FlashAttention 原理与工程实践:从 IO-aware 到 H100 上的 1 PFLOPs/s

关键词:IO-aware、在线 Softmax(online softmax)、tiling、shared memory、SDPA、PyTorch 后端、FA1/FA2/FA3

1. 为什么需要 FlashAttention?

标准自注意力(Scaled Dot-Product Attention)要显式构造 A = softmax(QK^T / √d),其中 A 是形状为 [L, L] 的注意力矩阵。

当序列长度 L 增大时:

  • 显存/内存压力 :需要物化(materialize)L×L 的矩阵,中间张量巨大(训练时还要保存用于反传),显存占用随 O(L^2) 暴涨;

  • 速度瓶颈在内存带宽 :GPU 算力很强,但把大矩阵在 HBM(显存)与 SM(片上 SRAM/寄存器)之间来回搬运,I/O 成了主要瓶颈------算力吃不满。

FlashAttention 的关键洞见 是:自注意力在计算上虽是 O(L^2),但完全没必要把 L×L 的注意力矩阵完整放到显存里。把算法设计为 IO-aware ,通过分块(tiling)把计算局限在片上(shared memory / 寄存器),并对 softmax 采用在线(分块增量)计算 ,即可大幅减少 HBM 读写,做到精确注意力但只需 O(L) 级别的外存占用 ,训练/推理显著加速。这一思想由 FlashAttention 系列系统化阐述并给出最优 IO 复杂度证明与 GPU 友好的实现。arXiv


2. FlashAttention 的核心思想

2.1 分块(Tiling)与工作集(Working Set)

  • Q, K, V 沿序列维度分块(常见 block size 如 128、256 等,具体由硬件共享内存大小与 kernel 调度决定)。

  • 一次把一个 Q 块与若干个 K/V 块放入片上存储(shared memory + 寄存器) ,完成该 Q 块对应的部分注意力与加权求和,然后再换下一个 K/V 块。

  • 这样整个过程从不物化 完整的 [L, L] 注意力矩阵,只在片上对当前小块做运算,并把最终输出 O 写回 HBM。

直观对比(ASCII 示意):

sql 复制代码
标准注意力(物化 A=softmax(QK^T)):

 Q [Lxd]  ──┐              ┌── K [Lxd]
            ├── matmul ────┤
            │   [LxL]      │
            └── softmax ───┘
                  │
                  └── A[LxL]  ×  V[Lxd]  →  O[Lxd]   # A 常驻显存

FlashAttention(分块流式,不物化 A):

for q_block in partition(Q):
    init m, l, acc_o for rows in q_block   # 行最大值、归一化因子、输出累加器
    for kv_block in partition(K,V):
        load q_block, kv_block into SRAM
        s = q_block @ kv_block.K^T / sqrt(d)
        # 在线 softmax 更新(见下)
        p_t = softmax_tile_update(s, m, l)
        acc_o += p_t @ kv_block.V
    write acc_o to HBM

2.2 在线 Softmax(Online / Streaming Softmax)

把一行 softmax 按列块(tile)逐步"扫过去"。为保证数值稳定和与一次性 softmax 完全一致,维护三样东西:

  • m:当前已处理部分的行最大值(running max)

  • l:以 m 为基准的归一化分母(running sum of exp)

  • acc_o加权输出的累加器(以相同归一化系数更新)

对当前 tile 的得分矩阵 S_t(shape: [Br, Bc]Br/Bc 为行/列块大小),先做
m_new = max(m, rowwise_max(S_t)),然后

sql 复制代码
alpha = exp(m - m_new)                 # 缩放旧分母与旧输出
l = alpha * l + sum(exp(S_t - m_new), axis=tile_cols)
acc_o = alpha * acc_o + (exp(S_t - m_new) @ V_t)
m = m_new

tile 全扫完之后,最后 O_row = acc_o / l 即是精确 softmax 的结果(与一次性 softmax 完全一致)。这就是 online softmax 的要点。arXiv

小例子(标量级)

假设某一行分两块列:S1=[2.0, 0.0]S2=[3.0, 1.0]

  • 处理 S1m=2.0, l=exp(0)+exp(-2)=1+0.1353=1.1353

  • 处理 S2m_new=max(2.0,3.0)=3.0alpha=exp(2-3)=exp(-1)=0.3679
    l = 0.3679*1.1353 + (exp(0)+exp(-2)) = 0.4176 + (1+0.1353)=1.5529

    注意 l 等价于对全量 [2,0,3,1]softmax 的分母。

    最终 p = exp([2,0,3,1]-3)/l = [e^{-1},e^{-3},1,e^{-2}]/1.5529,与一次性 softmax 等价。

2.3 掩码 / 因果(Causal)/ Dropout

  • 因果掩码 可以在每个 S_t 上局部应用:对越过上三角的元素打 -inf,仍可在线更新,不影响精确性;

  • 一般掩码(padding、局部窗口、ALiBi 等)同理可在 tile 内施加;

  • Dropout 在训练时在 tile 内依据 RNG 生成 mask 并参与 acc_o 更新(反向时可基于种子/Philox 复现)。

    这些在官方实现与 PyTorch SDPA 后端里均已工程化支持(随特定版本/后端略有差异)。GitHub+1


3. 从 FA1 到 FA2:更好的并行与工作划分

FA1(2022)首次把 IO-aware 精确注意力做成了通用可用 的 CUDA 内核,在 A100 上对常见设定给出 2--4× 的端到端加速,同时训练内存从二次降到线性级别。arXiv

FA2(2023)进一步分析了 GPU 上的占用率与通信代价,提出三点工程优化:

  1. 减少非 matmul FLOPs;

  2. 单头内进一步并行 ,跨 thread block 切分 Q 行提高占用;

  3. 在 block 内更好地在 warp 之间分配工作,减少 shared memory 往返

    由此把 A100 的注意力核效率从 ~25--40% 提升到 50--73% FLOPs 利用率 ,训练 GPT 时单卡可达 225 TFLOPs/sarXiv+1


4. FA3(2024):面向 Hopper(H100)的再加速

Hopper 架构(H100)新引入 Tensor Memory Accelerator (TMA) 、更强的 Tensor Core 异步执行与 FP8 支持。FlashAttention-3 通过:

  • warp specialization + TMA,重叠数据搬运与计算;

  • 更紧密地交错 block 级 matmul 与 softmax;

  • 块级量化(FP8)/不相干处理,用硬件低精度加速同时降低误差;

在 H100 上,FA3 相比 FA2 再提速 1.5--2.0× ,FP16 可达 ~740 TFLOPs/s(约 75% 利用率) ,FP8 甚至接近 1.2 PFLOPs/s ,并报告 低于基线 FP8 的数值误差arXiv+2arXiv+2


5. 数值稳定性与等价性

  • 在线 softmax 使用行最大值重定标保证与一次性 softmax 完全等价(浮点舍入误差级别差异);

  • 混合精度(FP16/BF16/FP8)下,缩放与累加通常在更高精度(FP32)中进行,以减少溢出/下溢;

  • FA3 针对 FP8 做了专门设计,报告在 H100 上优于朴素 FP8 注意力的误差。arXiv+1


6. 伪代码:行级在线 softmax + 分块

python 复制代码
# 计算 O = softmax(QK^T / sqrt(d)) V
# 仅示意 forward;忽略并行/向量化细节
def flash_attention_forward(Q, K, V, block_rows=128, block_cols=128, causal=False):
    L, d = Q.shape
    O = zeros(L, d)

    for rs in range(0, L, block_rows):          # 行块
        re = min(rs + block_rows, L)
        q = Q[rs:re]                             # (Br, d)

        # 初始化在线 softmax 的三个状态
        m = full(re - rs, -inf)                  # 行最大值
        l = zeros(re - rs)                       # 归一化分母
        o = zeros(re - rs, d)                    # 输出累加器

        for cs in range(0, L, block_cols):       # 列块
            ce = min(cs + block_cols, L)
            k = K[cs:ce]                         # (Bc, d)
            v = V[cs:ce]                         # (Bc, d)

            s = q @ k.T / sqrt(d)                # (Br, Bc)

            if causal:
                # 局部因果掩码:只屏蔽 q 的绝对行 > k 的绝对列 的元素
                for i in range(re - rs):
                    s[i, max(0, (rs+i) - cs + 1):] = -inf

            # 在线 softmax 更新
            m_new = maximum(m, s.max(axis=1))
            alpha = exp(m - m_new)
            p = exp(s - m_new[:, None])          # 每行以 m_new 作为基准
            l = alpha * l + p.sum(axis=1)
            o = alpha[:, None] * o + p @ v

            m = m_new

        O[rs:re] = o / l[:, None]
    return O

7. 在 PyTorch 中如何"开箱即用"

7.1 直接用 PyTorch 的 SDPA(推荐)

PyTorch 的 torch.nn.functional.scaled_dot_product_attention(SDPA)会自动选择 最优后端(FlashAttention-2 / Memory-Efficient / 数学实现),也提供上下文管理器强制选择。只要你的张量与 dtype 满足条件,它就会走到 FA2 kerneldocs.pytorch.org

bash 复制代码
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

B, H, L, D = 8, 16, 4096, 128
q = torch.randn(B, H, L, D, dtype=torch.float16, device="cuda")
k = torch.randn(B, H, L, D, dtype=torch.float16, device="cuda")
v = torch.randn(B, H, L, D, dtype=torch.float16, device="cuda")

# 可选:强制用 FlashAttention 后端(否则让 PyTorch 自动选择即可)
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
    o = F.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True
    )

要检查/控制后端,可使用 sdpa_kerneltorch.backends.cuda.enable_*_sdp() 开关。GQA 支持is_causal/attn_mask 的限制、dropout 语义等见官方文档。docs.pytorch.org+2docs.pytorch.org+2

7.2 直接使用 flash-attn 包(需要编译/兼容)

安装(建议带 --no-build-isolation,否则编译时间可能很长;Windows 支持较新版本才开始实验):

bash 复制代码
pip install flash-attn --no-build-isolation
# 或者源码安装
# python setup.py install

包内暴露了函数式接口与模块化 MHA,支持因果、滑窗、ALiBi、变长等。常用接口例如:

python 复制代码
from flash_attn import flash_attn_qkvpacked_func   # qkv 打包更快
from flash_attn import flash_attn_func

# qkv: (B, L, 3, H, D)
o = flash_attn_qkvpacked_func(
    qkv, dropout_p=0.0, softmax_scale=None, causal=True,
    window_size=(-1, -1), alibi_slopes=None, deterministic=False
)

请注意 GPU 架构 / CUDA 版本 / 头维 D 的约束 (例如 Ampere/Ada/Hopper、fp16/bf16、D ≤ 256 等),以及 ROCm/AMD 的支持情况(CK / Triton 后端)。GitHub


8. 训练与推理的工程要点

  1. 形状与 dtype :FA 后端通常要求 float16/bfloat16(内部可能以 FP32 累加),head_dim 常见需是 8/16 的倍数,越规则越容易命中最快路径。GitHub

  2. 掩码/因果 :尽量使用 is_causal=True 而不是构造巨大上三角 mask;一般 mask 用 SDPA 的 attn_mask 即可。docs.pytorch.org

  3. Dropout :训练时为非零;推理或 model.eval() 时务必设为 0,避免数值偏差与不必要随机性。docs.pytorch.org

  4. 调后端 :用 sdpa_kernel([...]) 强制运行指定后端,便于 A/B 与排障;若触发不了 FlashAttention,PyTorch 会给出原因(dtype、形状、设备等)。docs.pytorch.org

  5. 长序列/变长 :FA 提供 varlen 接口与滑窗局部注意力;推理侧与 KV-cache 结合良好。GitHub

  6. 瓶颈判断:若"算力富余、显存紧张/带宽打满",FA 往往收益巨大;若本就小序列/小 batch,收益不一定显著。

  7. 混合精度与稳定性 :优先 BF16(硬件支持时),在 FP8(H100/FA3)场景按官方建议设置比例/scaler。arXiv


9. 性能与对比:FA1、FA2、FA3

  • FA1 :提出 IO-aware 精确注意力,线性外存2--4× 加速arXiv

  • FA2更好的并行/工作划分 ,在 A100 上达 50--73% FLOPs 利用率 ,GPT 训练 ~225 TFLOPs/sarXiv+1

  • FA3 :利用 Hopper(TMA + Tensor Core 异步 + FP8),在 H100 上 1.5--2.0× 进一步提速,FP16 ~740 TFLOPs/s , FP8 ~1.2 PFLOPs/sarXiv+2arXiv+2

这些数字均为论文/官方基准报告的端到端或核级指标,实际应用取决于你的 batch、序列、隐藏维、硬件/软件栈等。


10. 与 PyTorch SDPA 的关系

PyTorch 自 2.x 起的 SDPA 会在 CUDA 上优先尝试 Memory-Efficient / FlashAttention-2 等融合 kernel,并根据输入自动路由。

  • 你可以完全不关心细节,直接用 F.scaled_dot_product_attention

  • 也可以用上下文管理器强制选择某个后端做对比与调试。docs.pytorch.org


11. 快速基准脚本(训练或推理)

python 复制代码
import torch, time
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

def bench(b=8, h=16, L=8192, d=128, causal=True, backend=SDPBackend.FLASH_ATTENTION):
    q = torch.randn(b, h, L, d, dtype=torch.float16, device="cuda")
    k = torch.randn(b, h, L, d, dtype=torch.float16, device="cuda")
    v = torch.randn(b, h, L, d, dtype=torch.float16, device="cuda")

    # 预热
    for _ in range(5):
        with sdpa_kernel(backends=[backend]):
            F.scaled_dot_product_attention(q, k, v, is_causal=causal)

    # 正式计时
    torch.cuda.synchronize()
    t0 = time.time()
    with sdpa_kernel(backends=[backend]):
        for _ in range(20):
            F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    torch.cuda.synchronize()
    dt = (time.time() - t0)/20
    print(f"backend={backend.name}, L={L}, d={d}, time={dt*1000:.2f} ms")

bench()

提示:你也可以切换到 SDPBackend.MATH 当作"朴素参考基线",或用 flash-attn 包做函数级对照。docs.pytorch.org+1


12. 常见问题(FAQ)

Q1:为什么我的代码没有走 FlashAttention?

A:检查 dtype(半精度)、设备是否 CUDA、head_dim 是否满足约束、是否传了不受支持的 mask 形状;用 sdpa_kernel 强制后端并查看警告。docs.pytorch.org

Q2:训练显存还是不够?

A:结合 --gradient-checkpointing 与 FA,显存占用还能进一步下降;同时合理设置 batch / 叠加张量并行策略。

Q3:Ascend / ROCm / 低资源 GPU 怎么办?

A:官方已提供 ROCm(CK / Triton)后端;社区也有将 FA 思想迁移到 NPU/老 GPU 的工作可参考(如"Extend FlashAttention2 to NPUs and Low-resource GPUs",针对昇腾等做了两级 tiling 与通信优化)。实际以厂商/版本为准。GitHub+1

Q4:FP8 会不会不准?

A:FA3 论文报告在 H100 上的 FP8 方案相对基线 FP8 误差更低,同时速度更快;但请结合你任务的损失与评测实际验证。arXiv


13. 进一步阅读

  • FA1 论文 :提出 IO-aware 精确注意力与在线 softmax 的系统化实现与理论分析。arXiv

  • FA2 论文 :更好的并行/工作划分,把核效率推向 GEMM。arXiv

  • FA3 论文/博文 :面向 Hopper 的异步/TMA/FP8 设计与 SOTA 性能。arXiv+1

  • PyTorch SDPA 文档 :如何启用/强制指定后端、限制与示例代码。docs.pytorch.org

  • flash-attention 仓库(含安装与 API) :实践落地与最新支持矩阵。GitHub


14. 一句话总结

FlashAttention 用"IO-aware + 在线 softmax + GPU 友好分块"把"精确注意力"跑得又快又省显存 ;FA2 把核效率推近 GEMM;FA3 借助 Hopper 新特性把注意力推到 ~0.75 PF/s(FP16) / ~1.2 PF/s(FP8) 。在 PyTorch 中用 SDPA 即可"一键吃到"这套加速,在大多数长上下文场景里是必开的基础设施级优化arXiv+1

相关推荐
有点不太正常9 小时前
Differentially Private Synthetic Text Generation for RAG——论文阅读
论文阅读·大模型·llm·rag
山顶夕景9 小时前
【LLM】大模型vibe coding(cursor、copilot、comate)
大模型·copilot·coding·vibe coding·代码模型
杀生丸学AI13 小时前
【三维重建】即插即用的3DGS的PDE优化:高质量渲染和重建
人工智能·3d·大模型·aigc·3dgs·高斯泼溅·空间智能
想躺平的咸鱼干1 天前
远程MCP的调用和阿里云生态的知识库和工作流的使用
阿里云·大模型·云计算·idea·格式化输出·mcp
haogexiaole1 天前
什么是语言模型
大模型
泥烟1 天前
使用Milvus和DeepSeek构建RAG demo
大模型·milvus·deepseek
CoderJia程序员甲2 天前
GitHub 热榜项目 - 日榜(2025-10-09)
ai·开源·大模型·github·ai教程
Wild_Pointer.2 天前
面向Qt/C++开发工程师的Ai提示词(附Trae示例)
人工智能·ai·大模型
喜欢吃豆2 天前
从潜在空间到实际应用:Embedding模型架构与训练范式的综合解析
python·自然语言处理·架构·大模型·微调·embedding