FA4 CuTeDSL 前向内核深度分析
【总】开篇
本篇分析 FA4 CuTeDSL 前向内核的实现架构与核心机制。FA4 前向内核是 FlashAttention 项目中最核心的组件,负责计算 O = softmax(Q @ K^T) @ V。FA4 的最大革新在于使用 Python CuTeDSL 重写了所有内核,替代了此前基于 Cutlass C++ 的实现,同时支持 SM80/90/100/120 四种 GPU 架构。
核心结论 :FA4 CuTeDSL 前向内核通过架构特化实现了跨代 GPU 的最优性能------SM80 使用 cp.async 异步拷贝 + HMMA 矩阵乘、SM90 使用 TMA 加载 + WGMMA 矩阵乘 + 生产者-消费者流水线、SM100 使用 UMMA + 2CTA + SplitKV + Paged KV + 持久化内核 + FP8 等黑威尔特性、SM120 复用 SM80 路径但适配更小的共享内存。所有架构的核心计算循环都是 Q tile 加载 → K/V blocks 遍历 → online softmax → 写出 O,但在数据搬运、矩阵乘指令、流水线编排上各有根本性差异。
【分】主体
1. flash_fwd.py --- FlashAttentionForwardBase / FlashAttentionForwardSm80
1.1 基类设计:FlashAttentionForwardBase
FlashAttentionForwardBase(flash_fwd.py:40-577)是所有前向内核的公共基类,定义了共享的参数结构、内存布局、数据加载逻辑和 epilogue 写出逻辑。
__init__ 参数 (flash_fwd.py:42-117):
| 参数 | 类型 | 说明 |
|---|---|---|
dtype |
Type[cutlass.Numeric] |
数据类型,支持 FP16/BF16 |
head_dim |
int |
Q/K 的 head 维度 |
head_dim_v |
Optional[int] |
V 的 head 维度,默认等于 head_dim |
qhead_per_kvhead |
int |
GQA 中每个 KV head 对应的 Q head 数 |
is_causal / is_local |
bool |
因果/局部注意力掩码 |
pack_gqa |
bool |
是否使用 PackGQA 优化 |
tile_m / tile_n |
int |
Q/KV 的分块大小,默认 128×128 |
num_stages |
int |
流水线阶段数 |
num_threads |
int |
线程数 |
Q_in_regs |
bool |
是否将 Q 保持在寄存器中 |
score_mod / mask_mod |
Optional[cutlass.Constexpr] |
自定义分数修改/掩码函数 |
基类在 __init__ 中完成关键的对齐计算:将 head_dim 向上取整到 16 的倍数作为 tile_hdim(flash_fwd.py:83),并设置 check_hdim_oob 标志以判断是否需要 head 维度越界检查(flash_fwd.py:88-89)。
can_implement 静态方法 (flash_fwd.py:119-175)检查内核是否可执行:验证数据类型(仅 FP16/BF16)、head_dim 对齐(8 的倍数)、tile_n 对齐(16 的倍数)、线程数对齐(32 的倍数)、共享内存容量是否足够。共享内存用量计算公式为:
smem_usage_QV = (smem_Q + smem_V) if not Q_in_regs else max(smem_Q, smem_V)
smem_usage = smem_usage_QV + smem_K
_setup_attributes 方法 (flash_fwd.py:206-300)配置共享内存布局和全局内存拷贝:
- 使用
_get_smem_layout_atom()获取架构特定的 smem 布局原子 - 通过
cute.tile_to_shape扩展为完整的 Q/K/V/O 布局 - 配置
cpasync.CopyG2SOp作为异步拷贝原子(128 bits/拷贝) - 构造
gmem_tiled_copy_Q/K/V/O用于全局内存到共享内存的数据搬运
epilogue 方法 (flash_fwd.py:331-449)处理内核的收尾工作:
- 将
acc_O(Float32 累加器)转换为输出数据类型并写入 smem - 通过
NamedBarrierFwd.Epilogue同步所有 epilogue 线程 - 写出 LSE(log-sum-exp)到全局内存
- 根据
use_tma_O选择 TMA 或 cp.async 写出 O 到全局内存
load_Q / load_K / load_V 方法 (flash_fwd.py:456-576)封装了数据从全局内存到共享内存的异步加载逻辑,包含边界检查和谓词计算。
1.2 FlashAttentionForwardSm80 实现
FlashAttentionForwardSm80(flash_fwd.py:579-1229)继承自 FlashAttentionForwardBase,实现了 Ampere 架构(SM80)的前向注意力内核。
Smem 布局 (flash_fwd.py:580-586):使用 sm80_utils.get_smem_layout_atom 获取 Ampere 兼容的共享内存布局,Q/K 使用相同布局,V/O 使用 head_dim_v 对应的布局。
MMA 指令 (flash_fwd.py:588-599):使用 warp.MmaF16BF16Op,即 HMMA(Half-Precision Matrix Multiply-Accumulate),指令形状为 (16, 8, 16),warp 数量为 num_threads // 32。QK 和 PV 使用两个独立的 tiled_mma。
共享内存结构 (flash_fwd.py:601-620):根据 Q_in_regs 选择 SharedStorageQKV(Q/V 分开存储)或 SharedStorageSharedQV(Q/V 共享同一块 smem 以节省空间),所有结构体按 1024 字节对齐。
__call__ 方法 (flash_fwd.py:622-743):
- 类型检查和张量布局重排(varlen 3D vs 非 varlen 4D)
- 选择 TileScheduler:varlen 用
SingleTileVarlenScheduler,否则用SingleTileScheduler - 计算 softmax_scale 的 log2 版本
- 启动 kernel,grid 维度由 TileScheduler 决定
kernel 方法 (flash_fwd.py:745-1082)是 SM80 前向的核心:
- 初始化阶段 :获取 tile scheduler 的工作分配,计算序列长度信息,确定 K/V 遍历范围
[n_block_min, n_block_max) - 共享内存分配 :通过
SmemAllocator分配 smem,获取 Q/K/V 的 tensor 视图 - MMA 分区 :为 QK 和 PV 两个 MMA 操作分配寄存器片段
tSrQ、tSrK、tOrVt,以及累加器acc_O(初始化为 0) - Smem 拷贝原子 :使用
LdMatrix8x8x16bOp从 smem 加载到寄存器,Q/K 不转置,V 转置 - Softmax 初始化 :创建
Softmax对象并 reset
Prologue (flash_fwd.py:957-989):
- 加载 Q tile 到 smem,
cp_async_commit_group - 如果
Q_in_regs:先加载 1 stage K,等待 Q 加载完成,将 Q 从 smem 拷贝到寄存器,再加载 V - 如果
!Q_in_regs:加载所有 stages 的 K 和 V,然后等待 Q 加载完成
Mainloop (flash_fwd.py:993-1057):
核心循环由 compute_one_n_block 方法实现,从 n_block_max-1 向 n_block_min 遍历 K/V blocks。循环分为三段:
- 最后一块 (需序列长度掩码):
is_first_n_block=True - 因果/局部掩码块 (需因果掩码):
mask_seqlen=True - 无掩码块 :
mask_seqlen=False
compute_one_n_block 方法 (flash_fwd.py:1084-1194)处理单个 K/V block 的完整计算:
sync() → 加载 V_next → QK GEMM → score_mod → 掩码 → online_softmax → rescale_O →
P 类型转换 → PV GEMM
关键步骤:
sync():等待 smem 中的 QK 数据就绪(cp_async_wait_group)sm80_utils.gemm:执行 Q @ K^T 的 HMMAsoftmax.online_softmax:在线 softmax,维护 row_max 和 row_sumsoftmax.rescale_O:用新的缩放因子调整 O 累加器sm80_utils.gemm_rs:执行 P @ V 的 HMMA(P 来自寄存器)
#mermaid-svg-ZX4JGYu6C8bW2jJo{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-ZX4JGYu6C8bW2jJo .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-ZX4JGYu6C8bW2jJo .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-ZX4JGYu6C8bW2jJo .error-icon{fill:#552222;}#mermaid-svg-ZX4JGYu6C8bW2jJo .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-ZX4JGYu6C8bW2jJo .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-ZX4JGYu6C8bW2jJo .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-ZX4JGYu6C8bW2jJo .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-ZX4JGYu6C8bW2jJo .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-ZX4JGYu6C8bW2jJo .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-ZX4JGYu6C8bW2jJo .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-ZX4JGYu6C8bW2jJo .marker{fill:#333333;stroke:#333333;}#mermaid-svg-ZX4JGYu6C8bW2jJo .marker.cross{stroke:#333333;}#mermaid-svg-ZX4JGYu6C8bW2jJo svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-ZX4JGYu6C8bW2jJo p{margin:0;}#mermaid-svg-ZX4JGYu6C8bW2jJo .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-ZX4JGYu6C8bW2jJo .cluster-label text{fill:#333;}#mermaid-svg-ZX4JGYu6C8bW2jJo .cluster-label span{color:#333;}#mermaid-svg-ZX4JGYu6C8bW2jJo .cluster-label span p{background-color:transparent;}#mermaid-svg-ZX4JGYu6C8bW2jJo .label text,#mermaid-svg-ZX4JGYu6C8bW2jJo span{fill:#333;color:#333;}#mermaid-svg-ZX4JGYu6C8bW2jJo .node rect,#mermaid-svg-ZX4JGYu6C8bW2jJo .node circle,#mermaid-svg-ZX4JGYu6C8bW2jJo .node ellipse,#mermaid-svg-ZX4JGYu6C8bW2jJo .node polygon,#mermaid-svg-ZX4JGYu6C8bW2jJo .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-ZX4JGYu6C8bW2jJo .rough-node .label text,#mermaid-svg-ZX4JGYu6C8bW2jJo .node .label text,#mermaid-svg-ZX4JGYu6C8bW2jJo .image-shape .label,#mermaid-svg-ZX4JGYu6C8bW2jJo .icon-shape .label{text-anchor:middle;}#mermaid-svg-ZX4JGYu6C8bW2jJo .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-ZX4JGYu6C8bW2jJo .rough-node .label,#mermaid-svg-ZX4JGYu6C8bW2jJo .node .label,#mermaid-svg-ZX4JGYu6C8bW2jJo .image-shape .label,#mermaid-svg-ZX4JGYu6C8bW2jJo .icon-shape .label{text-align:center;}#mermaid-svg-ZX4JGYu6C8bW2jJo .node.clickable{cursor:pointer;}#mermaid-svg-ZX4JGYu6C8bW2jJo .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-ZX4JGYu6C8bW2jJo .arrowheadPath{fill:#333333;}#mermaid-svg-ZX4JGYu6C8bW2jJo .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-ZX4JGYu6C8bW2jJo .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-ZX4JGYu6C8bW2jJo .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-ZX4JGYu6C8bW2jJo .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-ZX4JGYu6C8bW2jJo .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-ZX4JGYu6C8bW2jJo .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-ZX4JGYu6C8bW2jJo .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-ZX4JGYu6C8bW2jJo .cluster text{fill:#333;}#mermaid-svg-ZX4JGYu6C8bW2jJo .cluster span{color:#333;}#mermaid-svg-ZX4JGYu6C8bW2jJo div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-ZX4JGYu6C8bW2jJo .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-ZX4JGYu6C8bW2jJo rect.text{fill:none;stroke-width:0;}#mermaid-svg-ZX4JGYu6C8bW2jJo .icon-shape,#mermaid-svg-ZX4JGYu6C8bW2jJo .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-ZX4JGYu6C8bW2jJo .icon-shape p,#mermaid-svg-ZX4JGYu6C8bW2jJo .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-ZX4JGYu6C8bW2jJo .icon-shape .label rect,#mermaid-svg-ZX4JGYu6C8bW2jJo .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-ZX4JGYu6C8bW2jJo .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-ZX4JGYu6C8bW2jJo .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-ZX4JGYu6C8bW2jJo :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} 是
否
加载 Q tile 到 smem
Prologue: 加载 K/V stages
遍历 K/V blocks
等待 smem QK 就绪
QK GEMM: S = Q @ K^T
应用 score_mod / mask
Online Softmax: 更新 row_max, row_sum
Rescale O 累加器
P = S 转换为 FP16/BF16
PV GEMM: O += P @ V
还有更多 blocks?
Softmax finalize + rescale
Epilogue: 写出 O 和 LSE
2. flash_fwd_sm90.py --- FlashAttentionForwardSm90
FlashAttentionForwardSm90(flash_fwd_sm90.py:52-1545)继承自 FlashAttentionForwardBase,实现了 Hopper 架构(SM90)的前向注意力内核。相比 SM80,SM90 的核心升级是 TMA(Tensor Memory Access)加载 、WGMMA(Warp Group Matrix Multiply-Accumulate) 和生产者-消费者流水线。
2.1 架构差异:TMA vs cp.async
SM90 的 __init__(flash_fwd_sm90.py:53-70)新增了:
intra_wg_overlap: bool = True:是否启用 warp group 内重叠mma_pv_is_rs: bool = True:PV 的 MMA 是否使用寄存器到 smem 路径paged_kv_non_tma: bool = False:是否使用非 TMA 的 Paged KV
TMA 加载 (flash_fwd_sm90.py:261-301):
- Q 使用
CopyBulkTensorTileG2SOp(TMA G2S)加载 - K/V 同样使用 TMA 加载
- O 使用
CopyBulkTensorTileS2GOp(TMA S2G)写回 - TMA 由单个 warp(warp 0)发起,而 cp.async 需要整个 warp group 参与
WGMMA vs HMMA (flash_fwd_sm90.py:96-118):
- QK MMA:
warpgroup.OperandMajorMode.K×warpgroup.OperandMajorMode.K,atom_layout 为(tile_m // 64, 1, 1) - PV MMA:A 来源可以是 RMEM(
mma_pv_is_rs=True)或 SMEM(mma_pv_is_rs=False),B 来源为 SMEM(OperandMajorMode.MN) - WGMMA 指令形状为 64×N×K,由 128 个线程(4 个 warp)组成的 warp group 协同执行
2.2 生产者-消费者流水线
SM90 内核的 kernel 方法(flash_fwd_sm90.py:401-636)采用生产者-消费者模型:
-
生产者(warp_idx < 4,即 warp 0-3):负责数据加载
setmaxregister_decrease:减少寄存器分配以释放给消费者- 调用
self.load()方法执行 Q/K/V 的 TMA 或 cp.async 加载
-
消费者(warp_idx >= 4,即 warp 4-7+):负责 MMA 计算
setmaxregister_increase:增加寄存器分配用于 WGMMA- 调用
self.mma()方法执行注意力计算
流水线类型 (flash_fwd_sm90.py:458-513):
PipelineTmaAsync:TMA 加载的异步流水线,生产者为单个 TMA warpPipelineCpAsync:cp.async 加载的异步流水线,生产者为 128 个线程- Q 使用 1 stage 流水线,K/V 使用
num_stagesstage 流水线
load 方法 (flash_fwd_sm90.py:638-914):
生产者端的加载逻辑,使用 while work_tile.is_valid_tile 持久化循环:
- 获取当前 tile 的 m_block、head_idx、batch_idx
- 计算 TMA 加载的闭包函数
load_Q、load_K、load_V - 支持 Paged KV:TMA 路径(page_size == n_block_size)和 cp.async 路径
- K/V 加载使用
pipeline_k和pipeline_v的producer_acquire/commit机制 - 支持 Block Sparsity 的特殊加载路径
mma 方法 (flash_fwd_sm90.py:936-1267):
消费者端的计算逻辑,同样使用持久化循环:
- MMA 分区 :使用
sm90_utils.partition_fragment_ABC分配 QK 和 PV 的寄存器片段 - Softmax 初始化 :创建
Softmax对象 - 主循环 :遍历 K/V blocks,调用
mma_one_n_block或mma_one_n_block_intrawg_overlap
Intra-WarpGroup Overlap (flash_fwd_sm90.py:1410-1477):
当 intra_wg_overlap=True 时,QK GEMM 和 PV GEMM 可以在同一个 warp group 内重叠执行:
first_half_block_overlap:执行 QK GEMM + softmax + P 转换last_half_block_overlap:执行 PV GEMM- 两者之间通过
warp_scheduler_barrier同步
#mermaid-svg-13LIAhEn3T8jpmXD{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-13LIAhEn3T8jpmXD .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-13LIAhEn3T8jpmXD .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-13LIAhEn3T8jpmXD .error-icon{fill:#552222;}#mermaid-svg-13LIAhEn3T8jpmXD .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-13LIAhEn3T8jpmXD .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-13LIAhEn3T8jpmXD .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-13LIAhEn3T8jpmXD .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-13LIAhEn3T8jpmXD .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-13LIAhEn3T8jpmXD .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-13LIAhEn3T8jpmXD .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-13LIAhEn3T8jpmXD .marker{fill:#333333;stroke:#333333;}#mermaid-svg-13LIAhEn3T8jpmXD .marker.cross{stroke:#333333;}#mermaid-svg-13LIAhEn3T8jpmXD svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-13LIAhEn3T8jpmXD p{margin:0;}#mermaid-svg-13LIAhEn3T8jpmXD .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-13LIAhEn3T8jpmXD .cluster-label text{fill:#333;}#mermaid-svg-13LIAhEn3T8jpmXD .cluster-label span{color:#333;}#mermaid-svg-13LIAhEn3T8jpmXD .cluster-label span p{background-color:transparent;}#mermaid-svg-13LIAhEn3T8jpmXD .label text,#mermaid-svg-13LIAhEn3T8jpmXD span{fill:#333;color:#333;}#mermaid-svg-13LIAhEn3T8jpmXD .node rect,#mermaid-svg-13LIAhEn3T8jpmXD .node circle,#mermaid-svg-13LIAhEn3T8jpmXD .node ellipse,#mermaid-svg-13LIAhEn3T8jpmXD .node polygon,#mermaid-svg-13LIAhEn3T8jpmXD .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-13LIAhEn3T8jpmXD .rough-node .label text,#mermaid-svg-13LIAhEn3T8jpmXD .node .label text,#mermaid-svg-13LIAhEn3T8jpmXD .image-shape .label,#mermaid-svg-13LIAhEn3T8jpmXD .icon-shape .label{text-anchor:middle;}#mermaid-svg-13LIAhEn3T8jpmXD .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-13LIAhEn3T8jpmXD .rough-node .label,#mermaid-svg-13LIAhEn3T8jpmXD .node .label,#mermaid-svg-13LIAhEn3T8jpmXD .image-shape .label,#mermaid-svg-13LIAhEn3T8jpmXD .icon-shape .label{text-align:center;}#mermaid-svg-13LIAhEn3T8jpmXD .node.clickable{cursor:pointer;}#mermaid-svg-13LIAhEn3T8jpmXD .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-13LIAhEn3T8jpmXD .arrowheadPath{fill:#333333;}#mermaid-svg-13LIAhEn3T8jpmXD .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-13LIAhEn3T8jpmXD .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-13LIAhEn3T8jpmXD .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-13LIAhEn3T8jpmXD .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-13LIAhEn3T8jpmXD .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-13LIAhEn3T8jpmXD .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-13LIAhEn3T8jpmXD .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-13LIAhEn3T8jpmXD .cluster text{fill:#333;}#mermaid-svg-13LIAhEn3T8jpmXD .cluster span{color:#333;}#mermaid-svg-13LIAhEn3T8jpmXD div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-13LIAhEn3T8jpmXD .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-13LIAhEn3T8jpmXD rect.text{fill:none;stroke-width:0;}#mermaid-svg-13LIAhEn3T8jpmXD .icon-shape,#mermaid-svg-13LIAhEn3T8jpmXD .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-13LIAhEn3T8jpmXD .icon-shape p,#mermaid-svg-13LIAhEn3T8jpmXD .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-13LIAhEn3T8jpmXD .icon-shape .label rect,#mermaid-svg-13LIAhEn3T8jpmXD .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-13LIAhEn3T8jpmXD .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-13LIAhEn3T8jpmXD .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-13LIAhEn3T8jpmXD :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} 消费者 Warp 4+
生产者 Warp 0-3
mbarrier 信号
释放 pipeline slot
TMA/cp.async 加载 Q
TMA/cp.async 加载 K/V blocks
pipeline producer_acquire/commit
pipeline consumer_wait
WGMMA: S = Q @ K^T
Online Softmax
Rescale O
WGMMA: O += P @ V
pipeline consumer_release
2.3 Warp Scheduler Barrier
当使用多个 MMA warp group(num_wg_mma >= 2)时,需要 warp scheduler barrier 来协调不同 warp group 的执行顺序(flash_fwd_sm90.py:1524-1545):
warp_scheduler_barrier_sync:当前 warp group 等待warp_scheduler_barrier_arrive:通知下一个 warp group 可以开始
3. flash_fwd_sm100.py --- FlashAttentionForwardSm100
FlashAttentionForwardSm100(flash_fwd_sm100.py:114-999+)是 Blackwell 架构(SM100)的前向注意力内核,不继承自 FlashAttentionForwardBase,而是完全独立实现。这是 SM100 架构根本性变化所致。
3.1 Blackwell 架构核心特性
UMMA(Unified Matrix Multiply-Accumulate) :
SM100 使用 tcgen05 指令集替代 SM90 的 warpgroup 指令。UMMA 的关键区别:
- 累加器存储在 TMEM(Tensor Memory) 而非寄存器中
- MMA 指令由单个 warp 发起(而非 warp group)
- 支持 2CTA 模式:两个 CTA 协同执行一个更大的 MMA 操作
2CTA 模式 (flash_fwd_sm100.py:154-172):
use_2cta_instrs=True时,cta_group_size=2,cluster_shape_mn=(2,1)- MMA tiler 的 M 维度翻倍:
mma_tiler_qk = (2 * m_block_size, n_block_size, head_dim_padded) - 两个 CTA 组成一个 cluster,协同加载和计算
SplitKV (flash_fwd_sm100.py:124,190-191):
- 将 K/V 序列沿序列维度拆分为多个 split,每个 split 独立计算部分 O 和 LSE
- 最终通过
flash_fwd_combine.py合并部分结果 - 不支持
head_dim_v_padded >= 192的 SplitKV
Paged KV (flash_fwd_sm100.py:139):
use_tma_KV=True:TMA 路径,page_size 必须等于 tile_nuse_tma_KV=False(paged_kv_non_tma):cp.async 路径,支持任意 page_size
持久化内核 (flash_fwd_sm100.py:130,176):
is_persistent=True时使用StaticPersistentTileScheduler- 一个 CTA 处理多个 tile,减少 kernel launch 开销
- 支持 CLC(Cooperative Launch Controller)动态调度器
FP8 支持 (flash_fwd_sm100.py:105-111,443-460):
- 通过
DescaleTensors提供 Q/K/V 的反量化缩放因子 - FP8 的 MMA 使用
q_dtype.width == 8的数据类型
3.2 Warp 特化
SM100 内核使用高度特化的 warp 分配(flash_fwd_sm100.py:249-286):
| Warp ID | 角色 | 说明 |
|---|---|---|
| 0-3 | Softmax 0 | 计算 S0 的 softmax(第一个 Q stage) |
| 4-7 | Softmax 1 | 计算 S1 的 softmax(第二个 Q stage) |
| 8-11 | Correction | O 的在线缩放修正 |
| 12 | MMA | 执行 QK 和 PV 的 UMMA |
| 13 | Epilogue | 写出 O 到全局内存 |
| 14 | Load | TMA/cp.async 数据加载 |
| 15 | Empty | 空闲(或 CLC 调度器) |
当 q_stage=1 时,Softmax 1 warps 被重新分配为 Load warps 或 Empty warps。
寄存器分配 (flash_fwd_sm100.py:304-323):
通过 _TUNING_CONFIG 查找表根据 (use_2cta, is_causal, head_dim_padded, is_sm103) 确定各 warp 的寄存器数量。总预算为 512 个寄存器/线程,在 softmax、correction 和 other warps 之间分配。
TMEM 布局 (flash_fwd_sm100.py:289-299):
tmem_s_offset = [0, n_block_size]:S0 和 S1 的 TMEM 偏移tmem_o_offset:O 的 TMEM 偏移(紧跟 S 之后)tmem_p_offset:P 的 TMEM 偏移(S 的后半部分,因为 P 只需 FP16 部分)
3.3 流水线架构
SM100 的流水线比 SM90 更复杂(flash_fwd_sm100.py:914-999):
PipelineTmaUmma/PipelineAsyncUmma:Q 和 K/V 的加载流水线,连接 TMA/cp.async 生产者和 UMMA 消费者PipelineUmmaAsync:S → Softmax 的信号流水线,MMA warp 通知 softmax warp S 已就绪PipelineAsyncUmma:P → MMA 的信号流水线,softmax warp 通知 MMA warp P 已写入 TMEMPipelineUmmaAsync:O → Correction 的信号流水线PipelineAsync:Softmax stats → Correction 的信号流水线
#mermaid-svg-8Qh8m9TRjQYJFjeN{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-8Qh8m9TRjQYJFjeN .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-8Qh8m9TRjQYJFjeN .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-8Qh8m9TRjQYJFjeN .error-icon{fill:#552222;}#mermaid-svg-8Qh8m9TRjQYJFjeN .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-8Qh8m9TRjQYJFjeN .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-8Qh8m9TRjQYJFjeN .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-8Qh8m9TRjQYJFjeN .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-8Qh8m9TRjQYJFjeN .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-8Qh8m9TRjQYJFjeN .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-8Qh8m9TRjQYJFjeN .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-8Qh8m9TRjQYJFjeN .marker{fill:#333333;stroke:#333333;}#mermaid-svg-8Qh8m9TRjQYJFjeN .marker.cross{stroke:#333333;}#mermaid-svg-8Qh8m9TRjQYJFjeN svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-8Qh8m9TRjQYJFjeN p{margin:0;}#mermaid-svg-8Qh8m9TRjQYJFjeN .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-8Qh8m9TRjQYJFjeN .cluster-label text{fill:#333;}#mermaid-svg-8Qh8m9TRjQYJFjeN .cluster-label span{color:#333;}#mermaid-svg-8Qh8m9TRjQYJFjeN .cluster-label span p{background-color:transparent;}#mermaid-svg-8Qh8m9TRjQYJFjeN .label text,#mermaid-svg-8Qh8m9TRjQYJFjeN span{fill:#333;color:#333;}#mermaid-svg-8Qh8m9TRjQYJFjeN .node rect,#mermaid-svg-8Qh8m9TRjQYJFjeN .node circle,#mermaid-svg-8Qh8m9TRjQYJFjeN .node ellipse,#mermaid-svg-8Qh8m9TRjQYJFjeN .node polygon,#mermaid-svg-8Qh8m9TRjQYJFjeN .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-8Qh8m9TRjQYJFjeN .rough-node .label text,#mermaid-svg-8Qh8m9TRjQYJFjeN .node .label text,#mermaid-svg-8Qh8m9TRjQYJFjeN .image-shape .label,#mermaid-svg-8Qh8m9TRjQYJFjeN .icon-shape .label{text-anchor:middle;}#mermaid-svg-8Qh8m9TRjQYJFjeN .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-8Qh8m9TRjQYJFjeN .rough-node .label,#mermaid-svg-8Qh8m9TRjQYJFjeN .node .label,#mermaid-svg-8Qh8m9TRjQYJFjeN .image-shape .label,#mermaid-svg-8Qh8m9TRjQYJFjeN .icon-shape .label{text-align:center;}#mermaid-svg-8Qh8m9TRjQYJFjeN .node.clickable{cursor:pointer;}#mermaid-svg-8Qh8m9TRjQYJFjeN .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-8Qh8m9TRjQYJFjeN .arrowheadPath{fill:#333333;}#mermaid-svg-8Qh8m9TRjQYJFjeN .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-8Qh8m9TRjQYJFjeN .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-8Qh8m9TRjQYJFjeN .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-8Qh8m9TRjQYJFjeN .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-8Qh8m9TRjQYJFjeN .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-8Qh8m9TRjQYJFjeN .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-8Qh8m9TRjQYJFjeN .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-8Qh8m9TRjQYJFjeN .cluster text{fill:#333;}#mermaid-svg-8Qh8m9TRjQYJFjeN .cluster span{color:#333;}#mermaid-svg-8Qh8m9TRjQYJFjeN div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-8Qh8m9TRjQYJFjeN .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-8Qh8m9TRjQYJFjeN rect.text{fill:none;stroke-width:0;}#mermaid-svg-8Qh8m9TRjQYJFjeN .icon-shape,#mermaid-svg-8Qh8m9TRjQYJFjeN .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-8Qh8m9TRjQYJFjeN .icon-shape p,#mermaid-svg-8Qh8m9TRjQYJFjeN .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-8Qh8m9TRjQYJFjeN .icon-shape .label rect,#mermaid-svg-8Qh8m9TRjQYJFjeN .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-8Qh8m9TRjQYJFjeN .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-8Qh8m9TRjQYJFjeN .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-8Qh8m9TRjQYJFjeN :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} Epilogue Warp
Correction Warps
Softmax Warps
MMA Warp
Load Warp
pipeline_kv
pipeline_s_p_o
pipeline_p_lastsplit
pipeline_o_acc
pipeline_o_epi
TMA 加载 Q
TMA 加载 K/V
UMMA: S = Q @ K^T
信号: S 就绪
UMMA: O += P @ V
信号: O 就绪
等待 S 就绪
从 TMEM 读取 S
Online Softmax + exp2
写入 P 到 TMEM
信号: P 就绪
等待 O 就绪
从 TMEM 读取 O
Rescale O
写回 TMEM
等待最终 O
TMA 写出 O
3.4 Exp2 仿真
SM100 的 softmax 需要计算 exp2(S * scale - row_max * scale)。在非 SM103 的 Blackwell GPU 上,硬件 exp2 指令较慢,因此使用 多项式仿真 (flash_fwd_sm100.py:72-97,462-469):
ex2_emu_freq:每隔多少个 fragment 使用一次仿真ex2_emu_start_frg:从第几个 fragment 开始仿真- SM103(B300)有快速硬件 exp2,
ex2_emu_freq=0
4. flash_fwd_sm120.py --- FlashAttentionForwardSm120
FlashAttentionForwardSm120(flash_fwd_sm120.py:14-59)是 SM120(Blackwell GeForce / DGX Spark)的前向内核,继承自 FlashAttentionForwardSm80。
SM120 的关键约束是共享内存容量仅 99 KB(vs SM80 的 163 KB),因此:
- 复用 SM80 的 MMA 指令(
mma.sync.aligned.m16n8k16,即 HMMA) - 复用 SM80 的 cp.async 数据加载路径
- 仅重写
can_implement方法,将共享内存容量检查改为sm_120的 99 KB
python
# flash_fwd_sm120.py:54
smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_120")
注意 arch = 80 是故意设置的(flash_fwd_sm120.py:17),因为 SM120 使用与 SM80 相同的 cp.async 代码路径,编译目标由实际 GPU 决定而非此字段。
5. flash_fwd_combine.py --- FlashAttentionForwardCombine
FlashAttentionForwardCombine(flash_fwd_combine.py:21-698)负责合并 SplitKV 产生的部分结果。当 K/V 序列很长时,SM100 内核将其拆分为多个 split 并行计算,每个 split 产生 O_partial 和 LSE_partial,最终由 Combine 内核合并。
__init__ 参数 (flash_fwd_combine.py:22-54):
dtype:输出数据类型dtype_partial:部分累积数据类型head_dim:head 维度tile_m:M 分块大小(默认 8)k_block_size:K 分块大小(默认 64)log_max_splits:最大 split 数的 log2(默认 4,即最多 16 splits)num_threads:线程数(默认 256)stages:流水线阶段数(默认 4)
合并算法 (flash_fwd_combine.py:323-665):
- 加载 LSE_partial:从全局内存加载所有 split 的 LSE 到 smem
- 计算全局 LSE :
- 对每个 row,找到所有 split 中的
lse_max - 计算缩放因子
scale_s = exp(lse_s - lse_max) - 求和
lse_sum = Σ scale_s - 最终
lse = lse_max + log(lse_sum)
- 对每个 row,找到所有 split 中的
- 加载 O_partial 并累积 :
- 使用流水线化加载,
stages-1个 stage 预取 - 对每个 split:
O += scale_s * O_partial_s
- 使用流水线化加载,
- 写出最终 O:将累积结果写回全局内存
关键优化:使用 warp_reduction_max 和 warp_reduction_sum 进行高效的 warp 级归约,使用 swizzled smem 布局避免 bank conflict。
6. flash_fwd_mla_sm100.py --- FlashAttentionMLAForwardSm100
FlashAttentionMLAForwardSm100(flash_fwd_mla_sm100.py:53-200)实现了 Multi-head Latent Attention(MLA)的前向内核,这是 DeepSeek-V2/V3 提出的注意力变体。
MLA 的核心思想:将 K/V 压缩到低维潜在空间,注意力计算在压缩空间中进行。
关键参数 (flash_fwd_mla_sm100.py:54-69):
hdim=64:Q/K 的 head 维度(较小)hdimv=512:V 的 head 维度(较大)qhead_per_kvhead=128:MQA 配置,128 个 Q head 共享 1 个 KV headis_topk_gather=True:是否使用 TopK 聚合 KVtopk_length=2048:TopK 选择的序列长度
Warp 分配 (flash_fwd_mla_sm100.py:110-142):
MLA 使用 12 或 16 个 warp,包括专门的 cpasync_load_warp_indices(4 个 warp)用于 cp.async 加载 KV,以及 relay_warp_id 用于中继。
2CTA 配置 (flash_fwd_mla_sm100.py:172-176):
use_2cta_instrs=True,cluster_shape_mn=(2,1)cta_tile_m=64,cluster_tile_m=128(2 个 CTA 各处理 64 行)
V 维度拆分 (flash_fwd_mla_sm100.py:189):
num_hdimv_splits=2:将 hdimv=512 拆分为两个 256 的部分- 每个部分独立执行
Qv @ V^T和P @ V的 MMA
7. sm100_hd256_2cta_fmha_forward.py --- BlackwellFusedMultiHeadAttentionForward
BlackwellFusedMultiHeadAttentionForward(sm100_hd256_2cta_fmha_forward.py:36-1889)是 SM100 上 head_dim=256 的专用 2CTA 前向内核。
约束 (sm100_hd256_2cta_fmha_forward.py:60-76):
- 仅支持
(head_dim, head_dim_v) = (256, 256) - 不支持 score_mod、mask_mod、aux_tensors、pack_gqa、SplitKV、q_subtile_factor
- 固定
m_block_size=128, n_block_size=128
Warp 分配 (sm100_hd256_2cta_fmha_forward.py:116-133):
- Softmax: 4 warp (0-3)
- Correction: 4 warp (4-7)
- MMA: 1 warp (8)
- Load: 1 warp (9)
- Empty: 2 warp (10-11)
- 共 12 warp = 384 线程
MMA Tiler (sm100_hd256_2cta_fmha_forward.py:82-106):
mma_tiler = (128, 128, 256)qk_mma_tiler = (256, 128, 128):2CTA 的 QK MMA,M 维度翻倍,K 维度拆分为 128 的迭代iterations_qk = 256 // 128 = 2:QK 需要两次迭代iterations_pv = 256 // 128 = 2:PV 需要两次迭代
内核结构 (sm100_hd256_2cta_fmha_forward.py:559-1528):
内核分为四个主要部分,每个部分由不同的 warp 执行:
- Load Warp (
flash_fwd_sm100.py:808-1031):TMA 加载 Q/K/V,使用PipelineTmaUmma - MMA Warp (
sm100_hd256_2cta_fmha_forward.py:1036-1283):执行 UMMA QK 和 PV,只有 leader CTA 发起 MMA - Softmax Warps (
sm100_hd256_2cta_fmha_forward.py:1285-1401):从 TMEM 读取 S,执行 softmax,写 P 回 TMEM - Correction Warps (
sm100_hd256_2cta_fmha_forward.py:1406-1492):在线缩放 O 和最终 epilogue
TMEM 协作释放 (sm100_hd256_2cta_fmha_forward.py:1521-1527):
2CTA 模式下,TMEM 由两个 CTA 共享,因此需要 cluster 级同步后才能释放:
python
cute.arch.cluster_arrive()
cute.arch.cluster_wait()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
8. Online Softmax 累积过程
所有架构的前向内核都使用 Online Softmax 算法来避免对完整注意力矩阵的物化。以下是其核心过程:
渲染错误: Mermaid 渲染失败: Parse error on line 12: ...um₁ = row_sum₀ * exp(...) + Σ P₁] -----------------------^ Expecting 'SQE', 'DOUBLECIRCLEEND', 'PE', '-)', 'STADIUMEND', 'SUBROUTINEEND', 'PIPE', 'CYLINDEREND', 'DIAMOND_STOP', 'TAGEND', 'TRAPEND', 'INVTRAPEND', 'UNICODE_TEXT', 'TEXT', 'TAGSTART', got 'PS'
Online Softmax 的数学原理:
对于每个新的 K/V block j:
- 计算新的 row_max:
m_new = max(m_old, max(S_j)) - 计算缩放因子:
α = exp(m_old - m_new) - 缩放累积的 O 和 row_sum:
O *= α,row_sum *= α - 计算新的 P:
P_j = exp(S_j - m_new) - 更新 row_sum:
row_sum += Σ P_j - 更新 O:
O += P_j @ V_j
最终:O_final = O / row_sum,LSE = m_final + log(row_sum)
在 SM80/90 中,这通过 Softmax.online_softmax() 和 Softmax.rescale_O() 实现。在 SM100 中,softmax 由专门的 warp 执行,correction warp 负责 O 的缩放,两者通过 pipeline barrier 协调。
9. 四种架构前向内核对比
#mermaid-svg-EgaAYWenYTf8gVwQ{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-EgaAYWenYTf8gVwQ .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-EgaAYWenYTf8gVwQ .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-EgaAYWenYTf8gVwQ .error-icon{fill:#552222;}#mermaid-svg-EgaAYWenYTf8gVwQ .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-EgaAYWenYTf8gVwQ .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-EgaAYWenYTf8gVwQ .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-EgaAYWenYTf8gVwQ .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-EgaAYWenYTf8gVwQ .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-EgaAYWenYTf8gVwQ .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-EgaAYWenYTf8gVwQ .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-EgaAYWenYTf8gVwQ .marker{fill:#333333;stroke:#333333;}#mermaid-svg-EgaAYWenYTf8gVwQ .marker.cross{stroke:#333333;}#mermaid-svg-EgaAYWenYTf8gVwQ svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-EgaAYWenYTf8gVwQ p{margin:0;}#mermaid-svg-EgaAYWenYTf8gVwQ .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-EgaAYWenYTf8gVwQ .cluster-label text{fill:#333;}#mermaid-svg-EgaAYWenYTf8gVwQ .cluster-label span{color:#333;}#mermaid-svg-EgaAYWenYTf8gVwQ .cluster-label span p{background-color:transparent;}#mermaid-svg-EgaAYWenYTf8gVwQ .label text,#mermaid-svg-EgaAYWenYTf8gVwQ span{fill:#333;color:#333;}#mermaid-svg-EgaAYWenYTf8gVwQ .node rect,#mermaid-svg-EgaAYWenYTf8gVwQ .node circle,#mermaid-svg-EgaAYWenYTf8gVwQ .node ellipse,#mermaid-svg-EgaAYWenYTf8gVwQ .node polygon,#mermaid-svg-EgaAYWenYTf8gVwQ .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-EgaAYWenYTf8gVwQ .rough-node .label text,#mermaid-svg-EgaAYWenYTf8gVwQ .node .label text,#mermaid-svg-EgaAYWenYTf8gVwQ .image-shape .label,#mermaid-svg-EgaAYWenYTf8gVwQ .icon-shape .label{text-anchor:middle;}#mermaid-svg-EgaAYWenYTf8gVwQ .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-EgaAYWenYTf8gVwQ .rough-node .label,#mermaid-svg-EgaAYWenYTf8gVwQ .node .label,#mermaid-svg-EgaAYWenYTf8gVwQ .image-shape .label,#mermaid-svg-EgaAYWenYTf8gVwQ .icon-shape .label{text-align:center;}#mermaid-svg-EgaAYWenYTf8gVwQ .node.clickable{cursor:pointer;}#mermaid-svg-EgaAYWenYTf8gVwQ .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-EgaAYWenYTf8gVwQ .arrowheadPath{fill:#333333;}#mermaid-svg-EgaAYWenYTf8gVwQ .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-EgaAYWenYTf8gVwQ .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-EgaAYWenYTf8gVwQ .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-EgaAYWenYTf8gVwQ .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-EgaAYWenYTf8gVwQ .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-EgaAYWenYTf8gVwQ .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-EgaAYWenYTf8gVwQ .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-EgaAYWenYTf8gVwQ .cluster text{fill:#333;}#mermaid-svg-EgaAYWenYTf8gVwQ .cluster span{color:#333;}#mermaid-svg-EgaAYWenYTf8gVwQ div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-EgaAYWenYTf8gVwQ .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-EgaAYWenYTf8gVwQ rect.text{fill:none;stroke-width:0;}#mermaid-svg-EgaAYWenYTf8gVwQ .icon-shape,#mermaid-svg-EgaAYWenYTf8gVwQ .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-EgaAYWenYTf8gVwQ .icon-shape p,#mermaid-svg-EgaAYWenYTf8gVwQ .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-EgaAYWenYTf8gVwQ .icon-shape .label rect,#mermaid-svg-EgaAYWenYTf8gVwQ .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-EgaAYWenYTf8gVwQ .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-EgaAYWenYTf8gVwQ .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-EgaAYWenYTf8gVwQ :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} SM120 Blackwell-Geo
cp.async 同 SM80
HMMA 同 SM80
__syncthreads 同 SM80
同 SM80, 99KB smem 限制
SM100 Blackwell
TMA + UMMA 集成
UMMA + TMEM 累加器
多层 pipeline: Load→MMA→Softmax→Correction→Epilogue
16 warp 高度特化: Load/MMA/Softmax/Correction/Epilogue
SM90 Hopper
TMA 张量加载
WGMMA 64xNxK
mbarrier 生产者-消费者流水线
生产者 warp + 消费者 warp group
SM80 Ampere
cp.async 异步拷贝
HMMA 16x8x16
__syncthreads + cp_async_wait_group
128-256 线程统一角色
| 特性 | SM80 | SM90 | SM100 | SM120 |
|---|---|---|---|---|
| 数据加载 | cp.async | TMA / cp.async | TMA / cp.async | cp.async |
| MMA 指令 | HMMA | WGMMA | UMMA (tcgen05) | HMMA |
| 累加器位置 | 寄存器 | 寄存器 | TMEM | 寄存器 |
| 流水线 | cp_async_wait_group | mbarrier Pipeline | 多层 Pipeline | cp_async_wait_group |
| Warp 特化 | 无 | 生产者/消费者 | 5 类 warp | 无 |
| 2CTA | 不支持 | 不支持 | 支持 | 不支持 |
| SplitKV | 不支持 | 不支持 | 支持 | 不支持 |
| Paged KV | 不支持 | TMA/cp.async | TMA/cp.async | 不支持 |
| FP8 | 不支持 | 不支持 | 支持 | 不支持 |
| 持久化 | 不支持 | 不支持 | 支持 | 不支持 |
| 共享内存 | 163 KB | 228 KB | 227 KB | 99 KB |
| head_dim | ≤128 | ≤128 | ≤256 | ≤128 |
【总】收尾
FA4 CuTeDSL 前向内核通过架构特化实现跨代 GPU 的最优性能。从 SM80 的简单 cp.async + HMMA 模型,到 SM90 的 TMA + WGMMA + 生产者-消费者流水线,再到 SM100 的 UMMA + TMEM + 2CTA + 多层 warp 特化流水线,每一代架构都充分利用了硬件的新特性。
SM120 作为 Blackwell 的消费级版本,因共享内存限制回退到 SM80 路径。SplitKV 和 Combine 内核解决了超长序列的并行度问题。MLA 内核为 DeepSeek 风格的注意力提供了专门优化。hd256 专用内核则为大规模 head 维度场景提供了最高性能。
所有这些内核共享同一套 Online Softmax 数学框架,但在数据搬运、计算指令、同步机制上做了深度架构适配,体现了 FA4 "一次算法、多代优化"的设计哲学。