pytorch nn.CrossEntropyLoss

PyTorch 新手必读:彻底搞懂 nn.CrossEntropyLoss

在深度学习的分类任务中(比如识别猫狗),损失函数(交叉熵损失)------nn.CrossEntropyLoss几乎是你的必选项。但是,它的用法和原理往往和新手的直觉不太一样,因为 PyTorch 的设计包含了一些"隐藏细节"。这篇博客将手把手带你理解并掌握它,保证没有任何复杂的数学公式堆砌,让你从入门到精通。

1. 常见误区

这是新手最常犯的错误:在传递给 Loss 函数之前,自己手动加了 Softmax。

在理论课上,我们学到的流程通常是:

神经网络输出 -> Softmax 归一化(变成概率) -> 计算交叉熵

但是在 PyTorch 中,nn.CrossEntropyLoss 已经在内部帮你把 Softmax 这一步做好了!

  • PyTorch 的设计nn.CrossEntropyLoss = LogSoftmax + NLLLoss (负对数似然损失)。
  • 你的做法 :直接把模型输出的原始分数(Logits) 扔进去即可。

注意 :如果你在模型最后一层加了 nn.Softmax(),然后再用这个损失函数,那你实际上是做了两次 Softmax,这会导致梯度数值不稳定,模型训练变慢甚至无法收敛。

2. 输入数据Shape 详解

理解数据的**维度(Shape) 类型(Type)**是使用 PyTorch 的关键。我们假设你在做一个 3 分类任务(比如:猫、狗、鸟),batch_size 为 2。

模型的预测值(Input)
  • 内容:未经过 Softmax 的原始分数(Logits)。可以是正数也可以是负数。
  • 维度[Batch_Size, Num_Classes] (N, C)
  • 数据类型float32 (Float)
真实标签(Target)
  • 内容 :正确类别的索引号(不是 One-hot 编码!)。比如 0 代表猫,1 代表狗,2 代表鸟。
  • 维度[Batch_Size] (N)
  • 数据类型int64 (Long) ------ 注意:必须是整数!

新手易错点 :千万不要把标签转成 One-hot 编码(比如 [0, 1, 0])传进去,PyTorch 只需要你告诉它正确答案是第几个(比如 1)就行了。

3. 手把手代码实战

好了,理论讲完了,我们直接看代码。你可以复制这段代码到你的编辑器里运行一下。

python 复制代码
import torch
import torch.nn as nn

# 1. 定义损失函数
criterion = nn.CrossEntropyLoss()

# 2. 模拟模型的输出 (Logits)
# 假设 batch_size=2 (两张图),共有 3 个类别 (猫、狗、鸟)
# 这些数字是模型最后一层全连接层的输出,没有经过 Softmax
predictions = torch.tensor([
    [2.0, 0.5, -1.0],  # 第一张图:模型认为最有可能是第0类(2.0最大)
    [0.1, 1.5, 0.8]    # 第二张图:模型认为最有可能是第1类(1.5最大)
], dtype=torch.float)

# 3. 定义真实的标签 (Labels)
# 注意:这里直接写类别的索引,不需要 One-hot
# 第一张图是第0类(猫),第二张图是第2类(鸟)
labels = torch.tensor([0, 2], dtype=torch.long) 

# 4. 计算损失
loss = criterion(predictions, labels)

print(f"模型预测形状: {predictions.shape}") # torch.Size([2, 3])
print(f"真实标签形状: {labels.shape}")      # torch.Size([2])
print(f"计算出的损失值: {loss.item():.4f}")

# 解析:
# 第一张图预测很准(2.0对应标签0),损失会比较小
# 第二张图预测错了(1.5对应1,但标签是2),损失会比较大
# CrossEntropyLoss 会把这两个样本的损失平均一下
4. 进阶技巧:处理数据不平衡

如果你的数据集中,"狗"的照片特别多,"猫"的照片特别少,模型很容易学会"偷懒"------不管看什么都猜是狗,因为这样准确率也不低。

这时候,nn.CrossEntropyLossweight 参数就派上用场了。我们可以给稀少的类别更高的权重,惩罚模型对稀少类别的误判。

python 复制代码
# 假设类别 0 (猫) 样本很少,类别 1 (狗) 样本很多
# 我们给类别 0 赋予更高的权重 (比如 2.0),给类别 1 正常的权重 (1.0)
class_weights = torch.tensor([2.0, 1.0, 1.0])

# 将权重传入损失函数
criterion_weighted = nn.CrossEntropyLoss(weight=class_weights)

# 之后的计算步骤一模一样
loss = criterion_weighted(predictions, labels)
总结:避坑指南

为了让你在实际开发中少走弯路,请记住以下三点:

  1. 不要加 Softmax:模型的最后一层直接输出 Linear 层的结果即可,不要画蛇添足。
  2. 标签要是索引 :Target 应该是 [0, 2, 1] 这样的整数索引,而不是 One-hot 向量。
  3. 类型要对齐:预测值是 Float,标签值是 Long。

希望这篇教程能帮你彻底搞定 PyTorch 中的交叉熵损失函数!如果你在实际操作中遇到了问题,欢迎在评论区留言,我们一起探讨。Happy Coding!

相关推荐
数据皮皮侠AI9 小时前
中国城市可再生能源数据集(2005-2021)|顶刊 Sci Data 11 种能源面板
大数据·人工智能·笔记·能源·1024程序员节
G31135422739 小时前
如何用 QClaw 龙虾做一个规律作息健康助理 Agent
大数据·人工智能·ai·云计算
幂律智能9 小时前
零售行业合同管理数智化转型解决方案
大数据·人工智能·零售
旺财矿工9 小时前
零基础搭建 OpenClaw 2.6.6 Win11 本地化运行环境
人工智能·openclaw·小龙虾·龙虾·openclaw安装包
九成宫9 小时前
动手学深度学习PyTorch版初步安装过程
人工智能·pytorch·深度学习
Traving Yu9 小时前
Prompt提示词工程
人工智能·prompt
NOCSAH9 小时前
统好AI CRM功能解析:智能录入与跟进
人工智能
He少年9 小时前
【AI 辅助编程做设备数据采集:一个真实项目的迭代复盘(OpenSpec 驱动)】
人工智能
华万通信king9 小时前
WorkBuddy知识库企业级搭建实战:从零到生产级别的完整路径
大数据·人工智能
测试员周周9 小时前
【AI测试系统】第3篇:AI生成的测试用例太“水”?14年老兵:规则引擎+AI才是王炸组合
人工智能·python·测试