KV Cache(键值缓存)技术

在自回归模型(如Transformer解码器)中,生成文本时是逐个token进行的。每次生成新token时,注意力机制需要计算当前token与之前所有token之间的关系,这涉及大量的矩阵运算。由于计算复杂度随序列长度增长而急剧上升(例如,预测第1001个token时需处理1000×1000的QK矩阵),效率会显著下降。

为提升推理速度,引入了KV Cache(键值缓存)技术:

  • Q 是当前时间步的输入,每次生成新token时都会变化,因此无需缓存。
  • K(Key)和V(Value) 来自之前的token,在后续步骤中可以复用,因此将其缓存起来,避免重复计算。
  • 由于decoder具有因果性(causal attention),每个token只关注前面的token,所以只需重新计算新token的Q,以及它对已有K、V的注意力权重。
  • 通过缓存已计算过的K和V,可以大幅减少重复计算,从而加速推理过程。

简而言之:KV Cache是一种"用内存换取速度"的优化技巧,通过存储历史K和V的计算结果,避免在每一步都重新计算整个注意力矩阵,显著提升了自回归模型的推理效率。


核心要点:
只缓存K和V,不缓存Q → 因为Q是动态输入,K/V可复用 → 提升效率 → KV Cache = 内存换速度。

实现要点(PyTorch 伪代码)

复制代码
class MultiHeadAttentionWithCache(nn.Module):
    def __init__(self):
        self.W_q = ...
        self.W_k = ...
        self.W_v = ...
        self.register_buffer("cache_k", None, persistent=False)
        self.register_buffer("cache_v", None, persistent=False)

    def forward(self, x, use_cache=False):
        Q = self.W_q(x)                     # (B, 1, D) ------ 生成阶段通常 seq_len=1
        K_new = self.W_k(x)                 # (B, 1, D)
        V_new = self.W_v(x)

        if use_cache:
            if self.cache_k is None:
                self.cache_k = K_new
                self.cache_v = V_new
            else:
                self.cache_k = torch.cat([self.cache_k, K_new], dim=1)
                self.cache_v = torch.cat([self.cache_v, V_new], dim=1)
            K, V = self.cache_k, self.cache_v
        else:
            K, V = K_new, V_new

        attn = Q @ K.transpose(-2, -1) / sqrt(d)
        attn = mask_and_softmax(attn)       # causal mask
        output = attn @ V
        return output

Hugging Face Transformers 中,KV Cache(也称为 past key values)是默认启用的 ,只要使用的是支持它的生成方法(如 .generate()),并且模型是自回归解码器(如 GPT-2、LLaMA、Qwen、Bloom 等)。

下面将从 代码示例 + 原理解释 两个层面,详细说明如何显式启用和查看 KV Cache


1. 基础用法:自动启用 KV Cache(推荐)

复制代码
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "Qwen/Qwen2-0.5B"  # 或 gpt2, meta-llama/Llama-3-8b 等
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

prompt = "今天天气很"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# 使用 generate() ------ 自动启用 KV Cache
outputs = model.generate(
    **inputs,
    max_new_tokens=10,
    do_sample=False,
    use_cache=True  # 默认就是 True!
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

关键点

  • use_cache=True.generate() 的默认参数。
  • 模型内部会自动管理 past_key_values,你无需手动处理。
  • 这就是为什么生成速度比逐 token 调用快很多!

2. 手动控制 KV Cache(用于研究或自定义生成)

如果想一步步生成显式操作 KV Cache,可以这样做:

复制代码
import torch

# 初始化
input_ids = tokenizer("今天天气很", return_tensors="pt").input_ids.to(model.device)
past_key_values = None  # 初始为空

generated = input_ids.clone()

for step in range(5):  # 生成5个新token
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            use_cache=True
        )
    
    # 获取 logits 并采样下一个 token
    next_token_logits = outputs.logits[:, -1, :]
    next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
    
    # 更新
    generated = torch.cat([generated, next_token], dim=1)
    input_ids = next_token  # 下一步只输入新 token!
    past_key_values = outputs.past_key_values  # 关键:传递 KV Cache

print(tokenizer.decode(generated[0], skip_special_tokens=True))

重点解释:

  • 第一次调用input_ids = [今, 天, 天, 气, 很]past_key_values=None
    • 模型计算全部 K/V,并返回 past_key_values(包含5个 token 的 K/V)
  • 第二次调用input_ids = [好](仅新 token),past_key_values=上一步结果
    • 模型只计算 [好] 的 Q,K/V 从 cache 读取历史 + 自己
    • 返回新的 past_key_values(长度6)
  • 如此循环,每步只处理一个 token,但利用缓存避免重复计算。

3. 查看 KV Cache 的结构

可以打印 past_key_values 来理解其格式:

复制代码
print(type(outputs.past_key_values))  # tuple
print(len(outputs.past_key_values))   # = num_layers (e.g., 24 for LLaMA-7B)

# 每一层是一个 (key, value) 元组
layer_0_key, layer_0_value = outputs.past_key_values[0]
print(layer_0_key.shape)   # (batch_size, num_heads, seq_len, head_dim)
print(layer_0_value.shape) # same

例如,对于 Qwen2-0.5B(层数 24):

  • past_key_values 是长度为 24 的 tuple
  • 每个元素是 (key_tensor, value_tensor)
  • shape: (1, 14, 5, 64) → batch=1, heads=14, seq_len=5, head_dim=64

注意:不同模型的 KV 形状可能不同(比如是否转置、是否合并层等),但逻辑一致。


4. 禁用 KV Cache(对比实验)

复制代码
# 不使用 cache:每次都要重新 encode 整个序列!
outputs_slow = model.generate(
    **inputs,
    max_new_tokens=10,
    use_cache=False  # 强制禁用
)

会发现:

  • 速度明显变慢
  • 显存波动更大(因为每步都重新计算整个 attention

5. 注意事项

项目 说明
仅适用于 decoder-only 模型 如 GPT、LLaMA、Qwen。Encoder-decoder(如 T5)也有类似机制,但更复杂。
必须开启 use_cache=True model.generate()model.forward()
输入必须是增量的 在手动循环中,后续 input_ids 只能是新 token,不能重复传整个序列!
位置编码要正确 模型内部会根据 past_key_values 的长度自动计算当前位置(无需处理)

总结

场景 是否需要手动管 KV Cache?
正常生成文本 不需要,.generate() 自动处理
自定义解码策略(如 beam search 修改版) 需要显式传递 past_key_values
研究 attention 行为 可打印 past_key_values 分析
相关推荐
linmoo19864 小时前
Langchain4j 系列之十一 - 工具调用(AI Services)
人工智能·langchain·工具·langchain4j·toolcall·tool calling
victory04316 小时前
llama2 MLP 门控FFN
深度学习·transformer
laplace01237 小时前
Part3 RAG文档切分
笔记·python·中间件·langchain·rag
至此流年莫相忘7 小时前
LangGraph之条件边
langchain
paopao_wu8 小时前
LangChainV1.0[05]-记忆管理
人工智能·python·langchain·ai编程
Hcoco_me9 小时前
大模型面试题46:在训练7B LLM时,如果使用AdamW优化器,那么它需要的峰值显存是多少?
开发语言·人工智能·深度学习·transformer·word2vec
雍凉明月夜10 小时前
深度学习网络笔记Ⅴ(Transformer源码详解)
笔记·深度学习·transformer
小五Z10 小时前
LangChain框架--LLM接入方式
ai·langchain
爱吃泡芙的小白白11 小时前
Agent学习——并行化模式
学习·langchain·agent·google adk
菜鸟冲锋号11 小时前
从零搭建高可用GraphRAG系统:LangChain+Neo4j+FAISS+Qwen-7B实战指南
langchain·neo4j·faiss