【深度学习】分类损失函数解析

【深度学习】分类相关的损失解析

文章目录

1. 介绍

在分类任务中,我们通常使用各种损失函数来衡量模型输出与真实标签之间的差异。有时候搞不清楚用什么,下面是几种常见的分类相关损失函数及其解析,与代码示例

2. 解析

  • 二元交叉熵损失(Binary Cross Entropy Loss,BCELoss):

    torch.nn.BCELoss() 是用于二元分类的损失函数。它将模型输出的概率与真实标签的二进制值进行比较,并计算二元交叉熵损失。BCELoss 可以处理每个样本属于多个类别的情况。当使用 BCELoss 时,需要注意模型输出经过 sigmoid 激活函数转换为 [0, 1] 的概率形式。

  • 带 logits 的二元交叉熵损失(Binary Cross Entropy With Logits Loss,BCEWithLogitsLoss):

    torch.nn.BCEWithLogitsLoss() 是和 BCELoss 相似的损失函数,它同时应用了 sigmoid 函数和二元交叉熵损失。在使用 BCEWithLogitsLoss 时,不需要对模型输出手动应用 sigmoid 函数,因为该函数内部已经自动执行了这个操作。

  • 多类别交叉熵损失(Multiclass Cross Entropy Loss,CrossEntropyLoss):

    torch.nn.CrossEntropyLoss() 是用于多类别分类任务的损失函数。它将模型输出的每个类别的分数与真实标签进行比较,并计算交叉熵损失。CrossEntropyLoss 适用于每个样本只能属于一个类别的情况。注意,在使用 CrossEntropyLoss 前,通常需要确保模型输出经过 softmax 或 log softmax 函数。

  • 多标签二元交叉熵损失(Multilabel Binary Cross Entropy Loss):

    当每个样本可以属于多个类别时,我们可以使用二元交叉熵损失来处理多标签分类任务。对于每个样本,将模型输出的概率与真实标签进行比较,并计算每个标签的二元交叉熵损失。可以逐标签地对每个标签应用 BCELoss,或者使用 torch.nn.BCEWithLogitsLoss() 并将模型输出中的最后一个维度设置为标签数量。

3. 代码示例

1)二元交叉熵损失(BCELoss):

python 复制代码
import torch
import torch.nn as nn

# 模型输出经过 sigmoid 函数处理
model_output = torch.sigmoid(model(input))
# 真实标签
target = torch.Tensor([0, 1, 1, 0])
# 创建损失函数对象
loss_fn = nn.BCELoss()
# 计算损失
loss = loss_fn(model_output, target)

2)带 logits 的二元交叉熵损失(BCEWithLogitsLoss):

python 复制代码
import torch
import torch.nn as nn

# 模型输出未经过 sigmoid 函数处理
model_output = model(input)
# 真实标签
target = torch.Tensor([0, 1, 1, 0])
# 创建损失函数对象
loss_fn = nn.BCEWithLogitsLoss()
# 计算损失
loss = loss_fn(model_output, target)

3)多类别交叉熵损失(CrossEntropyLoss):

python 复制代码
import torch
import torch.nn as nn

# 模型输出经过 softmax 函数处理
model_output = nn.functional.softmax(model(input), dim=1)
# 真实标签(每个样本只能属于一个类别)
target = torch.LongTensor([2, 1, 0])
# 创建损失函数对象
loss_fn = nn.CrossEntropyLoss()
# 计算损失
loss = loss_fn(model_output, target)

4)多标签二元交叉熵损失(Multilabel Binary Cross Entropy Loss):

python 复制代码
import torch
import torch.nn as nn

# 模型输出未经过 sigmoid 函数处理
model_output = model(input)
# 真实标签
target = torch.Tensor([[0, 1], [1, 1], [1, 0], [0, 1]])
# 创建损失函数对象
loss_fn = nn.BCEWithLogitsLoss()
# 计算损失,将模型输出的最后一个维度设置为标签数量
loss = loss_fn(model_output, target)
相关推荐
AI英德西牛仔11 小时前
AI复制的文字带星号
人工智能·ai·chatgpt·豆包·deepseek·ds随心转
卖报的大地主12 小时前
扩散薛定谔桥(Diffusion Schrödinger Bridge)
人工智能
向成科技12 小时前
当“超轻量AI”遇上“最强国产芯”
人工智能·物联网·ai·芯片·国产化·硬件·主板
远见阁12 小时前
智能体是如何“思考”的:ReAct模式
人工智能·ai·ai智能体
L-影12 小时前
为什么你的数据里藏着“隐形圈子”?聊聊AI中的聚类
人工智能·ai·数据挖掘·聚类
江瀚视野12 小时前
小马智行Robotaxi营收增超1.2倍,小马的成绩单该咋看?
人工智能
Tony Bai12 小时前
Rust 看了流泪,AI 看了沉默:扒开 Go 泛型最让你抓狂的“残疾”类型推断
开发语言·人工智能·后端·golang·rust
2301_7644413312 小时前
AI动态编排革命:Skill与Dify工作流终极对决
人工智能·机器学习
ai大模型中转api测评12 小时前
从并发噩梦到弹性自由:2026年开发者如何构建高可用的API分发层?
人工智能·gpt·gemini
程序员Shawn12 小时前
【机器学习 | 第五篇】- 决策树
人工智能·决策树·机器学习