Pytorch损失函数-torch.nn.NLLLoss()

一、简介

1.1 nn.CrossEntropyLoss

交叉熵损失函数的定义如下:

就是我们预测的概率的对数与标签的乘积,当qk->1的时候,它的损失接近零。

1.2 nn.NLLLoss

官方文档中介绍称: nn.NLLLoss输入是一个对数概率向量和一个目标标签.

它与nn.CrossEntropyLoss的关系可以描述为:

python 复制代码
softmax(x)+log(x)+nn.NLLLoss  ==>  nn.CrossEntropyLoss

二、计算示例

2.1 softmax 函数

python 复制代码
import math
z = [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0]
z_exp = [math.exp(i) for i in z]  
print(z_exp)  # Result: [2.72, 7.39, 20.09, 54.6, 2.72, 7.39, 20.09] 
sum_z_exp = sum(z_exp)  
print(sum_z_exp)  # Result: 114.98 
softmax = [round(i / sum_z_exp, 3) for i in z_exp]
print(softmax)  # Result: [0.024, 0.064, 0.175, 0.475, 0.024, 0.064, 0.175]

2.2 nn.NLLLoss

此时,nn.NLLLoss的结果就是把上面的输出与Label对应的那个值拿出来,再去掉负号,再求均值。

2.2.1 计算log概率

python 复制代码
import torch
input=torch.randn(3,3)
soft_input = torch.nn.Softmax(dim=0)
soft_input(input)
Out[20]: 
tensor([[0.7284, 0.7364, 0.3343],
        [0.1565, 0.0365, 0.0408],
        [0.1150, 0.2270, 0.6250]])

#对softmax结果取log
torch.log(soft_input(input))
Out[21]: 
tensor([[-0.3168, -0.3059, -1.0958],
        [-1.8546, -3.3093, -3.1995],
        [-2.1625, -1.4827, -0.4701]])

2.2.2 计算nill loss

假设标签是[0,1,2],第一行取第0个元素,第二行取第1个,第三行取第2个,去掉负号,即[0.3168,3.3093,0.4701],求平均值,就可以得到损失值。

python 复制代码
(0.3168+3.3093+0.4701)/3
Out[22]: 1.3654000000000002

#验证一下

loss=torch.nn.NLLLoss()
target=torch.tensor([0,1,2])
loss(input,target)
Out[26]: tensor(0.1365)

2.3 nn.CrossEntropyLoss 结果对比

python 复制代码
loss=torch.nn.NLLLoss()
target=torch.tensor([0,1,2])
loss(input,target)
Out[26]: tensor(-0.1399)
loss =torch.nn.CrossEntropyLoss()
input = torch.tensor([[ 1.1879,  1.0780,  0.5312],
        [-0.3499, -1.9253, -1.5725],
        [-0.6578, -0.0987,  1.1570]])
target = torch.tensor([0,1,2])
loss(input,target)
Out[30]: tensor(0.1365)

以上为全部实验验证两个loss函数之间的关系!!!

参考原文链接:https://blog.csdn.net/Jeremy_lf/article/details/102725285

相关推荐
jndingxin9 分钟前
OpenCV 图形API(63)图像结构分析和形状描述符------计算图像中非零像素的边界框函数boundingRect()
人工智能·opencv·计算机视觉
旧故新长15 分钟前
支持Function Call的本地ollama模型对比评测-》开发代理agent
人工智能·深度学习·机器学习
明月与玄武22 分钟前
Python编程的真谛:超越语法,理解编程本质
python·编程语言
CodeCraft Studio24 分钟前
Excel处理控件Aspose.Cells教程:使用 Python 在 Excel 中进行数据验
开发语言·python·excel
微学AI27 分钟前
融合注意力机制和BiGRU的电力领域发电量预测项目研究,并给出相关代码
人工智能·深度学习·自然语言处理·注意力机制·bigru
知来者逆38 分钟前
计算机视觉——速度与精度的完美结合的实时目标检测算法RF-DETR详解
图像处理·人工智能·深度学习·算法·目标检测·计算机视觉·rf-detr
一勺汤41 分钟前
YOLOv11改进-双Backbone架构:利用双backbone提高yolo11目标检测的精度
人工智能·yolo·双backbone·double backbone·yolo11 backbone·yolo 双backbone
武汉唯众智创44 分钟前
高职人工智能技术应用专业(计算机视觉方向)实训室解决方案
人工智能·计算机视觉·人工智能实训室·计算机视觉实训室·人工智能计算机视觉实训室
Johny_Zhao1 小时前
MySQL 高可用集群搭建部署
linux·人工智能·mysql·信息安全·云计算·shell·yum源·系统运维·itsm
拾忆-eleven1 小时前
C语言实战:用Pygame打造高难度水果消消乐游戏
c语言·python·pygame