YOLO26 自定义损失函数 分类任务自定义损失的接口约定
flyfish
这个约定是 分类训练循环中调用损失函数的固定调用契约,自定义损失类必须完全符合这个契约,才能被框架正常识别、调用,不会出现参数不匹配、返回值解包失败等报错。
分别约束了「调用形式、入参格式、返回值格式、挂载位置」:
1. 调用形式约定:必须实现 __call__ 方法,实例可直接调用
框架在训练迭代中,会以函数调用的方式直接使用损失函数实例,框架内部的调用逻辑(简化版)为:
python
# 训练循环中框架的固定调用写法
loss, loss_items = self.model.criterion(preds, batch)
因此自定义损失类必须实现 __call__ 方法,让类的实例可以像函数一样被直接调用。
普通 PyTorch 损失函数继承 nn.Module,通过 forward 方法实现计算,本质也是依赖 nn.Module 自带的 __call__ 机制。YOLO 分类的损失不强制要求继承 nn.Module,只要类实现了 __call__ 即可正常工作。
2. 入参格式约定:固定接收 2 个参数,顺序不可修改
损失函数的 __call__ 方法必须固定接收两个入参,顺序为 (模型预测结果, 批次数据字典),不可调换、不可增减参数。
第一个参数:preds(模型前向输出)
分类模型前向传播的输出结果
格式兼容:不同版本 Ultralytics 中,分类模型的输出有两种形式:
- 直接返回形状为
[batch_size, 类别数]的分类 logits 张量 - 返回元组/列表,通常结构为
[中间特征图, 最终分类logits],有效预测值在第二个位置
对应代码中的兼容处理:
python
# 兼容两种输出格式,提取最终的分类预测张量
preds = preds[1] if isinstance(preds, (list, tuple)) else preds
第二个参数:batch(批次数据字典)
数据加载器(DataLoader)返回的单批次数据,固定为字典格式
分类任务固定包含键:"cls",对应形状为 [batch_size] 的类别索引标签 (不是 one-hot 编码)
损失计算时,必须通过 batch["cls"] 取出真实标签
示例:batch=8 的二分类任务中,batch["cls"] 是形如 tensor([0, 1, 0, 1, 1, 0, 0, 1]) 的一维张量
3. 返回值约定:必须返回二元组
必须返回 2 个值,框架会自动解包,少返回/多返回都会直接触发报错。
| 返回值顺序 | 作用 | 格式要求 |
|---|---|---|
| 第一个值 | 用于反向传播,更新模型权重 | 必须是带计算图的标量张量(可求导),通常是批次内所有样本损失求均值后的结果 |
| 第二个值 | 用于训练日志统计、进度条打印、指标文件记录 | 必须是分离梯度后的损失值 (.detach()),不参与计算图,避免显存泄漏 |
对应代码中的标准实现:
python
loss = focal_loss.mean() # 第一个值:带梯度的标量损失,用于反向传播更新参数
return loss, loss.detach() # 第二个值:脱梯度的损失值,仅用于日志打印和统计
目标检测任务中会返回多个损失项的字典,但分类任务只有单损失,直接返回脱梯度的标量即可。
4. 挂载位置约定:必须挂载为模型的 criterion 属性
框架是通过 self.model.criterion 来定位损失函数的,因此无论你用哪种注入方式,最终都要把自定义损失的实例,赋值给模型实例的 .criterion 属性。
子类化标准法:模型初始化时自动调用 init_criterion() 生成实例并赋值给 self.criterion,属于框架原生的标准流程
示例
下面是一个最小化的、完全符合接口约定的自定义损失(包装原生交叉熵),可以直接接入YOLO分类训练:
python
from torch.nn import CrossEntropyLoss
class SimpleCustomLoss:
def __init__(self, label_smoothing=0.1):
self.ce = CrossEntropyLoss(label_smoothing=label_smoothing)
# 约定1:实现 __call__ 方法
# 约定2:固定入参顺序 preds, batch
def __call__(self, preds, batch):
# 兼容模型输出格式
preds = preds[1] if isinstance(preds, (list, tuple)) else preds
# 从 batch 字典中取出分类标签
loss = self.ce(preds, batch["cls"])
# 约定3:返回 (带梯度损失, 脱梯度损失) 二元组
return loss, loss.detach()