Chapter 4: selector.py Line-by-Line Walkthrough
4.0 File Overview
selector.py (165 lines) is the entry module for backend selection: it picks the best attention backend for the current runtime configuration.
| Line range | Contents |
|---|---|
| 1-15 | Imports |
| 17-47 | AttentionSelectorConfig NamedTuple |
| 49-95 | get_attn_backend() main entry point |
| 97-125 | _cached_get_attn_backend() cached selection |
| 127-165 | Mamba backend selection |
4.1 Imports
python
from functools import cache
from typing import NamedTuple, cast, get_args
- cache: function-level memoization decorator (unbounded, equivalent to lru_cache(maxsize=None); Python 3.9+)
- NamedTuple: base class for named tuples
- cast: type cast (only meaningful to type checkers; a no-op at runtime)
- get_args: retrieves the allowed values of a Literal type
python
import torch
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.v1.attention.backend import AttentionBackend, AttentionType
from vllm.v1.attention.backends.registry import (
MAMBA_TYPE_TO_BACKEND_MAP,
MambaAttentionBackendEnum,
)
- resolve_obj_by_qualname: dynamically imports a class by its fully qualified name
- MAMBA_TYPE_TO_BACKEND_MAP: mapping from Mamba type string to backend enum
4.2 AttentionSelectorConfig (lines 17-47)
python
class AttentionSelectorConfig(NamedTuple):
head_size: int
dtype: torch.dtype
kv_cache_dtype: CacheDType | None
block_size: int | None
use_mla: bool = False
has_sink: bool = False
use_sparse: bool = False
use_mm_prefix: bool = False
use_per_head_quant_scales: bool = False
attn_type: str = AttentionType.DECODER
use_non_causal: bool = False
Design intent:
- A NamedTuple is immutable, has value semantics, and is hashable, making it suitable as a cache key (see the sketch below)
- It encapsulates every configuration dimension that affects backend selection
- Each field corresponds to one check dimension in AttentionBackend.validate_configuration()
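A quick standalone check of the cache-key property (the Key class is illustrative, not vLLM code):
python
from typing import NamedTuple

class Key(NamedTuple):
    head_size: int
    use_mla: bool = False

a, b = Key(64), Key(64)
assert a == b and hash(a) == hash(b)   # value semantics: works as a dict/cache key
assert Key(64, use_mla=True) != a      # any differing field yields a different key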
python
def __repr__(self):
return (
f"AttentionSelectorConfig(head_size={self.head_size}, "
f"dtype={self.dtype}, "
f"kv_cache_dtype={self.kv_cache_dtype}, "
f"block_size={self.block_size}, "
f"use_mla={self.use_mla}, "
f"has_sink={self.has_sink}, "
f"use_sparse={self.use_sparse}, "
f"use_mm_prefix={self.use_mm_prefix}, "
f"use_per_head_quant_scales={self.use_per_head_quant_scales}, "
f"attn_type={self.attn_type}, "
f"use_non_causal={self.use_non_causal})"
)
The custom repr keeps debug logs readable.
4.3 get_attn_backend (lines 49-95) --- main entry point
python
def get_attn_backend(
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
use_mm_prefix: bool = False,
use_per_head_quant_scales: bool = False,
attn_type: str | None = None,
num_heads: int | None = None,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
This is the single entry point through which external code requests a backend.
python
if kv_cache_dtype is not None:
valid_cache_dtypes = get_args(CacheDType)
assert kv_cache_dtype in valid_cache_dtypes, (
f"Invalid kv_cache_dtype: {kv_cache_dtype}. "
f"Valid values are: {valid_cache_dtypes}"
)
- Validates that kv_cache_dtype is one of the allowed values
python
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
- Fetches the current thread's VllmConfig (stored in a context variable)
python
cache_config = vllm_config.cache_config
if cache_config is not None and cache_config.user_specified_block_size:
block_size = cache_config.block_size
else:
block_size = None
- Reads block_size from the cache config
- It is honored only when the user set it explicitly; otherwise the backend picks its own
python
speculative_config = vllm_config.speculative_config
use_non_causal = (
speculative_config is not None and speculative_config.method == "dflash"
)
- Speculative decoding with the dflash method requires non-causal attention
python
attn_selector_config = AttentionSelectorConfig(
head_size=head_size,
dtype=dtype,
kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
block_size=block_size,
use_mla=use_mla,
has_sink=has_sink,
use_sparse=use_sparse,
use_mm_prefix=use_mm_prefix,
use_per_head_quant_scales=use_per_head_quant_scales,
attn_type=attn_type or AttentionType.DECODER,
use_non_causal=use_non_causal,
)
- Builds the selector config; the attention type defaults to DECODER
python
return _cached_get_attn_backend(
backend=vllm_config.attention_config.backend,
attn_selector_config=attn_selector_config,
num_heads=num_heads,
)
- Calls the cached variant, passing the user-specified backend name and the config
4.4 _cached_get_attn_backend (lines 97-125)
python
@cache
def _cached_get_attn_backend(
backend,
attn_selector_config: AttentionSelectorConfig,
num_heads: int | None = None,
) -> type[AttentionBackend]:
The @cache decorator runs the selection logic only once per distinct argument tuple; subsequent calls return the cached result.
Why caching is needed:
- Backend selection involves dynamic imports and platform detection, which are not cheap
- For a fixed configuration the chosen backend never changes, so re-selection is wasted work
- AttentionSelectorConfig is an immutable, hashable NamedTuple, making it a valid cache key (demonstrated below)
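A minimal demonstration of the memoization behavior (the counter is illustrative, not part of vLLM):
python
from functools import cache

calls = 0

@cache
def select(cfg: tuple) -> str:
    global calls
    calls += 1                 # stands in for the expensive selection logic
    return f"backend-for-{cfg}"

select((64, "bf16"))
select((64, "bf16"))           # served from the cache; the body never re-runs
assert calls == 1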
python
from vllm.platforms import current_platform
attention_cls = current_platform.get_attn_backend_cls(
backend,
attn_selector_config=attn_selector_config,
num_heads=num_heads,
)
- Platform delegation: the actual selection logic is handed off to the current platform
- Each platform (CUDA/ROCm/CPU/XPU) implements its own selection strategy
- The backend argument is the user-specified backend name, or None for automatic selection
python
if not attention_cls:
raise ValueError(
f"Invalid attention backend for {current_platform.device_name}"
)
- Raises if no backend matches
python
backend = resolve_obj_by_qualname(attention_cls)
- Resolves the fully qualified name string into a class object
- Lazy import: the backend module is imported only when actually needed
python
required_layout = backend.get_required_kv_cache_layout()
if required_layout is not None:
from vllm.v1.attention.backends.utils import set_kv_cache_layout
set_kv_cache_layout(required_layout)
logger.info(
"Using %s KV cache layout for %s backend.",
required_layout,
backend.get_name(),
)
return backend
- KV cache layout adaptation: if the backend requires a specific layout, it is set globally
- set_kv_cache_layout() updates the global variable and clears the cache
- This happens once the backend choice is final, ensuring later KV cache allocations use the right layout
4.5 Mamba backend selection (lines 127-165)
python
def get_mamba_attn_backend(
mamba_type: str,
) -> type[AttentionBackend]:
"""Select which mamba attention backend to use and lazily import it."""
return _cached_get_mamba_attn_backend(mamba_type)
The selection entry point for Mamba/SSM backends; the logic is simpler --- a direct lookup keyed by the type string.
python
@cache
def _cached_get_mamba_attn_backend(
mamba_type: str,
) -> type[AttentionBackend]:
assert mamba_type and isinstance(mamba_type, str)
python
selected_backend = None
try:
backend_name = MAMBA_TYPE_TO_BACKEND_MAP[mamba_type]
selected_backend = MambaAttentionBackendEnum[backend_name]
except KeyError as e:
raise ValueError(
f"Invalid mamba attention backend type: '{mamba_type}'. Valid "
f"types are: {list(MAMBA_TYPE_TO_BACKEND_MAP.keys())}"
) from e
- Looks up the corresponding enum member in MAMBA_TYPE_TO_BACKEND_MAP
- Raises a friendly error if the type is unknown
python
mamba_attn_backend = selected_backend.get_class()
return mamba_attn_backend
- get_class() on the enum member returns the actual class
- This triggers the lazy import
Chapter 5: registry.py Line-by-Line Walkthrough
5.0 File Overview
registry.py (263 lines) implements the attention backend registry, supporting declarative enums, runtime overrides, and decorator-based registration.
| Line range | Contents |
|---|---|
| 1-11 | Imports |
| 13-30 | _AttentionBackendEnumMeta metaclass |
| 32-122 | AttentionBackendEnum enum |
| 124-196 | MambaAttentionBackendEnum enum |
| 198-210 | MAMBA_TYPE_TO_BACKEND_MAP mapping |
| 212-215 | Override storage dicts |
| 217-263 | register_backend() registration function |
5.1 Imports
python
from collections.abc import Callable
from enum import Enum, EnumMeta
from typing import TYPE_CHECKING, cast
from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname
- EnumMeta: the metaclass of Enum, used here to customize enum behavior
- resolve_obj_by_qualname: dynamic class loading
5.2 _AttentionBackendEnumMeta metaclass (lines 13-30)
python
class _AttentionBackendEnumMeta(EnumMeta):
"""Metaclass for AttentionBackendEnum to provide better error messages."""
def __getitem__(cls, name: str):
"""Get backend by name with helpful error messages."""
try:
return super().__getitem__(name)
except KeyError:
members = cast("dict[str, Enum]", cls.__members__).keys()
valid_backends = ", ".join(members)
raise ValueError(
f"Unknown attention backend: '{name}'. "
f"Valid options are: {valid_backends}"
) from None
Design intent:
- Overrides __getitem__ (the Enum['NAME'] operation)
- Replaces the default KeyError with a ValueError listing all valid options
- from None suppresses the original exception chain, keeping the traceback clean
Use case: when a user configures an invalid backend name, they get a helpful error message, as the standalone sketch below shows.
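The same metaclass trick in isolation (a hypothetical Color enum, not vLLM code):
python
from enum import Enum, EnumMeta

class FriendlyEnumMeta(EnumMeta):
    def __getitem__(cls, name: str):
        try:
            return super().__getitem__(name)
        except KeyError:
            valid = ", ".join(cls.__members__)
            raise ValueError(
                f"Unknown member: '{name}'. Valid options are: {valid}"
            ) from None

class Color(Enum, metaclass=FriendlyEnumMeta):
    RED = 1
    GREEN = 2

Color["RED"]     # fine
# Color["BLUE"]  # -> ValueError: Unknown member: 'BLUE'. Valid options are: RED, GREEN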
5.3 AttentionBackendEnum enum (lines 32-122)
python
class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
"""Enumeration of all supported attention backends."""
Each enum member's value is a fully qualified class path, used for lazy importing.
5.3.1 Backend enum members
python
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
- The preferred backend on NVIDIA GPUs, using FlashAttention-2/3 kernels
python
FLASH_ATTN_DIFFKV = "vllm.v1.attention.backends.flash_attn_diffkv.FlashAttentionDiffKVBackend"
- The DiffKV variant of FlashAttention
python
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
- Attention kernels written in Triton; the generic GPU fallback
python
ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
- The basic attention backend for AMD ROCm
python
ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
ROCM_AITER_TRITON_MLA = "vllm.v1.attention.backends.mla.aiter_triton_mla.AiterTritonMLABackend"
ROCM_AITER_FA = "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
ROCM_AITER_MLA_SPARSE = "vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse.ROCMAiterMLASparseBackend"
- The AMD ROCm AITER backend family (AITER: AMD's AI Tensor Engine for ROCm)
python
XPU_MLA_SPARSE = "vllm.v1.attention.backends.mla.xpu_mla_sparse.XPUMLASparseBackend"
- The Intel XPU MLA sparse backend
python
TORCH_SDPA = "" # this tag is only used for ViT
- PyTorch's native SDPA (Scaled Dot-Product Attention); the empty string marks it as used only for ViT
python
FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
- The FlashInfer backend; high-performance prefill and decode
python
FLASHINFER_MLA = "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
FLASHINFER_MLA_SPARSE = "vllm.v1.attention.backends.mla.flashinfer_mla_sparse.FlashInferMLASparseBackend"
- FlashInfer MLA variants
python
TRITON_MLA = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
CUTLASS_MLA = "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
FLASHMLA = "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
FLASHMLA_SPARSE = "vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend"
FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
- Assorted MLA backend implementations
python
NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend"
- The no-attention backend (for certain special model layers)
python
FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
- The PyTorch FlexAttention backend (experimental)
python
TREE_ATTN = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"
- The tree attention backend (tree-structured verification for speculative decoding)
python
ROCM_AITER_UNIFIED_ATTN = "vllm.v1.attention.backends.rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend"
- The ROCm AITER unified attention backend
python
CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
- The CPU attention backend
python
TURBOQUANT = "vllm.v1.attention.backends.turboquant_attn.TurboQuantAttentionBackend"
- The TurboQuant quantized attention backend
python
CUSTOM = None
- Placeholder for third-party/custom backends
- Its value is None (avoiding a clash with TORCH_SDPA's empty string)
- It must be registered via register_backend() before use
5.3.2 Enum methods
python
def get_path(self, include_classname: bool = True) -> str:
"""Get the class path for this backend (respects overrides)."""
path = _ATTN_OVERRIDES.get(self, self.value)
if not path:
raise ValueError(
f"Backend {self.name} must be registered before use. "
f"Use register_backend(Backend.{self.name}, 'your.module.YourClass')"
)
if not include_classname:
path = path.rsplit(".", 1)[0]
return path
Key design points:
- An override path in _ATTN_OVERRIDES takes precedence
- Without an override, the enum's default value is used
- None or an empty string raises a friendly error
- include_classname=False returns only the module path with the class name stripped (see the sketch below)
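A plain-Python sketch of the override-respecting lookup (the _OVERRIDES dict and the example path are hypothetical stand-ins):
python
_OVERRIDES: dict[str, str] = {}

def get_path(name: str, default: str | None, include_classname: bool = True) -> str:
    path = _OVERRIDES.get(name, default)
    if not path:
        raise ValueError(f"Backend {name} must be registered before use.")
    if not include_classname:
        path = path.rsplit(".", 1)[0]   # strip the trailing class name
    return path

print(get_path("FLASH_ATTN", "pkg.mod.FlashAttentionBackend"))
# -> pkg.mod.FlashAttentionBackend
print(get_path("FLASH_ATTN", "pkg.mod.FlashAttentionBackend", include_classname=False))
# -> pkg.mod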
python
def get_class(self) -> "type[AttentionBackend]":
"""Get the backend class (respects overrides)."""
return resolve_obj_by_qualname(self.get_path())
- Dynamically loads the class by its qualified name
python
def is_overridden(self) -> bool:
return self in _ATTN_OVERRIDES
def clear_override(self) -> None:
_ATTN_OVERRIDES.pop(self, None)
- Check / clear the override
5.4 MambaAttentionBackendEnum (lines 124-196)
Structurally symmetric with AttentionBackendEnum, but covering SSM/Mamba backends:
python
MAMBA1 = "vllm.v1.attention.backends.mamba1_attn.Mamba1AttentionBackend"
MAMBA2 = "vllm.v1.attention.backends.mamba2_attn.Mamba2AttentionBackend"
SHORT_CONV = "vllm.v1.attention.backends.short_conv_attn.ShortConvAttentionBackend"
LINEAR = "vllm.v1.attention.backends.linear_attn.LinearAttentionBackend"
GDN_ATTN = "vllm.v1.attention.backends.gdn_attn.GDNAttentionBackend"
CUSTOM = None
Five SSM variants plus the custom placeholder. The methods and metaclass mirror the standard attention enum.
5.5 MAMBA_TYPE_TO_BACKEND_MAP
python
MAMBA_TYPE_TO_BACKEND_MAP = {
"mamba1": MambaAttentionBackendEnum.MAMBA1.name,
"mamba2": MambaAttentionBackendEnum.MAMBA2.name,
"short_conv": MambaAttentionBackendEnum.SHORT_CONV.name,
"linear_attention": MambaAttentionBackendEnum.LINEAR.name,
"gdn_attention": MambaAttentionBackendEnum.GDN_ATTN.name,
"custom": MambaAttentionBackendEnum.CUSTOM.name,
}
- Maps type-name strings to enum member names
- Used for the Mamba backend lookup in selector.py
5.6 Override storage
python
_ATTN_OVERRIDES: dict[AttentionBackendEnum, str] = {}
_MAMBA_ATTN_OVERRIDES: dict[MambaAttentionBackendEnum, str] = {}
- Module-level global dicts storing the runtime overrides
- Key: enum member; value: the overriding class path
- Module-level state means it is shared across calls
5.7 register_backend (lines 217-263)
python
def register_backend(
backend: AttentionBackendEnum | MambaAttentionBackendEnum,
class_path: str | None = None,
is_mamba: bool = False,
) -> Callable[[type], type]:
A dual-purpose function:
- Direct registration: pass the class_path argument
- Decorator registration: omit class_path and a decorator is returned
python
def decorator(cls: type) -> type:
if is_mamba:
_MAMBA_ATTN_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}"
else:
_ATTN_OVERRIDES[backend] = f"{cls.__module__}.{cls.__qualname__}"
return cls
Decorator mode: the fully qualified name is derived automatically from the class object.
python
if class_path is not None:
if is_mamba:
_MAMBA_ATTN_OVERRIDES[backend] = class_path
else:
_ATTN_OVERRIDES[backend] = class_path
return lambda x: x
Direct registration: the provided path string is stored as-is.
python
return decorator
Without class_path, the decorator is returned.
Usage examples:
python
# Decorator style
@register_backend(AttentionBackendEnum.FLASH_ATTN)
class MyCustomFlashAttn:
...
# Direct registration
register_backend(AttentionBackendEnum.CUSTOM, "my.module.MyCustomBackend")
# Registering a Mamba backend
@register_backend(MambaAttentionBackendEnum.LINEAR, is_mamba=True)
class MyCustomMambaAttn:
...
Chapter 6: utils.py Line-by-Line Walkthrough
6.0 File Overview
utils.py (892 lines) is the attention module's toolbox, providing KV cache layout management, batch manipulation, metadata-building helpers, and other shared utilities.
| Line range | Contents |
|---|---|
| 1-40 | Imports and globals |
| 41-75 | KV cache layout management |
| 77-140 | PerLayerParameters |
| 142-220 | make_local_attention_virtual_batches |
| 222-280 | make_kv_sharing_fast_prefill_common_attn_metadata |
| 282-370 | split_decodes_prefills_and_extends |
| 372-430 | split_decodes_and_prefills |
| 432-460 | split_prefill_chunks |
| 462-530 | reorder_batch_to_split_decodes_and_prefills |
| 532-555 | reshape_query_for_spec_decode / reshape_attn_output_for_spec_decode |
| 557-580 | subclass_attention_metadata |
| 582-640 | create_fast_prefill_custom_backend |
| 642-720 | compute_causal_conv1d_metadata |
| 722-770 | get_dcp_local_seq_lens |
| 772-892 | mamba_get_block_table_tensor |
6.1 Imports and globals (lines 1-40)
python
import functools
from collections.abc import Callable
from dataclasses import dataclass, field, fields, make_dataclass
from typing import TYPE_CHECKING, Any, Literal, Protocol, get_args
- make_dataclass: creates a dataclass dynamically (used to generate metadata classes at runtime)
- Literal: literal types (used for the KVCacheLayoutType definition)
python
import numpy as np
import torch
from typing_extensions import runtime_checkable
- runtime_checkable: lets a Protocol support isinstance checks
python
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils.math_utils import cdiv
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
- cdiv: ceiling division
- KVCacheSpec: base class for KV cache specs
- MambaSpec: the Mamba spec (SSM-specific)
python
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
python
import vllm.envs as envs
from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionMetadata,
CommonAttentionMetadata,
subclass_attention_backend,
)
6.2 KV cache layout management (lines 41-75)
python
KVCacheLayoutType = Literal["NHD", "HND"]
- Two KV cache layouts (concrete shapes sketched below):
  - NHD: [2, num_blocks, block_size, num_heads, head_size] --- block dimensions first
  - HND: [2, num_heads, num_blocks, block_size, head_size] --- head dimension first
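The same two layouts as concrete tensor shapes (illustrative sizes):
python
import torch

num_blocks, block_size, num_heads, head_size = 16, 16, 8, 64
nhd = torch.empty(2, num_blocks, block_size, num_heads, head_size)  # NHD
hnd = torch.empty(2, num_heads, num_blocks, block_size, head_size)  # HND
print(nhd.shape)  # torch.Size([2, 16, 16, 8, 64])
print(hnd.shape)  # torch.Size([2, 8, 16, 16, 64])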
python
_KV_CACHE_LAYOUT_OVERRIDE: KVCacheLayoutType | None = None
- Global override variable; takes the highest priority
python
PAD_SLOT_ID = -1
NULL_BLOCK_ID = 0
- PAD_SLOT_ID: the ID of a padding slot (-1 marks it invalid)
- NULL_BLOCK_ID: the ID of the null block (block 0, used for initialization and empty requests)
python
def is_valid_kv_cache_layout(value: str) -> bool:
return value in get_args(KVCacheLayoutType)
- Validates that a layout string is legal
python
@functools.lru_cache
def get_kv_cache_layout():
global _KV_CACHE_LAYOUT_OVERRIDE
cache_layout: Literal["NHD", "HND"] | None = None
if _KV_CACHE_LAYOUT_OVERRIDE is not None:
cache_layout = _KV_CACHE_LAYOUT_OVERRIDE
logger.info_once(
"`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. "
"Setting KV cache layout to %s.",
cache_layout,
)
return cache_layout
cache_layout = envs.VLLM_KV_CACHE_LAYOUT
if cache_layout is None:
cache_layout = get_kv_connector_cache_layout()
else:
assert is_valid_kv_cache_layout(cache_layout)
logger.info_once(...)
return cache_layout
Layout resolution priority:
- Code-level override (_KV_CACHE_LAYOUT_OVERRIDE) --- highest
- Environment variable (VLLM_KV_CACHE_LAYOUT)
- The KV connector's default layout
- None if none of the above apply
@lru_cache: computed once; subsequent calls return the cached value.
python
def set_kv_cache_layout(cache_layout: KVCacheLayoutType | None):
global _KV_CACHE_LAYOUT_OVERRIDE
_KV_CACHE_LAYOUT_OVERRIDE = cache_layout
get_kv_cache_layout.cache_clear()
- Sets the override and clears the LRU cache, forcing the next get_kv_cache_layout() call to recompute (a standalone sketch follows)
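A minimal sketch of the cache-then-invalidate pattern (not the vLLM source):
python
import functools

_override = None

@functools.lru_cache
def get_layout():
    return _override or "NHD"    # stands in for the full resolution chain

def set_layout(value):
    global _override
    _override = value
    get_layout.cache_clear()     # force the next call to recompute

assert get_layout() == "NHD"
set_layout("HND")
assert get_layout() == "HND"     # recomputed after cache_clear()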
6.3 PerLayerParameters (lines 77-140)
python
@dataclass
class PerLayerParameters:
"""Currently, FlashInfer backend only support models in which all layers share
the same values for the following hyperparameters."""
window_left: int
logits_soft_cap: float | None
sm_scale: float
has_sinks: bool = False
has_same_window_lefts: bool | None = field(default=None, compare=False)
has_same_all_params: bool | None = field(default=None, compare=False)
Design intent: stores each attention layer's hyperparameters for FlashInfer's plan phase.
- window_left: left boundary of the sliding window (-1 means no window)
- logits_soft_cap: soft cap applied to the logits
- sm_scale: softmax scaling factor
- has_sinks: whether attention sinks are present
- has_same_window_lefts / has_same_all_params: cross-layer consistency flags (compare=False keeps them out of equality comparison)
python
def get_per_layer_parameters(
vllm_config: VllmConfig, layer_names: list[str], cls_: type["AttentionImpl"]
) -> dict[str, PerLayerParameters]:
layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase, layer_names)
per_layer_params: dict[str, PerLayerParameters] = {}
for key, layer in layers.items():
impl = layer.impl
assert isinstance(impl, cls_)
window_size = getattr(impl, "sliding_window", None)
window_left = window_size[0] if window_size is not None else -1
logits_soft_cap = getattr(impl, "logits_soft_cap", None)
sm_scale = impl.scale
has_sinks = getattr(impl, "sinks", None) is not None
per_layer_params[key] = PerLayerParameters(
window_left, logits_soft_cap, sm_scale, has_sinks
)
return per_layer_params
Scans the model's attention layers and extracts their hyperparameters.
python
def infer_global_hyperparameters(
per_layer_params: dict[str, PerLayerParameters],
) -> PerLayerParameters:
assert len(per_layer_params) > 0, "No attention layers found in the model."
param_sets = list(per_layer_params.values())
global_params = param_sets[0]
global_params.has_same_window_lefts = all(
params.window_left == global_params.window_left for params in param_sets
)
global_params.has_same_all_params = all(
params == global_params for params in param_sets
)
return global_params
Takes the first layer's parameters as the global set and records whether all layers agree (FlashInfer requires the layers to be consistent).
6.4 make_local_attention_virtual_batches (lines 142-220)
The most complex function in this file: it builds virtual batches for local attention.
python
def make_local_attention_virtual_batches(
attn_chunk_size: int,
common_attn_metadata: CommonAttentionMetadata,
block_size: int = 0,
) -> tuple[CommonAttentionMetadata, Callable[[torch.Tensor], torch.Tensor]]:
Core idea: split a long sequence's attention into multiple local attention chunks, each handed to the attention kernel as an independent "virtual batch item".
Why local attention:
- Standard causal attention: every token attends to all preceding tokens
- Local attention: every token attends only to the most recent attn_chunk_size tokens
- Benefit: compute drops from O(n²) to O(n·chunk_size)
Worked example (translated from the code comments):
Suppose a batch of 3 sequences:
- q_seqlens = [4, 10, 5]
- kv_seqlens = [6, 17, 9]
- attn_chunk_size = 4
Standard causal attention mask (batch idx 0):
k_toks > 0 1 2 3 4 5
q_toks v
0 | 1 1 1
1 | 1 1 1 1
2 | 1 1 1 1 1
3 | 1 1 1 1 1 1
Local attention mask (chunk_size=4):
k_toks > 0 1 2 3 4 5
q_toks v
0 | 1 1 1
1 | 1 1 1 1
2 | 1
3 | 1 1
Split into virtual batches:
local-batch 0 (q=2, kv=4): local-batch 1 (q=2, kv=2):
k_toks > 0 1 2 3 k_toks > 4 5
q_toks v q_toks v
0 | 1 1 1 2 | 1
1 | 1 1 1 1 3 | 1 1
Core computation steps:
python
q_tokens_in_first_block = np.minimum(
attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens
).astype(np.int32)
- Computes the number of query tokens in each request's first chunk
- seq_lens_np - q_seqlens = the number of already-computed tokens
- Taking it modulo chunk_size gives those tokens' offset within the current chunk
- chunk_size - offset = how many new tokens the first chunk can still hold
- Take the minimum with q_seqlens (cannot exceed the total query count)
python
tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size)
- tokens_in_last_block: token count of the final chunk (a Python negative-modulo trick; both formulas are checked below)
- local_blocks: the number of virtual chunks per request
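Running both formulas on the walkthrough's example numbers (a standalone NumPy check; cdiv is inlined as negative floor division):
python
import numpy as np

chunk = 4
q_seqlens = np.array([4, 10, 5])
kv_seqlens = np.array([6, 17, 9])

first = np.minimum(chunk - ((kv_seqlens - q_seqlens) % chunk), q_seqlens)
print(first)         # [2 1 4]  query tokens in each request's first chunk

# Negative modulo: 6 % -4 == -2, so chunk + (n % -chunk) gives the size of
# the final (possibly partial) chunk.
last = chunk + (kv_seqlens % -chunk)
print(last)          # [2 1 1]

local_blocks = 1 + -(-(q_seqlens - first) // chunk)   # cdiv
print(local_blocks)  # [2 4 2]  virtual batches per request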
python
cu_num_blocks = np.cumsum(local_blocks)
virtual_batches = cu_num_blocks[-1]
block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
rarange = np.repeat(local_blocks, local_blocks) - arange - 1
- Builds the per-virtual-batch arange index and its reverse
python
seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
seqlens_q_local[arange == 0] = q_tokens_in_first_block
seqlens_q_local[arange > 0] = np.minimum(
seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size
)[arange > 0]
- Computes each virtual batch's query length
- The first chunk may be partial
- Subsequent chunks hold at most chunk_size
python
seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32)
seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
- Each virtual batch's KV length: middle chunks are full; the last one may be partial
Block table construction:
python
k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - (
rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks)
)
block_starts = k_seqstarts_absolute // block_size
pages_per_local_batch = attn_chunk_size // block_size
block_indices = block_starts[:, None] + np.arange(
pages_per_local_batch, dtype=np.int32
)
block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] - 1)
batch_indices = np.repeat(
np.arange(actual_batch_size, dtype=np.int32),
local_blocks * pages_per_local_batch,
)
- Computes each virtual batch's absolute KV start position
- Converts it to block indices
- Builds a new block_table in which each virtual batch references only the blocks within its KV range
Return value:
python
return CommonAttentionMetadata(
query_start_loc_cpu=query_start_loc_cpu,
query_start_loc=query_start_loc_cpu.to(device=device, non_blocking=True),
seq_lens=seq_lens_cpu.to(device=device, non_blocking=True),
...
), make_block_table
Returns the modified CommonAttentionMetadata together with a block_table rebuilding function.
6.5 make_kv_sharing_fast_prefill_common_attn_metadata (lines 222-280)
python
def make_kv_sharing_fast_prefill_common_attn_metadata(
common_attn_metadata: CommonAttentionMetadata,
) -> CommonAttentionMetadata:
KV-sharing fast-prefill path: under KV sharing, prefill computes full attention only at generation positions (the last token); non-generation positions can be skipped.
python
if common_attn_metadata.max_query_len == 1:
return common_attn_metadata
- The batch is all decodes; nothing to do
python
logits_indices = logits_indices_padded[:num_logits_indices]
request_ids = torch.bucketize(logits_indices, query_start_loc[1:], right=True)
num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)
- bucketize: finds which request each logit index belongs to
- bincount: counts how many decode tokens each request has
Both primitives are checked below.
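Both primitives on made-up indices (3 requests of 3, 2, and 4 tokens):
python
import torch

query_start_loc = torch.tensor([0, 3, 5, 9])
logits_indices = torch.tensor([2, 4, 7, 8])   # generation positions

request_ids = torch.bucketize(logits_indices, query_start_loc[1:], right=True)
print(request_ids)        # tensor([0, 1, 2, 2])  owning request of each index

num_decode_tokens = torch.bincount(request_ids, minlength=3)
print(num_decode_tokens)  # tensor([1, 1, 2])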
python
decode_query_start_loc = torch.empty(num_reqs + 1, ...)
decode_query_start_loc[0] = 0
decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0)
Builds a new query_start_loc that contains only generation-position tokens.
6.6 split_decodes_prefills_and_extends (lines 282-370)
python
def split_decodes_prefills_and_extends(
common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1,
) -> tuple[int, int, int, int, int, int]:
Three-way split: partitions the batch into decode / extend / prefill.
- decode: all prompt tokens processed; generates one token per step
- extend: a chunked-prefill intermediate state; part of the prompt already processed
- prefill: processing the prompt for the first time
Batch ordering: decode → extend → prefill
python
if max_query_len <= decode_threshold:
return num_reqs, 0, 0, num_tokens, 0, 0
- Every request is a decode
python
query_lens = query_start_loc[1:] - query_start_loc[:-1]
is_prefill_or_extend = query_lens > decode_threshold
is_prefill = (seq_lens == query_lens) & is_prefill_or_extend
- query_len > threshold: not a pure decode
- seq_len == query_len: no tokens computed yet, i.e. a pure prefill
python
first_extend = is_prefill_or_extend.int().argmax(dim=-1).item()
first_prefill = is_prefill.int().argmax(dim=-1).item()
- argmax locates the position of the first True (demonstrated below)
Returns six values: num_decodes, num_extends, num_prefills, plus the token counts for each category.
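The first-True trick in isolation (note the all-False caveat):
python
import torch

mask = torch.tensor([False, False, True, True])
print(mask.int().argmax(dim=-1).item())  # 2: index of the first True

# Caveat: an all-False mask also returns 0; the function above is safe because
# the all-decode case has already returned early.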
6.7 split_decodes_and_prefills (lines 372-430)
python
def split_decodes_and_prefills(
common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1,
require_uniform: bool = False,
treat_short_extends_as_decodes: bool = True,
) -> tuple[int, int, int, int]:
Two-way split: partitions the batch into just decode / prefill (the simplified variant).
Parameters:
- require_uniform: whether decode requests must share the same query_len (required for CUDA Graphs)
- treat_short_extends_as_decodes: whether short extends are counted as decodes
python
if require_uniform:
if torch.all((query_lens == query_lens[0]) | (query_lens == 0)):
return num_reqs, 0, num_tokens, 0 # all decodes
is_prefill = query_lens != query_lens[0]
else:
is_prefill = query_lens > decode_threshold
- With require_uniform: any mismatched query_len is treated as prefill
- Otherwise: query_len > threshold is treated as prefill
python
if not treat_short_extends_as_decodes:
assert common_attn_metadata.is_prefilling is not None
is_prefill |= common_attn_metadata.is_prefilling
- Strict mode: the is_prefilling flag separates true decodes from short extends
6.8 split_prefill_chunks (lines 432-460)
python
def split_prefill_chunks(
seq_lens_cpu: torch.Tensor, workspace_size: int, request_offset: int = 0
) -> list[tuple[int, int]]:
Splits prefill requests into workspace-sized chunks, so no single long prefill monopolizes GPU memory.
python
chunk_bounds = []
i, n = 0, len(seq_lens_cpu)
assert torch.all(seq_lens_cpu <= workspace_size).item()
while i < n:
start, chunk_total = i, 0
while i < n and (chunk_total + (s := seq_lens_cpu[i].item())) <= workspace_size:
chunk_total += s
i += 1
chunk_bounds.append((start + request_offset, i + request_offset))
return chunk_bounds
Greedy bin packing: pack as many requests as possible into one chunk until the running token total would exceed workspace_size. A usage sketch follows.
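A runnable check of the chunking (the function body is copied from above to keep the snippet self-contained; in vLLM it lives in vllm.v1.attention.backends.utils):
python
import torch

def split_prefill_chunks(seq_lens_cpu, workspace_size, request_offset=0):
    chunk_bounds = []
    i, n = 0, len(seq_lens_cpu)
    assert torch.all(seq_lens_cpu <= workspace_size).item()
    while i < n:
        start, chunk_total = i, 0
        while i < n and (chunk_total + (s := seq_lens_cpu[i].item())) <= workspace_size:
            chunk_total += s
            i += 1
        chunk_bounds.append((start + request_offset, i + request_offset))
    return chunk_bounds

# 300+500 fit in a 1000-token workspace; so do 900+100.
print(split_prefill_chunks(torch.tensor([300, 500, 900, 100]), 1000))
# -> [(0, 2), (2, 4)]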
6.9 reorder_batch_to_split_decodes_and_prefills (lines 462-530)
python
def reorder_batch_to_split_decodes_and_prefills(
input_batch: "InputBatch",
scheduler_output: "SchedulerOutput",
decode_threshold: int = 1,
) -> bool:
Physically reorders the batch: rearranges the requests in InputBatch into decode → short_extend → long_extend → prefill order.
python
is_pure_prefill = ~has_context
is_long_extend = has_context & ~is_below_threshold
is_short_extend = has_context & is_below_threshold & ~done_prefilling
is_decode = has_context & is_below_threshold & done_prefilling
Four mutually exclusive classes:
- No context → pure prefill
- Context present, above the threshold → long extend
- Context present, below the threshold, still prefilling → short extend
- Context present, below the threshold, prefill finished → decode
python
req_regions = np.zeros(num_reqs, dtype=np.int32) # 0 = decode by default
req_regions[is_short_extend] = 1
req_regions[is_long_extend] = 2
req_regions[is_pure_prefill] = 3
python
target_regions = np.repeat(
[0, 1, 2, 3],
[num_decodes, num_short_extends, num_long_extends, num_prefills],
).astype(np.int32)
needs_swap = req_regions != target_regions
if not needs_swap.any():
return False
If the current order already matches the target order, no reordering is needed.
python
for src in src_dest_map:
dst = src_dest_map[src]
while src != dst:
input_batch.swap_states(src, dst)
next_dst = src_dest_map.get(dst, dst)
src_dest_map[dst] = dst
dst = next_dst
Cyclic swaps: swap_states performs the reordering in place, avoiding extra memory allocation. A standalone sketch of the cycle-following idea follows.
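Using a plain Python list in place of InputBatch.swap_states (illustrative labels; src_dest_map must describe a permutation, i.e. every displaced index appears as a key):
python
def reorder_in_place(items: list, src_dest_map: dict[int, int]) -> list:
    for src in list(src_dest_map):
        dst = src_dest_map[src]
        while src != dst:
            items[src], items[dst] = items[dst], items[src]
            next_dst = src_dest_map.get(dst, dst)
            src_dest_map[dst] = dst          # slot dst is now settled
            dst = next_dst
    return items

# One 4-cycle: the prefill at index 0 moves behind three decodes in 3 swaps,
# with no temporary buffers.
print(reorder_in_place(["p", "d1", "d2", "d3"], {0: 3, 1: 0, 2: 1, 3: 2}))
# -> ['d1', 'd2', 'd3', 'p']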
6.10 reshape_query_for_spec_decode / reshape_attn_output_for_spec_decode (lines 532-555)
python
def reshape_query_for_spec_decode(query: torch.Tensor, batch_size: int) -> torch.Tensor:
assert query.dim() == 3
total_tokens = query.shape[0]
num_heads = query.shape[1]
head_dim = query.shape[2]
seq_len = total_tokens // batch_size
return query.view(batch_size, seq_len, num_heads, head_dim)
Reshapes the 3D query [total_tokens, num_heads, head_dim] into the 4D [batch_size, seq_len, num_heads, head_dim] that speculative decoding expects.
python
def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor) -> torch.Tensor:
if attn_output.dim() == 3:
return attn_output
total_tokens = attn_output.shape[0] * attn_output.shape[1]
return attn_output.view(total_tokens, attn_output.shape[2], attn_output.shape[3])
The reverse operation: flattens the 4D output back to 3D.
6.11 subclass_attention_metadata (lines 557-580)
python
def subclass_attention_metadata(
name_prefix: str,
metadata_cls: Any,
fields: list[tuple[str, Any, Any]],
) -> Any:
name: str = name_prefix + metadata_cls.__name__
Wrapped = make_dataclass(name, fields, bases=(metadata_cls,))
return Wrapped
Dynamically creates a metadata subclass: make_dataclass generates a dataclass with extra fields at runtime.
- metadata_cls: the base metadata class
- fields: the list of extra fields, [(name, type, default)]
- Returns the new dataclass subclass (sketched below)
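The pattern in isolation (BaseMetadata is hypothetical; the name mirrors name_prefix + metadata_cls.__name__):
python
from dataclasses import dataclass, make_dataclass

@dataclass
class BaseMetadata:
    num_tokens: int

FastPrefillBaseMetadata = make_dataclass(
    "FastPrefill" + BaseMetadata.__name__,   # name_prefix + base class name
    [("logits_indices", object, None)],      # extra (name, type, default) fields
    bases=(BaseMetadata,),
)

m = FastPrefillBaseMetadata(num_tokens=8)
print(type(m).__name__, m.num_tokens, m.logits_indices)
# -> FastPrefillBaseMetadata 8 None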
6.12 KVSharingFastPrefillMetadata Protocol + create_fast_prefill_custom_backend (lines 582-640)
python
@runtime_checkable
class KVSharingFastPrefillMetadata(Protocol):
logits_indices_padded: torch.Tensor | None = None
num_logits_indices: int | None = None
The Protocol marks metadata classes that support KV-sharing fast prefill.
python
def create_fast_prefill_custom_backend(
prefix: str,
underlying_attn_backend: type[AttentionBackend],
) -> type[AttentionBackend]:
Factory function: creates a fast-prefill-capable variant on top of an existing backend.
python
class FastPrefillAttentionBuilder(underlying_builder):
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> AttentionMetadata:
new_common_attn_metadata = (
make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata)
)
metadata = super().build(
common_prefix_len, new_common_attn_metadata, fast_build
)
- Inherits from the underlying builder
- build() first converts the CommonAttentionMetadata into its fast-prefill version
- Then calls the parent build()
python
class KVSharingFastPrefillAttentionMetadata(
metadata.__class__,
KVSharingFastPrefillMetadata,
):
def __init__(self, metadata, common_attn_metadata):
for _field in fields(metadata.__class__):
setattr(self, _field.name, getattr(metadata, _field.name))
self.logits_indices_padded = (
common_attn_metadata.logits_indices_padded
)
self.num_logits_indices = common_attn_metadata.num_logits_indices
Dynamic metadata class: multiply inherits from the underlying metadata class and the Protocol, shallow-copies every field, and adds the logits indices.
python
attn_backend = subclass_attention_backend(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=FastPrefillAttentionBuilder,
)
return attn_backend
Finally, the dynamic-subclass factory produces the new backend class.
6.13 compute_causal_conv1d_metadata (lines 642-720)
python
def compute_causal_conv1d_metadata(
query_start_loc_p_cpu: torch.Tensor,
*,
device: torch.device,
):
Computes metadata for the causal_conv1d kernel (used by Mamba/SSM).
python
for BLOCK_M in [8]:
nums = -(-seqlens // BLOCK_M)
- For each sequence, computes how many BLOCK_M-sized blocks it needs
- -(-a // b) is ceiling division (demonstrated below)
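A quick check of the trick (pure Python):
python
import math

for a in (7, 8, 9):
    assert -(-a // 8) == math.ceil(a / 8)
print([-(-a // 8) for a in (7, 8, 9)])  # [1, 1, 2]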
python
mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums))
- Batch index list: the sequence index that owns each block
python
batch_ptr = torch.full(
(MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=device
)
batch_ptr[0:mlist_len].copy_(mlist, non_blocking=True)
- Preallocates a GPU buffer and asynchronously copies the CPU data into it
Returns nums_dict, holding the batch pointers and offset pointers for each BLOCK_M.
6.14 get_dcp_local_seq_lens (lines 722-770)
python
def get_dcp_local_seq_lens(
seq_lens: torch.Tensor,
dcp_size: int = 1,
dcp_rank: int | None = None,
cp_kv_cache_interleave_size: int = 1,
) -> torch.Tensor:
DCP (Decode Context Parallelism) sequence-length splitting: divides each sequence length across DCP ranks so every rank handles only a slice of the KV cache.
python
if dcp_rank is None:
rank_offsets = (
torch.arange(dcp_size, ...)
.unsqueeze(0).repeat(num_requests, 1)
)
else:
rank_offsets = torch.tensor([[dcp_rank]], ...)
- dcp_rank=None: computes the distribution for all ranks
- Otherwise only for the specified rank
python
base = seq_lens_tiled // cp_kv_cache_interleave_size // dcp_size * cp_kv_cache_interleave_size
remainder = seq_lens_tiled - base * dcp_size
remainder = torch.clip(
remainder - rank_offsets * cp_kv_cache_interleave_size,
0,
cp_kv_cache_interleave_size,
)
dcp_local_seq_lens = base + remainder
Allocation algorithm:
- base: the evenly divided base length per rank
- remainder: the leftover tokens, with the first few ranks receiving slightly more
- The cp_kv_cache_interleave_size interleaving factor is respected
A worked numeric check follows.
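A worked check with interleave size 1 (simplified: plain seq_lens stands in for seq_lens_tiled): a 10-token sequence over 4 ranks splits as [3, 3, 2, 2].
python
import torch

seq_lens = torch.tensor([10])
dcp_size, interleave = 4, 1
rank_offsets = torch.arange(dcp_size).unsqueeze(0)        # [[0, 1, 2, 3]]

base = seq_lens // interleave // dcp_size * interleave    # tensor([2])
remainder = seq_lens - base * dcp_size                    # tensor([2])
remainder = torch.clip(
    remainder.unsqueeze(1) - rank_offsets * interleave, 0, interleave
)
print(base.unsqueeze(1) + remainder)                      # tensor([[3, 3, 2, 2]])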
6.15 mamba_get_block_table_tensor (lines 772-892)
python
def mamba_get_block_table_tensor(
block_table: torch.Tensor,
seq_lens: torch.Tensor,
kv_cache_spec: KVCacheSpec,
mamba_cache_mode: str,
) -> torch.Tensor:
Block table adaptation for Mamba/SSM: different Mamba cache modes need different block_table slices.
Three modes:
python
if mamba_cache_mode in ("all", "none"):
return block_table
"all":使用完整block_table"none":直接使用输入block_table(已预裁剪)
python
else: # "align"
assert isinstance(kv_cache_spec, MambaSpec)
start_indices = torch.clamp(
(seq_lens - 1) // kv_cache_spec.block_size,
min=0,
)
offsets = torch.arange(
1 + kv_cache_spec.num_speculative_blocks,
device=block_table.device,
dtype=torch.int32,
)
indices_to_gather = (start_indices.unsqueeze(1) + offsets).to(torch.int64)
return torch.gather(block_table, 1, indices_to_gather)
"align"模式:
- 只取每个请求最后1+num_speculative_blocks个块
start_indices:每个请求当前序列的起始块位置torch.gather:从block_table中提取指定位置
用途:Mamba的SSM状态只依赖最近的token,不需要完整block_table,只需末尾几块用于投机解码。
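A minimal torch.gather check of the "align" slice (block_size=4, num_speculative_blocks=1, made-up block IDs):
python
import torch

block_table = torch.tensor([[10, 11, 12, 13],
                            [20, 21, 22, 23]])
seq_lens = torch.tensor([6, 9])                     # current blocks: 1 and 2

start_indices = torch.clamp((seq_lens - 1) // 4, min=0)
offsets = torch.arange(2)                           # 1 + num_speculative_blocks
idx = (start_indices.unsqueeze(1) + offsets).to(torch.int64)
print(torch.gather(block_table, 1, idx))
# -> tensor([[11, 12],
#            [22, 23]])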
Appendix: Summary of Core Design Patterns
| Design pattern | Where applied | Notes |
|---|---|---|
| Strategy | AttentionBackend → AttentionImpl | interchangeable backends behind a uniform interface |
| Factory method | get_impl_cls(), get_builder_cls() | backend classes act as factories for their implementation classes |
| Template method | AttentionImplBase.__new__() | subclasses automatically pick up the CP configuration |
| Registry | registry.py | enums + decorator registration + runtime overrides |
| Decorator | register_backend() | class-registration decorator |
| Proxy | selector.py → platform.get_attn_backend_cls() | delegates the actual selection to the platform |
| Memoization | @cache on _cached_get_attn_backend | avoids repeated selection |
| Composition | create_fast_prefill_custom_backend | composes a new backend from an existing one |
| Protocol | AttentionLayer, KVSharingFastPrefillMetadata | structural subtyping (duck typing + static checks) |
| Dynamic subclassing | subclass_attention_backend | creates subclasses at runtime via type() |
This document is Part 1 of the vLLM v1 Attention module series, covering a line-by-line analysis of the four core abstraction files (2,324 lines in total).
Part 2 will cover the concrete backend implementations (backends/*.py), and Part 3 the ops layer (ops/*.py).