YOLO26 自定义损失函数 重写 init_criterion 方法 损失类不继承基类
flyfish
YOLO 的 ClassificationModel 基类中,专门设计了 init_criterion 方法:模型初始化时,会自动调用该方法生成损失函数并赋值给 self.criterion 属性。
通过继承模型基类,重写 init_criterion 方法,让模型在初始化阶段就原生创建自定义损失函数。
二、实现步骤
步骤1:自定义损失类
步骤2:自定义模型类,重写损失初始化方法
继承 ClassificationModel 基类,重写 init_criterion 方法,返回自定义损失的实例。
模型初始化时会自动调用该方法,将自定义损失设为模型的原生属性。
步骤3:自定义训练器,重写模型创建方法
继承官方 ClassificationTrainer,重写 get_model 方法,返回自定义的模型类实例,让训练器全程使用自定义模型进行训练。
步骤4:配置参数,启动训练
按照官方格式配置超参数,使用自定义训练器启动训练。
代码
python
import torch
from torch.nn import CrossEntropyLoss
from ultralytics.models.yolo.classify.train import ClassificationTrainer
from ultralytics.nn.tasks import ClassificationModel
class FocalClassificationLoss:
"""
自定义聚焦损失函数(Focal Loss)
完全兼容YOLO分类框架接口
"""
def __init__(self, gamma=2.0, alpha=None, label_smoothing=0.1):
"""
初始化损失函数超参数
:param gamma: 聚焦因子,调节难分样本的关注权重
:param alpha: 类别平衡权重
:param label_smoothing: 标签平滑系数
"""
self.gamma = gamma
self.alpha = torch.tensor(alpha) if alpha else None
self.label_smoothing = label_smoothing
self.ce = CrossEntropyLoss(label_smoothing=label_smoothing, reduction="none")
self.call_count = 0
def __call__(self, preds, batch):
"""
YOLO框架标准损失计算接口
:param preds: 模型预测输出
:param batch: 批次数据,包含标签
:return: 反向传播用损失、日志用损失
"""
self.call_count += 1
# 兼容多输出格式,提取有效预测张量
preds = preds[1] if isinstance(preds, (list, tuple)) else preds
# 计算基础交叉熵损失
ce_loss = self.ce(preds, batch["cls"])
# 计算真实标签的预测概率
pt = torch.exp(-ce_loss)
# Focal Loss计算
focal_loss = ((1 - pt) ** self.gamma) * ce_loss
# 应用类别平衡权重
if self.alpha is not None:
alpha_t = self.alpha.to(preds.device)[batch["cls"]]
focal_loss = alpha_t * focal_loss
loss = focal_loss.mean()
return loss, loss.detach()
class CustomClassificationModel(ClassificationModel):
"""
自定义分类模型
继承官方基类,重写损失函数初始化方法,原生支持自定义损失
"""
def init_criterion(self):
"""
重写官方损失初始化方法
模型初始化时会自动调用该方法创建损失函数
"""
return FocalClassificationLoss(gamma=2.0, alpha=[2.3, 0.64], label_smoothing=0.1)
class CustomModelTrainer(ClassificationTrainer):
"""
自定义训练器
重写模型创建方法,返回自定义的分类模型
"""
def get_model(self, cfg=None, weights=None, verbose=True):
"""
重写模型创建逻辑
:param cfg: 模型配置文件
:param weights: 预训练权重路径
:param verbose: 是否打印模型详情
:return: 自定义分类模型实例
"""
# 创建自定义模型实例
model = CustomClassificationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose)
# 加载预训练权重
if weights:
model.load(weights)
return model
def run_training():
# 训练超参数配置
overrides = dict(
model="yolo26s-cls.pt",
data="./cats_and_dogs_filtered",
imgsz=224,
epochs=5,
batch=8,
workers=2,
device=0,
amp=False,
patience=0,
save=False,
project="./runs/classify/cat_dog_standard",
name="custom_model",
exist_ok=True,
pretrained=True,
)
# 启动训练
trainer = CustomModelTrainer(overrides=overrides.copy())
trainer.train()
# 验证自定义损失生效情况
if hasattr(trainer.model, 'criterion') and hasattr(trainer.model.criterion, 'call_count'):
print(f"自定义损失函数累计调用次数: {trainer.model.criterion.call_count}")
if __name__ == "__main__":
run_training()