PyTorch梯度计算

PyTorch梯度计算

  • 介绍
  • [PyTorch 梯度计算和更新详解](#PyTorch 梯度计算和更新详解)
  • 补充:torch.autograd.Function使用
  • [PyTorch Autograd 高级特性详解](#PyTorch Autograd 高级特性详解)
    • 基础使用
      • [1. **torch.autograd.Function 的结构**](#1. torch.autograd.Function 的结构)
      • [2. **ctx (Context) 对象的重要方法**](#2. ctx (Context) 对象的重要方法)
    • 使用示例
      • [1. **梯度检查 (Gradient Check)**](#1. 梯度检查 (Gradient Check))
      • [2. **高阶导数和 Hessian 矩阵**](#2. 高阶导数和 Hessian 矩阵)
      • [3. **Hook 机制详解**](#3. Hook 机制详解)
      • [4. **实用进阶技巧**](#4. 实用进阶技巧)
    • 总结与最佳实践
      • [**核心 API 速查**](#核心 API 速查)
  • [补充2:深入理解 backward() 的 grad_outputs 参数](#补充2:深入理解 backward() 的 grad_outputs 参数)
      • [**1. grad_outputs 的本质**](#1. grad_outputs 的本质)
      • [**2. 三种典型情况**](#2. 三种典型情况)
      • [**3. 为什么向量输出必须指定?**](#3. 为什么向量输出必须指定?)
      • [**4. 实用场景**](#4. 实用场景)
      • [**5. 常见错误**](#5. 常见错误)

介绍

PyTorch 梯度计算和更新详解

我来详细介绍 PyTorch 的梯度计算和更新机制,并提供易于理解的示例。

一、核心概念

1. 自动微分(Autograd)

PyTorch 使用动态计算图来自动计算梯度。当你对张量进行操作时,PyTorch 会记录这些操作,然后反向传播时自动计算梯度。

2. requires_grad 属性

  • 设置为 True 的张量会追踪所有对它的操作
  • 用于告诉 PyTorch 需要计算这个张量的梯度

3. .grad 属性

  • 存储计算得到的梯度值
  • 默认会累积梯度,需要手动清零

二、基础示例

示例 1:简单的梯度计算

python 复制代码
import torch

# 创建一个需要梯度的张量
x = torch.tensor([2.0], requires_grad=True)

# 定义一个简单函数:y = x^2
y = x ** 2

print(f"x = {x}")
print(f"y = {y}")

# 反向传播计算梯度
y.backward()

# 查看梯度:dy/dx = 2x = 2*2 = 4
print(f"x的梯度: {x.grad}")

输出解释:

  • y = x²
  • dy/dx = 2x = 2×2 = 4

示例 2:多步运算的梯度

python 复制代码
import torch

x = torch.tensor([3.0], requires_grad=True)

# 多步运算:z = (x + 2) * (x^2)
y = x + 2          # y = x + 2 = 5
z = y * (x ** 2)   # z = 5 * 9 = 45

print(f"z = {z}")

z.backward()

# 梯度计算:dz/dx = d[(x+2)*x²]/dx = x² + 2x(x+2) = x² + 2x² + 4x = 3x² + 4x
# 当 x=3: 3*9 + 4*3 = 27 + 12 = 39
print(f"x的梯度: {x.grad}")

三、神经网络训练示例

python 复制代码
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim


print("\n" + "=" * 60)
print("示例 2: 梯度累积问题")
print("=" * 60)

x = torch.tensor([3.0], requires_grad=True)

# 第一次计算
y1 = x ** 2
y1.backward()
print(f"第一次backward后,x.grad = {x.grad.item()}")

# 如果不清零,梯度会累积
y2 = x ** 3
y2.backward()
print(f"第二次backward后(未清零),x.grad = {x.grad.item()}")

# 清零梯度
x.grad.zero_()
y3 = x ** 3
y3.backward()
print(f"清零后再次backward,x.grad = {x.grad.item()}")

print("\n" + "=" * 60)
print("示例 3: 简单线性回归")
print("=" * 60)

# 生成训练数据:y = 3x + 2 + 噪声
torch.manual_seed(42)
x_train = torch.randn(100, 1)
y_label = 3 * x_train + 2 + torch.randn(100, 1) * 0.3

# 定义模型参数
w = torch.tensor([[0.0]], requires_grad=True)
b = torch.tensor([[0.0]], requires_grad=True)

# 训练参数
learning_rate = 0.01
epochs = 100

# 记录损失
losses = []

print("开始训练...")
for epoch in range(epochs):
    # 前向传播
    y_pred = x_train @ w + b
    
    # 计算损失(均方误差)
    loss = ((y_pred - y_label) ** 2).mean()
    
    # 反向传播
    loss.backward()
    
    # 手动更新参数(梯度下降)
    with torch.no_grad():  # 更新参数时不需要计算梯度
        w -= learning_rate * w.grad
        b -= learning_rate * b.grad
    
    # 清零梯度
    w.grad.zero_()
    b.grad.zero_()
    
    losses.append(loss.item())
    
    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch+1:3d}: Loss = {loss.item():.4f}, w = {w.item():.4f}, b = {b.item():.4f}")

print(f"\n最终参数: w = {w.item():.4f}, b = {b.item():.4f}")
print("真实参数: w = 3.0000, b = 2.0000")

print("\n" + "=" * 60)
print("示例 4: 使用优化器的神经网络")
print("=" * 60)

# 定义一个简单的神经网络
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(1, 10)
        self.fc2 = nn.Linear(10, 1)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 创建模型
model = SimpleNN()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练数据:y = sin(x)
x_data = torch.linspace(-3, 3, 100).reshape(-1, 1)
y_data = torch.sin(x_data)

print("使用优化器训练神经网络...")
for epoch in range(200):
    # 前向传播
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    
    # 反向传播
    optimizer.zero_grad()  # 清零梯度
    loss.backward()        # 计算梯度
    optimizer.step()       # 更新参数
    
    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch+1:3d}: Loss = {loss.item():.4f}")

print("\n" + "=" * 60)
print("示例 5: 查看网络中的梯度")
print("=" * 60)

# 创建一个小型网络
net = nn.Sequential(
    nn.Linear(2, 3),
    nn.ReLU(),
    nn.Linear(3, 1)
)

# 前向传播
x = torch.tensor([[1.0, 2.0]])
y = torch.tensor([[1.0]])
output = net(x)
loss = (output - y) ** 2

# 反向传播
loss.backward()

# 查看每层的梯度
print("网络各层的梯度:")
for name, param in net.named_parameters():
    if param.grad is not None:
        print(f"{name}:")
        print(f"  形状: {param.grad.shape}")
        print(f"  梯度: {param.grad}")
        print()

print("=" * 60)
print("关键要点总结:")
print("=" * 60)
print("1. requires_grad=True: 启用梯度追踪")
print("2. .backward(): 计算梯度")
print("3. .grad: 访问梯度值")
print("4. .zero_grad(): 清零梯度(防止累积)")
print("5. with torch.no_grad(): 临时禁用梯度计算")
print("6. optimizer.step(): 使用计算的梯度更新参数")
print("=" * 60)

四、重要概念详解

1. 为什么需要 zero_grad()

梯度默认会累积,如果不清零,新的梯度会加到旧梯度上:

python 复制代码
x = torch.tensor([2.0], requires_grad=True)

y1 = x ** 2
y1.backward()
print(f"第一次:{x.grad}")  # 4.0

y2 = x ** 2
y2.backward()
print(f"第二次(累积):{x.grad}")  # 8.0(4+4)

x.grad.zero_()
y3 = x ** 2
y3.backward()
print(f"清零后:{x.grad}")  # 4.0

2. with torch.no_grad()的作用

在不需要计算梯度的地方(如评估、推理、手动更新参数)使用,可以节省内存和计算:

python 复制代码
# 推理时不需要梯度
with torch.no_grad():
    predictions = model(test_data)

# 手动更新参数时不需要梯度
with torch.no_grad():
    w -= learning_rate * w.grad

五、调试技巧

检查梯度是否计算:

python 复制代码
# 检查张量是否需要梯度
print(x.requires_grad)

# 检查梯度是否已计算
print(x.grad)

# 查看计算图
print(y.grad_fn)

# 分离计算图(不再追踪梯度)
x_detached = x.detach()

梯度裁剪(防止梯度爆炸):

python 复制代码
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

总结

PyTorch 的梯度计算遵循以下流程:

  1. 设置 requires_grad=True 追踪计算
  2. 前向传播执行计算
  3. 调用 .backward() 自动计算梯度
  4. 使用梯度更新参数
  5. 清零梯度准备下一次迭代

这个自动微分机制让我们无需手动推导复杂的梯度公式,极大简化了深度学习模型的开发!

补充:torch.autograd.Function使用

PyTorch Autograd 高级特性详解

PyTorch 的高级自动微分特性,包括自定义反向传播、梯度检查、高阶导数等。## 深入解析关键概念

基础使用

1. torch.autograd.Function 的结构

python 复制代码
class CustomFunction(Function):
    @staticmethod
    def forward(ctx, *args):
        # ctx: 上下文对象
        # *args: 可变数量的输入
        # 返回: 前向传播的输出
        pass
    
    @staticmethod
    def backward(ctx, *grad_outputs):
        # *grad_outputs: 从后续层传来的梯度
        # 返回: 对每个输入的梯度 (顺序与forward的输入对应)
        # 如果某个输入不需要梯度,返回 None
        pass

2. ctx (Context) 对象的重要方法

python 复制代码
# 保存张量 (会自动处理梯度追踪)
ctx.save_for_backward(tensor1, tensor2, ...)

# 保存非张量数据
ctx.constant = value
ctx.mark_non_differentiable(tensor)  # 标记输出不需要梯度

# 标记输入是否需要梯度 (优化性能)
ctx.needs_input_grad[0]  # 检查第一个输入是否需要梯度

使用示例

1. 梯度检查 (Gradient Check)

数值方法验证解析梯度的正确性:

python 复制代码
import numpy as np
import torch
from torch.autograd import Function, gradcheck

print("=" * 70)
print("梯度检查原理详解")
print("=" * 70)

def numerical_gradient(f, x, eps=1e-5):
    """
    使用有限差分法计算数值梯度
    f(x + h) - f(x - h)
    ---------------------
           2h
    """
    grad = torch.zeros_like(x)
    
    for i in range(x.numel()):
        # 创建扰动
        x_plus = x.clone()
        x_minus = x.clone()
        x_plus.view(-1)[i] += eps
        x_minus.view(-1)[i] -= eps
        
        # 计算数值梯度
        f_plus = f(x_plus)
        f_minus = f(x_minus)
        grad.view(-1)[i] = (f_plus - f_minus) / (2 * eps)
    
    return grad

# 定义一个简单函数
def func(x):
    return (x ** 2).sum()

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# 解析梯度
y = func(x)
y.backward()
analytical_grad = x.grad.clone()

# 数值梯度
x_no_grad = x.detach().clone()
numerical_grad = numerical_gradient(func, x_no_grad)

print("函数: f(x) = sum(x^2)")
print(f"输入: {x.data}")
print(f"\n解析梯度 (自动微分): {analytical_grad}")
print(f"数值梯度 (有限差分): {numerical_grad}")
print(f"差异: {(analytical_grad - numerical_grad).abs().max().item():.2e}")

print("\n" + "=" * 70)
print("使用 gradcheck 自动验证")
print("=" * 70)

class MyExp(Function):
    @staticmethod
    def forward(ctx, x):
        result = torch.exp(x)
        ctx.save_for_backward(result)
        return result
    
    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result

# 使用 double 精度以提高数值稳定性
x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)

print("测试自定义 Exp 函数...")
test_result = gradcheck(MyExp.apply, x, eps=1e-6, atol=1e-4)
print(f"结果: {'✓ 通过' if test_result else '✗ 失败'}")

print("\n" + "=" * 70)
print("常见梯度错误示例")
print("=" * 70)

class WrongGradient(Function):
    """故意写错的梯度"""
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2
    
    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # 错误: 应该是 2*x,这里写成 3*x
        return grad_output * 3 * x

print("测试错误的梯度实现...")
x = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
try:
    test_result = gradcheck(WrongGradient.apply, x, eps=1e-6, atol=1e-4)
    print(f"结果: {'✓ 通过' if test_result else '✗ 失败'}")
except RuntimeError as e:
    print("✗ 检测到错误: gradcheck 失败")
    print("  说明梯度实现有误!")

print("\n" + "=" * 70)
print("复杂函数的梯度验证")
print("=" * 70)

class ComplexFunction(Function):
    """
    复杂函数: f(x, y) = sin(x) * exp(y) + x^2 * y
    """
    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return torch.sin(x) * torch.exp(y) + x**2 * y
    
    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        
        # df/dx = cos(x)*exp(y) + 2*x*y
        grad_x = grad_output * (torch.cos(x) * torch.exp(y) + 2 * x * y)
        
        # df/dy = sin(x)*exp(y) + x^2
        grad_y = grad_output * (torch.sin(x) * torch.exp(y) + x**2)
        
        return grad_x, grad_y

x = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
y = torch.randn(2, 3, dtype=torch.double, requires_grad=True)

print("测试复杂函数...")
test_result = gradcheck(ComplexFunction.apply, (x, y), eps=1e-6, atol=1e-4)
print(f"结果: {'✓ 通过' if test_result else '✗ 失败'}")

print("\n" + "=" * 70)
print("梯度检查最佳实践")
print("=" * 70)
print("1. 使用 double 精度 (torch.double)")
print("2. 设置合适的 eps (通常 1e-6)")
print("3. 设置合适的 atol (绝对容差,通常 1e-4)")
print("4. 测试多个随机输入")
print("5. 在开发阶段测试,生产环境可以关闭")
print("=" * 70)

2. 高阶导数和 Hessian 矩阵

python 复制代码
import torch
import torch.autograd as autograd

# ======================================================================
# 示例 1: 计算高阶导数
# ======================================================================
# f(x) = x^4, x=2
# f(2) = 16.0
# f'(2) = 4*2^3 = 32.0
# f''(2) = 12*2^2 = 48.0
# f'''(2) = 24*2 = 48.0
# f''''(2) = 24 = 24.0
print("=" * 70)
print("示例 1: 计算高阶导数")
print("=" * 70)

# 函数: f(x) = x^4
x = torch.tensor([2.0], requires_grad=True)

# 零阶: f(x)
y = x ** 4
print("f(x) = x^4, x=2")
print(f"f(2) = {y.item()}")

# 一阶导数: f'(x) = 4x^3
dy_dx = autograd.grad(y, x, create_graph=True)[0]
print(f"f'(2) = 4*2^3 = {dy_dx.item()}")

# 二阶导数: f''(x) = 12x^2
d2y_dx2 = autograd.grad(dy_dx, x, create_graph=True)[0]
print(f"f''(2) = 12*2^2 = {d2y_dx2.item()}")

# 三阶导数: f'''(x) = 24x
d3y_dx3 = autograd.grad(d2y_dx2, x, create_graph=True)[0]
print(f"f'''(2) = 24*2 = {d3y_dx3.item()}")

# 四阶导数: f''''(x) = 24
d4y_dx4 = autograd.grad(d3y_dx3, x)[0]
print(f"f''''(2) = 24 = {d4y_dx4.item()}")

# ======================================================================
# 示例 2: 计算 Hessian 矩阵
# ======================================================================
# 函数: f(x,y) = x² + 2xy + 3y²
# 在点 x=1.0, y=2.0
# f = 17.0
# Hessian 矩阵:
# tensor([[2., 2.],
#         [2., 6.]])
# 理论 Hessian:
#   ∂²f/∂x² = 2
#   ∂²f/∂x∂y = 2
#   ∂²f/∂y∂x = 2
#   ∂²f/∂y² = 6
print("\n" + "=" * 70)
print("示例 2: 计算 Hessian 矩阵")
print("=" * 70)

def compute_hessian(f, x):
    """
    计算标量函数 f 关于向量 x 的 Hessian 矩阵
    Hessian[i,j] = ∂²f/∂xi∂xj
    """
    n = x.shape[0]
    hessian = torch.zeros(n, n)
    
    # 计算一阶导数
    grad = autograd.grad(f, x, create_graph=True)[0]
    
    # 对每个一阶导数分量再求导
    for i in range(n):
        grad2 = autograd.grad(grad[i], x, retain_graph=True)[0]
        hessian[i] = grad2
    
    return hessian

# 函数: f(x, y) = x^2 + 2xy + 3y^2
x = torch.tensor([1.0, 2.0], requires_grad=True)
f = x[0]**2 + 2*x[0]*x[1] + 3*x[1]**2

print("函数: f(x,y) = x² + 2xy + 3y²")
print(f"在点 x={x[0].item()}, y={x[1].item()}")
print(f"f = {f.item()}")

hessian = compute_hessian(f, x)
print("\nHessian 矩阵:")
print(hessian)

print("\n理论 Hessian:")
print("  ∂²f/∂x² = 2")
print("  ∂²f/∂x∂y = 2")
print("  ∂²f/∂y∂x = 2")
print("  ∂²f/∂y² = 6")


# ======================================================================
# 示例 3: 使用 functorch 高效计算 Hessian
# ======================================================================
# Rosenbrock 函数在 x=tensor([0.5000, 0.5000])
# f = 6.5000
# 梯度: tensor([-51.,  50.])
# Hessian 矩阵:
# tensor([[ 102., -200.],
#         [-200.,  200.]])
print("\n" + "=" * 70)
print("示例 3: 使用 functorch 高效计算 Hessian")
print("=" * 70)

def rosenbrock(x):
    """Rosenbrock 函数: f(x,y) = (1-x)² + 100(y-x²)²"""
    return (1 - x[0])**2 + 100 * (x[1] - x[0]**2)**2

x = torch.tensor([0.5, 0.5], requires_grad=True)
f_val = rosenbrock(x)

print(f"Rosenbrock 函数在 x={x.data}")
print(f"f = {f_val.item():.4f}")

# 计算梯度
grad = autograd.grad(f_val, x, create_graph=True)[0]
print(f"\n梯度: {grad.data}")

# 计算 Hessian
hessian = compute_hessian(f_val, x)
print("\nHessian 矩阵:")
print(hessian)

# ======================================================================
# 示例 4: Jacobian 向量积 (JVP)
# ======================================================================
# 输入 x = tensor([2., 3.])
# 方向 v = tensor([1., 1.])
# 输出 f(x) = tensor([7., 6.])
# JVP (J@v) = tensor([7., 3.])
# 手动计算验证:
# J = [[2*x[0], 1    ],
#      [x[1],   x[0] ]]
# J = [[4.0, 1.0],
#      [3.0, 2.0]]
# v^T @ J = [7.0, 3.0]
print("\n" + "=" * 70)
print("示例 4: Jacobian 向量积 (JVP)")
print("=" * 70)

def f(x):
    """向量函数 R² -> R²"""
    return torch.stack([
        x[0]**2 + x[1],
        x[0] * x[1]
    ])

x = torch.tensor([2.0, 3.0], requires_grad=True)
v = torch.tensor([1.0, 1.0])  # 方向向量

# 计算 JVP: J(x) @ v
y = f(x)
jvp = autograd.grad(y, x, grad_outputs=v, create_graph=True)[0]

print(f"输入 x = {x.data}")
print(f"方向 v = {v}")
print(f"输出 f(x) = {y.data}")
print(f"JVP (J@v) = {jvp.data}")

print("\n手动计算验证:")
print("J = [[2*x[0], 1    ],")
print("     [x[1],   x[0] ]]")
print(f"J = [[{2*x[0].item():.1f}, {1.0:.1f}],")
print(f"     [{x[1].item():.1f}, {x[0].item():.1f}]]")
print(f"v^T @ J = [{v[0]*2*x[0].item() + v[1]*x[1].item():.1f}, {v[0]*1.0 + v[1]*x[0].item():.1f}]")

# ======================================================================
# 示例 5: 向量 Jacobian 积 (VJP)
# ======================================================================
# 输入 x = tensor([2., 3.])
# 方向 v = tensor([1., 1.])
# VJP (v^T@J) = tensor([7., 3.])
# 手动计算验证:
# v^T@J = [1, 1] @ J = [7.0, 3.0]
print("\n" + "=" * 70)
print("示例 5: 向量 Jacobian 积 (VJP)")
print("=" * 70)

x = torch.tensor([2.0, 3.0], requires_grad=True)
v = torch.tensor([1.0, 1.0])

y = f(x)

# 计算 VJP: v^T @ J(x)
vjp = autograd.grad(y, x, grad_outputs=v)[0]

print(f"输入 x = {x.data}")
print(f"方向 v = {v}")
print(f"VJP (v^T@J) = {vjp.data}")

print("\n手动计算验证:")
print(f"v^T@J = [1, 1] @ J = [{2*x[0].item() + x[1].item():.1f}, {1 + x[0].item():.1f}]")

# ======================================================================
# 示例 6: 拉普拉斯算子 (Laplacian)
# ======================================================================
# 函数: f(x,y) = x² + y²
# 拉普拉斯: ∇²f = ∂²f/∂x² + ∂²f/∂y² = 4.0
# 理论值: 2 + 2 = 4
print("\n" + "=" * 70)
print("示例 6: 拉普拉斯算子 (Laplacian)")
print("=" * 70)

def compute_laplacian(f, x):
    """
    计算标量函数的拉普拉斯: ∇²f = Σ ∂²f/∂xi²
    """
    grad = autograd.grad(f, x, create_graph=True)[0]
    
    laplacian = 0
    for i in range(x.shape[0]):
        grad2 = autograd.grad(grad[i], x, retain_graph=True)[0]
        laplacian += grad2[i]
    
    return laplacian

# 函数: f(x,y) = x² + y²
x = torch.tensor([1.0, 2.0], requires_grad=True)
f = (x ** 2).sum()

laplacian = compute_laplacian(f, x)
print("函数: f(x,y) = x² + y²")
print(f"拉普拉斯: ∇²f = ∂²f/∂x² + ∂²f/∂y² = {laplacian.item()}")
print("理论值: 2 + 2 = 4")

# ======================================================================
# 示例 7: 牛顿法优化 (使用 Hessian)
# ======================================================================
# 使用牛顿法优化 Rosenbrock 函数
# 目标: 找到最小值点 (1, 1)
# 迭代 0: x=tensor([1., 0.]), f(x)=100.000000
# 迭代 2: x=tensor([1., 1.]), f(x)=0.000000
# 迭代 4: x=tensor([1., 1.]), f(x)=0.000000
# 最优解: tensor([1., 1.])
# 函数值: 0.000000
print("\n" + "=" * 70)
print("示例 7: 牛顿法优化 (使用 Hessian)")
print("=" * 70)

def newton_method(f, x0, n_iterations=10):
    """使用牛顿法最小化函数"""
    x = x0.clone().requires_grad_(True)
    
    history = [x0.clone().detach()]
    
    for i in range(n_iterations):
        # 计算函数值
        f_val = f(x)
        
        # 计算梯度
        grad = autograd.grad(f_val, x, create_graph=True)[0]
        
        # 计算 Hessian
        hessian = compute_hessian(f_val, x)
        
        # 牛顿更新: x = x - H^(-1) @ g
        with torch.no_grad():
            delta = torch.linalg.solve(hessian, grad)
            x = x - delta
            x.requires_grad_(True)
        
        history.append(x.clone().detach())
        
        if i % 2 == 0:
            print(f"迭代 {i}: x={x.data}, f(x)={f(x).item():.6f}")
    
    return x.detach(), history

# 优化 Rosenbrock 函数
x0 = torch.tensor([0.0, 0.0], requires_grad=True)
print("使用牛顿法优化 Rosenbrock 函数")
print("目标: 找到最小值点 (1, 1)")

x_opt, history = newton_method(rosenbrock, x0, n_iterations=6)
print(f"\n最优解: {x_opt}")
print(f"函数值: {rosenbrock(x_opt).item():.6f}")
print("理论最优: [1, 1], f=0")

print("\n" + "=" * 70)
print("关键概念总结")
print("=" * 70)
print("1. create_graph=True: 创建高阶导数的计算图")
print("2. retain_graph=True: 保留图以便多次 backward")
print("3. Hessian: 二阶导数矩阵,用于优化和分析")
print("4. JVP: Jacobian-Vector Product, 高效的前向模式")
print("5. VJP: Vector-Jacobian Product, 高效的反向模式")
print("6. 拉普拉斯: 二阶导数的迹,物理中常用")
print("7. 牛顿法: 使用 Hessian 的二阶优化方法")
print("=" * 70)

3. Hook 机制详解

python 复制代码
import torch
import torch.nn as nn

print("=" * 70)
print("示例 1: Tensor Hook - 监控梯度")
print("=" * 70)

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = torch.tensor([4.0, 5.0, 6.0], requires_grad=True)


# 定义 hook 函数
def grad_hook(grad):
    print(f"  梯度被计算: {grad}")
    # 可以修改梯度
    return grad * 2  # 梯度翻倍


# 注册 hook
hook_handle = x.register_hook(grad_hook)

z = (x * y).sum()
print(f"z = (x * y).sum() = {z.item()}")

print("\n反向传播:")
z.backward()

print(f"\nx 的最终梯度: {x.grad}")
print(f"y 的最终梯度: {y.grad}")
print("注意: x 的梯度被翻倍了!")

# 移除 hook
hook_handle.remove()

print("\n" + "=" * 70)
print("示例 2: Module Hook - 前向传播")
print("=" * 70)


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 4)
        self.fc2 = nn.Linear(4, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


model = SimpleNet()

# 存储中间激活值
activations = {}


def forward_hook(module, input, output):
    """前向传播 hook"""
    # module: 当前层
    # input: 输入 (tuple)
    # output: 输出
    name = module.__class__.__name__
    activations[name] = output.detach()
    print(f"  {name}: input shape={input[0].shape}, output shape={output.shape}")


# 为每一层注册 hook
for name, layer in model.named_modules():
    if isinstance(layer, nn.Linear):
        layer.register_forward_hook(forward_hook)

x = torch.randn(2, 3)
print("前向传播:")
output = model(x)

print("\n保存的激活值:")
for name, act in activations.items():
    print(f"  {name}: shape={act.shape}, mean={act.mean().item():.4f}")

print("\n" + "=" * 70)
print("示例 3: Module Hook - 反向传播")
print("=" * 70)

gradients = {}


def backward_hook(module, grad_input, grad_output):
    """反向传播 hook"""
    # grad_input: 相对于输入的梯度 (tuple)
    # grad_output: 相对于输出的梯度 (tuple)
    name = module.__class__.__name__
    gradients[name] = {
        'grad_output': grad_output[0].detach() if grad_output[0] is not None else None,
        'grad_input': grad_input[0].detach() if grad_input[0] is not None else None,
    }
    print(f"  {name}: grad_output shape={grad_output[0].shape if grad_output[0] is not None else None}")


model = SimpleNet()

# 注册反向 hook
for name, layer in model.named_modules():
    if isinstance(layer, nn.Linear):
        layer.register_full_backward_hook(backward_hook)

x = torch.randn(2, 3)
output = model(x)
loss = output.sum()

print("反向传播:")
loss.backward()

print("\n保存的梯度信息:")
for name, grads in gradients.items():
    print(f"  {name}:")
    if grads['grad_output'] is not None:
        print(f"    grad_output shape: {grads['grad_output'].shape}")
    if grads['grad_input'] is not None:
        print(f"    grad_input shape: {grads['grad_input'].shape}")

print("\n" + "=" * 70)
print("示例 4: 梯度裁剪 Hook")
print("=" * 70)


def gradient_clipping_hook(grad, max_norm=1.0):
    """梯度裁剪 hook"""
    norm = grad.norm()
    if norm > max_norm:
        print(f"  梯度范数 {norm:.4f} 超过 {max_norm}, 进行裁剪")
        return grad * (max_norm / norm)
    return grad


x = torch.tensor([10.0], requires_grad=True)
x.register_hook(lambda grad: gradient_clipping_hook(grad, max_norm=1.0))

y = x**3
print("y = x^3, x=10")
print(f"y = {y.item()}")

print("\n反向传播 (带梯度裁剪):")
y.backward()
print(f"裁剪后的梯度: {x.grad.item():.4f}")
print(f"原始梯度应该是: 3*x^2 = {3 * 10**2}")

print("\n" + "=" * 70)
print("示例 5: 梯度监控和可视化")
print("=" * 70)


class GradientMonitor:
    """梯度监控器"""

    def __init__(self):
        self.gradients = {}
        self.hooks = []

    def register(self, model):
        """为模型所有参数注册 hook"""
        for name, param in model.named_parameters():
            if param.requires_grad:
                hook = param.register_hook(lambda grad, name=name: self.save_grad(name, grad))
                self.hooks.append(hook)

    def save_grad(self, name, grad):
        """保存梯度统计"""
        self.gradients[name] = {
            'mean': grad.mean().item(),
            'std': grad.std().item(),
            'min': grad.min().item(),
            'max': grad.max().item(),
            'norm': grad.norm().item(),
        }

    def report(self):
        """报告梯度统计"""
        print("\n梯度统计报告:")
        print("-" * 70)
        print(f"{'层名':<20} {'均值':<10} {'标准差':<10} {'范数':<10}")
        print("-" * 70)
        for name, stats in self.gradients.items():
            print(f"{name:<20} {stats['mean']:<10.4f} {stats['std']:<10.4f} {stats['norm']:<10.4f}")

    def clear(self):
        """清除保存的梯度"""
        self.gradients.clear()

    def remove_hooks(self):
        """移除所有 hooks"""
        for hook in self.hooks:
            hook.remove()


# 创建模型和监控器
model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 10), nn.ReLU(), nn.Linear(10, 1))

monitor = GradientMonitor()
monitor.register(model)

# 训练一步
x = torch.randn(5, 10)
y = torch.randn(5, 1)
output = model(x)
loss = ((output - y) ** 2).mean()

print("前向传播完成,开始反向传播...")
loss.backward()

monitor.report()

print("\n" + "=" * 70)
print("示例 6: Pre-Hook (输入修改)")
print("=" * 70)


def pre_hook(module, input):
    """前向传播前的 hook,可以修改输入"""
    print(f"  原始输入范围: [{input[0].min():.2f}, {input[0].max():.2f}]")
    # 归一化输入
    normalized = (input[0] - input[0].mean()) / (input[0].std() + 1e-8)
    print(f"  归一化后范围: [{normalized.min():.2f}, {normalized.max():.2f}]")
    return (normalized,)


model = nn.Linear(5, 3)
model.register_forward_pre_hook(pre_hook)

x = torch.randn(2, 5) * 100  # 大范围的输入
print("前向传播 (带输入归一化):")
output = model(x)

print("\n" + "=" * 70)
print("示例 7: 梯度累积检测")
print("=" * 70)


class GradientAccumulationDetector:
    """检测意外的梯度累积"""

    def __init__(self):
        self.grad_counts = {}

    def register(self, model):
        for name, param in model.named_parameters():
            if param.requires_grad:
                param.register_hook(lambda grad, name=name: self.check_accumulation(name, grad))

    def check_accumulation(self, name, grad):
        self.grad_counts[name] = self.grad_counts.get(name, 0) + 1
        if self.grad_counts[name] > 1:
            print(f"  ⚠️  警告: {name} 的梯度被计算了 {self.grad_counts[name]} 次!")

    def reset(self):
        self.grad_counts.clear()


model = nn.Linear(3, 2)
detector = GradientAccumulationDetector()
detector.register(model)

x = torch.randn(1, 3)

print("第一次反向传播:")
loss1 = model(x).sum()
loss1.backward()

print("\n第二次反向传播 (未清零梯度):")
loss2 = model(x).sum()
loss2.backward()

print("\n检测到梯度累积问题!")
print("解决方案: 使用 optimizer.zero_grad() 或 param.grad.zero_()")

print("\n" + "=" * 70)
print("示例 8: 条件梯度修改")
print("=" * 70)


def conditional_grad_modification(grad, threshold=1.0):
    """根据条件修改梯度"""
    # 对异常大的梯度进行标记和修改
    large_grad_mask = grad.abs() > threshold
    if large_grad_mask.any():
        print(f"  检测到 {large_grad_mask.sum().item()} 个异常梯度")
        # 将大梯度裁剪到阈值
        grad = torch.where(large_grad_mask, grad.sign() * threshold, grad)
    return grad


x = torch.tensor([1.0, 2.0, 10.0], requires_grad=True)
x.register_hook(lambda grad: conditional_grad_modification(grad, threshold=5.0))

y = (x**3).sum()
print("反向传播 (条件梯度修改):")
y.backward()
print(f"修改后的梯度: {x.grad}")
print("原始梯度应该是: [3, 12, 300]")

print("\n" + "=" * 70)
print("Hook 使用最佳实践")
print("=" * 70)
print("1. 记得移除不再需要的 hook (调用 handle.remove())")
print("2. Hook 函数应该简洁高效,避免复杂计算")
print("3. 注意 hook 的执行顺序 (按注册顺序)")
print("4. 修改梯度时要小心,可能影响训练稳定性")
print("5. 使用 hook 进行调试,但生产环境中移除")
print("6. Forward hook 不能修改输出,只能观察")
print("7. Backward hook 可以修改梯度,需谨慎使用")
print("=" * 70)

4. 实用进阶技巧

python 复制代码
import time

import torch
import torch.nn as nn
from torch.autograd import Function

# ======================================================================
# 技巧 1: 混合精度训练中的自定义函数
# ======================================================================
# 输入 (FP16): tensor([ 0.1000,  1.0000, 10.0000], dtype=torch.float16, requires_grad=True)
# 输出 (FP16): tensor([-2.3027,  0.0000,  2.3027], dtype=torch.float16,
#        grad_fn=<MixedPrecisionOperationBackward>)
# 梯度 (FP16): tensor([10.0000,  1.0000,  0.1000], dtype=torch.float16)
print("=" * 70)
print("技巧 1: 混合精度训练中的自定义函数")
print("=" * 70)


class MixedPrecisionOperation(Function):
    """在 FP16 训练中保持某些操作的 FP32 精度"""

    @staticmethod
    def forward(ctx, x):
        # 转换到 FP32 进行高精度计算
        x_fp32 = x.float()
        result = torch.log(x_fp32 + 1e-8)  # 数值稳定的 log
        ctx.save_for_backward(x)
        return result.to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # 在 FP32 中计算梯度
        x_fp32 = x.float()
        grad_input = grad_output.float() / (x_fp32 + 1e-8)
        return grad_input.to(x.dtype)


x = torch.tensor([0.1, 1.0, 10.0], dtype=torch.float16, requires_grad=True)
y = MixedPrecisionOperation.apply(x)

print(f"输入 (FP16): {x}")
print(f"输出 (FP16): {y}")

y.sum().backward()
print(f"梯度 (FP16): {x.grad}")

# ======================================================================
# 技巧 2: 梯度检查点 (Gradient Checkpointing)
# ======================================================================
# 不使用检查点:
#   时间: 0.1083s

# 使用检查点:
#   时间: 0.3041s
#   (可能更慢,但显著节省内存)
print("\n" + "=" * 70)
print("技巧 2: 梯度检查点 (Gradient Checkpointing)")
print("=" * 70)


def checkpoint_function(func, *args):
    """
    简化版的梯度检查点
    前向传播时不保存中间激活,反向传播时重新计算
    """

    class CheckpointFunction(Function):
        @staticmethod
        def forward(ctx, *inputs):
            # 不保存中间结果
            ctx.func = func
            with torch.no_grad():
                outputs = func(*inputs)
            # 只保存输入用于重新计算
            ctx.save_for_backward(*inputs)
            return outputs

        @staticmethod
        def backward(ctx, *grad_outputs):
            inputs = ctx.saved_tensors
            # 重新计算前向传播(这次保存计算图)
            with torch.enable_grad():
                detached_inputs = [x.detach().requires_grad_(True) for x in inputs]
                outputs = ctx.func(*detached_inputs)

            # 计算梯度
            torch.autograd.backward(outputs, grad_outputs)
            return tuple(x.grad for x in detached_inputs)

    return CheckpointFunction.apply(*args)


# 示例: 深层网络
def heavy_computation(x):
    """模拟计算密集型操作"""
    for _ in range(100):
        x = torch.relu(x)
        x = x + 0.01
    return x


x = torch.randn(1000, 1000, requires_grad=True)

print("不使用检查点:")
start = time.time()
y = heavy_computation(x)
loss = y.sum()
loss.backward()
time_normal = time.time() - start
print(f"  时间: {time_normal:.4f}s")

x.grad = None
print("\n使用检查点:")
start = time.time()
y = checkpoint_function(heavy_computation, x)
loss = y.sum()
loss.backward()
time_checkpoint = time.time() - start
print(f"  时间: {time_checkpoint:.4f}s")
print("  (可能更慢,但显著节省内存)")

# ======================================================================
# 技巧 3: 分离部分计算图
# ======================================================================
# 训练判别器 - 生成器梯度应该被阻断:
#   D 损失: 2.4889
#   生成器有梯度? False (应该是 False)
# 训练生成器 - 判别器梯度应该被阻断:
#   G 损失: -2.4889
#   生成器有梯度? True (应该是 True)
#   判别器有梯度? True (应该是 False)
print("\n" + "=" * 70)
print("技巧 3: 分离部分计算图")
print("=" * 70)


# 场景: 生成对抗网络 (GAN) 训练
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 20)

    def forward(self, x):
        return torch.tanh(self.fc(x))


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(20, 1)

    def forward(self, x):
        return torch.sigmoid(self.fc(x))


G = Generator()
D = Discriminator()

z = torch.randn(5, 10)

print("训练判别器 - 生成器梯度应该被阻断:")
fake_data = G(z)
# 分离生成器输出,避免梯度流回生成器
fake_data_detached = fake_data.detach()
d_loss = D(fake_data_detached).sum()
print(f"  D 损失: {d_loss.item():.4f}")
d_loss.backward()

# 检查生成器参数是否有梯度
g_has_grad = any(p.grad is not None for p in G.parameters())
print(f"  生成器有梯度? {g_has_grad} (应该是 False)")

print("\n训练生成器 - 判别器梯度应该被阻断:")
for p in D.parameters():
    p.requires_grad = False  # 冻结判别器

fake_data = G(z)
g_loss = -D(fake_data).sum()  # 生成器想要欺骗判别器
print(f"  G 损失: {g_loss.item():.4f}")

for p in G.parameters():
    p.grad = None
g_loss.backward()

g_has_grad = any(p.grad is not None for p in G.parameters())
d_has_grad = any(p.grad is not None for p in D.parameters())
print(f"  生成器有梯度? {g_has_grad} (应该是 True)")
print(f"  判别器有梯度? {d_has_grad} (应该是 False)")

# ======================================================================
# 技巧 4: 梯度累积实现大批次训练
# ======================================================================
# 模拟批次大小: 32
# 实际小批次: 8
# 累积步数: 4
#   步骤 1: loss=1.5185
#   步骤 2: loss=1.1364
#   步骤 3: loss=1.6141
#   步骤 4: loss=1.4942
# 参数更新完成
print("\n" + "=" * 70)
print("技巧 4: 梯度累积实现大批次训练")
print("=" * 70)

model = nn.Linear(5, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 模拟大批次 (batch_size=32) 但内存只够处理 batch_size=8
real_batch_size = 32
mini_batch_size = 8
accumulation_steps = real_batch_size // mini_batch_size

print(f"模拟批次大小: {real_batch_size}")
print(f"实际小批次: {mini_batch_size}")
print(f"累积步数: {accumulation_steps}")

# 生成数据
data = torch.randn(real_batch_size, 5)
targets = torch.randn(real_batch_size, 3)

optimizer.zero_grad()

for i in range(accumulation_steps):
    # 取小批次
    start_idx = i * mini_batch_size
    end_idx = start_idx + mini_batch_size

    mini_batch = data[start_idx:end_idx]
    mini_targets = targets[start_idx:end_idx]

    # 前向传播
    output = model(mini_batch)
    loss = ((output - mini_targets) ** 2).mean()

    # 累积梯度 (除以累积步数以得到平均梯度)
    (loss / accumulation_steps).backward()

    print(f"  步骤 {i + 1}: loss={loss.item():.4f}")

# 一次性更新参数
optimizer.step()
print("参数更新完成")

# ======================================================================
# 技巧 5: 多任务学习的梯度平衡
# ======================================================================
# 任务1损失: 2.0551
# 任务2损失: 0.7745
# 任务1梯度范数: 2.5237
# 任务2梯度范数: 1.3568
# 动态权重: w1=0.3496, w2=0.6782
print("\n" + "=" * 70)
print("技巧 5: 多任务学习的梯度平衡")
print("=" * 70)


class MultiTaskModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = nn.Linear(10, 20)
        self.task1_head = nn.Linear(20, 5)
        self.task2_head = nn.Linear(20, 3)

    def forward(self, x):
        shared_features = torch.relu(self.shared(x))
        out1 = self.task1_head(shared_features)
        out2 = self.task2_head(shared_features)
        return out1, out2


model = MultiTaskModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

x = torch.randn(4, 10)
y1 = torch.randn(4, 5)
y2 = torch.randn(4, 3)

# 计算两个任务的损失
out1, out2 = model(x)
loss1 = ((out1 - y1) ** 2).mean()
loss2 = ((out2 - y2) ** 2).mean()

print(f"任务1损失: {loss1.item():.4f}")
print(f"任务2损失: {loss2.item():.4f}")

# 方法1: 简单加权
# loss_total = 0.5 * loss1 + 0.5 * loss2

# 方法2: 梯度归一化
optimizer.zero_grad()

# 分别计算梯度
loss1.backward(retain_graph=True)
grad1_norm = torch.cat([p.grad.flatten() for p in model.parameters() if p.grad is not None]).norm()

optimizer.zero_grad()
loss2.backward(retain_graph=True)
grad2_norm = torch.cat([p.grad.flatten() for p in model.parameters() if p.grad is not None]).norm()

print(f"\n任务1梯度范数: {grad1_norm.item():.4f}")
print(f"任务2梯度范数: {grad2_norm.item():.4f}")

# 根据梯度范数动态调整权重
w1 = 1.0 / (grad1_norm.item() + 1e-8)
w2 = 1.0 / (grad2_norm.item() + 1e-8)
w1 = w1 / (w1 + w2)
w2 = w2 / (w1 + w2)

print(f"动态权重: w1={w1:.4f}, w2={w2:.4f}")

optimizer.zero_grad()
loss_total = w1 * loss1 + w2 * loss2
loss_total.backward()
optimizer.step()

# ======================================================================
# 技巧 6: 自定义梯度缩放
# ======================================================================
# 正常反向传播:
#   梯度: tensor([2., 4.])
# 梯度反转:
#   梯度: tensor([-2., -4.]) (注意负号)
print("\n" + "=" * 70)
print("技巧 6: 自定义梯度缩放")
print("=" * 70)


class GradientRescale(Function):
    """自定义梯度缩放,常用于对抗训练"""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.scale, None


# 应用: 梯度反转层 (用于域适应)
class GradientReversalLayer(Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.alpha * grad_output, None


x = torch.tensor([1.0, 2.0], requires_grad=True)

# 正常梯度
y1 = (x**2).sum()

# 梯度反转
grad_reverse = GradientReversalLayer.apply
y2 = grad_reverse((x**2).sum(), 1.0)

print("正常反向传播:")
y1.backward()
print(f"  梯度: {x.grad}")

x.grad = None
print("\n梯度反转:")
y2.backward()
print(f"  梯度: {x.grad} (注意负号)")

# ======================================================================
# 技巧 7: 条件计算图
# ======================================================================
# 条件为 True:
#   输出: 1.9872
#   梯度: 2.5217
# 条件为 False:
#   输出: 1.0000
#   梯度: 2.0000
print("\n" + "=" * 70)
print("技巧 7: 条件计算图")
print("=" * 70)


def conditional_forward(x, condition):
    """根据条件选择不同的计算路径"""
    if condition:
        # 路径 A: 复杂计算
        return torch.sigmoid(x) * torch.exp(x)
    else:
        # 路径 B: 简单计算
        return x**2


x = torch.tensor([1.0], requires_grad=True)

print("条件为 True:")
y1 = conditional_forward(x, True)
y1.backward()
print(f"  输出: {y1.item():.4f}")
print(f"  梯度: {x.grad.item():.4f}")

x.grad = None
print("\n条件为 False:")
y2 = conditional_forward(x, False)
y2.backward()
print(f"  输出: {y2.item():.4f}")
print(f"  梯度: {x.grad.item():.4f}")

# ======================================================================
# 技巧 8: 参数组的分别优化
# ======================================================================
# 不同层的学习率:
#   组 0: lr=0.001
#   组 1: lr=0.0001
print("\n" + "=" * 70)
print("技巧 8: 参数组的分别优化")
print("=" * 70)

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 10))

# 为不同层设置不同学习率
optimizer = torch.optim.Adam(
    [{'params': model[0].parameters(), 'lr': 1e-3}, {'params': model[2].parameters(), 'lr': 1e-4}]
)

x = torch.randn(5, 10)
y = model(x)
loss = y.sum()

print("不同层的学习率:")
for i, param_group in enumerate(optimizer.param_groups):
    print(f"  组 {i}: lr={param_group['lr']}")

loss.backward()
optimizer.step()

print("\n" + "=" * 70)
print("高级技巧总结")
print("=" * 70)
print("1. 混合精度: 关键操作保持 FP32 精度")
print("2. 梯度检查点: 用计算换内存")
print("3. 分离计算图: 控制梯度流向")
print("4. 梯度累积: 模拟大批次训练")
print("5. 多任务平衡: 动态调整任务权重")
print("6. 梯度缩放: 自定义梯度大小")
print("7. 条件计算: 动态选择计算路径")
print("8. 参数组优化: 不同参数不同策略")
print("=" * 70)


# ======================================================================
# 高级技巧总结
# ======================================================================
# 1. 混合精度: 关键操作保持 FP32 精度
# 2. 梯度检查点: 用计算换内存
# 3. 分离计算图: 控制梯度流向
# 4. 梯度累积: 模拟大批次训练
# 5. 多任务平衡: 动态调整任务权重
# 6. 梯度缩放: 自定义梯度大小
# 7. 条件计算: 动态选择计算路径
# 8. 参数组优化: 不同参数不同策略
# ======================================================================

总结与最佳实践

核心 API 速查

python 复制代码
# 基础操作
tensor.requires_grad_(True)          # 启用梯度追踪
tensor.detach()                      # 分离计算图
with torch.no_grad():                # 临时禁用梯度

# 梯度计算
loss.backward()                      # 反向传播
torch.autograd.grad(outputs, inputs) # 计算梯度

# 自定义函数
class MyFunc(Function):
    @staticmethod
    def forward(ctx, ...):           # 前向传播
        ctx.save_for_backward(...)   # 保存张量
    
    @staticmethod
    def backward(ctx, grad_output):  # 反向传播
        return grad_input            # 返回梯度

# Hook 机制
tensor.register_hook(hook_fn)        # 张量 hook
module.register_forward_hook(fn)     # 前向 hook
module.register_full_backward_hook(fn) # 反向 hook

# 高阶导数
grad = autograd.grad(..., create_graph=True)  # 创建高阶图

# 梯度检查
gradcheck(func, inputs)              # 验证梯度正确性

补充2:深入理解 backward() 的 grad_outputs 参数

1. grad_outputs 的本质

复制代码
完整的链式法则:
dL/dx = dL/dy × dy/dx
        ↑       ↑
   grad_outputs  当前层计算
   (外部提供)   (自动求导)

2. 三种典型情况

情况 grad_outputs 值 说明
标量损失 1.0 (隐式) loss.backward()
向量中间层 来自后续层 hidden.backward(grad)
手动指定权重 用户定义 y.backward(weights)

3. 为什么向量输出必须指定?

python 复制代码
# 问题: y 有多个元素,对 x 的影响如何组合?
x = [x1, x2]
y = [y1, y2] = [x1², x2²]

# PyTorch 需要知道如何组合:
dL/dx1 = dL/dy1 × dy1/dx1 + dL/dy2 × dy2/dx1
         ↑                 ↑
         这两个权重需要你指定!

4. 实用场景

场景 1: 多任务学习

python 复制代码
loss1, loss2 = model(x)
# 使用 grad_outputs 指定任务权重
loss1.backward(torch.tensor(0.7), retain_graph=True)
loss2.backward(torch.tensor(0.3))

场景 2: 计算雅可比矩阵

python 复制代码
for i in range(n):
    grad_output = torch.zeros(n)
    grad_output[i] = 1.0  # 单位向量
    y.backward(grad_output, retain_graph=True)

场景 3: 自定义梯度流

python 复制代码
# 只关心某些输出的梯度
mask = torch.tensor([1.0, 0.0, 1.0])  # 忽略第2个输出
y.backward(mask)

5. 常见错误

python 复制代码
# ❌ 错误: 向量输出不提供 grad_outputs
y = x ** 2  # 向量
y.backward()  # RuntimeError!

# ✓ 正确
y.backward(torch.ones_like(y))

# ❌ 错误: 形状不匹配
y = x ** 2  # shape: (2,)
y.backward(torch.tensor([1.0]))  # 形状错误

# ✓ 正确
y.backward(torch.tensor([1.0, 1.0]))
相关推荐
deephub1 小时前
Prompt Engineering 的本质:角色、任务、上下文、格式、约束
人工智能·prompt·大语言模型·多智能体
段一凡-华北理工大学1 小时前
工业领域的Hadoop架构学习~系列文章08:Flink流处理引擎
人工智能·hadoop·学习·架构·flink·高炉炼铁·高炉炼铁智能化
2601_958352901 小时前
双麦双波束独立拾音:A-59F 让智能工牌与翻译设备“听清每一个方向”
人工智能·语音识别·硬件开发·音频处理模块·消除回音
词元Max1 小时前
3.1 Agent开发需要懂多少数学?
人工智能·python
FelixBitSoul1 小时前
面试必考!RAG 知识库全链路深度解析:父子分块 × Rerank × 查询重写 × 标准化改写
人工智能·langchain·aigc
许彰午1 小时前
06_Java面向对象入门
java·开发语言·python
ZHW_AI课题组1 小时前
使用 Rectified Flow 和 Diffusion Transformer实现 MNIST 手写数字图像生成
人工智能·python·机器学习
z202305081 小时前
RDMA之DCQCN (14)
linux·服务器·网络·人工智能·ai
SimpleLearingAI1 小时前
PyTorch & Numpy 实现线性回归详解
人工智能·算法·多模态大模型