09c-斯坦福CS336作业二:系统与分布式训练 ⚙️
本文档详细解析斯坦福 CS336 课程 Assignment 2 的核心内容,涵盖单 GPU 性能分析与优化(基准测试、混合精度、内存分析)、FlashAttention-2 的 Triton 内核实现、分布式数据并行训练(DDP)的渐进式优化,以及优化器状态分片(ZeRO-1)技术。通过理论与实践相结合,帮助读者深入理解大模型训练中的系统工程问题 🔧
This document provides an in-depth analysis of Stanford CS336 Assignment 2, covering single-GPU profiling and optimization (benchmarking, mixed precision, memory analysis), FlashAttention-2 Triton kernel implementation, progressive optimization of Distributed Data Parallel training (DDP), and optimizer state sharding (ZeRO-1) 🛠️
术语表 / Terminology
| 术语 / Term | 中文 | 说明 / Description |
|---|---|---|
| Profiling | 性能分析 | 测量程序在时间和内存维度的资源消耗,定位性能瓶颈 |
| Benchmarking | 基准测试 | 在标准化条件下测量模型前向/反向传播的速度和显存 |
| Mixed Precision | 混合精度 | 同时使用 FP16/BF16 和 FP32 进行训练,加速计算并保持数值稳定 |
| FlashAttention-2 | 闪存注意力 2 | 通过分块计算、重计算和算子融合优化注意力机制,降低内存 IO 开销 |
| Tiling | 分块计算 | 将输入数据分割为小块在 SRAM 中计算,避免在 HBM 中存储完整注意力矩阵 |
| Recomputation | 重计算 | 前向传播时不存储中间激活值,反向传播时重新计算,以时间换空间 |
| Operator Fusion | 算子融合 | 将多个操作合并到单个 GPU 内核中执行,减少内存读写次数 |
| DDP (Distributed Data Parallel) | 分布式数据并行 | 将数据批次拆分到多个 GPU,每个 GPU 持有完整模型副本,梯度通过 All-Reduce 同步 |
| All-Reduce | 全归约 | 集合通信操作,将所有进程的数据求和并广播给所有进程 |
| FSDP (Fully Sharded Data Parallel) | 全分片数据并行 | 将参数、梯度和优化器状态分片到多个 GPU,降低单卡显存占用 |
| ZeRO (Zero Redundancy Optimizer) | 零冗余优化器 | 通过在数据并行进程间划分优化器状态、梯度和参数来减少显存冗余 |
| Triton | Triton 编程语言 | OpenAI 开发的 GPU 编程语言,提供比 CUDA 更简洁的编程接口 |
| HBM (High Bandwidth Memory) | 高带宽存储器 | GPU 上的全局显存,容量大但访问速度相对较慢 |
| SRAM | 静态随机存取存储器 | GPU 芯片上的高速缓存,速度极快但容量有限 |
| NCCL | NVIDIA 集合通信库 | NVIDIA 提供的高性能多 GPU 集合通信库 |
章节阅读路线图 🗺️ / Chapter Reading Roadmap
- 作业概述 📋 / Assignment Overview → 了解 Assignment 2 的整体目标、结构和核心任务
- 性能分析与基准测试 🔬 / Performance Profiling & Benchmarking → 掌握端到端计时、Nsight 分析器、混合精度和内存分析
- FlashAttention-2 优化 ⚡ / FlashAttention-2 Optimization → 从分块计算到 Triton 内核,实现高效注意力机制
- 分布式数据并行训练 🖥️ / Distributed Data Parallel Training → 从朴素 DDP 到桶化梯度通信的渐进式优化
- 优化器状态分片 💾 / Optimizer State Sharding → 实现 ZeRO-1 风格的优化器状态分片
- 总结 📝 / Summary → 回顾核心要点
1. 作业概述 📋 / Assignment Overview
📦 Note: 本章介绍 Assignment 2 的整体目标、结构和核心任务 / This chapter introduces the overall goals, structure, and core tasks of Assignment 2.
1.1 作业定位
Assignment 2 是 CS336 课程的第二个作业,聚焦于 系统优化和分布式训练 ------即如何让大模型训练"跑得更快、用得更少、扩得更大"。
直观类比 🏭:如果说 Assignment 1 是"造出一辆能开的车",那么 Assignment 2 就是"把这辆车改装成赛车"------不仅要能跑,还要跑得又快又稳又省油耗。
学生需要在 Assignment 1 构建的 Transformer 语言模型基础上,完成以下四大核心任务:
| 任务 | 核心内容 | 分值占比 |
|---|---|---|
| 性能分析与基准测试 🔬 | 端到端计时、Nsight 分析、混合精度、内存分析 | ~20 分 |
| FlashAttention-2 ⚡ | 分块计算、Triton 内核、在线 Softmax | ~30 分 |
| 分布式数据并行训练 🖥️ | All-Reduce、DDP、梯度桶化、计算通信重叠 | ~30 分 |
| 优化器状态分片 💾 | ZeRO-1 实现、显存分析 | ~20 分 |
1.2 模型规模定义
作业中使用了多种模型规模进行基准测试,所有模型的词汇表大小均为 10,000,批量大小为 4:
| 规模 | dmodel(模型维度) | dff(前馈网络维度) | num_layers(层数) | num_heads(注意力头数) |
|---|---|---|---|---|
| small(小型) | 768 | 3072 | 12 | 12 |
| medium(中型) | 1024 | 4096 | 24 | 16 |
| large(大型) | 1280 | 5120 | 36 | 20 |
| xl(超大) | 1600 | 6400 | 48 | 25 |
| 2.7B(27 亿参数) | 2560 | 10240 | 32 | 32 |
1.3 为什么需要系统优化? 🤔
在 Assignment 1 中,我们实现了能训练的 Transformer 模型。但在实际生产中,训练一个 2.7B 参数的模型面临着严峻的系统挑战:
- 显存墙(Memory Wall) 🧱:注意力机制的中间矩阵大小为 O(n2),序列长度翻倍,显存占用翻四倍
- 计算墙(Compute Wall) ⚡:矩阵乘法虽占主导,但非矩阵运算(如 Softmax、LayerNorm)的内存 IO 开销不可忽视
- 通信墙(Communication Wall) 📡:多 GPU 训练时,梯度同步的通信开销可能抵消并行化带来的加速
直观类比 🚗:系统优化就像改装赛车------
- 减重(内存优化)→ 让车更轻,跑得更快
- 升级引擎(FlashAttention)→ 让单位时间做更多功
- 多引擎协同(分布式训练)→ 让多个引擎高效配合
参考资料:
- Stanford CS336 官方课程网站 -- Stanford ⭐值得阅读
- Stanford CS336 Assignment 2 GitHub 仓库 -- GitHub ⭐值得阅读
- CS336 Assignment 2 (systems): Systems and Parallelism 翻译和实现 -- CSDN
- Standford CS336 Language Modeling from Scratch通关总结 -- 知乎
2. 性能分析与基准测试 🔬 / Performance Profiling & Benchmarking
🔬 Note: 本章讲解如何对模型进行性能分析,定位优化机会 / This chapter explains how to profile models and identify optimization opportunities.
在进行任何优化之前,先对程序进行性能分析至关重要 ------否则可能耗费精力优化那些对整体性能影响甚微的模块,最终无法实现可量化的端到端提升。
直观类比 🏥:性能分析就像给程序做"体检"------先做全面检查(基准测试),再用精密仪器(Profiler)定位具体问题,最后对症下药(针对性优化)。
2.1 端到端基准测试 ⏱️ / End-to-End Benchmarking
最基础的性能分析方式是计时正向传播和反向传播过程。
CUDA 异步陷阱 ⚠️:在 PyTorch 中,CUDA 调用是异步的------调用 torch.matmul 后,CPU 会立即返回,不等 GPU 完成计算。因此,直接测量 Python 函数调用的返回时间不能反映 GPU 实际执行耗时。
正确做法 :每次测量后调用 torch.cuda.synchronize() 等待所有 GPU 内核执行完成。
python
import timeit # 导入高精度计时模块 ⏱️
import torch # 导入 PyTorch 核心库 🔥
def benchmark(model, data, n_warmup=5, n_steps=10): # 定义基准测试函数
"""对模型的正向传播和反向传播进行端到端基准测试
参数:
model: PyTorch 模型实例
data: 输入数据
n_warmup: 热身步数(不计入计时)
n_steps: 正式计时步数
返回:
avg_time: 平均每步耗时(秒)
"""
# 热身阶段:让 GPU 完成初始化、JIT 编译等一次性开销 🏃
for _ in range(n_warmup): # 执行热身步骤
output = model(data) # 前向传播
output.sum().backward() # 反向传播
torch.cuda.synchronize() # 等待 GPU 完成所有热身操作 ⏳
# 正式计时阶段 📊
start = timeit.default_timer() # 记录开始时间(最高分辨率时钟)
for _ in range(n_steps): # 执行 n_steps 次迭代
output = model(data) # 前向传播
output.sum().backward() # 反向传播
torch.cuda.synchronize() # 每次迭代后同步 GPU ⏳
end = timeit.default_timer() # 记录结束时间
avg_time = (end - start) / n_steps # 计算平均每步耗时
return avg_time # 返回结果
关键细节 📝:
- 热身步骤不可少:GPU 首次执行操作时有 JIT 编译、显存分配等一次性开销,不热身会导致测量结果偏大
- 使用
timeit.default_timer():提供系统最高分辨率的时钟,比time.time()更适合基准测试 - 多次测量取平均:减少随机波动的影响,同时报告标准差以评估变异性
参考资料:
2.2 Nsight Systems 性能分析器 🔍 / Nsight Systems Profiler
端到端基准测试无法告诉我们时间在各个组件上的具体分布 ,因此无法精准定位优化机会。NVIDIA 提供的 nsys 工具可以分析 CUDA 内核级别的执行细节。
使用方法 :在 Python 脚本前加上 nsys profile 前缀即可:
bash
uv run nsys profile -o result python benchmark.py # 使用 nsys 分析脚本性能
nsys 能告诉我们什么? 📊:
- 每个 CUDA 内核的累计 GPU 耗时排名
- 矩阵乘法、Softmax、LayerNorm 等操作各占多少时间
- 正向传播和反向传播中,耗时最长的内核是否相同
- CPU-GPU 之间的数据传输开销
NVTX 范围标注 🏷️:可以使用 NVTX 标注代码中的特定区域,在性能分析结果中以块的形式显示:
python
import torch.cuda.nvtx as nvtx # 导入 NVTX 标注模块 🏷️
@nvtx.range("缩放点积注意力") # 标注注意力计算范围
def annotated_scaled_dot_product_attention(Q, K, V): # 定义带标注的注意力函数
with nvtx.range("计算注意力分数"): # 标注子操作:Q × K^T
scores = torch.matmul(Q, K.transpose(-2, -1)) # 计算点积
with nvtx.range("计算 softmax"): # 标注子操作:归一化
weights = torch.softmax(scores, dim=-1) # Softmax 归一化
with nvtx.range("最终矩阵乘法"): # 标注子操作:加权求和
output = torch.matmul(weights, V) # 加权求和
return output # 返回输出
关键发现 🔍:通过 nsys 分析可以发现------
- 矩阵乘法(GEMM)占正向传播耗时的大部分,但并非全部
- Softmax 的实际运行开销远高于其 FLOPs 所暗示的水平------它是"低 FLOPs 但高延迟"的瓶颈操作
- 完整训练时,非 GEMM 操作(梯度缩放、激活函数导数、优化器更新)的时间占比显著上升
参考资料:
- 快速入门 Nsys/TorchProfiler/NCU -- 知乎 ⭐值得阅读
- Speed Up PyTorch Training by 3x with NVIDIA Nsight -- Practical ML
2.3 混合精度训练 🎚️ / Mixed Precision Training
现代 NVIDIA GPU 配备了专门的 Tensor Cores ,可加速低精度下的矩阵乘法运算。例如,NVIDIA A100 在 FP32 精度下最大吞吐量为 19.5 TFLOP/秒,而在 BF16 精度下可提升至 312 TFLOP/秒------提速约 16 倍!
直观类比 💡:混合精度就像用不同精度的尺子量东西------大部分测量用"粗尺子"(BF16)快速完成,关键位置用"细尺子"(FP32)保证准确性。
为什么不能全部用低精度? 🤔
- FP16 的动态范围小,某些梯度值过小会变为零
- FP16 容易溢出,导致损失值变为 NaN
- BF16 训练更稳定(动态范围与 FP32 相同),但仍可能影响最终模型性能
混合精度训练 (Mixed Precision Training)通过 torch.autocast 上下文管理器实现------部分操作(如矩阵乘法)以低精度执行,其他操作(如累加、归约)保持 FP32 精度:
python
model = TransformerModel() # 创建模型实例 🏗️
dtype = torch.bfloat16 # 选择 BF16 精度 🎚️
x = input_data # 输入数据
with torch.autocast(device="cuda", dtype=dtype): # 开启自动混合精度 🎯
y = model(x) # 前向传播:矩阵乘法自动用 BF16
# LayerNorm、Softmax 等操作自动保持 FP32 精度 ✅
混合精度中各组件的数据类型 📋:
| 组件 | 数据类型 | 原因 |
|---|---|---|
| 模型参数(autocast 内) | FP32 | 参数本身不自动转换 |
| 前馈层(Linear)输出 | BF16 | 矩阵乘法自动降精度 |
| LayerNorm 输出 | FP32 | 方差计算对精度敏感 |
| 损失(Loss) | FP32 | 保持数值稳定 |
| 梯度(Gradients) | FP32 | 累加需要高精度 |
参考资料:
2.4 内存分析 💾 / Memory Profiling
PyTorch 内置了强大的内存分析器,可跟踪随时间变化的显存分配情况。
python
# 开始记录内存历史 📊
torch.cuda.memory._record_memory_history(max_entries=1000000)
# 运行要分析的代码(前向传播、反向传播等) 🏃
output = model(data)
output.sum().backward()
# 保存内存快照(Pickle 格式) 💾
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
# 停止记录 🛑
torch.cuda.memory._record_memory_history(enabled=None)
生成的 memory_snapshot.pickle 文件可以上传到 PyTorch 内存可视化工具,在浏览器中查看:
- 活跃内存时间线(Active Memory Timeline):显示显存随时间的变化曲线
- 峰值内存使用量:识别显存使用的高峰时刻
- 每个内存分配的调用栈:定位哪些操作消耗了最多显存
Transformer 显存消耗分析 📐:
以 2.7B 模型为例( dmodel=2560, num_layers=32):
| 组件 | 显存占用 | 说明 |
|---|---|---|
| 模型参数(FP32) | ~10.8 GB | 2.7×109×4 字节 |
| 优化器状态(AdamW) | ~21.6 GB | 每个参数 2 个浮点数(动量 + 方差) |
| 梯度(FP32) | ~10.8 GB | 与参数相同大小 |
| 激活值(依赖序列长度) | 变化 | 注意力矩阵为 O(n2) |
激活张量的显存计算 🧮:对于残差流中的激活张量(形状 batch_size,seq_len,dmodel):
显存=batch_size×seq_len×dmodel×bytes_per_element
例如, batch_size=1, seq_len=256, dmodel=2560,FP32 精度:
1×256×2560×4=2,621,440 字节≈2.5 MB
虽然单个激活张量不大,但注意力层的中间矩阵( seq_len×seq_len)在长序列时会急剧膨胀,成为显存瓶颈。
参考资料:
- PyTorch 显存可视化与 Snapshot 数据分析 -- 知乎 ⭐值得阅读
- Visualize and Understand GPU Memory in PyTorch -- Hugging Face ⭐值得阅读
- PyTorch 官方 CUDA 内存分析文档 -- PyTorch
💡 Key Takeaways / 核心要点
- Profile before optimizing --- avoid wasting time on non-bottleneck components / 先分析再优化,避免优化非瓶颈组件
- CUDA operations are asynchronous --- always synchronize before timing / CUDA 操作是异步的,计时前必须同步
- Mixed precision can yield 16x speedup on Tensor Cores / 混合精度在 Tensor Cores 上可提速 16 倍
- Memory profiling reveals hidden bottlenecks --- attention matrices dominate at long sequences / 内存分析揭示隐藏瓶颈,注意力矩阵在长序列时主导显存
3. FlashAttention-2 优化 ⚡ / FlashAttention-2 Optimization
⚡ Note: 本章讲解 FlashAttention-2 的核心技术和 Triton 内核实现 / This chapter covers FlashAttention-2 core techniques and Triton kernel implementation.
3.1 标准注意力机制的低效性 🐢 / Inefficiency of Standard Attention
标准注意力机制的前向传播包含三步:
- 计算注意力分数: S=QKT/dk
- Softmax 归一化: P=softmax(S)
- 加权求和: O=PV
问题在哪? 🤔
朴素实现需要在 GPU 的 HBM(高带宽存储器)中存储完整的 seq_len×seq_len 注意力分数矩阵 P。当序列长度较长时,这个矩阵会极大:
| 序列长度 | 注意力矩阵大小(FP32) | 说明 |
|---|---|---|
| 512 | 1 MB | 可接受 |
| 4096 | 64 MB | 开始显著 |
| 16384 | 1 GB | 显存瓶颈 |
| 65536 | 16 GB | 单卡无法容纳 |
直观类比 📦:标准注意力就像在一张巨大的纸上写完所有答案,再抄到最终卷子上------纸(HBM)很大但搬运很慢,而且大部分内容抄完就不需要了。
根本原因 :标准实现在 HBM 和 SRAM 之间频繁传输 P 及其他大型激活值,产生极高的内存 IO 开销 。例如,标准反向传播会在计算 dV=PTdO 和 dS=dsoftmax(dP) 时两次从 HBM 读取 P。
3.2 FlashAttention-2 的三大核心技术 🔧 / Three Core Techniques
FlashAttention-2 通过三种技术协同工作,避免在 HBM 中读写完整的注意力矩阵:
3.2.1 分块计算(Tiling) 🧩
将输入数据分割为多个小块(tile),在 GPU 的 SRAM(高速缓存)中完成计算,避免将完整注意力矩阵写入 HBM。
直观类比 🍕:分块计算就像吃披萨------不会一口吞下整个披萨,而是一块一块地吃。每块大小受 SRAM 容量限制,但吃完所有块后,效果等同于吃完整个披萨。
具体来说,将 Q 分割为 Tq 个大小为 Bq×d 的分块,将 K、 V 分割为 Tk 个大小为 Bk×d 的分块:
Q1,...,QTq(每个 Bq×d)
K(1),...,K(Tk),V(1),...,V(Tk)(每个 Bk×d)
3.2.2 重计算(Recomputation) 🔄
不再将 seq_len×seq_len 的注意力矩阵 P 存储在 HBM 中,而是在反向传播时重新计算。
以时间换空间 :前向传播时只保存少量"检查点"( Q、 K、 V 和 logsumexp L),反向传播时重新计算 P。这样,内存复杂度从 O(n2) 降低到 O(n)。
FlashAttention-2 还会存储注意力分数的对数求和指数(logsumexp) L,用于简化反向传播计算:
Li=log(j∑exp(Sij))
借助 L 和预计算的 D=rowsum(O∘dO),反向传播可在无需执行 softmax 运算的情况下完成:
Pij=exp(Sij−Li)
dSij=Pij∘(dPij−Di)
3.2.3 算子融合(Operator Fusion) 🔀
将所有注意力操作( QKT、缩放、Softmax、加权求和)融合到单个 Triton 内核中执行,避免对中间结果的重复内存 IO。
直观类比 🍳:算子融合就像做菜时把所有步骤在一个锅里完成------切好的菜直接下锅炒,不用每次都盛出来再倒回去。
3.3 在线 Softmax 算法 📊 / Online Softmax Algorithm
分块计算面临一个关键挑战:Softmax 需要整行数据才能计算分母,但我们对行进行了分块,无法一次性看到所有数据。
在线 Softmax 解决了这个问题------通过维护两个累计值,在遍历每个分块时增量更新 Softmax 结果:
- mi(j):累计最大值,保证数值稳定性
- li(j):Softmax 分母的累计代理值
每处理一个新的键分块 j,更新规则为:
mi(j)=max(mi(j−1),rowmax(Si(j)))
P~i(j)=exp(Si(j)−mi(j))
li(j)=exp(mi(j−1)−mi(j))⋅li(j−1)+rowsum(P~i(j))
Oi(j)=diag(exp(mi(j−1)−mi(j)))⋅Oi(j−1)+P~i(j)V(j)
最终归一化:
Oi=diag((li(Tk))−1)⋅Oi(Tk)
直观类比 🗳️:在线 Softmax 就像实时计票------每收到一批新票,不需要重新数之前的票,只需要更新"当前最高票"和"总票数"两个计数器,就能得到正确的百分比。
3.4 FlashAttention-2 前向传播算法 📝 / FA2 Forward Pass Algorithm
算法 1 FlashAttention-2 前向传播
输入 : Q∈RNq×d, K,V∈RNk×d,分块大小 Bq,Bk
- 将 Q 分割为 Tq=⌈Nq/Bq⌉ 个分块 Q1,...,QTq
- 将 K,V 分割为 Tk=⌈Nk/Bk⌉ 个分块
- 对于 i=1,...,Tq:
- 从全局内存加载 Qi
- 初始化 Oi(0)=0, li(0)=0, mi(0)=−∞
- 对于 j=1,...,Tk:
- 从全局内存加载 K(j),V(j)
- 计算 Si(j)=Qi(K(j))T/d
- 更新 mi(j)、 P~i(j)、 li(j)、 Oi(j)(见 3.3 节公式)
- 归一化: Oi=diag((li(Tk))−1)⋅Oi(Tk)
- 计算 logsumexp: Li=mi(Tk)+log(li(Tk))
- 将 Oi 和 Li 写入全局内存
输出 :输出矩阵 O 和 logsumexp 矩阵 L
关键优势:
| 特性 | 标准注意力 | FlashAttention-2 |
|---|---|---|
| 内存复杂度 | O(n2) | O(n) |
| HBM 读写 | 完整注意力矩阵 | 仅分块数据 |
| 数值稳定性 | 需要单独处理 | 在线算法内置保证 |
| 长序列支持 | 受限(OOM) | 支持超长序列 |
3.5 Triton 内核实现要点 🔨 / Triton Kernel Implementation Notes
作业要求学生使用 Triton 编程语言实现 FlashAttention-2 内核。以下是关键实现要点:
Triton 基础:Triton 是 OpenAI 开发的 GPU 编程语言,提供比 CUDA 更简洁的接口。每个 Triton 程序实例是一个线程块,所有线程执行同一程序,可在 GPU 上并行执行。
python
import triton # 导入 Triton 核心库 🔥
import triton.language as tl # 导入 Triton 语言模块
@triton.jit # JIT 编译装饰器 ⚡
def flash_fwd_kernel(
Q_ptr, K_ptr, V_ptr, # 输入张量指针 🔍🔑💎
O_ptr, L_ptr, # 输出张量指针 📤
stride_qb, stride_qq, stride_qd, # Q 的步长参数
N_QUERIES, N_KEYS, # 序列长度
scale, # 缩放因子 $1/\sqrt{d}$
D: tl.constexpr, # 隐藏维度(编译期常量)
Q_TILE_SIZE: tl.constexpr, # 查询分块大小
K_TILE_SIZE: tl.constexpr, # 键分块大小
):
query_tile_index = tl.program_id(0) # 查询分块索引 📍
batch_index = tl.program_id(1) # 批次索引 📍
# ... 分块计算逻辑 ...
关键实现技巧 🛠️:
- 使用
tl.make_block_ptr简化指针运算 - 片上缓冲区( Oi, l, m)应使用
tl.float32保证精度 - 使用
tl.dot执行矩阵乘法 - 启动网格设置为 (Tq,batch_size)
参考资料:
- FlashAttention: Making Attention I/O-Aware -- Hugging Face ⭐值得阅读
- FlashAttention-2 in Triton: From GPU Mental Models to Kernel -- Timashov ⭐值得阅读
- 从零开始用自定义 Triton 内核编写 FlashAttention-2 -- 知乎
- 通透理解 FlashAttention 全面降低显存读写 -- CSDN
- FlashAttention 原始论文 -- arXiv ⭐值得阅读
💡 Key Takeaways / 核心要点
- Standard attention stores full n×n matrix --- FlashAttention avoids this / 标准注意力存储完整 n×n 矩阵,FlashAttention 避免了这一点
- Tiling + recomputation + fusion work together to reduce HBM IO / 分块 + 重计算 + 算子融合协同降低 HBM IO
- Online softmax enables block-wise computation --- no need to see the full row / 在线 Softmax 支持分块计算,无需看到完整行
- Memory complexity drops from O(n2) to O(n) --- enabling much longer sequences / 内存复杂度从 O(n2) 降到 O(n),支持更长序列
4. 分布式数据并行训练 🖥️ / Distributed Data Parallel Training
🖥️ Note: 本章讲解如何用多块 GPU 加速训练,从朴素 DDP 到桶化梯度通信 / This chapter covers multi-GPU training, from naive DDP to bucketed gradient communication.
4.1 分布式通信基础 📡 / Distributed Communication Basics
在 PyTorch 中,分布式训练的基础是进程组(Process Group)和集合通信操作(Collective Communication)。
核心概念 📝:
- 节点(Node):网络中的一台机器
- 工作进程(Worker):参与分布式训练的程序实例
- 全局进程数(World Size):进程组中工作进程的总数
- 全局序号(Rank):唯一标识每个工作进程的 ID(0 到 world_size-1)
最简单的分布式示例:
python
import torch.distributed as dist # 导入分布式模块 📡
import torch.multiprocessing as mp # 导入多进程模块 🧵
def distributed_demo(rank, world_size): # 定义分布式示例函数
os.environ["MASTER_ADDR"] = "localhost" # 设置主节点地址
os.environ["MASTER_PORT"] = "29500" # 设置主节点端口
dist.init_process_group("gloo", rank=rank, world_size=world_size) # 初始化进程组
data = torch.randint(0, 10, (3,)) # 每个进程生成随机数据 🎲
print(f"rank {rank} 数据(全归约前): {data}") # 打印全归约前数据
dist.all_reduce(data, async_op=False) # 全归约操作:所有进程数据求和 📊
print(f"rank {rank} 数据(全归约后): {data}") # 打印全归约后数据(所有进程相同)
if __name__ == "__main__":
world_size = 4 # 4 个工作进程
mp.spawn(fn=distributed_demo, args=(world_size,), nprocs=world_size, join=True)
运行结果:每个进程最初持有不同的数据,经过 All-Reduce 后,所有进程持有相同的求和结果。
后端选择 🎯:
| 后端 | 适用场景 | 说明 |
|---|---|---|
| NCCL | GPU 训练(生产环境) | 基于 NVIDIA NCCL 库,性能最优 |
| Gloo | CPU 训练 / 本地开发 | 支持仅 CPU 机器 |
4.2 数据并行训练流程 🔄 / Data Parallel Training Workflow
数据并行性将数据批次拆分到多个 GPU 上,每个 GPU 持有完整模型副本。训练流程如下:
直观类比 📚:数据并行就像多个学生同时做同一份试卷的不同部分,最后把答案汇总------每个学生(GPU)看不同的题目(数据),但用同一本参考书(模型)。
markdown
1. 初始化:所有 GPU 持有相同的模型参数和优化器状态 🏗️
↓
2. 数据分片:将批次 n 个样本分给 d 个 GPU,每个 GPU 得到 n/d 个样本 📦
↓
3. 本地计算:每个 GPU 用自己的数据执行前向传播和反向传播,计算本地梯度 ⚡
↓
4. 梯度同步:通过 All-Reduce 对所有 GPU 的梯度求平均 🔄
↓
5. 参数更新:每个 GPU 用相同的平均梯度更新模型,保持同步 ✅
↓
6. 重复步骤 2-5,直到训练完成 🔁
4.3 朴素 DDP 实现 🛠️ / Naive DDP Implementation
最朴素的实现方式是:反向传播后,对每个参数张量单独执行 All-Reduce。
python
class NaiveDDP(torch.nn.Module): # 朴素 DDP 容器类 📦
def __init__(self, module):
super().__init__()
self.module = module
# 广播参数:确保所有进程持有相同初始参数 📡
for param in self.module.parameters():
dist.broadcast(param.data, src=0)
def forward(self, *inputs, **kwargs):
return self.module.forward(*inputs, **kwargs) # 直接转发前向传播
def sync_gradients(self): # 同步梯度 🔄
for param in self.module.parameters():
if param.grad is not None:
dist.all_reduce(param.grad.data) # 对每个参数单独 All-Reduce
param.grad.data /= dist.get_world_size() # 除以进程数,得到平均值
问题:每次 All-Reduce 调用都有通信开销,对于深度 Transformer 模型(数百个参数张量),开销累积显著。
4.4 渐进式优化 🔧 / Progressive Optimization
作业要求学生逐步优化 DDP 实现,每一步都测量性能提升:
4.4.1 批量通信(Flattening) 📦
将所有梯度拼接成单一张量,一次性执行 All-Reduce,减少通信调用次数。
python
# 拼接所有梯度为单一扁平化张量 📦
def sync_gradients_flat(self):
flat_grads = torch._utils._flatten_dense_tensors( # 拼接所有梯度
[p.grad.data for p in self.module.parameters() if p.grad is not None]
)
dist.all_reduce(flat_grads) # 一次 All-Reduce 代替 N 次 🚀
flat_grads /= dist.get_world_size() # 求平均
# 拆分回各个参数 📤
for p, g in zip(
[p for p in self.module.parameters() if p.grad is not None],
torch._utils._unflatten_dense_tensors(flat_grads, ...)
):
p.grad.data = g
4.4.2 计算与通信重叠(Overlap) 🔄
利用反向传播增量计算梯度的特点:某个参数的梯度就绪后,立即异步启动 All-Reduce,无需等待所有梯度计算完成。
python
class DDPOverlap(torch.nn.Module): # 带重叠的 DDP 容器 📦
def __init__(self, module):
super().__init__()
self.module = module
self.handles = [] # 异步操作句柄列表 📋
for param in self.module.parameters():
dist.broadcast(param.data, src=0)
if param.requires_grad:
# 注册梯度累积完成后的钩子 📎
param.register_post_accumulate_grad_hook(self._async_all_reduce)
def _async_all_reduce(self, param): # 异步 All-Reduce 回调 🔄
with torch.no_grad():
param.grad.data /= dist.get_world_size() # 先除以进程数
self.handles.append( # 记录异步句柄
dist.all_reduce(param.grad.data, async_op=True) # 异步 All-Reduce 🚀
)
def finish_gradient_synchronization(self): # 等待所有通信完成 ⏳
for handle in self.handles:
handle.wait() # 等待每个 All-Reduce 完成
self.handles.clear() # 清空句柄列表
直观类比 🏭:这就像工厂流水线------第一个车间完成零件后立即送到质检,不用等所有零件都做完。质检(通信)和生产(反向传播)同时进行。
4.4.3 桶化梯度通信(Bucketing) 🗑️
将参数分组为"桶"(bucket),每个桶内的梯度就绪后一起执行 All-Reduce。这结合了两种优势:
- 减少通信调用次数(桶内批量通信)
- 保持计算与通信的重叠(桶级别的重叠)
python
class DDPBucketed(torch.nn.Module): # 桶化 DDP 容器 📦
def __init__(self, module, bucket_size_mb=25):
super().__init__()
self.module = module
self.bucket_size_mb = bucket_size_mb # 桶大小(MB) 📏
# 按参数逆序分配桶(反向传播梯度就绪顺序大致与此相反) 🔄
self.buckets = self._create_buckets()
def _create_buckets(self): # 创建参数桶 🗑️
buckets = [] # 桶列表
current_bucket = [] # 当前桶
current_size = 0 # 当前桶大小
for param in reversed(list(self.module.parameters())): # 逆序遍历参数
param_size = param.numel() * param.element_size() # 参数大小(字节)
current_bucket.append(param) # 加入当前桶
current_size += param_size
if current_size >= self.bucket_size_mb * 1024 * 1024: # 超过桶大小限制
buckets.append(current_bucket) # 保存当前桶
current_bucket = [] # 新建桶
current_size = 0
if current_bucket:
buckets.append(current_bucket) # 保存最后一个桶
return buckets # 返回桶列表
桶大小的权衡 ⚖️:
- 桶太小:通信调用次数多,开销大
- 桶太大:重叠效果差,需等待更多梯度就绪才能通信
- 最优桶大小:取决于模型大小、网络带宽和通信开销
4.5 四维并行简介 🌐 / Four-Dimensional Parallelism
实际生产中,通常结合多种并行策略形成"四维并行":
| 并行维度 | 核心思想 | 说明 |
|---|---|---|
| 数据并行(DP) | 拆分数据批次 | 每个 GPU 处理不同的数据子集 |
| FSDP | 分片参数/梯度/优化器状态 | 每个 GPU 只持有参数的一个分片 |
| 张量并行(TP) | 拆分激活值 | 每个 GPU 处理操作的一部分 |
| 流水线并行(PP) | 按层拆分模型 | 不同 GPU 处理不同的层 |
参考资料:
- PyTorch DDP 官方文档 -- PyTorch ⭐值得阅读
- PyTorch 分布式训练实践 -- arXiv ⭐值得阅读
- Data Parallelism: DDP, Gradient Synchronization & All-Reduce -- Brenddörfer
- 模型训练数据并行 FSDP -- 知乎
💡 Key Takeaways / 核心要点
- DDP splits data, not model --- each GPU holds a full model copy / DDP 拆分数据而非模型,每个 GPU 持有完整模型副本
- All-Reduce synchronizes gradients --- all GPUs end up with the same averaged gradients / All-Reduce 同步梯度,所有 GPU 得到相同的平均梯度
- Overlap computation with communication --- async All-Reduce as soon as gradients are ready / 计算与通信重叠,梯度就绪后立即异步 All-Reduce
- Bucketing balances overhead and overlap --- group parameters to reduce calls while maintaining overlap / 桶化平衡开销与重叠,减少调用次数同时保持重叠
5. 优化器状态分片 💾 / Optimizer State Sharding
💾 Note: 本章讲解 ZeRO-1 风格的优化器状态分片 / This chapter covers ZeRO-1 style optimizer state sharding.
5.1 问题背景:显存冗余 🧱 / The Problem: Memory Redundancy
DDP 要求每个 GPU 持有模型参数和优化器状态的完整副本,这带来显著的显存冗余。
以 AdamW 优化器为例,它为每个参数维护两个浮点数(一阶动量 m 和二阶动量 v):
优化器状态=2×参数量×4 字节(FP32)
对于 2.7B 模型:
| 组件 | 显存占用 | 计算方式 |
|---|---|---|
| 模型参数(FP32) | ~10.8 GB | 2.7×109×4 |
| 优化器状态(AdamW) | ~21.6 GB | 2.7×109×2×4 |
| 梯度(FP32) | ~10.8 GB | 2.7×109×4 |
| 总计 | ~43.2 GB | 仅参数+优化器+梯度 |
直观类比 📚:DDP 就像给每个学生发一整本参考答案------如果 4 个学生,就有 4 本完全一样的答案,浪费 3 本。
5.2 ZeRO 优化思想 🎯 / ZeRO Optimization Idea
ZeRO(Zero Redundancy Optimizer)的核心思想是:将优化器状态分片到多个 GPU 上,消除冗余。
ZeRO 分为三个阶段:
| 阶段 | 分片内容 | 显存节省 | 说明 |
|---|---|---|---|
| ZeRO-1 | 优化器状态 | 4x | 每个 GPU 只维护 1/N 的优化器状态 |
| ZeRO-2 | + 梯度 | 8x | 梯度也分片存储 |
| ZeRO-3 | + 参数 | 与 GPU 数成正比 | 参数也分片存储 |
作业要求实现的是 ZeRO-1 的简化版本。
5.3 实现思路 🔨 / Implementation Approach
python
class ShardedOptimizer(torch.optim.Optimizer): # 分片优化器类 📦
def __init__(self, params, optimizer_cls, **kwargs):
# 将参数分配到不同 GPU(每个 GPU 只负责一部分) 📊
# 初始化包装后的优化器,只包含本 GPU 负责的参数 🎯
...
def step(self, closure=None, **kwargs): # 优化步骤 🚀
# 包装后的优化器只更新其分片内的参数 🎯
self.wrapped_optimizer.step(closure, **kwargs)
# 将更新后的参数广播到其他 GPU 📡
for param in self._my_shard_params:
dist.broadcast(param.data, src=self._param_owner(param))
工作流程:
markdown
1. 每个 GPU 的优化器只负责 1/N 的参数 🎯
↓
2. 执行 optimizer.step(),更新本地分片的参数 ⚡
↓
3. 将更新后的参数广播到其他 GPU 📡
↓
4. 所有 GPU 的参数保持同步 ✅
5.4 显存与通信的权衡 ⚖️ / Memory-Communication Trade-off
| 指标 | 无分片(DDP) | ZeRO-1 分片 |
|---|---|---|
| 优化器显存 | 完整(~21.6 GB) | 1/N(如 2 GPU 时 ~10.8 GB) |
| 通信开销 | 仅梯度 All-Reduce | 梯度 All-Reduce + 参数广播 |
| 实现复杂度 | 低 | 中 |
ZeRO-1 的核心权衡:用额外的通信换取显存空间。对于显存受限的场景(如单卡无法容纳优化器状态),这是非常值得的。
参考资料:
- ZeRO: Memory Optimizations Toward Training Trillion Parameter Models -- arXiv ⭐值得阅读
- ZeRO Optimization: Stages 1, 2, and 3 Explained -- Brenddörfer
- PyTorch FSDP 官方教程 -- PyTorch ⭐值得阅读
- 如何用数据并行训练万亿参数模型 -- 知乎
💡 Key Takeaways / 核心要点
- DDP has massive memory redundancy --- every GPU holds full optimizer states / DDP 存在大量显存冗余,每个 GPU 持有完整优化器状态
- ZeRO-1 shards optimizer states across GPUs --- each GPU only maintains 1/N / ZeRO-1 将优化器状态分片到多个 GPU,每个 GPU 只维护 1/N
- Trade communication for memory --- broadcast updated params after local optimization / 用通信换显存,本地优化后广播更新参数
- AdamW doubles the parameter memory --- 2 floating point states per parameter / AdamW 使参数显存翻倍,每个参数 2 个浮点数状态
6. 总结 📝 / Summary
本节我们完成了 CS336 Assignment 2 的全面解析,核心要点回顾:🎯
| 主题 | 核心技术 | 关键收获 |
|---|---|---|
| 性能分析 🔬 | 端到端基准测试、Nsight 分析器、混合精度、内存分析 | 先分析再优化,避免优化非瓶颈组件 |
| FlashAttention-2 ⚡ | 分块计算、在线 Softmax、重计算、算子融合 | 内存复杂度从 O(n2) 降到 O(n) |
| 分布式训练 🖥️ | All-Reduce、DDP、梯度桶化、计算通信重叠 | 从朴素到优化的渐进式工程思维 |
| 优化器分片 💾 | ZeRO-1、参数广播、显存去冗余 | 用通信换显存的经典权衡 |
核心思维模型 🧠:
单 GPU 优化:
性能分析 → 定位瓶颈 → FlashAttention(注意力优化)→ 混合精度(计算加速)
多 GPU 扩展:
数据并行 → 梯度同步 → 计算通信重叠 → 桶化通信 → 优化器分片
三大"墙"与对应解法 🧱:
| 挑战 | 解法 | 核心技术 |
|---|---|---|
| 显存墙 🧱 | FlashAttention-2 + 混合精度 | 分块计算、重计算、算子融合 |
| 计算墙 ⚡ | Tensor Cores + 算子融合 | BF16 混合精度、内核融合 |
| 通信墙 📡 | 梯度桶化 + 计算重叠 + ZeRO | 异步 All-Reduce、优化器分片 |
参考资料:
- Stanford CS336 官方课程网站 -- Stanford ⭐值得阅读
- Stanford CS336 Assignment 2 GitHub 仓库 -- GitHub ⭐值得阅读
- Standford CS336 Language Modeling from Scratch通关总结 -- 知乎
- 斯坦福大模型课CS336,硬核到让人怀疑人生? -- 博客园
最后更新时间:2026-06-21