深度学习的核心求梯度就是多维函数求导数

文章目录

    • 一、先回顾:梯度消失的根源
    • 二、残差连接的核心:给梯度加一个"1"
      • [1. 残差块的正向传播](#1. 残差块的正向传播)
      • [2. 反向传播的梯度推导(关键!)](#2. 反向传播的梯度推导(关键!))
      • [3. 这个"+1"的魔力在哪里?](#3. 这个“+1”的魔力在哪里?)
    • [三、极简例子:2层残差网络 vs 普通网络](#三、极简例子:2层残差网络 vs 普通网络)
      • 设定
      • [1. 普通网络的梯度计算(无残差)](#1. 普通网络的梯度计算(无残差))
        • 正向传播
        • [反向传播(求 \(\frac{\partial L}{\partial h_0}\))](#反向传播(求 \frac{\partial L}{\partial h_0}))
      • [2. 残差网络的梯度计算(有残差)](#2. 残差网络的梯度计算(有残差))
        • 正向传播
        • [反向传播(求 \(\frac{\partial L}{\partial h_0}\))](#反向传播(求 \frac{\partial L}{\partial h_0}))
    • 四、核心结论

1. 深度学习的核心求梯度就是多维函数求导数

**完全正确!**✅

  • 深度学习模型的损失函数 (L(\theta)) 是关于参数向量 (\theta = (\theta_1,\dots,\theta_n)) 的多维函数
  • 梯度 (\nabla_\theta L(\theta)) 就是这个多维函数对每个参数 (\theta_i) 求偏导数后得到的向量:

    \\nabla_\\theta L(\\theta) = \\left( \\frac{\\partial L}{\\partial \\theta_1},; \\frac{\\partial L}{\\partial \\theta_2},; \\dots,; \\frac{\\partial L}{\\partial \\theta_n} \\right)

  • 所以"求梯度"本质上就是"对多维函数求偏导数",只是在高维场景下用梯度这个统一的向量形式来表示。

2. 强化学习中的梯度上升更新

是的,强化学习的目标是最大化期望回报 (R_\theta) ,所以我们会沿着梯度的方向更新参数:

\\theta_{\\text{new}} = \\theta_{\\text{old}} + \\alpha \\cdot \\nabla_\\theta R_\\theta

其中 (\alpha > 0) 是学习率,这就是梯度上升(如果是最小化损失函数则是梯度下降)。


🎯 简单函数举例:单参数策略的期望回报梯度更新

我们用一个极简的例子来完整走一遍流程。

设定
  • 假设策略只有一个参数 (\theta),在状态 (s) 下选择动作 (a=1) 的概率为 (\pi_\theta(a=1|s) = \theta),选择 (a=0) 的概率为 (1-\theta)((0 < \theta < 1))。
  • 环境规则:选择 (a=1) 时获得回报 (r=2),选择 (a=0) 时获得回报 (r=1)。
  • 期望回报 (R_\theta = \mathbb{E}{a \sim \pi\theta}[r] = 2\theta + 1 \cdot (1-\theta) = \theta + 1)。
步骤1:求期望回报的梯度

\\nabla_\\theta R_\\theta = \\frac{d}{d\\theta} (\\theta + 1) = 1

步骤2:梯度上升更新参数

取学习率 (\alpha = 0.1),初始参数 (\theta_0 = 0.3):

\\theta_1 = \\theta_0 + \\alpha \\cdot \\nabla_\\theta R_\\theta = 0.3 + 0.1 \\times 1 = 0.4

\\theta_2 = \\theta_1 + 0.1 \\times 1 = 0.5

\\theta_3 = 0.5 + 0.1 \\times 1 = 0.6

结果
  • 参数 (\theta) 不断增大,最终趋近于 (1)。
  • 此时策略会一直选择 (a=1),期望回报 (R_\theta = 1+1=2),达到最大。

🧩 再看一个带基线的例子(更贴近强化学习公式)

我们在期望回报中引入基线 (b=1),梯度公式变为:

\\nabla_\\theta R_\\theta = \\mathbb{E}*{a \\sim \\pi* \\theta}\\left\[(r - b) \\cdot \\nabla_\\theta \\log \\pi_\\theta(a\|s)\\right

]

  1. 计算 (\nabla_\theta \log \pi_\theta(a|s)):

    • 若 (a=1):(\log \pi_\theta = \log \theta),梯度为 (\frac{1}{\theta})
    • 若 (a=0):(\log \pi_\theta = \log (1-\theta)),梯度为 (-\frac{1}{1-\theta})
  2. 代入期望计算:

    \\nabla_\\theta R_\\theta = \\theta \\cdot (2-1) \\cdot \\frac{1}{\\theta} + (1-\\theta) \\cdot (1-1) \\cdot \\left(-\\frac{1}{1-\\theta}\\right) = 1 + 0 = 1

    结果和之前一致,说明基线不改变梯度的期望,只降低方差。


我可以帮你整理这个例子的完整梯度更新步骤清单,把每一步的计算和更新过程都列出来,让你能更清晰地看到参数变化的过程。需要吗?

没问题!我用一个高中生也能轻松理解的例子,把梯度下降讲明白。


🎯 生活化理解:梯度下降 = 下山最快的路

想象你站在一座山坡上,目标是尽快走到山谷最低点。

  • 函数值 = 你当前的海拔高度
  • 自变量 (x) = 你在山坡上的水平位置
  • 梯度(导数) = 脚下山坡的陡峭程度和方向(正表示上坡,负表示下坡)
  • 梯度下降 = 每次都朝着"最陡下坡"的方向走一小步,最终到达山谷。

🧮 具体数学例子

我们用一个简单的二次函数:

f(x) = (x - 3)\^2 + 2

  • 这个函数的图像是开口向上的抛物线,最低点在 (x=3) 处,此时 (f(x)=2)。
  • 我们的目标是通过梯度下降,从任意初始点(比如 (x_0=7))出发,找到这个最低点。
步骤1:求梯度(导数)

对函数求导得到:

f'(x) = 2(x - 3)

这个导数就是梯度,它告诉我们在任意 (x) 点,函数值的变化率和方向。

步骤2:梯度下降更新规则

每次更新都沿着梯度的反方向走一小步:

x_{\\text{new}} = x_{\\text{old}} - \\alpha \\cdot f'(x_{\\text{old}})

  • (\alpha) 是"步长"(学习率),这里取 (\alpha=0.2),确保走得稳。
步骤3:一步步计算

我们从 (x_0=7) 开始:

  1. 第一次更新

    \\begin{align\*} f'(7) \&= 2(7-3) = 8 \\ x_1 \&= 7 - 0.2 \\times 8 = 7 - 1.6 = 5.4 \\end{align\*}

    此时 (f(5.4) = (5.4-3)^2 + 2 = 7.76),比初始的 (f(7)=18) 小了很多。

  2. 第二次更新

    \\begin{align\*} f'(5.4) \&= 2(5.4-3) = 4.8 \\ x_2 \&= 5.4 - 0.2 \\times 4.8 = 5.4 - 0.96 = 4.44 \\end{align\*}

    (f(4.44) = (4.44-3)^2 + 2 = 4.07),继续下降。

  3. 第三次更新

    \\begin{align\*} f'(4.44) \&= 2(4.44-3) = 2.88 \\ x_3 \&= 4.44 - 0.2 \\times 2.88 = 4.44 - 0.576 = 3.864 \\end{align\*}

    (f(3.864) \approx 2.75),越来越接近最小值 2。

  4. 继续迭代

    再更新几次,(x) 会越来越接近 3,(f(x)) 也会越来越接近 2,最终收敛到最低点。


✅ 核心结论

  • 梯度(导数)告诉我们函数在当前点的变化率和方向
  • 梯度下降通过反向走小步,让函数值一步步减小,最终找到最小值。
  • 这就像下山时,每一步都朝着最陡的下坡方向走,最终到达山谷。

我可以帮你整理这个例子的梯度下降可视化表格,把每一步的 (x) 值、梯度和函数值都列出来,这样给高中生讲的时候会更直观。需要吗?

哈哈,这个问题问得特别好!其实在收敛到最低点后,如果不停止算法,会出现几种有趣的情况,我们还是用刚才的例子 (f(x)=(x-3)^2+2) 来解释。


1. 理想情况:在最低点"原地踏步"

当我们到达最低点 (x=3) 时,函数的梯度(导数)为:

f'(3) = 2(3-3) = 0

代入梯度下降的更新公式:

x_{\\text{new}} = 3 - \\alpha \\times 0 = 3

这说明:

  • 只要走到了最低点,梯度就为0,更新后 (x) 的值不会再变化。
  • 即使继续迭代,算法也会在 (x=3) 这个点"原地踏步",不会再移动。

2. 实际情况:在最低点附近"小幅震荡"

在真实的深度学习场景中,因为学习率 (\alpha) 不可能无限小,或者计算中存在数值误差,算法不会正好停在最低点,而是会出现:

  • 学习率偏大:当接近最低点时,梯度已经很小,但步长 (\alpha \cdot f'(x)) 仍然会让 (x) 跨过最低点,然后在另一侧再走回来,形成小幅来回震荡。
  • 学习率偏小:会非常缓慢地逼近最低点,看起来像是"粘"在离最低点很近的地方不再动了。

比如,假设我们在 (x=3.1) 处,梯度 (f'(3.1)=0.2),如果学习率 (\alpha=1):

x_{\\text{new}} = 3.1 - 1 \\times 0.2 = 2.9

下一步在 (x=2.9) 处,梯度 (f'(2.9)=-0.2):

x_{\\text{new}} = 2.9 - 1 \\times (-0.2) = 3.1

这样就会在 (3.1) 和 (2.9) 之间来回震荡,永远不会精确停在 (3)。


3. 为什么我们需要"停止条件"

正是因为存在震荡或缓慢收敛的情况,实际算法都会设置停止条件,比如:

  • 连续几次迭代后,函数值的变化小于一个很小的阈值(比如 (10^{-6}))。
  • 梯度的绝对值小于一个很小的阈值(比如 (|f'(x)| < 10^{-5}))。
  • 达到预设的最大迭代次数。

我可以帮你整理一份梯度下降停止条件的教学版说明,用简单的语言和例子讲给高中生听,让他们明白为什么算法不能一直跑下去。需要吗?

你的观察特别准!模型层数越深,反向传播的梯度推导式子确实会更长、理论上更复杂,但PyTorch这类框架已经把所有复杂的数学计算都封装好了,所以你完全感受不到这种难度。我们拆成"理论复杂度"和"工程封装"两部分说清楚:

一、先回答核心问题:层数越深,梯度计算的"理论式子"确实越复杂

反向传播的核心是链式法则的多层嵌套,层数越多,嵌套的环节就越多,梯度的数学表达式自然越长、越繁琐。

举个极简例子:2层 vs 3层网络的梯度式子对比

我们用最简单的全连接层+ReLU激活,对比2层和3层网络的梯度推导:

1. 2层网络(输入→隐藏层→输出)

假设:

  • 隐藏层输出:(h = \text{ReLU}(x \cdot W_1 + b_1))
  • 输出层预测:(\hat{y} = h \cdot W_2 + b_2)
  • 损失:(L = \frac{1}{2}(\hat{y} - y)^2)

损失对 (W_1) 的梯度(反向传播的核心):

\\frac{\\partial L}{\\partial W_1} = \\frac{\\partial L}{\\partial \\hat{y}} \\cdot \\frac{\\partial \\hat{y}}{\\partial h} \\cdot \\frac{\\partial h}{\\partial (xW_1+b_1)} \\cdot \\frac{\\partial (xW_1+b_1)}{\\partial W_1}

式子只有4个因子相乘,逻辑很清晰。

2. 3层网络(输入→隐藏层1→隐藏层2→输出)

新增隐藏层2:(h2 = \text{ReLU}(h1 \cdot W_2 + b_2)),输出层:(\hat{y} = h2 \cdot W_3 + b_3)。

损失对 (W_1) 的梯度:

\\frac{\\partial L}{\\partial W_1} = \\frac{\\partial L}{\\partial \\hat{y}} \\cdot \\frac{\\partial \\hat{y}}{\\partial h2} \\cdot \\frac{\\partial h2}{\\partial h1} \\cdot \\frac{\\partial h1}{\\partial (xW_1+b_1)} \\cdot \\frac{\\partial (xW_1+b_1)}{\\partial W_1}

式子多了1个因子,嵌套层数增加------如果是10层、100层网络,这个链式法则的式子会变成一长串相乘,手动推导不仅繁琐,还极易出错。

更深层的额外挑战:梯度消失/爆炸

层数越深,链式法则里的因子越多:

  • 如果每个因子都小于1,相乘后梯度会趋近于0(梯度消失),底层参数几乎不更新;
  • 如果每个因子都大于1,梯度会急剧增大(梯度爆炸),参数更新失控。

这是深层网络梯度计算的"本质难度",但框架也帮你解决了(比如用ReLU缓解梯度消失、用梯度裁剪缓解爆炸)。

二、为什么你用PyTorch完全感受不到这种困难?

核心原因是:PyTorch把"手动推导梯度式子"和"处理梯度消失/爆炸"的所有工作都自动化、封装化了,具体靠3点:

1. 计算图:自动记录所有计算的"依赖关系"

当你写 y = x @ W1 + b1h = F.relu(y) 时,PyTorch不会只计算数值,还会悄悄构建一个计算图 ------记录"哪个变量由哪些参数计算而来"(比如h依赖yy依赖W1b1)。

反向传播时,PyTorch不需要你写任何梯度式子,只需要调用 loss.backward(),它就会:

  • 从损失L出发,沿着计算图反向遍历;
  • 对每个节点(比如ReLU、矩阵乘法),自动调用预设的导数规则 (比如ReLU的导数是(x>0).float(),矩阵乘法的导数是输入的转置);
  • 把链式法则的多层嵌套计算,拆解成一个个简单的导数计算步骤,自动完成相乘和传递。
2. 自动微分:把"数学推导"变成"代码自动执行"

PyTorch的torch.autograd模块是"自动微分引擎",它的核心是:

  • 正向传播时,为每个操作(如+@relu)注册对应的"反向导数函数";
  • 反向传播时,按计算图的顺序,依次执行这些导数函数,自动算出所有参数的梯度。

哪怕是100层的网络,你只需要写正向传播的代码,反向传播的梯度计算完全由框架完成------你看不到一长串的梯度式子,只需要关心loss.backward()这一行代码。

3. 内置优化:自动处理深层网络的梯度问题

对于梯度消失/爆炸这类深层网络的核心难题,PyTorch也提供了现成的解决方案:

  • 激活函数:F.relu()F.gelu() 等内置激活函数,天生缓解梯度消失;
  • 优化器:torch.optim.Adam() 等优化器内置了梯度缩放、动量机制,减少梯度波动;
  • 梯度裁剪:torch.nn.utils.clip_grad_norm_() 一键解决梯度爆炸,不用手动处理。

三、一个直观的比喻:你是司机,PyTorch是自动驾驶系统

  • 手动推导梯度 = 你自己开手动挡车,需要手动控制离合、油门、换挡,层数越深(路况越复杂),操作越难;
  • 用PyTorch = 你开自动驾驶车,只需要告诉车"目标是最小化loss"(踩油门),车会自动处理离合、换挡、避让障碍(梯度计算、梯度消失/爆炸),你只需要关注最终的行驶方向(模型效果)。

总结

  1. 理论层面:模型层数越深,反向传播的梯度式子确实越长、嵌套越复杂,还会出现梯度消失/爆炸等难题;
  2. 工程层面 :PyTorch通过计算图+自动微分+内置优化 ,把所有复杂的数学推导和问题处理都封装了,你只需要写正向传播代码,调用loss.backward()就能自动算出所有参数的梯度;
  3. 你的体验:因为框架替你做了所有"脏活累活",所以哪怕训练100层的网络,也和训练2层网络的代码复杂度差不多,完全感受不到梯度计算的难度。

如果需要,我可以给你一个2层vs3层网络的PyTorch极简代码示例,让你直观看到:层数增加后,代码只需要多写一行层的定义,反向传播的代码完全不变,帮你更深刻理解这种"封装带来的便捷性"。需要吗?

你的理解完全抓对了核心!残差连接(Residual Connection)的精髓,就是通过 (f(x)+x) 这个简单的加法,在反向传播时给梯度加了一个**"保底的1"**,从而从根本上缓解深层网络的梯度消失问题。

我们用数学推导+极简例子,把这个逻辑拆得明明白白。


一、先回顾:梯度消失的根源

深层网络梯度消失,核心是链式法则的连乘效应

假设一个深层网络的每一层变换是 (h_{l} = F_{l}(h_{l-1}))((F_l) 是第 (l) 层的卷积/全连接+激活),那么损失 (L) 对第 1 层输入 (h_1) 的梯度是:

\frac{\partial L}{\partial h_1} = \frac{\partial L}{\partial h_L} \times \frac{\partial h_L}{\partial h_{L-1}} \times \frac{\partial h_{L-1}}{\partial h_{L-2}} \times \dots \times \frac{\partial h_2}{\partial h_1}

如果每一层的 (\frac{\partial h_l}{\partial h_{l-1}}) 都是小于1的数 (比如ReLU的导数在正区间是1,负区间是0;Sigmoid的导数最大只有0.25),多层连乘后,梯度会指数级衰减------传到浅层时,梯度已经趋近于0,参数根本更新不动。


二、残差连接的核心:给梯度加一个"1"

残差块的定义很简单,我们先明确公式:

1. 残差块的正向传播

普通层:(h_l = F_l(h_{l-1}))

残差层:(h_l = F_l(h_{l-1}) + h_{l-1})

其中 (F_l(h_{l-1})) 叫残差函数 (就是层里的卷积/全连接+激活操作),(h_{l-1}) 是直接短路过来的输入

2. 反向传播的梯度推导(关键!)

我们要计算 损失对残差块输入 (h_{l-1}) 的梯度 (\frac{\partial L}{\partial h_{l-1}}),根据链式法则:

\\frac{\\partial L}{\\partial h_{l-1}} = \\frac{\\partial L}{\\partial h_l} \\times \\frac{\\partial h_l}{\\partial h_{l-1}}

把残差层的 (h_l = F_l + h_{l-1}) 代入,求偏导:

\\frac{\\partial h_l}{\\partial h_{l-1}} = \\frac{\\partial F_l(h_{l-1})}{\\partial h_{l-1}} + \\frac{\\partial h_{l-1}}{\\partial h_{l-1}} = \\frac{\\partial F_l}{\\partial h_{l-1}} + 1

因此,梯度公式变成:

\\boldsymbol{\\frac{\\partial L}{\\partial h_{l-1}} = \\frac{\\partial L}{\\partial h_l} \\times \\left( \\frac{\\partial F_l}{\\partial h_{l-1}} + 1 \\right)}

3. 这个"+1"的魔力在哪里?

  • 没有残差连接时,(\frac{\partial h_l}{\partial h_{l-1}} = \frac{\partial F_l}{\partial h_{l-1}}),如果这个值小于1,多层连乘就会梯度消失;
  • 有残差连接时,梯度里多了一个保底的"1"------哪怕 (\frac{\partial F_l}{\partial h_{l-1}}) 很小(比如0.1),(\frac{\partial h_l}{\partial h_{l-1}}) 也至少是 (1.1),不会趋近于0;
  • 更关键的是:梯度可以通过 "短路路径"直接传递------当 (\frac{\partial F_l}{\partial h_{l-1}}) 趋近于0时,梯度近似等于 (\frac{\partial L}{\partial h_l} \times 1),相当于梯度直接从深层传到浅层,完全不会消失!

三、极简例子:2层残差网络 vs 普通网络

我们用全连接层+ReLU做例子,手动算梯度,看差距有多明显。

设定

  • 输入:(x = h_0 = 2)
  • 普通层:(h_1 = \text{ReLU}(W_1 h_0 + b_1)),(h_2 = \text{ReLU}(W_2 h_1 + b_2))
  • 残差层:(h_1 = \text{ReLU}(W_1 h_0 + b_1) + h_0),(h_2 = \text{ReLU}(W_2 h_1 + b_1) + h_1)
  • 简化参数:(W_1=W_2=0.1),(b_1=b_2=0)(让 (\frac{\partial F}{\partial h}) 很小,模拟梯度消失场景)
  • 损失:(L = (h_2 - y)^2),假设 (y=5),我们只看梯度传递的大小。

1. 普通网络的梯度计算(无残差)

正向传播
  • (h_1 = \text{ReLU}(0.1 \times 2) = \text{ReLU}(0.2) = 0.2)
  • (h_2 = \text{ReLU}(0.1 \times 0.2) = \text{ReLU}(0.02) = 0.02)
反向传播(求 (\frac{\partial L}{\partial h_0}))
  • 第一步:(\frac{\partial L}{\partial h_2} = 2(h_2 - y) = 2(0.02-5) = -9.96)
  • 第二步:(\frac{\partial h_2}{\partial h_1} = W_2 \times \text{ReLU导数}(0.02) = 0.1 \times 1 = 0.1)(ReLU正区间导数为1)
  • 第三步:(\frac{\partial h_1}{\partial h_0} = W_1 \times \text{ReLU导数}(0.2) = 0.1 \times 1 = 0.1)
  • 最终梯度:

    \\frac{\\partial L}{\\partial h_0} = -9.96 \\times 0.1 \\times 0.1 = -0.0996

    梯度从-9.96衰减到了0.0996,只剩原来的1%! 再叠几层,梯度直接趋近于0。

2. 残差网络的梯度计算(有残差)

正向传播
  • (h_1 = \text{ReLU}(0.1 \times 2) + 2 = 0.2 + 2 = 2.2)
  • (h_2 = \text{ReLU}(0.1 \times 2.2) + 2.2 = 0.22 + 2.2 = 2.42)
反向传播(求 (\frac{\partial L}{\partial h_0}))
  • 第一步:(\frac{\partial L}{\partial h_2} = 2(2.42-5) = -5.16)
  • 第二步:残差层的梯度因子 (\frac{\partial h_2}{\partial h_1} = \frac{\partial F_2}{\partial h_1} + 1 = 0.1 + 1 = 1.1)
  • 第三步:残差层的梯度因子 (\frac{\partial h_1}{\partial h_0} = \frac{\partial F_1}{\partial h_0} + 1 = 0.1 + 1 = 1.1)
  • 最终梯度:

    \\frac{\\partial L}{\\partial h_0} = -5.16 \\times 1.1 \\times 1.1 = -5.16 \\times 1.21 = -6.2436

    梯度不仅没衰减,反而因为1.1的连乘,比深层梯度更大! 这就是残差连接的核心作用。

四、核心结论

  1. 残差连接的理论依据 :不是随便加的 (f(x)+x),而是通过导数的加法法则,给梯度传递加了一个"保底的1";
  2. 缓解梯度消失的本质:梯度可以通过两条路径传递------一条是经过残差函数 (F_l) 的"正常路径",一条是直接短路的"1路径";哪怕正常路径的梯度衰减,短路路径也能保证梯度不消失;
  3. 何凯明的厉害之处:用一个极其简单的加法操作,解决了深层网络训练的世界级难题------这就是"大道至简"的体现。

我可以帮你写一个PyTorch的极简残差块代码,对比普通层和残差层在10层网络中的梯度值,让你直观看到残差连接如何保住梯度。需要吗?

相关推荐
冬奇Lab38 分钟前
一天一个开源项目(第36篇):EverMemOS - 跨 LLM 与平台的长时记忆 OS,让 Agent 会记忆更会推理
人工智能·开源·资讯
冬奇Lab39 分钟前
OpenClaw 源码深度解析(一):Gateway——为什么需要一个"中枢"
人工智能·开源·源码阅读
AngelPP4 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年4 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼5 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS5 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区6 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈6 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang7 小时前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx
shengjk18 小时前
NanoClaw 深度剖析:一个"AI 原生"架构的个人助手是如何运转的?
人工智能