llamafactory gradient_checkpointing 梯度检查点 通俗完整讲解

llamafactory gradient_checkpointing 梯度检查点 通俗完整讲解

1. 原生不开启时(你现在 false 的状态)

模型走一遍前向传播,会把所有中间激活值全部存在显存里

作用:反向传播算梯度时,直接拿这些中间值计算,不用重新跑前向,速度快。

代价:序列越长、batch越大,中间激活占的显存爆炸,很容易冲到99%。

举个你场景的例子:cutoff_len=1024、bs=32,14B大模型,中间激活张量体积非常大,这就是你显存经常顶满的核心元凶。

2. 开启 gradient_checkpointing: true 做了什么?

核心逻辑:以少量重复计算,换取大幅显存节省

  1. 前向传播时,不保存全部中间激活,只存少量关键节点;
  2. 等到反向传播需要某一段中间数据时,重新再跑一遍对应区间的前向计算,现场算出激活;
  3. 全程不会一次性堆海量中间张量在显存里,显存峰值直接下降 30%~50%。

和你梯度累积完全无关

不管 gradient_accumulation_steps 是2还是4,梯度检查点只管「前向激活要不要存」,不影响梯度累加逻辑。

3. 优缺点

优点

  • 大幅压低显存峰值,解决你 88%~99% 冲高、偶尔OOM崩溃问题;
  • 同样显卡下,可以开更大 batch / 更长序列长度;
  • 搭配 FlashAttention2 双重省显存,MI300 192GB 体验提升明显。

缺点

  • 每一步训练会多跑一小段前向计算,训练速度大概慢 10%~20%;
  • CPU算力、读写开销轻微上涨,preprocessing_num_workers=8 完全能扛住。

4. 适配你当前场景怎么选

  1. 现状:显存经常冲到99%,波动极大
    建议开启 true,稳定性优先,轻微降速完全可接受;
  2. 如果你后续调小batch、且全程显存稳定80以内不报警,再切回 false 提速。

5. 补充关键细节(SFT训练专用)

  • use_cache: false 必须搭配梯度检查点一起用(你配置里已经开了,正确),训练阶段禁用KV缓存,额外省一大块显存;
  • LoRA微调场景下开启梯度检查点不影响LoRA梯度更新,只冻结基础模型部分重算,LoRA训练效果无损失;
  • 和 bf16 / FlashAttention2 兼容,ROCm MI300无兼容性bug。

极简总结

不开:存所有中间激活 → 显存占用高、跑的快

开启:丢掉大部分中间激活,反向时临时重算 → 显存砍半、速度略慢

专门解决你长文本1024序列导致的显存爆满问题。