PyTorch 常见的损失函数:从基础到大模型的应用
在用 PyTorch 训练神经网络时,损失函数(Loss Function)是不可或缺的"裁判"。它告诉模型预测结果与真实答案的差距有多大,优化器则根据这个差距调整参数。PyTorch 提供了丰富而强大的损失函数接口,位于 torch.nn
模块中。今天我们就来聊聊几个常见的损失函数(比如 nn.MSELoss
和 nn.CrossEntropyLoss
),看看它们的原理和适用场景,最后再揭秘一下 GPT、BERT、LLaMA 等大模型用的是哪些损失函数。
1. 常见损失函数详解
PyTorch 的损失函数种类很多,但有些是"常客",几乎每个深度学习任务都会碰到。以下是几个典型代表:
(1) nn.MSELoss
:均方误差损失
-
全称:Mean Squared Error Loss
-
公式 :
MSE = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 \text{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 MSE=n1i=1∑n(yi−y^i)2
其中 ( y i y_i yi ) 是真实值,( y ^ i \hat{y}_i y^i ) 是预测值,( n n n ) 是样本数。 -
作用:衡量预测值与真实值的平方差平均值,适合连续值的回归任务。
-
使用场景 :
- 预测房价、温度等连续变量。
- 图像重建任务(如自动编码器)。
-
代码示例 :
pythonimport torch import torch.nn as nn loss_fn = nn.MSELoss() pred = torch.tensor([2.5, 0.0, 2.1]) target = torch.tensor([3.0, -0.5, 2.0]) loss = loss_fn(pred, target) print(loss) # 输出张量的均方误差
-
特点:对异常值(outliers)敏感,因为误差被平方放大。
(2) nn.CrossEntropyLoss
:交叉熵损失
-
全称:Cross-Entropy Loss
-
公式 :
对于多分类任务,假设有 ( C C C ) 个类别:
Loss = − 1 N ∑ i = 1 N [ y i log ( y ^ i ) + ( 1 − y i ) log ( 1 − y ^ i ) ] \text{Loss} = -\frac{1}{N} \sum_{i=1}^{N} [y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i)] Loss=−N1i=1∑N[yilog(y^i)+(1−yi)log(1−y^i)]
实际实现中,PyTorch 结合了LogSoftmax
和NLLLoss
(负对数似然损失)。 -
作用:衡量预测概率分布与真实分布的差异,适合分类任务。
-
使用场景 :
- 图像分类(如手写数字识别)。
- 多标签分类问题。
-
代码示例 :
pythonloss_fn = nn.CrossEntropyLoss() pred = torch.tensor([[1.0, 2.0, 0.5], [0.1, 2.5, 0.3]]) # 未归一化的 logits target = torch.tensor([1, 2]) # 真实类别索引 loss = loss_fn(pred, target) print(loss)
-
特点 :
- 输入是未归一化的 logits(不需要手动加 Softmax)。
- 目标是类别索引(而不是 one-hot 编码)。
(3) nn.BCELoss
和 nn.BCEWithLogitsLoss
:二元交叉熵损失
-
全称:Binary Cross-Entropy Loss
-
公式 :
BCE = − 1 N ∑ i = 1 N [ y i log ( y ^ i ) + ( 1 − y i ) log ( 1 − y ^ i ) ] \text{BCE} = -\frac{1}{N} \sum_{i=1}^{N} [y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i)] BCE=−N1i=1∑N[yilog(y^i)+(1−yi)log(1−y^i)]
BCEWithLogitsLoss
额外包含 Sigmoid 激活。 -
作用 :适合二分类任务,预测值为 0 或 1 的概率。
为什么上面两个公式一样?请参考笔者的另一篇博客:PyTorch 损失函数解惑:为什么 nn.CrossEntropyLoss 和 nn.BCELoss 的公式看起来一样? -
使用场景 :
- 判断图片中是否有猫。
- 多标签分类(每个标签独立预测)。
-
代码示例 :
pythonloss_fn = nn.BCEWithLogitsLoss() pred = torch.tensor([1.0, -2.0, 0.5]) # logits target = torch.tensor([1.0, 0.0, 1.0]) # 0 或 1 loss = loss_fn(pred, target) print(loss)
-
特点 :
nn.BCELoss
需要手动加 Sigmoid,nn.BCEWithLogitsLoss
更方便且数值更稳定。具体可以参考笔者的另一篇博客:PyTorch 的 nn.BCELoss:为什么需要"手动加 Sigmoid"?
(4) 其他值得一提的损失函数
nn.L1Loss
:均绝对误差(MAE),对异常值不敏感,适合回归任务。nn.NLLLoss
:负对数似然损失,通常与LogSoftmax
搭配使用。具体可以参考笔者的另一篇博客:PyTorch 的 nn.NLLLoss:负对数似然损失全解析nn.KLDivLoss
:KL 散度,用于分布对比,常用于变分自编码器(VAE)。
2. 大模型用的是哪些损失函数?
聊完了基础损失函数,我们来看看 GPT、BERT、LLaMA 等大模型是怎么定义"差距"的。这些模型通常是语言模型(LM),任务和损失函数与它们的训练目标密切相关。
(1) GPT:语言建模的交叉熵
-
任务:自回归语言建模(Autoregressive LM),根据前文预测下一个词。
-
损失函数 :
nn.CrossEntropyLoss
-
细节 :
- GPT 的输入是词序列,输出是对下一个词的概率分布(词汇表大小可能是 50k+)。
- 模型输出 logits,经过
nn.CrossEntropyLoss
计算与真实下一个词的交叉熵。
-
为什么用交叉熵 ?
- 预测下一个词本质上是分类任务,词汇表是类别集合。
- 交叉熵鼓励模型输出接近真实词的概率分布。
-
代码示意 :
pythonloss_fn = nn.CrossEntropyLoss() logits = model(input_ids) # [batch_size, seq_len, vocab_size] loss = loss_fn(logits.view(-1, vocab_size), target_ids.view(-1))
(2) BERT:掩码语言建模与 NSP
- 任务 :
- 掩码语言建模(Masked LM, MLM):预测句子中被掩盖的词。
- 下句预测(Next Sentence Prediction, NSP):判断两句话是否连续。
- 损失函数 :
- MLM :
nn.CrossEntropyLoss
- 输入是被掩盖的词的 logits,目标是正确词的索引。
- NSP :
nn.CrossEntropyLoss
(二分类版本)- 输出是两个类别的 logits(是/否),目标是 0 或 1。
- MLM :
- 细节 :
- BERT 的总损失是 MLM 损失和 NSP 损失之和。
- MLM 类似分类任务,但只对掩码位置计算损失。
- 为什么用交叉熵 ?
- MLM 是词汇表级别的分类,NSP 是二分类,交叉熵都很合适。
(3) LLaMA:高效语言建模
- 任务:与 GPT 类似,也是自回归语言建模。
- 损失函数 :
nn.CrossEntropyLoss
- 细节 :
- LLaMA 是高效优化的 Transformer,训练目标与 GPT 一致,预测下一个 token。
- 损失计算方式与 GPT 几乎相同,针对大词汇表的多分类。
- 特别之处 :
- LLaMA 可能在预训练中加入了一些正则化(如 label smoothing),但核心仍是交叉熵。
3. 大模型为何偏爱交叉熵?
从 GPT 到 BERT、LLaMA,这些大模型几乎都离不开 nn.CrossEntropyLoss
,原因有以下几点:
- 分类本质:语言建模的目标是预测词或类别,属于分类任务,交叉熵是天然选择。
- 概率解释:交叉熵衡量预测分布与真实分布的差异,与语言模型的概率生成目标一致。
- 数值稳定性:PyTorch 的实现结合了 Softmax 和对数运算,避免了数值溢出的问题。
4. 小结:选择合适的损失函数
- 回归任务 :用
nn.MSELoss
或nn.L1Loss
,适合连续值预测。 - 分类任务 :用
nn.CrossEntropyLoss
(多类)或nn.BCEWithLogitsLoss
(二类),适合离散类别预测。 - 大模型 :语言模型大多用
nn.CrossEntropyLoss
,因为它们本质上是"词级分类器"。
PyTorch 的损失函数设计得很贴心,既有基础的数学实现,也有针对深度学习的优化(比如内置 Softmax)。了解这些损失函数的原理和适用场景,能帮你在任务中选对"裁判",让模型训练更高效。
5. 彩蛋:如何调试损失函数?
- 检查输入 :确保
pred
和target
的形状匹配(比如CrossEntropyLoss
需要[batch, classes]
和[batch]
)。 - 打印中间值 :用
print(loss.item())
查看损失大小,判断是否合理。 - 验证梯度 :用
loss.backward()
后检查参数的.grad
,确认优化方向。
希望这篇博客能帮你理清 PyTorch 损失函数的脉络!
后记
2025年2月28日19点04分于上海,在grok3大模型辅助下完成。