pytorch nn.CrossEntropyLoss

PyTorch 新手必读:彻底搞懂 nn.CrossEntropyLoss

在深度学习的分类任务中(比如识别猫狗),损失函数(交叉熵损失)------nn.CrossEntropyLoss几乎是你的必选项。但是,它的用法和原理往往和新手的直觉不太一样,因为 PyTorch 的设计包含了一些"隐藏细节"。这篇博客将手把手带你理解并掌握它,保证没有任何复杂的数学公式堆砌,让你从入门到精通。

1. 常见误区

这是新手最常犯的错误:在传递给 Loss 函数之前,自己手动加了 Softmax。

在理论课上,我们学到的流程通常是:

神经网络输出 -> Softmax 归一化(变成概率) -> 计算交叉熵

但是在 PyTorch 中,nn.CrossEntropyLoss 已经在内部帮你把 Softmax 这一步做好了!

  • PyTorch 的设计nn.CrossEntropyLoss = LogSoftmax + NLLLoss (负对数似然损失)。
  • 你的做法 :直接把模型输出的原始分数(Logits) 扔进去即可。

注意 :如果你在模型最后一层加了 nn.Softmax(),然后再用这个损失函数,那你实际上是做了两次 Softmax,这会导致梯度数值不稳定,模型训练变慢甚至无法收敛。

2. 输入数据Shape 详解

理解数据的**维度(Shape) 类型(Type)**是使用 PyTorch 的关键。我们假设你在做一个 3 分类任务(比如:猫、狗、鸟),batch_size 为 2。

模型的预测值(Input)
  • 内容:未经过 Softmax 的原始分数(Logits)。可以是正数也可以是负数。
  • 维度[Batch_Size, Num_Classes] (N, C)
  • 数据类型float32 (Float)
真实标签(Target)
  • 内容 :正确类别的索引号(不是 One-hot 编码!)。比如 0 代表猫,1 代表狗,2 代表鸟。
  • 维度[Batch_Size] (N)
  • 数据类型int64 (Long) ------ 注意:必须是整数!

新手易错点 :千万不要把标签转成 One-hot 编码(比如 [0, 1, 0])传进去,PyTorch 只需要你告诉它正确答案是第几个(比如 1)就行了。

3. 手把手代码实战

好了,理论讲完了,我们直接看代码。你可以复制这段代码到你的编辑器里运行一下。

python 复制代码
import torch
import torch.nn as nn

# 1. 定义损失函数
criterion = nn.CrossEntropyLoss()

# 2. 模拟模型的输出 (Logits)
# 假设 batch_size=2 (两张图),共有 3 个类别 (猫、狗、鸟)
# 这些数字是模型最后一层全连接层的输出,没有经过 Softmax
predictions = torch.tensor([
    [2.0, 0.5, -1.0],  # 第一张图:模型认为最有可能是第0类(2.0最大)
    [0.1, 1.5, 0.8]    # 第二张图:模型认为最有可能是第1类(1.5最大)
], dtype=torch.float)

# 3. 定义真实的标签 (Labels)
# 注意:这里直接写类别的索引,不需要 One-hot
# 第一张图是第0类(猫),第二张图是第2类(鸟)
labels = torch.tensor([0, 2], dtype=torch.long) 

# 4. 计算损失
loss = criterion(predictions, labels)

print(f"模型预测形状: {predictions.shape}") # torch.Size([2, 3])
print(f"真实标签形状: {labels.shape}")      # torch.Size([2])
print(f"计算出的损失值: {loss.item():.4f}")

# 解析:
# 第一张图预测很准(2.0对应标签0),损失会比较小
# 第二张图预测错了(1.5对应1,但标签是2),损失会比较大
# CrossEntropyLoss 会把这两个样本的损失平均一下
4. 进阶技巧:处理数据不平衡

如果你的数据集中,"狗"的照片特别多,"猫"的照片特别少,模型很容易学会"偷懒"------不管看什么都猜是狗,因为这样准确率也不低。

这时候,nn.CrossEntropyLossweight 参数就派上用场了。我们可以给稀少的类别更高的权重,惩罚模型对稀少类别的误判。

python 复制代码
# 假设类别 0 (猫) 样本很少,类别 1 (狗) 样本很多
# 我们给类别 0 赋予更高的权重 (比如 2.0),给类别 1 正常的权重 (1.0)
class_weights = torch.tensor([2.0, 1.0, 1.0])

# 将权重传入损失函数
criterion_weighted = nn.CrossEntropyLoss(weight=class_weights)

# 之后的计算步骤一模一样
loss = criterion_weighted(predictions, labels)
总结:避坑指南

为了让你在实际开发中少走弯路,请记住以下三点:

  1. 不要加 Softmax:模型的最后一层直接输出 Linear 层的结果即可,不要画蛇添足。
  2. 标签要是索引 :Target 应该是 [0, 2, 1] 这样的整数索引,而不是 One-hot 向量。
  3. 类型要对齐:预测值是 Float,标签值是 Long。

希望这篇教程能帮你彻底搞定 PyTorch 中的交叉熵损失函数!如果你在实际操作中遇到了问题,欢迎在评论区留言,我们一起探讨。Happy Coding!

相关推荐
雍凉明月夜1 小时前
Ⅳ人工智能机器学习之监督学习的概述
人工智能·深度学习·学习
三块可乐两块冰1 小时前
【第二十二周】机器学习笔记二十一
人工智能·笔记·机器学习
持续学习的程序员+11 小时前
强化学习阶段性总结
人工智能·算法
永远都不秃头的程序员(互关)1 小时前
昇腾CANN算子开发实践:从入门到性能优化
人工智能·python·机器学习
ConardLi1 小时前
分析了 100 万亿 Token 后,得出的几个关于 AI 的真相
前端·人工智能·后端
明月照山海-1 小时前
机器学习周报二十五
人工智能·机器学习
AI Echoes1 小时前
LangGraph 需求转换图架构的技巧-CRAG实现
人工智能·python·langchain·prompt·agent
AI Echoes1 小时前
LangChain LLM函数调用使用技巧与应用场景
人工智能·python·langchain·prompt·agent
Allen正心正念20251 小时前
生成式多模态图像模型返回格式的处理方法
人工智能