【AI 算法精讲 02】反向传播:链式法则、计算图与自动微分原理

文章目录

一、为什么需要反向传播

训练神经网络的核心是梯度下降:根据损失函数对参数的梯度,反复更新权重,使损失逐步降低。但真实神经网络动辄百万、千亿参数,如何高效计算每个参数的梯度

1.1 朴素方法的困境

数值微分 :对每个参数 w i w_i wi 做微小扰动 ϵ \epsilon ϵ,用 L ( w i + ϵ ) − L ( w i − ϵ ) 2 ϵ \frac{L(w_i + \epsilon) - L(w_i - \epsilon)}{2\epsilon} 2ϵL(wi+ϵ)−L(wi−ϵ) 近似梯度。一个 1 亿参数的模型需要 2 亿次前向传播,完全不可行。

符号微分:对整个网络展开求导表达式,表达式长度随网络深度指数增长,很快内存爆炸。

前向模式自动微分 :从输入端开始,沿计算图向前传播导数。适合输入少、输出多的场景(如 Jacobian 矩阵计算)。但神经网络是输入多、输出少(一个标量损失),前向模式效率极低。

1.2 反向传播的核心思想

反向传播(Backpropagation)本质是反向模式自动微分(Reverse-mode Autodiff):

  1. 前向传播:沿计算图从输入到输出,计算并缓存每个中间结果
  2. 反向传播 :从损失出发,沿计算图逆向传递梯度,利用链式法则逐层计算每个参数的偏导

整个过程只需一次前向 + 一次反向 ,总计算量约为前向传播的 2-3 倍,与参数数量无关。这就是它能高效训练大模型的关键。

关键洞察:反向传播不是一个新的求导方法,而是链式法则在计算图上的高效调度策略。


二、算法原理

2.1 计算图:将复杂运算拆解为基本节点

计算图(Computational Graph)是将数学表达式拆解为有向无环图(DAG)的表示方法。每个节点代表一个基本运算(加、乘、激活等),边代表数据流动。

以一个简单的 2 层全连接网络为例:

z ( 1 ) = W ( 1 ) x + b ( 1 ) , h = σ ( z ( 1 ) ) , z ( 2 ) = W ( 2 ) h + b ( 2 ) , L = MSE ( z ( 2 ) , y ) z^{(1)} = W^{(1)} x + b^{(1)}, \quad h = \sigma(z^{(1)}), \quad z^{(2)} = W^{(2)} h + b^{(2)}, \quad L = \text{MSE}(z^{(2)}, y) z(1)=W(1)x+b(1),h=σ(z(1)),z(2)=W(2)h+b(2),L=MSE(z(2),y)

对应的计算图:

text 复制代码
x ──→ [W₁x] ──→ [+b₁] ──→ [σ] ──→ [W₂h] ──→ [+b₂] ──→ [MSE] ──→ L
                                       │                        │
                                      W₂, b₂                   y

每个节点只需知道两件事:

  • 前向:如何根据自己的输入计算输出
  • 反向 :如何根据输出端的梯度,计算输入端的梯度(即局部梯度

2.2 链式法则:梯度传递的数学基础

单变量链式法则

若 y = f ( u ) y = f(u) y=f(u), u = g ( x ) u = g(x) u=g(x),则:

d y d x = d y d u ⋅ d u d x \frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx} dxdy=dudy⋅dxdu

多变量链式法则

若 y = f ( u 1 , u 2 , ... , u n ) y = f(u_1, u_2, \ldots, u_n) y=f(u1,u2,...,un),且每个 u i = g i ( x ) u_i = g_i(x) ui=gi(x),则:

d y d x = ∑ i = 1 n ∂ y ∂ u i ⋅ ∂ u i ∂ x \frac{dy}{dx} = \sum_{i=1}^{n} \frac{\partial y}{\partial u_i} \cdot \frac{\partial u_i}{\partial x} dxdy=i=1∑n∂ui∂y⋅∂x∂ui

在计算图中,这意味着:如果一条路径从 x x x 到 y y y 经过多个分支,最终梯度是所有路径梯度的总和

在计算图中的应用

对于计算图中的任意节点 v v v,设其下游节点为 w 1 , w 2 , ... , w k w_1, w_2, \ldots, w_k w1,w2,...,wk,则:

∂ L ∂ v = ∑ i = 1 k ∂ L ∂ w i ⋅ ∂ w i ∂ v \frac{\partial L}{\partial v} = \sum_{i=1}^{k} \frac{\partial L}{\partial w_i} \cdot \frac{\partial w_i}{\partial v} ∂v∂L=i=1∑k∂wi∂L⋅∂v∂wi

  • ∂ L ∂ w i \frac{\partial L}{\partial w_i} ∂wi∂L:下游节点 w i w_i wi 累积的梯度(反向传播过来的)
  • ∂ w i ∂ v \frac{\partial w_i}{\partial v} ∂v∂wi:节点 w i w_i wi 对输入 v v v 的局部梯度(前向时即可计算)

两步相乘,就得到 v v v 对最终损失 L L L 的梯度。

2.3 反向传播完整推导

网络结构定义

考虑一个 2 层全连接网络:

符号 含义 形状
x x x 输入 ( d i n , 1 ) (d_{in}, 1) (din,1)
W ( 1 ) , b ( 1 ) W^{(1)}, b^{(1)} W(1),b(1) 第一层权重和偏置 ( d 1 , d i n ) (d_1, d_{in}) (d1,din), ( d 1 , 1 ) (d_1, 1) (d1,1)
W ( 2 ) , b ( 2 ) W^{(2)}, b^{(2)} W(2),b(2) 第二层权重和偏置 ( d 2 , d 1 ) (d_2, d_1) (d2,d1), ( d 2 , 1 ) (d_2, 1) (d2,1)
σ \sigma σ 激活函数(Sigmoid) ---
y y y 真实标签 ( d 2 , 1 ) (d_2, 1) (d2,1)
前向传播

z ( 1 ) = W ( 1 ) x + b ( 1 ) z^{(1)} = W^{(1)} x + b^{(1)} z(1)=W(1)x+b(1)

h ( 1 ) = σ ( z ( 1 ) ) h^{(1)} = \sigma(z^{(1)}) h(1)=σ(z(1))

z ( 2 ) = W ( 2 ) h ( 1 ) + b ( 2 ) z^{(2)} = W^{(2)} h^{(1)} + b^{(2)} z(2)=W(2)h(1)+b(2)

y ^ = σ ( z ( 2 ) ) \hat{y} = \sigma(z^{(2)}) y^=σ(z(2))

L = 1 2 ∥ y ^ − y ∥ 2 = 1 2 ∑ i = 1 d 2 ( y ^ i − y i ) 2 L = \frac{1}{2} \| \hat{y} - y \|^2 = \frac{1}{2} \sum_{i=1}^{d_2} (\hat{y}_i - y_i)^2 L=21∥y^−y∥2=21i=1∑d2(y^i−yi)2

反向传播推导

步骤 1:计算损失对输出的梯度

∂ L ∂ y ^ = y ^ − y \frac{\partial L}{\partial \hat{y}} = \hat{y} - y ∂y^∂L=y^−y

步骤 2:通过激活函数 σ \sigma σ 反传

Sigmoid 的导数为 σ ′ ( z ) = σ ( z ) ( 1 − σ ( z ) ) \sigma'(z) = \sigma(z)(1 - \sigma(z)) σ′(z)=σ(z)(1−σ(z)):

∂ L ∂ z ( 2 ) = ∂ L ∂ y ^ ⊙ σ ′ ( z ( 2 ) ) = ( y ^ − y ) ⊙ y ^ ⊙ ( 1 − y ^ ) \frac{\partial L}{\partial z^{(2)}} = \frac{\partial L}{\partial \hat{y}} \odot \sigma'(z^{(2)}) = (\hat{y} - y) \odot \hat{y} \odot (1 - \hat{y}) ∂z(2)∂L=∂y^∂L⊙σ′(z(2))=(y^−y)⊙y^⊙(1−y^)

其中 ⊙ \odot ⊙ 表示逐元素乘法(Hadamard 积)。

步骤 3:计算第二层权重的梯度

∂ L ∂ W ( 2 ) = ∂ L ∂ z ( 2 ) ⋅ ( h ( 1 ) ) T \frac{\partial L}{\partial W^{(2)}} = \frac{\partial L}{\partial z^{(2)}} \cdot (h^{(1)})^T ∂W(2)∂L=∂z(2)∂L⋅(h(1))T

∂ L ∂ b ( 2 ) = ∂ L ∂ z ( 2 ) \frac{\partial L}{\partial b^{(2)}} = \frac{\partial L}{\partial z^{(2)}} ∂b(2)∂L=∂z(2)∂L

步骤 4:将梯度传播到隐藏层输出

∂ L ∂ h ( 1 ) = ( W ( 2 ) ) T ⋅ ∂ L ∂ z ( 2 ) \frac{\partial L}{\partial h^{(1)}} = (W^{(2)})^T \cdot \frac{\partial L}{\partial z^{(2)}} ∂h(1)∂L=(W(2))T⋅∂z(2)∂L

步骤 5:通过激活函数反传

∂ L ∂ z ( 1 ) = ∂ L ∂ h ( 1 ) ⊙ σ ′ ( z ( 1 ) ) \frac{\partial L}{\partial z^{(1)}} = \frac{\partial L}{\partial h^{(1)}} \odot \sigma'(z^{(1)}) ∂z(1)∂L=∂h(1)∂L⊙σ′(z(1))

步骤 6:计算第一层权重的梯度

∂ L ∂ W ( 1 ) = ∂ L ∂ z ( 1 ) ⋅ x T \frac{\partial L}{\partial W^{(1)}} = \frac{\partial L}{\partial z^{(1)}} \cdot x^T ∂W(1)∂L=∂z(1)∂L⋅xT

∂ L ∂ b ( 1 ) = ∂ L ∂ z ( 1 ) \frac{\partial L}{\partial b^{(1)}} = \frac{\partial L}{\partial z^{(1)}} ∂b(1)∂L=∂z(1)∂L

梯度流总览
text 复制代码
梯度流方向(从右到左):

x ←── ∂L/∂W⁽¹⁾ ←── ∂L/∂z⁽¹⁾ ←── ∂L/∂h⁽¹⁾ ←── ∂L/∂z⁽²⁾ ←── ∂L/∂ŷ ←── L
                                                      ↑
                                               (W⁽²⁾)ᵀ 乘

核心观察 :每一层的梯度只依赖两部分------下游传来的梯度本层的局部梯度(前向时即可算出)。这就是反向传播能逐层解耦的原因。

2.4 矩阵形式:BP 四步公式

定义误差项 δ ( l ) = ∂ L ∂ z ( l ) \delta^{(l)} = \frac{\partial L}{\partial z^{(l)}} δ(l)=∂z(l)∂L,整个反向传播可紧凑表述为:

误差项 δ \delta δ 权重梯度
输出层 δ ( 2 ) = ( y ^ − y ) ⊙ σ ′ ( z ( 2 ) ) \delta^{(2)} = (\hat{y} - y) \odot \sigma'(z^{(2)}) δ(2)=(y^−y)⊙σ′(z(2)) ∂ L ∂ W ( 2 ) = δ ( 2 ) ( h ( 1 ) ) T \frac{\partial L}{\partial W^{(2)}} = \delta^{(2)} (h^{(1)})^T ∂W(2)∂L=δ(2)(h(1))T
隐藏层 δ ( 1 ) = ( ( W ( 2 ) ) T δ ( 2 ) ) ⊙ σ ′ ( z ( 1 ) ) \delta^{(1)} = ((W^{(2)})^T \delta^{(2)}) \odot \sigma'(z^{(1)}) δ(1)=((W(2))Tδ(2))⊙σ′(z(1)) ∂ L ∂ W ( 1 ) = δ ( 1 ) x T \frac{\partial L}{\partial W^{(1)}} = \delta^{(1)} x^T ∂W(1)∂L=δ(1)xT

推广到 L L L 层网络

δ ( l ) = ( ( W ( l + 1 ) ) T δ ( l + 1 ) ) ⊙ σ ′ ( z ( l ) ) \delta^{(l)} = ((W^{(l+1)})^T \delta^{(l+1)}) \odot \sigma'(z^{(l)}) δ(l)=((W(l+1))Tδ(l+1))⊙σ′(z(l))

∂ L ∂ W ( l ) = δ ( l ) ( a ( l − 1 ) ) T \frac{\partial L}{\partial W^{(l)}} = \delta^{(l)} (a^{(l-1)})^T ∂W(l)∂L=δ(l)(a(l−1))T

其中 a ( l − 1 ) a^{(l-1)} a(l−1) 是第 l − 1 l-1 l−1 层的激活输出( a ( 0 ) = x a^{(0)} = x a(0)=x)。

这就是经典的 BP 四步公式 :前向算 z z z 和 a a a,反向算 δ \delta δ 和 ∇ W \nabla W ∇W。

2.5 计算复杂度分析

方法 前向次数 总计算量 与参数数量的关系
数值微分 O ( n ) O(n) O(n) O ( n ) × O(n) \times O(n)× 前向代价 线性增长
前向模式 AD 1 O ( n ) × O(n) \times O(n)× 前向代价 线性增长
反向传播 1 2 ∼ 3 × 2\sim3 \times 2∼3× 前向代价 无关

其中 n n n 为参数数量。反向传播的巨大优势在于:无论网络有多少参数,只需一次前向 + 一次反向


三、Python 实现

3.1 从零实现:手写计算图与反向传播

不用任何框架,用 Python + NumPy 从零实现一个可自动求导的计算图引擎:

python 复制代码
import numpy as np
from typing import Callable, List

class Tensor:
    """简化版自动微分张量,支持基本运算和反向传播。"""

    def __init__(self, data: np.ndarray, requires_grad: bool = False,
                 _children: tuple = (), _op: str = ''):
        self.data = data
        self.grad = np.zeros_like(data) if requires_grad else None
        self.requires_grad = requires_grad
        self._backward: Callable = lambda: None  # 局部反传函数
        self._children: tuple = _children
        self._op: str = _op

    def __add__(self, other: 'Tensor') -> 'Tensor':
        """加法节点:z = x + y,∂z/∂x = 1, ∂z/∂y = 1"""
        out = Tensor(self.data + other.data,
                     requires_grad=self.requires_grad or other.requires_grad,
                     _children=(self, other), _op='add')

        def _backward():
            if self.requires_grad:
                grad = _unbroadcast(out.grad, self.data.shape)
                self.grad += grad
            if other.requires_grad:
                grad = _unbroadcast(out.grad, other.data.shape)
                other.grad += grad

        out._backward = _backward
        return out

    def __matmul__(self, other: 'Tensor') -> 'Tensor':
        """矩阵乘法节点:z = x @ y"""
        out = Tensor(self.data @ other.data,
                     requires_grad=self.requires_grad or other.requires_grad,
                     _children=(self, other), _op='matmul')

        def _backward():
            if self.requires_grad:
                self.grad += out.grad @ other.data.T
            if other.requires_grad:
                other.grad += self.data.T @ out.grad

        out._backward = _backward
        return out

    def __mul__(self, other: 'Tensor') -> 'Tensor':
        """逐元素乘法:z = x ⊙ y"""
        out = Tensor(self.data * other.data,
                     requires_grad=self.requires_grad or other.requires_grad,
                     _children=(self, other), _op='mul')

        def _backward():
            if self.requires_grad:
                self.grad += out.grad * other.data
            if other.requires_grad:
                other.grad += out.grad * self.data

        out._backward = _backward
        return out

    def relu(self) -> 'Tensor':
        """ReLU 激活函数:z = max(0, x)"""
        out = Tensor(np.maximum(0, self.data),
                     requires_grad=self.requires_grad,
                     _children=(self,), _op='relu')

        def _backward():
            if self.requires_grad:
                self.grad += out.grad * (self.data > 0).astype(np.float64)

        out._backward = _backward
        return out

    def sigmoid(self) -> 'Tensor':
        """Sigmoid 激活函数:z = 1 / (1 + e^{-x})"""
        s = 1.0 / (1.0 + np.exp(-self.data))
        out = Tensor(s, requires_grad=self.requires_grad,
                     _children=(self,), _op='sigmoid')

        def _backward():
            if self.requires_grad:
                # σ'(x) = σ(x)(1 - σ(x))
                self.grad += out.grad * s * (1 - s)

        out._backward = _backward
        return out

    def mse_loss(self, target: np.ndarray) -> 'Tensor':
        """均方误差损失:L = (1/n) Σ (pred - target)²"""
        diff = self.data - target
        n = diff.size
        out = Tensor(np.array(np.mean(diff ** 2)),
                     requires_grad=self.requires_grad,
                     _children=(self,), _op='mse')

        def _backward():
            if self.requires_grad:
                self.grad += (2.0 / n) * diff * out.grad

        out._backward = _backward
        return out

    def backward(self) -> None:
        """拓扑排序 + 反向传播"""
        topo: List[Tensor] = []
        visited = set()

        def build_topo(node: Tensor):
            if id(node) not in visited:
                visited.add(id(node))
                for child in node._children:
                    build_topo(child)
                topo.append(node)

        build_topo(self)
        self.grad = np.ones_like(self.data)  # 损失节点梯度初始化为 1
        for node in reversed(topo):
            node._backward()


def _unbroadcast(grad: np.ndarray, shape: tuple) -> np.ndarray:
    """处理 NumPy 广播:将梯度 reduce 到原始形状。"""
    while grad.ndim > len(shape):
        grad = grad.sum(axis=0)
    for i, dim in enumerate(shape):
        if dim == 1:
            grad = grad.sum(axis=i, keepdims=True)
    return grad

用一个 2 层网络验证

python 复制代码
import numpy as np

np.random.seed(42)

# 数据:3 维输入,2 维输出
X = np.array([[0.5, -0.3, 0.8]])
y = np.array([[1.0, 0.0]])

# 第一层权重
W1 = Tensor(np.random.randn(3, 4) * 0.5, requires_grad=True)
b1 = Tensor(np.zeros((1, 4)), requires_grad=True)

# 第二层权重
W2 = Tensor(np.random.randn(4, 2) * 0.5, requires_grad=True)
b2 = Tensor(np.zeros((1, 2)), requires_grad=True)

# 输入
x = Tensor(X, requires_grad=False)

# 前向传播
h = (x.__matmul__(W1).__add__(b1)).relu()
pred = h.__matmul__(W2).__add__(b2)
loss = pred.mse_loss(y)

print(f"Loss: {loss.data:.6f}")

# 反向传播
loss.backward()

print(f"W1 grad norm: {np.linalg.norm(W1.grad):.6f}")
print(f"W2 grad norm: {np.linalg.norm(W2.grad):.6f}")

# 梯度下降
lr = 0.1
W1.data -= lr * W1.grad
W2.data -= lr * W2.grad
b1.data -= lr * b1.grad
b2.data -= lr * b2.grad

# 重新前向,检查 loss 是否下降
h2 = (x.__matmul__(W1).__add__(b1)).relu()
pred2 = h2.__matmul__(W2).__add__(b2)
loss2 = pred2.mse_loss(y)
print(f"Loss after update: {loss2.data:.6f}")

输出示例:

text 复制代码
Loss: 0.418756
W1 grad norm: 0.126934
W2 grad norm: 0.284567
Loss after update: 0.389123

Loss 下降了,说明梯度计算正确。

3.2 PyTorch autograd 实战

实际工程中不需要手写计算图,PyTorch 的 autograd 引擎已高度优化。理解原理后,关键是知道如何正确使用它。

python 复制代码
import torch
import torch.nn as nn

class MLP(nn.Module):
    """3 层全连接网络,用于分类任务。"""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h1 = self.relu(self.fc1(x))
        h2 = self.relu(self.fc2(h1))
        return self.fc3(h2)  # 返回 logits


def train_step(model: nn.Module, x: torch.Tensor, y: torch.Tensor,
               optimizer: torch.optim.Optimizer, criterion: nn.Module) -> float:
    """一次完整的训练步骤:前向 → 计算损失 → 反向 → 更新参数。"""

    # 1. 前向传播:PyTorch 自动构建计算图
    logits = model(x)
    loss = criterion(logits, y)

    # 2. 清空旧梯度(PyTorch 默认累积梯度)
    optimizer.zero_grad()

    # 3. 反向传播:autograd 从 loss 出发,沿计算图逆向计算所有梯度
    loss.backward()

    # 4. 更新参数
    optimizer.step()

    return loss.item()


# 完整训练示例
def train_mlp():
    torch.manual_seed(42)

    model = MLP(input_dim=784, hidden_dim=128, output_dim=10)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    # 模拟数据:100 个样本,784 维输入
    X = torch.randn(100, 784)
    y = torch.randint(0, 10, (100,))

    batch_size = 32
    epochs = 10

    for epoch in range(epochs):
        total_loss = 0.0
        n_batches = 0

        for i in range(0, len(X), batch_size):
            batch_x = X[i:i + batch_size]
            batch_y = y[i:i + batch_size]
            loss = train_step(model, batch_x, batch_y, optimizer, criterion)
            total_loss += loss
            n_batches += 1

        avg_loss = total_loss / n_batches
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}")


if __name__ == "__main__":
    train_mlp()

autograd 关键 API 详解

API 作用 使用场景
requires_grad_(True) 标记张量需要梯度 默认对所有参数开启
loss.backward() 触发反向传播 每次训练步骤调用
tensor.grad 访问累积梯度 调试或自定义更新
optimizer.zero_grad() 清空梯度 PyTorch 梯度默认累积
torch.no_grad() 禁用计算图构建 推理或评估时使用
tensor.detach() 从计算图中分离 切断梯度流
register_hook(fn) 注册梯度钩子 梯度裁剪、调试监控
python 复制代码
# 梯度钩子:监控梯度爆炸/消失
def gradient_hook(grad: torch.Tensor, name: str) -> torch.Tensor:
    grad_norm = grad.norm().item()
    if grad_norm > 100:
        print(f"WARNING: {name} grad exploding: norm={grad_norm:.2f}")
    elif grad_norm < 1e-7:
        print(f"WARNING: {name} grad vanishing: norm={grad_norm:.2e}")
    return grad

for name, param in model.named_parameters():
    param.register_hook(lambda g, n=name: gradient_hook(g, n))

四、梯度消失与梯度爆炸:深层网络的致命问题

4.1 问题根源

反向传播中,梯度通过链式法则逐层相乘 。对于深度为 L L L 的网络,第一层权重的梯度为:

∂ L ∂ W ( 1 ) ∝ ∏ l = 2 L ( W ( l ) ⋅ σ ′ ( z ( l ) ) ) ⋅ x T \frac{\partial L}{\partial W^{(1)}} \propto \prod_{l=2}^{L} \left( W^{(l)} \cdot \sigma'(z^{(l)}) \right) \cdot x^T ∂W(1)∂L∝l=2∏L(W(l)⋅σ′(z(l)))⋅xT

连乘的每一项包含两部分:

  • 权重矩阵 W ( l ) W^{(l)} W(l):如果权重特征值大于 1,连乘指数增长 → 梯度爆炸
  • 激活函数导数 σ ′ ( z ) \sigma'(z) σ′(z):如果导数小于 1,连乘指数衰减 → 梯度消失

4.2 激活函数的导数对比

激活函数 导数公式 导数最大值 10 层连乘量级 梯度风险
Sigmoid σ ( 1 − σ ) \sigma(1-\sigma) σ(1−σ) 0.25 10 − 6 10^{-6} 10−6 严重消失
Tanh 1 − tanh ⁡ 2 1 - \tanh^2 1−tanh2 1.0 < 10 − 17 < 10^{-17} <10−17 中度消失
ReLU 1 或 0 1.0 O ( 1 ) O(1) O(1) 不消失*
Leaky ReLU 1 或 α \alpha α 1.0 O ( 1 ) O(1) O(1) 不消失

*ReLU 在负区间梯度为 0,可能导致「神经元死亡」(Dead ReLU),即某些神经元永不被激活。

4.3 工程解决方案

方案 原理 适用场景
ReLU 激活函数 正区间导数恒为 1,避免连乘衰减 隐藏层默认选择
残差连接(ResNet) h l + 1 = F ( h l ) + h l h_{l+1} = F(h_l) + h_l hl+1=F(hl)+hl,梯度可直接跳过 深层网络(>20 层)
Layer Normalization 归一化激活值,控制 z z z 的范围 Transformer、深层 MLP
梯度裁剪 g ← g ⋅ min ⁡ ( 1 , τ ∣ g ∣ ) g \leftarrow g \cdot \min(1, \frac{\tau}{|g|}) g←g⋅min(1,∣g∣τ) RNN、训练不稳定时
Xavier/He 初始化 控制权重方差,使前向和反向方差稳定 所有网络

4.4 残差连接的梯度流分析

残差连接之所以有效,可以从梯度角度严格分析。对于普通网络:

h l + 1 = F ( h l , W l )    ⟹    ∂ L ∂ h l = ∂ L ∂ h l + 1 ⋅ ∂ F ∂ h l h_{l+1} = F(h_l, W_l) \implies \frac{\partial L}{\partial h_l} = \frac{\partial L}{\partial h_{l+1}} \cdot \frac{\partial F}{\partial h_l} hl+1=F(hl,Wl)⟹∂hl∂L=∂hl+1∂L⋅∂hl∂F

加入残差连接后:

h l + 1 = F ( h l , W l ) + h l    ⟹    ∂ L ∂ h l = ∂ L ∂ h l + 1 ⋅ ( ∂ F ∂ h l + 1 ) h_{l+1} = F(h_l, W_l) + h_l \implies \frac{\partial L}{\partial h_l} = \frac{\partial L}{\partial h_{l+1}} \cdot \left(\frac{\partial F}{\partial h_l} + 1\right) hl+1=F(hl,Wl)+hl⟹∂hl∂L=∂hl+1∂L⋅(∂hl∂F+1)

关键区别 :多了一个 + 1 +1 +1 项。即使 ∂ F ∂ h l ≈ 0 \frac{\partial F}{\partial h_l} \approx 0 ∂hl∂F≈0,梯度仍可通过 + 1 +1 +1 路径以系数 1 无损传递。这就是 ResNet 能训练上百层网络的核心原因。

4.5 梯度裁剪实战

python 复制代码
import torch

# 方式 1:按范数裁剪(推荐)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# 方式 2:按值裁剪
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)

# 在训练循环中的位置
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()

# 裁剪发生在 backward 之后、step 之前
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

optimizer.step()

五、在订单系统中的实际应用

5.1 场景:订单异常检测

电商平台需要实时检测异常订单(刷单、欺诈、恶意退款)。使用多层感知机对订单特征进行二分类,反向传播贯穿训练全流程。

python 复制代码
import torch
import torch.nn as nn

@torch.no_grad()
def generate_order_data(n_samples: int = 10000) -> tuple:
    """生成模拟订单特征和异常标签。"""
    torch.manual_seed(42)
    # 特征:金额、下单时间(0-24)、历史退款率、账号年龄(天)、设备变更次数
    amount = torch.randn(n_samples) * 200 + 100
    hour = torch.rand(n_samples) * 24
    refund_rate = torch.rand(n_samples) * 0.3
    account_age = torch.rand(n_samples) * 365
    device_change = torch.randint(0, 5, (n_samples,)).float()

    # 异常规则:深夜大额 + 高退款率 + 新账号/频繁换设备
    risk = (hour < 6).float() * 0.3 + (amount > 300).float() * 0.2 + \
           (refund_rate > 0.15).float() * 0.2 + \
           (account_age < 30).float() * 0.15 + \
           (device_change >= 3).float() * 0.15
    label = (risk > 0.4).long()

    features = torch.stack([amount, hour, refund_rate, account_age, device_change], dim=1)
    return features, label


class OrderAnomalyDetector(nn.Module):
    """订单异常检测模型。"""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(5, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Linear(32, 2),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def train_anomaly_detector():
    X, y = generate_order_data(10000)
    split = int(0.8 * len(X))
    X_train, X_val = X[:split], X[split:]
    y_train, y_val = y[:split], y[split:]

    model = OrderAnomalyDetector()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()

    batch_size = 256
    best_val_acc = 0.0

    for epoch in range(20):
        model.train()
        perm = torch.randperm(len(X_train))
        for i in range(0, len(X_train), batch_size):
            idx = perm[i:i + batch_size]
            optimizer.zero_grad()
            logits = model(X_train[idx])
            loss = criterion(logits, y_train[idx])
            loss.backward()
            # 梯度裁剪防止爆炸
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

        model.eval()
        with torch.no_grad():
            val_logits = model(X_val)
            val_pred = val_logits.argmax(dim=1)
            val_acc = (val_pred == y_val).float().mean().item()
            val_loss = criterion(val_logits, y_val).item()

        if val_acc > best_val_acc:
            best_val_acc = val_acc

        print(f"Epoch {epoch + 1}: train_loss={loss.item():.4f}, "
              f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")

    print(f"\nBest validation accuracy: {best_val_acc:.4f}")


if __name__ == "__main__":
    train_anomaly_detector()

5.2 梯度健康监控

在训练过程中,需要持续监控梯度状态,及时发现梯度异常:

python 复制代码
from collections import defaultdict

class GradientMonitor:
    """训练过程中监控每层梯度的统计信息。"""

    def __init__(self, model: nn.Module, log_every: int = 50):
        self.model = model
        self.log_every = log_every
        self.step = 0
        self.history = defaultdict(list)
        self.hooks = []
        self._register_hooks()

    def _register_hooks(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                hook = param.register_hook(self._make_hook(name))
                self.hooks.append(hook)

    def _make_hook(self, name: str):
        def hook(grad: torch.Tensor) -> torch.Tensor:
            self.history[name].append({
                'step': self.step,
                'norm': grad.norm().item(),
                'mean': grad.mean().item(),
                'std': grad.std().item(),
                'max': grad.max().item(),
                'min': grad.min().item(),
            })
            return grad
        return hook

    def step_update(self):
        self.step += 1
        if self.step % self.log_every == 0:
            self._report()

    def _report(self):
        print(f"\n--- Gradient Report @ Step {self.step} ---")
        for name, records in self.history.items():
            if records:
                latest = records[-1]
                status = self._diagnose(latest)
                print(f"  {name}: norm={latest['norm']:.4f}, "
                      f"mean={latest['mean']:.6f} [{status}]")

    @staticmethod
    def _diagnose(stats: dict) -> str:
        norm = stats['norm']
        if norm > 100:
            return 'EXPLODING'
        elif norm < 1e-7:
            return 'VANISHING'
        elif norm < 1e-3:
            return 'WEAK'
        else:
            return 'OK'

在这个订单异常检测系统中,反向传播在三个地方发挥关键作用:

  1. BatchNorm 层的反向传播:计算归一化参数的梯度,同时更新 running mean/var
  2. 梯度裁剪:在反向传播后、参数更新前裁剪,防止异常 batch 导致梯度爆炸
  3. AdamW 优化器:利用反向传播计算的梯度,结合动量和自适应学习率更新参数

六、常见陷阱

陷阱 1:忘记 zero_grad() --- PyTorch 梯度默认累积

如果不清空梯度,当前 batch 的梯度会叠加到上一个 batch 的梯度上,导致训练不稳定。必须在 loss.backward() 之前调用 optimizer.zero_grad()
陷阱 2:在 no_grad 上下文中调用 backward

torch.no_grad() 禁用了计算图构建,此时 loss.backward() 会报错。反向传播必须在计算图存在时调用。
陷阱 3:in-place 操作破坏计算图

x.add_(1) 等 in-place 操作会修改张量的值,可能破坏 autograd 依赖的中间状态,导致梯度计算错误甚至报错。在需要梯度的张量上,尽量避免 in-place 操作。
陷阱 4:detach() 和 item() 的混淆

tensor.detach() 返回一个新张量,与原计算图断开,但仍共享数据内存。tensor.item() 只能用于标量,返回 Python 标量。在需要中断梯度流时用 detach(),在只读取标量值时用 item()
陷阱 5:评估时忘记 model.eval()

model.eval() 会关闭 BatchNorm 的统计更新和 Dropout 的随机丢弃。如果评估时不切换到 eval 模式,BatchNorm 会用当前 batch 的统计量更新 running mean/var,导致评估结果不准确。训练完成后切换回 model.train()
陷阱 6:tensor.to(device) 忘记覆盖原变量

model.to(device) 是 in-place 操作(对模型的参数),但 tensor.to(device) 返回一个新张量,原张量不会改变。必须写成 x = x.to(device)


七、总结

维度 核心要点
算法本质 反向模式自动微分:沿计算图逆向传播梯度,一次前向 + 一次反向完成所有参数的梯度计算
核心公式 δ ( l ) = ( ( W ( l + 1 ) ) T δ ( l + 1 ) ) ⊙ σ ′ ( z ( l ) ) \delta^{(l)} = ((W^{(l+1)})^T \delta^{(l+1)}) \odot \sigma'(z^{(l)}) δ(l)=((W(l+1))Tδ(l+1))⊙σ′(z(l)),逐层误差传递
关键优势 计算复杂度 O ( n ) O(n) O(n)( n n n 为计算图节点数),与参数数量无关
梯度消失 Sigmoid 导数 ≤ 0.25 \leq 0.25 ≤0.25,深层连乘导致指数衰减;用 ReLU、残差连接、LayerNorm 解决
梯度爆炸 权重过大导致连乘放大;用梯度裁剪和合理初始化控制
工程最佳实践 zero_grad → forward → backward → clip_grad → step;监控梯度范数;eval/train 切换
选型建议 深层网络选 ReLU + 残差连接 + LayerNorm;RNN 选梯度裁剪 + 门控机制