【函数讲解】botorch中的函数 is_non_dominated():用于计算非支配(non-dominated)前沿

默认求最大的Pareto前沿

python 复制代码
        # 获取训练目标值,计算Pareto前沿(非支配解集合),然后从样本中提取出Pareto最优解。
        train_obj = self.samples[1]
        pareto_mask = is_non_dominated(train_obj)
        pareto_y = train_obj[pareto_mask]

源码

这里用到了一个函数 is_non_dominated(),来看下该函数的源码:

python 复制代码
from __future__ import annotations

import torch
from torch import Tensor


def is_non_dominated(Y: Tensor, deduplicate: bool = True) -> Tensor:
    r"""Computes the non-dominated front.

    Note: this assumes maximization.

    Args:
        输入:张量 Y,其维度为 (batch_shape) x n x m,这里 n 代表样本数量,m 代表每个样本的目标数量。,其中括号括住的batch_shape意思是可选,可以有这个维度或者没有
        Y: A `(batch_shape) x n x m`-dim tensor of outcomes.
        deduplicate: A boolean indicating whether to only return unique points on the pareto frontier.

    Returns:
        返回:布尔张量,指示每个样本是否是非支配点
        A `(batch_shape) x n`-dim boolean tensor indicating whether each point is non-dominated.
    """
    Y1 = Y.unsqueeze(-3)
    Y2 = Y.unsqueeze(-2)
    dominates = (Y1 >= Y2).all(dim=-1) & (Y1 > Y2).any(dim=-1)
    nd_mask = ~(dominates.any(dim=-1))
    if deduplicate:
        # remove duplicates
        # find index of first occurrence of each unique element
        indices = (Y1 == Y2).all(dim=-1).long().argmax(dim=-1)
        keep = torch.zeros_like(nd_mask)
        keep.scatter_(dim=-1, index=indices, value=1.0)
        return nd_mask & keep
    return nd_mask

示例:

有一组解,每个解有两个目标值,找出这组解中的非支配解:

python 复制代码
import torch
from botorch.utils.multi_objective import is_non_dominated

# 假设有5个解,每个解有2个目标
Y = torch.tensor([
    [0.5, 0.7],
    [0.6, 0.6],
    [0.8, 0.3],
    [0.4, 0.9],
    [0.7, 0.5]
])

# 调用 is_non_dominated 函数
non_dominated_mask = is_non_dominated(Y)

# 打印非支配解
print("Non-dominated solutions:", Y[non_dominated_mask])

实例:

这里就是先拿到所有的目标值,然后计算哪些是Pareto(True或者False),最后再原始数据中选出所有True的数据

python 复制代码
        # 获取训练目标值,计算Pareto前沿(非支配解集合),然后从样本中提取出Pareto最优解。
        train_obj = self.samples[1]
        pareto_mask = is_non_dominated(train_obj)
        pareto_y = train_obj[pareto_mask]
相关推荐
Shawn_Shawn3 小时前
mcp学习笔记(一)-mcp核心概念梳理
人工智能·llm·mcp
33三 三like5 小时前
《基于知识图谱和智能推荐的养老志愿服务系统》开发日志
人工智能·知识图谱
芝士爱知识a5 小时前
【工具推荐】2026公考App横向评测:粉笔、华图与智蛙面试App功能对比
人工智能·软件推荐·ai教育·结构化面试·公考app·智蛙面试app·公考上岸
腾讯云开发者6 小时前
港科大熊辉|AI时代的职场新坐标——为什么你应该去“数据稀疏“的地方?
人工智能
工程师老罗6 小时前
YoloV1数据集格式转换,VOC XML→YOLOv1张量
xml·人工智能·yolo
颜酱6 小时前
图结构完全解析:从基础概念到遍历实现
javascript·后端·算法
m0_736919107 小时前
C++代码风格检查工具
开发语言·c++·算法
yugi9878387 小时前
基于MATLAB强化学习的单智能体与多智能体路径规划算法
算法·matlab
Coder_Boy_7 小时前
技术让开发更轻松的底层矛盾
java·大数据·数据库·人工智能·深度学习