D2L(2) — softmax回归

0. 简介

线性回归是预测多少 的问题,而回归亦可被用于预测哪一个的问题:

  • 某个电子邮件是否属于垃圾邮件文件夹?
  • 某个用户可能注册不注册订阅服务?
  • 某个图像描绘的是驴、狗、猫、还是鸡?
  • 某人接下来最有可能看哪部电影?

0.1 softmax和hardmax

softmax本质上就是将一个序列,变成相应的概率,并满足以下条件:

  • 所有概率值都在 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 0 , 1 ] [0,1] </math>[0,1] 之间;
  • 所有值加起来是 <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 1 </math>1。

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> y j ^ = s o f t m a x ( o j ) = e o j ∑ i = 1 n e o i \hat{y_j} = softmax(o_j) = \frac{e^{o_j}}{\sum_{i=1}^{n}e^{o_i}} </math>yj^=softmax(oj)=∑i=1neoieoj

假设有个输入是 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 0.0 , 1.0 , 2.0 ] [0.0, 1.0, 2.0] </math>[0.0,1.0,2.0],那么softmax计算后就是 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 0.0900 , 0.2447 , 0.6652 ] [0.0900, 0.2447, 0.6652] </math>[0.0900,0.2447,0.6652],假设输入变为 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 0.0 , 1.0 , 3.0 ] [0.0, 1.0, 3.0] </math>[0.0,1.0,3.0],而softmax计算后就变成了 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 0.0420 , 0.1142 , 0.8438 ] [0.0420, 0.1142, 0.8438] </math>[0.0420,0.1142,0.8438]。可以看出,softmax本质上有很强的马太效应强(大)的更强(大),弱(小)的更弱(小)

hardmax就更暴力了,针对以上两种输入,输出都是 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ 0.0 , 0.0 , 1.0 ] [0.0, 0.0, 1.0] </math>[0.0,0.0,1.0],这使得其函数本身的梯度是非常稀疏的,只有被选中的变量上面才有梯度,这对于一些任务来说几乎是不可接受的。

1. 基本推导

前面的具体的推导过程我就不详述了,可以参考3.4. softmax回归。这里我说一下我个人的疑惑:

  1. 为啥使用交叉熵损失而不是平方误差?
  2. 不管是前面的线性回归,还是现在的softmax回归,其实真实的求导过程都是求得损失函数针对于参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> w \mathbf{w} </math>w 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> b b </math>b 的导数,从而朝着梯度下降的方向挪动,为什么书上只推导了 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ o j \frac{\partial L}{\partial o_j} </math>∂oj∂L,而省略了 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ w j \frac{\partial L}{\partial w_j} </math>∂wj∂L?

1.1 交叉熵损失

假设我们选择平方误差作为softmax的损失函数
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L = 1 2 ∑ i = 1 m ( y i ^ − y i ) 2 L = \frac{1}{2}\sum_{i=1}^{m}(\hat{y_i} - y_i)^2 </math>L=21i=1∑m(yi^−yi)2

那么其对模型logits输出的导数如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ o j = ∑ i = 1 m ∂ L ∂ y i ^ ⋅ ∂ y i ^ ∂ o j \frac{\partial L}{\partial o_j} = \sum_{i=1}^{m}\frac{\partial L}{\partial \hat{y_i}} \cdot \frac{\partial \hat{y_i}}{\partial o_j} </math>∂oj∂L=i=1∑m∂yi^∂L⋅∂oj∂yi^

根据之前的求导:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ y i ^ = y i ^ − y i \frac{\partial L}{\partial \hat{y_i}} = \hat{y_i} - y_i </math>∂yi^∂L=yi^−yi

而计算 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ y i ^ ∂ o j \frac{\partial \hat{y_i}}{\partial o_j} </math>∂oj∂yi^ 就需要分两种情况了:

当 <math xmlns="http://www.w3.org/1998/Math/MathML"> i = j i = j </math>i=j 时:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ y j ^ ∂ o j = ∂ ∂ o j ( e o j ∑ k = 1 m e o k ) = e o j ∑ k = 1 m e o k − e o j e o j ( ∑ k = 1 m e o k ) 2 = y i ^ ( 1 − y i ^ ) \frac{\partial \hat{y_j}}{\partial o_j} = \frac{\partial}{\partial o_j}(\frac{e^{o_j}}{\sum_{k=1}^{m}e^{o_k}}) = \frac{e^{o_j}\sum_{k=1}^{m}e^{o_k} - e^{o_j}e^{o_j}}{(\sum_{k=1}^{m}e^{o_k})^2} = \hat{y_i}(1-\hat{y_i}) </math>∂oj∂yj^=∂oj∂(∑k=1meokeoj)=(∑k=1meok)2eoj∑k=1meok−eojeoj=yi^(1−yi^)

当 <math xmlns="http://www.w3.org/1998/Math/MathML"> i ≠ j i \neq j </math>i=j 时:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ y i ^ ∂ o j = ∂ ∂ o j ( e o i ∑ k = 1 m e o k ) = 0 ⋅ ∑ k = 1 m e o k − e o i e o j ( ∑ k = 1 m e o k ) 2 = − y i ^ y j ^ \frac{\partial \hat{y_i}}{\partial o_j} = \frac{\partial}{\partial o_j}(\frac{e^{o_i}}{\sum_{k=1}^{m}e^{o_k}}) = \frac{0 \cdot \sum_{k=1}^{m}e^{o_k} - e^{o_i}e^{o_j}}{(\sum_{k=1}^{m}e^{o_k})^2} = -\hat{y_i}\hat{y_j} </math>∂oj∂yi^=∂oj∂(∑k=1meokeoi)=(∑k=1meok)20⋅∑k=1meok−eoieoj=−yi^yj^

统一表达式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ y i ^ ∂ o j = y i ^ ( δ i j − y j ^ ) , δ i j = { 1 , i = j 0 , i ≠ j \frac{\partial \hat{y_i}}{\partial o_j} = \hat{y_i}(\delta_{ij}-\hat{y_j}), \delta_{ij} = \left\{ \begin{aligned} 1, i =j \\ 0, i \neq j \end{aligned} \right. </math>∂oj∂yi^=yi^(δij−yj^),δij={1,i=j0,i=j

所以最终的表达式为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ o j = ∑ i = 1 m ( y i ^ − y i ) y i ^ ( δ i j − y i ^ ) \frac{\partial L}{\partial o_j} = \sum_{i=1}^{m}(\hat{y_i} - y_i)\hat{y_i}(\delta_{ij}-\hat{y_i}) </math>∂oj∂L=i=1∑m(yi^−yi)yi^(δij−yi^)

然后这个公式面临着梯度消失 问题,假设 <math xmlns="http://www.w3.org/1998/Math/MathML"> y i = 1 y_i = 1 </math>yi=1,而预测值如下:

  • 当预测正确时,即 <math xmlns="http://www.w3.org/1998/Math/MathML"> y i ^ ≈ 1 \hat{y_i}\approx1 </math>yi^≈1,梯度趋近于0,此时正常;
  • 当预测错误 时,即 <math xmlns="http://www.w3.org/1998/Math/MathML"> y i ^ ≈ 0 \hat{y_i}\approx0 </math>yi^≈0,此时梯度也趋近于0,这就不正常了,因为预测值和实际值相差很远,即错误很大,但是更新的步长也很小,这就导致训练过程会非常缓慢甚至停滞,难以逃离这个"局部陷阱"!

而当我们使用交叉熵损失时
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L = − ∑ i = 1 m y i log ⁡ y j ^ L = -\sum_{i=1}^{m}y_i\log\hat{y_j} </math>L=−i=1∑myilogyj^

和上面一样:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ o j = ∑ i = 1 m ∂ L ∂ y i ^ ⋅ ∂ y i ^ ∂ o j \frac{\partial L}{\partial o_j} = \sum_{i=1}^{m}\frac{\partial L}{\partial \hat{y_i}} \cdot \frac{\partial \hat{y_i}}{\partial o_j} </math>∂oj∂L=i=1∑m∂yi^∂L⋅∂oj∂yi^

因为都是softmax表达式,所以:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ y i ^ ∂ o j = y i ^ ( δ i j − y j ^ ) , δ i j = { 1 , i = j 0 , i ≠ j \frac{\partial \hat{y_i}}{\partial o_j} = \hat{y_i}(\delta_{ij}-\hat{y_j}), \delta_{ij} = \left\{ \begin{aligned} 1, i =j \\ 0, i \neq j \end{aligned} \right. </math>∂oj∂yi^=yi^(δij−yj^),δij={1,i=j0,i=j

重点是求导 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ y i ^ \frac{\partial L}{\partial \hat{y_i}} </math>∂yi^∂L:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ y i ^ = ∂ ∂ y i ^ ( − ∑ i = 1 m y i log ⁡ y j ^ ) = − y i y i ^ \frac{\partial L}{\partial \hat{y_i}} = \frac{\partial}{\partial \hat{y_i}}(-\sum_{i=1}^{m}y_i\log\hat{y_j}) = -\frac{y_i}{\hat{y_i}} </math>∂yi^∂L=∂yi^∂(−i=1∑myilogyj^)=−yi^yi

其导数始终为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ o j = ∑ i = 1 m ( − y i y i ^ ) y i ^ ( δ i j − y j ^ ) = − ( ∑ i = 1 m y i δ i j − ∑ i = 1 m y i y j ^ ) \frac{\partial L}{\partial o_j} = \sum_{i=1}^{m}(-\frac{y_i}{\hat{y_i}})\hat{y_i}(\delta_{ij}-\hat{y_j}) = -(\sum_{i=1}^{m}y_i\delta_{ij} - \sum_{i=1}^{m}y_i\hat{y_j}) </math>∂oj∂L=i=1∑m(−yi^yi)yi^(δij−yj^)=−(i=1∑myiδij−i=1∑myiyj^)

因为 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∑ i = 1 m y i = 1 \sum_{i=1}^{m}y_i = 1 </math>∑i=1myi=1, <math xmlns="http://www.w3.org/1998/Math/MathML"> ∑ i = 1 m y i δ i j = y j ( i = j 时起作用 ) \sum_{i=1}^{m}y_i\delta_{ij} = y_j(i =j时起作用) </math>∑i=1myiδij=yj(i=j时起作用),所以:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> ∂ L ∂ o j = y j ^ − y j \frac{\partial L}{\partial o_j} = \hat{y_j} - y_j </math>∂oj∂L=yj^−yj

就不会存在以上问题,其梯度与误差成正比,梯度越大,更新力度越大。

1.2 为什么书上只推导了 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ o j \frac{\partial L}{\partial o_j} </math>∂oj∂L,而省略了 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ L ∂ w j \frac{\partial L}{\partial w_j} </math>∂wj∂L?

我们的最终必然是要求损失函数针对于参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> w \mathbf{w} </math>w 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> b b </math>b 的导数,而以上推导是是对于softmax+交叉熵损失,softmax这一层的梯度公式就是 <math xmlns="http://www.w3.org/1998/Math/MathML"> y j ^ − y j \hat{y_j} - y_j </math>yj^−yj,与前面网络的结构没有关系。

而在softmax之前:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> O = X W + b \mathbf{O} = \mathbf{X}\mathbf{W} + \mathbf{b} </math>O=XW+b

这是一个线性层,所以 <math xmlns="http://www.w3.org/1998/Math/MathML"> ∂ o ∂ w j = x j \frac{\partial o}{\partial w_j} = x_j </math>∂wj∂o=xj。

它之所以被单独提出来详细推导,是因为:

  1. 它是一个核心且通用的结果:只要输出层是Softmax+交叉熵损失,这个梯度公式就永远是 Y^−Y,与网络前面的结构无关。
  2. 它极其简洁,体现了数学的美感。
  3. 理解这一步是理解整个反向传播过程的基础。现代的深度学习框架(如PyTorch、TensorFlow)会自动完成链式法则的其他部分,但理解这个起点的计算原理至关重要。

2. softmax回归为什么是线性模型

尽管softmax是一个非线性函数,但softmax回归的输出仍然由输入特征的仿射变换决定。 因此,softmax回归是一个线性模型(linear model)。

总的来说,softmax回归分为两个步骤:

  1. 线性部分(仿射变换): <math xmlns="http://www.w3.org/1998/Math/MathML"> o = w x + b \mathbf{o} = \mathbf{w}\mathbf{x} + b </math>o=wx+b;
  2. 非线性部分(softmax函数): <math xmlns="http://www.w3.org/1998/Math/MathML"> y j ^ = s o f t m a x ( o j ) = e o j ∑ i = 1 n e o i \hat{y_j} = softmax(o_j) = \frac{e^{o_j}}{\sum_{i=1}^{n}e^{o_i}} </math>yj^=softmax(oj)=∑i=1neoieoj

而线性部分定义了决策的边界,最终的分类决策规则始终取决于线性部分的结果,第二部分只是将输出校准为有意义的概率值,便于解释和计算损失。

一个比喻: 想象你用三个不同的秤(对应三个类别)去称一个物体(对应一个样本)。每个秤给出一个读数(对应 <math xmlns="http://www.w3.org/1998/Math/MathML"> o j o_j </math>oj)。

  • softmax函数的作用就像是把这三个读数都转换成"这个物体最可能被哪个秤称出来"的概率百分比。这是一个非线性转换。
  • 但最终你判断物体属于哪个类别,依据仍然是哪个秤的原始读数最大。这个"找最大值"的决策规则是线性的。
相关推荐
码事漫谈2 小时前
一文读懂“本体论”这个时髦词
后端
码事漫谈2 小时前
C++线程编程模型演进:从Pthread到jthread的技术革命
后端
半夏知半秋2 小时前
kcp学习-skynet中的kcp绑定
开发语言·笔记·后端·学习
szm02253 小时前
Spring
java·后端·spring
AlexDeng3 小时前
EF Core 开发实践:Left Join 查询的多种实现方式
后端
马卡巴卡3 小时前
用Spring的ApplicationEventPublisher进行事件发布和监听
后端
y***n6144 小时前
springboot项目架构
spring boot·后端·架构
无名之辈J4 小时前
生产环境慢 SQL 排查与优化
后端
悟能不能悟4 小时前
jasper里面$F和$P的区别
开发语言·后端