前言
你是否遇到过这些困惑?
- 为什么用了
CrossEntropyLoss手动加Softmax,结果 loss 不降反升? BCELoss和BCEWithLogitsLoss到底有什么区别?- 多分类和多标签分类,损失函数该怎么选?
- target 到底该传索引还是 one-hot?
如果你也被这些问题困扰过,那这篇文章就是为你准备的。
本文将从原理到实践 ,用数据流图 + 数学公式 + 代码示例三位一体的方式,彻底讲清楚 PyTorch 中损失函数与激活函数的搭配规则。读完这篇文章,你将能够:
- 根据任务类型,快速选择正确的损失函数
- 理解每种损失函数的内部计算过程
- 避开常见的踩坑点
- 拿到可直接复用的代码模板
组合推荐
| 任务类型 | 推荐组合 | predict 形状 | target 形状/类型 | |
|---|---|---|---|---|
| 多分类(N选1) | Linear → CrossEntropyLoss |
[N, C] |
[N] LongTensor |
|
| 二分类 | Linear → BCEWithLogitsLoss |
[N, 1] |
[N, 1] FloatTensor |
|
| 多标签(N选M) | Linear → BCEWithLogitsLoss |
[N, L] |
[N, L] FloatTensor |
|
| 回归 | Linear → MSELoss |
[N, 1] |
[N, 1] FloatTensor |
注意:
CrossEntropyLoss内置 Softmax,不要手动加BCEWithLogitsLoss内置 Sigmoid,不要手动加CrossEntropyLoss的 target 是类别索引,不是 one-hot
一、快速查表
在深入原理之前,先给大家一个速查表。遇到具体场景时,可以直接来这里查。
按场景查组合
| 我要做什么? | 用什么损失函数? | |
|---|---|---|
| 图像分类(ImageNet 1000类) | CrossEntropyLoss | |
| 文本情感分析(正面/负面) | BCEWithLogitsLoss 或 CrossEntropyLoss | |
| 垃圾邮件检测(是/否) | BCEWithLogitsLoss | |
| 文章打标签(可多选:科技/体育/娱乐) | BCEWithLogitsLoss | |
| 预测房价(连续值) | MSELoss | |
| 目标检测框回归 | SmoothL1Loss | |
| 模型蒸馏 | KLDivLoss | |
| 正负样本极度不平衡(1:1000) | FocalLoss | |
| 人脸识别/图像检索 | TripletMarginLoss | |
| 语音识别/OCR | CTCLoss |
按损失函数查用法
| 损失函数 | 适用场景 | 是否内置激活 | predict | target |
|---|---|---|---|---|
| CrossEntropyLoss | 多分类(互斥) | Softmax | [N,C] logits |
[N] 索引 |
| BCEWithLogitsLoss | 二分类/多标签 | Sigmoid | [N,1]或[N,L] |
同形状,float |
| BCELoss | 二分类(不推荐) | 无,需手动 | [N,1] 概率 |
同形状,float |
| MSELoss | 回归 | 无 | [N,1] |
[N,1] |
| L1Loss | 回归(抗异常值) | 无 | [N,1] |
[N,1] |
| SmoothL1Loss | 目标检测bbox | 无 | [N,1] |
[N,1] |
二、多分类任务:CrossEntropyLoss
这是最常用的分类损失函数,适用于互斥的多分类场景(每个样本只属于一个类别)。
2.1 正确用法
python
import torch
import torch.nn as nn
class MultiClassModel(nn.Module):
def __init__(self, input_size, num_classes):
super().__init__() # 线性层直接输出,不加激活函数!
self.fc = nn.Linear(input_size, num_classes)
self.loss = nn.CrossEntropyLoss()
def forward(self, x, target=None):
predict = self.fc(x) # [N, C] 原始分数(logits)
if target is not None:
return self.loss(predict, target) # target: [N]
return predict
# 推理时获取预测类别
probs = torch.softmax(predict, dim=1)
pred_class = torch.argmax(probs, dim=1)
2.2 数据格式
predict: [N, C] → 每个样本对应 C 个类别的原始分数(logits)
target: [N] → 每个样本的真实类别索引 (0 ~ C-1)
注意:target 是索引,不是 one-hot 编码!
2.3 内部计算过程详解
很多初学者不理解 CrossEntropyLoss 内部到底做了什么。我们用一个具体例子来拆解:
输入数据:
predict = [[2.0, 1.0, 0.1]] # 1个样本,3分类的 logits target = [0] # 真实类别是第0类
┌─────────────────────────────────────────────────────────────┐
│ Step 1: Softmax 将 logits 转为概率分布 │├─────────────────────────────────────────────────────────────┤
│ │
│ e^2.0 = 7.39 │
│ e^1.0 = 2.72 │
│ e^0.1 = 1.11 │
│ ────────── │
│ sum = 11.22 │
│ │
│ p[0] = 7.39 / 11.22 = 0.66 │
│ p[1] = 2.72 / 11.22 = 0.24 │
│ p[2] = 1.11 / 11.22 = 0.10 │
│ │
│ 概率分布 p = [0.66, 0.24, 0.10] │└─────────────────────────────────────────────────────────────┘
↓┌─────────────────────────────────────────────────────────────┐
│ Step 2: 取真实类别的概率,计算负对数 │├─────────────────────────────────────────────────────────────┤
│ │
│ target = 0,所以取 p[0] = 0.66 ││ │
│ Loss = -log(0.66) = 0.42 │
│ │
└─────────────────────────────────────────────────────────────┘
公式总结:
CrossEntropyLoss=−log(eztarget∑jezj)\text{CrossEntropyLoss} = -\log\left(\frac{e^{z_{target}}}{\sum_{j} e^{z_j}}\right)CrossEntropyLoss=−log(∑jezjeztarget)
2.4 为什么这样设计有效?
让我们看看不同预测情况下的 loss 值:
| 预测情况 | 真实类别的概率 | Loss = -log§ | 模型更新幅度 |
|---|---|---|---|
| 预测正确,置信度高 | 0.95 | 0.05 | 几乎不更新 |
| 预测正确,置信度低 | 0.60 | 0.51 | 适度更新 |
| 预测错误 | 0.10 | 2.30 | 大幅更新 |
核心思想:损失函数的梯度会推动模型,让正确类别的分数不断变大!
2.6 为什么 target 是索引而不是 one-hot?
原始交叉熵公式是这样的:
L=−∑i=0C−1yi⋅log(pi)L = -\sum_{i=0}^{C-1} y_i \cdot \log(p_i)L=−i=0∑C−1yi⋅log(pi)
其中 yyy 是 one-hot 编码,如 [0, 1, 0]。
展开后:
L=−(0×log(p0)+1×log(p1)+0×log(p2))L = -(0 \times \log(p_0) + 1 \times \log(p_1) + 0 \times \log(p_2))L=−(0×log(p0)+1×log(p1)+0×log(p2))
发现问题没?只有一项是有效的! 因为 one-hot 中只有一个 1,其他都是 0。
所以可以简化为:
L=−log(ptarget)L = -\log(p_{target})L=−log(ptarget)
结论:只需要知道真实类别的索引即可,不需要完整的 one-hot,更省内存也更高效!
三、二分类任务:BCEWithLogitsLoss
二分类是多分类的特例,但有更高效的实现方式。
3.1 方式一:当作2类多分类
python
self.fc = nn.Linear(hidden_size, 2) # 输出2个分数
self.loss = nn.CrossEntropyLoss()
# predict: [N, 2], target: [N] (值为0或1)
这种方式可行,但输出维度多了一个,不够高效。
3.2 方式二:BCEWithLogitsLoss(推荐)
python
class BinaryClassModel(nn.Module):
def __init__(self, input_size):
super().__init__()
self.fc = nn.Linear(input_size, 1) # 只输出1个分数
self.loss = nn.BCEWithLogitsLoss()
def forward(self, x, target=None):
predict = self.fc(x) # [N, 1] 原始分数(logits)
if target is not None:
return self.loss(predict, target.float()) # target 必须是 float
return torch.sigmoid(predict) # 推理时转概率
# 推理时
probs = torch.sigmoid(predict)
pred_class = (probs > 0.5).long()
3.3 数据流图示
predict: [N, 1] target: [N, 1]
↓ ↓
┌─────────┐ ┌─────────┐
│ 2.0 │ 样本1 │ 1.0 │ 正样本
├─────────┤ ├─────────┤
│ -1.5 │ 样本2 │ 0.0 │ 负样本
├─────────┤ ├─────────┤
│ 0.5 │ 样本3 │ 1.0 │ 正样本
└─────────┘ └─────────┘
↓ Sigmoid(内置,不需要手动加!)
↓
┌─────────┐
│ 0.88 │ 样本1 → L = -log(0.88) = 0.13
├─────────┤
│ 0.18 │ 样本2 → L = -log(1-0.18) = -log(0.82) = 0.20
├─────────┤
│ 0.62 │ 样本3 → L = -log(0.62) = 0.48
└─────────┘
↓ Loss = (0.13 + 0.20 + 0.48) / 3 = 0.27```
3.4 数学原理
BCEWithLogitsLoss = Sigmoid + BCELoss
步骤1:Sigmoid 将分数转为概率
p=σ(z)=11+e−zp = \sigma(z) = \frac{1}{1 + e^{-z}}p=σ(z)=1+e−z1
步骤2:二元交叉熵
L=−[y⋅log(p)+(1−y)⋅log(1−p)]L = -[y \cdot \log(p) + (1-y) \cdot \log(1-p)]L=−[y⋅log(p)+(1−y)⋅log(1−p)]
公式展开理解:
- 当 y = 1 (正样本)时:L=−log(p)L = -\log(p)L=−log(p) → 希望 p 接近 1
- 当 y = 0 (负样本)时:L=−log(1−p)L = -\log(1-p)L=−log(1−p) → 希望 p 接近 0
3.5 为什么用 BCEWithLogitsLoss 而不是 Sigmoid + BCELoss?
很多人会问:我手动加 Sigmoid 再用 BCELoss 不行吗?
可以,但不推荐。 原因有三:
-
数值稳定性:当 sigmoid(x) 接近 0 或 1 时,log(sigmoid(x)) 会产生数值问题。BCEWithLogitsLoss 内部用 log-sum-exp 技巧避免了这个问题。
-
计算效率:一次前向传播完成,不需要中间存储 sigmoid 结果。
-
梯度更稳定:避免 sigmoid 饱和区的梯度消失问题。
四、多标签分类:BCEWithLogitsLoss
4.1 什么是多标签分类?
多标签 ≠ 多分类!
| 多分类 | 多标签 | |
|---|---|---|
| 定义 | N选1 | N选M |
| 例子 | 这张图是猫还是狗? | 这张图里有猫、有狗、有人? |
| 激活函数 | Softmax(概率和=1) | Sigmoid(每个标签独立) |
| 损失函数 | CrossEntropyLoss | BCEWithLogitsLoss |
4.2 正确用法
python
class MultiLabelModel(nn.Module):
def __init__(self, input_size, num_labels):
super().__init__()
self.fc = nn.Linear(input_size, num_labels)
self.loss = nn.BCEWithLogitsLoss()
def forward(self, x, target=None):
predict = self.fc(x) # [N, L] 每个标签的分数
if target is not None:
return self.loss(predict, target.float()) # target: [N, L] 多热编码
return torch.sigmoid(predict)
# 推理时
probs = torch.sigmoid(predict)
pred_labels = (probs > 0.5).long() # 每个位置独立判断
4.3 数据格式
predict: [N, L] → L 个标签的原始分数
target: [N, L] → 多热编码(multi-hot),如 [1, 0, 1, 0, 1]
4.4 数据流图示
predict: [N, L] target: [N, L]
↓ ↓
┌────────────────┐ ┌────────────────┐
│ 2.0 -1.0 0.5 │ 样本1 │ 1 0 1 │ 标签0和2为正
└────────────────┘ └────────────────┘
↓ Sigmoid(对每个位置独立计算)
↓
┌────────────────┐
│ 0.88 0.27 0.62 │
└────────────────┘
↓
每个位置独立计算BCE:
位置0: target=1, p=0.88 → -log(0.88) = 0.13
位置1: target=0, p=0.27 → -log(1-0.27) = 0.31
位置2: target=1, p=0.62 → -log(0.62) = 0.48
↓
Loss = (0.13 + 0.31 + 0.48) / 3 = 0.31
五、回归任务
回归任务预测的是连续值,不需要激活函数,线性层直接输出即可。
5.1 MSELoss(均方误差)
最常用的回归损失函数。
python
class RegressionModel(nn.Module):
def __init__(self, input_size):
super().__init__()
self.fc = nn.Linear(input_size, 1)
self.loss = nn.MSELoss()
def forward(self, x, target=None):
predict = self.fc(x) # [N, 1] 直接输出数值
if target is not None:
return self.loss(predict, target)
return predict
公式:
L=1N∑i=1N(yi−y^i)2L = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2L=N1i=1∑N(yi−y^i)2
特点:
- ✅ 处处可导,优化稳定
- ✅ 大误差时梯度大,收敛快
- ❌ 对异常值敏感(平方会放大误差)
5.2 L1Loss(平均绝对误差)
对异常值更鲁棒。
公式:
L=1N∑i=1N∣yi−y^i∣L = \frac{1}{N} \sum_{i=1}^{N} |y_i - \hat{y}_i|L=N1i=1∑N∣yi−y^i∣
特点:
- ✅ 对异常值鲁棒(不会放大大误差)
- ❌ 在 0 点不可导
- ❌ 梯度恒定,收敛可能慢
5.3 SmoothL1Loss(Huber Loss)
MSE 和 L1 的结合体,目标检测中的标配。
L={0.5(y−y^)2if ∣y−y^∣<1∣y−y^∣−0.5otherwiseL = \begin{cases} 0.5(y - \hat{y})^2 & \text{if } |y - \hat{y}| < 1 \\ |y - \hat{y}| - 0.5 & \text{otherwise} \end{cases}L={0.5(y−y^)2∣y−y^∣−0.5if ∣y−y^∣<1otherwise
5.4 回归损失对比
| 损失函数 | 对异常值 | 收敛速度 | 适用场景 | |
|---|---|---|---|---|
| MSELoss | 敏感 | 快 | 误差正态分布 | |
| L1Loss | 鲁棒 | 慢 | 有异常值 | |
| SmoothL1Loss | 鲁棒 | 中等 | 目标检测 bbox |
六、特殊任务
6.1 知识蒸馏:KLDivLoss
用大模型(Teacher)的软标签指导小模型(Student)训练。
python
T = 4.0 # 温度参数
loss = nn.KLDivLoss(reduction='batchmean')(
F.log_softmax(student_output / T, dim=1), # Student: log概率
F.softmax(teacher_output / T, dim=1) # Teacher: 概率
为什么要用温度 T?
原始 logits = [3.0, 1.0, 0.1]
T=1: softmax = [0.84, 0.11, 0.05] # 很尖锐,信息少
T=4: softmax = [0.45, 0.30, 0.25] # 更平滑,保留类别间关系
6.2 不平衡分类:Focal Loss
正负样本极度不平衡时(如目标检测中背景:前景 = 1000:1),普通交叉熵会被大量简单负样本主导。
FL(pt)=−αt(1−pt)γlog(pt)FL(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t)FL(pt)=−αt(1−pt)γlog(pt)
python
class FocalLoss(nn.Module):
def __init__(self, alpha=0.25, gamma=2.0):
super().__init__() self.alpha = alpha
self.gamma = gamma
def forward(self, predict, target):
ce_loss = F.cross_entropy(predict, target, reduction='none')
pt = torch.exp(-ce_loss)
focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
return focal_loss.mean()
核心思想:降低易分样本的权重,让模型专注于难分样本。
| 样本类型 | pt | (1−pt)2(1-p_t)^2(1−pt)2 | 效果 |
|---|---|---|---|
| 易分样本 | 0.95 | 0.0025 | 权重极小 |
| 难分样本 | 0.10 | 0.81 | 权重大 |
6.3 对比学习:TripletMarginLoss
学习样本间的相对距离,常用于人脸识别、图像检索。
python
loss = nn.TripletMarginLoss(margin=1.0)(anchor, positive, negative)
L=max(0,d(a,p)−d(a,n)+margin)L = \max(0, d(a, p) - d(a, n) + margin)L=max(0,d(a,p)−d(a,n)+margin)
目标:让同类样本靠近,不同类样本远离。
6.4 序列对齐:CTCLoss
输入输出长度不对齐的场景,如语音识别、OCR。
python
loss = nn.CTCLoss(blank=0)(log_probs, targets, input_lengths, target_lengths)
核心思想:允许多个输入对应同一输出,通过 blank 符号处理对齐。
七、常见错误及修正
7.1 ❌ CrossEntropyLoss + Softmax(重复激活)
python
# ❌ 错误写法
predict = F.softmax(self.fc(x), dim=1)
loss = F.cross_entropy(predict, target) # 内部又做了一次 softmax!
# ✅ 正确写法
predict = self.fc(x) # 直接用 logitsloss = F.cross_entropy(predict, target)
7.2 ❌ BCELoss 忘记 Sigmoid
python
# ❌ 错误写法
predict = self.fc(x) # logits,可能是负数!
loss = F.binary_cross_entropy(predict, target) # BCELoss 期望输入是概率
# ✅ 正确写法1:手动加 Sigmoidpredict = torch.sigmoid(self.fc(x))
loss = F.binary_cross_entropy(predict, target)
# ✅ 正确写法2:用 BCEWithLogitsLoss(推荐)
predict = self.fc(x)
loss = F.binary_cross_entropy_with_logits(predict, target)
7.3 ❌ target 类型错误
python
# CrossEntropyLoss 需要 LongTensortarget = torch.LongTensor([0, 1, 2]) # ✅
target = torch.FloatTensor([0.0, 1.0]) # ❌
# BCEWithLogitsLoss 需要 FloatTensortarget = torch.FloatTensor([0.0, 1.0]) # ✅
target = torch.LongTensor([0, 1]) # ❌
7.4 ❌ 维度不匹配
python
# CrossEntropyLoss
predict: [N, C] # ✅ 二维
target: [N] # ✅ 一维(不是 [N, 1]!)
# BCEWithLogitsLoss - 形状必须一致
predict: [N, 1] # ✅
target: [N, 1] # ✅
7.5 ❌ 多标签用了 CrossEntropyLoss
python
# ❌ 错误:CrossEntropyLoss 假设类别互斥
target = torch.FloatTensor([[1, 1, 0]]) # 同时属于类别0和1
loss = F.cross_entropy(predict, target) # 错!
# ✅ 正确:多标签用 BCEWithLogitsLossloss = F.binary_cross_entropy_with_logits(predict, target)
八、快速选择决策树
你的任务是什么?
│
├── 分类任务
│ │
│ ├── 类别是否互斥?
│ │ │
│ │ ├── 是(N选1)
│ │ │ └── 几个类别?
│ │ │ ├── 2类 → BCEWithLogitsLoss(更高效)
│ │ │ │ 或 CrossEntropyLoss│ │ │ └── >2类 → CrossEntropyLoss│ │ │
│ │ └── 否(N选M,多标签)
│ │ └── BCEWithLogitsLoss
│ │
│ └── 样本是否极度不平衡?
│ └── 是 → FocalLoss│
├── 回归任务
│ │
│ ├── 是否有异常值?
│ │ ├── 有 → L1Loss 或 SmoothL1Loss│ │ └── 无 → MSELoss│ │
│ └── 是目标检测 bbox?
│ └── 是 → SmoothL1Loss│
└── 特殊任务
├── 知识蒸馏 → KLDivLoss ├── 相似度学习 → TripletMarginLoss └── 序列对齐 → CTCLoss
九、核心要点总结
| 要点 | 说明 |
|---|---|
| CrossEntropyLoss 内置 Softmax | 输入是 logits,不要手动加 Softmax |
| BCEWithLogitsLoss 内置 Sigmoid | 输入是 logits,不要手动加 Sigmoid |
| CrossEntropyLoss 的 target 是索引 | 形状 [N],不是 one-hot |
| BCEWithLogitsLoss 的 target 是浮点数 | 需要 .float() 转换 |
| 多标签用 BCE,不是 CrossEntropy | 每个标签独立,不互斥 |
| 回归任务不需要激活函数 | 线性层直接输出数值 |
十、完整代码模板
多分类
python
class MultiClassModel(nn.Module):
def __init__(self, input_size, num_classes):
super().__init__()
self.fc = nn.Linear(input_size, num_classes)
self.loss = nn.CrossEntropyLoss()
def forward(self, x, target=None):
logits = self.fc(x)
if target is not None:
return self.loss(logits, target)
return logits
# 推理
probs = F.softmax(logits, dim=1)
pred_class = torch.argmax(probs, dim=1)
二分类
python
class BinaryClassModel(nn.Module):
def __init__(self, input_size):
super().__init__()
self.fc = nn.Linear(input_size, 1)
self.loss = nn.BCEWithLogitsLoss()
def forward(self, x, target=None):
logits = self.fc(x)
if target is not None:
return self.loss(logits, target.float())
return torch.sigmoid(logits)
# 推理
probs = torch.sigmoid(logits)
pred_class = (probs > 0.5).long()
多标签分类
python
class MultiLabelModel(nn.Module):
def __init__(self, input_size, num_labels):
super().__init__()
self.fc = nn.Linear(input_size, num_labels)
self.loss = nn.BCEWithLogitsLoss()
def forward(self, x, target=None):
logits = self.fc(x)
if target is not None:
return self.loss(logits, target.float())
return torch.sigmoid(logits)
# 推理
probs = torch.sigmoid(logits)
pred_labels = (probs > 0.5).long()
回归
python
class RegressionModel(nn.Module):
def __init__(self, input_size):
super().__init__()
self.fc = nn.Linear(input_size, 1)
self.loss = nn.MSELoss()
def forward(self, x, target=None):
predict = self.fc(x)
if target is not None:
return self.loss(predict, target)
return predict