【vllm】(五)vLLM v1 Attention — 模块超深度分析之五

第十三章:ROCm/AMD 后端

13.1 rocm_attn.py 逐行解析 (533行)

13.1.1 RocmAttentionBackend

AMD ROCm平台的标准注意力后端:

  • 使用 rocm_aiter_ops --- AMD aiter库封装
  • get_kv_cache_shape(): (2, num_blocks, block_size, num_kv_heads, head_size) --- 标准NHD布局
  • 支持 fp8 KV cache (ROCm FP8格式)
13.1.2 RocmAttentionMetadata
python 复制代码
@dataclass
class RocmAttentionMetadata:
    num_actual_tokens: int
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: torch.Tensor | None
    prefix_kv_lens: torch.Tensor | None
    suffix_kv_lens: torch.Tensor | None
    scheduler_metadata: torch.Tensor | None
    prefix_scheduler_metadata: torch.Tensor | None
  • Cascade Attention : ROCm支持cascade (公共前缀+后缀分离计算)
    • cu_prefix_query_lens: 前缀查询长度
    • prefix_kv_lens / suffix_kv_lens: 前缀/后缀KV长度
  • scheduler_metadata: aiter调度器元数据(AOT调度)
13.1.3 RocmAttentionMetadataBuilder
  • _cudagraph_support = ALWAYS: 完全支持CUDA Graph
  • build_for_cudagraph_capture(): 特殊处理 --- seq_lens设为1避免graph capture过慢
  • build(): 构建metadata → 支持cascade attention路径
13.1.4 RocmAttentionImpl

forward() 方法:

  1. KV cache写入: triton_reshape_and_cache_flash() → 写入KV到paged cache
  2. Prefill路径: chunked_prefill_paged_decode() --- ROCm chunked prefill kernel
  3. Decode路径: rocm_aiter_ops.decode_forward() --- aiter decode kernel
  4. Cascade路径: 公共前缀 → 后缀 → merge_attn_states() 合并
  5. FP8 KV cache: 使用 k_scale/v_scale 反量化

13.2 rocm_aiter_fa.py 逐行解析 (1460行)

这是attention模块中最大的单个后端文件。

13.2.1 核心Triton Kernel --- cp_mha_gather_cache_kernel

功能: Context Parallelism(CP)场景下,从远程GPU gather KV cache到本地。

逐行解析:

  • token_id = tl.program_id(0): 每个program处理1个token
  • head_id = tl.program_id(1): 每个program处理1个head
  • batch_idx = tl.load(token_to_batch_ptr + token_id): token→batch映射
  • batch_start = tl.load(seq_start_ptr + batch_idx): batch起始位置
  • block_offset = batch_offset // PAGE_SIZE: 计算block偏移
  • block_id = tl.load(block_table_ptr + max_block_num * batch_idx + block_offset): 查block_table
  • slot_id = batch_offset % PAGE_SIZE: block内slot偏移
  • NHD布局 : key_cache[block_id, slot, head, col] = key[token, head, col]
  • HND布局 : key_cache[block_id, head, slot, col] = key[token, head, col]
  • FP8反量化(DEQUANT) : key = key_cache * k_scale

设计意图: CP需要从其他GPU收集KV → 拼接到本地 → 执行注意力。此kernel实现了高效的跨GPU KV gather。

13.2.2 RocmAiterFABackend
  • 继承 AttentionBackend
  • 支持 fp8 KV cache
  • CP(Context Parallelism): 支持 _CP_TOKENS_PER_ITER_ROCM = 32K 每次迭代token数
  • KV cache布局: 优先NHD (num_blocks, page_size, num_heads, head_dim)
13.2.3 RocmAiterFAMetadataBuilder
  • 支持 split_decodes_prefills_and_extends() --- 三分类(decode/prefill/extend)
  • Extend: CP场景下的KV扩展(gather远程KV后扩展本地cache)
  • CP路径: 构建 cp_seqlen (本地+远程KV长度)
  • Workspace管理: 预分配gather缓冲区
13.2.4 RocmAiterFAImpl

forward() 方法复杂流程:

  1. 非CP路径: 标准prefill/decode
  2. CP路径 :
    a. gather远程KV到workspace → 拼接本地KV
    b. Prefill: 在完整KV上做attention
    c. 输出修正: correct_attn_cp_output() --- CP输出聚合
  3. FP8 KV cache: 使用scale反量化
  4. Sliding Window: 限制KV扫描范围
  5. Cascade: 公共前缀分离计算

13.3 rocm_aiter_unified_attn.py 逐行解析 (318行)

ROCm统一注意力后端:

  • 将prefill/decode/cascade统一到单个后端
  • 使用aiter库的 flash_attention_forward 统一接口
  • 支持FP8 KV cache
  • 自动选择prefill/decode kernel路径

第十四章:CPU/Mamba/GDN/Tree/TurboQuant 等专用后端

14.1 cpu_attn.py 逐行解析 (498行)

14.1.1 CPUAttentionBackend

CPU平台专用注意力后端:

  • supported_dtypes: [fp16, bf16, fp32] --- CPU支持更多精度
  • get_supported_head_sizes(): [32,64,80,96,112,128,160,192,224,256,512] --- CPU更灵活
  • supports_attn_type(): 支持所有4种注意力类型(DECODER/ENCODER/ENCODER_ONLY/ENCODER_DECODER)
  • get_kv_cache_shape(): (2, num_blocks, num_kv_heads, block_size, head_size) --- CPU用HND布局
  • use_cascade_attention(): False --- CPU不支持cascade
14.1.2 CPUAttentionMetadata
python 复制代码
@dataclass
class CPUAttentionMetadata:
    isa: str                          # CPU指令集(avx2/avx512/sve等)
    num_actual_tokens: int
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor
    scheduler_metadata: torch.Tensor | None
    causal: bool = True
    use_sdpa_prefill: bool = False    # 是否使用SDPA prefill
    num_decode_tokens: int = 0
    sdpa_attn_masks: list | None = None  # SDPA attention masks
    sdpa_start_loc: torch.Tensor | None = None
  • isa: CPU架构指令集,决定使用哪个kernel
  • use_sdpa_prefill: PyTorch SDPA (Scaled Dot Product Attention) 用于prefill
  • sdpa_attn_masks: SDPA需要的attention mask列表
14.1.3 CPUAttentionImpl

forward() 方法:

  1. KV cache写入 : ops.reshape_and_cache_cpu() --- CPU版KV写入
  2. Prefill路径 :
    • SDPA可用: torch.nn.functional.scaled_dot_product_attention() --- PyTorch原生
    • 不可用: 手动实现Q×K^T/√d → softmax → ×V
  3. Decode路径 : ops.paged_attention_v1() --- CPU分页注意力
  4. FP8 KV cache : ops.convert_fp8_cpu() 反量化

14.2 mamba_attn.py 逐行解析 (589行)

14.2.1 BaseMambaAttentionMetadata

Mamba状态空间模型(SSM)的metadata:

  • 不使用传统Q/K/V → 无KV Cache → 无attention计算
  • has_initial_states_p: 是否有初始状态(prefill时从checkpoint加载)
  • state_indices_tensor_p/d: prefill/decode的状态索引
  • num_accepted_tokens: spec decode接受的token数(用于加载正确checkpoint)
  • block_idx_*: 前缀缓存相关block索引
  • cu_chunk_seqlen_p: chunked prefill的累积chunk长度
14.2.2 BaseMambaAttentionMetadataBuilder
  • reorder_batch_threshold = 1: 不做reorder(Mamba不支持混合batch)
  • supports_update_block_table = True: 支持增量更新block_table
  • speculative_config: 推测解码配置
  • use_spec_decode: 是否使用spec decode

build() 方法:

  1. split_decodes_and_prefills(): 分离decode/prefill
  2. Prefill: 构建状态索引 + 初始状态标记 + chunked prefill元数据
  3. Decode: 构建decode状态索引 + spec decode token处理
  4. Prefix caching: 计算block索引 (Mamba的prefix是状态checkpoint)

14.3 mamba1_attn.py 逐行解析 (64行)

Mamba-1 SSM后端:

  • 仅64行,thin wrapper
  • 使用 selective_scan_fn --- Mamba-1的selective scan算子
  • 无KV Cache → get_kv_cache_shape() 返回 (0,)

14.4 mamba2_attn.py 逐行解析 (171行)

Mamba-2 SSM后端:

  • 使用 ssd_fn --- Mamba-2的SSD(Structured State Duality)算法
  • 比Mamba-1更高效的矩阵化实现
  • 支持chunked prefill + decode
  • get_kv_cache_shape(): 返回 (num_blocks, block_size, state_dim) --- SSM状态缓存

14.5 gdn_attn.py 逐行解析 (475行)

GDN (Gated DeltaNet) 注意力后端:

14.5.1 GDNAttentionBackend
  • is_ssm() = True: GDN被归类为SSM(类似Mamba)
  • 不使用标准Q/K/V → 使用门控递推(gated recurrence)
  • KV Cache: 状态缓存(类似Mamba)
14.5.2 GDNAttentionMetadata
  • num_spec_decodes / num_spec_decode_tokens: Spec decode计数
  • has_initial_state: 是否有初始状态
  • spec_query_start_loc / non_spec_query_start_loc: spec/non-spec分离
  • chunk_indices / chunk_offsets: FLA(Flash Linear Attention) chunk元数据
  • nums_dict / batch_ptr / token_chunk_offset_ptr: Triton causal_conv1d参数
14.5.3 GDNAttentionMetadataBuilder
  • 类似Mamba的metadata构建
  • Spec decode: 分离spec和非spec请求
  • compute_causal_conv1d_metadata(): 计算causal_conv1d所需的元数据
  • mamba_get_block_table_tensor(): 获取Mamba格式的block_table

14.6 tree_attn.py 逐行解析 (442行)

Tree Attention --- 推测解码(Spec Decoding)专用注意力:

14.6.1 TreeAttentionBackend
  • get_supported_kernel_block_sizes(): [MultipleOf(16)] --- block_size必须是16的倍数
  • forward_includes_kv_cache_update = False: forward不包含KV写入(需外部处理)
  • use_cascade_attention() = False: 不支持cascade
14.6.2 TreeAttentionMetadata
python 复制代码
@dataclass
class TreeAttentionMetadata:
    num_actual_tokens: int
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor
    tree_attn_bias: torch.Tensor | None = None  # 树形注意力掩码
  • tree_attn_bias: 核心字段 --- 树形注意力掩码
    • 定义了draft token之间的因果关系
    • 每个draft token只attend到其祖先token
14.6.3 TreeAttentionImpl

forward() 方法:

  1. 使用 unified_attention() (Triton统一注意力)执行实际计算
  2. tree_attn_bias 控制哪些token对之间可以attend
  3. 结果: 找到最长匹配前缀 → accept/reject draft tokens

设计意图: 一次性验证整棵draft tree → 比逐token串行验证更高效。多个draft路径并行计算,找到最深匹配。

14.7 turboquant_attn.py 逐行解析 (800行)

TurboQuant注意力 --- KV Cache极致压缩:

14.7.1 核心概念
  • KV Cache布局 : (num_blocks, block_size, num_kv_heads, slot_size)

    • slot_size = key_packed_size + value_fp16_size
    • 例: turboquant_k3v4_nc + head_dim=256: [100 bytes key | 512 bytes value] = 612
  • 量化模式:

    • turboquant_k8v4: 8-bit key + 4-bit value
    • turboquant_4bit_nc: 4-bit非均匀量化
    • turboquant_k3v4_nc: 3-bit key + 4-bit value (最激进)
    • turboquant_3bit_nc: 3-bit非均匀量化
14.7.2 _build_hadamard --- Hadamard矩阵
python 复制代码
def _build_hadamard(d: int, device_str: str) -> torch.Tensor:
    H = torch.tensor([[1.0]])
    while H.shape[0] < d:
        H = torch.cat([torch.cat([H, H], 1), torch.cat([H, -H], 1)], 0)
    return (H / math.sqrt(d)).to(torch.device(device_str))
  • Sylvester构造的Hadamard矩阵
  • 用于Walsh-Hadamard变换(WHT) → 量化前的正交旋转
  • @functools.cache: 缓存避免重复构造
  • 优势: 单次cuBLAS GEMM代替log2(D)次butterfly操作
14.7.3 TurboQuantAttentionImpl

forward() 流程:

  1. Prefill路径:

    • 标准SDPA/FlashAttention计算attention
    • triton_turboquant_store(): 将KV量化写入压缩cache
    • 小continuation(≤128 tokens): 直接用TQ decode kernel
  2. Decode路径:

    • triton_turboquant_decode_attention(): 从压缩cache读取+解量化+计算attention
    • Stage1: MSE key解量化 + 质心查表
    • Stage2: 3/4-bit value解包 + softmax + weighted sum
  3. KV Cache写入(do_kv_cache_update):

    • Hadamard旋转 → 量化 → 打包到slot
    • FP8/MSE量化 + 3/4-bit value打包

14.8 flash_attn_diffkv.py 逐行解析 (284行)

FlashAttention DiffKV后端:

  • 支持差异化KV维度(key和value的head_dim不同)
  • 用于某些模型(K头维度 ≠ V头维度)
  • FlashAttention API原生支持不同KV维度
  • flash_attn_varlen_func(dv=...): 传入不同的V维度

14.9 linear_attn.py 逐行解析 (93行)

Linear Attention后端:

  • O(n)复杂度注意力(而非标准O(n²))
  • 使用kernel trick: φ(Q) × (φ(K)^T × V) → 先计算K^T×V(无序列长度依赖)
  • 93行精简实现
  • 用于Linear Transformer/Perceiver等模型

14.10 short_conv_attn.py 逐行解析 (34行)

短卷积注意力:

  • 仅34行,最短的后端
  • 不使用attention → 使用局部1D卷积
  • 用于某些Mamba-1变体的short convolution层
  • 无KV Cache

第十五章总结:专用后端设计哲学

15.1 平台适应性矩阵

后端 NVIDIA AMD Intel XPU CPU
FlashMLA ✅ SM90+ - - -
FlashMLASparse ✅ SM90+ - - -
FlashInferMLA - - -
FlashAttnMLA - - -
CUTLASSMLA ✅ SM90+ - - -
TritonMLA -
ROCmAiterMLA - - -
XPUMLASparse - - -
ROCmAttn - - -
ROCmAiterFA - - -
CPUAttn - - -

15.2 MLA后端选择决策

复制代码
DeepSeek-V3.2模型 → 选择MLA后端:
  ├── NVIDIA GPU:
  │   ├── SM90+ (Hopper/Blackwell) + Dense → FlashMLA
  │   ├── SM90+ + Sparse → FlashMLASparse
  │   ├── SM90+ + CUTLASS → CUTLASSMLA
  │   └── 其他SM → TritonMLA (跨平台fallback)
  ├── AMD GPU:
  │   ├── aiter可用 → ROCmAiterMLA
  │   └── aiter不可用 → TritonMLA
  └── Intel XPU:
      └── XPUMLASparse

15.3 特殊后端适用场景

后端 适用场景 核心优势
MambaAttn 状态空间模型 无KV Cache, 状态递推
TreeAttn 推测解码验证 树形并行验证draft
TurboQuant 长上下文+显存受限 KV极致压缩(3-4bit)
GDNAttn GatedDeltaNet 门控递推, 类Mamba
LinearAttn Linear Transformer O(n)注意力
ShortConv 局部卷积 无需全局注意力
FlashAttnDiffKV K/V维度不同 差异化KV支持

Part3 完 --- 涵盖MLA体系(13文件) + ROCm/AMD后端(3文件) + CPU/Mamba/GDN/Tree/TurboQuant等(10文件) = 26个文件的完整逐行分析

vLLM v1 Attention 底层算子(ops/)超深度逐行分析 --- Part 4

分析范围 : vLLM v1 attention 子系统 ops/ 目录下全部算子文件
源码根目录 : vllm/v1/attention/ops/
分析方法 : 逐行解析代码作用、语法、逻辑意图,说明每个变量/函数/模块的设计目的
撰写风格: 资深架构师视角,严谨、不漏细节


第十五章:通用算子(common.py)逐行解析

文件 : ops/common.py (465行)
职责: 上下文并行(Context Parallelism, CP)注意力输出修正、序列打包/解包工具函数

15.1 模块总览

common.py 是 vLLM v1 attention 算子层的"基础设施"模块,提供两类核心功能:

  1. CP 注意力输出修正:当使用上下文并行(Context Parallelism)时,每个 rank 只看到部分 KV 序列,产生局部注意力输出和局部 log-sum-exp(LSE)。需要在 rank 间 all-gather LSE 后修正各 rank 的输出,再通过 reduce-scatter 或 all-reduce 得到最终结果。
  2. 序列打包/解包(pack/unpack):将不等长序列压缩为定长 batch 张量(padding),以及反向解包。这对 MLA 稀疏注意力等需要 batch 维度的算子至关重要。

15.2 导入与全局变量

python 复制代码
import torch
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.triton_utils import tl, triton
  • torch:PyTorch 核心库,提供张量操作
  • GroupCoordinator:vLLM 分布式通信协调器,封装 NCCL 集合通信
  • tl, triton:Triton JIT 编译框架及其内建库(tl 为 triton.language)

15.3 _correct_attn_cp_out_kernel --- CP注意力输出修正Triton核

15.3.1 函数签名与参数
python 复制代码
@triton.jit
def _correct_attn_cp_out_kernel(
    outputs_ptr,        # 输入/输出注意力结果 [B, H, D]
    new_output_ptr,     # 修正后输出 [B, H, D]
    lses_ptr,           # 所有rank的LSE值 [N, B, H](N=world_size)
    vlse_ptr,           # 输出:全局LSE [B, H]
    outputs_stride_B, outputs_stride_H, outputs_stride_D,  # outputs的步幅
    lses_stride_N, lses_stride_B, lses_stride_H,          # lses的步幅
    lse_idx,            # 当前rank在CP组中的rank id
    HEAD_DIM: tl.constexpr,    # 头维度(编译时常量)
    N_ROUNDED: tl.constexpr,   # rank数量(编译时常量)
    IS_BASE_E: tl.constexpr,  # LSE是否以e为底(True)或以2为底(False)
)

设计意图 :此核修正 CP 场景下各 rank 的局部注意力输出。每个 rank 对自己的 KV 分片计算注意力后,得到局部 out_local 和局部 lse_local。但 softmax 的归一化范围应该是全局的(所有 KV 分片),所以需要:

  1. 从 all-gathered LSE 计算全局 LSE
  2. exp(lse_local - lse_global) 作为缩放因子修正输出
15.3.2 逐行解析
python 复制代码
batch_idx = tl.program_id(axis=0).to(tl.int64)  # 批次索引
head_idx = tl.program_id(axis=1).to(tl.int64)   # 头索引
d_offsets = tl.arange(0, HEAD_DIM)               # 头维度偏移向量
num_n_offsets = tl.arange(0, N_ROUNDED)           # rank偏移向量
  • 每个 program 处理一个 (batch, head) 对
  • d_offsetsnum_n_offsets 用于向量化加载
python 复制代码
lse_offsets = (
    num_n_offsets * lses_stride_N
    + batch_idx * lses_stride_B
    + head_idx * lses_stride_H
)
  • 计算所有 N 个 rank 的 LSE 在 lses_ptr 中的偏移
python 复制代码
lse = tl.load(lses_ptr + lse_offsets)  # 加载 [N] 个LSE值
lse = tl.where((lse != lse) | (lse == float("inf")), -float("inf"), lse)
  • NaN 和 +inf 处理lse != lse 检测 NaN(IEEE 754 下 NaN ≠ 自身),lse == inf 检测正无穷。这些异常值替换为 -inf,在 log-sum-exp 中等效于"无贡献"。
python 复制代码
lse_max = tl.max(lse, axis=0)  # 所有rank的LSE最大值(数值稳定性)
lse_max = tl.where(lse_max == -float("inf"), 0, lse_max)  # 全为-inf时置0
lse -= lse_max  # 减去最大值(数值稳定化的标准操作)
  • log-sum-exp 数值稳定性 :先减去最大值再取 exp,避免溢出。若所有 LSE 为 -inf(空序列),最大值设为 0 以避免 (-inf) - (-inf) = NaN
python 复制代码
if IS_BASE_E:
    lse_exp = tl.exp(lse)          # e^(lse - max)
    lse_acc = tl.sum(lse_exp, axis=0)  # Σ e^(lse - max)
    lse = tl.log(lse_acc)          # log(Σ e^(lse - max))
else:
    lse_exp = tl.exp2(lse)         # 2^(lse - max)
    lse_acc = tl.sum(lse_exp, axis=0)
    lse = tl.log2(lse_acc)         # log2(Σ 2^(lse - max))
lse += lse_max  # 加回最大值 → 全局LSE
  • 双基底支持 :FlashAttention 2/3 内部使用 log2 计算 LSE(更快,因为 GPU 上 exp2exp 快),但数学上 LSE 通常以 e 为底。IS_BASE_E 标志决定使用哪种基底。
  • 最终公式 : lse_global = lse_max + log(Σ_i exp(lse_i - lse_max))
python 复制代码
tl.store(vlse_ptr + lse_offsets, lse)  # 只存 [B, H] 部分(去掉N维度)
  • 存储全局 LSE 供后续使用
python 复制代码
lse_offset = (lse_idx * lses_stride_N + batch_idx * lses_stride_B 
              + head_idx * lses_stride_H)
lse_tmp = tl.load(lses_ptr + lse_offset)   # 当前rank的原始LSE
lse_finally = lse_tmp - lse                 # 局部LSE - 全局LSE
lse_finally = tl.where(
    (lse_finally != lse_finally) | (lse_finally == float("inf")),
    -float("inf"), lse_finally,
)
factor = tl.exp(lse_finally) if IS_BASE_E else tl.exp2(lse_finally)
  • 修正因子exp(lse_local - lse_global) 是将局部归一化输出缩放到全局归一化输出的因子
  • 数学原理:局部输出 o_local = Σ(p_i * v_i) / Σ(p_i),其中 p_i = exp(q·k_i - lse_local)。全局输出 o_global = Σ(p_i * v_i) / Σ_all(p_i)。二者关系为 o_global = o_local * exp(lse_local - lse_global)
python 复制代码
output = tl.load(outputs_ptr + output_offsets)
output = output * factor
tl.store(new_output_ptr + output_offsets, output)
  • 加载原始输出、乘以修正因子、存储

15.4 CPTritonContext --- 避免Triton重编译的上下文

python 复制代码
class CPTritonContext:
    def __init__(self):
        self.inner_kernel = None

    def call_kernel(self, kernel, grid, *regular_args, **const_args):
        if self.inner_kernel is None:
            self.inner_kernel = kernel[grid](*regular_args, **const_args)
        else:
            self.inner_kernel[grid](*regular_args)

设计意图 :Triton JIT 编译的 kernel 在首次调用时编译,后续调用复用编译结果。但 kernel[grid](...) 语法每次都会创建新的启动配置对象。CPTritonContext 缓存已编译的 kernel 实例,避免重复创建。

  • 首次调用:编译 + 启动
  • 后续调用:仅启动(grid 可能不同,但 const_args 不变)

15.5 correct_attn_out --- CP注意力输出修正入口

python 复制代码
def correct_attn_out(
    out: torch.Tensor,         # [B, H, D]
    lses: torch.Tensor,        # [N, B, H]
    cp_rank: int,              # 当前rank在CP组中的rank id
    ctx: CPTritonContext,      # Triton上下文
    is_lse_base_on_e: bool = True,  # LSE基底
) -> tuple[torch.Tensor, torch.Tensor]:

维度规范化

python 复制代码
if out.ndim == 4 and out.shape[1] == 1:
    out = out.squeeze(1)  # [B, 1, H, D] → [B, H, D]
  • 有些后端(如 FlashAttention)可能输出 4D 张量,squeeze 掉大小为 1 的维度
python 复制代码
if lses.ndim == 4 and lses.shape[-1] == 1:
    lses = lses.squeeze(-1)
if lses.ndim == 4 and lses.shape[1] == 1:
    lses = lses.squeeze(1)
  • LSE 可能被封装为 [N, 1, B, H][N, B, H, 1],squeeze 到 3D

步幅处理

python 复制代码
lse = torch.empty_strided(
    (B, H), (l_sB, l_sH), device=lses.device, dtype=lses.dtype
)
  • 关键设计lse 的步幅必须与 lses 在 B/H 维度上的步幅一致,因为核函数中 vlse_ptr 使用 lses_stride_B/H 写入。这保证了即使 lses 是非连续视图(如 4D squeeze),写入位置也正确。

核启动配置

python 复制代码
grid = (B, H, 1)  # 每个(batch, head)对一个program

15.6 _cp_lse_common --- CP LSE通用处理

python 复制代码
def _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=None, is_lse_base_on_e=True):
    if cp_group.world_size == 1:
        return cp_attn_out  # 单rank无需修正
  • 短路:如果 CP 组只有 1 个 rank(即未启用 CP),直接返回
python 复制代码
lses = cp_group.all_gather(cp_attn_lse, dim=0).reshape(
    (cp_group.world_size,) + cp_attn_lse.shape
)
  • All-gather 所有 rank 的 LSE,reshape 为 [N, B, H]

15.7 cp_lse_ag_out_rs / cp_lse_ag_out_ar --- 两种CP输出聚合策略

cp_lse_ag_out_rs(AllGather + ReduceScatter)

python 复制代码
out = cp_group.reduce_scatter(out, dim=1)  # 在head维度上scatter
  • 每个 rank 获得自己负责的 head 子集的输出
  • 适用于 TP+CP 联合场景,head 已经在 TP 中被分割

cp_lse_ag_out_ar(AllGather + AllReduce)

python 复制代码
out = cp_group.all_reduce(out)  # 全归约
  • 所有 rank 获得完整的输出
  • 适用于纯 CP 场景

15.8 _pack_seq_kernel --- 序列打包Triton核

python 复制代码
@triton.jit
def _pack_seq_kernel(
    x_ptr,        # [N, D] 输入(N=总token数)
    out_ptr,      # [B, Lmax, D] 输出
    lengths_ptr,  # [B] 各序列长度
    N, D, Lmax, PAD_VALUE, BLOCK_T, BLOCK_D: tl.constexpr
)

设计意图 :将 1D 连续 token 序列(格式 [total_tokens, D])打包为 2D batch 张量 [B, Lmax, D],短序列用 PAD_VALUE(默认 -inf)填充。这对 MLA 稀疏注意力中需要 batch 维度的 MQA logits 计算至关重要。

关键逻辑

python 复制代码
in_start = 0
for i in range(pid_b):
    in_start += tl.load(lengths_ptr + i)  # 累加得到当前batch的起始位置
seq_len = tl.load(lengths_ptr + pid_b)    # 当前序列长度
  • 通过累加 lengths 计算每个 batch 在展平序列中的起始索引
python 复制代码
valid_row = (off_t < seq_len) & t_mask  # 有效token掩码
  • 超出序列长度的位置写入 PAD_VALUE

15.9 pack_seq_triton / unpack_seq_triton --- 打包/解包Python入口

python 复制代码
def pack_seq_triton(x, lengths, pad_value=-float("inf"), block_t=64, block_d=64):
  • 支持多维输入(通过 reshape 处理)
  • pad_value=-inf 使得 padding 位置在 softmax 中贡献为 0
python 复制代码
def unpack_seq_triton(packed_tensor, lengths, block_t=64, block_d=64):
  • 反向操作:从 [B, Lmax, ...] 解包回 [N, ...]
  • 跳过 padding 位置

第十六章:Triton注意力算子族

涵盖文件:triton_unified_attention.py, triton_prefill_attention.py, triton_decode_attention.py, triton_reshape_and_cache_flash.py, triton_merge_attn_states.py

16.1 Triton统一注意力算子(triton_unified_attention.py, 1268行)

来源 : IBM Research 贡献,基于 vLLM 的 Triton attention 框架
职责: 提供统一的 prefill + decode 注意力计算,支持 2D(单段 softmax)和 3D(分段 softmax)两种调度模式

16.1.1 模块级常量与辅助函数
python 复制代码
is_batch_invariant = envs.VLLM_BATCH_INVARIANT
float8_info = torch.finfo(current_platform.fp8_dtype())
  • is_batch_invariant:环境变量控制是否假设 batch 中所有序列等长(优化路径)
  • float8_info:FP8 数据类型的 min/max 范围,用于核内 clamp

cdiv_fn(Triton JIT 版向上取整):

python 复制代码
@triton.jit
def cdiv_fn(x, y):
    return (x + y - 1) // y

apply_softcap(Gemma2 等模型使用的 tanh softcap):

python 复制代码
@triton.jit
def apply_softcap(S, x):
    Sdiv = S / x
    p1 = tl.exp(Sdiv)
    p2 = tl.exp(-Sdiv)
    return x * (p1 - p2) / (p1 + p2)  # x * tanh(S/x)
  • 数学等价于 x * tanh(S / x),但用 exp 展开,避免 Triton 不直接支持 tanh 的问题
  • 用于限制注意力 logit 范围,提升训练稳定性
16.1.2 _prepare_kv_tile --- KV 量化反量化统一接口
python 复制代码
@triton.jit
def _prepare_kv_tile(
    data, Q, tensor_scale, scale_cache_ptr,
    physical_block_idx, seq_offset, kv_head_idx,
    stride_s_blk, stride_s_slot, stride_s_head,
    tile_mask, BLOCK_SIZE, KV_QUANT_MODE
)

KV_QUANT_MODE 设计

含义 反量化策略
0 无量化(bf16/fp16) 仅类型转换(实际无操作)
1 FP8 全局量化 data * tensor_scale(每个张量一个 scale)
2 int8 per-token-head 量化 返回 data + token_head_scales,调用者在 dot 后乘
3 FP8 per-token-head 量化 同上

为什么 KV_QUANT_MODE ≥ 2 时不在 tile 内反量化?

python 复制代码
if KV_QUANT_MODE >= 2:  # per-token-head (int8 or fp8)
    token_head_scales = tl.load(scale_cache_ptr + scale_idx, ...)
    return data.to(Q.dtype), token_head_scales
  • Per-token-head 量化下,每个 (token, head) 对有自己的 scale
  • 将 scale 乘法融合到 softmax scale 中(S += tl.dot(Q, K) * (scale * k_token_head_scales)),避免额外的 BLOCK_M×TILE_SIZE 逐元素乘法
  • 这是数值精度和计算效率的权衡:先做整数 dot product,再乘 scale
16.1.3 find_seq_idx --- 二分查找定位token所属序列
python 复制代码
@triton.jit
def find_seq_idx(query_start_len_ptr, target_idx, num_seqs, 
                 BLOCK_Q, use_q_block_mode):
  • 问题:在 batch 预填充中,所有 token 展平排列,核函数需要知道当前 program 对应哪个序列
  • 算法 :二分查找 query_start_len 数组
  • use_q_block_mode=True 时,query_start_len 的值是 cu_seqlens[i] // BLOCK_Q + i(因为 q-block 是序列级分块,需要偏移序列索引)
16.1.4 kernel_unified_attention_2d --- 2D统一注意力核

核函数组织

复制代码
Grid: (total_num_q_blocks, num_kv_heads)
每个program: 处理一个 (q_block, kv_head) 对

核心计算流程

  1. 定位当前序列
python 复制代码
seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True)
q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx
q_block_local_idx = q_block_global_idx - q_block_start_idx
  • q_block_global_idx:全局 q-block 编号
  • q_block_local_idx:在当前序列内的 q-block 编号
  1. 加载 Query
python 复制代码
query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
query_offset_0 = cur_batch_in_all_start_index + query_pos
query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv
  • GQA 支持num_queries_per_kv > 1 时,一个 kv_head 对应多个 q_head
  • offs_m 的低 log2(num_queries_per_kv) 位选 q_head,高位选 token
  1. Softmax 初始化
python 复制代码
if not USE_SINKS:
    M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
else:
    M = tl.load(sink_ptr + query_offset_1, ...)  # 从 attention sink 初始化
L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)  # 初始化为1(不是0,因为M初始为-inf时exp(-inf-(-inf))=1)
  • Attention Sinks(StreamingLLM 技术):初始 M 值从 sink 张量加载,允许模型"记住"序列开头的 token,防止滑动窗口导致的注意力崩溃
  1. Tile循环 --- 分块KV处理
python 复制代码
for j in range(tile_start, tile_end):
    seq_offset = j * TILE_SIZE + offs_t
    physical_block_idx = tl.load(block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE)
  • 遍历 KV 的 tile,通过 block_table 将逻辑位置映射到物理 cache 块
  1. 注意力掩码构建(causal + sliding_window + mm_prefix 三层叠加):
python 复制代码
seq_mask = seq_offset[None, :] <= query_abs_pos  # 因果掩码

if SLIDING_WINDOW > 0:
    seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW)

if USE_MM_PREFIX:
    for i in range(MAX_MM_RANGES):
        # 多模态token的双向注意力范围
        seq_mask |= q_in_range & k_in_range
  • 掩码叠加顺序:先 causal AND sliding_window,再 OR mm_prefix
  • 这与 FlexAttention 的语义一致:(causal ∧ sliding_window) ∨ mm_prefix
  1. Online Softmax 累积(标准 FlashAttention 算法):
python 复制代码
m_j = tl.maximum(M, tl.max(S, axis=1))  # 新的行最大值
P = tl.exp(S - m_j[:, None])             # 未归一化的注意力权重
l_j = tl.sum(P, axis=1)                  # 行和
alpha = tl.exp(M - m_j)                  # 旧累积值的缩放因子
acc = acc * alpha[:, None]               # 缩放旧的累加器
L = L * alpha + l_j                      # 更新行和
M = m_j                                  # 更新行最大值
acc += tl.dot(P.to(V.dtype), V)          # 累积加权值
  1. Sliding Window 特殊处理
python 复制代码
if SLIDING_WINDOW:
    qpos_lo = q_block_local_idx * BLOCK_Q
    V = tl.where((context_len + qpos_lo - seq_offset[:, None]) < SLIDING_WINDOW, V, 0.0)
  • 对滑动窗口外的 V 值置零,防止它们通过 tl.dot 污染累加器
  • 为什么在 P 上掩码还不够?因为 tl.dot 的掩码语义可能不完全屏蔽,需要额外保护
  1. FP8 输出支持
python 复制代码
if USE_FP8:
    acc = acc * tl.load(out_scale)
    acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
16.1.5 kernel_unified_attention_3d --- 3D分段softmax核

核函数组织

复制代码
Grid: (total_num_q_blocks, num_kv_heads, num_par_softmax_segments)
每个program: 处理一个 (q_block, kv_head, segment) 三元组

与 2D 版本的核心区别

  • 将 KV 序列分成 NUM_SEGMENTS_PER_SEQ 段,每段独立计算部分 softmax(部分 M, L, acc)
  • 每个 segment 输出独立的 (部分输出, 部分M, 部分L)
  • 最终由 reduce_segments 核将所有段的结果合并

为什么需要 3D 模式?

  • 长序列 decode 时,单个 program 处理整个 KV 序列的 softmax 可能导致 GPU 利用率低(串行循环太长)
  • 3D 模式将 KV 分段并行处理,增加并行度
  • 代价:需要额外存储中间结果和一次 reduce

段范围计算

python 复制代码
tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)
for j in range(
    max(segm_idx * tiles_per_segment, tile_start),
    min((segm_idx + 1) * tiles_per_segment, tile_end),
):

Sink 初始化的特殊处理

python 复制代码
if USE_SINKS:
    if segm_idx == 0:
        M = tl.load(sink_ptr + ...)
    else:
        M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
  • 只有第一个 segment 使用 sink 值,因为 sink 是全局初始值
16.1.6 reduce_segments --- 分段结果归约
python 复制代码
@triton.jit
def reduce_segments(...)

算法(标准的 online softmax reduce):

  1. 加载所有 segment 的 M 值,计算全局最大值 overall_max
  2. 加载所有 segment 的 L 值,缩放后求和 overall_expsum
  3. 加载所有 segment 的输出,缩放后求和
  4. 归一化:acc = acc_sum / overall_expsum
16.1.7 unified_attention --- Python入口与调度策略

2D vs 3D 选择逻辑

python 复制代码
if (seq_threshold_3D is None
    or max_seqlen_q > 1          # 有prefill
    or num_seqs > seq_threshold_3D
    or is_batch_invariant):
    kernel_unified_attention_2d[...]  # 2D模式
else:
    kernel_unified_attention_3d[...]  # 3D模式(纯decode)
    reduce_segments[...]
  • 2D 模式:适用于 prefill + decode 混合批次
  • 3D 模式:仅在纯 decode 且序列数低于阈值时启用(阈值由配置控制)

BLOCK_M / BLOCK_Q 计算

python 复制代码
BLOCK_M = 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(num_queries_per_kv)
BLOCK_Q = BLOCK_M // num_queries_per_kv
  • GQA 下 BLOCK_M 需要覆盖整个 kv_group 的 q_head
  • BLOCK_Q 是每个 kv_group 在一个 BLOCK_M 中覆盖的 token 数

Tile Size 选择(Gemma3 优化):

python 复制代码
def _get_tile_size(head_size, sliding_window, element_size, is_prefill):
    if _is_gemma3_attention(head_size, sliding_window):
        return 32  # Gemma3专用
    return 32 if is_prefill else (16 if element_size >= 2 else 32)
  • Gemma3 使用 sliding_window=1024 + head_size=128/256 的独特组合
  • 较大的 head_size 需要更大的 tile 以充分利用 tensor core
  • FP8(element_size=1)需要 tile≥32 以满足对齐要求

16.2 Triton Prefill专用算子(triton_prefill_attention.py, 253行)

来源 : 适配自 SGLang,原作者 ModelTC/LightLLM
职责: 纯 prefill 场景的注意力计算,不支持 KV cache 分页(page_size=1假设)

16.2.1 核函数 _fwd_kernel

与统一核的关键区别

  • 不使用 KV cache:直接从连续的 K、V 张量读取(prefill 阶段 K/V 还未写入 cache)
  • 使用 exp2 替代 exp :通过 sm_scale *= RCP_LN2(1/ln(2))将 softmax scale 融入 base-2 指数
  • 支持双向滑动窗口SLIDING_WINDOW_QSLIDING_WINDOW_K 分别控制 Q→K 和 K→Q 方向的窗口

注意力掩码

python 复制代码
mask = pos_k < cur_batch_seq_len  # 序列边界
if IS_CAUSAL:
    mask &= pos_q >= pos_k        # 因果掩码
if SLIDING_WINDOW_Q > 0:
    mask &= pos_q - pos_k <= SLIDING_WINDOW_Q  # 前向窗口
if SLIDING_WINDOW_K > 0:
    mask &= pos_k - pos_q <= SLIDING_WINDOW_K  # 后向窗口

Block Size 自适应

python 复制代码
def get_block_size(dtype):
    if dtype == torch.float32: return 32
    elif has_device_capability(80): return 128  # A100+
    else: return 64

16.3 Triton Decode专用算子(triton_decode_attention.py, 778行)

来源 : 适配自 SGLang,原始设计来自 LightLLM 的 DeepSeek-V2 GQA 解码核
职责: 高效的 decode 阶段注意力,支持 MHA、GQA、MQA、MLA

16.3.1 两阶段架构

Stage 1:Split-KV 并行计算

  • 将 KV 序列分成 NUM_KV_SPLITS
  • 每个 split 独立计算部分注意力输出和 LSE
  • 输出中间结果 Att_Out [B, H, NUM_KV_SPLITS, D+1](最后1个元素存 LSE)

Stage 2:跨 split 归约

  • 加载所有 split 的部分输出和 LSE
  • Online softmax reduce 得到最终输出

设计优势

  • 长序列 decode 时,单 program 串行遍历整个 KV 序列成为瓶颈
  • Split-KV 将序列分段并行处理,增加 SM 利用率
  • 两阶段通信量小:仅在 stage2 做一次 reduce
16.3.2 _fwd_kernel_stage1 --- MHA版本
python 复制代码
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
    kv_page_number = tl.load(Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n // PAGE_SIZE)
    kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
  • 分页 KV cache 访问 :通过 Req_to_tokens(block table)将逻辑位置映射到物理位置
  • 支持任意 PAGE_SIZE(不仅限于 1)

FP8 反量化

python 复制代码
if k.dtype.is_fp8():
    k = (k.to(tl.float32) * ks).to(q.dtype)  # FP8 → fp32 → 反量化 → q.dtype

Logit Cap(tanh softcap)

python 复制代码
if logit_cap > 0:
    qk = logit_cap * tanh(qk / logit_cap)

中间输出格式

复制代码
Att_Out[b, h, s, 0:D]   = 部分注意力输出(已除以部分expsum)
Att_Out[b, h, s, D]     = 部分LSE = e_max + log(e_sum)
16.3.3 _fwd_grouped_kernel_stage1 --- GQA/MQA/MLA版本

关键区别 :一个 program 处理 BLOCK_H 个 q_head(而非 1 个)

python 复制代码
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
VALID_BLOCK_H = BLOCK_H if kv_group_num > BLOCK_H else kv_group_num
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
  • 将多个 q_head 捆绑处理,共享 KV 加载,减少全局内存访问
  • kv_group_num = Hq // Hkv,即 GQA 组大小

MLA(Multi-head Latent Attention)特殊处理

python 复制代码
if not IS_MLA:
    v = tl.load(V_Buffer + offs_buf_v, ...)  # 正常加载V
else:
    v = tl.trans(k)  # MLA: V = K^T(共享压缩KV)
  • DeepSeek-V2/V3 的 MLA 将 KV 压缩为单个低维向量 c_kv
  • V 就是 K 的转置(因为 c_kv 同时编码 K 和 V 的信息)
  • 避免重复加载相同的 c_kv

DPE(Decoupled Position Embedding)支持

python 复制代码
if BLOCK_DPE > 0:
    kpe = tl.load(K_Buffer + offs_buf_kpe, ...)
    qk += tl.dot(qpe, kpe.to(qpe.dtype))
  • MLA 的 key 由 k_nope(非位置部分)和 k_pe(位置编码部分)组成
  • qk = q_nope @ k_nope + q_pe @ k_pe
  • 对应 DeepSeek-V2 的解耦 RoPE 设计

具体维度映射

python 复制代码
if Lk == 576:      # DeepSeek-V2 256+64的head_dim+nope分解
    BLOCK_DMODEL = 512
    BLOCK_DPE = 64
elif Lk == 288:    # DeepSeek-V2 小模型的128+32分解
    BLOCK_DMODEL = 256
    BLOCK_DPE = 32
16.3.4 _fwd_kernel_stage2 --- 跨split归约
python 复制代码
for split_kv_id in range(0, NUM_KV_SPLITS):
    tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d)
    tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
    n_e_max = tl.maximum(tlogic, e_max)
    old_scale = tl.exp(e_max - n_e_max)
    acc *= old_scale
    exp_logic = tl.exp(tlogic - n_e_max)
    acc += exp_logic * tv
    e_sum = e_sum * old_scale + exp_logic
    e_max = n_e_max
  • 标准的 online softmax reduce,与 stage1 的累积逻辑对称
16.3.5 入口函数
python 复制代码
def decode_attention_fwd(q, k_buffer, v_buffer, o, lse, req_to_token, 
                         b_seq_len, attn_logits, num_kv_splits, sm_scale, 
                         page_size=1, logit_cap=0.0, k_scale=None, 
                         v_scale=None, is_mla=False):
  • MHA vs GQA 选择
python 复制代码
if kv_group_num == 1:
    decode_attention_fwd_normal(...)   # MHA:每个program一个q_head
else:
    decode_attention_fwd_grouped(...)  # GQA/MQA/MLA:每个program一组q_head

16.4 Triton KV Cache写入算子(triton_reshape_and_cache_flash.py, 601行)

职责: 将新生成的 K、V 写入分页 KV cache,支持 FP8 量化和 per-token-head 动态量化

16.4.1 reshape_and_cache_kernel_flash --- 主KV写入核

两种 KV Cache 布局

Head-Major Layout(5D)

复制代码
K_cache: [num_blocks, num_kv_heads, head_size//x, block_size, x]
V_cache: [num_blocks, num_kv_heads, head_size, block_size]
  • Key 使用 5D 布局以优化 Tensor Core 访问(x 维度对齐到 16B)

Slot-Major Layout(4D)

复制代码
KV_cache: [num_blocks, block_size, num_kv_heads, head_size+head_size_v]
  • K 和 V 在同一张量中连续存储

写入逻辑

python 复制代码
slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
if slot_idx < 0:
    return  # padding token,跳过
block_idx = slot_idx // block_size
block_offset = slot_idx % block_size
  • slot_mapping:每个 token 在 KV cache 中的物理位置
  • slot_idx < 0 表示 padding(如 chunked prefill 中的虚拟 token)

FP8 量化

python 复制代码
if FP8_KV_CACHE:
    key_tile = key_load / tl.load(k_scale)  # 反向量化:除以scale → FP8范围
else:
    key_tile = key_load
  • 注意:Triton 的 tl.store 会自动将 float 值截断为目标 dtype(FP8)
  • k_scale / v_scale 是全局张量级 scale(per-tensor quantization)
16.4.2 _reshape_cache_per_token_head --- Per-Token-Head动态量化核

设计意图:KV cache 的 per-token-head 量化比 per-tensor 量化精度更高,每个 (token, head) 对有独立的 scale。

Grid : (num_tokens, num_kv_heads) --- 每个 program 处理一个 (token, head) 对

核心逻辑

python 复制代码
k_h = tl.load(key_ptr + tok * stride_key_tok + head * stride_key_head + dim_offs,
              mask=k_mask, other=0.0).to(tl.float32)
k_scale = tl.maximum(tl.max(tl.abs(k_h)) / QUANT_MAX, 1e-6)
tl.store(k_scale_cache_ptr + blk * stride_ks_blk + slot_in_blk * stride_ks_slot + head * stride_ks_head, k_scale)
k_q = tl.clamp(k_h * (1.0 / k_scale), QUANT_MIN, QUANT_MAX)
tl.store(key_cache_ptr + ..., k_q, mask=k_mask)
  1. 加载一个 (token, head) 的 K 值
  2. 计算 absmax / QUANT_MAX 作为 scale
  3. 量化:clamp(value / scale, min, max)
  4. 存储:量化值写入 key_cache,scale 写入 k_scale_cache

量化参数映射

python 复制代码
_PER_TOKEN_HEAD_QUANT_PARAMS = {
    torch.int8: (127.0, -128.0),
    FP8_DTYPE: (FP8_MAX, FP8_MIN),
}
16.4.3 reshape_and_cache_kernel_flash_diffkv --- 差异化KV维度写入

设计意图 :MLA(如 DeepSeek-V2)中 K 和 V 的维度不同(K 包含 nope+pe,V 仅 nope 部分)。此核将 K 和 V 写入同一个合并 cache [num_blocks, block_size, num_heads, head_size_k + head_size_v]


16.5 Triton注意力状态合并(triton_merge_attn_states.py, 175行)

数学原理 : Section 2.2 of Split-KV Attention论文
职责: 合并两段部分注意力输出(prefix/suffix),用于 chunked prefill 场景

核函数 merge_attn_states_kernel

两种 token 处理路径

1. 无 prefix context 的 token(decode token)

python 复制代码
if not prefix_mask:
    s_lse = tl.load(suffix_lse + ...)
    s_out = tl.load(suffix_output + ...)
    tl.store(output + ..., s_out)  # 直接复制suffix输出

2. 有 prefix context 的 token(prefill token)

python 复制代码
p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)

# FA2/FA3 兼容性处理
p_lse = float("-inf") if p_lse == float("inf") else p_lse
s_lse = float("-inf") if s_lse == float("inf") else s_lse

max_lse = tl.maximum(p_lse, s_lse)
p_se = tl.exp(p_lse - max_lse)  # prefix的exp(lse - max)
s_se = tl.exp(s_lse - max_lse)  # suffix的exp(lse - max)
out_se = p_se + s_se             # 合并的exp sum

# LSE计算(可选输出)
if OUTPUT_LSE:
    out_lse = tl.log(out_se) + max_lse

# 加权合并
p_scale = p_se / out_se
s_scale = s_se / out_se
out = p_out * p_scale + s_out * s_scale

数学原理

设 prefix 部分的注意力输出为 O_p = Σ_i (α_i * v_i),suffix 部分为 O_s = Σ_j (β_j * v_j)

全局注意力输出为:

复制代码
O = (Σ_i α_i * v_i + Σ_j β_j * v_j) / (Σ_i α_i + Σ_j β_j)
  = O_p * exp(lse_p - lse) + O_s * exp(lse_s - lse)

其中 lse = log(exp(lse_p) + exp(lse_s))(log-sum-exp 合并)。


第十七章:Prefix Prefill与Chunked Prefill算子

涵盖文件:prefix_prefill.py, chunked_prefill_paged_decode.py, merge_attn_states.py, paged_attn.py

17.1 Prefix Prefill算子(prefix_prefill.py, 864行)

来源 : 适配自 LightLLM 的 context_attention_fwd
职责: 带有 prefix(KV cache 中已有 context)的 prefill 注意力计算

17.1.1 全局参数
python 复制代码
BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
NUM_WARPS = 4 if current_platform.is_rocm() else 8
IS_TURING = current_platform.get_device_capability() == (7, 5)
float8_info = torch.finfo(current_platform.fp8_dtype())
  • Ampere+ (SM80+): BLOCK=128,更多的 shared memory
  • Turing (SM75): 使用 IEEE 浮点(无 tensor core FP32 支持)
17.1.2 _fwd_kernel --- 主Prefill核

核心架构:两阶段注意力计算

阶段1 --- Context注意力(Q vs KV cache中的prefix):

python 复制代码
for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, loop_unroll_factor=num_unroll_cache):
    token_indices = start_n + offs_bs_n
    bn_logical_indices = token_indices // PHYSICAL_BLOCK_SIZE
    bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + bn_logical_indices * stride_b_loc_s)
    internal_offsets = token_indices % PHYSICAL_BLOCK_SIZE
  • 非标准 block_size 支持

    • 正常模型:block_size = 16/32/64/128/256(2的幂)
    • Qwen3-Next-80B:block_size = 544(非2的幂!)
    • 当 block_size 不是 2 的幂时,一个 32-token 的 tile 可能跨越两个物理块
    • 解决方案:每个 token 独立计算逻辑块索引 token_indices // PHYSICAL_BLOCK_SIZE
  • K/V 地址计算(5D K cache):

python 复制代码
off_k = (
    bn[None, :] * stride_k_cache_bs           # 物理块偏移
    + cur_kv_head * stride_k_cache_h           # head偏移
    + (offs_d[:, None] // x) * stride_k_cache_d  # dim//x偏移
    + internal_offsets[None, :] * stride_k_cache_bl  # block内位置偏移
    + (offs_d[:, None] % x) * stride_k_cache_x    # x内偏移
)
  • Key cache 5D 布局:[num_blocks, num_kv_heads, head_size//x, block_size, x]
  • x = 16 // element_size:将 head_size 维度拆分为 (head_size//x, x) 以优化内存访问

阶段2 --- Query自注意力(Q vs 新K/V,带因果掩码):

python 复制代码
for start_n in tl.range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N,
                         loop_unroll_factor=num_unroll_request):
    qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
  • 因果掩码:Q 只能看到自身及之前的 token
  • block_mask * (start_m + 1) * BLOCK_M:当当前 BLOCK_M 超出 query 范围时,跳过

FP8 KV 反量化

python 复制代码
if k_load.dtype.is_fp8():
    k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype)
else:
    k = k_load

Sliding Window

python 复制代码
if SLIDING_WINDOW > 0:
    qk = tl.where(
        (cur_batch_ctx_len + offs_m[:, None]) - (start_n + offs_bs_n[None, :]) < SLIDING_WINDOW,
        qk, float("-inf"),
    )
  • 绝对位置计算:query 位置 = ctx_len + query_pos,key 位置 = start_n + key_pos
  • 只允许 query 看到 SLIDING_WINDOW 范围内的 key

FP8 输出

python 复制代码
if USE_FP8:
    acc = acc * tl.load(out_scale_inv)  # scale → FP8范围
    acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
17.1.3 _fwd_kernel_alibi --- ALiBi位置编码版本

ALiBi(Attention with Linear Biases)

python 复制代码
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - alibi_start_q[:, None]) * alibi_slope
alibi = tl.where((alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, float("-inf"))
qk += alibi
  • ALiBi 将位置信息编码为注意力偏置,而非 RoPE 那样修改 Q/K
  • alibi_slope:每个 head 的斜率(预计算,通常为几何序列)
  • 偏置公式:bias = slope * (key_pos - query_pos),且只允许 bias ≤ 0(因果性)
17.1.4 context_attention_fwd --- Python入口

非标准 block_size 处理

python 复制代码
is_pow2 = real_block_size > 0 and (real_block_size & (real_block_size - 1) == 0)
if is_pow2:
    BLOCK_M = 128
    BLOCK_N = 64
else:
    BLOCK_M = 32
    BLOCK_N = 32
TRITON_BLOCK_SIZE = 32  # 始终使用32作为内部tile大小
  • 2的幂 block_size:可以使用较大的 BLOCK_M/N,因为一个 tile 不会跨块
  • 非2的幂(如544):必须使用小的 BLOCK_M/N=32,配合跨块索引逻辑

Block Table 指针归一化

python 复制代码
if is_block_table_ptr:
    kv_element_size = k_cache.element_size()
    block_byte_stride = k_cache.stride(0) * kv_element_size
    base_addr = k_cache.data_ptr()
    processed_b_loc = torch.where(
        mask, (b_loc - base_addr) // block_byte_stride, b_loc
    ).to(torch.int32)
  • 某些后端(如 FlashAttention)的 block_table 存储的是字节指针而非块索引
  • 转换为块索引:(ptr - base_addr) / block_byte_stride

17.2 Chunked Prefill + Paged Decode(chunked_prefill_paged_decode.py, 467行)

来源 : IBM Research 贡献
职责: 混合 batch 中同时处理 prefill 和 decode 请求

17.2.1 kernel_paged_attention_2d --- Paged Decode核

Grid : (num_seqs, num_kv_heads) --- 每个program处理一个 (序列, kv_head) 对

Query 加载(GQA展开):

python 复制代码
query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange(0, num_queries_per_kv_padded)
Q = tl.load(query_ptr + query_offset + tl.arange(0, HEAD_SIZE_PADDED)[None, :],
            mask=dim_mask[None, :] & head_mask[:, None], other=0.0)
  • 一次加载一个 kv_group 的所有 q_head(num_queries_per_kv_padded 个)
  • padded 确保是 2 的幂,为 tl.dot 对齐

Block遍历(非标准block_size支持):

python 复制代码
for j in range(0, num_blocks):
    abs_token_idx = start_n + offs_n
    l_block_idx = abs_token_idx // PHYSICAL_BLOCK_SIZE  # 逻辑块索引
    p_block_idx = tl.load(block_tables_ptr + block_table_offset + l_block_idx)
    internal_offsets = abs_token_idx % PHYSICAL_BLOCK_SIZE
  • 与 prefix_prefill 相同的跨块索引逻辑

Decode token 过滤

python 复制代码
if filter_by_query_len:
    cur_batch_query_len = tl.load(query_start_len_ptr + seq_idx + 1) - tl.load(query_start_len_ptr + seq_idx)
    if cur_batch_query_len > 1:
        return  # 跳过prefill token(它们由prefix_prefill核处理)
17.2.2 chunked_prefill_paged_decode --- 混合调度入口

两阶段处理

python 复制代码
if max_query_len > 1:
    context_attention_fwd(...)  # Prefill tokens: 使用prefix_prefill核

kernel_paged_attention_2d[...]  # Decode tokens: 使用paged decode核
  1. 先处理 prefill token(使用 prefix_prefill.context_attention_fwd,skip_decode=True)
  2. 再处理 decode token(使用 kernel_paged_attention_2d,filter_by_query_len=True)

ROCm 自定义分页注意力回退

python 复制代码
if use_custom:
    ops.paged_attention_rocm(...)  # ROCm自定义CUDA核
else:
    kernel_paged_attention_2d[...]  # Triton回退

17.3 注意力状态合并(merge_attn_states.py, 103行)

设计意图:提供 CPU/GPU 双路径的注意力状态合并,自动选择最优实现。

python 复制代码
def merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse, ...):
    if current_platform.is_cuda() and supported_dtypes(prefix_output) and supported_headdim(prefix_output):
        from vllm._custom_ops import merge_attn_states  # CUDA核
        return merge_attn_states(...)
    else:
        from vllm.v1.attention.ops.triton_merge_attn_states import merge_attn_states  # Triton核
        return merge_attn_states(...)

CUDA 核条件

  • supported_dtypes:float32/half/bfloat16
  • supported_headdim:head_dim 必须是 4(float32)或 8(half/bf16)的倍数(128b 对齐)

17.4 分页注意力入口(paged_attn.py, 51行)

PagedAttention:传统分页注意力的入口点。

split_kv_cache:将联合 KV cache 拆分为 K cache 和 V cache

python 复制代码
@staticmethod
def split_kv_cache(kv_cache, num_kv_heads, head_size):
    x = 16 // kv_cache.element_size()  # 16B对齐因子
    key_cache = kv_cache[0].view(num_blocks, num_kv_heads, head_size // x, -1, x)
    value_cache = kv_cache[1].view(num_blocks, num_kv_heads, head_size, -1)
    return key_cache, value_cache

write_to_paged_cache:调用 CUDA 核写入 KV cache

python 复制代码
@staticmethod
def write_to_paged_cache(key, value, key_cache, value_cache, slot_mapping, 
                         kv_cache_dtype, k_scale, v_scale):
    ops.reshape_and_cache(key, value, key_cache, value_cache, 
                          slot_mapping.flatten(), kv_cache_dtype, k_scale, v_scale)
  • 调用 vLLM 自定义 CUDA 算子 ops.reshape_and_cache

第十八章:ViT/DCP/FlashMLA/特殊算子

涵盖文件:flashmla.py, vit_attn_wrappers.py, dcp_alltoall.py, rocm_aiter_mla_sparse.py, xpu_mla_sparse.py, triton_turboquant_store.py, triton_turboquant_decode.py

18.1 FlashMLA算子(flashmla.py, 153行)

来源 : 适配自 DeepSeek-AI 的 FlashMLA
职责: DeepSeek-V2/V3 MLA 注意力的专用后端

18.1.1 可用性检测
python 复制代码
def _is_flashmla_available() -> tuple[bool, str | None]:
    if not _flashmla_C_AVAILABLE:
        return False, "vllm._flashmla_C is not available..."
    if not _flashmla_extension_C_AVAILABLE:
        return False, "vllm._flashmla_extension_C is not available..."
    return True, None
  • 双重检测:核心库 _flashmla_C 和扩展库 _flashmla_extension_C
  • 编译失败时优雅降级
python 复制代码
def is_flashmla_dense_supported():
    if not current_platform.is_device_capability_family(90):
        return False, "FlashMLA Dense is only supported on Hopper devices."
    return True, None

def is_flashmla_sparse_supported():
    if not (is_device_capability_family(90) or is_device_capability_family(100)):
        return False, "FlashMLA Sparse is only supported on Hopper and Blackwell devices."
    return True, None
  • Dense:仅 Hopper(SM90),因为 FlashMLA Dense 使用了 Hopper 专用的 WMMA 指令
  • Sparse:Hopper + Blackwell(SM100),因为稀疏注意力需要 2:4 稀疏矩阵支持
18.1.2 FP8 MLA Decode
python 复制代码
def flash_mla_with_kvcache_fp8(q, k_cache, block_table, cache_seqlens,
                                head_dim_v, tile_scheduler_metadata, num_splits, ...):
    out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
        q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale, causal,
        tile_scheduler_metadata, num_splits, descale_q, descale_k)
    return out, softmax_lse
  • 专用 FP8 MLA decode 核:直接在 FP8 精度下计算 Q·K 点积
  • tile_scheduler_metadata / num_splits:FlashMLA 的自定义 tile 调度器输出
  • descale_q / descale_k:FP8 反量化 scale

18.2 ViT注意力封装(vit_attn_wrappers.py, 361行)

职责 : 为 Vision Transformer 提供多种注意力后端的统一封装,兼容 torch.compile

18.2.1 设计动机

ViT 注意力与 LLM 注意力有本质区别:

  • 双向注意力(非因果)
  • 固定序列长度(图像 patch 数量固定)
  • 不使用 KV cache(每次推理都是完整 forward)

torch.compile 不支持 .item() 等动态操作(如 FlashAttention 的 max_seqlen 参数),所以需要封装。

18.2.2 四种后端封装

1. FlashAttention 封装

python 复制代码
def flash_attn_maxseqlen_wrapper(q, k, v, batch_size, is_rocm_aiter, fa_version, ...):
    q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
    output = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens, ...)
    context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
  • einops.rearrange:批次维度展平/恢复
  • cu_seqlens:构建累积序列长度数组(ViT 中每个序列等长)

2. Triton 封装

python 复制代码
def triton_attn_wrapper(q, k, v, batch_size, scale, cu_seqlens, max_seqlen):
    q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
    context_attention_fwd(q, k, v, output, is_causal=False, ...)
  • 使用 triton_prefill_attention.context_attention_fwd
  • is_causal=False:ViT 使用双向注意力

3. PyTorch SDPA 封装

python 复制代码
def torch_sdpa_wrapper(q, k, v, scale, cu_seqlens, enable_gqa):
    if cu_seqlens is None:
        return apply_sdpa(q, k, v, scale, enable_gqa)
    # 有cu_seqlens时,按序列拆分分别计算
    for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
        output_i = apply_sdpa(q_i, k_i, v_i, scale, enable_gqa)
  • 回退方案:使用 PyTorch 原生 F.scaled_dot_product_attention
  • 支持变长序列(按 cu_seqlens 拆分)

4. FlashInfer 封装

python 复制代码
def flashinfer_wrapper(q, k, v, scale, workspace_buffer, cu_seqlens, ...):
    output, _ = cudnn_batch_prefill_with_kv_cache(
        q, k, v, scale, workspace_buffer, causal=False, ...)
  • 使用 cuDNN 后端的 FlashInfer
  • cu_seqlens 拆分为 Q/K/O 和 V 两套偏移(cuDNN 要求)
18.2.3 torch.compile 兼容性

所有封装都通过 direct_register_custom_op 注册为自定义算子:

python 复制代码
direct_register_custom_op(
    op_name="flash_attn_maxseqlen_wrapper",
    op_func=flash_attn_maxseqlen_wrapper,
    fake_impl=flash_attn_maxseqlen_wrapper_fake,
)
  • fake_impl:提供 torch.compile 的 meta tensor 推断
  • 避免 torch.compile 试图追踪 Python 控制流和 .item() 调用

18.3 DCP AlltoAll通信(dcp_alltoall.py, 363行)

论文参考 : https://arxiv.org/abs/2507.07120
职责: Decode Context Parallelism 的 All-to-All 通信后端,替代 AllGather+ReduceScatter

18.3.1 设计动机

传统 CP 通信流程(AG+RS):

复制代码
1. AllGather Q → 每个rank获得完整Q
2. AllGather K metadata
3. 本地注意力计算
4. ReduceScatter output → 每个rank获得部分head的输出

= 3 次 NCCL 集合通信

A2A 通信流程:

复制代码
1. 本地注意力计算(每个rank拥有部分head + 本地KV分片)
2. All-to-All output → 交换各rank的部分注意力输出
3. All-to-All LSE → 交换各rank的LSE
4. 本地 LSE-weighted combine → 最终输出

= 2 次 NCCL 集合通信

优势:减少一次 NCCL 调用,降低延迟。长序列 decode 时 NCCL 延迟是 step 时间的显著组成部分。

18.3.2 _lse_weighted_combine --- CPU参考实现
python 复制代码
def _lse_weighted_combine(outputs, lses, return_lse, is_lse_base_on_e):
    lse_max, _ = lses.max(dim=0)  # [B, H]
    weights = torch.exp(lses - lse_max.unsqueeze(0))  # [N, B, H]
    weight_sum = weights.sum(dim=0, keepdim=True)
    weights = weights / weight_sum.clamp(min=1e-10)
    result = (outputs * weights.unsqueeze(-1)).sum(dim=0)  # [B, H, D]
  • 纯 PyTorch 实现,用于测试和验证
  • 标准的 LSE-weighted 组合:output = Σ_n (output_n * exp(lse_n - lse_max)) / Σ_n exp(lse_n - lse_max)
18.3.3 _dcp_lse_combine_kernel --- Triton GPU核

三遍扫描算法

  1. 第一遍:找所有 rank 的 LSE 最大值(数值稳定性)
  2. 第二遍 :计算 Σ exp(lse - max) 和全局 LSE
  3. 第三遍:加权组合所有 rank 的部分输出
python 复制代码
for n in tl.static_range(N):
    lse_val = tl.load(recv_lse_ptr + n * rl_stride_N + base_lse_offset)
    lse_val = tl.where((lse_val != lse_val) | (lse_val == float("inf")), -float("inf"), lse_val)
    lse_max = tl.maximum(lse_max, lse_val)
  • tl.static_range(N):编译时展开循环(N 是 constexpr)
18.3.4 dcp_a2a_lse_reduce --- 完整A2A流程
python 复制代码
def dcp_a2a_lse_reduce(cp_attn_out, cp_attn_lse, cp_group, ...):
    # Reshape for All-to-All: [B, H, D] → [N, B, H/N, D]
    send_output = local_output.view(B, world_size, H_per_rank, D).permute(1, 0, 2, 3)
    send_lse = local_lse.view(B, world_size, H_per_rank).permute(1, 0, 2)

    # Async A2A overlap
    work_output = dist.all_to_all_single(recv_output.view(-1), send_output.view(-1), 
                                          group=cp_group.device_group, async_op=True)
    work_lse = dist.all_to_all_single(recv_lse.view(-1), send_lse.view(-1),
                                       group=cp_group.device_group, async_op=True)
    work_output.wait()
    work_lse.wait()

    # Local LSE-weighted combination
    return dcp_lse_combine_triton(recv_output, recv_lse, ...)
  • Reshape 逻辑 :将 head 维度拆分为 N 份,每个 rank 的数据标记为发送给对应 rank
  • 异步 A2A:output 和 LSE 的 A2A 同时启动,重叠通信
  • 本地组合:无需额外通信,Triton 核完成

18.4 ROCm Aiter MLA稀疏算子(rocm_aiter_mla_sparse.py, 655行)

职责 : DeepSeek-V3.2 的 MLA 稀疏注意力在 AMD ROCm 平台上的实现
核心: Indexer(token选择器) + FP8 MQA Logits + TopK 筛选

18.4.1 整体架构

DeepSeek-V3.2 的稀疏 MLA 注意力流程:

  1. Indexer: 用 FP8 MQA logits 计算每个 query 应该关注哪些 KV token
  2. TopK: 选择 top-k 个最相关的 KV token
  3. Sparse Attention: 只在选中的 KV token 上计算完整注意力
18.4.2 _indexer_k_quant_and_cache_kernel --- Indexer K量化写入
python 复制代码
@triton.jit
def _indexer_k_quant_and_cache_kernel(k_ptr, kv_cache_ptr, kv_cache_scale_ptr, 
                                       slot_mapping_ptr, ...):
    val = tl.load(src_ptr + offset)
    amax = tl.max(val.abs(), axis=-1).to(tl.float32)
    scale = tl.maximum(1e-4, amax) / 448.0  # FP8 E4M3 max = 448
    if USE_UE8M0:
        scale = tl.exp2(tl.ceil(tl.log2(scale)))  # 对齐到2的幂
    fp8_val = (val.to(tl.float32) / scale).to(kv_cache_ptr.type.element_ty)
  • FP8 量化:per-token absmax 量化,scale = absmax / 448
  • UE8M0 格式 :scale 对齐到 2 的幂(exp2(ceil(log2(scale)))),某些硬件优化需要
18.4.3 _cp_gather_indexer_quant_cache_kernel --- 从KV cache读取量化K
python 复制代码
@triton.jit
def _cp_gather_indexer_quant_cache_kernel(kv_cache_ptr, kv_cache_scale_ptr,
                                            k_fp8_ptr, k_scale_ptr, ...):
    # 通过block_table定位KV cache中的K
    block_id = tl.load(block_table_ptr + block_table_offset)
    # 读取FP8 K和scale
    val = tl.load(src_cache_ptr + tiled_src_offset)
    scale_val = tl.load(src_scale_ptr)
    # 存储到连续缓冲区
    tl.store(dst_k_ptr + offset, val)
    tl.store(k_scale_ptr + tid, scale_val)
  • 从分页 KV cache 中读取已量化的 K,组装为连续缓冲区供 MQA logits 计算
18.4.4 rocm_aiter_sparse_attn_indexer --- 完整Indexer流程

Prefill 路径

python 复制代码
if has_prefill:
    for chunk in prefill_metadata.chunks:
        # 1. 从KV cache读取量化K
        ops.cp_gather_indexer_k_quant_cache(kv_cache, k_fp8, k_scale, ...)
        # 2. FP8 MQA logits
        logits = rocm_fp8_mqa_logits(q_fp8[chunk.token_start:chunk.token_end], ...)
        # 3. TopK筛选
        torch.ops._C.top_k_per_row_prefill(logits, ...)

Decode 路径

python 复制代码
if has_decode:
    # 1. Pack Q为batch格式
    padded_q_fp8_decode_tokens = pack_seq_triton(q_fp8[:num_decode_tokens], decode_lens)
    # 2. FP8 Paged MQA logits
    logits = rocm_fp8_paged_mqa_logits(padded_q_fp8_decode_tokens, kv_cache, ...)
    # 3. TopK筛选
    torch.ops._C.top_k_per_row_decode(logits, ...)
    # 4. Unpack结果
    topk_indices = unpack_seq_triton(topk_indices.reshape(...), decode_lens)
  • 使用 pack_seq_triton / unpack_seq_triton(来自 common.py)处理不等长 decode 序列

18.5 XPU MLA稀疏算子(xpu_mla_sparse.py, 265行)

职责 : Intel XPU(GPU Max/Flex)上的 MLA 稀疏注意力实现
设计: BF16 精度的 Triton 实现,无需 FP8 量化

核函数 _bf16_mla_sparse_kernel

与 ROCm 版本的关键区别

  • 直接使用 BF16 精度(无 FP8 量化步骤)
  • 使用 exp2(base-2 指数)替代 exp(GPU 上 exp2 更快)
  • sm_scale *= LOG2E 将 softmax scale 转换到 base-2

Index-based 稀疏注意力

python 复制代码
for start_indice in range(0, index_topk, BLOCK_N):
    indices = tl.load(indices_ptr + cur_q * stride_indices_token + cur_kv_head_id * stride_indices_head + offs_indice)
    # 根据topk indices加载K和V
    k = tl.load(k_buffer + indices[None, :] * stride_k_token + ...)
    v = tl.load(v_buffer + indices[:, None] * stride_v_token + ...)
  • 不遍历完整 KV 序列,只加载 topk indices 指向的 token
  • 大幅减少计算量和内存访问

DPE(解耦位置编码)支持

python 复制代码
if BLOCK_DPE > 0:
    kpe = tl.load(k_buffer + offs_kpe, ...)
    qk += tl.dot(qpe, kpe.to(qpe.dtype))

LSE 输出格式

python 复制代码
max_logits = e_max * LOGE2  # base-2 → base-e 转换
lse = max_logits + tl.log2(e_sum) * LOGE2  # log2(sum) → log(sum)

18.6 TurboQuant存储算子(triton_turboquant_store.py, 447行)

职责: TurboQuant(TQ)KV cache 的写入端,支持 FP8 key + 均匀量化 value,以及 MSE 最优量化 key

18.6.1 TurboQuant 概述

TurboQuant 是一种混合 KV cache 量化方案:

  • Key: FP8 量化(per-token absmax)或 MSE 最优量化(vector quantization)
  • Value: 均匀量化(3-bit 或 4-bit,per-head min-max 量化)

相比标准 FP8 量化,TurboQuant 的 value 使用更低精度(3/4 bit),进一步压缩 KV cache 占用。

18.6.2 _store_quantized_value --- Value均匀量化子程序

4-bit 量化

python 复制代码
val_min = tl.min(tl.where(d_mask, val_vec, float("inf")), axis=0)
val_max = tl.max(tl.where(d_mask, val_vec, -float("inf")), axis=0)
v_scale = (val_max - val_min) / 15.0  # 4-bit: 0~15
q_all = tl.minimum(tl.maximum(((val_vec - val_min) / v_scale + 0.5).to(tl.int32), 0), 15)
# 打包:两个4-bit值 → 一个字节
q_pairs = tl.reshape(q_all, [BLOCK_D // 2, 2])
packed_val = tl.sum((q_pairs & 0xF) << shifts_4[None, :], axis=1).to(tl.uint8)

3-bit 量化(更复杂):

python 复制代码
v_scale = (val_max - val_min) / 7.0  # 3-bit: 0~7
q_vals = tl.minimum(tl.maximum(((val_vec - val_min) / v_scale + 0.5).to(tl.int32), 0), 7)
# 打包:8个3-bit值 → 3个字节(24 bit)
q_grp = tl.reshape(q_vals, [BLOCK_GRP, 8])
packed_24 = tl.sum(q_grp << shifts_3bit[None, :], axis=1)  # 24-bit打包
b0 = (packed_24 & 0xFF).to(tl.uint8)
b1 = ((packed_24 >> 8) & 0xFF).to(tl.uint8)
b2 = ((packed_24 >> 16) & 0xFF).to(tl.uint8)
  • 3-bit 打包是非对齐的(8×3=24 bit = 3 bytes),比 4-bit(2×4=8 bit = 1 byte)复杂

Scale/Zero 存储(float16):

python 复制代码
sc_f16 = v_scale.to(tl.float16)
sc_u16 = sc_f16.to(tl.uint16, bitcast=True)  # 浮点→位模式
tl.store(KV_cache_ptr + slot_base + sc_offset, (sc_u16 & 0xFF).to(tl.uint8))
tl.store(KV_cache_ptr + slot_base + sc_offset + 1, ((sc_u16 >> 8) & 0xFF).to(tl.uint8))
  • float16 的 2 字节存储到 uint8 缓冲区,需要拆分为低/高字节
18.6.3 _tq_fused_store_fp8 --- FP8 Key + Value量化融合核
python 复制代码
@triton.jit
def _tq_fused_store_fp8(Key_ptr, Value_ptr, KV_cache_ptr, Slot_mapping_ptr, ...):
    # FP8 KEY: 在核内直接转换
    k_vals = tl.load(Key_ptr + base + d_offs, mask=d_mask, other=0.0)
    k_fp8 = k_vals.to(tl.float8e4b15) if FP8_E4B15 else k_vals.to(tl.float8e4nv)
    k_bytes = k_fp8.to(tl.uint8, bitcast=True)
    tl.store(KV_cache_ptr + slot_base + d_offs, k_bytes, mask=d_mask)
    
    # VALUE: 均匀量化
    _store_quantized_value(Value_ptr, KV_cache_ptr, ...)
  • FP8 格式选择
    • e4b15(Ampere/Ada, SM < 8.9):自定义 FP8 格式
    • e4nv(Hopper+, SM ≥ 8.9):NVIDIA 标准 FP8 格式
18.6.4 _tq_fused_store_mse --- MSE量化融合核

MSE(最小均方误差)量化流程

  1. 旋转投影(核外 cuBLAS GEMM,更快):
python 复制代码
k_flat = key.float().reshape(NH, D)
norms = k_flat.norm(dim=1, keepdim=True)
x_hat = k_flat / (norms + 1e-8)  # 归一化
y = x_hat @ PiT  # 旋转投影到量化空间
  1. 二分搜索桶化(核内):
python 复制代码
lo = tl.zeros([BLOCK_D], dtype=tl.int32)
hi = tl.full([BLOCK_D], N_CENTROIDS - 1, dtype=tl.int32)
for _ in range(MSE_BITS):  # log2(n_centroids) 次迭代
    mid = (lo + hi) >> 1
    mid_val = tl.load(Midpoints_ptr + safe_mid, ...)
    lo = tl.where(y_vec >= mid_val, mid + 1, lo)
    hi = tl.where(y_vec >= mid_val, hi, mid)
idx = tl.minimum(lo, N_CENTROIDS - 1)
  • 对每个维度独立查找最近的质心索引
  • 二分搜索复杂度 O(log N_CENTROIDS) vs 线性搜索 O(N_CENTROIDS)
  1. 索引打包(3-bit 或 4-bit):
  • 与 value 打包逻辑相同
  1. 范数存储(fp16, 2 bytes)

18.7 TurboQuant解码算子(triton_turboquant_decode.py, 623行)

职责: TurboQuant KV cache 的读取+注意力计算端

18.7.1 _tq_decode_stage1 --- TQ解码Stage1核

两种 Key 读取路径

FP8 Key 路径

python 复制代码
if KEY_FP8:
    k_raw = tl.load(KV_cache_ptr + k_addrs, ...)
    k_float = k_raw.to(tl.float8e4nv, bitcast=True).to(tl.float32)
    scores = tl.sum(q_rot[None, :] * k_float, axis=1) * ATTN_SCALE

MSE Key 路径

python 复制代码
else:
    # 1. 从打包位中解包MSE索引
    mse_idx = (raw16 >> mse_bit_shift[None, :]) & mse_mask
    # 2. 查表得到质心值
    c_vals = tl.load(Centroids_ptr + mse_idx, ...)
    # 3. 加载向量范数
    vec_norms = (n_lo | (n_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
    # 4. 重构K ≈ norm * centroid
    scores = vec_norms * term1 * ATTN_SCALE

Value 解量化

3-bit Value 解包

python 复制代码
v_idx = ((raw16 >> val_bit_shift[None, :]) & 0x7).to(tl.float32)
v_scales = (sc_lo | (sc_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
v_zeros = (zr_lo | (zr_hi << 8)).to(tl.float16, bitcast=True).to(tl.float32)
values = v_idx * v_scales[:, None] + v_zeros[:, None]  # 仿射反量化

4-bit Value 解包

python 复制代码
v_idx = ((val_raw >> vb_shift[None, :]) & 0xF).to(tl.float32)
values = v_idx * v_scales[:, None] + v_zeros[:, None]

Online Softmax + 值累积

python 复制代码
n_e_max = tl.maximum(tl.max(scores, 0), m_prev)
re_scale = tl.exp(m_prev - n_e_max)
p = tl.exp(scores - n_e_max)
acc = acc * re_scale + tl.sum(p[:, None] * values, 0)
l_prev = l_prev * re_scale + tl.sum(p, 0)
m_prev = n_e_max
18.7.2 _tq_full_dequant_kv --- 全量KV解量化核

设计意图:将 TQ KV cache 完整解量化为 fp16 K 和 V,用于 prefill 阶段(需要完整 K/V 做矩阵注意力)。

18.7.3 triton_turboquant_decode_attention --- 完整解码入口

三步流程

  1. Query 旋转投影(MSE 路径需要):
python 复制代码
if key_fp8:
    q_rot = query.contiguous()  # FP8: 直接用原始query
else:
    q_rot = (query.float() @ PiT).contiguous()  # MSE: 旋转投影
  1. Stage 1: 分段 KV 注意力
  2. Stage 2 : 跨段归约(复用 triton_decode_attention._fwd_kernel_stage2

缓冲区复用

python 复制代码
if mid_o_buf is not None and mid_o_buf.shape[0] >= B:
    mid_o = mid_o_buf[:B, :Hq, :NUM_KV_SPLITS, :]
else:
    mid_o = torch.empty(...)
    if buf_holder is not None:
        buf_holder._tq_mid_o_buf = mid_o
  • 预分配缓冲区,避免每次调用都分配新内存
  • 对 CUDA Graph 兼容至关重要(CUDA Graph 要求固定的内存地址)

总结:vLLM v1 Attention Ops架构全景

算子分层关系

复制代码
┌─────────────────────────────────────────────────────────┐
│                    上层调度(Backends)                    │
│  FlashAttention / Triton / FlashInfer / ROCm / XPU      │
├─────────────────────────────────────────────────────────┤
│                   通用工具层(common.py)                  │
│  CP修正 / Pack-Unpack / Block Table归一化               │
├──────────────────────┬──────────────────────────────────┤
│    KV Cache 写入层    │         注意力计算层               │
│  reshape_and_cache   │  unified_attention (2D/3D)       │
│  per-token-head量化  │  prefix_prefill                  │
│  TurboQuant store    │  decode_attention (2-stage)      │
│  Indexer quant+cache │  chunked_prefill_paged_decode    │
├──────────────────────┴──────────────────────────────────┤
│                   注意力合并层                            │
│  merge_attn_states (Triton/CUDA)                        │
├─────────────────────────────────────────────────────────┤
│                  特殊算子层                               │
│  FlashMLA (Dense/Sparse) │ ViT Wrappers                 │
│  DCP AlltoAll            │ ROCm Aiter MLA Sparse        │
│  XPU MLA Sparse          │ TurboQuant Decode            │
└─────────────────────────────────────────────────────────┘

核心设计模式

  1. 两阶段Split-KV:decode_attention 和 turboquant_decode 都采用 Split-KV + Reduce 的两阶段架构,解决长序列单 program 串行瓶颈

  2. 2D/3D统一调度:unified_attention 根据是否有 prefill 和序列数选择 2D(单段 softmax)或 3D(分段 softmax)模式

  3. 多层掩码叠加:causal ∧ sliding_window ∨ mm_prefix,与 FlexAttention 语义一致

  4. 多级KV量化:从无量化 → FP8 per-tensor → FP8/int8 per-token-head → MSE 最优量化 + 均匀量化 value,精度和压缩率递进

  5. 平台自适应 :通过 current_platform 运行时检测,自动选择 CUDA/ROCm/XPU 路径和参数

  6. 非标准block_size兼容:Qwen3-Next 的 544 block_size 通过逐 token 计算逻辑块索引实现兼容

  7. GQA/MQA/MLA统一支持 :通过 num_queries_per_kvIS_MLA 编译时常量,同一核函数支持多种注意力模式

  8. 编译时优化 :大量使用 tl.constexpr 将配置参数编译为常量,消除运行时分支和冗余计算

相关推荐
网络工程小王2 小时前
【hermes多智能体协作】个人学习笔记
笔记·学习·ai·智能体·hermes
俊哥V2 小时前
每日 AI 研究简报 · 2026-04-22
人工智能·ai
yyk的萌2 小时前
Claude Code 命令大全
linux·运维·服务器·ai·claude code
zs宝来了2 小时前
PyTorch DDP:分布式训练与梯度同步
机器学习·ai·基础设施
我母鸡啊2 小时前
软考架构师故事系列-数据库系统
后端·架构
张忠琳2 小时前
【vllm】(五)vLLM v1 Attention — 模块超深度分析之二
人工智能·深度学习·ai·架构·vllm
九章智算云3 小时前
一份CLAUDE.md,为何能让GitHub榜首项目狂揽6万星?
人工智能·ai·大模型·agent·ai工具·claude code·vibe-coding
Yunzenn3 小时前
# 零基础复现Claude Code(二):地基篇——让模型开口说话
人工智能·架构
heimeiyingwang3 小时前
【架构实战】容器安全最佳实践
安全·架构