从 PyTorch Attention 源码理解 KV Cache、缓存命中与 Prefix Cache

做大模型推理时,我们经常会碰到几个看起来很像的词:KV cache、cache hit、prefix cache、prompt cache。它们都和"缓存"有关,但并不在同一个抽象层次上。

可以先用一句话建立直觉:KV cache 是 Transformer decoder self-attention 里缓存下来的 Key/Value 张量;缓存命中是推理系统发现某段已有 K/V 可以复用;prefix cache 则是把"相同 token 前缀对应的 KV cache"跨请求复用起来的一套工程机制。

如果你熟悉一点 PyTorch,理解过 attention 的基本公式,那么 KV cache 并不神秘。它本质上就是把原本会在每一步重复计算的中间张量保存下来,下一步继续用。

1. 先看普通 causal self-attention

Transformer attention 的核心公式是:

scss 复制代码
Attention(Q, K, V) = softmax(QK^T / sqrt(d)) V

写成一个最小版 PyTorch causal self-attention,大概是这样:

ini 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
​
class CausalSelfAttention(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
​
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)
​
    def forward(self, x):
        # x: [B, T, C]
        B, T, C = x.shape
​
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
​
        # [B, T, C] -> [B, H, T, D]
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
​
        scores = q @ k.transpose(-2, -1)  # [B, H, T, T]
        scores = scores / (self.head_dim ** 0.5)
​
        mask = torch.tril(torch.ones(T, T, device=x.device)).bool()
        scores = scores.masked_fill(~mask, float("-inf"))
​
        probs = F.softmax(scores, dim=-1)
        out = probs @ v  # [B, H, T, D]
​
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(out)

这里最重要的是三次线性投影:

ini 复制代码
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)

如果是在 prefill 阶段处理完整 prompt,比如输入长度 T=1000,模型会一次性算出所有 token 的 Q/K/V,然后做一个 [T, T] 的 causal attention。这很自然,也没什么浪费。

麻烦发生在自回归生成阶段。生成文本时,模型每次只往后追加一个 token。假设上下文里已经有 1000 个 token,现在要生成第 1001 个 token。新 token 的 query 需要和前面 1000 个 token,以及它自己,对应的 key/value 做 attention。如果每一步都把完整上下文重新喂进模型,就会一遍又一遍地重算历史 token 的 K/V、attention 和 MLP,成本非常高。

KV cache 要解决的就是这个重复计算问题。

2. KV cache 到底缓存了什么

KV cache 缓存的不是原始 token,不是 embedding,也不是 logits,而是每一层 self-attention 中已经投影好的 Key 和 Value。

对一个 decoder-only LLM 来说,每一层都有自己的 KV cache。结构上通常类似这样:

ini 复制代码
past_key_values = [
    (k_layer_0, v_layer_0),
    (k_layer_1, v_layer_1),
    # ...
    (k_layer_n, v_layer_n),
]

其中每个 k_layer_iv_layer_i 的典型 shape 是:

csharp 复制代码
[B, num_kv_heads, seq_len, head_dim]

如果是普通 multi-head attention,num_kv_heads 通常等于 num_heads。如果模型用了 MQA 或 GQA,K/V head 数量会少于 Q head 数量。比如 Q 有 32 个 heads,而 K/V 只有 8 个 heads。这样做的一个直接收益就是降低 KV cache 的显存占用。

把上面的 attention 改成带 cache 的版本,可以写成下面这样:

ini 复制代码
class CachedCausalSelfAttention(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
​
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)
​
    def forward(self, x, past_k=None, past_v=None):
        # x: [B, T_new, C]
        # decode 阶段 T_new 通常是 1
        B, T_new, C = x.shape
​
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
​
        q = q.view(B, T_new, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T_new, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T_new, self.num_heads, self.head_dim).transpose(1, 2)
​
        if past_k is not None:
            # past_k: [B, H, T_past, D]
            # k:      [B, H, T_new,  D]
            k_all = torch.cat([past_k, k], dim=2)
            v_all = torch.cat([past_v, v], dim=2)
        else:
            k_all = k
            v_all = v
​
        scores = q @ k_all.transpose(-2, -1)  # [B, H, T_new, T_total]
        scores = scores / (self.head_dim ** 0.5)
​
        # decode 时如果 T_new=1,k_all 只包含历史+当前,通常不需要完整 causal mask
        probs = F.softmax(scores, dim=-1)
        out = probs @ v_all  # [B, H, T_new, D]
​
        out = out.transpose(1, 2).contiguous().view(B, T_new, C)
        out = self.o_proj(out)
​
        return out, k_all, v_all

关键逻辑只有几行:

ini 复制代码
k_all = torch.cat([past_k, k], dim=2)
v_all = torch.cat([past_v, v], dim=2)

当前 token 只需要计算自己的 Q/K/V。它的 K/V 会追加到历史 cache 后面;它的 Q 则拿来和完整的历史 K/V 做 attention。

这里也能解释一个常见问题:为什么叫 KV cache,而不是 QKV cache?原因是历史 token 的 query 对未来没什么用。历史 token 的输出已经算完了,未来 token 不需要再用历史 query。未来 token 真正需要的是"历史 token 可以被我 attend 到",也就是历史 token 的 K 和 V。

3. Prefill 和 Decode:KV cache 在哪里发挥作用

LLM 推理通常可以拆成两个阶段:prefill 和 decode。

Prefill 阶段处理完整 prompt:

ini 复制代码
outputs = model(input_ids=prompt_ids, use_cache=True)
past_key_values = outputs.past_key_values

如果 prompt 长度是 1000,那么每一层会生成类似这样的 cache:

yaml 复制代码
k: [B, H_kv, 1000, D]
v: [B, H_kv, 1000, D]

Decode 阶段开始后,每次只输入最新的一个 token:

ini 复制代码
input_ids = next_token  # [B, 1]
outputs = model(
    input_ids=input_ids,
    past_key_values=past_key_values,
    use_cache=True,
)
past_key_values = outputs.past_key_values

一个极简 generation loop 可以写成这样:

ini 复制代码
past_key_values = None
input_ids = prompt_ids

for _ in range(max_new_tokens):
    outputs = model(
        input_ids=input_ids,
        past_key_values=past_key_values,
        use_cache=True,
    )

    logits = outputs.logits
    past_key_values = outputs.past_key_values

    next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)

    # 关键点:下一轮只喂最新 token,而不是完整上下文
    input_ids = next_token

没有 KV cache 时,第 t 步生成可能要重新跑长度为 t 的上下文,历史部分会被反复计算。有 KV cache 后,每一步只需要为新 token 计算一次 Q/K/V,并复用之前保存下来的 K/V。

不过,KV cache 并不意味着 decode 变成了常数时间。生成新 token 时仍然要做这一步:

ini 复制代码
scores = q_new @ k_cache.transpose(-2, -1)

这里的 k_cache 会随着上下文变长而变长,所以单步 attention 的计算量仍然和上下文长度相关。KV cache 省掉的是"重复计算历史 token 的中间状态",不是把长上下文 attention 本身消掉。

4. 缓存命中和 KV cache 的关系

说到"缓存命中",其实已经不是 attention 算子内部的概念了,而是推理服务层面的概念。

在单个请求内部,decode 每一步都会复用 past_key_values。这当然也是一种复用,但一般不会特别叫 cache hit,因为这是同一个请求的正常状态传递。

工程上更常说的 cache hit,通常指跨请求或跨轮次的 prefix KV cache hit。比如请求 A 是:

makefile 复制代码
System: 你是一个代码助手...
User: 请解释 KV cache

请求 B 是:

makefile 复制代码
System: 你是一个代码助手...
User: 请解释 attention

如果两者的 system prompt 完全相同,并且 tokenization 后的前缀 token 序列也完全一致,那么推理服务可以直接复用 system prompt 对应的 KV cache。请求 B 不需要重新 prefill 那段 system prompt,只需要从没有命中的 suffix 开始继续计算。

概念上可以这么写:

ini 复制代码
cached_past, hit_len = prefix_cache.lookup(input_ids)

if cached_past is None:
    outputs = model(input_ids=input_ids, use_cache=True)
else:
    remaining_ids = input_ids[:, hit_len:]
    outputs = model(
        input_ids=remaining_ids,
        past_key_values=cached_past,
        use_cache=True,
    )

所以这几个词的关系可以理解为:

ini 复制代码
KV cache = 实际被缓存的每层 K/V 张量
cache hit = 当前请求能复用已有 KV cache
cache miss = 当前请求需要重新 prefill 计算 K/V

需要注意的是,命中的对象通常不是自然语言字符串,而是 token 序列。两个 prompt 看起来只差一个空格、换行或标点,但 token 序列可能已经不同,最终能命中的前缀长度也可能变短。

5. Prefix cache:把相同前缀的 KV cache 复用起来

Prefix cache 可以理解成"跨请求复用 decoder KV cache 的机制"。它不是一种新的 attention,也不是另一套模型缓存,而是 KV cache 上面的一层工程化管理策略。

更准确地说,prefix cache 是针对自回归 decoder 的 KV cache 复用策略:当一个新请求的 token 前缀和之前算过的前缀一致时,直接复用这个前缀在各个 decoder layer 里的 self-attention K/V,从而减少 prefill 计算。

这里有一个很重要的限制:必须是前缀。

假设系统之前缓存过:

css 复制代码
A B C D E F

现在来了一个新请求:

r 复制代码
X Y C D E F

虽然 C D E F 这段内容一样,但它不是同一个前缀,通常不能直接复用。原因在于 decoder self-attention 中每个 token 的 hidden state 都依赖它前面的上下文。之前 C 前面是 A B,现在 C 前面是 X Y,那么 C 在各层里算出来的 K/V 也可能不同。

所以 prefix cache 的前提可以概括成:

arduino 复制代码
same prefix, same position, same model state

落到真实推理系统里,还要继续检查模型权重版本、tokenizer、RoPE position、LoRA adapter、量化配置等是否一致。温度、top_p 这类采样参数主要影响 decode 时怎么选 token,一般不影响 prompt prefill 的 K/V;但不同 adapter 或不同 system prompt 会直接改变 K/V,自然也就不能复用。

6. 工程实现里为什么不会一直 torch.cat

前面的教学代码里用了:

ini 复制代码
k_all = torch.cat([past_k, k], dim=2)

这个写法很好理解,但不适合高性能推理。因为每生成一个 token 都 torch.cat,意味着不断申请新显存、复制旧 cache,长上下文或高并发下成本会很明显。

更接近工程实现的方式是预分配 cache,然后按位置写入:

ini 复制代码
# k_cache: [B, H, max_seq_len, D]
# v_cache: [B, H, max_seq_len, D]

def append_kv_cache(k_cache, v_cache, cache_pos, k_new, v_new):
    T_new = k_new.shape[2]
    k_cache[:, :, cache_pos:cache_pos + T_new, :] = k_new
    v_cache[:, :, cache_pos:cache_pos + T_new, :] = v_new
    return cache_pos + T_new

更复杂的推理框架会把 KV cache 做成 block/page 管理。比如每 16 或 32 个 token 一块:

makefile 复制代码
block_0: token 0-15
block_1: token 16-31
block_2: token 32-47

这样 prefix cache 命中时,就可以按 block 复用,而不是粗暴地管理一整段连续 tensor。vLLM 的 PagedAttention、SGLang、TensorRT-LLM、TGI 这类推理系统,都会围绕 KV cache 做大量优化,包括分页管理、prefix sharing、cache eviction、KV offload、跨请求调度等。

7. 小结

KV cache 是 decoder self-attention 的底层缓存,保存的是每一层已经算好的 Key/Value 张量。它让 decode 阶段每次只计算新 token 的 Q/K/V,而不是反复重算完整上下文。

缓存命中是推理服务层面的判断:当前请求的某段 token 前缀,是否已经有可复用的 KV cache。如果命中,就可以跳过命中前缀的 prefill 计算。

Prefix cache 则是 KV cache 的一种工程化复用策略,专门优化多个请求共享相同前缀的场景。它复用的是"相同 token 前缀在每个 decoder layer 中的 K/V",不是原始文本,也不是任意中间片段。

如果从源码视角收束一下,可以这么理解:

swift 复制代码
KV cache 关心的是 past_key_values 怎么传;
prefix cache 关心的是新请求进来时,past_key_values 能不能从已有缓存里拿。
相关推荐
IT_陈寒2 小时前
React状态更新总是不及时?你可能漏了这步批处理机制
前端·人工智能·后端
Jinkey2 小时前
要用户手机号真的是为了打骚扰电话吗?浅谈微信生态会员账号体系与资产合并
后端·微信·微信小程序
葫芦和十三3 小时前
图解 MongoDB 06|模式演进:无 schema 是优势还是债
后端·mongodb·agent
葫芦和十三10 小时前
图解 MongoDB 05|文档模型设计:内嵌 vs 引用,反范式不是免费午餐
后端·mongodb·agent
不能放弃治疗14 小时前
单 Agent 实现模式
后端
IT_陈寒16 小时前
Redis内存爆了,原来我漏掉了这个致命配置
前端·人工智能·后端
fliter16 小时前
最后一块拼图:用 bitvec 构造 IPv4 包,真正做出自己的 Ping
后端
fliter17 小时前
用 Rust 解析并生成 ICMP 包:checksum、nom 与 cookie-factory
后端
蝎子莱莱爱打怪18 小时前
XZLL-IM干货系列 03|消息 ID 设计:一个 UUID 搞不定的事,我用两个 ID 解决了
后端·面试·开源