深度学习中6种loss函数Pytorch API调用示例

自定义数据

复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

batchsize=2
num_class=4

logits=torch.randn(batchsize,num_class)
target=torch.randint(num_class,size=(batchsize,))#delta目标分布
target_logits=torch.randn(batchsize,num_class)#非delta目标分布

交叉熵 CrossEntropyLoss

python 复制代码
## 1. CE Loss  交叉熵

ce_loss_fn=torch.nn.CrossEntropyLoss()
ce_loss=ce_loss_fn(logits,target)
print("ce_loss1:",ce_loss)

ce_loss=ce_loss_fn(logits,torch.softmax(target_logits,dim=-1))
print("ce_loss2:",ce_loss)

负对数似然 NLLLoss

python 复制代码
## 2. NLL Loss 负对数似然
nll_fn=torch.nn.NLLLoss()
nll_loss=nll_fn(torch.log(torch.softmax(logits,dim=-1)+1e-7),target)
print("nll_loss:",nll_loss)

####CE LOSS value = NLL LOSS value

KL散度 KLDivLoss

python 复制代码
## 3. KL loss  KL散度
kl_loss_fn=torch.nn.KLDivLoss()
kl_loss=kl_loss_fn(torch.log(torch.softmax(logits,dim=-1)+1e-7),torch.softmax(target_logits,dim=-1))
print("kl_loss:", kl_loss)

交叉熵=信息熵+KL散度 CE=IE+KLD

python 复制代码
## 4. 验证 CE=IE+KLD
print("===========================")
ce_loss_fn_sample=torch.nn.CrossEntropyLoss(reduction="none")#单独对每个样本求交叉熵
ce_loss_sample=ce_loss_fn_sample(logits,torch.softmax(target_logits,dim=-1))
print("ce_loss_sample:",ce_loss_sample)

kl_loss_fn_sample=torch.nn.KLDivLoss(reduction="none")
kl_loss_sample=kl_loss_fn_sample(torch.log(torch.softmax(logits,dim=-1)+1e-7),torch.softmax(target_logits,dim=-1)).sum(-1)
print("kl_loss_sample:",kl_loss_sample)

target_information_entropy=torch.distributions.Categorical(probs=torch.softmax(target_logits,dim=-1)).entropy()
print("target_information_entropy:", target_information_entropy)# IE为常数,如果目标分布是delta分布IE=0

print(torch.allclose(ce_loss_sample,kl_loss_sample+target_information_entropy))#对比两个浮点张量是否相等

二分类交叉熵 BCELoss

python 复制代码
## 5. BCE Loss  二分类交叉熵
print("===========================")
bce_loss_fn=torch.nn.BCELoss()
logits=torch.rand(batchsize)
prob_1=torch.sigmoid(logits)
target=torch.randint(2,size=(batchsize,))
bce_loss=bce_loss_fn(prob_1,target.float())
print("bce_loss:",bce_loss)

### NLL Loss是BCE Loss的一般形式,用NLL Loss代替BCE loss做二分类
prob_0=1-prob_1.unsqueeze(-1)
prob=torch.cat([prob_0,prob_1.unsqueeze(-1)],dim=-1)
nll_loss_binary=nll_fn(torch.log(prob),target)
print("nll_loss_binary:",nll_loss_binary)

余弦相似度 CosineEmbeddingLoss

python 复制代码
## 6. cosine similarity loss 余弦相似度
cosine_loss_fn=torch.nn.CosineEmbeddingLoss()
v1=torch.randn(batchsize,512)
v2=torch.randn(batchsize,512)
target=torch.randint(2,size=(batchsize,))*2-1 #生成【-1,1】之间的随机值
cosine_loss=cosine_loss_fn(v1,v2,target)
print("consine_loss:",cosine_loss)
相关推荐
视觉语言导航13 小时前
ICRA-2025 | 阿德莱德机器人拓扑导航探索!TANGO:具有局部度量控制的拓扑目标可穿越性感知具身导航
人工智能·机器人·具身智能
西猫雷婶18 小时前
CNN卷积计算
人工智能·神经网络·cnn
格林威19 小时前
常规线扫描镜头有哪些类型?能做什么?
人工智能·深度学习·数码相机·算法·计算机视觉·视觉检测·工业镜头
lyx331369675919 小时前
#深度学习基础:神经网络基础与PyTorch
pytorch·深度学习·神经网络·参数初始化
倔强青铜三20 小时前
苦练Python第63天:零基础玩转TOML配置读写,tomllib模块实战
人工智能·python·面试
B站计算机毕业设计之家20 小时前
智慧交通项目:Python+YOLOv8 实时交通标志系统 深度学习实战(TT100K+PySide6 源码+文档)✅
人工智能·python·深度学习·yolo·计算机视觉·智慧交通·交通标志
高工智能汽车20 小时前
棱镜观察|极氪销量遇阻?千里智驾左手服务吉利、右手对标华为
人工智能·华为
txwtech21 小时前
第6篇 OpenCV RotatedRect如何判断矩形的角度
人工智能·opencv·计算机视觉
正牌强哥21 小时前
Futures_ML——机器学习在期货量化交易中的应用与实践
人工智能·python·机器学习·ai·交易·akshare
倔强青铜三21 小时前
苦练Python第62天:零基础玩转CSV文件读写,csv模块实战
人工智能·python·面试