PyTorch 损失函数与激活函数的正确组合

前言

你是否遇到过这些困惑?

  • 为什么用了 CrossEntropyLoss 手动加 Softmax,结果 loss 不降反升?
  • BCELossBCEWithLogitsLoss 到底有什么区别?
  • 多分类和多标签分类,损失函数该怎么选?
  • target 到底该传索引还是 one-hot?

如果你也被这些问题困扰过,那这篇文章就是为你准备的。

本文将从原理到实践 ,用数据流图 + 数学公式 + 代码示例三位一体的方式,彻底讲清楚 PyTorch 中损失函数与激活函数的搭配规则。读完这篇文章,你将能够:

  1. 根据任务类型,快速选择正确的损失函数
  2. 理解每种损失函数的内部计算过程
  3. 避开常见的踩坑点
  4. 拿到可直接复用的代码模板

组合推荐

任务类型 推荐组合 predict 形状 target 形状/类型
多分类(N选1) LinearCrossEntropyLoss [N, C] [N] LongTensor
二分类 LinearBCEWithLogitsLoss [N, 1] [N, 1] FloatTensor
多标签(N选M) LinearBCEWithLogitsLoss [N, L] [N, L] FloatTensor
回归 LinearMSELoss [N, 1] [N, 1] FloatTensor

注意:

  1. CrossEntropyLoss 内置 Softmax,不要手动加
  2. BCEWithLogitsLoss 内置 Sigmoid,不要手动加
  3. CrossEntropyLoss 的 target 是类别索引,不是 one-hot

一、快速查表

在深入原理之前,先给大家一个速查表。遇到具体场景时,可以直接来这里查。

按场景查组合

我要做什么? 用什么损失函数?
图像分类(ImageNet 1000类) CrossEntropyLoss
文本情感分析(正面/负面) BCEWithLogitsLoss 或 CrossEntropyLoss
垃圾邮件检测(是/否) BCEWithLogitsLoss
文章打标签(可多选:科技/体育/娱乐) BCEWithLogitsLoss
预测房价(连续值) MSELoss
目标检测框回归 SmoothL1Loss
模型蒸馏 KLDivLoss
正负样本极度不平衡(1:1000) FocalLoss
人脸识别/图像检索 TripletMarginLoss
语音识别/OCR CTCLoss

按损失函数查用法

损失函数 适用场景 是否内置激活 predict target
CrossEntropyLoss 多分类(互斥) Softmax [N,C] logits [N] 索引
BCEWithLogitsLoss 二分类/多标签 Sigmoid [N,1][N,L] 同形状,float
BCELoss 二分类(不推荐) 无,需手动 [N,1] 概率 同形状,float
MSELoss 回归 [N,1] [N,1]
L1Loss 回归(抗异常值) [N,1] [N,1]
SmoothL1Loss 目标检测bbox [N,1] [N,1]

二、多分类任务:CrossEntropyLoss

这是最常用的分类损失函数,适用于互斥的多分类场景(每个样本只属于一个类别)。

2.1 正确用法

python 复制代码
import torch  
import torch.nn as nn  
  
class MultiClassModel(nn.Module):  
    def __init__(self, input_size, num_classes):        
	    super().__init__()        # 线性层直接输出,不加激活函数!  
        self.fc = nn.Linear(input_size, num_classes)        
        self.loss = nn.CrossEntropyLoss()  
    def forward(self, x, target=None):        
	    predict = self.fc(x)  # [N, C] 原始分数(logits)  
	        if target is not None:            
		        return self.loss(predict, target)  # target: [N]        
		return predict  
# 推理时获取预测类别  
		probs = torch.softmax(predict, dim=1)  
		pred_class = torch.argmax(probs, dim=1)  

2.2 数据格式

predict: [N, C] → 每个样本对应 C 个类别的原始分数(logits)

target: [N] → 每个样本的真实类别索引 (0 ~ C-1)
注意:target 是索引,不是 one-hot 编码!

2.3 内部计算过程详解

很多初学者不理解 CrossEntropyLoss 内部到底做了什么。我们用一个具体例子来拆解:

输入数据:
predict = [[2.0, 1.0, 0.1]] # 1个样本,3分类的 logits target = [0] # 真实类别是第0类

复制代码
┌─────────────────────────────────────────────────────────────┐  
│  Step 1: Softmax 将 logits 转为概率分布                       │├─────────────────────────────────────────────────────────────┤  
│                                                             │  
│   e^2.0 = 7.39                                              │  
│   e^1.0 = 2.72                                              │  
│   e^0.1 = 1.11                                              │  
│   ──────────                                                │  
│   sum   = 11.22                                             │  
│                                                             │  
│   p[0] = 7.39 / 11.22 = 0.66                                │  
│   p[1] = 2.72 / 11.22 = 0.24                                │  
│   p[2] = 1.11 / 11.22 = 0.10                                │  
│                                                             │  
│   概率分布 p = [0.66, 0.24, 0.10]                            │└─────────────────────────────────────────────────────────────┘  
                              ↓┌─────────────────────────────────────────────────────────────┐  
│  Step 2: 取真实类别的概率,计算负对数                          │├─────────────────────────────────────────────────────────────┤  
│                                                             │  
│   target = 0,所以取 p[0] = 0.66                             ││                                                             │  
│   Loss = -log(0.66) = 0.42                                  │  
│                                                             │  
└─────────────────────────────────────────────────────────────┘ 

公式总结:

CrossEntropyLoss=−log⁡(eztarget∑jezj)\text{CrossEntropyLoss} = -\log\left(\frac{e^{z_{target}}}{\sum_{j} e^{z_j}}\right)CrossEntropyLoss=−log(∑jezjeztarget)

2.4 为什么这样设计有效?

让我们看看不同预测情况下的 loss 值:

预测情况 真实类别的概率 Loss = -log§ 模型更新幅度
预测正确,置信度高 0.95 0.05 几乎不更新
预测正确,置信度低 0.60 0.51 适度更新
预测错误 0.10 2.30 大幅更新

核心思想:损失函数的梯度会推动模型,让正确类别的分数不断变大!

2.6 为什么 target 是索引而不是 one-hot?

原始交叉熵公式是这样的:

L=−∑i=0C−1yi⋅log⁡(pi)L = -\sum_{i=0}^{C-1} y_i \cdot \log(p_i)L=−i=0∑C−1yi⋅log(pi)

其中 yyy 是 one-hot 编码,如 [0, 1, 0]

展开后:
L=−(0×log⁡(p0)+1×log⁡(p1)+0×log⁡(p2))L = -(0 \times \log(p_0) + 1 \times \log(p_1) + 0 \times \log(p_2))L=−(0×log(p0)+1×log(p1)+0×log(p2))

发现问题没?只有一项是有效的! 因为 one-hot 中只有一个 1,其他都是 0。

所以可以简化为:
L=−log⁡(ptarget)L = -\log(p_{target})L=−log(ptarget)

结论:只需要知道真实类别的索引即可,不需要完整的 one-hot,更省内存也更高效!


三、二分类任务:BCEWithLogitsLoss

二分类是多分类的特例,但有更高效的实现方式。

3.1 方式一:当作2类多分类

python 复制代码
self.fc = nn.Linear(hidden_size, 2)  # 输出2个分数  
self.loss = nn.CrossEntropyLoss()  
# predict: [N, 2], target: [N] (值为0或1)  

这种方式可行,但输出维度多了一个,不够高效。

3.2 方式二:BCEWithLogitsLoss(推荐)

python 复制代码
class BinaryClassModel(nn.Module):  
    def __init__(self, input_size):        
	    super().__init__()        
	    self.fc = nn.Linear(input_size, 1)  # 只输出1个分数  
        self.loss = nn.BCEWithLogitsLoss()  
    def forward(self, x, target=None):        
	    predict = self.fc(x)  # [N, 1] 原始分数(logits)  
        if target is not None:            
	        return self.loss(predict, target.float())  # target 必须是 float        
		return torch.sigmoid(predict)  # 推理时转概率  
		
		
		
# 推理时  
probs = torch.sigmoid(predict)  
pred_class = (probs > 0.5).long()  

3.3 数据流图示

复制代码
predict: [N, 1]           target: [N, 1]  
     ↓                         ↓
┌─────────┐               ┌─────────┐  
│   2.0   │  样本1        │   1.0   │  正样本  
├─────────┤               ├─────────┤  
│  -1.5   │  样本2         │   0.0   │  负样本  
├─────────┤               ├─────────┤  
│   0.5   │  样本3         │   1.0   │  正样本  
└─────────┘                └─────────┘  
     ↓  Sigmoid(内置,不需要手动加!)  
     ↓
┌─────────┐  
│  0.88   │  样本1 → L = -log(0.88) = 0.13  
├─────────┤  
│  0.18   │  样本2 → L = -log(1-0.18) = -log(0.82) = 0.20  
├─────────┤  
│  0.62   │  样本3 → L = -log(0.62) = 0.48  
└─────────┘  
     ↓  Loss = (0.13 + 0.20 + 0.48) / 3 = 0.27```

3.4 数学原理

BCEWithLogitsLoss = Sigmoid + BCELoss

步骤1:Sigmoid 将分数转为概率
p=σ(z)=11+e−zp = \sigma(z) = \frac{1}{1 + e^{-z}}p=σ(z)=1+e−z1

步骤2:二元交叉熵
L=−[y⋅log⁡(p)+(1−y)⋅log⁡(1−p)]L = -[y \cdot \log(p) + (1-y) \cdot \log(1-p)]L=−[y⋅log(p)+(1−y)⋅log(1−p)]

公式展开理解:

  • y = 1 (正样本)时:L=−log⁡(p)L = -\log(p)L=−log(p) → 希望 p 接近 1
  • y = 0 (负样本)时:L=−log⁡(1−p)L = -\log(1-p)L=−log(1−p) → 希望 p 接近 0

3.5 为什么用 BCEWithLogitsLoss 而不是 Sigmoid + BCELoss?

很多人会问:我手动加 Sigmoid 再用 BCELoss 不行吗?

可以,但不推荐。 原因有三:

  1. 数值稳定性:当 sigmoid(x) 接近 0 或 1 时,log(sigmoid(x)) 会产生数值问题。BCEWithLogitsLoss 内部用 log-sum-exp 技巧避免了这个问题。

  2. 计算效率:一次前向传播完成,不需要中间存储 sigmoid 结果。

  3. 梯度更稳定:避免 sigmoid 饱和区的梯度消失问题。


四、多标签分类:BCEWithLogitsLoss

4.1 什么是多标签分类?

多标签 ≠ 多分类!

多分类 多标签
定义 N选1 N选M
例子 这张图是猫还是狗? 这张图里有猫、有狗、有人?
激活函数 Softmax(概率和=1) Sigmoid(每个标签独立)
损失函数 CrossEntropyLoss BCEWithLogitsLoss

4.2 正确用法

python 复制代码
class MultiLabelModel(nn.Module):  
    def __init__(self, input_size, num_labels):        
	    super().__init__()        
	    self.fc = nn.Linear(input_size, num_labels)        
	    self.loss = nn.BCEWithLogitsLoss()  
    def forward(self, x, target=None):        
	    predict = self.fc(x)  # [N, L] 每个标签的分数  
        if target is not None:            
	        return self.loss(predict, target.float())  # target: [N, L] 多热编码  
        return torch.sigmoid(predict)  
        
        
# 推理时  
probs = torch.sigmoid(predict)  
pred_labels = (probs > 0.5).long()  # 每个位置独立判断  

4.3 数据格式

复制代码
predict: [N, L]  →  L 个标签的原始分数  
target:  [N, L]  →  多热编码(multi-hot),如 [1, 0, 1, 0, 1]

4.4 数据流图示

复制代码
predict: [N, L]              target: [N, L]  
     ↓                            ↓
┌────────────────┐          ┌────────────────┐  
│ 2.0  -1.0  0.5 │ 样本1     │  1    0    1   │ 标签0和2为正  
└────────────────┘          └────────────────┘  
     ↓  Sigmoid(对每个位置独立计算)  
     ↓
┌────────────────┐  
│ 0.88 0.27 0.62 │  
└────────────────┘  
     ↓  
 每个位置独立计算BCE:  
  位置0: target=1, p=0.88 → -log(0.88) = 0.13  
  位置1: target=0, p=0.27 → -log(1-0.27) = 0.31  
  位置2: target=1, p=0.62 → -log(0.62) = 0.48  
     ↓  
 Loss = (0.13 + 0.31 + 0.48) / 3 = 0.31 

五、回归任务

回归任务预测的是连续值,不需要激活函数,线性层直接输出即可。

5.1 MSELoss(均方误差)

最常用的回归损失函数。

python 复制代码
class RegressionModel(nn.Module):  
    def __init__(self, input_size):        
	    super().__init__()        
	    self.fc = nn.Linear(input_size, 1)        
	    self.loss = nn.MSELoss()  
    def forward(self, x, target=None):        
	    predict = self.fc(x)  # [N, 1] 直接输出数值  
        if target is not None:            
	        return self.loss(predict, target)        
		return predict  

公式:
L=1N∑i=1N(yi−y^i)2L = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2L=N1i=1∑N(yi−y^i)2

特点:

  • ✅ 处处可导,优化稳定
  • ✅ 大误差时梯度大,收敛快
  • ❌ 对异常值敏感(平方会放大误差)

5.2 L1Loss(平均绝对误差)

对异常值更鲁棒。

公式:
L=1N∑i=1N∣yi−y^i∣L = \frac{1}{N} \sum_{i=1}^{N} |y_i - \hat{y}_i|L=N1i=1∑N∣yi−y^i∣

特点:

  • ✅ 对异常值鲁棒(不会放大大误差)
  • ❌ 在 0 点不可导
  • ❌ 梯度恒定,收敛可能慢

5.3 SmoothL1Loss(Huber Loss)

MSE 和 L1 的结合体,目标检测中的标配。

L={0.5(y−y^)2if ∣y−y^∣<1∣y−y^∣−0.5otherwiseL = \begin{cases} 0.5(y - \hat{y})^2 & \text{if } |y - \hat{y}| < 1 \\ |y - \hat{y}| - 0.5 & \text{otherwise} \end{cases}L={0.5(y−y^)2∣y−y^∣−0.5if ∣y−y^∣<1otherwise

5.4 回归损失对比

损失函数 对异常值 收敛速度 适用场景
MSELoss 敏感 误差正态分布
L1Loss 鲁棒 有异常值
SmoothL1Loss 鲁棒 中等 目标检测 bbox

六、特殊任务

6.1 知识蒸馏:KLDivLoss

用大模型(Teacher)的软标签指导小模型(Student)训练。

python 复制代码
T = 4.0  # 温度参数  
loss = nn.KLDivLoss(reduction='batchmean')(  
    F.log_softmax(student_output / T, dim=1),  # Student: log概率  
    F.softmax(teacher_output / T, dim=1)        # Teacher: 概率  

为什么要用温度 T?

复制代码
原始 logits = [3.0, 1.0, 0.1]
T=1: softmax = [0.84, 0.11, 0.05]  # 很尖锐,信息少  
T=4: softmax = [0.45, 0.30, 0.25]  # 更平滑,保留类别间关系  

6.2 不平衡分类:Focal Loss

正负样本极度不平衡时(如目标检测中背景:前景 = 1000:1),普通交叉熵会被大量简单负样本主导。

FL(pt)=−αt(1−pt)γlog⁡(pt)FL(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t)FL(pt)=−αt(1−pt)γlog(pt)

python 复制代码
class FocalLoss(nn.Module):  
    def __init__(self, alpha=0.25, gamma=2.0):        
	    super().__init__()        self.alpha = alpha        
	    self.gamma = gamma  
    def forward(self, predict, target):        
	    ce_loss = F.cross_entropy(predict, target, reduction='none')        
	    pt = torch.exp(-ce_loss)        
	    focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss        
		return focal_loss.mean()  

核心思想:降低易分样本的权重,让模型专注于难分样本。

样本类型 pt (1−pt)2(1-p_t)^2(1−pt)2 效果
易分样本 0.95 0.0025 权重极小
难分样本 0.10 0.81 权重大

6.3 对比学习:TripletMarginLoss

学习样本间的相对距离,常用于人脸识别、图像检索。

python 复制代码
loss = nn.TripletMarginLoss(margin=1.0)(anchor, positive, negative)  

L=max⁡(0,d(a,p)−d(a,n)+margin)L = \max(0, d(a, p) - d(a, n) + margin)L=max(0,d(a,p)−d(a,n)+margin)

目标:让同类样本靠近,不同类样本远离。

6.4 序列对齐:CTCLoss

输入输出长度不对齐的场景,如语音识别、OCR。

python 复制代码
loss = nn.CTCLoss(blank=0)(log_probs, targets, input_lengths, target_lengths)  

核心思想:允许多个输入对应同一输出,通过 blank 符号处理对齐。


七、常见错误及修正

7.1 ❌ CrossEntropyLoss + Softmax(重复激活)

python 复制代码
# ❌ 错误写法  
predict = F.softmax(self.fc(x), dim=1)  
loss = F.cross_entropy(predict, target)  # 内部又做了一次 softmax!  
  
# ✅ 正确写法  
predict = self.fc(x)  # 直接用 logitsloss = F.cross_entropy(predict, target)  

7.2 ❌ BCELoss 忘记 Sigmoid

python 复制代码
# ❌ 错误写法  
predict = self.fc(x)  # logits,可能是负数!  
loss = F.binary_cross_entropy(predict, target)  # BCELoss 期望输入是概率  
  
# ✅ 正确写法1:手动加 Sigmoidpredict = torch.sigmoid(self.fc(x))  
loss = F.binary_cross_entropy(predict, target)  
  
# ✅ 正确写法2:用 BCEWithLogitsLoss(推荐)  
predict = self.fc(x)  
loss = F.binary_cross_entropy_with_logits(predict, target)  

7.3 ❌ target 类型错误

python 复制代码
# CrossEntropyLoss 需要 LongTensortarget = torch.LongTensor([0, 1, 2])      # ✅  
target = torch.FloatTensor([0.0, 1.0])    # ❌  
  
# BCEWithLogitsLoss 需要 FloatTensortarget = torch.FloatTensor([0.0, 1.0])    # ✅  
target = torch.LongTensor([0, 1])         # ❌  

7.4 ❌ 维度不匹配

python 复制代码
# CrossEntropyLoss  
predict: [N, C]  # ✅ 二维  
target:  [N]     # ✅ 一维(不是 [N, 1]!)  
  
# BCEWithLogitsLoss - 形状必须一致  
predict: [N, 1]  # ✅  
target:  [N, 1]  # ✅  

7.5 ❌ 多标签用了 CrossEntropyLoss

python 复制代码
# ❌ 错误:CrossEntropyLoss 假设类别互斥  
target = torch.FloatTensor([[1, 1, 0]])  # 同时属于类别0和1  
loss = F.cross_entropy(predict, target)   # 错!  
  
# ✅ 正确:多标签用 BCEWithLogitsLossloss = F.binary_cross_entropy_with_logits(predict, target)  

八、快速选择决策树

复制代码
你的任务是什么?  
│  
├── 分类任务  
│   │  
│   ├── 类别是否互斥?  
│   │   │  
│   │   ├── 是(N选1)  
│   │   │   └── 几个类别?  
│   │   │       ├── 2类 → BCEWithLogitsLoss(更高效)  
│   │   │       │         或 CrossEntropyLoss│   │   │       └── >2类 → CrossEntropyLoss│   │   │  
│   │   └── 否(N选M,多标签)  
│   │       └── BCEWithLogitsLoss  
│   │  
│   └── 样本是否极度不平衡?  
│       └── 是 → FocalLoss│  
├── 回归任务  
│   │  
│   ├── 是否有异常值?  
│   │   ├── 有 → L1Loss 或 SmoothL1Loss│   │   └── 无 → MSELoss│   │  
│   └── 是目标检测 bbox?  
│       └── 是 → SmoothL1Loss│  
└── 特殊任务  
    ├── 知识蒸馏 → KLDivLoss    ├── 相似度学习 → TripletMarginLoss    └── 序列对齐 → CTCLoss

九、核心要点总结

要点 说明
CrossEntropyLoss 内置 Softmax 输入是 logits,不要手动加 Softmax
BCEWithLogitsLoss 内置 Sigmoid 输入是 logits,不要手动加 Sigmoid
CrossEntropyLoss 的 target 是索引 形状 [N]不是 one-hot
BCEWithLogitsLoss 的 target 是浮点数 需要 .float() 转换
多标签用 BCE,不是 CrossEntropy 每个标签独立,不互斥
回归任务不需要激活函数 线性层直接输出数值

十、完整代码模板

多分类

python 复制代码
class MultiClassModel(nn.Module):  
    def __init__(self, input_size, num_classes):        
	    super().__init__()        
	    self.fc = nn.Linear(input_size, num_classes)        
	    self.loss = nn.CrossEntropyLoss()  
    def forward(self, x, target=None):        
	    logits = self.fc(x)        
	    if target is not None:            
		    return self.loss(logits, target)        
		return logits  
# 推理  
probs = F.softmax(logits, dim=1)  
pred_class = torch.argmax(probs, dim=1)  

二分类

python 复制代码
class BinaryClassModel(nn.Module):  
    def __init__(self, input_size):        
	    super().__init__()        
	    self.fc = nn.Linear(input_size, 1)        
	    self.loss = nn.BCEWithLogitsLoss()  
    def forward(self, x, target=None):        
	    logits = self.fc(x)        
	    if target is not None:            
		    return self.loss(logits, target.float())        
	    return torch.sigmoid(logits)  
    
    
# 推理  
probs = torch.sigmoid(logits)  
pred_class = (probs > 0.5).long()  

多标签分类

python 复制代码
class MultiLabelModel(nn.Module):  
    def __init__(self, input_size, num_labels):        
	    super().__init__()        
		self.fc = nn.Linear(input_size, num_labels)        
		self.loss = nn.BCEWithLogitsLoss()  
    def forward(self, x, target=None):        
	    logits = self.fc(x)        
	    if target is not None:            
		    return self.loss(logits, target.float())        
		return torch.sigmoid(logits)  
	
	
# 推理  
probs = torch.sigmoid(logits)  
pred_labels = (probs > 0.5).long()  

回归

python 复制代码
class RegressionModel(nn.Module):  
    def __init__(self, input_size):        
	    super().__init__()        
	    self.fc = nn.Linear(input_size, 1)        
	    self.loss = nn.MSELoss()  
    def forward(self, x, target=None):        
	    predict = self.fc(x)        
	    if target is not None:            
		    return self.loss(predict, target)        
		return predict  

相关推荐
AAA简单玩转程序设计1 小时前
Python避坑指南:基础玩家的3个"开挂"技巧
python
Mrliu__1 小时前
Opencv(十八) : 图像凸包检测
人工智能·opencv·计算机视觉
Brduino脑机接口技术答疑1 小时前
脑机接口数据处理连载(六) 脑机接口频域特征提取实战:傅里叶变换与功率谱分析
人工智能·python·算法·机器学习·数据分析·脑机接口
计算所陈老师1 小时前
Palantir的核心是Ontology
大数据·人工智能·知识图谱
大转转FE1 小时前
[特殊字符] 浏览器自动化革命:从 Selenium 到 AI Browser 的 20 年进化史
运维·人工智能·selenium·测试工具·自动化
轻竹办公PPT1 小时前
写开题报告花完精力了,PPT 没法做了。
python·powerpoint
世岩清上1 小时前
世岩清上:科技向善,让乡村“被看见”更“被理解”
人工智能·ar·乡村振兴·和美乡村
dagouaofei1 小时前
AI 生成开题报告 PPT 会自动提炼重点吗?
人工智能·python·powerpoint