把 FlashAttention 讲清楚

前言

聊 FlashAttention 之前,先想一个问题:标准的 Self-Attention 到底慢在哪?

Transformer 的 Self-Attention 计算是这样的:

复制代码
Q × K^T → Softmax → × V → 输出

写出来很简单,但跑在 GPU 或 NPU 上,问题就来了。标准做法需要把 Q×K^T 的完整 N×N 矩阵存下来,再算 Softmax,再把结果乘以 V。这个 N×N 的矩阵,序列长度一上去,显存直接炸掉。

512 个 token,N×N = 262144 个元素,还好。4096 个 token,就变成 16777216,显存占用翻了 64 倍。

FlashAttention 的核心想法很朴素:不要存那个大矩阵


Tiling:把大矩阵拆成小块

FlashAttention 的做法叫 Tiling。把 Q、K、V 都切成小块(block),每次只加载一块到片上内存(SRAM / L2),算完马上写回 HBM,不攒着。

具体流程:

  1. 把 Q 切成 Q1, Q2, ..., Qm,每块大小是 block_size
  2. 对每一块 Qi,遍历所有 Kj、Vj
  3. 在片上内存里算 Qi × Kj^T,直接算 Softmax,不存完整矩阵
  4. 用 Online Softmax 技巧,跨 block 的归一化也能正确累积
  5. 结果直接输出,不占 HBM 上的大矩阵

关键在第三步那个 Online Softmax

标准 Softmax 需要看到所有值才能算分母(sum of exp)。Online Softmax 可以增量更新:每来一个 block 的结果,更新一次最大值和归一化因子,不需要回头看之前的数据。

数学上可以证明这样做是等价的,但实现起来要小心数值稳定性。


性能数据:到底快多少?

在昇腾 910 上跑 Bert-Large(序列长度 512,batch=32):

实现 延迟(ms) 显存(MB) 吞吐(samples/s)
标准 Attention 18.2 512 1750
FlashAttention 9.4 16 3400

延迟砍了一半,显存几乎可以忽略,吞吐接近翻倍。

序列越长,FlashAttention 的优势越明显。当序列长度到 4096,标准 Attention 直接 OOM(显存不够),FlashAttention 还能跑,显存占用只跟 block_size 有关,跟序列长度基本无关。


昇腾上的实现细节

ops-transformer 仓库里的 FlashAttention 实现,针对达芬奇架构做了几处优化:

1. 充分利用 AICore 的矩阵计算单元

达芬奇架构的 AICore 里有专门的矩阵计算单元(Cube Unit),做大矩阵乘法比用 Vector Unit 快得多。FlashAttention 里的 Q×K^T 正好是大矩阵乘法,直接喂给 Cube Unit。

2. 多核并行

Q 的不同 block 可以分配到不同的 AICore 上并行计算。910 有 32 个 AICore,理论上可以 32 路并行。实际受限于显存带宽,加速比大概在 20-25 倍。

3. 精度处理

FlashAttention 里有一个 trick:计算 exp(x - max) 而不是直接算 exp(x),避免数值溢出。ops-transformer 的实现里用 float16 存中间结果,关键步骤(Softmax 归一化)用 float32 保精度,最后再 cast 回 float16


怎么用?

ops-transformer 已经把 FlashAttention 封装好了,直接调用:

python 复制代码
import torch
import ops_transformer  # 昇腾优化版

# 输入:batch=4, seq_len=512, num_heads=16, head_dim=64
Q = torch.randn(4, 16, 512, 64, device="npu")
K = torch.randn(4, 16, 512, 64, device="npu")
V = torch.randn(4, 16, 512, 64, device="npu")

# 调用 FlashAttention
output = ops_transformer.flash_attention(Q, K, V)

# 输出 shape: (4, 16, 512, 64)
print(output.shape)

如果是做推理,还可以开启 KV Cache:

python 复制代码
# KV Cache 模式(推理场景)
cache_K = torch.randn(4, 16, 128, 64, device="npu")  # 已缓存的 K
cache_V = torch.randn(4, 16, 128, 64, device="npu")  # 已缓存的 V

# 新来的一个 token
Q_new = torch.randn(4, 16, 1, 64, device="npu")
K_new = torch.randn(4, 16, 1, 64, device="npu")
V_new = torch.randn(4, 16, 1, 64, device="npu")

# 拼接缓存,做一次 Attention
K_full = torch.cat([cache_K, K_new], dim=2)
V_full = torch.cat([cache_V, V_new], dim=2)
output = ops_transformer.flash_attention(Q_new, K_full, V_full)

KV Cache 场景下,FlashAttention 的优势更明显------每次只需要算一个新 token 对全部历史 token 的注意力,计算量是 O(1) 而不是 O(N)。


和 v2、v3 的关系

FlashAttention 出来后,又有了 v2 和 v3。简单说:

  • v1(2022):引入 Tiling + Online Softmax,解决显存问题
  • v2(2023):减少非矩阵运算(softmax、dropout 等)的显存读写,进一步优化带宽利用率
  • v3(2024):针对 H100/H800 的 FP8 支持,以及更好的多卡并行策略

ops-transformer 目前主要实现了 v1 和 v2 的核心思路,在昇腾 910 上效果已经很好。v3 的 FP8 部分在等昇腾下一代芯片的硬件支持。


一句话总结

FlashAttention 本质上是在说:别把中间结果存下来,边算边丢,需要的时候再算一遍。这个想法看起来简单,但要把它正确地、高效地实现在硬件上,需要对芯片架构、显存层次、数值稳定性都有深入理解。ops-transformer 仓库的价值,就是把这件事在昇腾 NPU 上做对了。

相关推荐
i_am_a_div_日积月累_3 小时前
1.创建electron项目
electron
song5015 小时前
多卡训练加速:HCCL 集合通信实战
分布式·python·flutter·ci/cd·分类
500848 小时前
ATC 做了什么:从 ONNX 到 .om
分布式·架构·开源·wpf·开源鸿蒙
5008410 小时前
Graph Engine 是什么,为什么需要它
java·人工智能·性能优化·ocr·wpf
三声三视11 小时前
Electron在鸿蒙PC上监听文件变化,chokidar静默失效,我被迫写了一个轮询器
electron·harmonyos·桌面应用
风清云淡_A11 小时前
【Flutter3.8x】flutter从入门到实战基础教程(一):新建一个flutter项目
flutter
1001101_QIA12 小时前
Flutter 开发报错:Android cmdline-tools 缺失 环境排查与完整修复方案
android·flutter
一念春风13 小时前
.md文件浏览器
c#·wpf