ptorch中的nn.KLDivLoss:KL散度损失

KL散度被广泛应用于度量分布之间的差异,其形式为: D K L ( P ∣ ∣ Q ) = ∑ i = 1 N p i l o g p i q i = ∑ i = 1 N p i ∗ ( l o g p i − l o g q i ) D_{KL}(P||Q)=\sum_{i=1}^{N}p_ilog\frac{p_i}{q_i}=\sum_{i=1}^{N}p_i*(logp_i-logq_i) DKL(P∣∣Q)=i=1∑Npilogqipi=i=1∑Npi∗(logpi−logqi)  pytorch中给出了两种不同的方法用于计算KL散度,分别是torch.nn.functional.kl_div()和torch.nn.KLDivLoss(),两者计算效果类似,区别无非是直接计算和作为损失函数类,我们重点看torch.nn.KLDivLoss(),在深度学习中是一个很常见的损失。官方文档地址为:

nn.KLDivLoss:https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html#torch.nn.KLDivLoss\>

函数定义:

c 复制代码
torch.nn.KLDivLoss(size_average=None, reduce=None, reduction='mean', log_target=False)

其中,size_average与reduce参数已被弃用,具体功能由参数reduction代替;reduction:指定损失输出的形式,有四种选择:none|mean|batchmean|sum。none:损失不做任何处理,直接输出一个数组;mean:将得到的损失求平均值再输出;batchmean:将输出的总和除以batchsize;sum:将得到的损失求和再输出;log_target:指定是否对输入的target使用log操作。

在使用上,nn.KLDivLoss和交叉熵损失是不同的,对于pytorch中的交叉熵损失torch.nn.CrossEntropyLoss,我们给进的网络预测结果不需要进行softmax处理,给进的labels可以仅仅是一个label的list,函数中内置了对标签进行的ont-hot操作,而在nn.KLDivLoss中并没有这种操作,因此,对于nn.KLDivLoss输入的两个分布input和target,我们首先要对其进行softmax操作。此外,当log_target参数设定为False时,计算方式为: P ∗ ( l o g P − Q ) P*(logP-Q) P∗(logP−Q),这与定义式的结果不同,因此,还需要对input取对数操作(在官方文档中也有提及,建议将input映射到对数空间,防止数值下溢),一个示例代码为:

c 复制代码
import torch
import torch.nn.Functional as F
torch.nn.KLDivLoss(F.softmax(Q).log(), F.softmax(P), reduction='mean')
相关推荐
蹦蹦跳跳真可爱5893 小时前
Python----深度学习(基于深度学习Pytroch簇分类,圆环分类,月牙分类)
人工智能·pytorch·python·深度学习·分类
Sherlock Ma10 小时前
PDFMathTranslate:基于LLM的PDF文档翻译及双语对照的工具【使用教程】
人工智能·pytorch·语言模型·pdf·大模型·机器翻译·deepseek
谦行13 小时前
工欲善其事,必先利其器—— PyTorch 深度学习基础操作
pytorch·深度学习·ai编程
开心快乐幸福一家人16 小时前
Spark-SQL与Hive集成及数据分析实践
人工智能·pytorch·深度学习
什么芮.16 小时前
spark-streaming
pytorch·sql·spark·kafka·scala
小宋加油啊17 小时前
深度学习小记(包括pytorch 还有一些神经网络架构)
pytorch·深度学习·神经网络
IT_Octopus20 小时前
AI工程pytorch小白TorchServe部署模型服务
人工智能·pytorch·python
北上ing1 天前
从FP32到BF16,再到混合精度的全景解析
人工智能·pytorch·深度学习·计算机视觉·stable diffusion
蔗理苦1 天前
2025-04-24 Python&深度学习4—— 计算图与动态图机制
开发语言·pytorch·python·深度学习·计算图
Y1nhl1 天前
搜广推校招面经八十一
开发语言·人工智能·pytorch·深度学习·机器学习·推荐算法·搜索算法