09c-斯坦福CS336作业二:系统与分布式训练

09c-斯坦福CS336作业二:系统与分布式训练 ⚙️

本文档详细解析斯坦福 CS336 课程 Assignment 2 的核心内容,涵盖单 GPU 性能分析与优化(基准测试、混合精度、内存分析)、FlashAttention-2 的 Triton 内核实现、分布式数据并行训练(DDP)的渐进式优化,以及优化器状态分片(ZeRO-1)技术。通过理论与实践相结合,帮助读者深入理解大模型训练中的系统工程问题 🔧

This document provides an in-depth analysis of Stanford CS336 Assignment 2, covering single-GPU profiling and optimization (benchmarking, mixed precision, memory analysis), FlashAttention-2 Triton kernel implementation, progressive optimization of Distributed Data Parallel training (DDP), and optimizer state sharding (ZeRO-1) 🛠️


术语表 / Terminology

术语 / Term 中文 说明 / Description
Profiling 性能分析 测量程序在时间和内存维度的资源消耗,定位性能瓶颈
Benchmarking 基准测试 在标准化条件下测量模型前向/反向传播的速度和显存
Mixed Precision 混合精度 同时使用 FP16/BF16 和 FP32 进行训练,加速计算并保持数值稳定
FlashAttention-2 闪存注意力 2 通过分块计算、重计算和算子融合优化注意力机制,降低内存 IO 开销
Tiling 分块计算 将输入数据分割为小块在 SRAM 中计算,避免在 HBM 中存储完整注意力矩阵
Recomputation 重计算 前向传播时不存储中间激活值,反向传播时重新计算,以时间换空间
Operator Fusion 算子融合 将多个操作合并到单个 GPU 内核中执行,减少内存读写次数
DDP (Distributed Data Parallel) 分布式数据并行 将数据批次拆分到多个 GPU,每个 GPU 持有完整模型副本,梯度通过 All-Reduce 同步
All-Reduce 全归约 集合通信操作,将所有进程的数据求和并广播给所有进程
FSDP (Fully Sharded Data Parallel) 全分片数据并行 将参数、梯度和优化器状态分片到多个 GPU,降低单卡显存占用
ZeRO (Zero Redundancy Optimizer) 零冗余优化器 通过在数据并行进程间划分优化器状态、梯度和参数来减少显存冗余
Triton Triton 编程语言 OpenAI 开发的 GPU 编程语言,提供比 CUDA 更简洁的编程接口
HBM (High Bandwidth Memory) 高带宽存储器 GPU 上的全局显存,容量大但访问速度相对较慢
SRAM 静态随机存取存储器 GPU 芯片上的高速缓存,速度极快但容量有限
NCCL NVIDIA 集合通信库 NVIDIA 提供的高性能多 GPU 集合通信库

章节阅读路线图 🗺️ / Chapter Reading Roadmap

  1. 作业概述 📋 / Assignment Overview → 了解 Assignment 2 的整体目标、结构和核心任务
  2. 性能分析与基准测试 🔬 / Performance Profiling & Benchmarking → 掌握端到端计时、Nsight 分析器、混合精度和内存分析
  3. FlashAttention-2 优化 ⚡ / FlashAttention-2 Optimization → 从分块计算到 Triton 内核,实现高效注意力机制
  4. 分布式数据并行训练 🖥️ / Distributed Data Parallel Training → 从朴素 DDP 到桶化梯度通信的渐进式优化
  5. 优化器状态分片 💾 / Optimizer State Sharding → 实现 ZeRO-1 风格的优化器状态分片
  6. 总结 📝 / Summary → 回顾核心要点

1. 作业概述 📋 / Assignment Overview

📦 Note: 本章介绍 Assignment 2 的整体目标、结构和核心任务 / This chapter introduces the overall goals, structure, and core tasks of Assignment 2.

1.1 作业定位

Assignment 2 是 CS336 课程的第二个作业,聚焦于 系统优化和分布式训练 ------即如何让大模型训练"跑得更快、用得更少、扩得更大"。

直观类比 🏭:如果说 Assignment 1 是"造出一辆能开的车",那么 Assignment 2 就是"把这辆车改装成赛车"------不仅要能跑,还要跑得又快又稳又省油耗。

学生需要在 Assignment 1 构建的 Transformer 语言模型基础上,完成以下四大核心任务:

任务 核心内容 分值占比
性能分析与基准测试 🔬 端到端计时、Nsight 分析、混合精度、内存分析 ~20 分
FlashAttention-2 ⚡ 分块计算、Triton 内核、在线 Softmax ~30 分
分布式数据并行训练 🖥️ All-Reduce、DDP、梯度桶化、计算通信重叠 ~30 分
优化器状态分片 💾 ZeRO-1 实现、显存分析 ~20 分

1.2 模型规模定义

作业中使用了多种模型规模进行基准测试,所有模型的词汇表大小均为 10,000,批量大小为 4:

规模 dmodel d_{\text{model}} dmodel(模型维度) dff d_{ff} dff(前馈网络维度) num_layers\text{num\_layers} num_layers(层数) num_heads\text{num\_heads} num_heads(注意力头数)
small(小型) 768 3072 12 12
medium(中型) 1024 4096 24 16
large(大型) 1280 5120 36 20
xl(超大) 1600 6400 48 25
2.7B(27 亿参数) 2560 10240 32 32

1.3 为什么需要系统优化? 🤔

在 Assignment 1 中,我们实现了能训练的 Transformer 模型。但在实际生产中,训练一个 2.7B 参数的模型面临着严峻的系统挑战:

  1. 显存墙(Memory Wall) 🧱:注意力机制的中间矩阵大小为 O(n2)O(n^2) O(n2),序列长度翻倍,显存占用翻四倍
  2. 计算墙(Compute Wall) ⚡:矩阵乘法虽占主导,但非矩阵运算(如 Softmax、LayerNorm)的内存 IO 开销不可忽视
  3. 通信墙(Communication Wall) 📡:多 GPU 训练时,梯度同步的通信开销可能抵消并行化带来的加速

直观类比 🚗:系统优化就像改装赛车------

  • 减重(内存优化)→ 让车更轻,跑得更快
  • 升级引擎(FlashAttention)→ 让单位时间做更多功
  • 多引擎协同(分布式训练)→ 让多个引擎高效配合

参考资料:


2. 性能分析与基准测试 🔬 / Performance Profiling & Benchmarking

🔬 Note: 本章讲解如何对模型进行性能分析,定位优化机会 / This chapter explains how to profile models and identify optimization opportunities.

在进行任何优化之前,先对程序进行性能分析至关重要 ------否则可能耗费精力优化那些对整体性能影响甚微的模块,最终无法实现可量化的端到端提升。

直观类比 🏥:性能分析就像给程序做"体检"------先做全面检查(基准测试),再用精密仪器(Profiler)定位具体问题,最后对症下药(针对性优化)。

2.1 端到端基准测试 ⏱️ / End-to-End Benchmarking

最基础的性能分析方式是计时正向传播和反向传播过程

CUDA 异步陷阱 ⚠️:在 PyTorch 中,CUDA 调用是异步的------调用 torch.matmul 后,CPU 会立即返回,不等 GPU 完成计算。因此,直接测量 Python 函数调用的返回时间不能反映 GPU 实际执行耗时。

正确做法 :每次测量后调用 torch.cuda.synchronize() 等待所有 GPU 内核执行完成。

python 复制代码
import timeit                                             # 导入高精度计时模块 ⏱️
import torch                                              # 导入 PyTorch 核心库 🔥

def benchmark(model, data, n_warmup=5, n_steps=10):       # 定义基准测试函数
    """对模型的正向传播和反向传播进行端到端基准测试

    参数:
        model: PyTorch 模型实例
        data: 输入数据
        n_warmup: 热身步数(不计入计时)
        n_steps: 正式计时步数

    返回:
        avg_time: 平均每步耗时(秒)
    """
    # 热身阶段:让 GPU 完成初始化、JIT 编译等一次性开销 🏃
    for _ in range(n_warmup):                               # 执行热身步骤
        output = model(data)                                # 前向传播
        output.sum().backward()                             # 反向传播

    torch.cuda.synchronize()                                # 等待 GPU 完成所有热身操作 ⏳

    # 正式计时阶段 📊
    start = timeit.default_timer()                          # 记录开始时间(最高分辨率时钟)
    for _ in range(n_steps):                                # 执行 n_steps 次迭代
        output = model(data)                                # 前向传播
        output.sum().backward()                             # 反向传播
        torch.cuda.synchronize()                            # 每次迭代后同步 GPU ⏳
    end = timeit.default_timer()                            # 记录结束时间

    avg_time = (end - start) / n_steps                      # 计算平均每步耗时
    return avg_time                                         # 返回结果

关键细节 📝:

  • 热身步骤不可少:GPU 首次执行操作时有 JIT 编译、显存分配等一次性开销,不热身会导致测量结果偏大
  • 使用 timeit.default_timer() :提供系统最高分辨率的时钟,比 time.time() 更适合基准测试
  • 多次测量取平均:减少随机波动的影响,同时报告标准差以评估变异性

参考资料:

2.2 Nsight Systems 性能分析器 🔍 / Nsight Systems Profiler

端到端基准测试无法告诉我们时间在各个组件上的具体分布 ,因此无法精准定位优化机会。NVIDIA 提供的 nsys 工具可以分析 CUDA 内核级别的执行细节。

使用方法 :在 Python 脚本前加上 nsys profile 前缀即可:

bash 复制代码
uv run nsys profile -o result python benchmark.py           # 使用 nsys 分析脚本性能

nsys 能告诉我们什么? 📊:

  • 每个 CUDA 内核的累计 GPU 耗时排名
  • 矩阵乘法、Softmax、LayerNorm 等操作各占多少时间
  • 正向传播和反向传播中,耗时最长的内核是否相同
  • CPU-GPU 之间的数据传输开销

NVTX 范围标注 🏷️:可以使用 NVTX 标注代码中的特定区域,在性能分析结果中以块的形式显示:

python 复制代码
import torch.cuda.nvtx as nvtx                              # 导入 NVTX 标注模块 🏷️

@nvtx.range("缩放点积注意力")                                # 标注注意力计算范围
def annotated_scaled_dot_product_attention(Q, K, V):        # 定义带标注的注意力函数
    with nvtx.range("计算注意力分数"):                        # 标注子操作:Q × K^T
        scores = torch.matmul(Q, K.transpose(-2, -1))       # 计算点积
    with nvtx.range("计算 softmax"):                          # 标注子操作:归一化
        weights = torch.softmax(scores, dim=-1)              # Softmax 归一化
    with nvtx.range("最终矩阵乘法"):                          # 标注子操作:加权求和
        output = torch.matmul(weights, V)                    # 加权求和
    return output                                           # 返回输出

关键发现 🔍:通过 nsys 分析可以发现------

  • 矩阵乘法(GEMM)占正向传播耗时的大部分,但并非全部
  • Softmax 的实际运行开销远高于其 FLOPs 所暗示的水平------它是"低 FLOPs 但高延迟"的瓶颈操作
  • 完整训练时,非 GEMM 操作(梯度缩放、激活函数导数、优化器更新)的时间占比显著上升

参考资料:

2.3 混合精度训练 🎚️ / Mixed Precision Training

现代 NVIDIA GPU 配备了专门的 Tensor Cores ,可加速低精度下的矩阵乘法运算。例如,NVIDIA A100 在 FP32 精度下最大吞吐量为 19.5 TFLOP/秒,而在 BF16 精度下可提升至 312 TFLOP/秒------提速约 16 倍

直观类比 💡:混合精度就像用不同精度的尺子量东西------大部分测量用"粗尺子"(BF16)快速完成,关键位置用"细尺子"(FP32)保证准确性。

为什么不能全部用低精度? 🤔

  • FP16 的动态范围小,某些梯度值过小会变为零
  • FP16 容易溢出,导致损失值变为 NaN
  • BF16 训练更稳定(动态范围与 FP32 相同),但仍可能影响最终模型性能

混合精度训练 (Mixed Precision Training)通过 torch.autocast 上下文管理器实现------部分操作(如矩阵乘法)以低精度执行,其他操作(如累加、归约)保持 FP32 精度:

python 复制代码
model = TransformerModel()                                  # 创建模型实例 🏗️
dtype = torch.bfloat16                                      # 选择 BF16 精度 🎚️
x = input_data                                              # 输入数据

with torch.autocast(device="cuda", dtype=dtype):            # 开启自动混合精度 🎯
    y = model(x)                                            # 前向传播:矩阵乘法自动用 BF16
    # LayerNorm、Softmax 等操作自动保持 FP32 精度 ✅

混合精度中各组件的数据类型 📋:

组件 数据类型 原因
模型参数(autocast 内) FP32 参数本身不自动转换
前馈层(Linear)输出 BF16 矩阵乘法自动降精度
LayerNorm 输出 FP32 方差计算对精度敏感
损失(Loss) FP32 保持数值稳定
梯度(Gradients) FP32 累加需要高精度

参考资料:

2.4 内存分析 💾 / Memory Profiling

PyTorch 内置了强大的内存分析器,可跟踪随时间变化的显存分配情况。

python 复制代码
# 开始记录内存历史 📊
torch.cuda.memory._record_memory_history(max_entries=1000000)

# 运行要分析的代码(前向传播、反向传播等) 🏃
output = model(data)
output.sum().backward()

# 保存内存快照(Pickle 格式) 💾
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")

# 停止记录 🛑
torch.cuda.memory._record_memory_history(enabled=None)

生成的 memory_snapshot.pickle 文件可以上传到 PyTorch 内存可视化工具,在浏览器中查看:

  • 活跃内存时间线(Active Memory Timeline):显示显存随时间的变化曲线
  • 峰值内存使用量:识别显存使用的高峰时刻
  • 每个内存分配的调用栈:定位哪些操作消耗了最多显存

Transformer 显存消耗分析 📐:

以 2.7B 模型为例( dmodel=2560 d_{\text{model}} = 2560 dmodel=2560, num_layers=32\text{num\_layers} = 32 num_layers=32):

组件 显存占用 说明
模型参数(FP32) ~10.8 GB 2.7×109×42.7 \times 10^9 \times 4 2.7×109×4 字节
优化器状态(AdamW) ~21.6 GB 每个参数 2 个浮点数(动量 + 方差)
梯度(FP32) ~10.8 GB 与参数相同大小
激活值(依赖序列长度) 变化 注意力矩阵为 O(n2)O(n^2) O(n2)

激活张量的显存计算 🧮:对于残差流中的激活张量(形状 batch_size,seq_len,dmodel\\text{batch\\_size}, \\text{seq\\_len}, d_{\\text{model}} batch_size,seq_len,dmodel):
显存=batch_size×seq_len×dmodel×bytes_per_element\text{显存} = \text{batch\_size} \times \text{seq\len} \times d{\text{model}} \times \text{bytes\_per\_element} 显存=batch_size×seq_len×dmodel×bytes_per_element

例如, batch_size=1\text{batch\_size} = 1 batch_size=1, seq_len=256\text{seq\len} = 256 seq_len=256, dmodel=2560 d{\text{model}} = 2560 dmodel=2560,FP32 精度:
1×256×2560×4=2,621,440 字节≈2.5 MB1 \times 256 \times 2560 \times 4 = 2{,}621{,}440 \text{ 字节} \approx 2.5 \text{ MB} 1×256×2560×4=2,621,440 字节≈2.5 MB

虽然单个激活张量不大,但注意力层的中间矩阵( seq_len×seq_len\text{seq\_len} \times \text{seq\_len} seq_len×seq_len)在长序列时会急剧膨胀,成为显存瓶颈。


参考资料:

💡 Key Takeaways / 核心要点

  • Profile before optimizing --- avoid wasting time on non-bottleneck components / 先分析再优化,避免优化非瓶颈组件
  • CUDA operations are asynchronous --- always synchronize before timing / CUDA 操作是异步的,计时前必须同步
  • Mixed precision can yield 16x speedup on Tensor Cores / 混合精度在 Tensor Cores 上可提速 16 倍
  • Memory profiling reveals hidden bottlenecks --- attention matrices dominate at long sequences / 内存分析揭示隐藏瓶颈,注意力矩阵在长序列时主导显存

3. FlashAttention-2 优化 ⚡ / FlashAttention-2 Optimization

Note: 本章讲解 FlashAttention-2 的核心技术和 Triton 内核实现 / This chapter covers FlashAttention-2 core techniques and Triton kernel implementation.

3.1 标准注意力机制的低效性 🐢 / Inefficiency of Standard Attention

标准注意力机制的前向传播包含三步:

  1. 计算注意力分数: S=QKT/ dk S = Q K^T / \sqrt{d_k} S=QKT/dk
  2. Softmax 归一化: P=softmax(S)P = \text{softmax}(S) P=softmax(S)
  3. 加权求和: O=PVO = P V O=PV

问题在哪? 🤔

朴素实现需要在 GPU 的 HBM(高带宽存储器)中存储完整的 seq_len×seq_len\text{seq\_len} \times \text{seq\_len} seq_len×seq_len 注意力分数矩阵 PP P。当序列长度较长时,这个矩阵会极大:

序列长度 注意力矩阵大小(FP32) 说明
512 1 MB 可接受
4096 64 MB 开始显著
16384 1 GB 显存瓶颈
65536 16 GB 单卡无法容纳

直观类比 📦:标准注意力就像在一张巨大的纸上写完所有答案,再抄到最终卷子上------纸(HBM)很大但搬运很慢,而且大部分内容抄完就不需要了。

根本原因 :标准实现在 HBM 和 SRAM 之间频繁传输 PP P 及其他大型激活值,产生极高的内存 IO 开销 。例如,标准反向传播会在计算 dV=PTdOdV = P^T dO dV=PTdO 和 dS=dsoftmax(dP)dS = d\text{softmax}(dP) dS=dsoftmax(dP) 时两次从 HBM 读取 PP P。

3.2 FlashAttention-2 的三大核心技术 🔧 / Three Core Techniques

FlashAttention-2 通过三种技术协同工作,避免在 HBM 中读写完整的注意力矩阵:

3.2.1 分块计算(Tiling) 🧩

将输入数据分割为多个小块(tile),在 GPU 的 SRAM(高速缓存)中完成计算,避免将完整注意力矩阵写入 HBM。

直观类比 🍕:分块计算就像吃披萨------不会一口吞下整个披萨,而是一块一块地吃。每块大小受 SRAM 容量限制,但吃完所有块后,效果等同于吃完整个披萨。

具体来说,将 QQ Q 分割为 Tq T_q Tq 个大小为 Bq×d B_q \times d Bq×d 的分块,将 KK K、 VV V 分割为 Tk T_k Tk 个大小为 Bk×d B_k \times d Bk×d 的分块:
Q1,..., QTq (每个 Bq×d) Q_1, \ldots, Q_{T_q} \quad (\text{每个 } B_q \times d) Q1,...,QTq(每个 Bq×d)
K(1),...,K (Tk) ,V(1),...,V (Tk) (每个 Bk×d)K^{(1)}, \ldots, K^{(T_k)}, \quad V^{(1)}, \ldots, V^{(T_k)} \quad (\text{每个 } B_k \times d) K(1),...,K(Tk),V(1),...,V(Tk)(每个 Bk×d)

3.2.2 重计算(Recomputation) 🔄

不再将 seq_len×seq_len\text{seq\_len} \times \text{seq\_len} seq_len×seq_len 的注意力矩阵 PP P 存储在 HBM 中,而是在反向传播时重新计算。

以时间换空间 :前向传播时只保存少量"检查点"( QQ Q、 KK K、 VV V 和 logsumexp LL L),反向传播时重新计算 PP P。这样,内存复杂度从 O(n2)O(n^2) O(n2) 降低到 O(n)O(n) O(n)。

FlashAttention-2 还会存储注意力分数的对数求和指数(logsumexp) LL L,用于简化反向传播计算:
Li=log⁡ (∑jexp⁡( Sij )) L_i = \log\left(\sum_j \exp(S_{ij})\right) Li=log(j∑exp(Sij))

借助 LL L 和预计算的 D=rowsum(O∘dO)D = \text{rowsum}(O \circ dO) D=rowsum(O∘dO),反向传播可在无需执行 softmax 运算的情况下完成:
Pij =exp⁡( Sij −Li) P_{ij} = \exp(S_{ij} - L_i) Pij=exp(Sij−Li)
d Sij = Pij ∘(d Pij −Di) dS_{ij} = P_{ij} \circ (dP_{ij} - D_i) dSij=Pij∘(dPij−Di)

3.2.3 算子融合(Operator Fusion) 🔀

将所有注意力操作( QKTQK^T QKT、缩放、Softmax、加权求和)融合到单个 Triton 内核中执行,避免对中间结果的重复内存 IO。

直观类比 🍳:算子融合就像做菜时把所有步骤在一个锅里完成------切好的菜直接下锅炒,不用每次都盛出来再倒回去。

3.3 在线 Softmax 算法 📊 / Online Softmax Algorithm

分块计算面临一个关键挑战:Softmax 需要整行数据才能计算分母,但我们对行进行了分块,无法一次性看到所有数据。

在线 Softmax 解决了这个问题------通过维护两个累计值,在遍历每个分块时增量更新 Softmax 结果:

  • mi(j) m_i^{(j)} mi(j):累计最大值,保证数值稳定性
  • li(j) l_i^{(j)} li(j):Softmax 分母的累计代理值

每处理一个新的键分块 jj j,更新规则为:
mi(j) =max⁡ ( mi(j−1) ,rowmax( Si(j) )) m_i^{(j)} = \max\left(m_i^{(j-1)}, \text{rowmax}(S_i^{(j)})\right) mi(j)=max(mi(j−1),rowmax(Si(j)))
P~i(j) =exp⁡ ( Si(j) − mi(j) ) \tilde{P}_i^{(j)} = \exp\left(S_i^{(j)} - m_i^{(j)}\right) P~i(j)=exp(Si(j)−mi(j))
li(j) =exp⁡ ( mi(j−1) − mi(j) ) ⋅ li(j−1) +rowsum ( P~i(j) ) l_i^{(j)} = \exp\left(m_i^{(j-1)} - m_i^{(j)}\right) \cdot l_i^{(j-1)} + \text{rowsum}\left(\tilde{P}_i^{(j)}\right) li(j)=exp(mi(j−1)−mi(j))⋅li(j−1)+rowsum(P~i(j))
Oi(j) =diag (exp⁡ ( mi(j−1) − mi(j) ) ) ⋅ Oi(j−1) + P~i(j) V(j) O_i^{(j)} = \text{diag}\left(\exp\left(m_i^{(j-1)} - m_i^{(j)}\right)\right) \cdot O_i^{(j-1)} + \tilde{P}_i^{(j)} V^{(j)} Oi(j)=diag(exp(mi(j−1)−mi(j)))⋅Oi(j−1)+P~i(j)V(j)

最终归一化:
Oi=diag( ( li (Tk) ) −1 )⋅ Oi (Tk) O_i = \text{diag}\left(\left(l_i^{(T_k)}\right)^{-1}\right) \cdot O_i^{(T_k)} Oi=diag((li(Tk))−1)⋅Oi(Tk)

直观类比 🗳️:在线 Softmax 就像实时计票------每收到一批新票,不需要重新数之前的票,只需要更新"当前最高票"和"总票数"两个计数器,就能得到正确的百分比。

3.4 FlashAttention-2 前向传播算法 📝 / FA2 Forward Pass Algorithm

算法 1 FlashAttention-2 前向传播

输入 Q∈R Nq×d Q \in \mathbb{R}^{N_q \times d} Q∈RNq×d, K,V∈R Nk×d K, V \in \mathbb{R}^{N_k \times d} K,V∈RNk×d,分块大小 Bq,Bk B_q, B_k Bq,Bk

  1. QQ Q 分割为 Tq=⌈Nq/Bq⌉ T_q = \lceil N_q / B_q \rceil Tq=⌈Nq/Bq⌉ 个分块 Q1,..., QTq Q_1, \ldots, Q_{T_q} Q1,...,QTq
  2. K,VK, V K,V 分割为 Tk=⌈Nk/Bk⌉ T_k = \lceil N_k / B_k \rceil Tk=⌈Nk/Bk⌉ 个分块
  3. 对于 i=1,...,Tqi = 1, \ldots, T_q i=1,...,Tq:
    • 从全局内存加载 Qi Q_i Qi
    • 初始化 Oi(0) =0 O_i^{(0)} = 0 Oi(0)=0, li(0) =0 l_i^{(0)} = 0 li(0)=0, mi(0) =−∞ m_i^{(0)} = -\infty mi(0)=−∞
    • 对于 j=1,...,Tkj = 1, \ldots, T_k j=1,...,Tk:
      • 从全局内存加载 K(j),V(j)K^{(j)}, V^{(j)} K(j),V(j)
      • 计算 Si(j) =Qi(K(j))T/d S_i^{(j)} = Q_i (K^{(j)})^T / \sqrt{d} Si(j)=Qi(K(j))T/d
      • 更新 mi(j) m_i^{(j)} mi(j)、 P~i(j) \tilde{P}_i^{(j)} P~i(j)、 li(j) l_i^{(j)} li(j)、 Oi(j) O_i^{(j)} Oi(j)(见 3.3 节公式)
    • 归一化: Oi=diag(( li (Tk) )−1)⋅ Oi (Tk) O_i = \text{diag}((l_i^{(T_k)})^{-1}) \cdot O_i^{(T_k)} Oi=diag((li(Tk))−1)⋅Oi(Tk)
    • 计算 logsumexp: Li= mi (Tk) +log⁡( li (Tk) ) L_i = m_i^{(T_k)} + \log(l_i^{(T_k)}) Li=mi(Tk)+log(li(Tk))
  4. Oi O_i Oi 和 Li L_i Li 写入全局内存

输出 :输出矩阵 OO O 和 logsumexp 矩阵 LL L

关键优势

特性 标准注意力 FlashAttention-2
内存复杂度 O(n2)O(n^2) O(n2) O(n)O(n) O(n)
HBM 读写 完整注意力矩阵 仅分块数据
数值稳定性 需要单独处理 在线算法内置保证
长序列支持 受限(OOM) 支持超长序列

3.5 Triton 内核实现要点 🔨 / Triton Kernel Implementation Notes

作业要求学生使用 Triton 编程语言实现 FlashAttention-2 内核。以下是关键实现要点:

Triton 基础:Triton 是 OpenAI 开发的 GPU 编程语言,提供比 CUDA 更简洁的接口。每个 Triton 程序实例是一个线程块,所有线程执行同一程序,可在 GPU 上并行执行。

python 复制代码
import triton                                              # 导入 Triton 核心库 🔥
import triton.language as tl                                # 导入 Triton 语言模块

@triton.jit                                                # JIT 编译装饰器 ⚡
def flash_fwd_kernel(
    Q_ptr, K_ptr, V_ptr,                                    # 输入张量指针 🔍🔑💎
    O_ptr, L_ptr,                                           # 输出张量指针 📤
    stride_qb, stride_qq, stride_qd,                        # Q 的步长参数
    N_QUERIES, N_KEYS,                                      # 序列长度
    scale,                                                   # 缩放因子 $1/\sqrt{d}$
    D: tl.constexpr,                                         # 隐藏维度(编译期常量)
    Q_TILE_SIZE: tl.constexpr,                               # 查询分块大小
    K_TILE_SIZE: tl.constexpr,                               # 键分块大小
):
    query_tile_index = tl.program_id(0)                      # 查询分块索引 📍
    batch_index = tl.program_id(1)                           # 批次索引 📍
    # ... 分块计算逻辑 ...

关键实现技巧 🛠️:

  • 使用 tl.make_block_ptr 简化指针运算
  • 片上缓冲区( Oi O_i Oi, ll l, mm m)应使用 tl.float32 保证精度
  • 使用 tl.dot 执行矩阵乘法
  • 启动网格设置为 (Tq,batch_size) (T_q, \text{batch\_size}) (Tq,batch_size)

参考资料:

💡 Key Takeaways / 核心要点

  • Standard attention stores full n×nn \times n n×n matrix --- FlashAttention avoids this / 标准注意力存储完整 n×nn \times n n×n 矩阵,FlashAttention 避免了这一点
  • Tiling + recomputation + fusion work together to reduce HBM IO / 分块 + 重计算 + 算子融合协同降低 HBM IO
  • Online softmax enables block-wise computation --- no need to see the full row / 在线 Softmax 支持分块计算,无需看到完整行
  • Memory complexity drops from O(n2)O(n^2) O(n2) to O(n)O(n) O(n) --- enabling much longer sequences / 内存复杂度从 O(n2)O(n^2) O(n2) 降到 O(n)O(n) O(n),支持更长序列

4. 分布式数据并行训练 🖥️ / Distributed Data Parallel Training

🖥️ Note: 本章讲解如何用多块 GPU 加速训练,从朴素 DDP 到桶化梯度通信 / This chapter covers multi-GPU training, from naive DDP to bucketed gradient communication.

4.1 分布式通信基础 📡 / Distributed Communication Basics

在 PyTorch 中,分布式训练的基础是进程组(Process Group)集合通信操作(Collective Communication)

核心概念 📝:

  • 节点(Node):网络中的一台机器
  • 工作进程(Worker):参与分布式训练的程序实例
  • 全局进程数(World Size):进程组中工作进程的总数
  • 全局序号(Rank):唯一标识每个工作进程的 ID(0 到 world_size-1)

最简单的分布式示例

python 复制代码
import torch.distributed as dist                            # 导入分布式模块 📡
import torch.multiprocessing as mp                          # 导入多进程模块 🧵

def distributed_demo(rank, world_size):                     # 定义分布式示例函数
    os.environ["MASTER_ADDR"] = "localhost"                 # 设置主节点地址
    os.environ["MASTER_PORT"] = "29500"                     # 设置主节点端口
    dist.init_process_group("gloo", rank=rank, world_size=world_size)  # 初始化进程组

    data = torch.randint(0, 10, (3,))                       # 每个进程生成随机数据 🎲
    print(f"rank {rank} 数据(全归约前): {data}")              # 打印全归约前数据
    dist.all_reduce(data, async_op=False)                   # 全归约操作:所有进程数据求和 📊
    print(f"rank {rank} 数据(全归约后): {data}")              # 打印全归约后数据(所有进程相同)

if __name__ == "__main__":
    world_size = 4                                          # 4 个工作进程
    mp.spawn(fn=distributed_demo, args=(world_size,), nprocs=world_size, join=True)

运行结果:每个进程最初持有不同的数据,经过 All-Reduce 后,所有进程持有相同的求和结果。

后端选择 🎯:

后端 适用场景 说明
NCCL GPU 训练(生产环境) 基于 NVIDIA NCCL 库,性能最优
Gloo CPU 训练 / 本地开发 支持仅 CPU 机器

4.2 数据并行训练流程 🔄 / Data Parallel Training Workflow

数据并行性将数据批次拆分到多个 GPU 上,每个 GPU 持有完整模型副本。训练流程如下:

直观类比 📚:数据并行就像多个学生同时做同一份试卷的不同部分,最后把答案汇总------每个学生(GPU)看不同的题目(数据),但用同一本参考书(模型)。

markdown 复制代码
1. 初始化:所有 GPU 持有相同的模型参数和优化器状态 🏗️
   ↓
2. 数据分片:将批次 n 个样本分给 d 个 GPU,每个 GPU 得到 n/d 个样本 📦
   ↓
3. 本地计算:每个 GPU 用自己的数据执行前向传播和反向传播,计算本地梯度 ⚡
   ↓
4. 梯度同步:通过 All-Reduce 对所有 GPU 的梯度求平均 🔄
   ↓
5. 参数更新:每个 GPU 用相同的平均梯度更新模型,保持同步 ✅
   ↓
6. 重复步骤 2-5,直到训练完成 🔁

4.3 朴素 DDP 实现 🛠️ / Naive DDP Implementation

最朴素的实现方式是:反向传播后,对每个参数张量单独执行 All-Reduce。

python 复制代码
class NaiveDDP(torch.nn.Module):                            # 朴素 DDP 容器类 📦
    def __init__(self, module):
        super().__init__()
        self.module = module
        # 广播参数:确保所有进程持有相同初始参数 📡
        for param in self.module.parameters():
            dist.broadcast(param.data, src=0)

    def forward(self, *inputs, **kwargs):
        return self.module.forward(*inputs, **kwargs)        # 直接转发前向传播

    def sync_gradients(self):                                # 同步梯度 🔄
        for param in self.module.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad.data)              # 对每个参数单独 All-Reduce
                param.grad.data /= dist.get_world_size()       # 除以进程数,得到平均值

问题:每次 All-Reduce 调用都有通信开销,对于深度 Transformer 模型(数百个参数张量),开销累积显著。

4.4 渐进式优化 🔧 / Progressive Optimization

作业要求学生逐步优化 DDP 实现,每一步都测量性能提升:

4.4.1 批量通信(Flattening) 📦

将所有梯度拼接成单一张量,一次性执行 All-Reduce,减少通信调用次数。

python 复制代码
# 拼接所有梯度为单一扁平化张量 📦
def sync_gradients_flat(self):
    flat_grads = torch._utils._flatten_dense_tensors(         # 拼接所有梯度
        [p.grad.data for p in self.module.parameters() if p.grad is not None]
    )
    dist.all_reduce(flat_grads)                               # 一次 All-Reduce 代替 N 次 🚀
    flat_grads /= dist.get_world_size()                       # 求平均
    # 拆分回各个参数 📤
    for p, g in zip(
        [p for p in self.module.parameters() if p.grad is not None],
        torch._utils._unflatten_dense_tensors(flat_grads, ...)
    ):
        p.grad.data = g
4.4.2 计算与通信重叠(Overlap) 🔄

利用反向传播增量计算梯度的特点:某个参数的梯度就绪后,立即异步启动 All-Reduce,无需等待所有梯度计算完成。

python 复制代码
class DDPOverlap(torch.nn.Module):                            # 带重叠的 DDP 容器 📦
    def __init__(self, module):
        super().__init__()
        self.module = module
        self.handles = []                                     # 异步操作句柄列表 📋
        for param in self.module.parameters():
            dist.broadcast(param.data, src=0)
            if param.requires_grad:
                # 注册梯度累积完成后的钩子 📎
                param.register_post_accumulate_grad_hook(self._async_all_reduce)

    def _async_all_reduce(self, param):                      # 异步 All-Reduce 回调 🔄
        with torch.no_grad():
            param.grad.data /= dist.get_world_size()          # 先除以进程数
        self.handles.append(                                  # 记录异步句柄
            dist.all_reduce(param.grad.data, async_op=True)   # 异步 All-Reduce 🚀
        )

    def finish_gradient_synchronization(self):               # 等待所有通信完成 ⏳
        for handle in self.handles:
            handle.wait()                                     # 等待每个 All-Reduce 完成
        self.handles.clear()                                  # 清空句柄列表

直观类比 🏭:这就像工厂流水线------第一个车间完成零件后立即送到质检,不用等所有零件都做完。质检(通信)和生产(反向传播)同时进行。

4.4.3 桶化梯度通信(Bucketing) 🗑️

将参数分组为"桶"(bucket),每个桶内的梯度就绪后一起执行 All-Reduce。这结合了两种优势:

  • 减少通信调用次数(桶内批量通信)
  • 保持计算与通信的重叠(桶级别的重叠)
python 复制代码
class DDPBucketed(torch.nn.Module):                           # 桶化 DDP 容器 📦
    def __init__(self, module, bucket_size_mb=25):
        super().__init__()
        self.module = module
        self.bucket_size_mb = bucket_size_mb                  # 桶大小(MB) 📏
        # 按参数逆序分配桶(反向传播梯度就绪顺序大致与此相反) 🔄
        self.buckets = self._create_buckets()

    def _create_buckets(self):                               # 创建参数桶 🗑️
        buckets = []                                          # 桶列表
        current_bucket = []                                   # 当前桶
        current_size = 0                                      # 当前桶大小
        for param in reversed(list(self.module.parameters())):  # 逆序遍历参数
            param_size = param.numel() * param.element_size()  # 参数大小(字节)
            current_bucket.append(param)                       # 加入当前桶
            current_size += param_size
            if current_size >= self.bucket_size_mb * 1024 * 1024:  # 超过桶大小限制
                buckets.append(current_bucket)                  # 保存当前桶
                current_bucket = []                             # 新建桶
                current_size = 0
        if current_bucket:
            buckets.append(current_bucket)                      # 保存最后一个桶
        return buckets                                        # 返回桶列表

桶大小的权衡 ⚖️:

  • 桶太小:通信调用次数多,开销大
  • 桶太大:重叠效果差,需等待更多梯度就绪才能通信
  • 最优桶大小:取决于模型大小、网络带宽和通信开销

4.5 四维并行简介 🌐 / Four-Dimensional Parallelism

实际生产中,通常结合多种并行策略形成"四维并行":

并行维度 核心思想 说明
数据并行(DP) 拆分数据批次 每个 GPU 处理不同的数据子集
FSDP 分片参数/梯度/优化器状态 每个 GPU 只持有参数的一个分片
张量并行(TP) 拆分激活值 每个 GPU 处理操作的一部分
流水线并行(PP) 按层拆分模型 不同 GPU 处理不同的层

参考资料:

💡 Key Takeaways / 核心要点

  • DDP splits data, not model --- each GPU holds a full model copy / DDP 拆分数据而非模型,每个 GPU 持有完整模型副本
  • All-Reduce synchronizes gradients --- all GPUs end up with the same averaged gradients / All-Reduce 同步梯度,所有 GPU 得到相同的平均梯度
  • Overlap computation with communication --- async All-Reduce as soon as gradients are ready / 计算与通信重叠,梯度就绪后立即异步 All-Reduce
  • Bucketing balances overhead and overlap --- group parameters to reduce calls while maintaining overlap / 桶化平衡开销与重叠,减少调用次数同时保持重叠

5. 优化器状态分片 💾 / Optimizer State Sharding

💾 Note: 本章讲解 ZeRO-1 风格的优化器状态分片 / This chapter covers ZeRO-1 style optimizer state sharding.

5.1 问题背景:显存冗余 🧱 / The Problem: Memory Redundancy

DDP 要求每个 GPU 持有模型参数和优化器状态的完整副本,这带来显著的显存冗余。

以 AdamW 优化器为例,它为每个参数维护两个浮点数(一阶动量 mm m 和二阶动量 vv v):
优化器状态=2×参数量×4 字节(FP32)\text{优化器状态} = 2 \times \text{参数量} \times 4 \text{ 字节(FP32)} 优化器状态=2×参数量×4 字节(FP32)

对于 2.7B 模型:

组件 显存占用 计算方式
模型参数(FP32) ~10.8 GB 2.7×109×42.7 \times 10^9 \times 4 2.7×109×4
优化器状态(AdamW) ~21.6 GB 2.7×109×2×42.7 \times 10^9 \times 2 \times 4 2.7×109×2×4
梯度(FP32) ~10.8 GB 2.7×109×42.7 \times 10^9 \times 4 2.7×109×4
总计 ~43.2 GB 仅参数+优化器+梯度

直观类比 📚:DDP 就像给每个学生发一整本参考答案------如果 4 个学生,就有 4 本完全一样的答案,浪费 3 本。

5.2 ZeRO 优化思想 🎯 / ZeRO Optimization Idea

ZeRO(Zero Redundancy Optimizer)的核心思想是:将优化器状态分片到多个 GPU 上,消除冗余

ZeRO 分为三个阶段:

阶段 分片内容 显存节省 说明
ZeRO-1 优化器状态 4x 每个 GPU 只维护 1/N 的优化器状态
ZeRO-2 + 梯度 8x 梯度也分片存储
ZeRO-3 + 参数 与 GPU 数成正比 参数也分片存储

作业要求实现的是 ZeRO-1 的简化版本。

5.3 实现思路 🔨 / Implementation Approach

python 复制代码
class ShardedOptimizer(torch.optim.Optimizer):              # 分片优化器类 📦
    def __init__(self, params, optimizer_cls, **kwargs):
        # 将参数分配到不同 GPU(每个 GPU 只负责一部分) 📊
        # 初始化包装后的优化器,只包含本 GPU 负责的参数 🎯
        ...

    def step(self, closure=None, **kwargs):                  # 优化步骤 🚀
        # 包装后的优化器只更新其分片内的参数 🎯
        self.wrapped_optimizer.step(closure, **kwargs)
        # 将更新后的参数广播到其他 GPU 📡
        for param in self._my_shard_params:
            dist.broadcast(param.data, src=self._param_owner(param))

工作流程

markdown 复制代码
1. 每个 GPU 的优化器只负责 1/N 的参数 🎯
   ↓
2. 执行 optimizer.step(),更新本地分片的参数 ⚡
   ↓
3. 将更新后的参数广播到其他 GPU 📡
   ↓
4. 所有 GPU 的参数保持同步 ✅

5.4 显存与通信的权衡 ⚖️ / Memory-Communication Trade-off

指标 无分片(DDP) ZeRO-1 分片
优化器显存 完整(~21.6 GB) 1/N(如 2 GPU 时 ~10.8 GB)
通信开销 仅梯度 All-Reduce 梯度 All-Reduce + 参数广播
实现复杂度

ZeRO-1 的核心权衡:用额外的通信换取显存空间。对于显存受限的场景(如单卡无法容纳优化器状态),这是非常值得的。


参考资料:

💡 Key Takeaways / 核心要点

  • DDP has massive memory redundancy --- every GPU holds full optimizer states / DDP 存在大量显存冗余,每个 GPU 持有完整优化器状态
  • ZeRO-1 shards optimizer states across GPUs --- each GPU only maintains 1/N / ZeRO-1 将优化器状态分片到多个 GPU,每个 GPU 只维护 1/N
  • Trade communication for memory --- broadcast updated params after local optimization / 用通信换显存,本地优化后广播更新参数
  • AdamW doubles the parameter memory --- 2 floating point states per parameter / AdamW 使参数显存翻倍,每个参数 2 个浮点数状态

6. 总结 📝 / Summary

本节我们完成了 CS336 Assignment 2 的全面解析,核心要点回顾:🎯

主题 核心技术 关键收获
性能分析 🔬 端到端基准测试、Nsight 分析器、混合精度、内存分析 先分析再优化,避免优化非瓶颈组件
FlashAttention-2 ⚡ 分块计算、在线 Softmax、重计算、算子融合 内存复杂度从 O(n2)O(n^2) O(n2) 降到 O(n)O(n) O(n)
分布式训练 🖥️ All-Reduce、DDP、梯度桶化、计算通信重叠 从朴素到优化的渐进式工程思维
优化器分片 💾 ZeRO-1、参数广播、显存去冗余 用通信换显存的经典权衡

核心思维模型 🧠:

复制代码
单 GPU 优化:
  性能分析 → 定位瓶颈 → FlashAttention(注意力优化)→ 混合精度(计算加速)

多 GPU 扩展:
  数据并行 → 梯度同步 → 计算通信重叠 → 桶化通信 → 优化器分片

三大"墙"与对应解法 🧱:

挑战 解法 核心技术
显存墙 🧱 FlashAttention-2 + 混合精度 分块计算、重计算、算子融合
计算墙 ⚡ Tensor Cores + 算子融合 BF16 混合精度、内核融合
通信墙 📡 梯度桶化 + 计算重叠 + ZeRO 异步 All-Reduce、优化器分片

参考资料:


最后更新时间:2026-06-21

相关推荐
阿里云大数据AI技术2 小时前
用 SQL 解锁多模态数据分析:Hologres 让图片、语音、视频变成结构化洞察
人工智能
阿里云大数据AI技术3 小时前
EMR Serverless StarRocks 湖仓多模态检索:One SQL on One Data,实现全文 + 标量 + 向量三路混合检索
人工智能
冬奇Lab4 小时前
Skill 系列(02):Skill 安全风险——三类攻击面的实战测试
人工智能·安全·开源
冬奇Lab4 小时前
每日一个开源项目(第138篇):OpenMontage - 把 AI 编程助手变成完整的视频制作团队
人工智能·开源·claude
米小虾5 小时前
智谱港股盘中市值突破万亿港元!GLM-5.2 开源引爆国产 AI 价值重估
人工智能·chatglm (智谱)
阿里云大数据AI技术5 小时前
义乌小商品城基于MaxFrame AI Function的亿级AI 数据产线提速之路
人工智能
甲维斯6 小时前
用AI还原《坦克大战》并3D化升级!
前端·人工智能·游戏开发
IT_陈寒7 小时前
SpringBoot自动配置坑了我一晚上,原来问题出在这
前端·人工智能·后端