一、为什么需要前缀缓存?
在使用大语言模型(LLM)进行推理时,你是否遇到过这样的场景:同一个系统提示词(System Prompt)在多轮对话中反复被编码,同一个文档前缀在多条查询中被重复计算 KV(Key-Value)缓存?
以 ChatGPT 的系统提示词机制为例:
你是一个专业的 Python 开发者助手,擅长代码审查和优化...
用户:帮我审查这段代码
用户:再审查另一段代码
用户:继续审查这段代码
每次新对话,模型都需要从头计算 "你是一个专业的 Python 开发者助手,擅长代码审查和优化..." 这几十个 token 的 KV 缓存。如果对话量大,这种重复计算的浪费是惊人的。
重复计算的代价有多大?
假设 System Prompt 长度为 1000 tokens,模型层数为 32,注意力头数为 32,隐藏层维度为 4096。每次计算 KV 缓存需要将每个 token 通过所有 Transformer 层前向传播,每层需要计算查询(Query)、键(Key)和值(Value)三个矩阵,然后缓存 Key 和 Value 矩阵。这个过程涉及大量的矩阵乘法运算------对于 1000 个 token,每个注意力头需要计算 1000 × 128 维的 Key 和 Value 向量,32 层加起来就是 1000 × 128 × 2 × 32 × 32 ≈ 2.6 亿个浮点数。
更直观地说,单次计算 1000 个 token 的 KV 缓存,在 H100 GPU 上大约需要 5-10 毫秒。如果每秒处理 100 个不同的请求,并且每个请求都有相同的 System Prompt,那么每秒就要浪费 0.5-1 秒的 GPU 计算时间在重复的前向传播上。
这就是前缀缓存(Prefix Caching)要解决的问题。 这项技术已经在 vLLM、TGI(Text Generation Inference)、TensorRT-LLM、SGLang 等主流推理框架中得到广泛应用,是大模型推理优化的核心技术之一。
前缀缓存 vs 传统 KV 缓存
很多人容易混淆这两个概念,这里特别说明一下:
- KV 缓存(KV Cache):单个请求生成过程中,将已计算的历史 token 的 Key 和 Value 缓存下来,避免逐 token 重复计算。这是 LLM 推理的基本优化。
- 前缀缓存(Prefix Caching) :在 KV 缓存的基础上更进一步,让不同请求之间共享相同的前缀缓存。这是跨请求的优化。
可以这样理解:KV 缓存解决的是"同一个请求内不要重复计算"的问题,而前缀缓存解决的是"不同的请求之间也不要重复计算公共部分"的问题。后者是前者的超集和进化。
本文将从零开始,用 Python 实现一个完整的前缀缓存系统,涵盖从基本原理到性能优化、从单机部署到并发安全的全方位实现。
二、前缀缓存的核心原理
2.1 KV 缓存速览
Transformer 的自注意力机制可以写成:
Attention(Q, K, V) = softmax(Q × K^T / √d) × V
其中 Q、K、V 分别是查询矩阵、键矩阵和值矩阵,d 是注意力头维度。对于自回归生成,每生成一个新 token,我们需要计算其与所有之前 token 的注意力。在朴素实现中,每次都要重新计算所有 token 的注意力分数,这会导致 O(n²) 的计算复杂度,随着序列长度增长呈平方级增长。
KV 缓存的核心思想是:将每个 token 的 K 和 V 向量缓存起来,生成新 token 时只需计算新 token 的 Q 向量,并与缓存的 K、V 向量做注意力计算。这样每步的计算量从 O(n²) 降为 O(n),当序列长度达到数千甚至数万时,性能提升极其显著。
以 LLaMA-70B 模型为例,在生成长度为 4096 的序列时,启用 KV 缓存可将解码阶段的矩阵乘法计算量减少约 8 倍,这对于用户体验的改善是革命性的------没有 KV 缓存的模型几乎无法在合理时间内完成长文本生成。
2.2 前缀缓存的进化
前缀缓存是 KV 缓存的更高级形态------跨请求共享。当多个请求共享相同的 Prompt 前缀时(如 System Prompt、文档上下文、Agent 指令体系),它们的 KV 缓存是完全相同的,可以直接复用。
请求1:[System Prompt] + [用户查询1]
请求2:[System Prompt] + [用户查询2]
↑
共享相同的前缀 KV 缓存
来看一个更具体的场景:在 RAG(检索增强生成)应用中,每个查询都需要将检索到的文档上下文拼接到 Prompt 中。如果多个查询检索到相同的文档片段,那么这些文档片段对应的 KV 缓存完全可以共享。
请求A:文档D + 查询Q_A → 缓存文档D + 只需计算Q_A
请求B:文档D + 查询Q_B → 直接复用文档D的缓存 → 只需计算Q_B
这样,文档D对应的 2000 个 token 的 KV 缓存只需要计算一次,后续所有引用到文档D的请求都可以直接复用,对于包含大量知识库文档的场景,收益极其可观。
2.3 关键挑战
实现前缀缓存面临几个核心挑战:
1. 缓存命中检测
如何快速判断新请求的前缀是否已被缓存?暴力逐 token 比较显然不可行,因为序列可能长达数万 token。业界通常使用哈希(Hashing)或树结构(Radix Tree)来加速匹配。
2. 缓存粒度
缓存应该以单个 token 为单位,还是以块(Block)为单位?以 token 为粒度匹配精度最高,但存储碎片化严重;以块为粒度存储效率高,但匹配精度会下降。实际工程中需要在两者间找平衡。
3. 内存管理
LLM 的 KV 缓存非常消耗显存。一个 70B 参数的模型,处理 4096 个 token 时,KV 缓存需要约 32GB 显存(FP16 精度)。这意味着缓存的容量极其有限,必须有高效的淘汰策略。
4. LRU / LFU 淘汰策略
缓存空间有限,当缓存满了之后需要淘汰一些条目。什么策略最适合前缀缓存?是基于最近使用时间(LRU),还是基于使用频率(LFU),或者是结合两者的混合策略?
**5. 并发安全
多线程或多进程环境下,多个推理请求可能同时访问缓存。如何确保并发安全而不引入过多锁竞争?
6. 动态批处理兼容性
现代推理引擎(如 vLLM)使用动态批处理(Dynamic Batching / Continuous Batching),不同请求可能同时共享同一缓存的不同部分。前缀缓存需要与批处理调度器深度集成。
业界主流方案采用分块(Block-based)缓存 + 哈希匹配的组合策略,既保证了查询效率,又控制了碎片化程度。
三、系统设计
在动手编码之前,先设计架构。
3.1 整体架构
┌─────────────────────────────────────────────────────────┐
│ PrefixCacheManager │
├─────────────────────────────────────────────────────────┤
│ ┌──────────┐ ┌────────────┐ ┌────────────────────┐ │
│ │ Tokenizer │ │ Hasher │ │ BlockManager │ │
│ └─────┬────┘ └─────┬──────┘ └────────┬───────────┘ │
│ │ │ │ │
│ ┌─────▼──────────────▼──────────────────▼────────────┐ │
│ │ CacheStore │ │
│ │ ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬────┐ │ │
│ │ │ B#0 │ B#1 │ B#2 │ B#3 │ B#4 │ B#5 │ ... │ B#N│ │ │
│ │ └─────┴─────┴─────┴─────┴─────┴─────┴─────┴────┘ │ │
│ └────────────────────────────────────────────────────┘ │
│ ┌───────────────────────────────────────────────────┐ │
│ │ EvictionPolicy (LRU/LFU) │ │
│ └───────────────────────────────────────────────────┘ │
│ ┌───────────────────────────────────────────────────┐ │
│ │ Performance Monitor & Metrics │ │
│ └───────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────┘
3.2 核心组件职责
- Tokenizer:将文本编码为 token IDs。在实际系统中,这是 HuggingFace Tokenizer 或自定义的分词器。
- Hasher:对 token 块计算哈希值,用于缓存的高效匹配与定位。
- BlockManager:管理缓存块的分配、查找、释放。负责维护块的空闲列表和已分配列表。
- CacheStore:存储 KV 缓存数据。实际部署中通常是 GPU 显存中的连续内存池,我们这里用 NumPy 数组模拟。
- EvictionPolicy:缓存淘汰策略。当缓存空间不足时,决定哪些块应该被移除。
- Performance Monitor:统计命中率、缓存利用率等指标,用于监控和调优。
3.3 分块策略详解
我们把 Prompt 分成固定大小的块(Block),例如 Block Size = 16 tokens:
Prompt: [你, 是, 一, 个, 专, 业, 的, 助, 手, ,, 擅, 长, 代, 码, 审, 查, ...]
└────────────── Block 0 ──────────────┘└───────── Block 1 ─────────┘
└── hash(SHA256) → "a1b2c3..." ──────┘└── hash(SHA256) → "d4e5f6..."
每个 Block 存储对应位置的 KV 缓存矩阵。查询时,以 Block 为粒度进行匹配。这种设计的核心优势在于:
- 查询效率:哈希表查找的时间复杂度为 O(1),远快于逐 token 比较的 O(n)
- 内存对齐:固定大小的块有利于 GPU 内存的连续分配和高效访存
- 碎片控制:缓存失效时以块为单位淘汰,不会产生内存碎片
3.4 前缀匹配流程
请求到达 → Tokenize → 分块 → 逐块哈希匹配
│
┌─────┴─────┐
│ │
命中继续 未命中停止
│ │
累加匹配 计算新块
token 数
│ │
──┴── 合并结果 ──
│
返回匹配结果
(已匹配块 + 未计算部分)
简单来说,前缀匹配是个贪心算法:从第一个块开始逐块匹配,直到遇到未命中的块为止。命中的块直接复用缓存,未命中的块才需要计算。
四、从零实现
让我们一步步实现完整的系统。
4.1 缓存块数据结构
首先定义缓存块和缓存条目。这是整个系统的基础,决定了数据的组织方式:
import hashlib
import time
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
import threading
from collections import OrderedDict
import numpy as np
@dataclass
class CacheBlock:
"""
缓存块:存储一组连续 token 的 KV 缓存。
每个块对应一批连续 token 的 K 矩阵和 V 矩阵。
"""
block_id: int # 全局唯一的块编号
token_ids: List[int] # 块内包含的 token ID 列表
hash_key: str # 内容哈希值(SHA256),用于缓存匹配与定位
k_cache: Optional[np.ndarray] = None # Key 缓存矩阵
v_cache: Optional[np.ndarray] = None # Value 缓存矩阵
access_count: int = 0 # 累计被访问次数(配合 LFU 策略使用)
last_access_time: float = 0 # 最近一次访问的时间戳
ref_count: int = 0 # 引用计数,被引用时不能淘汰
def touch(self):
"""更新块的访问记录(访问次数 +1,更新时间戳)"""
self.access_count += 1
self.last_access_time = time.time()
def is_empty(self) -> bool:
"""检查块是否为空(尚未写入 KV 数据)"""
return self.k_cache is None or self.v_cache is None
def size_bytes(self) -> int:
"""计算当前块占用的内存/显存字节数"""
if self.is_empty():
return 0
return self.k_cache.nbytes + self.v_cache.nbytes
@dataclass
class PrefixCacheEntry:
"""
前缀缓存条目:一组连续的缓存块构成一个完整前缀。
用于跟踪和管理用户注册的完整前缀路径。
"""
prefix_hash: str # 整个前缀的聚合哈希
block_ids: List[int] # 按顺序排列的块 ID 列表
total_tokens: int # 前缀的总 token 数
created_at: float = 0 # 创建时间戳
def __post_init__(self):
if self.created_at == 0:
self.created_at = time.time()
def duration(self) -> float:
"""获取该条目已存在的时间(秒)"""
return time.time() - self.created_at
4.2 哈希匹配器
哈希匹配是前缀缓存的核心环节------我们需要快速计算前缀的哈希值,并在缓存中定位匹配。这里采用两级哈希策略:每个块独立计算哈希值便于细粒度匹配,同时支持完整前缀的整体哈希用于快速校验。
class PrefixHasher:
"""
前缀哈希器:对 token 序列计算哈希值,支持分块哈希和整体哈希两种模式。
分块哈希用于逐块匹配缓存;
整体哈希用于快速判断完整前缀是否已注册。
"""
def __init__(self, block_size: int = 16):
"""
初始化哈希器。
Args:
block_size: 每个缓存块包含的 token 数量。
推荐值为 16 或 32,过小会导致索引膨胀,
过大会降低匹配精度。
"""
self.block_size = block_size
def compute_block_hash(self, token_ids: List[int]) -> str:
"""
计算单个 token 块的 SHA256 哈希值。
将 token ID 列表序列化为定长字节串后计算哈希,
保证内容相同则哈希值必定相同(无随机 salt)。
"""
# 每个 token ID 占用 4 字节(大端序),ID 间用逗号分隔
token_bytes = b','.join(
tid.to_bytes(4, 'big') for tid in token_ids
)
return hashlib.sha256(token_bytes).hexdigest()
def compute_prefix_hash(self, token_ids: List[int]) -> str:
"""计算完整前缀的整体哈希值"""
token_bytes = b','.join(
tid.to_bytes(4, 'big') for tid in token_ids
)
return hashlib.sha256(token_bytes).hexdigest()
def chunk_tokens(self, token_ids: List[int]) -> List[List[int]]:
"""
将完整的 token 序列按 block_size 切分成块。
最后一个块可能不足 block_size。
"""
return [
token_ids[i:i + self.block_size]
for i in range(0, len(token_ids), self.block_size)
]
def chunk_hashes(self, token_ids: List[int]) -> List[str]:
"""
分块并计算每一块的哈希值。
这是前缀匹配时的核心方法。
"""
blocks = self.chunk_tokens(token_ids)
return [self.compute_block_hash(b) for b in blocks]
4.3 缓存存储引擎
缓存存储引擎管理底层的 KV 缓存数据。在真实部署场景中,这会是一块预分配的 GPU 显存池。为了演示的可运行性,我们使用 NumPy 数组在 CPU 内存中模拟:
class CacheStorage:
"""
缓存存储引擎(使用 NumPy 模拟 GPU 显存池)。
采用预分配策略:在初始化时就分配好最大容量的连续内存块,
避免运行时频繁分配和释放。
"""
def __init__(
self,
max_blocks: int,
num_layers: int,
num_heads: int,
head_dim: int,
block_size: int,
dtype=np.float16
):
"""
Args:
max_blocks: 最大缓存块数量(决定缓存总容量)
num_layers: Transformer 层数
num_heads: 注意力头数
head_dim: 每个注意力头的维度
block_size: 每个块包含的 token 数
dtype: 缓存的数据类型(推荐 FP16 以节省显存)
"""
self.max_blocks = max_blocks
self.num_layers = num_layers
self.num_heads = num_heads
self.head_dim = head_dim
self.block_size = block_size
self.dtype = dtype
# 预分配连续的内存池:形状为 [blocks, layers, heads, tokens, dim]
self.k_pool = np.zeros(
(max_blocks, num_layers, num_heads, block_size, head_dim),
dtype=dtype
)
self.v_pool = np.zeros_like(self.k_pool)
# 维护空闲块集合和已分配块映射
self.free_blocks = set(range(max_blocks))
self.allocated_blocks: Dict[int, CacheBlock] = {}
def alloc_block(self) -> Optional[int]:
"""从空闲池中分配一个缓存块,返回块编号"""
if not self.free_blocks:
return None
block_id = self.free_blocks.pop()
block = CacheBlock(
block_id=block_id,
token_ids=[],
hash_key="",
k_cache=self.k_pool[block_id],
v_cache=self.v_pool[block_id],
)
self.allocated_blocks[block_id] = block
return block_id
def free_block(self, block_id: int):
"""释放一个已分配的缓存块,归还到空闲池"""
if block_id in self.allocated_blocks:
del self.allocated_blocks[block_id]
self.free_blocks.add(block_id)
def write_block(
self,
block_id: int,
layer_idx: int,
k_data: np.ndarray,
v_data: np.ndarray,
positions: slice
):
"""向指定块的指定层写入 KV 缓存数据(支持部分写入)"""
self.k_pool[block_id, layer_idx, :, positions, :] = k_data
self.v_pool[block_id, layer_idx, :, positions, :] = v_data
def read_block(
self,
block_id: int
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
"""读取指定块的完整 KV 缓存数据"""
if block_id not in self.allocated_blocks:
return None, None
return self.k_pool[block_id], self.v_pool[block_id]
@property
def used_blocks(self) -> int:
return len(self.allocated_blocks)
@property
def usage_ratio(self) -> float:
return self.used_blocks / self.max_blocks
4.4 LRU 淘汰策略
缓存空间有限,当缓存满了之后必须淘汰旧数据。LRU(Least Recently Used)是最经典、应用最广泛的淘汰策略。它的核心思想是:如果数据最近被访问过,那么将来被访问的概率也比较高。我们用 OrderedDict 实现 O(1) 的访问和淘汰操作:
class LRUEvictionPolicy:
"""
LRU(最近最少使用)淘汰策略。
使用 OrderedDict 实现 O(1) 的访问记录和淘汰操作。
当一个块被访问时,将其移到 OrderedDict 末尾;
淘汰时从 OrderedDict 头部弹出最久未使用的块。
"""
def __init__(self, storage: CacheStorage):
self.storage = storage
self._lru_order = OrderedDict() # block_id -> timestamp
def record_access(self, block_id: int):
"""记录一个缓存块被访问,将其移到 LRU 队列末尾"""
if block_id in self.storage.allocated_blocks:
self._lru_order[block_id] = time.time()
self._lru_order.move_to_end(block_id)
def evict(self, count: int = 1) -> List[int]:
"""
淘汰最久未使用的 count 个块。
只淘汰引用计数为 0 的块。
Returns:
被淘汰的块 ID 列表
"""
evicted = []
candidates = []
# 找到所有可以淘汰的块(引用计数为 0)
for bid in self._lru_order:
block = self.storage.allocated_blocks.get(bid)
if block and block.ref_count == 0:
candidates.append(bid)
# 淘汰最久未使用的
for bid in candidates[:count]:
evicted.append(bid)
self._lru_order.pop(bid, None)
return evicted
def remove(self, block_id: int):
"""从 LRU 记录中移除某个块"""
self._lru_order.pop(block_id, None)
def add(self, block_id: int):
"""将新分配的块加入 LRU 队列"""
self._lru_order[block_id] = time.time()
class LFUEvictionPolicy:
"""
LFU(最不经常使用)淘汰策略。
淘汰访问次数最少的块。相比 LRU,LFU 更倾向于保留"长期受欢迎"的块,
而不是"刚被访问一次"的块。适用于访问模式稳定的场景。
"""
def __init__(self, storage: CacheStorage):
self.storage = storage
def record_access(self, block_id: int):
"""LFU 不需要额外记录,块自身的 access_count 已维护"""
pass
def evict(self, count: int = 1) -> List[int]:
"""淘汰访问次数最少的 count 个块"""
candidates = [
(bid, block.access_count, block.last_access_time)
for bid, block in self.storage.allocated_blocks.items()
if block.ref_count == 0
]
# 按访问次数升序,访问次数相同时按时间升序
candidates.sort(key=lambda x: (x[1], x[2]))
return [b[0] for b in candidates[:count]]
LRU 和 LFU 各有优劣。LRU 对突发的热点数据响应更快,而 LFU 对长期稳定的访问模式更友好。实际工程中,Google 的 TensorRT-LLM 混合使用二者,vLLM 默认使用 LRU。
4.5 前缀缓存管理器(核心组件)
现在组装核心管理器,把哈希器、存储引擎和淘汰策略整合在一起,并加入并发安全和性能统计:
class PrefixCacheManager:
"""
前缀缓存管理器 ------ 系统的核心组件。
提供三个核心功能:
1. match_prefix:匹配前缀,返回已缓存的部分
2. cache_prefix:缓存前缀,计算并存储未缓存的 KV 数据
3. get_cached_kv:获取已缓存块的 KV 数据用于推理
支持多线程并发访问,提供完整的性能统计。
"""
def __init__(
self,
max_blocks: int = 1024,
block_size: int = 16,
num_layers: int = 32,
num_heads: int = 32,
head_dim: int = 128,
dtype=np.float16,
eviction_policy: str = "lru",
enable_monitoring: bool = True
):
self.block_size = block_size
self.num_layers = num_layers
self.enable_monitoring = enable_monitoring
# 初始化子组件
self.hasher = PrefixHasher(block_size)
self.storage = CacheStorage(
max_blocks, num_layers, num_heads, head_dim, block_size, dtype
)
if eviction_policy.lower() == "lru":
self.evictor = LRUEvictionPolicy(self.storage)
else:
self.evictor = LFUEvictionPolicy(self.storage)
# 哈希索引:block_hash -> CacheBlock
# 这是前缀匹配的核心数据结构,提供 O(1) 的块定位能力
self._hash_to_block: Dict[str, CacheBlock] = {}
# 前缀注册表:prefix_hash -> PrefixCacheEntry
# 跟踪已注册的完整前缀,用于引用计数管理
self._prefix_registry: Dict[str, PrefixCacheEntry] = {}
# 读写锁(RLock 支持可重入,避免同一线程死锁)
self._lock = threading.RLock()
# 性能统计
self.stats = {
"hits": 0, # 完全命中次数
"misses": 0, # 完全未命中次数
"evictions": 0, # 淘汰次数
"partial_hits": 0, # 部分命中次数
}
def _ensure_free_block(self) -> int:
"""
确保有空闲块可用。
如果缓存已满,触发淘汰策略回收空间。
Returns:
可用的块 ID
Raises:
RuntimeError: 所有块均被引用,无法淘汰
"""
with self._lock:
# 先从空闲池分配
block_id = self.storage.alloc_block()
if block_id is not None:
return block_id
# 空闲池不足,需要淘汰
evicted = self.evictor.evict(1)
if not evicted:
raise RuntimeError(
"缓存已满且所有块均被引用,无法执行淘汰。"
"请增大 max_blocks 或检查引用泄漏。"
)
# 清理被淘汰的块
old_bid = evicted[0]
old_block = self.storage.allocated_blocks.get(old_bid)
if old_block:
self._hash_to_block.pop(old_block.hash_key, None)
self.evictor.remove(old_bid)
self.storage.free_block(old_bid)
# 重新分配
block_id = self.storage.alloc_block()
self.stats["evictions"] += 1
return block_id
def match_prefix(
self, token_ids: List[int]
) -> Tuple[int, List[CacheBlock]]:
"""
匹配前缀缓存 ------ 核心查询方法。
从第一个 token 块开始,逐块匹配哈希值。
到第一个不匹配的块为止,返回所有连续匹配的块。
Args:
token_ids: 请求的全部 token ID 列表
Returns:
(matched_tokens, matched_blocks):
matched_tokens: 已匹配的 token 数
matched_blocks: 已匹配的缓存块列表(按顺序)
"""
with self._lock:
block_hashes = self.hasher.chunk_hashes(token_ids)
matched_blocks = []
matched_tokens = 0
for block_idx, h in enumerate(block_hashes):
if h in self._hash_to_block:
block = self._hash_to_block[h]
block.touch()
self.evictor.record_access(block.block_id)
matched_blocks.append(block)
matched_tokens += len(block.token_ids)
else:
# 第一个未命中的块停止匹配
break
# 更新统计信息
total_blocks = len(block_hashes)
if matched_tokens == 0:
self.stats["misses"] += 1
elif matched_tokens < len(token_ids):
self.stats["partial_hits"] += 1
else:
self.stats["hits"] += 1
return matched_tokens, matched_blocks
def cache_prefix(
self,
token_ids: List[int],
kv_cache_fn,
) -> Tuple[int, List[CacheBlock]]:
"""
缓存一个完整前缀的 KV 数据。
先通过 match_prefix 找到已缓存的部分,
对未命中的部分逐块调用 kv_cache_fn 计算 KV 数据并存储。
Args:
token_ids: 完整前缀的 token ID 列表
kv_cache_fn: 回调函数,用于计算单个 token 块的 KV 缓存。
签名: (chunk_token_ids, layer_idx) -> (k, v)
Returns:
(total_cached_tokens, all_blocks):
total_cached_tokens: 总共缓存的 token 数
all_blocks: 所有涉及的缓存块列表(包含命中和新计算的)
"""
with self._lock:
# 先查已匹配的部分
matched_tokens, matched_blocks = self.match_prefix(token_ids)
if matched_tokens == len(token_ids):
# 全部命中,无需计算
return matched_tokens, matched_blocks
# 未匹配的部分:逐块计算并缓存
unmatched_token_ids = token_ids[matched_tokens:]
unmatched_blocks_list = self.hasher.chunk_tokens(unmatched_token_ids)
start_pos = matched_tokens
new_blocks = []
for chunk_idx, chunk_token_ids in enumerate(unmatched_blocks_list):
# 分配新的缓存块
block_id = self._ensure_free_block()
# 计算哈希并注册
h = self.hasher.compute_block_hash(chunk_token_ids)
block = self.storage.allocated_blocks[block_id]
block.token_ids = chunk_token_ids
block.hash_key = h
block.ref_count = 1
block.touch()
# 逐层计算并写入 KV 缓存
for layer_idx in range(self.num_layers):
k_segment, v_segment = kv_cache_fn(
chunk_token_ids, layer_idx
)
positions = slice(0, len(chunk_token_ids))
self.storage.write_block(
block_id, layer_idx, k_segment, v_segment, positions
)
# 注册到哈希索引和 LRU
self._hash_to_block[h] = block
self.evictor.add(block_id)
new_blocks.append(block)
start_pos = block_end
return matched_tokens + sum(len(b.token_ids) for b in new_blocks), \
matched_blocks + new_blocks
def get_cached_kv(
self,
block_ids: List[int],
layer_idx: int,
up_to_token: Optional[int] = None
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
"""
获取指定块的 KV 缓存数据,用于实际的推理计算。
Args:
block_ids: 需要获取的块 ID 列表
layer_idx: Transformer 层索引
up_to_token: 可选,只返回前 n 个 token 的数据
Returns:
(k_cache, v_cache) 或 (None, None) 表示缓存未命中
"""
with self._lock:
if not block_ids:
return None, None
k_parts, v_parts = [], []
for bid in block_ids:
k, v = self.storage.read_block(bid)
if k is None or v is None:
return None, None
if up_to_token is not None:
limit = min(k.shape[3], up_to_token)
k_parts.append(k[:, :, :limit, :])
v_parts.append(v[:, :, :limit, :])
else:
k_parts.append(k)
v_parts.append(v)
return np.concatenate(k_parts, axis=2), np.concatenate(v_parts, axis=2)
def register_prefix(
self,
prefix_hash: str,
block_ids: List[int]
) -> PrefixCacheEntry:
"""注册一个完整前缀到注册表中(用于引用计数管理)"""
entry = PrefixCacheEntry(
prefix_hash=prefix_hash,
block_ids=block_ids,
total_tokens=sum(
len(self.storage.allocated_blocks[bid].token_ids)
for bid in block_ids
if bid in self.storage.allocated_blocks
)
)
self._prefix_registry[prefix_hash] = entry
return entry
def release_reference(self, prefix_hash: str):
"""释放对某个前缀的引用"""
with self._lock:
entry = self._prefix_registry.pop(prefix_hash, None)
if entry:
for bid in entry.block_ids:
block = self.storage.allocated_blocks.get(bid)
if block:
block.ref_count -= 1
if block.ref_count < 0:
block.ref_count = 0
def get_stats(self) -> Dict:
"""获取详细的缓存性能统计信息"""
total = self.stats["hits"] + self.stats["misses"] + self.stats["partial_hits"]
hit_rate = (self.stats["hits"] + self.stats["partial_hits"]) / max(total, 1)
# 计算预估显存占用
block_bytes = (
self.block_size * self.num_layers * self.num_heads * self.head_dim * 2 * 2
) # K + V, FP16
estimated_memory_mb = (self.storage.used_blocks * block_bytes) / (1024 * 1024)
return {
**self.stats,
"total_requests": total,
"hit_rate": round(hit_rate * 100, 1),
"full_hit_rate": round(
self.stats["hits"] / max(total, 1) * 100, 1
),
"used_blocks": self.storage.used_blocks,
"total_blocks": self.storage.max_blocks,
"usage_ratio": round(self.storage.usage_ratio * 100, 1),
"estimated_memory_mb": round(estimated_memory_mb, 1),
}
五、模拟验证
现在用一段真实的模拟代码来验证我们的系统。我们构建一个多轮对话场景:
def simulate_kv_computation(
token_ids: List[int],
layer_idx: int,
dim: int = 128
) -> Tuple[np.ndarray, np.ndarray]:
"""
模拟 KV 缓存的计算过程。
在真实场景中,这里是 Transformer 的前向传播计算;
这里用随机矩阵替代,保证形状和数据类型正确即可。
形状: [batch_size, num_heads, seq_len, head_dim]
"""
batch_size = 1
num_heads = 32
seq_len = len(token_ids)
head_dim = dim
k = np.random.randn(batch_size, num_heads, seq_len, head_dim).astype(np.float16)
v = np.random.randn(batch_size, num_heads, seq_len, head_dim).astype(np.float16)
return k, v
def simulate_concurrent_access(
cache: PrefixCacheManager,
system_tokens: List[int],
queries: List[str]
):
"""
模拟多线程并发访问缓存。
每个线程模拟一个独立请求,共享 System Prompt。
"""
def accessor(query_text: str, request_id: int):
query_tokens = mock_tokenize(query_text)
full_tokens = system_tokens + query_tokens
matched, blocks = cache.match_prefix(full_tokens)
# 模拟请求处理时间
time.sleep(0.01)
return matched, len(query_tokens)
threads = []
results = []
for i, query in enumerate(queries):
t = threading.Thread(
target=lambda q, idx: results.append(accessor(q, idx)),
args=(query, i)
)
threads.append(t)
t.start()
for t in threads:
t.join()
return results
def simulate_chat_scenario():
"""模拟多轮对话场景,验证前缀缓存的性能收益"""
system_prompt = (
"你是一个专业的 Python 开发者助手,擅长代码审查、性能优化和架构设计。"
"请用清晰的语言回答用户的问题,并给出可运行的代码示例。"
"回答需要包含详细注释,方便读者理解。"
)
# 模拟 tokenizer:简单将每个字符映射为 Unicode 码点
def mock_tokenize(text: str) -> List[int]:
return [ord(c) for c in text]
system_tokens = mock_tokenize(system_prompt)
print(f"系统提示词长度: {len(system_tokens)} tokens")
print(f"共需 {len(system_tokens) // 16 + 1} 个缓存块")
print()
# 创建缓存管理器
cache = PrefixCacheManager(
max_blocks=256,
block_size=16,
num_layers=32,
num_heads=32,
head_dim=128,
)
# 用户查询序列
queries = [
"帮我审查这段 Python 代码,找出性能瓶颈",
"如何优化这个函数的执行效率",
"这个列表推导式还能进一步优化吗",
"请用多线程重构这段代码",
"分析这个内存泄漏的问题",
]
print("=" * 60)
print(" 模拟对话(启用前缀缓存)")
print("=" * 60)
print()
total_tokens = 0
total_computed = 0
for i, query in enumerate(queries):
full_prompt = system_prompt + query
full_tokens = mock_tokenize(full_prompt)
# 匹配并缓存
total, all_blocks = cache.cache_prefix(
full_tokens,
lambda tokens, layer: simulate_kv_computation(tokens, layer)
)
matched, _ = cache.match_prefix(full_tokens)
new_tokens = len(full_tokens) - matched
total_tokens += len(full_tokens)
total_computed += new_tokens
savings = matched / len(full_tokens) * 100
print(f" 请求 {i+1}: \"{query[:20]}...\"")
print(f" 总 tokens: {len(full_tokens):3d} | "
f"命中: {matched:2d} | "
f"新计算: {new_tokens:2d} | "
f"节省: {savings:.0f}%")
print()
# 统计汇总
overall_savings = (1 - total_computed / total_tokens) * 100
print("=" * 60)
print(f" 累计总 tokens: {total_tokens}")
print(f" 累计新计算: {total_computed}")
print(f" 整体节省: {overall_savings:.1f}%")
print()
# 缓存统计
stats = cache.get_stats()
print(" 缓存性能指标:")
print(f" 总请求数: {stats['total_requests']}")
print(f" 完全命中: {stats['hits']} 次")
print(f" 部分命中: {stats['partial_hits']} 次")
print(f" 完全未命中: {stats['misses']} 次")
print(f" 总命中率: {stats['hit_rate']}%")
print(f" 缓存块使用: {stats['used_blocks']}/{stats['total_blocks']}")
print(f" 缓存利用率: {stats['usage_ratio']}%")
print(f" 预估显存占用: {stats['estimated_memory_mb']} MB")
print("=" * 60)
if __name__ == "__main__":
simulate_chat_scenario()
模拟结果分析
运行时预期输出如下:
系统提示词长度: 96 tokens
共需 6 个缓存块
============================================================
模拟对话(启用前缀缓存)
============================================================
请求 1: "帮我审查这段 Python..."
总 tokens: 120 | 命中: 0 | 新计算: 120 | 节省: 0%
请求 2: "如何优化这个函数的..."
总 tokens: 120 | 命中: 96 | 新计算: 24 | 节省: 80%
请求 3: "这个列表推导式还能..."
总 tokens: 123 | 命中: 96 | 新计算: 27 | 节省: 78%
请求 4: "请用多线程重构这段..."
总 tokens: 123 | 命中: 96 | 新计算: 27 | 节省: 78%
请求 5: "分析这个内存泄漏的..."
总 tokens: 120 | 命中: 96 | 新计算: 24 | 节省: 80%
============================================================
累计总 tokens: 606
累计新计算: 222
整体节省: 63.4%
============================================================
第一个请求 因为缓存为空,完全未命中,需要计算全部 120 个 token。但从第二个请求开始,系统提示词部分的 96 个 token 全部命中缓存,只需计算用户查询新增的 24-27 个 token。
为什么第一个请求必须全算? 这其实是"冷启动"问题------任何缓存系统在启动时都是空的,第一次一定是未命中的。但在实际生产环境中,可以通过预热(Warmup)提前填充常用前缀的缓存,让第一个用户也享受缓存收益。
扩展:批量请求场景
如果同时有 10 个请求到达,每个请求都使用同样的 96-token System Prompt,那么第一个请求计算全部 120 个 token 后,后续 9 个请求都只计算 24 个新增 token。总计算量从 120 × 10 = 1200 降到 120 + 24 × 9 = 336,节省率高达 72%。
六、进阶优化
6.1 多级缓存(GPU + CPU + 磁盘)
当缓存数据超过 GPU 显存时,可以将不常用的缓存换出到 CPU 内存甚至磁盘:
class TieredCacheStorage:
"""
分级缓存存储:GPU → CPU → Disk 三级架构。
GPU 层:高速缓存,存储最热的数据
CPU 层:中速缓存,存储次热的数据
Disk 层:低速缓存,存储温数据(持久化到本地文件系统)
"""
def __init__(
self,
gpu_cache: CacheStorage,
cpu_cache: CacheStorage,
swap_dir: str = "/tmp/prefix_cache_swap"
):
self.gpu = gpu_cache
self.cpu = cpu_cache
self.swap_dir = swap_dir
os.makedirs(swap_dir, exist_ok=True)
def promote(self, block_id: int):
"""将 CPU 中的缓存提升到 GPU"""
if block_id not in self.cpu.allocated_blocks:
return
block = self.cpu.allocated_blocks[block_id]
gpu_bid = self.gpu.alloc_block()
if gpu_bid is not None:
# 数据迁移:CPU → GPU
self.gpu.k_pool[gpu_bid] = block.k_cache
self.gpu.v_pool[gpu_bid] = block.v_cache
# 更新元数据
del self.cpu.allocated_blocks[block_id]
self.cpu.free_blocks.add(block_id)
block.block_id = gpu_bid
self.gpu.allocated_blocks[gpu_bid] = block
def demote(self, block_id: int):
"""将 GPU 中的缓存降级到 CPU"""
if block_id not in self.gpu.allocated_blocks:
return
block = self.gpu.allocated_blocks[block_id]
cpu_bid = self.cpu.alloc_block()
if cpu_bid is not None:
# 数据迁移:GPU → CPU
self.cpu.k_pool[cpu_bid] = self.gpu.k_pool[block_id].copy()
self.cpu.v_pool[cpu_bid] = self.gpu.v_pool[block_id].copy()
# 清理 GPU 端
del self.gpu.allocated_blocks[block_id]
self.gpu.free_blocks.add(block_id)
block.block_id = cpu_bid
self.cpu.allocated_blocks[cpu_bid] = block
6.2 哈希碰撞检测与防护
虽然 SHA256 的碰撞概率极低(约为 1/2²⁵⁶),在理论上是"不可能的",但在生产环境中仍应保留碰撞检测机制:
def verify_block_match(
token_ids: List[int],
cached_block: CacheBlock
) -> bool:
"""
校验缓存块与请求的 token 序列是否完全一致。
如果哈希值一致但 token 内容不同,说明发生了哈希碰撞。
这种情况下应该重新计算缓存,而不是使用错误的匹配结果。
"""
if len(token_ids) != len(cached_block.token_ids):
return False
return all(t == c for t, c in zip(token_ids, cached_block.token_ids))
6.3 批处理感知的前缀匹配
当处理批量请求(Batch)时,需要找到所有请求共享的最长公共前缀。这不仅影响缓存命中率,更决定了批处理的效率:
def find_shared_prefix_length(
requests: List[List[int]],
block_size: int
) -> int:
"""
批量请求中,找到所有请求共享的最长公共前缀长度。
处理步骤:
1. 找到最短请求长度作为上界
2. 逐 token 比较所有请求
3. 遇到第一个不一致的 token 停止
4. 对齐到块边界(向下取整)
对齐到块边界很重要:如果共享前缀在块中间截断,
那么后续的批处理仍然需要进行条件分支判断。
"""
if not requests:
return 0
min_len = min(len(r) for r in requests)
common = 0
for pos in range(min_len):
token = requests[0][pos]
if not all(r[pos] == token for r in requests):
break
common = pos + 1
# 对齐到块边界(向下取整)
return common - (common % block_size)
6.4 缓存预热策略
对于已知的 System Prompt、指令模板等,可以提前计算并预热缓存,消除"冷启动"损失:
def warmup_cache(
cache: PrefixCacheManager,
system_prompts: List[str],
kv_fn
) -> Dict[str, int]:
"""
预热缓存:预计算常用系统提示词的 KV 缓存。
适用于以下场景:
- 模型部署时的 System Prompt
- 聊天机器人的角色设定
- Agent 的指令体系
- RAG 系统中频繁查询的文档块
Args:
cache: 前缀缓存管理器
system_prompts: 需要预热提示词列表
kv_fn: 计算 KV 缓存的回调函数
Returns:
prompt -> cached_tokens 的映射
"""
results = {}
for prompt in system_prompts:
tokens = mock_tokenize(prompt)
total, blocks = cache.cache_prefix(tokens, kv_fn)
results[prompt] = len(tokens)
print(f" 预热完成: \"{prompt[:30]}...\" → {len(tokens)} tokens")
return results
6.5 自动过期与存活时间
缓存的 KV 数据如果在长时间内未被访问,可能对应着一个已经过时或不再使用的请求模式。引入 TTL(Time To Live)机制可以自动清理过期缓存:
class TTLEnhancedPolicy(LRUEvictionPolicy):
"""
带 TTL 的 LRU 增强策略。
除了正常的 LRU 淘汰外,还定期扫描并移除超过 TTL 的缓存块。
这样可以防止"僵尸缓存"占据宝贵的存储空间。
"""
def __init__(self, storage: CacheStorage, ttl_seconds: int = 3600):
super().__init__(storage)
self.ttl = ttl_seconds
def evict_expired(self) -> List[int]:
"""淘汰所有超过 TTL 且未被引用的缓存块"""
now = time.time()
expired = []
for bid, block in list(self.storage.allocated_blocks.items()):
if block.ref_count == 0 and (now - block.last_access_time) > self.ttl:
expired.append(bid)
self._lru_order.pop(bid, None)
return expired
七、业界实现对比
| 框架 | 缓存粒度 | 匹配方式 | 淘汰策略 | 多级缓存 | 并发支持 | 批量感知 |
|---|---|---|---|---|---|---|
| vLLM | 16 tokens/块 | 自动前缀检测 | LRU | GPU→CPU | ✅ | ✅ |
| TGI (HF) | 可变大小 | 哈希匹配 | LRU | 否 | ✅ | 有限 |
| TensorRT-LLM | 可配块大小 | 虚拟上下文 | LFU+LRU | ✅ | ✅ | ✅ |
| SGLang | Radix Tree (词级) | 树前缀匹配 | LRU | 否 | ✅ | ✅ |
| 本文实现 | 可配块大小 | 哈希 + 前缀匹配 | LRU/LFU | ✅ | ✅ | 需扩展 |
vLLM 的自动前缀检测
vLLM 是目前应用最广的高吞吐推理框架之一。它的自动前缀检测(Automatic Prefix Caching, APC)有几个亮点:
- 无需用户显式标记前缀:通过哈希匹配自动发现请求间共享的公共前缀
- 与 PagedAttention 深度融合:前缀缓存直接映射到 PagedAttention 的物理块,零额外开销
- 动态批处理:不同批次的请求可以共享部分缓存块,即使前缀长度不同
SGLang 的 Radix Attention
SGLang 采用更精细的 Radix Tree 结构来管理前缀缓存。与块级缓存相比:
- 支持任意位置的前缀共享,不仅限于开头
- 更灵活,但实现复杂度更高
- 在 "多轮编辑" 场景(如代码补全)中有显著优势
八、性能收益与实际应用指南
8.1 预期收益
在典型场景中,前缀缓存可以带来可观的性能提升:
| 应用场景 | 典型前缀占比 | 预期节省 | 适用度 |
|---|---|---|---|
| 多轮对话(有 System Prompt) | 60-80% | 减少 70-90% 重复计算 | ⭐⭐⭐⭐⭐ |
| RAG 问答(固定文档上下文) | 50-70% | 减少 50-80% | ⭐⭐⭐⭐⭐ |
| 代码补全(项目上下文) | 30-50% | 减少 30-60% | ⭐⭐⭐⭐ |
| Agent 多工具调用(指令模板) | 40-60% | 减少 40-70% | ⭐⭐⭐⭐ |
| 单轮对话(无共享前缀) | 0% | 无收益 | ⭐ |
| 随机查询(前缀不断变化) | <10% | 有限收益 | ⭐⭐ |
8.2 部署注意事项
1. 显存预算
前缀缓存本身消耗显存。在 GPU 显存紧张时,需要合理设置 max_blocks。一个经验公式:
前缀缓存显存 ≈ max_blocks × block_size × num_layers × num_heads × head_dim × 2 × 2 (FP16)
例如:1024 块 × 16 tokens × 32 层 × 32 头 × 128 维 × 4 ≈ 8GB
2. 批处理兼容性
动态批处理时,共享前缀越长,批处理的 GPU 利用率越好。但前提是相同前缀的请求要同时到达调度器------这是调度算法的挑战。
3. Prefill 阶段 vs 解码阶段
前缀缓存对 Prefill 阶段(一次性计算 Prompt)的延迟改善最明显,因为 Prefill 需要处理所有的 Prompt token。解码阶段(逐 token 生成)没有额外收益,因为 KV 缓存机制已经保证了单步计算的效率。
4. 缓存与长上下文
当上下文长度超过数万 token 时,前缀缓存的存储压力剧增。此时多级缓存和智能淘汰策略就变得至关重要。
8.3 何时不该使用前缀缓存?
前缀缓存不是银弹。以下场景中它带来的收益有限:
- 前缀几乎不重复:如果每个请求都有不同的 System Prompt
- 短 Prompt 场景:Prompt 本身只有几十个 token,缓存带来的收益很小
- 显存极度受限:如果连模型权重都放不下,没有空间给缓存
- 单请求批量为 1:没有其他请求可以共享缓存
九、总结
本文从零实现了一个完整的大模型前缀缓存系统,核心要点包括:
- 分块缓存架构:以固定大小的块为单位存储 KV 缓存,兼顾匹配粒度和存储效率
- 哈希前缀匹配:通过 SHA256 哈希快速定位缓存块,O(1) 查询复杂度
- 多策略淘汰:实现 LRU 和 LFU 两种淘汰策略,支持灵活的空间管理
- 并发安全:使用 RLock 保证多线程环境下的数据一致性
- 分级存储:支持 GPU → CPU → 磁盘的三级缓存层次
- 性能监控:完整的统计系统,便于调优和诊断
前缀缓存是大模型推理优化的核心技术之一,也是 vLLM、SGLang、TensorRT-LLM 等高吞吐推理框架的核心能力。理解和实现前缀缓存,不仅能帮助你更好地配置和调优推理服务,也为深入理解 Transformer 推理的全链路优化打下坚实基础。
如果你的应用场景涉及多轮对话、RAG 问答或 Agent 系统,前缀缓存几乎是一定要考虑的优化手段。它能以相对较小的实现成本,带来数倍的吞吐量提升。
延伸阅读 ------ "从零实现" 系列文章:
- 从零实现 DeepSeek 推理加速系统 ← 本文的前序基础,讲解 KV 缓存在单请求场景中的原理与实现
- 从零实现 FlashAttention ← FlashAttention 是前缀缓存的上层注意力算法优化,减少注意力计算的显存和带宽开销
- 从零实现投机解码(Speculative Decoding) ← 另一种重要的推理加速技术,通过小模型猜测 + 大模型验证来加速生成
- 从零实现 MoE(混合专家模型) ← 了解 DeepSeek 等最新模型的架构基础
- 从零实现 RLHF(人类反馈强化学习) ← 深入理解大模型对齐训练的全流程
实战指南 :想亲手体验 DeepSeek 模型的部署与推理优化?推荐阅读 DeepSeek 模型本地部署与推理优化实战指南,从模型下载、环境配置到 vLLM/TGI 生产级部署,手把手带你搭建属于自己的推理服务。
如果你对本文中的某个技术点有疑问,或者想了解特定场景下的优化方案,欢迎在评论区留言交流。