【pytorch学习】交叉熵损失函数 nn.CrossEntropyLoss() = nn.LogSoftmax(dim=1) + nn.NLLLoss()

结论: Pytorch中CrossEntropyLoss()函数的主要是将softmax-log-NLLLoss合并到一块得到的结果。

from: https://zhuanlan.zhihu.com/p/98785902

python 复制代码
import torch
import torch.nn as nn
x_input=torch.randn(3,3)#随机生成输入 
print('x_input:\n',x_input) 
y_target=torch.tensor([1,2,0])#设置输出具体值 print('y_target\n',y_target)

#计算输入softmax,此时可以看到每一行加到一起结果都是1
softmax_func=nn.Softmax(dim=1)
soft_output=softmax_func(x_input)
print('soft_output:\n',soft_output)

#在softmax的基础上取log
log_output=torch.log(soft_output)
print('log_output:\n',log_output)

#对比softmax与log的结合与nn.LogSoftmaxloss(负对数似然损失)的输出结果,发现两者是一致的。
logsoftmax_func=nn.LogSoftmax(dim=1)
logsoftmax_output=logsoftmax_func(x_input)
print('logsoftmax_output:\n',logsoftmax_output)

#pytorch中关于NLLLoss的默认参数配置为:reducetion=True、size_average=True
nllloss_func=nn.NLLLoss()
nlloss_output=nllloss_func(logsoftmax_output,y_target)
print('nlloss_output:\n',nlloss_output)

#直接使用pytorch中的loss_func=nn.CrossEntropyLoss()看与经过NLLLoss的计算是不是一样
crossentropyloss=nn.CrossEntropyLoss()
crossentropyloss_output=crossentropyloss(x_input,y_target)
print('crossentropyloss_output:\n',crossentropyloss_output)

最后计算得到的结果为:

python 复制代码
x_input:
 tensor([[ 2.8883,  0.1760,  1.0774],
        [ 1.1216, -0.0562,  0.0660],
        [-1.3939, -0.0967,  0.5853]])
y_target
 tensor([1, 2, 0])
soft_output:
 tensor([[0.8131, 0.0540, 0.1329],
        [0.6039, 0.1860, 0.2102],
        [0.0841, 0.3076, 0.6083]])
log_output:
 tensor([[-0.2069, -2.9192, -2.0178],
        [-0.5044, -1.6822, -1.5599],
        [-2.4762, -1.1790, -0.4970]])
logsoftmax_output:
 tensor([[-0.2069, -2.9192, -2.0178],
        [-0.5044, -1.6822, -1.5599],
        [-2.4762, -1.1790, -0.4970]])
nlloss_output:
 tensor(2.3185)
crossentropyloss_output:
 tensor(2.3185)

通过上面的结果可以看出,直接使用pytorch中的loss_func=nn.CrossEntropyLoss()计算得到的结果与softmax-log-NLLLoss计算得到的结果是一致的。

相关推荐
广州灵眸科技有限公司21 小时前
瑞芯微(EASY EAI)RV1126B 千兆以太网电路
服务器·前端·人工智能·python·深度学习
AndrewHZ21 小时前
【大模型通关指南】3. 全球主流大模型全栈对比(含Google I/O最新Gemini,2026.05.20)
人工智能·深度学习·大模型·openai·claude·gemini·deepseek
speop21 小时前
【thorough-pytorch】评价指标
人工智能·pytorch·python
市值水晶21 小时前
海澜之家一季报:主品牌稳了,变量来了
大数据·人工智能
喵喵苗21 小时前
嵌入式和 FPGA 工程师与AI 结合技术提升规划
人工智能·fpga开发
C137的本贾尼21 小时前
图像生成初探:OpenAI 与千帆平台一键出图
人工智能
Python大数据分析@21 小时前
现在怎么去学习AI,在哪里去学习?
人工智能·学习
阿里云大数据AI技术21 小时前
开发者博客|在阿里云 PAI 平台实现规模化的机器人感知强化学习
人工智能·阿里云·机器人·强化学习·nvidia
Luhui Dev21 小时前
业务级 Agent 的 Runtime 设计:从 LangChain 看可靠性工程
人工智能·agent·luhuidev
魔乐社区21 小时前
基于昇腾 MindSpeed LLM 玩转 DeepSeek-V4-Flash
人工智能·开源·大模型