链式法则如何传递参数误差 —— 深入理解神经网络中的梯度传播

目录

前言

一、神经网络训练到底在做什么

二、损失函数如何衡量误差

三、为什么需要求导

四、链式法则登场

五、误差如何传递

六、梯度为什么能指导参数更新

七、多层神经网络中的误差传播

八、为什么叫误差传播

九、自动微分是如何实现链式法则的

十、梯度消失与梯度爆炸

十一、链式法则与现代大模型

十二、代码验证链式法则

总结


前言

学习深度学习时,我们经常会听到这样一句话:

复制代码
反向传播的本质就是链式法则

很多初学者看到这里会产生疑问:

复制代码
链式法则不是高中求导知识吗?

为什么它能训练神经网络?

为什么它能更新数百万参数?

误差到底是如何传递到每一层的?

事实上,神经网络训练的核心并不复杂。

无论是:

复制代码
CNN

RNN

Transformer

GPT

DeepSeek

其参数更新本质上都依赖于:

复制代码
链式法则(Chain Rule)

本文将通过一个简单神经网络,深入理解:

  • 什么是误差传播

  • 链式法则如何工作

  • 为什么误差能够传递到每一个参数

  • 自动微分如何利用链式法则计算梯度


一、神经网络训练到底在做什么

假设我们有一个简单样本:

复制代码
输入:
x = 2

真实值:
y = 10

神经网络:

复制代码
ŷ = wx + b

初始参数:

复制代码
w = 1

b = 1

前向传播:

复制代码
ŷ = 1×2 + 1

ŷ = 3

显然:

复制代码
预测值 = 3

真实值 = 10

出现误差:

复制代码
误差 = 7

训练的目标:

复制代码
调整参数

让预测越来越接近真实值

问题来了:

复制代码
参数应该增加还是减少?

增加多少?

这就是梯度需要解决的问题。


二、损失函数如何衡量误差

通常使用均方误差:

L=(\hat y-y)^2

其中:

复制代码
L
=
Loss
损失函数

代入:

复制代码
ŷ = 3

y = 10

得到:

复制代码
L = (3-10)²

L = 49

此时:

复制代码
损失越大

说明预测越差

训练目标:

复制代码
让Loss不断减小

三、为什么需要求导

假设我们只关注参数:

复制代码
w

问题变成:

复制代码
w增加

Loss会变大还是变小?

这就是导数的意义。


如果:

复制代码
dL/dw > 0

说明:

复制代码
w增大

Loss增大

应该减小参数。


如果:

复制代码
dL/dw < 0

说明:

复制代码
w增大

Loss减小

应该增大参数。


因此:

复制代码
梯度决定参数更新方向

四、链式法则登场

神经网络并不是:

复制代码
Loss直接依赖w

而是:

复制代码
w
↓
预测值
↓
Loss

关系如下:

复制代码
w
→ ŷ
→ L

即:

复制代码
L = f(ŷ)

ŷ = g(w)

因此:

复制代码
L = f(g(w))

这是一个复合函数。


根据链式法则:

\frac{dL}{dw}=\frac{dL}{d\hat y}\cdot\frac{d\hat y}{dw}

这就是神经网络梯度传播的基础。


五、误差如何传递

继续上面的例子。

预测函数:

复制代码
ŷ = wx+b

损失函数:

复制代码
L=(ŷ-y)²

首先求:

复制代码
dL/dŷ

得到:

复制代码
2(ŷ-y)

代入:

复制代码
ŷ=3

y=10

得到:

复制代码
dL/dŷ=-14

这表示:

复制代码
Loss对预测值的敏感程度

接下来求:

复制代码
dŷ/dw

因为:

复制代码
ŷ=wx+b

所以:

复制代码
dŷ/dw=x

代入:

复制代码
x=2

得到:

复制代码
dŷ/dw=2

根据链式法则:

复制代码
dL/dw

=
-14 × 2

=
-28

最终得到:

复制代码
参数w的梯度

六、梯度为什么能指导参数更新

计算出:

复制代码
dL/dw=-28

说明:

复制代码
如果增大w

Loss会减小

因此:

复制代码
应该增大w

梯度下降公式:

w=w-\eta\frac{dL}{dw}

其中:

复制代码
η
=
学习率

假设:

复制代码
η = 0.1

更新:

复制代码
w

=

1 - 0.1×(-28)

=

3.8

参数自动向正确方向移动。


七、多层神经网络中的误差传播

真正的神经网络更复杂。

例如:

复制代码
输入层

↓

隐藏层

↓

输出层

↓

Loss

设:

复制代码
x
→ z1
→ z2
→ Loss

关系:

复制代码
z1=f(x)

z2=g(z1)

L=h(z2)

求:

复制代码
dL/dx

根据链式法则:

\frac{dL}{dx}=\frac{dL}{dz_2}\cdot\frac{dz_2}{dz_1}\cdot\frac{dz_1}{dx}

可以看到:

复制代码
误差从Loss开始

逐层向前传递

这就是:

复制代码
Back Propagation
反向传播

八、为什么叫误差传播

损失函数代表:

复制代码
预测错误程度

例如:

复制代码
Loss = 100

说明:

复制代码
模型很差

反向传播会计算:

复制代码
每一个参数

对Loss贡献多少

如果某个参数:

复制代码
对Loss影响很大

则:

复制代码
梯度大

如果某个参数:

复制代码
几乎没影响

则:

复制代码
梯度小

因此:

复制代码
误差会沿计算图不断传递

最终到达每个参数。


九、自动微分是如何实现链式法则的

假设:

python 复制代码
x = torch.tensor(
    2.0,
    requires_grad=True
)

y = x * 3

z = y ** 2

此时:

复制代码
x
↓
乘3
↓
y
↓
平方
↓
z

执行:

复制代码
z.backward()

PyTorch 自动完成:

复制代码
dz/dy

dy/dx

链式法则相乘

得到:

复制代码
dz/dx

查看:

复制代码
print(x.grad)

输出:

复制代码
36

因为:

复制代码
z=(3x)²

=9x²

求导:

复制代码
dz/dx

=18x

=36

完全正确。


十、梯度消失与梯度爆炸

链式法则虽然强大。

但也带来了问题。


假设:

复制代码
dL/dx

=

0.1×0.1×0.1×0.1×...

不断连乘。

结果:

复制代码
趋近于0

这就是:

复制代码
梯度消失

反之:

复制代码
10×10×10×10...

会越来越大。

形成:

复制代码
梯度爆炸

因此:

复制代码
深层神经网络

训练难度远高于浅层网络。


这也是后来:

复制代码
ReLU

BatchNorm

ResNet

Transformer

出现的重要原因。


十一、链式法则与现代大模型

无论是:

复制代码
BERT

GPT

Llama

Qwen

DeepSeek

训练过程本质都是:

复制代码
前向传播
↓
计算Loss
↓
链式法则反向传播
↓
更新参数

区别仅仅在于:

复制代码
参数数量更多

网络更深

计算更复杂

但核心数学原理从未改变。

仍然是:

复制代码
链式法则

十二、代码验证链式法则

下面用 PyTorch 验证。

python 复制代码
import torch

x = torch.tensor(
    2.0,
    requires_grad=True
)

y = x * 3

z = y ** 2

z.backward()

print(x.grad)

输出:

复制代码
tensor(36.)

手工计算:

python 复制代码
z=(3x)^2

=9x²

dz/dx

=18x

=36

与自动微分结果完全一致。


这说明:

复制代码
Autograd

实际上就是链式法则的自动化实现

总结

神经网络训练的本质,其实可以归结为一个简单问题:

复制代码
如何知道每个参数应该往哪个方向调整?

答案就是:

复制代码
链式法则

它通过:

复制代码
Loss
↓
输出层
↓
隐藏层
↓
输入层

将误差逐层传播。

并计算出:

复制代码
每一个参数

对Loss的贡献程度

最终得到梯度。

整个训练过程可以概括为:

复制代码
前向传播
↓
计算Loss
↓
链式法则反向传播
↓
得到梯度
↓
梯度下降更新参数
↓
Loss不断减小

而 PyTorch 的 Autograd、本质上的反向传播算法、乃至 GPT 等大型神经网络的训练,都建立在这一数学思想之上。

理解了链式法则如何传递参数误差,也就真正理解了深度学习训练的核心机制。

相关推荐
Anastasiozzzz1 小时前
从有限状态机到智能体图:传统 FSM 与 Agent Graph的演进
java·人工智能·python·ai
程序员cxuan7 小时前
为每个任务配一套 harness:Claude Code 里的动态工作流
人工智能
程序员cxuan7 小时前
Claude Fable 5 来了
人工智能·后端·程序员
云边云科技_云网融合7 小时前
云边云科技亮相 2026 WOD 制造业数智化博览会 云网融合赋能制造焕新
人工智能·科技·安全·制造
biter down7 小时前
从 0 到 1 搭建 Python 接口自动化测试框架(博客系统实战)
开发语言·python
Σίσυφος19007 小时前
激光三角 光平面标定-多高度误差分析
人工智能·计算机视觉·平面
JS菌7 小时前
手写一个 AI Agent 全栈项目:从沙箱执行到子智能体的完整实现
前端·人工智能·后端
lqqjuly7 小时前
前沿算法深度解析(二)
人工智能·算法·机器学习
Bode_20027 小时前
基于大数据分析的全生命周期质量追溯质量评估体系落地方案
大数据·人工智能