FlashAttention 学习笔记:从公式到分布式

FlashAttention 切分详解:从公式到分布式

  • [1. 核心公式与切分逻辑 (The Tiling Strategy)](#1. 核心公式与切分逻辑 (The Tiling Strategy))
  • [2. 具体数字举例 (Step-by-Step)](#2. 具体数字举例 (Step-by-Step))
        • [Step 1: 加载 Q 1 Q_1 Q1 (外循环开始)](#Step 1: 加载 Q 1 Q_1 Q1 (外循环开始))
        • [Step 2: 加载 K 1 , V 1 K_1, V_1 K1,V1 (内循环第 1 轮)](#Step 2: 加载 K 1 , V 1 K_1, V_1 K1,V1 (内循环第 1 轮))
        • [Step 3: 加载 K 2 , V 2 K_2, V_2 K2,V2 (内循环第 2 轮)](#Step 3: 加载 K 2 , V 2 K_2, V_2 K2,V2 (内循环第 2 轮))
        • [Step 4: 写回](#Step 4: 写回)
  • [3. "分布式"应用](#3. “分布式”应用)
    • [场景 A:Ring Attention (针对长文本 Prefill)](#场景 A:Ring Attention (针对长文本 Prefill))
    • [场景 B:FlashDecoding (针对 Decoding)](#场景 B:FlashDecoding (针对 Decoding))
    • [4. 总结](#4. 总结)

1. 核心公式与切分逻辑 (The Tiling Strategy)

宏观公式

O = Softmax ( Q K T ) V O = \text{Softmax}(Q K^T) V O=Softmax(QKT)V

其中:

  • Q Q Q (Query): [ N , d ] [N, d] [N,d] ( N N N=序列长度)
  • K , V K, V K,V (Key, Value): [ N , d ] [N, d] [N,d]
  • O O O (Output): [ N , d ] [N, d] [N,d]

微观切分 (Tiling)

FlashAttention 不会一次性算出 N × N N \times N N×N 的 Attention 矩阵。它把 Q , K , V Q, K, V Q,K,V 切分成小块(Block),放入 SRAM(高速缓存/本地内存)中计算。

  • Q Q Q 切分(外循环): 将 Q Q Q 按行切分成 T r T_r Tr 个块,每块大小 B r × d B_r \times d Br×d。记为 Q 1 , Q 2 , ... Q_1, Q_2, \dots Q1,Q2,...
  • K , V K, V K,V 切分(内循环): 将 K , V K, V K,V 按行(对应 Attention 矩阵的列)切分成 T c T_c Tc 个块,每块大小 B c × d B_c \times d Bc×d。记为 K 1 , V 1 , ... K_1, V_1, \dots K1,V1,...

算法流程 (双层循环)

python 复制代码
for i in 1 to Tr:  # 外循环:遍历 Query 块 (加载 Qi 到 SRAM)
    # 初始化局部累加器 O_i, l_i (sum), m_i (max)
    for j in 1 to Tc:  # 内循环:遍历 Key/Value 块 (加载 Kj, Vj 到 SRAM)
        1. 计算分数: S_ij = Qi * Kj^T
        2. 更新统计量 (Online Softmax): 更新局部 max 和 sum
        3. 计算局部结果: P_ij = Softmax(S_ij)
        4. 累加到 O_i: O_i = O_i + P_ij * Vj (注意这里有 rescale)
    # 内循环结束,O_i 计算完成,写回 HBM

2. 具体数字举例 (Step-by-Step)

为了看懂,我们把维度设得很小:

  • 序列长度 N = 4 N = 4 N=4
  • Block 大小 B = 2 B = 2 B=2
  • 所以 Q Q Q 被切成 2 块 ( Q 1 , Q 2 Q_1, Q_2 Q1,Q2), K , V K, V K,V 被切成 2 块 ( K 1 , V 1 K_1, V_1 K1,V1 和 K 2 , V 2 K_2, V_2 K2,V2)。

目标:计算 O 1 O_1 O1 (也就是前 2 个 Token 的输出)。

Step 1: 加载 Q 1 Q_1 Q1 (外循环开始)

SRAM 中读入 Q 1 Q_1 Q1 (Token 0, 1)。此时 O 1 O_1 O1 初始化为 0。

Step 2: 加载 K 1 , V 1 K_1, V_1 K1,V1 (内循环第 1 轮)

从 HBM 读入 K 1 , V 1 K_1, V_1 K1,V1 (Token 0, 1)。

  1. 算分数: S 11 = Q 1 × K 1 T S_{11} = Q_1 \times K_1^T S11=Q1×K1T (这是一个 2 × 2 2\times2 2×2 的小矩阵)。
    • 假设 S 11 = [ 10 20 10 10 ] S_{11} = \begin{bmatrix} 10 & 20 \\ 10 & 10 \end{bmatrix} S11=[10102010]
  2. 局部 Softmax:
    • 行最大值 m n e w = [ 20 , 10 ] m_{new} = [20, 10] mnew=[20,10]。
    • 算出局部概率 P 11 P_{11} P11。
  3. 更新输出: O 1 = P 11 × V 1 O_{1} = P_{11} \times V_1 O1=P11×V1。
    • 此时 O 1 O_1 O1 包含了 Token 0,1 对 Token 0,1 的注意力结果。
Step 3: 加载 K 2 , V 2 K_2, V_2 K2,V2 (内循环第 2 轮)

从 HBM 读入 K 2 , V 2 K_2, V_2 K2,V2 (Token 2, 3)。注意: Q 1 Q_1 Q1 还在 SRAM 里,不用动!

  1. 算分数: S 12 = Q 1 × K 2 T S_{12} = Q_1 \times K_2^T S12=Q1×K2T。
    • 假设 S 12 = [ 30 5 5 5 ] S_{12} = \begin{bmatrix} 30 & 5 \\ 5 & 5 \end{bmatrix} S12=[30555]
  2. Online Softmax 更新 (关键):
    • 对比上一轮的最大值 m o l d = [ 20 , 10 ] m_{old}=[20, 10] mold=[20,10] 和现在的最大值 [ 30 , 5 ] [30, 5] [30,5]。
    • 第一行:新最大值是 30。说明上一轮算出的 O 1 O_1 O1 第一行偏小了,需要乘以 e 20 − 30 e^{20-30} e20−30 缩小(Rescale) 之前的贡献,再加上新的贡献。
  3. 累加输出: O 1 = Rescale ( O 1 ) + P 12 × V 2 O_1 = \text{Rescale}(O_1) + P_{12} \times V_2 O1=Rescale(O1)+P12×V2。
Step 4: 写回

内循环结束,所有 K , V K, V K,V 都看过了。 O 1 O_1 O1 就是最终结果,写回 HBM。


3. "分布式"应用

在单 GPU 上,上面的循环是串行执行的。但在分布式上,我们可以把循环拆开并行

场景 A:Ring Attention (针对长文本 Prefill)

这是 FlashAttention 分布式版本的最经典实现。
逻辑: 将 FlashAttention 的内循环(遍历 K, V 块) 变成 跨设备的"传球"

假设:芯片有 2 个 Memory Block (Core 0, Core 1)。

  • 序列长 N = 4 N=4 N=4。
  • Core 0 存了 Token 0,1 的数据 ( Q 0 , K 0 , V 0 Q_0, K_0, V_0 Q0,K0,V0)。
  • Core 1 存了 Token 2,3 的数据 ( Q 1 , K 1 , V 1 Q_1, K_1, V_1 Q1,K1,V1)。

目标: Core 0 需要计算 Q 0 Q_0 Q0 对全量 K K K ( K 0 K_0 K0 和 K 1 K_1 K1) 的注意力。

并行流程:

  1. Phase 1 (本地计算):

    • Core 0: 计算 A t t n ( Q 0 , K 0 , V 0 ) Attn(Q_0, K_0, V_0) Attn(Q0,K0,V0)。 (自己算自己的)
    • Core 1: 计算 A t t n ( Q 1 , K 1 , V 1 ) Attn(Q_1, K_1, V_1) Attn(Q1,K1,V1)。
    • 此时,NoC 网络空闲。
  2. Phase 2 (通信 + 计算重叠):

    • 通信: Core 0 把 K 0 , V 0 K_0, V_0 K0,V0 发给 Core 1;Core 1 把 K 1 , V 1 K_1, V_1 K1,V1 发给 Core 0。 (形成一个环)
    • 计算: 当 Core 0 收到 K 1 , V 1 K_1, V_1 K1,V1 后,立即计算 A t t n ( Q 0 , K 1 , V 1 ) Attn(Q_0, K_1, V_1) Attn(Q0,K1,V1)。
    • Online Softmax: 利用在线公式,把 Phase 2 的结果和 Phase 1 的结果合并。
  • 在片内,Memory Block 之间走 NoC (片上网络) ,带宽极高,这就是 Ring Attention 的最佳舞台。

场景 B:FlashDecoding (针对 Decoding)

逻辑: 针对 FlashAttention 外循环(Q)无法并行 的问题(因为 Decoding 时 Q Q Q 只有 1 行)。
解法: 强行切分 K 维度 (Split-K)。

假设: 芯片有 100 个核。KV Cache 很大,散落在 100 个 Block 里。

  1. 广播 Q: 把当前的 Query ( 1 × d 1 \times d 1×d) 广播给所有 100 个核。
  2. 并行计算 (Map):
    • Core 0 计算 Q Q Q 和 Block 0 里的 K , V K, V K,V 的 Attention → \rightarrow → 得到 Partial_O_0
    • ...
    • Core 99 计算 Q Q Q 和 Block 99 里的 K , V K, V K,V 的 Attention → \rightarrow → 得到 Partial_O_99
  3. 树状归约 (Reduce):
    • 这 100 个 Partial_O 必须要合并。
    • 利用 Online Softmax 公式,两两合并,最终得到全局的 O O O。

4. 总结

FlashAttention 是怎么切分的?分布式场景怎么用?

"FlashAttention 采用了**双层分块(Tiling)**策略:

  1. 外层循环切分 Query,负责决定输出的行。
  2. 内层循环切分 Key/Value,负责在 SRAM 中流式计算并累加结果。
  3. 利用 Online Softmax 技巧,保证了分块计算后能还原出精确的全局 Softmax 结果,避免了 N × N N \times N N×N 矩阵的显存读写。

分布式/多核架构时,这种切分方式可以自然地映射为并行策略:

  • 针对 Prefill (长文本): 采用 Ring Attention 模式。将内层循环的 KV 加载变成片上网络的数据流动。每个 Core 固定处理一部分 Q,让 KV Block 在 Core 之间流转,实现计算与通信的完美重叠。
  • 针对 Decoding (生成): 由于 Q 很小,采用 Split-K (FlashDecoding) 模式。将 KV Cache 物理打散到所有 Block,广播 Q,让所有 Core 并行计算部分 Attention,最后通过片上归约树(Reduction Tree)合并结果。

这种软硬协同的切分,能最大化利用片内的高带宽和多核算力。"

相关推荐
愚公搬代码2 小时前
【愚公系列】《AI+直播营销》046-销讲型直播内容策划(适合销讲型直播的产品和服务)
人工智能
AlphaFinance2 小时前
Compact命令实践指南
人工智能·claude
臭东西的学习笔记2 小时前
论文学习——人类抗体从通用蛋白质语言模型的高效进化
人工智能·学习·语言模型
HZjiangzi2 小时前
考古现场三维记录革新:思看科技SIMSCAN-E无线扫描仪应用详解
人工智能·科技
一招定胜负2 小时前
opencv视频处理
人工智能·opencv·音视频
KG_LLM图谱增强大模型2 小时前
ARK投资2026年度大创意报告:把握颠覆性创新的未来十年
人工智能·知识图谱
沉淅尘2 小时前
Context Engineering: 优化大语言模型性能的关键策略与艺术
数据库·人工智能·语言模型
救救孩子把2 小时前
59-机器学习与大模型开发数学教程-5-6 Adam、RMSProp、AdaGrad 等自适应优化算法
人工智能·算法·机器学习
王莽v22 小时前
LLM 分布式推理:切分、通信与优化
人工智能·分布式