在自回归模型(如Transformer解码器)中,生成文本时是逐个token进行的。每次生成新token时,注意力机制需要计算当前token与之前所有token之间的关系,这涉及大量的矩阵运算。由于计算复杂度随序列长度增长而急剧上升(例如,预测第1001个token时需处理1000×1000的QK矩阵),效率会显著下降。
为提升推理速度,引入了KV Cache(键值缓存)技术:
- Q 是当前时间步的输入,每次生成新token时都会变化,因此无需缓存。
- K(Key)和V(Value) 来自之前的token,在后续步骤中可以复用,因此将其缓存起来,避免重复计算。
- 由于decoder具有因果性(causal attention),每个token只关注前面的token,所以只需重新计算新token的Q,以及它对已有K、V的注意力权重。
- 通过缓存已计算过的K和V,可以大幅减少重复计算,从而加速推理过程。
简而言之:KV Cache是一种"用内存换取速度"的优化技巧,通过存储历史K和V的计算结果,避免在每一步都重新计算整个注意力矩阵,显著提升了自回归模型的推理效率。
核心要点:
只缓存K和V,不缓存Q → 因为Q是动态输入,K/V可复用 → 提升效率 → KV Cache = 内存换速度。
实现要点(PyTorch 伪代码)
class MultiHeadAttentionWithCache(nn.Module):
def __init__(self):
self.W_q = ...
self.W_k = ...
self.W_v = ...
self.register_buffer("cache_k", None, persistent=False)
self.register_buffer("cache_v", None, persistent=False)
def forward(self, x, use_cache=False):
Q = self.W_q(x) # (B, 1, D) ------ 生成阶段通常 seq_len=1
K_new = self.W_k(x) # (B, 1, D)
V_new = self.W_v(x)
if use_cache:
if self.cache_k is None:
self.cache_k = K_new
self.cache_v = V_new
else:
self.cache_k = torch.cat([self.cache_k, K_new], dim=1)
self.cache_v = torch.cat([self.cache_v, V_new], dim=1)
K, V = self.cache_k, self.cache_v
else:
K, V = K_new, V_new
attn = Q @ K.transpose(-2, -1) / sqrt(d)
attn = mask_and_softmax(attn) # causal mask
output = attn @ V
return output
在 Hugging Face Transformers 中,KV Cache(也称为 past key values)是默认启用的 ,只要使用的是支持它的生成方法(如 .generate()),并且模型是自回归解码器(如 GPT-2、LLaMA、Qwen、Bloom 等)。
下面将从 代码示例 + 原理解释 两个层面,详细说明如何显式启用和查看 KV Cache。
1. 基础用法:自动启用 KV Cache(推荐)
from transformers import AutoTokenizer, AutoModelForCausalLM
model_name = "Qwen/Qwen2-0.5B" # 或 gpt2, meta-llama/Llama-3-8b 等
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
prompt = "今天天气很"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# 使用 generate() ------ 自动启用 KV Cache
outputs = model.generate(
**inputs,
max_new_tokens=10,
do_sample=False,
use_cache=True # 默认就是 True!
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
关键点:
use_cache=True是.generate()的默认参数。- 模型内部会自动管理
past_key_values,你无需手动处理。- 这就是为什么生成速度比逐 token 调用快很多!
2. 手动控制 KV Cache(用于研究或自定义生成)
如果想一步步生成 并显式操作 KV Cache,可以这样做:
import torch
# 初始化
input_ids = tokenizer("今天天气很", return_tensors="pt").input_ids.to(model.device)
past_key_values = None # 初始为空
generated = input_ids.clone()
for step in range(5): # 生成5个新token
with torch.no_grad():
outputs = model(
input_ids=input_ids,
past_key_values=past_key_values,
use_cache=True
)
# 获取 logits 并采样下一个 token
next_token_logits = outputs.logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
# 更新
generated = torch.cat([generated, next_token], dim=1)
input_ids = next_token # 下一步只输入新 token!
past_key_values = outputs.past_key_values # 关键:传递 KV Cache
print(tokenizer.decode(generated[0], skip_special_tokens=True))
重点解释:
- 第一次调用 :
input_ids = [今, 天, 天, 气, 很],past_key_values=None- 模型计算全部 K/V,并返回
past_key_values(包含5个 token 的 K/V)
- 模型计算全部 K/V,并返回
- 第二次调用 :
input_ids = [好](仅新 token),past_key_values=上一步结果- 模型只计算
[好]的 Q,K/V 从 cache 读取历史 + 自己 - 返回新的
past_key_values(长度6)
- 模型只计算
- 如此循环,每步只处理一个 token,但利用缓存避免重复计算。
3. 查看 KV Cache 的结构
可以打印 past_key_values 来理解其格式:
print(type(outputs.past_key_values)) # tuple
print(len(outputs.past_key_values)) # = num_layers (e.g., 24 for LLaMA-7B)
# 每一层是一个 (key, value) 元组
layer_0_key, layer_0_value = outputs.past_key_values[0]
print(layer_0_key.shape) # (batch_size, num_heads, seq_len, head_dim)
print(layer_0_value.shape) # same
例如,对于 Qwen2-0.5B(层数 24):
past_key_values是长度为 24 的 tuple- 每个元素是
(key_tensor, value_tensor) - shape:
(1, 14, 5, 64)→ batch=1, heads=14, seq_len=5, head_dim=64
注意:不同模型的 KV 形状可能不同(比如是否转置、是否合并层等),但逻辑一致。
4. 禁用 KV Cache(对比实验)
# 不使用 cache:每次都要重新 encode 整个序列!
outputs_slow = model.generate(
**inputs,
max_new_tokens=10,
use_cache=False # 强制禁用
)
会发现:
- 速度明显变慢
- 显存波动更大(因为每步都重新计算整个 attention
5. 注意事项
| 项目 | 说明 |
|---|---|
| 仅适用于 decoder-only 模型 | 如 GPT、LLaMA、Qwen。Encoder-decoder(如 T5)也有类似机制,但更复杂。 |
必须开启 use_cache=True |
在 model.generate() 或 model.forward() 中 |
| 输入必须是增量的 | 在手动循环中,后续 input_ids 只能是新 token,不能重复传整个序列! |
| 位置编码要正确 | 模型内部会根据 past_key_values 的长度自动计算当前位置(无需处理) |
总结
| 场景 | 是否需要手动管 KV Cache? |
|---|---|
| 正常生成文本 | 不需要,.generate() 自动处理 |
| 自定义解码策略(如 beam search 修改版) | 需要显式传递 past_key_values |
| 研究 attention 行为 | 可打印 past_key_values 分析 |