【函数讲解】botorch中的函数 is_non_dominated():用于计算非支配(non-dominated)前沿

默认求最大的Pareto前沿

python 复制代码
        # 获取训练目标值,计算Pareto前沿(非支配解集合),然后从样本中提取出Pareto最优解。
        train_obj = self.samples[1]
        pareto_mask = is_non_dominated(train_obj)
        pareto_y = train_obj[pareto_mask]

源码

这里用到了一个函数 is_non_dominated(),来看下该函数的源码:

python 复制代码
from __future__ import annotations

import torch
from torch import Tensor


def is_non_dominated(Y: Tensor, deduplicate: bool = True) -> Tensor:
    r"""Computes the non-dominated front.

    Note: this assumes maximization.

    Args:
        输入:张量 Y,其维度为 (batch_shape) x n x m,这里 n 代表样本数量,m 代表每个样本的目标数量。,其中括号括住的batch_shape意思是可选,可以有这个维度或者没有
        Y: A `(batch_shape) x n x m`-dim tensor of outcomes.
        deduplicate: A boolean indicating whether to only return unique points on the pareto frontier.

    Returns:
        返回:布尔张量,指示每个样本是否是非支配点
        A `(batch_shape) x n`-dim boolean tensor indicating whether each point is non-dominated.
    """
    Y1 = Y.unsqueeze(-3)
    Y2 = Y.unsqueeze(-2)
    dominates = (Y1 >= Y2).all(dim=-1) & (Y1 > Y2).any(dim=-1)
    nd_mask = ~(dominates.any(dim=-1))
    if deduplicate:
        # remove duplicates
        # find index of first occurrence of each unique element
        indices = (Y1 == Y2).all(dim=-1).long().argmax(dim=-1)
        keep = torch.zeros_like(nd_mask)
        keep.scatter_(dim=-1, index=indices, value=1.0)
        return nd_mask & keep
    return nd_mask

示例:

有一组解,每个解有两个目标值,找出这组解中的非支配解:

python 复制代码
import torch
from botorch.utils.multi_objective import is_non_dominated

# 假设有5个解,每个解有2个目标
Y = torch.tensor([
    [0.5, 0.7],
    [0.6, 0.6],
    [0.8, 0.3],
    [0.4, 0.9],
    [0.7, 0.5]
])

# 调用 is_non_dominated 函数
non_dominated_mask = is_non_dominated(Y)

# 打印非支配解
print("Non-dominated solutions:", Y[non_dominated_mask])

实例:

这里就是先拿到所有的目标值,然后计算哪些是Pareto(True或者False),最后再原始数据中选出所有True的数据

python 复制代码
        # 获取训练目标值,计算Pareto前沿(非支配解集合),然后从样本中提取出Pareto最优解。
        train_obj = self.samples[1]
        pareto_mask = is_non_dominated(train_obj)
        pareto_y = train_obj[pareto_mask]
相关推荐
波动几何1 小时前
CAD制图编辑器cad-editor
人工智能
耿雨飞7 小时前
第三章:LangChain Classic vs. 新版 LangChain —— 架构演进与迁移指南
人工智能·架构·langchain
BizViewStudio7 小时前
甄选 2026:AI 重构新媒体代运营行业的三大核心变革与落地路径
大数据·人工智能·新媒体运营·媒体
俊哥V7 小时前
AI一周事件 · 2026年4月8日至4月14日
人工智能·ai
GitCode官方8 小时前
G-Star Gathering Day 杭州站回顾
人工智能·开源·atomgit
宇擎智脑科技8 小时前
开源 AI Agent 架构设计对比:Python 单体 vs TypeScript 插件化
人工智能·openclaw·hermes agent
房开民8 小时前
可变参数模板
java·开发语言·算法
不知名的忻9 小时前
Morris遍历(力扣第99题)
java·算法·leetcode·morris遍历
冷色系里的一抹暖调9 小时前
OpenClaw Docker部署避坑指南:服务启动成功但网页打不开?
人工智能·docker·容器·openclaw
曹牧9 小时前
自动编程AI落地方案‌
人工智能