PyTorch自动求导

1. 计算图构建过程

python 复制代码
x = torch.ones(5, requires_grad=True)  # 定义叶子节点,启用梯度跟踪
y = x + 2                             # 加法操作,生成中间节点 y
z = y * y * 3                         # 平方与乘法操作,生成中间节点 z
out = z.mean()                        # 标量输出(损失函数)
  • 动态计算图构建

    每行代码触发一个操作,PyTorch 动态记录操作依赖关系,生成有向无环图(DAG):

    x → (Add) → y → (Pow + Mul) → z → (Mean) → out

    节点类型:

    • 叶子节点 :用户直接创建的 xx.is_leaf = True)。
    • 非叶子节点y, z, out由运算生成(grad_fn属性记录操作类型)
  • 梯度跟踪机制

    设置 requires_grad=True后,所有依赖 x的中间节点自动继承此属性(如 y.requires_grad=True


2. 反向传播与梯度计算

python 复制代码
out.backward()  # 触发反向传播
  • 反向传播流程

    1. 1.out开始反向遍历 :因 out是标量(shape=()),无需额外指定梯度权重

      1. 链式法则应用​:
      • out = z.mean()→ ∂zi∂out=51(z有 5 个元素)。
      • z = 3y^2→ ∂yi∂zi=6yi。
      • y = x + 2→ ∂xi∂yi=1
    2. 3.梯度计算

      ∂xi​∂out​=∂zi​∂out​⋅∂yi​∂zi​​⋅∂xi​∂yi​​=51​⋅6yi​⋅1=56​(xi​+2)。

  • 梯度存储

    结果存入叶子节点 x.grad,非叶子节点(如 y, z)的梯度默认不保留以节省内存


3. 梯度结果验证

python 复制代码
print(f"x 的梯度: {x.grad}")  # 输出:tensor([3.6000, 3.6000, 3.6000, 3.6000, 3.6000])
  • 数学推导

    代入 xi​=1:

    ∂xi​∂out​=56​(1+2)=518​=3.6。

    与代码输出一致,验证了链式法则的正确性


4. 梯度累积问题

  • 默认行为

    backward()计算的梯度会累加x.grad。若多次执行 out.backward(),梯度将叠加(如运行两次后 x.grad变为 [7.2, 7.2, ...]

  • 解决方案

    训练循环中需在每次反向传播前调用 x.grad.zero_()optimizer.zero_grad()清零梯度


关键概念总结

概念 说明 代码示例
叶子节点 用户直接创建的张量,梯度计算终点 x = torch.ones(5, requires_grad=True)
动态计算图 运行时动态构建的操作依赖图,反向传播后自动释放 y = x + 2生成 AddBackward节点
非标量反向传播 out非标量(如向量),需传入 gradient参数作为权重矩阵 z.backward(torch.ones_like(z))
梯度保留 设置 retain_graph=True可保留计算图,支持多次反向传播 out.backward(retain_graph=True)

提示 ​:理解计算图结构是调试自动求导的关键。可通过 print(y.grad_fn)查看操作类型(如输出 <AddBackward0>),或使用 torchviz库可视化计算图