关键词:IO-aware、在线 Softmax(online softmax)、tiling、shared memory、SDPA、PyTorch 后端、FA1/FA2/FA3
1. 为什么需要 FlashAttention?
标准自注意力(Scaled Dot-Product Attention)要显式构造 A = softmax(QK^T / √d)
,其中 A
是形状为 [L, L]
的注意力矩阵。
当序列长度 L
增大时:
-
显存/内存压力 :需要物化(materialize)
L×L
的矩阵,中间张量巨大(训练时还要保存用于反传),显存占用随O(L^2)
暴涨; -
速度瓶颈在内存带宽 :GPU 算力很强,但把大矩阵在 HBM(显存)与 SM(片上 SRAM/寄存器)之间来回搬运,I/O 成了主要瓶颈------算力吃不满。
FlashAttention 的关键洞见 是:自注意力在计算上虽是 O(L^2)
,但完全没必要把 L×L
的注意力矩阵完整放到显存里。把算法设计为 IO-aware ,通过分块(tiling)把计算局限在片上(shared memory / 寄存器),并对 softmax 采用在线(分块增量)计算 ,即可大幅减少 HBM 读写,做到精确注意力但只需 O(L) 级别的外存占用 ,训练/推理显著加速。这一思想由 FlashAttention 系列系统化阐述并给出最优 IO 复杂度证明与 GPU 友好的实现。arXiv
2. FlashAttention 的核心思想
2.1 分块(Tiling)与工作集(Working Set)
-
将
Q, K, V
沿序列维度分块(常见 block size 如 128、256 等,具体由硬件共享内存大小与 kernel 调度决定)。 -
一次把一个
Q
块与若干个K/V
块放入片上存储(shared memory + 寄存器) ,完成该Q
块对应的部分注意力与加权求和,然后再换下一个K/V
块。 -
这样整个过程从不物化 完整的
[L, L]
注意力矩阵,只在片上对当前小块做运算,并把最终输出O
写回 HBM。
直观对比(ASCII 示意):
sql
标准注意力(物化 A=softmax(QK^T)):
Q [Lxd] ──┐ ┌── K [Lxd]
├── matmul ────┤
│ [LxL] │
└── softmax ───┘
│
└── A[LxL] × V[Lxd] → O[Lxd] # A 常驻显存
FlashAttention(分块流式,不物化 A):
for q_block in partition(Q):
init m, l, acc_o for rows in q_block # 行最大值、归一化因子、输出累加器
for kv_block in partition(K,V):
load q_block, kv_block into SRAM
s = q_block @ kv_block.K^T / sqrt(d)
# 在线 softmax 更新(见下)
p_t = softmax_tile_update(s, m, l)
acc_o += p_t @ kv_block.V
write acc_o to HBM
2.2 在线 Softmax(Online / Streaming Softmax)
把一行 softmax 按列块(tile)逐步"扫过去"。为保证数值稳定和与一次性 softmax 完全一致,维护三样东西:
-
m
:当前已处理部分的行最大值(running max) -
l
:以m
为基准的归一化分母(running sum of exp) -
acc_o
:加权输出的累加器(以相同归一化系数更新)
对当前 tile 的得分矩阵 S_t
(shape: [Br, Bc]
,Br
/Bc
为行/列块大小),先做
m_new = max(m, rowwise_max(S_t))
,然后
sql
alpha = exp(m - m_new) # 缩放旧分母与旧输出
l = alpha * l + sum(exp(S_t - m_new), axis=tile_cols)
acc_o = alpha * acc_o + (exp(S_t - m_new) @ V_t)
m = m_new
tile 全扫完之后,最后 O_row = acc_o / l
即是精确 softmax 的结果(与一次性 softmax 完全一致)。这就是 online softmax 的要点。arXiv
小例子(标量级)
假设某一行分两块列:
S1=[2.0, 0.0]
,S2=[3.0, 1.0]
。
处理
S1
:m=2.0, l=exp(0)+exp(-2)=1+0.1353=1.1353
;处理
S2
:m_new=max(2.0,3.0)=3.0
,alpha=exp(2-3)=exp(-1)=0.3679
;
l = 0.3679*1.1353 + (exp(0)+exp(-2)) = 0.4176 + (1+0.1353)=1.5529
;注意
l
等价于对全量[2,0,3,1]
做softmax
的分母。最终
p = exp([2,0,3,1]-3)/l = [e^{-1},e^{-3},1,e^{-2}]/1.5529
,与一次性 softmax 等价。
2.3 掩码 / 因果(Causal)/ Dropout
-
因果掩码 可以在每个
S_t
上局部应用:对越过上三角的元素打-inf
,仍可在线更新,不影响精确性; -
一般掩码(padding、局部窗口、ALiBi 等)同理可在 tile 内施加;
-
Dropout 在训练时在 tile 内依据 RNG 生成 mask 并参与
acc_o
更新(反向时可基于种子/Philox 复现)。这些在官方实现与 PyTorch SDPA 后端里均已工程化支持(随特定版本/后端略有差异)。GitHub+1
3. 从 FA1 到 FA2:更好的并行与工作划分
FA1(2022)首次把 IO-aware 精确注意力做成了通用可用 的 CUDA 内核,在 A100 上对常见设定给出 2--4× 的端到端加速,同时训练内存从二次降到线性级别。arXiv
FA2(2023)进一步分析了 GPU 上的占用率与通信代价,提出三点工程优化:
-
减少非 matmul FLOPs;
-
单头内进一步并行 ,跨 thread block 切分
Q
行提高占用; -
在 block 内更好地在 warp 之间分配工作,减少 shared memory 往返 。
由此把 A100 的注意力核效率从 ~25--40% 提升到 50--73% FLOPs 利用率 ,训练 GPT 时单卡可达 225 TFLOPs/s 。arXiv+1
4. FA3(2024):面向 Hopper(H100)的再加速
Hopper 架构(H100)新引入 Tensor Memory Accelerator (TMA) 、更强的 Tensor Core 异步执行与 FP8 支持。FlashAttention-3 通过:
-
warp specialization + TMA,重叠数据搬运与计算;
-
更紧密地交错 block 级 matmul 与 softmax;
-
块级量化(FP8)/不相干处理,用硬件低精度加速同时降低误差;
在 H100 上,FA3 相比 FA2 再提速 1.5--2.0× ,FP16 可达 ~740 TFLOPs/s(约 75% 利用率) ,FP8 甚至接近 1.2 PFLOPs/s ,并报告 低于基线 FP8 的数值误差 。arXiv+2arXiv+2
5. 数值稳定性与等价性
-
在线 softmax 使用行最大值重定标保证与一次性 softmax 完全等价(浮点舍入误差级别差异);
-
混合精度(FP16/BF16/FP8)下,缩放与累加通常在更高精度(FP32)中进行,以减少溢出/下溢;
-
FA3 针对 FP8 做了专门设计,报告在 H100 上优于朴素 FP8 注意力的误差。arXiv+1
6. 伪代码:行级在线 softmax + 分块
python
# 计算 O = softmax(QK^T / sqrt(d)) V
# 仅示意 forward;忽略并行/向量化细节
def flash_attention_forward(Q, K, V, block_rows=128, block_cols=128, causal=False):
L, d = Q.shape
O = zeros(L, d)
for rs in range(0, L, block_rows): # 行块
re = min(rs + block_rows, L)
q = Q[rs:re] # (Br, d)
# 初始化在线 softmax 的三个状态
m = full(re - rs, -inf) # 行最大值
l = zeros(re - rs) # 归一化分母
o = zeros(re - rs, d) # 输出累加器
for cs in range(0, L, block_cols): # 列块
ce = min(cs + block_cols, L)
k = K[cs:ce] # (Bc, d)
v = V[cs:ce] # (Bc, d)
s = q @ k.T / sqrt(d) # (Br, Bc)
if causal:
# 局部因果掩码:只屏蔽 q 的绝对行 > k 的绝对列 的元素
for i in range(re - rs):
s[i, max(0, (rs+i) - cs + 1):] = -inf
# 在线 softmax 更新
m_new = maximum(m, s.max(axis=1))
alpha = exp(m - m_new)
p = exp(s - m_new[:, None]) # 每行以 m_new 作为基准
l = alpha * l + p.sum(axis=1)
o = alpha[:, None] * o + p @ v
m = m_new
O[rs:re] = o / l[:, None]
return O
7. 在 PyTorch 中如何"开箱即用"
7.1 直接用 PyTorch 的 SDPA(推荐)
PyTorch 的 torch.nn.functional.scaled_dot_product_attention
(SDPA)会自动选择 最优后端(FlashAttention-2 / Memory-Efficient / 数学实现),也提供上下文管理器强制选择。只要你的张量与 dtype 满足条件,它就会走到 FA2 kernel 。docs.pytorch.org
bash
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend
B, H, L, D = 8, 16, 4096, 128
q = torch.randn(B, H, L, D, dtype=torch.float16, device="cuda")
k = torch.randn(B, H, L, D, dtype=torch.float16, device="cuda")
v = torch.randn(B, H, L, D, dtype=torch.float16, device="cuda")
# 可选:强制用 FlashAttention 后端(否则让 PyTorch 自动选择即可)
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
o = F.scaled_dot_product_attention(
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True
)
要检查/控制后端,可使用 sdpa_kernel
或 torch.backends.cuda.enable_*_sdp()
开关。GQA 支持 、is_causal
/attn_mask
的限制、dropout 语义等见官方文档。docs.pytorch.org+2docs.pytorch.org+2
7.2 直接使用 flash-attn 包(需要编译/兼容)
安装(建议带 --no-build-isolation
,否则编译时间可能很长;Windows 支持较新版本才开始实验):
bash
pip install flash-attn --no-build-isolation
# 或者源码安装
# python setup.py install
包内暴露了函数式接口与模块化 MHA,支持因果、滑窗、ALiBi、变长等。常用接口例如:
python
from flash_attn import flash_attn_qkvpacked_func # qkv 打包更快
from flash_attn import flash_attn_func
# qkv: (B, L, 3, H, D)
o = flash_attn_qkvpacked_func(
qkv, dropout_p=0.0, softmax_scale=None, causal=True,
window_size=(-1, -1), alibi_slopes=None, deterministic=False
)
请注意 GPU 架构 / CUDA 版本 / 头维 D 的约束 (例如 Ampere/Ada/Hopper、fp16/bf16、D ≤ 256 等),以及 ROCm/AMD 的支持情况(CK / Triton 后端)。GitHub
8. 训练与推理的工程要点
-
形状与 dtype :FA 后端通常要求
float16/bfloat16
(内部可能以 FP32 累加),head_dim
常见需是 8/16 的倍数,越规则越容易命中最快路径。GitHub -
掩码/因果 :尽量使用
is_causal=True
而不是构造巨大上三角 mask;一般 mask 用 SDPA 的attn_mask
即可。docs.pytorch.org -
Dropout :训练时为非零;推理或
model.eval()
时务必设为 0,避免数值偏差与不必要随机性。docs.pytorch.org -
调后端 :用
sdpa_kernel([...])
强制运行指定后端,便于 A/B 与排障;若触发不了 FlashAttention,PyTorch 会给出原因(dtype、形状、设备等)。docs.pytorch.org -
长序列/变长 :FA 提供 varlen 接口与滑窗局部注意力;推理侧与 KV-cache 结合良好。GitHub
-
瓶颈判断:若"算力富余、显存紧张/带宽打满",FA 往往收益巨大;若本就小序列/小 batch,收益不一定显著。
-
混合精度与稳定性 :优先 BF16(硬件支持时),在 FP8(H100/FA3)场景按官方建议设置比例/scaler。arXiv
9. 性能与对比:FA1、FA2、FA3
-
FA1 :提出 IO-aware 精确注意力,线性外存 与 2--4× 加速 ;arXiv
-
FA2 :更好的并行/工作划分 ,在 A100 上达 50--73% FLOPs 利用率 ,GPT 训练 ~225 TFLOPs/s ;arXiv+1
-
FA3 :利用 Hopper(TMA + Tensor Core 异步 + FP8),在 H100 上 1.5--2.0× 进一步提速,FP16 ~740 TFLOPs/s , FP8 ~1.2 PFLOPs/s ;arXiv+2arXiv+2
这些数字均为论文/官方基准报告的端到端或核级指标,实际应用取决于你的 batch、序列、隐藏维、硬件/软件栈等。
10. 与 PyTorch SDPA 的关系
PyTorch 自 2.x 起的 SDPA 会在 CUDA 上优先尝试 Memory-Efficient / FlashAttention-2 等融合 kernel,并根据输入自动路由。
-
你可以完全不关心细节,直接用
F.scaled_dot_product_attention
; -
也可以用上下文管理器强制选择某个后端做对比与调试。docs.pytorch.org
11. 快速基准脚本(训练或推理)
python
import torch, time
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend
def bench(b=8, h=16, L=8192, d=128, causal=True, backend=SDPBackend.FLASH_ATTENTION):
q = torch.randn(b, h, L, d, dtype=torch.float16, device="cuda")
k = torch.randn(b, h, L, d, dtype=torch.float16, device="cuda")
v = torch.randn(b, h, L, d, dtype=torch.float16, device="cuda")
# 预热
for _ in range(5):
with sdpa_kernel(backends=[backend]):
F.scaled_dot_product_attention(q, k, v, is_causal=causal)
# 正式计时
torch.cuda.synchronize()
t0 = time.time()
with sdpa_kernel(backends=[backend]):
for _ in range(20):
F.scaled_dot_product_attention(q, k, v, is_causal=causal)
torch.cuda.synchronize()
dt = (time.time() - t0)/20
print(f"backend={backend.name}, L={L}, d={d}, time={dt*1000:.2f} ms")
bench()
提示:你也可以切换到
SDPBackend.MATH
当作"朴素参考基线",或用flash-attn
包做函数级对照。docs.pytorch.org+1
12. 常见问题(FAQ)
Q1:为什么我的代码没有走 FlashAttention?
A:检查 dtype(半精度)、设备是否 CUDA、head_dim 是否满足约束、是否传了不受支持的 mask 形状;用 sdpa_kernel
强制后端并查看警告。docs.pytorch.org
Q2:训练显存还是不够?
A:结合 --gradient-checkpointing
与 FA,显存占用还能进一步下降;同时合理设置 batch / 叠加张量并行策略。
Q3:Ascend / ROCm / 低资源 GPU 怎么办?
A:官方已提供 ROCm(CK / Triton)后端;社区也有将 FA 思想迁移到 NPU/老 GPU 的工作可参考(如"Extend FlashAttention2 to NPUs and Low-resource GPUs",针对昇腾等做了两级 tiling 与通信优化)。实际以厂商/版本为准。GitHub+1
Q4:FP8 会不会不准?
A:FA3 论文报告在 H100 上的 FP8 方案相对基线 FP8 误差更低,同时速度更快;但请结合你任务的损失与评测实际验证。arXiv
13. 进一步阅读
-
FA1 论文 :提出 IO-aware 精确注意力与在线 softmax 的系统化实现与理论分析。arXiv
-
FA2 论文 :更好的并行/工作划分,把核效率推向 GEMM。arXiv
-
FA3 论文/博文 :面向 Hopper 的异步/TMA/FP8 设计与 SOTA 性能。arXiv+1
-
PyTorch SDPA 文档 :如何启用/强制指定后端、限制与示例代码。docs.pytorch.org
-
flash-attention 仓库(含安装与 API) :实践落地与最新支持矩阵。GitHub
14. 一句话总结
FlashAttention 用"IO-aware + 在线 softmax + GPU 友好分块"把"精确注意力"跑得又快又省显存 ;FA2 把核效率推近 GEMM;FA3 借助 Hopper 新特性把注意力推到 ~0.75 PF/s(FP16) / ~1.2 PF/s(FP8) 。在 PyTorch 中用 SDPA 即可"一键吃到"这套加速,在大多数长上下文场景里是必开的基础设施级优化 。arXiv+1