Pytorch 反向传播

1. 理论基础:动态计算图与自动微分

动态计算图(DAG)

PyTorch采用动态计算图(Dynamic Computation Graph),在运行时即时构建计算流程。每个操作(如加法、乘法)会在执行时生成一个节点,并记录输入输出依赖关系。这种"定义即运行"机制允许灵活控制流(如条件分支、循环),但牺牲了静态图的编译期优化机会。

自动微分(AutoGrad)

基于反向模式自动微分(Reverse-mode AD),利用链式法则计算梯度:

  • 前向传播:记录操作序列,构建DAG。
  • 反向传播:从输出(Loss)出发,按拓扑逆序逐层计算梯度。
    数学形式:若 y = f ( x ) y = f(x) y=f(x) ,则梯度 d L d x = ∑ i d L d y i ⋅ ∂ y i ∂ x \frac{dL}{dx} = \sum_{i} \frac{dL}{dy_i} \cdot \frac{\partial y_i}{\partial x} dxdL=∑idyidL⋅∂x∂yi。

典型应用场景

  • 调试时动态调整网络结构(如注意力机制)。
  • 需要复杂控制流的模型(如强化学习策略网络)。

2. 实现流程:反向传播步骤分解

分阶段流程:
  1. 前向传播

    python 复制代码
    import torch
    # 定义参数
    w = torch.randn(1, requires_grad=True)
    b = torch.randn(1, requires_grad=True)
    
    # 输入数据
    x = torch.tensor([2.0])
    y_true = torch.tensor([5.0])
    
    # 前向计算
    y_pred = w * x + b  # 构建计算图
  2. 计算Loss

    python 复制代码
    loss = (y_pred - y_true) ** 2  # Loss = (wx + b - y_true)^2
  3. 反向传播

    python 复制代码
    loss.backward()  # 启动反向传播,计算梯度
  4. 梯度更新

    python 复制代码
    with torch.no_grad():
        w -= 0.1 * w.grad  # 手动更新参数
        b -= 0.1 * b.grad
  5. 梯度清零

    python 复制代码
    w.grad.zero_()  # 避免梯度累积
    b.grad.zero_()

关键点

  • requires_grad标记需跟踪梯度的张量。
  • backward()触发Autograd引擎递归计算梯度。
  • torch.no_grad()临时禁用梯度计算。

3. 核心组件:grad_fn与Autograd引擎

grad_fn属性
python 复制代码
a = torch.tensor([2.0], requires_grad=True)
b = a ** 2
c = b + 3
print(c.grad_fn)  # <AddBackward0 object>,记录生成c的操作
  • grad_fn指向创建该Tensor的操作函数(如AddBackward0MulBackward0)。
  • 叶子节点 (Leaf Nodes):由用户直接创建的Tensor(如a),其grad_fn=None
Autograd引擎工作原理
  1. 依赖追踪:在前向传播时,Autograd记录每个操作的前向函数和反向传播函数。
  2. 拓扑排序:反向传播时,按DAG逆序处理节点(确保父节点梯度已计算)。
  3. 梯度累加 :中间节点梯度会累积到tensor.grad中。

源码级原理

  • C++实现的Engine类管理任务队列,Python端通过torch.autograd.backward()交互。
  • 每个grad_fn实现forward()backward()方法。

4. 代码演示:反向传播与高阶导数

基本反向传播示例
python 复制代码
import torch

# 定义模型参数
w = torch.randn(1, requires_grad=True)
b = torch.randn(1, requires_grad=True)

# 数据
x = torch.tensor([2.0])
y_true = torch.tensor([5.0])

# 前向传播
y_pred = w * x + b
loss = (y_pred - y_true) ** 2

# 反向传播
loss.backward()

# 查看梯度
print(f"w.grad: {w.grad}, b.grad: {b.grad}")
高阶导数计算
python 复制代码
x = torch.tensor([3.0], requires_grad=True)
y = x ** 3  # y = x^3

# 一阶导数 dy/dx = 3x²
first_derivative = torch.autograd.grad(y, x, create_graph=True)[0]  # create_graph=True保留计算图

# 二阶导数 d²y/dx² = 6x
second_derivative = torch.autograd.grad(first_derivative, x)[0]

print(f"First derivative: {first_derivative}")  # 输出 3*3²=27
print(f"Second derivative: {second_derivative}")  # 输出 6*3=18

关键点

  • create_graph=True允许计算高阶导数。
  • 返回值是元组,需索引取第一个元素。

5. 内存管理:计算图保留与优化

retain_graph参数
python 复制代码
loss1 = (w * x + b - y_true) ** 2
loss2 = (w * x**2 + b - y_true) ** 2

loss1.backward(retain_graph=True)  # 第一次反向传播保留图
loss2.backward()  # 继续第二次反向传播
  • 默认反向传播后释放计算图,retain_graph=True防止释放。
  • 应用场景:多任务学习中多个Loss顺序反向传播。
内存优化技巧
  • 及时清除无用中间变量del tensor或使用上下文管理器。
  • 合并操作:减少冗余计算(如将多个激活函数合并)。
  • 降低精度 :使用torch.float16或混合精度训练。
python 复制代码
with torch.no_grad():  # 推理阶段禁用梯度
    predictions = model(inputs)

6. 注意事项:常见陷阱与解决方案

梯度累积
python 复制代码
for batch in data_loader:
    optimizer.zero_grad()
    outputs = model(batch)
    loss = loss_function(outputs, labels)
    loss.backward()  # 梯度累积(未调用zero_grad)
optimizer.step()  # 累积多个batch后更新
  • 应用场景:小显存设备模拟大batch size训练。
In-place操作限制
python 复制代码
x = torch.randn(3, requires_grad=True)
# 错误!In-place操作破坏计算图
x.add_(1)  # 报错:a leaf Variable that requires grad is being used in an in-place operation
  • PyTorch禁止修改requires_grad=True张量的值(除非标记为.volatile)。
非标量输出处理
python 复制代码
x = torch.tensor([2.0], requires_grad=True)
y = torch.stack([x**2, x**3])  # 非标量输出
v = torch.tensor([1.0, 0.1])  # 外部梯度权重
y.backward(v)  # 相当于计算 ∂L/∂x = v[0]*dy[0]/dx + v[1]*dy[1]/dx
print(x.grad)  # 输出 1.0*4 + 0.1*12 = 5.2
  • 对非标量输出调用backward()时必须传入gradient参数,用于指定外部梯度。

总结

维度 关键技术点 典型应用
理论基础 DAG、反向模式AD 动态模型设计
实现流程 forward → backward → update 训练自定义模型
核心组件 grad_fn、Autograd引擎 调试梯度计算流程
内存管理 retain_graph、no_grad 多任务学习、低显存训练
注意事项 梯度累积、in-place限制、非标量处理 复杂Loss设计、高阶优化问题
相关推荐
数据与后端架构提升之路36 分钟前
小鹏VLA 2.0的“神秘涌现”:从痛苦到突破,自动驾驶与机器人如何突然“开窍”?
人工智能·机器人·自动驾驶
fruge38 分钟前
CANN核心特性深度解析:简化AI开发的技术优势
人工智能
沛沛老爹1 小时前
AI入门知识之RAFT方法:基于微调的RAG优化技术详解
人工智能·llm·sft·raft·rag
zskj_zhyl1 小时前
科技助老与智慧养老的国家级政策与地方实践探索
大数据·人工智能·科技
避避风港1 小时前
Java 抽象类
java·开发语言·python
YangYang9YangYan1 小时前
职业本科发展路径与规划指南
大数据·人工智能·学习·数据分析
牛客企业服务1 小时前
2025年AI面试防作弊指南:技术笔试如何识别异常行为
人工智能·面试·职场和发展
shayudiandian2 小时前
CNN详解:卷积神经网络是如何识别图像的?
人工智能·深度学习·cnn
V_156560272192 小时前
2025年蚌埠市“三首产品”、市级服务型制造示范、市级企业技术中心等5个项目认定申报指南大全
大数据·人工智能·制造
盘古信息IMS2 小时前
AI算力时代,PCB制造如何借助盘古信息MOM构建数字化新范式?
人工智能·制造