PyTorch 常见的损失函数:从基础到大模型的应用

PyTorch 常见的损失函数:从基础到大模型的应用

在用 PyTorch 训练神经网络时,损失函数(Loss Function)是不可或缺的"裁判"。它告诉模型预测结果与真实答案的差距有多大,优化器则根据这个差距调整参数。PyTorch 提供了丰富而强大的损失函数接口,位于 torch.nn 模块中。今天我们就来聊聊几个常见的损失函数(比如 nn.MSELossnn.CrossEntropyLoss),看看它们的原理和适用场景,最后再揭秘一下 GPT、BERT、LLaMA 等大模型用的是哪些损失函数。

1. 常见损失函数详解

PyTorch 的损失函数种类很多,但有些是"常客",几乎每个深度学习任务都会碰到。以下是几个典型代表:

(1) nn.MSELoss:均方误差损失
  • 全称:Mean Squared Error Loss

  • 公式
    MSE = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 \text{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 MSE=n1i=1∑n(yi−y^i)2
    其中 ( y i y_i yi ) 是真实值,( y ^ i \hat{y}_i y^i ) 是预测值,( n n n ) 是样本数。

  • 作用:衡量预测值与真实值的平方差平均值,适合连续值的回归任务。

  • 使用场景

    • 预测房价、温度等连续变量。
    • 图像重建任务(如自动编码器)。
  • 代码示例

    python 复制代码
    import torch
    import torch.nn as nn
    
    loss_fn = nn.MSELoss()
    pred = torch.tensor([2.5, 0.0, 2.1])
    target = torch.tensor([3.0, -0.5, 2.0])
    loss = loss_fn(pred, target)
    print(loss)  # 输出张量的均方误差
  • 特点:对异常值(outliers)敏感,因为误差被平方放大。

(2) nn.CrossEntropyLoss:交叉熵损失
  • 全称:Cross-Entropy Loss

  • 公式
    对于多分类任务,假设有 ( C C C ) 个类别:
    Loss = − 1 N ∑ i = 1 N [ y i log ⁡ ( y ^ i ) + ( 1 − y i ) log ⁡ ( 1 − y ^ i ) ] \text{Loss} = -\frac{1}{N} \sum_{i=1}^{N} [y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i)] Loss=−N1i=1∑N[yilog(y^i)+(1−yi)log(1−y^i)]
    实际实现中,PyTorch 结合了 LogSoftmaxNLLLoss(负对数似然损失)。

  • 作用:衡量预测概率分布与真实分布的差异,适合分类任务。

  • 使用场景

    • 图像分类(如手写数字识别)。
    • 多标签分类问题。
  • 代码示例

    python 复制代码
    loss_fn = nn.CrossEntropyLoss()
    pred = torch.tensor([[1.0, 2.0, 0.5], [0.1, 2.5, 0.3]])  # 未归一化的 logits
    target = torch.tensor([1, 2])  # 真实类别索引
    loss = loss_fn(pred, target)
    print(loss)
  • 特点

    • 输入是未归一化的 logits(不需要手动加 Softmax)。
    • 目标是类别索引(而不是 one-hot 编码)。
(3) nn.BCELossnn.BCEWithLogitsLoss:二元交叉熵损失
  • 全称:Binary Cross-Entropy Loss

  • 公式
    BCE = − 1 N ∑ i = 1 N [ y i log ⁡ ( y ^ i ) + ( 1 − y i ) log ⁡ ( 1 − y ^ i ) ] \text{BCE} = -\frac{1}{N} \sum_{i=1}^{N} [y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i)] BCE=−N1i=1∑N[yilog(y^i)+(1−yi)log(1−y^i)]
    BCEWithLogitsLoss 额外包含 Sigmoid 激活。

  • 作用 :适合二分类任务,预测值为 0 或 1 的概率。
    为什么上面两个公式一样?请参考笔者的另一篇博客:PyTorch 损失函数解惑:为什么 nn.CrossEntropyLoss 和 nn.BCELoss 的公式看起来一样?

  • 使用场景

    • 判断图片中是否有猫。
    • 多标签分类(每个标签独立预测)。
  • 代码示例

    python 复制代码
    loss_fn = nn.BCEWithLogitsLoss()
    pred = torch.tensor([1.0, -2.0, 0.5])  # logits
    target = torch.tensor([1.0, 0.0, 1.0])  # 0 或 1
    loss = loss_fn(pred, target)
    print(loss)
  • 特点

(4) 其他值得一提的损失函数
  • nn.L1Loss:均绝对误差(MAE),对异常值不敏感,适合回归任务。
  • nn.NLLLoss :负对数似然损失,通常与 LogSoftmax 搭配使用。具体可以参考笔者的另一篇博客:PyTorch 的 nn.NLLLoss:负对数似然损失全解析
  • nn.KLDivLoss:KL 散度,用于分布对比,常用于变分自编码器(VAE)。
2. 大模型用的是哪些损失函数?

聊完了基础损失函数,我们来看看 GPT、BERT、LLaMA 等大模型是怎么定义"差距"的。这些模型通常是语言模型(LM),任务和损失函数与它们的训练目标密切相关。

(1) GPT:语言建模的交叉熵
  • 任务:自回归语言建模(Autoregressive LM),根据前文预测下一个词。

  • 损失函数nn.CrossEntropyLoss

  • 细节

    • GPT 的输入是词序列,输出是对下一个词的概率分布(词汇表大小可能是 50k+)。
    • 模型输出 logits,经过 nn.CrossEntropyLoss 计算与真实下一个词的交叉熵。
  • 为什么用交叉熵

    • 预测下一个词本质上是分类任务,词汇表是类别集合。
    • 交叉熵鼓励模型输出接近真实词的概率分布。
  • 代码示意

    python 复制代码
    loss_fn = nn.CrossEntropyLoss()
    logits = model(input_ids)  # [batch_size, seq_len, vocab_size]
    loss = loss_fn(logits.view(-1, vocab_size), target_ids.view(-1))
(2) BERT:掩码语言建模与 NSP
  • 任务
    • 掩码语言建模(Masked LM, MLM):预测句子中被掩盖的词。
    • 下句预测(Next Sentence Prediction, NSP):判断两句话是否连续。
  • 损失函数
    • MLMnn.CrossEntropyLoss
      • 输入是被掩盖的词的 logits,目标是正确词的索引。
    • NSPnn.CrossEntropyLoss(二分类版本)
      • 输出是两个类别的 logits(是/否),目标是 0 或 1。
  • 细节
    • BERT 的总损失是 MLM 损失和 NSP 损失之和。
    • MLM 类似分类任务,但只对掩码位置计算损失。
  • 为什么用交叉熵
    • MLM 是词汇表级别的分类,NSP 是二分类,交叉熵都很合适。
(3) LLaMA:高效语言建模
  • 任务:与 GPT 类似,也是自回归语言建模。
  • 损失函数nn.CrossEntropyLoss
  • 细节
    • LLaMA 是高效优化的 Transformer,训练目标与 GPT 一致,预测下一个 token。
    • 损失计算方式与 GPT 几乎相同,针对大词汇表的多分类。
  • 特别之处
    • LLaMA 可能在预训练中加入了一些正则化(如 label smoothing),但核心仍是交叉熵。
3. 大模型为何偏爱交叉熵?

从 GPT 到 BERT、LLaMA,这些大模型几乎都离不开 nn.CrossEntropyLoss,原因有以下几点:

  • 分类本质:语言建模的目标是预测词或类别,属于分类任务,交叉熵是天然选择。
  • 概率解释:交叉熵衡量预测分布与真实分布的差异,与语言模型的概率生成目标一致。
  • 数值稳定性:PyTorch 的实现结合了 Softmax 和对数运算,避免了数值溢出的问题。
4. 小结:选择合适的损失函数
  • 回归任务 :用 nn.MSELossnn.L1Loss,适合连续值预测。
  • 分类任务 :用 nn.CrossEntropyLoss(多类)或 nn.BCEWithLogitsLoss(二类),适合离散类别预测。
  • 大模型 :语言模型大多用 nn.CrossEntropyLoss,因为它们本质上是"词级分类器"。

PyTorch 的损失函数设计得很贴心,既有基础的数学实现,也有针对深度学习的优化(比如内置 Softmax)。了解这些损失函数的原理和适用场景,能帮你在任务中选对"裁判",让模型训练更高效。

5. 彩蛋:如何调试损失函数?
  • 检查输入 :确保 predtarget 的形状匹配(比如 CrossEntropyLoss 需要 [batch, classes][batch])。
  • 打印中间值 :用 print(loss.item()) 查看损失大小,判断是否合理。
  • 验证梯度 :用 loss.backward() 后检查参数的 .grad,确认优化方向。

希望这篇博客能帮你理清 PyTorch 损失函数的脉络!

后记

2025年2月28日19点04分于上海,在grok3大模型辅助下完成。

相关推荐
我不会编程5552 小时前
Python Cookbook-2.24 在 Mac OSX平台上统计PDF文档的页数
开发语言·python·pdf
胡歌13 小时前
final 关键字在不同上下文中的用法及其名称
开发语言·jvm·python
程序员张小厨4 小时前
【0005】Python变量详解
开发语言·python
Hacker_Oldv5 小时前
Python 爬虫与网络安全有什么关系
爬虫·python·web安全
深蓝海拓5 小时前
PySide(PyQT)重新定义contextMenuEvent()实现鼠标右键弹出菜单
开发语言·python·pyqt
车载诊断技术5 小时前
人工智能AI在汽车设计领域的应用探索
数据库·人工智能·网络协议·架构·汽车·是诊断功能配置的核心
AuGuSt_817 小时前
【深度学习】Hopfield网络:模拟联想记忆
人工智能·深度学习
jndingxin7 小时前
OpenCV计算摄影学(6)高动态范围成像(HDR imaging)
人工智能·opencv·计算机视觉
数据攻城小狮子7 小时前
深入剖析 OpenCV:全面掌握基础操作、图像处理算法与特征匹配
图像处理·python·opencv·算法·计算机视觉
Sol-itude7 小时前
【文献阅读】Collective Decision for Open Set Recognition
论文阅读·人工智能·机器学习·支持向量机