TensorFlow 2.0 手写数字分类教程之SparseCategoricalCrossentropy 核心原理(三)

用「通俗例子+公式拆解+逐步代入计算」的方式,把数学公式讲透------每一步都不跳,每个符号都解释,全程紧扣你已经懂的通俗逻辑,让公式成为"看得见、算得清"的工具,而不是抽象符号~

先明确核心目标:数学公式本质是把「"正确类别置信度越高,损失越低"」这个通俗逻辑,用精准的符号表达出来。咱们还是用之前的3分类例子(猫=0、狗=1、鸟=2),正确答案是狗(标签=1),全程围绕这个例子展开!

第一步:先搞懂交叉熵的核心公式(基础中的基础)

交叉熵的公式是:
H(p,q)=−∑i=1Cp(i)log⁡(q(i)) H(p,q) = -\sum_{i=1}^C p(i) \log(q(i)) H(p,q)=−i=1∑Cp(i)log(q(i))

逐个拆解符号(像查字典一样)
符号 通俗含义 对应咱们的例子
H(p,q)H(p,q)H(p,q) 交叉熵(损失值),衡量"真实标签"和"模型预测"的差异 最终要算的损失分
p(i)p(i)p(i) 真实标签的概率分布(iii 代表第 iii 个类别) 正确答案是狗(i=1i=1i=1),所以 p(0)=0p(0)=0p(0)=0(猫不是)、p(1)=1p(1)=1p(1)=1(狗是)、p(2)=0p(2)=0p(2)=0(鸟不是)
q(i)q(i)q(i) 模型预测的概率分布(iii 代表第 iii 个类别) 模型猜猫0.1、狗0.8、鸟0.1,所以 q(0)=0.1q(0)=0.1q(0)=0.1、q(1)=0.8q(1)=0.8q(1)=0.8、q(2)=0.1q(2)=0.1q(2)=0.1
∑i=1C\sum_{i=1}^C∑i=1C 求和符号(把所有类别 i=1i=1i=1 到 i=Ci=Ci=C 的结果加起来) C=3C=3C=3(3分类),所以要算 i=0i=0i=0、i=1i=1i=1、i=2i=2i=2 三项的和(注:公式里 i=1i=1i=1 是习惯,实际咱们从0开始也一样)
log⁡\loglog 自然对数(ln),核心作用:把"置信度"转换成"损失尺度"------置信度越高,log⁡(q(i))\log(q(i))log(q(i)) 越大,加负号后损失越小(刚好对应通俗逻辑) 比如 log⁡(0.8)≈−0.223\log(0.8)≈-0.223log(0.8)≈−0.223,log⁡(0.05)≈−2.996\log(0.05)≈-2.996log(0.05)≈−2.996
负号 −-− 反转对数结果的正负(因为 log⁡(0−1)\log(0-1)log(0−1) 的结果是负数,加负号后损失值为正,方便理解和计算) 比如 log⁡(0.8)≈−0.223\log(0.8)≈-0.223log(0.8)≈−0.223,加负号后是 0.2230.2230.223(损失值)
代入例子计算(手把手算一遍)

咱们的例子:p(0)=0p(0)=0p(0)=0、p(1)=1p(1)=1p(1)=1、p(2)=0p(2)=0p(2)=0;q(0)=0.1q(0)=0.1q(0)=0.1、q(1)=0.8q(1)=0.8q(1)=0.8、q(2)=0.1q(2)=0.1q(2)=0.1

把这些值代入公式,先算每一项,再求和:

  1. 计算 i=0i=0i=0(猫):p(0)×log⁡(q(0))=0×log⁡(0.1)=0×(−2.302)=0p(0) \times \log(q(0)) = 0 \times \log(0.1) = 0 \times (-2.302) = 0p(0)×log(q(0))=0×log(0.1)=0×(−2.302)=0
  2. 计算 i=1i=1i=1(狗):p(1)×log⁡(q(1))=1×log⁡(0.8)=1×(−0.223)=−0.223p(1) \times \log(q(1)) = 1 \times \log(0.8) = 1 \times (-0.223) = -0.223p(1)×log(q(1))=1×log(0.8)=1×(−0.223)=−0.223
  3. 计算 i=2i=2i=2(鸟):p(2)×log⁡(q(2))=0×log⁡(0.1)=0×(−2.302)=0p(2) \times \log(q(2)) = 0 \times \log(0.1) = 0 \times (-2.302) = 0p(2)×log(q(2))=0×log(0.1)=0×(−2.302)=0

然后求和:∑=0+(−0.223)+0=−0.223\sum = 0 + (-0.223) + 0 = -0.223∑=0+(−0.223)+0=−0.223

最后加负号:H(p,q)=−(−0.223)=0.223H(p,q) = -(-0.223) = 0.223H(p,q)=−(−0.223)=0.223

结果和之前通俗例子里的"单样本损失0.223"完全一致!

公式简化(关键结论,不用记复杂求和)

因为真实标签是「one-hot分布」(只有正确类别是1,其他都是0),所以求和时,除了"正确类别"那一项,其他项都是 0×log⁡(q(i))=00 \times \log(q(i))=00×log(q(i))=0,相当于"白算"。

所以交叉熵公式可以直接简化成:
H(p,q)=−log⁡(q(正确类别)) H(p,q) = -\log(q(\text{正确类别})) H(p,q)=−log(q(正确类别))

翻译成人话:损失值 = -ln(模型对正确类别的置信度)

这就是核心!之前的例子里,正确类别是狗,置信度0.8,所以损失就是 −ln⁡(0.8)≈0.223-\ln(0.8)≈0.223−ln(0.8)≈0.223,和计算结果一致~

第二步:理解 from_logits=True 时的公式(优化后的计算)

之前通俗讲解里提到:from_logits=True 时,模型输出的是「原始得分(logits)」(比如 z=[1.0,3.0,0.5]z=[1.0, 3.0, 0.5]z=[1.0,3.0,0.5]),不是0-1的置信度。这时候公式变了,但本质还是一样的------只是为了避免计算出错,换了一种更聪明的计算方式。

先搞懂两个关键转换(铺垫)
  1. Softmax转换 :把logits(原始得分)变成置信度(0-1,总和1)

    公式:Softmax(z)i=ezi∑j=1Cezj\text{Softmax}(z)i = \frac{e^{z_i}}{\sum{j=1}^C e^{z_j}}Softmax(z)i=∑j=1Cezjezi

    翻译:第 iii 类的置信度 = 第 iii 类的指数得分 / 所有类别的指数得分总和

    例子:z=[1.0,3.0,0.5]z=[1.0, 3.0, 0.5]z=[1.0,3.0,0.5]

    • 计算指数得分:e1.0≈2.718e^{1.0}≈2.718e1.0≈2.718、e3.0≈20.085e^{3.0}≈20.085e3.0≈20.085、e0.5≈1.648e^{0.5}≈1.648e0.5≈1.648
    • 总和:2.718+20.085+1.648≈24.4512.718+20.085+1.648≈24.4512.718+20.085+1.648≈24.451
    • 置信度:q(0)=2.718/24.451≈0.111q(0)=2.718/24.451≈0.111q(0)=2.718/24.451≈0.111、q(1)=20.085/24.451≈0.821q(1)=20.085/24.451≈0.821q(1)=20.085/24.451≈0.821、q(2)=1.648/24.451≈0.067q(2)=1.648/24.451≈0.067q(2)=1.648/24.451≈0.067
  2. log_softmax转换 :直接把logits变成"置信度的对数"(避免单独算Softmax导致溢出)

    公式:log⁡(Softmax(z))i=zi−log⁡(∑j=1Cezj)\log(\text{Softmax}(z))i = z_i - \log(\sum{j=1}^C e^{z_j})log(Softmax(z))i=zi−log(∑j=1Cezj)

    翻译:第 iii 类置信度的对数 = 第 iii 类原始得分 - 所有类别的指数得分总和的对数

    例子:z=[1.0,3.0,0.5]z=[1.0, 3.0, 0.5]z=[1.0,3.0,0.5]

    • 总和的对数:log⁡(24.451)≈3.200\log(24.451)≈3.200log(24.451)≈3.200
    • log⁡(Softmax(z))1\log(\text{Softmax}(z))_1log(Softmax(z))1(正确类别1):3.0−3.200=−0.2003.0 - 3.200 = -0.2003.0−3.200=−0.200
from_logits=True 时的损失公式

因为损失值是 −log⁡(q(正确类别))-\log(q(\text{正确类别}))−log(q(正确类别)),而 log⁡(q(正确类别))\log(q(\text{正确类别}))log(q(正确类别)) 就是 log⁡(Softmax(z))正确类别\log(\text{Softmax}(z)){\text{正确类别}}log(Softmax(z))正确类别,所以代入后:
loss=−log⁡(Softmax(z))正确类别=−(z正确类别−log⁡(∑j=1Cezj)) \text{loss} = - \log(\text{Softmax}(z))
{\text{正确类别}} = - \left( z_{\text{正确类别}} - \log(\sum_{j=1}^C e^{z_j}) \right) loss=−log(Softmax(z))正确类别=−(z正确类别−log(j=1∑Cezj))

代入例子计算(再算一遍,验证结果)

例子:logits z=[1.0,3.0,0.5]z=[1.0, 3.0, 0.5]z=[1.0,3.0,0.5],正确类别1

  1. 计算 ∑ezj≈24.451\sum e^{z_j}≈24.451∑ezj≈24.451(和之前一样)
  2. 计算 log⁡(∑ezj)≈3.200\log(\sum e^{z_j})≈3.200log(∑ezj)≈3.200
  3. 计算 z正确类别=z1=3.0z_{\text{正确类别}} = z_1 = 3.0z正确类别=z1=3.0
  4. 代入公式:loss=−(3.0−3.200)=−(−0.200)=0.200\text{loss} = - (3.0 - 3.200) = -(-0.200) = 0.200loss=−(3.0−3.200)=−(−0.200)=0.200

这个结果和之前通俗讲解里的"0.200"一致!

为什么要这么麻烦?(核心优势)

如果直接算 Softmax(z)\text{Softmax}(z)Softmax(z),当logits很大时(比如 z=100z=100z=100),e100e^{100}e100 会是一个超级大的数,电脑算的时候会"溢出"(超出计算范围),导致结果出错。而 log_softmax 直接通过公式变换,避开了单独计算 Softmax(z)\text{Softmax}(z)Softmax(z),既保证了结果正确,又提高了计算效率------这就是为什么推荐 from_logits=True

第三步:批量数据的损失公式(实际训练用)

实际训练时,一次会喂给模型一批数据(比如2个样本),损失公式要算"批量内所有样本的平均损失",公式很简单:
KaTeX parse error: Expected 'EOF', got '_' at position 13: \text{batch_̲loss} = \frac{1...

拆解符号
符号 通俗含义 对应例子
NNN 批量大小(样本数量) N=2N=2N=2(2个样本)
H(pk,qk)H(p_k, q_k)H(pk,qk) 第 kkk 个样本的损失值 第1个样本损失0.223,第2个样本损失0.105
1N∑\frac{1}{N} \sumN1∑ 所有样本损失求和后,除以样本数(求平均) (0.223+0.105)/2≈0.164(0.223 + 0.105) / 2 ≈ 0.164(0.223+0.105)/2≈0.164
代入例子计算

之前的批量例子:

  • 样本1:标签1,置信度[0.1,0.8,0.1],损失0.223
  • 样本2:标签0,置信度[0.9,0.05,0.05],损失0.105
    批量损失 = (0.223+0.105)/2≈0.164(0.223 + 0.105) / 2 ≈ 0.164(0.223+0.105)/2≈0.164

和通俗讲解的结果完全一致!

最终总结:公式的核心逻辑(记3句话,再也不怕)

  1. 所有公式的本质:损失值 = -ln(模型对正确类别的置信度)(简化后核心);
  2. from_logits=True 的公式:只是"避免溢出的聪明计算方式",最终还是在算上面这句话;
  3. 批量损失公式:只是"多个样本损失的平均值",方便模型整体调整参数。
相关推荐
Boll09660002 小时前
开关柜设备状态识别与分类_YOLO11_C3k2_RetBlock实现
人工智能·分类·数据挖掘
一车小面包2 小时前
知识点12.22
人工智能
serve the people2 小时前
TensorFlow 2.0 手写数字分类教程之SparseCategoricalCrossentropy 核心原理(四)
人工智能·分类·tensorflow
博一波2 小时前
在技术转型中重温基础:机器学习核心领域梳理
人工智能·机器学习
子洋2 小时前
AI Agent 设计模式 - ReAct 模式
前端·人工智能·后端
likerhood2 小时前
6. pytorch 卷积神经网络
人工智能·pytorch·神经网络
受伤的僵尸2 小时前
算法类复习(1)-非自注意力机制(图像处理中的注意力)
人工智能·算法
AI technophile2 小时前
OpenCV计算机视觉实战(33)——文字识别详解
人工智能·opencv·计算机视觉
囊中之锥.2 小时前
机器学习:认识随机森林
人工智能·随机森林·机器学习