从零实现大模型前缀缓存(Prefix Caching)系统 — 告别重复计算,推理速度翻倍

一、为什么需要前缀缓存?

在使用大语言模型(LLM)进行推理时,你是否遇到过这样的场景:同一个系统提示词(System Prompt)在多轮对话中反复被编码,同一个文档前缀在多条查询中被重复计算 KV(Key-Value)缓存?

以 ChatGPT 的系统提示词机制为例:

复制代码
你是一个专业的 Python 开发者助手,擅长代码审查和优化...
用户:帮我审查这段代码
用户:再审查另一段代码  
用户:继续审查这段代码

每次新对话,模型都需要从头计算 "你是一个专业的 Python 开发者助手,擅长代码审查和优化..." 这几十个 token 的 KV 缓存。如果对话量大,这种重复计算的浪费是惊人的。

重复计算的代价有多大?

假设 System Prompt 长度为 1000 tokens,模型层数为 32,注意力头数为 32,隐藏层维度为 4096。每次计算 KV 缓存需要将每个 token 通过所有 Transformer 层前向传播,每层需要计算查询(Query)、键(Key)和值(Value)三个矩阵,然后缓存 Key 和 Value 矩阵。这个过程涉及大量的矩阵乘法运算------对于 1000 个 token,每个注意力头需要计算 1000 × 128 维的 Key 和 Value 向量,32 层加起来就是 1000 × 128 × 2 × 32 × 32 ≈ 2.6 亿个浮点数。

更直观地说,单次计算 1000 个 token 的 KV 缓存,在 H100 GPU 上大约需要 5-10 毫秒。如果每秒处理 100 个不同的请求,并且每个请求都有相同的 System Prompt,那么每秒就要浪费 0.5-1 秒的 GPU 计算时间在重复的前向传播上。

这就是前缀缓存(Prefix Caching)要解决的问题。 这项技术已经在 vLLM、TGI(Text Generation Inference)、TensorRT-LLM、SGLang 等主流推理框架中得到广泛应用,是大模型推理优化的核心技术之一。

前缀缓存 vs 传统 KV 缓存

很多人容易混淆这两个概念,这里特别说明一下:

  • KV 缓存(KV Cache):单个请求生成过程中,将已计算的历史 token 的 Key 和 Value 缓存下来,避免逐 token 重复计算。这是 LLM 推理的基本优化。
  • 前缀缓存(Prefix Caching) :在 KV 缓存的基础上更进一步,让不同请求之间共享相同的前缀缓存。这是跨请求的优化。

可以这样理解:KV 缓存解决的是"同一个请求内不要重复计算"的问题,而前缀缓存解决的是"不同的请求之间也不要重复计算公共部分"的问题。后者是前者的超集和进化。

本文将从零开始,用 Python 实现一个完整的前缀缓存系统,涵盖从基本原理到性能优化、从单机部署到并发安全的全方位实现。

二、前缀缓存的核心原理

2.1 KV 缓存速览

Transformer 的自注意力机制可以写成:

复制代码
Attention(Q, K, V) = softmax(Q × K^T / √d) × V

其中 Q、K、V 分别是查询矩阵、键矩阵和值矩阵,d 是注意力头维度。对于自回归生成,每生成一个新 token,我们需要计算其与所有之前 token 的注意力。在朴素实现中,每次都要重新计算所有 token 的注意力分数,这会导致 O(n²) 的计算复杂度,随着序列长度增长呈平方级增长。

KV 缓存的核心思想是:将每个 token 的 K 和 V 向量缓存起来,生成新 token 时只需计算新 token 的 Q 向量,并与缓存的 K、V 向量做注意力计算。这样每步的计算量从 O(n²) 降为 O(n),当序列长度达到数千甚至数万时,性能提升极其显著。

以 LLaMA-70B 模型为例,在生成长度为 4096 的序列时,启用 KV 缓存可将解码阶段的矩阵乘法计算量减少约 8 倍,这对于用户体验的改善是革命性的------没有 KV 缓存的模型几乎无法在合理时间内完成长文本生成。

2.2 前缀缓存的进化

前缀缓存是 KV 缓存的更高级形态------跨请求共享。当多个请求共享相同的 Prompt 前缀时(如 System Prompt、文档上下文、Agent 指令体系),它们的 KV 缓存是完全相同的,可以直接复用。

复制代码
请求1:[System Prompt] + [用户查询1]
请求2:[System Prompt] + [用户查询2]
               ↑
       共享相同的前缀 KV 缓存

来看一个更具体的场景:在 RAG(检索增强生成)应用中,每个查询都需要将检索到的文档上下文拼接到 Prompt 中。如果多个查询检索到相同的文档片段,那么这些文档片段对应的 KV 缓存完全可以共享。

复制代码
请求A:文档D + 查询Q_A → 缓存文档D + 只需计算Q_A
请求B:文档D + 查询Q_B → 直接复用文档D的缓存 → 只需计算Q_B

这样,文档D对应的 2000 个 token 的 KV 缓存只需要计算一次,后续所有引用到文档D的请求都可以直接复用,对于包含大量知识库文档的场景,收益极其可观。

2.3 关键挑战

实现前缀缓存面临几个核心挑战:

1. 缓存命中检测

如何快速判断新请求的前缀是否已被缓存?暴力逐 token 比较显然不可行,因为序列可能长达数万 token。业界通常使用哈希(Hashing)或树结构(Radix Tree)来加速匹配。

2. 缓存粒度

缓存应该以单个 token 为单位,还是以块(Block)为单位?以 token 为粒度匹配精度最高,但存储碎片化严重;以块为粒度存储效率高,但匹配精度会下降。实际工程中需要在两者间找平衡。

3. 内存管理

LLM 的 KV 缓存非常消耗显存。一个 70B 参数的模型,处理 4096 个 token 时,KV 缓存需要约 32GB 显存(FP16 精度)。这意味着缓存的容量极其有限,必须有高效的淘汰策略。

4. LRU / LFU 淘汰策略

缓存空间有限,当缓存满了之后需要淘汰一些条目。什么策略最适合前缀缓存?是基于最近使用时间(LRU),还是基于使用频率(LFU),或者是结合两者的混合策略?

**5. 并发安全

多线程或多进程环境下,多个推理请求可能同时访问缓存。如何确保并发安全而不引入过多锁竞争?

6. 动态批处理兼容性

现代推理引擎(如 vLLM)使用动态批处理(Dynamic Batching / Continuous Batching),不同请求可能同时共享同一缓存的不同部分。前缀缓存需要与批处理调度器深度集成。

业界主流方案采用分块(Block-based)缓存 + 哈希匹配的组合策略,既保证了查询效率,又控制了碎片化程度。

三、系统设计

在动手编码之前,先设计架构。

3.1 整体架构

复制代码
┌─────────────────────────────────────────────────────────┐
│                    PrefixCacheManager                      │
├─────────────────────────────────────────────────────────┤
│  ┌──────────┐  ┌────────────┐  ┌────────────────────┐   │
│  │ Tokenizer │  │  Hasher    │  │  BlockManager      │   │
│  └─────┬────┘  └─────┬──────┘  └────────┬───────────┘   │
│        │              │                  │                │
│  ┌─────▼──────────────▼──────────────────▼────────────┐  │
│  │                    CacheStore                        │  │
│  │  ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬────┐  │  │
│  │  │ B#0 │ B#1 │ B#2 │ B#3 │ B#4 │ B#5 │ ... │ B#N│  │  │
│  │  └─────┴─────┴─────┴─────┴─────┴─────┴─────┴────┘  │  │
│  └────────────────────────────────────────────────────┘  │
│  ┌───────────────────────────────────────────────────┐   │
│  │              EvictionPolicy (LRU/LFU)              │   │
│  └───────────────────────────────────────────────────┘   │
│  ┌───────────────────────────────────────────────────┐   │
│  │            Performance Monitor & Metrics             │   │
│  └───────────────────────────────────────────────────┘   │
└─────────────────────────────────────────────────────────┘

3.2 核心组件职责

  1. Tokenizer:将文本编码为 token IDs。在实际系统中,这是 HuggingFace Tokenizer 或自定义的分词器。
  2. Hasher:对 token 块计算哈希值,用于缓存的高效匹配与定位。
  3. BlockManager:管理缓存块的分配、查找、释放。负责维护块的空闲列表和已分配列表。
  4. CacheStore:存储 KV 缓存数据。实际部署中通常是 GPU 显存中的连续内存池,我们这里用 NumPy 数组模拟。
  5. EvictionPolicy:缓存淘汰策略。当缓存空间不足时,决定哪些块应该被移除。
  6. Performance Monitor:统计命中率、缓存利用率等指标,用于监控和调优。

3.3 分块策略详解

我们把 Prompt 分成固定大小的块(Block),例如 Block Size = 16 tokens:

复制代码
Prompt: [你, 是, 一, 个, 专, 业, 的, 助, 手, ,, 擅, 长, 代, 码, 审, 查, ...]
        └────────────── Block 0 ──────────────┘└───────── Block 1 ─────────┘
        └── hash(SHA256) → "a1b2c3..." ──────┘└── hash(SHA256) → "d4e5f6..."

每个 Block 存储对应位置的 KV 缓存矩阵。查询时,以 Block 为粒度进行匹配。这种设计的核心优势在于:

  • 查询效率:哈希表查找的时间复杂度为 O(1),远快于逐 token 比较的 O(n)
  • 内存对齐:固定大小的块有利于 GPU 内存的连续分配和高效访存
  • 碎片控制:缓存失效时以块为单位淘汰,不会产生内存碎片

3.4 前缀匹配流程

复制代码
请求到达 → Tokenize → 分块 → 逐块哈希匹配
                                    │
                              ┌─────┴─────┐
                              │            │
                          命中继续     未命中停止
                              │            │
                          累加匹配      计算新块
                           token 数
                              │            │
                           ──┴── 合并结果 ──
                                    │
                              返回匹配结果
                           (已匹配块 + 未计算部分)

简单来说,前缀匹配是个贪心算法:从第一个块开始逐块匹配,直到遇到未命中的块为止。命中的块直接复用缓存,未命中的块才需要计算。

四、从零实现

让我们一步步实现完整的系统。

4.1 缓存块数据结构

首先定义缓存块和缓存条目。这是整个系统的基础,决定了数据的组织方式:

复制代码
import hashlib
import time
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
import threading
from collections import OrderedDict
import numpy as np


@dataclass
class CacheBlock:
    """
    缓存块:存储一组连续 token 的 KV 缓存。
    每个块对应一批连续 token 的 K 矩阵和 V 矩阵。
    """
    block_id: int                # 全局唯一的块编号
    token_ids: List[int]         # 块内包含的 token ID 列表
    hash_key: str                # 内容哈希值(SHA256),用于缓存匹配与定位
    k_cache: Optional[np.ndarray] = None  # Key 缓存矩阵
    v_cache: Optional[np.ndarray] = None  # Value 缓存矩阵
    access_count: int = 0        # 累计被访问次数(配合 LFU 策略使用)
    last_access_time: float = 0  # 最近一次访问的时间戳
    ref_count: int = 0           # 引用计数,被引用时不能淘汰

    def touch(self):
        """更新块的访问记录(访问次数 +1,更新时间戳)"""
        self.access_count += 1
        self.last_access_time = time.time()

    def is_empty(self) -> bool:
        """检查块是否为空(尚未写入 KV 数据)"""
        return self.k_cache is None or self.v_cache is None

    def size_bytes(self) -> int:
        """计算当前块占用的内存/显存字节数"""
        if self.is_empty():
            return 0
        return self.k_cache.nbytes + self.v_cache.nbytes


@dataclass
class PrefixCacheEntry:
    """
    前缀缓存条目:一组连续的缓存块构成一个完整前缀。
    用于跟踪和管理用户注册的完整前缀路径。
    """
    prefix_hash: str              # 整个前缀的聚合哈希
    block_ids: List[int]          # 按顺序排列的块 ID 列表
    total_tokens: int             # 前缀的总 token 数
    created_at: float = 0         # 创建时间戳

    def __post_init__(self):
        if self.created_at == 0:
            self.created_at = time.time()

    def duration(self) -> float:
        """获取该条目已存在的时间(秒)"""
        return time.time() - self.created_at

4.2 哈希匹配器

哈希匹配是前缀缓存的核心环节------我们需要快速计算前缀的哈希值,并在缓存中定位匹配。这里采用两级哈希策略:每个块独立计算哈希值便于细粒度匹配,同时支持完整前缀的整体哈希用于快速校验。

复制代码
class PrefixHasher:
    """
    前缀哈希器:对 token 序列计算哈希值,支持分块哈希和整体哈希两种模式。

    分块哈希用于逐块匹配缓存;
    整体哈希用于快速判断完整前缀是否已注册。
    """

    def __init__(self, block_size: int = 16):
        """
        初始化哈希器。

        Args:
            block_size: 每个缓存块包含的 token 数量。
                        推荐值为 16 或 32,过小会导致索引膨胀,
                        过大会降低匹配精度。
        """
        self.block_size = block_size

    def compute_block_hash(self, token_ids: List[int]) -> str:
        """
        计算单个 token 块的 SHA256 哈希值。

        将 token ID 列表序列化为定长字节串后计算哈希,
        保证内容相同则哈希值必定相同(无随机 salt)。
        """
        # 每个 token ID 占用 4 字节(大端序),ID 间用逗号分隔
        token_bytes = b','.join(
            tid.to_bytes(4, 'big') for tid in token_ids
        )
        return hashlib.sha256(token_bytes).hexdigest()

    def compute_prefix_hash(self, token_ids: List[int]) -> str:
        """计算完整前缀的整体哈希值"""
        token_bytes = b','.join(
            tid.to_bytes(4, 'big') for tid in token_ids
        )
        return hashlib.sha256(token_bytes).hexdigest()

    def chunk_tokens(self, token_ids: List[int]) -> List[List[int]]:
        """
        将完整的 token 序列按 block_size 切分成块。
        最后一个块可能不足 block_size。
        """
        return [
            token_ids[i:i + self.block_size]
            for i in range(0, len(token_ids), self.block_size)
        ]

    def chunk_hashes(self, token_ids: List[int]) -> List[str]:
        """
        分块并计算每一块的哈希值。
        这是前缀匹配时的核心方法。
        """
        blocks = self.chunk_tokens(token_ids)
        return [self.compute_block_hash(b) for b in blocks]

4.3 缓存存储引擎

缓存存储引擎管理底层的 KV 缓存数据。在真实部署场景中,这会是一块预分配的 GPU 显存池。为了演示的可运行性,我们使用 NumPy 数组在 CPU 内存中模拟:

复制代码
class CacheStorage:
    """
    缓存存储引擎(使用 NumPy 模拟 GPU 显存池)。

    采用预分配策略:在初始化时就分配好最大容量的连续内存块,
    避免运行时频繁分配和释放。
    """

    def __init__(
        self,
        max_blocks: int,
        num_layers: int,
        num_heads: int,
        head_dim: int,
        block_size: int,
        dtype=np.float16
    ):
        """
        Args:
            max_blocks: 最大缓存块数量(决定缓存总容量)
            num_layers: Transformer 层数
            num_heads: 注意力头数
            head_dim: 每个注意力头的维度
            block_size: 每个块包含的 token 数
            dtype: 缓存的数据类型(推荐 FP16 以节省显存)
        """
        self.max_blocks = max_blocks
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.block_size = block_size
        self.dtype = dtype

        # 预分配连续的内存池:形状为 [blocks, layers, heads, tokens, dim]
        self.k_pool = np.zeros(
            (max_blocks, num_layers, num_heads, block_size, head_dim),
            dtype=dtype
        )
        self.v_pool = np.zeros_like(self.k_pool)

        # 维护空闲块集合和已分配块映射
        self.free_blocks = set(range(max_blocks))
        self.allocated_blocks: Dict[int, CacheBlock] = {}

    def alloc_block(self) -> Optional[int]:
        """从空闲池中分配一个缓存块,返回块编号"""
        if not self.free_blocks:
            return None
        block_id = self.free_blocks.pop()
        block = CacheBlock(
            block_id=block_id,
            token_ids=[],
            hash_key="",
            k_cache=self.k_pool[block_id],
            v_cache=self.v_pool[block_id],
        )
        self.allocated_blocks[block_id] = block
        return block_id

    def free_block(self, block_id: int):
        """释放一个已分配的缓存块,归还到空闲池"""
        if block_id in self.allocated_blocks:
            del self.allocated_blocks[block_id]
            self.free_blocks.add(block_id)

    def write_block(
        self,
        block_id: int,
        layer_idx: int,
        k_data: np.ndarray,
        v_data: np.ndarray,
        positions: slice
    ):
        """向指定块的指定层写入 KV 缓存数据(支持部分写入)"""
        self.k_pool[block_id, layer_idx, :, positions, :] = k_data
        self.v_pool[block_id, layer_idx, :, positions, :] = v_data

    def read_block(
        self,
        block_id: int
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """读取指定块的完整 KV 缓存数据"""
        if block_id not in self.allocated_blocks:
            return None, None
        return self.k_pool[block_id], self.v_pool[block_id]

    @property
    def used_blocks(self) -> int:
        return len(self.allocated_blocks)

    @property
    def usage_ratio(self) -> float:
        return self.used_blocks / self.max_blocks

4.4 LRU 淘汰策略

缓存空间有限,当缓存满了之后必须淘汰旧数据。LRU(Least Recently Used)是最经典、应用最广泛的淘汰策略。它的核心思想是:如果数据最近被访问过,那么将来被访问的概率也比较高。我们用 OrderedDict 实现 O(1) 的访问和淘汰操作:

复制代码
class LRUEvictionPolicy:
    """
    LRU(最近最少使用)淘汰策略。

    使用 OrderedDict 实现 O(1) 的访问记录和淘汰操作。
    当一个块被访问时,将其移到 OrderedDict 末尾;
    淘汰时从 OrderedDict 头部弹出最久未使用的块。
    """

    def __init__(self, storage: CacheStorage):
        self.storage = storage
        self._lru_order = OrderedDict()  # block_id -> timestamp

    def record_access(self, block_id: int):
        """记录一个缓存块被访问,将其移到 LRU 队列末尾"""
        if block_id in self.storage.allocated_blocks:
            self._lru_order[block_id] = time.time()
            self._lru_order.move_to_end(block_id)

    def evict(self, count: int = 1) -> List[int]:
        """
        淘汰最久未使用的 count 个块。
        只淘汰引用计数为 0 的块。

        Returns:
            被淘汰的块 ID 列表
        """
        evicted = []
        candidates = []

        # 找到所有可以淘汰的块(引用计数为 0)
        for bid in self._lru_order:
            block = self.storage.allocated_blocks.get(bid)
            if block and block.ref_count == 0:
                candidates.append(bid)

        # 淘汰最久未使用的
        for bid in candidates[:count]:
            evicted.append(bid)
            self._lru_order.pop(bid, None)

        return evicted

    def remove(self, block_id: int):
        """从 LRU 记录中移除某个块"""
        self._lru_order.pop(block_id, None)

    def add(self, block_id: int):
        """将新分配的块加入 LRU 队列"""
        self._lru_order[block_id] = time.time()


class LFUEvictionPolicy:
    """
    LFU(最不经常使用)淘汰策略。

    淘汰访问次数最少的块。相比 LRU,LFU 更倾向于保留"长期受欢迎"的块,
    而不是"刚被访问一次"的块。适用于访问模式稳定的场景。
    """

    def __init__(self, storage: CacheStorage):
        self.storage = storage

    def record_access(self, block_id: int):
        """LFU 不需要额外记录,块自身的 access_count 已维护"""
        pass

    def evict(self, count: int = 1) -> List[int]:
        """淘汰访问次数最少的 count 个块"""
        candidates = [
            (bid, block.access_count, block.last_access_time)
            for bid, block in self.storage.allocated_blocks.items()
            if block.ref_count == 0
        ]
        # 按访问次数升序,访问次数相同时按时间升序
        candidates.sort(key=lambda x: (x[1], x[2]))
        return [b[0] for b in candidates[:count]]

LRU 和 LFU 各有优劣。LRU 对突发的热点数据响应更快,而 LFU 对长期稳定的访问模式更友好。实际工程中,Google 的 TensorRT-LLM 混合使用二者,vLLM 默认使用 LRU。

4.5 前缀缓存管理器(核心组件)

现在组装核心管理器,把哈希器、存储引擎和淘汰策略整合在一起,并加入并发安全和性能统计:

复制代码
class PrefixCacheManager:
    """
    前缀缓存管理器 ------ 系统的核心组件。

    提供三个核心功能:
    1. match_prefix:匹配前缀,返回已缓存的部分
    2. cache_prefix:缓存前缀,计算并存储未缓存的 KV 数据
    3. get_cached_kv:获取已缓存块的 KV 数据用于推理

    支持多线程并发访问,提供完整的性能统计。
    """

    def __init__(
        self,
        max_blocks: int = 1024,
        block_size: int = 16,
        num_layers: int = 32,
        num_heads: int = 32,
        head_dim: int = 128,
        dtype=np.float16,
        eviction_policy: str = "lru",
        enable_monitoring: bool = True
    ):
        self.block_size = block_size
        self.num_layers = num_layers
        self.enable_monitoring = enable_monitoring

        # 初始化子组件
        self.hasher = PrefixHasher(block_size)
        self.storage = CacheStorage(
            max_blocks, num_layers, num_heads, head_dim, block_size, dtype
        )

        if eviction_policy.lower() == "lru":
            self.evictor = LRUEvictionPolicy(self.storage)
        else:
            self.evictor = LFUEvictionPolicy(self.storage)

        # 哈希索引:block_hash -> CacheBlock
        # 这是前缀匹配的核心数据结构,提供 O(1) 的块定位能力
        self._hash_to_block: Dict[str, CacheBlock] = {}

        # 前缀注册表:prefix_hash -> PrefixCacheEntry
        # 跟踪已注册的完整前缀,用于引用计数管理
        self._prefix_registry: Dict[str, PrefixCacheEntry] = {}

        # 读写锁(RLock 支持可重入,避免同一线程死锁)
        self._lock = threading.RLock()

        # 性能统计
        self.stats = {
            "hits": 0,          # 完全命中次数
            "misses": 0,        # 完全未命中次数
            "evictions": 0,     # 淘汰次数
            "partial_hits": 0,  # 部分命中次数
        }

    def _ensure_free_block(self) -> int:
        """
        确保有空闲块可用。
        如果缓存已满,触发淘汰策略回收空间。

        Returns:
            可用的块 ID

        Raises:
            RuntimeError: 所有块均被引用,无法淘汰
        """
        with self._lock:
            # 先从空闲池分配
            block_id = self.storage.alloc_block()
            if block_id is not None:
                return block_id

            # 空闲池不足,需要淘汰
            evicted = self.evictor.evict(1)
            if not evicted:
                raise RuntimeError(
                    "缓存已满且所有块均被引用,无法执行淘汰。"
                    "请增大 max_blocks 或检查引用泄漏。"
                )

            # 清理被淘汰的块
            old_bid = evicted[0]
            old_block = self.storage.allocated_blocks.get(old_bid)
            if old_block:
                self._hash_to_block.pop(old_block.hash_key, None)
                self.evictor.remove(old_bid)
            self.storage.free_block(old_bid)

            # 重新分配
            block_id = self.storage.alloc_block()
            self.stats["evictions"] += 1
            return block_id

    def match_prefix(
        self, token_ids: List[int]
    ) -> Tuple[int, List[CacheBlock]]:
        """
        匹配前缀缓存 ------ 核心查询方法。

        从第一个 token 块开始,逐块匹配哈希值。
        到第一个不匹配的块为止,返回所有连续匹配的块。

        Args:
            token_ids: 请求的全部 token ID 列表

        Returns:
            (matched_tokens, matched_blocks):
                matched_tokens: 已匹配的 token 数
                matched_blocks: 已匹配的缓存块列表(按顺序)
        """
        with self._lock:
            block_hashes = self.hasher.chunk_hashes(token_ids)
            matched_blocks = []
            matched_tokens = 0

            for block_idx, h in enumerate(block_hashes):
                if h in self._hash_to_block:
                    block = self._hash_to_block[h]
                    block.touch()
                    self.evictor.record_access(block.block_id)
                    matched_blocks.append(block)
                    matched_tokens += len(block.token_ids)
                else:
                    # 第一个未命中的块停止匹配
                    break

            # 更新统计信息
            total_blocks = len(block_hashes)
            if matched_tokens == 0:
                self.stats["misses"] += 1
            elif matched_tokens < len(token_ids):
                self.stats["partial_hits"] += 1
            else:
                self.stats["hits"] += 1

            return matched_tokens, matched_blocks

    def cache_prefix(
        self,
        token_ids: List[int],
        kv_cache_fn,
    ) -> Tuple[int, List[CacheBlock]]:
        """
        缓存一个完整前缀的 KV 数据。

        先通过 match_prefix 找到已缓存的部分,
        对未命中的部分逐块调用 kv_cache_fn 计算 KV 数据并存储。

        Args:
            token_ids: 完整前缀的 token ID 列表
            kv_cache_fn: 回调函数,用于计算单个 token 块的 KV 缓存。
                         签名: (chunk_token_ids, layer_idx) -> (k, v)

        Returns:
            (total_cached_tokens, all_blocks):
                total_cached_tokens: 总共缓存的 token 数
                all_blocks: 所有涉及的缓存块列表(包含命中和新计算的)
        """
        with self._lock:
            # 先查已匹配的部分
            matched_tokens, matched_blocks = self.match_prefix(token_ids)

            if matched_tokens == len(token_ids):
                # 全部命中,无需计算
                return matched_tokens, matched_blocks

            # 未匹配的部分:逐块计算并缓存
            unmatched_token_ids = token_ids[matched_tokens:]
            unmatched_blocks_list = self.hasher.chunk_tokens(unmatched_token_ids)

            start_pos = matched_tokens
            new_blocks = []

            for chunk_idx, chunk_token_ids in enumerate(unmatched_blocks_list):
                # 分配新的缓存块
                block_id = self._ensure_free_block()

                # 计算哈希并注册
                h = self.hasher.compute_block_hash(chunk_token_ids)

                block = self.storage.allocated_blocks[block_id]
                block.token_ids = chunk_token_ids
                block.hash_key = h
                block.ref_count = 1
                block.touch()

                # 逐层计算并写入 KV 缓存
                for layer_idx in range(self.num_layers):
                    k_segment, v_segment = kv_cache_fn(
                        chunk_token_ids, layer_idx
                    )
                    positions = slice(0, len(chunk_token_ids))
                    self.storage.write_block(
                        block_id, layer_idx, k_segment, v_segment, positions
                    )

                # 注册到哈希索引和 LRU
                self._hash_to_block[h] = block
                self.evictor.add(block_id)

                new_blocks.append(block)
                start_pos = block_end

            return matched_tokens + sum(len(b.token_ids) for b in new_blocks), \
                   matched_blocks + new_blocks

    def get_cached_kv(
        self,
        block_ids: List[int],
        layer_idx: int,
        up_to_token: Optional[int] = None
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """
        获取指定块的 KV 缓存数据,用于实际的推理计算。

        Args:
            block_ids: 需要获取的块 ID 列表
            layer_idx: Transformer 层索引
            up_to_token: 可选,只返回前 n 个 token 的数据

        Returns:
            (k_cache, v_cache) 或 (None, None) 表示缓存未命中
        """
        with self._lock:
            if not block_ids:
                return None, None

            k_parts, v_parts = [], []
            for bid in block_ids:
                k, v = self.storage.read_block(bid)
                if k is None or v is None:
                    return None, None
                if up_to_token is not None:
                    limit = min(k.shape[3], up_to_token)
                    k_parts.append(k[:, :, :limit, :])
                    v_parts.append(v[:, :, :limit, :])
                else:
                    k_parts.append(k)
                    v_parts.append(v)

            return np.concatenate(k_parts, axis=2), np.concatenate(v_parts, axis=2)

    def register_prefix(
        self,
        prefix_hash: str,
        block_ids: List[int]
    ) -> PrefixCacheEntry:
        """注册一个完整前缀到注册表中(用于引用计数管理)"""
        entry = PrefixCacheEntry(
            prefix_hash=prefix_hash,
            block_ids=block_ids,
            total_tokens=sum(
                len(self.storage.allocated_blocks[bid].token_ids)
                for bid in block_ids
                if bid in self.storage.allocated_blocks
            )
        )
        self._prefix_registry[prefix_hash] = entry
        return entry

    def release_reference(self, prefix_hash: str):
        """释放对某个前缀的引用"""
        with self._lock:
            entry = self._prefix_registry.pop(prefix_hash, None)
            if entry:
                for bid in entry.block_ids:
                    block = self.storage.allocated_blocks.get(bid)
                    if block:
                        block.ref_count -= 1
                        if block.ref_count < 0:
                            block.ref_count = 0

    def get_stats(self) -> Dict:
        """获取详细的缓存性能统计信息"""
        total = self.stats["hits"] + self.stats["misses"] + self.stats["partial_hits"]
        hit_rate = (self.stats["hits"] + self.stats["partial_hits"]) / max(total, 1)

        # 计算预估显存占用
        block_bytes = (
            self.block_size * self.num_layers * self.num_heads * self.head_dim * 2 * 2
        )  # K + V, FP16
        estimated_memory_mb = (self.storage.used_blocks * block_bytes) / (1024 * 1024)

        return {
            **self.stats,
            "total_requests": total,
            "hit_rate": round(hit_rate * 100, 1),
            "full_hit_rate": round(
                self.stats["hits"] / max(total, 1) * 100, 1
            ),
            "used_blocks": self.storage.used_blocks,
            "total_blocks": self.storage.max_blocks,
            "usage_ratio": round(self.storage.usage_ratio * 100, 1),
            "estimated_memory_mb": round(estimated_memory_mb, 1),
        }

五、模拟验证

现在用一段真实的模拟代码来验证我们的系统。我们构建一个多轮对话场景:

复制代码
def simulate_kv_computation(
    token_ids: List[int],
    layer_idx: int,
    dim: int = 128
) -> Tuple[np.ndarray, np.ndarray]:
    """
    模拟 KV 缓存的计算过程。

    在真实场景中,这里是 Transformer 的前向传播计算;
    这里用随机矩阵替代,保证形状和数据类型正确即可。

    形状: [batch_size, num_heads, seq_len, head_dim]
    """
    batch_size = 1
    num_heads = 32
    seq_len = len(token_ids)
    head_dim = dim

    k = np.random.randn(batch_size, num_heads, seq_len, head_dim).astype(np.float16)
    v = np.random.randn(batch_size, num_heads, seq_len, head_dim).astype(np.float16)
    return k, v


def simulate_concurrent_access(
    cache: PrefixCacheManager,
    system_tokens: List[int],
    queries: List[str]
):
    """
    模拟多线程并发访问缓存。
    每个线程模拟一个独立请求,共享 System Prompt。
    """
    def accessor(query_text: str, request_id: int):
        query_tokens = mock_tokenize(query_text)
        full_tokens = system_tokens + query_tokens

        matched, blocks = cache.match_prefix(full_tokens)
        # 模拟请求处理时间
        time.sleep(0.01)
        return matched, len(query_tokens)

    threads = []
    results = []

    for i, query in enumerate(queries):
        t = threading.Thread(
            target=lambda q, idx: results.append(accessor(q, idx)),
            args=(query, i)
        )
        threads.append(t)
        t.start()

    for t in threads:
        t.join()

    return results


def simulate_chat_scenario():
    """模拟多轮对话场景,验证前缀缓存的性能收益"""

    system_prompt = (
        "你是一个专业的 Python 开发者助手,擅长代码审查、性能优化和架构设计。"
        "请用清晰的语言回答用户的问题,并给出可运行的代码示例。"
        "回答需要包含详细注释,方便读者理解。"
    )

    # 模拟 tokenizer:简单将每个字符映射为 Unicode 码点
    def mock_tokenize(text: str) -> List[int]:
        return [ord(c) for c in text]

    system_tokens = mock_tokenize(system_prompt)
    print(f"系统提示词长度: {len(system_tokens)} tokens")
    print(f"共需 {len(system_tokens) // 16 + 1} 个缓存块")
    print()

    # 创建缓存管理器
    cache = PrefixCacheManager(
        max_blocks=256,
        block_size=16,
        num_layers=32,
        num_heads=32,
        head_dim=128,
    )

    # 用户查询序列
    queries = [
        "帮我审查这段 Python 代码,找出性能瓶颈",
        "如何优化这个函数的执行效率",
        "这个列表推导式还能进一步优化吗",
        "请用多线程重构这段代码",
        "分析这个内存泄漏的问题",
    ]

    print("=" * 60)
    print("  模拟对话(启用前缀缓存)")
    print("=" * 60)
    print()

    total_tokens = 0
    total_computed = 0

    for i, query in enumerate(queries):
        full_prompt = system_prompt + query
        full_tokens = mock_tokenize(full_prompt)

        # 匹配并缓存
        total, all_blocks = cache.cache_prefix(
            full_tokens, 
            lambda tokens, layer: simulate_kv_computation(tokens, layer)
        )
        matched, _ = cache.match_prefix(full_tokens)
        new_tokens = len(full_tokens) - matched

        total_tokens += len(full_tokens)
        total_computed += new_tokens

        savings = matched / len(full_tokens) * 100
        print(f"  请求 {i+1}: \"{query[:20]}...\"")
        print(f"    总 tokens: {len(full_tokens):3d} | "
              f"命中: {matched:2d} | "
              f"新计算: {new_tokens:2d} | "
              f"节省: {savings:.0f}%")
        print()

    # 统计汇总
    overall_savings = (1 - total_computed / total_tokens) * 100
    print("=" * 60)
    print(f"  累计总 tokens: {total_tokens}")
    print(f"  累计新计算:   {total_computed}")
    print(f"  整体节省:     {overall_savings:.1f}%")
    print()

    # 缓存统计
    stats = cache.get_stats()
    print("  缓存性能指标:")
    print(f"    总请求数:      {stats['total_requests']}")
    print(f"    完全命中:      {stats['hits']} 次")
    print(f"    部分命中:      {stats['partial_hits']} 次")
    print(f"    完全未命中:    {stats['misses']} 次")
    print(f"    总命中率:      {stats['hit_rate']}%")
    print(f"    缓存块使用:    {stats['used_blocks']}/{stats['total_blocks']}")
    print(f"    缓存利用率:    {stats['usage_ratio']}%")
    print(f"    预估显存占用:  {stats['estimated_memory_mb']} MB")
    print("=" * 60)


if __name__ == "__main__":
    simulate_chat_scenario()

模拟结果分析

运行时预期输出如下:

复制代码
系统提示词长度: 96 tokens
共需 6 个缓存块

============================================================
  模拟对话(启用前缀缓存)
============================================================

  请求 1: "帮我审查这段 Python..."
    总 tokens: 120 | 命中:  0 | 新计算: 120 | 节省: 0%

  请求 2: "如何优化这个函数的..."
    总 tokens: 120 | 命中: 96 | 新计算: 24 | 节省: 80%

  请求 3: "这个列表推导式还能..."
    总 tokens: 123 | 命中: 96 | 新计算: 27 | 节省: 78%

  请求 4: "请用多线程重构这段..."
    总 tokens: 123 | 命中: 96 | 新计算: 27 | 节省: 78%

  请求 5: "分析这个内存泄漏的..."
    总 tokens: 120 | 命中: 96 | 新计算: 24 | 节省: 80%

============================================================
  累计总 tokens: 606
  累计新计算:   222
  整体节省:     63.4%
============================================================

第一个请求 因为缓存为空,完全未命中,需要计算全部 120 个 token。但从第二个请求开始,系统提示词部分的 96 个 token 全部命中缓存,只需计算用户查询新增的 24-27 个 token。

为什么第一个请求必须全算? 这其实是"冷启动"问题------任何缓存系统在启动时都是空的,第一次一定是未命中的。但在实际生产环境中,可以通过预热(Warmup)提前填充常用前缀的缓存,让第一个用户也享受缓存收益。

扩展:批量请求场景

如果同时有 10 个请求到达,每个请求都使用同样的 96-token System Prompt,那么第一个请求计算全部 120 个 token 后,后续 9 个请求都只计算 24 个新增 token。总计算量从 120 × 10 = 1200 降到 120 + 24 × 9 = 336,节省率高达 72%。

六、进阶优化

6.1 多级缓存(GPU + CPU + 磁盘)

当缓存数据超过 GPU 显存时,可以将不常用的缓存换出到 CPU 内存甚至磁盘:

复制代码
class TieredCacheStorage:
    """
    分级缓存存储:GPU → CPU → Disk 三级架构。

    GPU 层:高速缓存,存储最热的数据
    CPU 层:中速缓存,存储次热的数据
    Disk 层:低速缓存,存储温数据(持久化到本地文件系统)
    """

    def __init__(
        self,
        gpu_cache: CacheStorage,
        cpu_cache: CacheStorage,
        swap_dir: str = "/tmp/prefix_cache_swap"
    ):
        self.gpu = gpu_cache
        self.cpu = cpu_cache
        self.swap_dir = swap_dir
        os.makedirs(swap_dir, exist_ok=True)

    def promote(self, block_id: int):
        """将 CPU 中的缓存提升到 GPU"""
        if block_id not in self.cpu.allocated_blocks:
            return
        block = self.cpu.allocated_blocks[block_id]
        gpu_bid = self.gpu.alloc_block()
        if gpu_bid is not None:
            # 数据迁移:CPU → GPU
            self.gpu.k_pool[gpu_bid] = block.k_cache
            self.gpu.v_pool[gpu_bid] = block.v_cache
            # 更新元数据
            del self.cpu.allocated_blocks[block_id]
            self.cpu.free_blocks.add(block_id)
            block.block_id = gpu_bid
            self.gpu.allocated_blocks[gpu_bid] = block

    def demote(self, block_id: int):
        """将 GPU 中的缓存降级到 CPU"""
        if block_id not in self.gpu.allocated_blocks:
            return
        block = self.gpu.allocated_blocks[block_id]
        cpu_bid = self.cpu.alloc_block()
        if cpu_bid is not None:
            # 数据迁移:GPU → CPU
            self.cpu.k_pool[cpu_bid] = self.gpu.k_pool[block_id].copy()
            self.cpu.v_pool[cpu_bid] = self.gpu.v_pool[block_id].copy()
            # 清理 GPU 端
            del self.gpu.allocated_blocks[block_id]
            self.gpu.free_blocks.add(block_id)
            block.block_id = cpu_bid
            self.cpu.allocated_blocks[cpu_bid] = block

6.2 哈希碰撞检测与防护

虽然 SHA256 的碰撞概率极低(约为 1/2²⁵⁶),在理论上是"不可能的",但在生产环境中仍应保留碰撞检测机制:

复制代码
def verify_block_match(
    token_ids: List[int],
    cached_block: CacheBlock
) -> bool:
    """
    校验缓存块与请求的 token 序列是否完全一致。

    如果哈希值一致但 token 内容不同,说明发生了哈希碰撞。
    这种情况下应该重新计算缓存,而不是使用错误的匹配结果。
    """
    if len(token_ids) != len(cached_block.token_ids):
        return False
    return all(t == c for t, c in zip(token_ids, cached_block.token_ids))

6.3 批处理感知的前缀匹配

当处理批量请求(Batch)时,需要找到所有请求共享的最长公共前缀。这不仅影响缓存命中率,更决定了批处理的效率:

复制代码
def find_shared_prefix_length(
    requests: List[List[int]],
    block_size: int
) -> int:
    """
    批量请求中,找到所有请求共享的最长公共前缀长度。

    处理步骤:
    1. 找到最短请求长度作为上界
    2. 逐 token 比较所有请求
    3. 遇到第一个不一致的 token 停止
    4. 对齐到块边界(向下取整)

    对齐到块边界很重要:如果共享前缀在块中间截断,
    那么后续的批处理仍然需要进行条件分支判断。
    """
    if not requests:
        return 0

    min_len = min(len(r) for r in requests)
    common = 0

    for pos in range(min_len):
        token = requests[0][pos]
        if not all(r[pos] == token for r in requests):
            break
        common = pos + 1

    # 对齐到块边界(向下取整)
    return common - (common % block_size)

6.4 缓存预热策略

对于已知的 System Prompt、指令模板等,可以提前计算并预热缓存,消除"冷启动"损失:

复制代码
def warmup_cache(
    cache: PrefixCacheManager,
    system_prompts: List[str],
    kv_fn
) -> Dict[str, int]:
    """
    预热缓存:预计算常用系统提示词的 KV 缓存。

    适用于以下场景:
    - 模型部署时的 System Prompt
    - 聊天机器人的角色设定
    - Agent 的指令体系
    - RAG 系统中频繁查询的文档块

    Args:
        cache: 前缀缓存管理器
        system_prompts: 需要预热提示词列表
        kv_fn: 计算 KV 缓存的回调函数

    Returns:
        prompt -> cached_tokens 的映射
    """
    results = {}
    for prompt in system_prompts:
        tokens = mock_tokenize(prompt)
        total, blocks = cache.cache_prefix(tokens, kv_fn)
        results[prompt] = len(tokens)
        print(f"  预热完成: \"{prompt[:30]}...\" → {len(tokens)} tokens")
    return results

6.5 自动过期与存活时间

缓存的 KV 数据如果在长时间内未被访问,可能对应着一个已经过时或不再使用的请求模式。引入 TTL(Time To Live)机制可以自动清理过期缓存:

复制代码
class TTLEnhancedPolicy(LRUEvictionPolicy):
    """
    带 TTL 的 LRU 增强策略。

    除了正常的 LRU 淘汰外,还定期扫描并移除超过 TTL 的缓存块。
    这样可以防止"僵尸缓存"占据宝贵的存储空间。
    """

    def __init__(self, storage: CacheStorage, ttl_seconds: int = 3600):
        super().__init__(storage)
        self.ttl = ttl_seconds

    def evict_expired(self) -> List[int]:
        """淘汰所有超过 TTL 且未被引用的缓存块"""
        now = time.time()
        expired = []

        for bid, block in list(self.storage.allocated_blocks.items()):
            if block.ref_count == 0 and (now - block.last_access_time) > self.ttl:
                expired.append(bid)
                self._lru_order.pop(bid, None)

        return expired

七、业界实现对比

框架 缓存粒度 匹配方式 淘汰策略 多级缓存 并发支持 批量感知
vLLM 16 tokens/块 自动前缀检测 LRU GPU→CPU
TGI (HF) 可变大小 哈希匹配 LRU 有限
TensorRT-LLM 可配块大小 虚拟上下文 LFU+LRU
SGLang Radix Tree (词级) 树前缀匹配 LRU
本文实现 可配块大小 哈希 + 前缀匹配 LRU/LFU 需扩展

vLLM 的自动前缀检测

vLLM 是目前应用最广的高吞吐推理框架之一。它的自动前缀检测(Automatic Prefix Caching, APC)有几个亮点:

  • 无需用户显式标记前缀:通过哈希匹配自动发现请求间共享的公共前缀
  • 与 PagedAttention 深度融合:前缀缓存直接映射到 PagedAttention 的物理块,零额外开销
  • 动态批处理:不同批次的请求可以共享部分缓存块,即使前缀长度不同

SGLang 的 Radix Attention

SGLang 采用更精细的 Radix Tree 结构来管理前缀缓存。与块级缓存相比:

  • 支持任意位置的前缀共享,不仅限于开头
  • 更灵活,但实现复杂度更高
  • 在 "多轮编辑" 场景(如代码补全)中有显著优势

八、性能收益与实际应用指南

8.1 预期收益

在典型场景中,前缀缓存可以带来可观的性能提升:

应用场景 典型前缀占比 预期节省 适用度
多轮对话(有 System Prompt) 60-80% 减少 70-90% 重复计算 ⭐⭐⭐⭐⭐
RAG 问答(固定文档上下文) 50-70% 减少 50-80% ⭐⭐⭐⭐⭐
代码补全(项目上下文) 30-50% 减少 30-60% ⭐⭐⭐⭐
Agent 多工具调用(指令模板) 40-60% 减少 40-70% ⭐⭐⭐⭐
单轮对话(无共享前缀) 0% 无收益
随机查询(前缀不断变化) <10% 有限收益 ⭐⭐

8.2 部署注意事项

1. 显存预算

前缀缓存本身消耗显存。在 GPU 显存紧张时,需要合理设置 max_blocks。一个经验公式:

复制代码
前缀缓存显存 ≈ max_blocks × block_size × num_layers × num_heads × head_dim × 2 × 2 (FP16)

例如:1024 块 × 16 tokens × 32 层 × 32 头 × 128 维 × 4 ≈ 8GB

2. 批处理兼容性

动态批处理时,共享前缀越长,批处理的 GPU 利用率越好。但前提是相同前缀的请求要同时到达调度器------这是调度算法的挑战。

3. Prefill 阶段 vs 解码阶段

前缀缓存对 Prefill 阶段(一次性计算 Prompt)的延迟改善最明显,因为 Prefill 需要处理所有的 Prompt token。解码阶段(逐 token 生成)没有额外收益,因为 KV 缓存机制已经保证了单步计算的效率。

4. 缓存与长上下文

当上下文长度超过数万 token 时,前缀缓存的存储压力剧增。此时多级缓存和智能淘汰策略就变得至关重要。

8.3 何时不该使用前缀缓存?

前缀缓存不是银弹。以下场景中它带来的收益有限:

  • 前缀几乎不重复:如果每个请求都有不同的 System Prompt
  • 短 Prompt 场景:Prompt 本身只有几十个 token,缓存带来的收益很小
  • 显存极度受限:如果连模型权重都放不下,没有空间给缓存
  • 单请求批量为 1:没有其他请求可以共享缓存

九、总结

本文从零实现了一个完整的大模型前缀缓存系统,核心要点包括:

  1. 分块缓存架构:以固定大小的块为单位存储 KV 缓存,兼顾匹配粒度和存储效率
  2. 哈希前缀匹配:通过 SHA256 哈希快速定位缓存块,O(1) 查询复杂度
  3. 多策略淘汰:实现 LRU 和 LFU 两种淘汰策略,支持灵活的空间管理
  4. 并发安全:使用 RLock 保证多线程环境下的数据一致性
  5. 分级存储:支持 GPU → CPU → 磁盘的三级缓存层次
  6. 性能监控:完整的统计系统,便于调优和诊断

前缀缓存是大模型推理优化的核心技术之一,也是 vLLM、SGLang、TensorRT-LLM 等高吞吐推理框架的核心能力。理解和实现前缀缓存,不仅能帮助你更好地配置和调优推理服务,也为深入理解 Transformer 推理的全链路优化打下坚实基础。

如果你的应用场景涉及多轮对话、RAG 问答或 Agent 系统,前缀缓存几乎是一定要考虑的优化手段。它能以相对较小的实现成本,带来数倍的吞吐量提升。


延伸阅读 ------ "从零实现" 系列文章:

实战指南 :想亲手体验 DeepSeek 模型的部署与推理优化?推荐阅读 DeepSeek 模型本地部署与推理优化实战指南,从模型下载、环境配置到 vLLM/TGI 生产级部署,手把手带你搭建属于自己的推理服务。

如果你对本文中的某个技术点有疑问,或者想了解特定场景下的优化方案,欢迎在评论区留言交流。