binary_cross_entropy和binary_cross_entropy_with_logits的区别

binary_cross_entropy和binary_cross_entropy_with_logits的区别

引言

二分类问题是常见的机器学习任务之一,其目标是将样本分为两个类别。为了训练一个二分类模型,通常使用交叉熵作为损失函数。

二分类交叉熵损失函数有两种不同的形式,分别是 binary_cross_entropy_with_logitsbinary_cross_entropy。在 PyTorch 中,这两种损失函数都是可用的,它们的区别在于输入的形式不同,以及它们分别是在什么情况下使用更合适


无论生活中发生什么,你都可以选择快乐。 悲伤从来都不是一种选择。 快乐的关键是要知道你可以控制你接受什么和放弃什么。

主要区别与说明

binary_cross_entropy_with_logits 通常用于二元分类问题,其中每个样本都只属于两个类别之一。此损失函数的输入应该是模型的预测值和真实标签,通常是使用sigmoid函数将最终的输出值转换为概率值。

binary_cross_entropy 也是用于二元分类问题的损失函数,但其输入应该是模型的预测值和真实标签的概率值。因此,在使用此损失函数时,需要将模型的输出值使用sigmoid函数转换为概率值,然后再将其与真实标签进行比较。

总之,binary_cross_entropy_with_logits 适用于模型输出未经过概率变换的情况,而 binary_cross_entropy 适用于模型输出已经是概率值的情况。

实例说明

以下是一个基于PyTorch的实例,展示如何使用两种损失函数:

python 复制代码
import torch
import torch.nn as nn

# 创建一个样例数据
y_true = torch.Tensor([1, 0, 1, 1])
y_pred = torch.Tensor([0.9, 0.1, 0.8, 0.7])

# 使用binary_cross_entropy_with_logits计算损失函数
loss_logits = nn.BCEWithLogitsLoss()(y_pred, y_true)
print("loss with logits:", loss_logits)

# 错误示例
loss_sigmoid_error = nn.BCELoss()(y_pred, y_true)
print("注意:错误示例 loss with sigmoid_error:", loss_sigmoid_error)  # !!!注意:可以直接计算,但是这样的计算式错误的

# 使用binary_cross_entropy计算损失函数
y_pred_sigmoid = torch.sigmoid(y_pred)
print("y_pred_sigmoid:", y_pred_sigmoid)
loss_sigmoid = nn.BCELoss()(y_pred_sigmoid, y_true)
print("loss with sigmoid:", loss_sigmoid)

运行输出如下:

bash 复制代码
loss with logits: tensor(0.4650)
注意:错误示例 loss with sigmoid_error: tensor(0.1976)
y_pred_sigmoid: tensor([0.7109, 0.5250, 0.6900, 0.6682])
loss with sigmoid: tensor(0.4650)

其中,使用nn.BCEWithLogitsLoss()函数计算binary_cross_entropy_with_logits损失函数,而使用nn.BCELoss()函数计算binary_cross_entropy损失函数。在实际使用中,建议优先使用binary_cross_entropy_with_logits损失函数。

总结

binary_cross_entropy_with_logitsbinary_cross_entropy 两者都是用于二分类问题中的损失函数。它们的主要区别在于输入的形式以及计算方式。

binary_cross_entropy_with_logits的输入是网络输出的logits(未经sigmoid函数激活的),并且该函数会自动进行sigmoid函数激活处理。而binary_cross_entropy的输入是经过sigmoid函数激活的概率值。因此使用binary_cross_entropy_with_logits会更加方便且稳定,因为它可以避免数值计算溢出的情况。

这里的logits指的是,该损失函数已经内部自带了计算logit的操作,无需在传入给这个loss函数之前手动使用sigmoid/softmax将之前网络的输入映射到[0,1]之间。事实上,官方是推荐使用函数带有with_logits的,解释是

This loss combines a Sigmoid layer and the BCELoss in one single class. This version is more numerically stable than using a plain Sigmoid followed by a BCELoss as, by combining the operations into one layer, we take advantage of the log-sum-exp trick for numerical stability.

翻译一下就是说将sigmoid层和binaray_cross_entropy合在一起计算比分开依次计算有更好的数值稳定性,这主要是运用了log-sum-exp技巧。

reference

@misc{BibEntry2023Oct,

title = {{pytorch损失函数binary{ _ \_ }cross{ _ \ }entropy和binary{ _ \ }cross{ _ \ }entropy{ _ \ }with{ _ \ _}logits的区别-CSDN博客}},

year = {2023},

month = oct,

urldate = {2023-10-06},

language = {chinese},

note = {[Online; accessed 6. Oct. 2023]},

url = {https://blog.csdn.net/u010630669/article/details/105599067}

}

相关推荐
阡之尘埃1 小时前
Python数据分析案例61——信贷风控评分卡模型(A卡)(scorecardpy 全面解析)
人工智能·python·机器学习·数据分析·智能风控·信贷风控
丕羽4 小时前
【Pytorch】基本语法
人工智能·pytorch·python
bryant_meng4 小时前
【python】Distribution
开发语言·python·分布函数·常用分布
m0_594526306 小时前
Python批量合并多个PDF
java·python·pdf
工业互联网专业6 小时前
Python毕业设计选题:基于Hadoop的租房数据分析系统的设计与实现
vue.js·hadoop·python·flask·毕业设计·源码·课程设计
钱钱钱端6 小时前
【压力测试】如何确定系统最大并发用户数?
自动化测试·软件测试·python·职场和发展·压力测试·postman
慕卿扬6 小时前
基于python的机器学习(二)—— 使用Scikit-learn库
笔记·python·学习·机器学习·scikit-learn
Json____6 小时前
python的安装环境Miniconda(Conda 命令管理依赖配置)
开发语言·python·conda·miniconda
小袁在上班6 小时前
Python 单元测试中的 Mocking 与 Stubbing:提高测试效率的关键技术
python·单元测试·log4j
白狐欧莱雅6 小时前
使用python中的pygame简单实现飞机大战游戏
经验分享·python·游戏·pygame