损失函数:Cross Entropy Loss (交叉熵损失函数)

损失函数:Cross Entropy Loss (交叉熵损失函数)

前言

相关介绍

损失函数(Loss Function)在机器学习和深度学习中扮演着至关重要的角色,它是一个评估模型预测输出与真实标签之间差异程度的函数。损失函数量化了模型预测错误的程度,并在训练过程中作为优化的目标,模型通过不断地调整内部参数以最小化损失函数的值,从而实现更好的拟合数据和泛化能力。

主要特性与作用:

  1. 量化误差:损失函数将模型预测值与实际目标值之间的差异转化为数值,这样就可以通过数值大小直观地衡量模型的预测效果。

  2. 优化导向:在训练神经网络时,优化算法(如梯度下降法)会根据损失函数的梯度来更新模型参数,使损失函数朝着最小化方向移动。

  3. 种类多样:根据不同的任务和需求,有多种不同的损失函数可供选择。例如,在二分类任务中常用的有二元交叉熵损失函数(Binary Cross-Entropy Loss/BCE Loss),在多分类任务中有softmax交叉熵损失函数,在回归任务中常见的是均方误差(Mean Squared Error/MSE)和绝对误差(Mean Absolute Error/MAE)等。

常见的损失函数包括:

  • 二元交叉熵损失(Binary Cross-Entropy Loss / BCE Loss):适用于二分类问题,衡量的是sigmoid函数输出的概率与真实标签间的距离。

  • 多分类交叉熵损失(Categorical Cross-Entropy Loss):对于多分类问题,每个样本可能属于多个类别之一,使用softmax函数和交叉熵损失。

  • 均方误差(Mean Squared Error / MSE):在回归问题中常用,计算预测值与真实值之差的平方平均。

  • 均方根误差(Root Mean Squared Error / RMSE):MSE的平方根,也是回归任务中的损失函数。

  • Huber损失:一种既能兼顾均方误差又能容忍较大误差的混合损失函数,常用于回归问题中。

  • Dice系数损失(Dice Loss):在图像分割任务中广泛使用,衡量的是预测分割区域与真实分割区域的重叠程度。

  • IoU(Intersection over Union)损失:也是在图像分割领域常用的损失函数,计算的是预测区域与真实区域交集与其并集的比例。

  • Focal Loss:在目标检测中应对类别不平衡问题的损失函数,对易分类的样本给予较小的权重,强调难分类样本的训练。

每种损失函数都有其适用的情境和优缺点,选择合适的损失函数是优化模型性能的关键因素之一。
交叉熵(Cross-Entropy)之所以能够用于分类问题,是因为它能够很好地衡量模型预测的概率分布与实际标签分布之间的相似度,而且它拥有几个非常适合分类任务的重要特性:

  1. 信息论基础:交叉熵源于信息论中的概念,表示一个概率分布 (p) 与另一个概率分布 (q) 的差异。在分类问题中,我们可以把 (p) 视为真实数据的标签分布,(q)视为模型预测的概率分布。交叉熵可以衡量模型预测概率与实际类别标签之间的信息差异。

  2. 最大似然估计的自然延伸:在机器学习中,我们通常倾向于最大化模型对数据的似然性,即模型预测给定数据标签的概率。交叉熵损失函数实际上是负对数似然函数在多项式分布(对于多分类问题)或伯努利分布(对于二分类问题)下的特殊情况,通过最小化交叉熵损失,相当于最大化数据的对数似然性。

  3. 梯度稳定性:交叉熵损失函数是连续且可微的,其梯度容易计算且对于大多数情况是有意义的。这意味着在训练过程中,模型可以根据损失函数的梯度进行有效的参数更新。

  4. 稀疏性惩罚:对于多分类问题,softmax函数与交叉熵损失组合使用时,不仅鼓励模型正确预测每个样本的类别,同时也通过归一化机制惩罚了预测概率分布的不均匀性,即模型不能过于肯定任何一个错误类别。

  5. 处理多类别和二类别问题 :交叉熵既可以用于处理二分类问题(通过二元交叉熵,Binary Cross-Entropy),也可以处理多分类问题(通过多类别交叉熵,Multiclass

    Cross-Entropy)。在二分类问题中,通常搭配Sigmoid函数输出概率;在多分类问题中,通常配合Softmax函数生成类别概率分布。

总的来说,交叉熵损失函数因其良好的理论基础、优化目标清晰以及在实践中的优秀表现,成为了分类问题中最常用的损失函数之一。

Softmax函数

Softmax函数是深度学习和机器学习中广泛使用的激活函数,特别是在多分类问题中。它的目的是将一个线性变换的输出(通常称为logits)映射为一个概率分布,使得所有类别的概率总和为1,每个类别的概率都在0到1之间。

Softmax函数的形式:

对于一个向量 ( z ) ,其中包含每个类别的原始得分(logits),Softmax函数的计算公式如下:

s o f t m a x ( z ) i = e z i ∑ j = 1 K e z j softmax(z)i = \frac{e^{z_i}}{\sum{j=1}^{K} e^{z_j}} softmax(z)i=∑j=1Kezjezi

其中:

  • ( K ) 表示类别总数。
  • ( z_i ) 表示第 ( i ) 个类别的得分。
  • ( softmax(z)_i ) 表示第 ( i ) 个类别的归一化概率。

整个Softmax函数的结果是一个概率分布向量,其中每个元素都是原得分经过指数函数变换后再除以所有得分指数函数值之和,因此所有元素的和为1。

Softmax函数的特性:

  1. 概率性质:Softmax函数确保输出的每个元素都是非负数,并且所有元素的和为1,满足概率分布的要求。
  2. 竞争性:Softmax函数会使得分最高的类别获得最大的概率值,其余类别的概率按比例递减,形成了一种"赢家通吃"的效应。
  3. 平滑连续:由于使用了指数函数和平滑的除法运算,Softmax函数输出是平滑且连续的,便于在训练过程中梯度的计算和传播。

应用场景

在深度学习的多分类问题中,例如图像分类、文本分类等任务,Softmax函数通常与交叉熵损失函数一起使用。模型最后一层通常会产生一个logits向量,接着通过Softmax函数得到每个类别的概率,最后计算与实际标签之间的交叉熵损失,以此指导模型参数的更新。

代码实例

在PyTorch中,你可以直接使用torch.softmax()函数来实现Softmax操作。下面是一个简单的实例:

python 复制代码
import torch

# 假设我们有一个代表logits的张量
logits = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

# 使用torch.softmax函数计算Softmax值
probs = torch.softmax(logits, dim=1)

print(probs)
'''
tensor([[0.0900, 0.2447, 0.6652],
        [0.0900, 0.2447, 0.6652]])
'''

上述代码中,logits是一个2x3的张量,代表两个样本的三个类别的原始得分。dim=1表示我们在每个样本的类别间计算Softmax,也就是对每一行进行操作。执行torch.softmax()后,probs张量将包含每个样本各类别的归一化概率。

注意,如果你正在训练一个多分类模型并且使用了nn.CrossEntropyLoss()损失函数,通常不需要单独调用torch.softmax(),因为该损失函数内部已经包含了对logits计算Softmax的过程。在多数情况下,你只需将模型的原始输出(logits)传递给损失函数,并配合真实类别标签即可。

Cross Entropy Loss (交叉熵损失函数)

nn.CrossEntropyLoss是PyTorch中用于多分类问题的一种损失函数,特别适用于输出层是softmax激活函数后的分类任务。它结合了softmax函数和交叉熵损失(Cross-Entropy Loss)的操作,简化了模型训练过程中的计算步骤和代码实现。

基本概念:

  • 交叉熵损失(Cross-Entropy Loss)源于信息论中的熵概念,用于衡量两个概率分布之间的差异。在机器学习和深度学习中,它用来量化模型预测的概率分布与真实标签分布之间的差距。

  • softmax函数:在多分类问题中,softmax函数将模型的线性输出(logits)转换为一个概率分布,确保所有类别的概率和为1。softmax函数的输出可以用作模型预测的概率分布。

nn.CrossEntropyLoss的工作方式:

  • PyTorch中的nn.CrossEntropyLoss接收两个输入:

    • input:模型的原始输出(logits),通常是未经过softmax激活的张量。
    • target:真实的一维标签张量,包含了每个样本所属类别的索引,通常采用LongTensor类型。
  • 内部处理流程

    • 对于每个样本,首先计算其对应的softmax概率分布。
    • 然后,根据真实标签计算交叉熵损失。损失是对每个样本的损失值进行平均得到的,如果没有特殊指定,损失默认会在批次(batch)层面求平均。
  • 损失函数计算公式

    • 对于单个样本,交叉熵损失是 -∑(yi * log(pi)),其中 yi 是实际标签的one-hot编码(在实际情况中,由于标签是索引形式,nn.CrossEntropyLoss内部会处理one-hot编码),pi 是模型预测的该类别概率。
    • 对于整个批次,损失则是各样本损失的平均。

Cross Entropy Loss与BCE loss区别

CrossEntropyLossBCELoss 都是 PyTorch 中用于监督学习分类任务的损失函数,它们分别适用于不同的分类场景:

BCELoss (Binary Cross Entropy Loss)

  • BCELoss 是二元交叉熵损失函数,专门用于二分类问题,即输出只有两类(0或1,正面或负面,真或假等)。
  • 使用 BCELoss 时,模型的输出一般是通过 Sigmoid 函数得到的概率值,介于0和1之间。
  • 计算公式为 -y * log(p) - (1-y) * log(1-p),其中 y 是真实的标签(0或1),p 是模型预测的概率。
  • 输入要求是经过Sigmoid激活函数之后的输出张量和相应的真实标签张量,二者形状必须相同。

CrossEntropyLoss (Multinomial Cross Entropy Loss 或者 Softmax Cross Entropy Loss)

  • CrossEntropyLoss 适用于多分类问题,它可以处理任何数量的类别,不仅仅是二分类。
  • 对于多分类问题,模型的输出通常是一个 logits(未归一化的预测值),然后CrossEntropyLoss内部会先通过Softmax函数将其转换为概率分布,然后再计算交叉熵。
  • 使用 CrossEntropyLoss 时,不需要手动在输出层之前添加Sigmoid或Softmax函数,因为它已经包含了Softmax运算步骤。
  • 它结合了Softmax函数和交叉熵损失的功能,简化了多分类任务的训练流程,其计算公式基于交叉熵和类别间的互斥性(即对于每个样本,所有类别的概率之和为1)。
  • 输入要求是未经Softmax激活函数处理的logits张量和one-hot编码形式的真实标签张量。

总结来说,两者的主要区别在于:

  • BCELoss用于二分类任务,而CrossEntropyLoss适用于多分类任务。
  • BCELoss前接Sigmoid,CrossEntropyLoss前接Softmax(但这一步在使用CrossEntropyLoss时由损失函数内部自动完成)。
  • BCELoss处理的是二元概率分布,而CrossEntropyLoss处理的是多类别概率分布。

代码实例

python 复制代码
import torch
import torch.nn as nn

# 假设模型输出和真实标签
output_logits = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # 假设输出是两样本的三个类别的logits
targets = torch.tensor([1, 2])  # 假设第一样本是第二类,第二样本是第三类

# 创建交叉熵损失函数
criterion = nn.CrossEntropyLoss()

# 计算损失
loss = criterion(output_logits, targets)

print(loss.item())  # 输出损失值 # 0.9076058864593506

在上述代码中,nn.CrossEntropyLoss()函数内部处理了softmax激活和交叉熵损失计算,直接返回了模型预测与真实标签之间的交叉熵损失。

相关推荐
hunter20620614 分钟前
用opencv生成视频流,然后用rtsp进行拉流显示
人工智能·python·opencv
Daphnis_z16 分钟前
大模型应用编排工具Dify之常用编排组件
人工智能·chatgpt·prompt
yuanbenshidiaos1 小时前
【大数据】机器学习----------强化学习机器学习阶段尾声
人工智能·机器学习
盼小辉丶6 小时前
TensorFlow深度学习实战——情感分析模型
深度学习·神经网络·tensorflow
好评笔记6 小时前
AIGC视频生成模型:Stability AI的SVD(Stable Video Diffusion)模型
论文阅读·人工智能·深度学习·机器学习·计算机视觉·面试·aigc
算家云6 小时前
TangoFlux 本地部署实用教程:开启无限音频创意脑洞
人工智能·aigc·模型搭建·算家云、·应用社区·tangoflux
AI街潜水的八角7 小时前
工业缺陷检测实战——基于深度学习YOLOv10神经网络PCB缺陷检测系统
pytorch·深度学习·yolo
叫我:松哥8 小时前
基于Python django的音乐用户偏好分析及可视化系统设计与实现
人工智能·后端·python·mysql·数据分析·django
熊文豪9 小时前
深入解析人工智能中的协同过滤算法及其在推荐系统中的应用与优化
人工智能·算法
Vol火山9 小时前
AI引领工业制造智能化革命:机器视觉与时序数据预测的双重驱动
人工智能·制造