【函数讲解】botorch中的函数 is_non_dominated():用于计算非支配(non-dominated)前沿

默认求最大的Pareto前沿

python 复制代码
        # 获取训练目标值,计算Pareto前沿(非支配解集合),然后从样本中提取出Pareto最优解。
        train_obj = self.samples[1]
        pareto_mask = is_non_dominated(train_obj)
        pareto_y = train_obj[pareto_mask]

源码

这里用到了一个函数 is_non_dominated(),来看下该函数的源码:

python 复制代码
from __future__ import annotations

import torch
from torch import Tensor


def is_non_dominated(Y: Tensor, deduplicate: bool = True) -> Tensor:
    r"""Computes the non-dominated front.

    Note: this assumes maximization.

    Args:
        输入:张量 Y,其维度为 (batch_shape) x n x m,这里 n 代表样本数量,m 代表每个样本的目标数量。,其中括号括住的batch_shape意思是可选,可以有这个维度或者没有
        Y: A `(batch_shape) x n x m`-dim tensor of outcomes.
        deduplicate: A boolean indicating whether to only return unique points on the pareto frontier.

    Returns:
        返回:布尔张量,指示每个样本是否是非支配点
        A `(batch_shape) x n`-dim boolean tensor indicating whether each point is non-dominated.
    """
    Y1 = Y.unsqueeze(-3)
    Y2 = Y.unsqueeze(-2)
    dominates = (Y1 >= Y2).all(dim=-1) & (Y1 > Y2).any(dim=-1)
    nd_mask = ~(dominates.any(dim=-1))
    if deduplicate:
        # remove duplicates
        # find index of first occurrence of each unique element
        indices = (Y1 == Y2).all(dim=-1).long().argmax(dim=-1)
        keep = torch.zeros_like(nd_mask)
        keep.scatter_(dim=-1, index=indices, value=1.0)
        return nd_mask & keep
    return nd_mask

示例:

有一组解,每个解有两个目标值,找出这组解中的非支配解:

python 复制代码
import torch
from botorch.utils.multi_objective import is_non_dominated

# 假设有5个解,每个解有2个目标
Y = torch.tensor([
    [0.5, 0.7],
    [0.6, 0.6],
    [0.8, 0.3],
    [0.4, 0.9],
    [0.7, 0.5]
])

# 调用 is_non_dominated 函数
non_dominated_mask = is_non_dominated(Y)

# 打印非支配解
print("Non-dominated solutions:", Y[non_dominated_mask])

实例:

这里就是先拿到所有的目标值,然后计算哪些是Pareto(True或者False),最后再原始数据中选出所有True的数据

python 复制代码
        # 获取训练目标值,计算Pareto前沿(非支配解集合),然后从样本中提取出Pareto最优解。
        train_obj = self.samples[1]
        pareto_mask = is_non_dominated(train_obj)
        pareto_y = train_obj[pareto_mask]
相关推荐
tqs_123459 小时前
redis zset score的计算
java·算法
_Coin_-9 小时前
算法训练营DAY60 第十一章:图论part11
算法·图论
林木辛9 小时前
LeetCode热题 438.找到字符中所有字母异位词 (滑动窗口)
算法·leetcode
和鲸社区9 小时前
四大经典案例,入门AI算法应用,含分类、回归与特征工程|2025人工智能实训季初阶赛
人工智能·python·深度学习·算法·机器学习·分类·回归
IT古董10 小时前
【第五章:计算机视觉】1.计算机视觉基础-(3)卷积神经网络核心层与架构分析:卷积层、池化层、归一化层、激活层
人工智能·计算机视觉·cnn
黎燃10 小时前
AI生成音乐的创作逻辑深析:以AIVA为例
人工智能
dragoooon3410 小时前
[优选算法专题二——NO.16最小覆盖子串]
c++·算法·leetcode·学习方法
点云SLAM10 小时前
四元数 (Quaternion)在位姿(SE(3))表示下的各类导数(雅可比)知识(2)
人工智能·线性代数·算法·机器学习·slam·四元数·李群李代数
七芒星202310 小时前
ResNet(详细易懂解释):残差网络的革命性突破
人工智能·pytorch·深度学习·神经网络·学习·cnn
汉克老师10 小时前
第十四届蓝桥杯青少组C++选拔赛[2023.1.15]第二部分编程题(4 、移动石子)
c++·算法·蓝桥杯·蓝桥杯c++·c++蓝桥杯