KV Cache: 大模型自回归生成的关键优化技术
KV Cache(键值缓存)是 Transformer 解码器在自回归推理阶段普遍采用的一种工程优化。它通过缓存每一层注意力计算中历史 token 的 Key 和 Value 矩阵,避免每一步都重复计算,显著降低推理延迟,是 LLM 高性能服务的基础组件。
1. 背景:自回归生成中的重复计算问题
大型语言模型(LLM)在生成文本时采用自回归方式:逐个生成 token,并将新 token 拼接到输入序列末尾,作为下一步的输入。
假设已经生成了 t − 1 t-1 t−1 个 token,当前输入序列长度为 t − 1 t-1 t−1,下一步需要生成第 t t t 个 token。在未使用 KV Cache 的情况下,每一步都要重新计算整个序列 [ x 1 , x 2 , . . . , x t − 1 , x t ] [x_1, x_2, ..., x_{t-1}, x_t] [x1,x2,...,xt−1,xt] 的所有注意力 K 和 V。而实际上,前 t − 1 t-1 t−1 个 token 的 K 和 V 在前几步已经计算过,完全可以直接复用。
不加缓存时,第 t t t 步的计算复杂度为 O ( t 2 ) \mathcal{O}(t^2) O(t2)(因为需要计算 t × t t \times t t×t 的注意力矩阵)。随着序列长度 t t t 增大,计算开销呈平方增长,使推理变得极其缓慢。
2. KV Cache 的核心思想
只计算当前新 token 的 Query(Q)、Key(K)、Value(V),并复用之前所有 token 已缓存的 K 和 V。
- 缓存内容:每一层 Transformer 中,所有历史 token 的 Key 矩阵和 Value 矩阵。
- 更新方式:每生成一个新 token,计算其 K、V 并追加到缓存中。
- 注意力计算:使用当前 token 的 Q 与缓存中所有历史 K 计算注意力分数,再用该分数加权缓存中所有历史 V。
这样,第 t t t 步的计算复杂度从 O ( t 2 ) \mathcal{O}(t^2) O(t2) 降为 O ( t ) \mathcal{O}(t) O(t),因为只需要对新 token 计算一次 Q,并与长度为 t t t 的 K 序列做点积(即 Q K T Q K^T QKT 的其中一行)。
3. 工作机制与流程
以单层自注意力为例。假设已经生成了 t − 1 t-1 t−1 个 token,缓存了:
- K cache = [ k 1 , k 2 , ... , k t − 1 ] K_{\text{cache}} = [k_1, k_2, \dots, k_{t-1}] Kcache=[k1,k2,...,kt−1]
- V cache = [ v 1 , v 2 , ... , v t − 1 ] V_{\text{cache}} = [v_1, v_2, \dots, v_{t-1}] Vcache=[v1,v2,...,vt−1]
第 t t t 步生成新 token x t x_t xt 的过程:
- 将 x t x_t xt 通过 embedding 层和位置编码得到向量表示。
- 通过线性变换计算该 token 的查询 q t q_t qt、键 k t k_t kt、值 v t v_t vt:
q t = x t W Q , k t = x t W K , v t = x t W V q_t = x_t W_Q,\quad k_t = x_t W_K,\quad v_t = x_t W_V qt=xtWQ,kt=xtWK,vt=xtWV - 更新缓存:
K cache ← concat ( K cache , k t ) K_{\text{cache}} \leftarrow \text{concat}(K_{\text{cache}}, k_t) Kcache←concat(Kcache,kt)
V cache ← concat ( V cache , v t ) V_{\text{cache}} \leftarrow \text{concat}(V_{\text{cache}}, v_t) Vcache←concat(Vcache,vt) - 计算当前 token 的注意力输出:
- 注意力分数 s t = q t ⋅ [ k 1 , k 2 , ... , k t ] T s_t = q_t \cdot [k_1, k_2, \dots, k_t]^T st=qt⋅[k1,k2,...,kt]T(长度为 t t t 的行向量)。
- 权重 a t = softmax ( s t / d k ) a_t = \text{softmax}(s_t / \sqrt{d_k}) at=softmax(st/dk ),其中 d k d_k dk 是 head 维度。
- 输出 o t = a t ⋅ V cache o_t = a_t \cdot V_{\text{cache}} ot=at⋅Vcache。
- 经过 FFN 等后续层,得到输出 logits,采样得到下一个 token x t + 1 x_{t+1} xt+1。
每一步只计算一个 q t q_t qt,因此注意力计算的复杂度为 O ( t ) \mathcal{O}(t) O(t)。注意:第一个 token(prefill 阶段)仍需计算完整序列,这一步无法加速。
4. 为什么可以安全地缓存 K 和 V?
在 Transformer Decoder 中,通过因果注意力掩码 (causal attention mask)确保每个 token 只能看到它之前(包括自己)的 token,不能看到未来 token。因此,历史 token 的 K 和 V 不会依赖任何未来的信息,它们是静态的。无论后续生成多少 token,已有 token 的 K、V 都保持不变,这为缓存提供了理论依据。
5. 内存开销分析
KV Cache 消耗的显存与 batch_size 、seq_length 、层数 、注意力头数等因素成正比。具体公式如下:
Memory = batch_size × seq_len × num_layers × num_heads × head_dim × 2 × bytes_per_elem \text{Memory} = \text{batch\_size} \times \text{seq\_len} \times \text{num\_layers} \times \text{num\_heads} \times \text{head\_dim} \times 2 \times \text{bytes\_per\_elem} Memory=batch_size×seq_len×num_layers×num_heads×head_dim×2×bytes_per_elem
其中 2 2 2 表示同时缓存 K 和 V。
以 LLaMA 7B 为例:
num_layers = 32,num_heads = 32,head_dim = 128- 每 token 每层的 K+V 参数量 = 2 × 32 × 128 = 8192 2 \times 32 \times 128 = 8192 2×32×128=8192
- 使用 FP16(2 bytes/param)时,每 token 每层占用 8192 × 2 = 16 KB 8192 \times 2 = 16\ \text{KB} 8192×2=16 KB
- 32 层共 32 × 16 KB = 512 KB 32 \times 16\ \text{KB} = 512\ \text{KB} 32×16 KB=512 KB 每 token
当 batch_size = 64,seq_len = 2048 时:
Memory = 64 × 2048 × 512 KB = 64 GB \text{Memory} = 64 \times 2048 \times 512\ \text{KB} = 64\ \text{GB} Memory=64×2048×512 KB=64 GB
这远超模型参数本身的内存(7B FP16 约 14 GB)。因此,KV Cache 是 LLM 推理时的显存瓶颈,尤其在长上下文和大 batch 场景下。
6. 实现细节与代码示例
在工程实现中,KV Cache 通常以 past_key_values 的形式在各个解码层之间传递。以下是一个极简的 PyTorch 风格伪代码:
python
class DecoderLayer:
def __init__(self, d_model, n_head):
self.Wq = nn.Linear(d_model, d_model)
self.Wk = nn.Linear(d_model, d_model)
self.Wv = nn.Linear(d_model, d_model)
# ... 其他组件
def forward(self, x, past_kv=None):
q = self.Wq(x) # [seq_len, d_model]
k = self.Wk(x)
v = self.Wv(x)
if past_kv is not None:
past_k, past_v = past_kv
k = torch.cat([past_k, k], dim=0) # 沿序列长度拼接
v = torch.cat([past_v, v], dim=0)
# 注意力计算(带因果掩码)
attn_out = scaled_dot_product_attention(q, k, v, mask=causal_mask)
return attn_out, (k, v)
# 推理循环
past_kvs = [None] * num_layers # 每层单独缓存
for _ in range(max_new_tokens):
for i, layer in enumerate(layers):
x, past_kvs[i] = layer(x, past_kv=past_kvs[i])
# 采样下一个 token
next_token = sample(x[-1])
x = next_token.unsqueeze(0) # 下一步仅输入新 token
在实际框架(如 Hugging Face Transformers)中,past_key_values 是一个长度为 num_layers 的 tuple,每个元素是 (K, V) 两个张量,形状为 [batch_size, num_heads, past_seq_len, head_dim]。
7. 优化变体:降低 KV Cache 内存
为了缓解显存压力,工业界和学术界提出了多种优化方案。
7.1 Multi-Query Attention (MQA)
所有 Q 头共享同一组 K 和 V 。因此 KV Cache 大小降为原来的 1 num_heads \frac{1}{\text{num\_heads}} num_heads1。例如 PaLM、Falcon 等模型采用 MQA,推理速度提升明显,但可能带来轻微的质量损失。
7.2 Grouped-Query Attention (GQA)
将 Q 头分组,每组共享一个 K、V 头。这是 MQA 和标准多头注意力的折中方案。例如 LLaMA 2/3 中,7B 模型使用 32 个 Q 头、8 个 KV 头,缓存大小降至 1 4 \frac{1}{4} 41。Gemma、DeepSeek 等模型也广泛采用 GQA。
7.3 PagedAttention
将 KV Cache 分割成固定大小的"页",以非连续方式存储。这可以消除显存碎片,并支持变长序列的高效批处理。由 vLLM 项目提出,目前是生产环境推理的标准技术之一。
7.4 量化缓存
对 KV Cache 使用 INT8 甚至 INT4 量化。例如 KIVI、KVQuant 等工作表明,适度量化(每 channel 或 per-token)可以在保持生成质量的同时减少 50%~75% 的缓存占用。
7.5 滑动窗口与动态缓存
只保留最近 w w w 个 token 的 K、V,丢弃更早的信息。适用于局部依赖较强的任务(如代码补全、流式生成)。Longformer、BigBird 等模型原生支持窗口注意力,但会丧失长距离建模能力。
8. 局限性
- 显存消耗:长序列(如 100k tokens)即使单 batch 推理,KV Cache 也可能填满多张 80GB A100。
- 批处理不友好:不同请求的序列长度差异导致缓存形状不规则,需借助 PagedAttention 等技术。
- 首 token 延迟:第一个 token(prefill)仍需计算完整序列的 K、V,无法加速。
9. 总结
| 方面 | 说明 |
|---|---|
| 作用 | 消除自回归生成中历史 token 的重复 K/V 计算,将单步复杂度从 O ( t 2 ) \mathcal{O}(t^2) O(t2) 降至 O ( t ) \mathcal{O}(t) O(t) |
| 代价 | 显存占用随序列长度线性增长,成为长上下文推理的主要瓶颈 |
| 实现形式 | Hugging Face past_key_values,或手工维护每层 (K, V) 张量 |
| 优化方向 | MQA/GQA、量化、PagedAttention、滑动窗口 |
| 适用范围 | 所有基于 Transformer Decoder 的自回归模型(GPT、LLaMA、ChatGLM、Qwen 等) |
理解 KV Cache 是分析 LLM 推理速度、显存占用,以及设计高效部署方案的基础。结合 FlashAttention、Continuous Batching 等技术,可以进一步释放硬件潜力。