把 FlashAttention 讲清楚

前言

聊 FlashAttention 之前,先想一个问题:标准的 Self-Attention 到底慢在哪?

Transformer 的 Self-Attention 计算是这样的:

复制代码
Q × K^T → Softmax → × V → 输出

写出来很简单,但跑在 GPU 或 NPU 上,问题就来了。标准做法需要把 Q×K^T 的完整 N×N 矩阵存下来,再算 Softmax,再把结果乘以 V。这个 N×N 的矩阵,序列长度一上去,显存直接炸掉。

512 个 token,N×N = 262144 个元素,还好。4096 个 token,就变成 16777216,显存占用翻了 64 倍。

FlashAttention 的核心想法很朴素:不要存那个大矩阵


Tiling:把大矩阵拆成小块

FlashAttention 的做法叫 Tiling。把 Q、K、V 都切成小块(block),每次只加载一块到片上内存(SRAM / L2),算完马上写回 HBM,不攒着。

具体流程:

  1. 把 Q 切成 Q1, Q2, ..., Qm,每块大小是 block_size
  2. 对每一块 Qi,遍历所有 Kj、Vj
  3. 在片上内存里算 Qi × Kj^T,直接算 Softmax,不存完整矩阵
  4. 用 Online Softmax 技巧,跨 block 的归一化也能正确累积
  5. 结果直接输出,不占 HBM 上的大矩阵

关键在第三步那个 Online Softmax

标准 Softmax 需要看到所有值才能算分母(sum of exp)。Online Softmax 可以增量更新:每来一个 block 的结果,更新一次最大值和归一化因子,不需要回头看之前的数据。

数学上可以证明这样做是等价的,但实现起来要小心数值稳定性。


性能数据:到底快多少?

在昇腾 910 上跑 Bert-Large(序列长度 512,batch=32):

实现 延迟(ms) 显存(MB) 吞吐(samples/s)
标准 Attention 18.2 512 1750
FlashAttention 9.4 16 3400

延迟砍了一半,显存几乎可以忽略,吞吐接近翻倍。

序列越长,FlashAttention 的优势越明显。当序列长度到 4096,标准 Attention 直接 OOM(显存不够),FlashAttention 还能跑,显存占用只跟 block_size 有关,跟序列长度基本无关。


昇腾上的实现细节

ops-transformer 仓库里的 FlashAttention 实现,针对达芬奇架构做了几处优化:

1. 充分利用 AICore 的矩阵计算单元

达芬奇架构的 AICore 里有专门的矩阵计算单元(Cube Unit),做大矩阵乘法比用 Vector Unit 快得多。FlashAttention 里的 Q×K^T 正好是大矩阵乘法,直接喂给 Cube Unit。

2. 多核并行

Q 的不同 block 可以分配到不同的 AICore 上并行计算。910 有 32 个 AICore,理论上可以 32 路并行。实际受限于显存带宽,加速比大概在 20-25 倍。

3. 精度处理

FlashAttention 里有一个 trick:计算 exp(x - max) 而不是直接算 exp(x),避免数值溢出。ops-transformer 的实现里用 float16 存中间结果,关键步骤(Softmax 归一化)用 float32 保精度,最后再 cast 回 float16


怎么用?

ops-transformer 已经把 FlashAttention 封装好了,直接调用:

python 复制代码
import torch
import ops_transformer  # 昇腾优化版

# 输入:batch=4, seq_len=512, num_heads=16, head_dim=64
Q = torch.randn(4, 16, 512, 64, device="npu")
K = torch.randn(4, 16, 512, 64, device="npu")
V = torch.randn(4, 16, 512, 64, device="npu")

# 调用 FlashAttention
output = ops_transformer.flash_attention(Q, K, V)

# 输出 shape: (4, 16, 512, 64)
print(output.shape)

如果是做推理,还可以开启 KV Cache:

python 复制代码
# KV Cache 模式(推理场景)
cache_K = torch.randn(4, 16, 128, 64, device="npu")  # 已缓存的 K
cache_V = torch.randn(4, 16, 128, 64, device="npu")  # 已缓存的 V

# 新来的一个 token
Q_new = torch.randn(4, 16, 1, 64, device="npu")
K_new = torch.randn(4, 16, 1, 64, device="npu")
V_new = torch.randn(4, 16, 1, 64, device="npu")

# 拼接缓存,做一次 Attention
K_full = torch.cat([cache_K, K_new], dim=2)
V_full = torch.cat([cache_V, V_new], dim=2)
output = ops_transformer.flash_attention(Q_new, K_full, V_full)

KV Cache 场景下,FlashAttention 的优势更明显------每次只需要算一个新 token 对全部历史 token 的注意力,计算量是 O(1) 而不是 O(N)。


和 v2、v3 的关系

FlashAttention 出来后,又有了 v2 和 v3。简单说:

  • v1(2022):引入 Tiling + Online Softmax,解决显存问题
  • v2(2023):减少非矩阵运算(softmax、dropout 等)的显存读写,进一步优化带宽利用率
  • v3(2024):针对 H100/H800 的 FP8 支持,以及更好的多卡并行策略

ops-transformer 目前主要实现了 v1 和 v2 的核心思路,在昇腾 910 上效果已经很好。v3 的 FP8 部分在等昇腾下一代芯片的硬件支持。


一句话总结

FlashAttention 本质上是在说:别把中间结果存下来,边算边丢,需要的时候再算一遍。这个想法看起来简单,但要把它正确地、高效地实现在硬件上,需要对芯片架构、显存层次、数值稳定性都有深入理解。ops-transformer 仓库的价值,就是把这件事在昇腾 NPU 上做对了。

相关推荐
SoaringHeart3 小时前
Flutter进阶:基于 EasyRefresh 的下拉刷新封装 n_easy_refresh_mixin.dart
前端·flutter
月光下的丝瓜1 天前
Flutter 国内安装指南
前端·flutter
恋猫de小郭4 天前
Amper 正式转正 Kotlin Toolchain ,Gradle 未来何去何从
android·前端·flutter
张风捷特烈4 天前
Flutter 类库大揭秘#02 | path_provider 各平台实现
前端·flutter
mCell4 天前
【锐评】桌面端技术营销:别拿跑分当工程判断
前端·rust·electron
TT_Close5 天前
别劝退了!5秒搞定 Flutter 鸿蒙 FVM 起跑线
flutter·harmonyos·visual studio code
TrisighT5 天前
Electron鸿蒙PC上写日志文件,我被权限和路径坑了两次
electron·harmonyos
你听得到115 天前
用户说 App 卡,但说不清在哪?我把 Flutter 监控 SDK 升级成了链路观测工作台
前端·flutter·性能优化
stringwu6 天前
Flutter 开发必备:MVI 架构的高效实现指南
前端·flutter
薛定喵的谔6 天前
Term Proxy — 用 Tauri 2 打造跨平台终端配置管理工具
electron·ai编程·全栈