FlashAttention 1→4 进化史
- 背景:为什么需要FlashAttention?
-
- [1.1 标准Attention的计算瓶颈](#1.1 标准Attention的计算瓶颈)
- [1.2 GPU内存层次结构](#1.2 GPU内存层次结构)
- [数学基础:从Safe Softmax到Online Softmax](#数学基础:从Safe Softmax到Online Softmax)
-
- [2.1 Safe Softmax(3-pass)](#2.1 Safe Softmax(3-pass))
- [2.2 Online Softmax(2-pass)](#2.2 Online Softmax(2-pass))
- [FlashAttention V1:IO感知的分块注意力](#FlashAttention V1:IO感知的分块注意力)
-
- [3.1 核心思想](#3.1 核心思想)
- [3.2 分块算法(Tiling)](#3.2 分块算法(Tiling))
- [3.3 IO复杂度分析](#3.3 IO复杂度分析)
- [FlashAttention V2:并行性与工作划分优化](#FlashAttention V2:并行性与工作划分优化)
-
- [4.1 核心改进](#4.1 核心改进)
-
- [1. 减少非矩阵运算](#1. 减少非矩阵运算)
- [2. 改进并行性](#2. 改进并行性)
- [3. 优化Warp间工作划分](#3. 优化Warp间工作划分)
- [4.2 算法对比](#4.2 算法对比)
- [4.3 性能表现](#4.3 性能表现)
- [FlashAttention V3:Hopper架构的极致利用](#FlashAttention V3:Hopper架构的极致利用)
-
- [5.1 Hopper新特性](#5.1 Hopper新特性)
- [5.2 三大技术创新](#5.2 三大技术创新)
-
- [1. 异步执行与Warp专业化](#1. 异步执行与Warp专业化)
- [2. 非相干处理(Incoherent Processing)](#2. 非相干处理(Incoherent Processing))
- [3. 硬件感知的分块](#3. 硬件感知的分块)
- [5.3 性能表现](#5.3 性能表现)
- [FlashAttention V4:Blackwell架构的进化](#FlashAttention V4:Blackwell架构的进化)
-
- [6.1 Blackwell新特性](#6.1 Blackwell新特性)
- [6.2 核心创新](#6.2 核心创新)
-
- [1. 5-stage流水线](#1. 5-stage流水线)
- [2. 软件指数计算](#2. 软件指数计算)
- [3. 自适应在线缩放](#3. 自适应在线缩放)
- [6.3 当前状态](#6.3 当前状态)
- 各版本对比总结
-
- [7.1 技术演进路线](#7.1 技术演进路线)
- [7.2 关键指标对比](#7.2 关键指标对比)
- [7.3 适用场景建议](#7.3 适用场景建议)
- 实践指南
-
- [8.1 代码示例](#8.1 代码示例)
- [8.2 硬件兼容性](#8.2 硬件兼容性)
- 总结
-
- [9.1 核心要点](#9.1 核心要点)
- [9.2 一句话总结](#9.2 一句话总结)
- 参考文献
FlashAttention 系列算法是近年来Transformer加速领域最重要的突破之一。它通过IO感知(IO-Awareness)的设计理念,将注意力计算的速度提升了数十倍,同时将内存占用从 O ( N 2 ) O(N^2) O(N2) 降低到接近 O ( N ) O(N) O(N)。本文将详细解析 FlashAttention V1 到 V4 的核心原理、数学推导、技术演进以及性能对比。
背景:为什么需要FlashAttention?
1.1 标准Attention的计算瓶颈
标准的Scaled Dot-Product Attention计算公式为:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V
其中 Q , K , V ∈ R N × d Q, K, V \in \mathbb{R}^{N \times d} Q,K,V∈RN×d, N N N 是序列长度, d d d 是头维度。
标准实现的3-pass算法 :
┌─────────────────────────────────────────────────────────────┐
│ 标准Self-Attention计算流程 │
├─────────────────────────────────────────────────────────────┤
│ │
│ Step 1: S = Q @ K.T # [N, N] ← 写入HBM │
│ Step 2: P = softmax(S) # [N, N] ← 写入HBM │
│ Step 3: O = P @ V # [N, d] ← 写入HBM │
│ │
│ 内存需求:O(N²) 中间矩阵 S 和 P │
│ HBM访问次数:O(Nd + N²) │
│ │
└─────────────────────────────────────────────────────────────┘
当 N = 128 K N=128K N=128K 时, N 2 = 16 亿 N^2 = 16\text{亿} N2=16亿,仅存储中间矩阵就需要 12.8GB 显存(FP16)!这就是为什么原始Transformer无法处理超长序列的根本原因。
1.2 GPU内存层次结构
┌─────────────────────────────────────────────────────────────┐
│ GPU内存层次结构 │
├─────────────────────────────────────────────────────────────┤
│ │
│ Register File (寄存器) : 最快,每个SM私有 │
│ ↓ │
│ Shared Memory (共享内存) : 极快,~20MB/SM │
│ ↓ │
│ L1/L2 Cache (缓存) : 快速,~几十MB │
│ ↓ │
│ HBM (高带宽显存) : 慢速,~80GB/s,大容量 │
│ │
│ 带宽差异:HBM(1x) vs SRAM(10-20x) │
│ │
└─────────────────────────────────────────────────────────────┘
FlashAttention的核心洞察:计算速度远快于内存访问速度 ,因此优化的关键在于减少HBM访问次数,尽可能在SRAM中完成计算 。
数学基础:从Safe Softmax到Online Softmax
2.1 Safe Softmax(3-pass)
为了避免数值溢出,标准的safe softmax需要先减去最大值:
softmax ( x i ) = e x i − m ∑ j = 1 N e x j − m , m = max ( x 1 , . . . , x N ) \text{softmax}(x_i) = \frac{e^{x_i - m}}{\sum_{j=1}^N e^{x_j - m}}, \quad m = \max(x_1, ..., x_N) softmax(xi)=∑j=1Nexj−mexi−m,m=max(x1,...,xN)
3-pass算法 :
Algorithm 1: 3-pass safe softmax
--------------------------------
m = -inf
for i = 1 to N: # Pass 1: 找最大值
m = max(m, x_i)
d = 0
for i = 1 to N: # Pass 2: 计算分母
d = d + exp(x_i - m)
for i = 1 to N: # Pass 3: 计算softmax
y_i = exp(x_i - m) / d
2.2 Online Softmax(2-pass)
Online Softmax的核心思想是:在遍历数据的同时,维护一个运行的最大值和归一化因子 。
定义:
- m i = max ( m i − 1 , x i ) m_i = \max(m_{i-1}, x_i) mi=max(mi−1,xi)
- l ~ i = ∑ j = 1 i e x j − m i \tilde{l}i = \sum{j=1}^i e^{x_j - m_i} l~i=∑j=1iexj−mi
递推关系推导 :
l ~ i = ∑ j = 1 i − 1 e x j − m i + e x i − m i = ( ∑ j = 1 i − 1 e x j − m i − 1 ) e m i − 1 − m i + e x i − m i = l ~ i − 1 e m i − 1 − m i + e x i − m i \begin{align*} \tilde{l}i &= \sum{j=1}^{i-1} e^{x_j - m_i} + e^{x_i - m_i} \\ &= \left(\sum_{j=1}^{i-1} e^{x_j - m_{i-1}}\right) e^{m_{i-1} - m_i} + e^{x_i - m_i} \\ &= \tilde{l}{i-1} e^{m{i-1} - m_i} + e^{x_i - m_i} \end{align*} l~i=j=1∑i−1exj−mi+exi−mi=(j=1∑i−1exj−mi−1)emi−1−mi+exi−mi=l~i−1emi−1−mi+exi−mi
2-pass算法 :
Algorithm 2: 2-pass online softmax
---------------------------------
m = -inf
l = 0
for i = 1 to N: # Pass 1: 同时更新 m 和 l
m_prev = m
m = max(m, x_i)
l = l * exp(m_prev - m) + exp(x_i - m)
for i = 1 to N: # Pass 2: 计算softmax
y_i = exp(x_i - m) / l
关键洞察:虽然online softmax仍然是2-pass,但第一个pass中同时计算了最大值和归一化因子,减少了HBM访问 。
FlashAttention V1:IO感知的分块注意力
3.1 核心思想
FlashAttention V1 将online softmax的思想扩展到完整的注意力计算中,实现了1-pass attention 。
定义输出的替身 o ~ i \tilde{o}_i o~i:
o ~ i = ∑ j = 1 i e x j − m i l ~ i v j \tilde{o}i = \sum{j=1}^i \frac{e^{x_j - m_i}}{\tilde{l}_i} v_j o~i=j=1∑il~iexj−mivj
递推关系推导 :
o ~ i = ∑ j = 1 i − 1 e x j − m i l ~ i v j + e x i − m i l ~ i v i = ( ∑ j = 1 i − 1 e x j − m i − 1 l ~ i − 1 v j ) l ~ i − 1 e m i − 1 − m i l ~ i + e x i − m i l ~ i v i = o ~ i − 1 l ~ i − 1 e m i − 1 − m i l ~ i + e x i − m i l ~ i v i \begin{align*} \tilde{o}i &= \sum{j=1}^{i-1} \frac{e^{x_j - m_i}}{\tilde{l}i} v_j + \frac{e^{x_i - m_i}}{\tilde{l}i} v_i \\ &= \left(\sum{j=1}^{i-1} \frac{e^{x_j - m{i-1}}}{\tilde{l}{i-1}} v_j\right) \frac{\tilde{l}{i-1} e^{m_{i-1} - m_i}}{\tilde{l}i} + \frac{e^{x_i - m_i}}{\tilde{l}i} v_i \\ &= \tilde{o}{i-1} \frac{\tilde{l}{i-1} e^{m_{i-1} - m_i}}{\tilde{l}_i} + \frac{e^{x_i - m_i}}{\tilde{l}_i} v_i \end{align*} o~i=j=1∑i−1l~iexj−mivj+l~iexi−mivi=(j=1∑i−1l~i−1exj−mi−1vj)l~il~i−1emi−1−mi+l~iexi−mivi=o~i−1l~il~i−1emi−1−mi+l~iexi−mivi
3.2 分块算法(Tiling)
FlashAttention V1 将 Q , K , V Q, K, V Q,K,V 分块加载到SRAM中计算 :
┌─────────────────────────────────────────────────────────────┐
│ FlashAttention V1 分块计算 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 外层循环: 遍历 Q 的块 (Q_i) │
│ 内层循环: 遍历 K, V 的块 (K_j, V_j) │
│ 1. 加载 Q_i, K_j 到 SRAM │
│ 2. 计算 S_ij = Q_i @ K_j.T │
│ 3. 更新 running max 和 sum exp │
│ 4. 计算局部输出并 rescale │
│ 5. 累加到 O_i │
│ │
│ 整个过程无需将中间矩阵写回HBM! │
│ │
└─────────────────────────────────────────────────────────────┘
算法伪代码 :
Algorithm 3: FlashAttention-1 Forward
-------------------------------------
将 Q, K, V 分块为 Q_i, K_j, V_j, 块大小 B_q, B_kv
for i = 1 to ceil(N/B_q): # 遍历 Q 块
加载 Q_i 到 SRAM
初始化 O_i = 0, m_i = -inf, l_i = 0
for j = 1 to ceil(N/B_kv): # 遍历 K,V 块
加载 K_j, V_j 到 SRAM
# 计算当前块的点积
S_ij = Q_i @ K_j.T # [B_q, B_kv]
# 更新 running max
m_ij = row_max(S_ij) # [B_q]
m_new = max(m_i, m_ij) # [B_q]
# 更新 running sum (online softmax)
l_new = exp(m_i - m_new) * l_i + row_sum(exp(S_ij - m_new)) # [B_q]
# 计算局部输出并 rescale
P_ij_hat = exp(S_ij - m_new) # [B_q, B_kv]
O_i = exp(m_i - m_new) * O_i + P_ij_hat @ V_j # [B_q, d]
# 更新状态
m_i = m_new
l_i = l_new
将 O_i / l_i 写回HBM (最终输出)
3.3 IO复杂度分析
| 实现方式 | HBM访问次数 | 内存占用 |
|---|---|---|
| 标准Attention | O ( N d + N 2 ) O(Nd + N^2) O(Nd+N2) | O ( N 2 ) O(N^2) O(N2) |
| FlashAttention V1 | O ( N 2 d 2 / M ) O(N^2d^2 / M) O(N2d2/M) | O ( N ) O(N) O(N) |
其中 M M M 是SRAM大小 。
性能提升 :比PyTorch标准实现快 2-4倍 。
FlashAttention V2:并行性与工作划分优化
4.1 核心改进
FlashAttention V2 在V1基础上做了三项关键优化 :
1. 减少非矩阵运算
非矩阵乘法(如Softmax、点乘)在GPU张量核心上执行速度慢16倍。V2将尽可能多的操作转换为矩阵乘法 。
2. 改进并行性
V1中,每个注意力头分配一个thread block,当序列长但batch小时,GPU利用率不足。V2将Q的循环也并行化 :
┌─────────────────────────────────────────────────────────────┐
│ FlashAttention V1 vs V2 并行策略 │
├─────────────────────────────────────────────────────────────┤
│ │
│ V1: thread block 分配 = batch_size × num_heads │
│ 每个block处理一个头的所有Q │
│ │
│ V2: thread block 分配 = batch_size × num_heads × num_q_blocks
│ 每个block处理一个头的一个Q块 │
│ │
│ 效果:更细粒度的并行,GPU利用率从25-40% → 50-73% │
│ │
└─────────────────────────────────────────────────────────────┘
3. 优化Warp间工作划分
V2重新设计了Warp间的任务分配,减少通过共享内存的数据交换 。
4.2 算法对比
| 版本 | 外层循环 | 内层循环 | 并行粒度 |
|---|---|---|---|
| V1 | 遍历Q块 | 遍历K,V块 | 每个head一个block |
| V2 | 遍历K,V块 | 遍历Q块 | 每个Q块一个block |
为什么交换循环顺序更好? :
- 减少了中间结果的rescale次数
- 更好的寄存器利用
- 更容易实现MQA/GQA支持
4.3 性能表现
| 指标 | V1 | V2 | 提升 |
|---|---|---|---|
| A100利用率 | 25-40% | 50-73% | 2倍 |
| 训练速度 | 基准 | 2倍 | 2倍 |
FlashAttention V3:Hopper架构的极致利用
5.1 Hopper新特性
NVIDIA H100 GPU引入了三项关键硬件特性 :
| 特性 | 说明 | 优势 |
|---|---|---|
| WGMMA | Warpgroup级矩阵乘加指令 | 比Ampere快2倍 |
| TMA | 张量内存加速器 | 异步数据移动,自动地址计算 |
| FP8 | 8位浮点格式 | 理论性能翻倍 |
5.2 三大技术创新
1. 异步执行与Warp专业化
FlashAttention-3采用生产者-消费者模型 :
┌─────────────────────────────────────────────────────────────┐
│ Warp专业化流水线 │
├─────────────────────────────────────────────────────────────┤
│ │
│ Warp组1 (生产者): 使用TMA从HBM加载数据 → 存入SRAM │
│ Warp组2 (消费者): 使用WGMMA执行矩阵乘法 → 写回HBM │
│ │
│ 时间线: │
│ T0: 生产者加载块0 │
│ T1: 消费者计算块0 + 生产者加载块1 (重叠) │
│ T2: 消费者计算块1 + 生产者加载块2 (重叠) │
│ │
│ 效果:隐藏TMA延迟(~200-300 cycles) │
│ │
└─────────────────────────────────────────────────────────────┘
2. 非相干处理(Incoherent Processing)
解决FP8量化误差的关键技术 :
Q ′ = Q ⋅ H , K ′ = K ⋅ H Q' = Q \cdot H, \quad K' = K \cdot H Q′=Q⋅H,K′=K⋅H
其中 H H H 是随机正交矩阵(如Hadamard矩阵)。
原理:
- 离群值被"涂抹"到所有坐标上,分布更均匀
- 量化误差降低 2.6倍
- 由于正交性, ( Q H ) ( K H ) T = Q K T (QH)(KH)^T = QK^T (QH)(KH)T=QKT,结果不变
3. 硬件感知的分块
V3根据Hopper的架构特点优化分块大小 :
| 参数 | V2 | V3 |
|---|---|---|
| 块大小(BM, BN) | 64×64 | 128×128 |
| 流水线深度 | 2-stage | 4-6 stage |
| 支持头维度 | 64,128 | 72,80,96,128,256 |
5.3 性能表现
| 精度 | V2 on H100 | V3 on H100 | 提升 |
|---|---|---|---|
| FP16/BF16 | ~370 TFLOPS | 740-840 TFLOPS | 2倍 |
| FP8 | 不原生支持 | 1.2-1.3 PFLOPS | - |
| 利用率 | ~35% | ~75-85% | 2倍+ |
FlashAttention V4:Blackwell架构的进化
6.1 Blackwell新特性
Blackwell架构(B200, SM100)带来了新的优化机会 :
- TMEM(张量内存):专门的片上内存
- 增强的异步执行:更深的流水线
- 更高的计算密度:更多张量核心
6.2 核心创新
1. 5-stage流水线
FA4实现了更细粒度的流水线划分 :
┌─────────────────────────────────────────────────────────────┐
│ FlashAttention-4 5-stage流水线 │
├─────────────────────────────────────────────────────────────┤
│ │
│ Stage 1: 数据移动 (TMA加载) │
│ Stage 2: QK矩阵乘法 (WGMMA) │
│ Stage 3: Softmax计算 (CUDA Cores) │
│ Stage 4: PV矩阵乘法 (WGMMA) │
│ Stage 5: 结果写回 (TMA存储) │
│ │
│ 所有阶段异步执行,最大化硬件利用率 │
│ │
└─────────────────────────────────────────────────────────────┘
2. 软件指数计算
FA4在CUDA核心上用软件实现指数函数,避免SFU瓶颈 :
python
# 软件实现的exp2多项式近似
def fast_exp2(x):
# 3次多项式近似,精度与硬件SFU相当
return poly3(x) # 可在任意CUDA核心并行执行
优势:SFU数量有限(每SM几个),CUDA核心众多(每SM上百),并行度大幅提升。
3. 自适应在线缩放
传统方法每当发现新的最大值就rescale,FA4引入自适应策略 :
Algorithm 4: Adaptive Online Rescaling
--------------------------------------
if new_max > old_max * threshold: # 只有显著增大才rescale
rescale_all()
else:
accumulate_without_rescale() # 继续累积
效果 :rescale操作频率降低 10倍 。
6.3 当前状态
| 特性 | 支持状态 | 说明 |
|---|---|---|
| 前向传播 | ✅ 可用 | Blackwell优化版本已提交 |
| 反向传播 | ⚠️ 开发中 | varlen、GQA支持缺失 |
| 变长序列 | ⚠️ 部分支持 | 前向支持,反向开发中 |
| GQA/MQA | ⚠️ 部分支持 | 前向支持,反向开发中 |
| 框架集成 | ⏳ 进行中 | PyTorch SDPA尚未集成 |
性能 :比cuDNN注意力快 20-22% 在Blackwell上 。
各版本对比总结
7.1 技术演进路线
┌─────────────────────────────────────────────────────────────┐
│ FlashAttention 演进路线 │
├─────────────────────────────────────────────────────────────┤
│ │
│ V1 (2022) → Ampere (A100) │
│ ├─ 核心创新:Online Softmax + Tiling │
│ ├─ 性能:2-4× vs PyTorch,内存O(N²)→O(N) │
│ └─ 利用率:25-40% │
│ │
│ V2 (2023) → Ampere/Ada │
│ ├─ 核心创新:循环交换 + 细粒度并行 │
│ ├─ 性能:2× vs V1 │
│ └─ 利用率:50-73% │
│ │
│ V3 (2024) → Hopper (H100) │
│ ├─ 核心创新:Warp专业化 + TMA + FP8 + 非相干处理 │
│ ├─ 性能:1.5-2× vs V2,FP8达1.2 PFLOPS │
│ └─ 利用率:75-85% │
│ │
│ V4 (2025) → Blackwell (B200) │
│ ├─ 核心创新:5-stage流水线 + 软件exp + 自适应缩放 │
│ ├─ 性能:+20-22% vs cuDNN │
│ └─ 状态:前向可用,反向开发中 │
│ │
└─────────────────────────────────────────────────────────────┘
7.2 关键指标对比
| 版本 | 目标架构 | 核心技巧 | 速度提升 | 利用率 | 内存节省 |
|---|---|---|---|---|---|
| V1 | A100 | Online Softmax + Tiling | 2-4× | 25-40% | 10-20× |
| V2 | A100/Ada | 循环交换 + 并行优化 | 2× (vs V1) | 50-73% | 10-20× |
| V3 | H100 | Warp专业化 + TMA + FP8 | 1.5-2× (vs V2) | 75-85% | 10-20× |
| V4 | B200 | 5-stage流水线 + 软件exp | +20-22% (vs cuDNN) | ~90% | 10-20× |
7.3 适用场景建议
| 场景 | 推荐版本 | 理由 |
|---|---|---|
| 训练长序列 | V3 (H100) / V2 (A100) | 反向传播支持完善 |
| 推理服务 | V3 + PagedAttention | 结合vLLM使用 |
| Blackwell部署 | V4 (前向) / V3 (完整) | 前向可用,反向待完善 |
| 端侧/低资源 | V2 | 兼容性好,成熟稳定 |
| 研究开发 | 最新版 | 持续跟进新特性 |
实践指南
8.1 代码示例
python
# PyTorch 2.0+ 使用SDPA自动选择最优后端
import torch.nn.functional as F
# 自动使用FlashAttention(如果可用)
output = F.scaled_dot_product_attention(
query, key, value,
attn_mask=None,
dropout_p=0.0,
is_causal=True
)
# HuggingFace Transformers 启用FlashAttention-2
from transformers import AutoModel
model = AutoModel.from_pretrained(
"meta-llama/Llama-2-7b-hf",
attn_implementation="flash_attention_2" # 或 "sdpa"
)
8.2 硬件兼容性
| 版本 | 最低GPU架构 | 推荐GPU | CUDA版本 |
|---|---|---|---|
| V1 | Volta (V100) | Ampere (A100) | 11.0+ |
| V2 | Volta (V100) | Ampere/Ada | 11.4+ |
| V3 | Ampere (A100) | Hopper (H100) | 12.0+ |
| V4 | Hopper (H100) | Blackwell (B200) | 12.8+ |
总结
9.1 核心要点
┌─────────────────────────────────────────────────────────────┐
│ FlashAttention 核心要点 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1️⃣ 设计哲学 │
│ • IO-Awareness:减少HBM访问,利用SRAM │
│ • 算法-硬件协同设计:针对特定架构优化 │
│ │
│ 2️⃣ 数学基础 │
│ • Online Softmax:2-pass → 1-pass attention │
│ • 分块累加 + 动态rescale │
│ │
│ 3️⃣ 演进路线 │
│ • V1:证明可行性,奠定基础 │
│ • V2:优化并行,提升利用率 │
│ • V3:利用Hopper新特性,突破性能瓶颈 │
│ • V4:Blackwell优化,探索新方向 │
│ │
│ 4️⃣ 实际收益 │
│ • 速度:累计提升8-10倍 vs 标准实现 │
│ • 内存:从O(N²)到O(N),支持百万级上下文 │
│ • 精度:完全无损,与标准attention等价 │
│ │
└─────────────────────────────────────────────────────────────┘
9.2 一句话总结
FlashAttention系列通过算法-硬件协同设计,将注意力计算从内存瓶颈转变为计算瓶颈,让百万级上下文成为现实,是Transformer加速领域最重要的里程碑之一! 🚀
参考文献
- Dao et al. "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (2022)
- Dao. "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (2023)
- Shah et al. "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision" (2024)
- "FlashAttention 4: Faster, Memory-Efficient Attention for LLMs" DigitalOcean (2026)
- "From Online Softmax to FlashAttention" University of Washington
本文为原创内容,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。