pytorch nn.CrossEntropyLoss

PyTorch 新手必读:彻底搞懂 nn.CrossEntropyLoss

在深度学习的分类任务中(比如识别猫狗),损失函数(交叉熵损失)------nn.CrossEntropyLoss几乎是你的必选项。但是,它的用法和原理往往和新手的直觉不太一样,因为 PyTorch 的设计包含了一些"隐藏细节"。这篇博客将手把手带你理解并掌握它,保证没有任何复杂的数学公式堆砌,让你从入门到精通。

1. 常见误区

这是新手最常犯的错误:在传递给 Loss 函数之前,自己手动加了 Softmax。

在理论课上,我们学到的流程通常是:

神经网络输出 -> Softmax 归一化(变成概率) -> 计算交叉熵

但是在 PyTorch 中,nn.CrossEntropyLoss 已经在内部帮你把 Softmax 这一步做好了!

  • PyTorch 的设计nn.CrossEntropyLoss = LogSoftmax + NLLLoss (负对数似然损失)。
  • 你的做法 :直接把模型输出的原始分数(Logits) 扔进去即可。

注意 :如果你在模型最后一层加了 nn.Softmax(),然后再用这个损失函数,那你实际上是做了两次 Softmax,这会导致梯度数值不稳定,模型训练变慢甚至无法收敛。

2. 输入数据Shape 详解

理解数据的**维度(Shape) 类型(Type)**是使用 PyTorch 的关键。我们假设你在做一个 3 分类任务(比如:猫、狗、鸟),batch_size 为 2。

模型的预测值(Input)
  • 内容:未经过 Softmax 的原始分数(Logits)。可以是正数也可以是负数。
  • 维度[Batch_Size, Num_Classes] (N, C)
  • 数据类型float32 (Float)
真实标签(Target)
  • 内容 :正确类别的索引号(不是 One-hot 编码!)。比如 0 代表猫,1 代表狗,2 代表鸟。
  • 维度[Batch_Size] (N)
  • 数据类型int64 (Long) ------ 注意:必须是整数!

新手易错点 :千万不要把标签转成 One-hot 编码(比如 [0, 1, 0])传进去,PyTorch 只需要你告诉它正确答案是第几个(比如 1)就行了。

3. 手把手代码实战

好了,理论讲完了,我们直接看代码。你可以复制这段代码到你的编辑器里运行一下。

python 复制代码
import torch
import torch.nn as nn

# 1. 定义损失函数
criterion = nn.CrossEntropyLoss()

# 2. 模拟模型的输出 (Logits)
# 假设 batch_size=2 (两张图),共有 3 个类别 (猫、狗、鸟)
# 这些数字是模型最后一层全连接层的输出,没有经过 Softmax
predictions = torch.tensor([
    [2.0, 0.5, -1.0],  # 第一张图:模型认为最有可能是第0类(2.0最大)
    [0.1, 1.5, 0.8]    # 第二张图:模型认为最有可能是第1类(1.5最大)
], dtype=torch.float)

# 3. 定义真实的标签 (Labels)
# 注意:这里直接写类别的索引,不需要 One-hot
# 第一张图是第0类(猫),第二张图是第2类(鸟)
labels = torch.tensor([0, 2], dtype=torch.long) 

# 4. 计算损失
loss = criterion(predictions, labels)

print(f"模型预测形状: {predictions.shape}") # torch.Size([2, 3])
print(f"真实标签形状: {labels.shape}")      # torch.Size([2])
print(f"计算出的损失值: {loss.item():.4f}")

# 解析:
# 第一张图预测很准(2.0对应标签0),损失会比较小
# 第二张图预测错了(1.5对应1,但标签是2),损失会比较大
# CrossEntropyLoss 会把这两个样本的损失平均一下
4. 进阶技巧:处理数据不平衡

如果你的数据集中,"狗"的照片特别多,"猫"的照片特别少,模型很容易学会"偷懒"------不管看什么都猜是狗,因为这样准确率也不低。

这时候,nn.CrossEntropyLossweight 参数就派上用场了。我们可以给稀少的类别更高的权重,惩罚模型对稀少类别的误判。

python 复制代码
# 假设类别 0 (猫) 样本很少,类别 1 (狗) 样本很多
# 我们给类别 0 赋予更高的权重 (比如 2.0),给类别 1 正常的权重 (1.0)
class_weights = torch.tensor([2.0, 1.0, 1.0])

# 将权重传入损失函数
criterion_weighted = nn.CrossEntropyLoss(weight=class_weights)

# 之后的计算步骤一模一样
loss = criterion_weighted(predictions, labels)
总结:避坑指南

为了让你在实际开发中少走弯路,请记住以下三点:

  1. 不要加 Softmax:模型的最后一层直接输出 Linear 层的结果即可,不要画蛇添足。
  2. 标签要是索引 :Target 应该是 [0, 2, 1] 这样的整数索引,而不是 One-hot 向量。
  3. 类型要对齐:预测值是 Float,标签值是 Long。

希望这篇教程能帮你彻底搞定 PyTorch 中的交叉熵损失函数!如果你在实际操作中遇到了问题,欢迎在评论区留言,我们一起探讨。Happy Coding!

相关推荐
卤代烃7 分钟前
🦾 可为与不可为:CDP 视角下的 Browser 控制边界
前端·人工智能·浏览器
ggabb16 分钟前
海南封关:锚定中国制造2025,破解产业转移生死局
大数据·人工智能
_XU22 分钟前
AI工具如何重塑我的开发日常
前端·人工智能·深度学习
Blossom.1181 小时前
Prompt工程与思维链优化实战:从零构建动态Few-Shot与CoT推理引擎
人工智能·分布式·python·智能手机·django·prompt·边缘计算
zxsz_com_cn2 小时前
设备预测性维护典型案例:中讯烛龙赋能高端制造降本增效
人工智能
人工智能培训2 小时前
图神经网络初探(1)
人工智能·深度学习·知识图谱·群体智能·智能体
love530love3 小时前
Windows 11 下 Z-Image-Turbo 完整部署与 Flash Attention 2.8.3 本地编译复盘
人工智能·windows·python·aigc·flash-attn·z-image·cuda加速
雪下的新火3 小时前
AI工具-Hyper3D
人工智能·aigc·blender·ai工具·笔记分享
Das13 小时前
【机器学习】01_模型选择与评估
人工智能·算法·机器学习