YOLO26 自定义损失函数 分类任务自定义损失的接口约定

YOLO26 自定义损失函数 分类任务自定义损失的接口约定

flyfish

这个约定是 分类训练循环中调用损失函数的固定调用契约,自定义损失类必须完全符合这个契约,才能被框架正常识别、调用,不会出现参数不匹配、返回值解包失败等报错。

分别约束了「调用形式、入参格式、返回值格式、挂载位置」:

1. 调用形式约定:必须实现 __call__ 方法,实例可直接调用

框架在训练迭代中,会以函数调用的方式直接使用损失函数实例,框架内部的调用逻辑(简化版)为:

python 复制代码
# 训练循环中框架的固定调用写法
loss, loss_items = self.model.criterion(preds, batch)

因此自定义损失类必须实现 __call__ 方法,让类的实例可以像函数一样被直接调用。

普通 PyTorch 损失函数继承 nn.Module,通过 forward 方法实现计算,本质也是依赖 nn.Module 自带的 __call__ 机制。YOLO 分类的损失不强制要求继承 nn.Module,只要类实现了 __call__ 即可正常工作。

2. 入参格式约定:固定接收 2 个参数,顺序不可修改

损失函数的 __call__ 方法必须固定接收两个入参,顺序为 (模型预测结果, 批次数据字典),不可调换、不可增减参数。

第一个参数:preds(模型前向输出)

分类模型前向传播的输出结果

格式兼容:不同版本 Ultralytics 中,分类模型的输出有两种形式:

  1. 直接返回形状为 [batch_size, 类别数] 的分类 logits 张量
  2. 返回元组/列表,通常结构为 [中间特征图, 最终分类logits],有效预测值在第二个位置
    对应代码中的兼容处理:
python 复制代码
# 兼容两种输出格式,提取最终的分类预测张量
preds = preds[1] if isinstance(preds, (list, tuple)) else preds
第二个参数:batch(批次数据字典)

数据加载器(DataLoader)返回的单批次数据,固定为字典格式

分类任务固定包含键:"cls",对应形状为 [batch_size]类别索引标签 (不是 one-hot 编码)

损失计算时,必须通过 batch["cls"] 取出真实标签

示例:batch=8 的二分类任务中,batch["cls"] 是形如 tensor([0, 1, 0, 1, 1, 0, 0, 1]) 的一维张量

3. 返回值约定:必须返回二元组

必须返回 2 个值,框架会自动解包,少返回/多返回都会直接触发报错。

返回值顺序 作用 格式要求
第一个值 用于反向传播,更新模型权重 必须是带计算图的标量张量(可求导),通常是批次内所有样本损失求均值后的结果
第二个值 用于训练日志统计、进度条打印、指标文件记录 必须是分离梯度后的损失值.detach()),不参与计算图,避免显存泄漏

对应代码中的标准实现:

python 复制代码
loss = focal_loss.mean()       # 第一个值:带梯度的标量损失,用于反向传播更新参数
return loss, loss.detach()     # 第二个值:脱梯度的损失值,仅用于日志打印和统计

目标检测任务中会返回多个损失项的字典,但分类任务只有单损失,直接返回脱梯度的标量即可。

4. 挂载位置约定:必须挂载为模型的 criterion 属性

框架是通过 self.model.criterion 来定位损失函数的,因此无论你用哪种注入方式,最终都要把自定义损失的实例,赋值给模型实例的 .criterion 属性。

子类化标准法:模型初始化时自动调用 init_criterion() 生成实例并赋值给 self.criterion,属于框架原生的标准流程

示例

下面是一个最小化的、完全符合接口约定的自定义损失(包装原生交叉熵),可以直接接入YOLO分类训练:

python 复制代码
from torch.nn import CrossEntropyLoss

class SimpleCustomLoss:
    def __init__(self, label_smoothing=0.1):
        self.ce = CrossEntropyLoss(label_smoothing=label_smoothing)
    
    # 约定1:实现 __call__ 方法
    # 约定2:固定入参顺序 preds, batch
    def __call__(self, preds, batch):
        # 兼容模型输出格式
        preds = preds[1] if isinstance(preds, (list, tuple)) else preds
        # 从 batch 字典中取出分类标签
        loss = self.ce(preds, batch["cls"])
        # 约定3:返回 (带梯度损失, 脱梯度损失) 二元组
        return loss, loss.detach()
相关推荐
Together_CZ1 天前
Ultralytics YOLO26: Unified Real-Time End-to-End Vision Models——统一的实时端到端视觉模型
ultralytics·end-to-end·unified·yolo26·real-time·统一的实时端到端视觉模型·vision models
YOLO视觉与编程1 天前
jetson orin nano烧录jetpack7.2系统
人工智能·深度学习·yolo·目标检测·机器学习
stsdddd1 天前
YOLO系列目标检测数据集大全【第二十五期】
yolo·目标检测·目标跟踪
stsdddd2 天前
YOLO系列目标检测数据集大全【第二十二期】
yolo·目标检测·目标跟踪
王小王-1232 天前
基于 YOLOv8 与 Faster R-CNN 的红外图像行人检测系统设计与实现
yolo·目标检测·cnn·fasterrcnn·红外行人检测
stsdddd2 天前
YOLO系列目标检测数据集大全【第二十三期】
yolo·目标检测·目标跟踪
YOLO数据集集合2 天前
无人机航拍桥梁巡检数据集 | 桥梁结构缺陷检测 深度学习目标检测数据10338期
深度学习·yolo·目标检测·计算机视觉·无人机
前网易架构师-高司机2 天前
带标注的薄荷病叶数据集,识别率98.8%,3533张图,支持yolo,coco json,voc xml,文末有模型训练代码
yolo·数据集·缺陷·薄荷·叶子·风干·变质
爱睡懒觉的焦糖玛奇朵2 天前
【视觉检测之人员奔跑检测算法开发思路】
人工智能·python·深度学习·算法·yolo·视觉检测