二分类交叉熵与多分类交叉熵详解及实例计算

二分类交叉熵与多分类交叉熵详解及实例计算



文章目录


引言

在机器学习和深度学习中,交叉熵损失函数是衡量分类模型预测概率与真实标签之间差异的一种重要方法。本文将详细介绍二分类交叉熵和多分类交叉熵的定义、应用场景以及如何结合激活函数进行实例计算。

二分类交叉熵

定义

二分类交叉熵(Binary Cross Entropy, BCE)通常用于只有两个类别的分类问题。它的目的是最小化模型预测概率与实际标签之间的差异。

应用场景

假设我们正在构建一个模型来预测电子邮件是否为垃圾邮件(spam)。这里有两个类别:"是垃圾邮件"(标记为1)和"不是垃圾邮件"(标记为0)。

输出格式

  • 对于每个样本,模型输出一个介于0到1之间的实数,表示属于正类的概率。

标签格式

  • 标签是单个值,通常为0或1,其中0表示负类,1表示正类。

激活函数

  • 输出层通常使用Sigmoid激活函数,因为它可以将输出压缩到[0, 1]区间内,表示概率。

Sigmoid激活函数

Sigmoid函数的定义如下:
σ ( z ) = 1 1 + e − z \sigma(z) = \frac{1}{1 + e^{-z}} σ(z)=1+e−z1

其中 z z z 是模型的原始输出。

公式

BCE ( y , y ^ ) = − [ y log ⁡ ( y ^ ) + ( 1 − y ) log ⁡ ( 1 − y ^ ) ] \text{BCE}(y, \hat{y}) = -[y \log(\hat{y}) + (1-y) \log(1-\hat{y})] BCE(y,y^)=−[ylog(y^)+(1−y)log(1−y^)]

  • 其中 y y y 是实际标签(0 或 1), y ^ \hat{y} y^ 是模型预测的概率。

实例计算

假设模型的原始输出为1.5,经过Sigmoid激活函数后得到的预测概率为0.8,即预测该邮件为垃圾邮件的概率为80%。如果实际标签是1(确实是垃圾邮件),则二分类交叉熵损失计算如下:

  1. 计算Sigmoid激活函数的输出:
    y ^ = σ ( 1.5 ) = 1 1 + e − 1.5 ≈ 0.81757 \hat{y} = \sigma(1.5) = \frac{1}{1 + e^{-1.5}} \approx 0.81757 y^=σ(1.5)=1+e−1.51≈0.81757

  2. 计算二分类交叉熵损失:
    BCE ( y , y ^ ) = − [ 1 ⋅ log ⁡ ( 0.81757 ) + ( 1 − 1 ) ⋅ log ⁡ ( 1 − 0.81757 ) ] ≈ − log ⁡ ( 0.81757 ) ≈ 0.20136 \text{BCE}(y, \hat{y}) = -[1 \cdot \log(0.81757) + (1-1) \cdot \log(1-0.81757)] \approx -\log(0.81757) \approx 0.20136 BCE(y,y^)=−[1⋅log(0.81757)+(1−1)⋅log(1−0.81757)]≈−log(0.81757)≈0.20136

多分类交叉熵

定义

多分类交叉熵(Categorical Cross Entropy, CE)适用于三个或更多类别的分类任务。它的目标也是最小化模型预测概率与实际标签之间的差异。

应用场景

假设我们正在构建一个图像分类器,它可以识别三种不同的动物:猫、狗和兔子。

输出格式

  • 对于每个样本,模型输出一个向量,表示每个类别的预测概率。

标签格式

  • 标签是一组one-hot编码向量,其中只有一个元素为1,表示实际类别,其余元素均为0。

激活函数

  • 输出层通常使用Softmax激活函数,它将输出转换为概率分布,所有类别概率之和为1。

Softmax激活函数

Softmax函数的定义如下:
softmax ( z ) i = e z i ∑ j = 1 C e z j \text{softmax}(z)i = \frac{e^{z_i}}{\sum{j=1}^{C} e^{z_j}} softmax(z)i=∑j=1Cezjezi

其中 z z z 是模型的原始输出向量, i i i 表示类别索引。

公式

CE ( y , y ^ ) = − ∑ c = 1 C y c log ⁡ ( y ^ c ) \text{CE}(y, \hat{y}) = -\sum_{c=1}^{C} y_c \log(\hat{y}_c) CE(y,y^)=−c=1∑Cyclog(y^c)

  • 其中 C C C 是类别的数量, y c y_c yc 是实际标签的one-hot编码向量, y ^ c \hat{y}_c y^c 是模型预测的概率向量。

实例计算

假设模型的原始输出向量为 [ 1.0 , 2.0 , 0.5 ] [1.0, 2.0, 0.5] [1.0,2.0,0.5],经过Softmax激活函数后得到的预测概率向量为 [ 0.2 , 0.5 , 0.3 ] [0.2, 0.5, 0.3] [0.2,0.5,0.3],即预测这张图片为猫的概率为20%,为狗的概率为50%,为兔子的概率为30%。如果实际标签是狗(one-hot编码为[0, 1, 0]),则多分类交叉熵损失计算如下:

  1. 计算Softmax激活函数的输出:
    y ^ = softmax ( [ 1.0 , 2.0 , 0.5 ] ) \hat{y} = \text{softmax}([1.0, 2.0, 0.5]) y^=softmax([1.0,2.0,0.5])
    y ^ = [ e 1.0 e 1.0 + e 2.0 + e 0.5 , e 2.0 e 1.0 + e 2.0 + e 0.5 , e 0.5 e 1.0 + e 2.0 + e 0.5 ] ≈ [ 0.24473 , 0.53966 , 0.21561 ] \hat{y} = \left[\frac{e^{1.0}}{e^{1.0} + e^{2.0} + e^{0.5}}, \frac{e^{2.0}}{e^{1.0} + e^{2.0} + e^{0.5}}, \frac{e^{0.5}}{e^{1.0} + e^{2.0} + e^{0.5}}\right] \approx [0.24473, 0.53966, 0.21561] y^=[e1.0+e2.0+e0.5e1.0,e1.0+e2.0+e0.5e2.0,e1.0+e2.0+e0.5e0.5]≈[0.24473,0.53966,0.21561]

  2. 计算多分类交叉熵损失:
    CE ( y , y ^ ) = − [ 0 ⋅ log ⁡ ( 0.24473 ) + 1 ⋅ log ⁡ ( 0.53966 ) + 0 ⋅ log ⁡ ( 0.21561 ) ] = − log ⁡ ( 0.53966 ) ≈ 0.61886 \text{CE}(y, \hat{y}) = -\left[ 0 \cdot \log(0.24473) + 1 \cdot \log(0.53966) + 0 \cdot \log(0.21561) \right] = -\log(0.53966) \approx 0.61886 CE(y,y^)=−[0⋅log(0.24473)+1⋅log(0.53966)+0⋅log(0.21561)]=−log(0.53966)≈0.61886

小结

  • 二分类交叉熵常用于只有两个类别的分类问题,输出层使用Sigmoid激活函数。
  • 多分类交叉熵用于三个或更多类别的分类问题,输出层使用Softmax激活函数。
  • 输出格式:二分类交叉熵通常接收一个标量作为输入,而多分类交叉熵接收一个向量。
  • 标签格式:二分类中标签通常是单个值(0或1),而在多分类中标签是one-hot编码向量。
  • 应用场景:二分类适用于只有两类的情况,多分类适用于三个或更多
相关推荐
m0_6090004213 分钟前
向日葵好用吗?4款稳定的远程控制软件推荐。
运维·服务器·网络·人工智能·远程工作
开MINI的工科男1 小时前
深蓝学院-- 量产自动驾驶中的规划控制算法 小鹏
人工智能·机器学习·自动驾驶
AI大模型知识分享2 小时前
Prompt最佳实践|如何用参考文本让ChatGPT答案更精准?
人工智能·深度学习·机器学习·chatgpt·prompt·gpt-3
张人玉4 小时前
人工智能——猴子摘香蕉问题
人工智能
草莓屁屁我不吃4 小时前
Siri因ChatGPT-4o升级:我们的个人信息还安全吗?
人工智能·安全·chatgpt·chatgpt-4o
小言从不摸鱼4 小时前
【AI大模型】ChatGPT模型原理介绍(下)
人工智能·python·深度学习·机器学习·自然语言处理·chatgpt
AI科研视界5 小时前
ChatGPT+2:修订初始AI安全性和超级智能假设
人工智能·chatgpt
霍格沃兹测试开发学社测试人社区5 小时前
人工智能 | 基于ChatGPT开发人工智能服务平台
软件测试·人工智能·测试开发·chatgpt
小R资源5 小时前
3款免费的GPT类工具
人工智能·gpt·chatgpt·ai作画·ai模型·国内免费
artificiali8 小时前
Anaconda配置pytorch的基本操作
人工智能·pytorch·python