【pytorch损失函数(7)】损失函数的选择需结合属性类型(分类/回归)、任务粒度(单标签/多标签)以及数据特性(类别平衡性)

在属性(如颜色、类型等)的监督学习中,损失函数的选择需结合属性类型(分类/回归)、任务粒度(单标签/多标签)以及数据特性(类别平衡性)。

文章目录

      • [1. 分类属性(如颜色、类型、方向)](#1. 分类属性(如颜色、类型、方向))
        • [(1) 单标签分类(互斥类别)](#(1) 单标签分类(互斥类别))
        • [(2) 多标签分类(非互斥类别)](#(2) 多标签分类(非互斥类别))
      • [2. 回归属性(如长度、曲率、角度)](#2. 回归属性(如长度、曲率、角度))
        • [(1) 连续值预测](#(1) 连续值预测)
        • [(2) 方向或角度预测](#(2) 方向或角度预测)
      • [3. 多任务联合训练](#3. 多任务联合训练)
        • [(1) 加权多损失融合](#(1) 加权多损失融合)
        • [(2) 自适应权重(如Uncertainty Weighting)](#(2) 自适应权重(如Uncertainty Weighting))
      • 总结

1. 分类属性(如颜色、类型、方向)

(1) 单标签分类(互斥类别)
  • 场景:每个属性仅属于一个类别(如车道颜色只能是红/绿/蓝中的一种)。
  • 推荐损失:
    • 交叉熵损失(Cross-Entropy Loss)
      标准分类损失,适用于类别平衡的数据。

      python 复制代码
      loss_fn = nn.CrossEntropyLoss(weight=class_weights)  # 可选类别权重
      loss = loss_fn(preds, gts)  # preds: [B, C], gts: [B](类别索引)
    • Focal Loss
      解决类别不平衡问题(如某些颜色样本极少)。

      python 复制代码
      loss_fn = FocalLoss(alpha=0.25, gamma=2.0)  # 抑制易分类样本的权重
(2) 多标签分类(非互斥类别)
  • 场景:属性可同时属于多个类别(如车道类型="虚线"+"黄色")。
  • 推荐损失:
    • 二元交叉熵(Binary Cross-Entropy, BCE)
      每个类别独立计算概率。

      python 复制代码
      loss_fn = nn.BCEWithLogitsLoss()  # 内置Sigmoid
      loss = loss_fn(preds, gts)  # preds/gts: [B, C](C为类别数)
    • Dice Loss
      适用于类别高度不平衡的分割任务(如罕见属性)。

      python 复制代码
      def dice_loss(pred, gt, smooth=1e-6):
          pred = torch.sigmoid(pred)
          intersection = (pred  gt).sum()
          return 1 - (2.  intersection + smooth) / (pred.sum() + gt.sum() + smooth)

2. 回归属性(如长度、曲率、角度)

(1) 连续值预测
  • 场景:需预测数值型属性(如车道长度、曲率半径)。
  • 推荐损失:
    • L1 Loss(MAE)
      对异常值鲁棒,输出更稳定。

      python 复制代码
      loss_fn = nn.L1Loss()
      loss = loss_fn(preds, gts)  # preds/gts: [B, 1]
    • Smooth L1 Loss
      结合L1和L2的优点,避免梯度爆炸。

      python 复制代码
      loss_fn = nn.SmoothL1Loss(beta=0.1)  # beta控制过渡区间
(2) 方向或角度预测
  • 场景:预测车道方向(0°~360°),需处理周期性。
  • 推荐损失:
    • Huber Loss for Angles
      将角度差转换为周期性误差。

      python 复制代码
      def angle_loss(pred, gt):
          diff = torch.abs(pred - gt)
          return torch.mean(torch.min(diff, 360 - diff))  # 最小化环形距离

3. 多任务联合训练

若需同时优化多个属性(如颜色+类型+方向),可采用:

(1) 加权多损失融合
python 复制代码
loss_weights = {'color': 1.0, 'type': 0.5, 'angle': 0.2}  # 手动调参
total_loss = sum(loss_weights[k]  loss_fn[k](preds[k], gts[k]) for k in loss_weights)
(2) 自适应权重(如Uncertainty Weighting)
python 复制代码
# 学习每个任务的不确定性权重
log_vars = nn.ParameterDict({k: nn.Parameter(torch.zeros(1)) for k in ['color', 'type']})
loss = sum(0.5  torch.exp(-log_vars[k])  loss_fn[k](preds[k], gts[k]) + log_vars[k] for k in log_vars)

  1. 类别不平衡:
    • 分类任务优先选择 Focal Loss 或加权 CrossEntropy
    • 回归任务使用 HuberSmooth L1
  2. 输出归一化:
    • 角度预测需归一化到 [0, 360] 或使用 sin/cos 编码。
  3. 多任务权衡:
    • 通过网格搜索或不确定性加权调整损失比例。

总结

  • 分类属性:CrossEntropy(平衡数据)或 Focal Loss(不平衡数据)。
  • 回归属性:Smooth L1(通用)或自定义周期性损失(角度)。
  • 多任务:加权求和或自适应不确定性加权。
相关推荐
消晨消晨11 小时前
MONAI初上手——模型构建
pytorch·python·monai
keineahnung234514 小时前
PyTorch symbolic_shapes 模組的 is_contiguous 從哪來?── sizes_strides_user 安裝與實作解析
人工智能·pytorch·python·深度学习
轻口味15 小时前
HarmonyOS 6.1 全栈实战录 - 09 极光底座:ArkWeb 6.1 性能、安全与视觉插帧全特性深度实战
pytorch·安全·harmonyos
轻口味16 小时前
HarmonyOS 6.1 全栈实战录 - 13 流量增长新引擎:全场景归因与 App Linking 链接深度开发实战
pytorch·深度学习·harmonyos
数据皮皮侠AI19 小时前
基于经济学季刊方法测算的中国城市蔓延指数
大数据·人工智能·笔记·数据挖掘·回归
声声codeGrandMaster1 天前
seq2seq概念和数据集处理
人工智能·pytorch·python·算法·ai
2zcode1 天前
基于机器视觉与YOLO11的服装厂废料(边角料)分类检测系统(数据集+UI界面+训练代码+数据分析)
jvm·分类·数据分析·机器视觉·yolo11·服装厂废料
阿正的梦工坊1 天前
深入理解 PyTorch 中的 unsqueeze 操作
人工智能·pytorch·python
技术小黑1 天前
CNN算法实战系列03 | DenseNet121算法实战与解析
pytorch·深度学习·算法·cnn