FlashAttention 切分详解:从公式到分布式
- [1. 核心公式与切分逻辑 (The Tiling Strategy)](#1. 核心公式与切分逻辑 (The Tiling Strategy))
- [2. 具体数字举例 (Step-by-Step)](#2. 具体数字举例 (Step-by-Step))
-
-
-
- [Step 1: 加载 Q 1 Q_1 Q1 (外循环开始)](#Step 1: 加载 Q 1 Q_1 Q1 (外循环开始))
- [Step 2: 加载 K 1 , V 1 K_1, V_1 K1,V1 (内循环第 1 轮)](#Step 2: 加载 K 1 , V 1 K_1, V_1 K1,V1 (内循环第 1 轮))
- [Step 3: 加载 K 2 , V 2 K_2, V_2 K2,V2 (内循环第 2 轮)](#Step 3: 加载 K 2 , V 2 K_2, V_2 K2,V2 (内循环第 2 轮))
- [Step 4: 写回](#Step 4: 写回)
-
-
- [3. "分布式"应用](#3. “分布式”应用)
-
- [场景 A:Ring Attention (针对长文本 Prefill)](#场景 A:Ring Attention (针对长文本 Prefill))
- [场景 B:FlashDecoding (针对 Decoding)](#场景 B:FlashDecoding (针对 Decoding))
- [4. 总结](#4. 总结)
1. 核心公式与切分逻辑 (The Tiling Strategy)
宏观公式
O = Softmax ( Q K T ) V O = \text{Softmax}(Q K^T) V O=Softmax(QKT)V
其中:
- Q Q Q (Query): [ N , d ] [N, d] [N,d] ( N N N=序列长度)
- K , V K, V K,V (Key, Value): [ N , d ] [N, d] [N,d]
- O O O (Output): [ N , d ] [N, d] [N,d]
微观切分 (Tiling)
FlashAttention 不会一次性算出 N × N N \times N N×N 的 Attention 矩阵。它把 Q , K , V Q, K, V Q,K,V 切分成小块(Block),放入 SRAM(高速缓存/本地内存)中计算。
- Q Q Q 切分(外循环): 将 Q Q Q 按行切分成 T r T_r Tr 个块,每块大小 B r × d B_r \times d Br×d。记为 Q 1 , Q 2 , ... Q_1, Q_2, \dots Q1,Q2,...
- K , V K, V K,V 切分(内循环): 将 K , V K, V K,V 按行(对应 Attention 矩阵的列)切分成 T c T_c Tc 个块,每块大小 B c × d B_c \times d Bc×d。记为 K 1 , V 1 , ... K_1, V_1, \dots K1,V1,...
算法流程 (双层循环)
python
for i in 1 to Tr: # 外循环:遍历 Query 块 (加载 Qi 到 SRAM)
# 初始化局部累加器 O_i, l_i (sum), m_i (max)
for j in 1 to Tc: # 内循环:遍历 Key/Value 块 (加载 Kj, Vj 到 SRAM)
1. 计算分数: S_ij = Qi * Kj^T
2. 更新统计量 (Online Softmax): 更新局部 max 和 sum
3. 计算局部结果: P_ij = Softmax(S_ij)
4. 累加到 O_i: O_i = O_i + P_ij * Vj (注意这里有 rescale)
# 内循环结束,O_i 计算完成,写回 HBM
2. 具体数字举例 (Step-by-Step)
为了看懂,我们把维度设得很小:
- 序列长度 N = 4 N = 4 N=4
- Block 大小 B = 2 B = 2 B=2
- 所以 Q Q Q 被切成 2 块 ( Q 1 , Q 2 Q_1, Q_2 Q1,Q2), K , V K, V K,V 被切成 2 块 ( K 1 , V 1 K_1, V_1 K1,V1 和 K 2 , V 2 K_2, V_2 K2,V2)。
目标:计算 O 1 O_1 O1 (也就是前 2 个 Token 的输出)。
Step 1: 加载 Q 1 Q_1 Q1 (外循环开始)
SRAM 中读入 Q 1 Q_1 Q1 (Token 0, 1)。此时 O 1 O_1 O1 初始化为 0。
Step 2: 加载 K 1 , V 1 K_1, V_1 K1,V1 (内循环第 1 轮)
从 HBM 读入 K 1 , V 1 K_1, V_1 K1,V1 (Token 0, 1)。
- 算分数: S 11 = Q 1 × K 1 T S_{11} = Q_1 \times K_1^T S11=Q1×K1T (这是一个 2 × 2 2\times2 2×2 的小矩阵)。
- 假设 S 11 = [ 10 20 10 10 ] S_{11} = \begin{bmatrix} 10 & 20 \\ 10 & 10 \end{bmatrix} S11=[10102010]
- 局部 Softmax:
- 行最大值 m n e w = [ 20 , 10 ] m_{new} = [20, 10] mnew=[20,10]。
- 算出局部概率 P 11 P_{11} P11。
- 更新输出: O 1 = P 11 × V 1 O_{1} = P_{11} \times V_1 O1=P11×V1。
- 此时 O 1 O_1 O1 包含了 Token 0,1 对 Token 0,1 的注意力结果。
Step 3: 加载 K 2 , V 2 K_2, V_2 K2,V2 (内循环第 2 轮)
从 HBM 读入 K 2 , V 2 K_2, V_2 K2,V2 (Token 2, 3)。注意: Q 1 Q_1 Q1 还在 SRAM 里,不用动!
- 算分数: S 12 = Q 1 × K 2 T S_{12} = Q_1 \times K_2^T S12=Q1×K2T。
- 假设 S 12 = [ 30 5 5 5 ] S_{12} = \begin{bmatrix} 30 & 5 \\ 5 & 5 \end{bmatrix} S12=[30555]
- Online Softmax 更新 (关键):
- 对比上一轮的最大值 m o l d = [ 20 , 10 ] m_{old}=[20, 10] mold=[20,10] 和现在的最大值 [ 30 , 5 ] [30, 5] [30,5]。
- 第一行:新最大值是 30。说明上一轮算出的 O 1 O_1 O1 第一行偏小了,需要乘以 e 20 − 30 e^{20-30} e20−30 缩小(Rescale) 之前的贡献,再加上新的贡献。
- 累加输出: O 1 = Rescale ( O 1 ) + P 12 × V 2 O_1 = \text{Rescale}(O_1) + P_{12} \times V_2 O1=Rescale(O1)+P12×V2。
Step 4: 写回
内循环结束,所有 K , V K, V K,V 都看过了。 O 1 O_1 O1 就是最终结果,写回 HBM。
3. "分布式"应用
在单 GPU 上,上面的循环是串行执行的。但在分布式上,我们可以把循环拆开并行。
场景 A:Ring Attention (针对长文本 Prefill)
这是 FlashAttention 分布式版本的最经典实现。
逻辑: 将 FlashAttention 的内循环(遍历 K, V 块) 变成 跨设备的"传球"。
假设:芯片有 2 个 Memory Block (Core 0, Core 1)。
- 序列长 N = 4 N=4 N=4。
- Core 0 存了 Token 0,1 的数据 ( Q 0 , K 0 , V 0 Q_0, K_0, V_0 Q0,K0,V0)。
- Core 1 存了 Token 2,3 的数据 ( Q 1 , K 1 , V 1 Q_1, K_1, V_1 Q1,K1,V1)。
目标: Core 0 需要计算 Q 0 Q_0 Q0 对全量 K K K ( K 0 K_0 K0 和 K 1 K_1 K1) 的注意力。
并行流程:
-
Phase 1 (本地计算):
- Core 0: 计算 A t t n ( Q 0 , K 0 , V 0 ) Attn(Q_0, K_0, V_0) Attn(Q0,K0,V0)。 (自己算自己的)
- Core 1: 计算 A t t n ( Q 1 , K 1 , V 1 ) Attn(Q_1, K_1, V_1) Attn(Q1,K1,V1)。
- 此时,NoC 网络空闲。
-
Phase 2 (通信 + 计算重叠):
- 通信: Core 0 把 K 0 , V 0 K_0, V_0 K0,V0 发给 Core 1;Core 1 把 K 1 , V 1 K_1, V_1 K1,V1 发给 Core 0。 (形成一个环)
- 计算: 当 Core 0 收到 K 1 , V 1 K_1, V_1 K1,V1 后,立即计算 A t t n ( Q 0 , K 1 , V 1 ) Attn(Q_0, K_1, V_1) Attn(Q0,K1,V1)。
- Online Softmax: 利用在线公式,把 Phase 2 的结果和 Phase 1 的结果合并。
- 在片内,Memory Block 之间走 NoC (片上网络) ,带宽极高,这就是 Ring Attention 的最佳舞台。
场景 B:FlashDecoding (针对 Decoding)
逻辑: 针对 FlashAttention 外循环(Q)无法并行 的问题(因为 Decoding 时 Q Q Q 只有 1 行)。
解法: 强行切分 K 维度 (Split-K)。
假设: 芯片有 100 个核。KV Cache 很大,散落在 100 个 Block 里。
- 广播 Q: 把当前的 Query ( 1 × d 1 \times d 1×d) 广播给所有 100 个核。
- 并行计算 (Map):
- Core 0 计算 Q Q Q 和 Block 0 里的 K , V K, V K,V 的 Attention → \rightarrow → 得到
Partial_O_0。 - ...
- Core 99 计算 Q Q Q 和 Block 99 里的 K , V K, V K,V 的 Attention → \rightarrow → 得到
Partial_O_99。
- Core 0 计算 Q Q Q 和 Block 0 里的 K , V K, V K,V 的 Attention → \rightarrow → 得到
- 树状归约 (Reduce):
- 这 100 个
Partial_O必须要合并。 - 利用 Online Softmax 公式,两两合并,最终得到全局的 O O O。
- 这 100 个
4. 总结
FlashAttention 是怎么切分的?分布式场景怎么用?
"FlashAttention 采用了**双层分块(Tiling)**策略:
- 外层循环切分 Query,负责决定输出的行。
- 内层循环切分 Key/Value,负责在 SRAM 中流式计算并累加结果。
- 利用 Online Softmax 技巧,保证了分块计算后能还原出精确的全局 Softmax 结果,避免了 N × N N \times N N×N 矩阵的显存读写。
分布式/多核架构时,这种切分方式可以自然地映射为并行策略:
- 针对 Prefill (长文本): 采用 Ring Attention 模式。将内层循环的 KV 加载变成片上网络的数据流动。每个 Core 固定处理一部分 Q,让 KV Block 在 Core 之间流转,实现计算与通信的完美重叠。
- 针对 Decoding (生成): 由于 Q 很小,采用 Split-K (FlashDecoding) 模式。将 KV Cache 物理打散到所有 Block,广播 Q,让所有 Core 并行计算部分 Attention,最后通过片上归约树(Reduction Tree)合并结果。
这种软硬协同的切分,能最大化利用片内的高带宽和多核算力。"