第十三章:ROCm/AMD 后端
13.1 rocm_attn.py 逐行解析 (533行)
13.1.1 RocmAttentionBackend
AMD ROCm平台的标准注意力后端:
- 使用
rocm_aiter_ops--- AMD aiter库封装 get_kv_cache_shape():(2, num_blocks, block_size, num_kv_heads, head_size)--- 标准NHD布局- 支持
fp8KV cache (ROCm FP8格式)
13.1.2 RocmAttentionMetadata
python
@dataclass
class RocmAttentionMetadata:
num_actual_tokens: int
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
use_cascade: bool
common_prefix_len: int
cu_prefix_query_lens: torch.Tensor | None
prefix_kv_lens: torch.Tensor | None
suffix_kv_lens: torch.Tensor | None
scheduler_metadata: torch.Tensor | None
prefix_scheduler_metadata: torch.Tensor | None
- Cascade Attention : ROCm支持cascade (公共前缀+后缀分离计算)
cu_prefix_query_lens: 前缀查询长度prefix_kv_lens/suffix_kv_lens: 前缀/后缀KV长度
scheduler_metadata: aiter调度器元数据(AOT调度)
13.1.3 RocmAttentionMetadataBuilder
_cudagraph_support = ALWAYS: 完全支持CUDA Graphbuild_for_cudagraph_capture(): 特殊处理 --- seq_lens设为1避免graph capture过慢build(): 构建metadata → 支持cascade attention路径
13.1.4 RocmAttentionImpl
forward() 方法:
- KV cache写入:
triton_reshape_and_cache_flash()→ 写入KV到paged cache - Prefill路径:
chunked_prefill_paged_decode()--- ROCm chunked prefill kernel - Decode路径:
rocm_aiter_ops.decode_forward()--- aiter decode kernel - Cascade路径: 公共前缀 → 后缀 → merge_attn_states() 合并
- FP8 KV cache: 使用
k_scale/v_scale反量化
13.2 rocm_aiter_fa.py 逐行解析 (1460行)
这是attention模块中最大的单个后端文件。
13.2.1 核心Triton Kernel --- cp_mha_gather_cache_kernel
功能: Context Parallelism(CP)场景下,从远程GPU gather KV cache到本地。
逐行解析:
token_id = tl.program_id(0): 每个program处理1个tokenhead_id = tl.program_id(1): 每个program处理1个headbatch_idx = tl.load(token_to_batch_ptr + token_id): token→batch映射batch_start = tl.load(seq_start_ptr + batch_idx): batch起始位置block_offset = batch_offset // PAGE_SIZE: 计算block偏移block_id = tl.load(block_table_ptr + max_block_num * batch_idx + block_offset): 查block_tableslot_id = batch_offset % PAGE_SIZE: block内slot偏移- NHD布局 :
key_cache[block_id, slot, head, col]= key[token, head, col] - HND布局 :
key_cache[block_id, head, slot, col]= key[token, head, col] - FP8反量化(DEQUANT) :
key = key_cache * k_scale
设计意图: CP需要从其他GPU收集KV → 拼接到本地 → 执行注意力。此kernel实现了高效的跨GPU KV gather。
13.2.2 RocmAiterFABackend
- 继承
AttentionBackend - 支持
fp8KV cache - CP(Context Parallelism): 支持
_CP_TOKENS_PER_ITER_ROCM = 32K每次迭代token数 - KV cache布局: 优先NHD (num_blocks, page_size, num_heads, head_dim)
13.2.3 RocmAiterFAMetadataBuilder
- 支持
split_decodes_prefills_and_extends()--- 三分类(decode/prefill/extend) - Extend: CP场景下的KV扩展(gather远程KV后扩展本地cache)
- CP路径: 构建
cp_seqlen(本地+远程KV长度) - Workspace管理: 预分配gather缓冲区
13.2.4 RocmAiterFAImpl
forward() 方法复杂流程:
- 非CP路径: 标准prefill/decode
- CP路径 :
a. gather远程KV到workspace → 拼接本地KV
b. Prefill: 在完整KV上做attention
c. 输出修正:correct_attn_cp_output()--- CP输出聚合 - FP8 KV cache: 使用scale反量化
- Sliding Window: 限制KV扫描范围
- Cascade: 公共前缀分离计算
13.3 rocm_aiter_unified_attn.py 逐行解析 (318行)
ROCm统一注意力后端:
- 将prefill/decode/cascade统一到单个后端
- 使用aiter库的
flash_attention_forward统一接口 - 支持FP8 KV cache
- 自动选择prefill/decode kernel路径
第十四章:CPU/Mamba/GDN/Tree/TurboQuant 等专用后端
14.1 cpu_attn.py 逐行解析 (498行)
14.1.1 CPUAttentionBackend
CPU平台专用注意力后端:
supported_dtypes: [fp16, bf16, fp32] --- CPU支持更多精度get_supported_head_sizes(): [32,64,80,96,112,128,160,192,224,256,512] --- CPU更灵活supports_attn_type(): 支持所有4种注意力类型(DECODER/ENCODER/ENCODER_ONLY/ENCODER_DECODER)get_kv_cache_shape():(2, num_blocks, num_kv_heads, block_size, head_size)--- CPU用HND布局use_cascade_attention(): False --- CPU不支持cascade
14.1.2 CPUAttentionMetadata
python
@dataclass
class CPUAttentionMetadata:
isa: str # CPU指令集(avx2/avx512/sve等)
num_actual_tokens: int
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
scheduler_metadata: torch.Tensor | None
causal: bool = True
use_sdpa_prefill: bool = False # 是否使用SDPA prefill
num_decode_tokens: int = 0
sdpa_attn_masks: list | None = None # SDPA attention masks
sdpa_start_loc: torch.Tensor | None = None
isa: CPU架构指令集,决定使用哪个kerneluse_sdpa_prefill: PyTorch SDPA (Scaled Dot Product Attention) 用于prefillsdpa_attn_masks: SDPA需要的attention mask列表
14.1.3 CPUAttentionImpl
forward() 方法:
- KV cache写入 :
ops.reshape_and_cache_cpu()--- CPU版KV写入 - Prefill路径 :
- SDPA可用:
torch.nn.functional.scaled_dot_product_attention()--- PyTorch原生 - 不可用: 手动实现Q×K^T/√d → softmax → ×V
- SDPA可用:
- Decode路径 :
ops.paged_attention_v1()--- CPU分页注意力 - FP8 KV cache :
ops.convert_fp8_cpu()反量化
14.2 mamba_attn.py 逐行解析 (589行)
14.2.1 BaseMambaAttentionMetadata
Mamba状态空间模型(SSM)的metadata:
- 不使用传统Q/K/V → 无KV Cache → 无attention计算
has_initial_states_p: 是否有初始状态(prefill时从checkpoint加载)state_indices_tensor_p/d: prefill/decode的状态索引num_accepted_tokens: spec decode接受的token数(用于加载正确checkpoint)block_idx_*: 前缀缓存相关block索引cu_chunk_seqlen_p: chunked prefill的累积chunk长度
14.2.2 BaseMambaAttentionMetadataBuilder
reorder_batch_threshold = 1: 不做reorder(Mamba不支持混合batch)supports_update_block_table = True: 支持增量更新block_tablespeculative_config: 推测解码配置use_spec_decode: 是否使用spec decode
build() 方法:
split_decodes_and_prefills(): 分离decode/prefill- Prefill: 构建状态索引 + 初始状态标记 + chunked prefill元数据
- Decode: 构建decode状态索引 + spec decode token处理
- Prefix caching: 计算block索引 (Mamba的prefix是状态checkpoint)
14.3 mamba1_attn.py 逐行解析 (64行)
Mamba-1 SSM后端:
- 仅64行,thin wrapper
- 使用
selective_scan_fn--- Mamba-1的selective scan算子 - 无KV Cache →
get_kv_cache_shape()返回(0,)
14.4 mamba2_attn.py 逐行解析 (171行)
Mamba-2 SSM后端:
- 使用
ssd_fn--- Mamba-2的SSD(Structured State Duality)算法 - 比Mamba-1更高效的矩阵化实现
- 支持chunked prefill + decode
get_kv_cache_shape(): 返回(num_blocks, block_size, state_dim)--- SSM状态缓存
14.5 gdn_attn.py 逐行解析 (475行)
GDN (Gated DeltaNet) 注意力后端:
14.5.1 GDNAttentionBackend
is_ssm() = True: GDN被归类为SSM(类似Mamba)- 不使用标准Q/K/V → 使用门控递推(gated recurrence)
- KV Cache: 状态缓存(类似Mamba)
14.5.2 GDNAttentionMetadata
num_spec_decodes / num_spec_decode_tokens: Spec decode计数has_initial_state: 是否有初始状态spec_query_start_loc / non_spec_query_start_loc: spec/non-spec分离chunk_indices / chunk_offsets: FLA(Flash Linear Attention) chunk元数据nums_dict / batch_ptr / token_chunk_offset_ptr: Triton causal_conv1d参数
14.5.3 GDNAttentionMetadataBuilder
- 类似Mamba的metadata构建
- Spec decode: 分离spec和非spec请求
compute_causal_conv1d_metadata(): 计算causal_conv1d所需的元数据mamba_get_block_table_tensor(): 获取Mamba格式的block_table
14.6 tree_attn.py 逐行解析 (442行)
Tree Attention --- 推测解码(Spec Decoding)专用注意力:
14.6.1 TreeAttentionBackend
get_supported_kernel_block_sizes(): [MultipleOf(16)] --- block_size必须是16的倍数forward_includes_kv_cache_update = False: forward不包含KV写入(需外部处理)use_cascade_attention() = False: 不支持cascade
14.6.2 TreeAttentionMetadata
python
@dataclass
class TreeAttentionMetadata:
num_actual_tokens: int
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
tree_attn_bias: torch.Tensor | None = None # 树形注意力掩码
tree_attn_bias: 核心字段 --- 树形注意力掩码- 定义了draft token之间的因果关系
- 每个draft token只attend到其祖先token
14.6.3 TreeAttentionImpl
forward() 方法:
- 使用
unified_attention()(Triton统一注意力)执行实际计算 tree_attn_bias控制哪些token对之间可以attend- 结果: 找到最长匹配前缀 → accept/reject draft tokens
设计意图: 一次性验证整棵draft tree → 比逐token串行验证更高效。多个draft路径并行计算,找到最深匹配。
14.7 turboquant_attn.py 逐行解析 (800行)
TurboQuant注意力 --- KV Cache极致压缩:
14.7.1 核心概念
-
KV Cache布局 :
(num_blocks, block_size, num_kv_heads, slot_size)slot_size = key_packed_size + value_fp16_size- 例: turboquant_k3v4_nc + head_dim=256: [100 bytes key | 512 bytes value] = 612
-
量化模式:
turboquant_k8v4: 8-bit key + 4-bit valueturboquant_4bit_nc: 4-bit非均匀量化turboquant_k3v4_nc: 3-bit key + 4-bit value (最激进)turboquant_3bit_nc: 3-bit非均匀量化
14.7.2 _build_hadamard --- Hadamard矩阵
python
def _build_hadamard(d: int, device_str: str) -> torch.Tensor:
H = torch.tensor([[1.0]])
while H.shape[0] < d:
H = torch.cat([torch.cat([H, H], 1), torch.cat([H, -H], 1)], 0)
return (H / math.sqrt(d)).to(torch.device(device_str))
- Sylvester构造的Hadamard矩阵
- 用于Walsh-Hadamard变换(WHT) → 量化前的正交旋转
@functools.cache: 缓存避免重复构造- 优势: 单次cuBLAS GEMM代替log2(D)次butterfly操作
14.7.3 TurboQuantAttentionImpl
forward() 流程:
-
Prefill路径:
- 标准SDPA/FlashAttention计算attention
triton_turboquant_store(): 将KV量化写入压缩cache- 小continuation(≤128 tokens): 直接用TQ decode kernel
-
Decode路径:
triton_turboquant_decode_attention(): 从压缩cache读取+解量化+计算attention- Stage1: MSE key解量化 + 质心查表
- Stage2: 3/4-bit value解包 + softmax + weighted sum
-
KV Cache写入(do_kv_cache_update):
- Hadamard旋转 → 量化 → 打包到slot
- FP8/MSE量化 + 3/4-bit value打包
14.8 flash_attn_diffkv.py 逐行解析 (284行)
FlashAttention DiffKV后端:
- 支持差异化KV维度(key和value的head_dim不同)
- 用于某些模型(K头维度 ≠ V头维度)
- FlashAttention API原生支持不同KV维度
flash_attn_varlen_func(dv=...): 传入不同的V维度
14.9 linear_attn.py 逐行解析 (93行)
Linear Attention后端:
- O(n)复杂度注意力(而非标准O(n²))
- 使用kernel trick: φ(Q) × (φ(K)^T × V) → 先计算K^T×V(无序列长度依赖)
- 93行精简实现
- 用于Linear Transformer/Perceiver等模型
14.10 short_conv_attn.py 逐行解析 (34行)
短卷积注意力:
- 仅34行,最短的后端
- 不使用attention → 使用局部1D卷积
- 用于某些Mamba-1变体的short convolution层
- 无KV Cache
第十五章总结:专用后端设计哲学
15.1 平台适应性矩阵
| 后端 | NVIDIA | AMD | Intel XPU | CPU |
|---|---|---|---|---|
| FlashMLA | ✅ SM90+ | - | - | - |
| FlashMLASparse | ✅ SM90+ | - | - | - |
| FlashInferMLA | ✅ | - | - | - |
| FlashAttnMLA | ✅ | - | - | - |
| CUTLASSMLA | ✅ SM90+ | - | - | - |
| TritonMLA | ✅ | ✅ | ✅ | - |
| ROCmAiterMLA | - | ✅ | - | - |
| XPUMLASparse | - | - | ✅ | - |
| ROCmAttn | - | ✅ | - | - |
| ROCmAiterFA | - | ✅ | - | - |
| CPUAttn | - | - | - | ✅ |
15.2 MLA后端选择决策
DeepSeek-V3.2模型 → 选择MLA后端:
├── NVIDIA GPU:
│ ├── SM90+ (Hopper/Blackwell) + Dense → FlashMLA
│ ├── SM90+ + Sparse → FlashMLASparse
│ ├── SM90+ + CUTLASS → CUTLASSMLA
│ └── 其他SM → TritonMLA (跨平台fallback)
├── AMD GPU:
│ ├── aiter可用 → ROCmAiterMLA
│ └── aiter不可用 → TritonMLA
└── Intel XPU:
└── XPUMLASparse
15.3 特殊后端适用场景
| 后端 | 适用场景 | 核心优势 |
|---|---|---|
| MambaAttn | 状态空间模型 | 无KV Cache, 状态递推 |
| TreeAttn | 推测解码验证 | 树形并行验证draft |
| TurboQuant | 长上下文+显存受限 | KV极致压缩(3-4bit) |
| GDNAttn | GatedDeltaNet | 门控递推, 类Mamba |
| LinearAttn | Linear Transformer | O(n)注意力 |
| ShortConv | 局部卷积 | 无需全局注意力 |
| FlashAttnDiffKV | K/V维度不同 | 差异化KV支持 |
Part3 完 --- 涵盖MLA体系(13文件) + ROCm/AMD后端(3文件) + CPU/Mamba/GDN/Tree/TurboQuant等(10文件) = 26个文件的完整逐行分析
vLLM v1 Attention 底层算子(ops/)超深度逐行分析 --- Part 4
分析范围 : vLLM v1 attention 子系统 ops/ 目录下全部算子文件
源码根目录 :vllm/v1/attention/ops/
分析方法 : 逐行解析代码作用、语法、逻辑意图,说明每个变量/函数/模块的设计目的
撰写风格: 资深架构师视角,严谨、不漏细节
第十五章:通用算子(common.py)逐行解析
文件 :
ops/common.py(465行)
职责: 上下文并行(Context Parallelism, CP)注意力输出修正、序列打包/解包工具函数
15.1 模块总览
common.py 是 vLLM v1 attention 算子层的"基础设施"模块,提供两类核心功能:
- CP 注意力输出修正:当使用上下文并行(Context Parallelism)时,每个 rank 只看到部分 KV 序列,产生局部注意力输出和局部 log-sum-exp(LSE)。需要在 rank 间 all-gather LSE 后修正各 rank 的输出,再通过 reduce-scatter 或 all-reduce 得到最终结果。
- 序列打包/解包(pack/unpack):将不等长序列压缩为定长 batch 张量(padding),以及反向解包。这对 MLA 稀疏注意力等需要 batch 维度的算子至关重要。
15.2 导入与全局变量
python
import torch
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.triton_utils import tl, triton
torch:PyTorch 核心库,提供张量操作GroupCoordinator:vLLM 分布式通信协调器,封装 NCCL 集合通信tl, triton:Triton JIT 编译框架及其内建库(tl为 triton.language)
15.3 _correct_attn_cp_out_kernel --- CP注意力输出修正Triton核
15.3.1 函数签名与参数
python
@triton.jit
def _correct_attn_cp_out_kernel(
outputs_ptr, # 输入/输出注意力结果 [B, H, D]
new_output_ptr, # 修正后输出 [B, H, D]
lses_ptr, # 所有rank的LSE值 [N, B, H](N=world_size)
vlse_ptr, # 输出:全局LSE [B, H]
outputs_stride_B, outputs_stride_H, outputs_stride_D, # outputs的步幅
lses_stride_N, lses_stride_B, lses_stride_H, # lses的步幅
lse_idx, # 当前rank在CP组中的rank id
HEAD_DIM: tl.constexpr, # 头维度(编译时常量)
N_ROUNDED: tl.constexpr, # rank数量(编译时常量)
IS_BASE_E: tl.constexpr, # LSE是否以e为底(True)或以2为底(False)
)
设计意图 :此核修正 CP 场景下各 rank 的局部注意力输出。每个 rank 对自己的 KV 分片计算注意力后,得到局部 out_local 和局部 lse_local。但 softmax 的归一化范围应该是全局的(所有 KV 分片),所以需要:
- 从 all-gathered LSE 计算全局 LSE
- 用
exp(lse_local - lse_global)作为缩放因子修正输出
15.3.2 逐行解析
python
batch_idx = tl.program_id(axis=0).to(tl.int64) # 批次索引
head_idx = tl.program_id(axis=1).to(tl.int64) # 头索引
d_offsets = tl.arange(0, HEAD_DIM) # 头维度偏移向量
num_n_offsets = tl.arange(0, N_ROUNDED) # rank偏移向量
- 每个 program 处理一个 (batch, head) 对
d_offsets和num_n_offsets用于向量化加载
python
lse_offsets = (
num_n_offsets * lses_stride_N
+ batch_idx * lses_stride_B
+ head_idx * lses_stride_H
)
- 计算所有 N 个 rank 的 LSE 在
lses_ptr中的偏移
python
lse = tl.load(lses_ptr + lse_offsets) # 加载 [N] 个LSE值
lse = tl.where((lse != lse) | (lse == float("inf")), -float("inf"), lse)
- NaN 和 +inf 处理 :
lse != lse检测 NaN(IEEE 754 下 NaN ≠ 自身),lse == inf检测正无穷。这些异常值替换为-inf,在 log-sum-exp 中等效于"无贡献"。
python
lse_max = tl.max(lse, axis=0) # 所有rank的LSE最大值(数值稳定性)
lse_max = tl.where(lse_max == -float("inf"), 0, lse_max) # 全为-inf时置0
lse -= lse_max # 减去最大值(数值稳定化的标准操作)
- log-sum-exp 数值稳定性 :先减去最大值再取 exp,避免溢出。若所有 LSE 为
-inf(空序列),最大值设为 0 以避免(-inf) - (-inf) = NaN。
python
if IS_BASE_E:
lse_exp = tl.exp(lse) # e^(lse - max)
lse_acc = tl.sum(lse_exp, axis=0) # Σ e^(lse - max)
lse = tl.log(lse_acc) # log(Σ e^(lse - max))
else:
lse_exp = tl.exp2(lse) # 2^(lse - max)
lse_acc = tl.sum(lse_exp, axis=0)
lse = tl.log2(lse_acc) # log2(Σ 2^(lse - max))
lse += lse_max # 加回最大值 → 全局LSE
- 双基底支持 :FlashAttention 2/3 内部使用
log2计算 LSE(更快,因为 GPU 上exp2比exp快),但数学上 LSE 通常以 e 为底。IS_BASE_E标志决定使用哪种基底。 - 最终公式 :
lse_global = lse_max + log(Σ_i exp(lse_i - lse_max))
python
tl.store(vlse_ptr + lse_offsets, lse) # 只存 [B, H] 部分(去掉N维度)
- 存储全局 LSE 供后续使用
python
lse_offset = (lse_idx * lses_stride_N + batch_idx * lses_stride_B
+ head_idx * lses_stride_H)
lse_tmp = tl.load(lses_ptr + lse_offset) # 当前rank的原始LSE
lse_finally = lse_tmp - lse # 局部LSE - 全局LSE
lse_finally = tl.where(
(lse_finally != lse_finally) | (lse_finally == float("inf")),
-float("inf"), lse_finally,
)
factor = tl.exp(lse_finally) if IS_BASE_E else tl.exp2(lse_finally)
- 修正因子 :
exp(lse_local - lse_global)是将局部归一化输出缩放到全局归一化输出的因子 - 数学原理:局部输出
o_local = Σ(p_i * v_i) / Σ(p_i),其中p_i = exp(q·k_i - lse_local)。全局输出o_global = Σ(p_i * v_i) / Σ_all(p_i)。二者关系为o_global = o_local * exp(lse_local - lse_global)
python
output = tl.load(outputs_ptr + output_offsets)
output = output * factor
tl.store(new_output_ptr + output_offsets, output)
- 加载原始输出、乘以修正因子、存储
15.4 CPTritonContext --- 避免Triton重编译的上下文
python
class CPTritonContext:
def __init__(self):
self.inner_kernel = None
def call_kernel(self, kernel, grid, *regular_args, **const_args):
if self.inner_kernel is None:
self.inner_kernel = kernel[grid](*regular_args, **const_args)
else:
self.inner_kernel[grid](*regular_args)
设计意图 :Triton JIT 编译的 kernel 在首次调用时编译,后续调用复用编译结果。但 kernel[grid](...) 语法每次都会创建新的启动配置对象。CPTritonContext 缓存已编译的 kernel 实例,避免重复创建。
- 首次调用:编译 + 启动
- 后续调用:仅启动(grid 可能不同,但 const_args 不变)
15.5 correct_attn_out --- CP注意力输出修正入口
python
def correct_attn_out(
out: torch.Tensor, # [B, H, D]
lses: torch.Tensor, # [N, B, H]
cp_rank: int, # 当前rank在CP组中的rank id
ctx: CPTritonContext, # Triton上下文
is_lse_base_on_e: bool = True, # LSE基底
) -> tuple[torch.Tensor, torch.Tensor]:
维度规范化:
python
if out.ndim == 4 and out.shape[1] == 1:
out = out.squeeze(1) # [B, 1, H, D] → [B, H, D]
- 有些后端(如 FlashAttention)可能输出 4D 张量,squeeze 掉大小为 1 的维度
python
if lses.ndim == 4 and lses.shape[-1] == 1:
lses = lses.squeeze(-1)
if lses.ndim == 4 and lses.shape[1] == 1:
lses = lses.squeeze(1)
- LSE 可能被封装为
[N, 1, B, H]或[N, B, H, 1],squeeze 到 3D
步幅处理:
python
lse = torch.empty_strided(
(B, H), (l_sB, l_sH), device=lses.device, dtype=lses.dtype
)
- 关键设计 :
lse的步幅必须与lses在 B/H 维度上的步幅一致,因为核函数中vlse_ptr使用lses_stride_B/H写入。这保证了即使lses是非连续视图(如 4D squeeze),写入位置也正确。
核启动配置:
python
grid = (B, H, 1) # 每个(batch, head)对一个program
15.6 _cp_lse_common --- CP LSE通用处理
python
def _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=None, is_lse_base_on_e=True):
if cp_group.world_size == 1:
return cp_attn_out # 单rank无需修正
- 短路:如果 CP 组只有 1 个 rank(即未启用 CP),直接返回
python
lses = cp_group.all_gather(cp_attn_lse, dim=0).reshape(
(cp_group.world_size,) + cp_attn_lse.shape
)
- All-gather 所有 rank 的 LSE,reshape 为
[N, B, H]
15.7 cp_lse_ag_out_rs / cp_lse_ag_out_ar --- 两种CP输出聚合策略
cp_lse_ag_out_rs(AllGather + ReduceScatter):
python
out = cp_group.reduce_scatter(out, dim=1) # 在head维度上scatter
- 每个 rank 获得自己负责的 head 子集的输出
- 适用于 TP+CP 联合场景,head 已经在 TP 中被分割
cp_lse_ag_out_ar(AllGather + AllReduce):
python
out = cp_group.all_reduce(out) # 全归约
- 所有 rank 获得完整的输出
- 适用于纯 CP 场景
15.8 _pack_seq_kernel --- 序列打包Triton核
python
@triton.jit
def _pack_seq_kernel(
x_ptr, # [N, D] 输入(N=总token数)
out_ptr, # [B, Lmax, D] 输出
lengths_ptr, # [B] 各序列长度
N, D, Lmax, PAD_VALUE, BLOCK_T, BLOCK_D: tl.constexpr
)
设计意图 :将 1D 连续 token 序列(格式 [total_tokens, D])打包为 2D batch 张量 [B, Lmax, D],短序列用 PAD_VALUE(默认 -inf)填充。这对 MLA 稀疏注意力中需要 batch 维度的 MQA logits 计算至关重要。
关键逻辑:
python
in_start = 0
for i in range(pid_b):
in_start += tl.load(lengths_ptr + i) # 累加得到当前batch的起始位置
seq_len = tl.load(lengths_ptr + pid_b) # 当前序列长度
- 通过累加 lengths 计算每个 batch 在展平序列中的起始索引
python
valid_row = (off_t < seq_len) & t_mask # 有效token掩码
- 超出序列长度的位置写入 PAD_VALUE
15.9 pack_seq_triton / unpack_seq_triton --- 打包/解包Python入口
python
def pack_seq_triton(x, lengths, pad_value=-float("inf"), block_t=64, block_d=64):
- 支持多维输入(通过 reshape 处理)
pad_value=-inf使得 padding 位置在 softmax 中贡献为 0
python
def unpack_seq_triton(packed_tensor, lengths, block_t=64, block_d=64):
- 反向操作:从
[B, Lmax, ...]解包回[N, ...] - 跳过 padding 位置
第十六章:Triton注意力算子族
涵盖文件:
triton_unified_attention.py,triton_prefill_attention.py,triton_decode_attention.py,triton_reshape_and_cache_flash.py,triton_merge_attn_states.py
16.1 Triton统一注意力算子(triton_unified_attention.py, 1268行)
来源 : IBM Research 贡献,基于 vLLM 的 Triton attention 框架
职责: 提供统一的 prefill + decode 注意力计算,支持 2D(单段 softmax)和 3D(分段 softmax)两种调度模式
16.1.1 模块级常量与辅助函数
python
is_batch_invariant = envs.VLLM_BATCH_INVARIANT
float8_info = torch.finfo(current_platform.fp8_dtype())
is_batch_invariant:环境变量控制是否假设 batch 中所有序列等长(优化路径)float8_info:FP8 数据类型的 min/max 范围,用于核内 clamp
cdiv_fn(Triton JIT 版向上取整):
python
@triton.jit
def cdiv_fn(x, y):
return (x + y - 1) // y
apply_softcap(Gemma2 等模型使用的 tanh softcap):
python
@triton.jit
def apply_softcap(S, x):
Sdiv = S / x
p1 = tl.exp(Sdiv)
p2 = tl.exp(-Sdiv)
return x * (p1 - p2) / (p1 + p2) # x * tanh(S/x)
- 数学等价于
x * tanh(S / x),但用exp展开,避免 Triton 不直接支持tanh的问题 - 用于限制注意力 logit 范围,提升训练稳定性
16.1.2 _prepare_kv_tile --- KV 量化反量化统一接口
python
@triton.jit
def _prepare_kv_tile(
data, Q, tensor_scale, scale_cache_ptr,
physical_block_idx, seq_offset, kv_head_idx,
stride_s_blk, stride_s_slot, stride_s_head,
tile_mask, BLOCK_SIZE, KV_QUANT_MODE
)
KV_QUANT_MODE 设计:
| 值 | 含义 | 反量化策略 |
|---|---|---|
| 0 | 无量化(bf16/fp16) | 仅类型转换(实际无操作) |
| 1 | FP8 全局量化 | data * tensor_scale(每个张量一个 scale) |
| 2 | int8 per-token-head 量化 | 返回 data + token_head_scales,调用者在 dot 后乘 |
| 3 | FP8 per-token-head 量化 | 同上 |
为什么 KV_QUANT_MODE ≥ 2 时不在 tile 内反量化?:
python
if KV_QUANT_MODE >= 2: # per-token-head (int8 or fp8)
token_head_scales = tl.load(scale_cache_ptr + scale_idx, ...)
return data.to(Q.dtype), token_head_scales
- Per-token-head 量化下,每个 (token, head) 对有自己的 scale
- 将 scale 乘法融合到 softmax scale 中(
S += tl.dot(Q, K) * (scale * k_token_head_scales)),避免额外的 BLOCK_M×TILE_SIZE 逐元素乘法 - 这是数值精度和计算效率的权衡:先做整数 dot product,再乘 scale
16.1.3 find_seq_idx --- 二分查找定位token所属序列
python
@triton.jit
def find_seq_idx(query_start_len_ptr, target_idx, num_seqs,
BLOCK_Q, use_q_block_mode):
- 问题:在 batch 预填充中,所有 token 展平排列,核函数需要知道当前 program 对应哪个序列
- 算法 :二分查找
query_start_len数组 use_q_block_mode=True时,query_start_len的值是cu_seqlens[i] // BLOCK_Q + i(因为 q-block 是序列级分块,需要偏移序列索引)
16.1.4 kernel_unified_attention_2d --- 2D统一注意力核
核函数组织:
Grid: (total_num_q_blocks, num_kv_heads)
每个program: 处理一个 (q_block, kv_head) 对
核心计算流程:
- 定位当前序列:
python
seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True)
q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx
q_block_local_idx = q_block_global_idx - q_block_start_idx
q_block_global_idx:全局 q-block 编号q_block_local_idx:在当前序列内的 q-block 编号
- 加载 Query:
python
query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
query_offset_0 = cur_batch_in_all_start_index + query_pos
query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv
- GQA 支持 :
num_queries_per_kv> 1 时,一个 kv_head 对应多个 q_head offs_m的低log2(num_queries_per_kv)位选 q_head,高位选 token
- Softmax 初始化:
python
if not USE_SINKS:
M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
else:
M = tl.load(sink_ptr + query_offset_1, ...) # 从 attention sink 初始化
L = tl.full([BLOCK_M], 1.0, dtype=tl.float32) # 初始化为1(不是0,因为M初始为-inf时exp(-inf-(-inf))=1)
- Attention Sinks(StreamingLLM 技术):初始 M 值从 sink 张量加载,允许模型"记住"序列开头的 token,防止滑动窗口导致的注意力崩溃
- Tile循环 --- 分块KV处理:
python
for j in range(tile_start, tile_end):
seq_offset = j * TILE_SIZE + offs_t
physical_block_idx = tl.load(block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE)
- 遍历 KV 的 tile,通过 block_table 将逻辑位置映射到物理 cache 块
- 注意力掩码构建(causal + sliding_window + mm_prefix 三层叠加):
python
seq_mask = seq_offset[None, :] <= query_abs_pos # 因果掩码
if SLIDING_WINDOW > 0:
seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW)
if USE_MM_PREFIX:
for i in range(MAX_MM_RANGES):
# 多模态token的双向注意力范围
seq_mask |= q_in_range & k_in_range
- 掩码叠加顺序:先 causal AND sliding_window,再 OR mm_prefix
- 这与 FlexAttention 的语义一致:
(causal ∧ sliding_window) ∨ mm_prefix
- Online Softmax 累积(标准 FlashAttention 算法):
python
m_j = tl.maximum(M, tl.max(S, axis=1)) # 新的行最大值
P = tl.exp(S - m_j[:, None]) # 未归一化的注意力权重
l_j = tl.sum(P, axis=1) # 行和
alpha = tl.exp(M - m_j) # 旧累积值的缩放因子
acc = acc * alpha[:, None] # 缩放旧的累加器
L = L * alpha + l_j # 更新行和
M = m_j # 更新行最大值
acc += tl.dot(P.to(V.dtype), V) # 累积加权值
- Sliding Window 特殊处理:
python
if SLIDING_WINDOW:
qpos_lo = q_block_local_idx * BLOCK_Q
V = tl.where((context_len + qpos_lo - seq_offset[:, None]) < SLIDING_WINDOW, V, 0.0)
- 对滑动窗口外的 V 值置零,防止它们通过
tl.dot污染累加器 - 为什么在 P 上掩码还不够?因为
tl.dot的掩码语义可能不完全屏蔽,需要额外保护
- FP8 输出支持:
python
if USE_FP8:
acc = acc * tl.load(out_scale)
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
16.1.5 kernel_unified_attention_3d --- 3D分段softmax核
核函数组织:
Grid: (total_num_q_blocks, num_kv_heads, num_par_softmax_segments)
每个program: 处理一个 (q_block, kv_head, segment) 三元组
与 2D 版本的核心区别:
- 将 KV 序列分成
NUM_SEGMENTS_PER_SEQ段,每段独立计算部分 softmax(部分 M, L, acc) - 每个 segment 输出独立的
(部分输出, 部分M, 部分L) - 最终由
reduce_segments核将所有段的结果合并
为什么需要 3D 模式?
- 长序列 decode 时,单个 program 处理整个 KV 序列的 softmax 可能导致 GPU 利用率低(串行循环太长)
- 3D 模式将 KV 分段并行处理,增加并行度
- 代价:需要额外存储中间结果和一次 reduce
段范围计算:
python
tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)
for j in range(
max(segm_idx * tiles_per_segment, tile_start),
min((segm_idx + 1) * tiles_per_segment, tile_end),
):
Sink 初始化的特殊处理:
python
if USE_SINKS:
if segm_idx == 0:
M = tl.load(sink_ptr + ...)
else:
M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
- 只有第一个 segment 使用 sink 值,因为 sink 是全局初始值
16.1.6 reduce_segments --- 分段结果归约
python
@triton.jit
def reduce_segments(...)
算法(标准的 online softmax reduce):
- 加载所有 segment 的 M 值,计算全局最大值
overall_max - 加载所有 segment 的 L 值,缩放后求和
overall_expsum - 加载所有 segment 的输出,缩放后求和
- 归一化:
acc = acc_sum / overall_expsum
16.1.7 unified_attention --- Python入口与调度策略
2D vs 3D 选择逻辑:
python
if (seq_threshold_3D is None
or max_seqlen_q > 1 # 有prefill
or num_seqs > seq_threshold_3D
or is_batch_invariant):
kernel_unified_attention_2d[...] # 2D模式
else:
kernel_unified_attention_3d[...] # 3D模式(纯decode)
reduce_segments[...]
- 2D 模式:适用于 prefill + decode 混合批次
- 3D 模式:仅在纯 decode 且序列数低于阈值时启用(阈值由配置控制)
BLOCK_M / BLOCK_Q 计算:
python
BLOCK_M = 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(num_queries_per_kv)
BLOCK_Q = BLOCK_M // num_queries_per_kv
- GQA 下 BLOCK_M 需要覆盖整个 kv_group 的 q_head
- BLOCK_Q 是每个 kv_group 在一个 BLOCK_M 中覆盖的 token 数
Tile Size 选择(Gemma3 优化):
python
def _get_tile_size(head_size, sliding_window, element_size, is_prefill):
if _is_gemma3_attention(head_size, sliding_window):
return 32 # Gemma3专用
return 32 if is_prefill else (16 if element_size >= 2 else 32)
- Gemma3 使用 sliding_window=1024 + head_size=128/256 的独特组合
- 较大的 head_size 需要更大的 tile 以充分利用 tensor core
- FP8(element_size=1)需要 tile≥32 以满足对齐要求
16.2 Triton Prefill专用算子(triton_prefill_attention.py, 253行)
来源 : 适配自 SGLang,原作者 ModelTC/LightLLM
职责: 纯 prefill 场景的注意力计算,不支持 KV cache 分页(page_size=1假设)
16.2.1 核函数 _fwd_kernel
与统一核的关键区别:
- 不使用 KV cache:直接从连续的 K、V 张量读取(prefill 阶段 K/V 还未写入 cache)
- 使用 exp2 替代 exp :通过
sm_scale *= RCP_LN2(1/ln(2))将 softmax scale 融入 base-2 指数 - 支持双向滑动窗口 :
SLIDING_WINDOW_Q和SLIDING_WINDOW_K分别控制 Q→K 和 K→Q 方向的窗口
注意力掩码:
python
mask = pos_k < cur_batch_seq_len # 序列边界
if IS_CAUSAL:
mask &= pos_q >= pos_k # 因果掩码
if SLIDING_WINDOW_Q > 0:
mask &= pos_q - pos_k <= SLIDING_WINDOW_Q # 前向窗口
if SLIDING_WINDOW_K > 0:
mask &= pos_k - pos_q <= SLIDING_WINDOW_K # 后向窗口
Block Size 自适应:
python
def get_block_size(dtype):
if dtype == torch.float32: return 32
elif has_device_capability(80): return 128 # A100+
else: return 64
16.3 Triton Decode专用算子(triton_decode_attention.py, 778行)
来源 : 适配自 SGLang,原始设计来自 LightLLM 的 DeepSeek-V2 GQA 解码核
职责: 高效的 decode 阶段注意力,支持 MHA、GQA、MQA、MLA
16.3.1 两阶段架构
Stage 1:Split-KV 并行计算
- 将 KV 序列分成
NUM_KV_SPLITS段 - 每个 split 独立计算部分注意力输出和 LSE
- 输出中间结果
Att_Out [B, H, NUM_KV_SPLITS, D+1](最后1个元素存 LSE)
Stage 2:跨 split 归约
- 加载所有 split 的部分输出和 LSE
- Online softmax reduce 得到最终输出
设计优势:
- 长序列 decode 时,单 program 串行遍历整个 KV 序列成为瓶颈
- Split-KV 将序列分段并行处理,增加 SM 利用率
- 两阶段通信量小:仅在 stage2 做一次 reduce
16.3.2 _fwd_kernel_stage1 --- MHA版本
python
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
kv_page_number = tl.load(Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n // PAGE_SIZE)
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
- 分页 KV cache 访问 :通过
Req_to_tokens(block table)将逻辑位置映射到物理位置 - 支持任意
PAGE_SIZE(不仅限于 1)
FP8 反量化:
python
if k.dtype.is_fp8():
k = (k.to(tl.float32) * ks).to(q.dtype) # FP8 → fp32 → 反量化 → q.dtype
Logit Cap(tanh softcap):
python
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
中间输出格式:
Att_Out[b, h, s, 0:D] = 部分注意力输出(已除以部分expsum)
Att_Out[b, h, s, D] = 部分LSE = e_max + log(e_sum)
16.3.3 _fwd_grouped_kernel_stage1 --- GQA/MQA/MLA版本
关键区别 :一个 program 处理 BLOCK_H 个 q_head(而非 1 个)
python
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
VALID_BLOCK_H = BLOCK_H if kv_group_num > BLOCK_H else kv_group_num
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
- 将多个 q_head 捆绑处理,共享 KV 加载,减少全局内存访问
kv_group_num = Hq // Hkv,即 GQA 组大小
MLA(Multi-head Latent Attention)特殊处理:
python
if not IS_MLA:
v = tl.load(V_Buffer + offs_buf_v, ...) # 正常加载V
else:
v = tl.trans(k) # MLA: V = K^T(共享压缩KV)
- DeepSeek-V2/V3 的 MLA 将 KV 压缩为单个低维向量
c_kv - V 就是 K 的转置(因为
c_kv同时编码 K 和 V 的信息) - 避免重复加载相同的
c_kv
DPE(Decoupled Position Embedding)支持:
python
if BLOCK_DPE > 0:
kpe = tl.load(K_Buffer + offs_buf_kpe, ...)
qk += tl.dot(qpe, kpe.to(qpe.dtype))
- MLA 的 key 由
k_nope(非位置部分)和k_pe(位置编码部分)组成 qk = q_nope @ k_nope + q_pe @ k_pe- 对应 DeepSeek-V2 的解耦 RoPE 设计
具体维度映射:
python
if Lk == 576: # DeepSeek-V2 256+64的head_dim+nope分解
BLOCK_DMODEL = 512
BLOCK_DPE = 64
elif Lk == 288: # DeepSeek-V2 小模型的128+32分解
BLOCK_DMODEL = 256
BLOCK_DPE = 32
16.3.4 _fwd_kernel_stage2 --- 跨split归约
python
for split_kv_id in range(0, NUM_KV_SPLITS):
tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d)
tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
n_e_max = tl.maximum(tlogic, e_max)
old_scale = tl.exp(e_max - n_e_max)
acc *= old_scale
exp_logic = tl.exp(tlogic - n_e_max)
acc += exp_logic * tv
e_sum = e_sum * old_scale + exp_logic
e_max = n_e_max
- 标准的 online softmax reduce,与 stage1 的累积逻辑对称
16.3.5 入口函数
python
def decode_attention_fwd(q, k_buffer, v_buffer, o, lse, req_to_token,
b_seq_len, attn_logits, num_kv_splits, sm_scale,
page_size=1, logit_cap=0.0, k_scale=None,
v_scale=None, is_mla=False):
- MHA vs GQA 选择:
python
if kv_group_num == 1:
decode_attention_fwd_normal(...) # MHA:每个program一个q_head
else:
decode_attention_fwd_grouped(...) # GQA/MQA/MLA:每个program一组q_head
16.4 Triton KV Cache写入算子(triton_reshape_and_cache_flash.py, 601行)
职责: 将新生成的 K、V 写入分页 KV cache,支持 FP8 量化和 per-token-head 动态量化
16.4.1 reshape_and_cache_kernel_flash --- 主KV写入核
两种 KV Cache 布局:
Head-Major Layout(5D):
K_cache: [num_blocks, num_kv_heads, head_size//x, block_size, x]
V_cache: [num_blocks, num_kv_heads, head_size, block_size]
- Key 使用 5D 布局以优化 Tensor Core 访问(x 维度对齐到 16B)
Slot-Major Layout(4D):
KV_cache: [num_blocks, block_size, num_kv_heads, head_size+head_size_v]
- K 和 V 在同一张量中连续存储
写入逻辑:
python
slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
if slot_idx < 0:
return # padding token,跳过
block_idx = slot_idx // block_size
block_offset = slot_idx % block_size
slot_mapping:每个 token 在 KV cache 中的物理位置slot_idx < 0表示 padding(如 chunked prefill 中的虚拟 token)
FP8 量化:
python
if FP8_KV_CACHE:
key_tile = key_load / tl.load(k_scale) # 反向量化:除以scale → FP8范围
else:
key_tile = key_load
- 注意:Triton 的
tl.store会自动将 float 值截断为目标 dtype(FP8) k_scale/v_scale是全局张量级 scale(per-tensor quantization)
16.4.2 _reshape_cache_per_token_head --- Per-Token-Head动态量化核
设计意图:KV cache 的 per-token-head 量化比 per-tensor 量化精度更高,每个 (token, head) 对有独立的 scale。
Grid : (num_tokens, num_kv_heads) --- 每个 program 处理一个 (token, head) 对
核心逻辑:
python
k_h = tl.load(key_ptr + tok * stride_key_tok + head * stride_key_head + dim_offs,
mask=k_mask, other=0.0).to(tl.float32)
k_scale = tl.maximum(tl.max(tl.abs(k_h)) / QUANT_MAX, 1e-6)
tl.store(k_scale_cache_ptr + blk * stride_ks_blk + slot_in_blk * stride_ks_slot + head * stride_ks_head, k_scale)
k_q = tl.clamp(k_h * (1.0 / k_scale), QUANT_MIN, QUANT_MAX)
tl.store(key_cache_ptr + ..., k_q, mask=k_mask)
- 加载一个 (token, head) 的 K 值
- 计算
absmax / QUANT_MAX作为 scale - 量化:
clamp(value / scale, min, max) - 存储:量化值写入
key_cache,scale 写入k_scale_cache
量化参数映射:
python
_PER_TOKEN_HEAD_QUANT_PARAMS = {
torch.int8: (127.0, -128.0),
FP8_DTYPE: (FP8_MAX, FP8_MIN),
}
16.4.3 reshape_and_cache_kernel_flash_diffkv --- 差异化KV维度写入
设计意图 :MLA(如 DeepSeek-V2)中 K 和 V 的维度不同(K 包含 nope+pe,V 仅 nope 部分)。此核将 K 和 V 写入同一个合并 cache [num_blocks, block_size, num_heads, head_size_k + head_size_v]。
16.5 Triton注意力状态合并(triton_merge_attn_states.py, 175行)
数学原理 : Section 2.2 of Split-KV Attention论文
职责: 合并两段部分注意力输出(prefix/suffix),用于 chunked prefill 场景
核函数 merge_attn_states_kernel
两种 token 处理路径:
1. 无 prefix context 的 token(decode token):
python
if not prefix_mask:
s_lse = tl.load(suffix_lse + ...)
s_out = tl.load(suffix_output + ...)
tl.store(output + ..., s_out) # 直接复制suffix输出
2. 有 prefix context 的 token(prefill token):
python
p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)
# FA2/FA3 兼容性处理
p_lse = float("-inf") if p_lse == float("inf") else p_lse
s_lse = float("-inf") if s_lse == float("inf") else s_lse
max_lse = tl.maximum(p_lse, s_lse)
p_se = tl.exp(p_lse - max_lse) # prefix的exp(lse - max)
s_se = tl.exp(s_lse - max_lse) # suffix的exp(lse - max)
out_se = p_se + s_se # 合并的exp sum
# LSE计算(可选输出)
if OUTPUT_LSE:
out_lse = tl.log(out_se) + max_lse
# 加权合并
p_scale = p_se / out_se
s_scale = s_se / out_se
out = p_out * p_scale + s_out * s_scale
数学原理:
设 prefix 部分的注意力输出为 O_p = Σ_i (α_i * v_i),suffix 部分为 O_s = Σ_j (β_j * v_j)。
全局注意力输出为:
O = (Σ_i α_i * v_i + Σ_j β_j * v_j) / (Σ_i α_i + Σ_j β_j)
= O_p * exp(lse_p - lse) + O_s * exp(lse_s - lse)
其中 lse = log(exp(lse_p) + exp(lse_s))(log-sum-exp 合并)。
第十七章:Prefix Prefill与Chunked Prefill算子
涵盖文件:
prefix_prefill.py,chunked_prefill_paged_decode.py,merge_attn_states.py,paged_attn.py
17.1 Prefix Prefill算子(prefix_prefill.py, 864行)
来源 : 适配自 LightLLM 的 context_attention_fwd
职责: 带有 prefix(KV cache 中已有 context)的 prefill 注意力计算
17.1.1 全局参数
python
BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
NUM_WARPS = 4 if current_platform.is_rocm() else 8
IS_TURING = current_platform.get_device_capability() == (7, 5)
float8_info = torch.finfo(current_platform.fp8_dtype())
- Ampere+ (SM80+): BLOCK=128,更多的 shared memory
- Turing (SM75): 使用 IEEE 浮点(无 tensor core FP32 支持)
17.1.2 _fwd_kernel --- 主Prefill核
核心架构:两阶段注意力计算
阶段1 --- Context注意力(Q vs KV cache中的prefix):
python
for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, loop_unroll_factor=num_unroll_cache):
token_indices = start_n + offs_bs_n
bn_logical_indices = token_indices // PHYSICAL_BLOCK_SIZE
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + bn_logical_indices * stride_b_loc_s)
internal_offsets = token_indices % PHYSICAL_BLOCK_SIZE
-
非标准 block_size 支持:
- 正常模型:block_size = 16/32/64/128/256(2的幂)
- Qwen3-Next-80B:block_size = 544(非2的幂!)
- 当 block_size 不是 2 的幂时,一个 32-token 的 tile 可能跨越两个物理块
- 解决方案:每个 token 独立计算逻辑块索引
token_indices // PHYSICAL_BLOCK_SIZE
-
K/V 地址计算(5D K cache):
python
off_k = (
bn[None, :] * stride_k_cache_bs # 物理块偏移
+ cur_kv_head * stride_k_cache_h # head偏移
+ (offs_d[:, None] // x) * stride_k_cache_d # dim//x偏移
+ internal_offsets[None, :] * stride_k_cache_bl # block内位置偏移
+ (offs_d[:, None] % x) * stride_k_cache_x # x内偏移
)
- Key cache 5D 布局:
[num_blocks, num_kv_heads, head_size//x, block_size, x] x = 16 // element_size:将 head_size 维度拆分为(head_size//x, x)以优化内存访问
阶段2 --- Query自注意力(Q vs 新K/V,带因果掩码):
python
for start_n in tl.range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N,
loop_unroll_factor=num_unroll_request):
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
- 因果掩码:Q 只能看到自身及之前的 token
block_mask * (start_m + 1) * BLOCK_M:当当前 BLOCK_M 超出 query 范围时,跳过
FP8 KV 反量化:
python
if k_load.dtype.is_fp8():
k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
else:
k = k_load
Sliding Window:
python
if SLIDING_WINDOW > 0:
qk = tl.where(
(cur_batch_ctx_len + offs_m[:, None]) - (start_n + offs_bs_n[None, :]) < SLIDING_WINDOW,
qk, float("-inf"),
)
- 绝对位置计算:query 位置 =
ctx_len + query_pos,key 位置 =start_n + key_pos - 只允许 query 看到
SLIDING_WINDOW范围内的 key
FP8 输出:
python
if USE_FP8:
acc = acc * tl.load(out_scale_inv) # scale → FP8范围
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
17.1.3 _fwd_kernel_alibi --- ALiBi位置编码版本
ALiBi(Attention with Linear Biases):
python
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - alibi_start_q[:, None]) * alibi_slope
alibi = tl.where((alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, float("-inf"))
qk += alibi
- ALiBi 将位置信息编码为注意力偏置,而非 RoPE 那样修改 Q/K
alibi_slope:每个 head 的斜率(预计算,通常为几何序列)- 偏置公式:
bias = slope * (key_pos - query_pos),且只允许bias ≤ 0(因果性)
17.1.4 context_attention_fwd --- Python入口
非标准 block_size 处理:
python
is_pow2 = real_block_size > 0 and (real_block_size & (real_block_size - 1) == 0)
if is_pow2:
BLOCK_M = 128
BLOCK_N = 64
else:
BLOCK_M = 32
BLOCK_N = 32
TRITON_BLOCK_SIZE = 32 # 始终使用32作为内部tile大小
- 2的幂 block_size:可以使用较大的 BLOCK_M/N,因为一个 tile 不会跨块
- 非2的幂(如544):必须使用小的 BLOCK_M/N=32,配合跨块索引逻辑
Block Table 指针归一化:
python
if is_block_table_ptr:
kv_element_size = k_cache.element_size()
block_byte_stride = k_cache.stride(0) * kv_element_size
base_addr = k_cache.data_ptr()
processed_b_loc = torch.where(
mask, (b_loc - base_addr) // block_byte_stride, b_loc
).to(torch.int32)
- 某些后端(如 FlashAttention)的 block_table 存储的是字节指针而非块索引
- 转换为块索引:
(ptr - base_addr) / block_byte_stride
17.2 Chunked Prefill + Paged Decode(chunked_prefill_paged_decode.py, 467行)
来源 : IBM Research 贡献
职责: 混合 batch 中同时处理 prefill 和 decode 请求
17.2.1 kernel_paged_attention_2d --- Paged Decode核
Grid : (num_seqs, num_kv_heads) --- 每个program处理一个 (序列, kv_head) 对
Query 加载(GQA展开):
python
query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange(0, num_queries_per_kv_padded)
Q = tl.load(query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
mask=dim_mask[None, :] & head_mask[:, None], other=0.0)
- 一次加载一个 kv_group 的所有 q_head(
num_queries_per_kv_padded个) padded确保是 2 的幂,为tl.dot对齐
Block遍历(非标准block_size支持):
python
for j in range(0, num_blocks):
abs_token_idx = start_n + offs_n
l_block_idx = abs_token_idx // PHYSICAL_BLOCK_SIZE # 逻辑块索引
p_block_idx = tl.load(block_tables_ptr + block_table_offset + l_block_idx)
internal_offsets = abs_token_idx % PHYSICAL_BLOCK_SIZE
- 与 prefix_prefill 相同的跨块索引逻辑
Decode token 过滤:
python
if filter_by_query_len:
cur_batch_query_len = tl.load(query_start_len_ptr + seq_idx + 1) - tl.load(query_start_len_ptr + seq_idx)
if cur_batch_query_len > 1:
return # 跳过prefill token(它们由prefix_prefill核处理)
17.2.2 chunked_prefill_paged_decode --- 混合调度入口
两阶段处理:
python
if max_query_len > 1:
context_attention_fwd(...) # Prefill tokens: 使用prefix_prefill核
kernel_paged_attention_2d[...] # Decode tokens: 使用paged decode核
- 先处理 prefill token(使用
prefix_prefill.context_attention_fwd,skip_decode=True) - 再处理 decode token(使用
kernel_paged_attention_2d,filter_by_query_len=True)
ROCm 自定义分页注意力回退:
python
if use_custom:
ops.paged_attention_rocm(...) # ROCm自定义CUDA核
else:
kernel_paged_attention_2d[...] # Triton回退
17.3 注意力状态合并(merge_attn_states.py, 103行)
设计意图:提供 CPU/GPU 双路径的注意力状态合并,自动选择最优实现。
python
def merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse, ...):
if current_platform.is_cuda() and supported_dtypes(prefix_output) and supported_headdim(prefix_output):
from vllm._custom_ops import merge_attn_states # CUDA核
return merge_attn_states(...)
else:
from vllm.v1.attention.ops.triton_merge_attn_states import merge_attn_states # Triton核
return merge_attn_states(...)
CUDA 核条件:
supported_dtypes:float32/half/bfloat16supported_headdim:head_dim 必须是 4(float32)或 8(half/bf16)的倍数(128b 对齐)
17.4 分页注意力入口(paged_attn.py, 51行)
PagedAttention 类:传统分页注意力的入口点。
split_kv_cache:将联合 KV cache 拆分为 K cache 和 V cache
python
@staticmethod
def split_kv_cache(kv_cache, num_kv_heads, head_size):
x = 16 // kv_cache.element_size() # 16B对齐因子
key_cache = kv_cache[0].view(num_blocks, num_kv_heads, head_size // x, -1, x)
value_cache = kv_cache[1].view(num_blocks, num_kv_heads, head_size, -1)
return key_cache, value_cache
write_to_paged_cache:调用 CUDA 核写入 KV cache
python
@staticmethod
def write_to_paged_cache(key, value, key_cache, value_cache, slot_mapping,
kv_cache_dtype, k_scale, v_scale):
ops.reshape_and_cache(key, value, key_cache, value_cache,
slot_mapping.flatten(), kv_cache_dtype, k_scale, v_scale)
- 调用 vLLM 自定义 CUDA 算子
ops.reshape_and_cache
第十八章:ViT/DCP/FlashMLA/特殊算子
涵盖文件:
flashmla.py,vit_attn_wrappers.py,dcp_alltoall.py,rocm_aiter_mla_sparse.py,xpu_mla_sparse.py,triton_turboquant_store.py,triton_turboquant_decode.py
18.1 FlashMLA算子(flashmla.py, 153行)
来源 : 适配自 DeepSeek-AI 的 FlashMLA
职责: DeepSeek-V2/V3 MLA 注意力的专用后端
18.1.1 可用性检测
python
def _is_flashmla_available() -> tuple[bool, str | None]:
if not _flashmla_C_AVAILABLE:
return False, "vllm._flashmla_C is not available..."
if not _flashmla_extension_C_AVAILABLE:
return False, "vllm._flashmla_extension_C is not available..."
return True, None
- 双重检测:核心库
_flashmla_C和扩展库_flashmla_extension_C - 编译失败时优雅降级
python
def is_flashmla_dense_supported():
if not current_platform.is_device_capability_family(90):
return False, "FlashMLA Dense is only supported on Hopper devices."
return True, None
def is_flashmla_sparse_supported():
if not (is_device_capability_family(90) or is_device_capability_family(100)):
return False, "FlashMLA Sparse is only supported on Hopper and Blackwell devices."
return True, None
- Dense:仅 Hopper(SM90),因为 FlashMLA Dense 使用了 Hopper 专用的 WMMA 指令
- Sparse:Hopper + Blackwell(SM100),因为稀疏注意力需要 2:4 稀疏矩阵支持
18.1.2 FP8 MLA Decode
python
def flash_mla_with_kvcache_fp8(q, k_cache, block_table, cache_seqlens,
head_dim_v, tile_scheduler_metadata, num_splits, ...):
out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale, causal,
tile_scheduler_metadata, num_splits, descale_q, descale_k)
return out, softmax_lse
- 专用 FP8 MLA decode 核:直接在 FP8 精度下计算 Q·K 点积
tile_scheduler_metadata/num_splits:FlashMLA 的自定义 tile 调度器输出descale_q/descale_k:FP8 反量化 scale
18.2 ViT注意力封装(vit_attn_wrappers.py, 361行)
职责 : 为 Vision Transformer 提供多种注意力后端的统一封装,兼容
torch.compile
18.2.1 设计动机
ViT 注意力与 LLM 注意力有本质区别:
- 双向注意力(非因果)
- 固定序列长度(图像 patch 数量固定)
- 不使用 KV cache(每次推理都是完整 forward)
torch.compile 不支持 .item() 等动态操作(如 FlashAttention 的 max_seqlen 参数),所以需要封装。
18.2.2 四种后端封装
1. FlashAttention 封装:
python
def flash_attn_maxseqlen_wrapper(q, k, v, batch_size, is_rocm_aiter, fa_version, ...):
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens, ...)
context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
einops.rearrange:批次维度展平/恢复cu_seqlens:构建累积序列长度数组(ViT 中每个序列等长)
2. Triton 封装:
python
def triton_attn_wrapper(q, k, v, batch_size, scale, cu_seqlens, max_seqlen):
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
context_attention_fwd(q, k, v, output, is_causal=False, ...)
- 使用
triton_prefill_attention.context_attention_fwd is_causal=False:ViT 使用双向注意力
3. PyTorch SDPA 封装:
python
def torch_sdpa_wrapper(q, k, v, scale, cu_seqlens, enable_gqa):
if cu_seqlens is None:
return apply_sdpa(q, k, v, scale, enable_gqa)
# 有cu_seqlens时,按序列拆分分别计算
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
output_i = apply_sdpa(q_i, k_i, v_i, scale, enable_gqa)
- 回退方案:使用 PyTorch 原生
F.scaled_dot_product_attention - 支持变长序列(按
cu_seqlens拆分)
4. FlashInfer 封装:
python
def flashinfer_wrapper(q, k, v, scale, workspace_buffer, cu_seqlens, ...):
output, _ = cudnn_batch_prefill_with_kv_cache(
q, k, v, scale, workspace_buffer, causal=False, ...)
- 使用 cuDNN 后端的 FlashInfer
cu_seqlens拆分为 Q/K/O 和 V 两套偏移(cuDNN 要求)
18.2.3 torch.compile 兼容性
所有封装都通过 direct_register_custom_op 注册为自定义算子:
python
direct_register_custom_op(
op_name="flash_attn_maxseqlen_wrapper",
op_func=flash_attn_maxseqlen_wrapper,
fake_impl=flash_attn_maxseqlen_wrapper_fake,
)
fake_impl:提供torch.compile的 meta tensor 推断- 避免
torch.compile试图追踪 Python 控制流和.item()调用
18.3 DCP AlltoAll通信(dcp_alltoall.py, 363行)
论文参考 : https://arxiv.org/abs/2507.07120
职责: Decode Context Parallelism 的 All-to-All 通信后端,替代 AllGather+ReduceScatter
18.3.1 设计动机
传统 CP 通信流程(AG+RS):
1. AllGather Q → 每个rank获得完整Q
2. AllGather K metadata
3. 本地注意力计算
4. ReduceScatter output → 每个rank获得部分head的输出
= 3 次 NCCL 集合通信
A2A 通信流程:
1. 本地注意力计算(每个rank拥有部分head + 本地KV分片)
2. All-to-All output → 交换各rank的部分注意力输出
3. All-to-All LSE → 交换各rank的LSE
4. 本地 LSE-weighted combine → 最终输出
= 2 次 NCCL 集合通信
优势:减少一次 NCCL 调用,降低延迟。长序列 decode 时 NCCL 延迟是 step 时间的显著组成部分。
18.3.2 _lse_weighted_combine --- CPU参考实现
python
def _lse_weighted_combine(outputs, lses, return_lse, is_lse_base_on_e):
lse_max, _ = lses.max(dim=0) # [B, H]
weights = torch.exp(lses - lse_max.unsqueeze(0)) # [N, B, H]
weight_sum = weights.sum(dim=0, keepdim=True)
weights = weights / weight_sum.clamp(min=1e-10)
result = (outputs * weights.unsqueeze(-1)).sum(dim=0) # [B, H, D]
- 纯 PyTorch 实现,用于测试和验证
- 标准的 LSE-weighted 组合:
output = Σ_n (output_n * exp(lse_n - lse_max)) / Σ_n exp(lse_n - lse_max)
18.3.3 _dcp_lse_combine_kernel --- Triton GPU核
三遍扫描算法:
- 第一遍:找所有 rank 的 LSE 最大值(数值稳定性)
- 第二遍 :计算
Σ exp(lse - max)和全局 LSE - 第三遍:加权组合所有 rank 的部分输出
python
for n in tl.static_range(N):
lse_val = tl.load(recv_lse_ptr + n * rl_stride_N + base_lse_offset)
lse_val = tl.where((lse_val != lse_val) | (lse_val == float("inf")), -float("inf"), lse_val)
lse_max = tl.maximum(lse_max, lse_val)
tl.static_range(N):编译时展开循环(N 是 constexpr)
18.3.4 dcp_a2a_lse_reduce --- 完整A2A流程
python
def dcp_a2a_lse_reduce(cp_attn_out, cp_attn_lse, cp_group, ...):
# Reshape for All-to-All: [B, H, D] → [N, B, H/N, D]
send_output = local_output.view(B, world_size, H_per_rank, D).permute(1, 0, 2, 3)
send_lse = local_lse.view(B, world_size, H_per_rank).permute(1, 0, 2)
# Async A2A overlap
work_output = dist.all_to_all_single(recv_output.view(-1), send_output.view(-1),
group=cp_group.device_group, async_op=True)
work_lse = dist.all_to_all_single(recv_lse.view(-1), send_lse.view(-1),
group=cp_group.device_group, async_op=True)
work_output.wait()
work_lse.wait()
# Local LSE-weighted combination
return dcp_lse_combine_triton(recv_output, recv_lse, ...)
- Reshape 逻辑 :将 head 维度拆分为
N份,每个 rank 的数据标记为发送给对应 rank - 异步 A2A:output 和 LSE 的 A2A 同时启动,重叠通信
- 本地组合:无需额外通信,Triton 核完成
18.4 ROCm Aiter MLA稀疏算子(rocm_aiter_mla_sparse.py, 655行)
职责 : DeepSeek-V3.2 的 MLA 稀疏注意力在 AMD ROCm 平台上的实现
核心: Indexer(token选择器) + FP8 MQA Logits + TopK 筛选
18.4.1 整体架构
DeepSeek-V3.2 的稀疏 MLA 注意力流程:
- Indexer: 用 FP8 MQA logits 计算每个 query 应该关注哪些 KV token
- TopK: 选择 top-k 个最相关的 KV token
- Sparse Attention: 只在选中的 KV token 上计算完整注意力
18.4.2 _indexer_k_quant_and_cache_kernel --- Indexer K量化写入
python
@triton.jit
def _indexer_k_quant_and_cache_kernel(k_ptr, kv_cache_ptr, kv_cache_scale_ptr,
slot_mapping_ptr, ...):
val = tl.load(src_ptr + offset)
amax = tl.max(val.abs(), axis=-1).to(tl.float32)
scale = tl.maximum(1e-4, amax) / 448.0 # FP8 E4M3 max = 448
if USE_UE8M0:
scale = tl.exp2(tl.ceil(tl.log2(scale))) # 对齐到2的幂
fp8_val = (val.to(tl.float32) / scale).to(kv_cache_ptr.type.element_ty)
- FP8 量化:per-token absmax 量化,scale = absmax / 448
- UE8M0 格式 :scale 对齐到 2 的幂(
exp2(ceil(log2(scale)))),某些硬件优化需要
18.4.3 _cp_gather_indexer_quant_cache_kernel --- 从KV cache读取量化K
python
@triton.jit
def _cp_gather_indexer_quant_cache_kernel(kv_cache_ptr, kv_cache_scale_ptr,
k_fp8_ptr, k_scale_ptr, ...):
# 通过block_table定位KV cache中的K
block_id = tl.load(block_table_ptr + block_table_offset)
# 读取FP8 K和scale
val = tl.load(src_cache_ptr + tiled_src_offset)
scale_val = tl.load(src_scale_ptr)
# 存储到连续缓冲区
tl.store(dst_k_ptr + offset, val)
tl.store(k_scale_ptr + tid, scale_val)
- 从分页 KV cache 中读取已量化的 K,组装为连续缓冲区供 MQA logits 计算
18.4.4 rocm_aiter_sparse_attn_indexer --- 完整Indexer流程
Prefill 路径:
python
if has_prefill:
for chunk in prefill_metadata.chunks:
# 1. 从KV cache读取量化K
ops.cp_gather_indexer_k_quant_cache(kv_cache, k_fp8, k_scale, ...)
# 2. FP8 MQA logits
logits = rocm_fp8_mqa_logits(q_fp8[chunk.token_start:chunk.token_end], ...)
# 3. TopK筛选
torch.ops._C.top_k_per_row_prefill(logits, ...)
Decode 路径:
python
if has_decode:
# 1. Pack Q为batch格式
padded_q_fp8_decode_tokens = pack_seq_triton(q_fp8[:num_decode_tokens], decode_lens)
# 2. FP8 Paged MQA logits
logits = rocm_fp8_paged_mqa_logits(padded_q_fp8_decode_tokens, kv_cache, ...)
# 3. TopK筛选
torch.ops._C.top_k_per_row_decode(logits, ...)
# 4. Unpack结果
topk_indices = unpack_seq_triton(topk_indices.reshape(...), decode_lens)
- 使用
pack_seq_triton/unpack_seq_triton(来自 common.py)处理不等长 decode 序列
18.5 XPU MLA稀疏算子(xpu_mla_sparse.py, 265行)
职责 : Intel XPU(GPU Max/Flex)上的 MLA 稀疏注意力实现
设计: BF16 精度的 Triton 实现,无需 FP8 量化
核函数 _bf16_mla_sparse_kernel
与 ROCm 版本的关键区别:
- 直接使用 BF16 精度(无 FP8 量化步骤)
- 使用
exp2(base-2 指数)替代exp(GPU 上exp2更快) sm_scale *= LOG2E将 softmax scale 转换到 base-2
Index-based 稀疏注意力:
python
for start_indice in range(0, index_topk, BLOCK_N):
indices = tl.load(indices_ptr + cur_q * stride_indices_token + cur_kv_head_id * stride_indices_head + offs_indice)
# 根据topk indices加载K和V
k = tl.load(k_buffer + indices[None, :] * stride_k_token + ...)
v = tl.load(v_buffer + indices[:, None] * stride_v_token + ...)
- 不遍历完整 KV 序列,只加载 topk indices 指向的 token
- 大幅减少计算量和内存访问
DPE(解耦位置编码)支持:
python
if BLOCK_DPE > 0:
kpe = tl.load(k_buffer + offs_kpe, ...)
qk += tl.dot(qpe, kpe.to(qpe.dtype))
LSE 输出格式:
python
max_logits = e_max * LOGE2 # base-2 → base-e 转换
lse = max_logits + tl.log2(e_sum) * LOGE2 # log2(sum) → log(sum)
18.6 TurboQuant存储算子(triton_turboquant_store.py, 447行)
职责: TurboQuant(TQ)KV cache 的写入端,支持 FP8 key + 均匀量化 value,以及 MSE 最优量化 key
18.6.1 TurboQuant 概述
TurboQuant 是一种混合 KV cache 量化方案:
- Key: FP8 量化(per-token absmax)或 MSE 最优量化(vector quantization)
- Value: 均匀量化(3-bit 或 4-bit,per-head min-max 量化)
相比标准 FP8 量化,TurboQuant 的 value 使用更低精度(3/4 bit),进一步压缩 KV cache 占用。
18.6.2 _store_quantized_value --- Value均匀量化子程序
4-bit 量化:
python
val_min = tl.min(tl.where(d_mask, val_vec, float("inf")), axis=0)
val_max = tl.max(tl.where(d_mask, val_vec, -float("inf")), axis=0)
v_scale = (val_max - val_min) / 15.0 # 4-bit: 0~15
q_all = tl.minimum(tl.maximum(((val_vec - val_min) / v_scale + 0.5).to(tl.int32), 0), 15)
# 打包:两个4-bit值 → 一个字节
q_pairs = tl.reshape(q_all, [BLOCK_D // 2, 2])
packed_val = tl.sum((q_pairs & 0xF) << shifts_4[None, :], axis=1).to(tl.uint8)
3-bit 量化(更复杂):
python
v_scale = (val_max - val_min) / 7.0 # 3-bit: 0~7
q_vals = tl.minimum(tl.maximum(((val_vec - val_min) / v_scale + 0.5).to(tl.int32), 0), 7)
# 打包:8个3-bit值 → 3个字节(24 bit)
q_grp = tl.reshape(q_vals, [BLOCK_GRP, 8])
packed_24 = tl.sum(q_grp << shifts_3bit[None, :], axis=1) # 24-bit打包
b0 = (packed_24 & 0xFF).to(tl.uint8)
b1 = ((packed_24 >> 8) & 0xFF).to(tl.uint8)
b2 = ((packed_24 >> 16) & 0xFF).to(tl.uint8)
- 3-bit 打包是非对齐的(8×3=24 bit = 3 bytes),比 4-bit(2×4=8 bit = 1 byte)复杂
Scale/Zero 存储(float16):
python
sc_f16 = v_scale.to(tl.float16)
sc_u16 = sc_f16.to(tl.uint16, bitcast=True) # 浮点→位模式
tl.store(KV_cache_ptr + slot_base + sc_offset, (sc_u16 & 0xFF).to(tl.uint8))
tl.store(KV_cache_ptr + slot_base + sc_offset + 1, ((sc_u16 >> 8) & 0xFF).to(tl.uint8))
- float16 的 2 字节存储到 uint8 缓冲区,需要拆分为低/高字节
18.6.3 _tq_fused_store_fp8 --- FP8 Key + Value量化融合核
python
@triton.jit
def _tq_fused_store_fp8(Key_ptr, Value_ptr, KV_cache_ptr, Slot_mapping_ptr, ...):
# FP8 KEY: 在核内直接转换
k_vals = tl.load(Key_ptr + base + d_offs, mask=d_mask, other=0.0)
k_fp8 = k_vals.to(tl.float8e4b15) if FP8_E4B15 else k_vals.to(tl.float8e4nv)
k_bytes = k_fp8.to(tl.uint8, bitcast=True)
tl.store(KV_cache_ptr + slot_base + d_offs, k_bytes, mask=d_mask)
# VALUE: 均匀量化
_store_quantized_value(Value_ptr, KV_cache_ptr, ...)
- FP8 格式选择 :
e4b15(Ampere/Ada, SM < 8.9):自定义 FP8 格式e4nv(Hopper+, SM ≥ 8.9):NVIDIA 标准 FP8 格式
18.6.4 _tq_fused_store_mse --- MSE量化融合核
MSE(最小均方误差)量化流程:
- 旋转投影(核外 cuBLAS GEMM,更快):
python
k_flat = key.float().reshape(NH, D)
norms = k_flat.norm(dim=1, keepdim=True)
x_hat = k_flat / (norms + 1e-8) # 归一化
y = x_hat @ PiT # 旋转投影到量化空间
- 二分搜索桶化(核内):
python
lo = tl.zeros([BLOCK_D], dtype=tl.int32)
hi = tl.full([BLOCK_D], N_CENTROIDS - 1, dtype=tl.int32)
for _ in range(MSE_BITS): # log2(n_centroids) 次迭代
mid = (lo + hi) >> 1
mid_val = tl.load(Midpoints_ptr + safe_mid, ...)
lo = tl.where(y_vec >= mid_val, mid + 1, lo)
hi = tl.where(y_vec >= mid_val, hi, mid)
idx = tl.minimum(lo, N_CENTROIDS - 1)
- 对每个维度独立查找最近的质心索引
- 二分搜索复杂度 O(log N_CENTROIDS) vs 线性搜索 O(N_CENTROIDS)
- 索引打包(3-bit 或 4-bit):
- 与 value 打包逻辑相同
- 范数存储(fp16, 2 bytes)
18.7 TurboQuant解码算子(triton_turboquant_decode.py, 623行)
职责: TurboQuant KV cache 的读取+注意力计算端
18.7.1 _tq_decode_stage1 --- TQ解码Stage1核
两种 Key 读取路径:
FP8 Key 路径:
python
if KEY_FP8:
k_raw = tl.load(KV_cache_ptr + k_addrs, ...)
k_float = k_raw.to(tl.float8e4nv, bitcast=True).to(tl.float32)
scores = tl.sum(q_rot[None, :] * k_float, axis=1) * ATTN_SCALE
MSE Key 路径:
python
else:
# 1. 从打包位中解包MSE索引
mse_idx = (raw16 >> mse_bit_shift[None, :]) & mse_mask
# 2. 查表得到质心值
c_vals = tl.load(Centroids_ptr + mse_idx, ...)
# 3. 加载向量范数
vec_norms = (n_lo | (n_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
# 4. 重构K ≈ norm * centroid
scores = vec_norms * term1 * ATTN_SCALE
Value 解量化:
3-bit Value 解包:
python
v_idx = ((raw16 >> val_bit_shift[None, :]) & 0x7).to(tl.float32)
v_scales = (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
v_zeros = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
values = v_idx * v_scales[:, None] + v_zeros[:, None] # 仿射反量化
4-bit Value 解包:
python
v_idx = ((val_raw >> vb_shift[None, :]) & 0xF).to(tl.float32)
values = v_idx * v_scales[:, None] + v_zeros[:, None]
Online Softmax + 值累积:
python
n_e_max = tl.maximum(tl.max(scores, 0), m_prev)
re_scale = tl.exp(m_prev - n_e_max)
p = tl.exp(scores - n_e_max)
acc = acc * re_scale + tl.sum(p[:, None] * values, 0)
l_prev = l_prev * re_scale + tl.sum(p, 0)
m_prev = n_e_max
18.7.2 _tq_full_dequant_kv --- 全量KV解量化核
设计意图:将 TQ KV cache 完整解量化为 fp16 K 和 V,用于 prefill 阶段(需要完整 K/V 做矩阵注意力)。
18.7.3 triton_turboquant_decode_attention --- 完整解码入口
三步流程:
- Query 旋转投影(MSE 路径需要):
python
if key_fp8:
q_rot = query.contiguous() # FP8: 直接用原始query
else:
q_rot = (query.float() @ PiT).contiguous() # MSE: 旋转投影
- Stage 1: 分段 KV 注意力
- Stage 2 : 跨段归约(复用
triton_decode_attention._fwd_kernel_stage2)
缓冲区复用:
python
if mid_o_buf is not None and mid_o_buf.shape[0] >= B:
mid_o = mid_o_buf[:B, :Hq, :NUM_KV_SPLITS, :]
else:
mid_o = torch.empty(...)
if buf_holder is not None:
buf_holder._tq_mid_o_buf = mid_o
- 预分配缓冲区,避免每次调用都分配新内存
- 对 CUDA Graph 兼容至关重要(CUDA Graph 要求固定的内存地址)
总结:vLLM v1 Attention Ops架构全景
算子分层关系
┌─────────────────────────────────────────────────────────┐
│ 上层调度(Backends) │
│ FlashAttention / Triton / FlashInfer / ROCm / XPU │
├─────────────────────────────────────────────────────────┤
│ 通用工具层(common.py) │
│ CP修正 / Pack-Unpack / Block Table归一化 │
├──────────────────────┬──────────────────────────────────┤
│ KV Cache 写入层 │ 注意力计算层 │
│ reshape_and_cache │ unified_attention (2D/3D) │
│ per-token-head量化 │ prefix_prefill │
│ TurboQuant store │ decode_attention (2-stage) │
│ Indexer quant+cache │ chunked_prefill_paged_decode │
├──────────────────────┴──────────────────────────────────┤
│ 注意力合并层 │
│ merge_attn_states (Triton/CUDA) │
├─────────────────────────────────────────────────────────┤
│ 特殊算子层 │
│ FlashMLA (Dense/Sparse) │ ViT Wrappers │
│ DCP AlltoAll │ ROCm Aiter MLA Sparse │
│ XPU MLA Sparse │ TurboQuant Decode │
└─────────────────────────────────────────────────────────┘
核心设计模式
-
两阶段Split-KV:decode_attention 和 turboquant_decode 都采用 Split-KV + Reduce 的两阶段架构,解决长序列单 program 串行瓶颈
-
2D/3D统一调度:unified_attention 根据是否有 prefill 和序列数选择 2D(单段 softmax)或 3D(分段 softmax)模式
-
多层掩码叠加:causal ∧ sliding_window ∨ mm_prefix,与 FlexAttention 语义一致
-
多级KV量化:从无量化 → FP8 per-tensor → FP8/int8 per-token-head → MSE 最优量化 + 均匀量化 value,精度和压缩率递进
-
平台自适应 :通过
current_platform运行时检测,自动选择 CUDA/ROCm/XPU 路径和参数 -
非标准block_size兼容:Qwen3-Next 的 544 block_size 通过逐 token 计算逻辑块索引实现兼容
-
GQA/MQA/MLA统一支持 :通过
num_queries_per_kv和IS_MLA编译时常量,同一核函数支持多种注意力模式 -
编译时优化 :大量使用
tl.constexpr将配置参数编译为常量,消除运行时分支和冗余计算