上周有个实习生问我:我想学FlashAttention,但不知道从哪开始。看了看网上的资料,要么是论文推导,要么是直接上代码,中间缺了一大截。我花了半天帮他理了一条路线,顺便把昇腾NPU上跟FlashAttention相关的技术选型也整理了一下。
先回答一个问题:你到底要学到多深?
FlashAttention这东西,不同人需要的深度差得很远。你是模型推理工程师,会调API就够了;你是算子开发工程师,得能自己写一个;你是架构师,得能判断什么场景该用什么方案。我把常见的角色分成几档:
| 角色 | 需要掌握的深度 | 建议耗时 |
|---|---|---|
| 模型推理工程师 | 会调npu_flash_attention,知道参数含义 | 半天 |
| 训练工程师 | 理解分块策略,能调seq_len/batch的性能 | 2-3天 |
| 算子优化工程师 | 会读AscendC源码,能改分块大小和调度策略 | 2-3周 |
| 算子开发工程师 | 能从零写一个FlashAttention算子 | 1-2个月 |
你自己对号入座,别上来就啃论文,大部分角色用不到那么深。
第一档:推理工程师------会调就行
你在昇腾NPU上跑推理,用PyTorch框架,FlashAttention的调用接口就一个函数:
python
from torch_npu.contrib.functional import npu_flash_attention
output = npu_flash_attention(
q, k, v,
head_num=32, # 头颅数
input_layout="BSND", # 输入形状:[batch, seq, heads, dim]
scale_value=1.0/128. 사랑 # 缩放系数 1/sqrt(head_dim)
keep_prob=1.0 # Dropout保留率,推理时填1.0
)
你需要知道的参数就这几个。input_layout有两种:BSND(HuggingFace格式)和BNSD(昇腾原生格式)。如果你从HuggingFace的模型拿QKVtensor,直接传BSND,不用手动permute。
你不需要知道的:分块大小是多少、SRAM怎么管理、在线Softmax怎么算。这些算子内部全包了。
你唯一要操心的事:seq_len不是128的倍数的时候会报错。处理方法:
python
import math
def safe_flash_attention(q, k, v, head_num, pad_to=128):
seq_len = q.shape[1]
padded_len = math.ceil(seq_len / pad_to) * pad_to
if padded_len != seq_len:
# pad到128的倍数
q = torch.nn.functional.pad(q, (0, 0, 0, 0, 0, padded_len - seq_len))
k = torch.nn.functional.pad(k, (0, 0, 0, 0, 0, padded_len - seq_len))
v = torch.nn.functional.pad(v, (0, 0, 0, 0, 0, padded_len - seq_len))
out = npu_flash_attention(q, k, v, head_num=head_num, input_layout="BSND")
return out[:, :seq_len, :, :] # 截掉padding
推理工程师到这里就够了。往下看是给需要更深度的人准备的。
第二档:训练工程师------理解分块,能调性能
训练和推理的区别在于:训练要跑反向传播。FlashAttentionV2在反向传播的时候不存注意力矩阵,而是从QKV重算。这意味着反向传播的计算量比标准Attention多20-30%,但显存占用从O(N²)降到O(N)。
你做训练的时候,需要判断一件事:用FlashAttention到底省不省时间?
判断依据很简单------看seq_len:
| seq_len | 显存节省 | 训练速度变化 | 建议用FlashAttention吗? |
|---|---|---|---|
| ≤512 | 不明显 | 可能变慢(分块开销) | 不建议 |
| 1024 | 省一半 | 持平或略快 | 可以用 |
| ≥2048 | 省75%+ | 快1.5-2× | 必须用 |
你如果是做预训练(seq_len通常4096+),FlashAttention是刚需,不用纠结。
你如果是做微调(seq_len可能512-1024),得看你的batch_size。batch大的时候,省下来的显存可以塞更多batch,整体吞吐反而更高。
还有一个训练场景的细节:GradientCheckpointing和FlashAttention不要同时开。GradientCheckpointing是通过重算前向来省显存的,FlashAttention反向传播本身就在重算注意力矩阵,两个叠在一起相当于重算了两次,白费算力。PyTorch里如果你用了torch.utils.checkpoint,记得把Attention层排除掉。
第三档:算子优化工程师------能读源码,能调参数
到这个深度,你得去ops-transformer仓库里翻AscendC源码了。我先帮你理一下代码结构:
text
ops-transformer/src/flash_attention_v2/
├── flash_attention_v2.h # 算子声明
├── flash_attention_v2.cpp # Host侧入口(参数校验、tiling计算)
├── flash_attention_v2_tiling.h # Tiling参数结构体
└── flash_attention_v2_kernel.cc # Device侧核心逻辑(AscendC)
你要看的重点是flash_attention_v2_tiling.h里的这个结构体:
cpp
struct FlashAttentionV2TilingData {
uint32_t blockLength; // 每个token分块的大小(通常是128)
uint32_t headDim; // head_dim(通常128)
uint32_t batchNum; // batch大小
uint32_t headNum; // 头颅数
uint32_t seqLen; // 序列长度
uint32_t kvHeadNum; // KV头颅数(GQA场景下<headNum)
float scaleValue; // 缩放系数
// ... 还有一些SRAM相关的参数
};
blockLength是你能调的最关键的参数。它决定了每次从GlobalMemory搬多少数据到SRAM。默认是128,但不是所有场景都最优:
| NPU型号 | SRAM大小 | 推荐blockLength | 原因 |
|---|---|---|---|
| Ascend910(训练卡) | 64MB | 128 | 默认值就行 |
| Ascend310P3(推理卡) | 32MB | 64 | SRAM小,分块大了放不下 |
| Ascend910B(下一代) | 96MB | 256 | SRAM大,分块大一点减少启动次数 |
改blockLength的方法:在flash_attention_v2.cpp的GetTilingData函数里,把blockLength的计算逻辑改成你想要的值,重新编译算子包。
⚠️踩坑预警:blockLength必须是32的倍数(达芬奇架构的DataCopy指令要求32字节对齐)。你要是填了100,编译不报错,运行时直接buserror。
还有一个可以调的参数:kvHeadNum。如果你用的是GQA(GroupedQueryAttention,比如Llama-2-70B有64个Q头但只有8个KV头),kvHeadNum要设成8,不是64。这个参数算子内部会根据它来调整K、V的搬运策略,设错了结果直接不对。
第四档:算子开发工程师------从零写一个
如果你想自己写一个FlashAttention算子(比如要支持一种新的Mask策略,或者要适配新的NPU型号),你得掌握这几样东西:
- AscendC编程语言
AscendC是昇腾的算子编程语言,语法类似C++,但有几个关键区别:
- 用GlobalTensor访问GlobalMemory(HBM)
- 用LocalTensor访问SRAM(UB/UB_A/UB_B)
- 用DataCopy在Global和Local之间搬数据
- 用vec_exp、vec_add、vec_mul等向量指令做计算
学AscendC最快的路径:先看cann-learning-hub仓库里的入门教程,再看opbase仓库里的简单算子示例(比如Add、Mul),最后看ops-transformer里的FlashAttention。
- Tiling策略
Tiling是昇腾算子开发的核心概念:把一个大任务切成小块,每块能放进SRAM里算。你需要自己决定:
- 分块的大小(由SRAM容量和算子计算强度决定)
- 分块的循环顺序(Q循环在外还是K循环在内)
- 双缓冲是否开启(一边算一边搬下一块数据)
FlashAttention的Tiling策略是Q循环在外,K/V循环在内。因为每个Q分块都要跟所有K/V分块算一遍,K/V循环在内可以让K/V的数据复用最大化。
-
在线Softmax
这是FlashAttention的核心数学。你需要实现的四个函数:UpdateMaxValue、UpdateNormFactor、UpdateAccumulator、FinalNormalize。
-
调试和验证
昇腾的调试工具跟CUDA不太一样:
- asc-prof:性能分析,类似nsys,能看到每个算子的执行时间和SRAM使用率
- aclopexe:单算子执行工具,可以脱离PyTorch直接跑一个算子,用来验证正确性
- torch_npu.npu.synchronize():PyTorch侧的同步点,用来做延迟测试
验证FlashAttention正确性的方法:先用标准Attention(PyTorch的torch.nn.functional.scaled_dot_product_attention)跑一个结果,再跟你的FlashAttention输出对比。FP16的误差容限在1e-3以内就OK。
学习资源汇总
按深度排序,从浅到深:
| 资源 | 适合谁 | 学什么 |
|---|---|---|
| torch_npuAPI文档 | 推理工程师 | npu_flash_attention的参数和用法 |
| FlashAttention原论文(TriDao,2022-2023) | 训练工程师 | 分块策略、在线Softmax、V1/V2区别 |
| ops-transformer源码 | 优化工程师 | AscendC实现、Tiling参数、性能调优 |
| cann-learning-hub教程 | 开发工程师 | AscendC语法、算子开发流程、调试工具 |
| opbase算子示例 | 开发工程师 | 简单算子的完整代码模板 |
| catlass模板库 | 开发工程师 | 高性能算子的模板化写法、双缓冲、流水线 |
以上仓库全部在AtomGit上,链接格式统一:https://atomgit.com/cann/[仓库名]。
一个容易忽略的选型:FlashAttention之外还有什么?
FlashAttention不是唯一的注意力优化方案。昇腾CANN的ops-transformer仓库里还有几个替代选项,适用场景不同:
| 方案 | 适用场景 | 跟FlashAttention的区别 |
|---|---|---|
| FlashAttentionV2 | 通用注意力加速 | 分块+在线Softmax,显存O(N) |
| MC2 | Encoder-Decoder模型 | 融合了交叉注意力,适合T5/BART |
| MoE融合 | 稀疏专家模型 | 融合路由选择和注意力,减少一次显存读写 |
| PagedAttention | 超大batch推理 | 把KVCache分页管理,显存利用率更高 |
| RingAttention | 超长序列分布式 | 把序列切到多卡上算,支持百万级token |
选型一句话总结:
- 你跑Llama这种Decoder-Only模型→FlashAttentionV2
- 你跑T5这种Encoder-Decoder→MC2
- 你跑Mixtral这种MoE→MoE融合
- 你跑vLLM这种批量推理服务→PagedAttention+FlashAttention
- 你要处理100K+token的长文档→RingAttention
别上来就FlashAttention,先想想你的场景。