交叉熵在机器学习中的应用解析

文章目录

交叉熵(Cross-Entropy)是信息论和机器学习中的一个重要概念,用于衡量两个概率分布之间的差异。它在分类任务(如逻辑回归、神经网络)中常作为损失函数使用。

核心概念

香农信息量(自信息)

对于一个具有概率 P ( x ) P(x) P(x) 的事件 x x x,其信息量 I ( x ) I(x) I(x) 定义为:
I ( x ) = − log ⁡ b P ( x ) I(x) = -\log_b P(x) I(x)=−logbP(x)

其中:

  • log ⁡ b \log_b logb 是以 b b b 为底的对数,常用的底数有:
    • b = 2 b = 2 b=2:信息量单位为比特(bit)
    • b = e b = e b=e:信息量单位为奈特(nat)
    • b = 10 b = 10 b=10:信息量单位为哈特(hart)
  • 信息量 I ( x ) I(x) I(x) 表示事件 x x x 发生时所携带的信息的多少,概率越低的事件信息量越大。

熵(Entropy)

熵(平均信息量)

熵是随机变量不确定性的度量,定义为信息量的期望:
H ( X ) = − ∑ x ∈ X P ( x ) log ⁡ b P ( x ) H(X) = -\sum_{x \in X} P(x) \log_b P(x) H(X)=−x∈X∑P(x)logbP(x)

对于连续随机变量,熵可以表示为:
H ( X ) = − ∫ − ∞ ∞ p ( x ) log ⁡ b p ( x )   d x H(X) = -\int_{-\infty}^{\infty} p(x) \log_b p(x) \, dx H(X)=−∫−∞∞p(x)logbp(x)dx

其中 ( p(x) ) 是概率密度函数。

表示一个概率分布自身的不确定性。对于离散分布 P P P,熵定义为:
H ( P ) = − ∑ i P ( x i ) log ⁡ P ( x i ) H(P) = -\sum_{i} P(x_i) \log P(x_i) H(P)=−i∑P(xi)logP(xi)

  • 熵越大,不确定性越高。

KL散度(Kullback-Leibler Divergence)

衡量两个分布 P P P(真实分布)和 Q Q Q(预测分布)的差异:
D K L ( P ∥ Q ) = ∑ i P ( x i ) log ⁡ P ( x i ) Q ( x i ) D_{KL}(P \| Q) = \sum_{i} P(x_i) \log \frac{P(x_i)}{Q(x_i)} DKL(P∥Q)=i∑P(xi)logQ(xi)P(xi)

  • KL散度非负,且不对称。
  • 当 P = Q P = Q P=Q 时,交叉熵最小,等于 P P P 的熵。

交叉熵

交叉熵是熵与KL散度的组合:
H ( P , Q ) = H ( P ) + D K L ( P ∥ Q ) = − ∑ i P ( x i ) log ⁡ Q ( x i ) H(P, Q) = H(P) + D_{KL}(P \| Q) = -\sum_{i} P(x_i) \log Q(x_i) H(P,Q)=H(P)+DKL(P∥Q)=−i∑P(xi)logQ(xi)

  • 当 P P P 是真实分布(如one-hot标签), Q Q Q 是模型预测时,最小化交叉熵等价于最小化KL散度。

在机器学习中的应用

作为损失函数

对于二分类(Binary Classification):
  • 公式
    L = − 1 N ∑ i = 1 N [ y i log ⁡ ( p i ) + ( 1 − y i ) log ⁡ ( 1 − p i ) ] L = -\frac{1}{N} \sum_{i=1}^N \left[ y_i \log(p_i) + (1-y_i) \log(1-p_i) \right] L=−N1i=1∑N[yilog(pi)+(1−yi)log(1−pi)]
    其中 y i ∈ { 0 , 1 } y_i \in \{0,1\} yi∈{0,1} 是真实标签, p i p_i pi 是模型预测为正类的概率。
  • 场景
    逻辑回归、神经网络二分类输出层(如Sigmoid激活)。
对于多分类(Multiclass Classification):
  • 公式 (分类交叉熵,Categorical Cross-Entropy)
    L = − 1 N ∑ i = 1 N ∑ c = 1 C y i , c log ⁡ ( p i , c ) L = -\frac{1}{N} \sum_{i=1}^N \sum_{c=1}^C y_{i,c} \log(p_{i,c}) L=−N1i=1∑Nc=1∑Cyi,clog(pi,c)
    • y i , c y_{i,c} yi,c:样本 i i i 属于类别 c c c 的真实标签(one-hot编码)。
    • p i , c p_{i,c} pi,c:模型预测样本 i i i 属于类别 c c c 的概率。
  • 场景
    Softmax输出层配合交叉熵(如ResNet、Transformer的分类头)。
多标签分类(Multi-label Classification)
  • 特点 :每个样本可能属于多个类别,使用二元交叉熵对每个类别独立计算损失。
  • 公式
    L = − 1 N ∑ i = 1 N ∑ c = 1 C [ y i , c log ⁡ ( p i , c ) + ( 1 − y i , c ) log ⁡ ( 1 − p i , c ) ] L = -\frac{1}{N} \sum_{i=1}^N \sum_{c=1}^C \left[ y_{i,c} \log(p_{i,c}) + (1-y_{i,c}) \log(1-p_{i,c}) \right] L=−N1i=1∑Nc=1∑C[yi,clog(pi,c)+(1−yi,c)log(1−pi,c)]

其他应用场景

  • 生成模型:GAN中判别器的损失函数常使用交叉熵衡量真实/生成分布的差异。
  • 语言模型:预测下一个词的概率分布(如BERT、GPT的预训练目标)。
  • 强化学习:策略梯度方法中优化策略分布与最优分布的交叉熵。

实例

手撸计算

假设真实分布 P = [ 1 , 0 ] P = [1, 0] P=[1,0](类别1),模型预测 Q = [ 0.8 , 0.2 ] Q = [0.8, 0.2] Q=[0.8,0.2]:
H ( P , Q ) = − 1 ⋅ log ⁡ ( 0.8 ) − 0 ⋅ log ⁡ ( 0.2 ) ≈ 0.223 H(P, Q) = -1 \cdot \log(0.8) - 0 \cdot \log(0.2) \approx 0.223 H(P,Q)=−1⋅log(0.8)−0⋅log(0.2)≈0.223

若预测更差(如 Q = \[0.3, 0.7\] ):
H ( P , Q ) = − 1 ⋅ log ⁡ ( 0.3 ) ≈ 1.203 H(P, Q) = -1 \cdot \log(0.3) \approx 1.203 H(P,Q)=−1⋅log(0.3)≈1.203

实现示例(PyTorch)
python 复制代码
import torch.nn as nn

# 二分类
loss_fn = nn.BCELoss()  # 需手动Sigmoid
loss_fn = nn.BCEWithLogitsLoss()  # 内置Sigmoid

# 多分类
loss_fn = nn.CrossEntropyLoss()  # 输入为logits(无需Softmax)
注意事项
  • 数值稳定性 :计算 log ⁡ ( p ) \log(p) log(p)时可能溢出,通常框架会自动处理(如添加微小偏移 ϵ \epsilon ϵ)。
  • 概率归一化:确保模型输出符合概率分布(如通过Softmax或Sigmoid)。

直观解释

  • 当预测概率 Q Q Q 与真实分布 P P P 一致时,交叉熵最小(等于 P P P 的熵)。
  • 预测越偏离真实,交叉熵越大。

为什么用交叉熵?

  • 梯度友好性
    • 对于Softmax输出,交叉熵的梯度为 ∂ L ∂ z i = p i − y i \frac{\partial L}{\partial z_i} = p_i - y_i ∂zi∂L=pi−yi,避免了均方误差(MSE)的梯度消失问题(当 p i p_i pi接近0或1时,MSE梯度极小)。
  • 概率解释:直接优化模型输出的概率分布与真实分布的差异,与最大似然估计(MLE)等价。天然适配分类任务的概率输出。
  • 处理不平衡数据:可通过加权交叉熵(Weighted Cross-Entropy)调整类别权重。

变体与改进

  • 标签平滑(Label Smoothing) :防止模型对标签过度自信,将真实标签从1调整为 1 − ϵ 1-\epsilon 1−ϵ,其余类别分配 ϵ / ( C − 1 ) \epsilon/(C-1) ϵ/(C−1)。
  • Focal Loss :解决类别不平衡问题,降低易分类样本的权重:
    L = − α t ( 1 − p t ) γ log ⁡ ( p t ) L = -\alpha_t (1-p_t)^\gamma \log(p_t) L=−αt(1−pt)γlog(pt)
    ( γ \gamma γ 为调节因子, α t \alpha_t αt 为类别权重)。

理解交叉熵的关键是掌握其与熵、KL散度的关系,以及如何通过最小化它来使模型逼近真实分布。

相关推荐
@Mr_LiuYang1 个月前
Rethinking BiSeNet For Real-time Semantic Segmentation细节损失函数学习
语义分割·损失函数·边界提取·边界损失
xidianjiapei0011 个月前
一文读懂深度学习中的损失函数quantifying loss —— 作用、分类和示例代码
人工智能·深度学习·分类·损失函数·交叉熵
丶21363 个月前
【分类】【损失函数】处理类别不平衡:CEFL 和 CEFL2 损失函数的实现与应用
人工智能·分类·损失函数
余胜辉3 个月前
【深度学习】交叉熵:从理论到实践
人工智能·深度学习·机器学习·损失函数·交叉熵
chencjiajy4 个月前
机器学习基础:极大似然估计与交叉熵
深度学习·机器学习·损失函数
王亭_6664 个月前
深度学习中损失函数(loss function)介绍
人工智能·pytorch·深度学习·损失函数
goomind4 个月前
深度学习常用损失函数介绍
人工智能·深度学习·损失函数
lishanlu1366 个月前
目标检测中的损失函数
目标检测·损失函数·iou损失函数·边界框回归损失
Nicolas8936 个月前
【大模型理论篇】大模型相关的周边技术分享-关于《NN and DL》的笔记
深度学习·神经网络·损失函数·参数初始化·深度学习模型训练·规范化