Pytorch 反向传播

1. 理论基础:动态计算图与自动微分

动态计算图(DAG)

PyTorch采用动态计算图(Dynamic Computation Graph),在运行时即时构建计算流程。每个操作(如加法、乘法)会在执行时生成一个节点,并记录输入输出依赖关系。这种"定义即运行"机制允许灵活控制流(如条件分支、循环),但牺牲了静态图的编译期优化机会。

自动微分(AutoGrad)

基于反向模式自动微分(Reverse-mode AD),利用链式法则计算梯度:

  • 前向传播:记录操作序列,构建DAG。
  • 反向传播:从输出(Loss)出发,按拓扑逆序逐层计算梯度。
    数学形式:若 y = f ( x ) y = f(x) y=f(x) ,则梯度 d L d x = ∑ i d L d y i ⋅ ∂ y i ∂ x \frac{dL}{dx} = \sum_{i} \frac{dL}{dy_i} \cdot \frac{\partial y_i}{\partial x} dxdL=∑idyidL⋅∂x∂yi。

典型应用场景

  • 调试时动态调整网络结构(如注意力机制)。
  • 需要复杂控制流的模型(如强化学习策略网络)。

2. 实现流程:反向传播步骤分解

分阶段流程:
  1. 前向传播

    python 复制代码
    import torch
    # 定义参数
    w = torch.randn(1, requires_grad=True)
    b = torch.randn(1, requires_grad=True)
    
    # 输入数据
    x = torch.tensor([2.0])
    y_true = torch.tensor([5.0])
    
    # 前向计算
    y_pred = w * x + b  # 构建计算图
  2. 计算Loss

    python 复制代码
    loss = (y_pred - y_true) ** 2  # Loss = (wx + b - y_true)^2
  3. 反向传播

    python 复制代码
    loss.backward()  # 启动反向传播,计算梯度
  4. 梯度更新

    python 复制代码
    with torch.no_grad():
        w -= 0.1 * w.grad  # 手动更新参数
        b -= 0.1 * b.grad
  5. 梯度清零

    python 复制代码
    w.grad.zero_()  # 避免梯度累积
    b.grad.zero_()

关键点

  • requires_grad标记需跟踪梯度的张量。
  • backward()触发Autograd引擎递归计算梯度。
  • torch.no_grad()临时禁用梯度计算。

3. 核心组件:grad_fn与Autograd引擎

grad_fn属性
python 复制代码
a = torch.tensor([2.0], requires_grad=True)
b = a ** 2
c = b + 3
print(c.grad_fn)  # <AddBackward0 object>,记录生成c的操作
  • grad_fn指向创建该Tensor的操作函数(如AddBackward0MulBackward0)。
  • 叶子节点 (Leaf Nodes):由用户直接创建的Tensor(如a),其grad_fn=None
Autograd引擎工作原理
  1. 依赖追踪:在前向传播时,Autograd记录每个操作的前向函数和反向传播函数。
  2. 拓扑排序:反向传播时,按DAG逆序处理节点(确保父节点梯度已计算)。
  3. 梯度累加 :中间节点梯度会累积到tensor.grad中。

源码级原理

  • C++实现的Engine类管理任务队列,Python端通过torch.autograd.backward()交互。
  • 每个grad_fn实现forward()backward()方法。

4. 代码演示:反向传播与高阶导数

基本反向传播示例
python 复制代码
import torch

# 定义模型参数
w = torch.randn(1, requires_grad=True)
b = torch.randn(1, requires_grad=True)

# 数据
x = torch.tensor([2.0])
y_true = torch.tensor([5.0])

# 前向传播
y_pred = w * x + b
loss = (y_pred - y_true) ** 2

# 反向传播
loss.backward()

# 查看梯度
print(f"w.grad: {w.grad}, b.grad: {b.grad}")
高阶导数计算
python 复制代码
x = torch.tensor([3.0], requires_grad=True)
y = x ** 3  # y = x^3

# 一阶导数 dy/dx = 3x²
first_derivative = torch.autograd.grad(y, x, create_graph=True)[0]  # create_graph=True保留计算图

# 二阶导数 d²y/dx² = 6x
second_derivative = torch.autograd.grad(first_derivative, x)[0]

print(f"First derivative: {first_derivative}")  # 输出 3*3²=27
print(f"Second derivative: {second_derivative}")  # 输出 6*3=18

关键点

  • create_graph=True允许计算高阶导数。
  • 返回值是元组,需索引取第一个元素。

5. 内存管理:计算图保留与优化

retain_graph参数
python 复制代码
loss1 = (w * x + b - y_true) ** 2
loss2 = (w * x**2 + b - y_true) ** 2

loss1.backward(retain_graph=True)  # 第一次反向传播保留图
loss2.backward()  # 继续第二次反向传播
  • 默认反向传播后释放计算图,retain_graph=True防止释放。
  • 应用场景:多任务学习中多个Loss顺序反向传播。
内存优化技巧
  • 及时清除无用中间变量del tensor或使用上下文管理器。
  • 合并操作:减少冗余计算(如将多个激活函数合并)。
  • 降低精度 :使用torch.float16或混合精度训练。
python 复制代码
with torch.no_grad():  # 推理阶段禁用梯度
    predictions = model(inputs)

6. 注意事项:常见陷阱与解决方案

梯度累积
python 复制代码
for batch in data_loader:
    optimizer.zero_grad()
    outputs = model(batch)
    loss = loss_function(outputs, labels)
    loss.backward()  # 梯度累积(未调用zero_grad)
optimizer.step()  # 累积多个batch后更新
  • 应用场景:小显存设备模拟大batch size训练。
In-place操作限制
python 复制代码
x = torch.randn(3, requires_grad=True)
# 错误!In-place操作破坏计算图
x.add_(1)  # 报错:a leaf Variable that requires grad is being used in an in-place operation
  • PyTorch禁止修改requires_grad=True张量的值(除非标记为.volatile)。
非标量输出处理
python 复制代码
x = torch.tensor([2.0], requires_grad=True)
y = torch.stack([x**2, x**3])  # 非标量输出
v = torch.tensor([1.0, 0.1])  # 外部梯度权重
y.backward(v)  # 相当于计算 ∂L/∂x = v[0]*dy[0]/dx + v[1]*dy[1]/dx
print(x.grad)  # 输出 1.0*4 + 0.1*12 = 5.2
  • 对非标量输出调用backward()时必须传入gradient参数,用于指定外部梯度。

总结

维度 关键技术点 典型应用
理论基础 DAG、反向模式AD 动态模型设计
实现流程 forward → backward → update 训练自定义模型
核心组件 grad_fn、Autograd引擎 调试梯度计算流程
内存管理 retain_graph、no_grad 多任务学习、低显存训练
注意事项 梯度累积、in-place限制、非标量处理 复杂Loss设计、高阶优化问题
相关推荐
IT技术员几秒前
【Java学习】动态代理有哪些形式?
java·python·学习
上海云盾商务经理杨杨1 分钟前
AI如何重塑DDoS防护行业?六大变革与未来展望
人工智能·安全·web安全·ddos
q_q王5 分钟前
本地知识库工具FASTGPT的安装与搭建
python·大模型·llm·知识库·fastgpt
lanboAI6 分钟前
基于卷积神经网络的蔬菜水果识别系统,resnet50,mobilenet模型【pytorch框架+python源码】
pytorch·python·cnn
一刀到底21112 分钟前
ai agent(智能体)开发 python3基础8 网页抓取中 selenium 和 Playwright 区别和联系
人工智能·python
每天都要写算法(努力版)17 分钟前
【神经网络与深度学习】改变随机种子可以提升模型性能?
人工智能·深度学习·神经网络
剑哥在胡说20 分钟前
Python三大Web框架对比:Django、Flask、Tornado的异步实现方式详解
数据库·python·django·flask·tornado
烟锁池塘柳035 分钟前
【计算机视觉】三种图像质量评价指标详解:PSNR、SSIM与SAM
人工智能·深度学习·计算机视觉
da-peng-song38 分钟前
ArcGIS arcpy代码工具——根据属性结构表创建shape图层
javascript·python·arcgis
滚雪球~1 小时前
小市值策略复现(A股选股框架回测系统)
python·量化·策略