Pytorch CrossEntropyLoss() 原理和用法详解

文章目录

  • [1. 前言](#1. 前言)
  • [2. 交叉熵的含义](#2. 交叉熵的含义)
  • [3. 举例计算交叉熵](#3. 举例计算交叉熵)
  • [4. CrossEntropyLoss 计算过程](#4. CrossEntropyLoss 计算过程)
    • [4.1 具体过程](#4.1 具体过程)
    • [4.2 nn.NLLLoss() 的用法](#4.2 nn.NLLLoss() 的用法)
  • [5. 参考](#5. 参考)

1. 前言

在 PyTorch 中,CrossEntropyLoss() 是一个用于计算交叉熵损失(Cross-Entropy Loss)的损失函数。它通常用于多类别分类任务中,特别是当类别之间不平衡或者样本数目不均衡时。

官方文档用法参考这里

2. 交叉熵的含义

交叉熵(Cross-Entropy)是一种用于比较两个概率分布之间差异的度量。在机器学习中,交叉熵通常用作损失函数,用于衡量模型预测与真实标签之间的差异,尤其在分类任务中广泛使用。

假设有两个概率分布 P P P 和 Q Q Q,其中 P P P 表示真实分布, Q Q Q 表示模型的预测分布。这两个分布都是离散的,通常用于表示类别的概率分布。交叉熵损失函数的计算方式如下:

H ( P , Q ) = − ∑ i P ( i ) log ⁡ ( Q ( i ) ) H(P, Q) = - \sum_{i} P(i) \log(Q(i)) H(P,Q)=−i∑P(i)log(Q(i))

其中, i i i 表示类别的索引, P ( i ) P(i) P(i) 和 Q ( i ) Q(i) Q(i) 分别表示真实分布和预测分布中第 i i i 个类别的概率。交叉熵衡量了在真实分布下观察到的事件的平均信息量,与预测分布 Q Q Q 相对应。

在分类任务中,通常使用交叉熵损失函数来衡量模型预测的概率分布与真实标签的差异。在训练过程中,模型的目标是最小化交叉熵损失,使得模型的预测分布尽可能接近真实分布。

交叉熵越小,越接近真实模型。当模型的预测与真实标签完全一致时,交叉熵达到最小值为 0。

3. 举例计算交叉熵

假设我们有一个分类任务,共有 3 个类别,并且模型的预测结果和真实标签如下:

  • 5个样本所属的真实标签(Ground Truth):[1, 0, 2, 1, 2]
  • 模型的预测概率分布:
    • 类别 0 的预测概率分布:[0.2, 0.6, 0.2]
    • 类别 1 的预测概率分布:[0.7, 0.2, 0.1]
    • 类别 2 的预测概率分布:[0.1, 0.1, 0.8]

首先,我们需要计算每个样本的交叉熵损失,然后将它们求和并除以样本数量来计算平均损失。计算过程如下:

(1)对于第一个样本(真实标签为 1):

  • 真实标签概率分布:[0, 1, 0]
  • 模型预测概率分布:[0.7, 0.2, 0.1]
  • 交叉熵损失:-1 * (1 * log(0.7) + 0 * log(0.2) + 0 * log(0.1)) ≈ 0.36

(2)对于第二个样本(真实标签为 0):

  • 真实标签概率分布:[1, 0, 0]
  • 模型预测概率分布:[0.2, 0.6, 0.2]
  • 交叉熵损失:-1 * (0 * log(0.2) + 1 * log(0.6) + 0 * log(0.2)) ≈ 0.51

(3)对于第三个样本(真实标签为 2):

  • 真实标签概率分布:[0, 0, 1]
  • 模型预测概率分布:[0.1, 0.1, 0.8]
  • 交叉熵损失:-1 * (0 * log(0.1) + 0 * log(0.1) + 1 * log(0.8)) ≈ 0.22

(4)对于第四个样本(真实标签为 1):

  • 真实标签概率分布:[0, 1, 0]
  • 模型预测概率分布:[0.7, 0.2, 0.1]
  • 交叉熵损失:-1 * (1 * log(0.7) + 0 * log(0.2) + 0 * log(0.1)) ≈ 0.36

(5)对于第五个样本(真实标签为 2):

  • 真实标签概率分布:[0, 0, 1]
  • 模型预测概率分布:[0.1, 0.1, 0.8]
  • 交叉熵损失:-1 * (0 * log(0.1) + 0 * log(0.1) + 1 * log(0.8)) ≈ 0.22

最后,将每个样本的交叉熵损失相加,并除以样本数量得到平均损失:
平均损失 = 0.36 + 0.51 + 0.22 + 0.36 + 0.22 5 ≈ 0.334 \text{平均损失} = \frac{0.36 + 0.51 + 0.22 + 0.36 + 0.22}{5} \approx 0.334 平均损失=50.36+0.51+0.22+0.36+0.22≈0.334

所以,该多分类任务的平均交叉熵损失约为 0.334。

4. CrossEntropyLoss 计算过程

4.1 具体过程

Pytorch 中 CrossEntropyLoss() 函数包含以下步骤:

  1. softmax
  2. log
  3. NLLLoss

以下是验证流程:

python 复制代码
import torch
import torch.nn as nn

_input = torch.randn(4, 3)
print('input:\n', _input)

target = torch.tensor([1, 2, 0, 1])  # 设置输出具体值 

################# 输出:#################
input:
 tensor([[-0.0251, -1.0660, -1.2555],
        [ 0.4511,  1.4464,  0.9722],
        [ 0.3108,  0.4180, -0.4181],
        [ 1.0811, -1.6097, -0.6413]])
python 复制代码
# 计算输入softmax
softmax_f = nn.Softmax(dim=1)
soft_output = softmax_f(_input)
print('softmax_output:\n', soft_output)

# 在softmax的基础上取log
log_output = torch.log(soft_output)
print('log_output:\n', log_output)

################# 输出:#################
softmax_output:
 tensor([[0.6078, 0.2146, 0.1776],
        [0.1855, 0.5020, 0.3124],
        [0.3853, 0.4289, 0.1859],
        [0.8023, 0.0544, 0.1433]])
log_output:
 tensor([[-0.4979, -1.5388, -1.7284],
        [-1.6845, -0.6891, -1.1633],
        [-0.9538, -0.8466, -1.6828],
        [-0.2203, -2.9111, -1.9427]])
python 复制代码
# softmax+log与nn.LogSoftmaxloss的结果是一致的。
logsoftmax_func = nn.LogSoftmax(dim=1)
logsoftmax_output = logsoftmax_func(_input)
print('logsoftmax_output:\n', logsoftmax_output)

################# 输出:#################
logsoftmax_output:
 tensor([[-0.4979, -1.5388, -1.7284],
        [-1.6845, -0.6891, -1.1633],
        [-0.9538, -0.8466, -1.6828],
        [-0.2203, -2.9111, -1.9427]])
python 复制代码
# 先用nn.NLLLoss()计算
nllloss_func = nn.NLLLoss()
nlloss_output = nllloss_func(logsoftmax_output, target)
print('nlloss_output:\n', nlloss_output)

# 和nn.CrossEntropyLoss()的结果是一样的
crossentropyloss = nn.CrossEntropyLoss()
crossentropyloss_output = crossentropyloss(_input, target)
print('crossentropyloss_output:\n', crossentropyloss_output)

################# 输出:#################
nlloss_output:
 tensor(1.6417)
crossentropyloss_output:
 tensor(1.6417)

上述过程验证了CrossEntropyLoss() 函数包含以下步骤:

  1. softmax
  2. log
  3. NLLLoss

4.2 nn.NLLLoss() 的用法

下面有必要介绍 nn.NLLLoss() 的用法。在 PyTorch 中,NLLLoss() 是一个用于计算负对数似然损失(Negative Log Likelihood Loss)的损失函数。官方文档参考这里

例子:

python 复制代码
from torch import nn
import torch

# 初始化
nllloss = nn.NLLLoss() # 可选参数中有 reduction='mean', 'sum', 默认mean

# 两个张量,一个是预测向量,一个是真实标签label
predict = torch.Tensor([[2, 3, 1],
                        [3, 7, 9]])
label = torch.tensor([1, 2])
v = nllloss(predict, label)
print(v)

################# 输出:#################
tensor(-6.)

解释:

上面的label 表示依次在 predict 选取值。例如,上面的 label = [1,2],那么在predict[0]中选取3,在predict[1]中选取9。然后求平均值并取负: − ( 3 + 9 ) / 2 = − 6 -(3+9)/2=-6 −(3+9)/2=−6

5. 参考

nn.CrossEntropyLoss
nn.NLLLoss

欢迎关注本人,我是喜欢搞事的程序猿;一起进步,一起学习;

欢迎关注知乎/CSDN:SmallerFL

也欢迎关注我的wx公众号(精选高质量文章):一个比特定乾坤

相关推荐
网易独家音乐人Mike Zhou2 小时前
【卡尔曼滤波】数据预测Prediction观测器的理论推导及应用 C语言、Python实现(Kalman Filter)
c语言·python·单片机·物联网·算法·嵌入式·iot
安静读书2 小时前
Python解析视频FPS(帧率)、分辨率信息
python·opencv·音视频
小陈phd2 小时前
OpenCV从入门到精通实战(九)——基于dlib的疲劳监测 ear计算
人工智能·opencv·计算机视觉
Guofu_Liao3 小时前
大语言模型---LoRA简介;LoRA的优势;LoRA训练步骤;总结
人工智能·语言模型·自然语言处理·矩阵·llama
小二·4 小时前
java基础面试题笔记(基础篇)
java·笔记·python
小喵要摸鱼5 小时前
Python 神经网络项目常用语法
python
一念之坤6 小时前
零基础学Python之数据结构 -- 01篇
数据结构·python
wxl7812277 小时前
如何使用本地大模型做数据分析
python·数据挖掘·数据分析·代码解释器
NoneCoder7 小时前
Python入门(12)--数据处理
开发语言·python
ZHOU_WUYI7 小时前
3.langchain中的prompt模板 (few shot examples in chat models)
人工智能·langchain·prompt