文章目录
- [PyTorch 损失函数:按任务分类的实用指南](#PyTorch 损失函数:按任务分类的实用指南)
-
- 一、回归任务专用损失函数
- 二、分类任务专用损失函数
-
- [2.1 多分类任务(类别数 ≥3,如 10 类图像分类)](#2.1 多分类任务(类别数 ≥3,如 10 类图像分类))
- [2.2 二分类任务(类别数 =2,如判断是否为垃圾邮件)](#2.2 二分类任务(类别数 =2,如判断是否为垃圾邮件))
- 三、特殊任务专用损失函数
- 四、快速选择总结(避坑指南)
PyTorch 损失函数:按任务分类的实用指南
损失函数的核心作用是 "衡量模型预测与真实结果的差异",选择的关键是匹配任务类型(回归 / 分类)和模型输出形式。以下按 "任务场景" 分类整理,清晰标注每个损失函数的适用场景、核心特点及激活函数搭配要求。
一、回归任务专用损失函数
回归任务的目标是预测 "连续数值"(如房价、温度、销量),对应的损失函数聚焦于 "数值差异的量化"。
损失函数 | 核心特点 | 适用场景 | 激活函数搭配 |
---|---|---|---|
MSELoss | 计算 "预测值与真实值差值的平方的平均值"(均方误差),对大误差惩罚更严厉 | 普通回归任务(如预测房价) | 无需额外激活函数 |
SmoothL1Loss | 结合 L1 损失(对异常值不敏感)和 MSE 损失(梯度平滑)的优点,训练更稳定,泛化能力强 | 回归任务(尤其数据含异常值时,如目标检测中的边界框回归) | 无需额外激活函数 |
二、分类任务专用损失函数
分类任务的目标是预测 "离散类别"(如猫 / 狗 / 鸟、垃圾邮件 / 正常邮件),按 "类别数量" 分为 "多分类" 和 "二分类",损失函数需匹配类别输出形式。
2.1 多分类任务(类别数 ≥3,如 10 类图像分类)
损失函数 | 核心特点 | 激活函数搭配(关键!) |
---|---|---|
CrossEntropyLoss | 自带 log_softmax 操作,直接接收模型输出的 "原始得分(Logits)",无需额外加激活函数;本质是 "log_softmax + NLLLoss" |
模型最后一层不添加激活函数(直接输出 Logits) |
NLLLoss | 计算 "负对数似然损失",需先将模型输出转为 "对数概率" | 模型最后一层必须加 log_softmax 激活函数 |
2.2 二分类任务(类别数 =2,如判断是否为垃圾邮件)
损失函数 | 核心特点 | 激活函数搭配(关键!) |
---|---|---|
BCELoss | 计算 "二元交叉熵",需接收模型输出的 "0~1 概率值" | 模型最后一层必须加 sigmoid 激活函数(将输出压缩到 0~1) |
BCEWithLogitsLoss | 自带 sigmoid 操作,直接接收模型输出的 "原始得分(Logits)",避免单独加激活函数时的数值不稳定问题;本质是 "sigmoid + BCELoss" |
模型最后一层不添加激活函数(直接输出 Logits) |
三、特殊任务专用损失函数
针对特定场景设计,解决非常规的匹配或对齐问题。
损失函数 | 核心用途 | 典型场景 |
---|---|---|
CTCLoss | 用于 "输入与输出长度不固定时的对齐",无需提前标注输入(如音频)与输出(如文本)的一一对应关系 | 语音识别(音频→文字)、手写体识别(图像→文字) |
四、快速选择总结(避坑指南)
- 做回归 :数据无异常值用
MSELoss
,有异常值或需稳定训练用SmoothL1Loss
; - 做多分类 :想省事儿直接用
CrossEntropyLoss
(自带激活),若需自定义log_softmax
再用NLLLoss
; - 做二分类 :模型没加
sigmoid
用BCEWithLogitsLoss
(自带激活),加了sigmoid
用BCELoss
; - 语音 / 手写识别 :输入输出长度不固定时,直接用
CTCLoss
。