D2L(2) — softmax回归

0. 简介

线性回归是预测多少 的问题,而回归亦可被用于预测哪一个的问题:

  • 某个电子邮件是否属于垃圾邮件文件夹?
  • 某个用户可能注册不注册订阅服务?
  • 某个图像描绘的是驴、狗、猫、还是鸡?
  • 某人接下来最有可能看哪部电影?

0.1 softmax和hardmax

softmax本质上就是将一个序列,变成相应的概率,并满足以下条件:

  • 所有概率值都在 0 , 1 0,1 0,1 之间;
  • 所有值加起来是 1 1 1。

y j ^ = s o f t m a x ( o j ) = e o j ∑ i = 1 n e o i \hat{y_j} = softmax(o_j) = \frac{e^{o_j}}{\sum_{i=1}^{n}e^{o_i}} yj^=softmax(oj)=∑i=1neoieoj

假设有个输入是 0.0 , 1.0 , 2.0 0.0, 1.0, 2.0 0.0,1.0,2.0,那么softmax计算后就是 0.0900 , 0.2447 , 0.6652 0.0900, 0.2447, 0.6652 0.0900,0.2447,0.6652,假设输入变为 0.0 , 1.0 , 3.0 0.0, 1.0, 3.0 0.0,1.0,3.0,而softmax计算后就变成了 0.0420 , 0.1142 , 0.8438 0.0420, 0.1142, 0.8438 0.0420,0.1142,0.8438。可以看出,softmax本质上有很强的马太效应强(大)的更强(大),弱(小)的更弱(小)

hardmax就更暴力了,针对以上两种输入,输出都是 0.0 , 0.0 , 1.0 0.0, 0.0, 1.0 0.0,0.0,1.0,这使得其函数本身的梯度是非常稀疏的,只有被选中的变量上面才有梯度,这对于一些任务来说几乎是不可接受的。

1. 基本推导

前面的具体的推导过程我就不详述了,可以参考3.4. softmax回归。这里我说一下我个人的疑惑:

  1. 为啥使用交叉熵损失而不是平方误差?
  2. 不管是前面的线性回归,还是现在的softmax回归,其实真实的求导过程都是求得损失函数针对于参数 w \mathbf{w} w 和 b b b 的导数,从而朝着梯度下降的方向挪动,为什么书上只推导了 ∂ L ∂ o j \frac{\partial L}{\partial o_j} ∂oj∂L,而省略了 ∂ L ∂ w j \frac{\partial L}{\partial w_j} ∂wj∂L?

1.1 交叉熵损失

假设我们选择平方误差作为softmax的损失函数
L = 1 2 ∑ i = 1 m ( y i ^ − y i ) 2 L = \frac{1}{2}\sum_{i=1}^{m}(\hat{y_i} - y_i)^2 L=21i=1∑m(yi^−yi)2

那么其对模型logits输出的导数如下:
∂ L ∂ o j = ∑ i = 1 m ∂ L ∂ y i ^ ⋅ ∂ y i ^ ∂ o j \frac{\partial L}{\partial o_j} = \sum_{i=1}^{m}\frac{\partial L}{\partial \hat{y_i}} \cdot \frac{\partial \hat{y_i}}{\partial o_j} ∂oj∂L=i=1∑m∂yi^∂L⋅∂oj∂yi^

根据之前的求导:
∂ L ∂ y i ^ = y i ^ − y i \frac{\partial L}{\partial \hat{y_i}} = \hat{y_i} - y_i ∂yi^∂L=yi^−yi

而计算 ∂ y i ^ ∂ o j \frac{\partial \hat{y_i}}{\partial o_j} ∂oj∂yi^ 就需要分两种情况了:

i = j i = j i=j 时:
∂ y j ^ ∂ o j = ∂ ∂ o j ( e o j ∑ k = 1 m e o k ) = e o j ∑ k = 1 m e o k − e o j e o j ( ∑ k = 1 m e o k ) 2 = y i ^ ( 1 − y i ^ ) \frac{\partial \hat{y_j}}{\partial o_j} = \frac{\partial}{\partial o_j}(\frac{e^{o_j}}{\sum_{k=1}^{m}e^{o_k}}) = \frac{e^{o_j}\sum_{k=1}^{m}e^{o_k} - e^{o_j}e^{o_j}}{(\sum_{k=1}^{m}e^{o_k})^2} = \hat{y_i}(1-\hat{y_i}) ∂oj∂yj^=∂oj∂(∑k=1meokeoj)=(∑k=1meok)2eoj∑k=1meok−eojeoj=yi^(1−yi^)

i ≠ j i \neq j i=j 时:
∂ y i ^ ∂ o j = ∂ ∂ o j ( e o i ∑ k = 1 m e o k ) = 0 ⋅ ∑ k = 1 m e o k − e o i e o j ( ∑ k = 1 m e o k ) 2 = − y i ^ y j ^ \frac{\partial \hat{y_i}}{\partial o_j} = \frac{\partial}{\partial o_j}(\frac{e^{o_i}}{\sum_{k=1}^{m}e^{o_k}}) = \frac{0 \cdot \sum_{k=1}^{m}e^{o_k} - e^{o_i}e^{o_j}}{(\sum_{k=1}^{m}e^{o_k})^2} = -\hat{y_i}\hat{y_j} ∂oj∂yi^=∂oj∂(∑k=1meokeoi)=(∑k=1meok)20⋅∑k=1meok−eoieoj=−yi^yj^

统一表达式:
∂ y i ^ ∂ o j = y i ^ ( δ i j − y j ^ ) , δ i j = { 1 , i = j 0 , i ≠ j \frac{\partial \hat{y_i}}{\partial o_j} = \hat{y_i}(\delta_{ij}-\hat{y_j}), \delta_{ij} = \left\{ \begin{aligned} 1, i =j \\ 0, i \neq j \end{aligned} \right. ∂oj∂yi^=yi^(δij−yj^),δij={1,i=j0,i=j

所以最终的表达式为:
∂ L ∂ o j = ∑ i = 1 m ( y i ^ − y i ) y i ^ ( δ i j − y i ^ ) \frac{\partial L}{\partial o_j} = \sum_{i=1}^{m}(\hat{y_i} - y_i)\hat{y_i}(\delta_{ij}-\hat{y_i}) ∂oj∂L=i=1∑m(yi^−yi)yi^(δij−yi^)

然后这个公式面临着梯度消失 问题,假设 y i = 1 y_i = 1 yi=1,而预测值如下:

  • 当预测正确时,即 y i ^ ≈ 1 \hat{y_i}\approx1 yi^≈1,梯度趋近于0,此时正常;
  • 当预测错误 时,即 y i ^ ≈ 0 \hat{y_i}\approx0 yi^≈0,此时梯度也趋近于0,这就不正常了,因为预测值和实际值相差很远,即错误很大,但是更新的步长也很小,这就导致训练过程会非常缓慢甚至停滞,难以逃离这个"局部陷阱"!

而当我们使用交叉熵损失时
L = − ∑ i = 1 m y i log ⁡ y j ^ L = -\sum_{i=1}^{m}y_i\log\hat{y_j} L=−i=1∑myilogyj^

和上面一样:
∂ L ∂ o j = ∑ i = 1 m ∂ L ∂ y i ^ ⋅ ∂ y i ^ ∂ o j \frac{\partial L}{\partial o_j} = \sum_{i=1}^{m}\frac{\partial L}{\partial \hat{y_i}} \cdot \frac{\partial \hat{y_i}}{\partial o_j} ∂oj∂L=i=1∑m∂yi^∂L⋅∂oj∂yi^

因为都是softmax表达式,所以:
∂ y i ^ ∂ o j = y i ^ ( δ i j − y j ^ ) , δ i j = { 1 , i = j 0 , i ≠ j \frac{\partial \hat{y_i}}{\partial o_j} = \hat{y_i}(\delta_{ij}-\hat{y_j}), \delta_{ij} = \left\{ \begin{aligned} 1, i =j \\ 0, i \neq j \end{aligned} \right. ∂oj∂yi^=yi^(δij−yj^),δij={1,i=j0,i=j

重点是求导 ∂ L ∂ y i ^ \frac{\partial L}{\partial \hat{y_i}} ∂yi^∂L:
∂ L ∂ y i ^ = ∂ ∂ y i ^ ( − ∑ i = 1 m y i log ⁡ y j ^ ) = − y i y i ^ \frac{\partial L}{\partial \hat{y_i}} = \frac{\partial}{\partial \hat{y_i}}(-\sum_{i=1}^{m}y_i\log\hat{y_j}) = -\frac{y_i}{\hat{y_i}} ∂yi^∂L=∂yi^∂(−i=1∑myilogyj^)=−yi^yi

其导数始终为:
∂ L ∂ o j = ∑ i = 1 m ( − y i y i ^ ) y i ^ ( δ i j − y j ^ ) = − ( ∑ i = 1 m y i δ i j − ∑ i = 1 m y i y j ^ ) \frac{\partial L}{\partial o_j} = \sum_{i=1}^{m}(-\frac{y_i}{\hat{y_i}})\hat{y_i}(\delta_{ij}-\hat{y_j}) = -(\sum_{i=1}^{m}y_i\delta_{ij} - \sum_{i=1}^{m}y_i\hat{y_j}) ∂oj∂L=i=1∑m(−yi^yi)yi^(δij−yj^)=−(i=1∑myiδij−i=1∑myiyj^)

因为 ∑ i = 1 m y i = 1 \sum_{i=1}^{m}y_i = 1 ∑i=1myi=1, ∑ i = 1 m y i δ i j = y j ( i = j 时起作用 ) \sum_{i=1}^{m}y_i\delta_{ij} = y_j(i =j时起作用) ∑i=1myiδij=yj(i=j时起作用),所以:
∂ L ∂ o j = y j ^ − y j \frac{\partial L}{\partial o_j} = \hat{y_j} - y_j ∂oj∂L=yj^−yj

就不会存在以上问题,其梯度与误差成正比,梯度越大,更新力度越大。

1.2 为什么书上只推导了 ∂ L ∂ o j \frac{\partial L}{\partial o_j} ∂oj∂L,而省略了 ∂ L ∂ w j \frac{\partial L}{\partial w_j} ∂wj∂L?

我们的最终必然是要求损失函数针对于参数 w \mathbf{w} w 和 b b b 的导数,而以上推导是是对于softmax+交叉熵损失,softmax这一层的梯度公式就是 y j ^ − y j \hat{y_j} - y_j yj^−yj,与前面网络的结构没有关系。

而在softmax之前:
O = X W + b \mathbf{O} = \mathbf{X}\mathbf{W} + \mathbf{b} O=XW+b

这是一个线性层,所以 ∂ o ∂ w j = x j \frac{\partial o}{\partial w_j} = x_j ∂wj∂o=xj。

它之所以被单独提出来详细推导,是因为:

  1. 它是一个核心且通用的结果:只要输出层是Softmax+交叉熵损失,这个梯度公式就永远是 Y^−Y,与网络前面的结构无关。
  2. 它极其简洁,体现了数学的美感。
  3. 理解这一步是理解整个反向传播过程的基础。现代的深度学习框架(如PyTorch、TensorFlow)会自动完成链式法则的其他部分,但理解这个起点的计算原理至关重要。

2. softmax回归为什么是线性模型

尽管softmax是一个非线性函数,但softmax回归的输出仍然由输入特征的仿射变换决定。 因此,softmax回归是一个线性模型(linear model)。

总的来说,softmax回归分为两个步骤:

  1. 线性部分(仿射变换) o = w x + b \mathbf{o} = \mathbf{w}\mathbf{x} + b o=wx+b;
  2. 非线性部分(softmax函数) y j ^ = s o f t m a x ( o j ) = e o j ∑ i = 1 n e o i \hat{y_j} = softmax(o_j) = \frac{e^{o_j}}{\sum_{i=1}^{n}e^{o_i}} yj^=softmax(oj)=∑i=1neoieoj

而线性部分定义了决策的边界,最终的分类决策规则始终取决于线性部分的结果,第二部分只是将输出校准为有意义的概率值,便于解释和计算损失。

一个比喻: 想象你用三个不同的秤(对应三个类别)去称一个物体(对应一个样本)。每个秤给出一个读数(对应 o j o_j oj)。

  • softmax函数的作用就像是把这三个读数都转换成"这个物体最可能被哪个秤称出来"的概率百分比。这是一个非线性转换。
  • 但最终你判断物体属于哪个类别,依据仍然是哪个秤的原始读数最大。这个"找最大值"的决策规则是线性的。
相关推荐
集成显卡3 小时前
Rust实战七 |基于带 colored 颜色文字控制台的批量文件删除工具
开发语言·后端·rust
jeffer_liu4 小时前
Spring AI 生产级实战:工具调用
java·人工智能·后端·spring·ai编程
Cosolar5 小时前
AutoGen 精通教程:从零到企业级多 Agent 系统架构师
人工智能·后端·面试
狂炫冰美式6 小时前
你还在古法PPT吗,试试HTML呢?免费编辑导出工具给 xdm 放这了
前端·后端·github
万少7 小时前
未来组织的分水岭不是员工数量,而是人才密度
前端·后端·面试
Honmaple8 小时前
终端 AI 编程的两条路:Pi 极简哲学 vs Oh-My-Pi 全能主义深度对决
后端
我是一颗柠檬8 小时前
【Redis】发布订阅与消息队列Day8(2026年)
数据库·redis·后端·缓存
道友可好8 小时前
OpenSpec:轻到起飞的 AI 编程规范层
前端·人工智能·后端
IT_陈寒8 小时前
React状态管理这个坑,我爬了整整三天才出来
前端·人工智能·后端