用「通俗例子+公式拆解+逐步代入计算」的方式,把数学公式讲透------每一步都不跳,每个符号都解释,全程紧扣你已经懂的通俗逻辑,让公式成为"看得见、算得清"的工具,而不是抽象符号~
先明确核心目标:数学公式本质是把「"正确类别置信度越高,损失越低"」这个通俗逻辑,用精准的符号表达出来。咱们还是用之前的3分类例子(猫=0、狗=1、鸟=2),正确答案是狗(标签=1),全程围绕这个例子展开!
第一步:先搞懂交叉熵的核心公式(基础中的基础)
交叉熵的公式是:
H(p,q)=−∑i=1Cp(i)log(q(i)) H(p,q) = -\sum_{i=1}^C p(i) \log(q(i)) H(p,q)=−i=1∑Cp(i)log(q(i))
逐个拆解符号(像查字典一样)
| 符号 | 通俗含义 | 对应咱们的例子 |
|---|---|---|
| H(p,q)H(p,q)H(p,q) | 交叉熵(损失值),衡量"真实标签"和"模型预测"的差异 | 最终要算的损失分 |
| p(i)p(i)p(i) | 真实标签的概率分布(iii 代表第 iii 个类别) | 正确答案是狗(i=1i=1i=1),所以 p(0)=0p(0)=0p(0)=0(猫不是)、p(1)=1p(1)=1p(1)=1(狗是)、p(2)=0p(2)=0p(2)=0(鸟不是) |
| q(i)q(i)q(i) | 模型预测的概率分布(iii 代表第 iii 个类别) | 模型猜猫0.1、狗0.8、鸟0.1,所以 q(0)=0.1q(0)=0.1q(0)=0.1、q(1)=0.8q(1)=0.8q(1)=0.8、q(2)=0.1q(2)=0.1q(2)=0.1 |
| ∑i=1C\sum_{i=1}^C∑i=1C | 求和符号(把所有类别 i=1i=1i=1 到 i=Ci=Ci=C 的结果加起来) | C=3C=3C=3(3分类),所以要算 i=0i=0i=0、i=1i=1i=1、i=2i=2i=2 三项的和(注:公式里 i=1i=1i=1 是习惯,实际咱们从0开始也一样) |
| log\loglog | 自然对数(ln),核心作用:把"置信度"转换成"损失尺度"------置信度越高,log(q(i))\log(q(i))log(q(i)) 越大,加负号后损失越小(刚好对应通俗逻辑) | 比如 log(0.8)≈−0.223\log(0.8)≈-0.223log(0.8)≈−0.223,log(0.05)≈−2.996\log(0.05)≈-2.996log(0.05)≈−2.996 |
| 负号 −-− | 反转对数结果的正负(因为 log(0−1)\log(0-1)log(0−1) 的结果是负数,加负号后损失值为正,方便理解和计算) | 比如 log(0.8)≈−0.223\log(0.8)≈-0.223log(0.8)≈−0.223,加负号后是 0.2230.2230.223(损失值) |
代入例子计算(手把手算一遍)
咱们的例子:p(0)=0p(0)=0p(0)=0、p(1)=1p(1)=1p(1)=1、p(2)=0p(2)=0p(2)=0;q(0)=0.1q(0)=0.1q(0)=0.1、q(1)=0.8q(1)=0.8q(1)=0.8、q(2)=0.1q(2)=0.1q(2)=0.1
把这些值代入公式,先算每一项,再求和:
- 计算 i=0i=0i=0(猫):p(0)×log(q(0))=0×log(0.1)=0×(−2.302)=0p(0) \times \log(q(0)) = 0 \times \log(0.1) = 0 \times (-2.302) = 0p(0)×log(q(0))=0×log(0.1)=0×(−2.302)=0
- 计算 i=1i=1i=1(狗):p(1)×log(q(1))=1×log(0.8)=1×(−0.223)=−0.223p(1) \times \log(q(1)) = 1 \times \log(0.8) = 1 \times (-0.223) = -0.223p(1)×log(q(1))=1×log(0.8)=1×(−0.223)=−0.223
- 计算 i=2i=2i=2(鸟):p(2)×log(q(2))=0×log(0.1)=0×(−2.302)=0p(2) \times \log(q(2)) = 0 \times \log(0.1) = 0 \times (-2.302) = 0p(2)×log(q(2))=0×log(0.1)=0×(−2.302)=0
然后求和:∑=0+(−0.223)+0=−0.223\sum = 0 + (-0.223) + 0 = -0.223∑=0+(−0.223)+0=−0.223
最后加负号:H(p,q)=−(−0.223)=0.223H(p,q) = -(-0.223) = 0.223H(p,q)=−(−0.223)=0.223
结果和之前通俗例子里的"单样本损失0.223"完全一致!
公式简化(关键结论,不用记复杂求和)
因为真实标签是「one-hot分布」(只有正确类别是1,其他都是0),所以求和时,除了"正确类别"那一项,其他项都是 0×log(q(i))=00 \times \log(q(i))=00×log(q(i))=0,相当于"白算"。
所以交叉熵公式可以直接简化成:
H(p,q)=−log(q(正确类别)) H(p,q) = -\log(q(\text{正确类别})) H(p,q)=−log(q(正确类别))
翻译成人话:损失值 = -ln(模型对正确类别的置信度)
这就是核心!之前的例子里,正确类别是狗,置信度0.8,所以损失就是 −ln(0.8)≈0.223-\ln(0.8)≈0.223−ln(0.8)≈0.223,和计算结果一致~
第二步:理解 from_logits=True 时的公式(优化后的计算)
之前通俗讲解里提到:from_logits=True 时,模型输出的是「原始得分(logits)」(比如 z=[1.0,3.0,0.5]z=[1.0, 3.0, 0.5]z=[1.0,3.0,0.5]),不是0-1的置信度。这时候公式变了,但本质还是一样的------只是为了避免计算出错,换了一种更聪明的计算方式。
先搞懂两个关键转换(铺垫)
-
Softmax转换 :把logits(原始得分)变成置信度(0-1,总和1)
公式:Softmax(z)i=ezi∑j=1Cezj\text{Softmax}(z)i = \frac{e^{z_i}}{\sum{j=1}^C e^{z_j}}Softmax(z)i=∑j=1Cezjezi
翻译:第 iii 类的置信度 = 第 iii 类的指数得分 / 所有类别的指数得分总和
例子:z=[1.0,3.0,0.5]z=[1.0, 3.0, 0.5]z=[1.0,3.0,0.5]
- 计算指数得分:e1.0≈2.718e^{1.0}≈2.718e1.0≈2.718、e3.0≈20.085e^{3.0}≈20.085e3.0≈20.085、e0.5≈1.648e^{0.5}≈1.648e0.5≈1.648
- 总和:2.718+20.085+1.648≈24.4512.718+20.085+1.648≈24.4512.718+20.085+1.648≈24.451
- 置信度:q(0)=2.718/24.451≈0.111q(0)=2.718/24.451≈0.111q(0)=2.718/24.451≈0.111、q(1)=20.085/24.451≈0.821q(1)=20.085/24.451≈0.821q(1)=20.085/24.451≈0.821、q(2)=1.648/24.451≈0.067q(2)=1.648/24.451≈0.067q(2)=1.648/24.451≈0.067
-
log_softmax转换 :直接把logits变成"置信度的对数"(避免单独算Softmax导致溢出)
公式:log(Softmax(z))i=zi−log(∑j=1Cezj)\log(\text{Softmax}(z))i = z_i - \log(\sum{j=1}^C e^{z_j})log(Softmax(z))i=zi−log(∑j=1Cezj)
翻译:第 iii 类置信度的对数 = 第 iii 类原始得分 - 所有类别的指数得分总和的对数
例子:z=[1.0,3.0,0.5]z=[1.0, 3.0, 0.5]z=[1.0,3.0,0.5]
- 总和的对数:log(24.451)≈3.200\log(24.451)≈3.200log(24.451)≈3.200
- log(Softmax(z))1\log(\text{Softmax}(z))_1log(Softmax(z))1(正确类别1):3.0−3.200=−0.2003.0 - 3.200 = -0.2003.0−3.200=−0.200
from_logits=True 时的损失公式
因为损失值是 −log(q(正确类别))-\log(q(\text{正确类别}))−log(q(正确类别)),而 log(q(正确类别))\log(q(\text{正确类别}))log(q(正确类别)) 就是 log(Softmax(z))正确类别\log(\text{Softmax}(z)){\text{正确类别}}log(Softmax(z))正确类别,所以代入后:
loss=−log(Softmax(z))正确类别=−(z正确类别−log(∑j=1Cezj)) \text{loss} = - \log(\text{Softmax}(z)){\text{正确类别}} = - \left( z_{\text{正确类别}} - \log(\sum_{j=1}^C e^{z_j}) \right) loss=−log(Softmax(z))正确类别=−(z正确类别−log(j=1∑Cezj))
代入例子计算(再算一遍,验证结果)
例子:logits z=[1.0,3.0,0.5]z=[1.0, 3.0, 0.5]z=[1.0,3.0,0.5],正确类别1
- 计算 ∑ezj≈24.451\sum e^{z_j}≈24.451∑ezj≈24.451(和之前一样)
- 计算 log(∑ezj)≈3.200\log(\sum e^{z_j})≈3.200log(∑ezj)≈3.200
- 计算 z正确类别=z1=3.0z_{\text{正确类别}} = z_1 = 3.0z正确类别=z1=3.0
- 代入公式:loss=−(3.0−3.200)=−(−0.200)=0.200\text{loss} = - (3.0 - 3.200) = -(-0.200) = 0.200loss=−(3.0−3.200)=−(−0.200)=0.200
这个结果和之前通俗讲解里的"0.200"一致!
为什么要这么麻烦?(核心优势)
如果直接算 Softmax(z)\text{Softmax}(z)Softmax(z),当logits很大时(比如 z=100z=100z=100),e100e^{100}e100 会是一个超级大的数,电脑算的时候会"溢出"(超出计算范围),导致结果出错。而 log_softmax 直接通过公式变换,避开了单独计算 Softmax(z)\text{Softmax}(z)Softmax(z),既保证了结果正确,又提高了计算效率------这就是为什么推荐 from_logits=True!
第三步:批量数据的损失公式(实际训练用)
实际训练时,一次会喂给模型一批数据(比如2个样本),损失公式要算"批量内所有样本的平均损失",公式很简单:
KaTeX parse error: Expected 'EOF', got '_' at position 13: \text{batch_̲loss} = \frac{1...
拆解符号
| 符号 | 通俗含义 | 对应例子 |
|---|---|---|
| NNN | 批量大小(样本数量) | N=2N=2N=2(2个样本) |
| H(pk,qk)H(p_k, q_k)H(pk,qk) | 第 kkk 个样本的损失值 | 第1个样本损失0.223,第2个样本损失0.105 |
| 1N∑\frac{1}{N} \sumN1∑ | 所有样本损失求和后,除以样本数(求平均) | (0.223+0.105)/2≈0.164(0.223 + 0.105) / 2 ≈ 0.164(0.223+0.105)/2≈0.164 |
代入例子计算
之前的批量例子:
- 样本1:标签1,置信度[0.1,0.8,0.1],损失0.223
- 样本2:标签0,置信度[0.9,0.05,0.05],损失0.105
批量损失 = (0.223+0.105)/2≈0.164(0.223 + 0.105) / 2 ≈ 0.164(0.223+0.105)/2≈0.164
和通俗讲解的结果完全一致!
最终总结:公式的核心逻辑(记3句话,再也不怕)
- 所有公式的本质:损失值 = -ln(模型对正确类别的置信度)(简化后核心);
from_logits=True的公式:只是"避免溢出的聪明计算方式",最终还是在算上面这句话;- 批量损失公式:只是"多个样本损失的平均值",方便模型整体调整参数。