必知必会:大模型训练显存计算与优化详解
AI-Compass 致力于构建最全面、最实用、最前沿的AI技术学习和实践生态,通过六大核心模块的系统化组织,为不同层次的学习者和开发者提供完整学习路径。
- github地址:AI-Compass👈:https://github.com/tingaicompass/AI-Compass
- gitee地址:AI-Compass👈:https://gitee.com/tingaicompass/ai-compass
🌟 如果本项目对您有所帮助,请为我们点亮一颗星!🌟
1. 显存消耗的组成与模型权重计算
1.1 核心问题
大模型训练时显存被什么占满了?不同量化精度下模型权重需要多少显存?
1.2 原文核心要点
深度神经网络训练的显存消耗主要包括两大部分:模型状态(模型权重、梯度、优化器状态)和激活值(各个非线性模块的中间激活值)。不同量化精度下的显存占用差异巨大。
1.3 显存消耗的两大组成部分
换句话说,显存就像你的工作台空间:一部分放置"工具箱和材料"(模型状态),一部分作为"临时加工区"(激活值)。前者大小固定,后者随工作量波动。
| 组成部分 | 具体内容 | 说明 |
|---|---|---|
| 模型状态 | 模型权重(参数)、梯度、优化器状态 | 与模型参数量Φ成正比,是固定开销 |
| 激活值 | 各个非线性模块的中间激活值 | 与batch_size和序列长度相关,是动态开销 |
1.4 模型权重与量化精度的关系
假设模型参数量为 Φ(单位:参数个数),不同量化精度下的显存占用如下:
| 量化程度 | 每参数字节数 | 显存占用 | 1B参数模型 | 7B参数模型 |
|---|---|---|---|---|
| FP32 | 4字节 | 4Φ | 4GB | 28GB |
| FP16/BF16 | 2字节 | 2Φ | 2GB | 14GB |
| INT8 | 1字节 | 1Φ | 1GB | 7GB |
| INT4 | 0.5字节 | ≤1Φ | 0.5GB | 3.5GB |
1.5 模型参数量的计算公式
以 Llama-3 模型为例,其参数量由以下符号定义:
| 符号 | 含义 |
|---|---|
| n_vocab | 词表中词的个数 |
| d_hidden | 隐藏层维度(嵌入向量的维度) |
| n_head | 注意力头的数量 |
| n_kv-head | 分组查询注意力中的键值头数量 |
| n_layer | Transformer的层数 |
| d_FFN | 前馈神经网络的隐藏层维度 |
| b | 输入数据的批次大小(batch size) |
| s | 输入序列长度 |
模型总参数量公式:
\\Phi = n_{\\text{vocab}} \\times d_{\\text{hidden}} + n_{\\text{layer}} \\times \\left\[ d_{\\text{hidden}} + \\left(2 + 2 \\cdot \\frac{n_{\\text{kv}}}{n_{\\text{head}}}\\right) d_{\\text{hidden}}\^2 + d_{\\text{hidden}} + 3 \\cdot d_{\\text{hidden}} \\cdot d_{\\text{FFN}} \\right\] + d_{\\text{hidden}} + d_{\\text{hidden}} \\times n_{\\text{vocab}}
| 组成部分 | 公式项 | 说明 |
|---|---|---|
| 词嵌入层 | n_{\\text{vocab}} \\times d_{\\text{hidden}} | 词表大小 × 隐藏维度 |
| Transformer层(×n_{\\text{layer}}) | 含 QKV 投影 + FFN | GQA 时 KV 头数 < 注意力头数 |
| 输出层 | d_{\\text{hidden}} + d_{\\text{hidden}} \\times n_{\\text{vocab}} | LayerNorm + 输出投影 |
注意:
- 当 n_kv-head = 1 时为多查询注意力(MQA)
- 当 n_kv-head = n_head 时为多头注意力(MHA)
- 当 1 < n_kv-head < n_head 时为分组查询注意力(GQA)
1.6 通俗理解
直观类比
想象你在搬家,需要把所有家当装上卡车(GPU显存)。
- 模型权重 = 你的家具(沙发、床、桌子)------这些是固定的,搬多少次都一样重。
- 梯度 = 每件家具的"搬运说明书"------和家具数量一一对应,同样多。
- 优化器状态 = 每件家具的"维修记录"和"使用日志"------Adam优化器需要记录每个参数的"动量"和"方差",所以额外占用2倍的家具重量。
- 激活值 = 搬运过程中的临时存放点------搬的批次(batch_size)越多,需要的临时空间越大。
量化精度就像选择不同精度的"包装方式":
- FP32 = 用厚实的防震泡沫包裹每件家具(4字节/参数,最安全但最占空间)
- FP16 = 用薄一些的包装(2字节/参数,空间减半)
- INT8 = 只用塑料薄膜简单裹一下(1字节/参数)
核心要点
- 显存 = 模型状态(固定)+ 激活值(动态),两者都需要关注
- 量化精度每降一档,模型权重显存减半
- 7B模型仅权重(FP32)就需要28GB,整体训练显存远超单卡容量
1.7 小结
| 维度 | 说明 |
|---|---|
| 两大组成 | 模型状态(权重+梯度+优化器)+ 激活值 |
| 量化关系 | FP32=4Φ, FP16=2Φ, INT8=1Φ |
| 参数计算 | 含词嵌入层 + n_layer个Transformer层 + 输出层 |
| 关键认知 | 7B模型FP32权重=28GB,训练总显存约112GB |
2. FP32训练与混合精度训练
2.1 核心问题
FP32训练需要多少显存?混合精度训练能节省显存吗?
2.2 原文核心要点
使用AdamW优化器进行FP32训练,模型状态总显存为16Φ。混合精度训练并没有节省模型状态的显存!其真正优势是加速计算和降低激活值显存。
2.3 FP32训练的显存占用
通俗来讲,训练模型不仅要存"模型本身",还要存"每个参数的更新历史"(优化器状态),这才是显存的大头。
使用 AdamW 优化器进行 FP32 训练时:
| 组成部分 | 显存占用 | 说明 |
|---|---|---|
| 模型权重 M_model | 4Φ | FP32参数 |
| 梯度 M_grad | 4Φ | 与模型权重相同精度 |
| 优化器状态 M_optim | 8Φ | 动量(4Φ) + 方差(4Φ) |
| 总计 M_total | 16Φ | 4Φ + 4Φ + 8Φ |
M_{\\text{total}} = M_{\\text{model}} + M_{\\text{grad}} + M_{\\text{optim}} = 4\\Phi + 4\\Phi + 8\\Phi = 16\\Phi
实际案例:
| 模型规模 | 参数量Φ | 模型状态显存(16Φ) | 单张A100(80GB)能否容纳 |
|---|---|---|---|
| 1B | 10亿 | 16GB | 可以 |
| 7B | 70亿 | 112GB | 不可以 |
| 13B | 130亿 | 208GB | 不可以 |
| 70B | 700亿 | 1120GB | 不可以 |
2.4 混合精度训练
混合精度训练使用 FP16/BF16 存储模型参数和梯度,但需要保留一份 FP32 的模型权重副本:
| 组成部分 | 显存占用 | 说明 |
|---|---|---|
| 模型权重 M_model | 2Φ | FP16/BF16 |
| 梯度 M_grad | 2Φ | FP16/BF16 |
| 优化器状态 M_optim | 12Φ | FP32副本(4Φ) + 动量(4Φ) + 方差(4Φ) |
| 总计 M_total | 16Φ | 2Φ + 2Φ + 12Φ |
M_{\\text{total}}\^{\\text{mixed}} = \\underbrace{2\\Phi}*{\\text{FP16权重}} + \\underbrace{2\\Phi}* {\\text{FP16梯度}} + \\underbrace{4\\Phi + 4\\Phi + 4\\Phi}_{\\text{FP32副本 + 动量 + 方差}} = 16\\Phi
关键结论 :混合精度训练并没有节省模型权重、梯度以及优化器状态的显存占用!总计仍为16Φ。
混合精度训练的真正优势:
- 加速前向传播:半精度计算速度更快(尤其Tensor Core加速)
- 降低激活值显存:中间激活值使用半精度存储,显存占用减半
2.5 通俗理解
直观类比
想象你在做账。
- FP32训练 = 所有账目都用"精确到分"的方式记录------账本很厚,但数字精确。
- 混合精度训练 = 日常流水账用"精确到元"的简化方式记(FP16,省纸),但总账还是保留一份"精确到分"的备份(FP32副本),防止长期累积误差。
所以混合精度的"诡异之处"在于:虽然日常计算用了更短的数字,但因为必须保留一份精确备份,总的"账本厚度"(模型状态显存)没有变!省下来的只是"草稿纸"(激活值)的纸张。
换句话说,混合精度训练不是为了"省空间",而是为了"算得快"------就像用计算器比手算快,虽然计算器和笔记本加起来并不比原来的大账本轻。
核心要点
- FP32训练模型状态总显存 = 16Φ(权重4Φ + 梯度4Φ + 优化器8Φ)
- 混合精度训练模型状态仍然是16Φ,不节省模型状态显存
- 混合精度的真正价值:计算加速 + 激活值显存减半
2.6 小结
| 维度 | 说明 |
|---|---|
| FP32训练 | 模型状态=16Φ(7B模型=112GB) |
| 混合精度 | 模型状态仍=16Φ,不省显存 |
| 混合精度真正优势 | 计算加速 + 激活值减半 |
| 关键认识 | Adam优化器状态占大头(8-12Φ) |
3. 激活值显存与梯度检查点
3.1 核心问题
训练过程中的激活值占多少显存?如何用"计算换显存"来优化?
3.2 原文核心要点
激活值是训练过程中必须缓存的中间结果,用于反向传播时计算梯度。激活值显存与batch_size × sequence_length成正比。梯度检查点通过重新计算来换取显存。
3.3 激活值显存计算
换个角度理解:激活值就像炒菜时的"中间半成品"(切好的菜、调好的酱汁)。反向传播时需要这些"半成品"来计算梯度,所以前向传播必须全部保存下来。
自注意力机制的激活值:
| 计算阶段 | 需保存的激活值 | 显存占用 |
|---|---|---|
| 归一化前的输入 | 前置归一化输入 | 2×b×s×d_hidden |
| QKV投影后 | Q、K、V矩阵 | 2×b×s×(d_hidden + d_hidden×n_kv-head/n_head×2) |
| Softmax前 | 注意力logits | 2×b×n_head×s×s |
| Dropout掩码 | 0/1矩阵 | 1×b×n_head×s×s |
| Dropout后 | 注意力得分 | 2×b×n_head×s×s |
| 输出投影前 | 注意力输出 | 2×b×s×d_hidden |
自注意力激活值总量:
M_{\\text{attn}} = 8 \\cdot b \\cdot s \\cdot d_{\\text{hidden}} + 4 \\cdot \\frac{n_{\\text{kv}}}{n_{\\text{head}}} \\cdot b \\cdot s \\cdot d_{\\text{hidden}} + 5 \\cdot b \\cdot s\^2 \\cdot n_{\\text{head}}
FFN激活值总量:
M_{\\text{FFN}} = 8 \\cdot b \\cdot s \\cdot d_{\\text{FFN}} + 2 \\cdot b \\cdot s \\cdot d_{\\text{hidden}}
每层总激活值显存:
M_{\\text{layer,act}} = \\left(10 + 4 \\cdot \\frac{n_{\\text{kv}}}{n_{\\text{head}}}\\right) \\cdot b \\cdot s \\cdot d_{\\text{hidden}} + 8 \\cdot b \\cdot s \\cdot d_{\\text{FFN}} + 5 \\cdot b \\cdot s\^2 \\cdot n_{\\text{head}}
模型总激活值显存:
M_{\\text{total,act}} = n_{\\text{layer}} \\times \\left\[\\left(10 + 4 \\cdot \\frac{n_{\\text{kv}}}{n_{\\text{head}}}\\right) b \\cdot s \\cdot d_{\\text{hidden}} + 8 \\cdot b \\cdot s \\cdot d_{\\text{FFN}} + 5 \\cdot b \\cdot s\^2 \\cdot n_{\\text{head}}\\right\] + 4 \\cdot b \\cdot s \\cdot d_{\\text{hidden}}
| 符号 | 含义 |
|---|---|
| b | 批次大小 (batch size) |
| s | 序列长度 |
| d_{\\text{hidden}} | 隐藏层维度 |
| d_{\\text{FFN}} | FFN 中间维度 |
| n_{\\text{head}} | 注意力头数 |
| n_{\\text{kv}} | KV 头数(GQA) |
| n_{\\text{layer}} | Transformer 层数 |
数值示例:以 Llama-3 8B 为例(n_{\\text{layer}}=32, d_{\\text{hidden}}=4096, n_{\\text{head}}=32, n_{\\text{kv}}=8, d_{\\text{FFN}}=14336, b=1, s=4096):
- 自注意力激活:8 \\times 1 \\times 4096 \\times 4096 + 4 \\times \\frac{8}{32} \\times 1 \\times 4096 \\times 4096 + 5 \\times 1 \\times 4096\^2 \\times 32 \\approx 2.7\\text{GB}(单层)
- 注意力矩阵项 5 \\cdot b \\cdot s\^2 \\cdot n_{\\text{head}} 在长序列时迅速增长,这就是 FlashAttention 等方法的优化目标
关键洞察 :激活值显存与 batch_size × sequence_length 成正比,其中注意力矩阵部分与 s² 成正比,这就是长序列训练的显存瓶颈。
3.4 梯度检查点(Gradient Checkpointing)
核心思想:用计算换显存------在前向传播时不保存所有激活值,而是在反向传播时重新计算。
普通训练:保存所有激活值 → 显存大,速度快
梯度检查点:只保存部分激活值 → 显存小,需要重新计算部分激活值
实际案例:
假设模型有32层(L=32),梯度检查点的效果:
| 方案 | 保存的激活值层数 | 显存占用 | 额外计算开销 |
|---|---|---|---|
| 普通训练 | 32层全部保存 | O(L) = O(32) | 0% |
| 梯度检查点(√L个) | √32 ≈ 6层 | O(√L) ≈ O(6) | ~25-30% |
| 极端检查点(仅首层) | 1层 | O(1) | ~100%(等于两次前向) |
结论:使用√L个检查点是最优平衡,显存从O(L)降到O(√L),仅增加约25-30%计算时间。
3.5 通俗理解
直观类比
想象你在考数学试卷,有32道大题需要先做"草稿"再写"答案"。
- 普通训练 = 每道题的草稿都保留在草稿纸上。写答案(反向传播)时随时能查看。缺点:需要一大叠草稿纸(显存占用大)。
- 梯度检查点 = 只保留每5道题的草稿(关键节点)。写答案时如果需要第3题的草稿,就从第1题的草稿重新推算到第3题。缺点:需要多花时间重新算,但省了大量草稿纸。
激活值中注意力矩阵与序列长度的平方成正比------就像写作文时,文章越长,你需要记住的"前后文关联"就呈爆炸式增长。这就是为什么长文本训练特别吃显存。
核心要点
- 激活值显存与 batch_size × seq_len 成正比,注意力矩阵与 seq_len² 成正比
- 梯度检查点用约25-30%的额外计算时间换取显存从O(L)降到O(√L)
- 长序列训练的显存瓶颈在于注意力矩阵的二次增长
3.6 小结
| 维度 | 说明 |
|---|---|
| 激活值关键因素 | batch_size, seq_len, d_hidden, n_layer |
| 注意力瓶颈 | 注意力矩阵与seq_len²成正比 |
| 梯度检查点 | 用25-30%计算换取O(L)→O(√L)显存 |
| 适用场景 | 显存受限但算力充足,训练超大模型时必用 |
4. 数据并行与ZeRO优化
4.1 核心问题
如何让多张GPU协同训练?ZeRO优化如何将单卡显存从16Φ降到16Φ/N?
4.2 原文核心要点
DDP虽然实现多卡并行但不节省单卡显存。ZeRO通过三个层次逐步切分优化器状态、梯度和模型参数,最终实现16Φ/N的单卡显存。
4.3 集合通信原语
建立直觉之后,让我们深入理解分布式训练的通信基础------集合通信原语,它们是所有多卡协同的基石。
多卡训练的核心挑战是"如何让所有GPU保持同步"。这需要三种基础通信操作,可以类比为:大家一起做作业时"抄答案""汇总结果""分发任务"的不同方式。
分布式训练依赖三种核心通信操作:
| 通信原语 | 功能 | 说明 |
|---|---|---|
| All-Gather | 全聚集操作 | 从多个设备收集结果,并同步完整状态到所有设备 |
| Reduce-Scatter | 规约-分发操作 | 执行聚合操作(求和等),每个进程只获取结果的一部分 |
| All-Reduce | 全规约操作 | 执行规约操作后,将结果同步到所有设备 |
重要结论:All-Reduce = Reduce-Scatter + All-Gather
下图展示了三种集合通信原语的数据流向对比(以 4 卡为例):
flowchart TD subgraph AG["All-Gather:收集完整数据"] AG1[GPU0: A] --> AG_R[所有GPU] AG2[GPU1: B] --> AG_R AG3[GPU2: C] --> AG_R AG4[GPU3: D] --> AG_R AG_R --> AG_O1[GPU0: ABCD] AG_R --> AG_O2[GPU1: ABCD] AG_R --> AG_O3[GPU2: ABCD] AG_R --> AG_O4[GPU3: ABCD] end subgraph RS["Reduce-Scatter:规约后分片"] RS1[GPU0: A] --> RS_SUM[Sum: A+B+C+D] RS2[GPU1: B] --> RS_SUM RS3[GPU2: C] --> RS_SUM RS4[GPU3: D] --> RS_SUM RS_SUM --> RS_O1[GPU0: Sum_part0] RS_SUM --> RS_O2[GPU1: Sum_part1] RS_SUM --> RS_O3[GPU2: Sum_part2] RS_SUM --> RS_O4[GPU3: Sum_part3] end subgraph AR["All-Reduce:规约后全员获取"] AR1[GPU0: A] --> AR_SUM[Sum: A+B+C+D] AR2[GPU1: B] --> AR_SUM AR3[GPU2: C] --> AR_SUM AR4[GPU3: D] --> AR_SUM AR_SUM --> AR_O1[GPU0: Sum] AR_SUM --> AR_O2[GPU1: Sum] AR_SUM --> AR_O3[GPU2: Sum] AR_SUM --> AR_O4[GPU3: Sum] end
数值示例:假设 4 张 GPU 各自计算得到一个梯度向量的分片
| GPU | 初始数据 | All-Gather | Reduce-Scatter(求和) | All-Reduce(求和) |
|---|---|---|---|---|
| GPU0 | [1, 2] | [1,2,3,4,5,6,7,8] | [10, 12](所有第1-2个元素之和) | [10,12,14,16,18,20,22,24] |
| GPU1 | [3, 4] | [1,2,3,4,5,6,7,8] | [14, 16](所有第3-4个元素之和) | [10,12,14,16,18,20,22,24] |
| GPU2 | [5, 6] | [1,2,3,4,5,6,7,8] | [18, 20](所有第5-6个元素之和) | [10,12,14,16,18,20,22,24] |
| GPU3 | [7, 8] | [1,2,3,4,5,6,7,8] | [22, 24](所有第7-8个元素之和) | [10,12,14,16,18,20,22,24] |
4.4 DP vs DDP
| 特性 | DP (Data Parallelism) | DDP (Distributed Data Parallel) |
|---|---|---|
| 进程模型 | 单进程多线程 | 多进程 |
| 主设备 | 设备0负载重(通信、计算、存储不均衡) | 各设备独立 |
| 梯度同步 | 设备0聚合所有梯度 | All-Reduce同步 |
| 单卡显存 | 16Φ | 16Φ |
| 显存优化 | 无 | 无 |
DDP显存占用 :每个设备都需要完整的16Φ,DDP没有实现任何显存节省!
通俗理解 DP vs DDP:
想象一个团队做同一个项目的4份报告(数据并行)。
- DP(数据并行) = 有一个主管(设备0)统一协调。4个人各自写报告,最后都交给主管汇总修改意见,再由主管统一分发更新。问题:主管工作量特别大(负载不均衡)。
- DDP(分布式数据并行) = 4个人各自独立写报告,写完后大家一起开会讨论(All-Reduce),每个人都得到完整的修改意见,然后各自更新。优势:负载均衡。问题:每个人还是需要准备全套材料(显存不省)。
4.5 ZeRO的三个层次
ZeRO(Zero Redundancy Optimizer)是微软提出的显存优化技术,核心思想是将优化器状态、梯度、模型权重分块处理,分配到多个设备上。
| 层次 | 切分内容 | 常驻单卡显存 | N=8时单卡显存 |
|---|---|---|---|
| ZeRO-1 (P_os) | 优化器状态 | 4Φ + 12Φ/N | 4Φ + 1.5Φ = 5.5Φ |
| ZeRO-2 (P_os+g) | 优化器状态 + 梯度 | 2Φ + 14Φ/N | 2Φ + 1.75Φ = 3.75Φ |
| ZeRO-3 (P_os+g+p) | 全部 | 16Φ/N | 2Φ |
其中 N = num_devices(设备数量)
M_{\\text{ZeRO-1}} = 4\\Phi + \\frac{12\\Phi}{N}, \\quad M_{\\text{ZeRO-2}} = 2\\Phi + \\frac{14\\Phi}{N}, \\quad M_{\\text{ZeRO-3}} = \\frac{16\\Phi}{N}
实际案例:
训练7B模型(Φ=7B),使用8张A100(80GB):
| 配置 | 单卡模型状态显存 | 单卡是否可行 |
|---|---|---|
| DDP(无优化) | 16×7 = 112GB | 不可行(>80GB) |
| ZeRO-1(8卡) | 5.5×7 ≈ 38.5GB | 可行 |
| ZeRO-2(8卡) | 3.75×7 ≈ 26.3GB | 可行(更宽裕) |
| ZeRO-3(8卡) | 2×7 = 14GB | 非常宽裕 |
ZeRO各层次工作原理:
ZeRO-1:每个设备保存完整模型参数和梯度,优化器状态被切分。参数更新后通过All-Gather同步。
ZeRO-2:每个设备保存完整模型参数。梯度计算完一层后立即通过Reduce-Scatter分发,每个设备只保留自己负责的梯度分片。
ZeRO-3:模型参数、梯度、优化器状态全部切分。前向传播时通过All-Gather临时收集所需参数,计算完成后丢弃。
下图展示了ZeRO三个层次的切分策略对比:
flowchart LR subgraph DDP["DDP(无优化)"] D1[每卡: 参数2Φ + 梯度2Φ + 优化器12Φ = 16Φ] end subgraph Z1["ZeRO-1"] Z1A[每卡: 参数2Φ + 梯度2Φ] Z1B[优化器12Φ/N 切分] end subgraph Z2["ZeRO-2"] Z2A[每卡: 参数2Φ] Z2B[梯度+优化器 14Φ/N 切分] end subgraph Z3["ZeRO-3"] Z3A[全部 16Φ/N 切分] end DDP -->|切分优化器| Z1 Z1 -->|切分梯度| Z2 Z2 -->|切分参数| Z3
上图展示了ZeRO从DDP逐步切分到全分片的演进路径,每一步都进一步降低单卡显存。
4.6 ZeRO-Offload与梯度累积
ZeRO-Offload:当显存仍然不足时,可以将部分数据卸载到CPU内存甚至磁盘。
梯度累积:进行n次前向传播后再进行一次反向传播,等效增大batch_size为原来的n倍,只需一次反向传播的激活值显存。
4.7 通俗理解
直观类比
想象4个同学一起背一本很厚的字典。
- DDP = 每人各买一本完整字典(参数+梯度+优化器全部冗余),然后各自背不同的单词,最后交流学习心得。问题:每人都要扛一整本字典(显存不省)。
- ZeRO-1 = 每人买一本完整字典,但"笔记本"(优化器状态)拆成4份,每人只带1/4的笔记。查笔记时问其他同学借看一下。
- ZeRO-2 = 字典还是每人一本,但"笔记本"和"错题集"(梯度)都拆成4份。
- ZeRO-3 = 连字典也拆成4份!每人只带1/4字典+1/4笔记+1/4错题集。需要查某个字时,临时向拥有该部分的同学借阅。
ZeRO-3最省"书包空间"(显存),但"借阅"次数最多(通信开销最大)。
核心要点
- DDP不省显存,每卡仍需16Φ
- ZeRO逐级切分:优化器→梯度→参数,单卡显存从16Φ降到16Φ/N
- 级别越高显存越省,但通信开销越大------需要根据网络带宽权衡
4.8 小结
| 维度 | 说明 |
|---|---|
| DDP | 多卡并行但单卡仍需16Φ,不省显存 |
| ZeRO-1 | 切分优化器,单卡=4Φ+12Φ/N |
| ZeRO-2 | 切分优化器+梯度,单卡=2Φ+14Φ/N |
| ZeRO-3 | 全切分,单卡=16Φ/N(最省但通信最大) |
| 通信基础 | All-Reduce = Reduce-Scatter + All-Gather |
5. 模型并行与3D并行训练
5.1 核心问题
当模型太大连ZeRO也不够时怎么办?如何配置数据并行+张量并行+流水线并行的3D并行?
5.2 原文核心要点
ZeRO切分的是模型状态,模型并行切分的是模型计算图。现代大模型通常采用3D并行:数据并行+张量并行+流水线并行,总卡数 = D_dp × D_tp × D_pp。
5.3 ZeRO vs 模型并行
简单理解:ZeRO 是"把材料分散存放"(参数、梯度、优化器切分到多卡),模型并行是"把工作流程切分"(不同卡负责不同的计算步骤)。前者省存储空间,后者省计算时的显存占用。
| 特性 | ZeRO (尤其ZeRO-3) | 模型并行 |
|---|---|---|
| 切分对象 | 模型状态(参数、梯度、优化器) | 模型计算图 |
| 设备间传递 | 模型参数、梯度、优化器状态 | 中间激活值 |
| 每个设备功能 | 获取完整参数后独立计算完整梯度 | 只负责模型的一部分计算 |
5.4 模型并行的两种类型
| 类型 | 切分方式 | 通信特点 | 适用场景 |
|---|---|---|---|
| 张量并行 (TP) | 按矩阵分块,切分单层 | 通信频繁,需高带宽 | 同机多卡(NVLink互联) |
| 流水线并行 (PP) | 按层切分,不同层放不同设备 | 通信量小,仅层间传递 | 跨机多卡 |
模型并行的显存占用:
M_{\\text{per_gpu}} = \\frac{16\\Phi}{D_{\\text{tp}} \\times D_{\\text{pp}}}
5.5 3D并行训练配置
现代大模型通常采用 3D并行:数据并行 + 张量并行 + 流水线并行
核心公式:
D_{\\text{dp}} \\times D_{\\text{tp}} \\times D_{\\text{pp}} = N_{\\text{devices}}
结合ZeRO-1的每卡显存:
M_{\\text{per_gpu}} = \\frac{4\\Phi}{D_{\\text{tp}} \\times D_{\\text{pp}}} + \\frac{12\\Phi}{D_{\\text{dp}} \\times D_{\\text{tp}} \\times D_{\\text{pp}}}
重要结论 :D_dp × D_tp × D_pp = num_devices(总卡数),卡数越多,分摊到每个设备的优化器状态就越少。在万卡集群中,优化器状态甚至可以忽略不计!
实际案例:
配置128张GPU(4机,每机32卡)训练70B模型:
| 并行维度 | 值 | 配置依据 |
|---|---|---|
| 张量并行 D_tp | 8 | 机内8卡做张量并行(NVLink高速互联) |
| 流水线并行 D_pp | 4 | 4机做流水线并行(跨机通信量小) |
| 数据并行 D_dp | 128/(8×4) = 4 | 4路数据并行 |
每卡显存(模型状态):
M_{\\text{权重+梯度}} = \\frac{4\\Phi}{D_{\\text{tp}} \\times D_{\\text{pp}}} = \\frac{4 \\times 70\\text{B}}{8 \\times 4} = \\frac{280}{32} \\approx 8.75\\text{GB}
M_{\\text{优化器}} = \\frac{12\\Phi}{D_{\\text{dp}} \\times D_{\\text{tp}} \\times D_{\\text{pp}}} = \\frac{12 \\times 70\\text{B}}{4 \\times 8 \\times 4} = \\frac{840}{128} \\approx 6.56\\text{GB}
M_{\\text{总计}} \\approx 8.75 + 6.56 = 15.3\\text{GB} \\quad (\\text{远小于A100的80GB,宽裕!})
下图展示了3D并行训练的维度划分:
flowchart TD A[128张GPU总集群] --> B[数据并行 D_dp=4] B --> C[数据并行组0: 32卡] B --> D[数据并行组1: 32卡] B --> E[数据并行组2: 32卡] B --> F[数据并行组3: 32卡] C --> G[流水线阶段0: 8卡 TP] C --> H[流水线阶段1: 8卡 TP] C --> I[流水线阶段2: 8卡 TP] C --> J[流水线阶段3: 8卡 TP]
上图展示了128卡3D并行配置:4路数据并行,每组内4个流水线阶段,每阶段8卡做张量并行。
5.6 通俗理解
直观类比
想象你要建一座摩天大楼(训练大模型),需要组织128个工人(GPU)。
- 数据并行 = 把工人分成4组,每组建完全相同的一栋楼,最后取平均效果。问题:每组都要准备全套材料(显存不省)。
- 流水线并行 = 把楼分成4段(地基→主体→装修→封顶),每组工人负责一段。上一段完工后交给下一段。
- 张量并行 = 每段内,8个工人一起砌同一面墙的不同部分。需要频繁沟通对齐接缝(高带宽通信)。
3D并行就是三种方式的组合:4组建筑队(DP=4),每队分4段工程(PP=4),每段工程8人协作(TP=8),128人各司其职。
核心要点
- 3D并行 = 数据并行 × 张量并行 × 流水线并行
- 张量并行适合机内(高带宽),流水线并行适合跨机(通信量小)
- 万卡集群中,优化器状态占比趋近于零
5.7 小结
| 维度 | 说明 |
|---|---|
| ZeRO vs 模型并行 | ZeRO切分状态,模型并行切分计算图 |
| 张量并行(TP) | 层内切分,需高带宽,适合机内 |
| 流水线并行(PP) | 层间切分,通信量小,适合跨机 |
| 3D并行公式 | D_dp × D_tp × D_pp = 总卡数 |
| 万卡集群 | 优化器状态趋近于零 |
6. 显存优化方法总结
| 优化方法 | 原理 | 显存节省 | 代价 |
|---|---|---|---|
| 混合精度训练 | FP16/BF16计算 | 激活值减半(模型状态不省) | 需要FP32副本 |
| 梯度检查点 | 重计算激活值 | 激活值从O(L)到O(√L) | 增加约25-30%计算时间 |
| 梯度累积 | 多次前向,一次反向 | 减少激活值 | 等效增大batch |
| ZeRO-1 | 切分优化器状态 | 优化器状态/N | 额外通信 |
| ZeRO-2 | 切分优化器+梯度 | (优化器+梯度)/N | 更多通信 |
| ZeRO-3 | 全切分 | 全部状态/N | 显著增加通信 |
| 模型并行 | 切分模型计算图 | 模型相关/并行度 | 实现复杂 |
| ZeRO-Offload | CPU/磁盘卸载 | 大幅降低GPU显存 | 增加IO开销 |
显存公式速记
| 场景 | 公式 |
|---|---|
| FP32训练 | 16Φ |
| 混合精度训练 | 16Φ(模型状态相同) |
| ZeRO-1 | 4Φ + 12Φ/N |
| ZeRO-2 | 2Φ + 14Φ/N |
| ZeRO-3 | 16Φ/N |
| 模型并行 | 16Φ/(D_tp×D_pp) |
| 3D并行+ZeRO-1 | 4Φ/(D_tp×D_pp) + 12Φ/(D_dp×D_tp×D_pp) |
关键数字速记
- 1B参数(FP32):4GB
- 1B参数(FP16):2GB
- 7B模型训练(混合精度):约112GB模型状态
- Adam优化器:每参数额外8字节(动量4+方差4)
7. 高频面试题及答案
Q1: 请解释大模型训练中显存的主要组成部分。【基础】
答案 :
大模型训练显存由两大部分组成:模型状态(模型权重+梯度+优化器状态)和激活值。使用AdamW+FP32训练时,模型状态=16Φ(权重4Φ+梯度4Φ+优化器8Φ),7B模型仅模型状态就需112GB。
详细说明:
| 要点 | 说明 |
|---|---|
| 模型权重 | 可训练参数,FP32=4Φ, FP16=2Φ |
| 梯度 | 与参数一一对应,用于更新 |
| 优化器状态 | Adam需维护动量(4Φ)+方差(4Φ)=8Φ |
| 激活值 | 中间结果,与batch_size×seq_len成正比 |
Q2: 混合精度训练能节省多少显存?【基础】
答案 :
混合精度训练不能节省模型状态显存(仍为16Φ),因为需要保留FP32副本确保数值稳定性。其真正优势是:FP16计算加速(Tensor Core)和激活值显存减半。
详细说明:
| 要点 | 说明 |
|---|---|
| 模型状态 | FP32=16Φ, 混合精度=16Φ(不变) |
| 激活值 | 使用FP16存储,显存减半 |
| 计算速度 | FP16计算更快,尤其Tensor Core |
| FP32副本 | 必须保留,防止累积更新精度损失 |
Q3: 请详细解释ZeRO的三个阶段及其显存优化原理。【进阶】
答案 :
ZeRO通过逐级切分实现显存优化:ZeRO-1切分优化器状态(单卡=4Φ+12Φ/N),ZeRO-2额外切分梯度(2Φ+14Φ/N),ZeRO-3全部切分(16Φ/N)。级别越高显存越省但通信开销越大。
详细说明:
| 要点 | 说明 |
|---|---|
| ZeRO-1 | 只切分优化器,通信开销最小 |
| ZeRO-2 | 梯度Reduce-Scatter后只保留自己的分片 |
| ZeRO-3 | 前向时All-Gather临时收集参数,计算后丢弃 |
| 权衡 | 级别越高,显存越省,通信越多 |
Q4: 数据并行(DP/DDP)和模型并行有什么区别?【基础】
答案 :
数据并行切分数据(每卡完整模型副本,通过All-Reduce同步梯度),模型并行切分模型(张量并行切矩阵、流水线并行切层,设备间传递激活值)。DDP不省显存(每卡16Φ),模型并行显存=16Φ/(D_tp×D_pp)。
详细说明:
| 要点 | 说明 |
|---|---|
| DDP切分 | 数据切分,每卡完整模型,All-Reduce同步梯度 |
| 模型并行切分 | 模型计算图切分,设备间传递激活值 |
| 张量并行(TP) | 层内矩阵切分,需高带宽(NVLink) |
| 流水线并行(PP) | 层间切分,通信量小,适合跨机 |
Q5: 什么是梯度检查点?它如何节省显存?【进阶】
答案 :
梯度检查点用计算换显存:前向传播时只保存选定检查点位置的激活值,反向传播时从最近检查点重新计算中间激活值。使用√L个检查点,显存从O(L)降到O(√L),代价是约25-30%的额外计算时间。
详细说明:
| 要点 | 说明 |
|---|---|
| 核心思想 | 不保存所有激活值,反向传播时重新计算 |
| 最优配置 | √L个检查点(L为层数) |
| 显存节省 | 从O(L)降到O(√L) |
| 计算代价 | 增加约25-30%训练时间 |
Q6: 请解释All-Reduce、All-Gather和Reduce-Scatter的区别。【基础】
答案 :
All-Gather:每个设备收集所有设备的数据,最终所有设备有完整数据。Reduce-Scatter:执行规约后每个设备只获取结果的一部分。All-Reduce:规约后所有设备得到完整结果。核心关系:All-Reduce = Reduce-Scatter + All-Gather。
详细说明:
| 要点 | 说明 |
|---|---|
| All-Gather | [A],[B],[C]→ 每个设备都得到[A,B,C] |
| Reduce-Scatter | 规约后分片,每设备只得一部分结果 |
| All-Reduce | 规约后广播,每设备得完整结果 |
| 关系 | All-Reduce = Reduce-Scatter + All-Gather |
Q7: 如何估算训练一个7B参数模型需要多少显存?【进阶】
答案 :
混合精度+AdamW:模型状态=16×7=112GB(FP16权重14GB+FP16梯度14GB+FP32副本28GB+动量28GB+方差28GB),加上激活值约10-30GB,总计约120-150GB。单张A100(80GB)不够,用ZeRO-3(2卡)约56GB/卡可行。
详细说明:
| 要点 | 说明 |
|---|---|
| 模型权重(FP16) | 7B×2=14GB |
| 梯度(FP16) | 7B×2=14GB |
| 优化器(FP32) | 副本28GB+动量28GB+方差28GB=84GB |
| 解决方案 | ZeRO-3(2卡)=56GB/卡,或ZeRO-1(8卡)≈38.5GB/卡 |
Q8: DeepSeek为什么选择流水线并行+ZeRO-1而不是ZeRO-3?【进阶】
答案 :
工程权衡:ZeRO-3每次前向/反向都需All-Gather参数(通信量巨大),而流水线并行仅传递层间激活值。在多机场景下,PP+ZeRO-1通信更可控。且MoE架构下TP收益有限,万卡集群中优化器状态本就可忽略。
详细说明:
| 要点 | 说明 |
|---|---|
| ZeRO-3问题 | 每次前向/反向都要All-Gather,通信量大 |
| PP优势 | 仅层间传递激活值,通信量可控 |
| ZeRO-1足够 | 万卡集群中优化器状态=12Φ/N→趋近于零 |
| MoE考虑 | 张量并行对MoE架构收益有限 |
Q9: 激活值显存与哪些因素相关?如何优化?【进阶】
答案 :
激活值与batch_size(线性)、seq_len(线性+注意力矩阵的平方)、d_hidden(线性)、n_layer(线性)相关。优化方法:梯度检查点(O(L)→O(√L))、减小batch_size、梯度累积、FlashAttention(注意力从O(s²)→O(s))、序列并行。
详细说明:
| 要点 | 说明 |
|---|---|
| 主要因素 | b, s, d_hidden, n_layer |
| 注意力瓶颈 | 注意力矩阵与s²成正比(长序列瓶颈) |
| 梯度检查点 | O(L)→O(√L),增25-30%计算 |
| FlashAttention | 融合kernel,注意力从O(s²)→O(s) |
Q10: 3D并行训练如何配置?各维度的考虑因素是什么?【进阶】
答案 :
D_dp×D_tp×D_pp=总卡数。TP适合机内(≤8,需NVLink高带宽),PP适合跨机(层数需被PP整除),DP为剩余卡数。128卡典型配置:TP=8, PP=4, DP=4。先定TP(不超机内卡数),再定PP(根据层数和机数),最后算DP。
详细说明:
| 要点 | 说明 |
|---|---|
| 张量并行(TP) | 机内高带宽互联,通常≤8 |
| 流水线并行(PP) | 跨机,通信量小,可能有bubble |
| 数据并行(DP) | 总卡数/(TP×PP),增大有效batch |
| 128卡示例 | TP=8, PP=4, DP=4,单卡≈15.3GB |
Q11: 训练一个70B模型,你有256张A100(80GB),请设计完整的显存优化方案并估算每卡显存占用。【综合】
答案:
这是一道综合设计题,需要结合多种显存优化技术。
第一步:确定3D并行配置
- 张量并行 D_{\\text{tp}} = 8(机内8卡NVLink互联)
- 流水线并行 D_{\\text{pp}} = 4(模型80层,每阶段20层)
- 数据并行 D_{\\text{dp}} = 256/(8 \\times 4) = 8
第二步:计算模型状态显存(采用混合精度 + ZeRO-1)
M_{\\text{权重+梯度}} = \\frac{4\\Phi}{D_{\\text{tp}} \\times D_{\\text{pp}}} = \\frac{4 \\times 70\\text{B}}{32} = 8.75\\text{GB}
M_{\\text{优化器}} = \\frac{12\\Phi}{D_{\\text{dp}} \\times D_{\\text{tp}} \\times D_{\\text{pp}}} = \\frac{12 \\times 70\\text{B}}{256} \\approx 3.28\\text{GB}
模型状态总计 ≈ 12.03GB
第三步:估算激活值显存
- 使用梯度检查点,每卡只需保存 \\sqrt{20} \\approx 5 个检查点层的激活值
- 使用 FlashAttention 消除注意力矩阵 O(s\^2) 项
- 估算激活值约 15-25GB(取决于 batch_size 和 seq_len)
第四步:总计
- 模型状态 ≈ 12GB + 激活值 ≈ 20GB + 临时缓冲 ≈ 5GB ≈ 37GB/卡
- 80GB A100 绰绰有余,可以适当增大 batch_size 提升吞吐
Q12: 为什么混合精度训练不省模型状态显存但仍是标配?请结合激活值优化和梯度检查点综合分析。【综合】
答案:
混合精度训练虽然模型状态仍为 16\\Phi(因为必须保留 FP32 副本),但它在三个层面带来收益:
- 计算加速:FP16/BF16 在 Tensor Core 上的吞吐是 FP32 的 2-8 倍
- 激活值减半:中间激活值使用 FP16 存储,对于长序列训练这是巨大的节省
- 与梯度检查点协同:梯度检查点重计算时用 FP16 计算速度更快,减轻了"计算换显存"的代价
综合来看,现代大模型训练的标准配置是:混合精度(加速 + 激活值减半)+ 梯度检查点(激活值从 O(L) 到 O(\\sqrt{L}))+ FlashAttention(消除注意力 O(s\^2) 显存)+ ZeRO/3D并行(切分模型状态)。这四者缺一不可,共同使得千亿参数级模型训练成为可能。
8. 大厂常见面试题
Q13: 请计算训练一个13B模型在不同并行策略下的单卡显存占用,并给出推荐配置。【进阶】
来源:字节跳动/阿里巴巴 大模型训练岗常见计算题
答案:
模型参数 \\Phi = 13\\text{B},以8张A100(80GB)为例:
| 策略 | 单卡显存公式 | 数值 | 是否可行 |
|---|---|---|---|
| DDP(无优化) | 16\\Phi | 16 \\times 13 = 208\\text{GB} | 不可行 |
| ZeRO-1(8卡) | 4\\Phi + 12\\Phi/8 | 52 + 19.5 = 71.5\\text{GB} | 勉强可行(不含激活值) |
| ZeRO-2(8卡) | 2\\Phi + 14\\Phi/8 | 26 + 22.75 = 48.75\\text{GB} | 可行 |
| ZeRO-3(8卡) | 16\\Phi/8 | 26\\text{GB} | 宽裕 |
| TP=8 | 16\\Phi/8 | 26\\text{GB} | 宽裕 |
推荐配置:8卡单机优先用 ZeRO-2 + 梯度检查点 + 混合精度。ZeRO-3 虽然最省显存,但通信开销显著增大(每次前向/反向都需 All-Gather),在机内 NVLink 带宽下 ZeRO-2 通常是更好的平衡点。
Q14: ZeRO-Offload 和 ZeRO-Infinity 的区别是什么?在什么场景下使用?【进阶】
来源:微软/百度 基础架构岗高频题
答案:
| 特性 | ZeRO-Offload | ZeRO-Infinity |
|---|---|---|
| 卸载目标 | 优化器状态 + 梯度 → CPU | 全部(参数+梯度+优化器)→ CPU + NVMe |
| 基于 | ZeRO-2 | ZeRO-3 |
| 适用场景 | 单卡/少卡训练超出显存的模型 | 极端情况,需要在有限GPU上训练超大模型 |
| 性能影响 | CPU-GPU 带宽成为瓶颈,训练速度下降约 30-50% | NVMe 带宽更低,速度进一步下降 |
| 典型用途 | 学术实验室用消费级GPU微调大模型 | 万亿参数模型的可行性验证 |
核心权衡:ZeRO-Offload 用 PCIe 带宽换 GPU 显存,ZeRO-Infinity 进一步用 NVMe 带宽换更多显存。在有充足 GPU 资源时应优先使用纯 GPU 方案(ZeRO-1/2/3 + 模型并行)。
Q15: 序列并行(Sequence Parallelism)是什么?它解决了什么问题?【进阶】
来源:腾讯/华为 大模型团队面试常见问题
答案:
序列并行解决的是张量并行中非并行区域(如 LayerNorm、Dropout)仍需完整激活值的问题。
在标准张量并行中,虽然注意力和 FFN 的计算被切分到多卡,但 LayerNorm 和 Dropout 等操作仍在每张卡上保留完整的激活值。序列并行将序列维度也进行切分:
| 特性 | 张量并行 (TP) | 张量并行 + 序列并行 (TP+SP) |
|---|---|---|
| 注意力/FFN | 按隐藏维度切分 | 按隐藏维度切分 |
| LayerNorm/Dropout | 每卡完整激活值 | 按序列维度切分 |
| 激活值显存 | 约 M_{\\text{act}}/D_{\\text{tp}}(仅并行部分) | 接近 M_{\\text{act}}/D_{\\text{tp}}(全部) |
| 通信变化 | All-Reduce | All-Gather + Reduce-Scatter |
Megatron-LM v3 引入此技术,配合选择性激活重算(selective recomputation),可将激活值显存降低约 5 倍,是训练超长序列的关键技术。
总结
核心知识点回顾
| 知识点 | 核心内容 | 关键公式/数值 |
|---|---|---|
| 显存组成 | 模型状态 + 激活值 | 模型状态=16Φ(AdamW+FP32) |
| 量化精度 | FP32/FP16/INT8/INT4 | 1B参数FP32=4GB |
| 混合精度 | 模型状态不省,激活值减半 | 仍为16Φ |
| 激活值 | 与b×s成正比,注意力与s²成正比 | 长序列是瓶颈 |
| 梯度检查点 | 用计算换显存 | O(L)→O(√L),+25-30%计算 |
| DDP | 不省单卡显存 | 每卡16Φ |
| ZeRO-1/2/3 | 逐级切分优化器/梯度/参数 | 最终16Φ/N |
| 模型并行 | TP(层内)+PP(层间) | 16Φ/(D_tp×D_pp) |
| 3D并行 | DP+TP+PP | D_dp×D_tp×D_pp=总卡数 |
| 万卡集群 | 优化器状态趋近零 | 12Φ/(D_dp×D_tp×D_pp)→0 |
思维导图结构
大模型训练显存计算与优化
├── 1. 显存组成
│ ├── 模型状态(权重+梯度+优化器)= 16Φ
│ ├── 激活值(与b×s成正比,注意力与s²成正比)
│ └── 量化精度: FP32=4Φ, FP16=2Φ, INT8=1Φ
├── 2. 混合精度训练
│ ├── 模型状态仍=16Φ(不省显存!)
│ └── 真正优势: 计算加速 + 激活值减半
├── 3. 激活值优化
│ ├── 梯度检查点: O(L)→O(√L), +25-30%计算
│ ├── 梯度累积: 多次前向+一次反向
│ └── FlashAttention: O(s²)→O(s)
├── 4. 数据并行
│ ├── DDP: 不省显存(每卡16Φ)
│ ├── ZeRO-1: 切分优化器, 4Φ+12Φ/N
│ ├── ZeRO-2: 切分优化器+梯度, 2Φ+14Φ/N
│ ├── ZeRO-3: 全切分, 16Φ/N
│ └── ZeRO-Offload: CPU/磁盘卸载
├── 5. 模型并行
│ ├── 张量并行(TP): 层内切分, 需高带宽
│ ├── 流水线并行(PP): 层间切分, 通信量小
│ └── 显存: 16Φ/(D_tp×D_pp)
└── 6. 3D并行训练
├── 公式: D_dp × D_tp × D_pp = 总卡数
├── 配置: TP≤8(机内) → PP(跨机) → DP(剩余)
└── 万卡集群: 优化器状态→0
参考文献
AI-Compass 致力于构建最全面、最实用、最前沿的AI技术学习和实践生态,通过六大核心模块的系统化组织,为不同层次的学习者和开发者提供完整学习路径。
- github地址:AI-Compass👈:https://github.com/tingaicompass/AI-Compass
- gitee地址:AI-Compass👈:https://gitee.com/tingaicompass/ai-compass
🌟 如果本项目对您有所帮助,请为我们点亮一颗星!🌟