深度学习中6种loss函数Pytorch API调用示例

自定义数据

复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

batchsize=2
num_class=4

logits=torch.randn(batchsize,num_class)
target=torch.randint(num_class,size=(batchsize,))#delta目标分布
target_logits=torch.randn(batchsize,num_class)#非delta目标分布

交叉熵 CrossEntropyLoss

python 复制代码
## 1. CE Loss  交叉熵

ce_loss_fn=torch.nn.CrossEntropyLoss()
ce_loss=ce_loss_fn(logits,target)
print("ce_loss1:",ce_loss)

ce_loss=ce_loss_fn(logits,torch.softmax(target_logits,dim=-1))
print("ce_loss2:",ce_loss)

负对数似然 NLLLoss

python 复制代码
## 2. NLL Loss 负对数似然
nll_fn=torch.nn.NLLLoss()
nll_loss=nll_fn(torch.log(torch.softmax(logits,dim=-1)+1e-7),target)
print("nll_loss:",nll_loss)

####CE LOSS value = NLL LOSS value

KL散度 KLDivLoss

python 复制代码
## 3. KL loss  KL散度
kl_loss_fn=torch.nn.KLDivLoss()
kl_loss=kl_loss_fn(torch.log(torch.softmax(logits,dim=-1)+1e-7),torch.softmax(target_logits,dim=-1))
print("kl_loss:", kl_loss)

交叉熵=信息熵+KL散度 CE=IE+KLD

python 复制代码
## 4. 验证 CE=IE+KLD
print("===========================")
ce_loss_fn_sample=torch.nn.CrossEntropyLoss(reduction="none")#单独对每个样本求交叉熵
ce_loss_sample=ce_loss_fn_sample(logits,torch.softmax(target_logits,dim=-1))
print("ce_loss_sample:",ce_loss_sample)

kl_loss_fn_sample=torch.nn.KLDivLoss(reduction="none")
kl_loss_sample=kl_loss_fn_sample(torch.log(torch.softmax(logits,dim=-1)+1e-7),torch.softmax(target_logits,dim=-1)).sum(-1)
print("kl_loss_sample:",kl_loss_sample)

target_information_entropy=torch.distributions.Categorical(probs=torch.softmax(target_logits,dim=-1)).entropy()
print("target_information_entropy:", target_information_entropy)# IE为常数,如果目标分布是delta分布IE=0

print(torch.allclose(ce_loss_sample,kl_loss_sample+target_information_entropy))#对比两个浮点张量是否相等

二分类交叉熵 BCELoss

python 复制代码
## 5. BCE Loss  二分类交叉熵
print("===========================")
bce_loss_fn=torch.nn.BCELoss()
logits=torch.rand(batchsize)
prob_1=torch.sigmoid(logits)
target=torch.randint(2,size=(batchsize,))
bce_loss=bce_loss_fn(prob_1,target.float())
print("bce_loss:",bce_loss)

### NLL Loss是BCE Loss的一般形式,用NLL Loss代替BCE loss做二分类
prob_0=1-prob_1.unsqueeze(-1)
prob=torch.cat([prob_0,prob_1.unsqueeze(-1)],dim=-1)
nll_loss_binary=nll_fn(torch.log(prob),target)
print("nll_loss_binary:",nll_loss_binary)

余弦相似度 CosineEmbeddingLoss

python 复制代码
## 6. cosine similarity loss 余弦相似度
cosine_loss_fn=torch.nn.CosineEmbeddingLoss()
v1=torch.randn(batchsize,512)
v2=torch.randn(batchsize,512)
target=torch.randint(2,size=(batchsize,))*2-1 #生成【-1,1】之间的随机值
cosine_loss=cosine_loss_fn(v1,v2,target)
print("consine_loss:",cosine_loss)
相关推荐
手机不死我是天子2 分钟前
拆解大模型二:Transformer 最核心的设计,其实你高中就学过
人工智能·llm
gustt3 分钟前
MCP协议进阶:构建多工具Agent实现智能查询与浏览器交互
人工智能·agent·mcp
Halo咯咯7 分钟前
Claude Code 的工程哲学:缓存与工具设计的真实教训 | 经验分享
人工智能
风象南1 小时前
最适合新手先装的 20 个 OpenClaw Skills 来了!
人工智能
小兵张健12 小时前
35岁程序员的春天来了
人工智能
大怪v12 小时前
AI抢饭?前端佬:我要验牌!
前端·人工智能·程序员
冬奇Lab12 小时前
OpenClaw 深度解析(六):节点、Canvas 与子 Agent
人工智能·开源
刀法如飞14 小时前
AI提示词框架深度对比分析
人工智能·ai编程
IT_陈寒15 小时前
Python开发者必知的5大性能陷阱:90%的人都踩过的坑!
前端·人工智能·后端
1G16 小时前
openclaw控制浏览器/自动化的playwright MCP + Mcporter方案实现
人工智能