ptorch中的nn.KLDivLoss:KL散度损失

KL散度被广泛应用于度量分布之间的差异,其形式为: D K L ( P ∣ ∣ Q ) = ∑ i = 1 N p i l o g p i q i = ∑ i = 1 N p i ∗ ( l o g p i − l o g q i ) D_{KL}(P||Q)=\sum_{i=1}^{N}p_ilog\frac{p_i}{q_i}=\sum_{i=1}^{N}p_i*(logp_i-logq_i) DKL(P∣∣Q)=i=1∑Npilogqipi=i=1∑Npi∗(logpi−logqi)  pytorch中给出了两种不同的方法用于计算KL散度,分别是torch.nn.functional.kl_div()和torch.nn.KLDivLoss(),两者计算效果类似,区别无非是直接计算和作为损失函数类,我们重点看torch.nn.KLDivLoss(),在深度学习中是一个很常见的损失。官方文档地址为:

nn.KLDivLoss:https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html#torch.nn.KLDivLoss\>

函数定义:

c 复制代码
torch.nn.KLDivLoss(size_average=None, reduce=None, reduction='mean', log_target=False)

其中,size_average与reduce参数已被弃用,具体功能由参数reduction代替;reduction:指定损失输出的形式,有四种选择:none|mean|batchmean|sum。none:损失不做任何处理,直接输出一个数组;mean:将得到的损失求平均值再输出;batchmean:将输出的总和除以batchsize;sum:将得到的损失求和再输出;log_target:指定是否对输入的target使用log操作。

在使用上,nn.KLDivLoss和交叉熵损失是不同的,对于pytorch中的交叉熵损失torch.nn.CrossEntropyLoss,我们给进的网络预测结果不需要进行softmax处理,给进的labels可以仅仅是一个label的list,函数中内置了对标签进行的ont-hot操作,而在nn.KLDivLoss中并没有这种操作,因此,对于nn.KLDivLoss输入的两个分布input和target,我们首先要对其进行softmax操作。此外,当log_target参数设定为False时,计算方式为: P ∗ ( l o g P − Q ) P*(logP-Q) P∗(logP−Q),这与定义式的结果不同,因此,还需要对input取对数操作(在官方文档中也有提及,建议将input映射到对数空间,防止数值下溢),一个示例代码为:

c 复制代码
import torch
import torch.nn.Functional as F
torch.nn.KLDivLoss(F.softmax(Q).log(), F.softmax(P), reduction='mean')
相关推荐
喝过期的拉菲3 小时前
如何使用 Pytorch Lightning 启用早停机制
pytorch·lightning·早停机制
kk爱闹3 小时前
【挑战14天学完python和pytorch】- day01
android·pytorch·python
Yo_Becky7 小时前
【PyTorch】PyTorch预训练模型缓存位置迁移,也可拓展应用于其他文件的迁移
人工智能·pytorch·经验分享·笔记·python·程序人生·其他
xinxiangwangzhi_7 小时前
pytorch底层原理学习--PyTorch 架构梳理
人工智能·pytorch·架构
FF-Studio8 小时前
【硬核数学 · LLM篇】3.1 Transformer之心:自注意力机制的线性代数解构《从零构建机器学习、深度学习到LLM的数学认知》
人工智能·pytorch·深度学习·线性代数·机器学习·数学建模·transformer
盼小辉丶14 小时前
PyTorch实战(14)——条件生成对抗网络(conditional GAN,cGAN)
人工智能·pytorch·生成对抗网络
Gyoku Mint16 小时前
深度学习×第4卷:Pytorch实战——她第一次用张量去拟合你的轨迹
人工智能·pytorch·python·深度学习·神经网络·算法·聚类
郭庆汝1 天前
pytorch、torchvision与python版本对应关系
人工智能·pytorch·python
cver1231 天前
野生动物检测数据集介绍-5,138张图片 野生动物保护监测 智能狩猎相机系统 生态研究与调查
人工智能·pytorch·深度学习·目标检测·计算机视觉·目标跟踪
点我头像干啥1 天前
用 PyTorch 构建液态神经网络(LNN):下一代动态深度学习模型
pytorch·深度学习·神经网络