Self-Attention 与 KV Cache 的深度解析
- [Self-Attention 计算过程详解](#Self-Attention 计算过程详解)
- 自回归生成中的计算冗余
-
- [2.1 自回归生成的特点](#2.1 自回归生成的特点)
- [2.2 冗余计算分析](#2.2 冗余计算分析)
- [KV Cache:消除冗余的优雅方案](#KV Cache:消除冗余的优雅方案)
- 数学形式化推导
-
- [4.1 无 Cache 的注意力计算](#4.1 无 Cache 的注意力计算)
- [4.2 有 Cache 的注意力计算](#4.2 有 Cache 的注意力计算)
- [4.3 重要观察](#4.3 重要观察)
- 代码实现详解
-
- [5.1 不带 KV Cache 的朴素实现](#5.1 不带 KV Cache 的朴素实现)
- [5.2 带 KV Cache 的高效实现](#5.2 带 KV Cache 的高效实现)
- [5.3 性能对比测试](#5.3 性能对比测试)
- [KV Cache 的内存占用分析](#KV Cache 的内存占用分析)
-
- [6.1 内存计算公式](#6.1 内存计算公式)
- [6.2 不同模型的 KV Cache 占用](#6.2 不同模型的 KV Cache 占用)
- [KV Cache 的变体与优化](#KV Cache 的变体与优化)
-
- [7.1 多层缓存复用](#7.1 多层缓存复用)
- [7.2 稀疏缓存](#7.2 稀疏缓存)
- [7.3 分页缓存 (PagedAttention)](#7.3 分页缓存 (PagedAttention))
- 工程实践与注意事项
-
- [8.1 何时使用 KV Cache](#8.1 何时使用 KV Cache)
- [8.2 内存管理技巧](#8.2 内存管理技巧)
- [8.3 批处理中的 KV Cache](#8.3 批处理中的 KV Cache)
- [KV Cache 与注意力机制的演进关系](#KV Cache 与注意力机制的演进关系)
-
- [9.1 技术演进路线](#9.1 技术演进路线)
- [9.2 数学本质的统一视角](#9.2 数学本质的统一视角)
- 总结
-
- [10.1 核心要点](#10.1 核心要点)
- [10.2 一句话总结](#10.2 一句话总结)
- 参考文献
在 Transformer 模型的推理加速技术中,KV Cache 是最核心、最基础的优化手段。要深入理解 KV Cache 为何能带来数十倍的推理加速,必须先透彻理解 Self-Attention 的计算过程。本文将带你从零开始,逐步推导两者之间的紧密联系。
Self-Attention 计算过程详解
1.1 全量计算视角
假设我们有一个输入序列,长度为 n n n,每个 token 的维度为 d d d。在标准的 Self-Attention 计算中,我们需要为序列中的每个位置计算与其他所有位置的相关性。
输入表示
X ∈ R n × d X \in \mathbb{R}^{n \times d} X∈Rn×d
其中 X X X 是输入序列的 token 嵌入矩阵, n n n 是序列长度, d d d 是隐藏层维度。
生成 Q、K、V 矩阵
通过三个线性变换矩阵,我们将 X X X 映射到查询空间、键空间和值空间:
Q = X W Q , K = X W K , V = X W V Q = XW^Q, \quad K = XW^K, \quad V = XW^V Q=XWQ,K=XWK,V=XWV
其中:
- W Q ∈ R d × d k W^Q \in \mathbb{R}^{d \times d_k} WQ∈Rd×dk, W K ∈ R d × d k W^K \in \mathbb{R}^{d \times d_k} WK∈Rd×dk, W V ∈ R d × d v W^V \in \mathbb{R}^{d \times d_v} WV∈Rd×dv
- 通常 d k = d v = d / h d_k = d_v = d / h dk=dv=d/h, h h h 是注意力头数
- 输出维度: Q , K ∈ R n × d k Q, K \in \mathbb{R}^{n \times d_k} Q,K∈Rn×dk, V ∈ R n × d v V \in \mathbb{R}^{n \times d_v} V∈Rn×dv
注意力分数计算
计算 Query 和所有 Key 的点积,得到注意力分数矩阵:
S = Q K T ∈ R n × n S = QK^T \in \mathbb{R}^{n \times n} S=QKT∈Rn×n
展开形式:
S i j = Q i ⋅ K j T = ∑ m = 1 d k Q i , m ⋅ K j , m S_{ij} = Q_i \cdot K_j^T = \sum_{m=1}^{d_k} Q_{i,m} \cdot K_{j,m} Sij=Qi⋅KjT=m=1∑dkQi,m⋅Kj,m
缩放与 Softmax
为了防止点积结果过大导致梯度消失,除以 d k \sqrt{d_k} dk 并进行 Softmax 归一化:
A = softmax ( S d k ) ∈ R n × n A = \text{softmax}\left(\frac{S}{\sqrt{d_k}}\right) \in \mathbb{R}^{n \times n} A=softmax(dk S)∈Rn×n
其中 Softmax 按行进行:
A i j = exp ( S i j / d k ) ∑ k = 1 n exp ( S i k / d k ) A_{ij} = \frac{\exp(S_{ij}/\sqrt{d_k})}{\sum_{k=1}^{n} \exp(S_{ik}/\sqrt{d_k})} Aij=∑k=1nexp(Sik/dk )exp(Sij/dk )
加权求和
用注意力权重对 Value 进行加权求和:
O = A V ∈ R n × d v O = AV \in \mathbb{R}^{n \times d_v} O=AV∈Rn×dv
展开形式:
O i = ∑ j = 1 n A i j V j O_i = \sum_{j=1}^{n} A_{ij} V_j Oi=j=1∑nAijVj
计算复杂度分析
上述过程的计算复杂度为:
| 操作 | 计算量 | 内存占用 |
|---|---|---|
| Q K T QK^T QKT | O ( n 2 ⋅ d k ) O(n^2 \cdot d_k) O(n2⋅dk) | O ( n 2 ) O(n^2) O(n2) |
| Softmax | O ( n 2 ) O(n^2) O(n2) | O ( n 2 ) O(n^2) O(n2) |
| A V A V AV | O ( n 2 ⋅ d v ) O(n^2 \cdot d_v) O(n2⋅dv) | O ( n ⋅ d v ) O(n \cdot d_v) O(n⋅dv) |
| 总计 | O ( n 2 ⋅ d ) O(n^2 \cdot d) O(n2⋅d) | O ( n 2 ) O(n^2) O(n2) |
当 n n n 很大时(如 32K、128K), n 2 n^2 n2 项会成为不可承受之重。
自回归生成中的计算冗余
2.1 自回归生成的特点
在文本生成任务中,模型采用自回归方式逐个生成 token:
输入: "I love"
生成: "I love AI"
步骤:
Step 1: 输入 "I love" → 生成 "AI"
Step 2: 输入 "I love AI" → 生成 "."
Step 3: 输入 "I love AI." → 生成 "<eos>"
关键观察:每一步的输入都包含了上一步的全部输入。
2.2 冗余计算分析
以生成第 4 个 token 为例,对比 Step 2 和 Step 3 的计算:
Step 2 输入 : "I love AI"(长度 3)
Step 3 输入: "I love AI."(长度 4)
计算 Step 3 的 Self-Attention 时:
Q 1 : 4 = X 1 : 4 W Q Q_{1:4} = X_{1:4}W^Q Q1:4=X1:4WQ
K 1 : 4 = X 1 : 4 W K K_{1:4} = X_{1:4}W^K K1:4=X1:4WK
V 1 : 4 = X 1 : 4 W V V_{1:4} = X_{1:4}W^V V1:4=X1:4WV
关键问题 :Step 2 中已经计算过 K 1 : 3 K_{1:3} K1:3 和 V 1 : 3 V_{1:3} V1:3,Step 3 却要重新计算一遍!
对于长度为 n n n 的序列,生成第 n + 1 n+1 n+1 个 token 时:
- 已经计算过的 KV: K 1 : n K_{1:n} K1:n, V 1 : n V_{1:n} V1:n
- 需要新计算的 KV: K n + 1 K_{n+1} Kn+1, V n + 1 V_{n+1} Vn+1
- 需要重新计算的全量: K 1 : n K_{1:n} K1:n, V 1 : n V_{1:n} V1:n(如果不用 KV Cache)
这就造成了巨大的计算浪费。随着生成序列变长,浪费的比例趋近于 50%:
浪费比例 ≈ ∑ i = 1 n i ∑ i = 1 n ( i + 1 ) ≈ n ( n + 1 ) / 2 n ( n + 3 ) / 2 ≈ 1 − 2 n + 3 → n → ∞ 1 \text{浪费比例} \approx \frac{\sum_{i=1}^{n} i}{\sum_{i=1}^{n} (i+1)} \approx \frac{n(n+1)/2}{n(n+3)/2} \approx 1 - \frac{2}{n+3} \xrightarrow{n \to \infty} 1 浪费比例≈∑i=1n(i+1)∑i=1ni≈n(n+3)/2n(n+1)/2≈1−n+32n→∞ 1
KV Cache:消除冗余的优雅方案
3.1 核心思想
KV Cache 的核心思想是:用空间换时间 。既然每一步的 K K K 和 V V V 都会被重复使用,那就把它们缓存起来,避免重复计算。
缓存的内容
对于每个 Transformer 层,我们维护两个缓存矩阵:
K_cache = [ K 1 , K 2 , . . . , K t ] ∈ R t × d k \text{K\_cache} = [K_1, K_2, ..., K_t] \in \mathbb{R}^{t \times d_k} K_cache=[K1,K2,...,Kt]∈Rt×dk
V_cache = [ V 1 , V 2 , . . . , V t ] ∈ R t × d v \text{V\_cache} = [V_1, V_2, ..., V_t] \in \mathbb{R}^{t \times d_v} V_cache=[V1,V2,...,Vt]∈Rt×dv
其中 t t t 是当前已生成的 token 数量。
3.2 带 KV Cache 的计算流程
初始化阶段(Prefill Phase)
处理输入提示词,计算所有 token 的 KV,并缓存:
python
# 输入提示词长度 = n
for i in range(n):
K_i, V_i = compute_KV(X[i]) # 计算第 i 个 token 的 KV
K_cache.append(K_i)
V_cache.append(V_i)
O_i = attention(Q_i, K_cache, V_cache) # 用缓存计算输出
生成阶段(Generation Phase)
逐 token 生成,每次只计算新 token 的 KV:
python
# 已生成 t 个 token
while not finished:
# 1. 计算新 token 的 Q, K, V
Q_new = X_new @ W_Q
K_new = X_new @ W_K
V_new = X_new @ W_V
# 2. 更新缓存
K_cache = torch.cat([K_cache, K_new], dim=0)
V_cache = torch.cat([V_cache, V_new], dim=0)
# 3. 用完整缓存计算注意力
# Q_new: [1, d_k], K_cache: [t+1, d_k], V_cache: [t+1, d_v]
scores = Q_new @ K_cache.T # [1, t+1]
scores = scores / sqrt(d_k)
attn_weights = softmax(scores) # [1, t+1]
O_new = attn_weights @ V_cache # [1, d_v]
# 4. 生成下一个 token
next_token = generate(O_new)
X_new = embedding(next_token)
3.3 计算量对比
无 KV Cache
生成第 t + 1 t+1 t+1 个 token 时,需要计算:
FLOPs no-cache = 2 ( t + 1 ) 2 d k ⏟ QK T + 2 ( t + 1 ) 2 d v ⏟ AV \text{FLOPs}{\text{no-cache}} = \underbrace{2(t+1)^2 d_k}{\text{QK}^T} + \underbrace{2(t+1)^2 d_v}_{\text{AV}} FLOPsno-cache=QKT 2(t+1)2dk+AV 2(t+1)2dv
有 KV Cache
生成第 t + 1 t+1 t+1 个 token 时,只需要:
KaTeX parse error: Expected 'EOF', got '' at position 64: ...) d_k}{\text{Q_̲new @ K_cache}^...
加速比
Speedup = FLOPs no-cache FLOPs cache ≈ 2 ( t + 1 ) 2 d 2 ( t + 1 ) d = t + 1 \text{Speedup} = \frac{\text{FLOPs}{\text{no-cache}}}{\text{FLOPs}{\text{cache}}} \approx \frac{2(t+1)^2 d}{2(t+1)d} = t+1 Speedup=FLOPscacheFLOPsno-cache≈2(t+1)d2(t+1)2d=t+1
当 t = 1024 t=1024 t=1024 时,加速比达到 1024 倍!这就是 KV Cache 能带来数量级性能提升的原因。
数学形式化推导
4.1 无 Cache 的注意力计算
对于长度为 t t t 的序列,注意力输出为:
O ( t ) = softmax ( Q ( t ) K ( t ) T d k ) V ( t ) O^{(t)} = \text{softmax}\left(\frac{Q^{(t)} {K^{(t)}}^T}{\sqrt{d_k}}\right) V^{(t)} O(t)=softmax(dk Q(t)K(t)T)V(t)
其中:
- Q ( t ) = X ( t ) W Q ∈ R t × d k Q^{(t)} = X^{(t)} W^Q \in \mathbb{R}^{t \times d_k} Q(t)=X(t)WQ∈Rt×dk
- K ( t ) = X ( t ) W K ∈ R t × d k K^{(t)} = X^{(t)} W^K \in \mathbb{R}^{t \times d_k} K(t)=X(t)WK∈Rt×dk
- V ( t ) = X ( t ) W V ∈ R t × d v V^{(t)} = X^{(t)} W^V \in \mathbb{R}^{t \times d_v} V(t)=X(t)WV∈Rt×dv
4.2 有 Cache 的注意力计算
当生成第 t + 1 t+1 t+1 个 token 时,我们已有缓存:
K cache = K ( t ) ∈ R t × d k K_{\text{cache}} = K^{(t)} \in \mathbb{R}^{t \times d_k} Kcache=K(t)∈Rt×dk
V cache = V ( t ) ∈ R t × d v V_{\text{cache}} = V^{(t)} \in \mathbb{R}^{t \times d_v} Vcache=V(t)∈Rt×dv
新 token 的 Query、Key、Value:
Q t + 1 = x t + 1 W Q ∈ R 1 × d k Q_{t+1} = x_{t+1} W^Q \in \mathbb{R}^{1 \times d_k} Qt+1=xt+1WQ∈R1×dk
K t + 1 = x t + 1 W K ∈ R 1 × d k K_{t+1} = x_{t+1} W^K \in \mathbb{R}^{1 \times d_k} Kt+1=xt+1WK∈R1×dk
V t + 1 = x t + 1 W V ∈ R 1 × d v V_{t+1} = x_{t+1} W^V \in \mathbb{R}^{1 \times d_v} Vt+1=xt+1WV∈R1×dv
更新缓存:
K cache ′ = [ K cache ; K t + 1 ] ∈ R ( t + 1 ) × d k K_{\text{cache}}' = [K_{\text{cache}}; K_{t+1}] \in \mathbb{R}^{(t+1) \times d_k} Kcache′=[Kcache;Kt+1]∈R(t+1)×dk
V cache ′ = [ V cache ; V t + 1 ] ∈ R ( t + 1 ) × d v V_{\text{cache}}' = [V_{\text{cache}}; V_{t+1}] \in \mathbb{R}^{(t+1) \times d_v} Vcache′=[Vcache;Vt+1]∈R(t+1)×dv
计算注意力:
O t + 1 = softmax ( Q t + 1 K cache ′ T d k ) V cache ′ O_{t+1} = \text{softmax}\left(\frac{Q_{t+1} {K_{\text{cache}}'}^T}{\sqrt{d_k}}\right) V_{\text{cache}}' Ot+1=softmax(dk Qt+1Kcache′T)Vcache′
展开点积项:
Q t + 1 K cache ′ T = [ Q t + 1 K cache T , Q t + 1 K t + 1 T ] ∈ R 1 × ( t + 1 ) Q_{t+1} {K_{\text{cache}}'}^T = [Q_{t+1} K_{\text{cache}}^T, Q_{t+1} K_{t+1}^T] \in \mathbb{R}^{1 \times (t+1)} Qt+1Kcache′T=[Qt+1KcacheT,Qt+1Kt+1T]∈R1×(t+1)
4.3 重要观察
注意到 Q t + 1 K cache T Q_{t+1} K_{\text{cache}}^T Qt+1KcacheT 中包含了新 token 与所有历史 token 的相关性,而 Q t + 1 K t + 1 T Q_{t+1} K_{t+1}^T Qt+1Kt+1T 是自相关性。这正是 Transformer 能够捕捉长距离依赖的关键------新 token 可以看到所有历史信息。
代码实现详解
5.1 不带 KV Cache 的朴素实现
python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time
class SelfAttentionWithoutCache(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
# 线性变换矩阵
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x):
# x: [batch_size, seq_len, d_model]
batch_size, seq_len, _ = x.shape
# 1. 计算 Q, K, V
Q = self.W_q(x) # [batch, seq, d_model]
K = self.W_k(x) # [batch, seq, d_model]
V = self.W_v(x) # [batch, seq, d_model]
# 2. 分头
Q = Q.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
# Q: [batch, n_heads, seq_len, d_k]
# 3. 计算注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
# scores: [batch, n_heads, seq_len, seq_len]
# 4. Softmax
attn_weights = F.softmax(scores, dim=-1)
# 5. 加权求和
context = torch.matmul(attn_weights, V)
# context: [batch, n_heads, seq_len, d_k]
# 6. 合并头
context = context.transpose(1, 2).contiguous().view(
batch_size, seq_len, self.d_model
)
# 7. 输出投影
output = self.W_o(context)
return output
5.2 带 KV Cache 的高效实现
python
class SelfAttentionWithCache(nn.Module):
def __init__(self, d_model, n_heads, max_seq_len=2048):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.max_seq_len = max_seq_len
# 线性变换矩阵
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
# 初始化缓存
self.reset_cache()
def reset_cache(self):
"""重置 KV 缓存"""
self.K_cache = None
self.V_cache = None
def forward(self, x, use_cache=False):
"""
x: [batch_size, seq_len, d_model]
如果 use_cache=True,假设 x 只包含新 token
"""
batch_size, seq_len, _ = x.shape
# 计算 Q, K, V
Q = self.W_q(x)
K = self.W_k(x)
V = self.W_v(x)
# 分头
Q = Q.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
if use_cache and self.K_cache is not None:
# 使用缓存的 K, V
# K_cache, V_cache: [batch, n_heads, cached_len, d_k]
K = torch.cat([self.K_cache, K], dim=2)
V = torch.cat([self.V_cache, V], dim=2)
# 更新缓存
self.K_cache = K
self.V_cache = V
# 计算注意力分数(Q 只关注新 token 的 query)
if use_cache and seq_len == 1:
# 生成阶段:只计算新 token 的 query
# Q: [batch, n_heads, 1, d_k]
# K: [batch, n_heads, cached_len, d_k]
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
# scores: [batch, n_heads, 1, cached_len]
else:
# 预填充阶段:计算所有 token 的 query
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
# scores: [batch, n_heads, seq_len, seq_len+cached_len]
# 应用 causal mask(确保不能看到未来 token)
if not use_cache or seq_len > 1:
mask = torch.triu(torch.ones_like(scores), diagonal=1).bool()
scores = scores.masked_fill(mask, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
context = torch.matmul(attn_weights, V)
# 合并头
context = context.transpose(1, 2).contiguous().view(
batch_size, -1, self.d_model
)
output = self.W_o(context)
return output
5.3 性能对比测试
python
def test_performance():
d_model = 512
n_heads = 8
batch_size = 1
model_without_cache = SelfAttentionWithoutCache(d_model, n_heads)
model_with_cache = SelfAttentionWithCache(d_model, n_heads)
# 测试不同序列长度
for seq_len in [128, 256, 512, 1024, 2048]:
x = torch.randn(batch_size, seq_len, d_model)
# 无 Cache
start = time.time()
_ = model_without_cache(x)
time_without = time.time() - start
# 有 Cache (模拟生成过程)
model_with_cache.reset_cache()
start = time.time()
# 预填充阶段
_ = model_with_cache(x, use_cache=True)
# 生成 100 个新 token
for _ in range(100):
new_token = torch.randn(batch_size, 1, d_model)
_ = model_with_cache(new_token, use_cache=True)
time_with_cache = time.time() - start
print(f"Seq Len {seq_len}:")
print(f" Without Cache: {time_without*1000:.2f} ms (prefill only)")
print(f" With Cache: {time_with_cache*1000:.2f} ms (prefill + 100 gen)")
print(f" Speedup: {time_without*100/(time_with_cache/100):.1f}x per token")
print()
预期输出:
Seq Len 1024:
Without Cache: 45.32 ms (prefill only)
With Cache: 68.45 ms (prefill + 100 gen)
Speedup: 66.2x per token
KV Cache 的内存占用分析
6.1 内存计算公式
KV Cache 的内存占用可以精确计算:
Memory KV = 2 × n layers × n heads × d head × seq_len × bytes_per_param \text{Memory}{\text{KV}} = 2 \times n{\text{layers}} \times n_{\text{heads}} \times d_{\text{head}} \times \text{seq\_len} \times \text{bytes\_per\_param} MemoryKV=2×nlayers×nheads×dhead×seq_len×bytes_per_param
其中:
- n layers n_{\text{layers}} nlayers:Transformer 层数
- n heads n_{\text{heads}} nheads:注意力头数
- d head d_{\text{head}} dhead:每个头的维度
- seq_len \text{seq\_len} seq_len:当前序列长度
- bytes_per_param \text{bytes\_per\_param} bytes_per_param:参数精度(FP16=2 bytes,FP32=4 bytes)
实际计算示例
以 LLaMA-7B 为例:
- n layers = 32 n_{\text{layers}} = 32 nlayers=32
- n heads = 32 n_{\text{heads}} = 32 nheads=32
- d head = 128 d_{\text{head}} = 128 dhead=128
- FP16 精度:2 bytes
当 seq_len = 2048 \text{seq\_len} = 2048 seq_len=2048 时:
Memory = 2 × 32 × 32 × 128 × 2048 × 2 bytes \text{Memory} = 2 \times 32 \times 32 \times 128 \times 2048 \times 2 \text{ bytes} Memory=2×32×32×128×2048×2 bytes
= 2 × 32 × 32 × 128 × 2048 × 2 = 2 \times 32 \times 32 \times 128 \times 2048 \times 2 =2×32×32×128×2048×2
= 2 × 32 × 32 × 524288 = 2 \times 32 \times 32 \times 524288 =2×32×32×524288
= 2 × 32 × 16777216 = 2 \times 32 \times 16777216 =2×32×16777216
= 2 × 536870912 = 2 \times 536870912 =2×536870912
= 1 , 073 , 741 , 824 bytes ≈ 1.07 GB = 1,073,741,824 \text{ bytes} \approx 1.07 \text{ GB} =1,073,741,824 bytes≈1.07 GB
当 seq_len = 1 M \text{seq\_len} = 1M seq_len=1M 时:
Memory ≈ 1.07 GB × 1 M 2048 ≈ 534 GB \text{Memory} \approx 1.07 \text{ GB} \times \frac{1M}{2048} \approx 534 \text{ GB} Memory≈1.07 GB×20481M≈534 GB
这就是为什么超长上下文需要 GQA、MLA 等技术来压缩 KV Cache!
6.2 不同模型的 KV Cache 占用
| 模型 | 层数 | 头数 | 头维度 | 8K 上下文 | 32K 上下文 | 128K 上下文 |
|---|---|---|---|---|---|---|
| LLaMA-7B | 32 | 32 | 128 | 4.2 GB | 16.8 GB | 67.2 GB |
| LLaMA-70B | 80 | 64 | 128 | 21.0 GB | 84.0 GB | 336 GB |
| Qwen2-7B (GQA) | 32 | 8 (GQA) | 128 | 1.05 GB | 4.2 GB | 16.8 GB |
| DeepSeek-V2 (MLA) | 60 | 16 (压缩) | 128 | 0.26 GB | 1.05 GB | 4.2 GB |
KV Cache 的变体与优化
7.1 多层缓存复用
除了每层的 KV Cache,现代推理引擎还实现了更高级的缓存策略:
python
class MultiLayerCache:
def __init__(self, num_layers):
self.layer_caches = [LayerCache() for _ in range(num_layers)]
def reuse_across_layers(self, layer_idx, k, v):
"""某些情况下,浅层的 KV 可以复用于深层"""
if layer_idx > 0 and self.similarity_high(layer_idx):
return self.layer_caches[layer_idx-1].get()
return k, v
7.2 稀疏缓存
对于超长上下文,可以只缓存重要的 token:
python
class SparseKVCache:
def __init__(self, max_cache_size=2048):
self.max_cache_size = max_cache_size
self.k_cache = []
self.v_cache = []
self.scores = []
def update(self, k_new, v_new, importance_score):
# 只保留最重要的 token
self.k_cache.append(k_new)
self.v_cache.append(v_new)
self.scores.append(importance_score)
if len(self.k_cache) > self.max_cache_size:
# 移除最不重要的 token
min_idx = argmin(self.scores)
self.k_cache.pop(min_idx)
self.v_cache.pop(min_idx)
self.scores.pop(min_idx)
7.3 分页缓存 (PagedAttention)
vLLM 等推理引擎实现了类似虚拟内存的分页机制:
┌─────────────────────────────────────────────────────────────┐
│ PagedAttention 结构 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 逻辑 KV 块: [Block 0] [Block 1] [Block 2] ... │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ 物理内存: [Page 5] [Page 12] [Page 3] ... │
│ │
│ 优点:消除内部碎片,支持动态序列长度 │
│ │
└─────────────────────────────────────────────────────────────┘
工程实践与注意事项
8.1 何时使用 KV Cache
| 场景 | 是否使用 KV Cache | 原因 |
|---|---|---|
| 训练 | ❌ 不使用 | 数据是并行的,无需缓存 |
| 预填充阶段 | ✅ 使用 | 初始化缓存 |
| 生成阶段 | ✅ 使用 | 大幅加速推理 |
| 批处理推理 | ✅ 使用 | 每个序列独立缓存 |
8.2 内存管理技巧
python
class OptimizedAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
# 使用连续的张量提高访存效率
self.register_buffer('K_cache', None)
self.register_buffer('V_cache', None)
def maybe_compress_cache(self, max_cache_len=4096):
"""缓存超过阈值时进行压缩"""
if self.K_cache.size(2) > max_cache_len:
# 平均池化压缩
self.K_cache = F.avg_pool1d(
self.K_cache.transpose(1, 2),
kernel_size=2, stride=2
).transpose(1, 2)
self.V_cache = F.avg_pool1d(
self.V_cache.transpose(1, 2),
kernel_size=2, stride=2
).transpose(1, 2)
8.3 批处理中的 KV Cache
python
class BatchedAttentionWithCache(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
# ... 初始化 ...
def forward(self, x, cache=None, cache_lengths=None):
"""
x: [batch_size, seq_len, d_model]
cache: [batch_size, max_cache_len, n_heads, d_k]
cache_lengths: [batch_size] 每个序列的实际缓存长度
"""
batch_size, seq_len, _ = x.shape
# 为每个序列计算 QKV
Q = self.W_q(x) # [batch, seq, d_model]
# 处理不同长度的缓存
if cache is not None:
# 创建掩码,只关注有效缓存
mask = torch.arange(cache.size(1), device=x.device)[None, :] < cache_lengths[:, None]
# mask: [batch, max_cache_len]
# 扩展维度用于注意力计算
mask = mask[:, None, None, :] # [batch, 1, 1, max_cache_len]
# 注意力计算时应用掩码
scores = torch.matmul(Q.view(batch_size, seq_len, self.n_heads, self.d_k),
cache.transpose(1, 2)) / math.sqrt(self.d_k)
scores = scores.masked_fill(~mask, float('-inf'))
KV Cache 与注意力机制的演进关系
9.1 技术演进路线
┌─────────────────────────────────────────────────────────────┐
│ 从全量计算到 KV Cache 的演进 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 2017: 全注意力 (原始 Transformer) │
│ ↓ 发现:生成阶段存在大量重复计算 │
│ 2018: KV Cache 提出 │
│ ↓ 优化:用空间换时间,缓存 K/V 矩阵 │
│ 2022: 多查询注意力 (MQA) │
│ ↓ 优化:共享 KV 头,减少缓存 │
│ 2023: 分组查询注意力 (GQA) │
│ ↓ 优化:平衡性能和缓存 │
│ 2024: 多词元潜在注意力 (MLA) │
│ ↓ 优化:压缩 KV 缓存 90%+ │
│ 2025: 分页注意力 (PagedAttention) │
│ ↓ 优化:高效内存管理,支持动态序列 │
│ 2026: 稀疏 KV Cache │
│ → 优化:只缓存重要 token │
│ │
└─────────────────────────────────────────────────────────────┘
9.2 数学本质的统一视角
无论是原始的全注意力,还是带 KV Cache 的高效版本,其数学本质都是通过键值对存储和检索信息:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V
KV Cache 只是将历史的 K K K 和 V V V 保存下来,避免重复计算,并没有改变注意力机制的数学形式。这体现了工程优化对理论模型的忠实保留。
总结
10.1 核心要点
┌─────────────────────────────────────────────────────────────┐
│ KV Cache 核心要点 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1️⃣ 为什么需要 KV Cache │
│ • 自回归生成存在大量重复计算 │
│ • 每一步都要重新计算所有历史 token 的 KV │
│ • 浪费比例随序列长度增长趋近 100% │
│ │
│ 2️⃣ KV Cache 的工作原理 │
│ • 缓存所有历史 token 的 K 和 V 矩阵 │
│ • 新 token 只需计算自己的 K、V │
│ • 注意力计算时使用完整缓存 │
│ │
│ 3️⃣ 性能收益 │
│ • 计算复杂度:从 O(n²) 降至 O(n) │
│ • 加速比:约等于当前序列长度 │
│ • 内存开销:每层 2 × n × d_k × bytes │
│ │
│ 4️⃣ 优化方向 │
│ • GQA/MQA:共享 KV 头,减少缓存 │
│ • MLA:压缩 KV 表示,降低 90% 内存 │
│ • PagedAttention:分页管理,消除碎片 │
│ • 稀疏缓存:只保留重要 token │
│ │
└─────────────────────────────────────────────────────────────┘
10.2 一句话总结
KV Cache 是 Transformer 推理加速的基石,它通过缓存历史计算的键值对,将自回归生成的复杂度从 O(n²) 降至 O(n),是实现大模型高效推理不可或缺的技术! 🚀
参考文献
- Vaswani et al. "Attention Is All You Need" (2017)
- Pope et al. "Efficiently Scaling Transformer Inference" (2022)
- Kwon et al. "Efficient Memory Management for Large Language Model Serving with PagedAttention" (2023)
- DeepSeek-AI. "DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model" (2024)
本文为原创内容,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。