PyTorch 新手必读:彻底搞懂 nn.CrossEntropyLoss
在深度学习的分类任务中(比如识别猫狗),损失函数(交叉熵损失)------nn.CrossEntropyLoss几乎是你的必选项。但是,它的用法和原理往往和新手的直觉不太一样,因为 PyTorch 的设计包含了一些"隐藏细节"。这篇博客将手把手带你理解并掌握它,保证没有任何复杂的数学公式堆砌,让你从入门到精通。
1. 常见误区
这是新手最常犯的错误:在传递给 Loss 函数之前,自己手动加了 Softmax。
在理论课上,我们学到的流程通常是:
神经网络输出 -> Softmax 归一化(变成概率) -> 计算交叉熵
但是在 PyTorch 中,nn.CrossEntropyLoss 已经在内部帮你把 Softmax 这一步做好了!
- PyTorch 的设计 :
nn.CrossEntropyLoss=LogSoftmax+NLLLoss(负对数似然损失)。 - 你的做法 :直接把模型输出的原始分数(Logits) 扔进去即可。
注意 :如果你在模型最后一层加了 nn.Softmax(),然后再用这个损失函数,那你实际上是做了两次 Softmax,这会导致梯度数值不稳定,模型训练变慢甚至无法收敛。
2. 输入数据Shape 详解
理解数据的**维度(Shape)和 类型(Type)**是使用 PyTorch 的关键。我们假设你在做一个 3 分类任务(比如:猫、狗、鸟),batch_size 为 2。
模型的预测值(Input)
- 内容:未经过 Softmax 的原始分数(Logits)。可以是正数也可以是负数。
- 维度 :
[Batch_Size, Num_Classes](N, C) - 数据类型 :
float32(Float)
真实标签(Target)
- 内容 :正确类别的索引号(不是 One-hot 编码!)。比如 0 代表猫,1 代表狗,2 代表鸟。
- 维度 :
[Batch_Size](N) - 数据类型 :
int64(Long) ------ 注意:必须是整数!
新手易错点 :千万不要把标签转成 One-hot 编码(比如
[0, 1, 0])传进去,PyTorch 只需要你告诉它正确答案是第几个(比如1)就行了。
3. 手把手代码实战
好了,理论讲完了,我们直接看代码。你可以复制这段代码到你的编辑器里运行一下。
python
import torch
import torch.nn as nn
# 1. 定义损失函数
criterion = nn.CrossEntropyLoss()
# 2. 模拟模型的输出 (Logits)
# 假设 batch_size=2 (两张图),共有 3 个类别 (猫、狗、鸟)
# 这些数字是模型最后一层全连接层的输出,没有经过 Softmax
predictions = torch.tensor([
[2.0, 0.5, -1.0], # 第一张图:模型认为最有可能是第0类(2.0最大)
[0.1, 1.5, 0.8] # 第二张图:模型认为最有可能是第1类(1.5最大)
], dtype=torch.float)
# 3. 定义真实的标签 (Labels)
# 注意:这里直接写类别的索引,不需要 One-hot
# 第一张图是第0类(猫),第二张图是第2类(鸟)
labels = torch.tensor([0, 2], dtype=torch.long)
# 4. 计算损失
loss = criterion(predictions, labels)
print(f"模型预测形状: {predictions.shape}") # torch.Size([2, 3])
print(f"真实标签形状: {labels.shape}") # torch.Size([2])
print(f"计算出的损失值: {loss.item():.4f}")
# 解析:
# 第一张图预测很准(2.0对应标签0),损失会比较小
# 第二张图预测错了(1.5对应1,但标签是2),损失会比较大
# CrossEntropyLoss 会把这两个样本的损失平均一下
4. 进阶技巧:处理数据不平衡
如果你的数据集中,"狗"的照片特别多,"猫"的照片特别少,模型很容易学会"偷懒"------不管看什么都猜是狗,因为这样准确率也不低。
这时候,nn.CrossEntropyLoss 的 weight 参数就派上用场了。我们可以给稀少的类别更高的权重,惩罚模型对稀少类别的误判。
python
# 假设类别 0 (猫) 样本很少,类别 1 (狗) 样本很多
# 我们给类别 0 赋予更高的权重 (比如 2.0),给类别 1 正常的权重 (1.0)
class_weights = torch.tensor([2.0, 1.0, 1.0])
# 将权重传入损失函数
criterion_weighted = nn.CrossEntropyLoss(weight=class_weights)
# 之后的计算步骤一模一样
loss = criterion_weighted(predictions, labels)
总结:避坑指南
为了让你在实际开发中少走弯路,请记住以下三点:
- 不要加 Softmax:模型的最后一层直接输出 Linear 层的结果即可,不要画蛇添足。
- 标签要是索引 :Target 应该是
[0, 2, 1]这样的整数索引,而不是 One-hot 向量。 - 类型要对齐:预测值是 Float,标签值是 Long。
希望这篇教程能帮你彻底搞定 PyTorch 中的交叉熵损失函数!如果你在实际操作中遇到了问题,欢迎在评论区留言,我们一起探讨。Happy Coding!