深入理解 vLLM 的 Block 机制
基于 vLLM v1 架构源码分析,涵盖 BlockPool 核心数据结构、分配/释放/驱逐流程、Prefix Caching 实现,以及分布式场景下 Block ID 的统一机制。
1. 整体架构:Block 管理的层次结构
vLLM v1 的 KV cache 管理采用分层设计,BlockPool 是整个 block 生命周期的核心管理者。
关键 :在标准部署中,BlockPool 实例全局只有一个。它在 KVCacheCoordinator.__init__ 中创建,被所有 SingleTypeKVCacheManager 共享。
代码出处:
vllm/v1/core/kv_cache_coordinator.py--- BlockPool 创建vllm/v1/core/kv_cache_coordinator.py--- 所有 STM 共享block_poolvllm/v1/core/kv_cache_manager.py--- KVCacheManager 的快捷引用
2. 核心数据结构
2.1 KVCacheBlock --- Block 的元数据
每个 block 的元数据由 KVCacheBlock 表示,它不存储实际的 KV 数据,只管理逻辑状态。
python
# vllm/v1/core/kv_cache_utils.py
@dataclass(slots=True)
class KVCacheBlock:
block_id: int # 逻辑 ID,范围 [0, num_gpu_blocks)
ref_cnt: int = 0 # 引用计数,被多少个请求共享
_block_hash: BlockHashWithGroupId | None = None # 哈希键(仅满块缓存后设置)
prev_free_block: "KVCacheBlock | None" = None # 双向链表前驱
next_free_block: "KVCacheBlock | None" = None # 双向链表后继
关键属性解读:
| 属性 | 含义 | 何时变化 |
|---|---|---|
block_id |
逻辑索引,从 0 开始递增 | 创建后不变 |
ref_cnt |
引用计数,用于共享 block(prefix cache hit 时多个请求引用同一 block) | touch() +1,free_blocks() -1 |
_block_hash |
block 内容的哈希 + group_id,用于 prefix cache 查找 | cache_full_blocks() 设置,_maybe_evict_cached_block() 清除 |
prev/next_free_block |
空闲链表指针 | 仅由 FreeKVCacheBlockQueue 操作 |
代码出处:vllm/v1/core/kv_cache_utils.py
2.2 FreeKVCacheBlockQueue --- 空闲块的双向链表
空闲块通过双向链表组织,支持 O(1) 的头部弹出和中间删除。
驱逐顺序:链表头部是 LRU(最久未使用)的 block,尾部是最近释放的 block。分配时从头部取,释放时追加到尾部。当请求释放 block 时,block 按逆序释放(尾 block 先释放),确保尾部 block 是"最有价值"的缓存。
代码出处:vllm/v1/core/kv_cache_utils.py
2.3 BlockHashToBlockMap --- Prefix Cache 哈希表
用于通过 block hash 快速查找已缓存的 block,支持 prefix caching。
python
class BlockHashToBlockMap:
_cache: dict[BlockHashWithGroupId, KVCacheBlock | dict[int, KVCacheBlock]]
设计要点:
- 大多数情况下,一个 hash 只对应一个 block(直接存
KVCacheBlock) - 当多个 block 内容相同(hash 冲突),退化为
dict[block_id, KVCacheBlock] - 这种 union 类型设计是为了减少 GC 开销(避免为每个 key 都创建一个 dict)
代码出处:vllm/v1/core/block_pool.py
2.4 BlockPool --- Block 管理的入口
BlockPool 整合了上述所有数据结构:
初始化过程 vllm/v1/core/block_pool.py:
- 创建
num_gpu_blocks个KVCacheBlock,block_id 从 0 递增 - 用所有 block 构造
FreeKVCacheBlockQueue - 弹出 block_id=0 作为
null_block(占位符,ref_cnt 不维护)
3. Block 生命周期:分配、缓存、驱逐、释放
3.1 完整时序图
3.2 分配:get_new_blocks
python
# block_pool.py
def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
ret = self.free_block_queue.popleft_n(num_blocks)
# In order to only iterate the list once, we duplicated code a bit
if self.enable_caching:
for block in ret:
self._maybe_evict_cached_block(block)
assert block.ref_cnt == 0
block.ref_cnt += 1
else:
for block in ret:
assert block.ref_cnt == 0
block.ref_cnt += 1
return ret
分配逻辑:
- 从空闲链表头部弹出 N 个 block(LRU 优先分配)
- 若启用 prefix caching,检查 block 是否有缓存哈希,有则驱逐
- 设置
ref_cnt = 1
3.3 缓存命中:touch
当 prefix cache 命中时,已有 block 被"触摸"------增加引用计数:
python
# block_pool.py
def touch(self, blocks: Sequence[KVCacheBlock]) -> None:
for block in blocks:
# ref_cnt=0 means this block is in the free list (i.e. eviction
# candidate), so remove it.
if block.ref_cnt == 0 and not block.is_null:
self.free_block_queue.remove(block) # 从空闲链表移除
block.ref_cnt += 1
ref_cnt == 0 意味着 block 在空闲链表中(是驱逐候选),需要先移除。
3.4 驱逐:evict_blocks
驱逐的本质:数据即将失效
当空闲 block 不足时,需要驱逐 block 给新的请求使用 关键在于理解 block 在空闲链表中的状态。一个 block 在空闲链表中可能还有 hash,这意味着:
-
它的 KV cache 数据仍然在 GPU 显存中
-
没有任何请求正在使用它(ref_cnt == 0)
-
它是一个驱逐候选------如果新请求有相同前缀可以命中,如果显存紧张则被回收
驱逐发生的场景
假设有两个请求:
bash
请求 A: "The cat sat on the mat" → Block 3 (hash=0xAB)
请求 A 完成,Block 3 被释放 → ref_cnt=0,进入空闲链表,但 hash=0xAB 仍在哈希表中
此时 Block 3 的 GPU KV cache 数据仍然存在,是 "The cat sat on the mat" 的 KV。
--- 场景 1:不驱逐(正确情况)---
请求 B: "The cat sat on the roof" → hash=0xAB 命中 Block 3
→ 前缀 "The cat sat on the" 复用 Block 3 的 KV cache ✅
→ 只需计算 " roof" 部分
--- 场景 2:显存不足,需要驱逐 ---
请求 C: "Completely different text" → 需要新 block
→ 从空闲链表取出 Block 3
→ Block 3 的 KV cache 将被 "Completely different text" 的 KV 覆盖
→ 必须从哈希表移除 hash=0xAB → Block 3 的映射
→ 否则后续请求 D: "The cat sat on the..." 会命中 Block 3
→ 但 Block 3 的内容已经是 "Completely different text" 的 KV ❌
驱逐 = 从哈希表中移除映射,因为 block 的 KV cache 内容即将/已经被覆盖。保留映射会导致后续请求"假命中"到错误数据。
python
# block_pool.py
# 由 KV Connector 调用,当 Worker 报告某些 block 的 KV 数据已失效(如分布式 KV 传输中远端数据过期),需要主动从哈希表中移除,防止后续请求命中到过期数据
def evict_blocks(self, block_ids: set[int]) -> None:
for block_id in block_ids:
block = self.blocks[block_id]
self._maybe_evict_cached_block(block)
# 分配新 block 时,如果取出的空闲 block 还有缓存哈希,必须驱逐 ------ 因为该 block 即将被新请求的 KV 数据覆盖。
def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
# In order to only iterate the list once, we duplicated code a bit
if self.enable_caching:
for block in ret:
self._maybe_evict_cached_block(block)
assert block.ref_cnt == 0
block.ref_cnt += 1
else:
for block in ret:
assert block.ref_cnt == 0
block.ref_cnt += 1
return ret
def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
block_hash = block.block_hash
if block_hash is None:
return False # 无哈希,无需驱逐
if self.cached_block_hash_to_block.pop(block_hash, block.block_id) is None:
return False # 哈希表中找不到
block.reset_hash() # 清除哈希
return True
3.5 释放:free_blocks
python
# block_pool.py
def free_blocks(self, ordered_blocks, prepend=False) -> None:
for block in blocks_list:
block.ref_cnt -= 1
freed_blocks = [b for b in blocks_list if b.ref_cnt == 0 and not b.is_null]
if prepend:
self.free_block_queue.prepend_n(freed_blocks) # 优先复用
else:
self.free_block_queue.append_n(freed_blocks) # 追加到尾部
释放逻辑:
- 减少引用计数
- 只有
ref_cnt降为 0 的 block 才真正归还空闲链表 - 共享 block(prefix cache hit)在所有引用者释放后才归还
4. Prefix Caching 机制
Prefix Caching 是 vLLM 的核心优化:当不同请求共享相同前缀 token 时,可以复用已计算的 KV cache block,避免重复计算。
4.1 工作原理
4.2 Block Hash 的计算
Block hash 由 Request 对象在创建时和追加新 token 时计算:
BlockHash=NewType("BlockHash", bytes),本质是 bytes 类型BlockHashWithGroupId=BlockHash+ KV cache group ID 的组合,用于区分不同 group 中相同内容的 block
代码出处:vllm/v1/core/kv_cache_utils.py
4.3 缓存查找流程
- Scheduler 调用
KVCacheManager.get_computed_blocks(request) - 遍历 request 的
block_hashes,在BlockHashToBlockMap中逐块查找 - 找到匹配 block 后调用
touch()增加引用计数 - 返回所有命中 block 及其对应的 token 数
5. 分布式场景:Block ID 的统一机制
5.1 架构总览
5.2 为什么不同卡的 Block ID 天然一致?
核心原因:Block ID 是逻辑索引,不是物理地址。
-
所有 Worker 的
num_blocks相同(有 assert 保证):python# kv_cache_utils.py assert all( [cfg.num_blocks == kv_cache_configs[0].num_blocks for cfg in kv_cache_configs] )代码出处:
vllm/v1/core/kv_cache_utils.py -
Block ID 直接作为 KV cache tensor 的第一维下标 :Worker 端的 KV cache tensor 形状为
[num_blocks, 2, block_size, num_kv_heads, head_size],block_id=5 直接索引第 5 行。 -
Tensor Parallelism 下,同一 block_id 在不同卡存的是不同 head 分片:各卡独立计算自己负责的 KV head,最后通过 all-reduce 聚合结果。
5.3 Block ID 从 Scheduler 到 Worker 的完整数据流
关键代码文件:
- Scheduler 生成 block_ids:
vllm/v1/core/sched/scheduler.py - Worker 接收并更新:
vllm/v1/worker/gpu_model_runner.py - 写入 BlockTables:
vllm/v1/worker/gpu_model_runner.py - Attention kernel 查表:
vllm/v1/worker/block_table.py
6. Block 与 Slot 的关系
Block 是 KV cache 管理的逻辑单位,Slot 是 attention kernel 实际访问的物理位置。
ini
Slot 计算:
block_index = position // block_size
block_number = block_table[request_index][block_index]
slot = block_number * block_size + (position % block_size)
关键代码文件:vllm/v1/worker/block_table.py
7. 特殊场景
7.1 Null Block
BlockPool 初始化时,block_id=0 被弹出作为 null_block。它是一个占位符,用于:
- 滑动窗口注意力中被跳过的 block 位置
- Mamba 模型中 align 模式下的填充
null_block 的 ref_cnt 不被维护,释放时需要特殊跳过(not block.is_null)。
7.2 混合模型(Hybrid KV Cache Coordinator)
当模型同时包含 Full Attention 和 Sliding Window Attention 层时,使用 HybridKVCacheCoordinator。此时:
- 所有 KV cache group 共享同一个 BlockPool
- 不同 group 的 block_size 可能不同,但 hash_block_size 是统一的
BlockHashListWithBlockSize负责将 hash_block_size 粒度的哈希转换为实际 block_size 粒度
7.3 Preemption 与 Block 恢复
当 GPU 显存不足时,Scheduler 会抢占(preempt)低优先级请求:
- 调用
KVCacheManager.free(request)释放该请求的所有 block - 被释放的 block 归还空闲链表,可被高优先级请求使用
- 被抢占的请求后续重新调度时,需要重新分配 block 并重算 KV cache
8. 总结
| 概念 | 说明 |
|---|---|
| BlockPool | 全局唯一,管理所有 GPU block 的分配、释放和缓存 |
| KVCacheBlock | Block 的元数据(逻辑 ID、引用计数、哈希、链表指针),不存实际数据 |
| FreeKVCacheBlockQueue | 空闲块的双向链表,LRU 驱逐顺序 |
| BlockHashToBlockMap | Prefix cache 的哈希表,hash → block 映射 |
| Block ID | 逻辑索引 [0, N),直接作为 Worker 端 KV cache tensor 的下标 |
| Slot | Attention kernel 的物理访问位置 = block_id * block_size + offset |
| 分布式统一 | 所有 Worker 的 num_blocks 相同,block_id 天然一致,无需额外协调 |
| TP 下的 block | 同一 block_id 在不同卡存不同 head 分片,独立计算后 all-reduce |
关键文件索引
| 文件 | 职责 |
|---|---|
vllm/v1/core/block_pool.py |
BlockPool、BlockHashToBlockMap 定义 |
vllm/v1/core/kv_cache_utils.py |
KVCacheBlock、FreeKVCacheBlockQueue、BlockHash 类型定义 |
vllm/v1/core/kv_cache_coordinator.py |
BlockPool 创建、多 group 协调 |
vllm/v1/core/kv_cache_manager.py |
对外接口层,组合 coordinator |
vllm/v1/core/single_type_kv_cache_manager.py |
单类型 KV cache 的分配/释放/缓存逻辑 |
vllm/v1/core/sched/scheduler.py |
调度入口,持有 KVCacheManager |
vllm/v1/worker/block_table.py |
Worker 端 block_table GPU tensor 管理 |
vllm/v1/worker/gpu_model_runner.py |
Worker 端接收 block_ids 并更新 block_table |