交叉熵损失函数详解

文章目录


一、交叉熵损失是什么?

在机器学习的分类任务中,模型会预测一个样本属于某个特定类别的概率。因为每个样本只能属于一个特定的类别,所以对于这个类别,真实的概率值是1,而对于其他类别,真实概率值是0。交叉熵(Cross Entropy)用来衡量预测概率和真实概率之间的差异。

二元交叉熵损失(Binary Cross Entropy Loss)和多类交叉熵损失(Multiclass Cross Entropy Loss)是交叉熵损失的两种变体,分别适用于不同类型的分类任务。


二、Binary Cross Entropy Loss

二元交叉熵损失是二分类问题中广泛使用的损失函数。对于具有 N N N 个实例的数据集,二元交叉熵损失的计算方式如下:
loss = − 1 N Σ i = 1 N ( y i l o g ( p i ) + ( 1 − y i ) l o g ( 1 − p i ) ) \text{loss}=-\frac{1}{N}\Sigma_{i=1}^N(y_{i}log(p_{i}) + (1-y_{i})log(1-p_{i})) loss=−N1Σi=1N(yilog(pi)+(1−yi)log(1−pi))

其中:
y i y_{i} yi:是第 i i i 个样本的真实标签。它的值要么是0,要么是1。
p i p_{i} pi:是模型对第 i i i 个样本预测为1的概率。

例子:

假设我们有一个样本,其真实标签为1,即 ( y 0 = 1 y_{0} = 1 y0=1 )。模型对这个样本预测为1的概率是 ( p 0 = 0.012 p_{0} = 0.012 p0=0.012 )。

那么,这个样本的损失 ( loss 0 \text{loss}_0 loss0 ) 计算如下:
loss 0 = − ( y 0 ⋅ log ⁡ ( p 0 ) + ( 1 − y 0 ) ⋅ log ⁡ ( 1 − p 0 ) ) \text{loss}_0 = -(y_0 \cdot \log(p_0) + (1-y_0) \cdot \log(1-p_0)) loss0=−(y0⋅log(p0)+(1−y0)⋅log(1−p0))

代入已知值:
loss 0 = − ( 1 ⋅ log ⁡ ( 0.012 ) + ( 1 − 1 ) ⋅ log ⁡ ( 1 − 0.012 ) ) \text{loss}_0 = -(1 \cdot \log(0.012) + (1-1) \cdot \log(1-0.012)) loss0=−(1⋅log(0.012)+(1−1)⋅log(1−0.012))

简化后:
loss 0 = − log ⁡ ( 0.012 ) + 0 ≈ 4.423 \text{loss}_0 = -\log(0.012) + 0 \approx 4.423 loss0=−log(0.012)+0≈4.423

这个结果表示,当模型预测为1的概率很低(0.012)但真实标签是1时,损失值较大,说明模型的预测与实际情况有较大偏差。模型需要通过训练来降低这种损失,提高预测准确性。


三、Multiclass Cross Entropy Loss

多类交叉熵损失,也称为分类交叉熵或 softmax 损失,是多类分类问题中训练模型的广泛使用的损失函数。对于具有 N N N 个实例的数据集,多类交叉熵损失的计算方式如下:
l o s s = − 1 N Σ i = 1 N Σ j = 1 C ( y i , j . l o g ( p i , j ) ) loss=-\frac{1}{N}\Sigma_{i=1}^N\Sigma_{j=1}^C(y_{i,j}.log(p_{i,j})) loss=−N1Σi=1NΣj=1C(yi,j.log(pi,j))

其中:
C C C:类别总数
y i , j y_{i,j} yi,j:是样本 i i i 在类别 j j j 上的真实标签,通常是0或1。1表示样本 i i i 属于类别 j j j,0表示不属于。
p i , j p_{i,j} pi,j:是模型预测的样本 i i i 属于类别 j j j 的概率。

例子:

假设我们有一个样本,它属于一个三分类问题(即 C = 3 C = 3 C=3)。这个样本的真实标签是类别2(通常用 one-hot 编码表示为 [ 0 , 1 , 0 ] [0, 1, 0] [0,1,0]),也就是说:

  • y 0 , 1 = 0 y_{0,1} = 0 y0,1=0
  • y 0 , 2 = 1 y_{0,2} = 1 y0,2=1
  • y 0 , 3 = 0 y_{0,3} = 0 y0,3=0

模型对这个样本的预测概率为:

  • 类别1的概率 p 0 , 1 = 0.1 p_{0,1} = 0.1 p0,1=0.1
  • 类别2的概率 p 0 , 2 = 0.7 p_{0,2} = 0.7 p0,2=0.7
  • 类别3的概率 p 0 , 3 = 0.2 p_{0,3} = 0.2 p0,3=0.2

那么,这个样本的损失 loss 0 \text{loss}0 loss0 计算如下:
loss 0 = − ∑ j = 1 C ( y 0 , j ⋅ log ⁡ ( p 0 , j ) ) \text{loss}0 = -\sum{j=1}^C (y
{0,j} \cdot \log(p_{0,j})) loss0=−j=1∑C(y0,j⋅log(p0,j))

代入已知值:
loss 0 = − ( 0 ⋅ log ⁡ ( 0.1 ) + 1 ⋅ log ⁡ ( 0.7 ) + 0 ⋅ log ⁡ ( 0.2 ) ) \text{loss}_0 = -(0 \cdot \log(0.1) + 1 \cdot \log(0.7) + 0 \cdot \log(0.2)) loss0=−(0⋅log(0.1)+1⋅log(0.7)+0⋅log(0.2))

简化后:
loss 0 = − log ⁡ ( 0.7 ) ≈ 0.357 \text{loss}_0 = -\log(0.7) \approx 0.357 loss0=−log(0.7)≈0.357

这个结果表示模型对类别2的预测概率为0.7,与真实标签一致,因此损失值较小,说明模型在这个样本上的预测较为准确。模型的目标是通过训练来降低整体损失,提高对所有样本的预测准确性。


四、如何理解交叉熵损失?

图片来源:https://www.geeksforgeeks.org/what-is-cross-entropy-loss-function/

交叉熵损失是一个用于衡量模型预测结果与真实标签之间差距的标量值。对于数据集中的每个样本,交叉熵损失反映了模型预测的准确性。损失值越低,表示预测越准确;损失值越高,表示预测与真实情况的差距越大。

  • 二分类问题中的解释:

    • 在二分类问题中,因为只有两个类别(0和1),所以损失值的解释相对简单。
    • 如果真实标签是1,损失值主要取决于模型对类别1的预测概率接近1.0的程度。
    • 如果真实标签是0,损失值则取决于模型对类别1的预测概率接近0.0的程度。
  • 多分类问题中的解释:

    • 在多分类问题中,只有正确的真实标签会对损失产生影响,因为其他标签为零不会对损失函数增加任何值。
    • 较低的损失表示模型为正确类别分配了较高的概率,而为错误类别分配了较低的概率。

如果觉得这篇文章有用,就给个 👍和收藏⭐️吧!也欢迎在评论区分享你的看法!


参考

相关推荐
地中海~7 分钟前
DENIAL-OF-SERVICE POISONING ATTACKS ON LARGE LANGUAGE MODELS
人工智能·语言模型·自然语言处理
HSunR20 分钟前
概率论 期末 笔记
笔记·概率论
红色的山茶花31 分钟前
YOLOv9-0.1部分代码阅读笔记-loss_tal.py
笔记·深度学习·yolo
荒古前1 小时前
龟兔赛跑 PTA
c语言·算法
Colinnian1 小时前
Codeforces Round 994 (Div. 2)-D题
算法·动态规划
边缘计算社区1 小时前
首个!艾灵参编的工业边缘计算国家标准正式发布
大数据·人工智能·边缘计算
用户0099383143011 小时前
代码随想录算法训练营第十三天 | 二叉树part01
数据结构·算法
shinelord明1 小时前
【再谈设计模式】享元模式~对象共享的优化妙手
开发语言·数据结构·算法·设计模式·软件工程
游客5201 小时前
opencv中的各种滤波器简介
图像处理·人工智能·python·opencv·计算机视觉