Softmax 与 交叉熵损失

本文档详细解析了多分类任务中 Softmax交叉熵损失 (Cross-Entropy Loss) 的数学原理、配合机制以及工程实现中的数值稳定性问题。


1. 核心理论:从 Logits 到 Loss

假设神经网络最后一层的原始输出(未归一化)为向量 <math xmlns="http://www.w3.org/1998/Math/MathML"> o \mathbf{o} </math>o,其中第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 个节点的输出记为 <math xmlns="http://www.w3.org/1998/Math/MathML"> o i o_i </math>oi(工程上常称为 Logits)。

1.1 Softmax 函数:归一化

Softmax 的核心作用是将任意实数域的输出 <math xmlns="http://www.w3.org/1998/Math/MathML"> o i o_i </math>oi 映射为概率分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> y i y_i </math>yi。

公式: <math xmlns="http://www.w3.org/1998/Math/MathML"> y i = Softmax ( o i ) = e o i ∑ j = 1 C e o j y_i = \text{Softmax}(o_i) = \frac{e^{o_i}}{\sum_{j=1}^{C} e^{o_j}} </math>yi=Softmax(oi)=∑j=1Ceojeoi

  • 非负性 : <math xmlns="http://www.w3.org/1998/Math/MathML"> e o i > 0 e^{o_i} > 0 </math>eoi>0,保证概率为正。
  • 归一化 : <math xmlns="http://www.w3.org/1998/Math/MathML"> ∑ y i = 1 \sum y_i = 1 </math>∑yi=1,符合概率定义。
  • 差异放大:指数函数会拉大数值之间的差距,使"强者恒强"。

1.2 交叉熵损失 (Cross-Entropy Loss)

交叉熵用于衡量"预测概率分布"与"真实标签分布"之间的差异。

公式: <math xmlns="http://www.w3.org/1998/Math/MathML"> L = − ∑ i = 1 C t i ⋅ log ⁡ ( y i ) L = - \sum_{i=1}^{C} t_i \cdot \log(y_i) </math>L=−∑i=1Cti⋅log(yi)

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t 是真实标签的 One-hot 编码向量(正确类别为 1,其余为 0)。因此,对于单个样本,公式简化为:

<math xmlns="http://www.w3.org/1998/Math/MathML"> L = − log ⁡ ( y correct ) L = - \log(y_{\text{correct}}) </math>L=−log(ycorrect)

即:我们只关心模型对"正确类别"预测了多大的概率。

1.3 为什么它们是"黄金搭档"?

当我们将 Softmax 和 Cross-Entropy 结合在一起对 <math xmlns="http://www.w3.org/1998/Math/MathML"> o i o_i </math>oi 求导时,会得到非常优雅的梯度形式:

<math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ o i = y i − t i \frac{\partial L}{\partial o_i} = y_i - t_i </math>∂oi∂L=yi−ti

物理含义: 梯度等于 (预测概率) - (真实标签)

  • 这种线性的梯度特性避免了均方误差(MSE)在分类任务中可能遇到的梯度消失问题。
  • 误差越大,梯度越大,模型参数更新越快。

2. 工程挑战:数值稳定性 (Numerical Stability)

在实际工程落地时,直接按照数学公式计算 Softmax 会遇到严重问题。

2.1 上溢问题

指数函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> e x e^x </math>ex 增长极快。在标准的 float32 浮点数系统中:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> e 100 ≈ 2.6 × 1 0 43 e^{100} \approx 2.6 \times 10^{43} </math>e100≈2.6×1043
  • 若 <math xmlns="http://www.w3.org/1998/Math/MathML"> o i > 88 o_i > 88 </math>oi>88,则 <math xmlns="http://www.w3.org/1998/Math/MathML"> e o i → inf e^{o_i} \to \text{inf} </math>eoi→inf (无穷大)。

一旦分子或分母出现 inf,计算结果就会变成 NaN (Not a Number),导致训练崩溃。

2.2 解决方案:减去最大值 (The Max Trick)

利用 Softmax 的 平移不变性 ,我们在分子分母的指数中同时减去输入向量的最大值 <math xmlns="http://www.w3.org/1998/Math/MathML"> M = max ⁡ ( o ) M = \max(\mathbf{o}) </math>M=max(o)。

推导: <math xmlns="http://www.w3.org/1998/Math/MathML"> Softmax ( o i ) = e o i ∑ e o j = e o i ⋅ e − M ( ∑ e o j ) ⋅ e − M = e o i − M ∑ e o j − M \text{Softmax}(o_i) = \frac{e^{o_i}}{\sum e^{o_j}} = \frac{e^{o_i} \cdot e^{-M}}{(\sum e^{o_j}) \cdot e^{-M}} = \frac{e^{o_i - M}}{\sum e^{o_j - M}} </math>Softmax(oi)=∑eojeoi=(∑eoj)⋅e−Meoi⋅e−M=∑eoj−Meoi−M

优势:

  1. 最大的指数项变为 <math xmlns="http://www.w3.org/1998/Math/MathML"> e M − M = e 0 = 1 e^{M-M} = e^0 = 1 </math>eM−M=e0=1。
  2. 其余所有项的指数部分均为负数或零,结果在 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( 0 , 1 ] (0, 1] </math>(0,1] 之间。
  3. 彻底解决了上溢问题。

2.3 下溢问题

即便解决了上溢,如果在计算 Loss 时先算 Softmax 再算 Log,还可能遇到 下溢

如果某个类别的 <math xmlns="http://www.w3.org/1998/Math/MathML"> o i o_i </math>oi 非常小(负绝对值很大),经过 Softmax 后 <math xmlns="http://www.w3.org/1998/Math/MathML"> y i y_i </math>yi 可能会极其接近 0。 在浮点数精度受限的情况下,计算机可能直接将 <math xmlns="http://www.w3.org/1998/Math/MathML"> y i y_i </math>yi 截断为 0 。 随后计算 Loss 时: <math xmlns="http://www.w3.org/1998/Math/MathML"> L = − log ⁡ ( 0 ) → inf L = -\log(0) \to \text{inf} </math>L=−log(0)→inf 这会导致训练梯度爆炸或 Loss 变为无穷大。

2.4 解决方案:Log-Sum-Exp

核心思想: 不要分步计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> log ⁡ ( Softmax ) \log(\text{Softmax}) </math>log(Softmax),而是将其合并推导,转化为一个原子操作。

数学推导:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> log ⁡ ( Softmax ( o i ) ) = log ⁡ ( e o i ∑ j e o j ) = log ⁡ ( e o i ) − log ⁡ ( ∑ j e o j ) = o i − log ⁡ ( ∑ j e o j ) \begin{aligned} \log(\text{Softmax}(o_i)) &= \log\left( \frac{e^{o_i}}{\sum_{j} e^{o_j}} \right) \\ &= \log(e^{o_i}) - \log\left( \sum_{j} e^{o_j} \right) \\ &= o_i - \log\left( \sum_{j} e^{o_j} \right) \end{aligned} </math>log(Softmax(oi))=log(∑jeojeoi)=log(eoi)−log(j∑eoj)=oi−log(j∑eoj)

这里的 Log-Sum-Exp 部分: <math xmlns="http://www.w3.org/1998/Math/MathML"> log ⁡ ( ∑ e o j ) \log(\sum e^{o_j}) </math>log(∑eoj) 同样再次使用 Max Trick 来保证这一步的稳定性: <math xmlns="http://www.w3.org/1998/Math/MathML"> log ⁡ ( ∑ j e o j ) = log ⁡ ( ∑ j e o j − M ⋅ e M ) = M + log ⁡ ( ∑ j e o j − M ) \log\left( \sum_{j} e^{o_j} \right) = \log\left( \sum_{j} e^{o_j - M} \cdot e^M \right) = M + \log\left( \sum_{j} e^{o_j - M} \right) </math>log(∑jeoj)=log(∑jeoj−M⋅eM)=M+log(∑jeoj−M)

2.5 最终结论

通过合并计算,我们完全避免了直接计算概率 <math xmlns="http://www.w3.org/1998/Math/MathML"> y i y_i </math>yi 这一步,而是直接通过 Logits 计算 Log-Probability。 公式变为: <math xmlns="http://www.w3.org/1998/Math/MathML"> LogSoftmax ( o i ) = o i − M − log ⁡ ( ∑ e o j − M ) \text{LogSoftmax}(o_i) = o_i - M - \log\left( \sum e^{o_j - M} \right) </math>LogSoftmax(oi)=oi−M−log(∑eoj−M) 在这个公式中,所有中间数值都被限制在安全范围内,既不会上溢也不会下溢。


相关推荐
MicroTech20254 分钟前
微算法科技(NASDAQ :MLGO)抗量子区块链技术:筑牢量子时代的数字安全防线
科技·算法·区块链
Ivanqhz6 分钟前
图着色寄存器分配算法(Graph Coloring)
开发语言·javascript·python·算法·蓝桥杯·rust
Elsa️7468 分钟前
洛谷p5718 复习下快速排序和堆排序
数据结构·算法·排序算法
Frostnova丶11 分钟前
LeetCode 3567.子矩阵的最小绝对差
算法·leetcode·矩阵
夏日听雨眠12 分钟前
文件学习9
数据结构·学习·算法
华农DrLai12 分钟前
什么是自动Prompt优化?为什么需要算法来寻找最佳提示词?
人工智能·算法·llm·nlp·prompt·llama
黎阳之光13 分钟前
十五五智赋新程 黎阳之光以AI硬核技术筑造产业数智底座
大数据·人工智能·算法·安全·数字孪生
2401_8914821714 分钟前
C++中的原型模式
开发语言·c++·算法
皙然15 分钟前
深度解析三色标记算法:JVM 并发 GC 的核心底层逻辑
java·jvm·算法
sali-tec17 分钟前
C# 基于OpenCv的视觉工作流-章40-特征找图
图像处理·人工智能·opencv·算法·计算机视觉