pytorch nn.CrossEntropyLoss

PyTorch 新手必读:彻底搞懂 nn.CrossEntropyLoss

在深度学习的分类任务中(比如识别猫狗),损失函数(交叉熵损失)------nn.CrossEntropyLoss几乎是你的必选项。但是,它的用法和原理往往和新手的直觉不太一样,因为 PyTorch 的设计包含了一些"隐藏细节"。这篇博客将手把手带你理解并掌握它,保证没有任何复杂的数学公式堆砌,让你从入门到精通。

1. 常见误区

这是新手最常犯的错误:在传递给 Loss 函数之前,自己手动加了 Softmax。

在理论课上,我们学到的流程通常是:

神经网络输出 -> Softmax 归一化(变成概率) -> 计算交叉熵

但是在 PyTorch 中,nn.CrossEntropyLoss 已经在内部帮你把 Softmax 这一步做好了!

  • PyTorch 的设计nn.CrossEntropyLoss = LogSoftmax + NLLLoss (负对数似然损失)。
  • 你的做法 :直接把模型输出的原始分数(Logits) 扔进去即可。

注意 :如果你在模型最后一层加了 nn.Softmax(),然后再用这个损失函数,那你实际上是做了两次 Softmax,这会导致梯度数值不稳定,模型训练变慢甚至无法收敛。

2. 输入数据Shape 详解

理解数据的**维度(Shape) 类型(Type)**是使用 PyTorch 的关键。我们假设你在做一个 3 分类任务(比如:猫、狗、鸟),batch_size 为 2。

模型的预测值(Input)
  • 内容:未经过 Softmax 的原始分数(Logits)。可以是正数也可以是负数。
  • 维度[Batch_Size, Num_Classes] (N, C)
  • 数据类型float32 (Float)
真实标签(Target)
  • 内容 :正确类别的索引号(不是 One-hot 编码!)。比如 0 代表猫,1 代表狗,2 代表鸟。
  • 维度[Batch_Size] (N)
  • 数据类型int64 (Long) ------ 注意:必须是整数!

新手易错点 :千万不要把标签转成 One-hot 编码(比如 [0, 1, 0])传进去,PyTorch 只需要你告诉它正确答案是第几个(比如 1)就行了。

3. 手把手代码实战

好了,理论讲完了,我们直接看代码。你可以复制这段代码到你的编辑器里运行一下。

python 复制代码
import torch
import torch.nn as nn

# 1. 定义损失函数
criterion = nn.CrossEntropyLoss()

# 2. 模拟模型的输出 (Logits)
# 假设 batch_size=2 (两张图),共有 3 个类别 (猫、狗、鸟)
# 这些数字是模型最后一层全连接层的输出,没有经过 Softmax
predictions = torch.tensor([
    [2.0, 0.5, -1.0],  # 第一张图:模型认为最有可能是第0类(2.0最大)
    [0.1, 1.5, 0.8]    # 第二张图:模型认为最有可能是第1类(1.5最大)
], dtype=torch.float)

# 3. 定义真实的标签 (Labels)
# 注意:这里直接写类别的索引,不需要 One-hot
# 第一张图是第0类(猫),第二张图是第2类(鸟)
labels = torch.tensor([0, 2], dtype=torch.long) 

# 4. 计算损失
loss = criterion(predictions, labels)

print(f"模型预测形状: {predictions.shape}") # torch.Size([2, 3])
print(f"真实标签形状: {labels.shape}")      # torch.Size([2])
print(f"计算出的损失值: {loss.item():.4f}")

# 解析:
# 第一张图预测很准(2.0对应标签0),损失会比较小
# 第二张图预测错了(1.5对应1,但标签是2),损失会比较大
# CrossEntropyLoss 会把这两个样本的损失平均一下
4. 进阶技巧:处理数据不平衡

如果你的数据集中,"狗"的照片特别多,"猫"的照片特别少,模型很容易学会"偷懒"------不管看什么都猜是狗,因为这样准确率也不低。

这时候,nn.CrossEntropyLossweight 参数就派上用场了。我们可以给稀少的类别更高的权重,惩罚模型对稀少类别的误判。

python 复制代码
# 假设类别 0 (猫) 样本很少,类别 1 (狗) 样本很多
# 我们给类别 0 赋予更高的权重 (比如 2.0),给类别 1 正常的权重 (1.0)
class_weights = torch.tensor([2.0, 1.0, 1.0])

# 将权重传入损失函数
criterion_weighted = nn.CrossEntropyLoss(weight=class_weights)

# 之后的计算步骤一模一样
loss = criterion_weighted(predictions, labels)
总结:避坑指南

为了让你在实际开发中少走弯路,请记住以下三点:

  1. 不要加 Softmax:模型的最后一层直接输出 Linear 层的结果即可,不要画蛇添足。
  2. 标签要是索引 :Target 应该是 [0, 2, 1] 这样的整数索引,而不是 One-hot 向量。
  3. 类型要对齐:预测值是 Float,标签值是 Long。

希望这篇教程能帮你彻底搞定 PyTorch 中的交叉熵损失函数!如果你在实际操作中遇到了问题,欢迎在评论区留言,我们一起探讨。Happy Coding!

相关推荐
一次旅行8 分钟前
HyperTool:突破传统工具调用限制,让Agent更高效执行复杂任务
人工智能
陈天伟教授39 分钟前
图解人工智能(58)人工智能应用-围棋国手
人工智能·语音识别·机器翻译
闻道参看42 分钟前
2026年AI优质企业培训系统综合测评:合规管控/数据量化
人工智能
老虾头1 小时前
科技贴近烟火:本地化 AI,赋能各行各业日常经营
人工智能
毒爪的小新1 小时前
Linux 环境极速部署 vLLM:从零搭建生产级大模型推理服务
linux·人工智能·ai·语言模型·vllm
老大白菜1 小时前
25美元,DIY开源可穿戴智能AI眼镜:Arduino+乐鑫ESP32+DeepSeek项目
人工智能
岁月宁静2 小时前
RAG 文档摄入全链路,从原理到生产落地
vue.js·人工智能·python
小和尚同志2 小时前
AI 自动化测试探索(一):Playwright MCP
前端·人工智能·aigc
硅谷秋水2 小时前
面向长上下文自动驾驶的规划对齐Token压缩
人工智能·深度学习·机器学习·计算机视觉·自动驾驶