
文章目录
- 前言
-
- [为什么 FlashAttention 需要 Tiling?](#为什么 FlashAttention 需要 Tiling?)
- [Tiling 策略原理:把 SRAM 用到极致](#Tiling 策略原理:把 SRAM 用到极致)
-
- [SRAM 分块逻辑](#SRAM 分块逻辑)
- [Inner-Outer Loop 结构](#Inner-Outer Loop 结构)
- [Forward/Reverse 重计算](#Forward/Reverse 重计算)
- [ops-transformer 中的实现](#ops-transformer 中的实现)
-
- [FlashAttentionScore 算子](#FlashAttentionScore 算子)
- [FlashAttentionScoreGrad 算子](#FlashAttentionScoreGrad 算子)
- [Tiling 配置的自动调优](#Tiling 配置的自动调优)
- 性能收益:用数据说话
- 关键警告:两个容易踩的坑
-
- [坑 1:块大小选错,性能反而更差](#坑 1:块大小选错,性能反而更差)
- [坑 2:因果注意力 + Tiling,mask 逻辑要小心](#坑 2:因果注意力 + Tiling,mask 逻辑要小心)
- 总结与行动指引
前言
你有没有想过,为什么 Transformer 模型训练到一半突然 OOM?或者明明算力争力十足,Attention 层的耗时却居高不下?
答案藏在显存访问里。在昇腾CANN软件栈中,ops-transformer 作为 Transformer 专用算子库,通过 FlashAttention 算子的 Tiling 策略,把本来会压垮显存的 Attention 计算搬进了速度快得多的 SRAM。这不是简单的分块,而是一场针对内存层级架构的精确打击。
为什么 FlashAttention 需要 Tiling?
标准 Attention 的计算流程看起来人畜无害:
python
# 标准 Attention 实现(伪代码)
def standard_attention(Q, K, V):
# Q, K, V: [batch, heads, seq_len, head_dim]
scores = Q @ K.transpose(-2, -1) # [batch, heads, seq_len, seq_len]
# ⚠️ 问题1: scores 矩阵存入 HBM,占用 O(N²) 空间
attn_weights = softmax(scores) # 需要读回 scores,再次写回 HBM
# ⚠️ 问题2: softmax 需要全局归约,三次 HBM 往返
output = attn_weights @ V # 读 attn_weights,写 output
return output
在昇腾NPU上,HBM(High Bandwidth Memory)的带宽是瓶颈。以 100M 参数量的模型为例,seq_len=2048 时,scores 矩阵仅 FP16 格式就占用 2048×2048×2 bytes ≈ 8MB------看似不大,但多层、多卡并行时,这个中间结果的读写会成为性能杀手。
更致命的是 Attention 的计算密度:矩阵乘法本身是 compute-bound,但 softmax 和 dropout 是 memory-bound。当大模型训练时,HBM 的读写延迟会让 NPU 的算力大量闲置。
金句:Attention 的瓶颈从来不是计算不够快,而是数据搬得不够快。
Tiling 策略原理:把 SRAM 用到极致
FlashAttention 的核心思想是:不存储完整的 N×N 注意力矩阵,而是在 SRAM 中分块计算。
SRAM 分块逻辑
昇腾NPU的片上 SRAM 容量在几 MB 级别(取决于具体型号),远小于 HBM,但带宽是 HBM 的 10-100 倍。Tiling 的目标是把 Q、K、V 切成小块,让每个小块能完整放入 SRAM,避免频繁的 HBM 读写。
cpp
// Ascend C 伪代码:Tiling 参数计算
struct FlashAttentionTiling {
uint32_t batchSize;
uint32_t numHeads;
uint32_t seqLenQ; // Query 序列长度
uint32_t seqLenK; // Key/Value 序列长度
uint32_t headDim; // 头维度(通常 64/128)
// Tiling 参数
uint32_t blockSizeM; // Query 分块大小(通常 128/256)
uint32_t blockSizeN; // Key/Value 分块大小
uint32_t numBlocksM; // Query 方向分块数
uint32_t numBlocksN; // Key 方向分块数
// SRAM 预算分配
uint32_t sramSizePerCore; // 每核 SRAM 配额(字节)
uint32_t workspaceSize; // 临时缓冲区大小
};
FlashAttentionTiling CalculateTilingConfig(
const AttentionInputDesc& desc,
const HardwareInfo& hwInfo)
{
FlashAttentionTiling tiling;
// 核心:根据 SRAM 容量反推分块大小
uint32_t availableSram = hwInfo.sramPerCore - RESERVED_SRAM;
// Q/K/V 各占一块,加上输出和中间 buffer
uint32_t sramPerBlock = availableSram / 5; // 5 个 buffer
tiling.blockSizeM = std::min(
MAX_BLOCK_SIZE_M,
static_cast<uint32_t>(std::sqrt(sramPerBlock / sizeof(float)))
);
tiling.numBlocksM = (desc.seqLenQ + tiling.blockSizeM - 1) / tiling.blockSizeM;
tiling.numBlocksN = (desc.seqLenK + tiling.blockSizeN - 1) / tiling.blockSizeN;
return tiling;
}
Inner-Outer Loop 结构
Tiling 后的计算分为两层循环:
- Outer Loop:遍历 Key/Value 的块(block_N),每次加载一块 K、V 到 SRAM
- Inner Loop:遍历 Query 的块(block_M),每次加载一块 Q 到 SRAM,与当前 K、V 块计算局部注意力
python
# FlashAttention Tiling 计算流程(伪代码)
def flash_attention_tiled(Q, K, V, tiling):
# Q: [batch, heads, seq_len_q, head_dim]
# K, V: [batch, heads, seq_len_k, head_dim]
outputs = []
# Outer Loop: 遍历 K/V 分块
for kv_start in range(0, seq_len_k, block_size_n):
kv_end = min(kv_start + block_size_n, seq_len_k)
K_block = K[:, :, kv_start:kv_end, :] # 加载到 SRAM
V_block = V[:, :, kv_start:kv_end, :]
# Inner Loop: 遍历 Q 分块
for q_start in range(0, seq_len_q, block_size_m):
q_end = min(q_start + block_size_m, seq_len_q)
Q_block = Q[:, :, q_start:q_end, :] # 加载到 SRAM
# 局部注意力计算(全程在 SRAM 中)
scores_block = Q_block @ K_block.T # [block_m, block_n]
# 在线 softmax(不需要存储完整 scores 矩阵)
m_new, l_new, attn_block = online_softmax(
scores_block, m_old, l_old
)
# 累积输出(避免存储完整 attn_weights)
output_block = attn_block @ V_block
outputs.append(output_block)
return concatenate(outputs)
Forward/Reverse 重计算
FlashAttention 的另一个技巧是反向传播时不保存前向的注意力矩阵,而是重计算。
标准实现会在前向保存 attn_weights(N×N 矩阵),反向时直接用它算梯度。FlashAttention 选择多算一遍前向的 attention score,但省掉了存储 N×N 矩阵的空间。
cpp
// Ascend C:反向算子 FlashAttentionScoreGrad 的核心逻辑
__global__ void FlashAttentionScoreGradKernel(
const float* Q, // [batch, heads, seq_q, dim]
const float* K,
const float* V,
const float* dO, // 上游梯度
float* dQ,
float* dK,
float* dV,
const FlashAttentionTiling& tiling
) {
// 关键:不读取前向保存的 attn_weights,而是重新计算
for (int kv_block = 0; kv_block < tiling.numBlocksN; ++kv_block) {
// 重新加载 K、V 块
LoadToSRAM(K, kv_block, sram_K);
LoadToSRAM(V, kv_block, sram_V);
for (int q_block = 0; q_block < tiling.numBlocksM; ++q_block) {
// 重新计算 attention score(重计算在这里)
RecomputeAttentionScore(sram_Q, sram_K, sram_scores);
// 用重计算的 score 算梯度
ComputeGrad(dO, sram_scores, sram_V, dQ, dK, dV);
}
}
}
金句:用计算换存储,在 NPU 上这是一笔划算的买卖。
ops-transformer 中的实现
在昇腾CANN的 ops-transformer 库中,FlashAttention 被封装为两个核心算子:FlashAttentionScore 和 FlashAttentionScoreGrad。
FlashAttentionScore 算子
cpp
// ops-transformer 中的算子调用接口(简化)
#include "ops_transformer/flash_attention_score.h"
atb::Operation* CreateFlashAttentionScoreOp(
const FlashAttentionScoreParam& param)
{
auto op = new atb::FlashAttentionScore();
// 设置 Tiling 参数
op->SetAttr("head_num", param.numHeads);
op->SetAttr("head_dim", param.headDim);
op->SetAttr("block_size", param.blockSize); // Tiling 块大小
op->SetAttr("causal", param.isCausal); // 是否因果注意力
// 内省:自动选择最优 Tiling 配置
op->EnableAutoTuning(true);
return op;
}
// 使用示例
auto fa_op = CreateFlashAttentionScoreOp(param);
atb::Tensor Q = atb::FromNpu("query_tensor");
atb::Tensor K = atb::FromNpu("key_tensor");
atb::Tensor V = atb::FromNpu("value_tensor");
atb::Tensor output = fa_op->Execute({Q, K, V});
FlashAttentionScoreGrad 算子
反向算子需要处理重计算的细节,ops-transformer 提供了透明支持:
python
# 训练脚本中的反向传播(伪代码)
import ops_transformer as opt
# 前向
q = torch.randn(batch, heads, seq_q, dim, device='npu')
k = torch.randn(batch, heads, seq_k, dim, device='npu')
v = torch.randn(batch, heads, seq_k, dim, device='npu')
# 使用 ops-transformer 的 FlashAttention
output = opt.FlashAttentionScore.apply(q, k, v)
# 反向:自动触发 FlashAttentionScoreGrad
loss = output.sum()
loss.backward() # 内部调用 FlashAttentionScoreGrad,重计算 attention score
Tiling 配置的自动调优
ops-transformer 内置了 Tiling 参数的自动搜索逻辑:
cpp
// Ascend C:自动 Tiling 调优(伪代码)
class TilingAutoTuner {
public:
FlashAttentionTiling Tune(
const AttentionInputDesc& desc,
const NpuHardwareInfo& npuInfo)
{
std::vector<FlashAttentionTiling> candidates;
// 生成候选配置
for (uint32_t block_m : {64, 128, 256}) {
for (uint32_t block_n : {64, 128, 256}) {
candidates.push_back(
GenerateTiling(desc, npuInfo, block_m, block_n)
);
}
}
// 在 NPU 上跑 benchmark,选最快的
FlashAttentionTiling best = Profiler::SelectFastest(candidates);
// 缓存到文件,下次直接复用
TilingCache::Save(desc.hash(), best);
return best;
}
};
性能收益:用数据说话
在昇腾NPU上,FlashAttention 的 Tiling 策略带来的收益是实打实的:
| 指标 | 标准 Attention | FlashAttention (Tiling) | 加速比 |
|---|---|---|---|
| HBM 访问次数 | O(N²) | O(N²/B)(B 为块大小) | 减少 10-50× |
| 峰值显存占用 | O(N²) | O(N) | 节省 90%+ |
| 训练吞吐量(seq=2048) | 1.0× | 2.3× | +130% |
| 推理延迟(batch=1, seq=512) | 12ms | 4ms | 3× 加速 |
bash
# 性能 profiling 命令(在昇腾NPU上)
export ASCEND_GLOBAL_LOG_LEVEL=3
export ASCEND_SLOG_PRINT_TO_STDOUT=1
# 运行 benchmark
python benchmark_flash_attention.py \
--batch 4 \
--heads 32 \
--seq-len 2048 \
--head-dim 128 \
--use-tiling \
--profile
# 输出示例:
# [Profiler] FlashAttentionScore (Tiling):
# - Kernel time: 2.34ms (vs 7.89ms standard)
# - HBM read: 1.2GB (vs 18.7GB standard)
# - HBM write: 0.8GB (vs 12.4GB standard)
# - SRAM reuse ratio: 87%
金句:Tiling 不是让计算变快,而是让数据少跑路。
关键警告:两个容易踩的坑
坑 1:块大小选错,性能反而更差
Tiling 的块大小不是越大越好。块太大,SRAM 放不下,会触发 spilling(溢出到 HBM);块太小,Kernel Launch 开销和 SRAM 利用率低会成为新瓶颈。
python
# 错误示例:块大小超过 SRAM 容量
tiling_config = {
"block_size_m": 1024, # ❌ 假设 SRAM 只有 2MB,这会溢出
"block_size_n": 1024
}
# 正确做法:让 ops-transformer 自动选择
tiling_config = auto_tune_tiling(
seq_len=2048,
head_dim=128,
sram_budget_mb=2.0 # 查询 NPU 的 SRAM 大小
)
坑 2:因果注意力 + Tiling,mask 逻辑要小心
因果注意力(Causal Attention)要求上三角的 attention score 为 -inf。Tiling 后,每个块要正确应用 mask,否则会出现"未来信息泄露"。
cpp
// 错误:Tiling 后忘记处理因果 mask
__global__ void BrokenCausalAttention(
float* scores, // [block_m, block_n]
int q_start,
int kv_start,
int seq_len
) {
int q_idx = blockIdx.x * BLOCK_M + threadIdx.x;
int kv_idx = blockIdx.y * BLOCK_N + threadIdx.y;
// ❌ 没有检查 q_idx 和 kv_idx 的因果约束
scores[threadIdx.x][threadIdx.y] = Q[q_idx] @ K[kv_idx];
}
// 正确:在块内正确 mask
__global__ void CorrectCausalAttention(
float* scores,
int q_start,
int kv_start,
int seq_len
) {
int q_idx = q_start + threadIdx.x;
int kv_idx = kv_start + threadIdx.y;
// ✅ 应用因果 mask
if (q_idx < kv_idx) {
scores[threadIdx.x][threadIdx.y] = -INFINITY;
} else {
scores[threadIdx.x][threadIdx.y] = Q[q_idx] @ K[kv_idx];
}
}
总结与行动指引
FlashAttention 的 Tiling 策略本质是对内存层级的精确利用:把计算拆成适合 SRAM 的块,用重计算换存储,最终让 Attention 从 memory-bound 变成 compute-bound。
在昇腾CANN上,ops-transformer 库已经把这套逻辑封装成了开箱即用的算子。你不需要手写 Tiling 逻辑,只需要调用 FlashAttentionScore 和 FlashAttentionScoreGrad,剩下的交给昇腾NPU的硬件特性和 ops-transformer 的自动调优。
如果你对 Attention 的优化还想深入,建议继续学习 SparseFlashAttention(稀疏注意力),它进一步把 Attention 的计算复杂度从 O(N²) 降到 O(N√N)。相关代码和文档可以在 ops-transformer 仓库找到:
https://atomgit.com/cann/ops-transformer
把这个仓库 clone 下来,跑一遍 benchmark,你会看到 Tiling 带来的真实性能差异。