TensorFlow 2.0 手写数字分类教程之SparseCategoricalCrossentropy 核心原理(三)

用「通俗例子+公式拆解+逐步代入计算」的方式,把数学公式讲透------每一步都不跳,每个符号都解释,全程紧扣你已经懂的通俗逻辑,让公式成为"看得见、算得清"的工具,而不是抽象符号~

先明确核心目标:数学公式本质是把「"正确类别置信度越高,损失越低"」这个通俗逻辑,用精准的符号表达出来。咱们还是用之前的3分类例子(猫=0、狗=1、鸟=2),正确答案是狗(标签=1),全程围绕这个例子展开!

第一步:先搞懂交叉熵的核心公式(基础中的基础)

交叉熵的公式是:
H(p,q)=−∑i=1Cp(i)log⁡(q(i)) H(p,q) = -\sum_{i=1}^C p(i) \log(q(i)) H(p,q)=−i=1∑Cp(i)log(q(i))

逐个拆解符号(像查字典一样)
符号 通俗含义 对应咱们的例子
H(p,q)H(p,q)H(p,q) 交叉熵(损失值),衡量"真实标签"和"模型预测"的差异 最终要算的损失分
p(i)p(i)p(i) 真实标签的概率分布(iii 代表第 iii 个类别) 正确答案是狗(i=1i=1i=1),所以 p(0)=0p(0)=0p(0)=0(猫不是)、p(1)=1p(1)=1p(1)=1(狗是)、p(2)=0p(2)=0p(2)=0(鸟不是)
q(i)q(i)q(i) 模型预测的概率分布(iii 代表第 iii 个类别) 模型猜猫0.1、狗0.8、鸟0.1,所以 q(0)=0.1q(0)=0.1q(0)=0.1、q(1)=0.8q(1)=0.8q(1)=0.8、q(2)=0.1q(2)=0.1q(2)=0.1
∑i=1C\sum_{i=1}^C∑i=1C 求和符号(把所有类别 i=1i=1i=1 到 i=Ci=Ci=C 的结果加起来) C=3C=3C=3(3分类),所以要算 i=0i=0i=0、i=1i=1i=1、i=2i=2i=2 三项的和(注:公式里 i=1i=1i=1 是习惯,实际咱们从0开始也一样)
log⁡\loglog 自然对数(ln),核心作用:把"置信度"转换成"损失尺度"------置信度越高,log⁡(q(i))\log(q(i))log(q(i)) 越大,加负号后损失越小(刚好对应通俗逻辑) 比如 log⁡(0.8)≈−0.223\log(0.8)≈-0.223log(0.8)≈−0.223,log⁡(0.05)≈−2.996\log(0.05)≈-2.996log(0.05)≈−2.996
负号 −-− 反转对数结果的正负(因为 log⁡(0−1)\log(0-1)log(0−1) 的结果是负数,加负号后损失值为正,方便理解和计算) 比如 log⁡(0.8)≈−0.223\log(0.8)≈-0.223log(0.8)≈−0.223,加负号后是 0.2230.2230.223(损失值)
代入例子计算(手把手算一遍)

咱们的例子:p(0)=0p(0)=0p(0)=0、p(1)=1p(1)=1p(1)=1、p(2)=0p(2)=0p(2)=0;q(0)=0.1q(0)=0.1q(0)=0.1、q(1)=0.8q(1)=0.8q(1)=0.8、q(2)=0.1q(2)=0.1q(2)=0.1

把这些值代入公式,先算每一项,再求和:

  1. 计算 i=0i=0i=0(猫):p(0)×log⁡(q(0))=0×log⁡(0.1)=0×(−2.302)=0p(0) \times \log(q(0)) = 0 \times \log(0.1) = 0 \times (-2.302) = 0p(0)×log(q(0))=0×log(0.1)=0×(−2.302)=0
  2. 计算 i=1i=1i=1(狗):p(1)×log⁡(q(1))=1×log⁡(0.8)=1×(−0.223)=−0.223p(1) \times \log(q(1)) = 1 \times \log(0.8) = 1 \times (-0.223) = -0.223p(1)×log(q(1))=1×log(0.8)=1×(−0.223)=−0.223
  3. 计算 i=2i=2i=2(鸟):p(2)×log⁡(q(2))=0×log⁡(0.1)=0×(−2.302)=0p(2) \times \log(q(2)) = 0 \times \log(0.1) = 0 \times (-2.302) = 0p(2)×log(q(2))=0×log(0.1)=0×(−2.302)=0

然后求和:∑=0+(−0.223)+0=−0.223\sum = 0 + (-0.223) + 0 = -0.223∑=0+(−0.223)+0=−0.223

最后加负号:H(p,q)=−(−0.223)=0.223H(p,q) = -(-0.223) = 0.223H(p,q)=−(−0.223)=0.223

结果和之前通俗例子里的"单样本损失0.223"完全一致!

公式简化(关键结论,不用记复杂求和)

因为真实标签是「one-hot分布」(只有正确类别是1,其他都是0),所以求和时,除了"正确类别"那一项,其他项都是 0×log⁡(q(i))=00 \times \log(q(i))=00×log(q(i))=0,相当于"白算"。

所以交叉熵公式可以直接简化成:
H(p,q)=−log⁡(q(正确类别)) H(p,q) = -\log(q(\text{正确类别})) H(p,q)=−log(q(正确类别))

翻译成人话:损失值 = -ln(模型对正确类别的置信度)

这就是核心!之前的例子里,正确类别是狗,置信度0.8,所以损失就是 −ln⁡(0.8)≈0.223-\ln(0.8)≈0.223−ln(0.8)≈0.223,和计算结果一致~

第二步:理解 from_logits=True 时的公式(优化后的计算)

之前通俗讲解里提到:from_logits=True 时,模型输出的是「原始得分(logits)」(比如 z=[1.0,3.0,0.5]z=[1.0, 3.0, 0.5]z=[1.0,3.0,0.5]),不是0-1的置信度。这时候公式变了,但本质还是一样的------只是为了避免计算出错,换了一种更聪明的计算方式。

先搞懂两个关键转换(铺垫)
  1. Softmax转换 :把logits(原始得分)变成置信度(0-1,总和1)

    公式:Softmax(z)i=ezi∑j=1Cezj\text{Softmax}(z)i = \frac{e^{z_i}}{\sum{j=1}^C e^{z_j}}Softmax(z)i=∑j=1Cezjezi

    翻译:第 iii 类的置信度 = 第 iii 类的指数得分 / 所有类别的指数得分总和

    例子:z=[1.0,3.0,0.5]z=[1.0, 3.0, 0.5]z=[1.0,3.0,0.5]

    • 计算指数得分:e1.0≈2.718e^{1.0}≈2.718e1.0≈2.718、e3.0≈20.085e^{3.0}≈20.085e3.0≈20.085、e0.5≈1.648e^{0.5}≈1.648e0.5≈1.648
    • 总和:2.718+20.085+1.648≈24.4512.718+20.085+1.648≈24.4512.718+20.085+1.648≈24.451
    • 置信度:q(0)=2.718/24.451≈0.111q(0)=2.718/24.451≈0.111q(0)=2.718/24.451≈0.111、q(1)=20.085/24.451≈0.821q(1)=20.085/24.451≈0.821q(1)=20.085/24.451≈0.821、q(2)=1.648/24.451≈0.067q(2)=1.648/24.451≈0.067q(2)=1.648/24.451≈0.067
  2. log_softmax转换 :直接把logits变成"置信度的对数"(避免单独算Softmax导致溢出)

    公式:log⁡(Softmax(z))i=zi−log⁡(∑j=1Cezj)\log(\text{Softmax}(z))i = z_i - \log(\sum{j=1}^C e^{z_j})log(Softmax(z))i=zi−log(∑j=1Cezj)

    翻译:第 iii 类置信度的对数 = 第 iii 类原始得分 - 所有类别的指数得分总和的对数

    例子:z=[1.0,3.0,0.5]z=[1.0, 3.0, 0.5]z=[1.0,3.0,0.5]

    • 总和的对数:log⁡(24.451)≈3.200\log(24.451)≈3.200log(24.451)≈3.200
    • log⁡(Softmax(z))1\log(\text{Softmax}(z))_1log(Softmax(z))1(正确类别1):3.0−3.200=−0.2003.0 - 3.200 = -0.2003.0−3.200=−0.200
from_logits=True 时的损失公式

因为损失值是 −log⁡(q(正确类别))-\log(q(\text{正确类别}))−log(q(正确类别)),而 log⁡(q(正确类别))\log(q(\text{正确类别}))log(q(正确类别)) 就是 log⁡(Softmax(z))正确类别\log(\text{Softmax}(z)){\text{正确类别}}log(Softmax(z))正确类别,所以代入后:
loss=−log⁡(Softmax(z))正确类别=−(z正确类别−log⁡(∑j=1Cezj)) \text{loss} = - \log(\text{Softmax}(z))
{\text{正确类别}} = - \left( z_{\text{正确类别}} - \log(\sum_{j=1}^C e^{z_j}) \right) loss=−log(Softmax(z))正确类别=−(z正确类别−log(j=1∑Cezj))

代入例子计算(再算一遍,验证结果)

例子:logits z=[1.0,3.0,0.5]z=[1.0, 3.0, 0.5]z=[1.0,3.0,0.5],正确类别1

  1. 计算 ∑ezj≈24.451\sum e^{z_j}≈24.451∑ezj≈24.451(和之前一样)
  2. 计算 log⁡(∑ezj)≈3.200\log(\sum e^{z_j})≈3.200log(∑ezj)≈3.200
  3. 计算 z正确类别=z1=3.0z_{\text{正确类别}} = z_1 = 3.0z正确类别=z1=3.0
  4. 代入公式:loss=−(3.0−3.200)=−(−0.200)=0.200\text{loss} = - (3.0 - 3.200) = -(-0.200) = 0.200loss=−(3.0−3.200)=−(−0.200)=0.200

这个结果和之前通俗讲解里的"0.200"一致!

为什么要这么麻烦?(核心优势)

如果直接算 Softmax(z)\text{Softmax}(z)Softmax(z),当logits很大时(比如 z=100z=100z=100),e100e^{100}e100 会是一个超级大的数,电脑算的时候会"溢出"(超出计算范围),导致结果出错。而 log_softmax 直接通过公式变换,避开了单独计算 Softmax(z)\text{Softmax}(z)Softmax(z),既保证了结果正确,又提高了计算效率------这就是为什么推荐 from_logits=True

第三步:批量数据的损失公式(实际训练用)

实际训练时,一次会喂给模型一批数据(比如2个样本),损失公式要算"批量内所有样本的平均损失",公式很简单:
KaTeX parse error: Expected 'EOF', got '_' at position 13: \text{batch_̲loss} = \frac{1...

拆解符号
符号 通俗含义 对应例子
NNN 批量大小(样本数量) N=2N=2N=2(2个样本)
H(pk,qk)H(p_k, q_k)H(pk,qk) 第 kkk 个样本的损失值 第1个样本损失0.223,第2个样本损失0.105
1N∑\frac{1}{N} \sumN1∑ 所有样本损失求和后,除以样本数(求平均) (0.223+0.105)/2≈0.164(0.223 + 0.105) / 2 ≈ 0.164(0.223+0.105)/2≈0.164
代入例子计算

之前的批量例子:

  • 样本1:标签1,置信度[0.1,0.8,0.1],损失0.223
  • 样本2:标签0,置信度[0.9,0.05,0.05],损失0.105
    批量损失 = (0.223+0.105)/2≈0.164(0.223 + 0.105) / 2 ≈ 0.164(0.223+0.105)/2≈0.164

和通俗讲解的结果完全一致!

最终总结:公式的核心逻辑(记3句话,再也不怕)

  1. 所有公式的本质:损失值 = -ln(模型对正确类别的置信度)(简化后核心);
  2. from_logits=True 的公式:只是"避免溢出的聪明计算方式",最终还是在算上面这句话;
  3. 批量损失公式:只是"多个样本损失的平均值",方便模型整体调整参数。
相关推荐
NAGNIP11 小时前
一文搞懂深度学习中的通用逼近定理!
人工智能·算法·面试
冬奇Lab12 小时前
一天一个开源项目(第36篇):EverMemOS - 跨 LLM 与平台的长时记忆 OS,让 Agent 会记忆更会推理
人工智能·开源·资讯
冬奇Lab12 小时前
OpenClaw 源码深度解析(一):Gateway——为什么需要一个"中枢"
人工智能·开源·源码阅读
AngelPP16 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年16 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼16 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS16 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区17 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈17 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang18 小时前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx