交叉熵损失函数详解

文章目录


一、交叉熵损失是什么?

在机器学习的分类任务中,模型会预测一个样本属于某个特定类别的概率。因为每个样本只能属于一个特定的类别,所以对于这个类别,真实的概率值是1,而对于其他类别,真实概率值是0。交叉熵(Cross Entropy)用来衡量预测概率和真实概率之间的差异。

二元交叉熵损失(Binary Cross Entropy Loss)和多类交叉熵损失(Multiclass Cross Entropy Loss)是交叉熵损失的两种变体,分别适用于不同类型的分类任务。


二、Binary Cross Entropy Loss

二元交叉熵损失是二分类问题中广泛使用的损失函数。对于具有 N N N 个实例的数据集,二元交叉熵损失的计算方式如下:
loss = − 1 N Σ i = 1 N ( y i l o g ( p i ) + ( 1 − y i ) l o g ( 1 − p i ) ) \text{loss}=-\frac{1}{N}\Sigma_{i=1}^N(y_{i}log(p_{i}) + (1-y_{i})log(1-p_{i})) loss=−N1Σi=1N(yilog(pi)+(1−yi)log(1−pi))

其中:
y i y_{i} yi:是第 i i i 个样本的真实标签。它的值要么是0,要么是1。
p i p_{i} pi:是模型对第 i i i 个样本预测为1的概率。

例子:

假设我们有一个样本,其真实标签为1,即 ( y 0 = 1 y_{0} = 1 y0=1 )。模型对这个样本预测为1的概率是 ( p 0 = 0.012 p_{0} = 0.012 p0=0.012 )。

那么,这个样本的损失 ( loss 0 \text{loss}_0 loss0 ) 计算如下:
loss 0 = − ( y 0 ⋅ log ⁡ ( p 0 ) + ( 1 − y 0 ) ⋅ log ⁡ ( 1 − p 0 ) ) \text{loss}_0 = -(y_0 \cdot \log(p_0) + (1-y_0) \cdot \log(1-p_0)) loss0=−(y0⋅log(p0)+(1−y0)⋅log(1−p0))

代入已知值:
loss 0 = − ( 1 ⋅ log ⁡ ( 0.012 ) + ( 1 − 1 ) ⋅ log ⁡ ( 1 − 0.012 ) ) \text{loss}_0 = -(1 \cdot \log(0.012) + (1-1) \cdot \log(1-0.012)) loss0=−(1⋅log(0.012)+(1−1)⋅log(1−0.012))

简化后:
loss 0 = − log ⁡ ( 0.012 ) + 0 ≈ 4.423 \text{loss}_0 = -\log(0.012) + 0 \approx 4.423 loss0=−log(0.012)+0≈4.423

这个结果表示,当模型预测为1的概率很低(0.012)但真实标签是1时,损失值较大,说明模型的预测与实际情况有较大偏差。模型需要通过训练来降低这种损失,提高预测准确性。


三、Multiclass Cross Entropy Loss

多类交叉熵损失,也称为分类交叉熵或 softmax 损失,是多类分类问题中训练模型的广泛使用的损失函数。对于具有 N N N 个实例的数据集,多类交叉熵损失的计算方式如下:
l o s s = − 1 N Σ i = 1 N Σ j = 1 C ( y i , j . l o g ( p i , j ) ) loss=-\frac{1}{N}\Sigma_{i=1}^N\Sigma_{j=1}^C(y_{i,j}.log(p_{i,j})) loss=−N1Σi=1NΣj=1C(yi,j.log(pi,j))

其中:
C C C:类别总数
y i , j y_{i,j} yi,j:是样本 i i i 在类别 j j j 上的真实标签,通常是0或1。1表示样本 i i i 属于类别 j j j,0表示不属于。
p i , j p_{i,j} pi,j:是模型预测的样本 i i i 属于类别 j j j 的概率。

例子:

假设我们有一个样本,它属于一个三分类问题(即 C = 3 C = 3 C=3)。这个样本的真实标签是类别2(通常用 one-hot 编码表示为 [ 0 , 1 , 0 ] [0, 1, 0] [0,1,0]),也就是说:

  • y 0 , 1 = 0 y_{0,1} = 0 y0,1=0
  • y 0 , 2 = 1 y_{0,2} = 1 y0,2=1
  • y 0 , 3 = 0 y_{0,3} = 0 y0,3=0

模型对这个样本的预测概率为:

  • 类别1的概率 p 0 , 1 = 0.1 p_{0,1} = 0.1 p0,1=0.1
  • 类别2的概率 p 0 , 2 = 0.7 p_{0,2} = 0.7 p0,2=0.7
  • 类别3的概率 p 0 , 3 = 0.2 p_{0,3} = 0.2 p0,3=0.2

那么,这个样本的损失 loss 0 \text{loss}0 loss0 计算如下:
loss 0 = − ∑ j = 1 C ( y 0 , j ⋅ log ⁡ ( p 0 , j ) ) \text{loss}0 = -\sum{j=1}^C (y
{0,j} \cdot \log(p_{0,j})) loss0=−j=1∑C(y0,j⋅log(p0,j))

代入已知值:
loss 0 = − ( 0 ⋅ log ⁡ ( 0.1 ) + 1 ⋅ log ⁡ ( 0.7 ) + 0 ⋅ log ⁡ ( 0.2 ) ) \text{loss}_0 = -(0 \cdot \log(0.1) + 1 \cdot \log(0.7) + 0 \cdot \log(0.2)) loss0=−(0⋅log(0.1)+1⋅log(0.7)+0⋅log(0.2))

简化后:
loss 0 = − log ⁡ ( 0.7 ) ≈ 0.357 \text{loss}_0 = -\log(0.7) \approx 0.357 loss0=−log(0.7)≈0.357

这个结果表示模型对类别2的预测概率为0.7,与真实标签一致,因此损失值较小,说明模型在这个样本上的预测较为准确。模型的目标是通过训练来降低整体损失,提高对所有样本的预测准确性。


四、如何理解交叉熵损失?

图片来源:https://www.geeksforgeeks.org/what-is-cross-entropy-loss-function/

交叉熵损失是一个用于衡量模型预测结果与真实标签之间差距的标量值。对于数据集中的每个样本,交叉熵损失反映了模型预测的准确性。损失值越低,表示预测越准确;损失值越高,表示预测与真实情况的差距越大。

  • 二分类问题中的解释:

    • 在二分类问题中,因为只有两个类别(0和1),所以损失值的解释相对简单。
    • 如果真实标签是1,损失值主要取决于模型对类别1的预测概率接近1.0的程度。
    • 如果真实标签是0,损失值则取决于模型对类别1的预测概率接近0.0的程度。
  • 多分类问题中的解释:

    • 在多分类问题中,只有正确的真实标签会对损失产生影响,因为其他标签为零不会对损失函数增加任何值。
    • 较低的损失表示模型为正确类别分配了较高的概率,而为错误类别分配了较低的概率。

如果觉得这篇文章有用,就给个 👍和收藏⭐️吧!也欢迎在评论区分享你的看法!


参考

相关推荐
Swift社区2 小时前
LeetCode - #139 单词拆分
算法·leetcode·职场和发展
Kent_J_Truman3 小时前
greater<>() 、less<>()及运算符 < 重载在排序和堆中的使用
算法
IT 青年3 小时前
数据结构 (1)基本概念和术语
数据结构·算法
ZHOU_WUYI4 小时前
3.langchain中的prompt模板 (few shot examples in chat models)
人工智能·langchain·prompt
Dong雨4 小时前
力扣hot100-->栈/单调栈
算法·leetcode·职场和发展
如若1234 小时前
主要用于图像的颜色提取、替换以及区域修改
人工智能·opencv·计算机视觉
SoraLuna4 小时前
「Mac玩转仓颉内测版24」基础篇4 - 浮点类型详解
开发语言·算法·macos·cangjie
老艾的AI世界4 小时前
AI翻唱神器,一键用你喜欢的歌手翻唱他人的曲目(附下载链接)
人工智能·深度学习·神经网络·机器学习·ai·ai翻唱·ai唱歌·ai歌曲
DK221514 小时前
机器学习系列----关联分析
人工智能·机器学习
liujjjiyun4 小时前
小R的随机播放顺序
数据结构·c++·算法