一文搞懂FlashAttention怎么提升速度的?

本篇文章是Transformer系列的第七篇。

Transformer系列文章:

一览Transformer整体架构

Transformer------Attention怎么实现集中注意力

Transformer------FeedForward模块在干什么?

从0开始实现Transformer

什么是KV-Cache

Transformer注意力机制------MHA&MQA&GQA

所有相关源码示例、流程图、模型配置与知识库构建技巧,我也将持续更新在Github:LLMHub,欢迎关注收藏!

希望大家带着下面的问题来学习,我会在文末给出答案。

  1. 传统 Attention 的主要性能瓶颈在哪里?为什么需要 FlashAttention?
  2. FlashAttention 是如何利用 shared memory 降低显存占用并提高速度的?
  3. FlashAttention 在实际应用中还有哪些不足或限制?

一、引言

Transformer 模型自诞生以来,已成为自然语言处理、计算机视觉、语音等领域的核心架构。而 Attention 机制作为 Transformer 的核心计算模块,其计算复杂度和显存占用在处理长序列时常常成为性能瓶颈。传统 Attention 的时间和空间复杂度为 O(n^2),这在大规模模型或长文本输入中表现为效率低下、显存不足。

FlashAttention 是由 Stanford Hazy Research 团队提出的一种高效实现方式,专为 GPU 设计,通过 I/O 感知(IO-aware)的优化策略,在不损失精度的前提下显著加速 Attention 计算并降低显存占用。

二、Attention回顾

在标准 Transformer 中,Attention 计算如下:

其中:

  • Q(Query)、K(Key)、V(Value)为输入的线性变换结果;
  • QK^T生成的是 Attention Score 矩阵,大小为 (n, n),即每个 token 对所有 token 的相关性。

计算流程如下:

  1. 计算 QK^T:需存储一个n*n 矩阵;
  2. 计算 softmax;
  3. 乘以 V 得到最终输出。

瓶颈分析:

  • 需要在 GPU 上存储整个 QK^T 结果 → 显存开销大;
  • 需要频繁访问 HBM(高带宽显存)→ 带宽受限 → 性能下降。

三、FlashAttention 的核心思想(核心原理)

FlashAttention 的目标是:降低显存占用,同时提升速度,具体做法包括:

1. Tile-Based 计算

将 Q, K, V 分块为小块(tile),每次仅处理一小块:

  • 利用 GPU 的片上 SRAM(Shared Memory)完成 QK^T 和 softmax
  • 避免中间结果写入 HBM

上面图中左半部分是计算机的内存分布, HBM 是 "High Bandwidth Memory" 的缩写,也就是高带宽显存 ,是一种专为高性能计算和显存密集型任务 (如 GPU、AI 加速、图形渲染等)设计的下一代显存技术。 SRAM是一种静态随机访问存储器,用于高速缓存等内部存储器,具有更快的访问速度和更低的延迟,但成本更高且占用更多芯片空间。

标准Attention的计算算法如下:

可以看到,标准 Attention 实现大量中间结果需频繁访问 HBM ,而 HBM 的访问速度远远低于 GPU 的SRAM。因此 FlashAttention 通过"tile 计算+显存访问优化 "方案,减少了对 HBM 的依赖,提高了整体执行效率

softmax计算公式如下:

为了数值稳定性,FlashAttention采用Safe Softmax,对于向量x

同理,对于向量x = [ x1,x2],softmax可以分解计算:

这就说明即使Q,K,V被分成块也是可以计算softmax的。

2. Recomputation Strategy

为了节省存储中间的 softmax 权重,FlashAttention 在需要时重新计算部分内容,避免保存完整矩阵。

标准Attention的反向传播算法如下,其中P代表Softmax(QKᵀ / √dₖ),也就是注意力权重矩阵。

结合着Attention的计算公式更好理解

在标准 Attention 实现中,为了完成前向传播和反向传播,我们通常需要保存如下中间结果:

  • QKᵀ(Attention Score 矩阵)
  • softmax 权重
  • Attention output(最终结果)

这些矩阵很大,尤其是在处理长序列时,显存消耗会非常高。

FlashAttention 为了降低显存占用,采取了一种策略:

在前向传播时 不保留中间矩阵 ,而是到了反向传播阶段 再把它们重新计算出来

以 softmax 的 attention score 为例:

  • 标准方法:
latex 复制代码
QKᵀ → softmax → 缓存在显存中 → 用于乘V和反向传播
  • FlashAttention 方法:
latex 复制代码
QKᵀ → softmax → 直接用于乘V,不缓存
...
后面反向传播需要用到 softmax → 再重新计算一次 QKᵀ 和 softmax

这就节省了存 softmax 的显存开销,尤其在长序列上非常可观。

FlashAttention的前向传播算法如下:

FlashAttention的反向传播的过程如下:

可以看到其中没有存储,反向传播的过程中需要的数据都是重新计算的,这种"以算代存"的方式是一种典型的时间换空间(compute vs. memory)策略。虽然多计算一次会略微增加一点时间,但显存节省得非常明显,反而提升了整体性能,因为:

  • 显存访问慢,限制吞吐;
  • GPU 有大量计算资源,计算冗余可以承受;
  • 避免了 HBM 带宽瓶颈。

3.Block Sparse FlashAttention

传统 Attention 是 全连接的:每个 token 都和所有其他 token 交互,计算量为 O(n^2)。

Sparse Attention 只计算部分 token 对的关系,常见稀疏模式包括:

  • Sliding Window:每个 token 只关注它前后几个邻居;
  • Block Sparse:将 Q、K、V 分成若干块(block),只计算特定 block 对之间的 attention;
  • Global + Local:大部分是局部 attention,少数 token 有全局连接(如 Longformer)。

在FlashAttention 的基础上,为了进一步提升处理超长序列的性能和可扩展性,Block Sparse FlashAttention 结合了 FlashAttention 的 IO-aware 高效计算方式和 block-sparse attention mask 的稀疏结构,从而实现 更少计算 + 更少显存占用 的 attention 操作。

Block Sparse FlashAttention 的关键是在 FlashAttention 高效计算的基础上,只计算被稀疏掩码指定的 QK 块对,算法如下:

  1. 输入:Q、K、V 被划分为若干 block;
  2. 依据稀疏掩码(mask)决定哪些 Q-block 要与哪些 K-block 交互;
  3. 对每个有效块对,执行 FlashAttention 核心流程(QKᵀ → softmax → ×V);
  4. 将结果整合,拼接为完整输出。

四、FlashAttention vs 标准Attention

项目 原始 Attention FlashAttention
时间复杂度 O(n^2 d) O(n^2 d),但更快
显存消耗 高(存储中间矩阵) 低(tile重计算)
速度表现 慢(受限于显存读写) 快(高效访存)
精度控制 float32 为主 支持 fp16 / bf16

在长序列任务中,FlashAttention 可将显存减少 2-4 倍,速度提升达 2-4 倍。

五、从0手撸FlashAttention

python 复制代码
for i in range(0, N, block_size):
    q_block = q[:, i:i+block_size]  # [B, Bq, D]
    max_score = None
    row_sum_exp = None
    acc = torch.zeros_like(q_block)

    for j in range(0, N, block_size):
        k_block = k[:, j:j+block_size]  # [B, Bk, D]
        v_block = v[:, j:j+block_size]  # [B, Bk, D]

        # 1. Attention logits
        scores = torch.bmm(q_block, k_block.transpose(1, 2)) * scale  # [B, Bq, Bk]

        # 2. Numerical stability
        block_max = scores.max(dim=-1, keepdim=True).values  # [B, Bq, 1]
        scores = scores - block_max
        exp_scores = scores.exp()  # [B, Bq, Bk]

        # 3. Dropout (可选)
        if dropout_p > 0.0:
            exp_scores = F.dropout(exp_scores, p=dropout_p, training=True)

        # 4. Weighted sum
        acc += torch.bmm(exp_scores, v_block)  # [B, Bq, D]

        # 5. Softmax normalization (log-sum-exp trick)
        block_sum = exp_scores.sum(dim=-1, keepdim=True)  # [B, Bq, 1]
        if row_sum_exp is None:
            row_sum_exp = block_sum
            max_score = block_max
        else:
            row_sum_exp += block_sum
            max_score = torch.max(max_score, block_max)

    # Normalize accumulated result
    output[:, i:i+block_size] = acc / (row_sum_exp + 1e-6)

return output

要注意的是 上面的PyTorch 实现并没有用到 Shared Memory ,它只是演示了 FlashAttention 的思想流程。

真正利用了 SRAM 的,是 FlashAttention 的 CUDA kernel 或 Triton kernel 实现

如果想要测试效率,可以直接调用torch封装好的flashattention

python 复制代码
from flash_attn.modules.mha import FlashMHA
import torch

x = torch.randn(8, 512, 512, device='cuda')  # batch, seq_len, dim
mha = FlashMHA(embed_dim=512, num_heads=8, device='cuda')
output = mha(x)
print(output.shape)  # [8, 512, 512]

六、总结

FlashAttention 提供了一种高效、低显存的 Attention 实现方式,极大地缓解了 Transformer 模型在长序列处理中的性能瓶颈。在当前大模型时代,FlashAttention 成为高效训练与部署的关键组件之一。

最后,我们回答一下文章开头提出的问题。

  1. 传统 Attention 的主要性能瓶颈在哪里?为什么需要 FlashAttention?

标准的 Attention 实现存在两个严重问题:

  • 显存占用高 :完整计算 attention 需要构造形如 <font style="color:rgb(25, 27, 31);">[Batch, Head, SeqLen, SeqLen]</font> 的 score 矩阵 QKᵀ,即 O(n^2) 的显存需求;
  • 访存带宽瓶颈:计算过程中,Q、K、V、score、softmax 权重、输出 O 都需要多次读写 global memory(HBM),而 GPU 的计算能力往往无法完全发挥出来。

FlashAttention 被提出,目标就是通过"在 shared memory 中块级 tile 化 attention 计算",避免 score 的 materialization 和重复访存,从而提升效率、减少内存压力。

  1. FlashAttention 是如何利用 shared memory 降低显存占用并提高速度的?

FlashAttention 的关键设计是:

  • 将 Q/K/V 分为小块(tiles),在 shared memory(即 SRAM)中进行 attention 的计算;
  • 在计算 softmax 的过程中使用 log-sum-exp 技巧,确保数值稳定;
  • 将 softmax 后与 V 的乘法也集成进 tile 内的计算流程,避免生成大矩阵;
  • 利用 recomputation:不存储 softmax 权重 P,而是在反向传播时重算 QKᵀ,换取显存节省。
  1. FlashAttention 在实际应用中还有哪些不足或限制?

    尽管 FlashAttention 在性能和显存方面带来显著改善,但也存在一些实际问题:

  • 线程并行效率不高:使用的是 "1 warp 对应 1 Q 行" 的划分方式,warp 内线程空闲率高;
  • split-K 导致频繁 HBM 读写:每次 tile 操作都要访问 Q 和 O,存在冗余累加;
  • 不支持 MQA / GQA 等高效注意力结构:仅适用于标准 MHA;
  • 实现依赖 Triton 编译器:对部署平台要求高,难以在 PyTorch、TensorFlow 等框架中原生集成;
  • 反向传播内核较少优化:精度和性能兼顾方面还有改进空间。

关于深度学习和大模型相关的知识和前沿技术更新,请关注公众号算法coting!

以上内容部分参考了

FlashAttention:Fast and Memory-Efficient Exact Attention with IO-Awareness

Flash Attention原理详解(含代码讲解)

非常感谢,如有侵权请联系删除!

相关推荐
不知天地为何吴女士1 小时前
Day32| 509. 斐波那契数、70. 爬楼梯、746. 使用最小花费爬楼梯
算法
小坏坏的大世界1 小时前
C++ STL常用容器总结(vector, deque, list, map, set)
c++·算法
励志要当大牛的小白菜4 小时前
ART配对软件使用
开发语言·c++·qt·算法
qq_513970444 小时前
力扣 hot100 Day56
算法·leetcode
白-胖-子4 小时前
深入剖析大模型在文本生成式 AI 产品架构中的核心地位
人工智能·架构
PAK向日葵5 小时前
【算法导论】如何攻克一道Hard难度的LeetCode题?以「寻找两个正序数组的中位数」为例
c++·算法·面试
想要成为计算机高手5 小时前
11. isaacsim4.2教程-Transform 树与Odometry
人工智能·机器人·自动驾驶·ros·rviz·isaac sim·仿真环境
静心问道6 小时前
InstructBLIP:通过指令微调迈向通用视觉-语言模型
人工智能·多模态·ai技术应用
宇称不守恒4.06 小时前
2025暑期—06神经网络-常见网络2
网络·人工智能·神经网络
爱喝矿泉水的猛男7 小时前
非定长滑动窗口(持续更新)
算法·leetcode·职场和发展