交叉熵损失函数详解

文章目录


一、交叉熵损失是什么?

在机器学习的分类任务中,模型会预测一个样本属于某个特定类别的概率。因为每个样本只能属于一个特定的类别,所以对于这个类别,真实的概率值是1,而对于其他类别,真实概率值是0。交叉熵(Cross Entropy)用来衡量预测概率和真实概率之间的差异。

二元交叉熵损失(Binary Cross Entropy Loss)和多类交叉熵损失(Multiclass Cross Entropy Loss)是交叉熵损失的两种变体,分别适用于不同类型的分类任务。


二、Binary Cross Entropy Loss

二元交叉熵损失是二分类问题中广泛使用的损失函数。对于具有 N N N 个实例的数据集,二元交叉熵损失的计算方式如下:
loss = − 1 N Σ i = 1 N ( y i l o g ( p i ) + ( 1 − y i ) l o g ( 1 − p i ) ) \text{loss}=-\frac{1}{N}\Sigma_{i=1}^N(y_{i}log(p_{i}) + (1-y_{i})log(1-p_{i})) loss=−N1Σi=1N(yilog(pi)+(1−yi)log(1−pi))

其中:
y i y_{i} yi:是第 i i i 个样本的真实标签。它的值要么是0,要么是1。
p i p_{i} pi:是模型对第 i i i 个样本预测为1的概率。

例子:

假设我们有一个样本,其真实标签为1,即 ( y 0 = 1 y_{0} = 1 y0=1 )。模型对这个样本预测为1的概率是 ( p 0 = 0.012 p_{0} = 0.012 p0=0.012 )。

那么,这个样本的损失 ( loss 0 \text{loss}_0 loss0 ) 计算如下:
loss 0 = − ( y 0 ⋅ log ⁡ ( p 0 ) + ( 1 − y 0 ) ⋅ log ⁡ ( 1 − p 0 ) ) \text{loss}_0 = -(y_0 \cdot \log(p_0) + (1-y_0) \cdot \log(1-p_0)) loss0=−(y0⋅log(p0)+(1−y0)⋅log(1−p0))

代入已知值:
loss 0 = − ( 1 ⋅ log ⁡ ( 0.012 ) + ( 1 − 1 ) ⋅ log ⁡ ( 1 − 0.012 ) ) \text{loss}_0 = -(1 \cdot \log(0.012) + (1-1) \cdot \log(1-0.012)) loss0=−(1⋅log(0.012)+(1−1)⋅log(1−0.012))

简化后:
loss 0 = − log ⁡ ( 0.012 ) + 0 ≈ 4.423 \text{loss}_0 = -\log(0.012) + 0 \approx 4.423 loss0=−log(0.012)+0≈4.423

这个结果表示,当模型预测为1的概率很低(0.012)但真实标签是1时,损失值较大,说明模型的预测与实际情况有较大偏差。模型需要通过训练来降低这种损失,提高预测准确性。


三、Multiclass Cross Entropy Loss

多类交叉熵损失,也称为分类交叉熵或 softmax 损失,是多类分类问题中训练模型的广泛使用的损失函数。对于具有 N N N 个实例的数据集,多类交叉熵损失的计算方式如下:
l o s s = − 1 N Σ i = 1 N Σ j = 1 C ( y i , j . l o g ( p i , j ) ) loss=-\frac{1}{N}\Sigma_{i=1}^N\Sigma_{j=1}^C(y_{i,j}.log(p_{i,j})) loss=−N1Σi=1NΣj=1C(yi,j.log(pi,j))

其中:
C C C:类别总数
y i , j y_{i,j} yi,j:是样本 i i i 在类别 j j j 上的真实标签,通常是0或1。1表示样本 i i i 属于类别 j j j,0表示不属于。
p i , j p_{i,j} pi,j:是模型预测的样本 i i i 属于类别 j j j 的概率。

例子:

假设我们有一个样本,它属于一个三分类问题(即 C = 3 C = 3 C=3)。这个样本的真实标签是类别2(通常用 one-hot 编码表示为 [ 0 , 1 , 0 ] [0, 1, 0] [0,1,0]),也就是说:

  • y 0 , 1 = 0 y_{0,1} = 0 y0,1=0
  • y 0 , 2 = 1 y_{0,2} = 1 y0,2=1
  • y 0 , 3 = 0 y_{0,3} = 0 y0,3=0

模型对这个样本的预测概率为:

  • 类别1的概率 p 0 , 1 = 0.1 p_{0,1} = 0.1 p0,1=0.1
  • 类别2的概率 p 0 , 2 = 0.7 p_{0,2} = 0.7 p0,2=0.7
  • 类别3的概率 p 0 , 3 = 0.2 p_{0,3} = 0.2 p0,3=0.2

那么,这个样本的损失 loss 0 \text{loss}0 loss0 计算如下:
loss 0 = − ∑ j = 1 C ( y 0 , j ⋅ log ⁡ ( p 0 , j ) ) \text{loss}0 = -\sum{j=1}^C (y
{0,j} \cdot \log(p_{0,j})) loss0=−j=1∑C(y0,j⋅log(p0,j))

代入已知值:
loss 0 = − ( 0 ⋅ log ⁡ ( 0.1 ) + 1 ⋅ log ⁡ ( 0.7 ) + 0 ⋅ log ⁡ ( 0.2 ) ) \text{loss}_0 = -(0 \cdot \log(0.1) + 1 \cdot \log(0.7) + 0 \cdot \log(0.2)) loss0=−(0⋅log(0.1)+1⋅log(0.7)+0⋅log(0.2))

简化后:
loss 0 = − log ⁡ ( 0.7 ) ≈ 0.357 \text{loss}_0 = -\log(0.7) \approx 0.357 loss0=−log(0.7)≈0.357

这个结果表示模型对类别2的预测概率为0.7,与真实标签一致,因此损失值较小,说明模型在这个样本上的预测较为准确。模型的目标是通过训练来降低整体损失,提高对所有样本的预测准确性。


四、如何理解交叉熵损失?

图片来源:https://www.geeksforgeeks.org/what-is-cross-entropy-loss-function/

交叉熵损失是一个用于衡量模型预测结果与真实标签之间差距的标量值。对于数据集中的每个样本,交叉熵损失反映了模型预测的准确性。损失值越低,表示预测越准确;损失值越高,表示预测与真实情况的差距越大。

  • 二分类问题中的解释:

    • 在二分类问题中,因为只有两个类别(0和1),所以损失值的解释相对简单。
    • 如果真实标签是1,损失值主要取决于模型对类别1的预测概率接近1.0的程度。
    • 如果真实标签是0,损失值则取决于模型对类别1的预测概率接近0.0的程度。
  • 多分类问题中的解释:

    • 在多分类问题中,只有正确的真实标签会对损失产生影响,因为其他标签为零不会对损失函数增加任何值。
    • 较低的损失表示模型为正确类别分配了较高的概率,而为错误类别分配了较低的概率。

如果觉得这篇文章有用,就给个 👍和收藏⭐️吧!也欢迎在评论区分享你的看法!


参考

相关推荐
LeMay0838 分钟前
基础算法——排序算法(冒泡排序,选择排序,堆排序,插入排序,希尔排序,归并排序,快速排序,计数排序,桶排序,基数排序,Java排序)
java·算法·排序算法
花千树-0102 小时前
Milvus - GPU 索引类型及其应用场景
运维·人工智能·aigc·embedding·ai编程·milvus
微雨盈萍cbb2 小时前
OCR与PaddleOCR介绍
人工智能
战国2 小时前
卫星授时服务器,单北斗授时服务器,北斗卫星时钟服务器
服务器·网络·测试工具·分类
DogDaoDao2 小时前
深度学习常用开源数据集介绍【持续更新】
图像处理·人工智能·深度学习·ai·数据集
lqqjuly2 小时前
深度学习基础知识-编解码结构理论超详细讲解
人工智能·深度学习·编解码结构
迅为电子2 小时前
迅为RK3588开发板Android多屏显示之多屏同显和多屏异显
人工智能·rk3588·多屏显示
chaplinthink2 小时前
DB GPT本地安装部署
ai·dbgpt
2 小时前
在函数 \( f(x+1) = x^2 + 1 \) 中,\( x \) 和 \( x+1 \) 代表不同的概念
学习·机器学习