Gradient Checkpointing(梯度检查点)详解:用时间换空间的巧妙优化
目录
- [引言:为什么需要 Gradient Checkpointing?](#引言:为什么需要 Gradient Checkpointing?)
- 核心问题:显存不足的困境
- [Gradient Checkpointing 的基本概念:用生活比喻理解](#Gradient Checkpointing 的基本概念:用生活比喻理解)
- 技术原理:从数学角度理解
- 内存节省的数学分析
一句话理解
Gradient Checkpointing 是什么?
- 一句话总结 :用计算时间换取显存空间的技术,通过在前向传播时不保存中间激活值,反向传播时重新计算,从而节省约 50-70% 的显存。
核心要点:
- ✅ 显存节省:约 50-70%(从 O(N) 降到 O(√N))
- ⏱️ 时间代价:训练时间增加 20-30%
- 🎯 训练效果:几乎无影响(数值等价)
- 📚 官方 API :
torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=False, ...) - ⚠️ 重要提示 :必须显式传递
use_reentrant=False(PyTorch 2.5+ 强制要求)
适用场景:
- 显存不足,无法训练目标模型
- 需要增大 batch size 提高训练稳定性
- 训练深层模型(> 15 层)
引言:为什么需要 Gradient Checkpointing?
在深度学习训练中,我们经常面临一个经典矛盾:模型容量与显存限制的冲突。
真实场景
想象你正在训练一个大型 Transformer 模型(如 HunyuanDiT):
- 模型参数:可能达到数十亿甚至数百亿
- 显存需求:前向传播需要保存所有中间激活值(activations)
- 现实限制:GPU 显存有限(如 24GB、40GB 或 80GB)
当模型太大,显存装不下时,你该怎么办?
传统解决方案的局限
- 减小 batch size:虽然能降低显存,但训练不稳定,收敛慢
- 使用更小的模型:牺牲模型表达能力
- 购买更多/更大的 GPU:成本高昂
Gradient Checkpointing 的优雅解决方案
Gradient Checkpointing(梯度检查点)提供了一种巧妙的思路:
- 核心思想:不保存所有中间激活值,只保存关键检查点
- 代价:反向传播时需要重新计算部分前向过程
- 收益:显存占用降低约 50-70%,可以用更大的 batch size 或更深的模型
简单来说:用计算时间换取显存空间。
核心问题:显存不足的困境
深度学习训练中的显存消耗
在训练神经网络时,GPU 显存主要被以下内容占用:
-
模型参数(Parameters)
- 权重矩阵、偏置等
- 通常占用相对固定的显存
-
优化器状态(Optimizer States)
- Adam 优化器需要保存动量和方差
- 通常是参数大小的 2-3 倍
-
激活值(Activations) ⚠️ 这是最大的显存消耗者
- 前向传播时每层的输出
- 反向传播时需要这些值来计算梯度
- 对于深度模型,激活值可能占用 50-80% 的显存
激活值显存消耗示例
假设我们有一个 24 层的 Transformer 模型:
- 每层输出:
[batch_size=16, seq_len=1024, hidden_dim=1024] - 单个激活值大小:
16 × 1024 × 1024 × 4 bytes (float32) = 64 MB - 24 层总激活值:
64 MB × 24 = 1.5 GB
这还只是单次前向传播!如果考虑梯度计算时的中间值,显存需求会更大。
问题的本质
传统训练流程:
前向传播:计算并保存所有层的激活值 → 显存占用 = O(层数 × 激活值大小)
反向传播:使用保存的激活值计算梯度 → 显存占用 = O(层数 × 激活值大小)
对于深度模型,这个显存占用是线性增长的,很快就会超出 GPU 容量。
Gradient Checkpointing 的基本概念:用生活比喻理解
比喻:图书馆的借书系统
想象你在图书馆写论文,需要查阅很多书籍:
传统方式(不使用 Checkpointing):
- 你一次性把所有需要的书都借出来,堆在桌子上
- 优点:需要哪本书时,随手就能拿到
- 缺点:桌子空间有限,很快就放不下了
Checkpointing 方式:
- 你只保留几本最关键的参考书在桌子上(检查点)
- 当需要其他书时,再去书架上取(重新计算)
- 优点:桌子空间节省很多,可以放更多重要资料
- 缺点:需要某本书时,要多走几步去取(增加计算时间)
核心概念总结
通过以上比喻,我们可以理解 Gradient Checkpointing 的三个核心概念:
- 检查点(Checkpoint):只保存关键位置的激活值
- 丢弃中间值:非检查点的激活值被丢弃,释放显存
- 重新计算:反向传播时,从检查点重新计算需要的激活值
技术原理:从数学角度理解
前向传播与反向传播回顾
在深度学习中,训练过程分为两个阶段:
1. 前向传播(Forward Pass)
输入 x → Layer 1 → a₁ → Layer 2 → a₂ → ... → Layer N → aₙ → 输出 y
其中 a₁, a₂, ..., aₙ 是各层的激活值。
2. 反向传播(Backward Pass)
为了计算梯度,我们需要链式法则:
∂L/∂W₁ = (∂L/∂aₙ) × (∂aₙ/∂aₙ₋₁) × ... × (∂a₂/∂a₁) × (∂a₁/∂W₁)
关键点 :计算 ∂aᵢ₊₁/∂aᵢ 时,需要用到前向传播时的激活值 aᵢ。
传统方式的内存占用
前向传播:
- 计算 a₁,保存到显存
- 计算 a₂,保存到显存
- ...
- 计算 aₙ,保存到显存
- 显存占用 = a₁ + a₂ + ... + aₙ = O(N × |a|)
反向传播:
- 使用保存的 aₙ 计算梯度
- 使用保存的 aₙ₋₁ 计算梯度
- ...
- 使用保存的 a₁ 计算梯度
- 显存占用 = O(N × |a|) (激活值 + 梯度)
Gradient Checkpointing 的工作原理
策略:只保存检查点
假设我们有一个 24 层的模型,只在每 4 层设置一个检查点:
检查点位置:Layer 0, Layer 4, Layer 8, Layer 12, Layer 16, Layer 20, Layer 24
前向传播阶段
1. 计算 Layer 0-3,只保存 Layer 0 的输出(检查点)
2. 丢弃 Layer 1-3 的激活值
3. 计算 Layer 4-7,只保存 Layer 4 的输出(检查点)
4. 丢弃 Layer 5-7 的激活值
5. ... 以此类推
显存占用 = 检查点数量 × |a| = O(√N × |a|)
反向传播阶段
当需要计算 Layer 5 的梯度时:
1. 从最近的检查点(Layer 4)开始
2. 重新计算 Layer 4 → Layer 5 的前向过程
3. 使用重新计算的激活值计算梯度
4. 计算完成后,丢弃重新计算的激活值
数学复杂度分析
假设模型有 N 层,每层激活值大小为 |a|:
传统方式:
- 显存占用:
O(N × |a|) - 计算时间:
O(N)(前向)+O(N)(反向)=O(N)
Gradient Checkpointing(检查点间隔为 √N):
- 显存占用:
O(√N × |a|) - 计算时间:
O(N)(前向)+O(N × √N)(反向,需要重新计算)=O(N × √N)
权衡:
- 显存节省:从
O(N)降到O(√N),节省约 50-70% - 时间增加:从
O(N)增加到O(N × √N),增加约 20-30%
代码详解
1. 条件判断
python
if self.training and self.gradient_checkpointing:
self.training:只在训练模式下启用(推理时不需要)self.gradient_checkpointing:检查是否启用了该功能
2. 创建自定义前向函数
python
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
这个函数的作用是:
- 创建一个包装函数,将模块的前向传播逻辑封装起来
- PyTorch 的
checkpoint需要这个函数来知道如何重新计算
3. 使用 torch.utils.checkpoint.checkpoint
python
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, c, cond, skip_value,
use_reentrant=False
)
参数说明(基于官方文档):
create_custom_forward(block):前向传播函数(第一个参数,必需)x, c, cond, skip_value:输入参数(位置参数,可以是多个)use_reentrant=False:必须显式传递 (PyTorch 2.5+ 将要求必须传递)False:非重入模式(推荐),支持更多功能,更高效True:重入模式(旧版),有更多限制
preserve_rng_state=True(可选):是否保存和恢复随机数生成器状态context_fn=None(可选):用于选择性检查点的上下文函数
⚠️ 官方文档重要警告:
use_reentrant参数应该显式传递。在 PyTorch 2.9 版本中,如果不传递此参数将抛出异常。
一句话总结: checkpoint() 通过在前向传播时不保存中间激活值,反向传播时重新计算,从而节省显存;必须显式设置 use_reentrant=False 以获得最佳性能和兼容性。
工作原理:
-
前向传播时:
- 执行
block的前向计算 - 不保存中间激活值到计算图
- 只保存输入参数(作为检查点)
- 执行
-
反向传播时:
- 从保存的输入参数开始
- 重新执行
block的前向计算 - 使用重新计算的激活值进行反向传播
注意 :torch.utils.checkpoint.checkpoint 是 PyTorch 1.0+ 就提供的标准 API,与训练框架(PyTorch Lightning、原生 PyTorch 等)无关。
官方文档参考 :完整的 API 文档和参数说明请参考 PyTorch 官方文档,这是 torch.utils.checkpoint 模块的原始官方文档。
完整代码流程
让我们用一个具体的例子来理解:
假设有 24 层的模型,当前处理第 5 层:
不使用 Gradient Checkpointing:
python
# 前向传播
x = block(x, c, cond, skip_value) # 计算并保存所有中间值
# 反向传播时
# 直接使用保存的激活值计算梯度
使用 Gradient Checkpointing:
python
# 前向传播
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, c, cond, skip_value,
use_reentrant=False
)
# 只保存输入 (x, c, cond, skip_value),不保存 block 内部的中间值
# 反向传播时
# 1. 从保存的输入 (x, c, cond, skip_value) 开始
# 2. 重新执行 block 的前向计算
# 3. 使用重新计算的激活值进行反向传播
# 4. 计算完成后,丢弃重新计算的激活值
官方文档要点补充
根据 PyTorch 官方文档,以下是关键要点:
1. 核心 API:torch.utils.checkpoint.checkpoint()
函数签名:
python
torch.utils.checkpoint.checkpoint(
function, # 要检查点的函数
*args, # 位置参数
use_reentrant, # 必须显式传递(推荐 False)
preserve_rng_state=True, # 是否保存 RNG 状态
context_fn=None, # 选择性检查点上下文
**kwargs # 关键字参数
)
关键特性:
- ✅ 支持关键字参数(当
use_reentrant=False时) - ✅ 支持
torch.autograd.grad()(当use_reentrant=False时) - ✅ 支持选择性检查点(通过
context_fn) - ⚠️
use_reentrant参数必须显式传递(PyTorch 2.5+ 强制要求)
2. 便捷函数:checkpoint_sequential()
对于 nn.Sequential 模型,可以使用更简单的 API:
python
torch.utils.checkpoint.checkpoint_sequential(
functions, # nn.Sequential 或模块列表
segments, # 分块数量
input, # 输入张量
preserve_rng_state=True,
use_reentrant=False
)
一句话总结: 将 Sequential 模型自动分成多个段,每段作为一个检查点,简化使用。
3. 高级功能:选择性检查点(Selective Checkpointing)
可以控制哪些操作需要重新计算,哪些操作保存结果:
python
from torch.utils.checkpoint import (
CheckpointPolicy,
create_selective_checkpoint_contexts
)
# 定义策略:哪些操作保存,哪些重新计算
def policy_fn(ctx, op, *args, **kwargs):
if op == torch.ops.aten.mm.default: # 矩阵乘法保存
return CheckpointPolicy.MUST_SAVE
else: # 其他操作重新计算
return CheckpointPolicy.PREFER_RECOMPUTE
context_fn = functools.partial(
create_selective_checkpoint_contexts,
policy_fn
)
out = torch.utils.checkpoint.checkpoint(
fn, x, y,
use_reentrant=False,
context_fn=context_fn
)
一句话总结: 选择性检查点允许你精细控制哪些操作保存结果、哪些重新计算,可以进一步优化显存和计算时间的平衡。
4. 官方文档的重要注意事项
-
use_reentrant=False的优势:- 支持关键字参数
- 支持
torch.autograd.grad() - 支持选择性检查点
- 更好的性能
-
use_reentrant=True的限制:- 不支持关键字参数
- 不支持
torch.autograd.grad() - 更多限制和注意事项
-
版本兼容性:
- PyTorch 1.0+:基础 checkpoint 功能
- PyTorch 1.11+:支持
use_reentrant=False - PyTorch 2.5+:要求显式传递
use_reentrant - PyTorch 2.9+:不传递
use_reentrant将抛出异常
一句话总结: 始终使用 use_reentrant=False 并显式传递,这是官方推荐的最佳实践,能获得最佳性能和未来兼容性。
内存节省的数学分析
理论分析
假设我们有一个深度为 N 的模型,每层的激活值大小为 |a|。
传统方式的内存占用
前向传播:
- Layer 1 激活值:|a|
- Layer 2 激活值:|a|
- ...
- Layer N 激活值:|a|
- 总占用:N × |a|
反向传播:
- 需要保存激活值用于梯度计算
- 还需要保存梯度本身
- 总占用:约 2 × N × |a|
总计:约 3 × N × |a|
Gradient Checkpointing 的内存占用
假设检查点间隔为 k(通常 k ≈ √N):
检查点数量:N/k ≈ √N
前向传播:
- 只保存检查点的激活值:√N × |a|
- 当前正在计算的层:1 × |a|
- 总占用:√N × |a| + |a| ≈ √N × |a|
反向传播:
- 检查点激活值:√N × |a|
- 重新计算的激活值(最多 k 层):k × |a| ≈ √N × |a|
- 梯度:N × |a|
- 总占用:约 (2√N + N) × |a|
总计:约 (2√N + N) × |a|
内存节省比例
对于 N = 24 的模型:
- 传统方式:
3 × 24 × |a| = 72 × |a| - Checkpointing:
(2 × √24 + 24) × |a| ≈ (10 + 24) × |a| = 34 × |a| - 节省比例 :
(72 - 34) / 72 ≈ 53%
对于更深的模型(N = 100):
- 传统方式:
3 × 100 × |a| = 300 × |a| - Checkpointing:
(2 × 10 + 100) × |a| = 120 × |a| - 节省比例 :
(300 - 120) / 300 = 60%
实际测量
在实际训练中,Gradient Checkpointing 通常能节省:
- 浅层模型(10-20 层):30-50% 显存
- 深层模型(20-50 层):50-70% 显存
- 超深层模型(50+ 层):60-80% 显存
为什么是"约一半"?
对于典型的 Transformer 模型(20-30 层):
- 检查点通常设置为每 4-6 层一个
- 检查点数量约为总层数的 1/4 到 1/6
- 加上重新计算的开销,总显存节省约为 50-60%
这就是为什么你观察到"节省了将近一半的 GPU memory"。
对训练效果的影响
数值精度:理论上完全等价
重要结论:Gradient Checkpointing 不会改变训练结果。
数学证明
Gradient Checkpointing 只是改变了计算顺序,不改变计算内容:
-
前向传播:
- 传统:
y = f_N(f_{N-1}(...f_1(x)...)) - Checkpointing:
y = f_N(f_{N-1}(...f_1(x)...))(相同)
- 传统:
-
反向传播:
- 传统:使用保存的激活值计算梯度
- Checkpointing:重新计算激活值后计算梯度
- 结果相同:因为重新计算使用的是相同的输入和参数
-
数值稳定性:
- PyTorch 的
checkpoint使用确定性重计算 - 浮点运算顺序可能略有不同,但误差在机器精度范围内
- PyTorch 的
训练时间:增加 20-30%
时间开销分析
传统方式:
总时间 = 前向时间 + 反向时间
= O(N) + O(N)
= O(N)
Gradient Checkpointing:
总时间 = 前向时间 + 反向时间(含重计算)
= O(N) + O(N × k) // k 是检查点间隔
= O(N × k)
≈ O(N × √N)
对于 N = 24 的模型:
- 检查点间隔 k ≈ 4-6
- 时间增加:约 20-30%
实际测量
在实际训练中,Gradient Checkpointing 通常增加:
- 浅层模型:10-20% 训练时间
- 深层模型:20-30% 训练时间
- 超深层模型:30-40% 训练时间
训练效果:几乎无影响
实验证据
大量研究和实践表明:
-
收敛速度:基本一致
- 相同的 loss 曲线
- 相同的收敛步数
-
最终性能:完全一致
- 相同的验证集准确率
- 相同的模型质量
-
训练稳定性:完全一致
- 相同的梯度方差
- 相同的训练曲线波动
为什么没有影响?
- 数学等价性:计算结果理论上完全相同
- 浮点误差:在机器精度范围内,可忽略
- 批量效应:batch size 的影响远大于 checkpointing 的数值误差
实际应用中的权衡
使用 Gradient Checkpointing 的场景
✅ 推荐使用:
- 显存不足,无法训练目标模型
- 需要增大 batch size 提高训练稳定性
- 需要训练更深的模型
- 显存是瓶颈,计算时间可以接受
❌ 不推荐使用:
- 显存充足,不需要优化
- 训练时间敏感,需要最快速度
- 模型很浅(< 10 层),节省不明显
最佳实践
在你的配置中:
yaml
batch_size: 16
gradient_checkpointing: true
这个配置是合理的:
- 使用 checkpointing 节省显存
- 可以用更大的 batch size(16)
- 训练时间增加 20-30%,但可以接受
- 训练效果不受影响
最佳实践与建议
1. 何时启用 Gradient Checkpointing
判断标准
启用条件:
- GPU 显存使用率 > 90%
- 无法增大 batch size
- 模型深度 > 15 层
- 激活值显存占用 > 总显存的 50%
不启用条件:
- 显存充足(使用率 < 70%)
- 模型很浅(< 10 层)
- 训练时间敏感
2. 检查点间隔的选择
理论最优值
对于 N 层的模型,最优检查点间隔为:
k = √N
例如:
- N = 24 → k ≈ 5
- N = 100 → k ≈ 10
实际建议
在代码中,通常每个 block 设置一个检查点:
python
for layer, block in enumerate(self.blocks):
if self.training and self.gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(...)
这意味着检查点间隔为 1(每层都是检查点),这是最保守但也最安全的选择。
3. 与其他优化技术的结合
可以组合使用的技术
-
混合精度训练(AMP)
yamluse_amp: true amp_type: "bf16" gradient_checkpointing: true- AMP 减少激活值大小(fp32 → bf16)
- Checkpointing 减少激活值数量
- 两者结合可以节省 70-80% 显存
-
梯度累积(Gradient Accumulation)
- 可以模拟更大的 batch size
- 与 checkpointing 结合,进一步优化显存
-
数据并行(Data Parallelism)
- 多 GPU 训练
- 每个 GPU 使用 checkpointing
不建议组合的技术
- 激活值重计算(Activation Recomputation)
- 与 checkpointing 功能重复
- 同时使用会浪费计算
4. 性能调优建议
监控指标
训练时应该监控:
- GPU 显存使用率
- 训练时间(每个 step)
- Loss 曲线(确保训练正常)
调优策略
如果显存还有余量:
- 可以增大 batch size
- 或者减少 checkpointing 频率(增大检查点间隔)
如果训练太慢:
- 检查是否真的需要 checkpointing
- 考虑使用更快的 GPU
- 或者使用混合精度训练
5. 常见问题与解决方案
问题 1:启用后显存仍然不足
可能原因:
- 模型参数太大
- Batch size 太大
- 其他显存占用(优化器状态等)
解决方案:
- 减小 batch size
- 使用梯度累积
- 使用混合精度训练
- 考虑模型并行
问题 2:训练时间增加太多
可能原因:
- 检查点间隔太小
- 模型计算本身很慢
解决方案:
- 增大检查点间隔(如果可能)
- 使用更快的 GPU
- 优化模型结构
问题 3:训练不稳定
可能原因:
- 不是 checkpointing 的问题
- 可能是 batch size 或其他超参数
解决方案:
- 检查 loss 曲线
- 调整学习率
- 检查数据加载
总结
核心要点
-
Gradient Checkpointing 是什么?
- 一种用计算时间换取显存空间的技术
- 只保存关键检查点的激活值
- 反向传播时重新计算中间激活值
-
为什么能节省显存?
- 传统方式:保存所有层的激活值 → O(N)
- Checkpointing:只保存检查点 → O(√N)
- 节省比例:约 50-70%
-
对训练有什么影响?
- 数值结果:理论上完全等价,实际无影响
- 训练时间:增加 20-30%
- 训练效果:几乎无影响
-
何时使用?
- 显存不足时
- 需要更大 batch size 时
- 训练深层模型时
进一步学习
如果你想深入了解:
-
PyTorch 官方文档(原始文档)
- 官方原始文档:https://docs.pytorch.org/docs/stable/checkpoint.html
torch.utils.checkpoint.checkpointAPI(PyTorch 原生,不依赖任何框架)torch.utils.checkpoint.checkpoint_sequentialAPI(Sequential 模型便捷函数)- 选择性检查点(Selective Checkpointing)高级功能
- 完整的参数说明、使用示例、注意事项和版本兼容性
- 一句话总结:这是 PyTorch 官方提供的权威参考,包含所有 API 细节和最佳实践
-
相关论文
- "Training Deep Nets with Sublinear Memory Cost" (Chen et al., 2016)
- "Gradient Checkpointing" (Gruslys et al., 2016)
-
在不同框架中的使用
- 纯 PyTorch :直接使用
torch.utils.checkpoint.checkpoint - PyTorch Lightning:在模型代码中使用(如本文示例)
- Hugging Face Transformers :通过
model.gradient_checkpointing_enable()启用 - FairScale:提供了更高级的 checkpointing 策略
- 纯 PyTorch :直接使用
-
实践建议
- 在实际项目中尝试不同的检查点策略
- 监控显存和训练时间
- 根据实际情况调整
希望这篇博客帮助你深入理解 Gradient Checkpointing 的原理和应用! 🚀