链式法则如何传递参数误差 —— 深入理解神经网络中的梯度传播

目录

前言

一、神经网络训练到底在做什么

二、损失函数如何衡量误差

三、为什么需要求导

四、链式法则登场

五、误差如何传递

六、梯度为什么能指导参数更新

七、多层神经网络中的误差传播

八、为什么叫误差传播

九、自动微分是如何实现链式法则的

十、梯度消失与梯度爆炸

十一、链式法则与现代大模型

十二、代码验证链式法则

总结


前言

学习深度学习时,我们经常会听到这样一句话:

复制代码
反向传播的本质就是链式法则

很多初学者看到这里会产生疑问:

复制代码
链式法则不是高中求导知识吗?

为什么它能训练神经网络?

为什么它能更新数百万参数?

误差到底是如何传递到每一层的?

事实上,神经网络训练的核心并不复杂。

无论是:

复制代码
CNN

RNN

Transformer

GPT

DeepSeek

其参数更新本质上都依赖于:

复制代码
链式法则(Chain Rule)

本文将通过一个简单神经网络,深入理解:

  • 什么是误差传播

  • 链式法则如何工作

  • 为什么误差能够传递到每一个参数

  • 自动微分如何利用链式法则计算梯度


一、神经网络训练到底在做什么

假设我们有一个简单样本:

复制代码
输入:
x = 2

真实值:
y = 10

神经网络:

复制代码
ŷ = wx + b

初始参数:

复制代码
w = 1

b = 1

前向传播:

复制代码
ŷ = 1×2 + 1

ŷ = 3

显然:

复制代码
预测值 = 3

真实值 = 10

出现误差:

复制代码
误差 = 7

训练的目标:

复制代码
调整参数

让预测越来越接近真实值

问题来了:

复制代码
参数应该增加还是减少?

增加多少?

这就是梯度需要解决的问题。


二、损失函数如何衡量误差

通常使用均方误差:

L=(\hat y-y)^2

其中:

复制代码
L
=
Loss
损失函数

代入:

复制代码
ŷ = 3

y = 10

得到:

复制代码
L = (3-10)²

L = 49

此时:

复制代码
损失越大

说明预测越差

训练目标:

复制代码
让Loss不断减小

三、为什么需要求导

假设我们只关注参数:

复制代码
w

问题变成:

复制代码
w增加

Loss会变大还是变小?

这就是导数的意义。


如果:

复制代码
dL/dw > 0

说明:

复制代码
w增大

Loss增大

应该减小参数。


如果:

复制代码
dL/dw < 0

说明:

复制代码
w增大

Loss减小

应该增大参数。


因此:

复制代码
梯度决定参数更新方向

四、链式法则登场

神经网络并不是:

复制代码
Loss直接依赖w

而是:

复制代码
w
↓
预测值
↓
Loss

关系如下:

复制代码
w
→ ŷ
→ L

即:

复制代码
L = f(ŷ)

ŷ = g(w)

因此:

复制代码
L = f(g(w))

这是一个复合函数。


根据链式法则:

\frac{dL}{dw}=\frac{dL}{d\hat y}\cdot\frac{d\hat y}{dw}

这就是神经网络梯度传播的基础。


五、误差如何传递

继续上面的例子。

预测函数:

复制代码
ŷ = wx+b

损失函数:

复制代码
L=(ŷ-y)²

首先求:

复制代码
dL/dŷ

得到:

复制代码
2(ŷ-y)

代入:

复制代码
ŷ=3

y=10

得到:

复制代码
dL/dŷ=-14

这表示:

复制代码
Loss对预测值的敏感程度

接下来求:

复制代码
dŷ/dw

因为:

复制代码
ŷ=wx+b

所以:

复制代码
dŷ/dw=x

代入:

复制代码
x=2

得到:

复制代码
dŷ/dw=2

根据链式法则:

复制代码
dL/dw

=
-14 × 2

=
-28

最终得到:

复制代码
参数w的梯度

六、梯度为什么能指导参数更新

计算出:

复制代码
dL/dw=-28

说明:

复制代码
如果增大w

Loss会减小

因此:

复制代码
应该增大w

梯度下降公式:

w=w-\eta\frac{dL}{dw}

其中:

复制代码
η
=
学习率

假设:

复制代码
η = 0.1

更新:

复制代码
w

=

1 - 0.1×(-28)

=

3.8

参数自动向正确方向移动。


七、多层神经网络中的误差传播

真正的神经网络更复杂。

例如:

复制代码
输入层

↓

隐藏层

↓

输出层

↓

Loss

设:

复制代码
x
→ z1
→ z2
→ Loss

关系:

复制代码
z1=f(x)

z2=g(z1)

L=h(z2)

求:

复制代码
dL/dx

根据链式法则:

\frac{dL}{dx}=\frac{dL}{dz_2}\cdot\frac{dz_2}{dz_1}\cdot\frac{dz_1}{dx}

可以看到:

复制代码
误差从Loss开始

逐层向前传递

这就是:

复制代码
Back Propagation
反向传播

八、为什么叫误差传播

损失函数代表:

复制代码
预测错误程度

例如:

复制代码
Loss = 100

说明:

复制代码
模型很差

反向传播会计算:

复制代码
每一个参数

对Loss贡献多少

如果某个参数:

复制代码
对Loss影响很大

则:

复制代码
梯度大

如果某个参数:

复制代码
几乎没影响

则:

复制代码
梯度小

因此:

复制代码
误差会沿计算图不断传递

最终到达每个参数。


九、自动微分是如何实现链式法则的

假设:

python 复制代码
x = torch.tensor(
    2.0,
    requires_grad=True
)

y = x * 3

z = y ** 2

此时:

复制代码
x
↓
乘3
↓
y
↓
平方
↓
z

执行:

复制代码
z.backward()

PyTorch 自动完成:

复制代码
dz/dy

dy/dx

链式法则相乘

得到:

复制代码
dz/dx

查看:

复制代码
print(x.grad)

输出:

复制代码
36

因为:

复制代码
z=(3x)²

=9x²

求导:

复制代码
dz/dx

=18x

=36

完全正确。


十、梯度消失与梯度爆炸

链式法则虽然强大。

但也带来了问题。


假设:

复制代码
dL/dx

=

0.1×0.1×0.1×0.1×...

不断连乘。

结果:

复制代码
趋近于0

这就是:

复制代码
梯度消失

反之:

复制代码
10×10×10×10...

会越来越大。

形成:

复制代码
梯度爆炸

因此:

复制代码
深层神经网络

训练难度远高于浅层网络。


这也是后来:

复制代码
ReLU

BatchNorm

ResNet

Transformer

出现的重要原因。


十一、链式法则与现代大模型

无论是:

复制代码
BERT

GPT

Llama

Qwen

DeepSeek

训练过程本质都是:

复制代码
前向传播
↓
计算Loss
↓
链式法则反向传播
↓
更新参数

区别仅仅在于:

复制代码
参数数量更多

网络更深

计算更复杂

但核心数学原理从未改变。

仍然是:

复制代码
链式法则

十二、代码验证链式法则

下面用 PyTorch 验证。

python 复制代码
import torch

x = torch.tensor(
    2.0,
    requires_grad=True
)

y = x * 3

z = y ** 2

z.backward()

print(x.grad)

输出:

复制代码
tensor(36.)

手工计算:

python 复制代码
z=(3x)^2

=9x²

dz/dx

=18x

=36

与自动微分结果完全一致。


这说明:

复制代码
Autograd

实际上就是链式法则的自动化实现

总结

神经网络训练的本质,其实可以归结为一个简单问题:

复制代码
如何知道每个参数应该往哪个方向调整?

答案就是:

复制代码
链式法则

它通过:

复制代码
Loss
↓
输出层
↓
隐藏层
↓
输入层

将误差逐层传播。

并计算出:

复制代码
每一个参数

对Loss的贡献程度

最终得到梯度。

整个训练过程可以概括为:

复制代码
前向传播
↓
计算Loss
↓
链式法则反向传播
↓
得到梯度
↓
梯度下降更新参数
↓
Loss不断减小

而 PyTorch 的 Autograd、本质上的反向传播算法、乃至 GPT 等大型神经网络的训练,都建立在这一数学思想之上。

理解了链式法则如何传递参数误差,也就真正理解了深度学习训练的核心机制。

相关推荐
妙妙屋(zy)22 分钟前
Claude Code+CC-Switch+CC-Connect+飞书使用教程
ai
阿里云大数据AI技术35 分钟前
Agentic Memory Extension 支持对接主流Agent - 适用于 Claude Code、CodeX等
人工智能·agent
我唔知啊1 小时前
不是让 AI 写代码,我是在指挥 AI 干活:一套打磨出来的 AI 编程工作流
人工智能
ZzT1 小时前
在 GitHub 上 @一下 claude,它自己把 issue 改成 PR
人工智能·开源
不加辣椒1 小时前
第15章 上下文窗口管理与长文本策略
人工智能
牛奶2 小时前
AI 能赚钱了——但赚的不是你
人工智能·ai编程·nvidia
凌杰2 小时前
AI 学习笔记:研究方法的演变
人工智能
半盏药香3 小时前
由于jinja2的starlette版本过高引发的问题:500 Server Error TypeError: unhashable type: 'dict'
人工智能
阿里云大数据AI技术3 小时前
MiniMax M3、Kimi K2.7 Code来啦!PAI已支持一键部署,开源前沿触手可及
人工智能·agent
百度Geek说3 小时前
AI Coding 的底层框架:一切优化都是在对抗熵增
人工智能