0. 简介
线性回归是预测多少 的问题,而回归亦可被用于预测哪一个的问题:
- 某个电子邮件是否属于垃圾邮件文件夹?
- 某个用户可能注册 或不注册订阅服务?
- 某个图像描绘的是驴、狗、猫、还是鸡?
- 某人接下来最有可能看哪部电影?
0.1 softmax和hardmax
softmax本质上就是将一个序列,变成相应的概率,并满足以下条件:
- 所有概率值都在 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 0 , 1 ] [0,1] </math>[0,1] 之间;
- 所有值加起来是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 1 </math>1。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> y j ^ = s o f t m a x ( o j ) = e o j ∑ i = 1 n e o i \hat{y_j} = softmax(o_j) = \frac{e^{o_j}}{\sum_{i=1}^{n}e^{o_i}} </math>yj^=softmax(oj)=∑i=1neoieoj
假设有个输入是 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 0.0 , 1.0 , 2.0 ] [0.0, 1.0, 2.0] </math>[0.0,1.0,2.0],那么softmax计算后就是 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 0.0900 , 0.2447 , 0.6652 ] [0.0900, 0.2447, 0.6652] </math>[0.0900,0.2447,0.6652],假设输入变为 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 0.0 , 1.0 , 3.0 ] [0.0, 1.0, 3.0] </math>[0.0,1.0,3.0],而softmax计算后就变成了 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 0.0420 , 0.1142 , 0.8438 ] [0.0420, 0.1142, 0.8438] </math>[0.0420,0.1142,0.8438]。可以看出,softmax本质上有很强的马太效应 :强(大)的更强(大),弱(小)的更弱(小) 。
而hardmax就更暴力了,针对以上两种输入,输出都是 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 0.0 , 0.0 , 1.0 ] [0.0, 0.0, 1.0] </math>[0.0,0.0,1.0],这使得其函数本身的梯度是非常稀疏的,只有被选中的变量上面才有梯度,这对于一些任务来说几乎是不可接受的。
1. 基本推导
前面的具体的推导过程我就不详述了,可以参考3.4. softmax回归。这里我说一下我个人的疑惑:
- 为啥使用交叉熵损失而不是平方误差?
- 不管是前面的线性回归,还是现在的
softmax回归,其实真实的求导过程都是求得损失函数针对于参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> w \mathbf{w} </math>w 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> b b </math>b 的导数,从而朝着梯度下降的方向挪动,为什么书上只推导了 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ o j \frac{\partial L}{\partial o_j} </math>∂oj∂L,而省略了 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ w j \frac{\partial L}{\partial w_j} </math>∂wj∂L?
1.1 交叉熵损失
假设我们选择平方误差作为softmax的损失函数:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L = 1 2 ∑ i = 1 m ( y i ^ − y i ) 2 L = \frac{1}{2}\sum_{i=1}^{m}(\hat{y_i} - y_i)^2 </math>L=21i=1∑m(yi^−yi)2
那么其对模型logits输出的导数如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ o j = ∑ i = 1 m ∂ L ∂ y i ^ ⋅ ∂ y i ^ ∂ o j \frac{\partial L}{\partial o_j} = \sum_{i=1}^{m}\frac{\partial L}{\partial \hat{y_i}} \cdot \frac{\partial \hat{y_i}}{\partial o_j} </math>∂oj∂L=i=1∑m∂yi^∂L⋅∂oj∂yi^
根据之前的求导:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ y i ^ = y i ^ − y i \frac{\partial L}{\partial \hat{y_i}} = \hat{y_i} - y_i </math>∂yi^∂L=yi^−yi
而计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ y i ^ ∂ o j \frac{\partial \hat{y_i}}{\partial o_j} </math>∂oj∂yi^ 就需要分两种情况了:
当 <math xmlns="http://www.w3.org/1998/Math/MathML"> i = j i = j </math>i=j 时:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ y j ^ ∂ o j = ∂ ∂ o j ( e o j ∑ k = 1 m e o k ) = e o j ∑ k = 1 m e o k − e o j e o j ( ∑ k = 1 m e o k ) 2 = y i ^ ( 1 − y i ^ ) \frac{\partial \hat{y_j}}{\partial o_j} = \frac{\partial}{\partial o_j}(\frac{e^{o_j}}{\sum_{k=1}^{m}e^{o_k}}) = \frac{e^{o_j}\sum_{k=1}^{m}e^{o_k} - e^{o_j}e^{o_j}}{(\sum_{k=1}^{m}e^{o_k})^2} = \hat{y_i}(1-\hat{y_i}) </math>∂oj∂yj^=∂oj∂(∑k=1meokeoj)=(∑k=1meok)2eoj∑k=1meok−eojeoj=yi^(1−yi^)
当 <math xmlns="http://www.w3.org/1998/Math/MathML"> i ≠ j i \neq j </math>i=j 时:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ y i ^ ∂ o j = ∂ ∂ o j ( e o i ∑ k = 1 m e o k ) = 0 ⋅ ∑ k = 1 m e o k − e o i e o j ( ∑ k = 1 m e o k ) 2 = − y i ^ y j ^ \frac{\partial \hat{y_i}}{\partial o_j} = \frac{\partial}{\partial o_j}(\frac{e^{o_i}}{\sum_{k=1}^{m}e^{o_k}}) = \frac{0 \cdot \sum_{k=1}^{m}e^{o_k} - e^{o_i}e^{o_j}}{(\sum_{k=1}^{m}e^{o_k})^2} = -\hat{y_i}\hat{y_j} </math>∂oj∂yi^=∂oj∂(∑k=1meokeoi)=(∑k=1meok)20⋅∑k=1meok−eoieoj=−yi^yj^
统一表达式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ y i ^ ∂ o j = y i ^ ( δ i j − y j ^ ) , δ i j = { 1 , i = j 0 , i ≠ j \frac{\partial \hat{y_i}}{\partial o_j} = \hat{y_i}(\delta_{ij}-\hat{y_j}), \delta_{ij} = \left\{ \begin{aligned} 1, i =j \\ 0, i \neq j \end{aligned} \right. </math>∂oj∂yi^=yi^(δij−yj^),δij={1,i=j0,i=j
所以最终的表达式为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ o j = ∑ i = 1 m ( y i ^ − y i ) y i ^ ( δ i j − y i ^ ) \frac{\partial L}{\partial o_j} = \sum_{i=1}^{m}(\hat{y_i} - y_i)\hat{y_i}(\delta_{ij}-\hat{y_i}) </math>∂oj∂L=i=1∑m(yi^−yi)yi^(δij−yi^)
然后这个公式面临着梯度消失 问题,假设 <math xmlns="http://www.w3.org/1998/Math/MathML"> y i = 1 y_i = 1 </math>yi=1,而预测值如下:
- 当预测正确时,即 <math xmlns="http://www.w3.org/1998/Math/MathML"> y i ^ ≈ 1 \hat{y_i}\approx1 </math>yi^≈1,梯度趋近于0,此时正常;
- 当预测错误 时,即 <math xmlns="http://www.w3.org/1998/Math/MathML"> y i ^ ≈ 0 \hat{y_i}\approx0 </math>yi^≈0,此时梯度也趋近于0,这就不正常了,因为预测值和实际值相差很远,即错误很大,但是更新的步长也很小,这就导致训练过程会非常缓慢甚至停滞,难以逃离这个"局部陷阱"!
而当我们使用交叉熵损失时:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L = − ∑ i = 1 m y i log y j ^ L = -\sum_{i=1}^{m}y_i\log\hat{y_j} </math>L=−i=1∑myilogyj^
和上面一样:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ o j = ∑ i = 1 m ∂ L ∂ y i ^ ⋅ ∂ y i ^ ∂ o j \frac{\partial L}{\partial o_j} = \sum_{i=1}^{m}\frac{\partial L}{\partial \hat{y_i}} \cdot \frac{\partial \hat{y_i}}{\partial o_j} </math>∂oj∂L=i=1∑m∂yi^∂L⋅∂oj∂yi^
因为都是softmax表达式,所以:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ y i ^ ∂ o j = y i ^ ( δ i j − y j ^ ) , δ i j = { 1 , i = j 0 , i ≠ j \frac{\partial \hat{y_i}}{\partial o_j} = \hat{y_i}(\delta_{ij}-\hat{y_j}), \delta_{ij} = \left\{ \begin{aligned} 1, i =j \\ 0, i \neq j \end{aligned} \right. </math>∂oj∂yi^=yi^(δij−yj^),δij={1,i=j0,i=j
重点是求导 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ y i ^ \frac{\partial L}{\partial \hat{y_i}} </math>∂yi^∂L:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ y i ^ = ∂ ∂ y i ^ ( − ∑ i = 1 m y i log y j ^ ) = − y i y i ^ \frac{\partial L}{\partial \hat{y_i}} = \frac{\partial}{\partial \hat{y_i}}(-\sum_{i=1}^{m}y_i\log\hat{y_j}) = -\frac{y_i}{\hat{y_i}} </math>∂yi^∂L=∂yi^∂(−i=1∑myilogyj^)=−yi^yi
其导数始终为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ o j = ∑ i = 1 m ( − y i y i ^ ) y i ^ ( δ i j − y j ^ ) = − ( ∑ i = 1 m y i δ i j − ∑ i = 1 m y i y j ^ ) \frac{\partial L}{\partial o_j} = \sum_{i=1}^{m}(-\frac{y_i}{\hat{y_i}})\hat{y_i}(\delta_{ij}-\hat{y_j}) = -(\sum_{i=1}^{m}y_i\delta_{ij} - \sum_{i=1}^{m}y_i\hat{y_j}) </math>∂oj∂L=i=1∑m(−yi^yi)yi^(δij−yj^)=−(i=1∑myiδij−i=1∑myiyj^)
因为 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∑ i = 1 m y i = 1 \sum_{i=1}^{m}y_i = 1 </math>∑i=1myi=1, <math xmlns="http://www.w3.org/1998/Math/MathML"> ∑ i = 1 m y i δ i j = y j ( i = j 时起作用 ) \sum_{i=1}^{m}y_i\delta_{ij} = y_j(i =j时起作用) </math>∑i=1myiδij=yj(i=j时起作用),所以:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ o j = y j ^ − y j \frac{\partial L}{\partial o_j} = \hat{y_j} - y_j </math>∂oj∂L=yj^−yj
就不会存在以上问题,其梯度与误差成正比,梯度越大,更新力度越大。
1.2 为什么书上只推导了 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ o j \frac{\partial L}{\partial o_j} </math>∂oj∂L,而省略了 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ w j \frac{\partial L}{\partial w_j} </math>∂wj∂L?
我们的最终必然是要求损失函数针对于参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> w \mathbf{w} </math>w 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> b b </math>b 的导数,而以上推导是是对于softmax+交叉熵损失,softmax这一层的梯度公式就是 <math xmlns="http://www.w3.org/1998/Math/MathML"> y j ^ − y j \hat{y_j} - y_j </math>yj^−yj,与前面网络的结构没有关系。
而在softmax之前:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> O = X W + b \mathbf{O} = \mathbf{X}\mathbf{W} + \mathbf{b} </math>O=XW+b
这是一个线性层,所以 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ o ∂ w j = x j \frac{\partial o}{\partial w_j} = x_j </math>∂wj∂o=xj。
它之所以被单独提出来详细推导,是因为:
- 它是一个核心且通用的结果:只要输出层是Softmax+交叉熵损失,这个梯度公式就永远是 Y^−Y,与网络前面的结构无关。
- 它极其简洁,体现了数学的美感。
- 理解这一步是理解整个反向传播过程的基础。现代的深度学习框架(如PyTorch、TensorFlow)会自动完成链式法则的其他部分,但理解这个起点的计算原理至关重要。
2. softmax回归为什么是线性模型
尽管softmax是一个非线性函数,但softmax回归的输出仍然由输入特征的仿射变换决定。 因此,softmax回归是一个线性模型(linear model)。
总的来说,softmax回归分为两个步骤:
线性部分(仿射变换): <math xmlns="http://www.w3.org/1998/Math/MathML"> o = w x + b \mathbf{o} = \mathbf{w}\mathbf{x} + b </math>o=wx+b;非线性部分(softmax函数): <math xmlns="http://www.w3.org/1998/Math/MathML"> y j ^ = s o f t m a x ( o j ) = e o j ∑ i = 1 n e o i \hat{y_j} = softmax(o_j) = \frac{e^{o_j}}{\sum_{i=1}^{n}e^{o_i}} </math>yj^=softmax(oj)=∑i=1neoieoj
而线性部分定义了决策的边界,最终的分类决策规则始终取决于线性部分的结果,第二部分只是将输出校准为有意义的概率值,便于解释和计算损失。
一个比喻: 想象你用三个不同的秤(对应三个类别)去称一个物体(对应一个样本)。每个秤给出一个读数(对应 <math xmlns="http://www.w3.org/1998/Math/MathML"> o j o_j </math>oj)。
- softmax函数的作用就像是把这三个读数都转换成"这个物体最可能被哪个秤称出来"的概率百分比。这是一个非线性转换。
- 但最终你判断物体属于哪个类别,依据仍然是哪个秤的原始读数最大。这个"找最大值"的决策规则是线性的。