PyTorch中CrossEntropyLoss、BCELoss、BCEWithLogitsLoss的理解

import torch

predict =torch.Tensor(\[0.5796,0.4403,0.9087,-1.5673,-0.3150,1.6660])

#predict =torch.Tensor(\[0.5796,0.4403,-1.5673,-0.3150])

print(predict)

target =torch.tensor(0,2)

target_bce =torch.Tensor(\[1,0,0,0,0,1])

ce_loss=torch.nn.CrossEntropyLoss()

soft_max=torch.nn.Softmax(dim=-1)

sig_max=torch.nn.Sigmoid()

soft_out=soft_max(predict)

sig_out=sig_max(predict)

bce_loss=torch.nn.BCELoss()

bce_loss1=torch.nn.BCEWithLogitsLoss()

print(ce_loss(predict,target))

print(bce_loss(soft_out,target_bce))

print(bce_loss(sig_out,target_bce))

print(bce_loss1(predict,target_bce))

输出:

#predict:

tensor(\[ 0.5796, 0.4403, 0.9087,
-1.5673, -0.3150, 1.6660])

#print(ce_loss(predict,target)):

tensor(0.6725)

#print(bce_loss(soft_out,target_bce))
tensor(0.3950)

#print(bce_loss(sig_out,target_bce))

tensor(0.5900)

print(bce_loss1(predict,target_bce))
tensor(0.5900)

结论:

1.sigmoid激活+BCELoss等于BCEWithLogitsLoss

2.BCEWithLogitsLoss和CrossEntropyLoss不一样,但都可以不加激活

3.sigmoid激活+BCELoss和softmax激活+BCELoss有很大区别

相关推荐
程序猿追6 天前
那个右下角的小数字怎么“卡”住我打字——我用 HarmonyOS 自己写了一个字数限制输入框
pytorch·华为·harmonyos
xiao5kou4chang6kai46 天前
MATLAB机器学习、深度学习--从数据预处理到模型训练
深度学习·机器学习·matlab·数据预处理
renhongxia16 天前
世界模型作为AGI落地底层底座的作用
人工智能·深度学习·生成对抗网络·自然语言处理·知识图谱·agi
计算机科研狗@OUC6 天前
(cvpr26) AIMDepth: Asymmetric Image-Event Mamba for Monocular Depth Estimation
人工智能·深度学习·计算机视觉
闵孚龙6 天前
《PyTorch 深度修炼》Dataset 和 DataLoader:数据如何喂给模型
人工智能·pytorch·python
β添砖java6 天前
深度学习(22)网络中的网络NiN
人工智能·深度学习
Kobebryant-Manba6 天前
深度学习时候d2l报错和使用问题
人工智能·深度学习
zhangfeng11336 天前
deepspeed zero3 结合 llamafactory 微调 ,save_only_model: true 导致保存时候出错
开发语言·python·深度学习
大模型最新论文速读6 天前
06-16 · LLM 最新论文速览
论文阅读·人工智能·深度学习·机器学习·自然语言处理
宝贝儿好6 天前
【LLM】第二章:HuggingFace入门学习
人工智能·深度学习·神经网络·学习·算法·自然语言处理