深入理解交叉熵损失 CrossEntropyLoss - CrossEntropyLoss

深入理解交叉熵损失 CrossEntropyLoss - CrossEntropyLoss

flyfish

本系列的主要内容是在2017年所写,GPT使用了交叉熵损失函数,所以就温故而知新,文中代码又用新版的PyTorch写了一遍,在看交叉熵损失函数遇到问题时,可先看链接提供的基础知识,可以有更深的理解。

深入理解交叉熵损失 CrossEntropyLoss - one-hot 编码
深入理解交叉熵损失 CrossEntropyLoss - 对数
深入理解交叉熵损失 CrossEntropyLoss - 概率基础
深入理解交叉熵损失 CrossEntropyLoss - 概率分布
深入理解交叉熵损失 CrossEntropyLoss - 损失函数
深入理解交叉熵损失 CrossEntropyLoss - 归一化
深入理解交叉熵损失 CrossEntropyLoss - 信息论(交叉熵)
深入理解交叉熵损失 CrossEntropyLoss - Softmax
深入理解交叉熵损失 CrossEntropyLoss - nn.LogSoftmax

深入理解交叉熵损失 CrossEntropyLoss - 似然
深入理解交叉熵损失CrossEntropyLoss - 乘积符号在似然函数中的应用

深入理解交叉熵损失 CrossEntropyLoss - nn.NLLLoss

深入理解交叉熵损失 CrossEntropyLoss - nn.CrossEntropyLoss

深入理解交叉熵损失CrossEntropyLoss

在 PyTorch 中, torch.nn.CrossEntropyLoss 是一个常用的 损失函数 ,主要用于多分类任务。它结合了 nn.LogSoftmaxnn.NLLLoss ,并且内部进行了优化以避免 数值稳定性 问题。

具体来说,torch.nn.CrossEntropyLoss 计算的是预测值与目标值之间的交叉熵损失 。对于多分类问题,交叉熵损失是最常用的损失函数,因为它直接衡量了两个概率分布(预测概率分布和实际分布)之间的差异。

LogSoftmax和 NLLLoss两者的结合,对比立使用CrossEntropyLoss

nn.CrossEntropyLoss 在内部已经包含了 LogSoftmax 和 NLLLoss 的操作。

编写代码验证,分别是 LogSoftmax和 NLLLoss两者的结合,对比立使用CrossEntropyLoss。

py 复制代码
import torch
import torch.nn as nn

# 输入张量 (batch_size=2, num_classes=3)
input_tensor = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
# 目标张量 (batch_size=2)
target_tensor = torch.tensor([2, 0])

# 使用 nn.LogSoftmax 和 nn.NLLLoss
log_softmax = nn.LogSoftmax(dim=1)
log_probs = log_softmax(input_tensor)
nll_loss = nn.NLLLoss()
loss = nll_loss(log_probs, target_tensor)
print(f'Loss using LogSoftmax and NLLLoss: {loss.item()}')

# 使用 nn.CrossEntropyLoss
cross_entropy_loss = nn.CrossEntropyLoss()
loss_ce = cross_entropy_loss(input_tensor, target_tensor)
print(f'Loss using CrossEntropyLoss: {loss_ce.item()}')

输出结果

Loss using LogSoftmax and NLLLoss: 1.4076058864593506

Loss using CrossEntropyLoss: 1.4076058864593506

解释

对于单个样本,交叉熵损失的定义如下:

CrossEntropyLoss = − ∑ i = 1 C y i log ⁡ ( y ^ i ) \text{CrossEntropyLoss} = -\sum_{i=1}^{C} y_i \log(\hat{y}_i) CrossEntropyLoss=−i=1∑Cyilog(y^i)

其中:

  • C C C 是类别的数量。
  • y i y_i yi 是真实标签的一个one-hot编码(若样本属于类别 i i i,则 y i = 1 y_i = 1 yi=1,否则 y i = 0 y_i = 0 yi=0)。
  • y ^ i \hat{y}_i y^i 是模型预测的第 i i i 类的概率。

直观解释 Softmax和负对数似然

交叉熵损失结合了两个概念:

  1. Softmax
    首先将模型输出的原始分数(logits)通过 softmax 函数转换成概率分布,Softmax 函数将 logits 转换为概率分布。对于一个有 C C C 个类别的分类问题,Softmax 公式如下:

y ^ i = exp ⁡ ( z i ) ∑ j = 1 C exp ⁡ ( z j ) \hat{y}i = \frac{\exp(z_i)}{\sum{j=1}^{C} \exp(z_j)} y^i=∑j=1Cexp(zj)exp(zi)

其中 z i z_i zi 是第 i i i 类的 logit。

  1. 负对数似然
    计算这些概率分布与真实标签之间的负对数似然。在获得概率分布后,交叉熵损失计算真实标签的负对数概率。如果真实标签对应的类别概率很高,损失就小;如果概率很低,损失就大。这驱动模型在训练过程中提高真实标签类别的预测概率。

以下是一个简单的示例,展示如何计算交叉熵损失:

python 复制代码
import torch
import torch.nn as nn

# 假设我们有两个样本,每个样本属于3个类别中的一个
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.0, 0.3]])
# 真实标签
labels = torch.tensor([0, 1])

# 使用 nn.CrossEntropyLoss 计算损失
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)
print(f'Cross Entropy Loss: {loss.item()}')

Cross Entropy Loss: 0.37882310152053833

在这个示例中:

  • logits 是模型输出的原始分数。
  • labels 是真实的类别标签。
  • nn.CrossEntropyLoss 会先将 logits 转换为概率分布,然后计算真实标签的负对数似然损失。

二分类问题

二分类交叉熵损失的公式为:

CrossEntropyLoss = − ( y log ⁡ ( y ^ ) + ( 1 − y ) log ⁡ ( 1 − y ^ ) ) \text{CrossEntropyLoss} = - (y \log(\hat{y}) + (1 - y) \log(1 - \hat{y})) CrossEntropyLoss=−(ylog(y^)+(1−y)log(1−y^))

手动计算步骤

  1. 计算 Sigmoid 激活值

假设:

  • 真实标签 y = 1 y = 1 y=1
  • 模型输出的logits为 z = 1.5 z = 1.5 z=1.5
    计算过程:
    σ ( z ) = 1 1 + exp ⁡ ( − 1.5 ) \sigma(z) = \frac{1}{1 + \exp(-1.5)} σ(z)=1+exp(−1.5)1

我们使用更高精度来计算:
exp ⁡ ( − 1.5 ) ≈ 0.22313016014842982 \exp(-1.5) \approx 0.22313016014842982 exp(−1.5)≈0.22313016014842982
σ ( z ) = 1 1 + 0.22313016014842982 ≈ 1 1.22313016014842982 ≈ 0.8175744761936437 \sigma(z) = \frac{1}{1 + 0.22313016014842982} \approx \frac{1}{1.22313016014842982} \approx 0.8175744761936437 σ(z)=1+0.223130160148429821≈1.223130160148429821≈0.8175744761936437

  1. 计算交叉熵损失

CrossEntropyLoss = − ( y log ⁡ ( σ ( z ) ) + ( 1 − y ) log ⁡ ( 1 − σ ( z ) ) ) \text{CrossEntropyLoss} = - (y \log(\sigma(z)) + (1 - y) \log(1 - \sigma(z))) CrossEntropyLoss=−(ylog(σ(z))+(1−y)log(1−σ(z)))
CrossEntropyLoss = − log ⁡ ( 0.8175744761936437 ) \text{CrossEntropyLoss} = - \log(0.8175744761936437) CrossEntropyLoss=−log(0.8175744761936437)
log ⁡ ( 0.8175744761936437 ) ≈ − 0.2014132779827524 \log(0.8175744761936437) \approx -0.2014132779827524 log(0.8175744761936437)≈−0.2014132779827524
CrossEntropyLoss ≈ 0.2014132779827524 \text{CrossEntropyLoss} \approx 0.2014132779827524 CrossEntropyLoss≈0.2014132779827524

代码实现

python 复制代码
import torch
import torch.nn as nn
import math

# 真实标签和 logits
labels = torch.tensor([1.0])
logits = torch.tensor([1.5])

# 使用 BCEWithLogitsLoss
criterion = nn.BCEWithLogitsLoss()
loss = criterion(logits, labels)
print(f'Binary Classification Cross Entropy Loss: {loss.item()}')

# 手动计算 sigmoid 和交叉熵损失
sigmoid = 1 / (1 + math.exp(-1.5))
manual_loss = - (1 * math.log(sigmoid) + (1 - 1) * math.log(1 - sigmoid))
print(f'Manually Computed Cross Entropy Loss: {manual_loss}')

输出结果

python 复制代码
Binary Classification Cross Entropy Loss: 0.20141397416591644
Manually Computed Cross Entropy Loss: 0.2014132779827524

多分类问题

假设有3个类别:

  • 真实标签为第3类,所以one-hot编码 y = [ 0 , 0 , 1 ] y = [0, 0, 1] y=[0,0,1]。
  • 模型预测的logits为 logits = [ 0.1 , 0.2 , 0.7 ] \text{logits} = [0.1, 0.2, 0.7] logits=[0.1,0.2,0.7]。

手动计算步骤

  1. 计算Softmax
    y ^ i = exp ⁡ ( z i ) ∑ k = 1 C exp ⁡ ( z k ) \hat{y}i = \frac{\exp(z_i)}{\sum{k=1}^{C} \exp(z_k)} y^i=∑k=1Cexp(zk)exp(zi)

具体计算:

y ^ 1 = exp ⁡ ( 0.1 ) exp ⁡ ( 0.1 ) + exp ⁡ ( 0.2 ) + exp ⁡ ( 0.7 ) \hat{y}_1 = \frac{\exp(0.1)}{\exp(0.1) + \exp(0.2) + \exp(0.7)} y^1=exp(0.1)+exp(0.2)+exp(0.7)exp(0.1)
y ^ 2 = exp ⁡ ( 0.2 ) exp ⁡ ( 0.1 ) + exp ⁡ ( 0.2 ) + exp ⁡ ( 0.7 ) \hat{y}_2 = \frac{\exp(0.2)}{\exp(0.1) + \exp(0.2) + \exp(0.7)} y^2=exp(0.1)+exp(0.2)+exp(0.7)exp(0.2)
y ^ 3 = exp ⁡ ( 0.7 ) exp ⁡ ( 0.1 ) + exp ⁡ ( 0.2 ) + exp ⁡ ( 0.7 ) \hat{y}_3 = \frac{\exp(0.7)}{\exp(0.1) + \exp(0.2) + \exp(0.7)} y^3=exp(0.1)+exp(0.2)+exp(0.7)exp(0.7)

计算得到:

exp ⁡ ( 0.1 ) ≈ 1.1052 \exp(0.1) \approx 1.1052 exp(0.1)≈1.1052
exp ⁡ ( 0.2 ) ≈ 1.2214 \exp(0.2) \approx 1.2214 exp(0.2)≈1.2214
exp ⁡ ( 0.7 ) ≈ 2.0138 \exp(0.7) \approx 2.0138 exp(0.7)≈2.0138

总和:

exp ⁡ ( 0.1 ) + exp ⁡ ( 0.2 ) + exp ⁡ ( 0.7 ) ≈ 1.1052 + 1.2214 + 2.0138 = 4.3404 \exp(0.1) + \exp(0.2) + \exp(0.7) \approx 1.1052 + 1.2214 + 2.0138 = 4.3404 exp(0.1)+exp(0.2)+exp(0.7)≈1.1052+1.2214+2.0138=4.3404

各个概率:

y ^ 1 = 1.1052 4.3404 ≈ 0.2546 \hat{y}_1 = \frac{1.1052}{4.3404} \approx 0.2546 y^1=4.34041.1052≈0.2546
y ^ 2 = 1.2214 4.3404 ≈ 0.2814 \hat{y}_2 = \frac{1.2214}{4.3404} \approx 0.2814 y^2=4.34041.2214≈0.2814
y ^ 3 = 2.0138 4.3404 ≈ 0.4639 \hat{y}_3 = \frac{2.0138}{4.3404} \approx 0.4639 y^3=4.34042.0138≈0.4639

  1. 计算交叉熵损失
    CrossEntropyLoss = − ( 0 ⋅ log ⁡ ( 0.2546 ) + 0 ⋅ log ⁡ ( 0.2814 ) + 1 ⋅ log ⁡ ( 0.4639 ) ) \text{CrossEntropyLoss} = - (0 \cdot \log(0.2546) + 0 \cdot \log(0.2814) + 1 \cdot \log(0.4639)) CrossEntropyLoss=−(0⋅log(0.2546)+0⋅log(0.2814)+1⋅log(0.4639))
    CrossEntropyLoss = − log ⁡ ( 0.4639 ) ≈ 0.769 \text{CrossEntropyLoss} = - \log(0.4639) \approx 0.769 CrossEntropyLoss=−log(0.4639)≈0.769

代码验证

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

# 模拟输入的 logits 和真实标签
logits = torch.tensor([[0.1, 0.2, 0.7]], requires_grad=True)
labels = torch.tensor([2])

# 使用 CrossEntropyLoss
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)
print(f'Computed Cross Entropy Loss (using nn.CrossEntropyLoss): {loss.item()}')

# 手动计算 softmax 和交叉熵损失
softmax_probs = F.softmax(logits, dim=1)
manual_loss = -torch.log(softmax_probs[0, labels])
print(f'Manually Computed Cross Entropy Loss: {manual_loss.item()}')

输出结果

python 复制代码
Computed Cross Entropy Loss (using nn.CrossEntropyLoss): 0.7679495811462402
Manually Computed Cross Entropy Loss: 0.7679495811462402

注意在多分类问题的代码中,我们提供了logits而不是softmax后的概率,因为nn.CrossEntropyLoss会在内部应用softmax。

在二分类问题中,我们可以使用 nn.BCEWithLogitsLoss,它会在内部应用 Sigmoid 激活函数,并计算二分类的交叉熵损失。

在多分类问题中,我们可以使用 nn.CrossEntropyLoss,它会在内部应用 Softmax 激活函数,并计算多分类的交叉熵损失

相关推荐
绘梨衣54713 小时前
Agentic RAG、传统RAG、ReAct、Function Calling 核心关系
人工智能·chatgpt·tensorflow
qq56801807613 小时前
国内如何使用Gemini 3.1 Pro?
chatgpt·ai作画·ai编程·ai写作·agi
whyfail14 小时前
AI 平台订阅套餐 Coding Plan 、Token Plan对比指南(2026年4月)
人工智能·ai·chatgpt·订阅套餐·平台对比
小龙报14 小时前
【Coze-AI智能体平台】低代码省时高效:Coze 应用开发全流程指南
java·人工智能·python·深度学习·低代码·chatgpt·交互
大写的老王14 小时前
2026年AI工具终极对比:豆包、DeepSeek、元宝、ChatGPT、Cursor,谁才是你的最佳搭档?
人工智能·chatgpt
Agent产品评测局14 小时前
流程型制造业生产节拍智能调整,落地方法与案例 | 2026工业AI Agent架构全景解析
人工智能·ai·chatgpt·架构
蔡俊锋16 小时前
AI进阶运营:从信息爆炸到智能掌控
人工智能·chatgpt·ai进阶运营
yaodong51817 小时前
ChatGPT-Image-2 绘图实战:国内镜像站 Prompt 工程指南及多模型对比
gpt·chatgpt·prompt
迪娜学姐1 天前
ChatGPT image 2 科研绘图实测分享
人工智能·chatgpt