PyTorch交叉熵理解

PyTorch 中的交叉熵损失

CrossEntropyLoss

PyTorch 中使用CrossEntropyLoss 计算交叉熵损失,常用于分类任务。交叉熵损失衡量了模型输出的概率分布与实际标签分布之间的差异,目标是最小化该损失以优化模型。

我们通过一个具体的案例来详细说明 CrossEntropyLoss 的计算过程。

假设我们有一个简单的分类任务,共有 3 个类别。我们有 2 个样本的预测和实际标签。

输入

  • 模型的预测(logits,未经过 softmax 激活)

  • 实际标签

python 复制代码
import torch
import torch.nn as nn

# 模型的预测(logits)
logits = torch.tensor([[2.0, 1.0, 0.1],
                       [0.5, 2.0, 0.3]])

# 实际标签
labels = torch.tensor([0, 2])

计算步骤

  • 步骤 1: Softmax 激活

首先,将 logits 通过 softmax 激活函数转换为概率分布。

python 复制代码
softmax = nn.Softmax(dim=1)
probabilities = softmax(logits)
print(probabilities)

输出

python 复制代码
tensor([[0.6590, 0.2424, 0.0986],
        [0.1587, 0.7113, 0.1299]])
  • 步骤 2: 计算交叉熵

交叉熵损失的计算公式为:

C r o s s E n t r o p y L o s s = − ∑ i = 1 N log ⁡ ( p i , y i ) CrossEntropyLoss=-\sum_{i=1}^{N}{\log{(}}{{p}{i,{{y}{i}}}}) CrossEntropyLoss=−∑i=1Nlog(pi,yi)

其中 N 是样本数量, p i , y i p_{i,y_i} pi,yi是第 i个样本在实际标签 y i y_i yi 位置上的预测概率。

我们手动计算每个样本的交叉熵损失:

  • 对于第一个样本,实际标签为 0,预测概率为 0.6590

l o s s 1 = − log ⁡ ( 0.6590 ) ≈ 0.4171 {{loss}_{1}}=-\log{(}0.6590)\approx 0.4171 loss1=−log(0.6590)≈0.4171

  • 对于第二个样本,实际标签为 2,预测概率为 0.1299

l o s s 2 = − log ⁡ ( 0.1299 ) ≈ 2.0406 {{loss}_{2}}=-\log{(}0.1299)\approx 2.0406 loss2=−log(0.1299)≈2.0406

平均损失为:

m e a n = 0.4171 + 2.0406 2 ≈ 1.2288 mean=\frac{0.4171+2.0406}{2}\approx 1.2288 mean=20.4171+2.0406≈1.2288

  • 步骤 3: 使用 PyTorch 的 CrossEntropyLoss 计算

我们使用 PyTorch 的 CrossEntropyLoss 函数来验证计算结果:

python 复制代码
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)
print(loss.item())

输出

shell 复制代码
1.2288230657577515
  • 步骤4:依据公式使用 PyTorch 计算

依据前面的公式使用 PyTorch 计算来验算结果

python 复制代码
neg_log_p = -torch.log(probabilities)
loss_cal = neg_log_p[torch.arange(neg_log_p.shape[0]), labels].mean()
print(loss_cal.item())

输出

shell 复制代码
1.228823184967041

结果基本一致。

总结

  1. CrossEntropyLoss 接受未经过 softmax 的 logits 作为输入。

  2. 内部首先对 logits 应用 softmax,将其转换为概率分布。

  3. 然后根据实际标签计算交叉熵损失。

相关推荐
称昵写填未6 分钟前
在Pycharm配置conda虚拟环境的Python解释器
开发语言·ide·python·pycharm·conda·anaconda·虚拟环境
m0_743106468 分钟前
nerfstudio以及相关使用记录(长期更新)
python·深度学习·ubuntu·计算机视觉·3d
关山月25 分钟前
Python 列表方法可视化解释
python
3DVisionary26 分钟前
蓝光三维扫描技术:手机闪光灯模块全尺寸3D检测的精准解决方案
python·数码相机·3d·智能手机·蓝光3d扫描技术·非接触、高效率、全尺寸检测·完美适配手机微型零部件
cainiao08060527 分钟前
WPS 接入 DeepSeek-R1 深度实践:打造全能AI办公助手
人工智能·wps·ai办公
艾思科蓝 AiScholar33 分钟前
【ACM 独立出版 | EI 快检索】2025年数据挖掘与项目管理国际研讨会 (DMPM 2025)
人工智能·网络安全·数据挖掘·数据分析·创业创新·数据可视化·数据库管理员
a小胡哦44 分钟前
解锁 AI 核心:神经网络与机器学习知名算法全解析
人工智能·神经网络·机器学习
odoo中国44 分钟前
深度学习 Deep Learning 第1章 深度学习简介
人工智能·深度学习·deep learning
极客天成ScaleFlash1 小时前
极客天成 NVFile 并行文件存储:端到端无缓存新范式,为 AI 训练按下“快进键”
人工智能·缓存
刘刚好科技1 小时前
聚力·突破·共赢|修饰组学服务联盟正式成立,共启协同发展新篇章
大数据·人工智能