在进行大语言模型(LLM)的微调或预训练时,显存(VRAM)不足通常是首要面临的问题。为了在有限的硬件资源下完成训练,了解显存的具体去向以及相应的优化技术是比较基础的工作。
从模型训练的流程来看,显存的占用主要可以分为两大部分:模型状态(Model States)和剩余耗时产生的中间变量(主要是激活值)。以下对相关的优化方法做简单的梳理。
一、 模型状态的显存占用与 ZeRO 技术
模型状态包含了训练过程中最核心的三类数据:模型参数(Weights)、梯度(Gradients)以及优化器状态(Optimizer States)。
- 参数与梯度:对于一个 1.7B(17亿参数)的模型,如果使用 BF16 或 FP16 精度,参数本身约占 3.4GB。在训练过程中,系统还需要存储一份同样大小的梯度。
- 优化器状态:这是显存占用的"大户"。以常用的 Adam 优化器为例,它需要为每个参数记录动量(Momentum)和方差(Variance)。如果使用全精度(FP32)来存储这些状态以保证精度,其占用量通常是参数本身的数倍。
为了解决这些静态数据的冗余问题,微软提出的 ZeRO(Zero Redundancy Optimizer) 技术被广泛应用。它通过将数据切片并分散到多个显卡上来降低单卡负载:
- ZeRO-1:仅对优化器状态进行切片,每张显卡只负责维护一部分参数的优化器状态。
- ZeRO-2:在 1 的基础上,进一步对梯度进行切片。这是目前平衡显存节省与通信效率较好的选择。
- ZeRO-3:对模型参数也进行切片。当某一层需要计算时,临时从其他显卡"借"来参数,算完即释放。这种方式能最大程度节省显存,但显卡间的通信开销会显著增加。
二、 激活值的显存占用与重算机制
激活值(Activations)是指模型在"前向传播"过程中,每一层神经元计算出的中间结果。
与模型参数不同,激活值的占用量是动态的,它与训练时的批大小(Batch Size)和序列长度(Sequence Length)成正比。在处理长文本时,激活值的显存占用往往会超过模型参数本身。
由于反向传播计算梯度时必须用到这些中间结果,因此默认情况下它们必须保留在显存中。目前的优化主流方案是 梯度检查点(Gradient Checkpointing),其逻辑较为简单:
- 逻辑:在前向传播时,不再保存所有层的激活值,而是只保留一小部分关键节点的"检查点"。
- 重算:当反向传播需要用到被删除的中间值时,系统会根据最近的一个检查点重新进行一次前向计算。
- 代价:这是一种典型的"以时间换空间"的方法。它能节省大量的显存(有时可达 70% 以上),但会增加约 33% 的计算时间。
三、 激活值的卸载与并行策略
除了重算,还有一些进阶的手段来处理激活值,虽然它们对硬件环境的要求更高:
- 激活值卸载(Offloading):将暂时不用的激活值通过 PCIe 总线搬运到 CPU 内存中,需要时再搬回。受限于 PCIe 的带宽,这种方法在某些配置下可能会产生较明显的延迟。
- 序列并行(Sequence Parallelism):将长文本切分成几段,分配给不同的显卡分别计算。这属于分布式训练的高级范畴,通常需要较快的跨卡互联带宽支持。
四、 参数高效微调(LoRA)的辅助作用
在讨论上述底层优化时,不得不提 LoRA(Low-Rank Adaptation)。
严格来说,LoRA 改变的是需要更新的参数量。因为它冻结了原始模型的大部分参数,只训练极小规模的旁路矩阵,这直接导致:
- 梯度大幅减少:只需要存储少量可训练参数的梯度。
- 优化器状态减少:对应的优化器记录也随之减少。
虽然 LoRA 不直接改变激活值的计算方式,但由于它极大降低了"模型状态"的显存门槛,使得我们有更多的空间去增加 Batch Size 或序列长度。