TensorFlow 2.0 手写数字分类教程之SparseCategoricalCrossentropy 核心原理(四)

避免Softmax溢出和log-sum-exp技巧

先明确核心前提(避免混淆)

我们要解决的问题:单独算 Softmax 时,logits 里有大数值(比如 1000),会导致 e^z 溢出(变成 inf),后续计算全错

优化目标:不单独算 Softmax,而是把「Softmax + 交叉熵」的计算合并,通过数学变换绕开溢出

先复习两个关键工具(对数运算法则,必须会,超简单):

  1. log⁡(a/b)=log⁡(a)−log⁡(b)\log(a / b) = \log(a) - \log(b)log(a/b)=log(a)−log(b)(除法变减法)
  2. log⁡(ex)=x\log(e^x) = xlog(ex)=x(对数和指数是逆运算,抵消)

第一步:代数变换------合并 Softmax 和交叉熵

我们最终需要的是「交叉熵损失」,而交叉熵的核心是 log⁡(Softmax(z正确类别))\log(\text{Softmax}(z_{\text{正确类别}}))log(Softmax(z正确类别))(因为损失 = -这个值)。

咱们先从 Softmax 的定义出发,一步步推导出合并后的公式:

步骤1:写出 Softmax 的定义(目标类别的置信度)

假设正确类别是 yyy,logits 是 z=[z0,z1,...,zC]z = [z_0, z_1, ..., z_C]z=[z0,z1,...,zC](C 是类别数),则目标类别的 Softmax 置信度为:
Softmax(zy)=ezy∑j=0Cezj \text{Softmax}(z_y) = \frac{e^{z_y}}{\sum_{j=0}^C e^{z_j}} Softmax(zy)=∑j=0Cezjezy

(分子:目标类别的指数得分;分母:所有类别的指数得分总和)

步骤2:对 Softmax 结果取 log(交叉熵需要这一步)

我们需要计算 log⁡(Softmax(zy))\log(\text{Softmax}(z_y))log(Softmax(zy)),把上面的 Softmax 代入:
log⁡(ezy∑j=0Cezj) \log\left( \frac{e^{z_y}}{\sum_{j=0}^C e^{z_j}} \right) log(∑j=0Cezjezy)

步骤3:用对数运算法则拆分(关键一步)

根据 log⁡(a/b)=log⁡(a)−log⁡(b)\log(a/b) = \log(a) - \log(b)log(a/b)=log(a)−log(b),把分子分母的 log 拆开:
log⁡(ezy)−log⁡(∑j=0Cezj) \log(e^{z_y}) - \log\left( \sum_{j=0}^C e^{z_j} \right) log(ezy)−log(j=0∑Cezj)

步骤4:简化第一项(对数和指数抵消)

根据 log⁡(ex)=x\log(e^x) = xlog(ex)=x,log⁡(ezy)=zy\log(e^{z_y}) = z_ylog(ezy)=zy,所以最终简化为:
log⁡(Softmax(zy))=zy−log⁡(∑j=0Cezj) \log(\text{Softmax}(z_y)) = z_y - \log\left( \sum_{j=0}^C e^{z_j} \right) log(Softmax(zy))=zy−log(j=0∑Cezj)

步骤5:得到合并后的损失公式

因为交叉熵损失 = −log⁡(Softmax(zy))-\log(\text{Softmax}(z_y))−log(Softmax(zy)),代入上面的结果:
loss=−(zy−log⁡(∑j=0Cezj)) \text{loss} = - \left( z_y - \log\left( \sum_{j=0}^C e^{z_j} \right) \right) loss=−(zy−log(j=0∑Cezj))

关键结论:

通过这 5 步代数变换,我们完全绕开了单独计算 Softmax 置信度 (不用先算 ezy/∑ezje^{z_y}/\sum e^{z_j}ezy/∑ezj),而是直接用 logits 计算损失------这就从根源上避免了"先算 Softmax 导致的溢出"!

用之前的例子验证(确保变换等价)

例子:logits z=[1.0,3.0,0.5]z = [1.0, 3.0, 0.5]z=[1.0,3.0,0.5](正确类别 y=1y=1y=1)

按合并公式计算:

  1. 先算分母的和:∑ezj=e1.0+e3.0+e0.5≈2.718+20.085+1.648≈24.451\sum e^{z_j} = e^{1.0} + e^{3.0} + e^{0.5} ≈ 2.718 + 20.085 + 1.648 ≈ 24.451∑ezj=e1.0+e3.0+e0.5≈2.718+20.085+1.648≈24.451
  2. 算 log⁡(∑ezj)≈log⁡(24.451)≈3.200\log(\sum e^{z_j}) ≈ \log(24.451) ≈ 3.200log(∑ezj)≈log(24.451)≈3.200
  3. 代入公式:log⁡(Softmax(z1))=3.0−3.200=−0.200\log(\text{Softmax}(z_1)) = 3.0 - 3.200 = -0.200log(Softmax(z1))=3.0−3.200=−0.200
  4. 损失 = −(−0.200)=0.200-(-0.200) = 0.200−(−0.200)=0.200(和之前单独算 Softmax 再算交叉熵的结果完全一致!)

这说明:代数变换没有改变结果,只是换了计算顺序,却避开了溢出风险

第二步:log-sum-exp 技巧------解决"分母求和仍可能溢出"的问题

上面的合并公式里,还有一个潜在风险:∑ezj\sum e^{z_j}∑ezj 中的 zjz_jzj 如果是极大值(比如 zj=1000z_j=1000zj=1000),e1000e^{1000}e1000 依然会溢出(变成 inf)。

log-sum-exp 技巧就是为了处理这个"分母求和"的溢出问题,核心思路是:给所有 logits 减去一个最大值,让指数运算的结果变小,再通过数学调整保证最终结果不变

步骤1:log-sum-exp 技巧的公式推导

目标:计算 log⁡(∑j=0Cezj)\log\left( \sum_{j=0}^C e^{z_j} \right)log(∑j=0Cezj),避免 ezje^{z_j}ezj 溢出。

设 M=max⁡(z0,z1,...,zC)M = \max(z_0, z_1, ..., z_C)M=max(z0,z1,...,zC)(所有 logits 中的最大值),对求和项做变换:
∑j=0Cezj=∑j=0Cezj−M+M=eM×∑j=0Cezj−M \sum_{j=0}^C e^{z_j} = \sum_{j=0}^C e^{z_j - M + M} = e^M \times \sum_{j=0}^C e^{z_j - M} j=0∑Cezj=j=0∑Cezj−M+M=eM×j=0∑Cezj−M

(因为 ea+b=ea×ebe^{a+b} = e^a \times e^bea+b=ea×eb,所以 ezj=eM+(zj−M)=eM×ezj−Me^{z_j} = e^{M + (z_j - M)} = e^M \times e^{z_j - M}ezj=eM+(zj−M)=eM×ezj−M)

然后对两边取 log:
log⁡(∑j=0Cezj)=log⁡(eM×∑j=0Cezj−M) \log\left( \sum_{j=0}^C e^{z_j} \right) = \log\left( e^M \times \sum_{j=0}^C e^{z_j - M} \right) log(j=0∑Cezj)=log(eM×j=0∑Cezj−M)

再用对数运算法则 log⁡(a×b)=log⁡(a)+log⁡(b)\log(a \times b) = \log(a) + \log(b)log(a×b)=log(a)+log(b) 拆分:
log⁡(eM)+log⁡(∑j=0Cezj−M) \log(e^M) + \log\left( \sum_{j=0}^C e^{z_j - M} \right) log(eM)+log(j=0∑Cezj−M)

最后简化 log⁡(eM)=M\log(e^M) = Mlog(eM)=M,得到 log-sum-exp 公式:
log⁡(∑j=0Cezj)=M+log⁡(∑j=0Cezj−M) \log\left( \sum_{j=0}^C e^{z_j} \right) = M + \log\left( \sum_{j=0}^C e^{z_j - M} \right) log(j=0∑Cezj)=M+log(j=0∑Cezj−M)

步骤2:为什么这个技巧能避免溢出?

因为 MMM 是所有 logits 中的最大值,所以 zj−M≤0z_j - M ≤ 0zj−M≤0(比如 zj=1000z_j=1000zj=1000,M=1000M=1000M=1000,则 zj−M=0z_j - M=0zj−M=0;其他 logits 减去 M 后都是负数)。

而 e负数e^{负数}e负数 的结果是(0, 1] 之间的小数(比如 e−1000≈0e^{-1000}≈0e−1000≈0,e0=1e^0=1e0=1),不会溢出!

用极端例子验证(核心!看溢出怎么被解决)

例子:logits z=[1000,998,999]z = [1000, 998, 999]z=[1000,998,999](最大值 M=1000M=1000M=1000)

如果直接算 ∑ezj\sum e^{z_j}∑ezj:

  • e1000e^{1000}e1000 溢出为 inf,∑ezj=inf+e998+e999=inf\sum e^{z_j} = inf + e^{998} + e^{999} = inf∑ezj=inf+e998+e999=inf,后续计算全错。

用 log-sum-exp 技巧计算:

  1. 取 M=max⁡(z)=1000M = \max(z) = 1000M=max(z)=1000
  2. 计算每个 zj−Mz_j - Mzj−M:1000−1000=01000-1000=01000−1000=0,998−1000=−2998-1000=-2998−1000=−2,999−1000=−1999-1000=-1999−1000=−1
  3. 计算 ∑ezj−M=e0+e−2+e−1≈1+0.135+0.368≈1.503\sum e^{z_j - M} = e^0 + e^{-2} + e^{-1} ≈ 1 + 0.135 + 0.368 ≈ 1.503∑ezj−M=e0+e−2+e−1≈1+0.135+0.368≈1.503
  4. 计算 log⁡(∑ezj−M)≈log⁡(1.503)≈0.407\log(\sum e^{z_j - M}) ≈ \log(1.503) ≈ 0.407log(∑ezj−M)≈log(1.503)≈0.407
  5. 最终结果:M+0.407=1000+0.407=1000.407M + 0.407 = 1000 + 0.407 = 1000.407M+0.407=1000+0.407=1000.407(没有溢出!)

完美解决了 e1000e^{1000}e1000 溢出的问题,而且结果和"不溢出时直接计算"的结果一致(如果 e1000e^{1000}e1000 能算,log⁡(e1000+e998+e999)=log⁡(e1000(1+e−2+e−1))=1000+log⁡(1.503)≈1000.407\log(e^{1000} + e^{998} + e^{999}) = \log(e^{1000}(1 + e^{-2} + e^{-1})) = 1000 + \log(1.503) ≈ 1000.407log(e1000+e998+e999)=log(e1000(1+e−2+e−1))=1000+log(1.503)≈1000.407)。

步骤3:把 log-sum-exp 代入损失公式(最终优化版)

将 log-sum-exp 结果代入之前的合并损失公式:
loss=−(zy−(M+log⁡(∑j=0Cezj−M))) \text{loss} = - \left( z_y - \left( M + \log\left( \sum_{j=0}^C e^{z_j - M} \right) \right) \right) loss=−(zy−(M+log(j=0∑Cezj−M)))

简化后:
loss=−zy+M+log⁡(∑j=0Cezj−M) \text{loss} = -z_y + M + \log\left( \sum_{j=0}^C e^{z_j - M} \right) loss=−zy+M+log(j=0∑Cezj−M)

这个公式就是 TensorFlow 内部实际使用的计算逻辑------不管 logits 有多大,都不会溢出!

第三步:总结------为什么「from_logits=True」能避免溢出?

本质是「两步优化」的结合:

  1. 代数变换:合并 Softmax 和交叉熵的计算,绕开单独算 Softmax 时的指数溢出;
  2. log-sum-exp 技巧:处理合并后"分母求和"的溢出风险,通过减去 logits 最大值,让指数运算结果始终在(0,1] 之间,再通过数学调整保证结果正确。

整个过程的核心:不是消除指数运算,而是通过数学变换,让指数运算的输入始终是"小数值"(≤0),从而避免溢出,同时保证损失值的计算结果和原逻辑完全一致

最后用一个完整的极端例子验证(全程无溢出)

例子:logits z=[1000,998,999]z = [1000, 998, 999]z=[1000,998,999],正确类别 y=0y=0y=0(zy=1000z_y=1000zy=1000)

用最终优化公式计算损失:

  1. 取 M=max⁡(z)=1000M = \max(z) = 1000M=max(z)=1000
  2. 计算 ∑ezj−M=e0+e−2+e−1≈1.503\sum e^{z_j - M} = e^0 + e^{-2} + e^{-1} ≈ 1.503∑ezj−M=e0+e−2+e−1≈1.503
  3. 计算 log⁡(1.503)≈0.407\log(1.503) ≈ 0.407log(1.503)≈0.407
  4. 代入损失公式:loss=−1000+1000+0.407=0.407\text{loss} = -1000 + 1000 + 0.407 = 0.407loss=−1000+1000+0.407=0.407

如果单独算 Softmax:

  • 直接算 e1000e^{1000}e1000 溢出为 inf,Softmax(z0)=inf/(inf+e998+e999)=inf/inf=nan\text{Softmax}(z_0) = inf/(inf + e^{998} + e^{999}) = inf/inf = nanSoftmax(z0)=inf/(inf+e998+e999)=inf/inf=nan,损失无法计算。

而用优化后的公式,即使 logits 有 1000 这样的大数值,也能正常计算出损失值 0.407------这就是 log-sum-exp 技巧的魔力!

核心记忆点(不用记公式,记逻辑)

  1. 代数变换:把「Softmax + log」变成「logits - log(求和)」,绕开单独算 Softmax;
  2. log-sum-exp:给所有 logits 减最大值,让指数运算不溢出,再补回最大值保证结果不变;
  3. 最终目的:不管 logits 多大,都能安全计算损失,不出现 inf 或 nan。
相关推荐
博一波2 小时前
在技术转型中重温基础:机器学习核心领域梳理
人工智能·机器学习
子洋2 小时前
AI Agent 设计模式 - ReAct 模式
前端·人工智能·后端
likerhood2 小时前
6. pytorch 卷积神经网络
人工智能·pytorch·神经网络
受伤的僵尸2 小时前
算法类复习(1)-非自注意力机制(图像处理中的注意力)
人工智能·算法
AI technophile2 小时前
OpenCV计算机视觉实战(33)——文字识别详解
人工智能·opencv·计算机视觉
囊中之锥.2 小时前
机器学习:认识随机森林
人工智能·随机森林·机器学习
百胜软件@百胜软件2 小时前
CTO Wow Club 上海研讨会成功举办,百胜软件深度分享零售AI智能体实战之道
大数据·人工智能·零售
晨非辰2 小时前
基于Win32 API控制台的贪吃蛇游戏:从设计到C语言实现详解
c语言·c++·人工智能·后端·python·深度学习·游戏
Dingdangcat862 小时前
基于RetinaNet的仙人掌品种识别与分类:Gymnocalycium与Mammillaria属10品种自动识别
人工智能·数据挖掘