目录
- [第一篇:FlashAttention --- IO 感知的精确注意力](#第一篇:FlashAttention — IO 感知的精确注意力)
- [第二篇:FlashAttention-2 --- 并行优化与因果掩码](#第二篇:FlashAttention-2 — 并行优化与因果掩码)
- [第三篇:FlashAttention-3 --- Hopper 架构异步流水线](#第三篇:FlashAttention-3 — Hopper 架构异步流水线)
- 参考文献
第一篇:FlashAttention --- IO 感知的精确注意力
1. 引言
标准自注意力的计算复杂度为 O ( L 2 d ) O(L^2 d) O(L2d),但实践中真正的瓶颈往往不是浮点运算量(FLOPs),而是内存 IO。GPU 的计算能力(FLOPS)增长远快于显存带宽(GB/s),导致注意力计算被内存访问所制约。
FlashAttention (Dao et al., 2022)从一个全新的视角审视注意力:不优化 FLOPs,而是优化内存 IO 。通过分块计算(tiling)和在线 softmax 技巧,在不引入任何近似的情况下,将注意力的 HBM 访问量从 O ( L 2 ) O(L^2) O(L2) 降低到 O ( L 2 d 2 / M ) O(L^2 d^2 / M) O(L2d2/M),其中 M M M 是 SRAM 容量。在实际 GPU 上,这带来了 2-4× 的端到端加速 和 5-20× 的显存节省。
2. 理论基础 --- GPU 内存层次与 IO 复杂度
2.1 GPU 内存层次结构
现代 GPU(如 A100)的内存层次:
| 层级 | 容量 | 带宽 | 延迟 |
|---|---|---|---|
| HBM(高带宽显存) | 40-80 GB | 1.5-3.35 TB/s | ~数百周期 |
| SRAM(共享内存/L1) | 每 SM 164-228 KB | ~19 TB/s | ~数周期 |
| 寄存器 | 每 SM 256 KB | ~数百 TB/s | 1 周期 |
关键洞察 :SRAM 比 HBM 快约 10-20×,但容量小约 1000×。标准注意力实现需要将 O ( L 2 ) O(L^2) O(L2) 大小的注意力矩阵在 HBM 和 SRAM 之间反复传输,造成严重的 IO 瓶颈。
2.2 IO 复杂度理论
IO 复杂度衡量算法执行期间在不同内存层级之间的数据传输量。
对于矩阵乘法 C = A B \mathbf{C} = \mathbf{A} \mathbf{B} C=AB,其中 A ∈ R M × K \mathbf{A} \in \mathbb{R}^{M \times K} A∈RM×K, B ∈ R K × N \mathbf{B} \in \mathbb{R}^{K \times N} B∈RK×N:
朴素实现 :每次计算 C i j = ∑ k A i k B k j \mathbf{C}{ij} = \sum_k \mathbf{A}{ik} \mathbf{B}_{kj} Cij=∑kAikBkj,需要从 HBM 读取整行/整列。
分块实现 :将矩阵分成 B r × B c B_r \times B_c Br×Bc 大小的块,每次将两个块加载到 SRAM,在 SRAM 中完成子矩阵乘法。
IO 复杂度下界(Hong & Kung, 1981):
IO ≥ M N K M SRAM \text{IO} \geq \frac{MNK}{M_{\text{SRAM}}} IO≥MSRAMMNK
其中 M SRAM M_{\text{SRAM}} MSRAM 是 SRAM 容量。分块矩阵乘法可以达到这个下界。
2.3 标准注意力的 IO 分析
标准注意力的计算步骤:
S = Q K T ∈ R L × L (写入 HBM) \mathbf{S} = \mathbf{Q} \mathbf{K}^T \in \mathbb{R}^{L \times L} \quad \text{(写入 HBM)} S=QKT∈RL×L(写入 HBM)
P = softmax ( S ) ∈ R L × L (读取 + 写入 HBM) \mathbf{P} = \text{softmax}(\mathbf{S}) \in \mathbb{R}^{L \times L} \quad \text{(读取 + 写入 HBM)} P=softmax(S)∈RL×L(读取 + 写入 HBM)
O = P V ∈ R L × d (读取 HBM) \mathbf{O} = \mathbf{P} \mathbf{V} \in \mathbb{R}^{L \times d} \quad \text{(读取 HBM)} O=PV∈RL×d(读取 HBM)
IO 分析:
| 步骤 | HBM 读 | HBM 写 | 说明 |
|---|---|---|---|
| 计算 S \mathbf{S} S | O ( L d ) O(Ld) O(Ld) | O ( L 2 ) O(L^2) O(L2) | 读 Q,K;写 S |
| 计算 P \mathbf{P} P | O ( L 2 ) O(L^2) O(L2) | O ( L 2 ) O(L^2) O(L2) | 读 S;写 P |
| 计算 O \mathbf{O} O | O ( L 2 + L d ) O(L^2 + Ld) O(L2+Ld) | O ( L d ) O(Ld) O(Ld) | 读 P,V;写 O |
| 总计 | O ( L 2 + L d ) O(L^2 + Ld) O(L2+Ld) | O ( L 2 ) O(L^2) O(L2) | --- |
当 L ≫ d L \gg d L≫d 时(长序列),IO 复杂度为 O ( L 2 ) O(L^2) O(L2),由注意力矩阵 S , P \mathbf{S}, \mathbf{P} S,P 的读写主导。
问题 : L 2 L^2 L2 大小的注意力矩阵必须物化(materialized)到 HBM 中,这既是 IO 瓶颈,也是显存瓶颈。
3. FlashAttention 的核心算法
3.1 核心思想:分块 + 在线 Softmax
FlashAttention 的关键洞察:不需要将完整的注意力矩阵 S \mathbf{S} S 和 P \mathbf{P} P 物化到 HBM 中。
通过分块计算,在 SRAM 中逐块完成 softmax 和矩阵乘法,只将最终结果 O \mathbf{O} O 写回 HBM。
3.2 在线 Softmax 算法
标准 softmax 需要两次遍历数据:
softmax ( x ) i = e x i ∑ j e x j \text{softmax}(\mathbf{x})_i = \frac{e^{x_i}}{\sum_j e^{x_j}} softmax(x)i=∑jexjexi
第一次遍历计算 max ( x ) \max(\mathbf{x}) max(x) 和 ∑ j e x j − max \sum_j e^{x_j - \max} ∑jexj−max,第二次计算归一化值。
在线 softmax(Milakov & Gimelshein, 2018)只需一次遍历:
维护运行最大值 m k m_k mk 和运行和 l k l_k lk:
m k = max ( m k − 1 , x k ) m_k = \max(m_{k-1}, x_k) mk=max(mk−1,xk)
l k = e m k − 1 − m k l k − 1 + e x k − m k l_k = e^{m_{k-1} - m_k} l_{k-1} + e^{x_k - m_k} lk=emk−1−mklk−1+exk−mk
softmax ( x ) k = e x k − m k l k \text{softmax}(\mathbf{x})_k = \frac{e^{x_k - m_k}}{l_k} softmax(x)k=lkexk−mk
关键性质:当新数据到来时,可以增量更新,无需重新访问历史数据。
3.3 FlashAttention 算法详解
输入 : Q , K , V ∈ R L × d \mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{L \times d} Q,K,V∈RL×d,SRAM 容量 M M M
分块策略 :将 Q \mathbf{Q} Q 分成 T r = ⌈ L / B r ⌉ T_r = \lceil L / B_r \rceil Tr=⌈L/Br⌉ 个块, K , V \mathbf{K}, \mathbf{V} K,V 分成 T c = ⌈ L / B c ⌉ T_c = \lceil L / B_c \rceil Tc=⌈L/Bc⌉ 个块。
算法:
初始化: O = 0_{L×d}, l = 0_L, m = -∞_L
for j = 1 to T_c: // 遍历 K,V 的块
从 HBM 加载 K_j, V_j 到 SRAM // 块大小: B_c × d
for i = 1 to T_r: // 遍历 Q 的块
从 HBM 加载 Q_i, O_i, l_i, m_i 到 SRAM
// 在 SRAM 中计算
S_ij = Q_i · K_j^T // (B_r × B_c), 在 SRAM 中
m̃_ij = rowmax(S_ij) // (B_r,)
m_new = max(m_i, m̃_ij) // 更新最大值
P̃_ij = exp(S_ij - m̃_ij) // 未归一化的注意力权重
l̃_ij = rowsum(P̃_ij) // (B_r,)
// 修正之前的累积值
l_new = exp(m_i - m_new) · l_i + exp(m̃_ij - m_new) · l̃_ij
// 更新输出
O_i = diag(exp(m_i - m_new))^{-1} · (
l_i · O_i + exp(m̃_ij - m_new) · P̃_ij · V_j
) / l_new
更新 m_i = m_new, l_i = l_new
将 O_i, l_i, m_i 写回 HBM // 块大小: B_r × d + 2 B_r
3.4 算法的数学正确性
关键等式 :在处理完前 j j j 个 KV 块后, O i \mathbf{O}_i Oi 等于:
O i ( j ) = ∑ k = 1 j P ~ i k V k ∑ k = 1 j l ~ i k \mathbf{O}i^{(j)} = \frac{\sum{k=1}^{j} \tilde{\mathbf{P}}{ik} \mathbf{V}k}{\sum{k=1}^{j} \tilde{\mathbf{l}}{ik}} Oi(j)=∑k=1jl~ik∑k=1jP~ikVk
其中 P ~ i k = exp ( S i k − m ( j ) ) \tilde{\mathbf{P}}{ik} = \exp(\mathbf{S}{ik} - m^{(j)}) P~ik=exp(Sik−m(j)) 是用最终最大值 m ( j ) m^{(j)} m(j) 修正后的未归一化权重。
证明(归纳法):
基础 : j = 1 j = 1 j=1 时,算法正确计算了第一块的 softmax 加权输出。
归纳 :假设前 j − 1 j-1 j−1 块正确。处理第 j j j 块时:
新的最大值 m ( j ) = max ( m ( j − 1 ) , m ~ i j ) m^{(j)} = \max(m^{(j-1)}, \tilde{m}_{ij}) m(j)=max(m(j−1),m~ij)。
旧累积值需要乘以修正因子 exp ( m ( j − 1 ) − m ( j ) ) \exp(m^{(j-1)} - m^{(j)}) exp(m(j−1)−m(j)):
l ( j ) = e m ( j − 1 ) − m ( j ) l ( j − 1 ) + e m ~ i j − m ( j ) l ~ i j l^{(j)} = e^{m^{(j-1)} - m^{(j)}} l^{(j-1)} + e^{\tilde{m}{ij} - m^{(j)}} \tilde{l}{ij} l(j)=em(j−1)−m(j)l(j−1)+em~ij−m(j)l~ij
O ( j ) = e m ( j − 1 ) − m ( j ) l ( j − 1 ) O ( j − 1 ) + e m ~ i j − m ( j ) P ~ i j V j l ( j ) \mathbf{O}^{(j)} = \frac{e^{m^{(j-1)} - m^{(j)}} l^{(j-1)} \mathbf{O}^{(j-1)} + e^{\tilde{m}{ij} - m^{(j)}} \tilde{\mathbf{P}}{ij} \mathbf{V}_j}{l^{(j)}} O(j)=l(j)em(j−1)−m(j)l(j−1)O(j−1)+em~ij−m(j)P~ijVj
这恰好等于从头计算的 softmax 加权输出。 ■ \blacksquare ■
3.5 IO 复杂度分析
定理(Dao et al., 2022):FlashAttention 的 HBM IO 复杂度为:
IO = O ( L 2 d 2 M ) \text{IO} = O\left(\frac{L^2 d^2}{M}\right) IO=O(ML2d2)
其中 M M M 是 SRAM 容量。
证明:
- 外层循环 T c = L / B c T_c = L / B_c Tc=L/Bc 次
- 内层循环 T r = L / B r T_r = L / B_r Tr=L/Br 次
- 每次内层循环:读取 Q i \mathbf{Q}_i Qi( B r d B_r d Brd)、 O i , l i , m i \mathbf{O}_i, \mathbf{l}_i, \mathbf{m}_i Oi,li,mi( B r d + 2 B r B_r d + 2B_r Brd+2Br),写回( B r d + 2 B r B_r d + 2B_r Brd+2Br)
- 每次外层循环:读取 K j , V j \mathbf{K}_j, \mathbf{V}_j Kj,Vj( 2 B c d 2 B_c d 2Bcd)
总 IO:
IO = T c ⋅ T r ⋅ O ( B r d ) + T c ⋅ O ( B c d ) \text{IO} = T_c \cdot T_r \cdot O(B_r d) + T_c \cdot O(B_c d) IO=Tc⋅Tr⋅O(Brd)+Tc⋅O(Bcd)
= L B c ⋅ L B r ⋅ O ( B r d ) + L B c ⋅ O ( B c d ) = \frac{L}{B_c} \cdot \frac{L}{B_r} \cdot O(B_r d) + \frac{L}{B_c} \cdot O(B_c d) =BcL⋅BrL⋅O(Brd)+BcL⋅O(Bcd)
= O ( L 2 d B c ) + O ( L d ) = O\left(\frac{L^2 d}{B_c}\right) + O(Ld) =O(BcL2d)+O(Ld)
SRAM 约束: B r d + B c d + B r B c ≤ M B_r d + B_c d + B_r B_c \leq M Brd+Bcd+BrBc≤M,取 B r = B c = M / ( 2 d ) B_r = B_c = \sqrt{M / (2d)} Br=Bc=M/(2d) :
IO = O ( L 2 d M / d ) = O ( L 2 d 3 / 2 M ) \text{IO} = O\left(\frac{L^2 d}{\sqrt{M/d}}\right) = O\left(\frac{L^2 d^{3/2}}{\sqrt{M}}\right) IO=O(M/d L2d)=O(M L2d3/2)
在 d = 128 , M = 192 KB d = 128, M = 192 \text{KB} d=128,M=192KB 的典型设置下,IO 减少约 L 2 d / M ≈ L 2 / 1500 L^2 d / M \approx L^2 / 1500 L2d/M≈L2/1500 倍。 ■ \blacksquare ■
4. 反向传播与显存优化
4.1 问题:注意力矩阵不被存储
标准反向传播需要前向传播中存储的注意力矩阵 P \mathbf{P} P 来计算梯度:
∂ L ∂ V = P T ∂ L ∂ O \frac{\partial \mathcal{L}}{\partial \mathbf{V}} = \mathbf{P}^T \frac{\partial \mathcal{L}}{\partial \mathbf{O}} ∂V∂L=PT∂O∂L
∂ L ∂ P = ∂ L ∂ O V T \frac{\partial \mathcal{L}}{\partial \mathbf{P}} = \frac{\partial \mathcal{L}}{\partial \mathbf{O}} \mathbf{V}^T ∂P∂L=∂O∂LVT
∂ L ∂ S = ∂ L ∂ P ⊙ P − P ⊙ ( ∑ j ∂ L ∂ P i j P i j ) \frac{\partial \mathcal{L}}{\partial \mathbf{S}} = \frac{\partial \mathcal{L}}{\partial \mathbf{P}} \odot \mathbf{P} - \mathbf{P} \odot \left(\sum_j \frac{\partial \mathcal{L}}{\partial \mathbf{P}{ij}} \mathbf{P}{ij}\right) ∂S∂L=∂P∂L⊙P−P⊙(j∑∂Pij∂LPij)
FlashAttention 不存储 P \mathbf{P} P,因此需要重新计算。
4.2 重计算策略
FlashAttention 在反向传播时,利用前向传播中存储的以下信息重新计算 P \mathbf{P} P:
- O \mathbf{O} O:最终输出
- l \mathbf{l} l:softmax 的归一化因子
- m \mathbf{m} m:每行的最大值
重新计算 P \mathbf{P} P:
S i j = Q i K j T \mathbf{S}_{ij} = \mathbf{Q}_i \mathbf{K}_j^T Sij=QiKjT
P i j = exp ( S i j − m i ) l i \mathbf{P}{ij} = \frac{\exp(\mathbf{S}{ij} - \mathbf{m}_i)}{\mathbf{l}_i} Pij=liexp(Sij−mi)
额外计算量 :前向传播约 2 L 2 d 2L^2 d 2L2d FLOPs,反向传播约 4 L 2 d 4L^2 d 4L2d FLOPs(重新计算 S \mathbf{S} S 和 P \mathbf{P} P)。总 FLOPs 增加约 50-75%。
但:由于 IO 大幅减少,实际墙钟时间反而更快。
4.3 显存节省
| 方法 | 需存储的注意力相关数据 |
|---|---|
| 标准注意力 | S \mathbf{S} S ( L 2 L^2 L2) + P \mathbf{P} P ( L 2 L^2 L2) |
| FlashAttention | O \mathbf{O} O ( L d Ld Ld) + l \mathbf{l} l ( L L L) + m \mathbf{m} m ( L L L) |
显存节省: O ( L 2 ) → O ( L d ) O(L^2) \to O(Ld) O(L2)→O(Ld),当 L ≫ d L \gg d L≫d 时节省巨大。
5. 完整可运行实现
5.1 FlashAttention 前向传播
python
"""
FlashAttention --- 完整可运行实现 (教学版本)
依赖: torch >= 2.0, numpy, matplotlib
"""
import torch
import torch.nn.functional as F
import math
import numpy as np
from typing import Tuple
def standard_attention(
Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor
) -> torch.Tensor:
"""
标准注意力实现 (用于验证正确性)
Q, K, V: (B, H, L, d)
"""
d = Q.shape[-1]
S = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d) # (B, H, L, L)
P = torch.softmax(S, dim=-1) # (B, H, L, L)
O = torch.matmul(P, V) # (B, H, L, d)
return O
def flash_attention_forward(
Q: torch.Tensor, # (B, H, L, d)
K: torch.Tensor, # (B, H, L, d)
V: torch.Tensor, # (B, H, L, d)
block_size_r: int = 32,
block_size_c: int = 32,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
FlashAttention 前向传播 (在线 softmax 版本)
返回: O (输出), l (归一化因子), m (行最大值)
"""
B, H, L, d = Q.shape
scale = 1.0 / math.sqrt(d)
# 初始化输出和统计量
O = torch.zeros(B, H, L, d, device=Q.device, dtype=Q.dtype)
l = torch.zeros(B, H, L, device=Q.device, dtype=Q.dtype)
m = torch.full((B, H, L), float("-inf"), device=Q.device, dtype=Q.dtype)
# 分块
num_blocks_r = math.ceil(L / block_size_r)
num_blocks_c = math.ceil(L / block_size_c)
for j in range(num_blocks_c):
# 加载 K_j, V_j 块
j_start = j * block_size_c
j_end = min(j_start + block_size_c, L)
Kj = K[:, :, j_start:j_end, :] # (B, H, B_c, d)
Vj = V[:, :, j_start:j_end, :]
for i in range(num_blocks_r):
# 加载 Q_i, O_i, l_i, m_i
i_start = i * block_size_r
i_end = min(i_start + block_size_r, L)
Qi = Q[:, :, i_start:i_end, :] # (B, H, B_r, d)
Oi = O[:, :, i_start:i_end, :]
li = l[:, :, i_start:i_end]
mi = m[:, :, i_start:i_end]
# 计算 S_ij = Q_i @ K_j^T / sqrt(d)
Sij = torch.matmul(Qi, Kj.transpose(-2, -1)) * scale # (B, H, B_r, B_c)
# 当前行最大值
mij_tilde = Sij.max(dim=-1).values # (B, H, B_r)
# 新的最大值
mi_new = torch.max(mi, mij_tilde) # (B, H, B_r)
# 未归一化的注意力权重
Pij_tilde = torch.exp(Sij - mij_tilde.unsqueeze(-1)) # (B, H, B_r, B_c)
lij_tilde = Pij_tilde.sum(dim=-1) # (B, H, B_r)
# 修正因子
alpha = torch.exp(mi - mi_new) # (B, H, B_r)
beta = torch.exp(mij_tilde - mi_new) # (B, H, B_r)
# 更新归一化因子
li_new = alpha * li + beta * lij_tilde # (B, H, B_r)
# 更新输出
Oi_new = (
alpha.unsqueeze(-1) * li.unsqueeze(-1) * Oi
+ beta.unsqueeze(-1) * torch.matmul(Pij_tilde, Vj)
) / li_new.unsqueeze(-1)
# 写回
O[:, :, i_start:i_end, :] = Oi_new
l[:, :, i_start:i_end] = li_new
m[:, :, i_start:i_end] = mi_new
return O, l, m
def verify_flash_attention():
"""验证 FlashAttention 与标准注意力的等价性"""
torch.manual_seed(42)
B, H, L, d = 2, 4, 128, 64
Q = torch.randn(B, H, L, d)
K = torch.randn(B, H, L, d)
V = torch.randn(B, H, L, d)
# 标准注意力
O_std = standard_attention(Q, K, V)
# FlashAttention
O_flash, l, m = flash_attention_forward(Q, K, V, block_size_r=32, block_size_c=32)
max_diff = (O_std - O_flash).abs().max().item()
print(f"FlashAttention 等价性验证:")
print(f" 最大绝对误差: {max_diff:.6e}")
print(f" 数学等价: {max_diff < 1e-5}")
return max_diff < 1e-5
5.2 IO 复杂度对比实验
python
def compare_io_complexity():
"""对比标准注意力与 FlashAttention 的 IO 复杂度"""
import time
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seq_lengths = [512, 1024, 2048, 4096]
d = 64
B, H = 2, 8
print("注意力 IO 对比 (B=2, H=8, d=64)")
print("=" * 60)
print(f"{'序列长度':>10} | {'标准注意力 (ms)':>18} | {'FlashAttention (ms)':>20} | {'加速比':>8}")
print("-" * 60)
for L in seq_lengths:
Q = torch.randn(B, H, L, d, device=device)
K = torch.randn(B, H, L, d, device=device)
V = torch.randn(B, H, L, d, device=device)
# 预热
for _ in range(3):
_ = standard_attention(Q, K, V)
if device.type == "cuda":
torch.cuda.synchronize()
# 标准注意力
if device.type == "cuda":
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(10):
_ = standard_attention(Q, K, V)
if device.type == "cuda":
torch.cuda.synchronize()
t_std = (time.perf_counter() - t0) / 10 * 1000
# FlashAttention (使用 PyTorch 内置实现)
if hasattr(F, "scaled_dot_product_attention"):
for _ in range(3):
_ = F.scaled_dot_product_attention(Q, K, V)
if device.type == "cuda":
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(10):
_ = F.scaled_dot_product_attention(Q, K, V)
if device.type == "cuda":
torch.cuda.synchronize()
t_flash = (time.perf_counter() - t0) / 10 * 1000
else:
t_flash = t_std # fallback
speedup = t_std / t_flash if t_flash > 0 else 1.0
print(f"{L:>10} | {t_std:>18.2f} | {t_flash:>20.2f} | {speedup:>7.2f}×")
5.3 显存占用对比
python
def compare_memory_usage():
"""对比标准注意力与 FlashAttention 的显存占用"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
L = 4096
d = 64
B, H = 2, 8
bytes_per_elem = 4 # FP32
# 标准注意力需要存储 S 和 P
std_memory = B * H * L * L * 2 * bytes_per_elem # S + P
# FlashAttention 只需存储 O, l, m
flash_memory = B * H * (L * d + L + L) * bytes_per_elem
print(f"显存对比 (L={L}, d={d}, B={B}, H={H}):")
print(f" 标准注意力: {std_memory / 1024**2:.1f} MB (存储 S + P)")
print(f" FlashAttention: {flash_memory / 1024**2:.2f} MB (存储 O + l + m)")
print(f" 显存节省: {std_memory / flash_memory:.0f}×")
6. FlashAttention 的理论意义
6.1 IO 复杂度 vs 计算复杂度
| 指标 | 标准注意力 | FlashAttention |
|---|---|---|
| FLOPs | O ( L 2 d ) O(L^2 d) O(L2d) | O ( L 2 d ) O(L^2 d) O(L2d)(相同) |
| HBM IO | O ( L 2 + L d ) O(L^2 + Ld) O(L2+Ld) | O ( L 2 d 2 / M ) O(L^2 d^2 / M) O(L2d2/M) |
| 显存 | O ( L 2 + L d ) O(L^2 + Ld) O(L2+Ld) | O ( L + L d ) O(L + Ld) O(L+Ld) |
FlashAttention 的 FLOPs 不变,但 IO 大幅减少。在 GPU 上,IO 是实际瓶颈,因此减少 IO = 更快。
6.2 算法-硬件协同设计的典范
FlashAttention 的成功证明了一个重要的设计哲学:
不要只优化理论 FLOPs,要优化实际墙钟时间。
这需要深入理解硬件架构(内存层次、并行度、带宽),并将这些约束融入算法设计。
7. FlashAttention 数学公式总结
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║ FlashAttention 数学总结 ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║ ║
║ 1. 标准注意力: ║
║ S = Q·K^T / √d (L×L, 写入 HBM) ║
║ P = softmax(S) (L×L, 读写 HBM) ║
║ O = P·V (L×d, 读 HBM) ║
║ IO: O(L² + Ld) ║
║ ║
║ 2. 在线 Softmax 增量更新: ║
║ m_new = max(m_old, m̃_j) ║
║ l_new = e^{m_old-m_new}·l_old + e^{m̃_j-m_new}·l̃_j ║
║ O_new = (e^{m_old-m_new}·l_old·O_old + e^{m̃_j-m_new}·P̃_j·V_j) / l_new ║
║ 数学等价于从头计算的 softmax ║
║ ║
║ 3. FlashAttention IO 复杂度: ║
║ IO = O(L²·d² / M) (M = SRAM 容量) ║
║ 当 d=128, M=192KB 时, IO 减少 ~L²/1500 倍 ║
║ ║
║ 4. 显存节省: ║
║ 标准: O(L²) (存储 S + P) ║
║ Flash: O(Ld) (存储 O + l + m) ║
║ 节省: O(L/d) 倍 ║
║ ║
║ 5. 反向传播: ║
║ 不存储 P, 利用 (O, l, m) 重新计算 ║
║ FLOPs 增加 ~50-75%, 但 IO 减少 → 总时间更快 ║
║ ║
║ 6. 分块约束: ║
║ B_r·d + B_c·d + B_r·B_c ≤ M (SRAM 容量) ║
║ 最优: B_r = B_c = √(M/(2d)) ║
║ ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝
第二篇:FlashAttention-2 --- 并行优化与因果掩码
1. 引言
FlashAttention-2(Dao, 2023)在 FlashAttention 的基础上,通过三个关键优化进一步提升了 GPU 利用率:
- 减少非矩阵乘法运算:将 softmax 的缩放和修正操作融合到矩阵乘法中
- 序列长度维度并行:在 Q 的块之间增加并行度
- 优化因果掩码:利用因果注意力的下三角结构减少无效计算
2. FlashAttention-2 的核心优化
2.1 减少非 matmul 运算
在 FlashAttention-1 中,每个内层循环需要执行:
python
# 修正因子计算 (非 matmul)
alpha = torch.exp(mi - mi_new)
beta = torch.exp(mij_tilde - mi_new)
li_new = alpha * li + beta * lij_tilde
# 输出修正 (非 matmul)
Oi_new = (alpha * li * Oi + beta * Pij_tilde @ Vj) / li_new
这些逐元素操作(element-wise)虽然 FLOPs 少,但在 GPU 上效率低------GPU 的 Tensor Core 专门优化矩阵乘法,逐元素操作无法充分利用。
FlashAttention-2 的优化:将缩放操作融合到矩阵乘法中:
O i ( j ) = diag ( l ~ i ( j ) ) − 1 ( e m i ( j − 1 ) − m i ( j ) l i ( j − 1 ) O i ( j − 1 ) + P ~ i j V j ) \mathbf{O}_i^{(j)} = \text{diag}(\tilde{l}_i^{(j)})^{-1} \left( e^{m_i^{(j-1)} - m_i^{(j)}} l_i^{(j-1)} \mathbf{O}i^{(j-1)} + \tilde{\mathbf{P}}{ij} \mathbf{V}_j \right) Oi(j)=diag(l~i(j))−1(emi(j−1)−mi(j)li(j−1)Oi(j−1)+P~ijVj)
将 e m i ( j − 1 ) − m i ( j ) l i ( j − 1 ) e^{m_i^{(j-1)} - m_i^{(j)}} l_i^{(j-1)} emi(j−1)−mi(j)li(j−1) 作为标量乘法融合到最后的输出缩放中。
2.2 序列长度维度并行
FlashAttention-1 的并行策略:
外层: 遍历 K,V 块 (串行)
内层: 遍历 Q 块 (并行)
问题:外层循环是串行的(因为在线 softmax 需要按序处理 KV 块),限制了 GPU 的并行度。
FlashAttention-2 的优化:将 Q 的块分配给不同的 thread block,KV 块的遍历在每个 thread block 内串行执行:
Thread Block i: 处理 Q_i, 串行遍历所有 KV 块
Thread Block j: 处理 Q_j, 串行遍历所有 KV 块
...
每个 thread block 独立计算 O i \mathbf{O}_i Oi,最后合并(reduction)。这使得并行度从 T c T_c Tc(KV 块数)增加到 T r T_r Tr(Q 块数),通常 T r ≥ T c T_r \geq T_c Tr≥Tc。
2.3 因果掩码优化
因果注意力(Causal Attention)的掩码:
S i j = { Q i K j T / d if i ≥ j − ∞ if i < j \mathbf{S}_{ij} = \begin{cases} \mathbf{Q}_i \mathbf{K}_j^T / \sqrt{d} & \text{if } i \geq j \\ -\infty & \text{if } i < j \end{cases} Sij={QiKjT/d −∞if i≥jif i<j
在标准实现中, i < j i < j i<j 的块仍然被计算但被掩码为 − ∞ -\infty −∞,浪费计算。
FlashAttention-2 的优化 :对于每个 Q 块 i i i,只处理 j ≤ i j \leq i j≤i 的 KV 块:
for i = 1 to T_r:
for j = 1 to i: // 只处理 j ≤ i 的块
计算 S_ij, P_ij, 更新 O_i
进一步优化 :对于 j = i j = i j=i 的对角块,只有下三角部分有效。FlashAttention-2 对对角块使用特殊的内核,只计算下三角部分。
2.4 FlashAttention-2 的性能提升
在 A100 GPU 上的实测:
| 模型 | FlashAttention-1 | FlashAttention-2 | 提升 |
|---|---|---|---|
| GPT-J (2K) | 170 TFLOPS | 230 TFLOPS | 1.35× |
| GPT-J (4K) | 180 TFLOPS | 240 TFLOPS | 1.33× |
| 长序列 (8K) | 190 TFLOPS | 250 TFLOPS | 1.32× |
FlashAttention-2 达到了 A100 理论 FLOPS 的 ~73%(相比 FlashAttention-1 的 ~55%)。
3. 完整可运行实现
3.1 FlashAttention-2 因果掩码实现
python
def flash_attention_causal(
Q: torch.Tensor, # (B, H, L, d)
K: torch.Tensor, # (B, H, L, d)
V: torch.Tensor, # (B, H, L, d)
block_size: int = 64,
) -> torch.Tensor:
"""
FlashAttention-2 因果掩码版本
优化: 只计算 j ≤ i 的块, 对角块只计算下三角
"""
B, H, L, d = Q.shape
scale = 1.0 / math.sqrt(d)
O = torch.zeros(B, H, L, d, device=Q.device, dtype=Q.dtype)
l = torch.zeros(B, H, L, device=Q.device, dtype=Q.dtype)
m = torch.full((B, H, L), float("-inf"), device=Q.device, dtype=Q.dtype)
num_blocks = math.ceil(L / block_size)
for i in range(num_blocks):
i_start = i * block_size
i_end = min(i_start + block_size, L)
Qi = Q[:, :, i_start:i_end, :]
Oi = O[:, :, i_start:i_end, :]
li = l[:, :, i_start:i_end]
mi = m[:, :, i_start:i_end]
# 因果掩码: 只处理 j ≤ i 的块
for j in range(i + 1):
j_start = j * block_size
j_end = min(j_start + block_size, L)
Kj = K[:, :, j_start:j_end, :]
Vj = V[:, :, j_start:j_end, :]
# 计算 S_ij
Sij = torch.matmul(Qi, Kj.transpose(-2, -1)) * scale
# 因果掩码: 对角块只保留下三角
if i == j:
mask = torch.triu(
torch.ones(block_size, block_size, device=Q.device), diagonal=1
).bool()
Sij = Sij.masked_fill(mask.unsqueeze(0).unsqueeze(0), float("-inf"))
# 在线 softmax 更新
mij_tilde = Sij.max(dim=-1).values
mi_new = torch.max(mi, mij_tilde)
Pij_tilde = torch.exp(Sij - mij_tilde.unsqueeze(-1))
lij_tilde = Pij_tilde.sum(dim=-1)
alpha = torch.exp(mi - mi_new)
beta = torch.exp(mij_tilde - mi_new)
li_new = alpha * li + beta * lij_tilde
Oi_new = (
alpha.unsqueeze(-1) * li.unsqueeze(-1) * Oi
+ beta.unsqueeze(-1) * torch.matmul(Pij_tilde, Vj)
) / li_new.unsqueeze(-1)
Oi = Oi_new
li = li_new
mi = mi_new
O[:, :, i_start:i_end, :] = Oi
l[:, :, i_start:i_end] = li
m[:, :, i_start:i_end] = mi
return O
3.2 因果注意力等价性验证
python
def verify_causal_flash_attention():
"""验证因果 FlashAttention 与标准因果注意力的等价性"""
torch.manual_seed(42)
B, H, L, d = 2, 4, 128, 64
Q = torch.randn(B, H, L, d)
K = torch.randn(B, H, L, d)
V = torch.randn(B, H, L, d)
# 标准因果注意力
S = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d)
mask = torch.triu(torch.ones(L, L), diagonal=1).bool()
S = S.masked_fill(mask.unsqueeze(0).unsqueeze(0), float("-inf"))
P = torch.softmax(S, dim=-1)
O_std = torch.matmul(P, V)
# FlashAttention 因果版本
O_flash = flash_attention_causal(Q, K, V, block_size=32)
max_diff = (O_std - O_flash).abs().max().item()
print(f"因果 FlashAttention 等价性验证:")
print(f" 最大绝对误差: {max_diff:.6e}")
print(f" 数学等价: {max_diff < 1e-5}")
return max_diff < 1e-5
4. FlashAttention-2 数学公式总结
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║ FlashAttention-2 数学总结 ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║ ║
║ 1. 减少非 matmul 运算: ║
║ 将缩放操作融合到 matmul 中, Tensor Core 利用率从 ~55% 提升到 ~73% ║
║ ║
║ 2. 序列长度维度并行: ║
║ FlashAttention-1: Q 块内层并行, KV 块外层串行 ║
║ FlashAttention-2: Q 块并行 (不同 thread block), KV 块内部串行 ║
║ 并行度: T_c → T_r (通常 T_r ≥ T_c) ║
║ ║
║ 3. 因果掩码优化: ║
║ 只计算 j ≤ i 的块 (跳过上三角) ║
║ 对角块: 只计算下三角部分 ║
║ 计算量减少约 50% ║
║ ║
║ 4. 性能: ║
║ A100 上达到 ~250 TFLOPS (理论峰值的 ~73%) ║
║ 相比 FlashAttention-1 提升 ~33% ║
║ ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝
第三篇:FlashAttention-3 --- Hopper 架构异步流水线
1. 引言
FlashAttention-3(Shah et al., 2024)针对 NVIDIA Hopper 架构(H100/H200)的硬件特性进行了深度优化:
- 异步 WGMMA 指令:利用 Hopper 的异步矩阵乘法指令,实现计算与数据搬运的重叠
- FP8 低精度支持:利用 Hopper 的 FP8 Tensor Core,实现 2× 的计算吞吐
- 块量化(Block Quantization):在保持精度的同时实现 FP8 注意力
2. Hopper 架构的新特性
2.1 异步执行模型
Hopper 引入了 TMA(Tensor Memory Accelerator) 和 WGMMA(Warp Group Matrix Multiply-Accumulate) 指令:
- TMA:异步将数据从 HBM 搬运到 SMEM(共享内存),不占用计算单元
- WGMMA:异步执行矩阵乘法,可以直接从 SMEM 读取操作数
关键 :TMA 和 WGMMA 可以同时执行,实现计算与数据搬运的完全重叠。
2.2 FlashAttention-3 的三阶段流水线
时间 → ┌─────────┬─────────┬─────────┐
│ 阶段 1 │ 阶段 2 │ 阶段 3 │
│ TMA 加载 │ WGMMA │ Softmax │
│ KV 块 │ Q·K^T │ 修正 │
├─────────┼─────────┼─────────┤
│ TMA 加载 │ WGMMA │ Softmax │
│ 下一 KV │ P·V │ 修正 │
└─────────┴─────────┴─────────┘
三个阶段完全重叠,GPU 的计算单元和内存搬运单元同时工作。
2.3 FP8 与块量化
FP8(E4M3 格式)的动态范围有限,直接应用于注意力会导致精度损失。
块量化:将 Q、K 分成小块,每块独立缩放到 FP8 范围:
Q FP8 = Quantize ( Q block s q ) , s q = max ( ∣ Q block ∣ ) FP8_MAX \mathbf{Q}{\text{FP8}} = \text{Quantize}\left(\frac{\mathbf{Q}{\text{block}}}{s_q}\right), \quad s_q = \frac{\max(|\mathbf{Q}_{\text{block}}|)}{\text{FP8\_MAX}} QFP8=Quantize(sqQblock),sq=FP8_MAXmax(∣Qblock∣)
注意力计算后反量化:
S = Q FP8 K FP8 T d ⋅ s q ⋅ s k \mathbf{S} = \frac{\mathbf{Q}{\text{FP8}} \mathbf{K}{\text{FP8}}^T}{\sqrt{d}} \cdot s_q \cdot s_k S=d QFP8KFP8T⋅sq⋅sk
2.4 FlashAttention-3 的性能
在 H100 GPU 上:
| 精度 | FlashAttention-2 | FlashAttention-3 | 提升 |
|---|---|---|---|
| FP16 | 350 TFLOPS | 580 TFLOPS | 1.66× |
| FP8 | N/A | 1100 TFLOPS | 3.14× (vs FA2 FP16) |
FlashAttention-3 达到了 H100 理论 FLOPS 的 ~75% (FP16)和 ~73%(FP8)。
3. FlashAttention-3 数学公式总结
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║ FlashAttention-3 数学总结 ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║ ║
║ 1. 异步流水线: ║
║ TMA (数据搬运) ⊕ WGMMA (矩阵乘法) ⊕ Softmax 三阶段完全重叠 ║
║ GPU 计算单元和内存单元同时满载 ║
║ ║
║ 2. FP8 块量化: ║
║ Q_FP8 = Quantize(Q_block / s_q) ║
║ s_q = max(|Q_block|) / FP8_MAX ║
║ S = Q_FP8 · K_FP8^T · s_q · s_k / √d ║
║ ║
║ 3. 性能 (H100): ║
║ FP16: ~580 TFLOPS (理论峰值的 ~75%) ║
║ FP8: ~1100 TFLOPS (理论峰值的 ~73%) ║
║ ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝
4. FlashAttention 系列演进总结
FlashAttention (2022)
├── 核心: 分块 + 在线 Softmax
├── IO: O(L²d²/M)
├── 显存: O(Ld)
└── A100: ~180 TFLOPS
FlashAttention-2 (2023)
├── 优化 1: 减少非 matmul 运算
├── 优化 2: 序列长度维度并行
├── 优化 3: 因果掩码优化
└── A100: ~250 TFLOPS (+39%)
FlashAttention-3 (2024)
├── 优化 1: Hopper 异步流水线 (TMA + WGMMA)
├── 优化 2: FP8 块量化
└── H100: ~1100 TFLOPS FP8
参考文献
- Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.
- Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. ICLR 2024.
- Shah, J., Bikshandi, G., Zhang, Y., Thakkar, V., Ramani, P., & Dao, T. (2024). FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision. arXiv:2407.08691.
- Vaswani, A., Shazeer, N., et al. (2017). Attention Is All You Need. NeurIPS 2017.
- Milakov, M., & Gimelshein, N. (2018). Online normalizer calculation for softmax. arXiv:1805.02867.
- Hong, J. W., & Kung, H. T. (1981). I/O complexity: The red-blue pebble game. STOC 1981.
- NVIDIA. (2022). NVIDIA H100 Tensor Core GPU Architecture. Technical Report.