【深度学习】分类损失函数解析

【深度学习】分类相关的损失解析

文章目录

1. 介绍

在分类任务中,我们通常使用各种损失函数来衡量模型输出与真实标签之间的差异。有时候搞不清楚用什么,下面是几种常见的分类相关损失函数及其解析,与代码示例

2. 解析

  • 二元交叉熵损失(Binary Cross Entropy Loss,BCELoss):

    torch.nn.BCELoss() 是用于二元分类的损失函数。它将模型输出的概率与真实标签的二进制值进行比较,并计算二元交叉熵损失。BCELoss 可以处理每个样本属于多个类别的情况。当使用 BCELoss 时,需要注意模型输出经过 sigmoid 激活函数转换为 [0, 1] 的概率形式。

  • 带 logits 的二元交叉熵损失(Binary Cross Entropy With Logits Loss,BCEWithLogitsLoss):

    torch.nn.BCEWithLogitsLoss() 是和 BCELoss 相似的损失函数,它同时应用了 sigmoid 函数和二元交叉熵损失。在使用 BCEWithLogitsLoss 时,不需要对模型输出手动应用 sigmoid 函数,因为该函数内部已经自动执行了这个操作。

  • 多类别交叉熵损失(Multiclass Cross Entropy Loss,CrossEntropyLoss):

    torch.nn.CrossEntropyLoss() 是用于多类别分类任务的损失函数。它将模型输出的每个类别的分数与真实标签进行比较,并计算交叉熵损失。CrossEntropyLoss 适用于每个样本只能属于一个类别的情况。注意,在使用 CrossEntropyLoss 前,通常需要确保模型输出经过 softmax 或 log softmax 函数。

  • 多标签二元交叉熵损失(Multilabel Binary Cross Entropy Loss):

    当每个样本可以属于多个类别时,我们可以使用二元交叉熵损失来处理多标签分类任务。对于每个样本,将模型输出的概率与真实标签进行比较,并计算每个标签的二元交叉熵损失。可以逐标签地对每个标签应用 BCELoss,或者使用 torch.nn.BCEWithLogitsLoss() 并将模型输出中的最后一个维度设置为标签数量。

3. 代码示例

1)二元交叉熵损失(BCELoss):

python 复制代码
import torch
import torch.nn as nn

# 模型输出经过 sigmoid 函数处理
model_output = torch.sigmoid(model(input))
# 真实标签
target = torch.Tensor([0, 1, 1, 0])
# 创建损失函数对象
loss_fn = nn.BCELoss()
# 计算损失
loss = loss_fn(model_output, target)

2)带 logits 的二元交叉熵损失(BCEWithLogitsLoss):

python 复制代码
import torch
import torch.nn as nn

# 模型输出未经过 sigmoid 函数处理
model_output = model(input)
# 真实标签
target = torch.Tensor([0, 1, 1, 0])
# 创建损失函数对象
loss_fn = nn.BCEWithLogitsLoss()
# 计算损失
loss = loss_fn(model_output, target)

3)多类别交叉熵损失(CrossEntropyLoss):

python 复制代码
import torch
import torch.nn as nn

# 模型输出经过 softmax 函数处理
model_output = nn.functional.softmax(model(input), dim=1)
# 真实标签(每个样本只能属于一个类别)
target = torch.LongTensor([2, 1, 0])
# 创建损失函数对象
loss_fn = nn.CrossEntropyLoss()
# 计算损失
loss = loss_fn(model_output, target)

4)多标签二元交叉熵损失(Multilabel Binary Cross Entropy Loss):

python 复制代码
import torch
import torch.nn as nn

# 模型输出未经过 sigmoid 函数处理
model_output = model(input)
# 真实标签
target = torch.Tensor([[0, 1], [1, 1], [1, 0], [0, 1]])
# 创建损失函数对象
loss_fn = nn.BCEWithLogitsLoss()
# 计算损失,将模型输出的最后一个维度设置为标签数量
loss = loss_fn(model_output, target)
相关推荐
梦云澜1 小时前
论文阅读(十二):全基因组关联研究中生物通路的图形建模
论文阅读·人工智能·深度学习
远洋录1 小时前
构建一个数据分析Agent:提升分析效率的实践
人工智能·ai·ai agent
IT古董2 小时前
【深度学习】常见模型-Transformer模型
人工智能·深度学习·transformer
沐雪架构师3 小时前
AI大模型开发原理篇-2:语言模型雏形之词袋模型
人工智能·语言模型·自然语言处理
摸鱼仙人~4 小时前
Attention Free Transformer (AFT)-2020论文笔记
论文阅读·深度学习·transformer
python算法(魔法师版)4 小时前
深度学习深度解析:从基础到前沿
人工智能·深度学习
kakaZhui4 小时前
【llm对话系统】大模型源码分析之 LLaMA 位置编码 RoPE
人工智能·深度学习·chatgpt·aigc·llama
struggle20255 小时前
一个开源 GenBI AI 本地代理(确保本地数据安全),使数据驱动型团队能够与其数据进行互动,生成文本到 SQL、图表、电子表格、报告和 BI
人工智能·深度学习·目标检测·语言模型·自然语言处理·数据挖掘·集成学习
佛州小李哥5 小时前
通过亚马逊云科技Bedrock打造自定义AI智能体Agent(上)
人工智能·科技·ai·语言模型·云计算·aws·亚马逊云科技
云空6 小时前
《DeepSeek 网页/API 性能异常(DeepSeek Web/API Degraded Performance):网络安全日志》
运维·人工智能·web安全·网络安全·开源·网络攻击模型·安全威胁分析