六、InputBatch 深度解析
源文件:
gpu_input_batch.py,1085 行
InputBatch 是 GPU 推理批次的持久化状态容器,它维护了一个虚拟连续数组 ------逻辑上紧凑,物理上可能因请求增删而存在空洞(通过 condense() 消除)。
6.1 CachedRequestState
python
@dataclass
class CachedRequestState:
req_id: str
prompt_token_ids: list[int] | None
mm_features: list[MultiModalFeatureSpec]
sampling_params: SamplingParams | None
generator: torch.Generator | None
block_ids: tuple[list[int], ...]
num_computed_tokens: int
output_token_ids: list[int]
mrope_positions: torch.Tensor | None = None
mrope_position_delta: int | None = None
xdrope_positions: torch.Tensor | None = None
lora_request: LoRARequest | None = None
prompt_embeds: torch.Tensor | None = None
prev_num_draft_len: int = 0
pooling_params: PoolingParams | None = None
pooling_states: PoolingStates | None = None
字段详解
| 字段 | 类型 | 用途 |
|---|---|---|
req_id |
str |
请求唯一标识符,与调度器共享 |
prompt_token_ids |
`list[int] | None` |
mm_features |
list[MultiModalFeatureSpec] |
多模态特征列表(图像/音频位置、类型等) |
sampling_params |
`SamplingParams | None` |
generator |
`torch.Generator | None` |
block_ids |
tuple[list[int], ...] |
KV 缓存 block ID 元组,每个 KV 缓存组一个列表 |
num_computed_tokens |
int |
已完成计算的 token 数(含 prompt + 已接受 output) |
output_token_ids |
list[int] |
已生成的输出 token ID 列表 |
mrope_positions |
`torch.Tensor | None` |
mrope_position_delta |
`int | None` |
xdrope_positions |
`torch.Tensor | None` |
lora_request |
`LoRARequest | None` |
prompt_embeds |
`torch.Tensor | None` |
prev_num_draft_len |
int |
上一步的 draft token 数(异步调度用) |
pooling_params |
`PoolingParams | None` |
pooling_states |
`PoolingStates | None` |
post_init
python
def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds
)
if self.pooling_params is not None:
self.pooling_states = PoolingStates()
__post_init__ 在 dataclass 初始化后自动调用:
- 根据
prompt_token_ids或prompt_embeds计算num_prompt_tokens。 - 如果是 pooling 请求,初始化
pooling_states。
num_tokens 属性
python
@property
def num_tokens(self) -> int:
return self.num_prompt_tokens + len(self.output_token_ids)
总 token 数 = prompt token 数 + 已生成 output token 数。这是序列的当前总长度。
6.2 InputBatch.init
核心缓冲区
python
# Token IDs 缓冲区
self.token_ids_cpu_tensor = torch.zeros(
(max_num_reqs, max_model_len), device="cpu", dtype=torch.int32, pin_memory=False
)
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
二维 token ID 矩阵 :[max_num_reqs, max_model_len]。每行对应一个请求的完整 token 序列(prompt + output)。不使用 pin_memory 因为这个缓冲区不直接传输到 GPU------token IDs 通过 index_select 选择性地传输。
python
self.is_token_ids_tensor = torch.zeros(
(max_num_reqs, max_model_len), device="cpu", dtype=bool, pin_memory=False
)
self.is_token_ids = self.is_token_ids_tensor.numpy()
类型掩码:标记哪些位置是 token IDs(True)vs embeddings(False)。用于混合输入(prompt embeds)场景。
python
self.req_prompt_embeds: dict[int, torch.Tensor] = {}
懒分配 :Prompt embeddings 不预分配大矩阵(max_num_reqs × max_model_len × hidden_size 可能几十 GB),而是按请求存储在字典中。
计数器缓冲区
python
self.num_tokens_no_spec_cpu_tensor = torch.zeros(
(max_num_reqs,), device="cpu", dtype=torch.int32, pin_memory=pin_memory
)
self.num_tokens_no_spec = self.num_tokens_no_spec_cpu_tensor.numpy()
self.num_prompt_tokens_cpu_tensor = torch.zeros(...)
self.num_prompt_tokens = self.num_prompt_tokens_cpu_tensor.numpy()
self.num_computed_tokens_cpu_tensor = torch.zeros(...)
self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy()
三个 per-request 计数器,使用 pinned memory 以加速 CPU→GPU 传输:
num_tokens_no_spec:总 token 数(不含 spec tokens),用于写入位置计算。num_prompt_tokens:prompt token 数,用于区分 prefill/decode 阶段。num_computed_tokens_cpu:已计算 token 数,用于 position 计算。
Block Table
python
self.block_table = MultiGroupBlockTable(
max_num_reqs=max_num_reqs,
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
pin_memory=pin_memory,
device=device,
block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes,
max_num_blocks=max_num_blocks_per_req,
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
)
委托给 MultiGroupBlockTable,支持多 KV 缓存组(详见第七章)。
采样参数缓冲区
每个采样参数都有CPU + GPU 双缓冲:
python
self.temperature = torch.empty((max_num_reqs,), dtype=torch.float32, device=device)
self.temperature_cpu_tensor = torch.empty((max_num_reqs,), ..., pin_memory=pin_memory)
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
self.greedy_reqs: set[str] = set()
self.random_reqs: set[str] = set()
集合追踪 :greedy_reqs 和 random_reqs 分别追踪 greedy 和 random 采样的请求 ID,用于 all_greedy / all_random 快速判断------如果整个批次都是 greedy 采样,可以跳过 temperature scaling 和随机采样内核。
同样模式适用于 top_p、top_k、frequency_penalties、presence_penalties、repetition_penalties------每个都有 GPU buffer、CPU buffer、请求集合。
Logits Processor 基础设施
python
self.batch_update_builder = BatchUpdateBuilder()
self.logitsprocs = logitsprocs or LogitsProcessors()
self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids
BatchUpdateBuilder 追踪每步的批次变化(添加/移除/移动),生成 BatchUpdate 传递给 logits processors 更新内部状态。LogitsProcessors 是可插拔的 logits 处理管线(如自定义 logits processor)。
Spec Decode 相关
python
self.spec_token_ids: list[list[int]] = [[] for _ in range(max_num_reqs)]
self.num_accepted_tokens_cpu_tensor = torch.ones(
(max_num_reqs,), dtype=torch.int32, ...
)
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()
spec_token_ids:每个请求的 draft token IDs 列表。num_accepted_tokens_cpu:每个请求上一步接受的 token 数,默认为 1(正常 decode)。
异步调度相关
python
self.prev_sampled_token_ids: torch.Tensor | None = None
self.prev_req_id_to_index: dict[str, int] | None = None
self.sampled_token_ids_cpu: torch.Tensor | None = None
self.async_copy_ready_event: torch.Event | None = None
prev_sampled_token_ids:上一步的 GPU 采样结果张量。prev_req_id_to_index:上一步的请求→索引映射。sampled_token_ids_cpu/async_copy_ready_event:异步 GPU→CPU 拷贝的目标和同步事件。
6.3 add_request()
python
def add_request(self, request: "CachedRequestState") -> int:
req_index = self._register_add_request(request)
Slot 分配
python
def _register_add_request(self, request: "CachedRequestState") -> int:
if (new_req_index := self.batch_update_builder.pop_removed()) is None:
new_req_index = self.num_reqs
assert new_req_index < self.max_num_reqs
self.batch_update_builder.batch_changed = True
...
return new_req_index
Slot 分配策略 :优先复用被移除请求留下的空 slot(pop_removed),否则追加到末尾。这确保 slot 不会无限增长。
状态初始化
python
req_id = request.req_id
if req_index == len(self._req_ids):
self._req_ids.append(req_id)
self.req_output_token_ids.append(request.output_token_ids)
self.spec_token_ids.append([])
else:
self._req_ids[req_index] = req_id
self.req_output_token_ids[req_index] = request.output_token_ids
self.spec_token_ids[req_index].clear()
追加 vs 复用:如果是新 slot,append 到列表末尾;如果是复用空 slot,直接赋值。
python
self.req_id_to_index[req_id] = req_index
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(...)
self.num_prompt_tokens[req_index] = num_prompt_tokens
start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids)
if request.prompt_token_ids is not None:
self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids
self.is_token_ids[req_index, :num_prompt_tokens] = True
else:
self.is_token_ids[req_index, :num_prompt_tokens] = False
if request.prompt_embeds is not None:
self.req_prompt_embeds[req_index] = request.prompt_embeds
self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids
self.is_token_ids[req_index, start_idx:end_idx] = True
self.num_tokens_no_spec[req_index] = request.num_tokens
Token ID 初始化 :将 prompt token IDs 和 output token IDs 写入 token_ids_cpu 的对应位置。对于 prompt embeds,只存储引用不写入 token_ids_cpu。is_token_ids 标记哪些位置是 token IDs。
采样参数初始化
python
if sampling_params := request.sampling_params:
if sampling_params.sampling_type == SamplingType.GREEDY:
self.temperature_cpu[req_index] = 0.0
self.greedy_reqs.add(req_id)
else:
self.temperature_cpu[req_index] = sampling_params.temperature
self.random_reqs.add(req_id)
self.top_p_cpu[req_index] = sampling_params.top_p
if sampling_params.top_p < 1:
self.top_p_reqs.add(req_id)
...
分类收集 :根据采样参数将请求归类到不同集合,用于后续的快速判断(如 all_greedy 时跳过随机采样)。
LoRA 映射
python
if request.lora_request:
lora_id = request.lora_request.lora_int_id
if lora_id not in self.lora_id_to_request_ids:
self.lora_id_to_request_ids[lora_id] = set()
self.request_lora_mapping[req_index] = lora_id
self.lora_id_to_request_ids[lora_id].add(request.req_id)
self.lora_id_to_lora_request[lora_id] = request.lora_request
else:
self.request_lora_mapping[req_index] = 0
LoRA 路由 :request_lora_mapping[req_index] 存储每个请求的 LoRA ID,0 表示无 LoRA。lora_id_to_request_ids 追踪每个 LoRA 的请求集合,用于 LoRA 权重加载和激活。
6.4 remove_request()
python
def remove_request(self, req_id: str) -> int | None:
req_index = self.req_id_to_index.pop(req_id, None)
if req_index is None:
return None
self.batch_update_builder.removed_append(req_index)
self._req_ids[req_index] = None
self.req_output_token_ids[req_index] = None
self.spec_token_ids[req_index].clear()
self.block_table.clear_row(req_index)
核心操作:
- 从 ID→索引映射中移除。
- 将 slot 标记为空(
_req_ids[req_index] = None)。 - 注册到
batch_update_builder的 removed 列表(供 condense 和 logits processor 使用)。 - 清除 block table 行。
不立即压缩 :remove_request 只标记为空,不移动数据。后续 condense() 统一整理。
python
# LoRA 清理
lora_id = self.request_lora_mapping[req_index]
if lora_id != 0:
lora_req_ids = self.lora_id_to_request_ids[lora_id]
lora_req_ids.discard(req_id)
if not lora_req_ids:
del self.lora_id_to_request_ids[lora_id]
del self.lora_id_to_lora_request[lora_id]
self.request_lora_mapping[req_index] = 0
LoRA 引用计数:当某个 LoRA ID 的请求集合为空时,移除该 LoRA 的映射,释放 LoRA 权重缓存。
python
self.greedy_reqs.discard(req_id)
self.random_reqs.discard(req_id)
self.top_p_reqs.discard(req_id)
...
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
...
if self.allowed_token_ids_mask_cpu_tensor is not None:
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
全面清理 :从所有集合和映射中移除请求。allowed_token_ids_mask 行填 False(不屏蔽任何 token)。
6.5 condense()
python
def condense(self) -> None:
if not (empty_req_indices := self.batch_update_builder.removed):
return
if num_reqs == 0:
self._req_ids.clear()
self.req_output_token_ids.clear()
self.spec_token_ids.clear()
return
前置检查:没有空 slot 则直接返回;批次为空则清空列表。
python
last_req_index = num_reqs + len(empty_req_indices) - 1
while empty_req_indices:
while last_req_index in empty_req_indices:
last_req_index -= 1
empty_index = self.batch_update_builder.peek_removed()
if empty_index >= last_req_index:
break
self.batch_update_builder.pop_removed()
req_id = self._req_ids[last_req_index]
self._req_ids[empty_index] = req_id
self._req_ids[last_req_index] = None
self.req_id_to_index[req_id] = empty_index
双指针压缩:
empty_index:最小空 slot 索引(从removed队列取,按升序排列)。last_req_index:最大非空 slot 索引(从末尾向前扫描)。- 将
last_req_index的数据移动到empty_index,直到两者相遇。
python
num_tokens = self._get_active_token_count(last_req_index)
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
last_req_index, :num_tokens
]
self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
last_req_index, :num_tokens
]
只拷贝活跃 token :不拷贝整行 max_model_len,只拷贝 num_tokens 个 token。对于 max_model_len = 128K 的模型,这可以节省几个数量级的数据拷贝量。
python
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[last_req_index]
self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index]
self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[
last_req_index
]
self.block_table.move_row(last_req_index, empty_index)
逐字段移动 :所有 per-request 状态都需要从旧位置移到新位置。block_table.move_row 高效地拷贝 block table 行。
python
# Autoregressive models require detailed tracking
self.batch_update_builder.moved.append(
(last_req_index, empty_index, MoveDirectionality.UNIDIRECTIONAL)
)
Logits processor 通知:UNIDIRECTIONAL 移动(非交换)需要特殊处理,因为 logits processor 的内部状态可能依赖于请求位置。
python
del self._req_ids[num_reqs:]
del self.req_output_token_ids[num_reqs:]
del self.spec_token_ids[num_reqs:]
列表截断 :压缩后移除超出 num_reqs 的元素,释放内存。
6.6 _make_sampling_metadata()
python
def _make_sampling_metadata(self) -> SamplingMetadata:
num_reqs = self.num_reqs
if not self.all_greedy:
temperature = copy_slice(
self.temperature_cpu_tensor, self.temperature, num_reqs
)
else:
temperature = None
按需同步 :只有当批次中存在非 greedy 请求时,才将 temperature 拷贝到 GPU。copy_slice 是高效的 CPU→GPU 切片拷贝。同样逻辑应用于 top_p、top_k、penalties------只在有请求需要时才传输。
python
needs_prompt_token_ids = (
not self.no_penalties
or self.logits_processing_needs_token_ids[:num_reqs].any()
)
prompt_token_ids_cpu = (
self._make_prompt_token_ids_cpu_tensor() if needs_prompt_token_ids else None
)
prompt_token_ids = (
prompt_token_ids_cpu.to(device=self.device, non_blocking=True)
if prompt_token_ids_cpu is not None else None
)
惰性构建 prompt_token_ids:Frequency/presence/repetition penalty 需要知道 prompt 中的 token IDs 来计算惩罚。如果整个批次都没有 penalty,跳过构建------这避免了分配和传输一个大矩阵的开销。
python
needs_output_token_ids = (
not self.no_penalties
or bool(self.bad_words_token_ids)
or self.logitsprocs_need_output_token_ids
)
output_token_ids = (
cast(list[list[int]], self.req_output_token_ids)
if needs_output_token_ids else []
)
Output token IDs 的按需传递:只在需要 penalty / bad_words / 自定义 logits processor 时传递。空列表是轻量级的,不产生 CPU→GPU 传输。
python
return SamplingMetadata(
temperature=temperature,
all_greedy=self.all_greedy,
all_random=self.all_random,
top_p=None if self.no_top_p else self.top_p[:num_reqs],
top_k=None if self.no_top_k else self.top_k[:num_reqs],
generators=self.generators,
...
)
构建 SamplingMetadata :将所有采样相关数据打包为单一对象传给 Sampler。None 值表示该维度不需要计算,Sampler 内部会跳过对应的内核。
6.7 set_async_sampled_token_ids / update_async_output_token_ids
python
def set_async_sampled_token_ids(
self, sampled_token_ids_cpu: torch.Tensor, async_copy_ready_event: torch.Event
) -> None:
if self.sampling_metadata.output_token_ids:
self.sampled_token_ids_cpu = sampled_token_ids_cpu
self.async_copy_ready_event = async_copy_ready_event
else:
self.sampled_token_ids_cpu = None
self.async_copy_ready_event = None
条件存储 :只在 output_token_ids 非空(即有 logits processor 需要 output IDs)时才保存引用。否则不存储,避免不必要的同步等待。
python
def update_async_output_token_ids(self) -> None:
output_token_ids = self.sampling_metadata.output_token_ids
if self.sampled_token_ids_cpu is None or not output_token_ids:
return
assert self.prev_req_id_to_index is not None
sampled_token_ids = None
for index, req_id in enumerate(self.req_ids):
prev_index = self.prev_req_id_to_index.get(req_id)
if prev_index is None:
continue
req_output_token_ids = output_token_ids[index]
if not req_output_token_ids or req_output_token_ids[-1] != -1:
continue
if sampled_token_ids is None:
assert self.async_copy_ready_event is not None
self.async_copy_ready_event.synchronize()
sampled_token_ids = self.sampled_token_ids_cpu.tolist()
new_ids: list[int] = sampled_token_ids[prev_index]
num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1)
first_placeholder = req_output_token_ids.index(-1)
num_placeholders = len(req_output_token_ids) - first_placeholder
num_to_replace = min(num_sampled_ids, num_placeholders)
del new_ids[num_to_replace:]
req_output_token_ids[first_placeholder:] = new_ids
延迟同步策略:
- 遍历所有请求,检查是否有需要修正的占位符(
-1结尾)。 - 只在首次遇到需要修正的请求时 才调用
synchronize()------最小化 GPU→CPU 同步延迟。 - 使用
prev_req_id_to_index将当前请求映射到上一步的索引,从sampled_token_ids中取出对应的真实 IDs。 - 处理边界情况:占位符可能比实际采样 ID 多(乐观 spec decode)或少(KV load 失败导致 token 丢弃),取
min确保安全。
七、BlockTable 深度解析
源文件:
block_table.py(373行)+gpu/block_table.py(281行)
BlockTable 是 KV 缓存管理的核心数据结构,负责将逻辑 token position 映射到物理 KV 缓存 slot。
7.1 BlockTable 类(block_table.py)
7.1.1 构造函数
python
def __init__(
self,
block_size: int,
max_num_reqs: int,
max_num_blocks_per_req: int,
max_num_batched_tokens: int,
pin_memory: bool,
device: torch.device,
kernel_block_size: int,
cp_kv_cache_interleave_size: int,
):
Hybrid Block Size 处理
python
if kernel_block_size == block_size:
self.block_size = block_size
self.blocks_per_kv_block = 1
self.use_hybrid_blocks = False
else:
if block_size % kernel_block_size != 0:
raise ValueError(...)
self.block_size = kernel_block_size
self.blocks_per_kv_block = block_size // kernel_block_size
self.use_hybrid_blocks = True
设计背景:KV 缓存管理器和 attention 内核可能使用不同的 block size。例如,管理器按 32 token 分配内存块,但 FlashAttention 内核按 16 token 处理。此时每个管理器 block 包含 2 个内核 block。
Hybrid 模式:
block_size被重写为内核 block size(更细粒度)。blocks_per_kv_block记录比例关系。use_hybrid_blocks标识是否需要 ID 转换。
核心缓冲区
python
self.block_table = self._make_buffer(
self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32
)
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
self.slot_mapping = self._make_buffer(
self.max_num_batched_tokens, dtype=torch.int64
)
block_table :CpuGpuBuffer,CPU 端是 numpy ndarray,GPU 端是 torch Tensor。形状 [max_num_reqs, max_num_blocks_per_req],每行是一个请求的 block ID 序列。
num_blocks_per_row:每行的有效 block 数(类似变长数组的长度字段)。
slot_mapping :最终的 slot 映射,int64 因为 slot ID 可能很大(block_number * block_size + offset)。
7.1.2 append_row / add_row / clear_row / move_row / swap_row
append_row
python
def append_row(self, block_ids: list[int], row_idx: int) -> None:
if not block_ids:
return
if self.use_hybrid_blocks:
block_ids = self.map_to_kernel_blocks(
np.array(block_ids), self.blocks_per_kv_block, self._kernel_block_arange
)
num_blocks = len(block_ids)
start = self.num_blocks_per_row[row_idx]
self.num_blocks_per_row[row_idx] += num_blocks
self.block_table.np[row_idx, start : start + num_blocks] = block_ids
追加 block IDs:在已有 block 序列末尾追加新 blocks。这是 decode 阶段的常见操作------每生成几个 token 就追加一个新 block。
Hybrid 转换 :如果内核和管理器使用不同 block size,先调用 map_to_kernel_blocks 将管理器 block ID 转换为内核 block ID。
add_row
python
def add_row(self, block_ids: list[int], row_idx: int) -> None:
self.num_blocks_per_row[row_idx] = 0
self.append_row(block_ids, row_idx)
设置(非追加):先将行长度归零,再追加。用于新请求或恢复请求的 block 初始化------完全替换而非追加。
clear_row
python
def clear_row(self, row_idx: int) -> None:
num_blocks = self.num_blocks_per_row[row_idx]
if num_blocks > 0:
self.block_table.np[row_idx, :num_blocks] = 0
self.num_blocks_per_row[row_idx] = 0
清除行 :将有效区域置零并重置长度。0 不是有效的 block ID(block 0 是 padding block),因此置零等于"无 block"。
move_row
python
def move_row(self, src: int, tgt: int) -> None:
num_blocks = self.num_blocks_per_row[src]
block_table_np = self.block_table.np
block_table_np[tgt, :num_blocks] = block_table_np[src, :num_blocks]
self.num_blocks_per_row[tgt] = num_blocks
单向移动 :condense() 时使用,将数据从大索引移到小索引。源位置的数据不需要清除------它会被 condense 的下一轮操作覆盖或被截断删除。
swap_row
python
def swap_row(self, src: int, tgt: int) -> None:
src_tgt, tgt_src = [src, tgt], [tgt, src]
self.num_blocks_per_row[src_tgt] = self.num_blocks_per_row[tgt_src]
self.block_table.np[src_tgt] = self.block_table.np[tgt_src]
交换行 :用于批次重排序(_may_reorder_batch)。src_tgt = [src, tgt] 作为索引同时赋值------numpy 的高级索引会自动处理交换(创建临时副本)。
7.1.3 commit_block_table(): CPU→GPU 同步
python
def commit_block_table(self, num_reqs: int) -> None:
self.block_table.copy_to_gpu(num_reqs)
将 CPU 端的 block table 拷贝到 GPU。CpuGpuBuffer.copy_to_gpu(num_reqs) 只拷贝前 num_reqs 行,避免传输整个预分配矩阵。这是 async/non_blocking 的------在 _prepare_inputs 开头调用,与后续 CPU 计算重叠。
7.1.4 compute_slot_mapping(): Token→KV Slot 映射
python
def compute_slot_mapping(
self,
num_reqs: int,
query_start_loc: torch.Tensor,
positions: torch.Tensor,
) -> None:
num_tokens = positions.shape[0]
total_cp_world_size = self.pcp_world_size * self.dcp_world_size
total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
_compute_slot_mapping_kernel[(num_reqs + 1,)](
num_tokens,
self.max_num_batched_tokens,
query_start_loc,
positions,
self.block_table.gpu,
self.block_table.gpu.stride(0),
self.block_size,
self.slot_mapping.gpu,
TOTAL_CP_WORLD_SIZE=total_cp_world_size,
TOTAL_CP_RANK=total_cp_rank,
CP_KV_CACHE_INTERLEAVE_SIZE=self.cp_kv_cache_interleave_size,
PAD_ID=PAD_SLOT_ID,
BLOCK_SIZE=1024,
)
Triton kernel 并行化 :启动 num_reqs + 1 个 program(额外1个用于 padding)。每个 program 处理一个请求的所有 token。
内核逻辑 (_compute_slot_mapping_kernel):
python
@triton.jit
def _compute_slot_mapping_kernel(
num_tokens, max_num_tokens,
query_start_loc_ptr, positions_ptr,
block_table_ptr, block_table_stride,
block_size, slot_mapping_ptr,
TOTAL_CP_WORLD_SIZE: tl.constexpr,
TOTAL_CP_RANK: tl.constexpr,
CP_KV_CACHE_INTERLEAVE_SIZE: tl.constexpr,
PAD_ID: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
if req_idx == tl.num_programs(0) - 1:
# Padding program: fill remaining slots with PAD_ID
for i in range(num_tokens, max_num_tokens, BLOCK_SIZE):
offsets = i + tl.arange(0, BLOCK_SIZE)
tl.store(slot_mapping_ptr + offsets, PAD_ID, mask=offsets < max_num_tokens)
return
start_idx = tl.load(query_start_loc_ptr + req_idx).to(tl.int64)
end_idx = tl.load(query_start_loc_ptr + req_idx + 1).to(tl.int64)
Padding program :最后一个 program 负责将 slot_mapping 中 num_tokens 到 max_num_tokens 的区域填充为 PAD_SLOT_ID(通常是 -1)。这是 CUDA graph 兼容性要求的------CUDA graph 重放时访问固定大小的缓冲区,padding 区域的 slot ID 必须是安全的无效值。
python
virtual_block_size = block_size * TOTAL_CP_WORLD_SIZE
row_offset = req_idx * block_table_stride
for i in range(start_idx, end_idx, BLOCK_SIZE):
offsets = i + tl.arange(0, BLOCK_SIZE)
mask = offsets < end_idx
pos = tl.load(positions_ptr + offsets, mask=mask, other=0)
block_indices = pos // virtual_block_size
block_numbers = tl.load(block_table_ptr + row_offset + block_indices).to(tl.int64)
Slot 映射的核心算法:
-
virtual_block_size = block_size * TOTAL_CP_WORLD_SIZE:考虑 Context Parallelism (CP) 后的虚拟 block 大小。在 CP 模式下,一个逻辑 block 的 token 分散在多个 CP rank 上,每个 rank 只存储block_size / cp_size的 token。但 position 是全局的,所以需要用更大的虚拟 block size 来计算 block index。 -
block_indices = pos // virtual_block_size:position 除以虚拟 block size 得到逻辑 block 索引。 -
block_numbers = block_table[req_idx, block_indices]:从 block table 中查找物理 block 编号。
python
virtual_block_offsets = pos - block_indices * virtual_block_size
is_local = (
virtual_block_offsets // CP_KV_CACHE_INTERLEAVE_SIZE
) % TOTAL_CP_WORLD_SIZE == TOTAL_CP_RANK
local_block_offsets = (
virtual_block_offsets // (TOTAL_CP_WORLD_SIZE * CP_KV_CACHE_INTERLEAVE_SIZE)
) * CP_KV_CACHE_INTERLEAVE_SIZE + (
virtual_block_offsets % CP_KV_CACHE_INTERLEAVE_SIZE
)
slot_ids = block_numbers * block_size + local_block_offsets
slot_ids = tl.where(is_local, slot_ids, PAD_ID)
tl.store(slot_mapping_ptr + offsets, slot_ids, mask=mask)
CP 分片映射:
-
virtual_block_offsets:position 在 block 内的偏移量(虚拟空间)。 -
Locality 判断 :
is_local检查当前偏移是否属于本 CP rank。CP 可以是 interleave 模式(token 交替分配给各 rank),公式(offset // interleave_size) % cp_size == rank。 -
Local offset 转换:将虚拟偏移转换为 local 偏移。例如,CP=2、interleave=1 时,虚拟偏移 3 对应 rank 1 的 local 偏移 1。
-
Slot ID = block_number * block_size + local_offset:物理 slot 在 KV 缓存中的线性地址。
-
非 local → PAD_ID:不属于本 rank 的 slot 标记为 padding,attention 内核会跳过这些 slot。
7.1.5 map_to_kernel_blocks()
python
@staticmethod
def map_to_kernel_blocks(
kv_manager_block_ids: np.ndarray,
blocks_per_kv_block: int,
kernel_block_arange: np.ndarray,
) -> np.ndarray:
if blocks_per_kv_block == 1:
return kv_manager_block_ids
kernel_block_ids = (
kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block
+ kernel_block_arange
)
return kernel_block_ids.reshape(-1)
Block ID 转换示例:
- 管理器 block ID
[0, 1, 2],blocks_per_kv_block = 2 - 结果:
[0, 1, 2, 3, 4, 5] - 映射关系:管理器 0 → 内核 [0,1],管理器 1 → 内核 [2,3],管理器 2 → 内核 [4,5]
这通过广播乘法和加法实现,非常高效。
7.2 MultiGroupBlockTable(block_table.py)
python
class MultiGroupBlockTable:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
pin_memory: bool,
device: torch.device,
block_sizes: list[int],
kernel_block_sizes: list[int],
max_num_blocks: list[int] | None = None,
cp_kv_cache_interleave_size: int = 1,
) -> None:
多组管理
python
if max_num_blocks is None:
total_cp_world_size = get_total_cp_world_size()
max_num_blocks = [
cdiv(max_model_len, block_size * total_cp_world_size)
for block_size in block_sizes
]
self.block_tables = [
BlockTable(
block_size,
max_num_reqs,
max_num_blocks_per_req,
max_num_batched_tokens,
pin_memory,
device,
kernel_block_size,
cp_kv_cache_interleave_size,
)
for block_size, kernel_block_size, max_num_blocks_per_req in zip(
block_sizes, kernel_block_sizes, max_num_blocks
)
]
设计目的 :某些模型使用混合 KV 缓存------不同层可能有不同的 block size(如 attention 层 16 token/block、SSM 层 1 token/block)。MultiGroupBlockTable 为每个 KV 缓存组创建独立的 BlockTable,每组有自己的 block size 和 max blocks。
代理方法
所有单 BlockTable 的方法都被代理:
python
def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
for i, block_table in enumerate(self.block_tables):
block_table.append_row(block_ids[i], row_idx)
def compute_slot_mapping(self, num_reqs, query_start_loc, positions) -> None:
for block_table in self.block_tables:
block_table.compute_slot_mapping(num_reqs, query_start_loc, positions)
block_ids 从 list[int] 变为 tuple[list[int], ...]------每个 KV 缓存组一个列表。所有操作都是组间独立的,可以并行化(当前实现是串行的,但 CPU 操作足够快)。
7.3 BlockTables(gpu/block_table.py)--- GPU 优化的变体
这个文件提供了 BlockTable 的另一种实现,使用更高级的 GPU 内存管理:
7.3.1 StagedWriteTensor
python
self.block_tables: list[StagedWriteTensor] = []
for i in range(self.num_kv_cache_groups):
block_size = self.block_sizes[i]
max_num_blocks = cdiv(self.max_model_len, block_size * self.cp_size)
block_table = StagedWriteTensor(
(self.max_num_reqs, max_num_blocks),
dtype=torch.int32,
device=device,
)
self.block_tables.append(block_table)
StagedWriteTensor 是一种延迟写入机制------CPU 端的修改先 staged(暂存),然后通过 apply_staged_writes() 一次性提交到 GPU。这避免了每次 append_row 都触发一次 CPU→GPU 传输,将多次小传输合并为一次大传输。
7.3.2 gather_block_tables()
python
def gather_block_tables(
self, idx_mapping: torch.Tensor, num_reqs_padded: int,
) -> tuple[torch.Tensor, ...]:
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs_padded)](
idx_mapping, self.block_table_ptrs, self.input_block_table_ptrs,
self.block_table_strides, self.num_blocks.gpu, ...
)
return tuple(bt[:num_reqs_padded] for bt in self.input_block_tables)
GPU 端重排 :与 block_table.py 的"CPU 排列 + 整体传输"不同,这个变体在 GPU 上直接根据 idx_mapping(重排索引)gather block table 行到 input_block_tables。这避免了 CPU→GPU 传输,适合 CUDA graph 模式。
7.3.3 compute_slot_mappings()(GPU 变体)
python
def compute_slot_mappings(
self, idx_mapping, query_start_loc, positions, num_tokens_padded,
) -> torch.Tensor:
_compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
self.max_num_batched_tokens,
idx_mapping, query_start_loc, positions,
self.block_table_ptrs, self.block_table_strides,
self.block_sizes_tensor,
self.slot_mappings, ...
)
return self.slot_mappings[:, :num_tokens_padded]
与 block_table.py 的 kernel 类似,但增加了 idx_mapping 支持------通过重排索引间接访问 block table,适应批次重排场景。block_sizes_tensor 允许一个 kernel 处理不同 block size 的 KV 缓存组。
总结:三大核心组件的协作关系
SchedulerOutput
│
▼
┌──────────────────────┐
│ GPUModelRunner │
│ .execute_model() │
│ │ Phase 0: 前置检查
│ │ Phase 1: _update_states()
│ ┌───────────┐ │ │
│ │InputBatch │◄──┤──────┘ 更新持久化批次状态
│ │ │ │
│ │BlockTable │◄──┤── Phase 2: _prepare_inputs()
│ └───────────┘ │ │ block_table.commit + compute_slot_mapping
│ │ ▼
│ │ Phase 3: _preprocess() → input_ids, positions, embeds
│ │ Phase 4: _model_forward() → hidden_states, logits
│ │ Phase 5: _sample() + _bookkeeping_sync() → output
│ │
└──────────────────────┘
│
▼
ModelRunnerOutput / AsyncGPUModelRunnerOutput
关键设计原则:
- 两步协议(execute_model + sample_tokens):将 forward 和采样解耦,为异步调度提供弹性。
- 乐观假设 + 延迟修正:spec decode 场景下先假设所有 draft 被接受,forward 后再修正,避免 CPU-GPU 同步阻塞。
- 按需同步:只在需要时传输 GPU 数据(temperature、penalties 等),最小化 PCIe 带宽消耗。
- 计算-通信重叠 :block_table commit 在
_prepare_inputs开头异步启动,与后续 CPU 计算并行。 - 持久化批次 + condense:避免每步重建批次状态,通过 condense 消除空洞保持紧凑。
- Hybrid block size 支持:管理器与内核 block size 解耦,通过 ID 转换层桥接。
- Context Parallelism 感知:slot mapping kernel 内置 CP 分片逻辑,支持 token interleaving 和 local offset 转换。
8.1 Sampler(sampler.py)------ 顶层编排器
8.1.1 构造函数
python
class Sampler:
def __init__(self, max_num_reqs, vocab_size, device, req_states,
logprobs_mode="raw_logprobs", num_speculative_tokens=1):
初始化四个核心子状态:
- SamplingStates:温度、top_k/top_p/min_p、种子、logprobs 请求计数
- PenaltiesState:重复/频率/存在惩罚
- LogitBiasState:allowed_token_ids、logit_bias、min_tokens
- BadWordsState:禁用词 token 序列
所有子状态统一使用 UvaBackedTensor / StagedWriteTensor 实现 CPU 写入 → UVA/GPU 读取的零拷贝传输模式。
8.1.2 __call__ 完整流程
python
def __call__(self, logits, input_batch) -> SamplerOutput:
- 提取映射索引 :
expanded_idx_mapping(expanded token → request 映射)、idx_mapping_np(request index 映射)、positions、input_ids - NaN 检测 :在采样参数应用之前计算
num_nans,确保统计的是原始 logits 状态 - 调用
sample():执行采样管线,返回(sampled_token_ids, processed_logits) - logprobs 计算 :若任一请求需要 logprobs(
max_num_logprobs != NO_LOGPROBS),调用compute_topk_logprobs() - 构造输出 :
SamplerOutput(sampled_token_ids, logprobs_tensors, num_nans, num_sampled)
8.1.3 apply_sampling_params ------ 参数应用管线
执行顺序严格固定,每一步操作 logits 的方式不同:
logits (FP16/BF16) → copy to FP32
→ apply_logit_bias (in-place mask/bias)
→ apply_penalties (in-place scale/subtract)
→ apply_bad_words (in-place -inf mask)
→ apply_temperature (in-place divide)
→ apply_min_p (in-place -inf threshold)
→ apply_top_k_top_p (may return new tensor)
架构洞察:
- FP32 精度提升:原始 logits 从半精度拷贝为 FP32 后全程在 FP32 上操作,避免数值溢出
- top_k_top_p 是唯一可能产生新 tensor 的步骤 :因为它使用
torch.topk+ scatter 的方式实现,无法原地操作 - 执行顺序不是随意排列------penalties 必须在 temperature 之前(影响 logit 量级),bad_words 必须在采样前生效,min_p 必须在 top_k_top_p 之前(否则 top_k 可能先截断掉被 min_p 保护的 token)
8.1.4 sample() ------ Gumbel 采样
python
def sample(self, logits, ...) -> (sampled, processed_logits):
processed_logits = self.apply_sampling_params(logits, ...)
sampled = gumbel_sample(processed_logits, expanded_idx_mapping,
temperature.gpu, seeds.gpu, pos,
apply_temperature=False)
return sampled, processed_logits
apply_temperature=False 表明温度已在 apply_sampling_params 中应用,此处不再重复。
8.2 PenaltiesState(penalties.py)------ 惩罚机制
8.2.1 三种惩罚类型
| 惩罚类型 | 默认值 | 含义 | 机制 |
|---|---|---|---|
repetition_penalty |
1.0(无惩罚) | 重复惩罚 | 正logits除以penalty,负logits乘以penalty |
frequency_penalty |
0.0 | 频率惩罚 | logit -= freq_penalty × 出现次数 |
presence_penalty |
0.0 | 存在惩罚 | logit -= pres_penalty × (出现>0 ? 1 : 0) |
8.2.2 数据结构
prompt_bin_mask:[max_num_reqs, ceil(vocab_size/32)]的 int32 位掩码,用位运算存储 prompt token 是否出现output_bin_counts:[max_num_reqs, vocab_size]的 int32 计数矩阵,存储 output token 出现次数
位掩码设计极为精妙------prompt 通常包含数千 token,用 bit-packed 表示将内存降低 32×。output 计数则需要精确值(frequency_penalty 需要确切次数),无法压缩。
8.2.3 Triton Kernel _penalties_kernel
核心逻辑:
python
# 对每个 (token_idx, vocab_block) 组合
# 1. 加载 rep_penalty, freq_penalty, pres_penalty
# 2. 计算 output_bin_counts = base_output_counts + draft_counts
# draft_counts: 当前 step 内前面位置的 draft token 累积计数
# 3. 加载 prompt_bin_mask(解包 packed int32 → bool)
# 4. 应用 repetition_penalty:
# scale = where(prompt|output, rep_penalty, 1.0)
# logits *= where(logits > 0, 1.0/scale, scale)
# 5. 应用 frequency_penalty: logits -= freq * output_counts
# 6. 应用 presence_penalty: logits -= pres * (counts > 0)
推测解码支持 :draft_counts 通过 tl.static_range(MAX_SPEC_LEN) 在编译期展开循环计算,累积同一 step 内前序 draft 位置对 output counts 的贡献,确保惩罚在推测解码场景下正确。
8.2.4 _bincount_kernel ------ Token 计数初始化
对新加入惩罚的请求,通过 Triton kernel 并行计算 prompt bin mask(atomic_or 位操作)和 output bin counts(atomic_add)。这是一个全量扫描过程,但只在请求首次加入时执行一次。
8.3 LogitBiasState(logit_bias.py)------ Token 级偏置
8.3.1 三类偏置功能
- Allowed Token IDs:白名单模式,将非白名单 token logits 设为 -inf
- Logit Bias:对指定 token ID 施加固定偏置值(dict 映射)
- Min Tokens:当已生成 token 数 < min_len 时,将 stop token 的 logits 设为 -inf
8.3.2 Triton Kernel _bias_kernel
执行流程:
python
# 对每个 token_idx:
# 1. Allowed token IDs:
# - 保存白名单 token 的 logits
# - 全部设为 -inf(分块遍历 vocab)
# - 恢复白名单 token 的 logits
# 2. Logit bias:
# - 加载 bias_token_ids + bias 值
# - logits[token_ids] += bias
# 3. Min tokens:
# - if pos < min_len and num_stop_token_ids > 0:
# - stop_token_ids 的 logits 设为 -inf
Allowed Token IDs 的实现技巧 :先保存 → 全局 -inf → 恢复,避免了复杂的 mask 构造。分块大小 LOGITS_BLOCK_SIZE=8192 与 vocab 大小匹配,确保遍历效率。
8.4 min_p.py ------ 最小概率阈值
python
@triton.jit
def _min_p_kernel(logits, expanded_idx_mapping, min_p, vocab_size):
# 1. 两遍扫描:第一遍求 max(logits)
# 2. threshold = max_val + log(min_p)
# 3. 第二遍:logits < threshold → -inf
数学原理 :min_p 等价于在 softmax 概率空间中设置阈值。log(min_p) 将概率阈值转换到 logit 空间:softmax(x) < min_p ⟺ x < max + log(min_p)。这种实现避免了实际计算 softmax(数值不稳定且昂贵),直接在 logit 空间操作。
8.5 BadWordsState(bad_words.py)------ 禁用词 logits 掩码
8.5.1 数据结构
bad_word_token_ids:[max_num_reqs, 1024]展平存储所有禁用词的 token 序列bad_word_offsets:[max_num_reqs, 129]每个禁用词的起始偏移量(CSR 格式)num_bad_words:每个请求的禁用词数量
8.5.2 Triton Kernel _bad_words_kernel
核心是 多 token 序列匹配:
python
# 对每个 (token_idx, bw_idx):
# 1. 获取 bad_word 的 prefix_len = bad_word_len - 1
# 2. 如果 prefix_len > effective_len: 跳过
# 3. 逐位置比较 prefix:
# - 实际 token 从 all_token_ids (已输出) 或 input_ids (spec input) 获取
# 4. 若全部匹配: 将 last_token 的 logits 设为 -inf
推测解码兼容 :from_spec_input 判断逻辑------当 actual_pos >= output_len 时,token 来自推测输入(input_ids),否则来自已确认输出(all_token_ids),确保多步 draft 下的禁用词检查正确。
8.6 Gumbel 采样(gumbel.py)------ 核心采样引擎
8.6.1 Gumbel-Max Trick
vLLM 使用 Gumbel-Max Trick 统一 greedy 和 stochastic 采样:
temperature=0: argmax(logits) --- Gumbel noise 为 0
temperature>0: argmax(logits/temp + Gumbel_noise)
8.6.2 Triton Kernel _gumbel_sample_kernel
python
# 对每个 (token_idx, vocab_block):
# 1. 加载 logits block → FP32
# 2. gumbel_block_argmax():
# - 如果 temp != 0 且 APPLY_TEMPERATURE: logits /= temp
# - 如果 temp != 0: 生成 Gumbel noise
# - seed = randint(seed, pos) --- 确定性随机
# - u = tl_rand64(gumbel_seed, block) --- FP64 uniform
# - gumbel_noise = -log(-log(u)) --- 双精度 Gumbel 变换
# - logits = logits + gumbel_noise (mask 为 -inf)
# 3. block 内 max → (value, index)
# 4. 存入 local_argmax, local_max
最终归约 :CPU 端执行 local_max.argmax(dim=-1) → local_argmax.gather(),选出全局最大值对应的 token ID。
8.6.3 tl_rand64 ------ 确定性 FP64 随机数
python
def tl_rand64(seed, offset, includes_zero=False):
lo, hi, _, _ = tl.randint4x(seed, offset)
r = (hi << 32) | lo # 64-bit 整数
u = r * 5.421e-20 # 转换为 [0,1) 的 FP64 uniform
if not includes_zero:
u = max(u, 2.225e-308) # float64 tiny, 避免 log(0)
return u
使用 Triton 内置 randint4x 生成 4 个 32-bit 随机整数,拼合为 64-bit,再乘以 2^-64 转为 uniform。FP64 精度 对于 Gumbel 变换至关重要-------log(-log(u)) 在 u 接近 0 或 1 时对精度极度敏感,FP32 会产生灾难性数值误差。
8.7 logprob.py ------ Log Softmax + Top-K
8.7.1 compute_topk_logprobs
python
def compute_topk_logprobs(logits, num_logprobs, sampled_token_ids, cu_num_logits):
# 1. 构造 logprob_token_ids: sampled + topk_indices
# 2. compute_token_logprobs(): 只计算这些 token 的 log_softmax
# 3. _ranks_kernel: 计算每个 sampled token 在 vocab 中的排名
内存优化 :不实例化完整 [batch, vocab] 的 logprobs 矩阵,只计算 top-k + 1 个 token 的 logprobs,将显存从 O(batch × vocab) 降为 O(batch × k)。
8.7.2 _topk_log_softmax_kernel
两遍扫描算法:
- 第一遍:求
max_val(数值稳定) - 第二遍:求
log(sum(exp(logits - max_val)))=lse - 输出:
logits[topk_ids] - max_val - lse
这是经典的 log-softmax 数值稳定实现。
8.8 prompt_logprob.py ------ Prompt Token Logprob
8.8.1 PromptLogprobsWorker
管理 prompt token 的 logprob 计算,处理 分块 prefill 场景:
in_progress_prompt_logprobs:dict 映射 req_id → list[LogprobsTensors],存储跨 step 累积的 prompt logprobs- 当 prompt 被分块调度时,每步计算一部分 logprobs,追加到列表
- 当 prompt 全部处理完成(
is_prompt_chunked=False),合并所有分块并返回
8.8.2 compute_prompt_logprobs_with_chunking
为避免 prompt logits 过大导致 OOM,使用 CHUNK_SIZE=1024 分块计算:
python
for start_idx in range(0, num_prompt_tokens, 1024):
prompt_logits = logits_fn(hidden_states[start_idx:start_idx+1024])
logprobs = compute_topk_logprobs(prompt_logits, 0, token_ids[...])
架构洞察 :logits_fn 涉及 all-gather 通信(TP 场景),分块处理可避免一次性分配完整 logits 的显存峰值。
8.9 SamplingStates(states.py)------ 采样参数状态
python
class SamplingStates:
temperature: UvaBackedTensor # [max_num_reqs]
top_k: UvaBackedTensor # [max_num_reqs], 默认 = vocab_size
top_p: UvaBackedTensor # [max_num_reqs], 默认 = 1.0
min_p: UvaBackedTensor # [max_num_reqs], 默认 = 0.0
seeds: UvaBackedTensor # [max_num_reqs]
num_logprobs: np.ndarray # [max_num_reqs], 默认 = -1
关键优化 :每次 kernel launch 前检查是否有请求实际使用该参数------np.all(temp == 0.0 | temp == 1.0) 跳过 temperature kernel,np.all(min_p == 0.0) 跳过 min_p kernel。这种 CPU 端短路 避免了大量无效 kernel launch。
8.10 SamplerOutput(output.py)
python
@dataclass
class SamplerOutput:
sampled_token_ids: torch.Tensor # [num_reqs, 1]
logprobs_tensors: LogprobsTensors | None
num_nans: torch.Tensor | None
num_sampled: torch.Tensor | None
极简的数据容器,sampled_token_ids 为 2D 形状以便与推测解码的多 token 输出对齐。
九、CUDA Graph 与 UBatch 深度解析
9.1 CudaGraphManager(cudagraph_utils.py)------ 图管理基座
9.1.1 BatchExecutionDescriptor
python
@dataclass(frozen=True)
class BatchExecutionDescriptor:
cg_mode: CUDAGraphMode # NONE / PIECEWISE / FULL
num_tokens: int # 捕获时的 token 数
num_reqs: int | None # None = PIECEWISE 无需请求填充
uniform_token_count: int | None # 非 None = 均匀 decode
冻结不可变性 :frozen=True 使其可作为 dict key(self.graphs 的键)。
匹配规则 _is_compatible:
uniform_token_count=None(PIECEWISE)可处理任意 uniform_token_countnum_reqs=None(PIECEWISE)无需请求填充desc.num_tokens >= num_tokens:允许向上匹配(padding 到更大的图)
9.1.2 _init_candidates() ------ 候选图预计算
python
def _init_candidates(self):
capture_sizes = sorted(compilation_config.cudagraph_capture_sizes)
# 为每个 num_tokens 生成候选 descriptor 列表
# 建立 _candidates[token_count] → [desc1, desc2, ...] 的索引
# 同时按 mode 分组为 _capture_descs[mode] → [desc, ...]
优先级设计 :_capture_descs 中 PIECEWISE 先于 FULL 捕获------PIECEWISE 激活更大,先捕获让 FULL 的分配可以复用 PIECEWISE 的内存池空间。
9.1.3 capture() ------ 图捕获流程
python
def capture(self, create_forward_fn):
with graph_capture(device=self.device): # 设置 TP 上下文
for mode in [PIECEWISE, FULL]:
for desc in self._capture_descs[mode]:
forward_fn = create_forward_fn(desc)
forward_fn(NONE) # Warmup
if desc.cg_mode == PIECEWISE:
forward_fn(PIECEWISE) # PW 不创建 CUDAGraph
else:
graph = torch.cuda.CUDAGraph()
get_offloader().sync_prev_onload() # 同步 offloader
with torch.cuda.graph(graph, self.pool):
forward_fn(NONE)
get_offloader().join_after_forward()
self.graphs[desc] = graph
关键细节:
- PIECEWISE 不创建
CUDAGraph对象,只运行编译区域 - FULL 图捕获前同步 offloader 的 copy stream,确保预取完成
- 捕获后 join offloader 的 copy stream,避免未 join stream 错误
9.1.4 dispatch() ------ 运行时图匹配
python
def dispatch(self, num_reqs, num_tokens, uniform_token_count):
if self._graphs_captured and 0 < num_tokens < len(self._candidates):
for desc in self._candidates[num_tokens]:
if _is_compatible(desc, num_reqs, num_tokens, uniform_token_count):
return desc
return BatchExecutionDescriptor(NONE, num_tokens, num_reqs)
O(1) 查找------直接索引到 _candidates[num_tokens],遍历少量候选即可匹配。
9.2 ModelCudaGraphManager ------ 模型级图管理
扩展 CudaGraphManager,增加了 hidden states / intermediate tensors 的管理:
hidden_states:最后一个 PP rank 的输出 hidden states 缓冲区aux_hidden_states:EAGLE3 的辅助 hidden statesintermediate_tensors:非最后 PP rank 的中间张量
run_fullgraph() 在 replay 后从预分配缓冲区切片返回结果,避免图捕获时张量生命周期问题。
9.3 EncoderCudaGraphManager(encoder_cudagraph.py)------ 多模态编码器图
9.3.1 Budget-Based 图管理
不同于主模型的 token-count-based 图,编码器使用 token budget 概念:
python
self.token_budgets = _generate_budgets(min_budget, max_budget)
# 生成 2 的幂次预算: [min, 2*min, 4*min, ..., max]
9.3.2 贪心装箱 _execute_local()
python
# 1. 按 output token 数升序排序
# 2. 贪心装箱: 尽可能多的小图打包到同一 budget
# 3. 对每个 batch: 找最小 fitting budget
# 4. 超出 max_budget 的单图回退到 eager
交换论证:贪心升序装箱最小化 eager 回退------任何其他排序都会使某个 batch 的 token 总和更高,增加超出 budget 的概率。
9.3.3 DP 分片 _dp_shard()
当 mm_encoder_tp_mode="data" 时,将图像/视频按 input size 负载均衡分配到 TP rank,使用 get_load_balance_assignment() 实现。
9.4 UBatchWrapper(gpu_ubatch_wrapper.py)------ 微批次双缓冲重叠
9.4.1 DBO(Dual-Buffer Overlap)机制
UBatchWrapper 实现了 计算-通信重叠 的核心架构:
┌──────────┐ ┌──────────┐
│ UBatch 0 │ │ UBatch 1 │
│ compute │ ──→ │ comm │
│ stream │ │ stream │
└──────────┘ └──────────┘
│ │
▼ ▼
Switch streams + GPU events sync
compute_stream:执行模型 forward(矩阵乘法、attention)comm_stream:执行 all-to-all 等 EP 通信ready_barrier:线程同步屏障,确保两个 ubatch 线程同时就绪
9.4.2 _capture_ubatches() ------ 微批次图捕获
python
def _capture_ubatches(self, ubatch_metadata, model):
# 1. 启动 ubatch 线程,每个线程初始化 CUDA context
# 2. ready_barrier.wait() --- 所有线程就绪
# 3. 在 compute_stream 上捕获 CUDAGraph
# 4. 设置 cpu_wait_event --- 线程内交替执行
# 5. 收集结果并拼接
线程模型 :每个 ubatch 一个 Python 线程,各自持有 UBatchContext 管理流切换。cpu_wait_event / cpu_signal_event 实现线程间的 CPU 级同步------同一时刻只有一个线程在 GPU 上提交工作。
9.4.3 SM 控制 SMControlContextManager
python
class SMControlContextManager:
# comm_sms: 分配给通信的 SM 数
# compute_sms: total_sms - comm_sms
# DeepEP highthroughput: set_num_sms()
# DeepGEMM: deep_gemm_set_num_sms()
在 DBO 模式下,通信和计算共享 GPU SM 资源。通过限制通信 SM 数量,避免通信占用全部 SM 导致计算饥饿。
9.5 UBatchContext(ubatching.py)------ 流切换 + 事件同步
9.5.1 核心状态
python
class UBatchContext:
compute_stream: torch.cuda.Stream
comm_stream: torch.cuda.Stream
gpu_comm_done_event: torch.Event # 通信完成事件
gpu_compute_done_event: torch.Event # 计算完成事件
cpu_wait_event: threading.Event # CPU 线程等待
cpu_signal_event: threading.Event # 通知下一个 ubatch
9.5.2 流切换方法
| 方法 | 动作 | 同步 |
|---|---|---|
switch_to_comm() |
切到 comm_stream | 无 |
switch_to_compute() |
切到 compute_stream | 无 |
switch_to_comm_sync() |
signal compute_done → 切 comm → wait compute_done | GPU event |
switch_to_compute_sync() |
signal comm_done → 切 compute → wait comm_done | GPU event |
9.5.3 全局注册函数
python
dbo_yield = _register_ubatch_function(UBatchContext.yield_)
dbo_yield_and_switch_from_compute_to_comm = ...
dbo_yield_and_switch_from_comm_to_compute = ...
这些函数被模型层代码(如 MoE all-to-all)调用,实现自动的流切换------无需在每个层手动判断 DBO 上下文。
十、推测解码深度解析(gpu/spec_decode/ 目录)
10.1 EagleSpeculator(eagle/speculator.py)------ EAGLE 单步 draft
10.1.1 架构概述
EAGLE(Extrapolation Algorithm for Greater Language-model Efficiency)通过轻量级 draft 模型自回归生成候选 token,再由 target 模型一次 forward 验证。vLLM v1 实现了完整的 EAGLE/EAGLE3/MTP 三种变体。
10.1.2 propose() ------ 主入口
python
def propose(self, input_batch, attn_metadata, slot_mappings,
last_hidden_states, aux_hidden_states, ...):
# 1. 复制 hidden_states 到 draft buffer
# 2. EAGLE3: combine aux_hidden_states → compressed input
# 3. prepare_eagle_inputs(): 构建 draft 输入
# 4. Prefill (draft position 0):
# - 可能使用 FULL CUDA graph replay
# - 或 PIECEWISE/eager
# 5. Decode (draft positions 1..N-1):
# - 每步: run_model → compute_logits → gumbel_sample
# - update_eagle_inputs() 准备下一步
Padding 策略:为避免 CPU-GPU 同步获取 rejected 数量,draft 的 input 大小保持与 target 相同------每个请求的 query 长度 pad 到包含 rejected 位置,从而复用 target 的 attention metadata。
10.1.3 prefill() ------ Draft 首步
python
def prefill(self, num_reqs, num_tokens, ...):
# 1. run_model(): 执行 draft 模型 forward
# 2. 提取 last_token 位置的 hidden states
# 3. compute_logits → gumbel_sample(pos + 1)
# 4. 存储结果到 draft_tokens[:, 0] 和 hidden_states
pos + 1 的关键 :draft 采样的 Gumbel noise 种子使用 pos + 1,确保与 target 采样的种子一致,使得 strict rejection sample 可正确匹配。
10.1.4 generate_draft() ------ Draft 自回归循环
python
def generate_draft(self, num_reqs, num_tokens_padded, ...):
for step in range(1, num_speculative_steps):
last_hidden_states, hidden_states = self.run_model(...)
logits = self.model.compute_logits(last_hidden_states[:num_reqs])
draft_tokens = gumbel_sample(logits, ..., pos + 1)
self.draft_tokens[:num_reqs, step] = draft_tokens
if step < num_speculative_steps - 1:
update_eagle_inputs(draft_tokens, hidden_states, ...)
每步使用 update_eagle_inputs kernel 在 GPU 上更新 input_ids、positions、hidden_states 和 seq_lens,避免 CPU 回读。
10.2 RejectionSampler(rejection_sampler.py)------ 概率拒绝采样验证
10.2.1 三种验证方法
| 方法 | 原理 | 需要 draft_logits | 特点 |
|---|---|---|---|
strict |
target 采样 == draft 采样? | 否 | 最简单,但接受率低 |
probabilistic |
p(x) > u × q(x) | 是 | 接受率最高,数学最优 |
synthetic |
固定接受率衰减 | 否 | 中等,避免 target softmax |
10.2.2 Strict Rejection
python
@triton.jit
def _strict_rejection_sample_kernel:
for i in range(num_tokens - 1):
if not rejected:
if target_sampled != draft_sampled:
rejected = True
sampled[i] = target_sampled
num_sampled += 1
if not rejected:
sampled[last] = target_sampled # bonus token
num_sampled += 1
简单粗暴------逐 token 比较是否匹配,第一个不匹配即停止。
10.2.3 Probabilistic Rejection
三阶段流水线:
_compute_block_max_and_sumexp_kernel:并行计算每个 vocab block 的 max/sumexp,为 target 和 draft logits 构建归约结构_probabilistic_rejection_kernel:逐 token 执行概率比测试- Greedy: target_argmax == draft_sampled?
- Stochastic:
target_log_prob > log(u) + draft_log_prob
_resample_kernel:从残差分布max(p(x) - q(x), 0)重采样被拒绝/bonus token
数学正确性:残差分布采样确保最终输出严格遵循 target 分布------接受 draft token 时概率为 q(x),拒绝时从 p(x) - q(x) 采样,总概率 = q(x) × (q(x)/p(x)) + (p(x) - q(x)) = p(x)。
10.2.4 Synthetic Rejection
python
# 接受率按步衰减: acceptance_rate = base_rate × decay_factor^step
# 通过二分搜索计算 base_rate 和 decay_factor
# 使得平均联合概率 = synthetic_acceptance_rate
不需要 draft logits,用固定衰减率模拟接受过程,减少通信开销。
10.3 EAGLE 辅助工具
10.3.1 load_eagle_model(eagle/utils.py)
- 加载 draft 模型,设置
model_tag="eagle_head"用于编译后端区分 - Embedding 共享:若 draft 无自有 embed_tokens,直接引用 target 的
- LM Head 共享:若 draft 无自有 lm_head,直接引用 target 的
10.3.2 set_eagle3_aux_hidden_state_layers(eagle/eagle3_utils.py)
EAGLE3 使用 target 模型中间层的 hidden states 作为 draft 输入:
- 从
hf_config.eagle_aux_hidden_state_layer_ids或model.get_eagle3_default_aux_hidden_state_layers()获取辅助层 ID - 设置到模型上,在 forward 时自动提取
10.3.3 EagleCudaGraphManager(eagle/cudagraph.py)
EAGLE 使用独立内存池,避免与主模型的 CUDA graph 分配冲突。捕获 prefill 和 decode 两个阶段:
- Prefill:所有请求的 draft position 0
- Decode:
FULL_DECODE_ONLY模式(不支持 PIECEWISE decode,因为 PIECEWISE 的 num_reqs 不 padding 会导致 attention backend 越界)
十一、KV Connector 深度解析
11.1 KVConnector 基类(kv_connector.py)
python
class KVConnector:
def pre_forward(self, scheduler_output) -> None: pass
def post_forward(self, scheduler_output, wait_for_save=True) -> KVConnectorOutput | None: return None
def no_forward(self, scheduler_output) -> ModelRunnerOutput: return EMPTY_MODEL_RUNNER_OUTPUT
三接口设计:
pre_forward:模型 forward 前,加载跨节点 KV cachepost_forward:模型 forward 后,保存 KV cache 并获取传输状态no_forward:纯 KV 传输步骤(无实际推理),返回空输出
11.2 ActiveKVConnector
python
class ActiveKVConnector(KVConnector):
def pre_forward(self, scheduler_output):
self.kv_connector.bind_connector_metadata(metadata)
self.kv_connector.handle_preemptions(metadata)
self.kv_connector.start_load_kv(forward_context)
def post_forward(self, scheduler_output, ...):
if wait_for_save: self.kv_connector.wait_for_save()
output.finished_sending/recving = self.kv_connector.get_finished(...)
output.invalid_block_ids = self.kv_connector.get_block_ids_with_load_errors()
set_disabled() :通过 kv_transfer_state._KV_CONNECTOR_AGENT = None 禁用逐层 connector hooks,确保 warmup 等场景不触发 KV 传输。
11.3 KVConnectorModelRunnerMixin(kv_connector_model_runner_mixin.py)
提供 ModelRunner 的混入功能:
11.3.1 _get_kv_connector_output 上下文管理器
python
@contextmanager
def _get_kv_connector_output(scheduler_output, ...):
kv_connector.bind_connector_metadata(metadata)
kv_connector.start_load_kv(forward_context) # 后台加载
try:
yield output
finally:
if wait_for_save: kv_connector.wait_for_save()
output.finished_sending/recving = kv_connector.get_finished(...)
if not defer_finalize: kv_connector.clear_connector_metadata()
defer_finalize:推测解码场景下,draft model forward 后不能立即 finalize,需要等 target 采样完成。
11.3.2 use_uniform_kv_cache ------ 跨层统一 KV 布局
当 KV connector 配置且 attention backend 支持时,将所有层的 KV cache 合并为一个 [num_layers, ...] 的连续张量,使得跨节点 KV 传输可以一次拷贝所有层的块数据,而非逐层传输。
11.4 ECConnectorModelRunnerMixin(ec_connector_model_runner_mixin.py)
Encoder Cache Connector:将编码器输出(视觉 embeddings)跨节点传输。
python
@contextmanager
def _get_ec_connector_output(scheduler_output, encoder_cache):
ec_connector.bind_connector_metadata(metadata)
if ec_connector.is_consumer:
ec_connector.start_load_caches(encoder_cache)
try:
yield output
finally:
output.finished_sending/recving = ec_connector.get_finished(...)
Consumer 角色在 forward 前加载缓存,Provider 角色在 forward 后自动保存。
十二、多模态处理深度解析(gpu/mm/ 目录)
12.1 EncoderRunner(encoder_runner.py)------ MM 编码器执行
12.1.1 执行流水线
python
# 1. prepare_mm_inputs(): 从 encoder_cache 收集待编码的 MM 输入
# 2. execute_mm_encoder(): 按 modality 分组批量执行编码器
# 3. gather_mm_embeddings(): 将编码器输出映射到 prefill token 位置
# 4. get_inputs_embeds(): 合并 text embeddings + MM embeddings
12.1.2 gather_mm_embeddings() ------ 精确位置映射
python
def gather_mm_embeddings(self, req_ids, total_num_scheduled_tokens, ...):
# 对每个 prefill 请求:
# 对每个 MM feature:
# 计算 start_pos, num_encoder_tokens
# 判断当前 step 是否需要该 feature 的 embeddings
# 裁剪到当前 query 范围: [start_idx, end_idx]
# 标记 is_mm_embed mask
# 收集 mm_embeds
分块 prefill 兼容:当 MM feature 跨越多个 prefill step 时,只映射当前 step 范围内的 embeddings,避免重复处理。
12.1.3 group_and_batch_mm_kwargs
按 modality 分组批量执行------同一 modality 的多个输入合并为一个 batch,利用编码器的批量处理能力。不同 modality(image/video/audio)分别执行。
12.2 EncoderCache(encoder_cache.py)------ 编码器输出缓存
python
class EncoderCache:
mm_features: dict[str, list[MultiModalFeatureSpec]] # req_id → features
encoder_outputs: dict[str, torch.Tensor] # mm_hash → output
Hash-based 去重 :相同图像/视频(相同 mm_hash)只编码一次,后续请求复用缓存。free_encoder_cache(mm_hash) 在请求完成后释放缓存。
12.3 RopeState(rope.py)------ 多维 RoPE 位置编码
12.3.1 M-RoPE 与 XD-RoPE
| 变体 | 维度 | Decode 位置 | 适用模型 |
|---|---|---|---|
| M-RoPE | 3 | orig_pos + delta |
Qwen2-VL |
| XD-RoPE | 3-4 | orig_pos (delta=0) |
其他多维模型 |
12.3.2 prefill_positions 预计算
python
self.prefill_positions = StagedWriteTensor(
(max_num_reqs * num_dims, max_model_len), ...)
为每个请求的每个维度预计算完整的位置序列,存入 GPU tensor。Decode 时只需 orig_pos + delta 即可。
12.3.3 _prepare_rope_positions_kernel
python
# 对每个 request:
# is_prefill → 从 prefill_positions 加载
# is_decode → orig_pos + delta
# 存入 positions[dim, query_start:query_end]
非连续设计 :positions 的最后一维是 max_num_tokens + 1(+1 是故意非连续),使其与 torch.compile 兼容。
十三、其他支撑模块
13.1 StructuredOutputsWorker(structured_outputs.py)------ 语法约束输出
python
class StructuredOutputsWorker:
def apply_grammar_bitmask(self, logits, input_batch, grammar_req_ids, grammar_bitmask):
# 1. 异步拷贝 grammar_bitmask 到 GPU (copy_stream)
# 2. 构建 logits_indices: 每个 grammar request 的 logits 行映射
# 3. 等待 copy_stream 完成
# 4. _apply_grammar_bitmask_kernel: 解包 bitmask → -inf mask
性能关键:bitmask 拷贝在独立 stream 上异步执行,与主 stream 的 logits 计算并行。bitmask 来自 xgrammar,是一个 packed int32 数组(每 bit 代表一个 token 是否合法),解包后对不合法 token 设置 -inf。
13.2 dp_utils.py ------ 数据并行工具
13.2.1 sync_cudagraph_and_dp_padding
DP 场景下所有 rank 必须使用相同的 CUDA graph 形状:
python
def sync_cudagraph_and_dp_padding(...):
# 1. all_reduce([num_tokens, cg_mode, uniform_token_count])
# 2. 取 max num_tokens → 同步 padding
# 3. 取 min cg_mode → 任一 rank 用 eager 则全部 eager
# 4. 同步 uniform_token_count → 不一致则设为 None
关键约束:如果任一 DP rank 无法使用 CUDA graph(如形状不匹配),所有 rank 回退到 eager------CUDA graph 要求所有 rank 同步 replay。
13.3 cp_utils.py ------ 上下文并行工具
python
def prepare_dcp_local_seq_lens(dcp_local_seq_lens, seq_lens, num_reqs, dcp_size, dcp_rank, cp_interleave):
# DCP (Disaggregated Context Parallel) 本地 seq_len 计算
# 将 KV cache 按 round-robin 分配给不同 rank
# rounds = seq_len // (dcp_size * cp_interleave)
# remainder = max(seq_len % (dcp_size * cp_interleave) - dcp_rank * cp_interleave, 0)
# local_seq_len = rounds * cp_interleave + min(remainder, cp_interleave)
13.4 eplb_utils.py ------ 专家并行负载均衡
python
class EPLBController:
def step(self, is_dummy, is_profile):
# 调用 EplbState.step() 进行专家重分配
# 支持 MoE 模型的动态负载均衡
装饰器 @step_eplb_after(is_dummy=False) 自动在 model runner 方法后触发 EPLB step。
13.5 pp_utils.py ------ 流水线并行工具
python
def pp_broadcast(sampled_token_ids, num_sampled, num_rejected):
# Last PP rank → 其他 rank 广播采样结果
torch.distributed.broadcast(sampled_token_ids, src=pp.last_rank)
torch.distributed.broadcast([num_sampled, num_rejected], src=pp.last_rank)
def pp_receive(num_reqs):
# 非 last rank 接收采样结果
13.6 async_utils.py ------ 异步输出
python
class AsyncOutput(AsyncModelRunnerOutput):
def __init__(self, model_runner_output, sampler_output, ...):
with stream(copy_stream, main_stream):
copy_stream.wait_stream(main_stream)
self.sampled_token_ids = async_copy_to_np(...)
self.logprobs_tensors = sampler_output.logprobs_tensors.to_cpu_nonblocking()
self.copy_event.record(copy_stream)
def get_output(self):
self.copy_event.synchronize() # 等待异步拷贝完成
# 转换为 Python list 格式返回
双流异步 :在 copy_stream 上异步将 GPU tensor 拷贝到 CPU,主 stream 继续下一次推理。copy_event.synchronize() 只在需要结果时才等待。
13.7 buffer_utils.py ------ GPU 缓冲区管理
13.7.1 UvaBackedTensor
python
class UvaBackedTensor:
cpu: torch.Tensor # CPU 源数据
np: np.ndarray # numpy 视图
gpu: torch.Tensor # UVA 视图 (GPU 可直接访问)
def copy_to_uva(self):
self.gpu = self.pool.copy_to_uva(self.np)
UVA(Unified Virtual Addressing) :CPU pinned memory 通过 UVA 暴露给 GPU,GPU kernel 可直接通过 PCIe 读取,无需显式 H2D 拷贝。配合 UvaBufferPool 的 round-robin 缓冲池实现并发安全。
13.7.2 StagedWriteTensor
python
class StagedWriteTensor:
def stage_write(self, index, start, x):
# 累积写入请求到 _staged_write_* 列表
def apply_write(self):
# 一次性将所有累积写入提交到 GPU
# 使用 _apply_write_kernel (Triton) 并行写入
延迟写入 :多个 add_request 的写入累积在 CPU 列表中,apply_write 一次性批量提交,减少 CPU→GPU 传输次数。
13.8 warmup.py ------ 预热逻辑
python
def warmup_kernels(model_runner, worker_execute_model, worker_sample_tokens):
# Step 1: Prefill --- 每个请求 2 个 prompt token
# Step 2: Decode --- 每个请求 1 + num_spec_steps token
# 覆盖所有 Triton kernel 的 JIT 编译路径
# 同时 warmup grammar bitmask kernel
两步预热确保 prefill 和 decode 两个阶段的所有 kernel 都被编译。
13.9 model_states/ ------ 模型状态管理
13.9.1 ModelState 接口(interface.py)
python
class ModelState(ABC):
def add_request(self, req_index, new_req_data) -> None
def apply_staged_writes(self) -> None
def get_mm_embeddings(self, ...) -> torch.Tensor | None
def prepare_inputs(self, input_batch, req_states) -> dict
def prepare_dummy_inputs(self, num_reqs, num_tokens) -> dict
def prepare_attn(self, input_batch, cudagraph_mode, ...) -> dict
13.9.2 DefaultModelState(default.py)
标准 decoder-only 模型状态:
- 集成
EncoderRunner(多模态)和RopeState(多维 RoPE) prepare_inputs()主要处理 M-RoPE/XD-RoPE 的位置计算prepare_attn()调用build_attn_metadata()构建注意力元数据
13.9.3 WhisperModelState(whisper.py)
Whisper 专用状态:
- 编码器输出通过
encoder_outputs直接传入模型(非inputs_embeds) - 管理
encoder_seq_lens用于 cross-attention 的 KV cache 长度 - CUDA graph 捕获时使用
max_encoder_len确保图兼容
13.10 lora_utils.py ------ LoRA 工具
python
class LoraState:
lora_ids: np.ndarray # [max_num_reqs] LoRA ID
lora_requests: dict[str, LoRARequest]
def make_lora_inputs(self, req_ids, idx_mapping, num_scheduled_tokens):
prompt_lora_mapping = tuple(lora_ids[idx_mapping])
token_lora_mapping = tuple(lora_ids.repeat(num_scheduled_tokens))
active_lora_requests = {lora for req_id in req_ids if (lora := ...)}
简洁的 ID 映射机制------每个请求/ token 关联 LoRA ID,模型层根据 ID 查找对应的 LoRA 权重。
13.11 pool/ ------ 池化模型处理
13.11.1 PoolingRunner(pooling_runner.py)
python
class PoolingRunner:
def pool(self, hidden_states, input_batch, req_states):
last_hidden_states = hidden_states[input_batch.logits_indices]
last_hidden_states = F.normalize(last_hidden_states, p=2, dim=-1)
is_valid = input_batch.seq_lens == prompt_len
return last_hidden_states, is_valid
简单的 LAST pooling + L2 归一化,is_valid 标识请求是否完成(完整序列的 pooler 输出才有效)。
13.11.2 LateInteractionRunner(late_interaction_runner.py)
支持 ColBERT 风格的 late interaction scoring:
python
# CACHE_QUERY 模式: 缓存 query 的 token embeddings
# SCORE_DOC 模式: 计算 query × doc 的 MaxSim score
# query_uses: 引用计数,所有 doc 处理完成后释放 query 缓存
十四、Mamba/SSM 与其他 Worker 类型
14.1 mamba_utils.py ------ SSM 状态管理
14.1.1 核心问题
Mamba/SSM 模型不像 Transformer 使用 KV cache,而是维护 递归状态 (SSM state)。这些状态存储在 KV cache 的 block 中,但语义完全不同------需要精确的 状态复制 而非简单的 block 映射。
14.1.2 preprocess_mamba() ------ 前向预处理
python
def preprocess_mamba(scheduler_output, kv_cache_config, ...):
# 对每个请求:
# prev_state_idx = mamba_state_idx.get(req_id, computed // block_size)
# curr_state_idx = num_blocks - 1 - num_speculative_blocks
# mamba_state_idx[req_id] = curr_state_idx
# if prev != curr:
# collect_mamba_copy_meta(prev, curr, accept_bias, ...)
# do_mamba_copy_block(copy_bufs)
运行状态位置 :始终保存在最后一个非 speculative block(num_blocks - 1 - num_spec_blocks)。当 block 分配变化时,需要将旧状态拷贝到新位置。
14.1.3 postprocess_mamba() ------ 后向处理
python
def postprocess_mamba(scheduler_output, ...):
# 当 partial block 变为 full block 时:
# 从 running_state block 复制状态到新的 full block
# accept_token_bias = aligned_new_computed - num_tokens_running_state
14.1.4 batch_memcpy_kernel ------ 批量内存拷贝
python
@triton.jit
def batch_memcpy_kernel(src_ptrs, dst_ptrs, sizes):
# 每个程序处理一次 memcpy:
# src_ptr, dst_ptr, size = load from arrays
# 分块拷贝: BLOCK_SIZE=1024
通用批量拷贝 kernel------一次 launch 处理多组不同源/目标/大小的内存拷贝,适用于 Mamba 状态的零散复制需求。
14.2 cpu_worker.py + cpu_model_runner.py ------ CPU 推理
14.2.1 CPUWorker
继承 GPU Worker,但做了 CPU 特化:
python
class CPUWorker(Worker):
def init_device(self):
# 1. 绑定 NUMA node
# 2. 初始化分布式环境 (gloo backend)
# 3. 构造 CPUModelRunner
def determine_available_memory(self):
# 使用 psutil 查询可用 CPU 内存
# 考虑 explicit_kv_cache_size 或 auto 计算
def compile_or_warm_up_model(self):
# 只需一次 profile_run 用于 torch.compile 编译
不支持:sleep/wake 模式、dummy weights、elastic EP。
14.2.2 CPUModelRunner
继承 GPUModelRunner,做以下替换:
- Tensor 替换 :所有
CpuGpuBuffer.gpu指向.cpu,设备张量替换为 CPU 张量 - Kernel 替换 :Triton kernel 替换为
cpu_tl.*的 C++ 原生实现 - CUDA API 替换 :
_torch_cuda_wrapper()将torch.cuda.Stream/Event替换为占位符
14.3 xpu_worker.py + xpu_model_runner.py ------ XPU 推理
14.3.1 XPUWorker
类似 CPUWorker,但使用 XPU 设备:
python
class XPUWorker(Worker):
def init_device(self):
# 初始化 XPU 设备 + oneCCL 分布式
14.3.2 XPUModelRunner
python
class XPUModelRunner(GPUModelRunner):
# _torch_cuda_wrapper():
# torch.cuda.Stream → torch.xpu.Stream
# torch.cuda.CUDAGraph → torch.xpu.XPUGraph (if supported)
# torch.cuda.graph_pool_handle → torch.xpu.graph_pool_handle
XPU API 与 CUDA API 高度同构,通过 monkey-patch 实现兼容。
14.4 tpu_input_batch.py ------ TPU 输入批
TPU 使用不同的 InputBatch 实现,原因是 TPU 的 Attention 后端和采样机制与 GPU 不同:
- 使用 CPU 端的
token_ids_cpunumpy 数组管理 token ID - 采样参数(temperature/top_p/top_k)有独立的 CPU/GPU 双缓冲
- 不使用 UVA,而是显式的
pin_memory+non_blocking copy SamplingType分为GREEDY/RANDOM/TOP_P/TOP_K等类别,TPU 采样器按类别批量处理
14.5 workspace.py ------ 工作空间管理
python
class WorkspaceManager:
def __init__(self, device, num_ubatches=1):
self._current_workspaces: list[torch.Tensor | None] = [None] * num_ubatches
self._locked: bool = False
def get_simultaneous(self, *shapes_and_dtypes):
# 从单个分配中获取多个 workspace tensor
# 自动扩展大小(仅当未 lock 时)
def lock/unlock(self):
# 锁定/解锁 workspace 大小
UBatch 感知 :每个 ubatch slot 维护独立的 workspace,避免 DBO 场景下的数据竞争。lock() 在执行期间禁止增长,确保 CUDA graph 兼容。
14.6 utils.py ------ 辅助函数
14.6.1 KVBlockZeroer
python
class KVBlockZeroer:
def init_meta(self, attn_groups, kernel_block_sizes, cache_dtype):
# 预计算每个 KV cache buffer 的 segment 地址
# 支持 block_dim=0 (blocks outermost) 和 block_dim=1 (K/V outermost)
def zero_block_ids(self, block_ids):
# 使用 _zero_kv_blocks_kernel 批量清零新分配的 block
14.6.2 AttentionGroup
python
@dataclass
class AttentionGroup:
backend: AttentionBackend
builder: AttentionMetadataBuilder
kv_cache_spec: KVCacheSpec
layer_names: list[str]
将 attention 后端、元数据构建器和 KV cache 规范绑定为一组,支持同一模型内多种 attention 类型(如 self-attention + cross-attention)。
总结:Part 3 架构全景
核心设计原则
- GPU 原地操作:采样管线全程 in-place Triton kernel,零中间分配
- CPU 端短路:每步采样操作前检查是否有请求实际使用,跳过无效 kernel
- 延迟写入 + UVA:所有参数更新通过 staged write 批量提交,UVA 避免 H2D 拷贝
- 图兼容性:所有缓冲区预分配最大尺寸,CUDA graph 内零动态分配
- DP/TP/PP/CP 全并行:每个并行维度都有专门的状态同步与协调机制
- 推测解码深度集成:从采样(Gumbel 种子一致性)到惩罚(draft counts)到图管理(独立 pool)的全链路支持
- 多设备抽象:GPU → CPU/XPU/TPU 通过继承 + monkey-patch 实现最小改动兼容
数据流总览
SchedulerOutput
↓
GPUModelRunner._prepare_inputs()
├── RequestState.update() ← 请求状态更新
├── EncoderRunner.execute_mm_encoder() ← 多模态编码
├── RopeState.prepare_positions() ← 位置编码
├── BlockTables.compute_slot_mappings() ← KV block 映射
└── build_attn_metadata() ← 注意力元数据
↓
Model.forward() → logits
↓
Sampler.__call__()
├── apply_logit_bias() ← 白名单/偏置/min_tokens
├── apply_penalties() ← 重复/频率/存在惩罚
├── apply_bad_words() ← 禁用词掩码
├── apply_temperature() ← 温度缩放
├── apply_min_p() ← 最小概率阈值
├── apply_top_k_top_p() ← top-k/top-p 截断
├── gumbel_sample() ← Gumbel-Max 采样
└── compute_topk_logprobs() ← logprobs 计算
↓
[If Speculative Decoding]
RejectionSampler.__call__()
└── strict / probabilistic / synthetic rejection
↓
AsyncOutput.get_output() → ModelRunnerOutput
关键设计决策总结
| # | 决策 | 原因 |
|---|---|---|
| 1 | GPUModelRunner 6阶段流水线 | 清晰的职责分离,每阶段可独立优化 |
| 2 | InputBatch持久化GPU tensor | 避免每步重新分配,减少GPU内存碎片 |
| 3 | 批量CPU→GPU同步 | 减少PCIe传输次数,最小化延迟 |
| 4 | deferred_state_corrections | 避免在forward前修改GPU状态,保证正确性 |
| 5 | CUDA Graph decode图重放 | 消除kernel launch开销,提升decode吞吐 |
| 6 | UBatch DBO重叠 | compute/comm双流并行,隐藏NCCL延迟 |
| 7 | 8种SpecDecode方法 | 不同场景不同策略,2-5x吞吐提升 |
| 8 | EncoderCache mm_hash缓存 | 避免重复编码相同图像,节省GPU计算 |
| 9 | CuMemAllocator休眠/唤醒 | 权重卸载→GPU共享→模型热更新 |
| 10 | AsyncModelRunnerOutput | GPU→CPU异步流拷贝,避免同步阻塞 |