做大模型推理时,我们经常会碰到几个看起来很像的词:KV cache、cache hit、prefix cache、prompt cache。它们都和"缓存"有关,但并不在同一个抽象层次上。
可以先用一句话建立直觉:KV cache 是 Transformer decoder self-attention 里缓存下来的 Key/Value 张量;缓存命中是推理系统发现某段已有 K/V 可以复用;prefix cache 则是把"相同 token 前缀对应的 KV cache"跨请求复用起来的一套工程机制。
如果你熟悉一点 PyTorch,理解过 attention 的基本公式,那么 KV cache 并不神秘。它本质上就是把原本会在每一步重复计算的中间张量保存下来,下一步继续用。
1. 先看普通 causal self-attention
Transformer attention 的核心公式是:
scss
Attention(Q, K, V) = softmax(QK^T / sqrt(d)) V
写成一个最小版 PyTorch causal self-attention,大概是这样:
ini
import torch
import torch.nn as nn
import torch.nn.functional as F
class CausalSelfAttention(nn.Module):
def __init__(self, hidden_size: int, num_heads: int):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.q_proj = nn.Linear(hidden_size, hidden_size)
self.k_proj = nn.Linear(hidden_size, hidden_size)
self.v_proj = nn.Linear(hidden_size, hidden_size)
self.o_proj = nn.Linear(hidden_size, hidden_size)
def forward(self, x):
# x: [B, T, C]
B, T, C = x.shape
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
# [B, T, C] -> [B, H, T, D]
q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
scores = q @ k.transpose(-2, -1) # [B, H, T, T]
scores = scores / (self.head_dim ** 0.5)
mask = torch.tril(torch.ones(T, T, device=x.device)).bool()
scores = scores.masked_fill(~mask, float("-inf"))
probs = F.softmax(scores, dim=-1)
out = probs @ v # [B, H, T, D]
out = out.transpose(1, 2).contiguous().view(B, T, C)
return self.o_proj(out)
这里最重要的是三次线性投影:
ini
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
如果是在 prefill 阶段处理完整 prompt,比如输入长度 T=1000,模型会一次性算出所有 token 的 Q/K/V,然后做一个 [T, T] 的 causal attention。这很自然,也没什么浪费。
麻烦发生在自回归生成阶段。生成文本时,模型每次只往后追加一个 token。假设上下文里已经有 1000 个 token,现在要生成第 1001 个 token。新 token 的 query 需要和前面 1000 个 token,以及它自己,对应的 key/value 做 attention。如果每一步都把完整上下文重新喂进模型,就会一遍又一遍地重算历史 token 的 K/V、attention 和 MLP,成本非常高。
KV cache 要解决的就是这个重复计算问题。
2. KV cache 到底缓存了什么
KV cache 缓存的不是原始 token,不是 embedding,也不是 logits,而是每一层 self-attention 中已经投影好的 Key 和 Value。
对一个 decoder-only LLM 来说,每一层都有自己的 KV cache。结构上通常类似这样:
ini
past_key_values = [
(k_layer_0, v_layer_0),
(k_layer_1, v_layer_1),
# ...
(k_layer_n, v_layer_n),
]
其中每个 k_layer_i、v_layer_i 的典型 shape 是:
csharp
[B, num_kv_heads, seq_len, head_dim]
如果是普通 multi-head attention,num_kv_heads 通常等于 num_heads。如果模型用了 MQA 或 GQA,K/V head 数量会少于 Q head 数量。比如 Q 有 32 个 heads,而 K/V 只有 8 个 heads。这样做的一个直接收益就是降低 KV cache 的显存占用。
把上面的 attention 改成带 cache 的版本,可以写成下面这样:
ini
class CachedCausalSelfAttention(nn.Module):
def __init__(self, hidden_size: int, num_heads: int):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.q_proj = nn.Linear(hidden_size, hidden_size)
self.k_proj = nn.Linear(hidden_size, hidden_size)
self.v_proj = nn.Linear(hidden_size, hidden_size)
self.o_proj = nn.Linear(hidden_size, hidden_size)
def forward(self, x, past_k=None, past_v=None):
# x: [B, T_new, C]
# decode 阶段 T_new 通常是 1
B, T_new, C = x.shape
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
q = q.view(B, T_new, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(B, T_new, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(B, T_new, self.num_heads, self.head_dim).transpose(1, 2)
if past_k is not None:
# past_k: [B, H, T_past, D]
# k: [B, H, T_new, D]
k_all = torch.cat([past_k, k], dim=2)
v_all = torch.cat([past_v, v], dim=2)
else:
k_all = k
v_all = v
scores = q @ k_all.transpose(-2, -1) # [B, H, T_new, T_total]
scores = scores / (self.head_dim ** 0.5)
# decode 时如果 T_new=1,k_all 只包含历史+当前,通常不需要完整 causal mask
probs = F.softmax(scores, dim=-1)
out = probs @ v_all # [B, H, T_new, D]
out = out.transpose(1, 2).contiguous().view(B, T_new, C)
out = self.o_proj(out)
return out, k_all, v_all
关键逻辑只有几行:
ini
k_all = torch.cat([past_k, k], dim=2)
v_all = torch.cat([past_v, v], dim=2)
当前 token 只需要计算自己的 Q/K/V。它的 K/V 会追加到历史 cache 后面;它的 Q 则拿来和完整的历史 K/V 做 attention。
这里也能解释一个常见问题:为什么叫 KV cache,而不是 QKV cache?原因是历史 token 的 query 对未来没什么用。历史 token 的输出已经算完了,未来 token 不需要再用历史 query。未来 token 真正需要的是"历史 token 可以被我 attend 到",也就是历史 token 的 K 和 V。
3. Prefill 和 Decode:KV cache 在哪里发挥作用
LLM 推理通常可以拆成两个阶段:prefill 和 decode。
Prefill 阶段处理完整 prompt:
ini
outputs = model(input_ids=prompt_ids, use_cache=True)
past_key_values = outputs.past_key_values
如果 prompt 长度是 1000,那么每一层会生成类似这样的 cache:
yaml
k: [B, H_kv, 1000, D]
v: [B, H_kv, 1000, D]
Decode 阶段开始后,每次只输入最新的一个 token:
ini
input_ids = next_token # [B, 1]
outputs = model(
input_ids=input_ids,
past_key_values=past_key_values,
use_cache=True,
)
past_key_values = outputs.past_key_values
一个极简 generation loop 可以写成这样:
ini
past_key_values = None
input_ids = prompt_ids
for _ in range(max_new_tokens):
outputs = model(
input_ids=input_ids,
past_key_values=past_key_values,
use_cache=True,
)
logits = outputs.logits
past_key_values = outputs.past_key_values
next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
# 关键点:下一轮只喂最新 token,而不是完整上下文
input_ids = next_token
没有 KV cache 时,第 t 步生成可能要重新跑长度为 t 的上下文,历史部分会被反复计算。有 KV cache 后,每一步只需要为新 token 计算一次 Q/K/V,并复用之前保存下来的 K/V。
不过,KV cache 并不意味着 decode 变成了常数时间。生成新 token 时仍然要做这一步:
ini
scores = q_new @ k_cache.transpose(-2, -1)
这里的 k_cache 会随着上下文变长而变长,所以单步 attention 的计算量仍然和上下文长度相关。KV cache 省掉的是"重复计算历史 token 的中间状态",不是把长上下文 attention 本身消掉。
4. 缓存命中和 KV cache 的关系
说到"缓存命中",其实已经不是 attention 算子内部的概念了,而是推理服务层面的概念。
在单个请求内部,decode 每一步都会复用 past_key_values。这当然也是一种复用,但一般不会特别叫 cache hit,因为这是同一个请求的正常状态传递。
工程上更常说的 cache hit,通常指跨请求或跨轮次的 prefix KV cache hit。比如请求 A 是:
makefile
System: 你是一个代码助手...
User: 请解释 KV cache
请求 B 是:
makefile
System: 你是一个代码助手...
User: 请解释 attention
如果两者的 system prompt 完全相同,并且 tokenization 后的前缀 token 序列也完全一致,那么推理服务可以直接复用 system prompt 对应的 KV cache。请求 B 不需要重新 prefill 那段 system prompt,只需要从没有命中的 suffix 开始继续计算。
概念上可以这么写:
ini
cached_past, hit_len = prefix_cache.lookup(input_ids)
if cached_past is None:
outputs = model(input_ids=input_ids, use_cache=True)
else:
remaining_ids = input_ids[:, hit_len:]
outputs = model(
input_ids=remaining_ids,
past_key_values=cached_past,
use_cache=True,
)
所以这几个词的关系可以理解为:
ini
KV cache = 实际被缓存的每层 K/V 张量
cache hit = 当前请求能复用已有 KV cache
cache miss = 当前请求需要重新 prefill 计算 K/V
需要注意的是,命中的对象通常不是自然语言字符串,而是 token 序列。两个 prompt 看起来只差一个空格、换行或标点,但 token 序列可能已经不同,最终能命中的前缀长度也可能变短。
5. Prefix cache:把相同前缀的 KV cache 复用起来
Prefix cache 可以理解成"跨请求复用 decoder KV cache 的机制"。它不是一种新的 attention,也不是另一套模型缓存,而是 KV cache 上面的一层工程化管理策略。
更准确地说,prefix cache 是针对自回归 decoder 的 KV cache 复用策略:当一个新请求的 token 前缀和之前算过的前缀一致时,直接复用这个前缀在各个 decoder layer 里的 self-attention K/V,从而减少 prefill 计算。
这里有一个很重要的限制:必须是前缀。
假设系统之前缓存过:
css
A B C D E F
现在来了一个新请求:
r
X Y C D E F
虽然 C D E F 这段内容一样,但它不是同一个前缀,通常不能直接复用。原因在于 decoder self-attention 中每个 token 的 hidden state 都依赖它前面的上下文。之前 C 前面是 A B,现在 C 前面是 X Y,那么 C 在各层里算出来的 K/V 也可能不同。
所以 prefix cache 的前提可以概括成:
arduino
same prefix, same position, same model state
落到真实推理系统里,还要继续检查模型权重版本、tokenizer、RoPE position、LoRA adapter、量化配置等是否一致。温度、top_p 这类采样参数主要影响 decode 时怎么选 token,一般不影响 prompt prefill 的 K/V;但不同 adapter 或不同 system prompt 会直接改变 K/V,自然也就不能复用。
6. 工程实现里为什么不会一直 torch.cat
前面的教学代码里用了:
ini
k_all = torch.cat([past_k, k], dim=2)
这个写法很好理解,但不适合高性能推理。因为每生成一个 token 都 torch.cat,意味着不断申请新显存、复制旧 cache,长上下文或高并发下成本会很明显。
更接近工程实现的方式是预分配 cache,然后按位置写入:
ini
# k_cache: [B, H, max_seq_len, D]
# v_cache: [B, H, max_seq_len, D]
def append_kv_cache(k_cache, v_cache, cache_pos, k_new, v_new):
T_new = k_new.shape[2]
k_cache[:, :, cache_pos:cache_pos + T_new, :] = k_new
v_cache[:, :, cache_pos:cache_pos + T_new, :] = v_new
return cache_pos + T_new
更复杂的推理框架会把 KV cache 做成 block/page 管理。比如每 16 或 32 个 token 一块:
makefile
block_0: token 0-15
block_1: token 16-31
block_2: token 32-47
这样 prefix cache 命中时,就可以按 block 复用,而不是粗暴地管理一整段连续 tensor。vLLM 的 PagedAttention、SGLang、TensorRT-LLM、TGI 这类推理系统,都会围绕 KV cache 做大量优化,包括分页管理、prefix sharing、cache eviction、KV offload、跨请求调度等。
7. 小结
KV cache 是 decoder self-attention 的底层缓存,保存的是每一层已经算好的 Key/Value 张量。它让 decode 阶段每次只计算新 token 的 Q/K/V,而不是反复重算完整上下文。
缓存命中是推理服务层面的判断:当前请求的某段 token 前缀,是否已经有可复用的 KV cache。如果命中,就可以跳过命中前缀的 prefill 计算。
Prefix cache 则是 KV cache 的一种工程化复用策略,专门优化多个请求共享相同前缀的场景。它复用的是"相同 token 前缀在每个 decoder layer 中的 K/V",不是原始文本,也不是任意中间片段。
如果从源码视角收束一下,可以这么理解:
swift
KV cache 关心的是 past_key_values 怎么传;
prefix cache 关心的是新请求进来时,past_key_values 能不能从已有缓存里拿。