FlashAttention 学习笔记:从公式到分布式

FlashAttention 切分详解:从公式到分布式

  • [1. 核心公式与切分逻辑 (The Tiling Strategy)](#1. 核心公式与切分逻辑 (The Tiling Strategy))
  • [2. 具体数字举例 (Step-by-Step)](#2. 具体数字举例 (Step-by-Step))
        • [Step 1: 加载 Q 1 Q_1 Q1 (外循环开始)](#Step 1: 加载 Q 1 Q_1 Q1 (外循环开始))
        • [Step 2: 加载 K 1 , V 1 K_1, V_1 K1,V1 (内循环第 1 轮)](#Step 2: 加载 K 1 , V 1 K_1, V_1 K1,V1 (内循环第 1 轮))
        • [Step 3: 加载 K 2 , V 2 K_2, V_2 K2,V2 (内循环第 2 轮)](#Step 3: 加载 K 2 , V 2 K_2, V_2 K2,V2 (内循环第 2 轮))
        • [Step 4: 写回](#Step 4: 写回)
  • [3. "分布式"应用](#3. “分布式”应用)
    • [场景 A:Ring Attention (针对长文本 Prefill)](#场景 A:Ring Attention (针对长文本 Prefill))
    • [场景 B:FlashDecoding (针对 Decoding)](#场景 B:FlashDecoding (针对 Decoding))
    • [4. 总结](#4. 总结)

1. 核心公式与切分逻辑 (The Tiling Strategy)

宏观公式

O = Softmax ( Q K T ) V O = \text{Softmax}(Q K^T) V O=Softmax(QKT)V

其中:

  • Q Q Q (Query): [ N , d ] [N, d] [N,d] ( N N N=序列长度)
  • K , V K, V K,V (Key, Value): [ N , d ] [N, d] [N,d]
  • O O O (Output): [ N , d ] [N, d] [N,d]

微观切分 (Tiling)

FlashAttention 不会一次性算出 N × N N \times N N×N 的 Attention 矩阵。它把 Q , K , V Q, K, V Q,K,V 切分成小块(Block),放入 SRAM(高速缓存/本地内存)中计算。

  • Q Q Q 切分(外循环): 将 Q Q Q 按行切分成 T r T_r Tr 个块,每块大小 B r × d B_r \times d Br×d。记为 Q 1 , Q 2 , ... Q_1, Q_2, \dots Q1,Q2,...
  • K , V K, V K,V 切分(内循环): 将 K , V K, V K,V 按行(对应 Attention 矩阵的列)切分成 T c T_c Tc 个块,每块大小 B c × d B_c \times d Bc×d。记为 K 1 , V 1 , ... K_1, V_1, \dots K1,V1,...

算法流程 (双层循环)

python 复制代码
for i in 1 to Tr:  # 外循环:遍历 Query 块 (加载 Qi 到 SRAM)
    # 初始化局部累加器 O_i, l_i (sum), m_i (max)
    for j in 1 to Tc:  # 内循环:遍历 Key/Value 块 (加载 Kj, Vj 到 SRAM)
        1. 计算分数: S_ij = Qi * Kj^T
        2. 更新统计量 (Online Softmax): 更新局部 max 和 sum
        3. 计算局部结果: P_ij = Softmax(S_ij)
        4. 累加到 O_i: O_i = O_i + P_ij * Vj (注意这里有 rescale)
    # 内循环结束,O_i 计算完成,写回 HBM

2. 具体数字举例 (Step-by-Step)

为了看懂,我们把维度设得很小:

  • 序列长度 N = 4 N = 4 N=4
  • Block 大小 B = 2 B = 2 B=2
  • 所以 Q Q Q 被切成 2 块 ( Q 1 , Q 2 Q_1, Q_2 Q1,Q2), K , V K, V K,V 被切成 2 块 ( K 1 , V 1 K_1, V_1 K1,V1 和 K 2 , V 2 K_2, V_2 K2,V2)。

目标:计算 O 1 O_1 O1 (也就是前 2 个 Token 的输出)。

Step 1: 加载 Q 1 Q_1 Q1 (外循环开始)

SRAM 中读入 Q 1 Q_1 Q1 (Token 0, 1)。此时 O 1 O_1 O1 初始化为 0。

Step 2: 加载 K 1 , V 1 K_1, V_1 K1,V1 (内循环第 1 轮)

从 HBM 读入 K 1 , V 1 K_1, V_1 K1,V1 (Token 0, 1)。

  1. 算分数: S 11 = Q 1 × K 1 T S_{11} = Q_1 \times K_1^T S11=Q1×K1T (这是一个 2 × 2 2\times2 2×2 的小矩阵)。
    • 假设 S 11 = [ 10 20 10 10 ] S_{11} = \begin{bmatrix} 10 & 20 \\ 10 & 10 \end{bmatrix} S11=[10102010]
  2. 局部 Softmax:
    • 行最大值 m n e w = [ 20 , 10 ] m_{new} = [20, 10] mnew=[20,10]。
    • 算出局部概率 P 11 P_{11} P11。
  3. 更新输出: O 1 = P 11 × V 1 O_{1} = P_{11} \times V_1 O1=P11×V1。
    • 此时 O 1 O_1 O1 包含了 Token 0,1 对 Token 0,1 的注意力结果。
Step 3: 加载 K 2 , V 2 K_2, V_2 K2,V2 (内循环第 2 轮)

从 HBM 读入 K 2 , V 2 K_2, V_2 K2,V2 (Token 2, 3)。注意: Q 1 Q_1 Q1 还在 SRAM 里,不用动!

  1. 算分数: S 12 = Q 1 × K 2 T S_{12} = Q_1 \times K_2^T S12=Q1×K2T。
    • 假设 S 12 = [ 30 5 5 5 ] S_{12} = \begin{bmatrix} 30 & 5 \\ 5 & 5 \end{bmatrix} S12=[30555]
  2. Online Softmax 更新 (关键):
    • 对比上一轮的最大值 m o l d = [ 20 , 10 ] m_{old}=[20, 10] mold=[20,10] 和现在的最大值 [ 30 , 5 ] [30, 5] [30,5]。
    • 第一行:新最大值是 30。说明上一轮算出的 O 1 O_1 O1 第一行偏小了,需要乘以 e 20 − 30 e^{20-30} e20−30 缩小(Rescale) 之前的贡献,再加上新的贡献。
  3. 累加输出: O 1 = Rescale ( O 1 ) + P 12 × V 2 O_1 = \text{Rescale}(O_1) + P_{12} \times V_2 O1=Rescale(O1)+P12×V2。
Step 4: 写回

内循环结束,所有 K , V K, V K,V 都看过了。 O 1 O_1 O1 就是最终结果,写回 HBM。


3. "分布式"应用

在单 GPU 上,上面的循环是串行执行的。但在分布式上,我们可以把循环拆开并行

场景 A:Ring Attention (针对长文本 Prefill)

这是 FlashAttention 分布式版本的最经典实现。
逻辑: 将 FlashAttention 的内循环(遍历 K, V 块) 变成 跨设备的"传球"

假设:芯片有 2 个 Memory Block (Core 0, Core 1)。

  • 序列长 N = 4 N=4 N=4。
  • Core 0 存了 Token 0,1 的数据 ( Q 0 , K 0 , V 0 Q_0, K_0, V_0 Q0,K0,V0)。
  • Core 1 存了 Token 2,3 的数据 ( Q 1 , K 1 , V 1 Q_1, K_1, V_1 Q1,K1,V1)。

目标: Core 0 需要计算 Q 0 Q_0 Q0 对全量 K K K ( K 0 K_0 K0 和 K 1 K_1 K1) 的注意力。

并行流程:

  1. Phase 1 (本地计算):

    • Core 0: 计算 A t t n ( Q 0 , K 0 , V 0 ) Attn(Q_0, K_0, V_0) Attn(Q0,K0,V0)。 (自己算自己的)
    • Core 1: 计算 A t t n ( Q 1 , K 1 , V 1 ) Attn(Q_1, K_1, V_1) Attn(Q1,K1,V1)。
    • 此时,NoC 网络空闲。
  2. Phase 2 (通信 + 计算重叠):

    • 通信: Core 0 把 K 0 , V 0 K_0, V_0 K0,V0 发给 Core 1;Core 1 把 K 1 , V 1 K_1, V_1 K1,V1 发给 Core 0。 (形成一个环)
    • 计算: 当 Core 0 收到 K 1 , V 1 K_1, V_1 K1,V1 后,立即计算 A t t n ( Q 0 , K 1 , V 1 ) Attn(Q_0, K_1, V_1) Attn(Q0,K1,V1)。
    • Online Softmax: 利用在线公式,把 Phase 2 的结果和 Phase 1 的结果合并。
  • 在片内,Memory Block 之间走 NoC (片上网络) ,带宽极高,这就是 Ring Attention 的最佳舞台。

场景 B:FlashDecoding (针对 Decoding)

逻辑: 针对 FlashAttention 外循环(Q)无法并行 的问题(因为 Decoding 时 Q Q Q 只有 1 行)。
解法: 强行切分 K 维度 (Split-K)。

假设: 芯片有 100 个核。KV Cache 很大,散落在 100 个 Block 里。

  1. 广播 Q: 把当前的 Query ( 1 × d 1 \times d 1×d) 广播给所有 100 个核。
  2. 并行计算 (Map):
    • Core 0 计算 Q Q Q 和 Block 0 里的 K , V K, V K,V 的 Attention → \rightarrow → 得到 Partial_O_0
    • ...
    • Core 99 计算 Q Q Q 和 Block 99 里的 K , V K, V K,V 的 Attention → \rightarrow → 得到 Partial_O_99
  3. 树状归约 (Reduce):
    • 这 100 个 Partial_O 必须要合并。
    • 利用 Online Softmax 公式,两两合并,最终得到全局的 O O O。

4. 总结

FlashAttention 是怎么切分的?分布式场景怎么用?

"FlashAttention 采用了**双层分块(Tiling)**策略:

  1. 外层循环切分 Query,负责决定输出的行。
  2. 内层循环切分 Key/Value,负责在 SRAM 中流式计算并累加结果。
  3. 利用 Online Softmax 技巧,保证了分块计算后能还原出精确的全局 Softmax 结果,避免了 N × N N \times N N×N 矩阵的显存读写。

分布式/多核架构时,这种切分方式可以自然地映射为并行策略:

  • 针对 Prefill (长文本): 采用 Ring Attention 模式。将内层循环的 KV 加载变成片上网络的数据流动。每个 Core 固定处理一部分 Q,让 KV Block 在 Core 之间流转,实现计算与通信的完美重叠。
  • 针对 Decoding (生成): 由于 Q 很小,采用 Split-K (FlashDecoding) 模式。将 KV Cache 物理打散到所有 Block,广播 Q,让所有 Core 并行计算部分 Attention,最后通过片上归约树(Reduction Tree)合并结果。

这种软硬协同的切分,能最大化利用片内的高带宽和多核算力。"

相关推荐
xixixi777773 分钟前
英伟达Agent专用全模态模型出击,仿冒AI智能体泛滥成灾,《AI伦理安全指引》即将落地——AI治理迎来“技术-风险-规范”三重奏
人工智能·5g·安全·ai·大模型·英伟达·智能体
直奔標竿5 分钟前
Java开发者AI转型第二十六课!Spring AI 个人知识库实战(五)——联网搜索增强实战
java·开发语言·人工智能·spring boot·后端·spring
数据皮皮侠AI9 分钟前
中国城市可再生能源数据集(2005-2021)|顶刊 Sci Data 11 种能源面板
大数据·人工智能·笔记·能源·1024程序员节
G311354227313 分钟前
如何用 QClaw 龙虾做一个规律作息健康助理 Agent
大数据·人工智能·ai·云计算
幂律智能14 分钟前
零售行业合同管理数智化转型解决方案
大数据·人工智能·零售
旺财矿工16 分钟前
零基础搭建 OpenClaw 2.6.6 Win11 本地化运行环境
人工智能·openclaw·小龙虾·龙虾·openclaw安装包
九成宫17 分钟前
动手学深度学习PyTorch版初步安装过程
人工智能·pytorch·深度学习
Traving Yu17 分钟前
Prompt提示词工程
人工智能·prompt
NOCSAH18 分钟前
统好AI CRM功能解析:智能录入与跟进
人工智能
He少年19 分钟前
【AI 辅助编程做设备数据采集:一个真实项目的迭代复盘(OpenSpec 驱动)】
人工智能