重计算是大模型训练的核心内存优化技术 ,本质是以时间换空间的策略。我们从「为什么需要它」到「核心特性」再到「进阶玩法」,一步步讲清楚。
一、小白入门:先搞懂「大模型训练的痛点」
要理解重计算,先得知道大模型训练时,内存都被谁占了 。
普通深度学习模型(比如小的CNN)训练流程分两步:
- 前向传播 :输入数据经过各层网络,生成中间结果(叫激活值),最后输出预测结果;同时会把「模型参数」和「激活值」都存在内存里。
- 反向传播 :用预测结果和真实标签算损失,再顺着网络往回算每一层参数的梯度,这个过程必须用到前向的激活值;算完梯度后,更新模型参数。
但到了大模型(比如千亿参数的Transformer),情况变了:
- 模型参数本身占内存(比如GPT-3的参数占几百GB);
- 激活值的内存占比,往往比参数还高! 比如Transformer的注意力层,每一层的激活值会随batch size、序列长度指数级增长。
普通GPU的内存根本扛不住------直接报「OOM(内存溢出)」错误,训练直接中断。
这时候,重计算就登场了。
二、核心概念:重计算到底是什么?
重计算的本质很简单:
前向传播时,不把所有激活值都存下来,只存少数关键的「检查点」;反向传播时,扔掉的激活值不再从内存里取,而是重新跑一遍前向计算得到,再用它算梯度。
举个生活中的例子:
你算一道数学题 (1+2)×(3+4)-5,步骤是:
- 先算
1+2=3(中间结果A),3+4=7(中间结果B); - 再算
3×7=21(中间结果C); - 最后算
21-5=16(最终结果)。
如果你的「草稿纸内存不够」,记不下A、B、C三个中间结果,你可以只记最终结果16 和原始算式 (这就是「检查点」)。
等你需要验证计算是否正确时,不用翻草稿纸找A、B、C,而是重新算一遍原始算式,再得到中间结果------这就是「重计算」。
对应到模型训练:
- 「原始算式」= 网络结构;
- 「中间结果A/B/C」= 激活值;
- 「只记最终结果+算式」= 只存检查点,不存全部激活值;
- 「重新算算式」= 反向时重跑前向,生成需要的激活值。
三、递进1:重计算的「官方实现」------ 梯度检查点
重计算在工程上的标准实现叫 Gradient Checkpointing(梯度检查点) ,这是你在论文和框架里最常看到的名字。
它的核心流程可以分成两步:
-
前向阶段 - 存检查点
- 把整个网络分成若干段(比如每10层为一段);
- 只在每段的开头/结尾存一个「检查点」(即这一层的激活值),中间层的激活值算完就扔,不占内存;
- 最终只保留「少数检查点 + 模型参数」。
-
反向阶段 - 重算激活值
- 反向传播到某一段时,先从该段的检查点出发,重新跑一遍这段的前向计算,生成中间层的激活值;
- 用重算出来的激活值,计算这段网络的参数梯度;
- 算完后,重算的激活值再次被扔掉,不占用额外内存。
四、递进2:重计算的核心特性(小白必懂)
从「时间、空间、精度、灵活性」四个维度,总结重计算的关键特性:
1. 核心优势:极致的空间效率
这是重计算存在的唯一目的。
- 可以减少 30%~70% 的激活值内存占用,让原本在单卡上跑不起来的大模型,能在普通GPU集群上训练;
- 内存节省的比例,和「检查点的密度」直接相关:检查点越少,内存省得越多(比如每20层存一个检查点,比每10层存一个省更多内存)。
2. 必然代价:时间开销增加
天下没有免费的午餐,「省内存」的代价是「多花时间」。
- 反向传播时要重新跑前向计算,相当于训练时间会增加 20%~50%(具体比例看检查点密度);
- 这是典型的时间换空间 trade-off(权衡),也是大模型训练的核心取舍逻辑。
3. 关键保障:无精度损失
这是重计算和其他内存优化方法(比如量化、剪枝)的最大区别。
- 重计算是精确的重新计算,不是近似计算------重算出来的激活值,和前向时第一次算的结果完全一样;
- 最终训练出的模型精度,和「不使用重计算、内存足够」的情况完全相同 。
这对大模型训练至关重要------毕竟大模型训练成本极高,没人能接受精度损失。
4. 灵活可调:平衡时间与空间
重计算不是「非开即关」的开关,而是可以精细调节的旋钮:
- 检查点密度:你可以自己决定「每隔几层存一个检查点」------比如内存紧张时,就减少检查点数量(多花时间,多省内存);算力紧张时,就增加检查点数量(少花时间,少省内存);
- 选择性重计算:可以只对「内存占比高的层」做重计算(比如Transformer的注意力层),对「内存占比低的层」(比如简单的线性层)直接存激活值------这样能在时间和空间之间找到最优平衡点。
五、递进3:进阶特性(小白了解即可)
当你熟悉基础用法后,会发现重计算还有更高级的玩法,这些是工业界优化大模型训练速度的关键:
1. 分层重计算:针对不同层做差异化策略
不同网络层的「计算复杂度」和「内存占比」差异很大:
- 比如Transformer的注意力层:计算复杂(要算Q/K/V矩阵乘法),但激活值内存占比极高;
- 比如全连接层(FFN):计算简单,激活值内存占比低。
工业界会对注意力层「少存检查点,多重计算」,对FFN层「多存检查点,少重计算」------这样既省内存,又不会让时间开销过高。
2. 异步重计算:利用GPU并行性减少等待
GPU是并行计算的高手,反向传播和重计算可以异步进行:
- 在反向算梯度的同时,让GPU的另一个计算核心异步重算激活值;
- 等反向需要激活值时,重计算已经完成,不用等待------能把时间开销从50%降到20%左右。
3. 混合精度重计算:进一步压缩内存
结合「混合精度训练」(用FP16半精度代替FP32单精度存储参数和激活值),重计算可以更省内存:
- 检查点用FP16存储,重计算时也用FP16算------内存占用再砍一半,且精度损失几乎可以忽略。
六、递进4:实际应用的注意事项(小白踩坑指南)
-
和其他内存优化方法结合使用
重计算很少单独用,通常和这些技术搭配:
- 模型并行:把模型的不同层分到不同GPU上;
- 数据并行:把数据分到不同GPU上;
- ZeRO优化 :把参数和梯度分片存储,不重复占用内存。
组合使用后,才能支撑千亿参数模型的训练。
-
动态图框架更友好
PyTorch、TensorFlow 2.x 等动态图框架 ,实现重计算很简单(几行代码调用API就行);
而TensorFlow 1.x 等静态图框架,需要提前规划计算图,实现起来更复杂。
-
小模型没必要用
重计算是「大模型专属优化」------如果你的模型很小(比如几万参数的CNN),内存足够用,用重计算反而会增加训练时间,得不偿失。
总结
重计算的核心逻辑可以一句话概括:
为了让大模型能在普通硬件上训练,牺牲部分训练时间,换取内存占用的大幅降低,且不损失任何模型精度。
它是大模型训练的「必备技能」,也是小白从「训练小模型」进阶到「训练大模型」的关键知识点。