前言
聊 FlashAttention 之前,先想一个问题:标准的 Self-Attention 到底慢在哪?
Transformer 的 Self-Attention 计算是这样的:
Q × K^T → Softmax → × V → 输出
写出来很简单,但跑在 GPU 或 NPU 上,问题就来了。标准做法需要把 Q×K^T 的完整 N×N 矩阵存下来,再算 Softmax,再把结果乘以 V。这个 N×N 的矩阵,序列长度一上去,显存直接炸掉。
512 个 token,N×N = 262144 个元素,还好。4096 个 token,就变成 16777216,显存占用翻了 64 倍。
FlashAttention 的核心想法很朴素:不要存那个大矩阵。
Tiling:把大矩阵拆成小块
FlashAttention 的做法叫 Tiling。把 Q、K、V 都切成小块(block),每次只加载一块到片上内存(SRAM / L2),算完马上写回 HBM,不攒着。
具体流程:
- 把 Q 切成
Q1, Q2, ..., Qm,每块大小是block_size - 对每一块 Qi,遍历所有 Kj、Vj
- 在片上内存里算
Qi × Kj^T,直接算 Softmax,不存完整矩阵 - 用 Online Softmax 技巧,跨 block 的归一化也能正确累积
- 结果直接输出,不占 HBM 上的大矩阵
关键在第三步那个 Online Softmax。
标准 Softmax 需要看到所有值才能算分母(sum of exp)。Online Softmax 可以增量更新:每来一个 block 的结果,更新一次最大值和归一化因子,不需要回头看之前的数据。
数学上可以证明这样做是等价的,但实现起来要小心数值稳定性。
性能数据:到底快多少?
在昇腾 910 上跑 Bert-Large(序列长度 512,batch=32):
| 实现 | 延迟(ms) | 显存(MB) | 吞吐(samples/s) |
|---|---|---|---|
| 标准 Attention | 18.2 | 512 | 1750 |
| FlashAttention | 9.4 | 16 | 3400 |
延迟砍了一半,显存几乎可以忽略,吞吐接近翻倍。
序列越长,FlashAttention 的优势越明显。当序列长度到 4096,标准 Attention 直接 OOM(显存不够),FlashAttention 还能跑,显存占用只跟 block_size 有关,跟序列长度基本无关。
昇腾上的实现细节
ops-transformer 仓库里的 FlashAttention 实现,针对达芬奇架构做了几处优化:
1. 充分利用 AICore 的矩阵计算单元
达芬奇架构的 AICore 里有专门的矩阵计算单元(Cube Unit),做大矩阵乘法比用 Vector Unit 快得多。FlashAttention 里的 Q×K^T 正好是大矩阵乘法,直接喂给 Cube Unit。
2. 多核并行
Q 的不同 block 可以分配到不同的 AICore 上并行计算。910 有 32 个 AICore,理论上可以 32 路并行。实际受限于显存带宽,加速比大概在 20-25 倍。
3. 精度处理
FlashAttention 里有一个 trick:计算 exp(x - max) 而不是直接算 exp(x),避免数值溢出。ops-transformer 的实现里用 float16 存中间结果,关键步骤(Softmax 归一化)用 float32 保精度,最后再 cast 回 float16。
怎么用?
ops-transformer 已经把 FlashAttention 封装好了,直接调用:
python
import torch
import ops_transformer # 昇腾优化版
# 输入:batch=4, seq_len=512, num_heads=16, head_dim=64
Q = torch.randn(4, 16, 512, 64, device="npu")
K = torch.randn(4, 16, 512, 64, device="npu")
V = torch.randn(4, 16, 512, 64, device="npu")
# 调用 FlashAttention
output = ops_transformer.flash_attention(Q, K, V)
# 输出 shape: (4, 16, 512, 64)
print(output.shape)
如果是做推理,还可以开启 KV Cache:
python
# KV Cache 模式(推理场景)
cache_K = torch.randn(4, 16, 128, 64, device="npu") # 已缓存的 K
cache_V = torch.randn(4, 16, 128, 64, device="npu") # 已缓存的 V
# 新来的一个 token
Q_new = torch.randn(4, 16, 1, 64, device="npu")
K_new = torch.randn(4, 16, 1, 64, device="npu")
V_new = torch.randn(4, 16, 1, 64, device="npu")
# 拼接缓存,做一次 Attention
K_full = torch.cat([cache_K, K_new], dim=2)
V_full = torch.cat([cache_V, V_new], dim=2)
output = ops_transformer.flash_attention(Q_new, K_full, V_full)
KV Cache 场景下,FlashAttention 的优势更明显------每次只需要算一个新 token 对全部历史 token 的注意力,计算量是 O(1) 而不是 O(N)。
和 v2、v3 的关系
FlashAttention 出来后,又有了 v2 和 v3。简单说:
- v1(2022):引入 Tiling + Online Softmax,解决显存问题
- v2(2023):减少非矩阵运算(softmax、dropout 等)的显存读写,进一步优化带宽利用率
- v3(2024):针对 H100/H800 的 FP8 支持,以及更好的多卡并行策略
ops-transformer 目前主要实现了 v1 和 v2 的核心思路,在昇腾 910 上效果已经很好。v3 的 FP8 部分在等昇腾下一代芯片的硬件支持。
一句话总结
FlashAttention 本质上是在说:别把中间结果存下来,边算边丢,需要的时候再算一遍。这个想法看起来简单,但要把它正确地、高效地实现在硬件上,需要对芯片架构、显存层次、数值稳定性都有深入理解。ops-transformer 仓库的价值,就是把这件事在昇腾 NPU 上做对了。