1.1 模型显存总体分析
在训练/推理大模型时,GPU 显存并不是"全部都给模型用"。一部分显存会被 AI 框架 (如 PyTorch/CUDA runtime/通信库)占用,另一部分则由 系统/驱动 保留,用于上下文、内核态缓冲、显存页管理、ECC、显示/持久化守护进程等。因此,同一张 GPU 上"理论可用显存"和"实际可用显存"常常不一致。
从工程视角看,最直接的观测方式是通过 GPU 工具查询进程显存占用,例如 NVIDIA 平台常用 nvidia-smi 会打印各进程显存使用量。示例(进程与显存占用):
±--------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 1 N/A N/A 67321 C .../anaconda3/envs/py/bin/python 23646MiB |
| 1 N/A N/A 71612 C .../anaconda3/envs/py/bin/python 848MiB |
| 2 N/A N/A 67321 C .../anaconda3/envs/py/bin/python 25776MiB |
±--------------------------------------------------------------------------------------+
需要注意两点:
nvidia-smi看到的是 进程向驱动申请并保留的显存,其中既包括模型/张量,也包括框架内部缓存(例如 CUDA caching allocator 的预留与碎片)。- 同一个 PID 可能在多张卡上占用显存(如多卡训练/推理)。
1. 显存由哪些部分构成
从概念上,训练/推理显存开销通常包括:
- 模型参数(parameters)
- 优化器状态(optimizer states)
- 梯度(gradients)
- 激活(activations)
- 输入/输出张量(inputs/outputs)
- 临时张量/工作区(temporary / workspace)
- 自动求导图与元数据(autograd details / graph metadata)
- 其他不可归类部分(unknown)
为了更好地"估算 vs 实测"对齐,可以把显存拆成三类:
1.1 可估算值(可用公式近似)
- parameters
- optimizer states
- gradients
- activations
- inputs/outputs(通常较小,但长上下文/多模态时也可能显著)
1.2 未命名数据(很难提前精确估计)
- temporary/workspace(注意力算子、归一化、GEMM 的 workspace)
- unknown(缓存、碎片、通信缓冲、框架内部保留等)
1.3 框架相关
- autograd details(反向图、保存的中间信息、hook、AMP scaler 等)
2. 为什么"估算值"和"实际测量值"会差很多
在显存估算时,常见现象是:
估算得到总消耗 50GB,但实际
nvidia-smi看到占用达到 75GB,误差超过 30%。
这类差距通常来自以下因素叠加(往往不是单一原因):
2.1 CUDA/PyTorch 缓存与碎片
深度学习框架为了加速,会采用缓存分配器:释放的显存不会立刻还给驱动,而是保留在进程内用于复用。这样会导致:
- 实际"已保留显存"远大于"当前活跃张量显存"
- 频繁不同大小分配产生碎片,进一步抬高峰值保留
2.2 算子 workspace(临时工作区)
高性能算子(如 Attention、Conv、某些归一化、FlashAttention/SDPA 的不同实现)会按需要申请 workspace。workspace 的大小与:
- 序列长度 sss
- batch size(或 micro-batch)bbb
- head 数 aaa
- hidden size hhh
- 数据类型(fp16/bf16/fp8)
- 是否启用某些 fused kernel
密切相关,并且在不同 kernel 路径下差异很大。
2.3 通信与并行策略的额外缓冲
在 DP/TP/PP/ZeRO 等并行策略下,通信库(如 NCCL)可能分配:
- ring buffer
- fusion buffer
- reduce-scatter / all-gather 缓冲
当梯度 bucket 较大、融合策略激进时,缓冲占用会显著提升。
2.4 Autograd 与保存的中间量
即使模型参数/梯度可估算,框架为反向传播保存的中间张量(尤其在未做重计算时)会极大增加显存;同时还包含图结构与元信息。
2. 训练阶段的显存分析
训练阶段的"可估算部分"通常可写为:
TrainMem≈Model+Optimizer+Gradients+Activations \text{TrainMem} \approx \text{Model} + \text{Optimizer} + \text{Gradients} + \text{Activations} TrainMem≈Model+Optimizer+Gradients+Activations
同时按"是否随时间变化"可分为:
- 静态值:模型参数、优化器状态(训练过程中基本不变)
- 动态值:激活、梯度(随前向/反向阶段变化,峰值一般出现在反向或梯度聚合附近)
2.1 静态值分析
2.1.1 模型参数显存(Model Memory)
模型参数占用与 参数量 与 数据类型 直接相关。基本计算式:
ModelMem=TypeSize×Params \text{ModelMem} = \text{TypeSize} \times \text{Params} ModelMem=TypeSize×Params
将字节换算到 GB(以 102431024^310243 为换算基准)时,常用写法:
ModelMemfp32=4×Params10243ModelMemfp16/bf16=2×Params10243ModelMemfp8/int8=1×Params10243 \begin{aligned} \text{ModelMem}{\text{fp32}} &= \frac{4 \times \text{Params}}{1024^3} \\ \text{ModelMem}{\text{fp16/bf16}} &= \frac{2 \times \text{Params}}{1024^3} \\ \text{ModelMem}_{\text{fp8/int8}} &= \frac{1 \times \text{Params}}{1024^3} \end{aligned} ModelMemfp32ModelMemfp16/bf16ModelMemfp8/int8=102434×Params=102432×Params=102431×Params
checkpoint 存储大小估算
存储 checkpoint 时通常主要考虑"模型权重本身"。例如 1B 参数、fp32 存储:
DiskSize=4×10910243≈3.725 GB \text{DiskSize}=\frac{4 \times 10^9}{1024^3}\approx 3.725\ \text{GB} DiskSize=102434×109≈3.725 GB
工程上常用近似:
- 1B params≈4GB (fp32)1\text{B params} \approx 4\text{GB (fp32)}1B params≈4GB (fp32)
- 1B params≈2GB (fp16/bf16)1\text{B params} \approx 2\text{GB (fp16/bf16)}1B params≈2GB (fp16/bf16)
示例:
- 13B 模型若以 fp32 权重存储,粗略估算为 13×4≈52 GB13 \times 4 \approx 52\ \text{GB}13×4≈52 GB。
注意:许多训练采用混合精度,但最终保存 checkpoint 的权重格式取决于保存策略;不少默认保存仍可能是 fp32(或包含 fp32 master weights),因此需要区分"训练时参数精度"和"保存时权重精度"。
2.1.2 优化器状态(Optimizer States)
LLM 训练中最常见的是 Adam/AdamW。对每个参数,Adam 通常需要维护:
- 一阶矩(Momentum)
- 二阶矩(Variance)
同时在混合精度训练中,往往还存在一份 fp32 的 master 参数副本。
因此 Adam 的状态占用可近似为(单位 GB):
OptMemAdam=(4+4+4)×Params10243 \text{OptMem}_{\text{Adam}}= \frac{(4 + 4 + 4)\times \text{Params}}{1024^3} OptMemAdam=10243(4+4+4)×Params
其中 (4+4+4)(4+4+4)(4+4+4) 代表:
- fp32 master 参数副本:4 bytes
- Momentum:4 bytes
- Variance:4 bytes
8-bit 优化器(低精度状态)
若使用 8-bit 优化器(例如将一阶/二阶矩压到 8-bit 存储),常见近似为:
OptMem8bit=(4+1+1)×Params10243 \text{OptMem}_{\text{8bit}}= \frac{(4 + 1 + 1)\times \text{Params}}{1024^3} OptMem8bit=10243(4+1+1)×Params
含义:
- master 参数副本:4 bytes
- Momentum:1 byte
- Variance:1 byte
现实实现中还可能有额外的 scale/quantization metadata,因此这是工程近似,不是严格上界。
一个直观对比案例(帮助把量级"刻在脑子里")
假设参数量为 1B1\text{B}1B:
- fp16/bf16 参数本体大约 2 GB2\ \text{GB}2 GB
- Adam 状态(含 fp32 master + m + v)大约 12 GB12\ \text{GB}12 GB
也就是说在"常规 Adam + 混合精度"下:
- 优化器状态往往比模型权重本体大得多
- 这也是为什么 ZeRO/offload/8-bit optimizer 能带来巨大收益
2.2 动态值分析
动态值最关键的是 激活(activation) 与 梯度(gradient)。
2.2.1 激活值(Activation Memory)
激活是训练显存峰值的主要来源之一,因为反向传播需要用到前向中的中间结果(或其可重建形式)。
激活显存大小受多因素影响:
- 模型结构(层数 LLL、hidden size hhh、head 数 aaa)
- 序列长度 sss
- micro-batch bbb
- 是否启用 重计算(activation checkpointing)
- 并行策略(TP/PP 导致的保存位置变化)
- 注意力实现(不同 kernel/workspace)
一种常用的近似公式(参考 Megatron 系列分析思路)可以写为:
ActMem≈s⋅b⋅h⋅(34+5⋅a⋅sh)⋅L⋅γ \text{ActMem} \approx s \cdot b \cdot h \cdot \left(34 + 5\cdot \frac{a \cdot s}{h}\right)\cdot L \cdot \gamma ActMem≈s⋅b⋅h⋅(34+5⋅ha⋅s)⋅L⋅γ
其中(符号解释尽量直白):
- sss:序列长度(tokens 数)
- bbb:micro-batch size
- hhh:隐藏层维度(hidden size)
- aaa:注意力头数(number of heads)
- LLL:Transformer 层数
- γ\gammaγ:比例系数(把"元素个数"换算成 GB)
- 当采用 fp16/bf16 时,单元素 2 bytes,因此可以把换算写为:
γ=210243 \gamma = \frac{2}{1024^3} γ=102432 - 若是 fp32,则单元素 4 bytes,对应:
γ=410243 \gamma = \frac{4}{1024^3} γ=102434
- 当采用 fp16/bf16 时,单元素 2 bytes,因此可以把换算写为:
为什么这个公式里会有 (34+5⋅a⋅sh)\left(34 + 5\cdot \frac{a\cdot s}{h}\right)(34+5⋅ha⋅s)
它本质上是在粗略地把每层需要保存的中间量分成两类:
- 与 s⋅b⋅hs\cdot b\cdot hs⋅b⋅h 成正比的部分(MLP/残差/归一化等中间张量)
- 与注意力的 a⋅sa\cdot sa⋅s 相关的部分(例如注意力相关中间量在某些实现里会引入额外依赖)
这不是"精确到字节"的公式,但它能很好解释两个最重要的现象:
- sss(序列长度)对激活几乎是线性甚至更高阶的放大器
- 层数 LLL 与 micro-batch bbb 直接线性放大激活开销
激活的典型工程现象(案例)
- 把上下文从 s=2048s=2048s=2048 提到 s=8192s=8192s=8192,即使 batch 不变,激活很可能从"勉强能跑"变成"直接 OOM"。
- 在多模态(图像/视频)里,视觉 token 本身就会显著增加等效 sss,导致激活更快爆炸。
2.2.2 梯度值(Gradient Memory)
梯度与参数张量形状一致,通常与模型参数采用相同的数据类型(具体取决于实现与是否有 fp32 grad accumulator)。
常用估算:
-
若梯度为 fp32:
GradMemfp32=4×Params10243 \text{GradMem}_{\text{fp32}}=\frac{4 \times \text{Params}}{1024^3} GradMemfp32=102434×Params -
若梯度为 fp16/bf16:
GradMemfp16/bf16=2×Params10243 \text{GradMem}_{\text{fp16/bf16}}=\frac{2 \times \text{Params}}{1024^3} GradMemfp16/bf16=102432×Params
在一些配置下可能存在 fp32 梯度累积或额外 buffer,使得梯度相关显存大于上述估算。
3. 推理阶段显存分析
推理阶段不需要保存反向传播所需的激活,也不需要优化器状态与梯度,因此显存结构更简单。
推理显存主要来自:
- 模型参数(权重)
- KV cache(自回归生成时)
- 少量临时张量(算子 workspace)
- 框架缓存/碎片
因此很多场景下会用一个经验近似:
InferMem≈1.2×ModelMem \text{InferMem} \approx 1.2 \times \text{ModelMem} InferMem≈1.2×ModelMem
这个 1.21.21.2 可以理解为"模型权重之外的框架开销 + 少量临时 buffer"的粗略倍数。它在以下条件下更接近现实:
- 生成长度不长
- batch 不大
- KV cache 可控
但在自回归长生成里,KV cache 往往成为主导项,这时仅用 1.2×1.2\times1.2× 会低估。
4. 显存优化:思路与路径
当模型规模增长远超单卡物理显存时,显存优化通常围绕两条基本原则:
4.1 时间换空间
通过增加计算量换取更低显存占用,例如:
- 重计算(activation checkpointing):不保存全部激活,反向时再算一遍前向
代价:
- 额外算力消耗
- 训练 wall-time 变长
- 可能增加通信与同步开销(视并行策略而定)
4.2 空间转移
把本应在 GPU 上的内容"挪出去"或"摊开到多卡",例如:
- 多卡并行(TP/PP/DP/ZeRO)
- offload 到 CPU 内存或 NVMe(参数/优化器/梯度 offload)
代价:
- I/O 带宽压力上升(PCIe/NVLink/网络)
- 延迟增加,吞吐下降
- 工程复杂度显著提高
4.3 一个常用的优化路径(从高层到低层)
实践中,优化往往按影响范围与性价比从上到下推进:
4.3.1 多卡并行(最常用、通常不损精度)
通过并行策略减少单卡持有的数据量,可参考"模型参数/优化器/梯度/激活"的组成,针对性拆分:
- TP(Tensor Parallel):按张量维度切分参数与计算
- PP(Pipeline Parallel):按层切分
- DP(Data Parallel):按 batch 切分
- ZeRO:按参数/梯度/优化器状态切分(不同 stage 切分粒度不同)
- 重计算:进一步降低激活峰值(常与 PP/TP 配合)
典型权衡:
- 显存降下来,但通信量会上去(尤其 TP/ZeRO 的 all-gather / reduce-scatter)
4.3.2 算子优化(同精度更省显存)
选择数值等价或精度一致但更节省显存的实现路径,例如:
- 更低 workspace 的 attention kernel
- fused kernel(减少中间张量)
- 更好的内存复用策略
典型权衡:
- 调参/验证成本高
- 受限于硬件、驱动、CUDA 版本与框架版本
4.3.3 数据类型修改(用低精度替换高精度)
常见策略:
- fp32 →\rightarrow→ fp16/bf16:参数/激活/梯度直接减半
- int8/fp8:进一步降低权重或激活存储(需硬件与软件栈支持)
典型权衡:
- 可能影响收敛稳定性与最终精度
- 需要配合 loss scaling、校准、量化策略
4.3.4 消除框架副本(减少"隐形开销")
一些框架会引入中间副本,例如:
- master weights(混合精度)
- buffer/flatten 参数副本(某些并行实现)
- optimizer state 的冗余保存
典型权衡:
- 改动侵入性强
- 容易引入数值或稳定性问题
4.3.5 显存管理(减少碎片与峰值保留)
显存碎片会让"理论还够"变成"实际 OOM"。优化方向包括:
- 更稳定的分配形状(避免频繁变化的张量大小)
- 让关键路径更少创建短生命周期大张量
- 合理设置 batch/sequence 的组合减少抖动
典型权衡:
- 可控手段相对有限
- 与框架版本和 kernel 实现强绑定
4.3.6 底层 API/库替换(最后一公里)
不同 CUDA API/库版本、不同 kernel 实现,显存开销可能差异明显,例如:
- 使用更省显存的注意力实现(如 FlashAttention 路径)
- 更新到更高版本的 CUDA/cuDNN/cuBLAS 以获得更优的 workspace 策略
- 使用更优化的 SDPA 内核选择逻辑
典型权衡:
- 需要完整验证与回归测试
- 可能影响稳定性与可复现性
5. 贯穿全文的"量级直觉":用一个综合案例串起来
为了把上述概念落到直觉上,可以用"1B 参数模型 + Adam + 混合精度"的量级来记忆(只用作数量级参考):
-
模型参数(bf16)
ModelMembf16≈2×10910243≈1.86 GB \text{ModelMem}_{\text{bf16}} \approx \frac{2\times 10^9}{1024^3}\approx 1.86\ \text{GB} ModelMembf16≈102432×109≈1.86 GB近似记为 ≈2 GB\approx 2\ \text{GB}≈2 GB。
-
Adam 优化器状态(含 fp32 master + m + v)
OptMemAdam≈12×10910243≈11.18 GB \text{OptMem}_{\text{Adam}} \approx \frac{12\times 10^9}{1024^3}\approx 11.18\ \text{GB} OptMemAdam≈1024312×109≈11.18 GB近似记为 ≈12 GB\approx 12\ \text{GB}≈12 GB。
-
梯度(bf16)
GradMembf16≈2×10910243≈1.86 GB \text{GradMem}_{\text{bf16}} \approx \frac{2\times 10^9}{1024^3}\approx 1.86\ \text{GB} GradMembf16≈102432×109≈1.86 GB近似记为 ≈2 GB\approx 2\ \text{GB}≈2 GB。
仅静态项 + 梯度合计就已经大约:
2+12+2≈16 GB 2 + 12 + 2 \approx 16\ \text{GB} 2+12+2≈16 GB
这还没算:
- 激活(通常是训练峰值的核心)
- workspace/通信 buffer/碎片(经常就是"估算 vs 实测"的差距来源)
因此一旦序列长度 sss、层数 LLL、micro-batch bbb 增长,激活项就能把总显存推到远超上述 16GB 的水平。