【深度学习】pytorch计算KL散度、kl_div

使用pytorch进行KL散度计算,可以使用pytorch的kl_div函数

假设y为真实分布,x为预测分布。

java 复制代码
import torch
import torch.nn.functional as F

# 定义两组数据
tensor1 = torch.tensor([[0.1, 0.2, 0.3, 0.2, 0.2],
                        [0.2, 0.1, 0.2, 0.3, 0.2],
                        [0.2, 0.3, 0.1, 0.2, 0.2],
                        [0.2, 0.2, 0.3, 0.1, 0.2],
                        [0.2, 0.2, 0.2, 0.3, 0.1],
                        [0.1, 0.2, 0.2, 0.2, 0.3],
                        [0.3, 0.2, 0.1, 0.2, 0.2],
                        [0.2, 0.3, 0.2, 0.1, 0.2],
                        [0.1, 0.2, 0.2, 0.3, 0.2],
                        [0.2, 0.1, 0.3, 0.2, 0.2],
                        [0.2, 0.3, 0.2, 0.2, 0.1],
                        [0.1, 0.1, 0.2, 0.3, 0.3],
                        [0.3, 0.2, 0.2, 0.1, 0.2],
                        [0.2, 0.3, 0.1, 0.2, 0.2],
                        [0.1, 0.3, 0.2, 0.2, 0.2],
                        [0.2, 0.2, 0.1, 0.3, 0.2]])

tensor2 = torch.tensor([[0.2, 0.1, 0.3, 0.2, 0.2],
                        [0.3, 0.2, 0.2, 0.1, 0.2],
                        [0.2, 0.3, 0.2, 0.2, 0.1],
                        [0.1, 0.2, 0.3, 0.2, 0.2],
                        [0.2, 0.2, 0.1, 0.2, 0.3],
                        [0.3, 0.2, 0.2, 0.3, 0.0],
                        [0.2, 0.3, 0.1, 0.2, 0.2],
                        [0.1, 0.2, 0.2, 0.3, 0.2],
                        [0.2, 0.1, 0.3, 0.2, 0.2],
                        [0.2, 0.3, 0.2, 0.1, 0.2],
                        [0.1, 0.2, 0.3, 0.2, 0.2],
                        [0.2, 0.3, 0.2, 0.2, 0.1],
                        [0.2, 0.1, 0.2, 0.3, 0.2],
                        [0.3, 0.2, 0.2, 0.1, 0.2],
                        [0.2, 0.2, 0.3, 0.2, 0.1],
                        [0.1, 0.3, 0.2, 0.2, 0.2]])

# 计算两组张量之间的 KL 散度
logp_x = F.log_softmax(tensor1, dim=-1)
p_y = F.softmax(tensor2, dim=-1)

kl_divergence = F.kl_div(logp_x, p_y, reduction='batchmean')
kl_sum = F.kl_div(logp_x, p_y, reduction='sum')
print("KL散度(batchmean)值为:", kl_divergence.item())
print("KL散度(sum)值为:", kl_sum.item())

打印结果:

复制代码
KL散度(batchmean)值为: 0.00508523266762495
KL散度(sum)值为: 0.0813637226819992  

其中kl_div接收三个参数,第一个为预测分布,第二个为真实分布,第三个为reduction。(其实还有其他参数,只是基本用不到)

这里有一些细节需要注意,第一个参数与第二个参数都要进行softmax(dim=-1),目的是使两个概率分布的所有值之和都为1,若不进行此操作,如果x或y概率分布所有值的和大于1,则可能会使计算的KL为负数。

softmax接收一个参数dim,dim=-1表示在最后一维进行softmax操作。

除此之外,第一个参数还要进行log()操作(至于为什么,大概是为了方便pytorch的代码组织,pytorch定义的损失函数都调用handle_torch_function函数,方便权重控制等),才能得到正确结果。还有说是因为要用y指导x,所以求x的对数概率,y的概率

相关推荐
努力还债的学术吗喽1 分钟前
【速通】深度学习模型调试系统化方法论:从问题定位到性能优化
人工智能·深度学习·学习·调试·模型·方法论
云边云科技32 分钟前
零售行业新店网络零接触部署场景下,如何选择SDWAN
运维·服务器·网络·人工智能·安全·边缘计算·零售
audyxiao00142 分钟前
为了更强大的空间智能,如何将2D图像转换成完整、具有真实尺度和外观的3D场景?
人工智能·计算机视觉·3d·iccv·空间智能
伊织code1 小时前
PyTorch API 6
pytorch·api·ddp
Monkey的自我迭代1 小时前
机器学习总复习
人工智能·机器学习
大千AI助手1 小时前
GitHub Copilot:AI编程助手的架构演进与真实世界影响
人工智能·深度学习·大模型·github·copilot·ai编程·codex
用户5191495848451 小时前
耶稣蓝队集体防护Bash脚本:多模块协同防御实战
人工智能·aigc
☺����1 小时前
实现自己的AI视频监控系统-第一章-视频拉流与解码1
人工智能·python·音视频
Black_Rock_br2 小时前
本地部署的终极多面手:Qwen2.5-Omni-3B,视频剪、音频混、图像生、文本写全搞定
人工智能·音视频
用什么都重名2 小时前
《GPT-OSS 模型全解析:OpenAI 回归开源的 Mixture-of-Experts 之路》
人工智能·大模型·openai·gpt-oss