KV Cache 详解

KV Cache: 大模型自回归生成的关键优化技术

KV Cache(键值缓存)是 Transformer 解码器在自回归推理阶段普遍采用的一种工程优化。它通过缓存每一层注意力计算中历史 token 的 KeyValue 矩阵,避免每一步都重复计算,显著降低推理延迟,是 LLM 高性能服务的基础组件。

1. 背景:自回归生成中的重复计算问题

大型语言模型(LLM)在生成文本时采用自回归方式:逐个生成 token,并将新 token 拼接到输入序列末尾,作为下一步的输入。

假设已经生成了 t − 1 t-1 t−1 个 token,当前输入序列长度为 t − 1 t-1 t−1,下一步需要生成第 t t t 个 token。在未使用 KV Cache 的情况下,每一步都要重新计算整个序列 [ x 1 , x 2 , . . . , x t − 1 , x t ] [x_1, x_2, ..., x_{t-1}, x_t] [x1,x2,...,xt−1,xt] 的所有注意力 K 和 V。而实际上,前 t − 1 t-1 t−1 个 token 的 K 和 V 在前几步已经计算过,完全可以直接复用。

不加缓存时,第 t t t 步的计算复杂度为 O ( t 2 ) \mathcal{O}(t^2) O(t2)(因为需要计算 t × t t \times t t×t 的注意力矩阵)。随着序列长度 t t t 增大,计算开销呈平方增长,使推理变得极其缓慢。

2. KV Cache 的核心思想

只计算当前新 token 的 Query(Q)、Key(K)、Value(V),并复用之前所有 token 已缓存的 K 和 V。

  • 缓存内容:每一层 Transformer 中,所有历史 token 的 Key 矩阵和 Value 矩阵。
  • 更新方式:每生成一个新 token,计算其 K、V 并追加到缓存中。
  • 注意力计算:使用当前 token 的 Q 与缓存中所有历史 K 计算注意力分数,再用该分数加权缓存中所有历史 V。

这样,第 t t t 步的计算复杂度从 O ( t 2 ) \mathcal{O}(t^2) O(t2) 降为 O ( t ) \mathcal{O}(t) O(t),因为只需要对新 token 计算一次 Q,并与长度为 t t t 的 K 序列做点积(即 Q K T Q K^T QKT 的其中一行)。

3. 工作机制与流程

以单层自注意力为例。假设已经生成了 t − 1 t-1 t−1 个 token,缓存了:

  • K cache = [ k 1 , k 2 , ... , k t − 1 ] K_{\text{cache}} = [k_1, k_2, \dots, k_{t-1}] Kcache=[k1,k2,...,kt−1]
  • V cache = [ v 1 , v 2 , ... , v t − 1 ] V_{\text{cache}} = [v_1, v_2, \dots, v_{t-1}] Vcache=[v1,v2,...,vt−1]

第 t t t 步生成新 token x t x_t xt 的过程:

  1. 将 x t x_t xt 通过 embedding 层和位置编码得到向量表示。
  2. 通过线性变换计算该 token 的查询 q t q_t qt、键 k t k_t kt、值 v t v_t vt:
    q t = x t W Q , k t = x t W K , v t = x t W V q_t = x_t W_Q,\quad k_t = x_t W_K,\quad v_t = x_t W_V qt=xtWQ,kt=xtWK,vt=xtWV
  3. 更新缓存:
    K cache ← concat ( K cache , k t ) K_{\text{cache}} \leftarrow \text{concat}(K_{\text{cache}}, k_t) Kcache←concat(Kcache,kt)
    V cache ← concat ( V cache , v t ) V_{\text{cache}} \leftarrow \text{concat}(V_{\text{cache}}, v_t) Vcache←concat(Vcache,vt)
  4. 计算当前 token 的注意力输出:
    • 注意力分数 s t = q t ⋅ [ k 1 , k 2 , ... , k t ] T s_t = q_t \cdot [k_1, k_2, \dots, k_t]^T st=qt⋅[k1,k2,...,kt]T(长度为 t t t 的行向量)。
    • 权重 a t = softmax ( s t / d k ) a_t = \text{softmax}(s_t / \sqrt{d_k}) at=softmax(st/dk ),其中 d k d_k dk 是 head 维度。
    • 输出 o t = a t ⋅ V cache o_t = a_t \cdot V_{\text{cache}} ot=at⋅Vcache。
  5. 经过 FFN 等后续层,得到输出 logits,采样得到下一个 token x t + 1 x_{t+1} xt+1。

每一步只计算一个 q t q_t qt,因此注意力计算的复杂度为 O ( t ) \mathcal{O}(t) O(t)。注意:第一个 token(prefill 阶段)仍需计算完整序列,这一步无法加速。

4. 为什么可以安全地缓存 K 和 V?

在 Transformer Decoder 中,通过因果注意力掩码 (causal attention mask)确保每个 token 只能看到它之前(包括自己)的 token,不能看到未来 token。因此,历史 token 的 K 和 V 不会依赖任何未来的信息,它们是静态的。无论后续生成多少 token,已有 token 的 K、V 都保持不变,这为缓存提供了理论依据。

5. 内存开销分析

KV Cache 消耗的显存与 batch_sizeseq_length层数注意力头数等因素成正比。具体公式如下:

Memory = batch_size × seq_len × num_layers × num_heads × head_dim × 2 × bytes_per_elem \text{Memory} = \text{batch\_size} \times \text{seq\_len} \times \text{num\_layers} \times \text{num\_heads} \times \text{head\_dim} \times 2 \times \text{bytes\_per\_elem} Memory=batch_size×seq_len×num_layers×num_heads×head_dim×2×bytes_per_elem

其中 2 2 2 表示同时缓存 K 和 V。

LLaMA 7B 为例:

  • num_layers = 32num_heads = 32head_dim = 128
  • 每 token 每层的 K+V 参数量 = 2 × 32 × 128 = 8192 2 \times 32 \times 128 = 8192 2×32×128=8192
  • 使用 FP16(2 bytes/param)时,每 token 每层占用 8192 × 2 = 16 KB 8192 \times 2 = 16\ \text{KB} 8192×2=16 KB
  • 32 层共 32 × 16 KB = 512 KB 32 \times 16\ \text{KB} = 512\ \text{KB} 32×16 KB=512 KB 每 token

batch_size = 64seq_len = 2048 时:
Memory = 64 × 2048 × 512 KB = 64 GB \text{Memory} = 64 \times 2048 \times 512\ \text{KB} = 64\ \text{GB} Memory=64×2048×512 KB=64 GB

这远超模型参数本身的内存(7B FP16 约 14 GB)。因此,KV Cache 是 LLM 推理时的显存瓶颈,尤其在长上下文和大 batch 场景下。

6. 实现细节与代码示例

在工程实现中,KV Cache 通常以 past_key_values 的形式在各个解码层之间传递。以下是一个极简的 PyTorch 风格伪代码:

python 复制代码
class DecoderLayer:
    def __init__(self, d_model, n_head):
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        # ... 其他组件

    def forward(self, x, past_kv=None):
        q = self.Wq(x)   # [seq_len, d_model]
        k = self.Wk(x)
        v = self.Wv(x)

        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=0)   # 沿序列长度拼接
            v = torch.cat([past_v, v], dim=0)

        # 注意力计算(带因果掩码)
        attn_out = scaled_dot_product_attention(q, k, v, mask=causal_mask)
        return attn_out, (k, v)

# 推理循环
past_kvs = [None] * num_layers   # 每层单独缓存
for _ in range(max_new_tokens):
    for i, layer in enumerate(layers):
        x, past_kvs[i] = layer(x, past_kv=past_kvs[i])
    # 采样下一个 token
    next_token = sample(x[-1])
    x = next_token.unsqueeze(0)   # 下一步仅输入新 token
在实际框架(如 Hugging Face Transformers)中,past_key_values 是一个长度为 num_layers 的 tuple,每个元素是 (K, V) 两个张量,形状为 [batch_size, num_heads, past_seq_len, head_dim]。

7. 优化变体:降低 KV Cache 内存

为了缓解显存压力,工业界和学术界提出了多种优化方案。

7.1 Multi-Query Attention (MQA)

所有 Q 头共享同一组 K 和 V 。因此 KV Cache 大小降为原来的 1 num_heads \frac{1}{\text{num\_heads}} num_heads1。例如 PaLM、Falcon 等模型采用 MQA,推理速度提升明显,但可能带来轻微的质量损失。

7.2 Grouped-Query Attention (GQA)

将 Q 头分组,每组共享一个 K、V 头。这是 MQA 和标准多头注意力的折中方案。例如 LLaMA 2/3 中,7B 模型使用 32 个 Q 头、8 个 KV 头,缓存大小降至 1 4 \frac{1}{4} 41。Gemma、DeepSeek 等模型也广泛采用 GQA。

7.3 PagedAttention

将 KV Cache 分割成固定大小的"页",以非连续方式存储。这可以消除显存碎片,并支持变长序列的高效批处理。由 vLLM 项目提出,目前是生产环境推理的标准技术之一。

7.4 量化缓存

对 KV Cache 使用 INT8 甚至 INT4 量化。例如 KIVI、KVQuant 等工作表明,适度量化(每 channel 或 per-token)可以在保持生成质量的同时减少 50%~75% 的缓存占用。

7.5 滑动窗口与动态缓存

只保留最近 w w w 个 token 的 K、V,丢弃更早的信息。适用于局部依赖较强的任务(如代码补全、流式生成)。Longformer、BigBird 等模型原生支持窗口注意力,但会丧失长距离建模能力。

8. 局限性

  • 显存消耗:长序列(如 100k tokens)即使单 batch 推理,KV Cache 也可能填满多张 80GB A100。
  • 批处理不友好:不同请求的序列长度差异导致缓存形状不规则,需借助 PagedAttention 等技术。
  • 首 token 延迟:第一个 token(prefill)仍需计算完整序列的 K、V,无法加速。

9. 总结

方面 说明
作用 消除自回归生成中历史 token 的重复 K/V 计算,将单步复杂度从 O ( t 2 ) \mathcal{O}(t^2) O(t2) 降至 O ( t ) \mathcal{O}(t) O(t)
代价 显存占用随序列长度线性增长,成为长上下文推理的主要瓶颈
实现形式 Hugging Face past_key_values,或手工维护每层 (K, V) 张量
优化方向 MQA/GQA、量化、PagedAttention、滑动窗口
适用范围 所有基于 Transformer Decoder 的自回归模型(GPT、LLaMA、ChatGLM、Qwen 等)

理解 KV Cache 是分析 LLM 推理速度、显存占用,以及设计高效部署方案的基础。结合 FlashAttention、Continuous Batching 等技术,可以进一步释放硬件潜力。

参考文献与延伸阅读

相关推荐
AI精钢23 天前
DeepSeek KV Cache 入门解读:98% 命中率背后的工程逻辑
大模型·llm推理·kv cache·deepseek·ai工程
Luchang-Li1 个月前
不同架构模型KV Cache大小计算
kv cache
阿杰学AI1 个月前
AI核心知识123—大语言模型之 KV Cache
人工智能·ai·语言模型·自然语言处理·aigc·kv cache·键值缓存
handsomestWei2 个月前
KV Cache与vLLM、SGLang推理框架
vllm·推理框架·kv cache·sglang
lin_dec+2 个月前
KV Cache:大模型推理加速的关键技术
nlp·transformer·vllm·大模型推理·kv cache
一顿能吃五大海碗啊啊啊2 个月前
大模型推理加速 KV cache
mha·gqa·mqa·kv cache
dawdo2224 个月前
自己动手从头开始编写LLM推理引擎(9)-KV缓存实现和优化
缓存·llm·transformer·qwen·kv cache
被制作时长两年半的个人练习生4 个月前
KV Cache
kv cache
enjoy编程5 个月前
Spring AI 大模型工程核心:效率的极限博弈
注意力机制·flashattention·kv cache·pd分离·pagedattention·epd分离·radixattention