1. 计算图构建过程
python
x = torch.ones(5, requires_grad=True) # 定义叶子节点,启用梯度跟踪
y = x + 2 # 加法操作,生成中间节点 y
z = y * y * 3 # 平方与乘法操作,生成中间节点 z
out = z.mean() # 标量输出(损失函数)
-
动态计算图构建 :
每行代码触发一个操作,PyTorch 动态记录操作依赖关系,生成有向无环图(DAG):
x → (Add) → y → (Pow + Mul) → z → (Mean) → out
节点类型:
- 叶子节点 :用户直接创建的
x
(x.is_leaf = True
)。 - 非叶子节点 :
y
,z
,out
由运算生成(grad_fn
属性记录操作类型)
- 叶子节点 :用户直接创建的
-
梯度跟踪机制 :
设置
requires_grad=True
后,所有依赖x
的中间节点自动继承此属性(如y.requires_grad=True
)
2. 反向传播与梯度计算
python
out.backward() # 触发反向传播
-
•反向传播流程 :
-
1.从
out
开始反向遍历 :因out
是标量(shape=()
),无需额外指定梯度权重
。 -
- 链式法则应用:
out = z.mean()
→ ∂zi∂out=51(z
有 5 个元素)。z = 3y^2
→ ∂yi∂zi=6yi。y = x + 2
→ ∂xi∂yi=1
-
3.梯度计算 :
∂xi∂out=∂zi∂out⋅∂yi∂zi⋅∂xi∂yi=51⋅6yi⋅1=56(xi+2)。
-
-
•梯度存储 :
结果存入叶子节点
x.grad
,非叶子节点(如y
,z
)的梯度默认不保留以节省内存
3. 梯度结果验证
python
print(f"x 的梯度: {x.grad}") # 输出:tensor([3.6000, 3.6000, 3.6000, 3.6000, 3.6000])
-
•数学推导 :
代入 xi=1:
∂xi∂out=56(1+2)=518=3.6。
与代码输出一致,验证了链式法则的正确性
4. 梯度累积问题
-
•默认行为 :
backward()
计算的梯度会累加 到x.grad
。若多次执行out.backward()
,梯度将叠加(如运行两次后x.grad
变为[7.2, 7.2, ...]
) -
解决方案 :
训练循环中需在每次反向传播前调用
x.grad.zero_()
或optimizer.zero_grad()
清零梯度
关键概念总结
概念 | 说明 | 代码示例 |
---|---|---|
叶子节点 | 用户直接创建的张量,梯度计算终点 | x = torch.ones(5, requires_grad=True) |
动态计算图 | 运行时动态构建的操作依赖图,反向传播后自动释放 | y = x + 2 生成 AddBackward 节点 |
非标量反向传播 | 若 out 非标量(如向量),需传入 gradient 参数作为权重矩阵 |
z.backward(torch.ones_like(z)) |
梯度保留 | 设置 retain_graph=True 可保留计算图,支持多次反向传播 |
out.backward(retain_graph=True) |
提示 :理解计算图结构是调试自动求导的关键。可通过 print(y.grad_fn)
查看操作类型(如输出 <AddBackward0>
),或使用 torchviz
库可视化计算图