【Pytorch 深入理解(2)】减少训练显存-Gradient Checkpointing

Gradient Checkpointing(梯度检查点)详解:用时间换空间的巧妙优化

目录

  1. [引言:为什么需要 Gradient Checkpointing?](#引言:为什么需要 Gradient Checkpointing?)
  2. 核心问题:显存不足的困境
  3. [Gradient Checkpointing 的基本概念:用生活比喻理解](#Gradient Checkpointing 的基本概念:用生活比喻理解)
  4. 技术原理:从数学角度理解
  5. 内存节省的数学分析

一句话理解

Gradient Checkpointing 是什么?

  • 一句话总结 :用计算时间换取显存空间的技术,通过在前向传播时不保存中间激活值,反向传播时重新计算,从而节省约 50-70% 的显存。

核心要点:

  • 显存节省:约 50-70%(从 O(N) 降到 O(√N))
  • ⏱️ 时间代价:训练时间增加 20-30%
  • 🎯 训练效果:几乎无影响(数值等价)
  • 📚 官方 APItorch.utils.checkpoint.checkpoint(function, *args, use_reentrant=False, ...)
  • ⚠️ 重要提示 :必须显式传递 use_reentrant=False(PyTorch 2.5+ 强制要求)

适用场景:

  • 显存不足,无法训练目标模型
  • 需要增大 batch size 提高训练稳定性
  • 训练深层模型(> 15 层)

引言:为什么需要 Gradient Checkpointing?

在深度学习训练中,我们经常面临一个经典矛盾:模型容量与显存限制的冲突

真实场景

想象你正在训练一个大型 Transformer 模型(如 HunyuanDiT):

  • 模型参数:可能达到数十亿甚至数百亿
  • 显存需求:前向传播需要保存所有中间激活值(activations)
  • 现实限制:GPU 显存有限(如 24GB、40GB 或 80GB)

当模型太大,显存装不下时,你该怎么办?

传统解决方案的局限

  1. 减小 batch size:虽然能降低显存,但训练不稳定,收敛慢
  2. 使用更小的模型:牺牲模型表达能力
  3. 购买更多/更大的 GPU:成本高昂

Gradient Checkpointing 的优雅解决方案

Gradient Checkpointing(梯度检查点)提供了一种巧妙的思路:

  • 核心思想:不保存所有中间激活值,只保存关键检查点
  • 代价:反向传播时需要重新计算部分前向过程
  • 收益:显存占用降低约 50-70%,可以用更大的 batch size 或更深的模型

简单来说:用计算时间换取显存空间。


核心问题:显存不足的困境

深度学习训练中的显存消耗

在训练神经网络时,GPU 显存主要被以下内容占用:

  1. 模型参数(Parameters)

    • 权重矩阵、偏置等
    • 通常占用相对固定的显存
  2. 优化器状态(Optimizer States)

    • Adam 优化器需要保存动量和方差
    • 通常是参数大小的 2-3 倍
  3. 激活值(Activations) ⚠️ 这是最大的显存消耗者

    • 前向传播时每层的输出
    • 反向传播时需要这些值来计算梯度
    • 对于深度模型,激活值可能占用 50-80% 的显存

激活值显存消耗示例

假设我们有一个 24 层的 Transformer 模型:

  • 每层输出:[batch_size=16, seq_len=1024, hidden_dim=1024]
  • 单个激活值大小:16 × 1024 × 1024 × 4 bytes (float32) = 64 MB
  • 24 层总激活值:64 MB × 24 = 1.5 GB

这还只是单次前向传播!如果考虑梯度计算时的中间值,显存需求会更大。

问题的本质

传统训练流程:

复制代码
前向传播:计算并保存所有层的激活值 → 显存占用 = O(层数 × 激活值大小)
反向传播:使用保存的激活值计算梯度 → 显存占用 = O(层数 × 激活值大小)

对于深度模型,这个显存占用是线性增长的,很快就会超出 GPU 容量。


Gradient Checkpointing 的基本概念:用生活比喻理解

比喻:图书馆的借书系统

想象你在图书馆写论文,需要查阅很多书籍:

传统方式(不使用 Checkpointing):

  • 你一次性把所有需要的书都借出来,堆在桌子上
  • 优点:需要哪本书时,随手就能拿到
  • 缺点:桌子空间有限,很快就放不下了

Checkpointing 方式:

  • 你只保留几本最关键的参考书在桌子上(检查点)
  • 当需要其他书时,再去书架上取(重新计算)
  • 优点:桌子空间节省很多,可以放更多重要资料
  • 缺点:需要某本书时,要多走几步去取(增加计算时间)

核心概念总结

通过以上比喻,我们可以理解 Gradient Checkpointing 的三个核心概念:

  1. 检查点(Checkpoint):只保存关键位置的激活值
  2. 丢弃中间值:非检查点的激活值被丢弃,释放显存
  3. 重新计算:反向传播时,从检查点重新计算需要的激活值

技术原理:从数学角度理解

前向传播与反向传播回顾

在深度学习中,训练过程分为两个阶段:

1. 前向传播(Forward Pass)
复制代码
输入 x → Layer 1 → a₁ → Layer 2 → a₂ → ... → Layer N → aₙ → 输出 y

其中 a₁, a₂, ..., aₙ 是各层的激活值。

2. 反向传播(Backward Pass)

为了计算梯度,我们需要链式法则:

复制代码
∂L/∂W₁ = (∂L/∂aₙ) × (∂aₙ/∂aₙ₋₁) × ... × (∂a₂/∂a₁) × (∂a₁/∂W₁)

关键点 :计算 ∂aᵢ₊₁/∂aᵢ 时,需要用到前向传播时的激活值 aᵢ

传统方式的内存占用

复制代码
前向传播:
  - 计算 a₁,保存到显存
  - 计算 a₂,保存到显存
  - ...
  - 计算 aₙ,保存到显存
  - 显存占用 = a₁ + a₂ + ... + aₙ = O(N × |a|)

反向传播:
  - 使用保存的 aₙ 计算梯度
  - 使用保存的 aₙ₋₁ 计算梯度
  - ...
  - 使用保存的 a₁ 计算梯度
  - 显存占用 = O(N × |a|) (激活值 + 梯度)

Gradient Checkpointing 的工作原理

策略:只保存检查点

假设我们有一个 24 层的模型,只在每 4 层设置一个检查点:

复制代码
检查点位置:Layer 0, Layer 4, Layer 8, Layer 12, Layer 16, Layer 20, Layer 24
前向传播阶段
复制代码
1. 计算 Layer 0-3,只保存 Layer 0 的输出(检查点)
2. 丢弃 Layer 1-3 的激活值
3. 计算 Layer 4-7,只保存 Layer 4 的输出(检查点)
4. 丢弃 Layer 5-7 的激活值
5. ... 以此类推

显存占用 = 检查点数量 × |a| = O(√N × |a|)
反向传播阶段

当需要计算 Layer 5 的梯度时:

复制代码
1. 从最近的检查点(Layer 4)开始
2. 重新计算 Layer 4 → Layer 5 的前向过程
3. 使用重新计算的激活值计算梯度
4. 计算完成后,丢弃重新计算的激活值

数学复杂度分析

假设模型有 N 层,每层激活值大小为 |a|:

传统方式:

  • 显存占用:O(N × |a|)
  • 计算时间:O(N)(前向)+ O(N)(反向)= O(N)

Gradient Checkpointing(检查点间隔为 √N):

  • 显存占用:O(√N × |a|)
  • 计算时间:O(N)(前向)+ O(N × √N)(反向,需要重新计算)= O(N × √N)

权衡:

  • 显存节省:从 O(N) 降到 O(√N),节省约 50-70%
  • 时间增加:从 O(N) 增加到 O(N × √N),增加约 20-30%

代码详解

1. 条件判断
python 复制代码
if self.training and self.gradient_checkpointing:
  • self.training:只在训练模式下启用(推理时不需要)
  • self.gradient_checkpointing:检查是否启用了该功能
2. 创建自定义前向函数
python 复制代码
def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

这个函数的作用是:

  • 创建一个包装函数,将模块的前向传播逻辑封装起来
  • PyTorch 的 checkpoint 需要这个函数来知道如何重新计算
3. 使用 torch.utils.checkpoint.checkpoint
python 复制代码
x = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block),
    x, c, cond, skip_value,
    use_reentrant=False
)

参数说明(基于官方文档):

  • create_custom_forward(block):前向传播函数(第一个参数,必需)
  • x, c, cond, skip_value:输入参数(位置参数,可以是多个)
  • use_reentrant=False必须显式传递 (PyTorch 2.5+ 将要求必须传递)
    • False:非重入模式(推荐),支持更多功能,更高效
    • True:重入模式(旧版),有更多限制
  • preserve_rng_state=True(可选):是否保存和恢复随机数生成器状态
  • context_fn=None(可选):用于选择性检查点的上下文函数

⚠️ 官方文档重要警告:

use_reentrant 参数应该显式传递。在 PyTorch 2.9 版本中,如果不传递此参数将抛出异常。

一句话总结: checkpoint() 通过在前向传播时不保存中间激活值,反向传播时重新计算,从而节省显存;必须显式设置 use_reentrant=False 以获得最佳性能和兼容性。

工作原理:

  1. 前向传播时

    • 执行 block 的前向计算
    • 不保存中间激活值到计算图
    • 只保存输入参数(作为检查点)
  2. 反向传播时

    • 从保存的输入参数开始
    • 重新执行 block 的前向计算
    • 使用重新计算的激活值进行反向传播

注意torch.utils.checkpoint.checkpoint 是 PyTorch 1.0+ 就提供的标准 API,与训练框架(PyTorch Lightning、原生 PyTorch 等)无关。

官方文档参考 :完整的 API 文档和参数说明请参考 PyTorch 官方文档,这是 torch.utils.checkpoint 模块的原始官方文档。

完整代码流程

让我们用一个具体的例子来理解:

假设有 24 层的模型,当前处理第 5 层:

不使用 Gradient Checkpointing:

python 复制代码
# 前向传播
x = block(x, c, cond, skip_value)  # 计算并保存所有中间值

# 反向传播时
# 直接使用保存的激活值计算梯度

使用 Gradient Checkpointing:

python 复制代码
# 前向传播
x = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block),
    x, c, cond, skip_value,
    use_reentrant=False
)
# 只保存输入 (x, c, cond, skip_value),不保存 block 内部的中间值

# 反向传播时
# 1. 从保存的输入 (x, c, cond, skip_value) 开始
# 2. 重新执行 block 的前向计算
# 3. 使用重新计算的激活值进行反向传播
# 4. 计算完成后,丢弃重新计算的激活值

官方文档要点补充

根据 PyTorch 官方文档,以下是关键要点:

1. 核心 API:torch.utils.checkpoint.checkpoint()

函数签名:

python 复制代码
torch.utils.checkpoint.checkpoint(
    function,           # 要检查点的函数
    *args,             # 位置参数
    use_reentrant,     # 必须显式传递(推荐 False)
    preserve_rng_state=True,  # 是否保存 RNG 状态
    context_fn=None,    # 选择性检查点上下文
    **kwargs           # 关键字参数
)

关键特性:

  • ✅ 支持关键字参数(当 use_reentrant=False 时)
  • ✅ 支持 torch.autograd.grad()(当 use_reentrant=False 时)
  • ✅ 支持选择性检查点(通过 context_fn
  • ⚠️ use_reentrant 参数必须显式传递(PyTorch 2.5+ 强制要求)
2. 便捷函数:checkpoint_sequential()

对于 nn.Sequential 模型,可以使用更简单的 API:

python 复制代码
torch.utils.checkpoint.checkpoint_sequential(
    functions,      # nn.Sequential 或模块列表
    segments,       # 分块数量
    input,          # 输入张量
    preserve_rng_state=True,
    use_reentrant=False
)

一句话总结: 将 Sequential 模型自动分成多个段,每段作为一个检查点,简化使用。

3. 高级功能:选择性检查点(Selective Checkpointing)

可以控制哪些操作需要重新计算,哪些操作保存结果:

python 复制代码
from torch.utils.checkpoint import (
    CheckpointPolicy, 
    create_selective_checkpoint_contexts
)

# 定义策略:哪些操作保存,哪些重新计算
def policy_fn(ctx, op, *args, **kwargs):
    if op == torch.ops.aten.mm.default:  # 矩阵乘法保存
        return CheckpointPolicy.MUST_SAVE
    else:  # 其他操作重新计算
        return CheckpointPolicy.PREFER_RECOMPUTE

context_fn = functools.partial(
    create_selective_checkpoint_contexts, 
    policy_fn
)

out = torch.utils.checkpoint.checkpoint(
    fn, x, y,
    use_reentrant=False,
    context_fn=context_fn
)

一句话总结: 选择性检查点允许你精细控制哪些操作保存结果、哪些重新计算,可以进一步优化显存和计算时间的平衡。

4. 官方文档的重要注意事项
  1. use_reentrant=False 的优势:

    • 支持关键字参数
    • 支持 torch.autograd.grad()
    • 支持选择性检查点
    • 更好的性能
  2. use_reentrant=True 的限制:

    • 不支持关键字参数
    • 不支持 torch.autograd.grad()
    • 更多限制和注意事项
  3. 版本兼容性:

    • PyTorch 1.0+:基础 checkpoint 功能
    • PyTorch 1.11+:支持 use_reentrant=False
    • PyTorch 2.5+:要求显式传递 use_reentrant
    • PyTorch 2.9+:不传递 use_reentrant 将抛出异常

一句话总结: 始终使用 use_reentrant=False 并显式传递,这是官方推荐的最佳实践,能获得最佳性能和未来兼容性。


内存节省的数学分析

理论分析

假设我们有一个深度为 N 的模型,每层的激活值大小为 |a|。

传统方式的内存占用
复制代码
前向传播:
  - Layer 1 激活值:|a|
  - Layer 2 激活值:|a|
  - ...
  - Layer N 激活值:|a|
  - 总占用:N × |a|

反向传播:
  - 需要保存激活值用于梯度计算
  - 还需要保存梯度本身
  - 总占用:约 2 × N × |a|

总计:约 3 × N × |a|
Gradient Checkpointing 的内存占用

假设检查点间隔为 k(通常 k ≈ √N):

复制代码
检查点数量:N/k ≈ √N

前向传播:
  - 只保存检查点的激活值:√N × |a|
  - 当前正在计算的层:1 × |a|
  - 总占用:√N × |a| + |a| ≈ √N × |a|

反向传播:
  - 检查点激活值:√N × |a|
  - 重新计算的激活值(最多 k 层):k × |a| ≈ √N × |a|
  - 梯度:N × |a|
  - 总占用:约 (2√N + N) × |a|

总计:约 (2√N + N) × |a|
内存节省比例

对于 N = 24 的模型:

  • 传统方式:3 × 24 × |a| = 72 × |a|
  • Checkpointing:(2 × √24 + 24) × |a| ≈ (10 + 24) × |a| = 34 × |a|
  • 节省比例(72 - 34) / 72 ≈ 53%

对于更深的模型(N = 100):

  • 传统方式:3 × 100 × |a| = 300 × |a|
  • Checkpointing:(2 × 10 + 100) × |a| = 120 × |a|
  • 节省比例(300 - 120) / 300 = 60%

实际测量

在实际训练中,Gradient Checkpointing 通常能节省:

  • 浅层模型(10-20 层):30-50% 显存
  • 深层模型(20-50 层):50-70% 显存
  • 超深层模型(50+ 层):60-80% 显存

为什么是"约一半"?

对于典型的 Transformer 模型(20-30 层):

  • 检查点通常设置为每 4-6 层一个
  • 检查点数量约为总层数的 1/4 到 1/6
  • 加上重新计算的开销,总显存节省约为 50-60%

这就是为什么你观察到"节省了将近一半的 GPU memory"。


对训练效果的影响

数值精度:理论上完全等价

重要结论:Gradient Checkpointing 不会改变训练结果。

数学证明

Gradient Checkpointing 只是改变了计算顺序,不改变计算内容:

  1. 前向传播

    • 传统:y = f_N(f_{N-1}(...f_1(x)...))
    • Checkpointing:y = f_N(f_{N-1}(...f_1(x)...))(相同)
  2. 反向传播

    • 传统:使用保存的激活值计算梯度
    • Checkpointing:重新计算激活值后计算梯度
    • 结果相同:因为重新计算使用的是相同的输入和参数
  3. 数值稳定性

    • PyTorch 的 checkpoint 使用确定性重计算
    • 浮点运算顺序可能略有不同,但误差在机器精度范围内

训练时间:增加 20-30%

时间开销分析

传统方式:

复制代码
总时间 = 前向时间 + 反向时间
      = O(N) + O(N)
      = O(N)

Gradient Checkpointing:

复制代码
总时间 = 前向时间 + 反向时间(含重计算)
      = O(N) + O(N × k)  // k 是检查点间隔
      = O(N × k)
      ≈ O(N × √N)

对于 N = 24 的模型:

  • 检查点间隔 k ≈ 4-6
  • 时间增加:约 20-30%
实际测量

在实际训练中,Gradient Checkpointing 通常增加:

  • 浅层模型:10-20% 训练时间
  • 深层模型:20-30% 训练时间
  • 超深层模型:30-40% 训练时间

训练效果:几乎无影响

实验证据

大量研究和实践表明:

  1. 收敛速度:基本一致

    • 相同的 loss 曲线
    • 相同的收敛步数
  2. 最终性能:完全一致

    • 相同的验证集准确率
    • 相同的模型质量
  3. 训练稳定性:完全一致

    • 相同的梯度方差
    • 相同的训练曲线波动
为什么没有影响?
  1. 数学等价性:计算结果理论上完全相同
  2. 浮点误差:在机器精度范围内,可忽略
  3. 批量效应:batch size 的影响远大于 checkpointing 的数值误差

实际应用中的权衡

使用 Gradient Checkpointing 的场景

推荐使用:

  • 显存不足,无法训练目标模型
  • 需要增大 batch size 提高训练稳定性
  • 需要训练更深的模型
  • 显存是瓶颈,计算时间可以接受

不推荐使用:

  • 显存充足,不需要优化
  • 训练时间敏感,需要最快速度
  • 模型很浅(< 10 层),节省不明显
最佳实践

在你的配置中:

yaml 复制代码
batch_size: 16
gradient_checkpointing: true

这个配置是合理的:

  • 使用 checkpointing 节省显存
  • 可以用更大的 batch size(16)
  • 训练时间增加 20-30%,但可以接受
  • 训练效果不受影响

最佳实践与建议

1. 何时启用 Gradient Checkpointing

判断标准

启用条件:

  • GPU 显存使用率 > 90%
  • 无法增大 batch size
  • 模型深度 > 15 层
  • 激活值显存占用 > 总显存的 50%

不启用条件:

  • 显存充足(使用率 < 70%)
  • 模型很浅(< 10 层)
  • 训练时间敏感

2. 检查点间隔的选择

理论最优值

对于 N 层的模型,最优检查点间隔为:

复制代码
k = √N

例如:

  • N = 24 → k ≈ 5
  • N = 100 → k ≈ 10
实际建议

在代码中,通常每个 block 设置一个检查点:

python 复制代码
for layer, block in enumerate(self.blocks):
    if self.training and self.gradient_checkpointing:
        x = torch.utils.checkpoint.checkpoint(...)

这意味着检查点间隔为 1(每层都是检查点),这是最保守但也最安全的选择。

3. 与其他优化技术的结合

可以组合使用的技术
  1. 混合精度训练(AMP)

    yaml 复制代码
    use_amp: true
    amp_type: "bf16"
    gradient_checkpointing: true
    • AMP 减少激活值大小(fp32 → bf16)
    • Checkpointing 减少激活值数量
    • 两者结合可以节省 70-80% 显存
  2. 梯度累积(Gradient Accumulation)

    • 可以模拟更大的 batch size
    • 与 checkpointing 结合,进一步优化显存
  3. 数据并行(Data Parallelism)

    • 多 GPU 训练
    • 每个 GPU 使用 checkpointing
不建议组合的技术
  1. 激活值重计算(Activation Recomputation)
    • 与 checkpointing 功能重复
    • 同时使用会浪费计算

4. 性能调优建议

监控指标

训练时应该监控:

  • GPU 显存使用率
  • 训练时间(每个 step)
  • Loss 曲线(确保训练正常)
调优策略

如果显存还有余量:

  • 可以增大 batch size
  • 或者减少 checkpointing 频率(增大检查点间隔)

如果训练太慢:

  • 检查是否真的需要 checkpointing
  • 考虑使用更快的 GPU
  • 或者使用混合精度训练

5. 常见问题与解决方案

问题 1:启用后显存仍然不足

可能原因:

  • 模型参数太大
  • Batch size 太大
  • 其他显存占用(优化器状态等)

解决方案:

  • 减小 batch size
  • 使用梯度累积
  • 使用混合精度训练
  • 考虑模型并行
问题 2:训练时间增加太多

可能原因:

  • 检查点间隔太小
  • 模型计算本身很慢

解决方案:

  • 增大检查点间隔(如果可能)
  • 使用更快的 GPU
  • 优化模型结构
问题 3:训练不稳定

可能原因:

  • 不是 checkpointing 的问题
  • 可能是 batch size 或其他超参数

解决方案:

  • 检查 loss 曲线
  • 调整学习率
  • 检查数据加载

总结

核心要点

  1. Gradient Checkpointing 是什么?

    • 一种用计算时间换取显存空间的技术
    • 只保存关键检查点的激活值
    • 反向传播时重新计算中间激活值
  2. 为什么能节省显存?

    • 传统方式:保存所有层的激活值 → O(N)
    • Checkpointing:只保存检查点 → O(√N)
    • 节省比例:约 50-70%
  3. 对训练有什么影响?

    • 数值结果:理论上完全等价,实际无影响
    • 训练时间:增加 20-30%
    • 训练效果:几乎无影响
  4. 何时使用?

    • 显存不足时
    • 需要更大 batch size 时
    • 训练深层模型时

进一步学习

如果你想深入了解:

  1. PyTorch 官方文档(原始文档)

    • 官方原始文档https://docs.pytorch.org/docs/stable/checkpoint.html
    • torch.utils.checkpoint.checkpoint API(PyTorch 原生,不依赖任何框架)
    • torch.utils.checkpoint.checkpoint_sequential API(Sequential 模型便捷函数)
    • 选择性检查点(Selective Checkpointing)高级功能
    • 完整的参数说明、使用示例、注意事项和版本兼容性
    • 一句话总结:这是 PyTorch 官方提供的权威参考,包含所有 API 细节和最佳实践
  2. 相关论文

    • "Training Deep Nets with Sublinear Memory Cost" (Chen et al., 2016)
    • "Gradient Checkpointing" (Gruslys et al., 2016)
  3. 在不同框架中的使用

    • 纯 PyTorch :直接使用 torch.utils.checkpoint.checkpoint
    • PyTorch Lightning:在模型代码中使用(如本文示例)
    • Hugging Face Transformers :通过 model.gradient_checkpointing_enable() 启用
    • FairScale:提供了更高级的 checkpointing 策略
  4. 实践建议

    • 在实际项目中尝试不同的检查点策略
    • 监控显存和训练时间
    • 根据实际情况调整

希望这篇博客帮助你深入理解 Gradient Checkpointing 的原理和应用! 🚀

相关推荐
nix.gnehc1 小时前
PyTorch自动求导
人工智能·pytorch·python
mortimer1 小时前
视频自动翻译里的“时空折叠”:简单实用的音画同步实践
python·ffmpeg·aigc
Dfreedom.1 小时前
机器学习模型误差深度解读:从三类来源到偏差-方差权衡
人工智能·深度学习·机器学习·误差·偏差方差权衡
serve the people1 小时前
tensorflow tf.function 的 多态性(Polymorphism)
人工智能·python·tensorflow
爱思德学术1 小时前
【EI收录】第三届智能交通及智慧城市国际会议(ICITSC 2026)
人工智能·智慧城市
muxin-始终如一1 小时前
Semaphore 使用及原理详解
java·开发语言·python
水水不水啊1 小时前
通过一个域名,借助IPV6免费远程访问自己家里的设备
前端·python·算法
马踏岛国赏樱花1 小时前
低成本大模型构建-KTransformers
人工智能
nju_spy1 小时前
力扣每日一题(11.10-11.29)0-1 和 k 整除系列
python·算法·leetcode·前缀和·单调栈·最大公约数·0-1背包