深度学习中6种loss函数Pytorch API调用示例

自定义数据

复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

batchsize=2
num_class=4

logits=torch.randn(batchsize,num_class)
target=torch.randint(num_class,size=(batchsize,))#delta目标分布
target_logits=torch.randn(batchsize,num_class)#非delta目标分布

交叉熵 CrossEntropyLoss

python 复制代码
## 1. CE Loss  交叉熵

ce_loss_fn=torch.nn.CrossEntropyLoss()
ce_loss=ce_loss_fn(logits,target)
print("ce_loss1:",ce_loss)

ce_loss=ce_loss_fn(logits,torch.softmax(target_logits,dim=-1))
print("ce_loss2:",ce_loss)

负对数似然 NLLLoss

python 复制代码
## 2. NLL Loss 负对数似然
nll_fn=torch.nn.NLLLoss()
nll_loss=nll_fn(torch.log(torch.softmax(logits,dim=-1)+1e-7),target)
print("nll_loss:",nll_loss)

####CE LOSS value = NLL LOSS value

KL散度 KLDivLoss

python 复制代码
## 3. KL loss  KL散度
kl_loss_fn=torch.nn.KLDivLoss()
kl_loss=kl_loss_fn(torch.log(torch.softmax(logits,dim=-1)+1e-7),torch.softmax(target_logits,dim=-1))
print("kl_loss:", kl_loss)

交叉熵=信息熵+KL散度 CE=IE+KLD

python 复制代码
## 4. 验证 CE=IE+KLD
print("===========================")
ce_loss_fn_sample=torch.nn.CrossEntropyLoss(reduction="none")#单独对每个样本求交叉熵
ce_loss_sample=ce_loss_fn_sample(logits,torch.softmax(target_logits,dim=-1))
print("ce_loss_sample:",ce_loss_sample)

kl_loss_fn_sample=torch.nn.KLDivLoss(reduction="none")
kl_loss_sample=kl_loss_fn_sample(torch.log(torch.softmax(logits,dim=-1)+1e-7),torch.softmax(target_logits,dim=-1)).sum(-1)
print("kl_loss_sample:",kl_loss_sample)

target_information_entropy=torch.distributions.Categorical(probs=torch.softmax(target_logits,dim=-1)).entropy()
print("target_information_entropy:", target_information_entropy)# IE为常数,如果目标分布是delta分布IE=0

print(torch.allclose(ce_loss_sample,kl_loss_sample+target_information_entropy))#对比两个浮点张量是否相等

二分类交叉熵 BCELoss

python 复制代码
## 5. BCE Loss  二分类交叉熵
print("===========================")
bce_loss_fn=torch.nn.BCELoss()
logits=torch.rand(batchsize)
prob_1=torch.sigmoid(logits)
target=torch.randint(2,size=(batchsize,))
bce_loss=bce_loss_fn(prob_1,target.float())
print("bce_loss:",bce_loss)

### NLL Loss是BCE Loss的一般形式,用NLL Loss代替BCE loss做二分类
prob_0=1-prob_1.unsqueeze(-1)
prob=torch.cat([prob_0,prob_1.unsqueeze(-1)],dim=-1)
nll_loss_binary=nll_fn(torch.log(prob),target)
print("nll_loss_binary:",nll_loss_binary)

余弦相似度 CosineEmbeddingLoss

python 复制代码
## 6. cosine similarity loss 余弦相似度
cosine_loss_fn=torch.nn.CosineEmbeddingLoss()
v1=torch.randn(batchsize,512)
v2=torch.randn(batchsize,512)
target=torch.randint(2,size=(batchsize,))*2-1 #生成【-1,1】之间的随机值
cosine_loss=cosine_loss_fn(v1,v2,target)
print("consine_loss:",cosine_loss)
相关推荐
红衣小蛇妖35 分钟前
神经网络-Day46
人工智能·深度学习·神经网络
带电的小王1 小时前
【动手学深度学习】3.1. 线性回归
人工智能·深度学习·线性回归
谢尔登1 小时前
结合 AI 生成 mermaid、plantuml 等图表
人工智能
VR最前沿1 小时前
【应用】Ghost Dance:利用惯性动捕构建虚拟舞伴
人工智能·科技
说私域2 小时前
内容力重塑品牌增长:开源AI大模型驱动下的智能名片与S2B2C商城赋能抖音生态种草范式
人工智能·小程序·开源·零售
l1t2 小时前
三种读写传统xls格式文件开源库libxls、xlslib、BasicExcel的比较
c++·人工智能·开源·mfc
AI浩2 小时前
【Block总结】EBlock,快速傅里叶变换(FFT)增强输入图像的幅度|即插即用|CVPR2025
人工智能·目标检测·计算机视觉
Vertira2 小时前
Pytorch安装后 如何快速查看经典的网络模型.py文件(例如Alexnet,VGG)(已解决)
人工智能·pytorch·python
Listennnn2 小时前
信号处理基础到进阶再到前沿
人工智能·深度学习·信号处理
奔跑吧邓邓子2 小时前
DeepSeek 赋能智能养老:情感陪伴机器人的温暖革新
人工智能·机器人·deepseek·智能养老·情感陪伴