1.1 模型显存总体分析

1.1 模型显存总体分析

在训练/推理大模型时,GPU 显存并不是"全部都给模型用"。一部分显存会被 AI 框架 (如 PyTorch/CUDA runtime/通信库)占用,另一部分则由 系统/驱动 保留,用于上下文、内核态缓冲、显存页管理、ECC、显示/持久化守护进程等。因此,同一张 GPU 上"理论可用显存"和"实际可用显存"常常不一致。

从工程视角看,最直接的观测方式是通过 GPU 工具查询进程显存占用,例如 NVIDIA 平台常用 nvidia-smi 会打印各进程显存使用量。示例(进程与显存占用):

±--------------------------------------------------------------------------------------+

| Processes: |

| GPU GI CI PID Type Process name GPU Memory |

| ID ID Usage |

|=======================================================================================|

| 1 N/A N/A 67321 C .../anaconda3/envs/py/bin/python 23646MiB |

| 1 N/A N/A 71612 C .../anaconda3/envs/py/bin/python 848MiB |

| 2 N/A N/A 67321 C .../anaconda3/envs/py/bin/python 25776MiB |

±--------------------------------------------------------------------------------------+

需要注意两点:

  • nvidia-smi 看到的是 进程向驱动申请并保留的显存,其中既包括模型/张量,也包括框架内部缓存(例如 CUDA caching allocator 的预留与碎片)。
  • 同一个 PID 可能在多张卡上占用显存(如多卡训练/推理)。

1. 显存由哪些部分构成

从概念上,训练/推理显存开销通常包括:

  • 模型参数(parameters)
  • 优化器状态(optimizer states)
  • 梯度(gradients)
  • 激活(activations)
  • 输入/输出张量(inputs/outputs)
  • 临时张量/工作区(temporary / workspace)
  • 自动求导图与元数据(autograd details / graph metadata)
  • 其他不可归类部分(unknown)

为了更好地"估算 vs 实测"对齐,可以把显存拆成三类:

1.1 可估算值(可用公式近似)

  • parameters
  • optimizer states
  • gradients
  • activations
  • inputs/outputs(通常较小,但长上下文/多模态时也可能显著)

1.2 未命名数据(很难提前精确估计)

  • temporary/workspace(注意力算子、归一化、GEMM 的 workspace)
  • unknown(缓存、碎片、通信缓冲、框架内部保留等)

1.3 框架相关

  • autograd details(反向图、保存的中间信息、hook、AMP scaler 等)

2. 为什么"估算值"和"实际测量值"会差很多

在显存估算时,常见现象是:

估算得到总消耗 50GB,但实际 nvidia-smi 看到占用达到 75GB,误差超过 30%。

这类差距通常来自以下因素叠加(往往不是单一原因):

2.1 CUDA/PyTorch 缓存与碎片

深度学习框架为了加速,会采用缓存分配器:释放的显存不会立刻还给驱动,而是保留在进程内用于复用。这样会导致:

  • 实际"已保留显存"远大于"当前活跃张量显存"
  • 频繁不同大小分配产生碎片,进一步抬高峰值保留

2.2 算子 workspace(临时工作区)

高性能算子(如 Attention、Conv、某些归一化、FlashAttention/SDPA 的不同实现)会按需要申请 workspace。workspace 的大小与:

  • 序列长度 sss
  • batch size(或 micro-batch)bbb
  • head 数 aaa
  • hidden size hhh
  • 数据类型(fp16/bf16/fp8)
  • 是否启用某些 fused kernel
    密切相关,并且在不同 kernel 路径下差异很大。

2.3 通信与并行策略的额外缓冲

在 DP/TP/PP/ZeRO 等并行策略下,通信库(如 NCCL)可能分配:

  • ring buffer
  • fusion buffer
  • reduce-scatter / all-gather 缓冲
    当梯度 bucket 较大、融合策略激进时,缓冲占用会显著提升。

2.4 Autograd 与保存的中间量

即使模型参数/梯度可估算,框架为反向传播保存的中间张量(尤其在未做重计算时)会极大增加显存;同时还包含图结构与元信息。


2. 训练阶段的显存分析

训练阶段的"可估算部分"通常可写为:

TrainMem≈Model+Optimizer+Gradients+Activations \text{TrainMem} \approx \text{Model} + \text{Optimizer} + \text{Gradients} + \text{Activations} TrainMem≈Model+Optimizer+Gradients+Activations

同时按"是否随时间变化"可分为:

  • 静态值:模型参数、优化器状态(训练过程中基本不变)
  • 动态值:激活、梯度(随前向/反向阶段变化,峰值一般出现在反向或梯度聚合附近)

2.1 静态值分析

2.1.1 模型参数显存(Model Memory)

模型参数占用与 参数量数据类型 直接相关。基本计算式:

ModelMem=TypeSize×Params \text{ModelMem} = \text{TypeSize} \times \text{Params} ModelMem=TypeSize×Params

将字节换算到 GB(以 102431024^310243 为换算基准)时,常用写法:

ModelMemfp32=4×Params10243ModelMemfp16/bf16=2×Params10243ModelMemfp8/int8=1×Params10243 \begin{aligned} \text{ModelMem}{\text{fp32}} &= \frac{4 \times \text{Params}}{1024^3} \\ \text{ModelMem}{\text{fp16/bf16}} &= \frac{2 \times \text{Params}}{1024^3} \\ \text{ModelMem}_{\text{fp8/int8}} &= \frac{1 \times \text{Params}}{1024^3} \end{aligned} ModelMemfp32ModelMemfp16/bf16ModelMemfp8/int8=102434×Params=102432×Params=102431×Params

checkpoint 存储大小估算

存储 checkpoint 时通常主要考虑"模型权重本身"。例如 1B 参数、fp32 存储:

DiskSize=4×10910243≈3.725 GB \text{DiskSize}=\frac{4 \times 10^9}{1024^3}\approx 3.725\ \text{GB} DiskSize=102434×109≈3.725 GB

工程上常用近似:

  • 1B params≈4GB (fp32)1\text{B params} \approx 4\text{GB (fp32)}1B params≈4GB (fp32)
  • 1B params≈2GB (fp16/bf16)1\text{B params} \approx 2\text{GB (fp16/bf16)}1B params≈2GB (fp16/bf16)

示例:

  • 13B 模型若以 fp32 权重存储,粗略估算为 13×4≈52 GB13 \times 4 \approx 52\ \text{GB}13×4≈52 GB。

注意:许多训练采用混合精度,但最终保存 checkpoint 的权重格式取决于保存策略;不少默认保存仍可能是 fp32(或包含 fp32 master weights),因此需要区分"训练时参数精度"和"保存时权重精度"。


2.1.2 优化器状态(Optimizer States)

LLM 训练中最常见的是 Adam/AdamW。对每个参数,Adam 通常需要维护:

  • 一阶矩(Momentum)
  • 二阶矩(Variance)
    同时在混合精度训练中,往往还存在一份 fp32 的 master 参数副本。

因此 Adam 的状态占用可近似为(单位 GB):

OptMemAdam=(4+4+4)×Params10243 \text{OptMem}_{\text{Adam}}= \frac{(4 + 4 + 4)\times \text{Params}}{1024^3} OptMemAdam=10243(4+4+4)×Params

其中 (4+4+4)(4+4+4)(4+4+4) 代表:

  • fp32 master 参数副本:4 bytes
  • Momentum:4 bytes
  • Variance:4 bytes
8-bit 优化器(低精度状态)

若使用 8-bit 优化器(例如将一阶/二阶矩压到 8-bit 存储),常见近似为:

OptMem8bit=(4+1+1)×Params10243 \text{OptMem}_{\text{8bit}}= \frac{(4 + 1 + 1)\times \text{Params}}{1024^3} OptMem8bit=10243(4+1+1)×Params

含义:

  • master 参数副本:4 bytes
  • Momentum:1 byte
  • Variance:1 byte

现实实现中还可能有额外的 scale/quantization metadata,因此这是工程近似,不是严格上界。

一个直观对比案例(帮助把量级"刻在脑子里")

假设参数量为 1B1\text{B}1B:

  • fp16/bf16 参数本体大约 2 GB2\ \text{GB}2 GB
  • Adam 状态(含 fp32 master + m + v)大约 12 GB12\ \text{GB}12 GB

也就是说在"常规 Adam + 混合精度"下:

  • 优化器状态往往比模型权重本体大得多
  • 这也是为什么 ZeRO/offload/8-bit optimizer 能带来巨大收益

2.2 动态值分析

动态值最关键的是 激活(activation)梯度(gradient)

2.2.1 激活值(Activation Memory)

激活是训练显存峰值的主要来源之一,因为反向传播需要用到前向中的中间结果(或其可重建形式)。

激活显存大小受多因素影响:

  • 模型结构(层数 LLL、hidden size hhh、head 数 aaa)
  • 序列长度 sss
  • micro-batch bbb
  • 是否启用 重计算(activation checkpointing)
  • 并行策略(TP/PP 导致的保存位置变化)
  • 注意力实现(不同 kernel/workspace)

一种常用的近似公式(参考 Megatron 系列分析思路)可以写为:

ActMem≈s⋅b⋅h⋅(34+5⋅a⋅sh)⋅L⋅γ \text{ActMem} \approx s \cdot b \cdot h \cdot \left(34 + 5\cdot \frac{a \cdot s}{h}\right)\cdot L \cdot \gamma ActMem≈s⋅b⋅h⋅(34+5⋅ha⋅s)⋅L⋅γ

其中(符号解释尽量直白):

  • sss:序列长度(tokens 数)
  • bbb:micro-batch size
  • hhh:隐藏层维度(hidden size)
  • aaa:注意力头数(number of heads)
  • LLL:Transformer 层数
  • γ\gammaγ:比例系数(把"元素个数"换算成 GB)
    • 当采用 fp16/bf16 时,单元素 2 bytes,因此可以把换算写为:
      γ=210243 \gamma = \frac{2}{1024^3} γ=102432
    • 若是 fp32,则单元素 4 bytes,对应:
      γ=410243 \gamma = \frac{4}{1024^3} γ=102434
为什么这个公式里会有 (34+5⋅a⋅sh)\left(34 + 5\cdot \frac{a\cdot s}{h}\right)(34+5⋅ha⋅s)

它本质上是在粗略地把每层需要保存的中间量分成两类:

  • 与 s⋅b⋅hs\cdot b\cdot hs⋅b⋅h 成正比的部分(MLP/残差/归一化等中间张量)
  • 与注意力的 a⋅sa\cdot sa⋅s 相关的部分(例如注意力相关中间量在某些实现里会引入额外依赖)

这不是"精确到字节"的公式,但它能很好解释两个最重要的现象:

  1. sss(序列长度)对激活几乎是线性甚至更高阶的放大器
  2. 层数 LLL 与 micro-batch bbb 直接线性放大激活开销
激活的典型工程现象(案例)
  • 把上下文从 s=2048s=2048s=2048 提到 s=8192s=8192s=8192,即使 batch 不变,激活很可能从"勉强能跑"变成"直接 OOM"。
  • 在多模态(图像/视频)里,视觉 token 本身就会显著增加等效 sss,导致激活更快爆炸。

2.2.2 梯度值(Gradient Memory)

梯度与参数张量形状一致,通常与模型参数采用相同的数据类型(具体取决于实现与是否有 fp32 grad accumulator)。

常用估算:

  • 若梯度为 fp32:
    GradMemfp32=4×Params10243 \text{GradMem}_{\text{fp32}}=\frac{4 \times \text{Params}}{1024^3} GradMemfp32=102434×Params

  • 若梯度为 fp16/bf16:
    GradMemfp16/bf16=2×Params10243 \text{GradMem}_{\text{fp16/bf16}}=\frac{2 \times \text{Params}}{1024^3} GradMemfp16/bf16=102432×Params

在一些配置下可能存在 fp32 梯度累积或额外 buffer,使得梯度相关显存大于上述估算。


3. 推理阶段显存分析

推理阶段不需要保存反向传播所需的激活,也不需要优化器状态与梯度,因此显存结构更简单。

推理显存主要来自:

  • 模型参数(权重)
  • KV cache(自回归生成时)
  • 少量临时张量(算子 workspace)
  • 框架缓存/碎片

因此很多场景下会用一个经验近似:

InferMem≈1.2×ModelMem \text{InferMem} \approx 1.2 \times \text{ModelMem} InferMem≈1.2×ModelMem

这个 1.21.21.2 可以理解为"模型权重之外的框架开销 + 少量临时 buffer"的粗略倍数。它在以下条件下更接近现实:

  • 生成长度不长
  • batch 不大
  • KV cache 可控

但在自回归长生成里,KV cache 往往成为主导项,这时仅用 1.2×1.2\times1.2× 会低估。


4. 显存优化:思路与路径

当模型规模增长远超单卡物理显存时,显存优化通常围绕两条基本原则:

4.1 时间换空间

通过增加计算量换取更低显存占用,例如:

  • 重计算(activation checkpointing):不保存全部激活,反向时再算一遍前向

代价:

  • 额外算力消耗
  • 训练 wall-time 变长
  • 可能增加通信与同步开销(视并行策略而定)

4.2 空间转移

把本应在 GPU 上的内容"挪出去"或"摊开到多卡",例如:

  • 多卡并行(TP/PP/DP/ZeRO)
  • offload 到 CPU 内存或 NVMe(参数/优化器/梯度 offload)

代价:

  • I/O 带宽压力上升(PCIe/NVLink/网络)
  • 延迟增加,吞吐下降
  • 工程复杂度显著提高

4.3 一个常用的优化路径(从高层到低层)

实践中,优化往往按影响范围与性价比从上到下推进:

4.3.1 多卡并行(最常用、通常不损精度)

通过并行策略减少单卡持有的数据量,可参考"模型参数/优化器/梯度/激活"的组成,针对性拆分:

  • TP(Tensor Parallel):按张量维度切分参数与计算
  • PP(Pipeline Parallel):按层切分
  • DP(Data Parallel):按 batch 切分
  • ZeRO:按参数/梯度/优化器状态切分(不同 stage 切分粒度不同)
  • 重计算:进一步降低激活峰值(常与 PP/TP 配合)

典型权衡:

  • 显存降下来,但通信量会上去(尤其 TP/ZeRO 的 all-gather / reduce-scatter)

4.3.2 算子优化(同精度更省显存)

选择数值等价或精度一致但更节省显存的实现路径,例如:

  • 更低 workspace 的 attention kernel
  • fused kernel(减少中间张量)
  • 更好的内存复用策略

典型权衡:

  • 调参/验证成本高
  • 受限于硬件、驱动、CUDA 版本与框架版本

4.3.3 数据类型修改(用低精度替换高精度)

常见策略:

  • fp32 →\rightarrow→ fp16/bf16:参数/激活/梯度直接减半
  • int8/fp8:进一步降低权重或激活存储(需硬件与软件栈支持)

典型权衡:

  • 可能影响收敛稳定性与最终精度
  • 需要配合 loss scaling、校准、量化策略

4.3.4 消除框架副本(减少"隐形开销")

一些框架会引入中间副本,例如:

  • master weights(混合精度)
  • buffer/flatten 参数副本(某些并行实现)
  • optimizer state 的冗余保存

典型权衡:

  • 改动侵入性强
  • 容易引入数值或稳定性问题

4.3.5 显存管理(减少碎片与峰值保留)

显存碎片会让"理论还够"变成"实际 OOM"。优化方向包括:

  • 更稳定的分配形状(避免频繁变化的张量大小)
  • 让关键路径更少创建短生命周期大张量
  • 合理设置 batch/sequence 的组合减少抖动

典型权衡:

  • 可控手段相对有限
  • 与框架版本和 kernel 实现强绑定

4.3.6 底层 API/库替换(最后一公里)

不同 CUDA API/库版本、不同 kernel 实现,显存开销可能差异明显,例如:

  • 使用更省显存的注意力实现(如 FlashAttention 路径)
  • 更新到更高版本的 CUDA/cuDNN/cuBLAS 以获得更优的 workspace 策略
  • 使用更优化的 SDPA 内核选择逻辑

典型权衡:

  • 需要完整验证与回归测试
  • 可能影响稳定性与可复现性

5. 贯穿全文的"量级直觉":用一个综合案例串起来

为了把上述概念落到直觉上,可以用"1B 参数模型 + Adam + 混合精度"的量级来记忆(只用作数量级参考):

  1. 模型参数(bf16)
    ModelMembf16≈2×10910243≈1.86 GB \text{ModelMem}_{\text{bf16}} \approx \frac{2\times 10^9}{1024^3}\approx 1.86\ \text{GB} ModelMembf16≈102432×109≈1.86 GB

    近似记为 ≈2 GB\approx 2\ \text{GB}≈2 GB。

  2. Adam 优化器状态(含 fp32 master + m + v)
    OptMemAdam≈12×10910243≈11.18 GB \text{OptMem}_{\text{Adam}} \approx \frac{12\times 10^9}{1024^3}\approx 11.18\ \text{GB} OptMemAdam≈1024312×109≈11.18 GB

    近似记为 ≈12 GB\approx 12\ \text{GB}≈12 GB。

  3. 梯度(bf16)
    GradMembf16≈2×10910243≈1.86 GB \text{GradMem}_{\text{bf16}} \approx \frac{2\times 10^9}{1024^3}\approx 1.86\ \text{GB} GradMembf16≈102432×109≈1.86 GB

    近似记为 ≈2 GB\approx 2\ \text{GB}≈2 GB。

仅静态项 + 梯度合计就已经大约:
2+12+2≈16 GB 2 + 12 + 2 \approx 16\ \text{GB} 2+12+2≈16 GB

这还没算:

  • 激活(通常是训练峰值的核心)
  • workspace/通信 buffer/碎片(经常就是"估算 vs 实测"的差距来源)

因此一旦序列长度 sss、层数 LLL、micro-batch bbb 增长,激活项就能把总显存推到远超上述 16GB 的水平。


相关推荐
程序员的那些事_1 小时前
宕机瘫痪 13 小时!官方甩锅人类员工,内部员工:自家 AI 干的
人工智能
vm321 小时前
02:Agent Loop 深度剖析:ReAct 循环的工程实现
人工智能·python
Matrix_111 小时前
论文阅读--Agent AI 探索多模态交互的前沿领域(二)
论文阅读·人工智能
无忧智库2 小时前
大型国际机场全域态势感知与航班运行协同决策系统 (A-CDM) 深度解析:打造智慧民航的“最强大脑”(WORD)
人工智能
lisw052 小时前
如何在科学出版中负责任地使用人工智能?
人工智能·机器学习
mtouch3332 小时前
三维数字沙盘智能交互式可视化动态主界面系统
人工智能·ai·信息可视化·无人机·虚拟现实·电子沙盘·数字沙盘
AC赳赳老秦2 小时前
多模态 AI 驱动办公智能化变革:DeepSeek 赋能图文转写与视频摘要的高效实践
java·ide·人工智能·python·prometheus·ai-native·deepseek
未来之窗软件服务2 小时前
AI人工智能(十二)C# 运行sensevoice onnx—东方仙盟练气期
开发语言·人工智能·c#·仙盟创梦ide·东方仙盟
2501_926978332 小时前
嵌套分形意识融合理论3.0:概率分形通用理论与存在意义论的统一整合框架
人工智能·经验分享·机器学习·ai写作·agi