FlashAttention 深度解析

目录

  • [第一篇:FlashAttention --- IO 感知的精确注意力](#第一篇:FlashAttention — IO 感知的精确注意力)
  • [第二篇:FlashAttention-2 --- 并行优化与因果掩码](#第二篇:FlashAttention-2 — 并行优化与因果掩码)
  • [第三篇:FlashAttention-3 --- Hopper 架构异步流水线](#第三篇:FlashAttention-3 — Hopper 架构异步流水线)
  • 参考文献

第一篇:FlashAttention --- IO 感知的精确注意力

1. 引言

标准自注意力的计算复杂度为 O ( L 2 d ) O(L^2 d) O(L2d),但实践中真正的瓶颈往往不是浮点运算量(FLOPs),而是内存 IO。GPU 的计算能力(FLOPS)增长远快于显存带宽(GB/s),导致注意力计算被内存访问所制约。

FlashAttention (Dao et al., 2022)从一个全新的视角审视注意力:不优化 FLOPs,而是优化内存 IO 。通过分块计算(tiling)和在线 softmax 技巧,在不引入任何近似的情况下,将注意力的 HBM 访问量从 O ( L 2 ) O(L^2) O(L2) 降低到 O ( L 2 d 2 / M ) O(L^2 d^2 / M) O(L2d2/M),其中 M M M 是 SRAM 容量。在实际 GPU 上,这带来了 2-4× 的端到端加速5-20× 的显存节省


2. 理论基础 --- GPU 内存层次与 IO 复杂度

2.1 GPU 内存层次结构

现代 GPU(如 A100)的内存层次:

层级 容量 带宽 延迟
HBM(高带宽显存) 40-80 GB 1.5-3.35 TB/s ~数百周期
SRAM(共享内存/L1) 每 SM 164-228 KB ~19 TB/s ~数周期
寄存器 每 SM 256 KB ~数百 TB/s 1 周期

关键洞察 :SRAM 比 HBM 快约 10-20×,但容量小约 1000×。标准注意力实现需要将 O ( L 2 ) O(L^2) O(L2) 大小的注意力矩阵在 HBM 和 SRAM 之间反复传输,造成严重的 IO 瓶颈。

2.2 IO 复杂度理论

IO 复杂度衡量算法执行期间在不同内存层级之间的数据传输量。

对于矩阵乘法 C = A B \mathbf{C} = \mathbf{A} \mathbf{B} C=AB,其中 A ∈ R M × K \mathbf{A} \in \mathbb{R}^{M \times K} A∈RM×K, B ∈ R K × N \mathbf{B} \in \mathbb{R}^{K \times N} B∈RK×N:

朴素实现 :每次计算 C i j = ∑ k A i k B k j \mathbf{C}{ij} = \sum_k \mathbf{A}{ik} \mathbf{B}_{kj} Cij=∑kAikBkj,需要从 HBM 读取整行/整列。

分块实现 :将矩阵分成 B r × B c B_r \times B_c Br×Bc 大小的块,每次将两个块加载到 SRAM,在 SRAM 中完成子矩阵乘法。

IO 复杂度下界(Hong & Kung, 1981):

IO ≥ M N K M SRAM \text{IO} \geq \frac{MNK}{M_{\text{SRAM}}} IO≥MSRAMMNK

其中 M SRAM M_{\text{SRAM}} MSRAM 是 SRAM 容量。分块矩阵乘法可以达到这个下界。

2.3 标准注意力的 IO 分析

标准注意力的计算步骤:

S = Q K T ∈ R L × L (写入 HBM) \mathbf{S} = \mathbf{Q} \mathbf{K}^T \in \mathbb{R}^{L \times L} \quad \text{(写入 HBM)} S=QKT∈RL×L(写入 HBM)

P = softmax ( S ) ∈ R L × L (读取 + 写入 HBM) \mathbf{P} = \text{softmax}(\mathbf{S}) \in \mathbb{R}^{L \times L} \quad \text{(读取 + 写入 HBM)} P=softmax(S)∈RL×L(读取 + 写入 HBM)

O = P V ∈ R L × d (读取 HBM) \mathbf{O} = \mathbf{P} \mathbf{V} \in \mathbb{R}^{L \times d} \quad \text{(读取 HBM)} O=PV∈RL×d(读取 HBM)

IO 分析

步骤 HBM 读 HBM 写 说明
计算 S \mathbf{S} S O ( L d ) O(Ld) O(Ld) O ( L 2 ) O(L^2) O(L2) 读 Q,K;写 S
计算 P \mathbf{P} P O ( L 2 ) O(L^2) O(L2) O ( L 2 ) O(L^2) O(L2) 读 S;写 P
计算 O \mathbf{O} O O ( L 2 + L d ) O(L^2 + Ld) O(L2+Ld) O ( L d ) O(Ld) O(Ld) 读 P,V;写 O
总计 O ( L 2 + L d ) O(L^2 + Ld) O(L2+Ld) O ( L 2 ) O(L^2) O(L2) ---

当 L ≫ d L \gg d L≫d 时(长序列),IO 复杂度为 O ( L 2 ) O(L^2) O(L2),由注意力矩阵 S , P \mathbf{S}, \mathbf{P} S,P 的读写主导。

问题 : L 2 L^2 L2 大小的注意力矩阵必须物化(materialized)到 HBM 中,这既是 IO 瓶颈,也是显存瓶颈。


3. FlashAttention 的核心算法

3.1 核心思想:分块 + 在线 Softmax

FlashAttention 的关键洞察:不需要将完整的注意力矩阵 S \mathbf{S} S 和 P \mathbf{P} P 物化到 HBM 中

通过分块计算,在 SRAM 中逐块完成 softmax 和矩阵乘法,只将最终结果 O \mathbf{O} O 写回 HBM。

3.2 在线 Softmax 算法

标准 softmax 需要两次遍历数据:

softmax ( x ) i = e x i ∑ j e x j \text{softmax}(\mathbf{x})_i = \frac{e^{x_i}}{\sum_j e^{x_j}} softmax(x)i=∑jexjexi

第一次遍历计算 max ⁡ ( x ) \max(\mathbf{x}) max(x) 和 ∑ j e x j − max ⁡ \sum_j e^{x_j - \max} ∑jexj−max,第二次计算归一化值。

在线 softmax(Milakov & Gimelshein, 2018)只需一次遍历:

维护运行最大值 m k m_k mk 和运行和 l k l_k lk:

m k = max ⁡ ( m k − 1 , x k ) m_k = \max(m_{k-1}, x_k) mk=max(mk−1,xk)

l k = e m k − 1 − m k l k − 1 + e x k − m k l_k = e^{m_{k-1} - m_k} l_{k-1} + e^{x_k - m_k} lk=emk−1−mklk−1+exk−mk

softmax ( x ) k = e x k − m k l k \text{softmax}(\mathbf{x})_k = \frac{e^{x_k - m_k}}{l_k} softmax(x)k=lkexk−mk

关键性质:当新数据到来时,可以增量更新,无需重新访问历史数据。

3.3 FlashAttention 算法详解

输入 : Q , K , V ∈ R L × d \mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{L \times d} Q,K,V∈RL×d,SRAM 容量 M M M

分块策略 :将 Q \mathbf{Q} Q 分成 T r = ⌈ L / B r ⌉ T_r = \lceil L / B_r \rceil Tr=⌈L/Br⌉ 个块, K , V \mathbf{K}, \mathbf{V} K,V 分成 T c = ⌈ L / B c ⌉ T_c = \lceil L / B_c \rceil Tc=⌈L/Bc⌉ 个块。

算法

复制代码
初始化: O = 0_{L×d}, l = 0_L, m = -∞_L

for j = 1 to T_c:                          // 遍历 K,V 的块
    从 HBM 加载 K_j, V_j 到 SRAM           // 块大小: B_c × d

    for i = 1 to T_r:                      // 遍历 Q 的块
        从 HBM 加载 Q_i, O_i, l_i, m_i 到 SRAM

        // 在 SRAM 中计算
        S_ij = Q_i · K_j^T                 // (B_r × B_c), 在 SRAM 中

        m̃_ij = rowmax(S_ij)                // (B_r,)
        m_new = max(m_i, m̃_ij)             // 更新最大值

        P̃_ij = exp(S_ij - m̃_ij)            // 未归一化的注意力权重
        l̃_ij = rowsum(P̃_ij)                // (B_r,)

        // 修正之前的累积值
        l_new = exp(m_i - m_new) · l_i + exp(m̃_ij - m_new) · l̃_ij

        // 更新输出
        O_i = diag(exp(m_i - m_new))^{-1} · (
              l_i · O_i + exp(m̃_ij - m_new) · P̃_ij · V_j
        ) / l_new

        更新 m_i = m_new, l_i = l_new

        将 O_i, l_i, m_i 写回 HBM          // 块大小: B_r × d + 2 B_r

3.4 算法的数学正确性

关键等式 :在处理完前 j j j 个 KV 块后, O i \mathbf{O}_i Oi 等于:

O i ( j ) = ∑ k = 1 j P ~ i k V k ∑ k = 1 j l ~ i k \mathbf{O}i^{(j)} = \frac{\sum{k=1}^{j} \tilde{\mathbf{P}}{ik} \mathbf{V}k}{\sum{k=1}^{j} \tilde{\mathbf{l}}{ik}} Oi(j)=∑k=1jl~ik∑k=1jP~ikVk

其中 P ~ i k = exp ⁡ ( S i k − m ( j ) ) \tilde{\mathbf{P}}{ik} = \exp(\mathbf{S}{ik} - m^{(j)}) P~ik=exp(Sik−m(j)) 是用最终最大值 m ( j ) m^{(j)} m(j) 修正后的未归一化权重。

证明(归纳法):

基础 : j = 1 j = 1 j=1 时,算法正确计算了第一块的 softmax 加权输出。

归纳 :假设前 j − 1 j-1 j−1 块正确。处理第 j j j 块时:

新的最大值 m ( j ) = max ⁡ ( m ( j − 1 ) , m ~ i j ) m^{(j)} = \max(m^{(j-1)}, \tilde{m}_{ij}) m(j)=max(m(j−1),m~ij)。

旧累积值需要乘以修正因子 exp ⁡ ( m ( j − 1 ) − m ( j ) ) \exp(m^{(j-1)} - m^{(j)}) exp(m(j−1)−m(j)):

l ( j ) = e m ( j − 1 ) − m ( j ) l ( j − 1 ) + e m ~ i j − m ( j ) l ~ i j l^{(j)} = e^{m^{(j-1)} - m^{(j)}} l^{(j-1)} + e^{\tilde{m}{ij} - m^{(j)}} \tilde{l}{ij} l(j)=em(j−1)−m(j)l(j−1)+em~ij−m(j)l~ij

O ( j ) = e m ( j − 1 ) − m ( j ) l ( j − 1 ) O ( j − 1 ) + e m ~ i j − m ( j ) P ~ i j V j l ( j ) \mathbf{O}^{(j)} = \frac{e^{m^{(j-1)} - m^{(j)}} l^{(j-1)} \mathbf{O}^{(j-1)} + e^{\tilde{m}{ij} - m^{(j)}} \tilde{\mathbf{P}}{ij} \mathbf{V}_j}{l^{(j)}} O(j)=l(j)em(j−1)−m(j)l(j−1)O(j−1)+em~ij−m(j)P~ijVj

这恰好等于从头计算的 softmax 加权输出。 ■ \blacksquare ■

3.5 IO 复杂度分析

定理(Dao et al., 2022):FlashAttention 的 HBM IO 复杂度为:

IO = O ( L 2 d 2 M ) \text{IO} = O\left(\frac{L^2 d^2}{M}\right) IO=O(ML2d2)

其中 M M M 是 SRAM 容量。

证明

  • 外层循环 T c = L / B c T_c = L / B_c Tc=L/Bc 次
  • 内层循环 T r = L / B r T_r = L / B_r Tr=L/Br 次
  • 每次内层循环:读取 Q i \mathbf{Q}_i Qi( B r d B_r d Brd)、 O i , l i , m i \mathbf{O}_i, \mathbf{l}_i, \mathbf{m}_i Oi,li,mi( B r d + 2 B r B_r d + 2B_r Brd+2Br),写回( B r d + 2 B r B_r d + 2B_r Brd+2Br)
  • 每次外层循环:读取 K j , V j \mathbf{K}_j, \mathbf{V}_j Kj,Vj( 2 B c d 2 B_c d 2Bcd)

总 IO:

IO = T c ⋅ T r ⋅ O ( B r d ) + T c ⋅ O ( B c d ) \text{IO} = T_c \cdot T_r \cdot O(B_r d) + T_c \cdot O(B_c d) IO=Tc⋅Tr⋅O(Brd)+Tc⋅O(Bcd)

= L B c ⋅ L B r ⋅ O ( B r d ) + L B c ⋅ O ( B c d ) = \frac{L}{B_c} \cdot \frac{L}{B_r} \cdot O(B_r d) + \frac{L}{B_c} \cdot O(B_c d) =BcL⋅BrL⋅O(Brd)+BcL⋅O(Bcd)

= O ( L 2 d B c ) + O ( L d ) = O\left(\frac{L^2 d}{B_c}\right) + O(Ld) =O(BcL2d)+O(Ld)

SRAM 约束: B r d + B c d + B r B c ≤ M B_r d + B_c d + B_r B_c \leq M Brd+Bcd+BrBc≤M,取 B r = B c = M / ( 2 d ) B_r = B_c = \sqrt{M / (2d)} Br=Bc=M/(2d) :

IO = O ( L 2 d M / d ) = O ( L 2 d 3 / 2 M ) \text{IO} = O\left(\frac{L^2 d}{\sqrt{M/d}}\right) = O\left(\frac{L^2 d^{3/2}}{\sqrt{M}}\right) IO=O(M/d L2d)=O(M L2d3/2)

在 d = 128 , M = 192 KB d = 128, M = 192 \text{KB} d=128,M=192KB 的典型设置下,IO 减少约 L 2 d / M ≈ L 2 / 1500 L^2 d / M \approx L^2 / 1500 L2d/M≈L2/1500 倍。 ■ \blacksquare ■


4. 反向传播与显存优化

4.1 问题:注意力矩阵不被存储

标准反向传播需要前向传播中存储的注意力矩阵 P \mathbf{P} P 来计算梯度:

∂ L ∂ V = P T ∂ L ∂ O \frac{\partial \mathcal{L}}{\partial \mathbf{V}} = \mathbf{P}^T \frac{\partial \mathcal{L}}{\partial \mathbf{O}} ∂V∂L=PT∂O∂L

∂ L ∂ P = ∂ L ∂ O V T \frac{\partial \mathcal{L}}{\partial \mathbf{P}} = \frac{\partial \mathcal{L}}{\partial \mathbf{O}} \mathbf{V}^T ∂P∂L=∂O∂LVT

∂ L ∂ S = ∂ L ∂ P ⊙ P − P ⊙ ( ∑ j ∂ L ∂ P i j P i j ) \frac{\partial \mathcal{L}}{\partial \mathbf{S}} = \frac{\partial \mathcal{L}}{\partial \mathbf{P}} \odot \mathbf{P} - \mathbf{P} \odot \left(\sum_j \frac{\partial \mathcal{L}}{\partial \mathbf{P}{ij}} \mathbf{P}{ij}\right) ∂S∂L=∂P∂L⊙P−P⊙(j∑∂Pij∂LPij)

FlashAttention 不存储 P \mathbf{P} P,因此需要重新计算

4.2 重计算策略

FlashAttention 在反向传播时,利用前向传播中存储的以下信息重新计算 P \mathbf{P} P:

  • O \mathbf{O} O:最终输出
  • l \mathbf{l} l:softmax 的归一化因子
  • m \mathbf{m} m:每行的最大值

重新计算 P \mathbf{P} P:

S i j = Q i K j T \mathbf{S}_{ij} = \mathbf{Q}_i \mathbf{K}_j^T Sij=QiKjT

P i j = exp ⁡ ( S i j − m i ) l i \mathbf{P}{ij} = \frac{\exp(\mathbf{S}{ij} - \mathbf{m}_i)}{\mathbf{l}_i} Pij=liexp(Sij−mi)

额外计算量 :前向传播约 2 L 2 d 2L^2 d 2L2d FLOPs,反向传播约 4 L 2 d 4L^2 d 4L2d FLOPs(重新计算 S \mathbf{S} S 和 P \mathbf{P} P)。总 FLOPs 增加约 50-75%。

:由于 IO 大幅减少,实际墙钟时间反而更快。

4.3 显存节省

方法 需存储的注意力相关数据
标准注意力 S \mathbf{S} S ( L 2 L^2 L2) + P \mathbf{P} P ( L 2 L^2 L2)
FlashAttention O \mathbf{O} O ( L d Ld Ld) + l \mathbf{l} l ( L L L) + m \mathbf{m} m ( L L L)

显存节省: O ( L 2 ) → O ( L d ) O(L^2) \to O(Ld) O(L2)→O(Ld),当 L ≫ d L \gg d L≫d 时节省巨大。


5. 完整可运行实现

5.1 FlashAttention 前向传播

python 复制代码
"""
FlashAttention --- 完整可运行实现 (教学版本)
依赖: torch >= 2.0, numpy, matplotlib
"""

import torch
import torch.nn.functional as F
import math
import numpy as np
from typing import Tuple


def standard_attention(
    Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor
) -> torch.Tensor:
    """
    标准注意力实现 (用于验证正确性)

    Q, K, V: (B, H, L, d)
    """
    d = Q.shape[-1]
    S = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d)  # (B, H, L, L)
    P = torch.softmax(S, dim=-1)  # (B, H, L, L)
    O = torch.matmul(P, V)  # (B, H, L, d)
    return O


def flash_attention_forward(
    Q: torch.Tensor,   # (B, H, L, d)
    K: torch.Tensor,   # (B, H, L, d)
    V: torch.Tensor,   # (B, H, L, d)
    block_size_r: int = 32,
    block_size_c: int = 32,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    FlashAttention 前向传播 (在线 softmax 版本)

    返回: O (输出), l (归一化因子), m (行最大值)
    """
    B, H, L, d = Q.shape
    scale = 1.0 / math.sqrt(d)

    # 初始化输出和统计量
    O = torch.zeros(B, H, L, d, device=Q.device, dtype=Q.dtype)
    l = torch.zeros(B, H, L, device=Q.device, dtype=Q.dtype)
    m = torch.full((B, H, L), float("-inf"), device=Q.device, dtype=Q.dtype)

    # 分块
    num_blocks_r = math.ceil(L / block_size_r)
    num_blocks_c = math.ceil(L / block_size_c)

    for j in range(num_blocks_c):
        # 加载 K_j, V_j 块
        j_start = j * block_size_c
        j_end = min(j_start + block_size_c, L)
        Kj = K[:, :, j_start:j_end, :]  # (B, H, B_c, d)
        Vj = V[:, :, j_start:j_end, :]

        for i in range(num_blocks_r):
            # 加载 Q_i, O_i, l_i, m_i
            i_start = i * block_size_r
            i_end = min(i_start + block_size_r, L)
            Qi = Q[:, :, i_start:i_end, :]  # (B, H, B_r, d)
            Oi = O[:, :, i_start:i_end, :]
            li = l[:, :, i_start:i_end]
            mi = m[:, :, i_start:i_end]

            # 计算 S_ij = Q_i @ K_j^T / sqrt(d)
            Sij = torch.matmul(Qi, Kj.transpose(-2, -1)) * scale  # (B, H, B_r, B_c)

            # 当前行最大值
            mij_tilde = Sij.max(dim=-1).values  # (B, H, B_r)

            # 新的最大值
            mi_new = torch.max(mi, mij_tilde)  # (B, H, B_r)

            # 未归一化的注意力权重
            Pij_tilde = torch.exp(Sij - mij_tilde.unsqueeze(-1))  # (B, H, B_r, B_c)
            lij_tilde = Pij_tilde.sum(dim=-1)  # (B, H, B_r)

            # 修正因子
            alpha = torch.exp(mi - mi_new)  # (B, H, B_r)
            beta = torch.exp(mij_tilde - mi_new)  # (B, H, B_r)

            # 更新归一化因子
            li_new = alpha * li + beta * lij_tilde  # (B, H, B_r)

            # 更新输出
            Oi_new = (
                alpha.unsqueeze(-1) * li.unsqueeze(-1) * Oi
                + beta.unsqueeze(-1) * torch.matmul(Pij_tilde, Vj)
            ) / li_new.unsqueeze(-1)

            # 写回
            O[:, :, i_start:i_end, :] = Oi_new
            l[:, :, i_start:i_end] = li_new
            m[:, :, i_start:i_end] = mi_new

    return O, l, m


def verify_flash_attention():
    """验证 FlashAttention 与标准注意力的等价性"""
    torch.manual_seed(42)

    B, H, L, d = 2, 4, 128, 64
    Q = torch.randn(B, H, L, d)
    K = torch.randn(B, H, L, d)
    V = torch.randn(B, H, L, d)

    # 标准注意力
    O_std = standard_attention(Q, K, V)

    # FlashAttention
    O_flash, l, m = flash_attention_forward(Q, K, V, block_size_r=32, block_size_c=32)

    max_diff = (O_std - O_flash).abs().max().item()
    print(f"FlashAttention 等价性验证:")
    print(f"  最大绝对误差: {max_diff:.6e}")
    print(f"  数学等价: {max_diff < 1e-5}")

    return max_diff < 1e-5

5.2 IO 复杂度对比实验

python 复制代码
def compare_io_complexity():
    """对比标准注意力与 FlashAttention 的 IO 复杂度"""
    import time

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    seq_lengths = [512, 1024, 2048, 4096]
    d = 64
    B, H = 2, 8

    print("注意力 IO 对比 (B=2, H=8, d=64)")
    print("=" * 60)
    print(f"{'序列长度':>10} | {'标准注意力 (ms)':>18} | {'FlashAttention (ms)':>20} | {'加速比':>8}")
    print("-" * 60)

    for L in seq_lengths:
        Q = torch.randn(B, H, L, d, device=device)
        K = torch.randn(B, H, L, d, device=device)
        V = torch.randn(B, H, L, d, device=device)

        # 预热
        for _ in range(3):
            _ = standard_attention(Q, K, V)
            if device.type == "cuda":
                torch.cuda.synchronize()

        # 标准注意力
        if device.type == "cuda":
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(10):
            _ = standard_attention(Q, K, V)
        if device.type == "cuda":
            torch.cuda.synchronize()
        t_std = (time.perf_counter() - t0) / 10 * 1000

        # FlashAttention (使用 PyTorch 内置实现)
        if hasattr(F, "scaled_dot_product_attention"):
            for _ in range(3):
                _ = F.scaled_dot_product_attention(Q, K, V)
            if device.type == "cuda":
                torch.cuda.synchronize()
            t0 = time.perf_counter()
            for _ in range(10):
                _ = F.scaled_dot_product_attention(Q, K, V)
            if device.type == "cuda":
                torch.cuda.synchronize()
            t_flash = (time.perf_counter() - t0) / 10 * 1000
        else:
            t_flash = t_std  # fallback

        speedup = t_std / t_flash if t_flash > 0 else 1.0
        print(f"{L:>10} | {t_std:>18.2f} | {t_flash:>20.2f} | {speedup:>7.2f}×")

5.3 显存占用对比

python 复制代码
def compare_memory_usage():
    """对比标准注意力与 FlashAttention 的显存占用"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    L = 4096
    d = 64
    B, H = 2, 8
    bytes_per_elem = 4  # FP32

    # 标准注意力需要存储 S 和 P
    std_memory = B * H * L * L * 2 * bytes_per_elem  # S + P

    # FlashAttention 只需存储 O, l, m
    flash_memory = B * H * (L * d + L + L) * bytes_per_elem

    print(f"显存对比 (L={L}, d={d}, B={B}, H={H}):")
    print(f"  标准注意力: {std_memory / 1024**2:.1f} MB (存储 S + P)")
    print(f"  FlashAttention: {flash_memory / 1024**2:.2f} MB (存储 O + l + m)")
    print(f"  显存节省: {std_memory / flash_memory:.0f}×")

6. FlashAttention 的理论意义

6.1 IO 复杂度 vs 计算复杂度

指标 标准注意力 FlashAttention
FLOPs O ( L 2 d ) O(L^2 d) O(L2d) O ( L 2 d ) O(L^2 d) O(L2d)(相同)
HBM IO O ( L 2 + L d ) O(L^2 + Ld) O(L2+Ld) O ( L 2 d 2 / M ) O(L^2 d^2 / M) O(L2d2/M)
显存 O ( L 2 + L d ) O(L^2 + Ld) O(L2+Ld) O ( L + L d ) O(L + Ld) O(L+Ld)

FlashAttention 的 FLOPs 不变,但 IO 大幅减少。在 GPU 上,IO 是实际瓶颈,因此减少 IO = 更快。

6.2 算法-硬件协同设计的典范

FlashAttention 的成功证明了一个重要的设计哲学:

不要只优化理论 FLOPs,要优化实际墙钟时间。

这需要深入理解硬件架构(内存层次、并行度、带宽),并将这些约束融入算法设计。


7. FlashAttention 数学公式总结

复制代码
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║                    FlashAttention 数学总结                                                ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║                                                                                        ║
║  1. 标准注意力:                                                                         ║
║     S = Q·K^T / √d           (L×L, 写入 HBM)                                          ║
║     P = softmax(S)           (L×L, 读写 HBM)                                           ║
║     O = P·V                  (L×d, 读 HBM)                                             ║
║     IO: O(L² + Ld)                                                                      ║
║                                                                                        ║
║  2. 在线 Softmax 增量更新:                                                              ║
║     m_new = max(m_old, m̃_j)                                                            ║
║     l_new = e^{m_old-m_new}·l_old + e^{m̃_j-m_new}·l̃_j                                 ║
║     O_new = (e^{m_old-m_new}·l_old·O_old + e^{m̃_j-m_new}·P̃_j·V_j) / l_new            ║
║     数学等价于从头计算的 softmax                                                         ║
║                                                                                        ║
║  3. FlashAttention IO 复杂度:                                                           ║
║     IO = O(L²·d² / M)      (M = SRAM 容量)                                            ║
║     当 d=128, M=192KB 时, IO 减少 ~L²/1500 倍                                          ║
║                                                                                        ║
║  4. 显存节省:                                                                           ║
║     标准: O(L²)    (存储 S + P)                                                        ║
║     Flash: O(Ld)   (存储 O + l + m)                                                    ║
║     节省: O(L/d) 倍                                                                     ║
║                                                                                        ║
║  5. 反向传播:                                                                           ║
║     不存储 P, 利用 (O, l, m) 重新计算                                                  ║
║     FLOPs 增加 ~50-75%, 但 IO 减少 → 总时间更快                                        ║
║                                                                                        ║
║  6. 分块约束:                                                                           ║
║     B_r·d + B_c·d + B_r·B_c ≤ M   (SRAM 容量)                                         ║
║     最优: B_r = B_c = √(M/(2d))                                                        ║
║                                                                                        ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝

第二篇:FlashAttention-2 --- 并行优化与因果掩码

1. 引言

FlashAttention-2(Dao, 2023)在 FlashAttention 的基础上,通过三个关键优化进一步提升了 GPU 利用率:

  1. 减少非矩阵乘法运算:将 softmax 的缩放和修正操作融合到矩阵乘法中
  2. 序列长度维度并行:在 Q 的块之间增加并行度
  3. 优化因果掩码:利用因果注意力的下三角结构减少无效计算

2. FlashAttention-2 的核心优化

2.1 减少非 matmul 运算

在 FlashAttention-1 中,每个内层循环需要执行:

python 复制代码
# 修正因子计算 (非 matmul)
alpha = torch.exp(mi - mi_new)
beta = torch.exp(mij_tilde - mi_new)
li_new = alpha * li + beta * lij_tilde

# 输出修正 (非 matmul)
Oi_new = (alpha * li * Oi + beta * Pij_tilde @ Vj) / li_new

这些逐元素操作(element-wise)虽然 FLOPs 少,但在 GPU 上效率低------GPU 的 Tensor Core 专门优化矩阵乘法,逐元素操作无法充分利用。

FlashAttention-2 的优化:将缩放操作融合到矩阵乘法中:

O i ( j ) = diag ( l ~ i ( j ) ) − 1 ( e m i ( j − 1 ) − m i ( j ) l i ( j − 1 ) O i ( j − 1 ) + P ~ i j V j ) \mathbf{O}_i^{(j)} = \text{diag}(\tilde{l}_i^{(j)})^{-1} \left( e^{m_i^{(j-1)} - m_i^{(j)}} l_i^{(j-1)} \mathbf{O}i^{(j-1)} + \tilde{\mathbf{P}}{ij} \mathbf{V}_j \right) Oi(j)=diag(l~i(j))−1(emi(j−1)−mi(j)li(j−1)Oi(j−1)+P~ijVj)

将 e m i ( j − 1 ) − m i ( j ) l i ( j − 1 ) e^{m_i^{(j-1)} - m_i^{(j)}} l_i^{(j-1)} emi(j−1)−mi(j)li(j−1) 作为标量乘法融合到最后的输出缩放中。

2.2 序列长度维度并行

FlashAttention-1 的并行策略:

复制代码
外层: 遍历 K,V 块 (串行)
内层: 遍历 Q 块 (并行)

问题:外层循环是串行的(因为在线 softmax 需要按序处理 KV 块),限制了 GPU 的并行度。

FlashAttention-2 的优化:将 Q 的块分配给不同的 thread block,KV 块的遍历在每个 thread block 内串行执行:

复制代码
Thread Block i: 处理 Q_i, 串行遍历所有 KV 块
Thread Block j: 处理 Q_j, 串行遍历所有 KV 块
...

每个 thread block 独立计算 O i \mathbf{O}_i Oi,最后合并(reduction)。这使得并行度从 T c T_c Tc(KV 块数)增加到 T r T_r Tr(Q 块数),通常 T r ≥ T c T_r \geq T_c Tr≥Tc。

2.3 因果掩码优化

因果注意力(Causal Attention)的掩码:

S i j = { Q i K j T / d if i ≥ j − ∞ if i < j \mathbf{S}_{ij} = \begin{cases} \mathbf{Q}_i \mathbf{K}_j^T / \sqrt{d} & \text{if } i \geq j \\ -\infty & \text{if } i < j \end{cases} Sij={QiKjT/d −∞if i≥jif i<j

在标准实现中, i < j i < j i<j 的块仍然被计算但被掩码为 − ∞ -\infty −∞,浪费计算。

FlashAttention-2 的优化 :对于每个 Q 块 i i i,只处理 j ≤ i j \leq i j≤i 的 KV 块:

复制代码
for i = 1 to T_r:
    for j = 1 to i:  // 只处理 j ≤ i 的块
        计算 S_ij, P_ij, 更新 O_i

进一步优化 :对于 j = i j = i j=i 的对角块,只有下三角部分有效。FlashAttention-2 对对角块使用特殊的内核,只计算下三角部分。

2.4 FlashAttention-2 的性能提升

在 A100 GPU 上的实测:

模型 FlashAttention-1 FlashAttention-2 提升
GPT-J (2K) 170 TFLOPS 230 TFLOPS 1.35×
GPT-J (4K) 180 TFLOPS 240 TFLOPS 1.33×
长序列 (8K) 190 TFLOPS 250 TFLOPS 1.32×

FlashAttention-2 达到了 A100 理论 FLOPS 的 ~73%(相比 FlashAttention-1 的 ~55%)。


3. 完整可运行实现

3.1 FlashAttention-2 因果掩码实现

python 复制代码
def flash_attention_causal(
    Q: torch.Tensor,   # (B, H, L, d)
    K: torch.Tensor,   # (B, H, L, d)
    V: torch.Tensor,   # (B, H, L, d)
    block_size: int = 64,
) -> torch.Tensor:
    """
    FlashAttention-2 因果掩码版本

    优化: 只计算 j ≤ i 的块, 对角块只计算下三角
    """
    B, H, L, d = Q.shape
    scale = 1.0 / math.sqrt(d)

    O = torch.zeros(B, H, L, d, device=Q.device, dtype=Q.dtype)
    l = torch.zeros(B, H, L, device=Q.device, dtype=Q.dtype)
    m = torch.full((B, H, L), float("-inf"), device=Q.device, dtype=Q.dtype)

    num_blocks = math.ceil(L / block_size)

    for i in range(num_blocks):
        i_start = i * block_size
        i_end = min(i_start + block_size, L)
        Qi = Q[:, :, i_start:i_end, :]
        Oi = O[:, :, i_start:i_end, :]
        li = l[:, :, i_start:i_end]
        mi = m[:, :, i_start:i_end]

        # 因果掩码: 只处理 j ≤ i 的块
        for j in range(i + 1):
            j_start = j * block_size
            j_end = min(j_start + block_size, L)
            Kj = K[:, :, j_start:j_end, :]
            Vj = V[:, :, j_start:j_end, :]

            # 计算 S_ij
            Sij = torch.matmul(Qi, Kj.transpose(-2, -1)) * scale

            # 因果掩码: 对角块只保留下三角
            if i == j:
                mask = torch.triu(
                    torch.ones(block_size, block_size, device=Q.device), diagonal=1
                ).bool()
                Sij = Sij.masked_fill(mask.unsqueeze(0).unsqueeze(0), float("-inf"))

            # 在线 softmax 更新
            mij_tilde = Sij.max(dim=-1).values
            mi_new = torch.max(mi, mij_tilde)

            Pij_tilde = torch.exp(Sij - mij_tilde.unsqueeze(-1))
            lij_tilde = Pij_tilde.sum(dim=-1)

            alpha = torch.exp(mi - mi_new)
            beta = torch.exp(mij_tilde - mi_new)

            li_new = alpha * li + beta * lij_tilde

            Oi_new = (
                alpha.unsqueeze(-1) * li.unsqueeze(-1) * Oi
                + beta.unsqueeze(-1) * torch.matmul(Pij_tilde, Vj)
            ) / li_new.unsqueeze(-1)

            Oi = Oi_new
            li = li_new
            mi = mi_new

        O[:, :, i_start:i_end, :] = Oi
        l[:, :, i_start:i_end] = li
        m[:, :, i_start:i_end] = mi

    return O

3.2 因果注意力等价性验证

python 复制代码
def verify_causal_flash_attention():
    """验证因果 FlashAttention 与标准因果注意力的等价性"""
    torch.manual_seed(42)

    B, H, L, d = 2, 4, 128, 64
    Q = torch.randn(B, H, L, d)
    K = torch.randn(B, H, L, d)
    V = torch.randn(B, H, L, d)

    # 标准因果注意力
    S = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d)
    mask = torch.triu(torch.ones(L, L), diagonal=1).bool()
    S = S.masked_fill(mask.unsqueeze(0).unsqueeze(0), float("-inf"))
    P = torch.softmax(S, dim=-1)
    O_std = torch.matmul(P, V)

    # FlashAttention 因果版本
    O_flash = flash_attention_causal(Q, K, V, block_size=32)

    max_diff = (O_std - O_flash).abs().max().item()
    print(f"因果 FlashAttention 等价性验证:")
    print(f"  最大绝对误差: {max_diff:.6e}")
    print(f"  数学等价: {max_diff < 1e-5}")

    return max_diff < 1e-5

4. FlashAttention-2 数学公式总结

复制代码
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║                    FlashAttention-2 数学总结                                              ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║                                                                                        ║
║  1. 减少非 matmul 运算:                                                                 ║
║     将缩放操作融合到 matmul 中, Tensor Core 利用率从 ~55% 提升到 ~73%                   ║
║                                                                                        ║
║  2. 序列长度维度并行:                                                                   ║
║     FlashAttention-1: Q 块内层并行, KV 块外层串行                                       ║
║     FlashAttention-2: Q 块并行 (不同 thread block), KV 块内部串行                        ║
║     并行度: T_c → T_r (通常 T_r ≥ T_c)                                                  ║
║                                                                                        ║
║  3. 因果掩码优化:                                                                       ║
║     只计算 j ≤ i 的块 (跳过上三角)                                                      ║
║     对角块: 只计算下三角部分                                                            ║
║     计算量减少约 50%                                                                    ║
║                                                                                        ║
║  4. 性能:                                                                               ║
║     A100 上达到 ~250 TFLOPS (理论峰值的 ~73%)                                           ║
║     相比 FlashAttention-1 提升 ~33%                                                     ║
║                                                                                        ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝

第三篇:FlashAttention-3 --- Hopper 架构异步流水线

1. 引言

FlashAttention-3(Shah et al., 2024)针对 NVIDIA Hopper 架构(H100/H200)的硬件特性进行了深度优化:

  1. 异步 WGMMA 指令:利用 Hopper 的异步矩阵乘法指令,实现计算与数据搬运的重叠
  2. FP8 低精度支持:利用 Hopper 的 FP8 Tensor Core,实现 2× 的计算吞吐
  3. 块量化(Block Quantization):在保持精度的同时实现 FP8 注意力

2. Hopper 架构的新特性

2.1 异步执行模型

Hopper 引入了 TMA(Tensor Memory Accelerator)WGMMA(Warp Group Matrix Multiply-Accumulate) 指令:

  • TMA:异步将数据从 HBM 搬运到 SMEM(共享内存),不占用计算单元
  • WGMMA:异步执行矩阵乘法,可以直接从 SMEM 读取操作数

关键 :TMA 和 WGMMA 可以同时执行,实现计算与数据搬运的完全重叠。

2.2 FlashAttention-3 的三阶段流水线

复制代码
时间 →  ┌─────────┬─────────┬─────────┐
        │ 阶段 1   │ 阶段 2   │ 阶段 3   │
        │ TMA 加载 │ WGMMA    │ Softmax  │
        │ KV 块    │ Q·K^T    │ 修正     │
        ├─────────┼─────────┼─────────┤
        │ TMA 加载 │ WGMMA    │ Softmax  │
        │ 下一 KV  │ P·V      │ 修正     │
        └─────────┴─────────┴─────────┘

三个阶段完全重叠,GPU 的计算单元和内存搬运单元同时工作。

2.3 FP8 与块量化

FP8(E4M3 格式)的动态范围有限,直接应用于注意力会导致精度损失。

块量化:将 Q、K 分成小块,每块独立缩放到 FP8 范围:

Q FP8 = Quantize ( Q block s q ) , s q = max ⁡ ( ∣ Q block ∣ ) FP8_MAX \mathbf{Q}{\text{FP8}} = \text{Quantize}\left(\frac{\mathbf{Q}{\text{block}}}{s_q}\right), \quad s_q = \frac{\max(|\mathbf{Q}_{\text{block}}|)}{\text{FP8\_MAX}} QFP8=Quantize(sqQblock),sq=FP8_MAXmax(∣Qblock∣)

注意力计算后反量化:

S = Q FP8 K FP8 T d ⋅ s q ⋅ s k \mathbf{S} = \frac{\mathbf{Q}{\text{FP8}} \mathbf{K}{\text{FP8}}^T}{\sqrt{d}} \cdot s_q \cdot s_k S=d QFP8KFP8T⋅sq⋅sk

2.4 FlashAttention-3 的性能

在 H100 GPU 上:

精度 FlashAttention-2 FlashAttention-3 提升
FP16 350 TFLOPS 580 TFLOPS 1.66×
FP8 N/A 1100 TFLOPS 3.14× (vs FA2 FP16)

FlashAttention-3 达到了 H100 理论 FLOPS 的 ~75% (FP16)和 ~73%(FP8)。


3. FlashAttention-3 数学公式总结

复制代码
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║                    FlashAttention-3 数学总结                                              ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║                                                                                        ║
║  1. 异步流水线:                                                                         ║
║     TMA (数据搬运) ⊕ WGMMA (矩阵乘法) ⊕ Softmax  三阶段完全重叠                        ║
║     GPU 计算单元和内存单元同时满载                                                      ║
║                                                                                        ║
║  2. FP8 块量化:                                                                         ║
║     Q_FP8 = Quantize(Q_block / s_q)                                                    ║
║     s_q = max(|Q_block|) / FP8_MAX                                                     ║
║     S = Q_FP8 · K_FP8^T · s_q · s_k / √d                                              ║
║                                                                                        ║
║  3. 性能 (H100):                                                                        ║
║     FP16: ~580 TFLOPS (理论峰值的 ~75%)                                                ║
║     FP8:  ~1100 TFLOPS (理论峰值的 ~73%)                                               ║
║                                                                                        ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝

4. FlashAttention 系列演进总结

复制代码
FlashAttention (2022)
├── 核心: 分块 + 在线 Softmax
├── IO: O(L²d²/M)
├── 显存: O(Ld)
└── A100: ~180 TFLOPS

FlashAttention-2 (2023)
├── 优化 1: 减少非 matmul 运算
├── 优化 2: 序列长度维度并行
├── 优化 3: 因果掩码优化
└── A100: ~250 TFLOPS (+39%)

FlashAttention-3 (2024)
├── 优化 1: Hopper 异步流水线 (TMA + WGMMA)
├── 优化 2: FP8 块量化
└── H100: ~1100 TFLOPS FP8

参考文献

  1. Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.
  2. Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. ICLR 2024.
  3. Shah, J., Bikshandi, G., Zhang, Y., Thakkar, V., Ramani, P., & Dao, T. (2024). FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision. arXiv:2407.08691.
  4. Vaswani, A., Shazeer, N., et al. (2017). Attention Is All You Need. NeurIPS 2017.
  5. Milakov, M., & Gimelshein, N. (2018). Online normalizer calculation for softmax. arXiv:1805.02867.
  6. Hong, J. W., & Kung, H. T. (1981). I/O complexity: The red-blue pebble game. STOC 1981.
  7. NVIDIA. (2022). NVIDIA H100 Tensor Core GPU Architecture. Technical Report.
相关推荐
来让爷抱一个1 小时前
阿里发布Qwen3.7-Plus:连续跑11小时,自主开发了一个App
人工智能
圣殿骑士-Khtangc1 小时前
MoE 混合专家模型深度解析:DeepSeek-V3 和 Qwen-MoE 的工程奥秘
人工智能
IT_陈寒1 小时前
Python列表的+=操作符坑了我一整天
前端·人工智能·后端
高洁011 小时前
用知识图谱重构搜索引擎
人工智能·python·数据挖掘·virtualenv·知识图谱
广州灵眸科技有限公司1 小时前
3Tops NPU + 4核高性能架构:灵眸科技EASY-EAI-PI2开发板,为边缘AI开启“easy模式”
服务器·前端·人工智能·python·科技·深度学习·架构
小e说说2 小时前
海同科技可信吗?16年IT教育品牌深度实测解析
大数据·人工智能
满怀冰雪2 小时前
第05篇-滑动窗口算法-一套模板解决子串与子数组问题
java·算法
apcipot_rain2 小时前
计科八股20260609——10分钟速通《线性代数》,知识点极简版
人工智能·线性代数·机器学习