A Detailed Analysis of vLLM's Attention Backend Implementation

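📍 Scope

This document analyzes vLLM's attention backend architecture, following the full implementation chain from the high-level selector down to the low-level operators. Through a modular backend design, vLLM supports multiple hardware platforms (NVIDIA CUDA, AMD ROCm, CPU) and multiple attention variants (standard attention, MLA, Mamba SSM, and others) for high-performance inference serving.

Architecture Overview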

[Diagram: architecture overview. Nodes are grouped into the following layers.]

  • Selector layer: AttentionSelector (selector.py)
  • Standard attention backends: FlashAttention (flash_attn.py), FlashInfer (flashinfer.py), Triton Attention (triton_attn.py), ROCm Attention (rocm_attn.py), CPU Attention (cpu_attn.py)
  • MLA backend group: FlashInfer MLA (flashinfer_mla.py), FlashAttn MLA (flashattn_mla.py), Triton MLA (triton_mla.py), Cutlass MLA (cutlass_mla.py), FlashMLA (flashmla.py), ROCm AIter MLA (rocm_aiter_mla.py)
  • SSM / special attention: Mamba1 Attention (mamba1_attn.py), Mamba2 Attention (mamba2_attn.py), Linear Attention (linear_attn.py), Short Conv Attention (short_conv_attn.py), GDN Attention (gdn_attn.py)
  • Low-level operators (ops/): paged_attn.py (paged attention core), chunked_prefill_paged_decode.py (Chunked Prefill + Paged Decode), triton_prefill_attention.py (Triton prefill kernel), triton_decode_attention.py (Triton decode kernel), triton_unified_attention.py (unified Triton kernel), merge_attn_states.py (state merging), dcp_alltoall.py (DCP communication)


I. Attention Backend Selector

1.1 Core Entry Point: get_attn_backend()

File location: selector.py(file:///workspace/vllm/v1/attention/selector.py#L53-L103)

def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    use_mm_prefix: bool = False,
    use_per_head_quant_scales: bool = False,
    attn_type: str | None = None,
    num_heads: int | None = None,
) -> type[AttentionBackend]:

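The selector has three core responsibilities:

  1. Collect configuration: extract cache_config, attention_config, and related settings from VllmConfig
  2. Build the selector config: package all parameters into an AttentionSelectorConfig (a NamedTuple)
  3. Delegate to the platform layer: call current_platform.get_attn_backend_cls() to perform the actual backend selection

AttentionSelectorConfig structure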

File location: selector.py(file:///workspace/vllm/v1/attention/selector.py#L22-L50)

class AttentionSelectorConfig(NamedTuple):
    head_size: int                    # size of each attention head
    dtype: torch.dtype               # compute dtype
    kv_cache_dtype: CacheDType | None # KV cache dtype
    block_size: int | None            # paged block size
    use_mla: bool = False             # whether MLA is used
    has_sink: bool = False            # whether attention sinks are used
    use_sparse: bool = False          # whether sparse attention is used
    use_mm_prefix: bool = False       # whether a multi-modal prefix is used
    use_per_head_quant_scales: bool = False  # per-head quantization scales
    attn_type: str = AttentionType.DECODER   # attention type
    use_non_causal: bool = False      # non-causal attention
    use_batch_invariant: bool = False # batch-invariant mode

1.2 Caching

File location: selector.py(file:///workspace/vllm/v1/attention/selector.py#L106-L137)

@cache
def _cached_get_attn_backend(
    backend,
    attn_selector_config: AttentionSelectorConfig,
    num_heads: int | None = None,
) -> type[AttentionBackend]:
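
  • The @cache decorator memoizes the result, so an identical configuration is resolved only once
  • Calls the platform-specific get_attn_backend_cls() method to obtain the backend class
  • Automatically adjusts the KV cache layout (for example HND or NHD)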
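
The memoization described above works because a NamedTuple is hashable and can serve as a cache key. The following is a minimal sketch of that pattern only; SelectorConfig, resolve_backend, and the selection rule are hypothetical stand-ins, not vLLM's actual classes or logic.

```python
from functools import cache
from typing import NamedTuple


class SelectorConfig(NamedTuple):
    """Hypothetical, trimmed-down stand-in for AttentionSelectorConfig."""
    head_size: int
    dtype: str
    use_mla: bool = False


resolve_calls = 0  # counts how often the resolution body really runs


@cache
def resolve_backend(config: SelectorConfig) -> str:
    """Pretend backend resolution; the selection rule here is invented."""
    global resolve_calls
    resolve_calls += 1
    return "MLA_BACKEND" if config.use_mla else "STANDARD_BACKEND"


first = resolve_backend(SelectorConfig(128, "bfloat16"))
second = resolve_backend(SelectorConfig(128, "bfloat16"))  # served from the cache
third = resolve_backend(SelectorConfig(128, "bfloat16", use_mla=True))
```

Two equal configs compare and hash equal, so the second call never re-enters the function body; only a config that differs in some field triggers a new resolution.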

1.3 Mamba Backend Selector

File location: selector.py(file:///workspace/vllm/v1/attention/selector.py#L140-L168)

def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:

Mapping of Mamba types to backends (defined in registry.py(file:///workspace/vllm/v1/attention/backends/registry.py#L196-L203)):

mamba_type | Backend enum
mamba1 | MAMBA1
mamba2 | MAMBA2
short_conv | SHORT_CONV
linear_attention | LINEAR
gdn_attention | GDN_ATTN
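
The mapping table can be read as a plain dictionary lookup. The sketch below assumes nothing about how registry.py is written; MambaBackend, MAMBA_TYPE_TO_BACKEND, and lookup_mamba_backend are hypothetical names that only mirror the table's contents.

```python
from enum import Enum, auto


class MambaBackend(Enum):
    """Hypothetical enum mirroring the backend names in the table."""
    MAMBA1 = auto()
    MAMBA2 = auto()
    SHORT_CONV = auto()
    LINEAR = auto()
    GDN_ATTN = auto()


# mamba_type string -> backend enum member, copied from the table above
MAMBA_TYPE_TO_BACKEND = {
    "mamba1": MambaBackend.MAMBA1,
    "mamba2": MambaBackend.MAMBA2,
    "short_conv": MambaBackend.SHORT_CONV,
    "linear_attention": MambaBackend.LINEAR,
    "gdn_attention": MambaBackend.GDN_ATTN,
}


def lookup_mamba_backend(mamba_type: str) -> MambaBackend:
    """Return the backend for a mamba_type, failing loudly on unknown names."""
    try:
        return MAMBA_TYPE_TO_BACKEND[mamba_type]
    except KeyError:
        raise ValueError(f"unknown mamba_type: {mamba_type!r}") from None
```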

II. FlashAttention Backend

2.1 Backend Definition and Capability Declaration

File location: flash_attn.py(file:///workspace/vllm/v1/attention/backends/flash_attn.py#L69-L223)

class FlashAttentionBackend(AttentionBackend):
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto", "float16", "bfloat16",
    ]

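Key capabilities

Method | Returns | Notes
supports_compute_capability() | ≥ sm_80 | Requires an Ampere or newer GPU
supports_head_size() | divisible by 8 and ≤ 256 (≤ 512 for FA4) | Head-size constraint
supports_kv_cache_dtype() | auto / fp16 / bf16 / fp8 | FP8 requires specific support
supports_sink() | FA3 with sink support | Attention sink tokens
supports_non_causal() | True | Non-causal attention is supported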

2.2 KV Cache Layout

File location: flash_attn.py(file:///workspace/vllm/v1/attention/backends/flash_attn.py#L140-L171)

@staticmethod
def get_kv_cache_shape(...) -> tuple[int, ...]:
    return (2, num_blocks, block_size, num_kv_heads, head_size)

Layout: (kv_half, num_blocks, block_size, num_kv_heads, head_size)

Supported stride orders:

  • NHD: (0, 1, 2, 3, 4), the default layout
  • HND: (0, 1, 3, 2, 4), which swaps the block_size and num_kv_heads dimensions
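
To make the two stride orders concrete, the sketch below applies each permutation to a logical cache shape. The helper permute_shape and the example sizes are illustrative assumptions, not part of vLLM.

```python
def permute_shape(shape, stride_order):
    """Reorder the dimensions of `shape` according to `stride_order`."""
    return tuple(shape[i] for i in stride_order)


# (kv_half, num_blocks, block_size, num_kv_heads, head_size), example sizes only
logical_shape = (2, 1024, 16, 8, 128)

NHD = (0, 1, 2, 3, 4)  # default: tokens of a block before heads
HND = (0, 1, 3, 2, 4)  # heads before tokens of a block

nhd_shape = permute_shape(logical_shape, NHD)
hnd_shape = permute_shape(logical_shape, HND)
```

Under NHD the shape is unchanged; under HND the third and fourth dimensions trade places, so all tokens of one KV head within a block sit next to each other.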

2.3 Metadata Builder

File location: flash_attn.py(file:///workspace/vllm/v1/attention/backends/flash_attn.py#L279-L594)

FlashAttentionMetadata structure

@dataclass
class FlashAttentionMetadata:
    num_actual_tokens: int           # actual token count (excluding padding)
    max_query_len: int               # maximum query length
    query_start_loc: torch.Tensor    # query start offsets, shape [batch + 1]
    max_seq_len: int                 # maximum sequence length
    seq_lens: torch.Tensor           # per-sequence lengths
    block_table: torch.Tensor        # block (page) table
    slot_mapping: torch.Tensor       # slot mapping

    # Cascade attention support
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: torch.Tensor | None
    prefix_kv_lens: torch.Tensor | None
    suffix_kv_lens: torch.Tensor | None

    # DCP (Decode Context Parallelism)
    max_dcp_context_kv_len: int | None
    dcp_context_kv_lens: torch.Tensor | None

    # FA3 AOT Scheduling
    scheduler_metadata: torch.Tensor | None
    prefix_scheduler_metadata: torch.Tensor | None
    max_num_splits: int = 0

    causal: bool = True

CUDA Graph support

File location: flash_attn.py(file:///workspace/vllm/v1/attention/backends/flash_attn.py#L298-L302)

_cudagraph_support = (
    AttentionCGSupport.ALWAYS
    if get_flash_attn_version() == 3
    else AttentionCGSupport.UNIFORM_BATCH
)

  • FA3: supports every CUDA Graph scenario
  • FA2: supports uniform batches only (because of its special packed-GQA handling)

2.4 Core Forward Pass

File location: flash_attn.py(file:///workspace/vllm/v1/attention/backends/flash_attn.py#L677-L861)

Main forward flow

def forward(self, layer, query, key, value, kv_cache,
            attn_metadata, output, ...):
    # 1. Encoder attention is handled directly (no KV cache)
    if attn_type in (ENCODER_ONLY, ENCODER):
        return self._forward_encoder_attention(...)

    # 2. Unpack the KV cache and normalize strides
    key_cache, value_cache = kv_cache.unbind(0)
    fixed_k = canonicalize_singleton_dim_strides(key_cache)

    # 3. FP8 quantization handling
    if is_quantized_kv_cache(self.kv_cache_dtype):
        key_cache = key_cache.view(fp8_dtype)

    # 4. Regular attention / cascade attention / DCP branch
    if not attn_metadata.use_cascade:
        flash_attn_varlen_func(
            q=query[:num_actual_tokens],
            k=key_cache,
            v=value_cache,
            out=output[:num_actual_tokens],
            cu_seqlens_q=cu_seqlens_q,
            seqused_k=seqused_k,
            ...
        )
    else:
        cascade_attention(...)

Key features

  1. Cascade Attention (flash_attn.py#L1067-L1236(file:///workspace/vllm/v1/attention/backends/flash_attn.py#L1067-L1236)):

    • Enabled when common_prefix_len >= 256
    • Processes the shared prefix and the per-request suffixes separately to save memory bandwidth
    • Merges the two partial results with merge_attn_states()
  2. DCP, Decode Context Parallelism (flash_attn.py#L898-L995(file:///workspace/vllm/v1/attention/backends/flash_attn.py#L898-L995)):

    • The context KV is partitioned across multiple ranks
    • Queries are collected with an all-gather
    • Results are merged through an LSE reduce
  3. Encoder Attention (flash_attn.py#L997-L1064(file:///workspace/vllm/v1/attention/backends/flash_attn.py#L997-L1064)):

    • No KV cache is needed; attention is computed directly on Q/K/V
    • Bidirectional (non-causal) attention
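
Both the cascade merge and the DCP LSE reduce rest on the same identity: two softmax-attention results over disjoint key sets can be combined exactly if each carries its log-sum-exp (LSE). The pure-Python sketch below illustrates that identity for a single query; attend and merge_states are illustrative helpers, not the signature of vLLM's merge_attn_states().

```python
import math


def attend(scores, values):
    """Softmax attention for one query; returns (output, log-sum-exp)."""
    peak = max(scores)
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total
           for d in range(dim)]
    return out, peak + math.log(total)


def merge_states(out_a, lse_a, out_b, lse_b):
    """Combine two partial attention results using their LSE values."""
    peak = max(lse_a, lse_b)
    w_a = math.exp(lse_a - peak)
    w_b = math.exp(lse_b - peak)
    denom = w_a + w_b
    merged = [(w_a * a + w_b * b) / denom for a, b in zip(out_a, out_b)]
    return merged, peak + math.log(denom)


scores = [0.3, -1.2, 2.0, 0.7, 1.5]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [0.5, 0.5], [-1.0, 3.0]]

# Treat the first three keys as the shared prefix, the rest as the suffix.
prefix_out, prefix_lse = attend(scores[:3], values[:3])
suffix_out, suffix_lse = attend(scores[3:], values[3:])
merged_out, merged_lse = merge_states(prefix_out, prefix_lse,
                                      suffix_out, suffix_lse)
full_out, full_lse = attend(scores, values)
```

The merged output matches attention over all five keys up to floating-point rounding, which is why the prefix can be computed once and reused across requests.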

2.5 KV Cache Update

File location: flash_attn.py(file:///workspace/vllm/v1/attention/backends/flash_attn.py#L863-L896)

def do_kv_cache_update(self, layer, key, value, kv_cache, slot_mapping):
    reshape_and_cache_flash(
        key, value,
        key_cache, value_cache,
        slot_mapping,
        self.kv_cache_dtype,
        layer._k_scale, layer._v_scale,
    )

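The scatter write goes through reshape_and_cache_flash, a vLLM custom op that writes keys and values into the FlashAttention-style cache layout at the positions given by slot_mapping.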
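
A scatter write of this kind can be sketched in a few lines: each token's flat slot index decomposes into a block index and an offset inside that block. The function name scatter_to_paged_cache is hypothetical, and skipping negative slots as padding is an assumed convention rather than something the source states.

```python
def scatter_to_paged_cache(tokens, cache, slot_mapping, block_size):
    """Write each token into cache[block][offset] given its flat slot index."""
    for token, slot in zip(tokens, slot_mapping):
        if slot < 0:  # assumed convention: negative slot marks padding
            continue
        block_idx, offset = divmod(slot, block_size)
        cache[block_idx][offset] = token


block_size = 4
cache = [[None] * block_size for _ in range(3)]  # 3 blocks of 4 slots each

# Slot 5 -> block 1, offset 1; slot 6 -> block 1, offset 2; -1 is skipped.
scatter_to_paged_cache(["k0", "k1", "k2"], cache, [5, 6, -1], block_size)
```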


III. FlashInfer Backend

3.1 Backend Overview

File location: flashinfer.py(file:///workspace/vllm/v1/attention/backends/flashinfer.py#L327-L437)

class FlashInferBackend(AttentionBackend):
    supported_dtypes = [torch.float16, torch.bfloat16]
    supported_kv_cache_dtypes = [
        "auto", "float16", "bfloat16",
        "fp8", "fp8_e4m3", "fp8_e5m2",  # extended FP8 support
        "nvfp4",                          # NVFP4 quantization support
    ]

Distinctive strengths

  • Native support for NVFP4 (4-bit) quantization
  • Built-in integration of TRT-LLM kernels (Blackwell, SM100)
  • Efficient page-level attention operations

3.2 KV Cache Shape Differences

File location: flashinfer.py(file:///workspace/vllm/v1/attention/backends/flashinfer.py#L358-L369)

def get_kv_cache_shape(...):
    if cache_dtype_str == "nvfp4":
        last_dim = nvfp4_kv_cache_full_dim(head_size)  # packed dimension
        return (num_blocks, 2, block_size, num_kv_heads, last_dim)
    return (num_blocks, 2, block_size, num_kv_heads, head_size)

Note: FlashInfer's KV cache shape puts num_blocks in the first dimension, unlike FlashAttention's (2, num_blocks, ...).

3.3 Dual-Path Architecture: Native vs TRT-LLM

File location: flashinfer.py(file:///workspace/vllm/v1/attention/backends/flashinfer.py#L440-L537)

The FlashInfer backend maintains two parallel execution paths:

[Diagram: prefill and decode execution paths of the FlashInfer backend]

  • Prefill path: FIPrefill (native wrapper, BatchPrefillWithPagedKVCacheWrapper) or TRTLLMPrefill (trtllm_batch_context_with_kv_cache)
  • Decode path: FIDecode (native wrapper, BatchDecodeWithPagedKVCacheWrapper) or TRTLLMDecode (trtllm_batch_decode_with_kv_cache)

Metadata structure

@dataclass
class FlashInferMetadata:
    num_actual_tokens: int
    slot_mapping: torch.Tensor
    q_data_type: torch.dtype              # query dtype (may be FP8)
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    prefill: FIPrefill | TRTLLMPrefill | None
    decode: FIDecode | TRTLLMDecode | None

    use_cascade: bool
    cascade_wrapper: MultiLevelCascadeAttentionWrapper | None

3.4 TRT-LLM Routing Decision

File location: flashinfer.py(file:///workspace/vllm/v1/attention/backends/flashinfer.py#L915-L934)

# Prefill: decide whether to use TRT-LLM
prefill_use_trtllm = use_trtllm_attention(
    self.num_qo_heads, self.num_kv_heads,
    num_prefill_tokens, max_seq_len,
    self.dcp_world_size, self.cache_dtype,
    self.q_data_type, is_prefill=True,
    has_sinks=self.has_sinks,
)

# Decode: decide whether to use TRT-LLM (preferred when available)
decode_use_trtllm = (
    self.use_trtllm_decode_attention and self.dcp_world_size <= 1
)

Conditions for enabling TRT-LLM

  • Blackwell GPU (SM100)
  • NVFP4/FP8 quantization is supported
  • More efficient kernel implementations

3.5 NVFP4 Quantization Handling

File location: flashinfer.py(file:///workspace/vllm/v1/attention/backends/flashinfer.py#L1500-L1507)

if self.is_kvcache_nvfp4:
    nvfp4_kv_data, nvfp4_kv_block_scales = nvfp4_kv_cache_split_views(
        kv_cache_permute
    )

The NVFP4 format packs the data and its block scales into the same dimension, so the cache has to be split into separate views before use.

3.6 DCP Support

File location: flashinfer.py(file:///workspace/vllm/v1/attention/backends/flashinfer.py#L213-L324)

class BatchDCPPrefillWrapper:
    def __init__(self, workspace_buffer, dcp_a2a=False):
        self._context = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, ...)
        self._new_tokens = BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer, ...)

In DCP mode:

  • Context part: queries are gathered across ranks, then paged attention runs
  • New-token part: ragged KV attention
  • The results are merged through an LSE reduce

3.7 Fast Plan for CUDA Graph

File location: flashinfer.py(file:///workspace/vllm/v1/attention/backends/flashinfer.py#L1857-L1947)

def fast_plan_decode(self, indptr_cpu, indices, last_page_len_cpu, ...):
    if not self.is_cuda_graph_enabled or getattr(self, "vllm_first_call", True):
        self.plan(...)  # Warm-up
        return

    fast_decode_plan(self, ...)  # CUDA Graph optimized path

Optimizations:

  • Only indptr and last_page_len are copied host-to-device
  • The device-to-device copy of the indices buffer is avoided

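IV. MLA (Multi-head Latent Attention)

4.1 MLA Overview

MLA is the attention mechanism introduced with DeepSeek-V2 and used by the DeepSeek-V3/V4 models discussed here. By compressing keys and values, it significantly reduces KV cache memory usage.

Core ideas

  • Compress Key/Value into a low-dimensional latent space
  • Separate content (no-pe) information from positional (pe) information
  • Support sparse attention patterns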
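
The memory saving can be estimated with simple arithmetic: standard multi-head attention caches a key and a value per head, while MLA caches one compressed latent plus one positional vector per token. The sketch below uses the parameter names kv_lora_rank and qk_rope_head_dim that appear in the forward code later in this section; the numeric values are illustrative assumptions and are not read from any model config in this document.

```python
def mha_cache_elements_per_token(num_kv_heads: int, head_dim: int) -> int:
    """Elements cached per token per layer for standard attention (K and V)."""
    return 2 * num_kv_heads * head_dim


def mla_cache_elements_per_token(kv_lora_rank: int, qk_rope_head_dim: int) -> int:
    """Elements cached per token per layer for MLA (latent + positional part)."""
    return kv_lora_rank + qk_rope_head_dim


# Illustrative, assumed dimensions
mha_elements = mha_cache_elements_per_token(num_kv_heads=128, head_dim=128)
mla_elements = mla_cache_elements_per_token(kv_lora_rank=512, qk_rope_head_dim=64)
compression_ratio = mha_elements / mla_elements
```

With these assumed sizes the per-token cache shrinks from 32768 elements to 576, a reduction of roughly 57x.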

4.2 MLA Backend Family

Directory: mla/(file:///workspace/vllm/v1/attention/backends/mla/)

Backend | File | Use case
FlashInfer MLA | flashinfer_mla.py(file:///workspace/vllm/v1/attention/backends/mla/flashinfer_mla.py) | Blackwell GPUs, recommended
FlashAttn MLA | flashattn_mla.py(file:///workspace/vllm/v1/attention/backends/mla/flashattn_mla.py) | Ampere/Hopper GPUs
Triton MLA | triton_mla.py(file:///workspace/vllm/v1/attention/backends/mla/triton_mla.py) | General CUDA GPUs
Cutlass MLA | cutlass_mla.py(file:///workspace/vllm/v1/attention/backends/mla/cutlass_mla.py) | Cutlass implementation
FlashMLA | flashmla.py(file:///workspace/vllm/v1/attention/backends/mla/flashmla.py) | FlashMLA library
ROCm AIter MLA | rocm_aiter_mla.py(file:///workspace/vllm/v1/attention/backends/mla/rocm_aiter_mla.py) | AMD GPUs
Sparse variants | *_sparse.py | DeepSeek-V4 sparse mode

4.3 FlashInfer MLA Implementation

File location: flashinfer_mla.py(file:///workspace/vllm/v1/attention/backends/mla/flashinfer_mla.py#L38-L97)

class FlashInferMLABackend(MLACommonBackend):
    supported_kv_cache_dtypes = ["auto", "float16", "bfloat16", "fp8", "fp8_e4m3"]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [32, 64]  # only 32 and 64 are supported

    @classmethod
    def supports_compute_capability(cls, capability):
        return capability.major == 10  # Blackwell only

Constraints

  • qk_nope_head_dim must be one of [64, 128, 192]
  • Requires the HND KV cache layout
  • ALiBi, sliding_window, and logits_soft_cap are not supported

4.4 MLA Forward Flow

File location: flashinfer_mla.py(file:///workspace/vllm/v1/attention/backends/mla/flashinfer_mla.py#L155-L209)

def forward_mqa(self, q, kv_c_and_k_pe_cache, attn_metadata, layer):
    # 1. Concatenate the no-pe and pe parts of the query
    if isinstance(q, tuple):
        q_nope, q_pe = q
        q = torch.cat([q_nope, q_pe], dim=-1)

    # 2. Reshape to [num_decodes, q_len, heads, dim]
    q = q.view(attn_metadata.num_decodes, -1, q.shape[-2], q.shape[-1])

    # 3. Call the TRT-LLM MLA decode kernel
    o = trtllm_batch_decode_with_kv_cache_mla(
        query=q,
        kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
        workspace_buffer=self._workspace_buffer,
        qk_nope_head_dim=self.qk_nope_head_dim,
        kv_lora_rank=self.kv_lora_rank,
        qk_rope_head_dim=self.qk_rope_head_dim,
        block_tables=attn_metadata.decode.block_table,
        seq_lens=attn_metadata.decode.seq_lens,
        ...
    )

4.5 MLA Sparse Variants

DeepSeek-V4 introduces sparse MLA; the related implementations are:

  • flashinfer_mla_sparse.py(file:///workspace/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py): FlashInfer sparse MLA
  • sparse_utils.py(file:///workspace/vllm/v1/attention/backends/mla/sparse_utils.py): sparse utility functions
  • sparse_swa.py(file:///workspace/vllm/v1/attention/backends/mla/sparse_swa.py): Sliding Window Attention

4.6 MLA Prefill Selector

File location: mla/prefill/selector.py(file:///workspace/vllm/v1/attention/backends/mla/prefill/selector.py)

The MLA prefill phase has its own selection logic and supports:

  • FlashInfer prefill
  • FlashAttention prefill
  • TRT-LLM ragged prefill

V. Triton Attention

5.1 Backend Definition

File location: triton_attn.py(file:///workspace/vllm/v1/attention/backends/triton_attn.py#L265-L388)

class TritonAttentionBackend(AttentionBackend):
    supported_dtypes = [torch.float16, torch.bfloat16, torch.float32]
    supported_kv_cache_dtypes = [
        "auto", "float16", "bfloat16",
        "fp8", "fp8_e4m3", "fp8_e5m2",
        "int8_per_token_head",     # per-token-head INT8 quantization
        "fp8_per_token_head",      # per-token-head FP8 quantization
    ]

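Distinctive strengths

  • The broadest dtype support among the backends covered here (including float32)
  • Supports per-token-head quantization
  • Fully custom Triton kernels
  • Supports the ALiBi sqrt variant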

5.2 Per-Token-Head Quantization Support

File location: triton_attn.py#L307-L328(file:///workspace/vllm/v1/attention/backends/triton_attn.py#L307-L328)

def get_kv_cache_shape(...):
    if kv_cache_uses_per_token_head_scales(cache_dtype_str):
        # Append padding after head_size to hold the scale
        scale_pad = get_dtype_size(torch.float32) // get_dtype_size(cache_dtype)
        return (num_blocks, 2, block_size, num_kv_heads, head_size + scale_pad)
    return (num_blocks, 2, block_size, num_kv_heads, head_size)
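
The padding arithmetic above is easy to check by hand: one float32 scale takes 4 bytes, so a 1-byte cache dtype needs 4 extra elements per head. The sketch below repeats that calculation with a hypothetical byte-size table in place of get_dtype_size; padded_head_size is an illustrative name.

```python
# Hypothetical byte sizes standing in for get_dtype_size()
DTYPE_SIZE_BYTES = {"float32": 4, "float16": 2, "int8": 1, "fp8": 1}


def padded_head_size(head_size: int, cache_dtype: str) -> int:
    """Head dimension after reserving room for one float32 scale per head."""
    scale_pad = DTYPE_SIZE_BYTES["float32"] // DTYPE_SIZE_BYTES[cache_dtype]
    return head_size + scale_pad


int8_last_dim = padded_head_size(128, "int8")
fp8_last_dim = padded_head_size(64, "fp8")
```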

Scale extraction (triton_attn.py#L395-L445(file:///workspace/vllm/v1/attention/backends/triton_attn.py#L395-L445)):

def _ensure_scale_caches(self, kv_cache):
    # Create a zero-copy view with as_strided
    self._k_scale_cache = torch.as_strided(
        base_f32,
        size=(num_blocks, block_size, nkv),
        stride=(full_block_f32, slot_f32, head_f32),
        storage_offset=scale_off_f32,
    )

5.3 Unified Kernel Entry Point

File location: triton_attn.py#L609-L641(file:///workspace/vllm/v1/attention/backends/triton_attn.py#L609-L641)

unified_attention(
    q=query[:num_actual_tokens],
    k=key_cache,
    v=value_cache,
    out=output[:num_actual_tokens],
    cu_seqlens_q=cu_seqlens_q,
    max_seqlen_q=max_query_len,
    seqused_k=seqused_k,
    max_seqlen_k=max_seq_len,
    softmax_scale=self.scale,
    causal=True,
    alibi_slopes=self.alibi_slopes,
    window_size=self.sliding_window,
    block_table=block_table,
    softcap=self.logits_soft_cap,
    seq_threshold_3D=seq_threshold_3D,     # threshold for switching between the 2D and 3D kernels
    num_par_softmax_segments=num_par_softmax_segments,  # number of parallel softmax segments
    sinks=self.sinks,
    mm_prefix_range=mm_prefix_range_tensor,
    kv_quant_mode=self._kv_quant_mode,
    chunk_lookback=self.chunk_lookback,
)

5.4 Adaptive 2D/3D Kernel Selection

File location: triton_attn.py#L125-L176(file:///workspace/vllm/v1/attention/backends/triton_attn.py#L125-L176)

MIN_LAUNCH_GRID_SIZE_2D = 128  # minimum launch grid size for the 2D kernel

# Compute the threshold
self.seq_threshold_3D = MIN_LAUNCH_GRID_SIZE_2D // self.num_heads_kv

  • 2D kernel: grid = (num_q_blocks, num_kv_heads), suited to large batches
  • 3D kernel: grid = (num_seqs, num_heads_q, num_kv_heads), suited to small batches
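
The threshold turns into a simple dispatch rule. The sketch below reproduces the integer division from the snippet above; the comparison in pick_kernel (3D when the number of sequences is at or below the threshold) is an assumption inferred from the "small batch" and "large batch" descriptions, not a quote of the real dispatch code.

```python
MIN_LAUNCH_GRID_SIZE_2D = 128  # value taken from the snippet above


def seq_threshold_3d(num_heads_kv: int) -> int:
    """Largest batch size for which the 2D grid would be too small."""
    return MIN_LAUNCH_GRID_SIZE_2D // num_heads_kv


def pick_kernel(num_seqs: int, num_heads_kv: int) -> str:
    """Assumed rule: small batches go to the 3D kernel, the rest to 2D."""
    return "3d" if num_seqs <= seq_threshold_3d(num_heads_kv) else "2d"


threshold = seq_threshold_3d(num_heads_kv=8)
small_batch_kernel = pick_kernel(num_seqs=4, num_heads_kv=8)
large_batch_kernel = pick_kernel(num_seqs=64, num_heads_kv=8)
```

The intuition is that a 2D grid with few sequences and few KV heads launches too few programs to fill the GPU, so the extra grid dimension supplies the missing parallelism.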

5.5 Fused RoPE Support

File location: triton_attn.py#L740-L773(file:///workspace/vllm/v1/attention/backends/triton_attn.py#L740-L773)

def fused_rope_kvcache_supported(self):
    if self._is_per_token_head_quant:
        return False
    return rocm_aiter_ops.is_enabled()

def do_rope_and_kv_cache_update(self, layer, query, key, value,
                                positions, cos_sin_cache, is_neox,
                                kv_cache, layer_slot_mapping):
    rocm_aiter_ops.triton_rope_and_cache(
        query, key, value, positions, cos_sin_cache, is_neox,
        key_cache, value_cache, layer_slot_mapping,
        layer._k_scale, layer._v_scale, flash_layout, is_fp8_kv_cache,
    )

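RoPE and the KV cache write are fused into a single operation. As the code shows, the fused path is available only when rocm_aiter_ops is enabled and per-token-head quantization is not in use.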


VI. ROCm Backend: AMD GPU Support

6.1 Backend Definition

File location: rocm_attn.py(file:///workspace/vllm/v1/attention/backends/rocm_attn.py#L164-L253)

class RocmAttentionBackend(AttentionBackend):
    supported_dtypes = [torch.float16, torch.bfloat16, torch.float32]
    supported_kv_cache_dtypes = ["auto", "float16", "bfloat16",
                                  "fp8", "fp8_e4m3", "fp8_e5m2"]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [MultipleOf(16)]  # any multiple of 16 is supported

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [32, 64, 80, 96, 128, 160, 192, 224, 256]

6.2 Dual-Kernel Strategy

File location: rocm_attn.py#L453-L496(file:///workspace/vllm/v1/attention/backends/rocm_attn.py#L453-L496)

The ROCm backend picks a different low-level implementation depending on block_size:

def do_kv_cache_update(self, layer, key, value, kv_cache, slot_mapping):
    block_size = value_cache.shape[3]

    if block_size in (16, 32):
        # Standard block sizes: use the HIP C++ kernel
        PagedAttention.write_to_paged_cache(key, value, ...)
    else:
        # Non-standard block sizes (e.g. 544 for Qwen3-Next): use the Triton kernel
        triton_reshape_and_cache_flash(key, value, ...)

6.3 Chunked Prefill + Paged Decode

File location: rocm_attn.py#L353-L450(file:///workspace/vllm/v1/attention/backends/rocm_attn.py#L353-L450)

The ROCm backend uses a single chunked_prefill_paged_decode function:

def forward(self, layer, query, key, value, kv_cache,
            attn_metadata, output, ...):
    chunked_prefill_paged_decode(
        query=query[:num_actual_tokens],
        key=key[:num_actual_tokens],
        value=value[:num_actual_tokens],
        output=output[:num_actual_tokens],
        kv_cache_dtype=self.kv_cache_dtype,
        key_cache=key_cache,
        value_cache=value_cache,
        block_table=block_table,
        query_start_loc=cu_seqlens_q,
        seq_lens=seqused_k,
        max_seq_len=max_seqlen_k,
        max_query_len=max_seqlen_q,
        causal=attn_metadata.causal,
        ...
    )

This one function handles both the prefill and the decode phases.

6.4 AIter Integration

File location: rocm_aiter_unified_attn.py(file:///workspace/vllm/v1/attention/backends/rocm_aiter_unified_attn.py) and rocm_aiter_fa.py(file:///workspace/vllm/v1/attention/backends/rocm_aiter_fa.py)

The ROCm platform also provides accelerated variants built on AIter (AMD's AI Tensor Engine for ROCm):

  • RocmAiterUnifiedAttentionBackend: unified AIter backend
  • AiterFlashAttentionBackend: AIter FlashAttention compatibility layer
  • AiterMLABackend: AIter MLA implementation

VII. CPU Attention: the CPU Fallback Implementation

7.1 Backend Definition

File location: cpu_attn.py(file:///workspace/vllm/v1/attention/backends/cpu_attn.py#L41-97)

class CPUAttentionBackend(AttentionBackend):
    supported_dtypes = [torch.float16, torch.bfloat16, torch.float32]
    supported_kv_cache_dtypes = ["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [32, 64, 80, 96, 112, 128, 160, 192, 224, 256, 512]

7.2 Adaptive ISA Selection

File location: cpu_attn.py#L509-545(file:///workspace/vllm/v1/attention/backends/cpu_attn.py#L509-545)

def _get_attn_isa(dtype, block_size, head_size=None, kv_cache_dtype=None):
    supports_amx = torch.cpu._is_amx_tile_supported()
    arch = current_platform.get_cpu_architecture()

    if supports_amx and dtype in (torch.bfloat16,) and block_size % 32 == 0:
        return "amx"           # Intel AMX instruction set
    elif block_size % 32 == 0:
        if arch == CpuArchEnum.ARM:
            return "neon"       # ARM NEON
        elif arch == CpuArchEnum.S390X:
            return "vxe"        # IBM Z Vector Extensions
        elif arch == CpuArchEnum.POWERPC:
            return "vsx"        # PowerPC VSX
        else:
            return "vec"        # generic SIMD
    else:
        return "vec16"          # fallback when block_size is not a multiple of 32

7.3 Hybrid SDPA Strategy

File location: cpu_attn.py#L120-162(file:///workspace/vllm/v1/attention/backends/cpu_attn.py#L120-162)

On architectures such as x86, ARM, and PowerPC, a hybrid strategy is used:

if current_platform.get_cpu_architecture() not in _CPU_ARCH_PREFER_MIXED_BATCH:
    reorder_batch_threshold = 1
    self.use_sdpa_prefill = True
    # Decode tokens are ordered first → SDPA handles prefill
    # cpu_attention_with_kv_cache handles decode
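
The reordering step can be pictured as a stable partition of the batch: requests whose query length is at most reorder_batch_threshold count as decodes and move to the front, and everything else is treated as prefill. The sketch below is an illustration of that idea only; reorder_decodes_first is a hypothetical helper, not vLLM's batch reordering code.

```python
def reorder_decodes_first(query_lens, threshold=1):
    """Return (request order, number of decode requests), decodes first."""
    decodes = [i for i, n in enumerate(query_lens) if n <= threshold]
    prefills = [i for i, n in enumerate(query_lens) if n > threshold]
    return decodes + prefills, len(decodes)


# Query lengths of five requests; length 1 means a decode step.
order, num_decodes = reorder_decodes_first([7, 1, 1, 12, 1])
```

After reordering, the first num_decodes requests can go to the paged decode kernel and the remaining slice to SDPA, which matches the query[num_decode:] slicing used in the forward pass.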

7.4 CPU Attention Execution Flow

File location: cpu_attn.py#L285-391(file:///workspace/vllm/v1/attention/backends/cpu_attn.py#L285-L391)

def forward(self, layer, query, key, value, kv_cache,
            attn_metadata, output, ...):
    # 1. Encoder attention → SDPA
    if self.attn_type in (ENCODER_ONLY, ENCODER):
        return self._run_sdpa_forward(...)

    # 2. Update the KV cache
    ops.cpu_attn_reshape_and_cache(key, value, key_cache, value_cache, ...)

    # 3. SDPA prefill (if enabled)
    if attn_metadata.use_sdpa_prefill:
        self._run_sdpa_forward(query[num_decode:], key[num_decode:], ...)

    # 4. CPU paged attention for decode
    if num_actual_tokens > 0:
        ops.cpu_attention_with_kv_cache(
            query=query[:num_actual_tokens],
            key_cache=key_cache,
            value_cache=value_cache,
            output=output[:num_actual_tokens],
            ...
        )

8. Other Special Attention Variants

8.1 Mamba/Mamba2 Attention (SSM)

Mamba1 Attention

File location: mamba1_attn.py(file:///workspace/vllm/v1/attention/backends/mamba1_attn.py#L14-L36)

python
class Mamba1AttentionBackend(AttentionBackend):
    @staticmethod
    def get_name() -> str:
        return "MAMBA1_ATTN"

    @classmethod
    def is_ssm(cls) -> bool:
        return True  # marks the backend as a State Space Model

Mamba1 replaces conventional attention with a selective scan, which costs O(L) in the sequence length L rather than O(L²).
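
To make the O(L) claim concrete, here is a minimal single-channel sketch of a selective scan. It is an illustration under simplifying assumptions, not vLLM's kernel: the state is a single scalar, the parameters a, b, c are fixed constants (in Mamba, B and C are also input-dependent), and only the step size delta varies per token. The function name selective_scan is hypothetical.

```python
import math


def selective_scan(x, delta, a=-1.0, b=1.0, c=1.0):
    """Sequential scan over one channel with a scalar hidden state.

    h_t = exp(delta_t * a) * h_{t-1} + (delta_t * b) * x_t
    y_t = c * h_t

    The per-token step size delta_t is what makes the scan "selective":
    a small delta_t keeps the old state, a large one overwrites it.
    """
    h = 0.0
    ys = []
    for x_t, d_t in zip(x, delta):
        a_bar = math.exp(d_t * a)   # discretized state transition
        b_bar = d_t * b             # discretized input projection
        h = a_bar * h + b_bar * x_t
        ys.append(c * h)
    return ys


# One pass over the sequence: work grows linearly with len(x).
ys = selective_scan([1.0, 0.0, 0.0], [1.0, 1.0, 1.0])
```

Each token touches only the fixed-size state h, so there is no pairwise token interaction and no growing KV cache.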

Mamba2 Attention

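File location: mamba2_attn.py(file:///workspace/vllm/v1/attention/backends/mamba2_attn.py)

Mamba2 is an improved version of Mamba that introduces a more efficient SSM implementation based on the structured state space duality (SSD) formulation.

8.2 Linear Attention

File location: linear_attn.py(file:///workspace/vllm/v1/attention/backends/linear_attn.py#L21-32)

python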
class LinearAttentionBackend(AttentionBackend):
    @classmethod
    def is_ssm(cls) -> bool:
        return True

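Linear Attention replaces softmax attention with a linear-complexity variant:

  • A feature map is applied to queries and keys so that softmax is not needed
  • Complexity drops from O(N²d) to O(Nd²), where N is the sequence length and d the head dimension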
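
The two points above can be illustrated with a small pure-Python sketch. It assumes the common feature map elu(x) + 1, which is a choice made here for illustration and may differ from what a given model uses; the function names are hypothetical. The causal linear form keeps a running d×d state and is checked against the equivalent quadratic form.

```python
import math


def phi(x):
    """Feature map elu(x) + 1 (an assumed choice); always positive."""
    return [v + 1.0 if v > 0 else math.exp(v) for v in x]


def linear_attention_causal(Q, K, V):
    """Causal linear attention with a running state: O(N * d * d_v) work."""
    d, dv = len(Q[0]), len(V[0])
    S = [[0.0] * dv for _ in range(d)]   # running sum of phi(k) v^T
    z = [0.0] * d                        # running sum of phi(k)
    out = []
    for q, k, v in zip(Q, K, V):
        fq, fk = phi(q), phi(k)
        for i in range(d):
            z[i] += fk[i]
            for j in range(dv):
                S[i][j] += fk[i] * v[j]
        denom = sum(fq[i] * z[i] for i in range(d))
        out.append([sum(fq[i] * S[i][j] for i in range(d)) / denom
                    for j in range(dv)])
    return out


def quadratic_reference(Q, K, V):
    """Same result computed pairwise over all (query, key) pairs: O(N^2 * d)."""
    out = []
    for t, q in enumerate(Q):
        fq = phi(q)
        w = [sum(a * b for a, b in zip(fq, phi(K[s]))) for s in range(t + 1)]
        total = sum(w)
        out.append([sum(w[s] * V[s][j] for s in range(t + 1)) / total
                    for j in range(len(V[0]))])
    return out


Q = [[0.1, -0.2], [0.5, 0.3], [-0.4, 0.9]]
K = [[0.3, 0.1], [-0.6, 0.2], [0.7, -0.1]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
fast = linear_attention_causal(Q, K, V)
slow = quadratic_reference(Q, K, V)
```

The state (S, z) has a fixed size regardless of sequence length, which is why backends of this kind track state indices rather than a paged KV cache.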

Metadata structure

python
@dataclass
class LinearAttentionMetadata:
    state_indices_tensor: torch.Tensor  # state indices
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor

8.3 Short Conv Attention

File location: short_conv_attn.py(file:///workspace/vllm/v1/attention/backends/short_conv_attn.py#L12-33)

python
class ShortConvAttentionBackend(AttentionBackend):
    @classmethod
    def is_ssm(cls) -> bool:
        return True

Short Convolution Attention replaces global attention with a short convolution, which suits modeling of local dependencies.
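
As a rough illustration of the idea, the sketch below implements a causal 1D convolution with a short kernel, both over a whole sequence and one token at a time. It is a single-channel toy under stated assumptions; the function names and the rolling-state layout are hypothetical and do not describe vLLM's implementation.

```python
def causal_short_conv(x, weights):
    """y[t] = sum_i weights[i] * x[t - (K - 1) + i], zero-padded on the left."""
    k = len(weights)
    padded = [0.0] * (k - 1) + list(x)
    return [sum(w * padded[t + i] for i, w in enumerate(weights))
            for t in range(len(x))]


def conv_step(state, x_t, weights):
    """Process one new token using only the last K - 1 inputs as state."""
    window = state + [x_t]
    y_t = sum(w * v for w, v in zip(weights, window))
    return y_t, window[1:]   # drop the oldest input


weights = [0.5, 0.5]
x = [1.0, 2.0, 3.0, 4.0]
batch_out = causal_short_conv(x, weights)

# Token-by-token decoding with a fixed-size rolling state.
state = [0.0] * (len(weights) - 1)
stream_out = []
for x_t in x:
    y_t, state = conv_step(state, x_t, weights)
    stream_out.append(y_t)
```

Each output depends only on the last K inputs, so decoding needs a small fixed-size state per request instead of a cache that grows with the sequence.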

8.4 GDN (Gated DeltaNet) Attention

File location: gdn_attn.py(file:///workspace/vllm/v1/attention/backends/gdn_attn.py#L25-36)

python
class GDNAttentionBackend(AttentionBackend):
    @classmethod
    def is_ssm(cls) -> bool:
        return True

GDN is Gated DeltaNet attention, which combines a gating mechanism with DeltaNet's parallel scan.

Advanced features (gdn_attn.py#L76-100(file:///workspace/vllm/v1/attention/backends/gdn_attn.py#L76-100)):

python
@dataclass
class GDNAttentionMetadata:
    # Speculative decoding support
    spec_query_start_loc: torch.Tensor | None
    non_spec_query_start_loc: torch.Tensor | None
    spec_state_indices_tensor: torch.Tensor | None
    non_spec_state_indices_tensor: torch.Tensor | None
    num_accepted_tokens: torch.Tensor | None

    # FLA chunk metadata (precomputed to avoid a GPU→CPU sync)
    chunk_indices: torch.Tensor | None
    chunk_offsets: torch.Tensor | None

    # Metadata for the Triton causal_conv1d kernel
    nums_dict: dict | None
    batch_ptr: torch.Tensor | None

The GDN backend fully supports speculative decoding, including:

  • Separate handling of spec and non-spec requests
  • FLA (Flash Linear Attention) chunk operations
  • A Triton implementation of causal 1D convolution

9. Attention Operations Layer (Low-Level Kernels)

9.1 Ops Directory Structure

vllm/v1/attention/ops/
├── paged_attn.py                     # Paged attention core
├── chunked_prefill_paged_decode.py   # Chunked prefill + paged decode
├── triton_prefill_attention.py       # Triton prefill kernel
├── triton_decode_attention.py        # Triton decode kernel
├── triton_unified_attention.py       # Unified Triton kernel
├── triton_reshape_and_cache_flash.py # KV cache write
├── merge_attn_states.py              # Attention state merging
├── common.py                         # Shared utilities (LSE reduce)
├── dcp_alltoall.py                   # DCP AllToAll communication
├── prefix_prefill.py                 # Prefix prefill
├── flashmla.py                       # FlashMLA operations
└── deepseek_v4_ops/                  # DeepSeek-V4 specific operations
    ├── fused_indexer_q.py
    ├── fused_compress_quant_cache.py
    ├── fused_inv_rope_fp8_quant.py
    └── fused_qk_rmsnorm.py

9.2 Paged Attention Core

File location: paged_attn.py(file:///workspace/vllm/v1/attention/ops/paged_attn.py)

The PagedAttention class provides the low-level paged attention primitives:

python
class PagedAttention:
    @staticmethod
    def split_kv_cache(kv_cache, num_kv_heads, head_size):
        """Unpack the KV cache into separate key and value caches."""

    @staticmethod
    def write_to_paged_cache(key, value, key_cache, value_cache,
                              slot_mapping, kv_cache_dtype, ...):
        """Write keys and values into the KV cache (HIP/CUDA implementation)."""

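To show what a slot-mapped cache write does conceptually, here is a minimal list-based sketch. It is an illustration only: the helper write_tokens_to_blocks is hypothetical, the cache is modeled as nested Python lists rather than tensors, and it assumes that a slot index encodes block_index * block_size + offset and that negative slots mark padding tokens to skip.

```python
def write_tokens_to_blocks(keys, values, key_cache, value_cache,
                           slot_mapping, block_size):
    """Scatter per-token keys/values into a block-structured cache.

    Assumptions (for illustration): slot = block_index * block_size + offset,
    and a negative slot means "padding, do not write".
    """
    for k, v, slot in zip(keys, values, slot_mapping):
        if slot < 0:
            continue
        block, offset = divmod(slot, block_size)
        key_cache[block][offset] = k
        value_cache[block][offset] = v


block_size = 4
num_blocks = 2
key_cache = [[None] * block_size for _ in range(num_blocks)]
value_cache = [[None] * block_size for _ in range(num_blocks)]

# Three tokens: one goes to block 1 / offset 1, one to block 0 / offset 0,
# and the last one is padding.
write_tokens_to_blocks(
    keys=["k0", "k1", "k_pad"],
    values=["v0", "v1", "v_pad"],
    key_cache=key_cache,
    value_cache=value_cache,
    slot_mapping=[5, 0, -1],
    block_size=block_size,
)
```

Because tokens are addressed by slot rather than by position in a contiguous buffer, consecutive tokens of one request can live in non-adjacent blocks.
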
9.3 Chunked Prefill + Paged Decode

File location: chunked_prefill_paged_decode.py(file:///workspace/vllm/v1/attention/ops/chunked_prefill_paged_decode.py)

This is the unified entry point used by the ROCm backend:

python
def chunked_prefill_paged_decode(
    query, key, value, output,
    kv_cache_dtype, key_cache, value_cache,
    block_table, query_start_loc, seq_lens,
    max_seq_len, max_query_len, ..., causal=True
):

Characteristics:

  • Handles prefill and decode tokens in the same call
  • Chooses the prefill or decode kernel internally based on max_query_len
  • Supports sliding window, ALiBi slopes, and soft cap
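
The dispatch idea in the second bullet can be sketched as follows. This is a simplified illustration, not the function's actual control flow: the helper plan_kernels is hypothetical, and it assumes that a request with exactly one query token is a decode request while longer queries need the prefill path.

```python
def plan_kernels(query_start_loc):
    """Decide which requests go to the prefill path and which to decode.

    query_start_loc holds cumulative query offsets, so request i owns
    query rows [query_start_loc[i], query_start_loc[i + 1]).
    """
    query_lens = [end - start
                  for start, end in zip(query_start_loc, query_start_loc[1:])]
    max_query_len = max(query_lens, default=0)
    plan = {"prefill": [], "decode": []}
    if max_query_len > 1:
        # Only pay for the prefill kernel when some request needs it.
        plan["prefill"] = [i for i, n in enumerate(query_lens) if n > 1]
    plan["decode"] = [i for i, n in enumerate(query_lens) if n == 1]
    return plan


# Three requests with query lengths 1, 4, 1.
mixed = plan_kernels([0, 1, 5, 6])
# A decode-only batch never launches the prefill path.
decode_only = plan_kernels([0, 1, 2])
```

Checking max_query_len once per batch keeps the common decode-only case on a single kernel launch.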

9.4 Triton Prefill Attention

File location: triton_prefill_attention.py(file:///workspace/vllm/v1/attention/ops/triton_prefill_attention.py)

python
def context_attention_fwd(
    q, k, v, o,
    b_start_loc, b_seq_len, max_input_len,
    is_causal, softmax_scale,
    sliding_window_q, sliding_window_k,
    alibi_slopes=None, softcap=0.0
):

A prefill attention kernel implemented in Triton, used for:

  • ROCm encoder attention
  • The prefill phase of the Triton backend

9.5 Triton Unified Attention

File location: triton_unified_attention.py(file:///workspace/vllm/v1/attention/ops/triton_unified_attention.py)

python
def unified_attention(
    q, k, v, out,
    cu_seqlens_q, max_seqlen_q,
    seqused_k, max_seqlen_k,
    softmax_scale, causal, ...,
    seq_threshold_3D,           # threshold for switching between the 2D and 3D kernels
    num_par_softmax_segments,   # number of parallel softmax segments
    sinks=None,                 # Attention sinks
    mm_prefix_range=None,       # Multi-modal prefix
    chunk_lookback=-1,          # Chunk lookback
):

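This is the unified entry point of the Triton backend. Internally it contains:

  • 2D kernel: the general path, used when the batch contains prefill tokens or more sequences than seq_threshold_3D
  • 3D kernel: used for decode-only batches with few sequences; it splits each KV sequence into num_par_softmax_segments segments that are processed in parallel
  • Parallel segmented softmax: the per-segment partial results are reduced into the final output in a numerically stable way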

9.6 Merge Attention States

File location: merge_attn_states.py(file:///workspace/vllm/v1/attention/ops/merge_attn_states.py)

python
def merge_attn_states(output, prefix_output, prefix_lse,
                     suffix_output, suffix_lse):

Used in Cascade Attention to merge the attention results of the shared prefix and the per-request suffix. It relies on a log-sum-exp identity: each partial output is weighted by the softmax mass of its own key range, $e^{lse}$:

$$
\text{output} = \frac{O_{prefix} \cdot e^{lse_{prefix}} + O_{suffix} \cdot e^{lse_{suffix}}}{e^{lse_{prefix}} + e^{lse_{suffix}}}
$$
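
The merge can be checked numerically with a small pure-Python sketch: attention over a full key set should equal the LSE-weighted combination of attention over two disjoint key subsets, where each partial output is weighted by exp(lse) of its own subset. The helpers attend and merge_two_states are hypothetical names written for this illustration; they are not vLLM's merge_attn_states.

```python
import math


def attend(q, keys, values):
    """Softmax attention for one query; returns (output, log-sum-exp of scores)."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    out = [sum(e * v[j] for e, v in zip(exps, values)) / total
           for j in range(len(values[0]))]
    return out, m + math.log(total)


def merge_two_states(o_prefix, lse_prefix, o_suffix, lse_suffix):
    """Combine two partial attention outputs using their log-sum-exp values."""
    m = max(lse_prefix, lse_suffix)          # subtract the max for stability
    w_p = math.exp(lse_prefix - m)
    w_s = math.exp(lse_suffix - m)
    return [(w_p * a + w_s * b) / (w_p + w_s)
            for a, b in zip(o_prefix, o_suffix)]


q = [0.2, -0.5]
keys = [[1.0, 0.0], [0.3, 0.7], [-0.4, 0.9], [0.5, 0.5]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]

full, _ = attend(q, keys, values)
o_p, lse_p = attend(q, keys[:2], values[:2])    # shared prefix
o_s, lse_s = attend(q, keys[2:], values[2:])    # per-request suffix
merged = merge_two_states(o_p, lse_p, o_s, lse_s)
```

Subtracting the larger LSE before exponentiating leaves the ratio unchanged and avoids overflow when the scores are large.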

9.7 DCP AllToAll Communication

File location: dcp_alltoall.py(file:///workspace/vllm/v1/attention/ops/dcp_alltoall.py)

python
def dcp_a2a_lse_reduce(output, lse, group, return_lse=False, ...):
def cp_lse_ag_out_rs(output, lse, group, return_lse=False, ...):

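Two DCP reduce strategies:

  • AllToAll (dcp_a2a_lse_reduce): suited to large-scale parallelism
  • LSE AllGather + output ReduceScatter (cp_lse_ag_out_rs): lower communication volume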

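9.8 DeepSeek-V4 Specific Operations

Directory location: deepseek_v4_ops/(file:///workspace/vllm/v1/attention/ops/deepseek_v4_ops/)

| File | Function |
|---|---|
| fused_indexer_q.py | Fused Q indexer operation |
| fused_compress_quant_cache.py | Fused compression + cache quantization |
| fused_inv_rope_fp8_quant.py | Fused inverse RoPE + FP8 quantization |
| fused_qk_rmsnorm.py | Fused Q/K RMSNorm |

These fused operations reduce kernel launch overhead and memory traffic by combining steps that would otherwise run as separate kernels.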


10. Backend Selection Decision Flow

[Flowchart: backend selection decision flow. The diagram failed to render in the source; the surviving fragment shows a head-size check (VERIFY_HEAD) that routes head sizes outside 64, 128, 256 to a fallback path.]


11. Registry System

11.1 Backend Registry

File location: registry.py(file:///workspace/vllm/v1/attention/backends/registry.py#L34-131)

python
class AttentionBackendEnum(Enum):
    FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
    FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
    TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
    ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
    CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
    # ... more backends

11.2 Dynamic Override Mechanism

File location: registry.py#L210-L261(file:///workspace/vllm/v1/attention/backends/registry.py#L210-L261)

python
def register_backend(backend, class_path=None, is_mamba=False):
    """Register or override a backend implementation."""
    def decorator(cls):
        if is_mamba:
            _MAMBA_ATTN_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}"
        else:
            _ATTN_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}"
        return cls
    return decorator

This allows third parties or users to supply their own backend implementations.


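12. Summary of Performance Optimization Points

12.1 Where Each Backend Fits

| Backend | Best suited for | Hardware requirement | Notable features |
|---|---|---|---|
| FlashAttention | General-purpose NVIDIA GPUs | sm_80+ | FA3 AOT scheduling, Cascade Attn |
| FlashInfer | High throughput / low latency | sm_75+ | NVFP4, TRT-LLM, page-level ops |
| Triton | Unusual dtype/head_size | Any CUDA | Per-token-head quant, ALiBi sqrt |
| ROCm | AMD GPUs | CDNA | AIter acceleration, dual-kernel strategy |
| CPU | CPU inference / debugging | x86/ARM/PPC | Multi-ISA adaptation, SDPA hybrid |
| MLA family | DeepSeek-V3/V4 | Depends on the variant | KV compression, sparse attention |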

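12.2 Key Optimization Techniques

  1. CUDA Graph support: all major backends support CUDA Graph capture
  2. Cascade Attention: shared-prefix optimization (FlashAttention/FlashInfer)
  3. DCP (Decode Context Parallelism): KV partitioned across ranks
  4. FP8/NVFP4 quantization: reduces memory bandwidth requirements
  5. Fused operations: RoPE + cache write, quant + attention
  6. Batch Invariant Mode: fixed shapes to improve utilization
  7. Speculative decoding: fully supported by GDN/Mamba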

Appendix: Key File Index

| File path | Core content |
|---|---|
| v1/attention/selector.py(file:///workspace/vllm/v1/attention/selector.py) | Backend selector entry point |
| v1/attention/backend.py(file:///workspace/vllm/v1/attention/backend.py) | Abstract base class definitions |
| v1/attention/backends/registry.py(file:///workspace/vllm/v1/attention/backends/registry.py) | Backend enum and registration |
| v1/attention/backends/flash_attn.py(file:///workspace/vllm/v1/attention/backends/flash_attn.py) | FlashAttention backend |
| v1/attention/backends/flashinfer.py(file:///workspace/vllm/v1/attention/backends/flashinfer.py) | FlashInfer backend |
| v1/attention/backends/triton_attn.py(file:///workspace/vllm/v1/attention/backends/triton_attn.py) | Triton backend |
| v1/attention/backends/rocm_attn.py(file:///workspace/vllm/v1/attention/backends/rocm_attn.py) | ROCm backend |
| v1/attention/backends/cpu_attn.py(file:///workspace/vllm/v1/attention/backends/cpu_attn.py) | CPU backend |
| v1/attention/backends/mla/(file:///workspace/vllm/v1/attention/backends/mla/) | MLA backend family |
| v1/attention/ops/paged_attn.py(file:///workspace/vllm/v1/attention/ops/paged_attn.py) | Paged attention core |
| v1/attention/ops/chunked_prefill_paged_decode.py(file:///workspace/vllm/v1/attention/ops/chunked_prefill_paged_decode.py) | Chunked prefill + decode |
| v1/attention/ops/triton_unified_attention.py(file:///workspace/vllm/v1/attention/ops/triton_unified_attention.py) | Unified Triton kernel |