Pytorch CrossEntropyLoss() 原理和用法详解

文章目录

  • [1. 前言](#1. 前言)
  • [2. 交叉熵的含义](#2. 交叉熵的含义)
  • [3. 举例计算交叉熵](#3. 举例计算交叉熵)
  • [4. CrossEntropyLoss 计算过程](#4. CrossEntropyLoss 计算过程)
    • [4.1 具体过程](#4.1 具体过程)
    • [4.2 nn.NLLLoss() 的用法](#4.2 nn.NLLLoss() 的用法)
  • [5. 参考](#5. 参考)

1. 前言

在 PyTorch 中,CrossEntropyLoss() 是一个用于计算交叉熵损失(Cross-Entropy Loss)的损失函数。它通常用于多类别分类任务中,特别是当类别之间不平衡或者样本数目不均衡时。

官方文档用法参考这里

2. 交叉熵的含义

交叉熵(Cross-Entropy)是一种用于比较两个概率分布之间差异的度量。在机器学习中,交叉熵通常用作损失函数,用于衡量模型预测与真实标签之间的差异,尤其在分类任务中广泛使用。

假设有两个概率分布 P P P 和 Q Q Q,其中 P P P 表示真实分布, Q Q Q 表示模型的预测分布。这两个分布都是离散的,通常用于表示类别的概率分布。交叉熵损失函数的计算方式如下:

H ( P , Q ) = − ∑ i P ( i ) log ⁡ ( Q ( i ) ) H(P, Q) = - \sum_{i} P(i) \log(Q(i)) H(P,Q)=−i∑P(i)log(Q(i))

其中, i i i 表示类别的索引, P ( i ) P(i) P(i) 和 Q ( i ) Q(i) Q(i) 分别表示真实分布和预测分布中第 i i i 个类别的概率。交叉熵衡量了在真实分布下观察到的事件的平均信息量,与预测分布 Q Q Q 相对应。

在分类任务中,通常使用交叉熵损失函数来衡量模型预测的概率分布与真实标签的差异。在训练过程中,模型的目标是最小化交叉熵损失,使得模型的预测分布尽可能接近真实分布。

交叉熵越小,越接近真实模型。当模型的预测与真实标签完全一致时,交叉熵达到最小值为 0。

3. 举例计算交叉熵

假设我们有一个分类任务,共有 3 个类别,并且模型的预测结果和真实标签如下:

  • 5个样本所属的真实标签(Ground Truth):[1, 0, 2, 1, 2]
  • 模型的预测概率分布:
    • 类别 0 的预测概率分布:[0.2, 0.6, 0.2]
    • 类别 1 的预测概率分布:[0.7, 0.2, 0.1]
    • 类别 2 的预测概率分布:[0.1, 0.1, 0.8]

首先,我们需要计算每个样本的交叉熵损失,然后将它们求和并除以样本数量来计算平均损失。计算过程如下:

(1)对于第一个样本(真实标签为 1):

  • 真实标签概率分布:[0, 1, 0]
  • 模型预测概率分布:[0.7, 0.2, 0.1]
  • 交叉熵损失:-1 * (1 * log(0.7) + 0 * log(0.2) + 0 * log(0.1)) ≈ 0.36

(2)对于第二个样本(真实标签为 0):

  • 真实标签概率分布:[1, 0, 0]
  • 模型预测概率分布:[0.2, 0.6, 0.2]
  • 交叉熵损失:-1 * (0 * log(0.2) + 1 * log(0.6) + 0 * log(0.2)) ≈ 0.51

(3)对于第三个样本(真实标签为 2):

  • 真实标签概率分布:[0, 0, 1]
  • 模型预测概率分布:[0.1, 0.1, 0.8]
  • 交叉熵损失:-1 * (0 * log(0.1) + 0 * log(0.1) + 1 * log(0.8)) ≈ 0.22

(4)对于第四个样本(真实标签为 1):

  • 真实标签概率分布:[0, 1, 0]
  • 模型预测概率分布:[0.7, 0.2, 0.1]
  • 交叉熵损失:-1 * (1 * log(0.7) + 0 * log(0.2) + 0 * log(0.1)) ≈ 0.36

(5)对于第五个样本(真实标签为 2):

  • 真实标签概率分布:[0, 0, 1]
  • 模型预测概率分布:[0.1, 0.1, 0.8]
  • 交叉熵损失:-1 * (0 * log(0.1) + 0 * log(0.1) + 1 * log(0.8)) ≈ 0.22

最后,将每个样本的交叉熵损失相加,并除以样本数量得到平均损失:
平均损失 = 0.36 + 0.51 + 0.22 + 0.36 + 0.22 5 ≈ 0.334 \text{平均损失} = \frac{0.36 + 0.51 + 0.22 + 0.36 + 0.22}{5} \approx 0.334 平均损失=50.36+0.51+0.22+0.36+0.22≈0.334

所以,该多分类任务的平均交叉熵损失约为 0.334。

4. CrossEntropyLoss 计算过程

4.1 具体过程

Pytorch 中 CrossEntropyLoss() 函数包含以下步骤:

  1. softmax
  2. log
  3. NLLLoss

以下是验证流程:

python 复制代码
import torch
import torch.nn as nn

_input = torch.randn(4, 3)
print('input:\n', _input)

target = torch.tensor([1, 2, 0, 1])  # 设置输出具体值 

################# 输出:#################
input:
 tensor([[-0.0251, -1.0660, -1.2555],
        [ 0.4511,  1.4464,  0.9722],
        [ 0.3108,  0.4180, -0.4181],
        [ 1.0811, -1.6097, -0.6413]])
python 复制代码
# 计算输入softmax
softmax_f = nn.Softmax(dim=1)
soft_output = softmax_f(_input)
print('softmax_output:\n', soft_output)

# 在softmax的基础上取log
log_output = torch.log(soft_output)
print('log_output:\n', log_output)

################# 输出:#################
softmax_output:
 tensor([[0.6078, 0.2146, 0.1776],
        [0.1855, 0.5020, 0.3124],
        [0.3853, 0.4289, 0.1859],
        [0.8023, 0.0544, 0.1433]])
log_output:
 tensor([[-0.4979, -1.5388, -1.7284],
        [-1.6845, -0.6891, -1.1633],
        [-0.9538, -0.8466, -1.6828],
        [-0.2203, -2.9111, -1.9427]])
python 复制代码
# softmax+log与nn.LogSoftmaxloss的结果是一致的。
logsoftmax_func = nn.LogSoftmax(dim=1)
logsoftmax_output = logsoftmax_func(_input)
print('logsoftmax_output:\n', logsoftmax_output)

################# 输出:#################
logsoftmax_output:
 tensor([[-0.4979, -1.5388, -1.7284],
        [-1.6845, -0.6891, -1.1633],
        [-0.9538, -0.8466, -1.6828],
        [-0.2203, -2.9111, -1.9427]])
python 复制代码
# 先用nn.NLLLoss()计算
nllloss_func = nn.NLLLoss()
nlloss_output = nllloss_func(logsoftmax_output, target)
print('nlloss_output:\n', nlloss_output)

# 和nn.CrossEntropyLoss()的结果是一样的
crossentropyloss = nn.CrossEntropyLoss()
crossentropyloss_output = crossentropyloss(_input, target)
print('crossentropyloss_output:\n', crossentropyloss_output)

################# 输出:#################
nlloss_output:
 tensor(1.6417)
crossentropyloss_output:
 tensor(1.6417)

上述过程验证了CrossEntropyLoss() 函数包含以下步骤:

  1. softmax
  2. log
  3. NLLLoss

4.2 nn.NLLLoss() 的用法

下面有必要介绍 nn.NLLLoss() 的用法。在 PyTorch 中,NLLLoss() 是一个用于计算负对数似然损失(Negative Log Likelihood Loss)的损失函数。官方文档参考这里

例子:

python 复制代码
from torch import nn
import torch

# 初始化
nllloss = nn.NLLLoss() # 可选参数中有 reduction='mean', 'sum', 默认mean

# 两个张量,一个是预测向量,一个是真实标签label
predict = torch.Tensor([[2, 3, 1],
                        [3, 7, 9]])
label = torch.tensor([1, 2])
v = nllloss(predict, label)
print(v)

################# 输出:#################
tensor(-6.)

解释:

上面的label 表示依次在 predict 选取值。例如,上面的 label = [1,2],那么在predict[0]中选取3,在predict[1]中选取9。然后求平均值并取负: − ( 3 + 9 ) / 2 = − 6 -(3+9)/2=-6 −(3+9)/2=−6

5. 参考

nn.CrossEntropyLoss
nn.NLLLoss

欢迎关注本人,我是喜欢搞事的程序猿;一起进步,一起学习;

欢迎关注知乎/CSDN:SmallerFL

也欢迎关注我的wx公众号(精选高质量文章):一个比特定乾坤

相关推荐
biter00881 分钟前
opencv(15) OpenCV背景减除器(Background Subtractors)学习
人工智能·opencv·学习
吃个糖糖8 分钟前
35 Opencv 亚像素角点检测
人工智能·opencv·计算机视觉
qq_5290252925 分钟前
Torch.gather
python·深度学习·机器学习
数据小爬虫@26 分钟前
如何高效利用Python爬虫按关键字搜索苏宁商品
开发语言·爬虫·python
Cachel wood1 小时前
python round四舍五入和decimal库精确四舍五入
java·linux·前端·数据库·vue.js·python·前端框架
IT古董1 小时前
【漫话机器学习系列】017.大O算法(Big-O Notation)
人工智能·机器学习
凯哥是个大帅比1 小时前
人工智能ACA(五)--深度学习基础
人工智能·深度学习
終不似少年遊*1 小时前
pyecharts
python·信息可视化·数据分析·学习笔记·pyecharts·使用技巧
Python之栈1 小时前
【无标题】
数据库·python·mysql