本篇文章是Transformer系列的第七篇。
Transformer系列文章:
● Transformer------Attention怎么实现集中注意力
● Transformer------FeedForward模块在干什么?
● Transformer注意力机制------MHA&MQA&GQA
所有相关源码示例、流程图、模型配置与知识库构建技巧,我也将持续更新在Github:LLMHub,欢迎关注收藏!
希望大家带着下面的问题来学习,我会在文末给出答案。
- 传统 Attention 的主要性能瓶颈在哪里?为什么需要 FlashAttention?
- FlashAttention 是如何利用 shared memory 降低显存占用并提高速度的?
- FlashAttention 在实际应用中还有哪些不足或限制?
一、引言
Transformer 模型自诞生以来,已成为自然语言处理、计算机视觉、语音等领域的核心架构。而 Attention 机制作为 Transformer 的核心计算模块,其计算复杂度和显存占用在处理长序列时常常成为性能瓶颈。传统 Attention 的时间和空间复杂度为 O(n^2),这在大规模模型或长文本输入中表现为效率低下、显存不足。
FlashAttention 是由 Stanford Hazy Research 团队提出的一种高效实现方式,专为 GPU 设计,通过 I/O 感知(IO-aware)
的优化策略,在不损失精度的前提下显著加速 Attention 计算并降低显存占用。
二、Attention回顾
在标准 Transformer 中,Attention 计算如下:

其中:
- Q(Query)、K(Key)、V(Value)为输入的线性变换结果;
- QK^T生成的是 Attention Score 矩阵,大小为 (n, n),即每个 token 对所有 token 的相关性。
计算流程如下:
- 计算 QK^T:需存储一个n*n 矩阵;
- 计算 softmax;
- 乘以 V 得到最终输出。
瓶颈分析:
- 需要在 GPU 上存储整个 QK^T 结果 → 显存开销大;
- 需要频繁访问 HBM(高带宽显存)→ 带宽受限 → 性能下降。
三、FlashAttention 的核心思想(核心原理)
FlashAttention 的目标是:降低显存占用,同时提升速度,具体做法包括:
1. Tile-Based 计算
将 Q, K, V 分块为小块(tile),每次仅处理一小块:
- 利用 GPU 的片上 SRAM(Shared Memory)完成 QK^T 和 softmax
- 避免中间结果写入 HBM

上面图中左半部分是计算机的内存分布, HBM 是 "High Bandwidth Memory
" 的缩写,也就是高带宽显存 ,是一种专为高性能计算和显存密集型任务 (如 GPU、AI 加速、图形渲染等)设计的下一代显存技术。 SRAM是一种静态随机访问存储器,用于高速缓存等内部存储器,具有更快的访问速度和更低的延迟,但成本更高且占用更多芯片空间。
标准Attention的计算算法如下:

可以看到,标准 Attention 实现大量中间结果需频繁访问 HBM ,而 HBM 的访问速度远远低于 GPU 的SRAM。因此 FlashAttention 通过"tile 计算+显存访问优化 "方案,减少了对 HBM 的依赖,提高了整体执行效率。
softmax计算公式如下:

为了数值稳定性,FlashAttention采用Safe Softmax,对于向量x
同理,对于向量x = [ x1,x2],softmax可以分解计算:

这就说明即使Q,K,V被分成块也是可以计算softmax的。
2. Recomputation Strategy
为了节省存储中间的 softmax 权重,FlashAttention 在需要时重新计算部分内容,避免保存完整矩阵。
标准Attention的反向传播算法如下,其中P代表Softmax(QKᵀ / √dₖ),也就是注意力权重矩阵。
结合着Attention的计算公式更好理解


在标准 Attention 实现中,为了完成前向传播和反向传播,我们通常需要保存如下中间结果:
- QKᵀ(Attention Score 矩阵)
- softmax 权重
- Attention output(最终结果)
这些矩阵很大,尤其是在处理长序列时,显存消耗会非常高。
FlashAttention 为了降低显存占用,采取了一种策略:
在前向传播时 不保留中间矩阵 ,而是到了反向传播阶段 再把它们重新计算出来。
以 softmax 的 attention score 为例:
- 标准方法:
latex
QKᵀ → softmax → 缓存在显存中 → 用于乘V和反向传播
- FlashAttention 方法:
latex
QKᵀ → softmax → 直接用于乘V,不缓存
...
后面反向传播需要用到 softmax → 再重新计算一次 QKᵀ 和 softmax
这就节省了存 softmax 的显存开销,尤其在长序列上非常可观。
FlashAttention的前向传播算法如下:

FlashAttention的反向传播的过程如下:

可以看到其中没有存储,反向传播的过程中需要的数据都是重新计算的,这种"以算代存"的方式是一种典型的时间换空间(compute vs. memory)策略。虽然多计算一次会略微增加一点时间,但显存节省得非常明显,反而提升了整体性能,因为:
- 显存访问慢,限制吞吐;
- GPU 有大量计算资源,计算冗余可以承受;
- 避免了 HBM 带宽瓶颈。
3.Block Sparse FlashAttention
传统 Attention 是 全连接的:每个 token 都和所有其他 token 交互,计算量为 O(n^2)。
而 Sparse Attention 只计算部分 token 对的关系,常见稀疏模式包括:
- Sliding Window:每个 token 只关注它前后几个邻居;
- Block Sparse:将 Q、K、V 分成若干块(block),只计算特定 block 对之间的 attention;
- Global + Local:大部分是局部 attention,少数 token 有全局连接(如 Longformer)。
在FlashAttention 的基础上,为了进一步提升处理超长序列的性能和可扩展性,Block Sparse FlashAttention 结合了 FlashAttention 的 IO-aware 高效计算方式和 block-sparse attention mask 的稀疏结构,从而实现 更少计算 + 更少显存占用 的 attention 操作。
Block Sparse FlashAttention 的关键是在 FlashAttention 高效计算的基础上,只计算被稀疏掩码指定的 QK 块对,算法如下:

- 输入:Q、K、V 被划分为若干 block;
- 依据稀疏掩码(mask)决定哪些 Q-block 要与哪些 K-block 交互;
- 对每个有效块对,执行 FlashAttention 核心流程(QKᵀ → softmax → ×V);
- 将结果整合,拼接为完整输出。
四、FlashAttention vs 标准Attention
项目 | 原始 Attention | FlashAttention |
---|---|---|
时间复杂度 | O(n^2 d) | O(n^2 d),但更快 |
显存消耗 | 高(存储中间矩阵) | 低(tile重计算) |
速度表现 | 慢(受限于显存读写) | 快(高效访存) |
精度控制 | float32 为主 | 支持 fp16 / bf16 |
在长序列任务中,FlashAttention 可将显存减少 2-4 倍,速度提升达 2-4 倍。
五、从0手撸FlashAttention
python
for i in range(0, N, block_size):
q_block = q[:, i:i+block_size] # [B, Bq, D]
max_score = None
row_sum_exp = None
acc = torch.zeros_like(q_block)
for j in range(0, N, block_size):
k_block = k[:, j:j+block_size] # [B, Bk, D]
v_block = v[:, j:j+block_size] # [B, Bk, D]
# 1. Attention logits
scores = torch.bmm(q_block, k_block.transpose(1, 2)) * scale # [B, Bq, Bk]
# 2. Numerical stability
block_max = scores.max(dim=-1, keepdim=True).values # [B, Bq, 1]
scores = scores - block_max
exp_scores = scores.exp() # [B, Bq, Bk]
# 3. Dropout (可选)
if dropout_p > 0.0:
exp_scores = F.dropout(exp_scores, p=dropout_p, training=True)
# 4. Weighted sum
acc += torch.bmm(exp_scores, v_block) # [B, Bq, D]
# 5. Softmax normalization (log-sum-exp trick)
block_sum = exp_scores.sum(dim=-1, keepdim=True) # [B, Bq, 1]
if row_sum_exp is None:
row_sum_exp = block_sum
max_score = block_max
else:
row_sum_exp += block_sum
max_score = torch.max(max_score, block_max)
# Normalize accumulated result
output[:, i:i+block_size] = acc / (row_sum_exp + 1e-6)
return output
要注意的是 上面的PyTorch 实现并没有用到 Shared Memory ,它只是演示了 FlashAttention 的思想流程。
真正利用了 SRAM 的,是 FlashAttention 的 CUDA kernel 或 Triton kernel 实现。
如果想要测试效率,可以直接调用torch封装好的flashattention
python
from flash_attn.modules.mha import FlashMHA
import torch
x = torch.randn(8, 512, 512, device='cuda') # batch, seq_len, dim
mha = FlashMHA(embed_dim=512, num_heads=8, device='cuda')
output = mha(x)
print(output.shape) # [8, 512, 512]
六、总结
FlashAttention 提供了一种高效、低显存的 Attention 实现方式,极大地缓解了 Transformer 模型在长序列处理中的性能瓶颈。在当前大模型时代,FlashAttention 成为高效训练与部署的关键组件之一。
最后,我们回答一下文章开头提出的问题。
- 传统 Attention 的主要性能瓶颈在哪里?为什么需要 FlashAttention?
标准的 Attention 实现存在两个严重问题:
- 显存占用高 :完整计算 attention 需要构造形如
<font style="color:rgb(25, 27, 31);">[Batch, Head, SeqLen, SeqLen]</font>
的 score 矩阵 QKᵀ,即 O(n^2) 的显存需求; - 访存带宽瓶颈:计算过程中,Q、K、V、score、softmax 权重、输出 O 都需要多次读写 global memory(HBM),而 GPU 的计算能力往往无法完全发挥出来。
FlashAttention 被提出,目标就是通过"在 shared memory 中块级 tile 化 attention 计算",避免 score 的 materialization 和重复访存,从而提升效率、减少内存压力。
- FlashAttention 是如何利用 shared memory 降低显存占用并提高速度的?
FlashAttention 的关键设计是:
- 将 Q/K/V 分为小块(tiles),在 shared memory(即 SRAM)中进行 attention 的计算;
- 在计算 softmax 的过程中使用 log-sum-exp 技巧,确保数值稳定;
- 将 softmax 后与 V 的乘法也集成进 tile 内的计算流程,避免生成大矩阵;
- 利用 recomputation:不存储 softmax 权重 P,而是在反向传播时重算 QKᵀ,换取显存节省。
-
FlashAttention 在实际应用中还有哪些不足或限制?
尽管 FlashAttention 在性能和显存方面带来显著改善,但也存在一些实际问题:
- 线程并行效率不高:使用的是 "1 warp 对应 1 Q 行" 的划分方式,warp 内线程空闲率高;
- split-K 导致频繁 HBM 读写:每次 tile 操作都要访问 Q 和 O,存在冗余累加;
- 不支持 MQA / GQA 等高效注意力结构:仅适用于标准 MHA;
- 实现依赖 Triton 编译器:对部署平台要求高,难以在 PyTorch、TensorFlow 等框架中原生集成;
- 反向传播内核较少优化:精度和性能兼顾方面还有改进空间。
关于深度学习和大模型相关的知识和前沿技术更新,请关注公众号算法coting
!
以上内容部分参考了
FlashAttention:Fast and Memory-Efficient Exact Attention with IO-Awareness
非常感谢,如有侵权请联系删除!