PyTorch梯度计算
- 介绍
- [PyTorch 梯度计算和更新详解](#PyTorch 梯度计算和更新详解)
-
- 一、核心概念
-
- [1. 自动微分(Autograd)](#1. 自动微分(Autograd))
- [2. `requires_grad` 属性](#2.
requires_grad属性) - [3. `.grad` 属性](#3.
.grad属性)
- 二、基础示例
-
- [示例 1:简单的梯度计算](#示例 1:简单的梯度计算)
- [示例 2:多步运算的梯度](#示例 2:多步运算的梯度)
- 三、神经网络训练示例
- 四、重要概念详解
-
- [1. **为什么需要 `zero_grad()`?**](#1. 为什么需要
zero_grad()?) - [2. **`with torch.no_grad()`的作用**](#2.
with torch.no_grad()的作用)
- [1. **为什么需要 `zero_grad()`?**](#1. 为什么需要
- 五、调试技巧
- 总结
- 补充:torch.autograd.Function使用
- [PyTorch Autograd 高级特性详解](#PyTorch Autograd 高级特性详解)
-
- 基础使用
-
- [1. **torch.autograd.Function 的结构**](#1. torch.autograd.Function 的结构)
- [2. **ctx (Context) 对象的重要方法**](#2. ctx (Context) 对象的重要方法)
- 使用示例
-
- [1. **梯度检查 (Gradient Check)**](#1. 梯度检查 (Gradient Check))
- [2. **高阶导数和 Hessian 矩阵**](#2. 高阶导数和 Hessian 矩阵)
- [3. **Hook 机制详解**](#3. Hook 机制详解)
- [4. **实用进阶技巧**](#4. 实用进阶技巧)
- 总结与最佳实践
-
- [**核心 API 速查**](#核心 API 速查)
- [补充2:深入理解 backward() 的 grad_outputs 参数](#补充2:深入理解 backward() 的 grad_outputs 参数)
-
-
- [**1. grad_outputs 的本质**](#1. grad_outputs 的本质)
- [**2. 三种典型情况**](#2. 三种典型情况)
- [**3. 为什么向量输出必须指定?**](#3. 为什么向量输出必须指定?)
- [**4. 实用场景**](#4. 实用场景)
- [**5. 常见错误**](#5. 常见错误)
-
介绍
PyTorch 梯度计算和更新详解
我来详细介绍 PyTorch 的梯度计算和更新机制,并提供易于理解的示例。
一、核心概念
1. 自动微分(Autograd)
PyTorch 使用动态计算图来自动计算梯度。当你对张量进行操作时,PyTorch 会记录这些操作,然后反向传播时自动计算梯度。
2. requires_grad 属性
- 设置为
True的张量会追踪所有对它的操作 - 用于告诉 PyTorch 需要计算这个张量的梯度
3. .grad 属性
- 存储计算得到的梯度值
- 默认会累积梯度,需要手动清零
二、基础示例
示例 1:简单的梯度计算
python
import torch
# 创建一个需要梯度的张量
x = torch.tensor([2.0], requires_grad=True)
# 定义一个简单函数:y = x^2
y = x ** 2
print(f"x = {x}")
print(f"y = {y}")
# 反向传播计算梯度
y.backward()
# 查看梯度:dy/dx = 2x = 2*2 = 4
print(f"x的梯度: {x.grad}")
输出解释:
- y = x²
- dy/dx = 2x = 2×2 = 4
示例 2:多步运算的梯度
python
import torch
x = torch.tensor([3.0], requires_grad=True)
# 多步运算:z = (x + 2) * (x^2)
y = x + 2 # y = x + 2 = 5
z = y * (x ** 2) # z = 5 * 9 = 45
print(f"z = {z}")
z.backward()
# 梯度计算:dz/dx = d[(x+2)*x²]/dx = x² + 2x(x+2) = x² + 2x² + 4x = 3x² + 4x
# 当 x=3: 3*9 + 4*3 = 27 + 12 = 39
print(f"x的梯度: {x.grad}")
三、神经网络训练示例
python
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
print("\n" + "=" * 60)
print("示例 2: 梯度累积问题")
print("=" * 60)
x = torch.tensor([3.0], requires_grad=True)
# 第一次计算
y1 = x ** 2
y1.backward()
print(f"第一次backward后,x.grad = {x.grad.item()}")
# 如果不清零,梯度会累积
y2 = x ** 3
y2.backward()
print(f"第二次backward后(未清零),x.grad = {x.grad.item()}")
# 清零梯度
x.grad.zero_()
y3 = x ** 3
y3.backward()
print(f"清零后再次backward,x.grad = {x.grad.item()}")
print("\n" + "=" * 60)
print("示例 3: 简单线性回归")
print("=" * 60)
# 生成训练数据:y = 3x + 2 + 噪声
torch.manual_seed(42)
x_train = torch.randn(100, 1)
y_label = 3 * x_train + 2 + torch.randn(100, 1) * 0.3
# 定义模型参数
w = torch.tensor([[0.0]], requires_grad=True)
b = torch.tensor([[0.0]], requires_grad=True)
# 训练参数
learning_rate = 0.01
epochs = 100
# 记录损失
losses = []
print("开始训练...")
for epoch in range(epochs):
# 前向传播
y_pred = x_train @ w + b
# 计算损失(均方误差)
loss = ((y_pred - y_label) ** 2).mean()
# 反向传播
loss.backward()
# 手动更新参数(梯度下降)
with torch.no_grad(): # 更新参数时不需要计算梯度
w -= learning_rate * w.grad
b -= learning_rate * b.grad
# 清零梯度
w.grad.zero_()
b.grad.zero_()
losses.append(loss.item())
if (epoch + 1) % 20 == 0:
print(f"Epoch {epoch+1:3d}: Loss = {loss.item():.4f}, w = {w.item():.4f}, b = {b.item():.4f}")
print(f"\n最终参数: w = {w.item():.4f}, b = {b.item():.4f}")
print("真实参数: w = 3.0000, b = 2.0000")
print("\n" + "=" * 60)
print("示例 4: 使用优化器的神经网络")
print("=" * 60)
# 定义一个简单的神经网络
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(1, 10)
self.fc2 = nn.Linear(10, 1)
self.relu = nn.ReLU()
def forward(self, x):
x = self.relu(self.fc1(x))
x = self.fc2(x)
return x
# 创建模型
model = SimpleNN()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 训练数据:y = sin(x)
x_data = torch.linspace(-3, 3, 100).reshape(-1, 1)
y_data = torch.sin(x_data)
print("使用优化器训练神经网络...")
for epoch in range(200):
# 前向传播
y_pred = model(x_data)
loss = criterion(y_pred, y_data)
# 反向传播
optimizer.zero_grad() # 清零梯度
loss.backward() # 计算梯度
optimizer.step() # 更新参数
if (epoch + 1) % 50 == 0:
print(f"Epoch {epoch+1:3d}: Loss = {loss.item():.4f}")
print("\n" + "=" * 60)
print("示例 5: 查看网络中的梯度")
print("=" * 60)
# 创建一个小型网络
net = nn.Sequential(
nn.Linear(2, 3),
nn.ReLU(),
nn.Linear(3, 1)
)
# 前向传播
x = torch.tensor([[1.0, 2.0]])
y = torch.tensor([[1.0]])
output = net(x)
loss = (output - y) ** 2
# 反向传播
loss.backward()
# 查看每层的梯度
print("网络各层的梯度:")
for name, param in net.named_parameters():
if param.grad is not None:
print(f"{name}:")
print(f" 形状: {param.grad.shape}")
print(f" 梯度: {param.grad}")
print()
print("=" * 60)
print("关键要点总结:")
print("=" * 60)
print("1. requires_grad=True: 启用梯度追踪")
print("2. .backward(): 计算梯度")
print("3. .grad: 访问梯度值")
print("4. .zero_grad(): 清零梯度(防止累积)")
print("5. with torch.no_grad(): 临时禁用梯度计算")
print("6. optimizer.step(): 使用计算的梯度更新参数")
print("=" * 60)
四、重要概念详解
1. 为什么需要 zero_grad()?
梯度默认会累积,如果不清零,新的梯度会加到旧梯度上:
python
x = torch.tensor([2.0], requires_grad=True)
y1 = x ** 2
y1.backward()
print(f"第一次:{x.grad}") # 4.0
y2 = x ** 2
y2.backward()
print(f"第二次(累积):{x.grad}") # 8.0(4+4)
x.grad.zero_()
y3 = x ** 2
y3.backward()
print(f"清零后:{x.grad}") # 4.0
2. with torch.no_grad()的作用
在不需要计算梯度的地方(如评估、推理、手动更新参数)使用,可以节省内存和计算:
python
# 推理时不需要梯度
with torch.no_grad():
predictions = model(test_data)
# 手动更新参数时不需要梯度
with torch.no_grad():
w -= learning_rate * w.grad
五、调试技巧
检查梯度是否计算:
python
# 检查张量是否需要梯度
print(x.requires_grad)
# 检查梯度是否已计算
print(x.grad)
# 查看计算图
print(y.grad_fn)
# 分离计算图(不再追踪梯度)
x_detached = x.detach()
梯度裁剪(防止梯度爆炸):
python
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
总结
PyTorch 的梯度计算遵循以下流程:
- 设置
requires_grad=True追踪计算 - 前向传播执行计算
- 调用
.backward()自动计算梯度 - 使用梯度更新参数
- 清零梯度准备下一次迭代
这个自动微分机制让我们无需手动推导复杂的梯度公式,极大简化了深度学习模型的开发!
补充:torch.autograd.Function使用
PyTorch Autograd 高级特性详解
PyTorch 的高级自动微分特性,包括自定义反向传播、梯度检查、高阶导数等。## 深入解析关键概念
基础使用
1. torch.autograd.Function 的结构
python
class CustomFunction(Function):
@staticmethod
def forward(ctx, *args):
# ctx: 上下文对象
# *args: 可变数量的输入
# 返回: 前向传播的输出
pass
@staticmethod
def backward(ctx, *grad_outputs):
# *grad_outputs: 从后续层传来的梯度
# 返回: 对每个输入的梯度 (顺序与forward的输入对应)
# 如果某个输入不需要梯度,返回 None
pass
2. ctx (Context) 对象的重要方法
python
# 保存张量 (会自动处理梯度追踪)
ctx.save_for_backward(tensor1, tensor2, ...)
# 保存非张量数据
ctx.constant = value
ctx.mark_non_differentiable(tensor) # 标记输出不需要梯度
# 标记输入是否需要梯度 (优化性能)
ctx.needs_input_grad[0] # 检查第一个输入是否需要梯度
使用示例
1. 梯度检查 (Gradient Check)
数值方法验证解析梯度的正确性:
python
import numpy as np
import torch
from torch.autograd import Function, gradcheck
print("=" * 70)
print("梯度检查原理详解")
print("=" * 70)
def numerical_gradient(f, x, eps=1e-5):
"""
使用有限差分法计算数值梯度
f(x + h) - f(x - h)
---------------------
2h
"""
grad = torch.zeros_like(x)
for i in range(x.numel()):
# 创建扰动
x_plus = x.clone()
x_minus = x.clone()
x_plus.view(-1)[i] += eps
x_minus.view(-1)[i] -= eps
# 计算数值梯度
f_plus = f(x_plus)
f_minus = f(x_minus)
grad.view(-1)[i] = (f_plus - f_minus) / (2 * eps)
return grad
# 定义一个简单函数
def func(x):
return (x ** 2).sum()
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
# 解析梯度
y = func(x)
y.backward()
analytical_grad = x.grad.clone()
# 数值梯度
x_no_grad = x.detach().clone()
numerical_grad = numerical_gradient(func, x_no_grad)
print("函数: f(x) = sum(x^2)")
print(f"输入: {x.data}")
print(f"\n解析梯度 (自动微分): {analytical_grad}")
print(f"数值梯度 (有限差分): {numerical_grad}")
print(f"差异: {(analytical_grad - numerical_grad).abs().max().item():.2e}")
print("\n" + "=" * 70)
print("使用 gradcheck 自动验证")
print("=" * 70)
class MyExp(Function):
@staticmethod
def forward(ctx, x):
result = torch.exp(x)
ctx.save_for_backward(result)
return result
@staticmethod
def backward(ctx, grad_output):
result, = ctx.saved_tensors
return grad_output * result
# 使用 double 精度以提高数值稳定性
x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
print("测试自定义 Exp 函数...")
test_result = gradcheck(MyExp.apply, x, eps=1e-6, atol=1e-4)
print(f"结果: {'✓ 通过' if test_result else '✗ 失败'}")
print("\n" + "=" * 70)
print("常见梯度错误示例")
print("=" * 70)
class WrongGradient(Function):
"""故意写错的梯度"""
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return x ** 2
@staticmethod
def backward(ctx, grad_output):
x, = ctx.saved_tensors
# 错误: 应该是 2*x,这里写成 3*x
return grad_output * 3 * x
print("测试错误的梯度实现...")
x = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
try:
test_result = gradcheck(WrongGradient.apply, x, eps=1e-6, atol=1e-4)
print(f"结果: {'✓ 通过' if test_result else '✗ 失败'}")
except RuntimeError as e:
print("✗ 检测到错误: gradcheck 失败")
print(" 说明梯度实现有误!")
print("\n" + "=" * 70)
print("复杂函数的梯度验证")
print("=" * 70)
class ComplexFunction(Function):
"""
复杂函数: f(x, y) = sin(x) * exp(y) + x^2 * y
"""
@staticmethod
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
return torch.sin(x) * torch.exp(y) + x**2 * y
@staticmethod
def backward(ctx, grad_output):
x, y = ctx.saved_tensors
# df/dx = cos(x)*exp(y) + 2*x*y
grad_x = grad_output * (torch.cos(x) * torch.exp(y) + 2 * x * y)
# df/dy = sin(x)*exp(y) + x^2
grad_y = grad_output * (torch.sin(x) * torch.exp(y) + x**2)
return grad_x, grad_y
x = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
y = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
print("测试复杂函数...")
test_result = gradcheck(ComplexFunction.apply, (x, y), eps=1e-6, atol=1e-4)
print(f"结果: {'✓ 通过' if test_result else '✗ 失败'}")
print("\n" + "=" * 70)
print("梯度检查最佳实践")
print("=" * 70)
print("1. 使用 double 精度 (torch.double)")
print("2. 设置合适的 eps (通常 1e-6)")
print("3. 设置合适的 atol (绝对容差,通常 1e-4)")
print("4. 测试多个随机输入")
print("5. 在开发阶段测试,生产环境可以关闭")
print("=" * 70)
2. 高阶导数和 Hessian 矩阵
python
import torch
import torch.autograd as autograd
# ======================================================================
# 示例 1: 计算高阶导数
# ======================================================================
# f(x) = x^4, x=2
# f(2) = 16.0
# f'(2) = 4*2^3 = 32.0
# f''(2) = 12*2^2 = 48.0
# f'''(2) = 24*2 = 48.0
# f''''(2) = 24 = 24.0
print("=" * 70)
print("示例 1: 计算高阶导数")
print("=" * 70)
# 函数: f(x) = x^4
x = torch.tensor([2.0], requires_grad=True)
# 零阶: f(x)
y = x ** 4
print("f(x) = x^4, x=2")
print(f"f(2) = {y.item()}")
# 一阶导数: f'(x) = 4x^3
dy_dx = autograd.grad(y, x, create_graph=True)[0]
print(f"f'(2) = 4*2^3 = {dy_dx.item()}")
# 二阶导数: f''(x) = 12x^2
d2y_dx2 = autograd.grad(dy_dx, x, create_graph=True)[0]
print(f"f''(2) = 12*2^2 = {d2y_dx2.item()}")
# 三阶导数: f'''(x) = 24x
d3y_dx3 = autograd.grad(d2y_dx2, x, create_graph=True)[0]
print(f"f'''(2) = 24*2 = {d3y_dx3.item()}")
# 四阶导数: f''''(x) = 24
d4y_dx4 = autograd.grad(d3y_dx3, x)[0]
print(f"f''''(2) = 24 = {d4y_dx4.item()}")
# ======================================================================
# 示例 2: 计算 Hessian 矩阵
# ======================================================================
# 函数: f(x,y) = x² + 2xy + 3y²
# 在点 x=1.0, y=2.0
# f = 17.0
# Hessian 矩阵:
# tensor([[2., 2.],
# [2., 6.]])
# 理论 Hessian:
# ∂²f/∂x² = 2
# ∂²f/∂x∂y = 2
# ∂²f/∂y∂x = 2
# ∂²f/∂y² = 6
print("\n" + "=" * 70)
print("示例 2: 计算 Hessian 矩阵")
print("=" * 70)
def compute_hessian(f, x):
"""
计算标量函数 f 关于向量 x 的 Hessian 矩阵
Hessian[i,j] = ∂²f/∂xi∂xj
"""
n = x.shape[0]
hessian = torch.zeros(n, n)
# 计算一阶导数
grad = autograd.grad(f, x, create_graph=True)[0]
# 对每个一阶导数分量再求导
for i in range(n):
grad2 = autograd.grad(grad[i], x, retain_graph=True)[0]
hessian[i] = grad2
return hessian
# 函数: f(x, y) = x^2 + 2xy + 3y^2
x = torch.tensor([1.0, 2.0], requires_grad=True)
f = x[0]**2 + 2*x[0]*x[1] + 3*x[1]**2
print("函数: f(x,y) = x² + 2xy + 3y²")
print(f"在点 x={x[0].item()}, y={x[1].item()}")
print(f"f = {f.item()}")
hessian = compute_hessian(f, x)
print("\nHessian 矩阵:")
print(hessian)
print("\n理论 Hessian:")
print(" ∂²f/∂x² = 2")
print(" ∂²f/∂x∂y = 2")
print(" ∂²f/∂y∂x = 2")
print(" ∂²f/∂y² = 6")
# ======================================================================
# 示例 3: 使用 functorch 高效计算 Hessian
# ======================================================================
# Rosenbrock 函数在 x=tensor([0.5000, 0.5000])
# f = 6.5000
# 梯度: tensor([-51., 50.])
# Hessian 矩阵:
# tensor([[ 102., -200.],
# [-200., 200.]])
print("\n" + "=" * 70)
print("示例 3: 使用 functorch 高效计算 Hessian")
print("=" * 70)
def rosenbrock(x):
"""Rosenbrock 函数: f(x,y) = (1-x)² + 100(y-x²)²"""
return (1 - x[0])**2 + 100 * (x[1] - x[0]**2)**2
x = torch.tensor([0.5, 0.5], requires_grad=True)
f_val = rosenbrock(x)
print(f"Rosenbrock 函数在 x={x.data}")
print(f"f = {f_val.item():.4f}")
# 计算梯度
grad = autograd.grad(f_val, x, create_graph=True)[0]
print(f"\n梯度: {grad.data}")
# 计算 Hessian
hessian = compute_hessian(f_val, x)
print("\nHessian 矩阵:")
print(hessian)
# ======================================================================
# 示例 4: Jacobian 向量积 (JVP)
# ======================================================================
# 输入 x = tensor([2., 3.])
# 方向 v = tensor([1., 1.])
# 输出 f(x) = tensor([7., 6.])
# JVP (J@v) = tensor([7., 3.])
# 手动计算验证:
# J = [[2*x[0], 1 ],
# [x[1], x[0] ]]
# J = [[4.0, 1.0],
# [3.0, 2.0]]
# v^T @ J = [7.0, 3.0]
print("\n" + "=" * 70)
print("示例 4: Jacobian 向量积 (JVP)")
print("=" * 70)
def f(x):
"""向量函数 R² -> R²"""
return torch.stack([
x[0]**2 + x[1],
x[0] * x[1]
])
x = torch.tensor([2.0, 3.0], requires_grad=True)
v = torch.tensor([1.0, 1.0]) # 方向向量
# 计算 JVP: J(x) @ v
y = f(x)
jvp = autograd.grad(y, x, grad_outputs=v, create_graph=True)[0]
print(f"输入 x = {x.data}")
print(f"方向 v = {v}")
print(f"输出 f(x) = {y.data}")
print(f"JVP (J@v) = {jvp.data}")
print("\n手动计算验证:")
print("J = [[2*x[0], 1 ],")
print(" [x[1], x[0] ]]")
print(f"J = [[{2*x[0].item():.1f}, {1.0:.1f}],")
print(f" [{x[1].item():.1f}, {x[0].item():.1f}]]")
print(f"v^T @ J = [{v[0]*2*x[0].item() + v[1]*x[1].item():.1f}, {v[0]*1.0 + v[1]*x[0].item():.1f}]")
# ======================================================================
# 示例 5: 向量 Jacobian 积 (VJP)
# ======================================================================
# 输入 x = tensor([2., 3.])
# 方向 v = tensor([1., 1.])
# VJP (v^T@J) = tensor([7., 3.])
# 手动计算验证:
# v^T@J = [1, 1] @ J = [7.0, 3.0]
print("\n" + "=" * 70)
print("示例 5: 向量 Jacobian 积 (VJP)")
print("=" * 70)
x = torch.tensor([2.0, 3.0], requires_grad=True)
v = torch.tensor([1.0, 1.0])
y = f(x)
# 计算 VJP: v^T @ J(x)
vjp = autograd.grad(y, x, grad_outputs=v)[0]
print(f"输入 x = {x.data}")
print(f"方向 v = {v}")
print(f"VJP (v^T@J) = {vjp.data}")
print("\n手动计算验证:")
print(f"v^T@J = [1, 1] @ J = [{2*x[0].item() + x[1].item():.1f}, {1 + x[0].item():.1f}]")
# ======================================================================
# 示例 6: 拉普拉斯算子 (Laplacian)
# ======================================================================
# 函数: f(x,y) = x² + y²
# 拉普拉斯: ∇²f = ∂²f/∂x² + ∂²f/∂y² = 4.0
# 理论值: 2 + 2 = 4
print("\n" + "=" * 70)
print("示例 6: 拉普拉斯算子 (Laplacian)")
print("=" * 70)
def compute_laplacian(f, x):
"""
计算标量函数的拉普拉斯: ∇²f = Σ ∂²f/∂xi²
"""
grad = autograd.grad(f, x, create_graph=True)[0]
laplacian = 0
for i in range(x.shape[0]):
grad2 = autograd.grad(grad[i], x, retain_graph=True)[0]
laplacian += grad2[i]
return laplacian
# 函数: f(x,y) = x² + y²
x = torch.tensor([1.0, 2.0], requires_grad=True)
f = (x ** 2).sum()
laplacian = compute_laplacian(f, x)
print("函数: f(x,y) = x² + y²")
print(f"拉普拉斯: ∇²f = ∂²f/∂x² + ∂²f/∂y² = {laplacian.item()}")
print("理论值: 2 + 2 = 4")
# ======================================================================
# 示例 7: 牛顿法优化 (使用 Hessian)
# ======================================================================
# 使用牛顿法优化 Rosenbrock 函数
# 目标: 找到最小值点 (1, 1)
# 迭代 0: x=tensor([1., 0.]), f(x)=100.000000
# 迭代 2: x=tensor([1., 1.]), f(x)=0.000000
# 迭代 4: x=tensor([1., 1.]), f(x)=0.000000
# 最优解: tensor([1., 1.])
# 函数值: 0.000000
print("\n" + "=" * 70)
print("示例 7: 牛顿法优化 (使用 Hessian)")
print("=" * 70)
def newton_method(f, x0, n_iterations=10):
"""使用牛顿法最小化函数"""
x = x0.clone().requires_grad_(True)
history = [x0.clone().detach()]
for i in range(n_iterations):
# 计算函数值
f_val = f(x)
# 计算梯度
grad = autograd.grad(f_val, x, create_graph=True)[0]
# 计算 Hessian
hessian = compute_hessian(f_val, x)
# 牛顿更新: x = x - H^(-1) @ g
with torch.no_grad():
delta = torch.linalg.solve(hessian, grad)
x = x - delta
x.requires_grad_(True)
history.append(x.clone().detach())
if i % 2 == 0:
print(f"迭代 {i}: x={x.data}, f(x)={f(x).item():.6f}")
return x.detach(), history
# 优化 Rosenbrock 函数
x0 = torch.tensor([0.0, 0.0], requires_grad=True)
print("使用牛顿法优化 Rosenbrock 函数")
print("目标: 找到最小值点 (1, 1)")
x_opt, history = newton_method(rosenbrock, x0, n_iterations=6)
print(f"\n最优解: {x_opt}")
print(f"函数值: {rosenbrock(x_opt).item():.6f}")
print("理论最优: [1, 1], f=0")
print("\n" + "=" * 70)
print("关键概念总结")
print("=" * 70)
print("1. create_graph=True: 创建高阶导数的计算图")
print("2. retain_graph=True: 保留图以便多次 backward")
print("3. Hessian: 二阶导数矩阵,用于优化和分析")
print("4. JVP: Jacobian-Vector Product, 高效的前向模式")
print("5. VJP: Vector-Jacobian Product, 高效的反向模式")
print("6. 拉普拉斯: 二阶导数的迹,物理中常用")
print("7. 牛顿法: 使用 Hessian 的二阶优化方法")
print("=" * 70)
3. Hook 机制详解
python
import torch
import torch.nn as nn
print("=" * 70)
print("示例 1: Tensor Hook - 监控梯度")
print("=" * 70)
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = torch.tensor([4.0, 5.0, 6.0], requires_grad=True)
# 定义 hook 函数
def grad_hook(grad):
print(f" 梯度被计算: {grad}")
# 可以修改梯度
return grad * 2 # 梯度翻倍
# 注册 hook
hook_handle = x.register_hook(grad_hook)
z = (x * y).sum()
print(f"z = (x * y).sum() = {z.item()}")
print("\n反向传播:")
z.backward()
print(f"\nx 的最终梯度: {x.grad}")
print(f"y 的最终梯度: {y.grad}")
print("注意: x 的梯度被翻倍了!")
# 移除 hook
hook_handle.remove()
print("\n" + "=" * 70)
print("示例 2: Module Hook - 前向传播")
print("=" * 70)
class SimpleNet(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(3, 4)
self.fc2 = nn.Linear(4, 2)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
model = SimpleNet()
# 存储中间激活值
activations = {}
def forward_hook(module, input, output):
"""前向传播 hook"""
# module: 当前层
# input: 输入 (tuple)
# output: 输出
name = module.__class__.__name__
activations[name] = output.detach()
print(f" {name}: input shape={input[0].shape}, output shape={output.shape}")
# 为每一层注册 hook
for name, layer in model.named_modules():
if isinstance(layer, nn.Linear):
layer.register_forward_hook(forward_hook)
x = torch.randn(2, 3)
print("前向传播:")
output = model(x)
print("\n保存的激活值:")
for name, act in activations.items():
print(f" {name}: shape={act.shape}, mean={act.mean().item():.4f}")
print("\n" + "=" * 70)
print("示例 3: Module Hook - 反向传播")
print("=" * 70)
gradients = {}
def backward_hook(module, grad_input, grad_output):
"""反向传播 hook"""
# grad_input: 相对于输入的梯度 (tuple)
# grad_output: 相对于输出的梯度 (tuple)
name = module.__class__.__name__
gradients[name] = {
'grad_output': grad_output[0].detach() if grad_output[0] is not None else None,
'grad_input': grad_input[0].detach() if grad_input[0] is not None else None,
}
print(f" {name}: grad_output shape={grad_output[0].shape if grad_output[0] is not None else None}")
model = SimpleNet()
# 注册反向 hook
for name, layer in model.named_modules():
if isinstance(layer, nn.Linear):
layer.register_full_backward_hook(backward_hook)
x = torch.randn(2, 3)
output = model(x)
loss = output.sum()
print("反向传播:")
loss.backward()
print("\n保存的梯度信息:")
for name, grads in gradients.items():
print(f" {name}:")
if grads['grad_output'] is not None:
print(f" grad_output shape: {grads['grad_output'].shape}")
if grads['grad_input'] is not None:
print(f" grad_input shape: {grads['grad_input'].shape}")
print("\n" + "=" * 70)
print("示例 4: 梯度裁剪 Hook")
print("=" * 70)
def gradient_clipping_hook(grad, max_norm=1.0):
"""梯度裁剪 hook"""
norm = grad.norm()
if norm > max_norm:
print(f" 梯度范数 {norm:.4f} 超过 {max_norm}, 进行裁剪")
return grad * (max_norm / norm)
return grad
x = torch.tensor([10.0], requires_grad=True)
x.register_hook(lambda grad: gradient_clipping_hook(grad, max_norm=1.0))
y = x**3
print("y = x^3, x=10")
print(f"y = {y.item()}")
print("\n反向传播 (带梯度裁剪):")
y.backward()
print(f"裁剪后的梯度: {x.grad.item():.4f}")
print(f"原始梯度应该是: 3*x^2 = {3 * 10**2}")
print("\n" + "=" * 70)
print("示例 5: 梯度监控和可视化")
print("=" * 70)
class GradientMonitor:
"""梯度监控器"""
def __init__(self):
self.gradients = {}
self.hooks = []
def register(self, model):
"""为模型所有参数注册 hook"""
for name, param in model.named_parameters():
if param.requires_grad:
hook = param.register_hook(lambda grad, name=name: self.save_grad(name, grad))
self.hooks.append(hook)
def save_grad(self, name, grad):
"""保存梯度统计"""
self.gradients[name] = {
'mean': grad.mean().item(),
'std': grad.std().item(),
'min': grad.min().item(),
'max': grad.max().item(),
'norm': grad.norm().item(),
}
def report(self):
"""报告梯度统计"""
print("\n梯度统计报告:")
print("-" * 70)
print(f"{'层名':<20} {'均值':<10} {'标准差':<10} {'范数':<10}")
print("-" * 70)
for name, stats in self.gradients.items():
print(f"{name:<20} {stats['mean']:<10.4f} {stats['std']:<10.4f} {stats['norm']:<10.4f}")
def clear(self):
"""清除保存的梯度"""
self.gradients.clear()
def remove_hooks(self):
"""移除所有 hooks"""
for hook in self.hooks:
hook.remove()
# 创建模型和监控器
model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 10), nn.ReLU(), nn.Linear(10, 1))
monitor = GradientMonitor()
monitor.register(model)
# 训练一步
x = torch.randn(5, 10)
y = torch.randn(5, 1)
output = model(x)
loss = ((output - y) ** 2).mean()
print("前向传播完成,开始反向传播...")
loss.backward()
monitor.report()
print("\n" + "=" * 70)
print("示例 6: Pre-Hook (输入修改)")
print("=" * 70)
def pre_hook(module, input):
"""前向传播前的 hook,可以修改输入"""
print(f" 原始输入范围: [{input[0].min():.2f}, {input[0].max():.2f}]")
# 归一化输入
normalized = (input[0] - input[0].mean()) / (input[0].std() + 1e-8)
print(f" 归一化后范围: [{normalized.min():.2f}, {normalized.max():.2f}]")
return (normalized,)
model = nn.Linear(5, 3)
model.register_forward_pre_hook(pre_hook)
x = torch.randn(2, 5) * 100 # 大范围的输入
print("前向传播 (带输入归一化):")
output = model(x)
print("\n" + "=" * 70)
print("示例 7: 梯度累积检测")
print("=" * 70)
class GradientAccumulationDetector:
"""检测意外的梯度累积"""
def __init__(self):
self.grad_counts = {}
def register(self, model):
for name, param in model.named_parameters():
if param.requires_grad:
param.register_hook(lambda grad, name=name: self.check_accumulation(name, grad))
def check_accumulation(self, name, grad):
self.grad_counts[name] = self.grad_counts.get(name, 0) + 1
if self.grad_counts[name] > 1:
print(f" ⚠️ 警告: {name} 的梯度被计算了 {self.grad_counts[name]} 次!")
def reset(self):
self.grad_counts.clear()
model = nn.Linear(3, 2)
detector = GradientAccumulationDetector()
detector.register(model)
x = torch.randn(1, 3)
print("第一次反向传播:")
loss1 = model(x).sum()
loss1.backward()
print("\n第二次反向传播 (未清零梯度):")
loss2 = model(x).sum()
loss2.backward()
print("\n检测到梯度累积问题!")
print("解决方案: 使用 optimizer.zero_grad() 或 param.grad.zero_()")
print("\n" + "=" * 70)
print("示例 8: 条件梯度修改")
print("=" * 70)
def conditional_grad_modification(grad, threshold=1.0):
"""根据条件修改梯度"""
# 对异常大的梯度进行标记和修改
large_grad_mask = grad.abs() > threshold
if large_grad_mask.any():
print(f" 检测到 {large_grad_mask.sum().item()} 个异常梯度")
# 将大梯度裁剪到阈值
grad = torch.where(large_grad_mask, grad.sign() * threshold, grad)
return grad
x = torch.tensor([1.0, 2.0, 10.0], requires_grad=True)
x.register_hook(lambda grad: conditional_grad_modification(grad, threshold=5.0))
y = (x**3).sum()
print("反向传播 (条件梯度修改):")
y.backward()
print(f"修改后的梯度: {x.grad}")
print("原始梯度应该是: [3, 12, 300]")
print("\n" + "=" * 70)
print("Hook 使用最佳实践")
print("=" * 70)
print("1. 记得移除不再需要的 hook (调用 handle.remove())")
print("2. Hook 函数应该简洁高效,避免复杂计算")
print("3. 注意 hook 的执行顺序 (按注册顺序)")
print("4. 修改梯度时要小心,可能影响训练稳定性")
print("5. 使用 hook 进行调试,但生产环境中移除")
print("6. Forward hook 不能修改输出,只能观察")
print("7. Backward hook 可以修改梯度,需谨慎使用")
print("=" * 70)
4. 实用进阶技巧
python
import time
import torch
import torch.nn as nn
from torch.autograd import Function
# ======================================================================
# 技巧 1: 混合精度训练中的自定义函数
# ======================================================================
# 输入 (FP16): tensor([ 0.1000, 1.0000, 10.0000], dtype=torch.float16, requires_grad=True)
# 输出 (FP16): tensor([-2.3027, 0.0000, 2.3027], dtype=torch.float16,
# grad_fn=<MixedPrecisionOperationBackward>)
# 梯度 (FP16): tensor([10.0000, 1.0000, 0.1000], dtype=torch.float16)
print("=" * 70)
print("技巧 1: 混合精度训练中的自定义函数")
print("=" * 70)
class MixedPrecisionOperation(Function):
"""在 FP16 训练中保持某些操作的 FP32 精度"""
@staticmethod
def forward(ctx, x):
# 转换到 FP32 进行高精度计算
x_fp32 = x.float()
result = torch.log(x_fp32 + 1e-8) # 数值稳定的 log
ctx.save_for_backward(x)
return result.to(x.dtype)
@staticmethod
def backward(ctx, grad_output):
(x,) = ctx.saved_tensors
# 在 FP32 中计算梯度
x_fp32 = x.float()
grad_input = grad_output.float() / (x_fp32 + 1e-8)
return grad_input.to(x.dtype)
x = torch.tensor([0.1, 1.0, 10.0], dtype=torch.float16, requires_grad=True)
y = MixedPrecisionOperation.apply(x)
print(f"输入 (FP16): {x}")
print(f"输出 (FP16): {y}")
y.sum().backward()
print(f"梯度 (FP16): {x.grad}")
# ======================================================================
# 技巧 2: 梯度检查点 (Gradient Checkpointing)
# ======================================================================
# 不使用检查点:
# 时间: 0.1083s
# 使用检查点:
# 时间: 0.3041s
# (可能更慢,但显著节省内存)
print("\n" + "=" * 70)
print("技巧 2: 梯度检查点 (Gradient Checkpointing)")
print("=" * 70)
def checkpoint_function(func, *args):
"""
简化版的梯度检查点
前向传播时不保存中间激活,反向传播时重新计算
"""
class CheckpointFunction(Function):
@staticmethod
def forward(ctx, *inputs):
# 不保存中间结果
ctx.func = func
with torch.no_grad():
outputs = func(*inputs)
# 只保存输入用于重新计算
ctx.save_for_backward(*inputs)
return outputs
@staticmethod
def backward(ctx, *grad_outputs):
inputs = ctx.saved_tensors
# 重新计算前向传播(这次保存计算图)
with torch.enable_grad():
detached_inputs = [x.detach().requires_grad_(True) for x in inputs]
outputs = ctx.func(*detached_inputs)
# 计算梯度
torch.autograd.backward(outputs, grad_outputs)
return tuple(x.grad for x in detached_inputs)
return CheckpointFunction.apply(*args)
# 示例: 深层网络
def heavy_computation(x):
"""模拟计算密集型操作"""
for _ in range(100):
x = torch.relu(x)
x = x + 0.01
return x
x = torch.randn(1000, 1000, requires_grad=True)
print("不使用检查点:")
start = time.time()
y = heavy_computation(x)
loss = y.sum()
loss.backward()
time_normal = time.time() - start
print(f" 时间: {time_normal:.4f}s")
x.grad = None
print("\n使用检查点:")
start = time.time()
y = checkpoint_function(heavy_computation, x)
loss = y.sum()
loss.backward()
time_checkpoint = time.time() - start
print(f" 时间: {time_checkpoint:.4f}s")
print(" (可能更慢,但显著节省内存)")
# ======================================================================
# 技巧 3: 分离部分计算图
# ======================================================================
# 训练判别器 - 生成器梯度应该被阻断:
# D 损失: 2.4889
# 生成器有梯度? False (应该是 False)
# 训练生成器 - 判别器梯度应该被阻断:
# G 损失: -2.4889
# 生成器有梯度? True (应该是 True)
# 判别器有梯度? True (应该是 False)
print("\n" + "=" * 70)
print("技巧 3: 分离部分计算图")
print("=" * 70)
# 场景: 生成对抗网络 (GAN) 训练
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 20)
def forward(self, x):
return torch.tanh(self.fc(x))
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(20, 1)
def forward(self, x):
return torch.sigmoid(self.fc(x))
G = Generator()
D = Discriminator()
z = torch.randn(5, 10)
print("训练判别器 - 生成器梯度应该被阻断:")
fake_data = G(z)
# 分离生成器输出,避免梯度流回生成器
fake_data_detached = fake_data.detach()
d_loss = D(fake_data_detached).sum()
print(f" D 损失: {d_loss.item():.4f}")
d_loss.backward()
# 检查生成器参数是否有梯度
g_has_grad = any(p.grad is not None for p in G.parameters())
print(f" 生成器有梯度? {g_has_grad} (应该是 False)")
print("\n训练生成器 - 判别器梯度应该被阻断:")
for p in D.parameters():
p.requires_grad = False # 冻结判别器
fake_data = G(z)
g_loss = -D(fake_data).sum() # 生成器想要欺骗判别器
print(f" G 损失: {g_loss.item():.4f}")
for p in G.parameters():
p.grad = None
g_loss.backward()
g_has_grad = any(p.grad is not None for p in G.parameters())
d_has_grad = any(p.grad is not None for p in D.parameters())
print(f" 生成器有梯度? {g_has_grad} (应该是 True)")
print(f" 判别器有梯度? {d_has_grad} (应该是 False)")
# ======================================================================
# 技巧 4: 梯度累积实现大批次训练
# ======================================================================
# 模拟批次大小: 32
# 实际小批次: 8
# 累积步数: 4
# 步骤 1: loss=1.5185
# 步骤 2: loss=1.1364
# 步骤 3: loss=1.6141
# 步骤 4: loss=1.4942
# 参数更新完成
print("\n" + "=" * 70)
print("技巧 4: 梯度累积实现大批次训练")
print("=" * 70)
model = nn.Linear(5, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 模拟大批次 (batch_size=32) 但内存只够处理 batch_size=8
real_batch_size = 32
mini_batch_size = 8
accumulation_steps = real_batch_size // mini_batch_size
print(f"模拟批次大小: {real_batch_size}")
print(f"实际小批次: {mini_batch_size}")
print(f"累积步数: {accumulation_steps}")
# 生成数据
data = torch.randn(real_batch_size, 5)
targets = torch.randn(real_batch_size, 3)
optimizer.zero_grad()
for i in range(accumulation_steps):
# 取小批次
start_idx = i * mini_batch_size
end_idx = start_idx + mini_batch_size
mini_batch = data[start_idx:end_idx]
mini_targets = targets[start_idx:end_idx]
# 前向传播
output = model(mini_batch)
loss = ((output - mini_targets) ** 2).mean()
# 累积梯度 (除以累积步数以得到平均梯度)
(loss / accumulation_steps).backward()
print(f" 步骤 {i + 1}: loss={loss.item():.4f}")
# 一次性更新参数
optimizer.step()
print("参数更新完成")
# ======================================================================
# 技巧 5: 多任务学习的梯度平衡
# ======================================================================
# 任务1损失: 2.0551
# 任务2损失: 0.7745
# 任务1梯度范数: 2.5237
# 任务2梯度范数: 1.3568
# 动态权重: w1=0.3496, w2=0.6782
print("\n" + "=" * 70)
print("技巧 5: 多任务学习的梯度平衡")
print("=" * 70)
class MultiTaskModel(nn.Module):
def __init__(self):
super().__init__()
self.shared = nn.Linear(10, 20)
self.task1_head = nn.Linear(20, 5)
self.task2_head = nn.Linear(20, 3)
def forward(self, x):
shared_features = torch.relu(self.shared(x))
out1 = self.task1_head(shared_features)
out2 = self.task2_head(shared_features)
return out1, out2
model = MultiTaskModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
x = torch.randn(4, 10)
y1 = torch.randn(4, 5)
y2 = torch.randn(4, 3)
# 计算两个任务的损失
out1, out2 = model(x)
loss1 = ((out1 - y1) ** 2).mean()
loss2 = ((out2 - y2) ** 2).mean()
print(f"任务1损失: {loss1.item():.4f}")
print(f"任务2损失: {loss2.item():.4f}")
# 方法1: 简单加权
# loss_total = 0.5 * loss1 + 0.5 * loss2
# 方法2: 梯度归一化
optimizer.zero_grad()
# 分别计算梯度
loss1.backward(retain_graph=True)
grad1_norm = torch.cat([p.grad.flatten() for p in model.parameters() if p.grad is not None]).norm()
optimizer.zero_grad()
loss2.backward(retain_graph=True)
grad2_norm = torch.cat([p.grad.flatten() for p in model.parameters() if p.grad is not None]).norm()
print(f"\n任务1梯度范数: {grad1_norm.item():.4f}")
print(f"任务2梯度范数: {grad2_norm.item():.4f}")
# 根据梯度范数动态调整权重
w1 = 1.0 / (grad1_norm.item() + 1e-8)
w2 = 1.0 / (grad2_norm.item() + 1e-8)
w1 = w1 / (w1 + w2)
w2 = w2 / (w1 + w2)
print(f"动态权重: w1={w1:.4f}, w2={w2:.4f}")
optimizer.zero_grad()
loss_total = w1 * loss1 + w2 * loss2
loss_total.backward()
optimizer.step()
# ======================================================================
# 技巧 6: 自定义梯度缩放
# ======================================================================
# 正常反向传播:
# 梯度: tensor([2., 4.])
# 梯度反转:
# 梯度: tensor([-2., -4.]) (注意负号)
print("\n" + "=" * 70)
print("技巧 6: 自定义梯度缩放")
print("=" * 70)
class GradientRescale(Function):
"""自定义梯度缩放,常用于对抗训练"""
@staticmethod
def forward(ctx, x, scale):
ctx.scale = scale
return x
@staticmethod
def backward(ctx, grad_output):
return grad_output * ctx.scale, None
# 应用: 梯度反转层 (用于域适应)
class GradientReversalLayer(Function):
@staticmethod
def forward(ctx, x, alpha):
ctx.alpha = alpha
return x.view_as(x)
@staticmethod
def backward(ctx, grad_output):
return -ctx.alpha * grad_output, None
x = torch.tensor([1.0, 2.0], requires_grad=True)
# 正常梯度
y1 = (x**2).sum()
# 梯度反转
grad_reverse = GradientReversalLayer.apply
y2 = grad_reverse((x**2).sum(), 1.0)
print("正常反向传播:")
y1.backward()
print(f" 梯度: {x.grad}")
x.grad = None
print("\n梯度反转:")
y2.backward()
print(f" 梯度: {x.grad} (注意负号)")
# ======================================================================
# 技巧 7: 条件计算图
# ======================================================================
# 条件为 True:
# 输出: 1.9872
# 梯度: 2.5217
# 条件为 False:
# 输出: 1.0000
# 梯度: 2.0000
print("\n" + "=" * 70)
print("技巧 7: 条件计算图")
print("=" * 70)
def conditional_forward(x, condition):
"""根据条件选择不同的计算路径"""
if condition:
# 路径 A: 复杂计算
return torch.sigmoid(x) * torch.exp(x)
else:
# 路径 B: 简单计算
return x**2
x = torch.tensor([1.0], requires_grad=True)
print("条件为 True:")
y1 = conditional_forward(x, True)
y1.backward()
print(f" 输出: {y1.item():.4f}")
print(f" 梯度: {x.grad.item():.4f}")
x.grad = None
print("\n条件为 False:")
y2 = conditional_forward(x, False)
y2.backward()
print(f" 输出: {y2.item():.4f}")
print(f" 梯度: {x.grad.item():.4f}")
# ======================================================================
# 技巧 8: 参数组的分别优化
# ======================================================================
# 不同层的学习率:
# 组 0: lr=0.001
# 组 1: lr=0.0001
print("\n" + "=" * 70)
print("技巧 8: 参数组的分别优化")
print("=" * 70)
model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 10))
# 为不同层设置不同学习率
optimizer = torch.optim.Adam(
[{'params': model[0].parameters(), 'lr': 1e-3}, {'params': model[2].parameters(), 'lr': 1e-4}]
)
x = torch.randn(5, 10)
y = model(x)
loss = y.sum()
print("不同层的学习率:")
for i, param_group in enumerate(optimizer.param_groups):
print(f" 组 {i}: lr={param_group['lr']}")
loss.backward()
optimizer.step()
print("\n" + "=" * 70)
print("高级技巧总结")
print("=" * 70)
print("1. 混合精度: 关键操作保持 FP32 精度")
print("2. 梯度检查点: 用计算换内存")
print("3. 分离计算图: 控制梯度流向")
print("4. 梯度累积: 模拟大批次训练")
print("5. 多任务平衡: 动态调整任务权重")
print("6. 梯度缩放: 自定义梯度大小")
print("7. 条件计算: 动态选择计算路径")
print("8. 参数组优化: 不同参数不同策略")
print("=" * 70)
# ======================================================================
# 高级技巧总结
# ======================================================================
# 1. 混合精度: 关键操作保持 FP32 精度
# 2. 梯度检查点: 用计算换内存
# 3. 分离计算图: 控制梯度流向
# 4. 梯度累积: 模拟大批次训练
# 5. 多任务平衡: 动态调整任务权重
# 6. 梯度缩放: 自定义梯度大小
# 7. 条件计算: 动态选择计算路径
# 8. 参数组优化: 不同参数不同策略
# ======================================================================
总结与最佳实践
核心 API 速查
python
# 基础操作
tensor.requires_grad_(True) # 启用梯度追踪
tensor.detach() # 分离计算图
with torch.no_grad(): # 临时禁用梯度
# 梯度计算
loss.backward() # 反向传播
torch.autograd.grad(outputs, inputs) # 计算梯度
# 自定义函数
class MyFunc(Function):
@staticmethod
def forward(ctx, ...): # 前向传播
ctx.save_for_backward(...) # 保存张量
@staticmethod
def backward(ctx, grad_output): # 反向传播
return grad_input # 返回梯度
# Hook 机制
tensor.register_hook(hook_fn) # 张量 hook
module.register_forward_hook(fn) # 前向 hook
module.register_full_backward_hook(fn) # 反向 hook
# 高阶导数
grad = autograd.grad(..., create_graph=True) # 创建高阶图
# 梯度检查
gradcheck(func, inputs) # 验证梯度正确性
补充2:深入理解 backward() 的 grad_outputs 参数
1. grad_outputs 的本质
完整的链式法则:
dL/dx = dL/dy × dy/dx
↑ ↑
grad_outputs 当前层计算
(外部提供) (自动求导)
2. 三种典型情况
| 情况 | grad_outputs 值 | 说明 |
|---|---|---|
| 标量损失 | 1.0 (隐式) | loss.backward() |
| 向量中间层 | 来自后续层 | hidden.backward(grad) |
| 手动指定权重 | 用户定义 | y.backward(weights) |
3. 为什么向量输出必须指定?
python
# 问题: y 有多个元素,对 x 的影响如何组合?
x = [x1, x2]
y = [y1, y2] = [x1², x2²]
# PyTorch 需要知道如何组合:
dL/dx1 = dL/dy1 × dy1/dx1 + dL/dy2 × dy2/dx1
↑ ↑
这两个权重需要你指定!
4. 实用场景
场景 1: 多任务学习
python
loss1, loss2 = model(x)
# 使用 grad_outputs 指定任务权重
loss1.backward(torch.tensor(0.7), retain_graph=True)
loss2.backward(torch.tensor(0.3))
场景 2: 计算雅可比矩阵
python
for i in range(n):
grad_output = torch.zeros(n)
grad_output[i] = 1.0 # 单位向量
y.backward(grad_output, retain_graph=True)
场景 3: 自定义梯度流
python
# 只关心某些输出的梯度
mask = torch.tensor([1.0, 0.0, 1.0]) # 忽略第2个输出
y.backward(mask)
5. 常见错误
python
# ❌ 错误: 向量输出不提供 grad_outputs
y = x ** 2 # 向量
y.backward() # RuntimeError!
# ✓ 正确
y.backward(torch.ones_like(y))
# ❌ 错误: 形状不匹配
y = x ** 2 # shape: (2,)
y.backward(torch.tensor([1.0])) # 形状错误
# ✓ 正确
y.backward(torch.tensor([1.0, 1.0]))