目录
前言
看了几个视频和几篇文章学习了下 Flash Attention,记录下个人学习笔记,仅供自己参考😄
refer1:Flash Attention 为什么那么快?原理讲解
refer2:LLM(17):从 FlashAttention 到 PagedAttention, 如何进一步优化 Attention 性能
refer3:https://chatgpt.com/
0. 简述
这篇文章我们简单讨论下 Flash Attention(FA) 的原理,FA 现在是训练大模型默认采用的技术,从论文题目《FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness》就可以看出它有两大优势:
- Fast:加快模型训练的速度
- Memory-Efficient:内存高效,它可以减少显存的占用
并且 FA 保证 exact attention,也就是它和标准的 attention 计算得到的结果是完全一致的,并不像其它的一些算法是以降低 attention 的精度为代价来提高训练速度的。最后标题中的 with IO-Awareness 说明了它的整个算法是以改进 I/O 效率为目的的
论文的标题言简意赅,直接说明了 Flash Attention 的优势和目的
1. self-attention
在讲解 Flash Attention 之前,我们先看下 Transformer 中 self-attention 的标准计算过程,如下图所示:

Note :在上图中 N N N 代表序列长度, D D D 代表每个 head 的维度
上图展示了 self-attention 的标准计算过程,首先是输入的 token 向量通过 W q , W k , W v W_q,W_k,W_v Wq,Wk,Wv 生成对应的 Q , K , V Q,K,V Q,K,V 矩阵,然后 Q K T QK^T QKT 矩阵相乘得到注意力分数矩阵 S S S,接着对 S S S 矩阵按行求取 S o f t m a x \mathrm{Softmax} Softmax 得到注意力概率分布矩阵 P P P,最后利用 P V PV PV 相乘得到最终的注意力输出 O O O
为了简单起见,这里我们省略了缩放因子 1 D \frac{1}{\sqrt{D}} D 1,也没有考虑多头注意力机制以及 dropout、mask 等

Algorithm 0 描述了标准 attention 的实现:
- 输入为三个矩阵 Q , K , V ∈ R N × d \mathbf{Q},\mathbf{K},\mathbf{V} \in \mathbb{R}^{N\times d} Q,K,V∈RN×d,初始存储于 HBM 中
- 步骤 1 :计算 S = Q K ⊤ \mathbf{S} = \mathbf{Q}\mathbf{K}^{\top} S=QK⊤
- 加载 :从高带宽存储(HBM)中分块(by blocks)读取矩阵 Q , K \mathbf{Q},\mathbf{K} Q,K
- 计算 :对每个分块执行矩阵乘法 Q K ⊤ \mathbf{Q}\mathbf{K}^{\top} QK⊤ 得到注意力分数矩阵 S \mathbf{S} S
- 写回 :将计算得到的矩阵 S \mathbf{S} S 写回到 HBM 中
- 步骤 2 :计算 P = s o f t m a x ( S ) \mathbf{P} = \mathrm{softmax} (\mathbf{S}) P=softmax(S)
- 加载 :从 HBM 中读取矩阵 S \mathbf{S} S
- 计算 :对每个分块进行 s o f t m a x \mathrm{softmax} softmax 操作,得到注意力概率分布矩阵 P \mathbf{P} P
- 写回 :将 s o f t m a x \mathrm{softmax} softmax 后的矩阵 P \mathbf{P} P 写回到 HBM
- 这里的 s o f t m a x \mathrm{softmax} softmax 通常是在最后一个维度(即序列长度维度)上进行,以保证每个位置的注意力权重之和为 1
- 步骤 3 :计算 O = P V \mathbf{O}=\mathbf{PV} O=PV
- 加载 :从 HBM 中分块读取矩阵 P , V \mathbf{P},\mathbf{V} P,V
- 计算 :对每个分块执行矩阵乘法 P V \mathbf{PV} PV,得到输出矩阵 O \mathbf{O} O
- 写回 :将最终的输出矩阵 O \mathbf{O} O 写回到 HBM
- 步骤 4 :返回结果 O \mathbf{O} O
- 最终返回的结果矩阵 O \mathbf{O} O 就是标准自注意力的输出
整个流程对应了标准注意力机制中最常见的计算公式:
O = s o f t m a x ( Q K T ) ⏟ S ⏟ P × V O = \underbrace{\mathrm{softmax}\underbrace{(QK^T)}{S}}{P} \times V O=P softmaxS (QKT)×V
那大家不禁会问为什么不一次性计算 Q K T QK^T QKT、 P V PV PV 等矩阵乘法呢,而采用分块的方式呢?🤔
在实际硬件实现(例如 GPU 上)时,会出于效率考虑,采用"分块"(block-wise)读写和计算的方式。具体原因包括:
- 显存或带宽存储的容量与带宽限制:一次性将大矩阵读入或写出会导致带宽开销过大,分块能更好地利用缓存与寄存器
- 并行效率:矩阵乘法、softmax 都可以在分块层面做并行,提升运行效率
- 避免中间激增地存储开销 :例如直接存储完整的 S S S 或 P P P 可能很大,片上内存 SRAM 存储空间不够,而分块可以减少一次性占用的显存量
2. roofline model
在讲解 FlashAttention 的优化之前,我们先来分析下标准的 attention 计算的瓶颈,我们可以使用 roofline model 来进行分析
roofline model 是一种常用的、直观的性能分析工具,用来帮助我们评估应用程序在给定硬件架构(CPU 或 GPU)上的性能上限(即理论最佳性能),并确定性能瓶颈究竟来自于算力(Compute Bound)还是存储带宽(Memory Bandwidth Bound)
关于 roofline model 韩君老师也简单讲过,大家感兴趣的可以看看:四. TensorRT模型部署优化-Roofline model

上面是 roofline model 的图示,图中坐标含义如下:
- 横轴:Operational Intensity I I I / 计算强度
- 单位 FLOP/Byte
- 每访问内存一个字节所执行的浮点运算次数
- 计算强度越高,程序越趋近于计算密集型(compute-bound),受计算能力限制;计算强度越低,则趋近于内存密集型(memory-bound),受内存带宽限制
- 纵轴:Attainable Performance P P P / 性能
- 单位 FLOP/s
- 每秒执行的浮点运算次数
- 算力 π \pi π:也称为计算平台的性能上限
- 单位 FLOP/s
- 指的是一个计算平台倾尽全力每秒钟所能完成的浮点运算数
- 带宽 β \beta β:计算平台的带宽上限
- 单位 Byte/s
- 指的是一个计算平台倾尽全力每秒所能完成的内存交换量
- 计算强度上限 I m a x = π β I_{max}=\frac{\pi}{\beta} Imax=βπ:两个指标相除即可得到计算平台的计算强度上限
- 单位 FLOPs/Byte
- 描述的是在这个计算平台上,单位内存交换最多用来进行多少次计算
roofline model 中算力 决定 "屋顶" 的高度(图中绿色线段),带宽决定 "房檐" 的斜率(图中红色线段)。roofline model 划分出的两个瓶颈区域为:
P = { β ⋅ I , w h e n I < I m a x Memory Bound π , w h e n I ⩾ I m a x Compute Bound P=\left\{ \begin{array} {ll}\beta\cdot I, & when\ \ I<I_{max}\quad{\color{red}\text{Memory Bound}} \\ \\ \pi, & when\ \ I\geqslant I_{max}\quad{\color{green}\text{Compute Bound}} \end{array}\right. P=⎩ ⎨ ⎧β⋅I,π,when I<ImaxMemory Boundwhen I⩾ImaxCompute Bound
- 计算约束(Memory Bound)---此时 HBM 访问所花费的时间相对较低,不管模型的计算强度 I I I 多大,它的理论性能 P P P 最大只能等于计算平台的算力 π \pi π。例如,具有较大维度的矩阵乘法和具有大量通道的卷积
- 带宽约束(Compute Bound)---此时模型位于 "房檐" 区间,模型理论性能 P P P 大小完全由计算平台的带宽上限 β \beta β(房檐的斜率)以及模型自身的计算强度 I I I 所决定。例如,elementwise 操作(如 activation,dropout 等)和规约操作(如 sum,softmax,batch normalization,layer normalization 等)
在 self-attention 中,它的计算瓶颈主要是 Memory Bound,这主要是因为:
- 大量读写中间结果
- self-attention 的标准实现中会将 s o f t m a x ( Q K T ) \mathrm{softmax}(QK^T) softmax(QKT) 的中间分数矩阵 S S S 先写回 HBM,再读回 SRAM 做 softmax,再写出 P P P,随后又要读入 P P P 才能与 V V V 相乘
- 这类频繁的显存/主存的 I/O 极大地消耗带宽
- 序列长度变大时, N 2 N^2 N2 读写地影响更严重
- 注意力分数矩阵和概率矩阵的大小是 N 2 N^2 N2,随着 N N N 的增加,存储和带宽消耗呈平方级增长,很快就变成主要瓶颈
对于 Memory Bound 的优化一般是进行 fusion 融合操作,不对中间结果缓存,减少 HBM 的访问,如下图所示:

因此 FlashAttention 论文的优化思路本质就是要减少不必要的中间读写,尽可能地使用 SRAM 来加快计算速度,在内部使用 tiling 计算与 softmax 融合,提高算术强度,降低访存量,从而减轻内存带宽瓶颈
前面反复提到的 HBM 和 SRAM 究竟指什么内存呢,二者有什么区别呢?🤔
Note :杜老师课程时有简单讲解过内存模型,大家感兴趣的可以看看:3.3.cuda运行时API-内存的学习,pinnedmemory,内存效率问题
要搞清楚 HBM 和 SRAM 的区别,我们首先需要对 GPU 存储体系结构做了解:

上图展示了 CUDA 编程模型下各种类型的存储,以及它们在 GPU 体系结构中的特点、作用范围和访问方式,大家可以简单看看
我们来看一个不太严谨但是能辅助理解的图:

这张图展示了显卡中常见内存所处位置,其中 shared memory 为片上内存,global memory 为片外内存,我们需要知道的是距离计算芯片(运算单元)越近的内存,速度越快,但空间越小,价格越贵。因此 shared memory 的存储容量是远远小于 global memory 的,但它的速度是要更快的
ok,再回到我们的 HBM 和 SRAM 的话题
HBM (H igh B andwidth M emory)是指 "高带宽内存",属于 GPU 上的全局内存(global memory)。SRAM (S tatic R andom-A ccess Memory)是指 "静态随机存取存储器",在 GPU 架构中通常指代片上(on-chip)的快速缓存、寄存器或共享内存(shared memory/L1 Cache/寄存器等)
从性能角度看,SRAM(片上内存)比 HBM(片外内存)更快、延迟更低,但容量要小得多,HBM 容量更大,但带宽和延迟都逊于片上内存

3. 矩阵分块
下面我们来看下如何通过矩阵分块和融合多个计算来减少对 HBM 的访问,整个计算过程如下(假设 N = 6 , D = 4 N=6,D=4 N=6,D=4):

Note :这里我们先跳过 softmax 操作,因为它比较特殊,分块计算比较麻烦,后面我们再单独讨论,我们先认为结果直接就是 S × V S\times V S×V
首先我们从 HBM 中分块读取 Q Q Q 的前两行, K T K^T KT 的前三列,然后传入到 SRAM 中对它们进行矩阵乘法 Q K T QK^T QKT 的计算得到 S S S。得到的 S S S 并不存入 HBM 中,而是直接和 V V V 的分块进行计算,得到 O O O 的前两行
值得注意的是这时我们得到的 O O O 的前两行并不是最终的结果,因为我们知道 O O O 是对所有的 V V V 的一个加权平均,目前只是对 V V V 的前三行进行加权平均,后面我们还要对 O O O 进行更新,因此这里我们用浅一点的颜色来表示,这个值还只是一个中间结果

接着 K K K 和 V V V 的分块还保留在 SRAM 中,从 HBM 中读取 Q Q Q 的中间两行,然后经过同样的计算,得到 O O O 的第三行和第四行的中间结果

然后继续保留 K K K 和 V V V 的分块在 SRAM 中,从 H B M HBM HBM 里读取 Q Q Q 的最后两行,经过同样的计算,得到 O O O 的最后两行的中间结果

接下来读取 K T K^T KT 的后三列, V V V 的后三行, Q Q Q 的前两行进行计算,这里算出的 O O O 是对 V V V 的后三行值的加权平均,再从 HBM 中读取 O O O 之前保存的中间结果(也就是对 V V V 的前三行的加权平均值进行加和),这就是 O O O 的前两行的最终结果了

同样保持 K K K 和 V V V 的分块不变,从 HBM 里读取下一个分块的 Q Q Q 进行计算,从 HBM 里读取之前计算的中间结果 O ′ O^{\prime} O′,加和更新后存入 HBM

最后继续 SRAM 里面的 K K K 和 V V V 分块不变,从 HBM 里读取最后一个分块的 Q Q Q 进行计算,从 HBM 里读取之前计算的中间结果 O ′ O^{\prime} O′,加和更新后存入 HBM,这时就完成了一个 attention 的计算
我们可以发现,通过对矩阵分块以及将多步计算进行融合,中途没有将中间计算结果 S S S 存入 HBM 中,大大减少了 I/O 的时间,这一切看起来都不错,除了 softmax
softmax 是按行进行的,只有这一行所有的数据都计算完成后,才能进行之后对 V V V 的加权求和计算,所以如果想要让之前的矩阵分块对 attention 多步进行融合计算,前提是必须解决 softmax 分块计算的问题
4. softmax分块
下面我们就来讨论 softmax 的分块计算
在讨论之前我们先来看原始 softmax 中存在的数值溢出问题,我们在上篇文章中也详细讲过,大家感兴趣的可以看看:从Online Softmax到FlashAttention
原始 softmax 的计算公式如下:
softmax ( x i ) = e x i ∑ i = 1 N e x i \text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{i=1}^{N}e^{x_i}} softmax(xi)=∑i=1Nexiexi
现在我们训练都是在混合精度 FP16 下进行的,FP16 所能表示的最大数值是 65536,这意味着当 x ≥ 11 x \geq 11 x≥11 时, e x e^x ex 就超过了 FP16 所能表示的最大范围了
python
import numpy as np
scores = np.array([123, 456, 789])
softmax = np.exp(scores) / np.sum(np.exp(scores))
print(softmax) # [ 0. 0. nan]
当然针对数值溢出有其对应的优化方法,那就是将每一个输入值减去输入值中最大的值,这种方法也被称为 safe softmax
m ( x ) : = max i x i softmax ( x i ) = e x i − m ∑ i = 1 N e x i − m m(x) := \max_{i}x_i \\ \text{softmax}(x_i) = \frac{e^{x_i-m}}{\sum_{i=1}^{N}e^{x_i-m}} m(x):=imaxxisoftmax(xi)=∑i=1Nexi−mexi−m
python
import numpy as np
scores = np.array([123, 456, 789])
scores -= np.max(scores)
p = np.exp(scores) / np.sum(np.exp(scores))
print(p) # [5.75274406e-290 2.39848787e-145 1.00000000e+000]
参考自:一文详解Softmax函数
我们再来看下 safe softmax 的计算过程:
x = [ x 1 , ... , x N ] m ( x ) : = max i x i p ( x ) : = [ e x 1 − m ( x ) ... e x B − m ( x ) ] ℓ ( x ) : = ∑ i p ( x ) i softmax ( x ) : = p ( x ) ℓ ( x ) x=[x_1,\ldots,x_N]\\ m(x) := \max_{i}x_i \\ p(x) := [e^{x_1-m(x)}\quad\ldots \quad e^{x_B-m(x)}] \\ \ell \left( x \right):= \sum_{i}p \left( x \right)_{i} \\ \text{softmax}(x) := \frac{p(x)}{\ell(x)} x=[x1,...,xN]m(x):=imaxxip(x):=[ex1−m(x)...exB−m(x)]ℓ(x):=i∑p(x)isoftmax(x):=ℓ(x)p(x)
接下来我们看下 safe softmax 的分块是怎么做的
假设有一组 x = [ x 1 , ... , x N , ... , x 2 N ] x=[x_1,\ldots,x_N,\ldots,x_{2N}] x=[x1,...,xN,...,x2N],我们将它分为两个部分,第一个部分是 x ( 1 ) = [ x 1 , ... , x N ] x^{(1)}=[x_1,\ldots,x_N] x(1)=[x1,...,xN],第二个部分是 x ( 2 ) = [ x N + 1 , ... , x 2 N ] x^{(2)}=[x_{N+1},\ldots,x_{2N}] x(2)=[xN+1,...,x2N],也就是 x = [ x ( 1 ) , x ( 2 ) ] x=[x^{(1)},x^{(2)}] x=[x(1),x(2)],那么有:
m ( x ) = m ( [ x ( 1 ) x ( 2 ) ] ) = max ( m ( x ( 1 ) ) , m ( x ( 2 ) ) ) {m(x)=m\left(\left[x^{(1)}x^{(2)}\right]\right)=\max\left(m\left(x ^{(1)}\right),m\left(x^{(2)}\right)\right)} m(x)=m([x(1)x(2)])=max(m(x(1)),m(x(2)))
为了与分块前的结果保持统一,我们可以通过如下的方式来构造 p ( x ) p(x) p(x):
p ( x ) = [ e m ( x ( 1 ) ) − m ( x ) p ( x ( 1 ) ) e m ( x ( 2 ) ) − m ( x ) p ( x ( 2 ) ) ] = [ e m ( x ( 1 ) ) − m ( x ) [ e x 1 ( 1 ) − m ( x ( 1 ) ) ... e x N ( 1 ) − m ( x ( 1 ) ) ] e m ( x ( 2 ) ) − m ( x ) [ e x N + 1 ( 2 ) − m ( x ( 2 ) ) ... e x 2 N ( 2 ) − m ( x ( 2 ) ) ] ] = [ [ e x 1 ( 1 ) − m ( x ) ... e x N ( 1 ) − m ( x ) ] [ e x N + 1 ( 2 ) − m ( x ) ... e x 2 N ( 2 ) − m ( x ) ] ] = [ e x 1 − m ( x ) ... e x 2 N − m ( x ) ] \begin{aligned} p(x)&= \begin{bmatrix} e^{m\left(x^{(1)} \right)-m(x)}p\left(x^{(1)}\right)\quad e^{m\left(x^{(2)}\right)-m(x)}p\left(x^ {(2)}\right)\end{bmatrix}\\ &= \begin{bmatrix} e^{m\left(x^{(1)}\right)-m(x)}\big[ e^{x {1}^{(1)}-m(x^{(1)})}\quad\ldots\quad e^{x{N}^{(1)}-m(x^{(1)})} \big] \quad e^{m\left(x^{(2)}\right)-m(x)} \big[ e^{x {N+1}^{(2)}-m(x^{(2)})}\quad\ldots\quad e^{x{2N}^{(2)}-m(x^{(2)})} \big]\end{bmatrix}\\ &= \begin{bmatrix} \big[ e^{x_{1}^{(1)}-m(x)}\quad\ldots \quad e^{x_{N}^{(1)}-m(x)} \big]\quad\big[ e^{x_{N+1}^{(2)}-m(x)}\quad \ldots\quad e^{x_{2N}^{(2)}-m(x)} \big]\end{bmatrix} \\ &= \big[ e^{x_{1}-m(x)}\quad\ldots\quad e^{x_{2N}-m(x)} \big]\end{aligned} p(x)=[em(x(1))−m(x)p(x(1))em(x(2))−m(x)p(x(2))]=[em(x(1))−m(x)[ex1(1)−m(x(1))...exN(1)−m(x(1))]em(x(2))−m(x)[exN+1(2)−m(x(2))...ex2N(2)−m(x(2))]]=[[ex1(1)−m(x)...exN(1)−m(x)][exN+1(2)−m(x)...ex2N(2)−m(x)]]=[ex1−m(x)...ex2N−m(x)]
那为什么要这么构造 p ( x ) p(x) p(x) 呢?为什么要给它们乘以一个系数呢?我们来举个例子简单说明下,由于 m ( x ) m(x) m(x) 是 m ( x ( 1 ) ) m(x^{(1)}) m(x(1)) 是 m ( x ( 2 ) ) m(x^{(2)}) m(x(2)) 里较大的那个值,所以它可能等于 m ( x ( 1 ) ) m(x^{(1)}) m(x(1)) 或者 m ( x ( 2 ) ) m(x^{(2)}) m(x(2))
我们假设 m ( x ) = m ( x ( 2 ) ) m(x) = m(x^{(2)}) m(x)=m(x(2)),那么 p ( x ( 2 ) ) p(x^{(2)}) p(x(2)) 的系数 e m ( x ( 2 ) ) − m ( x ) = 1 e^{m\left(x^{(2)}\right)-m(x)}=1 em(x(2))−m(x)=1,分块部分的 p ( x ( 2 ) ) p(x^{(2)}) p(x(2)) 不变,这很好理解,因为 p ( x ( 2 ) ) p{(x^{(2)})} p(x(2)) 分块计算的时候,它的 x x x 减去的就是全局最大值,所以此时不需要再进行调整
那么对于 p ( x ( 1 ) ) p{(x^{(1)})} p(x(1)) 部分,它在分块计算时,它的 x x x 减去的是局部最大值不是全局最大值,那么它和全局最大值相比少减了多少呢?也就是上面的 m ( x ( 1 ) ) − m ( x ) m(x^{(1)})-m(x) m(x(1))−m(x),现在给它补回来,这样乘上 e m ( x ( 1 ) ) − m ( x ) e^{m\left(x^{(1)} \right)-m(x)} em(x(1))−m(x) 系数之后, p ( x ( 1 ) ) p{(x^{(1)})} p(x(1)) 就被调整为减去全局最大值的正确结果了
同理 ℓ ( x ) \ell \left(x\right) ℓ(x) 的构造方式如下:
ℓ ( x ) = ℓ ( [ x ( 1 ) x ( 2 ) ] ) = e m ( x ( 1 ) ) − m ( x ) ℓ ( x ( 1 ) ) + e m ( x ( 2 ) ) − m ( x ) ℓ ( x ( 2 ) ) \ell(x)=\ell\left(\left[x^{(1)}x^{(2)}\right]\right)=e^{m\left(x^{(1)}\right)-m(x)}\ell\left(x^{(1)}\right)+e^{m\left(x^{(2)}\right)-m(x)}\ell\left(x^{(2)}\right) ℓ(x)=ℓ([x(1)x(2)])=em(x(1))−m(x)ℓ(x(1))+em(x(2))−m(x)ℓ(x(2))
最后 softmax 计算如下:
softmax ( x ) : = p ( x ) ℓ ( x ) \text{softmax}(x) := \frac{p(x)}{\ell(x)} softmax(x):=ℓ(x)p(x)
可以看到 softmax 也可以通过分块来计算了,只是我们需要额外保存几个变量 m ( x ( 1 ) ) , m ( x ( 2 ) ) , ℓ ( x ( 1 ) ) , ℓ ( x ( 2 ) ) m(x^{(1)}),m(x^{(2)}),\ell(x^{(1)}),\ell(x^{(2)}) m(x(1)),m(x(2)),ℓ(x(1)),ℓ(x(2)),不过它们对于 softmax 每行都各自占用一个数字,存储占用非常小
另外就是在分块进行合并时,需要额外的调整计算,增加了计算量,但是这些计算量相对于减少 I/O 时间都是非常划算的
5. FlashAttention
下面我们看下 FlashAttention 伪代码的实现:

下面是 FlashAttention 的整体流程:
一、算法基本设置:
- 输入为三个矩阵 Q , K , V ∈ R N × d \mathbf{Q},\mathbf{K},\mathbf{V} \in \mathbb{R}^{N\times d} Q,K,V∈RN×d,初始存储于低速 HBM(高带宽内存)
- 片上内存 SRAM(on-chip SRAM)的大小为 M M M
二、算法详细步骤分析:
步骤 1-2(初始化阶段)
- 步骤 1 :设置分块大小
- 将序列长度 N N N 按照列块大小 B c B_c Bc 和行块大小 B r B_r Br 进行划分,使得每次在片上(on-chip SRAM)只处理一小部分 Q \mathbf{Q} Q 和 K \mathbf{K} K
- 设定每个列块的尺寸为 B c = ⌈ M 4 d ⌉ B_c=\lceil \frac{M}{4d} \rceil Bc=⌈4dM⌉,每个行块的尺寸为 B r = min ( ⌈ M 4 d ⌉ , d ) B_r=\min \left(\lceil \frac{M}{4d}\rceil,d\right) Br=min(⌈4dM⌉,d),同时确保其不超过序列长度 N N N
- 步骤 2 :在片外内存 HBM 中,初始化输出矩阵 O = ( 0 ) N × d ∈ R N × d \mathbf{O}=(0){N\times d}\in \mathbb{R}^{N\times d} O=(0)N×d∈RN×d、归一化因子向量 ℓ = ( 0 ) N ∈ R N \ell = (0){N} \in \mathbb{R}^{N} ℓ=(0)N∈RN 以及最大值向量 m = ( − ∞ ) N ∈ R N m = (- \infty)_N \in \mathbb{R}^N m=(−∞)N∈RN
步骤 3-4 (输入输出分块阶段)
- 步骤 3 :输入分块(tiling)
- 将输入矩阵 Q \mathbf{Q} Q 划分成 T r = ⌈ N B r ⌉ T_r= \lceil \frac{N}{B_r} \rceil Tr=⌈BrN⌉ 个子块 Q 1 , ... , Q T r \mathbf{Q}1,\ldots,\mathbf{Q}{T_r} Q1,...,QTr,每个子块大小为 B r × d B_r \times d Br×d
- 同样地,将 K , V \mathbf{K},\mathbf{V} K,V 划分成 T c = ⌈ N B c ⌉ T_c = \lceil \frac{N}{B_c} \rceil Tc=⌈BcN⌉ 个子块 K 1 , ... , K T c \mathbf{K}1,\ldots,\mathbf{K}{T_c} K1,...,KTc 以及 V 1 , ... , V T c \mathbf{V}1,\ldots,\mathbf{V}{T_c} V1,...,VTc,每个子块大小为 B c × d B_c \times d Bc×d
- 步骤 4 :输出与中间值分块
- 将输出矩阵 O \mathbf{O} O 同样分成 T r T_r Tr 个子块 O 1 , ... , O T r \mathbf{O}1,\ldots,\mathbf{O}{T_r} O1,...,OTr,每个子块大小为 B r × d B_r\times d Br×d
- 中间归一化因子向量 ℓ \ell ℓ 与中间最大值向量 m m m 分别划分为 T r T_r Tr 个子块 ℓ 1 , ... , ℓ T r \ell_1,\ldots,\ell_{T_r} ℓ1,...,ℓTr 以及 m 1 , ... , m T r m_1,\ldots,m_{T_r} m1,...,mTr,每块大小为 B r B_r Br
步骤 5-15(分块计算阶段)
- 步骤 5 :外层循环(步骤 5-15),遍历所有的键值块( K j , V j \mathbf{K}_j,\mathbf{V}_j Kj,Vj), j = 1 , ... , T c j=1,\ldots,T_c j=1,...,Tc
- 步骤 6 :加载对应的键值块 K j , V j \mathbf{K}_j,\mathbf{V}_j Kj,Vj 到片上内存 SRAM 上
- 步骤 7 :内层循环(步骤 7-14),遍历所有的查询块( Q i \mathbf{Q}_i Qi), i = 1 , ... , T r i=1,\ldots,T_r i=1,...,Tr,每个查询块内部的具体计算过程如下:
- 步骤 8 :从片外内存 HBM 上加载 Q i , O i , ℓ i , m i \mathbf{Q}_i,\mathbf{O}_i,\ell_i,m_i Qi,Oi,ℓi,mi 到片上内存 SRAM 上
- 步骤 9 :在片上内存 SRAM 上计算 attention 分数矩阵: S i j = Q i K j T ∈ R B r × B c \mathbf{S}_{ij}=\mathbf{Q}_i\mathbf{K}_j^T\in \mathbb{R}^{B_r\times B_c} Sij=QiKjT∈RBr×Bc
- 步骤 10 :在片上内存上对注意力分数进行分块式 softmax:
- 计算当前子块最大值(用于数值稳定): m ~ i j = r o w m a x ( S i j ) ∈ R B r \tilde{m}{ij}=\mathrm{rowmax}(\mathbf{S}{ij}) \in \mathbb{R}^{B_r} m~ij=rowmax(Sij)∈RBr
- 计算指数化矩阵: P ~ i j = e x p ( S i j − m ~ i j ) ∈ R B r × B c \tilde{\mathbf{P}}{ij} = \mathrm{exp}(\mathbf{S}{ij}-\tilde{m}_{ij})\in \mathbb{R}^{B_r\times B_c} P~ij=exp(Sij−m~ij)∈RBr×Bc
- 计算局部归一化因子(每行求和): ℓ ~ i j = r o w s u m ( P ~ i j ) ∈ R B r \tilde{\ell}{ij}=\mathrm{rowsum}(\tilde{\mathbf{P}}{ij})\in \mathbb{R}^{B_r} ℓ~ij=rowsum(P~ij)∈RBr
- 步骤 11 :更新全局归一化因子:
- 新的最大值: m i n e w = max ( m i , m ~ i j ) ∈ R B r m_i^{\mathrm{new}}=\max(m_i,\tilde{m}_{ij})\in \mathbb{R}^{B_r} minew=max(mi,m~ij)∈RBr
- 新的归一化向量: ℓ i n e w = e m i − m i n e w ℓ i + e m ~ i j − m i n e w ℓ ~ i j ∈ R B r \ell_i^{\mathrm{new}}=e^{m_i-m_i^{\mathrm{new}}}\ell_i+e^{\tilde{m}{ij}-m_i^{new}}\tilde{\ell}{ij}\in \mathbb{R}^{B_r} ℓinew=emi−minewℓi+em~ij−minewℓ~ij∈RBr
- 步骤 12 :计算注意力输出 O i \mathbf{O}_i Oi 并写回 HBM: O i ← d i a g ( ℓ i n e w ) − 1 ( d i a g ( ℓ i ) e m i − m i n e w O i + e m ~ i j − m i n e w P ~ i j V j ) \mathbf{O}i\leftarrow \mathrm{diag}(\ell_i^{\mathrm{new}})^{-1}(\mathrm{diag}(\ell_i)e^{m_i-m_i^{\mathrm{new}}}\mathbf{O}i+e^{\tilde{m}{ij}-m_i^{\mathrm{new}}}\tilde{\mathbf{P}}{ij}\mathbf{V}_j) Oi←diag(ℓinew)−1(diag(ℓi)emi−minewOi+em~ij−minewP~ijVj)
- 步骤 13 :更新全局向量 ℓ i ← ℓ i n e w , m i ← m i n e w \ell_i \leftarrow \ell_i^{\mathrm{new}},m_i\leftarrow m_i^{\mathrm{new}} ℓi←ℓinew,mi←minew 写回 HBM
- 步骤 14:内层循环结束
- 步骤 15:外层循环结束
步骤 16(输出结果阶段)
- 步骤 16 :将输出计算的 attention 结果 O \mathbf{O} O 返回
Note :列块大小为 ⌈ M 4 d ⌉ \lceil \frac{M}{4d} \rceil ⌈4dM⌉ 原因在于要存储 Q , K , V , O \mathrm{Q,K,V,O} Q,K,V,O 四个分块矩阵;行块大小为 min ( ⌈ M 4 d ⌉ , d ) \min \left(\lceil \frac{M}{4d}\rceil,d\right) min(⌈4dM⌉,d) 原因在于为了控制 Q \mathbf{Q} Q 矩阵分块最大就是一个 d × d d\times d d×d 的方阵,不要让 Q \mathrm{Q} Q 分块矩阵的行太大,否则可能使其在 SRAM 里生成的中间矩阵计算结果太大,而超出 SRAM 的大小。
其中步骤 12 需要专门来推导一下,为简洁起见,先不考虑 mask 和 dropout 操作:
O i ( j + 1 ) = P i , : j + 1 V : j + 1 = s o f t m a x ( S i , : j + 1 ) V : j + 1 = d i a g ( ℓ ( j + 1 ) ) − 1 ( exp ( [ S i , : j S i , j : j + 1 ] − m ( j + 1 ) ) ) [ V : j V j : j + 1 ] = d i a g ( ℓ ( j + 1 ) ) − 1 ( exp ( S i , : j − m ( j + 1 ) ) V : j + exp ( S i , j : j + 1 − m ( j + 1 ) ) V j : j + 1 ) = d i a g ( ℓ ( j + 1 ) ) − 1 ( e − m ( j + 1 ) exp ( S i , : j ) V : j + e − m ( j + 1 ) exp ( S i , j : j + 1 ) V j : j + 1 ) = d i a g ( ℓ ( j + 1 ) ) − 1 ( d i a g ( ℓ ( j ) ) e m ( j ) − m ( j + 1 ) d i a g ( ℓ ( j ) ) − 1 exp ( S i , : j − m ( j ) ) V : j + e − m ( j + 1 ) exp ( S i , j : j + 1 ) V j : j + 1 ) = d i a g ( ℓ ( j + 1 ) ) − 1 ( d i a g ( ℓ j ) e m ( j ) − m ( j + 1 ) P i , : j V : j + e − m ( j + 1 ) exp ( S i , j : j + 1 ) V j : j + 1 ) = d i a g ( ℓ ( j + 1 ) ) − 1 ( d i a g ( ℓ ( j ) ) e m ( j ) − m ( j + 1 ) O i ( j ) + e m ~ − m ( j + 1 ) exp ( S i , j : j + 1 − m ~ ) V j : j + 1 ) = d i a g ( ℓ ( j + 1 ) ) − 1 ( d i a g ( ℓ ( j ) ) e m ( j ) − m ( j + 1 ) O i ( j ) + e m ~ − m ( j + 1 ) P i , j : j + 1 V j : j + 1 ) \begin{aligned} \mathbf{O}i^{(j+1)} &= \mathbf{P}{i,:j+1}\mathbf{V}{:j+1} = \mathrm{softmax}\left(\mathbf{S}{i,:j+1}\right)\mathbf{V}{:j+1} \\ &= \mathrm{diag}\left(\ell^{(j+1)}\right)^{-1} \left(\exp\left([\mathbf{S}{i,:j} \quad \mathbf{S}{i,j:j+1}]-m^{(j+1)}\right)\right) \left[\mathbf{V}{:j}\atop\mathbf{V}{j:j+1}\right] \\ &= \mathrm{diag}\left(\ell^{(j+1)}\right)^{-1} \left(\exp\left(\mathbf{S}{i,:j}-m^{(j+1)}\right)\mathbf{V}{:j} + \exp\left(\mathbf{S}{i,j:j+1}-m^{(j+1)}\right)\mathbf{V}{j:j+1} \right) \\ &= \mathrm{diag}\left(\ell^{(j+1)}\right)^{-1} \left( e^{-m^{(j+1)}} \exp (\mathbf{S}{i,:j})\mathbf{V}{:j} + e^{-m^{(j+1)}} \exp (\mathbf{S}{i,j:j+1}) \mathbf{V}{j:j+1} \right) \\ &= \mathrm{diag}\left(\ell^{(j+1)}\right)^{-1} \left( \mathrm{diag}\left(\ell^{(j)}\right) e^{m^{(j)}-m^{(j+1)}} \mathrm{diag}\left(\ell^{(j)}\right)^{-1} \exp \left(\mathbf{S}{i,:j} - m^{(j)} \right) \mathbf{V}{:j} + e^{-m^{(j+1)}} \exp (\mathbf{S}{i,j:j+1}) \mathbf{V}{j:j+1} \right) \\ &= \mathrm{diag}\left(\ell^{(j+1)}\right)^{-1} \left( \mathrm{diag} \left(\ell^{j}\right) e^{m^{(j)}-m^{(j+1)}} \mathbf{P}{i,:j} \mathbf{V}{:j} + e^{-m^{(j+1)}} \exp (\mathbf{S}{i,j:j+1}) \mathbf{V}{j:j+1} \right) \\ &= \mathrm{diag}\left(\ell^{(j+1)}\right)^{-1} \left( \mathrm{diag}\left( \ell^{(j)} \right) e^{m^{(j)}-m^{(j+1)}} \mathbf{O}i^{(j)} + e^{\tilde{m}-m^{(j+1)}} \exp \left( \mathbf{S}{i,j:j+1} - \tilde{m} \right)\mathbf{V}{j:j+1} \right) \\ &= \mathrm{diag}\left(\ell^{(j+1)}\right)^{-1} \left( \mathrm{diag}\left( \ell^{(j)} \right) e^{m^{(j)}-m^{(j+1)}} \mathbf{O}i^{(j)} + e^{\tilde{m}-m^{(j+1)}} \mathbf{P}{i,j:j+1} \mathbf{V}_{j:j+1} \right) \end{aligned} Oi(j+1)=Pi,:j+1V:j+1=softmax(Si,:j+1)V:j+1=diag(ℓ(j+1))−1(exp([Si,:jSi,j:j+1]−m(j+1)))[Vj:j+1V:j]=diag(ℓ(j+1))−1(exp(Si,:j−m(j+1))V:j+exp(Si,j:j+1−m(j+1))Vj:j+1)=diag(ℓ(j+1))−1(e−m(j+1)exp(Si,:j)V:j+e−m(j+1)exp(Si,j:j+1)Vj:j+1)=diag(ℓ(j+1))−1(diag(ℓ(j))em(j)−m(j+1)diag(ℓ(j))−1exp(Si,:j−m(j))V:j+e−m(j+1)exp(Si,j:j+1)Vj:j+1)=diag(ℓ(j+1))−1(diag(ℓj)em(j)−m(j+1)Pi,:jV:j+e−m(j+1)exp(Si,j:j+1)Vj:j+1)=diag(ℓ(j+1))−1(diag(ℓ(j))em(j)−m(j+1)Oi(j)+em~−m(j+1)exp(Si,j:j+1−m~)Vj:j+1)=diag(ℓ(j+1))−1(diag(ℓ(j))em(j)−m(j+1)Oi(j)+em~−m(j+1)Pi,j:j+1Vj:j+1)
Note :对角矩阵的逆矩阵对于原始矩阵对角元素的倒数构成的矩阵( d i a g ( ℓ i n e w ) − 1 \mathrm{diag}(\ell_i^{\mathrm{new}})^{-1} diag(ℓinew)−1),乘以后面的 O \mathrm{O} O 值( ( d i a g ( ℓ i ) e m i − m i n e w O i + e m ~ i j − m i n e w P ~ i j V j ) (\mathrm{diag}(\ell_i)e^{m_i-m_i^{\mathrm{new}}}\mathbf{O}i+e^{\tilde{m}{ij}-m_i^{\mathrm{new}}}\tilde{\mathbf{P}}_{ij}\mathbf{V}_j) (diag(ℓi)emi−minewOi+em~ij−minewP~ijVj)),相当于给每一行除以每一行的求和值。这里先给原来的 O i \mathrm{O}_i Oi 先乘以原来的 d i a g ( ℓ i ) \mathrm{diag}(\ell_i) diag(ℓi) 还原回去,再加上新的 O \mathrm{O} O 值,再除以新的求和值,从而更新 O \mathrm{O} O 值到 HBM
其核心思想仍是使用分块的计算更新迭代以得到全局的结果
上述算法并没有增加额外的计算,只是将大的操作拆分成多个分块逐个计算,因此其算法复杂度仍为 O ( N 2 d ) O(N^2d) O(N2d),另外由于增加了变量 ℓ , m \ell,m ℓ,m,因此空间复杂度增加 O ( N ) O(N) O(N)
FlashAttention2 的大致思想和 FlashAttention1 类似,增加了一些工程上的优化,比如减少了非矩阵乘法的计算,将 Q \mathrm{Q} Q 改为外循环, K , V \mathrm{K,V} K,V 改为内循环,更进一步减少了对 HBM 的读写,增加了并行度
此外 FlashAttention2 进一步利用分块计算的优势,如果判断一个分块是原始矩阵的上三角部分,也就是它是被 mask 掉的部分,那么就不需要进行 attention的计算了,从而更进一步减少了计算量
OK,以上就是关于 FlashAttention 原理讲解的全部内容了
结语
这篇文章我们主要讲解了 Flash Attention 的原理,与上篇文章 [从Online Softmax到FlashAttention] 中详细的公式推导不同,这里我们主要是通过图示的方法来一步步看 FlashAttention 是怎么做的
FlashAttention 的核心思路是将 Q , K , V \mathbf{Q},\mathbf{K},\mathbf{V} Q,K,V 按块分割到片上高速存储(SRAM)中,逐块计算并累加结果,从而在有限的片上内存下高效完成大规模的注意力运算
首先从矩阵分块的角度出发来看 ( Q K T ) V (QK^T)V (QKT)V 是怎么做的,接着谈到 softmax 的分块计算,通过 safe softmax 的分块将 softmax 也融入到矩阵分块中计算,最后在前面基础上我们分析了 FlashAttention 的伪代码实现
视频和文章的讲解都非常的不错,大家感兴趣的可以多看看🤗