Softmax 与 交叉熵损失

本文档详细解析了多分类任务中 Softmax交叉熵损失 (Cross-Entropy Loss) 的数学原理、配合机制以及工程实现中的数值稳定性问题。


1. 核心理论:从 Logits 到 Loss

假设神经网络最后一层的原始输出(未归一化)为向量 <math xmlns="http://www.w3.org/1998/Math/MathML"> o \mathbf{o} </math>o,其中第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 个节点的输出记为 <math xmlns="http://www.w3.org/1998/Math/MathML"> o i o_i </math>oi(工程上常称为 Logits)。

1.1 Softmax 函数:归一化

Softmax 的核心作用是将任意实数域的输出 <math xmlns="http://www.w3.org/1998/Math/MathML"> o i o_i </math>oi 映射为概率分布 <math xmlns="http://www.w3.org/1998/Math/MathML"> y i y_i </math>yi。

公式: <math xmlns="http://www.w3.org/1998/Math/MathML"> y i = Softmax ( o i ) = e o i ∑ j = 1 C e o j y_i = \text{Softmax}(o_i) = \frac{e^{o_i}}{\sum_{j=1}^{C} e^{o_j}} </math>yi=Softmax(oi)=∑j=1Ceojeoi

  • 非负性 : <math xmlns="http://www.w3.org/1998/Math/MathML"> e o i > 0 e^{o_i} > 0 </math>eoi>0,保证概率为正。
  • 归一化 : <math xmlns="http://www.w3.org/1998/Math/MathML"> ∑ y i = 1 \sum y_i = 1 </math>∑yi=1,符合概率定义。
  • 差异放大:指数函数会拉大数值之间的差距,使"强者恒强"。

1.2 交叉熵损失 (Cross-Entropy Loss)

交叉熵用于衡量"预测概率分布"与"真实标签分布"之间的差异。

公式: <math xmlns="http://www.w3.org/1998/Math/MathML"> L = − ∑ i = 1 C t i ⋅ log ⁡ ( y i ) L = - \sum_{i=1}^{C} t_i \cdot \log(y_i) </math>L=−∑i=1Cti⋅log(yi)

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> t t </math>t 是真实标签的 One-hot 编码向量(正确类别为 1,其余为 0)。因此,对于单个样本,公式简化为:

<math xmlns="http://www.w3.org/1998/Math/MathML"> L = − log ⁡ ( y correct ) L = - \log(y_{\text{correct}}) </math>L=−log(ycorrect)

即:我们只关心模型对"正确类别"预测了多大的概率。

1.3 为什么它们是"黄金搭档"?

当我们将 Softmax 和 Cross-Entropy 结合在一起对 <math xmlns="http://www.w3.org/1998/Math/MathML"> o i o_i </math>oi 求导时,会得到非常优雅的梯度形式:

<math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ o i = y i − t i \frac{\partial L}{\partial o_i} = y_i - t_i </math>∂oi∂L=yi−ti

物理含义: 梯度等于 (预测概率) - (真实标签)

  • 这种线性的梯度特性避免了均方误差(MSE)在分类任务中可能遇到的梯度消失问题。
  • 误差越大,梯度越大,模型参数更新越快。

2. 工程挑战:数值稳定性 (Numerical Stability)

在实际工程落地时,直接按照数学公式计算 Softmax 会遇到严重问题。

2.1 上溢问题

指数函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> e x e^x </math>ex 增长极快。在标准的 float32 浮点数系统中:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> e 100 ≈ 2.6 × 1 0 43 e^{100} \approx 2.6 \times 10^{43} </math>e100≈2.6×1043
  • 若 <math xmlns="http://www.w3.org/1998/Math/MathML"> o i > 88 o_i > 88 </math>oi>88,则 <math xmlns="http://www.w3.org/1998/Math/MathML"> e o i → inf e^{o_i} \to \text{inf} </math>eoi→inf (无穷大)。

一旦分子或分母出现 inf,计算结果就会变成 NaN (Not a Number),导致训练崩溃。

2.2 解决方案:减去最大值 (The Max Trick)

利用 Softmax 的 平移不变性 ,我们在分子分母的指数中同时减去输入向量的最大值 <math xmlns="http://www.w3.org/1998/Math/MathML"> M = max ⁡ ( o ) M = \max(\mathbf{o}) </math>M=max(o)。

推导: <math xmlns="http://www.w3.org/1998/Math/MathML"> Softmax ( o i ) = e o i ∑ e o j = e o i ⋅ e − M ( ∑ e o j ) ⋅ e − M = e o i − M ∑ e o j − M \text{Softmax}(o_i) = \frac{e^{o_i}}{\sum e^{o_j}} = \frac{e^{o_i} \cdot e^{-M}}{(\sum e^{o_j}) \cdot e^{-M}} = \frac{e^{o_i - M}}{\sum e^{o_j - M}} </math>Softmax(oi)=∑eojeoi=(∑eoj)⋅e−Meoi⋅e−M=∑eoj−Meoi−M

优势:

  1. 最大的指数项变为 <math xmlns="http://www.w3.org/1998/Math/MathML"> e M − M = e 0 = 1 e^{M-M} = e^0 = 1 </math>eM−M=e0=1。
  2. 其余所有项的指数部分均为负数或零,结果在 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( 0 , 1 ] (0, 1] </math>(0,1] 之间。
  3. 彻底解决了上溢问题。

2.3 下溢问题

即便解决了上溢,如果在计算 Loss 时先算 Softmax 再算 Log,还可能遇到 下溢

如果某个类别的 <math xmlns="http://www.w3.org/1998/Math/MathML"> o i o_i </math>oi 非常小(负绝对值很大),经过 Softmax 后 <math xmlns="http://www.w3.org/1998/Math/MathML"> y i y_i </math>yi 可能会极其接近 0。 在浮点数精度受限的情况下,计算机可能直接将 <math xmlns="http://www.w3.org/1998/Math/MathML"> y i y_i </math>yi 截断为 0 。 随后计算 Loss 时: <math xmlns="http://www.w3.org/1998/Math/MathML"> L = − log ⁡ ( 0 ) → inf L = -\log(0) \to \text{inf} </math>L=−log(0)→inf 这会导致训练梯度爆炸或 Loss 变为无穷大。

2.4 解决方案:Log-Sum-Exp

核心思想: 不要分步计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> log ⁡ ( Softmax ) \log(\text{Softmax}) </math>log(Softmax),而是将其合并推导,转化为一个原子操作。

数学推导:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> log ⁡ ( Softmax ( o i ) ) = log ⁡ ( e o i ∑ j e o j ) = log ⁡ ( e o i ) − log ⁡ ( ∑ j e o j ) = o i − log ⁡ ( ∑ j e o j ) \begin{aligned} \log(\text{Softmax}(o_i)) &= \log\left( \frac{e^{o_i}}{\sum_{j} e^{o_j}} \right) \\ &= \log(e^{o_i}) - \log\left( \sum_{j} e^{o_j} \right) \\ &= o_i - \log\left( \sum_{j} e^{o_j} \right) \end{aligned} </math>log(Softmax(oi))=log(∑jeojeoi)=log(eoi)−log(j∑eoj)=oi−log(j∑eoj)

这里的 Log-Sum-Exp 部分: <math xmlns="http://www.w3.org/1998/Math/MathML"> log ⁡ ( ∑ e o j ) \log(\sum e^{o_j}) </math>log(∑eoj) 同样再次使用 Max Trick 来保证这一步的稳定性: <math xmlns="http://www.w3.org/1998/Math/MathML"> log ⁡ ( ∑ j e o j ) = log ⁡ ( ∑ j e o j − M ⋅ e M ) = M + log ⁡ ( ∑ j e o j − M ) \log\left( \sum_{j} e^{o_j} \right) = \log\left( \sum_{j} e^{o_j - M} \cdot e^M \right) = M + \log\left( \sum_{j} e^{o_j - M} \right) </math>log(∑jeoj)=log(∑jeoj−M⋅eM)=M+log(∑jeoj−M)

2.5 最终结论

通过合并计算,我们完全避免了直接计算概率 <math xmlns="http://www.w3.org/1998/Math/MathML"> y i y_i </math>yi 这一步,而是直接通过 Logits 计算 Log-Probability。 公式变为: <math xmlns="http://www.w3.org/1998/Math/MathML"> LogSoftmax ( o i ) = o i − M − log ⁡ ( ∑ e o j − M ) \text{LogSoftmax}(o_i) = o_i - M - \log\left( \sum e^{o_j - M} \right) </math>LogSoftmax(oi)=oi−M−log(∑eoj−M) 在这个公式中,所有中间数值都被限制在安全范围内,既不会上溢也不会下溢。


相关推荐
似水এ᭄往昔2 小时前
【C++】--AVL树的认识和实现
开发语言·数据结构·c++·算法·stl
栀秋6662 小时前
“无重复字符的最长子串”:从O(n²)哈希优化到滑动窗口封神,再到DP降维打击!
前端·javascript·算法
xhxxx2 小时前
不用 Set,只用两个布尔值:如何用标志位将矩阵置零的空间复杂度压到 O(1)
javascript·算法·面试
有意义2 小时前
斐波那契数列:从递归到优化的完整指南
javascript·算法·面试
charlie1145141913 小时前
编写INI Parser 测试完整指南 - 从零开始
开发语言·c++·笔记·学习·算法·单元测试·测试
mmz12073 小时前
前缀和问题2(c++)
c++·算法
TL滕3 小时前
从0开始学算法——第十六天(双指针算法)
数据结构·笔记·学习·算法
蒲小英3 小时前
算法-贪心算法
算法·贪心算法
mit6.8243 小时前
链式投票|流向贪心
算法