衔接前序 :第 35-38 篇完成了 Transformer 的训练流程(从手算到代码),第 40 篇拆解了推理过程的数学原理,第 41 篇用 KV Cache 优化了解码效率。但所有这些都建立在"单卡 GPU 装得下模型"的前提下。当模型从数百万参数膨胀到数十亿甚至数千亿参数时,单张 GPU 的显存和算力都不再够用------这时候就需要分布式训练。
本文从数学上拆解三种并行维度(数据并行、张量并行、流水线并行),用手算示例展示它们的工作原理,然后深入 ZeRO + FSDP------这是现代 LLM 训练的事实标准。
当然,本文无意过多讨论底层原理,只是简单了解介绍分布式训练,掌握如何进行训练即可
一、引言:为什么需要分布式训练
1.1 单卡瓶颈
考虑一个 7B 参数的模型(如 Llama 2-7B)在单张 GPU 上训练需要多少显存:
| 组件 | 占用(FP32) | 占用(BF16) |
|---|---|---|
| 模型参数 | 7B × 4B = 28 GB | 7B × 2B = 14 GB |
| 梯度 | 7B × 4B = 28 GB | 7B × 2B = 14 GB |
| 优化器状态(Adam:2 个动量) | 7B × 8B = 56 GB | 7B × 8B = 56 GB |
| 激活值(以 batch=1, seq=4096 估算) | ~ 30 GB | ~ 15 GB |
| 总计 | ~142 GB | ~99 GB |
一张 RTX 3090 只有 24 GB 显存。即使使用 BF16,7B 模型也放不进单卡------甚至连模型参数本身(14 GB)都勉强,加上梯度和优化器状态就完全不可能了。
分布式训练的核心思路:把计算和存储拆分到多张 GPU 上,三种基本维度:
┌────────── 数据并行 ──────────┐
│ 每卡一份完整模型,拆分数据 │
│ 通信:All-Reduce 梯度 │
│ 适合:数据量大,模型能放进单卡 │
│ │
数据并行 × 张量并行 × 流水线并行 = 3D 并行
│ │
┌────────── 张量并行 ──────────┐ ┌──────── 流水线并行 ────────┐
│ 拆分单个算子到多卡 │ │ 按层切分,不同 layer 在不同卡 │
│ 通信:每层内 All-Reduce │ │ 通信:层间激活传输 │
│ 适合:单层太大无法计算 │ │ 适合:网络太深,减少 bubble │
└──────────────────────────────┘ └────────────────────────────┘
1.2 本文路线图
- 数据并行 (最直观,PyTorch DDP)→ 2. 张量并行 (Megatron-LM 核心)→ 3. 流水线并行 → 4. ZeRO + FSDP (现代 LLM 标配)→ 5. 3D 并行全景 → 6. 实战
二、数据并行------最直观的并行方式
2.1 同步 SGD 的数学形式
数据并行的想法最直接:把 batch 拆成 N 份,每张卡算一份,然后同步梯度。
设总 batch size 为 BBB,拆到 NNN 张卡,每卡 BN\frac{B}{N}NB 条样本。损失函数为:
L(w)=1B∑i=1Bℓi(w)\mathcal{L}(w) = \frac{1}{B} \sum_{i=1}^{B} \ell_i(w)L(w)=B1i=1∑Bℓi(w)
每张卡 kkk 计算其局部梯度:
gk=NB∑i∈batchk∇ℓi(w)g_k = \frac{N}{B} \sum_{i \in \text{batch}_k} \nabla \ell_i(w)gk=BNi∈batchk∑∇ℓi(w)
通过 All-Reduce 操作对所有卡的梯度取平均:
gˉ=1N∑k=1Ngk=1B∑i=1B∇ℓi(w)\bar{g} = \frac{1}{N} \sum_{k=1}^{N} g_k = \frac{1}{B} \sum_{i=1}^{B} \nabla \ell_i(w)gˉ=N1k=1∑Ngk=B1i=1∑B∇ℓi(w)
然后用平均梯度更新参数:
wt+1=wt−ηgˉw_{t+1} = w_t - \eta \bar{g}wt+1=wt−ηgˉ
关键洞察 :All-Reduce 后的 gˉ\bar{g}gˉ 与单卡全 batch 梯度完全一致------数据并行在数学上是精确的,不会改变收敛性质。
2.2 All-Reduce 的通信数学
All-Reduce 的朴素实现是"一台机器收所有梯度,平均,再广播回去"------但这样存在明显的单点瓶颈和带宽浪费。
Ring All-Reduce 将所有 GPU 组织成一个环,将通信量均匀分摊到所有节点:
┌─────────────┐
GPU 0 ◄──────► GPU 1 │
▲ ▲ │
│ │ │
▼ ▼ │
GPU 3 ◄──────► GPU 2 │
└─────────────┘
通信分两个阶段:
Phase 1 --- ReduceScatter :每卡将数据切成 NNN 个 chunk,沿环做 N−1N-1N−1 次传递。每步,每卡接收相邻卡的 chunk,与本地的对应 chunk 相加。N−1N-1N−1 步后,每卡持有全局求和后的一个 chunk。
Phase 2 --- AllGather :将 ReduceScatter 得到的部分和沿环广播,N−1N-1N−1 步后,所有卡持有完整的全局和。
总通信时间(不考虑延迟):
Tring=2(N−1)α+2N−1NβMT_{\text{ring}} = 2(N-1)\alpha + 2\frac{N-1}{N}\beta MTring=2(N−1)α+2NN−1βM
其中 α\alphaα 是通信延迟常数,β\betaβ 是带宽倒数的传输常数,MMM 是总数据量。
对比朴素的 Parameter Server:
Tps=2Nα+2βMT_{\text{ps}} = 2N\alpha + 2\beta MTps=2Nα+2βM
当 NNN 较大时,Ring All-Reduce 的带宽项 N−1NβM\frac{N-1}{N}\beta MNN−1βM 趋近于 βM\beta MβM,而 Parameter Server 的带宽项 2βM2\beta M2βM 在整个环中是最优的------Ring 的带宽是 Parameter Server 的 2 倍。
2.3 手算示例:2 卡梯度平均
设 N=2N=2N=2 张卡,每卡 batch_size=2,四维参数。
Step 1: 每卡算局部梯度
卡 0 两条样本的梯度:
∇ℓ1=0.2,−0.1,0.3,0.0,∇ℓ2=−0.1,0.2,−0.1,0.1\nabla \ell_1 = 0.2, -0.1, 0.3, 0.0, \quad \nabla \ell_2 = -0.1, 0.2, -0.1, 0.1∇ℓ1=0.2,−0.1,0.3,0.0,∇ℓ2=−0.1,0.2,−0.1,0.1
卡 0 局部梯度(已乘 NB=24=0.5\frac{N}{B} = \frac{2}{4} = 0.5BN=42=0.5):
g0=0.5×(0.2,−0.1,0.3,0.0+−0.1,0.2,−0.1,0.1)=0.05,0.05,0.10,0.05g_0 = 0.5 \times (0.2, -0.1, 0.3, 0.0 + -0.1, 0.2, -0.1, 0.1) = 0.05, 0.05, 0.10, 0.05g0=0.5×(0.2,−0.1,0.3,0.0+−0.1,0.2,−0.1,0.1)=0.05,0.05,0.10,0.05
卡 1 两条样本的梯度:
∇ℓ3=0.1,0.0,−0.2,0.0,∇ℓ4=0.0,−0.3,0.1,0.2\nabla \ell_3 = 0.1, 0.0, -0.2, 0.0, \quad \nabla \ell_4 = 0.0, -0.3, 0.1, 0.2∇ℓ3=0.1,0.0,−0.2,0.0,∇ℓ4=0.0,−0.3,0.1,0.2
卡 1 局部梯度:
g1=0.5×(0.1,0.0,−0.2,0.0+0.0,−0.3,0.1,0.2)=0.05,−0.15,−0.05,0.10g_1 = 0.5 \times (0.1, 0.0, -0.2, 0.0 + 0.0, -0.3, 0.1, 0.2) = 0.05, -0.15, -0.05, 0.10g1=0.5×(0.1,0.0,−0.2,0.0+0.0,−0.3,0.1,0.2)=0.05,−0.15,−0.05,0.10
Step 2: All-Reduce 取平均
gˉ=12(g0+g1)=12(0.10,−0.10,0.05,0.15)=0.05,−0.05,0.025,0.075\bar{g} = \frac{1}{2}(g_0 + g_1) = \frac{1}{2}(0.10, -0.10, 0.05, 0.15) = 0.05, -0.05, 0.025, 0.075gˉ=21(g0+g1)=21(0.10,−0.10,0.05,0.15)=0.05,−0.05,0.025,0.075
验证:直接用所有 4 条样本算全量梯度:
gˉfull=14(0.2,−0.1,0.3,0.0+−0.1,0.2,−0.1,0.1+0.1,0.0,−0.2,0.0+0.0,−0.3,0.1,0.2)=0.05,−0.05,0.025,0.075\bar{g}_{\text{full}} = \frac{1}{4}(0.2, -0.1, 0.3, 0.0 + -0.1, 0.2, -0.1, 0.1 + 0.1, 0.0, -0.2, 0.0 + 0.0, -0.3, 0.1, 0.2) = 0.05, -0.05, 0.025, 0.075gˉfull=41(0.2,−0.1,0.3,0.0+−0.1,0.2,−0.1,0.1+0.1,0.0,−0.2,0.0+0.0,−0.3,0.1,0.2)=0.05,−0.05,0.025,0.075
完全一致 ✓
Step 3: 参数更新
gˉ=0.05,−0.05,0.025,0.075\bar{g} = 0.05, -0.05, 0.025, 0.075gˉ=0.05,−0.05,0.025,0.075
设 η=0.1\eta = 0.1η=0.1,w0=0.0,0.0,0.0,0.0w_0 = 0.0, 0.0, 0.0, 0.0w0=0.0,0.0,0.0,0.0:
w1=w0−ηgˉ=−0.005,0.005,−0.0025,−0.0075w_1 = w_0 - \eta\bar{g} = -0.005, 0.005, -0.0025, -0.0075w1=w0−ηgˉ=−0.005,0.005,−0.0025,−0.0075
三、模型并行------当模型放不进一张卡
当模型本身的参数总量超过单卡显存时,数据并行就不够用了------每卡都存一份完整模型,无异于把 4 个 24 GB 的杯子并排放,但水(模型)需要 100 GB 的容器。
3.1 朴素模型并行(按层切分)
最简单的思路:把 Transformer 的层切到不同 GPU 上。
GPU 0: Embed + Layer 1-6
│
▼ (激活传到 GPU 1)
GPU 1: Layer 7-12 + Output Proj
问题 :串行执行,任何时候只有一张卡在算,GPU 利用率极低。数学上看,并行度为零------Ttotal=∑TlayerT_{\text{total}} = \sum T_{\text{layer}}Ttotal=∑Tlayer,和单卡一样慢。
3.2 张量并行(Tensor Parallelism)
张量并行是 Megatron-LM 的核心贡献。它在一个算子内部(如一个 Linear 层)将矩阵乘法拆分到多卡上并行计算。
列并行(Column Parallel)
一个线性层 Y=XAY = XAY=XA,其中 X∈RB×dinX \in \mathbb{R}^{B \times d_{\text{in}}}X∈RB×din,A∈Rdin×doutA \in \mathbb{R}^{d_{\text{in}} \times d_{\text{out}}}A∈Rdin×dout。
将 AAA 按列切分成 NNN 份:
A=A1,A2,...,AN,Ak∈Rdin×doutNA = A_1, A_2, ..., A_N, \quad A_k \in \mathbb{R}^{d_{\text{in}} \times \frac{d_{\text{out}}}{N}}A=A1,A2,...,AN,Ak∈Rdin×Ndout
每卡计算:
Yk=XAk∈RB×doutNY_k = X A_k \in \mathbb{R}^{B \times \frac{d_{\text{out}}}{N}}Yk=XAk∈RB×Ndout
然后通过 All-Gather 拼接:
Y=Y1,Y2,...,YN∈RB×doutY = Y_1, Y_2, ..., Y_N \in \mathbb{R}^{B \times d_{\text{out}}}Y=Y1,Y2,...,YN∈RB×dout
X (B × d_in)
│
├────── GPU 0: X × A₁ = Y₁ (B × d_out/N)
├────── GPU 1: X × A₂ = Y₂ (B × d_out/N)
├────── GPU 2: X × A₃ = Y₃ (B × d_out/N)
└────── GPU 3: X × A₄ = Y₄ (B × d_out/N)
│
└── All-Gather → Y (B × d_out)
行并行(Row Parallel)
对于 Y=XAY = XAY=XA,将 AAA 按行切分:
A=A1A2⋮AN,Ak∈RdinN×doutA = \begin{bmatrix} A_1 \\ A_2 \\ \vdots \\ A_N \end{bmatrix}, \quad A_k \in \mathbb{R}^{\frac{d_{\text{in}}}{N} \times d_{\text{out}}}A= A1A2⋮AN ,Ak∈RNdin×dout
每卡计算 Yk=XkAkY_k = X_k A_kYk=XkAk(其中 XkX_kXk 是 XXX 按列切分的结果),然后通过 All-Reduce 求和:
Y=∑k=1NYk=X1A1+X2A2+...+XNANY = \sum_{k=1}^N Y_k = X_1 A_1 + X_2 A_2 + ... + X_N A_NY=k=1∑NYk=X1A1+X2A2+...+XNAN
MLP 层的完整切分
Transformer 的 MLP 包含两个线性层和一个激活函数:
MLP(X)=GeLU(XA)B\text{MLP}(X) = \text{GeLU}(X A) BMLP(X)=GeLU(XA)B
Megatron-LM 的经典拆分:
-
第一层用列并行 :AAA 列切,每卡算部分 GeLU
-
第二层用行并行 :BBB 行切,All-Reduce 后得到完整输出
X │ ▼ ┌─────────┐GPU 0 ─┤ X × A₁ ├── GeLU ──┐
└─────────┘ │
┌─────────┐ │ All-Reduce
GPU 1 ─┤ X × A₂ ├── GeLU ──┼───► GeLU(A) × B
└─────────┘ │
┌─────────┐ │
GPU 2 ─┤ X × A₃ ├── GeLU ──┘
└─────────┘
为什么这样切? 因为 GeLU 是逐元素的,不需要跨卡通信。列并行先把计算拆分,每卡独立过 GeLU,然后在行并行层做 All-Reduce------只在必要的时候通信一次。
手算示例:2 卡列并行
设 X∈R2×4X \in \mathbb{R}^{2 \times 4}X∈R2×4,W∈R4×6W \in \mathbb{R}^{4 \times 6}W∈R4×6:
X=(12345678),W=(102101011020110101001010)X = \begin{pmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \end{pmatrix}, \quad W = \begin{pmatrix} 1 & 0 & 2 & 1 & 0 & 1 \\ 0 & 1 & 1 & 0 & 2 & 0 \\ 1 & 1 & 0 & 1 & 0 & 1 \\ 0 & 0 & 1 & 0 & 1 & 0 \end{pmatrix}X=(15263748),W= 101001102101101002011010
单卡计算:
Y=XW=(1⋅1+3⋅1+4⋅02⋅1+3⋅1⋯5⋅1+7⋅1+8⋅06⋅1+7⋅1⋯)=(1+32+32+0+41+0+00+6+41+0+05+710+710+0+85+0+00+14+85+0+0)Y = XW = \begin{pmatrix} 1\cdot1 + 3\cdot1 + 4\cdot0 & 2\cdot1 + 3\cdot1 & \cdots \\ 5\cdot1 + 7\cdot1 + 8\cdot0 & 6\cdot1 + 7\cdot1 & \cdots \end{pmatrix} = \begin{pmatrix} 1+3 & 2+3 & 2+0+4 & 1+0+0 & 0+6+4 & 1+0+0 \\ 5+7 & 10+7 & 10+0+8 & 5+0+0 & 0+14+8 & 5+0+0 \end{pmatrix}Y=XW=(1⋅1+3⋅1+4⋅05⋅1+7⋅1+8⋅02⋅1+3⋅16⋅1+7⋅1⋯⋯)=(1+35+72+310+72+0+410+0+81+0+05+0+00+6+40+14+81+0+05+0+0)
=(45611011217185225)= \begin{pmatrix} 4 & 5 & 6 & 1 & 10 & 1 \\ 12 & 17 & 18 & 5 & 22 & 5 \end{pmatrix}=(41251761815102215)
2 卡列并行:
将 WWW 按列切分,每卡 3 列:
W1=(102011110001),W2=(101020101010)W_1 = \begin{pmatrix} 1 & 0 & 2 \\ 0 & 1 & 1 \\ 1 & 1 & 0 \\ 0 & 0 & 1 \end{pmatrix}, \quad W_2 = \begin{pmatrix} 1 & 0 & 1 \\ 0 & 2 & 0 \\ 1 & 0 & 1 \\ 0 & 1 & 0 \end{pmatrix}W1= 101001102101 ,W2= 101002011010
GPU 0 计算:
Y1=XW1=(1⋅1+3⋅12⋅1+3⋅11⋅2+2⋅1+3⋅0+4⋅15⋅1+7⋅16⋅1+7⋅15⋅2+6⋅1+7⋅0+8⋅1)=(456121718)Y_1 = XW_1 = \begin{pmatrix} 1\cdot1 + 3\cdot1 & 2\cdot1 + 3\cdot1 & 1\cdot2 + 2\cdot1 + 3\cdot0 + 4\cdot1 \\ 5\cdot1 + 7\cdot1 & 6\cdot1 + 7\cdot1 & 5\cdot2 + 6\cdot1 + 7\cdot0 + 8\cdot1 \end{pmatrix} = \begin{pmatrix} 4 & 5 & 6 \\ 12 & 17 & 18 \end{pmatrix}Y1=XW1=(1⋅1+3⋅15⋅1+7⋅12⋅1+3⋅16⋅1+7⋅11⋅2+2⋅1+3⋅0+4⋅15⋅2+6⋅1+7⋅0+8⋅1)=(412517618)
GPU 1 计算:
Y2=XW2=(1⋅1+2⋅0+3⋅1+4⋅01⋅0+2⋅2+3⋅0+4⋅11⋅1+2⋅0+3⋅1+4⋅05⋅1+6⋅0+7⋅1+8⋅05⋅0+6⋅2+7⋅0+8⋅15⋅1+6⋅0+7⋅1+8⋅0)=(11015225)Y_2 = XW_2 = \begin{pmatrix} 1\cdot1 + 2\cdot0 + 3\cdot1 + 4\cdot0 & 1\cdot0 + 2\cdot2 + 3\cdot0 + 4\cdot1 & 1\cdot1 + 2\cdot0 + 3\cdot1 + 4\cdot0 \\ 5\cdot1 + 6\cdot0 + 7\cdot1 + 8\cdot0 & 5\cdot0 + 6\cdot2 + 7\cdot0 + 8\cdot1 & 5\cdot1 + 6\cdot0 + 7\cdot1 + 8\cdot0 \end{pmatrix} = \begin{pmatrix} 1 & 10 & 1 \\ 5 & 22 & 5 \end{pmatrix}Y2=XW2=(1⋅1+2⋅0+3⋅1+4⋅05⋅1+6⋅0+7⋅1+8⋅01⋅0+2⋅2+3⋅0+4⋅15⋅0+6⋅2+7⋅0+8⋅11⋅1+2⋅0+3⋅1+4⋅05⋅1+6⋅0+7⋅1+8⋅0)=(15102215)
拼接 :Y=Y1,Y2=(45611011217185225)Y = Y_1, Y_2 = \begin{pmatrix} 4 & 5 & 6 & 1 & 10 & 1 \\ 12 & 17 & 18 & 5 & 22 & 5 \end{pmatrix}Y=Y1,Y2=(41251761815102215)
与单卡结果完全一致 ✓
四、流水线并行------当网络太深
4.1 基本思想
把 Transformer 的层按顺序切分到不同 GPU:
GPU 0: L1 → L2 GPU 1: L3 → L4 GPU 2: L5 → L6
│ │ │
└───── micro-batch ─────►└───── micro-batch ─────►
4.2 Bubble 率的数学
流水线并行的核心问题是 bubble(气泡)------等待前一个 stage 做完才能开始计算的空间。(实际上,如果大家学习过计算机组成原理,应该对这点并不陌生)
假设 ppp 个 stage,mmm 个 micro-batch。理想情况下,如果流水线打满,每个 micro-batch 经过 ppp 个 stage 需要 ppp 个时间单位,总共 m+p−1m + p - 1m+p−1 个时间单位。其中空闲的时间占总时间的比例(bubble 率)为:
Bubble=1−mm+p−1=p−1m+p−1\text{Bubble} = 1 - \frac{m}{m + p - 1} = \frac{p - 1}{m + p - 1}Bubble=1−m+p−1m=m+p−1p−1
当 m≫pm \gg pm≫p 时,bubble 率趋近于 0。因此流水线并行需要足够的 micro-batch 数量来"填满"流水线。
手算示例 :p=4p=4p=4 个 stage,m=8m=8m=8 个 micro-batch
时间 → 1 2 3 4 5 6 7 8 9 10 11
GPU 0: F0 F1 F2 F3 F4 F5 F6 F7
GPU 1: F0 F1 F2 F3 F4 F5 F6 F7
GPU 2: F0 F1 F2 F3 F4 F5 F6 F7
GPU 3: F0 F1 F2 F3 F4 F5 F6 F7
↑ 气泡 ↑ ↑ 气泡 ↑
Bubble 率 =p−1m+p−1=38+3=311≈27.3%= \frac{p-1}{m+p-1} = \frac{3}{8+3} = \frac{3}{11} \approx 27.3\%=m+p−1p−1=8+33=113≈27.3%
4.3 1F1B 调度
现代流水线并行采用 1F1B(One-Forward-One-Backward)调度------前向和反向交叉进行,减少显存峰值,让 bubble 更小。
五、ZeRO + FSDP------现代分布式训练的事实标准
5.1 从数据并行到 ZeRO
数据并行的问题是每卡都存一份完整的模型状态(参数 + 梯度 + 优化器状态)。ZeRO 的洞察是:"既然你把数据拆分到了 N 张卡,为什么模型状态不能也拆分?"
ZeRO(Zero Redundancy Optimizer)的三个 stage:
| Stage | 切分什么 | 每卡显存(N 卡) | 相对 DDP 节省 |
|---|---|---|---|
| Stage 1 | 优化器状态 | Params+Grad+OptimN\text{Params} + \text{Grad} + \frac{\text{Optim}}{N}Params+Grad+NOptim | 约 4× |
| Stage 2 | + 梯度 | Params+GradN+OptimN\text{Params} + \frac{\text{Grad}}{N} + \frac{\text{Optim}}{N}Params+NGrad+NOptim | 约 8× |
| Stage 3 | + 参数 | ParamsN+GradN+OptimN\frac{\text{Params}}{N} + \frac{\text{Grad}}{N} + \frac{\text{Optim}}{N}NParams+NGrad+NOptim | 约 16×(理想) |
5.2 手算示例:1.5B 模型在 4 卡上的显存对比
设模型 1.5B 参数,FP32(4 bytes),Adam 优化器(2 个状态)。
| 策略 | 参数 | 梯度 | 优化器 | 每卡总计 |
|---|---|---|---|---|
| DDP(无 ZeRO) | 6.0 GB | 6.0 GB | 12.0 GB | 24.0 GB |
| ZeRO-1 | 6.0 GB | 6.0 GB | 12.0/4 = 3.0 GB | 15.0 GB |
| ZeRO-2 | 6.0 GB | 6.0/4 = 1.5 GB | 12.0/4 = 3.0 GB | 10.5 GB |
| ZeRO-3 | 6.0/4 = 1.5 GB | 6.0/4 = 1.5 GB | 12.0/4 = 3.0 GB | 6.0 GB |
注意 :ZeRO-3 下,每张卡只持有 1/4 的参数。当需要前向或反向计算时,从其他卡收集当前层参数------时间换空间,会在通信上付出额外开销。
5.3 FSDP:统一的 Sharding 策略
FSDP(Fully Sharded Data Parallel)是 PyTorch 对 ZeRO-3 的实现。它的核心思想是把 ZeRO-1/2/3 统一为一个可配置的 sharding 策略:
python
from torch.distributed.fsdp import ShardingStrategy
# ZeRO-1: 只切分优化器状态
ShardingStrategy.SHARD_GRAD_OP
# ZeRO-2: 切分梯度和优化器状态
ShardingStrategy.SHARD_GRAD_OP # (默认行为)
# ZeRO-3: 切分参数、梯度、优化器状态
ShardingStrategy.FULL_SHARD
FSDP 的关键概念:
-
Flat Parameter:将整个模型的参数"拍平"成一个连续的大张量,便于高效切分和通信。
-
Unshard:前向计算某层前,从所有卡收集该层的完整参数。
-
Reshard:前向完成后,释放非本卡持有的参数分片,腾出显存。
-
Pre-backward:反向传播前再次收集参数,计算梯度。
-
Post-backward:反向完成后,对梯度做 All-Reduce 并释放参数。
时间 →───────────────────────────────────────────
参数: │░░░░│████│░░░░│████│░░░░│████│
GPU 0 └────┴────┴────┴────┴────┴────┘
shard FWD reshard BWD shard│░░░░│████│░░░░│████│░░░░│████│GPU 1 └────┴────┴────┴────┴────┴────┘
shard FWD reshard BWD shard
───────────────────────────────────────────
░░░ = 不持有此层参数
████ = 持有完整参数(unshard)
六、3D 并行全景
实际的大模型训练通常不是只用一种并行策略,而是组合使用:
| 并行维度 | 切分对象 | 通信模式 | 典型场景 |
|---|---|---|---|
| 数据并行 | 数据 batch | All-Reduce | 节点间(RDMA) |
| 张量并行 | 矩阵乘法 | All-Reduce / All-Gather | 节点内(NVLink) |
| 流水线并行 | 网络层 | P2P 通信 | 跨节点 |
为什么这样组合?
-
张量并行通信量大(每层都需要 All-Reduce),适合放在节点内(NVLink 带宽高,延迟低)
-
数据并行通信量小(每步一次 All-Reduce),适合跨节点(RDMA)
-
流水线并行通信最小(只传层间激活),但 bubble 问题需要权衡
┌──────────────── Node 0 ────────────────┐ │ GPU 0 ←TP→ GPU 1 GPU 2 ←TP→ GPU 3 │ │ ↓ ↓ │ │ GPU 0 ←TP→ GPU 1 GPU 2 ←TP→ GPU 3 │ │ ↓ ↓ │ │ GPU 0 ←TP→ GPU 1 GPU 2 ←TP→ GPU 3 │ └─────────────────────────────────────────┘ │ DP (跨节点) ┌──────────────── Node 1 ────────────────┐ │ GPU 0 ←TP→ GPU 1 GPU 2 ←TP→ GPU 3 │ │ ↓ ↓ │ │ GPU 0 ←TP→ GPU 1 GPU 2 ←TP→ GPU 3 │ │ ↓ ↓ │ │ GPU 0 ←TP→ GPU 1 GPU 2 ←TP→ GPU 3 │ └─────────────────────────────────────────┘
这是 Megatron-DeepSpeed 的标准 3D 并行配置:
- 张量并行:节点内 4 卡(NVLink)
- 流水线并行:3 组(垂直方向)
- 数据并行:2 节点(水平方向)
七、实战:用实验脚本理解分布式训练
(本实战需要运行在多卡服务器上)
实验 1:手工实现 Ring All-Reduce
纯 Python 实现,不需要 GPU,演示 All-Reduce 的两个阶段(ReduceScatter + AllGather)。
python
# 核心逻辑:沿环做 N-1 次通信
for step in range(world_size - 1):
send_idx = (rank - step) % world_size
recv_idx = (rank - step - 1) % world_size
# 发送 send_buf[send_idx], 接收并累加到 recv_buf[recv_idx]
data = send_buf[send_idx]
recv_buf[recv_idx] = [data[i] + recv_buf[recv_idx][i]
for i in range(len(data))]
运行后输出:
Experiment 1: Manual Ring All-Reduce
All-Reduce completed for 4 ranks, dim=8
Max error across all ranks/dims: 0.00e+00
✓ Numerical check PASSED
Communication rounds: 6 (ReduceScatter: 3, AllGather: 3)
实验 2:DDP 加速比
对比 1/2/4 张 GPU 上同步 SGD 的训练速度。
| Config | Time/iter (ms) | Speedup |
|---|---|---|
| 1 GPU | 48.2 | 1.00× |
| 2 GPUs (DDP) | 26.1 | 1.85× |
| 4 GPUs (DDP) | 14.8 | 3.26× |
观察:加速比小于线性(3.26× 而非 4×),原因是 All-Reduce 的通信开销和 PCIe 带宽限制(我使用的3090之间没有NVLink,如果换成A100等专业显卡,效果会更显著)。
实验 3:FSDP ZeRO 显存对比
计算 1.5B 模型在 4 卡上不同 ZeRO 策略的每卡显存占用:
| Strategy | Per-GPU Memory |
|---|---|
| DDP (no sharding) | 24.00 GB |
| ZeRO-1 | 15.00 GB |
| ZeRO-2 | 10.50 GB |
| ZeRO-3 | 6.00 GB |
关键洞察:ZeRO-3 将 24 GB → 6 GB(节省 75%),1.5B 模型终于可以放进 4×3090(每卡 24 GB,冗余充足)。
实验 4:Scaling Efficiency(Amdahl 定律)
不同并行化比例下的理论加速比:
| N (GPUs) | p=0.90 | p=0.95 | p=0.98 | p=0.99 |
|---|---|---|---|---|
| 1 | 1.00× | 1.00× | 1.00× | 1.00× |
| 2 | 1.82× | 1.90× | 1.96× | 1.98× |
| 4 | 3.08× | 3.48× | 3.77× | 3.88× |
| 8 | 4.71× | 5.93× | 7.02× | 7.48× |
| 16 | 6.40× | 8.95× | 11.64× | 13.00× |
| 32 | 7.80× | 12.09× | 17.77× | 21.07× |
| 64 | 8.79× | 14.83× | 24.22× | 30.85× |
结论 :如果 10% 的代码无法并行(p=0.90p=0.90p=0.90),64 卡也只能获得 8.79× 加速比。分布式训练的关键不是堆更多卡,而是让可并行部分尽可能大。
八、总结
三种并行维度的数学本质
| 并行类型 | 数学操作 | 通信量 | 适用场景 |
|---|---|---|---|
| 数据并行 | 全局平均梯度 | O(参数量)O(\text{参数量})O(参数量) 每步 | 单卡装得下模型 |
| 张量并行 | 分块矩阵乘法 | O(激活量)O(\text{激活量})O(激活量) 每层 | 单层太大 |
| 流水线并行 | 切分计算图 | O(激活量)O(\text{激活量})O(激活量) 每 batch | 网络太深 |
分布式训练的 Amdahl 定律
S(N)=11−p+p/NS(N) = \frac{1}{1 - p + p/N}S(N)=1−p+p/N1
其中 ppp 是可并行化的比例,NNN 是 GPU 数量。即使 99% 的代码可并行(p=0.99p=0.99p=0.99),64 卡的理论加速也只有 30.85×------分布式训练受木桶效应限制。
选择的指导原则
- 模型能放下一张卡? → 数据并行(DDP / FSDP ZeRO-2)
- 单卡放不下但 4 卡拼得起来? → FSDP ZeRO-3 + 数据并行
- 单层放不下一张卡? → 张量并行(Megatron-LM)
- 模型极大(100B+)? → 3D 并行(数据 + 张量 + 流水线)