讲透Transformer(五):Self-Attention与KV Cache的深度解析——从原理到实现

Self-Attention 与 KV Cache 的深度解析

  • [Self-Attention 计算过程详解](#Self-Attention 计算过程详解)
  • 自回归生成中的计算冗余
    • [2.1 自回归生成的特点](#2.1 自回归生成的特点)
    • [2.2 冗余计算分析](#2.2 冗余计算分析)
  • [KV Cache:消除冗余的优雅方案](#KV Cache:消除冗余的优雅方案)
    • [3.1 核心思想](#3.1 核心思想)
    • [3.2 带 KV Cache 的计算流程](#3.2 带 KV Cache 的计算流程)
      • [初始化阶段(Prefill Phase)](#初始化阶段(Prefill Phase))
      • [生成阶段(Generation Phase)](#生成阶段(Generation Phase))
    • [3.3 计算量对比](#3.3 计算量对比)
      • [无 KV Cache](#无 KV Cache)
      • [有 KV Cache](#有 KV Cache)
      • 加速比
  • 数学形式化推导
    • [4.1 无 Cache 的注意力计算](#4.1 无 Cache 的注意力计算)
    • [4.2 有 Cache 的注意力计算](#4.2 有 Cache 的注意力计算)
    • [4.3 重要观察](#4.3 重要观察)
  • 代码实现详解
    • [5.1 不带 KV Cache 的朴素实现](#5.1 不带 KV Cache 的朴素实现)
    • [5.2 带 KV Cache 的高效实现](#5.2 带 KV Cache 的高效实现)
    • [5.3 性能对比测试](#5.3 性能对比测试)
  • [KV Cache 的内存占用分析](#KV Cache 的内存占用分析)
    • [6.1 内存计算公式](#6.1 内存计算公式)
    • [6.2 不同模型的 KV Cache 占用](#6.2 不同模型的 KV Cache 占用)
  • [KV Cache 的变体与优化](#KV Cache 的变体与优化)
    • [7.1 多层缓存复用](#7.1 多层缓存复用)
    • [7.2 稀疏缓存](#7.2 稀疏缓存)
    • [7.3 分页缓存 (PagedAttention)](#7.3 分页缓存 (PagedAttention))
  • 工程实践与注意事项
    • [8.1 何时使用 KV Cache](#8.1 何时使用 KV Cache)
    • [8.2 内存管理技巧](#8.2 内存管理技巧)
    • [8.3 批处理中的 KV Cache](#8.3 批处理中的 KV Cache)
  • [KV Cache 与注意力机制的演进关系](#KV Cache 与注意力机制的演进关系)
    • [9.1 技术演进路线](#9.1 技术演进路线)
    • [9.2 数学本质的统一视角](#9.2 数学本质的统一视角)
  • 总结
    • [10.1 核心要点](#10.1 核心要点)
    • [10.2 一句话总结](#10.2 一句话总结)
    • 参考文献

在 Transformer 模型的推理加速技术中,KV Cache 是最核心、最基础的优化手段。要深入理解 KV Cache 为何能带来数十倍的推理加速,必须先透彻理解 Self-Attention 的计算过程。本文将带你从零开始,逐步推导两者之间的紧密联系。


Self-Attention 计算过程详解

1.1 全量计算视角

假设我们有一个输入序列,长度为 n n n,每个 token 的维度为 d d d。在标准的 Self-Attention 计算中,我们需要为序列中的每个位置计算与其他所有位置的相关性。

输入表示

X ∈ R n × d X \in \mathbb{R}^{n \times d} X∈Rn×d

其中 X X X 是输入序列的 token 嵌入矩阵, n n n 是序列长度, d d d 是隐藏层维度。

生成 Q、K、V 矩阵

通过三个线性变换矩阵,我们将 X X X 映射到查询空间、键空间和值空间:

Q = X W Q , K = X W K , V = X W V Q = XW^Q, \quad K = XW^K, \quad V = XW^V Q=XWQ,K=XWK,V=XWV

其中:

  • W Q ∈ R d × d k W^Q \in \mathbb{R}^{d \times d_k} WQ∈Rd×dk, W K ∈ R d × d k W^K \in \mathbb{R}^{d \times d_k} WK∈Rd×dk, W V ∈ R d × d v W^V \in \mathbb{R}^{d \times d_v} WV∈Rd×dv
  • 通常 d k = d v = d / h d_k = d_v = d / h dk=dv=d/h, h h h 是注意力头数
  • 输出维度: Q , K ∈ R n × d k Q, K \in \mathbb{R}^{n \times d_k} Q,K∈Rn×dk, V ∈ R n × d v V \in \mathbb{R}^{n \times d_v} V∈Rn×dv

注意力分数计算

计算 Query 和所有 Key 的点积,得到注意力分数矩阵:

S = Q K T ∈ R n × n S = QK^T \in \mathbb{R}^{n \times n} S=QKT∈Rn×n

展开形式:

S i j = Q i ⋅ K j T = ∑ m = 1 d k Q i , m ⋅ K j , m S_{ij} = Q_i \cdot K_j^T = \sum_{m=1}^{d_k} Q_{i,m} \cdot K_{j,m} Sij=Qi⋅KjT=m=1∑dkQi,m⋅Kj,m

缩放与 Softmax

为了防止点积结果过大导致梯度消失,除以 d k \sqrt{d_k} dk 并进行 Softmax 归一化:

A = softmax ( S d k ) ∈ R n × n A = \text{softmax}\left(\frac{S}{\sqrt{d_k}}\right) \in \mathbb{R}^{n \times n} A=softmax(dk S)∈Rn×n

其中 Softmax 按行进行:

A i j = exp ⁡ ( S i j / d k ) ∑ k = 1 n exp ⁡ ( S i k / d k ) A_{ij} = \frac{\exp(S_{ij}/\sqrt{d_k})}{\sum_{k=1}^{n} \exp(S_{ik}/\sqrt{d_k})} Aij=∑k=1nexp(Sik/dk )exp(Sij/dk )

加权求和

用注意力权重对 Value 进行加权求和:

O = A V ∈ R n × d v O = AV \in \mathbb{R}^{n \times d_v} O=AV∈Rn×dv

展开形式:

O i = ∑ j = 1 n A i j V j O_i = \sum_{j=1}^{n} A_{ij} V_j Oi=j=1∑nAijVj

计算复杂度分析

上述过程的计算复杂度为:

操作 计算量 内存占用
Q K T QK^T QKT O ( n 2 ⋅ d k ) O(n^2 \cdot d_k) O(n2⋅dk) O ( n 2 ) O(n^2) O(n2)
Softmax O ( n 2 ) O(n^2) O(n2) O ( n 2 ) O(n^2) O(n2)
A V A V AV O ( n 2 ⋅ d v ) O(n^2 \cdot d_v) O(n2⋅dv) O ( n ⋅ d v ) O(n \cdot d_v) O(n⋅dv)
总计 O ( n 2 ⋅ d ) O(n^2 \cdot d) O(n2⋅d) O ( n 2 ) O(n^2) O(n2)

当 n n n 很大时(如 32K、128K), n 2 n^2 n2 项会成为不可承受之重。


自回归生成中的计算冗余

2.1 自回归生成的特点

在文本生成任务中,模型采用自回归方式逐个生成 token:

复制代码
输入: "I love"
生成: "I love AI"
步骤:
Step 1: 输入 "I love" → 生成 "AI"
Step 2: 输入 "I love AI" → 生成 "."
Step 3: 输入 "I love AI." → 生成 "<eos>"

关键观察:每一步的输入都包含了上一步的全部输入

2.2 冗余计算分析

以生成第 4 个 token 为例,对比 Step 2 和 Step 3 的计算:

Step 2 输入 : "I love AI"(长度 3)
Step 3 输入: "I love AI."(长度 4)

计算 Step 3 的 Self-Attention 时:

Q 1 : 4 = X 1 : 4 W Q Q_{1:4} = X_{1:4}W^Q Q1:4=X1:4WQ
K 1 : 4 = X 1 : 4 W K K_{1:4} = X_{1:4}W^K K1:4=X1:4WK
V 1 : 4 = X 1 : 4 W V V_{1:4} = X_{1:4}W^V V1:4=X1:4WV

关键问题 :Step 2 中已经计算过 K 1 : 3 K_{1:3} K1:3 和 V 1 : 3 V_{1:3} V1:3,Step 3 却要重新计算一遍!

对于长度为 n n n 的序列,生成第 n + 1 n+1 n+1 个 token 时:

  • 已经计算过的 KV: K 1 : n K_{1:n} K1:n, V 1 : n V_{1:n} V1:n
  • 需要新计算的 KV: K n + 1 K_{n+1} Kn+1, V n + 1 V_{n+1} Vn+1
  • 需要重新计算的全量: K 1 : n K_{1:n} K1:n, V 1 : n V_{1:n} V1:n(如果不用 KV Cache)

这就造成了巨大的计算浪费。随着生成序列变长,浪费的比例趋近于 50%:

浪费比例 ≈ ∑ i = 1 n i ∑ i = 1 n ( i + 1 ) ≈ n ( n + 1 ) / 2 n ( n + 3 ) / 2 ≈ 1 − 2 n + 3 → n → ∞ 1 \text{浪费比例} \approx \frac{\sum_{i=1}^{n} i}{\sum_{i=1}^{n} (i+1)} \approx \frac{n(n+1)/2}{n(n+3)/2} \approx 1 - \frac{2}{n+3} \xrightarrow{n \to \infty} 1 浪费比例≈∑i=1n(i+1)∑i=1ni≈n(n+3)/2n(n+1)/2≈1−n+32n→∞ 1


KV Cache:消除冗余的优雅方案

3.1 核心思想

KV Cache 的核心思想是:用空间换时间 。既然每一步的 K K K 和 V V V 都会被重复使用,那就把它们缓存起来,避免重复计算。

缓存的内容

对于每个 Transformer 层,我们维护两个缓存矩阵:

K_cache = [ K 1 , K 2 , . . . , K t ] ∈ R t × d k \text{K\_cache} = [K_1, K_2, ..., K_t] \in \mathbb{R}^{t \times d_k} K_cache=[K1,K2,...,Kt]∈Rt×dk
V_cache = [ V 1 , V 2 , . . . , V t ] ∈ R t × d v \text{V\_cache} = [V_1, V_2, ..., V_t] \in \mathbb{R}^{t \times d_v} V_cache=[V1,V2,...,Vt]∈Rt×dv

其中 t t t 是当前已生成的 token 数量。

3.2 带 KV Cache 的计算流程

初始化阶段(Prefill Phase)

处理输入提示词,计算所有 token 的 KV,并缓存:

python 复制代码
# 输入提示词长度 = n
for i in range(n):
    K_i, V_i = compute_KV(X[i])  # 计算第 i 个 token 的 KV
    K_cache.append(K_i)
    V_cache.append(V_i)
    O_i = attention(Q_i, K_cache, V_cache)  # 用缓存计算输出

生成阶段(Generation Phase)

逐 token 生成,每次只计算新 token 的 KV:

python 复制代码
# 已生成 t 个 token
while not finished:
    # 1. 计算新 token 的 Q, K, V
    Q_new = X_new @ W_Q
    K_new = X_new @ W_K
    V_new = X_new @ W_V
    
    # 2. 更新缓存
    K_cache = torch.cat([K_cache, K_new], dim=0)
    V_cache = torch.cat([V_cache, V_new], dim=0)
    
    # 3. 用完整缓存计算注意力
    # Q_new: [1, d_k], K_cache: [t+1, d_k], V_cache: [t+1, d_v]
    scores = Q_new @ K_cache.T  # [1, t+1]
    scores = scores / sqrt(d_k)
    attn_weights = softmax(scores)  # [1, t+1]
    O_new = attn_weights @ V_cache  # [1, d_v]
    
    # 4. 生成下一个 token
    next_token = generate(O_new)
    X_new = embedding(next_token)

3.3 计算量对比

无 KV Cache

生成第 t + 1 t+1 t+1 个 token 时,需要计算:

FLOPs no-cache = 2 ( t + 1 ) 2 d k ⏟ QK T + 2 ( t + 1 ) 2 d v ⏟ AV \text{FLOPs}{\text{no-cache}} = \underbrace{2(t+1)^2 d_k}{\text{QK}^T} + \underbrace{2(t+1)^2 d_v}_{\text{AV}} FLOPsno-cache=QKT 2(t+1)2dk+AV 2(t+1)2dv

有 KV Cache

生成第 t + 1 t+1 t+1 个 token 时,只需要:

KaTeX parse error: Expected 'EOF', got '' at position 64: ...) d_k}{\text{Q_̲new @ K_cache}^...

加速比

Speedup = FLOPs no-cache FLOPs cache ≈ 2 ( t + 1 ) 2 d 2 ( t + 1 ) d = t + 1 \text{Speedup} = \frac{\text{FLOPs}{\text{no-cache}}}{\text{FLOPs}{\text{cache}}} \approx \frac{2(t+1)^2 d}{2(t+1)d} = t+1 Speedup=FLOPscacheFLOPsno-cache≈2(t+1)d2(t+1)2d=t+1

当 t = 1024 t=1024 t=1024 时,加速比达到 1024 倍!这就是 KV Cache 能带来数量级性能提升的原因。


数学形式化推导

4.1 无 Cache 的注意力计算

对于长度为 t t t 的序列,注意力输出为:

O ( t ) = softmax ( Q ( t ) K ( t ) T d k ) V ( t ) O^{(t)} = \text{softmax}\left(\frac{Q^{(t)} {K^{(t)}}^T}{\sqrt{d_k}}\right) V^{(t)} O(t)=softmax(dk Q(t)K(t)T)V(t)

其中:

  • Q ( t ) = X ( t ) W Q ∈ R t × d k Q^{(t)} = X^{(t)} W^Q \in \mathbb{R}^{t \times d_k} Q(t)=X(t)WQ∈Rt×dk
  • K ( t ) = X ( t ) W K ∈ R t × d k K^{(t)} = X^{(t)} W^K \in \mathbb{R}^{t \times d_k} K(t)=X(t)WK∈Rt×dk
  • V ( t ) = X ( t ) W V ∈ R t × d v V^{(t)} = X^{(t)} W^V \in \mathbb{R}^{t \times d_v} V(t)=X(t)WV∈Rt×dv

4.2 有 Cache 的注意力计算

当生成第 t + 1 t+1 t+1 个 token 时,我们已有缓存:

K cache = K ( t ) ∈ R t × d k K_{\text{cache}} = K^{(t)} \in \mathbb{R}^{t \times d_k} Kcache=K(t)∈Rt×dk
V cache = V ( t ) ∈ R t × d v V_{\text{cache}} = V^{(t)} \in \mathbb{R}^{t \times d_v} Vcache=V(t)∈Rt×dv

新 token 的 Query、Key、Value:

Q t + 1 = x t + 1 W Q ∈ R 1 × d k Q_{t+1} = x_{t+1} W^Q \in \mathbb{R}^{1 \times d_k} Qt+1=xt+1WQ∈R1×dk
K t + 1 = x t + 1 W K ∈ R 1 × d k K_{t+1} = x_{t+1} W^K \in \mathbb{R}^{1 \times d_k} Kt+1=xt+1WK∈R1×dk
V t + 1 = x t + 1 W V ∈ R 1 × d v V_{t+1} = x_{t+1} W^V \in \mathbb{R}^{1 \times d_v} Vt+1=xt+1WV∈R1×dv

更新缓存:

K cache ′ = [ K cache ; K t + 1 ] ∈ R ( t + 1 ) × d k K_{\text{cache}}' = [K_{\text{cache}}; K_{t+1}] \in \mathbb{R}^{(t+1) \times d_k} Kcache′=[Kcache;Kt+1]∈R(t+1)×dk
V cache ′ = [ V cache ; V t + 1 ] ∈ R ( t + 1 ) × d v V_{\text{cache}}' = [V_{\text{cache}}; V_{t+1}] \in \mathbb{R}^{(t+1) \times d_v} Vcache′=[Vcache;Vt+1]∈R(t+1)×dv

计算注意力:

O t + 1 = softmax ( Q t + 1 K cache ′ T d k ) V cache ′ O_{t+1} = \text{softmax}\left(\frac{Q_{t+1} {K_{\text{cache}}'}^T}{\sqrt{d_k}}\right) V_{\text{cache}}' Ot+1=softmax(dk Qt+1Kcache′T)Vcache′

展开点积项:

Q t + 1 K cache ′ T = [ Q t + 1 K cache T , Q t + 1 K t + 1 T ] ∈ R 1 × ( t + 1 ) Q_{t+1} {K_{\text{cache}}'}^T = [Q_{t+1} K_{\text{cache}}^T, Q_{t+1} K_{t+1}^T] \in \mathbb{R}^{1 \times (t+1)} Qt+1Kcache′T=[Qt+1KcacheT,Qt+1Kt+1T]∈R1×(t+1)

4.3 重要观察

注意到 Q t + 1 K cache T Q_{t+1} K_{\text{cache}}^T Qt+1KcacheT 中包含了新 token 与所有历史 token 的相关性,而 Q t + 1 K t + 1 T Q_{t+1} K_{t+1}^T Qt+1Kt+1T 是自相关性。这正是 Transformer 能够捕捉长距离依赖的关键------新 token 可以看到所有历史信息


代码实现详解

5.1 不带 KV Cache 的朴素实现

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time

class SelfAttentionWithoutCache(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # 线性变换矩阵
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def forward(self, x):
        # x: [batch_size, seq_len, d_model]
        batch_size, seq_len, _ = x.shape
        
        # 1. 计算 Q, K, V
        Q = self.W_q(x)  # [batch, seq, d_model]
        K = self.W_k(x)  # [batch, seq, d_model]
        V = self.W_v(x)  # [batch, seq, d_model]
        
        # 2. 分头
        Q = Q.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        # Q: [batch, n_heads, seq_len, d_k]
        
        # 3. 计算注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # scores: [batch, n_heads, seq_len, seq_len]
        
        # 4. Softmax
        attn_weights = F.softmax(scores, dim=-1)
        
        # 5. 加权求和
        context = torch.matmul(attn_weights, V)
        # context: [batch, n_heads, seq_len, d_k]
        
        # 6. 合并头
        context = context.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.d_model
        )
        
        # 7. 输出投影
        output = self.W_o(context)
        
        return output

5.2 带 KV Cache 的高效实现

python 复制代码
class SelfAttentionWithCache(nn.Module):
    def __init__(self, d_model, n_heads, max_seq_len=2048):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.max_seq_len = max_seq_len
        
        # 线性变换矩阵
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        # 初始化缓存
        self.reset_cache()
        
    def reset_cache(self):
        """重置 KV 缓存"""
        self.K_cache = None
        self.V_cache = None
        
    def forward(self, x, use_cache=False):
        """
        x: [batch_size, seq_len, d_model]
        如果 use_cache=True,假设 x 只包含新 token
        """
        batch_size, seq_len, _ = x.shape
        
        # 计算 Q, K, V
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # 分头
        Q = Q.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        
        if use_cache and self.K_cache is not None:
            # 使用缓存的 K, V
            # K_cache, V_cache: [batch, n_heads, cached_len, d_k]
            K = torch.cat([self.K_cache, K], dim=2)
            V = torch.cat([self.V_cache, V], dim=2)
        
        # 更新缓存
        self.K_cache = K
        self.V_cache = V
        
        # 计算注意力分数(Q 只关注新 token 的 query)
        if use_cache and seq_len == 1:
            # 生成阶段:只计算新 token 的 query
            # Q: [batch, n_heads, 1, d_k]
            # K: [batch, n_heads, cached_len, d_k]
            scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
            # scores: [batch, n_heads, 1, cached_len]
        else:
            # 预填充阶段:计算所有 token 的 query
            scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
            # scores: [batch, n_heads, seq_len, seq_len+cached_len]
        
        # 应用 causal mask(确保不能看到未来 token)
        if not use_cache or seq_len > 1:
            mask = torch.triu(torch.ones_like(scores), diagonal=1).bool()
            scores = scores.masked_fill(mask, float('-inf'))
        
        attn_weights = F.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, V)
        
        # 合并头
        context = context.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )
        
        output = self.W_o(context)
        
        return output

5.3 性能对比测试

python 复制代码
def test_performance():
    d_model = 512
    n_heads = 8
    batch_size = 1
    
    model_without_cache = SelfAttentionWithoutCache(d_model, n_heads)
    model_with_cache = SelfAttentionWithCache(d_model, n_heads)
    
    # 测试不同序列长度
    for seq_len in [128, 256, 512, 1024, 2048]:
        x = torch.randn(batch_size, seq_len, d_model)
        
        # 无 Cache
        start = time.time()
        _ = model_without_cache(x)
        time_without = time.time() - start
        
        # 有 Cache (模拟生成过程)
        model_with_cache.reset_cache()
        start = time.time()
        
        # 预填充阶段
        _ = model_with_cache(x, use_cache=True)
        
        # 生成 100 个新 token
        for _ in range(100):
            new_token = torch.randn(batch_size, 1, d_model)
            _ = model_with_cache(new_token, use_cache=True)
            
        time_with_cache = time.time() - start
        
        print(f"Seq Len {seq_len}:")
        print(f"  Without Cache: {time_without*1000:.2f} ms (prefill only)")
        print(f"  With Cache: {time_with_cache*1000:.2f} ms (prefill + 100 gen)")
        print(f"  Speedup: {time_without*100/(time_with_cache/100):.1f}x per token")
        print()

预期输出

复制代码
Seq Len 1024:
  Without Cache: 45.32 ms (prefill only)
  With Cache: 68.45 ms (prefill + 100 gen)
  Speedup: 66.2x per token

KV Cache 的内存占用分析

6.1 内存计算公式

KV Cache 的内存占用可以精确计算:

Memory KV = 2 × n layers × n heads × d head × seq_len × bytes_per_param \text{Memory}{\text{KV}} = 2 \times n{\text{layers}} \times n_{\text{heads}} \times d_{\text{head}} \times \text{seq\_len} \times \text{bytes\_per\_param} MemoryKV=2×nlayers×nheads×dhead×seq_len×bytes_per_param

其中:

  • n layers n_{\text{layers}} nlayers:Transformer 层数
  • n heads n_{\text{heads}} nheads:注意力头数
  • d head d_{\text{head}} dhead:每个头的维度
  • seq_len \text{seq\_len} seq_len:当前序列长度
  • bytes_per_param \text{bytes\_per\_param} bytes_per_param:参数精度(FP16=2 bytes,FP32=4 bytes)

实际计算示例

以 LLaMA-7B 为例:

  • n layers = 32 n_{\text{layers}} = 32 nlayers=32
  • n heads = 32 n_{\text{heads}} = 32 nheads=32
  • d head = 128 d_{\text{head}} = 128 dhead=128
  • FP16 精度:2 bytes

当 seq_len = 2048 \text{seq\_len} = 2048 seq_len=2048 时:

Memory = 2 × 32 × 32 × 128 × 2048 × 2 bytes \text{Memory} = 2 \times 32 \times 32 \times 128 \times 2048 \times 2 \text{ bytes} Memory=2×32×32×128×2048×2 bytes
= 2 × 32 × 32 × 128 × 2048 × 2 = 2 \times 32 \times 32 \times 128 \times 2048 \times 2 =2×32×32×128×2048×2
= 2 × 32 × 32 × 524288 = 2 \times 32 \times 32 \times 524288 =2×32×32×524288
= 2 × 32 × 16777216 = 2 \times 32 \times 16777216 =2×32×16777216
= 2 × 536870912 = 2 \times 536870912 =2×536870912
= 1 , 073 , 741 , 824 bytes ≈ 1.07 GB = 1,073,741,824 \text{ bytes} \approx 1.07 \text{ GB} =1,073,741,824 bytes≈1.07 GB

当 seq_len = 1 M \text{seq\_len} = 1M seq_len=1M 时:

Memory ≈ 1.07 GB × 1 M 2048 ≈ 534 GB \text{Memory} \approx 1.07 \text{ GB} \times \frac{1M}{2048} \approx 534 \text{ GB} Memory≈1.07 GB×20481M≈534 GB

这就是为什么超长上下文需要 GQA、MLA 等技术来压缩 KV Cache!

6.2 不同模型的 KV Cache 占用

模型 层数 头数 头维度 8K 上下文 32K 上下文 128K 上下文
LLaMA-7B 32 32 128 4.2 GB 16.8 GB 67.2 GB
LLaMA-70B 80 64 128 21.0 GB 84.0 GB 336 GB
Qwen2-7B (GQA) 32 8 (GQA) 128 1.05 GB 4.2 GB 16.8 GB
DeepSeek-V2 (MLA) 60 16 (压缩) 128 0.26 GB 1.05 GB 4.2 GB

KV Cache 的变体与优化

7.1 多层缓存复用

除了每层的 KV Cache,现代推理引擎还实现了更高级的缓存策略:

python 复制代码
class MultiLayerCache:
    def __init__(self, num_layers):
        self.layer_caches = [LayerCache() for _ in range(num_layers)]
        
    def reuse_across_layers(self, layer_idx, k, v):
        """某些情况下,浅层的 KV 可以复用于深层"""
        if layer_idx > 0 and self.similarity_high(layer_idx):
            return self.layer_caches[layer_idx-1].get()
        return k, v

7.2 稀疏缓存

对于超长上下文,可以只缓存重要的 token:

python 复制代码
class SparseKVCache:
    def __init__(self, max_cache_size=2048):
        self.max_cache_size = max_cache_size
        self.k_cache = []
        self.v_cache = []
        self.scores = []
        
    def update(self, k_new, v_new, importance_score):
        # 只保留最重要的 token
        self.k_cache.append(k_new)
        self.v_cache.append(v_new)
        self.scores.append(importance_score)
        
        if len(self.k_cache) > self.max_cache_size:
            # 移除最不重要的 token
            min_idx = argmin(self.scores)
            self.k_cache.pop(min_idx)
            self.v_cache.pop(min_idx)
            self.scores.pop(min_idx)

7.3 分页缓存 (PagedAttention)

vLLM 等推理引擎实现了类似虚拟内存的分页机制:

复制代码
┌─────────────────────────────────────────────────────────────┐
│                    PagedAttention 结构                       │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  逻辑 KV 块:  [Block 0] [Block 1] [Block 2] ...            │
│                │         │         │                        │
│                ▼         ▼         ▼                        │
│  物理内存:  [Page 5] [Page 12] [Page 3] ...                │
│                                                             │
│  优点:消除内部碎片,支持动态序列长度                        │
│                                                             │
└─────────────────────────────────────────────────────────────┘

工程实践与注意事项

8.1 何时使用 KV Cache

场景 是否使用 KV Cache 原因
训练 ❌ 不使用 数据是并行的,无需缓存
预填充阶段 ✅ 使用 初始化缓存
生成阶段 ✅ 使用 大幅加速推理
批处理推理 ✅ 使用 每个序列独立缓存

8.2 内存管理技巧

python 复制代码
class OptimizedAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # 使用连续的张量提高访存效率
        self.register_buffer('K_cache', None)
        self.register_buffer('V_cache', None)
        
    def maybe_compress_cache(self, max_cache_len=4096):
        """缓存超过阈值时进行压缩"""
        if self.K_cache.size(2) > max_cache_len:
            # 平均池化压缩
            self.K_cache = F.avg_pool1d(
                self.K_cache.transpose(1, 2), 
                kernel_size=2, stride=2
            ).transpose(1, 2)
            self.V_cache = F.avg_pool1d(
                self.V_cache.transpose(1, 2),
                kernel_size=2, stride=2
            ).transpose(1, 2)

8.3 批处理中的 KV Cache

python 复制代码
class BatchedAttentionWithCache(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        # ... 初始化 ...
        
    def forward(self, x, cache=None, cache_lengths=None):
        """
        x: [batch_size, seq_len, d_model]
        cache: [batch_size, max_cache_len, n_heads, d_k]
        cache_lengths: [batch_size] 每个序列的实际缓存长度
        """
        batch_size, seq_len, _ = x.shape
        
        # 为每个序列计算 QKV
        Q = self.W_q(x)  # [batch, seq, d_model]
        
        # 处理不同长度的缓存
        if cache is not None:
            # 创建掩码,只关注有效缓存
            mask = torch.arange(cache.size(1), device=x.device)[None, :] < cache_lengths[:, None]
            # mask: [batch, max_cache_len]
            
            # 扩展维度用于注意力计算
            mask = mask[:, None, None, :]  # [batch, 1, 1, max_cache_len]
            
            # 注意力计算时应用掩码
            scores = torch.matmul(Q.view(batch_size, seq_len, self.n_heads, self.d_k),
                                 cache.transpose(1, 2)) / math.sqrt(self.d_k)
            scores = scores.masked_fill(~mask, float('-inf'))

KV Cache 与注意力机制的演进关系

9.1 技术演进路线

复制代码
┌─────────────────────────────────────────────────────────────┐
│              从全量计算到 KV Cache 的演进                     │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  2017: 全注意力 (原始 Transformer)                          │
│        ↓  发现:生成阶段存在大量重复计算                     │
│  2018: KV Cache 提出                                        │
│        ↓  优化:用空间换时间,缓存 K/V 矩阵                  │
│  2022: 多查询注意力 (MQA)                                   │
│        ↓  优化:共享 KV 头,减少缓存                         │
│  2023: 分组查询注意力 (GQA)                                 │
│        ↓  优化:平衡性能和缓存                               │
│  2024: 多词元潜在注意力 (MLA)                               │
│        ↓  优化:压缩 KV 缓存 90%+                           │
│  2025: 分页注意力 (PagedAttention)                          │
│        ↓  优化:高效内存管理,支持动态序列                   │
│  2026: 稀疏 KV Cache                                        │
│        →  优化:只缓存重要 token                             │
│                                                             │
└─────────────────────────────────────────────────────────────┘

9.2 数学本质的统一视角

无论是原始的全注意力,还是带 KV Cache 的高效版本,其数学本质都是通过键值对存储和检索信息

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V

KV Cache 只是将历史的 K K K 和 V V V 保存下来,避免重复计算,并没有改变注意力机制的数学形式。这体现了工程优化对理论模型的忠实保留。


总结

10.1 核心要点

复制代码
┌─────────────────────────────────────────────────────────────┐
│                    KV Cache 核心要点                         │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1️⃣ 为什么需要 KV Cache                                     │
│      • 自回归生成存在大量重复计算                           │
│      • 每一步都要重新计算所有历史 token 的 KV               │
│      • 浪费比例随序列长度增长趋近 100%                      │
│                                                             │
│  2️⃣ KV Cache 的工作原理                                    │
│      • 缓存所有历史 token 的 K 和 V 矩阵                    │
│      • 新 token 只需计算自己的 K、V                         │
│      • 注意力计算时使用完整缓存                             │
│                                                             │
│  3️⃣ 性能收益                                                │
│      • 计算复杂度:从 O(n²) 降至 O(n)                       │
│      • 加速比:约等于当前序列长度                           │
│      • 内存开销:每层 2 × n × d_k × bytes                    │
│                                                             │
│  4️⃣ 优化方向                                                │
│      • GQA/MQA:共享 KV 头,减少缓存                         │
│      • MLA:压缩 KV 表示,降低 90% 内存                      │
│      • PagedAttention:分页管理,消除碎片                   │
│      • 稀疏缓存:只保留重要 token                            │
│                                                             │
└─────────────────────────────────────────────────────────────┘

10.2 一句话总结

KV Cache 是 Transformer 推理加速的基石,它通过缓存历史计算的键值对,将自回归生成的复杂度从 O(n²) 降至 O(n),是实现大模型高效推理不可或缺的技术! 🚀


参考文献

  1. Vaswani et al. "Attention Is All You Need" (2017)
  2. Pope et al. "Efficiently Scaling Transformer Inference" (2022)
  3. Kwon et al. "Efficient Memory Management for Large Language Model Serving with PagedAttention" (2023)
  4. DeepSeek-AI. "DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model" (2024)

本文为原创内容,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。

相关推荐
SmartBrain2 小时前
技术洞察:SpringAI与LangGraph选型对比
人工智能·spring boot·架构·langchain·aigc·fastapi
FserSuN2 小时前
OpenClaw接入模型并基于WebUI完成智能操作
人工智能
梦想画家2 小时前
WebAgent详解+实战:用开源AI智能体搞定产品与竞品市场调研
人工智能·webagent
Katecat996632 小时前
基于Mask R-CNN的肉鸡跛足检测系统:R50-SyncBN-GCB-R16-C3-C5-FPN模型训练与COCO数据集应用_2
人工智能·神经网络
小雨中_2 小时前
3.1 GPT 系列:Generative Pre-Training(从 GPT-1 到 GPT-3)
人工智能·gpt·深度学习·机器学习·自然语言处理·gpt-3
xuxianliang2 小时前
第158章 “神谕”的布局(AI)
人工智能·程序员创富
速易达网络2 小时前
AI学习路径 python到openclaw
人工智能·python·学习
量子-Alex2 小时前
【大模型智能体】MemGPT论文深度解读
人工智能