深度学习(三):PyTorch 损失函数:按任务分类的实用指南

文章目录

PyTorch 损失函数:按任务分类的实用指南

损失函数的核心作用是 "衡量模型预测与真实结果的差异",选择的关键是匹配任务类型(回归 / 分类)和模型输出形式。以下按 "任务场景" 分类整理,清晰标注每个损失函数的适用场景、核心特点及激活函数搭配要求。

一、回归任务专用损失函数

回归任务的目标是预测 "连续数值"(如房价、温度、销量),对应的损失函数聚焦于 "数值差异的量化"。

损失函数 核心特点 适用场景 激活函数搭配
MSELoss 计算 "预测值与真实值差值的平方的平均值"(均方误差),对大误差惩罚更严厉 普通回归任务(如预测房价) 无需额外激活函数
SmoothL1Loss 结合 L1 损失(对异常值不敏感)和 MSE 损失(梯度平滑)的优点,训练更稳定,泛化能力强 回归任务(尤其数据含异常值时,如目标检测中的边界框回归) 无需额外激活函数

二、分类任务专用损失函数

分类任务的目标是预测 "离散类别"(如猫 / 狗 / 鸟、垃圾邮件 / 正常邮件),按 "类别数量" 分为 "多分类" 和 "二分类",损失函数需匹配类别输出形式。

2.1 多分类任务(类别数 ≥3,如 10 类图像分类)

损失函数 核心特点 激活函数搭配(关键!)
CrossEntropyLoss 自带 log_softmax 操作,直接接收模型输出的 "原始得分(Logits)",无需额外加激活函数;本质是 "log_softmax + NLLLoss" 模型最后一层不添加激活函数(直接输出 Logits)
NLLLoss 计算 "负对数似然损失",需先将模型输出转为 "对数概率" 模型最后一层必须加 log_softmax 激活函数

2.2 二分类任务(类别数 =2,如判断是否为垃圾邮件)

损失函数 核心特点 激活函数搭配(关键!)
BCELoss 计算 "二元交叉熵",需接收模型输出的 "0~1 概率值" 模型最后一层必须加 sigmoid 激活函数(将输出压缩到 0~1)
BCEWithLogitsLoss 自带 sigmoid 操作,直接接收模型输出的 "原始得分(Logits)",避免单独加激活函数时的数值不稳定问题;本质是 "sigmoid + BCELoss" 模型最后一层不添加激活函数(直接输出 Logits)

三、特殊任务专用损失函数

针对特定场景设计,解决非常规的匹配或对齐问题。

损失函数 核心用途 典型场景
CTCLoss 用于 "输入与输出长度不固定时的对齐",无需提前标注输入(如音频)与输出(如文本)的一一对应关系 语音识别(音频→文字)、手写体识别(图像→文字)

四、快速选择总结(避坑指南)

  1. 做回归 :数据无异常值用 MSELoss,有异常值或需稳定训练用 SmoothL1Loss
  2. 做多分类 :想省事儿直接用 CrossEntropyLoss(自带激活),若需自定义 log_softmax 再用 NLLLoss
  3. 做二分类 :模型没加 sigmoidBCEWithLogitsLoss(自带激活),加了 sigmoidBCELoss
  4. 语音 / 手写识别 :输入输出长度不固定时,直接用 CTCLoss