YOLO26 自定义损失函数 重写 init_criterion 方法 损失类不继承基类

YOLO26 自定义损失函数 重写 init_criterion 方法 损失类不继承基类

flyfish

YOLO 的 ClassificationModel 基类中,专门设计了 init_criterion 方法:模型初始化时,会自动调用该方法生成损失函数并赋值给 self.criterion 属性。

通过继承模型基类,重写 init_criterion 方法,让模型在初始化阶段就原生创建自定义损失函数。

二、实现步骤

步骤1:自定义损失类
步骤2:自定义模型类,重写损失初始化方法

继承 ClassificationModel 基类,重写 init_criterion 方法,返回自定义损失的实例。

模型初始化时会自动调用该方法,将自定义损失设为模型的原生属性。

步骤3:自定义训练器,重写模型创建方法

继承官方 ClassificationTrainer,重写 get_model 方法,返回自定义的模型类实例,让训练器全程使用自定义模型进行训练。

步骤4:配置参数,启动训练

按照官方格式配置超参数,使用自定义训练器启动训练。

代码

python 复制代码
import torch
from torch.nn import CrossEntropyLoss
from ultralytics.models.yolo.classify.train import ClassificationTrainer
from ultralytics.nn.tasks import ClassificationModel


class FocalClassificationLoss:
    """
    自定义聚焦损失函数(Focal Loss)
    完全兼容YOLO分类框架接口
    """
    def __init__(self, gamma=2.0, alpha=None, label_smoothing=0.1):
        """
        初始化损失函数超参数
        :param gamma: 聚焦因子,调节难分样本的关注权重
        :param alpha: 类别平衡权重
        :param label_smoothing: 标签平滑系数
        """
        self.gamma = gamma
        self.alpha = torch.tensor(alpha) if alpha else None
        self.label_smoothing = label_smoothing
        self.ce = CrossEntropyLoss(label_smoothing=label_smoothing, reduction="none")
        self.call_count = 0

    def __call__(self, preds, batch):
        """
        YOLO框架标准损失计算接口
        :param preds: 模型预测输出
        :param batch: 批次数据,包含标签
        :return: 反向传播用损失、日志用损失
        """
        self.call_count += 1
        
        # 兼容多输出格式,提取有效预测张量
        preds = preds[1] if isinstance(preds, (list, tuple)) else preds
        # 计算基础交叉熵损失
        ce_loss = self.ce(preds, batch["cls"])
        # 计算真实标签的预测概率
        pt = torch.exp(-ce_loss)
        # Focal Loss计算
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        
        # 应用类别平衡权重
        if self.alpha is not None:
            alpha_t = self.alpha.to(preds.device)[batch["cls"]]
            focal_loss = alpha_t * focal_loss
        
        loss = focal_loss.mean()
        return loss, loss.detach()


class CustomClassificationModel(ClassificationModel):
    """
    自定义分类模型
    继承官方基类,重写损失函数初始化方法,原生支持自定义损失
    """
    def init_criterion(self):
        """
        重写官方损失初始化方法
        模型初始化时会自动调用该方法创建损失函数
        """
        return FocalClassificationLoss(gamma=2.0, alpha=[2.3, 0.64], label_smoothing=0.1)


class CustomModelTrainer(ClassificationTrainer):
    """
    自定义训练器
    重写模型创建方法,返回自定义的分类模型
    """
    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        重写模型创建逻辑
        :param cfg: 模型配置文件
        :param weights: 预训练权重路径
        :param verbose: 是否打印模型详情
        :return: 自定义分类模型实例
        """
        # 创建自定义模型实例
        model = CustomClassificationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose)
        # 加载预训练权重
        if weights:
            model.load(weights)
        return model


def run_training():
    # 训练超参数配置
    overrides = dict(
        model="yolo26s-cls.pt",
        data="./cats_and_dogs_filtered",
        imgsz=224,
        epochs=5,
        batch=8,
        workers=2,
        device=0,
        amp=False,
        patience=0,
        save=False,
        project="./runs/classify/cat_dog_standard",
        name="custom_model",
        exist_ok=True,
        pretrained=True,
    )

    # 启动训练
    trainer = CustomModelTrainer(overrides=overrides.copy())
    trainer.train()
    
    # 验证自定义损失生效情况
    if hasattr(trainer.model, 'criterion') and hasattr(trainer.model.criterion, 'call_count'):
        print(f"自定义损失函数累计调用次数: {trainer.model.criterion.call_count}")


if __name__ == "__main__":
    run_training()
相关推荐
装不满的克莱因瓶1 小时前
RLHF中的PPO算法——大语言模型对齐优化的核心引擎
人工智能·python·深度学习·算法·机器学习·语言模型·自然语言处理
c_lb72881 小时前
期货主连研究具体月实盘:KQ 连续与标的月份偏差怎么记
python·区块链
绘梨衣5471 小时前
采集基类设计遇到的描述符bug
爬虫·python·bug
TechWayfarer2 小时前
IP精准定位服务在保险行业的接入实践:区域需求洞察与精准服务
数据库·python·tcp/ip·flask
Li#2 小时前
AI编写操作使用说明书需要用到的工具和能力
python·ai编程·ai写作
红宝村村长2 小时前
torch.autograd.Function.apply()
开发语言·python
花间相见2 小时前
【LeetCode01】—— 无重复字符的最长子串:滑动窗口经典题详解
python·算法·leetcode
何以解忧,唯有..2 小时前
Python 中的继承机制:从基础到高级用法详解
java·开发语言·python
try2find3 小时前
agent环境安装spacy
python·智能体