13-KV Cache与位置编码表:大模型推理加速的核心技术

从自回归生成说起

在前面的章节中,我们学习了大模型的核心原理:给定前面的Token序列,预测下一个Token。但是,当我们实际使用大模型进行文本生成时,会遇到一个严重的性能问题

自回归生成的过程

假设我们要让模型生成一句话:"今天天气真好"(5个Token)

第1步:输入提示词"今天"

  • 输入序列:["今天"](1个Token)
  • 模型计算注意力,输出:["天气"]

第2步:继续生成

  • 输入序列:["今天", "天气"](2个Token)
  • 模型重新计算这2个Token的注意力,输出:["真"]

第3步:继续生成

  • 输入序列:["今天", "天气", "真"](3个Token)
  • 模型重新计算这3个Token的注意力,输出:["好"]

注意到问题了吗?每次生成新Token时,模型都要重新计算前面所有Token的注意力!

重复计算的代价

让我们用数学来量化这个问题。

注意力计算回顾

在注意力机制中,对于每个Token,我们需要计算:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Q = X ⋅ W Q (Query) K = X ⋅ W K (Key) V = X ⋅ W V (Value) Output = softmax ( Q ⋅ K T d k ) ⋅ V \begin{aligned} Q &= X \cdot W_Q \quad \text{(Query)} \\ K &= X \cdot W_K \quad \text{(Key)} \\ V &= X \cdot W_V \quad \text{(Value)} \\ \text{Output} &= \text{softmax}\left(\frac{Q \cdot K^T}{\sqrt{d_k}}\right) \cdot V \end{aligned} </math>QKVOutput=X⋅WQ(Query)=X⋅WK(Key)=X⋅WV(Value)=softmax(dk Q⋅KT)⋅V

重复计算示例

假设我们要生成长度为100的文本,每个生成步骤的计算量:

步骤 序列长度 需要计算的Token数 累计计算量
1 1 1 1
2 2 2 1+2=3
3 3 3 1+2+3=6
... ... ... ...
100 100 100 1+2+...+100=5050

总计算量 : <math xmlns="http://www.w3.org/1998/Math/MathML"> ∑ i = 1 100 i = 100 × 101 2 = 5050 \sum_{i=1}^{100} i = \frac{100 \times 101}{2} = 5050 </math>∑i=1100i=2100×101=5050 次Token的注意力计算

但实际上,真正需要的计算量只有100次!因为:

  • 第1个Token的K、V计算一次就够了
  • 第2个Token的K、V计算一次就够了
  • ...
  • 第100个Token的K、V计算一次就够了

问题的根源 :前面Token的K和V在每一步都被重新计算,但它们的值根本不会改变

KV Cache:缓存已计算的K和V

核心思想

KV Cache的思想非常简单:

既然每个Token的K和V只需要计算一次,那就把它们缓存起来,下次直接使用!

具体来说:

  1. 第1步:生成第1个Token

    • 计算: <math xmlns="http://www.w3.org/1998/Math/MathML"> K 1 , V 1 K_1, V_1 </math>K1,V1
    • 缓存:保存 <math xmlns="http://www.w3.org/1998/Math/MathML"> K 1 , V 1 K_1, V_1 </math>K1,V1
    • 输出:新Token
  2. 第2步:生成第2个Token

    • 计算:只计算新Token的 <math xmlns="http://www.w3.org/1998/Math/MathML"> K 2 , V 2 K_2, V_2 </math>K2,V2
    • 缓存:保存 <math xmlns="http://www.w3.org/1998/Math/MathML"> K 2 , V 2 K_2, V_2 </math>K2,V2,现在缓存中有 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ K 1 , K 2 ] , [ V 1 , V 2 ] [K_1, K_2], [V_1, V_2] </math>[K1,K2],[V1,V2]
    • 使用缓存:直接读取 <math xmlns="http://www.w3.org/1998/Math/MathML"> K 1 , V 1 K_1, V_1 </math>K1,V1,无需重新计算
    • 输出:新Token
  3. 第3步:生成第3个Token

    • 计算:只计算新Token的 <math xmlns="http://www.w3.org/1998/Math/MathML"> K 3 , V 3 K_3, V_3 </math>K3,V3
    • 缓存:保存 <math xmlns="http://www.w3.org/1998/Math/MathML"> K 3 , V 3 K_3, V_3 </math>K3,V3,现在缓存中有 <math xmlns="http://www.w3.org/1998/Math/MathML"> [ K 1 , K 2 , K 3 ] , [ V 1 , V 2 , V 3 ] [K_1, K_2, K_3], [V_1, V_2, V_3] </math>[K1,K2,K3],[V1,V2,V3]
    • 使用缓存:直接读取 <math xmlns="http://www.w3.org/1998/Math/MathML"> K 1 , K 2 , V 1 , V 2 K_1, K_2, V_1, V_2 </math>K1,K2,V1,V2,无需重新计算
    • 输出:新Token

性能提升

使用KV Cache后,生成100个Token的计算量:

步骤 需要新计算的KV 从缓存读取的KV 总计算量
1 1 0 1
2 1 1 2
3 1 2 3
... ... ... ...
100 1 99 100

总计算量 :100次(从5050次降到100次,加速50倍!)

KV Cache的数据结构

缓存的形状

对于一个多头注意力层:

  • 输入序列长度: <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n(已生成的Token数)
  • 模型维度: <math xmlns="http://www.w3.org/1998/Math/MathML"> d model d_{\text{model}} </math>dmodel(例如:4096)
  • 注意力头数: <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h(例如:32)
  • 每个头的维度: <math xmlns="http://www.w3.org/1998/Math/MathML"> d k = d model / h d_k = d_{\text{model}} / h </math>dk=dmodel/h(例如:128)
  • 层数: <math xmlns="http://www.w3.org/1998/Math/MathML"> L L </math>L(例如:32层)

每一层的KV Cache形状
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> K cache ∈ R n × h × d k V cache ∈ R n × h × d k \begin{aligned} K_{\text{cache}} &\in \mathbb{R}^{n \times h \times d_k} \\ V_{\text{cache}} &\in \mathbb{R}^{n \times h \times d_k} \end{aligned} </math>KcacheVcache∈Rn×h×dk∈Rn×h×dk

全模型的KV Cache形状(所有层):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Total K Cache ∈ R L × n × h × d k Total V Cache ∈ R L × n × h × d k \begin{aligned} \text{Total K Cache} &\in \mathbb{R}^{L \times n \times h \times d_k} \\ \text{Total V Cache} &\in \mathbb{R}^{L \times n \times h \times d_k} \end{aligned} </math>Total K CacheTotal V Cache∈RL×n×h×dk∈RL×n×h×dk

内存占用计算

假设使用FP16精度(每个数2字节),模型参数:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> L = 32 L = 32 </math>L=32 层
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> d model = 4096 d_{\text{model}} = 4096 </math>dmodel=4096
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> h = 32 h = 32 </math>h=32 头
  • 序列长度 <math xmlns="http://www.w3.org/1998/Math/MathML"> n = 2048 n = 2048 </math>n=2048

单个样本的KV Cache内存
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Memory = 2 × L × n × h × d k × 2 bytes = 2 × 32 × 2048 × 32 × 128 × 2 = 1 , 073 , 741 , 824 bytes = 1 GB \begin{aligned} \text{Memory} &= 2 \times L \times n \times h \times d_k \times 2 \text{ bytes} \\ &= 2 \times 32 \times 2048 \times 32 \times 128 \times 2 \\ &= 1,073,741,824 \text{ bytes} \\ &= 1 \text{ GB} \end{aligned} </math>Memory=2×L×n×h×dk×2 bytes=2×32×2048×32×128×2=1,073,741,824 bytes=1 GB

Batch推理的内存(batch_size=32)
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Total Memory = 1 GB × 32 = 32 GB \text{Total Memory} = 1 \text{ GB} \times 32 = 32 \text{ GB} </math>Total Memory=1 GB×32=32 GB

这就是为什么大模型推理需要大显存的原因之一!

实际例子:不同模型的KV Cache

模型 层数 d_model 头数 序列长度 单样本KV Cache Batch=32
GPT-2 Small 12 768 12 1024 36 MB 1.1 GB
LLaMA-7B 32 4096 32 2048 1 GB 32 GB
LLaMA-13B 40 5120 40 2048 1.6 GB 51 GB
LLaMA-65B 80 8192 64 2048 5.1 GB 163 GB
GPT-3 175B 96 12288 96 2048 9.2 GB 294 GB

可以看到,对于超大模型,KV Cache可能比模型权重本身还要占用更多显存

KV Cache的实现细节

伪代码实现

python 复制代码
class MultiHeadAttentionWithKVCache:
    def __init__(self, d_model, num_heads):
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # 权重矩阵
        self.W_Q = Parameter(torch.randn(d_model, d_model))
        self.W_K = Parameter(torch.randn(d_model, d_model))
        self.W_V = Parameter(torch.randn(d_model, d_model))
        self.W_O = Parameter(torch.randn(d_model, d_model))

        # KV Cache(初始为空)
        self.k_cache = []  # List of cached K tensors
        self.v_cache = []  # List of cached V tensors

    def forward(self, x, use_cache=True):
        """
        x: 输入Token的embedding,形状 (batch_size, 1, d_model)
           注意:推理时每次只输入1个新Token
        """
        batch_size = x.shape[0]

        # 计算新Token的Q、K、V
        Q_new = x @ self.W_Q  # (batch_size, 1, d_model)
        K_new = x @ self.W_K  # (batch_size, 1, d_model)
        V_new = x @ self.W_V  # (batch_size, 1, d_model)

        # 重塑为多头形状
        Q_new = Q_new.view(batch_size, 1, self.num_heads, self.d_k)
        K_new = K_new.view(batch_size, 1, self.num_heads, self.d_k)
        V_new = V_new.view(batch_size, 1, self.num_heads, self.d_k)

        if use_cache:
            # 将新的K、V添加到缓存
            self.k_cache.append(K_new)
            self.v_cache.append(V_new)

            # 拼接所有历史K、V
            K = torch.cat(self.k_cache, dim=1)  # (batch, seq_len, heads, d_k)
            V = torch.cat(self.v_cache, dim=1)
        else:
            K = K_new
            V = V_new

        # 计算注意力
        # Q: (batch, 1, heads, d_k) - 只有新Token的Query
        # K: (batch, seq_len, heads, d_k) - 所有Token的Key(包括历史)
        scores = torch.einsum('bqhd,bkhd->bhqk', Q_new, K) / math.sqrt(self.d_k)
        # scores: (batch, heads, 1, seq_len)

        attn_weights = F.softmax(scores, dim=-1)

        # 加权求和
        output = torch.einsum('bhqk,bkhd->bqhd', attn_weights, V)
        # output: (batch, 1, heads, d_k)

        # 重塑并投影
        output = output.reshape(batch_size, 1, self.d_model)
        output = output @ self.W_O

        return output

    def clear_cache(self):
        """清空KV Cache,开始新的生成任务"""
        self.k_cache = []
        self.v_cache = []

关键点解析

  1. 只计算新Token的K和V

    python 复制代码
    K_new = x @ self.W_K  # x的形状是(batch, 1, d_model),只有1个Token
  2. 从缓存读取历史K、V

    python 复制代码
    K = torch.cat(self.k_cache, dim=1)  # 拼接所有历史Token的K
  3. 注意力计算使用完整的K、V

    python 复制代码
    # Q只有1个Token(新Token)
    # K、V有n个Token(所有历史Token + 新Token)
    scores = torch.einsum('bqhd,bkhd->bhqk', Q_new, K)

KV Cache与位置编码的关系

这里有一个非常重要的问题:当我们使用KV Cache时,位置编码怎么办?

绝对位置编码的问题

回顾一下绝对位置编码(Sinusoidal或Learned):
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> X with_pos [ i ] = X [ i ] + PE [ i ] X_{\text{with\_pos}}[i] = X[i] + \text{PE}[i] </math>Xwith_pos[i]=X[i]+PE[i]

其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> PE [ i ] \text{PE}[i] </math>PE[i] 是第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 个位置的位置编码。

问题 :当使用KV Cache时,每次只输入1个新Token,但这个Token的绝对位置在不断变化!

举例:

  • 第1步:输入Token的位置是0,PE[0]
  • 第2步:输入Token的位置是1,PE[1]
  • 第3步:输入Token的位置是2,PE[2]
  • ...

看起来没问题?但实际上有个隐藏的问题:

缓存的K、V已经包含了位置编码信息

  • 第1个Token的K、V计算时使用了PE[0]
  • 第2个Token的K、V计算时使用了PE[1]
  • ...

所以,绝对位置编码在KV Cache场景下是兼容的,但需要注意:

  1. 必须传入正确的位置索引(当前Token是第几个)
  2. 位置编码表必须足够长(支持最大序列长度)

位置编码表(Position Embedding Table)

在实际实现中,位置编码通常预先计算并存储在一个位置编码表中:

python 复制代码
class PositionalEncoding:
    def __init__(self, d_model, max_seq_len=5000):
        # 预先计算所有位置的编码
        self.pe_table = torch.zeros(max_seq_len, d_model)

        position = torch.arange(0, max_seq_len).unsqueeze(1)  # (max_seq_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )

        # 偶数维度用sin
        self.pe_table[:, 0::2] = torch.sin(position * div_term)
        # 奇数维度用cos
        self.pe_table[:, 1::2] = torch.cos(position * div_term)

    def get_position_encoding(self, position):
        """
        position: 当前Token的位置索引(标量)
        返回: 该位置的位置编码向量 (d_model,)
        """
        return self.pe_table[position]

使用KV Cache时的流程

python 复制代码
# 第1步:生成第1个Token(位置0)
x_0 = token_embedding(token_0) + pe_table[0]  # 加上位置0的编码
output_0 = attention(x_0)

# 第2步:生成第2个Token(位置1)
x_1 = token_embedding(token_1) + pe_table[1]  # 加上位置1的编码
output_1 = attention(x_1)  # 使用KV Cache,读取位置0的K、V

# 第3步:生成第3个Token(位置2)
x_2 = token_embedding(token_2) + pe_table[2]  # 加上位置2的编码
output_2 = attention(x_2)  # 使用KV Cache,读取位置0、1的K、V

RoPE与KV Cache

RoPE(Rotary Position Embedding)是一种更现代的位置编码方式,它在计算注意力时动态地将位置信息旋转到Q和K中。

RoPE的优势

  1. 相对位置敏感:注意力分数只依赖于Token之间的相对距离
  2. 无需位置编码表:位置信息通过旋转矩阵动态计算
  3. 与KV Cache完美兼容:缓存的K已经包含了正确的位置信息

RoPE在KV Cache中的应用

python 复制代码
def apply_rotary_pos_emb(q, k, position):
    """
    应用旋转位置编码
    q, k: (batch, seq_len, heads, d_k)
    position: 当前Token的绝对位置
    """
    # 计算旋转角度
    theta = position / (10000 ** (torch.arange(0, d_k, 2) / d_k))

    # 构造旋转矩阵
    cos = torch.cos(theta)
    sin = torch.sin(theta)

    # 旋转Q和K
    q_rot = apply_rotation(q, cos, sin)
    k_rot = apply_rotation(k, cos, sin)

    return q_rot, k_rot

# 使用KV Cache时
Q_new = apply_rotary_pos_emb(Q_new, position=current_position)
K_new = apply_rotary_pos_emb(K_new, position=current_position)

# 缓存已旋转的K
k_cache.append(K_new)

关键点

  • 每个Token的K在计算时就已经包含了其位置信息(通过旋转)
  • 缓存的K不需要再次旋转
  • 新Token的Q旋转时使用其当前位置
  • 注意力计算时,Q和K的相对位置关系自动体现在旋转角度的差值中

位置编码与KV Cache的总结

位置编码类型 与KV Cache的兼容性 注意事项
绝对位置编码(Sinusoidal) ✅ 兼容 需要预先计算位置编码表,传入正确的位置索引
绝对位置编码(Learned) ✅ 兼容 同上,位置编码表是可学习参数
RoPE ✅ 完美兼容 缓存的K已包含位置信息,无需额外处理
ALiBi ✅ 完美兼容 位置偏置在计算注意力时动态添加

KV Cache的变体与优化

1. Multi-Query Attention (MQA)

问题:标准多头注意力中,每个头都有自己的K和V,导致KV Cache很大。

解决方案 :所有头共享一组K和V。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Q i = X ⋅ W i Q (每个头有独立的Q) K = X ⋅ W K (所有头共享K) V = X ⋅ W V (所有头共享V) \begin{aligned} Q_i &= X \cdot W_i^Q \quad \text{(每个头有独立的Q)} \\ K &= X \cdot W^K \quad \text{(所有头共享K)} \\ V &= X \cdot W^V \quad \text{(所有头共享V)} \end{aligned} </math>QiKV=X⋅WiQ(每个头有独立的Q)=X⋅WK(所有头共享K)=X⋅WV(所有头共享V)

优势

  • KV Cache大小减少 <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h 倍(头数)
  • 例如:32头变成1组K、V,内存占用减少32倍

劣势

  • 表达能力下降(所有头看到相同的K、V)

2. Grouped-Query Attention (GQA)

折中方案:将头分成若干组,每组共享K和V。

例如:32个头分成4组,每组8个头共享一组K、V。
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Group 1: heads 0-7 共享 K 1 , V 1 Group 2: heads 8-15 共享 K 2 , V 2 Group 3: heads 16-23 共享 K 3 , V 3 Group 4: heads 24-31 共享 K 4 , V 4 \begin{aligned} &\text{Group 1: heads 0-7 共享 } K_1, V_1 \\ &\text{Group 2: heads 8-15 共享 } K_2, V_2 \\ &\text{Group 3: heads 16-23 共享 } K_3, V_3 \\ &\text{Group 4: heads 24-31 共享 } K_4, V_4 \end{aligned} </math>Group 1: heads 0-7 共享 K1,V1Group 2: heads 8-15 共享 K2,V2Group 3: heads 16-23 共享 K3,V3Group 4: heads 24-31 共享 K4,V4

优势

  • KV Cache减少 <math xmlns="http://www.w3.org/1998/Math/MathML"> h / g h / g </math>h/g 倍( <math xmlns="http://www.w3.org/1998/Math/MathML"> g g </math>g是组数)
  • 保留了一定的多头表达能力

实际应用

  • LLaMA-2:使用GQA(8组)
  • Mistral:使用GQA(8组)

3. Paged Attention

问题:KV Cache是连续内存块,当序列很长时,可能无法分配足够大的连续内存。

解决方案:将KV Cache分成固定大小的"页",类似操作系统的虚拟内存。

python 复制代码
# 传统KV Cache:连续内存
k_cache = torch.zeros(batch, seq_len, heads, d_k)  # 需要连续的seq_len空间

# Paged Attention:分页存储
page_size = 16  # 每页存储16个Token的K/V
num_pages = seq_len // page_size
k_cache_pages = [
    torch.zeros(batch, page_size, heads, d_k) for _ in range(num_pages)
]

优势

  • 内存碎片友好
  • 支持动态序列长度
  • 减少内存浪费(不需要预先分配最大长度)

实际应用

  • vLLM:使用Paged Attention实现高效的批量推理

实际生成示例

让我们通过一个完整的例子来理解KV Cache的工作流程。

任务:生成文本 "今天天气真好"

初始状态

  • Prompt(用户输入):无(从头开始生成)
  • KV Cache:空

Step 1:生成"今天"

arduino 复制代码
输入:<BOS>(开始标记)
位置编码:PE[0]
计算:Q_0, K_0, V_0
KV Cache:K_0, V_0
输出:"今天"

Step 2:生成"天气"

css 复制代码
输入:"今天"
位置编码:PE[1]
计算:Q_1, K_1, V_1(只计算新Token)
KV Cache:[K_0, K_1], [V_0, V_1](添加新的K、V)
注意力:Q_1 attend to [K_0, K_1]
输出:"天气"

Step 3:生成"真"

css 复制代码
输入:"天气"
位置编码:PE[2]
计算:Q_2, K_2, V_2
KV Cache:[K_0, K_1, K_2], [V_0, V_1, V_2]
注意力:Q_2 attend to [K_0, K_1, K_2]
输出:"真"

Step 4:生成"好"

css 复制代码
输入:"真"
位置编码:PE[3]
计算:Q_3, K_3, V_3
KV Cache:[K_0, K_1, K_2, K_3], [V_0, V_1, V_2, V_3]
注意力:Q_3 attend to [K_0, K_1, K_2, K_3]
输出:"好"

性能对比

不使用KV Cache(每步重新计算)

  • Step 1:计算1个Token的KV → 1次
  • Step 2:计算2个Token的KV → 2次
  • Step 3:计算3个Token的KV → 3次
  • Step 4:计算4个Token的KV → 4次
  • 总计:1+2+3+4 = 10次KV计算

使用KV Cache

  • Step 1:计算K_0, V_0 → 1次
  • Step 2:计算K_1, V_1 → 1次
  • Step 3:计算K_2, V_2 → 1次
  • Step 4:计算K_3, V_3 → 1次
  • 总计 :4次KV计算(加速2.5倍

对于更长的序列(例如2048个Token),加速比接近1024倍!

KV Cache的管理策略

1. 固定长度截断

当序列超过最大长度时,丢弃最早的Token:

python 复制代码
max_cache_len = 2048

if len(k_cache) >= max_cache_len:
    # 移除最早的Token
    k_cache.pop(0)
    v_cache.pop(0)

# 添加新Token
k_cache.append(K_new)
v_cache.append(V_new)

优势 :简单,内存可控 劣势:可能丢失重要的历史信息

2. 滑动窗口

只保留最近的N个Token:

python 复制代码
window_size = 512

if len(k_cache) >= window_size:
    k_cache = k_cache[-window_size:]
    v_cache = v_cache[-window_size:]

优势 :专注于局部上下文 劣势:无法建模长距离依赖

3. 重要性采样

根据注意力权重,保留重要的Token:

python 复制代码
def prune_cache_by_attention(k_cache, v_cache, attention_weights, keep_ratio=0.5):
    # 计算每个Token的平均注意力分数
    importance = attention_weights.mean(dim=(0, 1))  # (seq_len,)

    # 选择重要性最高的Token
    num_keep = int(len(k_cache) * keep_ratio)
    keep_indices = torch.topk(importance, num_keep).indices

    # 只保留重要的Token
    k_cache = [k_cache[i] for i in keep_indices]
    v_cache = [v_cache[i] for i in keep_indices]

    return k_cache, v_cache

优势 :保留关键信息 劣势:计算复杂度高

4. H2O(Heavy-Hitter Oracle)

最新的研究表明,大多数注意力权重集中在少数"重要"Token上:

  • Heavy Hitters:注意力权重最高的Token(例如标点符号、关键词)
  • Recent Tokens:最近生成的Token

策略:只缓存Heavy Hitters + Recent Tokens

python 复制代码
def h2o_cache_management(k_cache, v_cache, attention_weights,
                         heavy_ratio=0.1, recent_ratio=0.1):
    seq_len = len(k_cache)

    # 计算累积注意力分数
    cumulative_attention = attention_weights.sum(dim=(0, 1, 2))  # (seq_len,)

    # 选择Heavy Hitters
    num_heavy = int(seq_len * heavy_ratio)
    heavy_indices = torch.topk(cumulative_attention, num_heavy).indices

    # 选择Recent Tokens
    num_recent = int(seq_len * recent_ratio)
    recent_indices = torch.arange(seq_len - num_recent, seq_len)

    # 合并索引
    keep_indices = torch.cat([heavy_indices, recent_indices]).unique()

    # 只保留选中的Token
    k_cache = [k_cache[i] for i in keep_indices]
    v_cache = [v_cache[i] for i in keep_indices]

    return k_cache, v_cache

优势

  • 大幅减少缓存大小(可减少90%)
  • 几乎不损失性能

小结

KV Cache的核心思想

  1. 问题:自回归生成时,每步都重新计算前面所有Token的K和V,导致大量重复计算
  2. 解决方案:缓存已计算的K和V,每步只计算新Token的K和V
  3. 性能提升 :对于长度N的序列,从 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N 2 ) O(N^2) </math>O(N2) 降到 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( N ) O(N) </math>O(N),加速可达N/2倍

位置编码与KV Cache

  1. 绝对位置编码:需要预先计算位置编码表,确保每个Token使用正确的位置索引
  2. RoPE:通过旋转矩阵动态编码位置,与KV Cache完美兼容
  3. ALiBi:通过注意力偏置编码位置,与KV Cache完美兼容

内存优化技术

  1. Multi-Query Attention (MQA) :所有头共享K和V,减少KV Cache <math xmlns="http://www.w3.org/1998/Math/MathML"> h h </math>h倍
  2. Grouped-Query Attention (GQA):头分组共享K和V,平衡性能和内存
  3. Paged Attention:分页存储KV Cache,减少内存碎片
  4. H2O:只缓存重要Token,减少90%缓存大小

实际应用

  • OpenAI GPT:使用KV Cache + 绝对位置编码
  • Meta LLaMA:使用KV Cache + RoPE + GQA
  • vLLM:使用KV Cache + Paged Attention,实现高效批量推理
  • DeepSeek:使用KV Cache + MLA(Multi-head Latent Attention),进一步压缩KV Cache

KV Cache是大模型推理加速的基石技术,几乎所有现代推理系统都依赖它来实现实时交互。理解KV Cache的原理,对于优化大模型部署和推理性能至关重要。

相关推荐
孟陬1 小时前
国外技术周刊 #1:Paul Graham 重新分享最受欢迎的文章《创作者的品味》、本周被划线最多 YouTube《如何在 19 分钟内学会 AI》、为何我不
java·前端·后端
想用offer打牌1 小时前
一站式了解四种限流算法
java·后端·go
嘻哈baby1 小时前
用 C++ 写线程池是怎样一种体验?
后端
嘻哈baby1 小时前
SQL Server 和 Oracle 以及 MySQL 有哪些区别?
后端
绝无仅有1 小时前
Redis过期删除与内存淘汰策略详解
后端·面试·架构
武子康2 小时前
大数据-237 离线数仓 - Hive 广告业务实战:ODS→DWD 事件解析、广告明细与转化分析落地
大数据·后端·apache hive
绝无仅有2 小时前
Redis大Key问题排查与解决方案全解析
后端·面试·架构
舒一笑2 小时前
Ubuntu系统安装CodeX出现问题
linux·后端
golang学习记2 小时前
GitLens 十大神技:彻底改变你在 VS Code 中的 Git 工作流
前端·后端·visual studio code