[vllm] (5) vLLM v1 Attention — In-Depth Module Analysis, Part 2

Chapter 4: selector.py Line-by-Line Analysis

4.0 File Overview

selector.py (165 lines) is the entry module for backend selection: it picks the best attention backend for the current runtime configuration.

Line range   Content
1-15         Imports
17-47        AttentionSelectorConfig NamedTuple
49-95        get_attn_backend() main entry point
97-125       _cached_get_attn_backend() cached selection
127-165      Mamba backend selection

4.1 Imports

```python
from functools import cache
from typing import NamedTuple, cast, get_args
```

  • cache: function-level memoization decorator (Python 3.9+; an unbounded lru_cache)
  • NamedTuple: typed named-tuple base class
  • cast: type cast (only meaningful to the type checker; a no-op at runtime)
  • get_args: extracts the allowed values of a Literal type

```python
import torch
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.v1.attention.backend import AttentionBackend, AttentionType
from vllm.v1.attention.backends.registry import (
    MAMBA_TYPE_TO_BACKEND_MAP,
    MambaAttentionBackendEnum,
)
```

  • resolve_obj_by_qualname: dynamically imports a class by its fully qualified name
  • MAMBA_TYPE_TO_BACKEND_MAP: maps Mamba type strings to backend enum names

4.2 AttentionSelectorConfig (lines 17-47)

```python
class AttentionSelectorConfig(NamedTuple):
    head_size: int
    dtype: torch.dtype
    kv_cache_dtype: CacheDType | None
    block_size: int | None
    use_mla: bool = False
    has_sink: bool = False
    use_sparse: bool = False
    use_mm_prefix: bool = False
    use_per_head_quant_scales: bool = False
    attn_type: str = AttentionType.DECODER
    use_non_causal: bool = False
```

Design intent

  • A NamedTuple is immutable, has value semantics, and is hashable, making it a good cache key
  • It encapsulates every configuration dimension that influences backend selection
  • Each field corresponds to one check dimension in AttentionBackend.validate_configuration()
```python
    def __repr__(self):
        return (
            f"AttentionSelectorConfig(head_size={self.head_size}, "
            f"dtype={self.dtype}, "
            f"kv_cache_dtype={self.kv_cache_dtype}, "
            f"block_size={self.block_size}, "
            f"use_mla={self.use_mla}, "
            f"has_sink={self.has_sink}, "
            f"use_sparse={self.use_sparse}, "
            f"use_mm_prefix={self.use_mm_prefix}, "
            f"use_per_head_quant_scales={self.use_per_head_quant_scales}, "
            f"attn_type={self.attn_type}, "
            f"use_non_causal={self.use_non_causal})"
        )
```

The custom repr makes debug logs easier to read.
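
To see why a NamedTuple works so well as a cache key, here is a minimal standalone sketch (SelectorKey is a hypothetical stand-in for AttentionSelectorConfig, not vLLM code):

```python
from typing import NamedTuple

class SelectorKey(NamedTuple):  # hypothetical stand-in for AttentionSelectorConfig
    head_size: int
    use_mla: bool = False

a = SelectorKey(head_size=128)
b = SelectorKey(head_size=128)
assert a == b                    # value semantics: field-wise equality
assert hash(a) == hash(b)        # hashable: usable as a dict / functools.cache key
lookup = {a: "FlashAttentionBackend"}
assert lookup[b] == "FlashAttentionBackend"  # b hits the entry created with a
```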


4.3 get_attn_backend (lines 49-95) --- Main Entry Point

```python
def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str | None,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
    use_mm_prefix: bool = False,
    use_per_head_quant_scales: bool = False,
    attn_type: str | None = None,
    num_heads: int | None = None,
) -> type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it."""
```

This is the single entry point through which external code selects a backend.

```python
    if kv_cache_dtype is not None:
        valid_cache_dtypes = get_args(CacheDType)
        assert kv_cache_dtype in valid_cache_dtypes, (
            f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
            f"Valid values are: {valid_cache_dtypes}"
        )
```

  • Validates that kv_cache_dtype is one of the allowed values

```python
    from vllm.config import get_current_vllm_config
    vllm_config = get_current_vllm_config()
```

  • Fetches the current VllmConfig (stored in a context variable)

```python
    cache_config = vllm_config.cache_config
    if cache_config is not None and cache_config.user_specified_block_size:
        block_size = cache_config.block_size
    else:
        block_size = None
```

  • Reads block_size from the cache config
  • It is only honored when the user set it explicitly; otherwise the backend picks its own

```python
    speculative_config = vllm_config.speculative_config
    use_non_causal = (
        speculative_config is not None and speculative_config.method == "dflash"
    )
```

  • Speculative decoding with the dflash method requires non-causal attention

```python
    attn_selector_config = AttentionSelectorConfig(
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
        block_size=block_size,
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
        use_mm_prefix=use_mm_prefix,
        use_per_head_quant_scales=use_per_head_quant_scales,
        attn_type=attn_type or AttentionType.DECODER,
        use_non_causal=use_non_causal,
    )
```

  • Builds the selector config; the attention type defaults to DECODER

```python
    return _cached_get_attn_backend(
        backend=vllm_config.attention_config.backend,
        attn_selector_config=attn_selector_config,
        num_heads=num_heads,
    )
```

  • Delegates to the cached variant, passing the user-specified backend name and the config

4.4 _cached_get_attn_backend (lines 97-125)

```python
@cache
def _cached_get_attn_backend(
    backend,
    attn_selector_config: AttentionSelectorConfig,
    num_heads: int | None = None,
) -> type[AttentionBackend]:
```

The @cache decorator means the selection logic runs only once per distinct argument tuple; subsequent calls return the cached result.

Why cache at all

  1. Backend selection involves dynamic imports and platform detection, which are not cheap
  2. For a given configuration the chosen backend never changes, so re-running the selection is wasted work
  3. AttentionSelectorConfig is an immutable, hashable NamedTuple, so it works directly as a cache key
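
A minimal standalone sketch of this caching behavior (the names are illustrative, not vLLM's): identical NamedTuple arguments hash to the same key, so the expensive body runs only once per configuration.

```python
from functools import cache
from typing import NamedTuple

class Key(NamedTuple):  # hypothetical stand-in for AttentionSelectorConfig
    head_size: int

calls = 0

@cache
def select_backend(key: Key) -> str:
    global calls
    calls += 1            # expensive platform detection would happen here
    return "FLASH_ATTN"

select_backend(Key(128))
select_backend(Key(128))  # same key: served from cache
assert calls == 1
select_backend(Key(64))   # new key: body runs again
assert calls == 2
```
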
```python
    from vllm.platforms import current_platform
    attention_cls = current_platform.get_attn_backend_cls(
        backend,
        attn_selector_config=attn_selector_config,
        num_heads=num_heads,
    )
```

  • Platform delegation: the actual selection logic is handed off to the current platform
  • Each platform (CUDA/ROCm/CPU/XPU) implements its own selection strategy
  • The backend argument is the user-specified backend name, or None for automatic selection

```python
    if not attention_cls:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}"
        )
```

  • Raises if no backend matches

```python
    backend = resolve_obj_by_qualname(attention_cls)
```

  • Resolves the fully-qualified-name string into a class object
  • Lazy import: the backend module is only imported when actually needed

```python
    required_layout = backend.get_required_kv_cache_layout()
    if required_layout is not None:
        from vllm.v1.attention.backends.utils import set_kv_cache_layout
        set_kv_cache_layout(required_layout)
        logger.info(
            "Using %s KV cache layout for %s backend.",
            required_layout,
            backend.get_name(),
        )
    return backend
```

  • KV cache layout adaptation: if the backend requires a specific layout, it is set globally
  • set_kv_cache_layout() updates the global variable and clears the cached layout
  • This happens after the backend is chosen, ensuring subsequent KV cache allocations use the right layout

4.5 Mamba Backend Selection (lines 127-165)

```python
def get_mamba_attn_backend(
    mamba_type: str,
) -> type[AttentionBackend]:
    """Select which mamba attention backend to use and lazily import it."""
    return _cached_get_mamba_attn_backend(mamba_type)
```

The selection entry point for Mamba/SSM backends. The logic is simpler: a direct lookup keyed by the type string.

```python
@cache
def _cached_get_mamba_attn_backend(
    mamba_type: str,
) -> type[AttentionBackend]:
    assert mamba_type and isinstance(mamba_type, str)
```

```python
    selected_backend = None
    try:
        backend_name = MAMBA_TYPE_TO_BACKEND_MAP[mamba_type]
        selected_backend = MambaAttentionBackendEnum[backend_name]
    except KeyError as e:
        raise ValueError(
            f"Invalid mamba attention backend type: '{mamba_type}'. Valid "
            f"types are: {list(MAMBA_TYPE_TO_BACKEND_MAP.keys())}"
        ) from e
```

  • Looks up the enum member via MAMBA_TYPE_TO_BACKEND_MAP
  • Raises a friendly error if the type is unknown

```python
    mamba_attn_backend = selected_backend.get_class()
    return mamba_attn_backend
```

  • Obtains the concrete class via the enum member's get_class()
  • This triggers the lazy import
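
A hedged usage sketch (the import path is an assumption based on this chapter's file layout; adjust to your vLLM build):

```python
# Assumes selector.py lives at vllm.v1.attention.selector in your build.
from vllm.v1.attention.selector import get_mamba_attn_backend

backend_cls = get_mamba_attn_backend("mamba2")   # maps to MambaAttentionBackendEnum.MAMBA2
print(backend_cls.get_name())                    # the import happens lazily, on first call

# get_mamba_attn_backend("not_a_type")           # would raise ValueError listing valid types
```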

Chapter 5: registry.py Line-by-Line Analysis

5.0 File Overview

registry.py (263 lines) implements the attention backend registry: declarative enums, runtime overrides, and decorator-based registration.

Line range   Content
1-11         Imports
13-30        _AttentionBackendEnumMeta metaclass
32-122       AttentionBackendEnum
124-196      MambaAttentionBackendEnum
198-210      MAMBA_TYPE_TO_BACKEND_MAP
212-215      Override storage dicts
217-263      register_backend()

5.1 Imports

```python
from collections.abc import Callable
from enum import Enum, EnumMeta
from typing import TYPE_CHECKING, cast
from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname
```

  • EnumMeta: the metaclass behind Enum, used here to customize enum behavior
  • resolve_obj_by_qualname: dynamic class loading

5.2 The _AttentionBackendEnumMeta Metaclass (lines 13-30)

```python
class _AttentionBackendEnumMeta(EnumMeta):
    """Metaclass for AttentionBackendEnum to provide better error messages."""

    def __getitem__(cls, name: str):
        """Get backend by name with helpful error messages."""
        try:
            return super().__getitem__(name)
        except KeyError:
            members = cast("dict[str, Enum]", cls.__members__).keys()
            valid_backends = ", ".join(members)
            raise ValueError(
                f"Unknown attention backend: '{name}'. "
                f"Valid options are: {valid_backends}"
            ) from None
```

Design intent

  • Overrides __getitem__ (the Enum['NAME'] lookup)
  • Replaces the default KeyError with a ValueError that lists all valid options
  • from None suppresses the original exception chain, keeping the traceback clean

Typical scenario: the user configures a nonexistent backend name and gets a helpful error instead of a bare KeyError.
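
The pattern is easy to reproduce in isolation; a minimal self-contained sketch (not vLLM code, enum values are placeholders):

```python
from enum import Enum, EnumMeta

class _FriendlyMeta(EnumMeta):
    def __getitem__(cls, name: str):
        try:
            return super().__getitem__(name)
        except KeyError:
            valid = ", ".join(cls.__members__)
            raise ValueError(
                f"Unknown attention backend: '{name}'. Valid options are: {valid}"
            ) from None

class Backend(Enum, metaclass=_FriendlyMeta):
    FLASH_ATTN = "path.to.FlashAttn"
    TRITON_ATTN = "path.to.TritonAttn"

Backend["FLASH_ATTN"]        # normal lookup still works
try:
    Backend["FLASH_ATN"]     # typo
except ValueError as e:
    print(e)                 # Unknown attention backend: 'FLASH_ATN'. Valid options are: ...
```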


5.3 The AttentionBackendEnum (lines 32-122)

```python
class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
    """Enumeration of all supported attention backends."""
```

Each enum member's value is a fully qualified class path, used for lazy import.

5.3.1 The Backend Members

```python
    FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
```

  • The preferred backend on NVIDIA GPUs, using FlashAttention-2/3 kernels

```python
    FLASH_ATTN_DIFFKV = "vllm.v1.attention.backends.flash_attn_diffkv.FlashAttentionDiffKVBackend"
```

  • FlashAttention variant with differential KV

```python
    TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
```

  • Attention kernels written in Triton; a general-purpose GPU fallback

```python
    ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
```

  • Baseline attention backend for AMD ROCm

```python
    ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
    ROCM_AITER_TRITON_MLA = "vllm.v1.attention.backends.mla.aiter_triton_mla.AiterTritonMLABackend"
    ROCM_AITER_FA = "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
    ROCM_AITER_MLA_SPARSE = "vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse.ROCMAiterMLASparseBackend"
```

  • The AMD ROCm AITER backend family (AITER is AMD's AI Tensor Engine for ROCm)

```python
    XPU_MLA_SPARSE = "vllm.v1.attention.backends.mla.xpu_mla_sparse.XPUMLASparseBackend"
```

  • Sparse MLA backend for Intel XPU

```python
    TORCH_SDPA = ""  # this tag is only used for ViT
```

  • PyTorch's native SDPA (Scaled Dot-Product Attention); the empty string marks it as ViT-only

```python
    FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
```

  • FlashInfer backend; high-performance prefill + decode

```python
    FLASHINFER_MLA = "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
    FLASHINFER_MLA_SPARSE = "vllm.v1.attention.backends.mla.flashinfer_mla_sparse.FlashInferMLASparseBackend"
```

  • FlashInfer MLA variants

```python
    TRITON_MLA = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
    CUTLASS_MLA = "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
    FLASHMLA = "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
    FLASHMLA_SPARSE = "vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend"
    FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
```

  • The various MLA backend implementations

```python
    NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend"
```

  • No-op attention backend (for certain special model layers)

```python
    FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
```

  • PyTorch FlexAttention backend (experimental)

```python
    TREE_ATTN = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"
```

  • Tree Attention backend (tree-structured verification for speculative decoding)

```python
    ROCM_AITER_UNIFIED_ATTN = "vllm.v1.attention.backends.rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend"
```

  • ROCm AITER unified attention backend

```python
    CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
```

  • CPU attention backend

```python
    TURBOQUANT = "vllm.v1.attention.backends.turboquant_attn.TurboQuantAttentionBackend"
```

  • TurboQuant quantized attention backend

```python
    CUSTOM = None
```

  • Placeholder for third-party/custom backends
  • Its value is None (avoiding a clash with TORCH_SDPA's empty string)
  • Must be registered via register_backend() before use
5.3.2 Enum Methods

```python
    def get_path(self, include_classname: bool = True) -> str:
        """Get the class path for this backend (respects overrides)."""
        path = _ATTN_OVERRIDES.get(self, self.value)
        if not path:
            raise ValueError(
                f"Backend {self.name} must be registered before use. "
                f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')"
            )
        if not include_classname:
            path = path.rsplit(".", 1)[0]
        return path
```

Key design points

  1. Prefer an override path from _ATTN_OVERRIDES
  2. Fall back to the enum's default value when there is no override
  3. None / empty string raises a friendly error
  4. include_classname=False returns only the module path (class name stripped)
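
Given the enum values shown above, the expected behavior looks like this (a usage sketch derived from the quoted code, not executed against a specific vLLM version):

```python
from vllm.v1.attention.backends.registry import AttentionBackendEnum

# Full class path (the enum value, unless overridden):
AttentionBackendEnum.FLASH_ATTN.get_path()
# -> "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"

# Module path only:
AttentionBackendEnum.FLASH_ATTN.get_path(include_classname=False)
# -> "vllm.v1.attention.backends.flash_attn"

# CUSTOM has value None, so it must be registered first:
# AttentionBackendEnum.CUSTOM.get_path()  # raises ValueError
```
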
```python
    def get_class(self) -> "type[AttentionBackend]":
        """Get the backend class (respects overrides)."""
        return resolve_obj_by_qualname(self.get_path())
```

  • Dynamically loads the class from its qualified name

```python
    def is_overridden(self) -> bool:
        return self in _ATTN_OVERRIDES

    def clear_override(self) -> None:
        _ATTN_OVERRIDES.pop(self, None)
```

  • Check / clear an override

5.4 MambaAttentionBackendEnum (lines 124-196)

Structurally symmetric with AttentionBackendEnum, but targeting SSM/Mamba backends:

```python
    MAMBA1 = "vllm.v1.attention.backends.mamba1_attn.Mamba1AttentionBackend"
    MAMBA2 = "vllm.v1.attention.backends.mamba2_attn.Mamba2AttentionBackend"
    SHORT_CONV = "vllm.v1.attention.backends.short_conv_attn.ShortConvAttentionBackend"
    LINEAR = "vllm.v1.attention.backends.linear_attn.LinearAttentionBackend"
    GDN_ATTN = "vllm.v1.attention.backends.gdn_attn.GDNAttentionBackend"
    CUSTOM = None
```

Five SSM variants plus the custom placeholder. The methods and metaclass mirror the standard attention enum.


5.5 MAMBA_TYPE_TO_BACKEND_MAP

```python
MAMBA_TYPE_TO_BACKEND_MAP = {
    "mamba1": MambaAttentionBackendEnum.MAMBA1.name,
    "mamba2": MambaAttentionBackendEnum.MAMBA2.name,
    "short_conv": MambaAttentionBackendEnum.SHORT_CONV.name,
    "linear_attention": MambaAttentionBackendEnum.LINEAR.name,
    "gdn_attention": MambaAttentionBackendEnum.GDN_ATTN.name,
    "custom": MambaAttentionBackendEnum.CUSTOM.name,
}
```

  • Maps string type names to enum member names
  • Used by the Mamba backend lookup in selector.py

5.6 Override Storage

```python
_ATTN_OVERRIDES: dict[AttentionBackendEnum, str] = {}
_MAMBA_ATTN_OVERRIDES: dict[MambaAttentionBackendEnum, str] = {}
```

  • Module-level global dicts holding runtime overrides
  • Key: enum member; value: overriding class path
  • Being module-level means the state is shared across calls

5.7 register_backend (lines 217-263)

```python
def register_backend(
    backend: AttentionBackendEnum | MambaAttentionBackendEnum,
    class_path: str | None = None,
    is_mamba: bool = False,
) -> Callable[[type], type]:
```

A dual-purpose function

  1. Direct registration: pass class_path
  2. Decorator registration: omit class_path and a decorator is returned

```python
    def decorator(cls: type) -> type:
        if is_mamba:
            _MAMBA_ATTN_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}"
        else:
            _ATTN_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}"
        return cls
```

Decorator path: the fully qualified name is derived from the class object itself.

```python
    if class_path is not None:
        if is_mamba:
            _MAMBA_ATTN_OVERRIDES[backend] = class_path
        else:
            _ATTN_OVERRIDES[backend] = class_path
        return lambda x: x
```

Direct registration: the provided path string is stored as-is.

```python
    return decorator
```

When no class_path is given, the decorator is returned.

Usage examples

```python
# As a decorator
@register_backend(AttentionBackendEnum.FLASH_ATTN)
class MyCustomFlashAttn:
    ...

# Direct registration
register_backend(AttentionBackendEnum.CUSTOM, "my.module.MyCustomBackend")

# Registering a Mamba backend
@register_backend(MambaAttentionBackendEnum.LINEAR, is_mamba=True)
class MyCustomMambaAttn:
    ...
```

Chapter 6: utils.py Line-by-Line Analysis

6.0 File Overview

utils.py (892 lines) is the attention module's toolbox: KV cache layout management, batch manipulation, metadata-building helpers, and other shared utilities.

Line range   Content
1-40         Imports and globals
41-75        KV cache layout management
77-140       PerLayerParameters
142-220      make_local_attention_virtual_batches
222-280      make_kv_sharing_fast_prefill_common_attn_metadata
282-370      split_decodes_prefills_and_extends
372-430      split_decodes_and_prefills
432-460      split_prefill_chunks
462-530      reorder_batch_to_split_decodes_and_prefills
532-555      reshape_query_for_spec_decode / reshape_attn_output_for_spec_decode
557-580      subclass_attention_metadata
582-640      create_fast_prefill_custom_backend
642-720      compute_causal_conv1d_metadata
722-770      get_dcp_local_seq_lens
772-892      mamba_get_block_table_tensor

6.1 Imports and Globals (lines 1-40)

```python
import functools
from collections.abc import Callable
from dataclasses import dataclass, field, fields, make_dataclass
from typing import TYPE_CHECKING, Any, Literal, Protocol, get_args
```

  • make_dataclass: creates dataclasses dynamically (used to generate metadata classes at runtime)
  • Literal: literal types (used to define KVCacheLayoutType)

```python
import numpy as np
import torch
from typing_extensions import runtime_checkable
```

  • runtime_checkable: lets a Protocol participate in isinstance checks

```python
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils.math_utils import cdiv
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
```

  • cdiv: ceiling division
  • KVCacheSpec: base class for KV cache specs
  • MambaSpec: Mamba spec (SSM-specific)

```python
if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch
```

```python
import vllm.envs as envs
from vllm.distributed.kv_transfer.kv_connector.utils import (
    get_kv_connector_cache_layout,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionImpl,
    AttentionMetadata,
    CommonAttentionMetadata,
    subclass_attention_backend,
)
```

6.2 KV Cache Layout Management (lines 41-75)

```python
KVCacheLayoutType = Literal["NHD", "HND"]
```

  • Two KV cache layouts (sketched below):
    • NHD: [2, num_blocks, block_size, num_heads, head_size] --- token (block_size) dimension before heads
    • HND: [2, num_blocks, num_heads, block_size, head_size] --- head dimension before tokens
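
A quick sketch of what the two layouts mean for the allocated tensor (illustrative sizes; the shapes follow the corrected layout description above):

```python
import torch

num_blocks, block_size, num_heads, head_size = 16, 32, 8, 128

# NHD: tokens (block_size) before heads
kv_nhd = torch.empty(2, num_blocks, block_size, num_heads, head_size)

# HND: heads before tokens
kv_hnd = torch.empty(2, num_blocks, num_heads, block_size, head_size)

# The same data can be viewed in either layout by permuting the two middle dims:
assert kv_nhd.permute(0, 1, 3, 2, 4).shape == kv_hnd.shape
```
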
```python
_KV_CACHE_LAYOUT_OVERRIDE: KVCacheLayoutType | None = None
```

  • Global override variable; takes highest priority

```python
PAD_SLOT_ID = -1
NULL_BLOCK_ID = 0
```

  • PAD_SLOT_ID: slot ID used for padding (-1 marks an invalid slot)
  • NULL_BLOCK_ID: the null block's ID (block 0, used for initialization / empty requests)

```python
def is_valid_kv_cache_layout(value: str) -> bool:
    return value in get_args(KVCacheLayoutType)
```

  • Validates a layout string

```python
@functools.lru_cache
def get_kv_cache_layout():
    global _KV_CACHE_LAYOUT_OVERRIDE
    cache_layout: Literal["NHD", "HND"] | None = None
    if _KV_CACHE_LAYOUT_OVERRIDE is not None:
        cache_layout = _KV_CACHE_LAYOUT_OVERRIDE
        logger.info_once(
            "`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. "
            "Setting KV cache layout to %s.",
            cache_layout,
        )
        return cache_layout

    cache_layout = envs.VLLM_KV_CACHE_LAYOUT
    if cache_layout is None:
        cache_layout = get_kv_connector_cache_layout()
    else:
        assert is_valid_kv_cache_layout(cache_layout)
        logger.info_once(...)
    return cache_layout
```

Layout resolution priority

  1. Code-level override (_KV_CACHE_LAYOUT_OVERRIDE) --- highest
  2. Environment variable (VLLM_KV_CACHE_LAYOUT)
  3. The KV connector's default layout
  4. None if none of the above is set

@lru_cache: computed once; later calls return the cached value.

```python
def set_kv_cache_layout(cache_layout: KVCacheLayoutType | None):
    global _KV_CACHE_LAYOUT_OVERRIDE
    _KV_CACHE_LAYOUT_OVERRIDE = cache_layout
    get_kv_cache_layout.cache_clear()
```

  • Sets the override and clears the LRU cache, forcing the next get_kv_cache_layout() call to recompute

6.3 PerLayerParameters (lines 77-140)

```python
@dataclass
class PerLayerParameters:
    """Currently, FlashInfer backend only support models in which all layers share
    the same values for the following hyperparameters."""
    window_left: int
    logits_soft_cap: float | None
    sm_scale: float
    has_sinks: bool = False
    has_same_window_lefts: bool | None = field(default=None, compare=False)
    has_same_all_params: bool | None = field(default=None, compare=False)
```

Design intent: store the per-layer attention hyperparameters consumed by FlashInfer's plan phase.

  • window_left: left boundary of the sliding window (-1 means no window)
  • logits_soft_cap: soft cap applied to logits
  • sm_scale: softmax scaling factor
  • has_sinks: whether the layer has attention sinks
  • has_same_window_lefts / has_same_all_params: cross-layer consistency flags (compare=False excludes them from equality comparison)

```python
def get_per_layer_parameters(
    vllm_config: VllmConfig, layer_names: list[str], cls_: type["AttentionImpl"]
) -> dict[str, PerLayerParameters]:
    layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase, layer_names)
    per_layer_params: dict[str, PerLayerParameters] = {}
    for key, layer in layers.items():
        impl = layer.impl
        assert isinstance(impl, cls_)
        window_size = getattr(impl, "sliding_window", None)
        window_left = window_size[0] if window_size is not None else -1
        logits_soft_cap = getattr(impl, "logits_soft_cap", None)
        sm_scale = impl.scale
        has_sinks = getattr(impl, "sinks", None) is not None
        per_layer_params[key] = PerLayerParameters(
            window_left, logits_soft_cap, sm_scale, has_sinks
        )
    return per_layer_params
```

Scans the model's attention layers and extracts their hyperparameters.

```python
def infer_global_hyperparameters(
    per_layer_params: dict[str, PerLayerParameters],
) -> PerLayerParameters:
    assert len(per_layer_params) > 0, "No attention layers found in the model."
    param_sets = list(per_layer_params.values())
    global_params = param_sets[0]
    global_params.has_same_window_lefts = all(
        params.window_left == global_params.window_left for params in param_sets
    )
    global_params.has_same_all_params = all(
        params == global_params for params in param_sets
    )
    return global_params
```

Determines whether all layers share the same hyperparameters (a FlashInfer constraint) and returns the first layer's parameters with the consistency flags filled in.


6.4 make_local_attention_virtual_batches (lines 142-220)

This is the most intricate function in the file: it builds virtual batches for local (chunked) attention.

```python
def make_local_attention_virtual_batches(
    attn_chunk_size: int,
    common_attn_metadata: CommonAttentionMetadata,
    block_size: int = 0,
) -> tuple[CommonAttentionMetadata, Callable[[torch.Tensor], torch.Tensor]]:
```

Core idea: split a long sequence's attention into multiple local attention chunks, each passed to the attention kernel as an independent "virtual batch item".

Why local attention

  • Standard causal attention: every token attends to all preceding tokens
  • Local attention: every token attends only to the most recent attn_chunk_size tokens
  • Benefit: compute drops from O(n²) to O(n·chunk_size)

Worked example (translated from the code comments):

Suppose a batch of 3 sequences with:

  • q_seqlens = [4, 10, 5]
  • kv_seqlens = [6, 17, 9]
  • attn_chunk_size = 4

Standard causal attention mask (batch idx 0):

```
       k_toks >   0 1 2 3 4 5
       q_toks v
              0 | 1 1 1
              1 | 1 1 1 1
              2 | 1 1 1 1 1
              3 | 1 1 1 1 1 1
```

Local attention mask (chunk_size=4):

```
       k_toks >   0 1 2 3 4 5
       q_toks v
              0 | 1 1 1
              1 | 1 1 1 1
              2 |         1
              3 |         1 1
```

Split into virtual batches:

```
local-batch 0 (q=2, kv=4):       local-batch 1 (q=2, kv=2):
  k_toks >   0 1 2 3               k_toks >   4 5
  q_toks v                         q_toks v
         0 | 1 1 1                        2 | 1
         1 | 1 1 1 1                      3 | 1 1
```
Core computation steps

```python
    q_tokens_in_first_block = np.minimum(
        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
    ).astype(np.int32)
```

  • Computes how many query tokens fall into each request's first chunk
  • seq_lens_np - q_seqlens = number of already-computed tokens
  • Modulo chunk_size → offset of the computed tokens within the current chunk
  • chunk_size - offset = how many new tokens the first chunk can still hold
  • Take the minimum with q_seqlens (can't exceed the total query count)

```python
    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)
```

  • tokens_in_last_block: token count in the final chunk (Python's negative-modulo trick)
  • local_blocks: number of virtual chunks per request

```python
    cu_num_blocks = np.cumsum(local_blocks)
    virtual_batches = cu_num_blocks[-1]
    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
```

  • Builds the per-virtual-batch arange index and its reverse

```python
    seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
    seqlens_q_local[arange == 0] = q_tokens_in_first_block
    seqlens_q_local[arange > 0] = np.minimum(
        seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
    )[arange > 0]
```

  • Computes each virtual batch's query length
  • The first chunk may be partial
  • Subsequent chunks hold at most chunk_size

```python
    seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
```

  • Each virtual batch's KV length: middle chunks are full, the last one may be partial (see the sketch below)
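
Running these formulas on the example above (q_seqlens=[4,10,5], kv_seqlens=[6,17,9], attn_chunk_size=4) reproduces the virtual-batch split; a standalone sketch using only the arithmetic quoted above:

```python
import numpy as np

attn_chunk_size = 4
q_seqlens = np.array([4, 10, 5])
seq_lens_np = np.array([6, 17, 9])   # kv_seqlens

q_tokens_in_first_block = np.minimum(
    attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
)
tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
# cdiv(a, b) written inline as -(-a // b)
local_blocks = 1 + -(-(q_seqlens - q_tokens_in_first_block) // attn_chunk_size)

print(q_tokens_in_first_block)  # [2 1 4]
print(tokens_in_last_block)     # [2 1 1]
print(local_blocks)             # [2 4 2] -> batch 0 splits into 2 virtual batches, as diagrammed
```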

Block table construction

```python
    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
        rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
    )
    block_starts = k_seqstarts_absolute // block_size
    pages_per_local_batch = attn_chunk_size // block_size

    block_indices = block_starts[:, None] + np.arange(
        pages_per_local_batch, dtype=np.int32
    )
    block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] - 1)
    batch_indices = np.repeat(
        np.arange(actual_batch_size, dtype=np.int32),
        local_blocks * pages_per_local_batch,
    )
```

  • Computes the absolute KV start position for each virtual batch
  • Converts positions into block indices
  • Builds the new block_table: each virtual batch references only the blocks within its KV range

Return value

```python
    return CommonAttentionMetadata(
        query_start_loc_cpu=query_start_loc_cpu,
        query_start_loc=query_start_loc_cpu.to(device=device, non_blocking=True),
        seq_lens=seq_lens_cpu.to(device=device, non_blocking=True),
        ...
    ), make_block_table
```

Returns the adjusted CommonAttentionMetadata plus a block_table rebuild function.


6.5 make_kv_sharing_fast_prefill_common_attn_metadata (lines 222-280)

```python
def make_kv_sharing_fast_prefill_common_attn_metadata(
    common_attn_metadata: CommonAttentionMetadata,
) -> CommonAttentionMetadata:
```

The KV-sharing fast prefill path: under KV sharing, prefill only needs full attention at the generation positions (the last token); non-generation positions can be skipped.

```python
    if common_attn_metadata.max_query_len == 1:
        return common_attn_metadata
```

  • Pure decode batch; nothing to do

```python
    logits_indices = logits_indices_padded[:num_logits_indices]
    request_ids = torch.bucketize(logits_indices, query_start_loc[1:], right=True)
    num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)
```

  • bucketize: finds which request each logits index belongs to
  • bincount: counts the decode tokens per request

```python
    decode_query_start_loc = torch.empty(num_reqs + 1, ...)
    decode_query_start_loc[0] = 0
    decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0)
```

Builds a new query_start_loc containing only the generation-position tokens.
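
A small worked example of the bucketize + bincount trick used above (illustrative values):

```python
import torch

query_start_loc = torch.tensor([0, 3, 5, 6])   # 3 requests with query lens 3, 2, 1
logits_indices = torch.tensor([1, 2, 4])       # generation positions to keep

# Which request owns each logits index? Boundaries are the request end offsets.
request_ids = torch.bucketize(logits_indices, query_start_loc[1:], right=True)
print(request_ids)                              # tensor([0, 0, 1])

num_decode_tokens = torch.bincount(request_ids, minlength=3)
print(num_decode_tokens)                        # tensor([2, 1, 0])
```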


6.6 split_decodes_prefills_and_extends (lines 282-370)

```python
def split_decodes_prefills_and_extends(
    common_attn_metadata: CommonAttentionMetadata,
    decode_threshold: int = 1,
) -> tuple[int, int, int, int, int, int]:
```

Three-way split: partitions the batch into decode / extend / prefill.

  • decode: all prompt tokens processed; generates one token per step
  • extend: mid-flight chunked prefill; part of the prompt already processed
  • prefill: processing the prompt for the first time

Batch order: decode → extend → prefill

```python
    if max_query_len <= decode_threshold:
        return num_reqs, 0, 0, num_tokens, 0, 0
```

  • Every request is a decode

```python
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    is_prefill_or_extend = query_lens > decode_threshold
    is_prefill = (seq_lens == query_lens) & is_prefill_or_extend
```

  • query_len > threshold: not a pure decode
  • seq_len == query_len: no previously computed tokens, so a pure prefill

```python
    first_extend = is_prefill_or_extend.int().argmax(dim=-1).item()
    first_prefill = is_prefill.int().argmax(dim=-1).item()
```

  • argmax over a bool-as-int tensor finds the first True position

Returns 6 values: num_decodes, num_extends, num_prefills, plus the token count of each group.


6.7 split_decodes_and_prefills (lines 372-430)

```python
def split_decodes_and_prefills(
    common_attn_metadata: CommonAttentionMetadata,
    decode_threshold: int = 1,
    require_uniform: bool = False,
    treat_short_extends_as_decodes: bool = True,
) -> tuple[int, int, int, int]:
```

Two-way split: partitions the batch into decode / prefill only (the simplified variant).

Parameters

  • require_uniform: whether decode requests must share the same query_len (needed for CUDA Graphs)
  • treat_short_extends_as_decodes: treat short extends as decodes

```python
    if require_uniform:
        if torch.all((query_lens == query_lens[0]) | (query_lens == 0)):
            return num_reqs, 0, num_tokens, 0  # all decodes
        is_prefill = query_lens != query_lens[0]
    else:
        is_prefill = query_lens > decode_threshold
```

  • With require_uniform: any query_len differing from the first is treated as prefill
  • Otherwise: query_len > threshold is treated as prefill

```python
    if not treat_short_extends_as_decodes:
        assert common_attn_metadata.is_prefilling is not None
        is_prefill |= common_attn_metadata.is_prefilling
```

  • Strict mode: the is_prefilling flags distinguish true decodes from short extends

6.8 split_prefill_chunks (lines 432-460)

```python
def split_prefill_chunks(
    seq_lens_cpu: torch.Tensor, workspace_size: int, request_offset: int = 0
) -> list[tuple[int, int]]:
```

Splits prefill requests into workspace-sized chunks so a single long prefill can't hog GPU memory.

```python
    chunk_bounds = []
    i, n = 0, len(seq_lens_cpu)
    assert torch.all(seq_lens_cpu <= workspace_size).item()

    while i < n:
        start, chunk_total = i, 0
        while i < n and (chunk_total + (s := seq_lens_cpu[i].item())) <= workspace_size:
            chunk_total += s
            i += 1
        chunk_bounds.append((start + request_offset, i + request_offset))
    return chunk_bounds
```

Greedy bin packing: keep adding requests to the current chunk until the next one would push the token total past workspace_size.
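
A standalone sketch of the packing behavior, reproducing the quoted loop so it runs without vLLM (input values are illustrative):

```python
import torch

def split_prefill_chunks(seq_lens_cpu, workspace_size, request_offset=0):
    # Same greedy loop as quoted above, copied here so the sketch is self-contained.
    chunk_bounds = []
    i, n = 0, len(seq_lens_cpu)
    while i < n:
        start, chunk_total = i, 0
        while i < n and (chunk_total + (s := seq_lens_cpu[i].item())) <= workspace_size:
            chunk_total += s
            i += 1
        chunk_bounds.append((start + request_offset, i + request_offset))
    return chunk_bounds

print(split_prefill_chunks(torch.tensor([3, 4, 2, 6]), workspace_size=8))
# [(0, 2), (2, 4)]: 3+4=7 fits in the first chunk; adding 2 would exceed 8
```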


6.9 reorder_batch_to_split_decodes_and_prefills (lines 462-530)

```python
def reorder_batch_to_split_decodes_and_prefills(
    input_batch: "InputBatch",
    scheduler_output: "SchedulerOutput",
    decode_threshold: int = 1,
) -> bool:
```

Physically reorders the batch: rearranges the requests inside InputBatch into decode → short_extend → long_extend → prefill order.

```python
    is_pure_prefill = ~has_context
    is_long_extend = has_context & ~is_below_threshold
    is_short_extend = has_context & is_below_threshold & ~done_prefilling
    is_decode = has_context & is_below_threshold & done_prefilling
```

Four mutually exclusive classes

  1. No context → pure prefill
  2. Has context, above threshold → long extend
  3. Has context, below threshold, still prefilling → short extend
  4. Has context, below threshold, done prefilling → decode

```python
    req_regions = np.zeros(num_reqs, dtype=np.int32)  # 0 = decode by default
    req_regions[is_short_extend] = 1
    req_regions[is_long_extend] = 2
    req_regions[is_pure_prefill] = 3
```

```python
    target_regions = np.repeat(
        [0, 1, 2, 3],
        [num_decodes, num_short_extends, num_long_extends, num_prefills],
    ).astype(np.int32)

    needs_swap = req_regions != target_regions
    if not needs_swap.any():
        return False
```

If the current order already matches the target, no reordering is needed.

```python
    for src in src_dest_map:
        dst = src_dest_map[src]
        while src != dst:
            input_batch.swap_states(src, dst)
            next_dst = src_dest_map.get(dst, dst)
            src_dest_map[dst] = dst
            dst = next_dst
```

Cycle-following swaps: swap_states reorders in place, avoiding extra memory allocation.
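
The same cycle-following idea on a plain Python list (a standalone sketch; the real swap_states swaps all per-request state, and the loop assumes src_dest_map contains every displaced index, i.e. full permutation cycles):

```python
def reorder_in_place(items, src_dest_map):
    # Follow each permutation cycle, swapping until every slot is settled.
    for src in list(src_dest_map):
        dst = src_dest_map[src]
        while src != dst:
            items[src], items[dst] = items[dst], items[src]  # stand-in for swap_states
            next_dst = src_dest_map.get(dst, dst)
            src_dest_map[dst] = dst   # mark dst as settled
            dst = next_dst
    return items

# Request 0 is a prefill that belongs at slot 2; the decode at slot 2 belongs at slot 0.
print(reorder_in_place(["prefill", "decode", "decode"], {0: 2, 2: 0}))
# ['decode', 'decode', 'prefill']
```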


6.10 reshape_query_for_spec_decode / reshape_attn_output_for_spec_decode (lines 532-555)

```python
def reshape_query_for_spec_decode(query: torch.Tensor, batch_size: int) -> torch.Tensor:
    assert query.dim() == 3
    total_tokens = query.shape[0]
    num_heads = query.shape[1]
    head_dim = query.shape[2]
    seq_len = total_tokens // batch_size
    return query.view(batch_size, seq_len, num_heads, head_dim)
```

Reshapes a 3D query [total_tokens, num_heads, head_dim] into 4D [batch_size, seq_len, num_heads, head_dim]; speculative decoding expects the 4D form.

```python
def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor) -> torch.Tensor:
    if attn_output.dim() == 3:
        return attn_output
    total_tokens = attn_output.shape[0] * attn_output.shape[1]
    return attn_output.view(total_tokens, attn_output.shape[2], attn_output.shape[3])
```

The inverse: flattens the 4D output back to 3D.


6.11 subclass_attention_metadata (lines 557-580)

```python
def subclass_attention_metadata(
    name_prefix: str,
    metadata_cls: Any,
    fields: list[tuple[str, Any, Any]],
) -> Any:
    name: str = name_prefix + metadata_cls.__name__
    Wrapped = make_dataclass(name, fields, bases=(metadata_cls,))
    return Wrapped
```

Dynamic metadata subclassing: make_dataclass generates a dataclass with extra fields at runtime.

  • metadata_cls: the base metadata class
  • fields: the extra fields, as [(name, type, default)]
  • Returns the new dataclass subclass
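
A minimal standalone demonstration of the make_dataclass mechanism (class names are illustrative, not vLLM's):

```python
from dataclasses import dataclass, make_dataclass

@dataclass
class BaseMetadata:               # hypothetical stand-in for a base metadata class
    num_reqs: int

Extended = make_dataclass(
    "FastPrefillBaseMetadata",            # name_prefix + base class name
    [("num_logits_indices", int, 0)],     # extra fields: (name, type, default)
    bases=(BaseMetadata,),
)

m = Extended(num_reqs=4, num_logits_indices=2)
assert isinstance(m, BaseMetadata)        # still usable wherever the base is expected
```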

6.12 KVSharingFastPrefillMetadata Protocol + create_fast_prefill_custom_backend (lines 582-640)

```python
@runtime_checkable
class KVSharingFastPrefillMetadata(Protocol):
    logits_indices_padded: torch.Tensor | None = None
    num_logits_indices: int | None = None
```

Protocol: marks metadata classes that support KV-sharing fast prefill.

```python
def create_fast_prefill_custom_backend(
    prefix: str,
    underlying_attn_backend: type[AttentionBackend],
) -> type[AttentionBackend]:
```

Factory function: derives a fast-prefill-capable variant from an existing backend.

```python
    class FastPrefillAttentionBuilder(underlying_builder):
        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            new_common_attn_metadata = (
                make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata)
            )
            metadata = super().build(
                common_prefix_len, new_common_attn_metadata, fast_build
            )
```

  • Inherits from the underlying builder
  • build() first converts the CommonAttentionMetadata into its fast-prefill form
  • Then calls the parent build

```python
            class KVSharingFastPrefillAttentionMetadata(
                metadata.__class__,
                KVSharingFastPrefillMetadata,
            ):
                def __init__(self, metadata, common_attn_metadata):
                    for _field in fields(metadata.__class__):
                        setattr(self, _field.name, getattr(metadata, _field.name))
                    self.logits_indices_padded = (
                        common_attn_metadata.logits_indices_padded
                    )
                    self.num_logits_indices = common_attn_metadata.num_logits_indices
```

Dynamic metadata class: multiply inherits from the underlying metadata class and the Protocol, shallow-copies every field, and adds the logits index info.

```python
    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=FastPrefillAttentionBuilder,
    )
    return attn_backend
```

Finally, the dynamic subclass factory produces the new Backend class.


6.13 compute_causal_conv1d_metadata (lines 642-720)

```python
def compute_causal_conv1d_metadata(
    query_start_loc_p_cpu: torch.Tensor,
    *,
    device: torch.device,
):
```

Computes metadata for the causal_conv1d kernel (used by Mamba/SSM).

```python
    for BLOCK_M in [8]:
        nums = -(-seqlens // BLOCK_M)
```

  • For each sequence, counts how many BLOCK_M-sized tiles it needs
  • -(-a // b) is ceiling division

```python
        mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
```

  • Batch index list: the sequence index each tile belongs to

```python
        batch_ptr = torch.full(
            (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=device
        )
        batch_ptr[0:mlist_len].copy_(mlist, non_blocking=True)
```

  • Pre-allocates a GPU buffer and asynchronously copies the CPU data into it

Returns nums_dict, holding the batch pointers and offset pointers for each BLOCK_M.


6.14 get_dcp_local_seq_lens (lines 722-770)

```python
def get_dcp_local_seq_lens(
    seq_lens: torch.Tensor,
    dcp_size: int = 1,
    dcp_rank: int | None = None,
    cp_kv_cache_interleave_size: int = 1,
) -> torch.Tensor:
```

DCP (Decode Context Parallelism) sequence-length partitioning: splits each sequence length across DCP ranks so every rank handles only its slice of the KV cache.

```python
    if dcp_rank is None:
        rank_offsets = (
            torch.arange(dcp_size, ...)
            .unsqueeze(0).repeat(num_requests, 1)
        )
    else:
        rank_offsets = torch.tensor([[dcp_rank]], ...)
```

  • dcp_rank=None: compute the distribution across all ranks
  • Otherwise compute it only for the given rank

```python
    base = seq_lens_tiled // cp_kv_cache_interleave_size // dcp_size * cp_kv_cache_interleave_size
    remainder = seq_lens_tiled - base * dcp_size
    remainder = torch.clip(
        remainder - rank_offsets * cp_kv_cache_interleave_size,
        0,
        cp_kv_cache_interleave_size,
    )
    dcp_local_seq_lens = base + remainder
```

Allocation scheme

  1. base: the evenly distributed length every rank gets
  2. remainder: the leftover tokens; earlier ranks take a bit more
  3. cp_kv_cache_interleave_size accounts for the interleaving factor
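
A worked example using the formula quoted above (the seq_lens_tiled construction is my reconstruction of the omitted tiling step, so treat it as an assumption): with seq_len = 10, dcp_size = 4, and interleave = 1, base = 10 // 1 // 4 * 1 = 2 and remainder = 10 - 2*4 = 2, so ranks 0 and 1 each absorb one leftover token.

```python
import torch

seq_lens, dcp_size, interleave = torch.tensor([10]), 4, 1
rank_offsets = torch.arange(dcp_size).unsqueeze(0)        # all ranks
seq_lens_tiled = seq_lens.unsqueeze(1).repeat(1, dcp_size)  # assumed tiling step

base = seq_lens_tiled // interleave // dcp_size * interleave
remainder = seq_lens_tiled - base * dcp_size
remainder = torch.clip(remainder - rank_offsets * interleave, 0, interleave)
print(base + remainder)   # tensor([[3, 3, 2, 2]]) -- sums back to 10
```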

6.15 mamba_get_block_table_tensor (lines 772-892)

```python
def mamba_get_block_table_tensor(
    block_table: torch.Tensor,
    seq_lens: torch.Tensor,
    kv_cache_spec: KVCacheSpec,
    mamba_cache_mode: str,
) -> torch.Tensor:
```

Block table adaptation for Mamba/SSM: different Mamba cache modes require different block_table slices.

Three modes:

```python
    if mamba_cache_mode in ("all", "none"):
        return block_table
```

  • "all": use the full block_table
  • "none": pass through the input block_table (already pre-sliced)

```python
    else:  # "align"
        assert isinstance(kv_cache_spec, MambaSpec)
        start_indices = torch.clamp(
            (seq_lens - 1) // kv_cache_spec.block_size,
            min=0,
        )
        offsets = torch.arange(
            1 + kv_cache_spec.num_speculative_blocks,
            device=block_table.device,
            dtype=torch.int32,
        )
        indices_to_gather = (start_indices.unsqueeze(1) + offsets).to(torch.int64)
        return torch.gather(block_table, 1, indices_to_gather)
```

The "align" mode

  • Keeps only the last 1 + num_speculative_blocks blocks per request
  • start_indices: the block position each request's sequence currently occupies
  • torch.gather: extracts those positions from block_table

Why: Mamba's SSM state depends only on recent tokens, so the full block_table is unnecessary; the trailing few blocks suffice for speculative decoding.
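
A concrete illustration of the "align" slicing (illustrative sizes: block_size=4, one speculative block):

```python
import torch

block_table = torch.arange(12).reshape(2, 6)    # 2 requests, 6 block slots each
seq_lens = torch.tensor([5, 9])
block_size, num_speculative_blocks = 4, 1

start_indices = torch.clamp((seq_lens - 1) // block_size, min=0)  # tensor([1, 2])
offsets = torch.arange(1 + num_speculative_blocks)                 # tensor([0, 1])
idx = (start_indices.unsqueeze(1) + offsets).to(torch.int64)
print(torch.gather(block_table, 1, idx))
# tensor([[1, 2],
#         [8, 9]])  -- the current block plus one speculative block per request
```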


Appendix: Core Design Patterns at a Glance

Pattern            Where                                            Notes
Strategy           AttentionBackend → AttentionImpl                 Interchangeable backends behind a uniform interface
Factory method     get_impl_cls(), get_builder_cls()                Backend classes act as factories for their impl classes
Template method    AttentionImplBase.__new__()                      Subclasses automatically pick up the CP configuration
Registry           registry.py                                      Enum + decorator registration + runtime overrides
Decorator          register_backend()                               Class-registration decorator
Delegation/proxy   selector.py → platform.get_attn_backend_cls()    Delegates the actual choice to the platform
Memoization        @cache on _cached_get_attn_backend               Avoids repeated selection
Composition        create_fast_prefill_custom_backend               Composes a new backend from an existing one
Protocol           AttentionLayer, KVSharingFastPrefillMetadata     Structural subtyping (duck typing + static checks)
Dynamic subclass   subclass_attention_backend                       Creates subclasses at runtime via type()

This document belongs to Part 1 of the vLLM v1 Attention module series, a line-by-line analysis of the 4 core abstraction-layer files (2,324 lines in total).

Part 2 will cover the concrete backend implementations (backends/*.py), and Part 3 the ops layer (ops/*.py).
