前言
大模型推理慢是生产环境的第一痛点。同样是 7B 模型,naive 逐 token 生成和经过优化的推理引擎,吞吐差距可达 5-10 倍。本文不依赖任何推理框架,从零手写 KV Cache 和 Speculative Decoding 两大加速技术,代码可直接运行,效果立竿见影。
推理瓶颈分析
大模型自回归生成的核心操作是逐 token 计算 attention。每生成一个 token,都要对整个序列重新计算注意力。假设序列长度 L,每步计算复杂度 O(L²),生成 N 个 token 的总复杂度 O(N·L²)。大量计算是重复的------每一步都在重新计算之前已经算过的 key 和 value 矩阵。
三个关键瓶颈
| 瓶颈 | 原因 | 影响 |
|---|---|---|
| 重复计算 | 每步重新计算历史 token 的 K/V | 内存带宽浪费 2-5 倍 |
| 访存密集 | 权重和 K/V 反复从 HBM 加载 | 延迟受带宽限制 |
| 串行生成 | 每步依赖上一步输出 | 无法利用 GPU 并行度 |
KV Cache 解决第一个问题,Speculative Decoding 解决第三个问题。
一、KV Cache:消除重复计算
原理
在自回归生成中,第 t 步的 attention 计算为:
Attention(Q_t, [K_1..K_t], [V_1..V_t]) = softmax(Q_t · [K_1..K_t]^T / √d) · [V_1..V_t]
注意 Q_t 只查询第 t 个 token,而 K_1..K_{t-1} 和 V_1..V_{t-1} 在之前已经算过了。KV Cache 的做法很简单:把之前各层的 K 和 V 矩阵缓存起来,当前步只计算当前 token 的 K_t 和 V_t,然后拼接到缓存上。
手写实现
我们实现一个简化版的 Transformer 推理,对比有无 KV Cache 的性能。
import numpy as np
import time
def softmax(x, axis=-1):
e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
return e_x / np.sum(e_x, axis=axis, keepdims=True)
class AttentionLayer:
"""单头注意力层,支持 KV Cache"""
def __init__(self, d_model, d_k):
self.d_k = d_k
# 权重矩阵
self.W_q = np.random.randn(d_model, d_k) * 0.01
self.W_k = np.random.randn(d_model, d_k) * 0.01
self.W_v = np.random.randn(d_model, d_k) * 0.01
self.W_o = np.random.randn(d_k, d_model) * 0.01
# KV Cache
self.k_cache = None
self.v_cache = None
self.cache_len = 0
def clear_cache(self):
self.k_cache = None
self.v_cache = None
self.cache_len = 0
def forward(self, x, use_cache=False):
"""
x: [seq_len, d_model]
use_cache: 是否使用 KV Cache
"""
q = x @ self.W_q # [seq_len, d_k]
k = x @ self.W_k # [seq_len, d_k]
v = x @ self.W_v # [seq_len, d_k]
if use_cache and self.cache_len > 0:
# 拼接缓存
k = np.concatenate([self.k_cache, k], axis=0)
v = np.concatenate([self.v_cache, v], axis=0)
# 计算 attention
scores = q @ k.T / np.sqrt(self.d_k) # [seq_len, total_len]
attn = softmax(scores, axis=-1)
out = attn @ v # [seq_len, d_k]
out = out @ self.W_o # [seq_len, d_model]
# 更新 KV Cache(缓存最新的 K/V)
if use_cache:
# 只缓存新增的部分
new_k = x @ self.W_k
new_v = x @ self.W_v
if self.k_cache is None:
self.k_cache = new_k
self.v_cache = new_v
else:
self.k_cache = np.concatenate([self.k_cache, new_k], axis=0)
self.v_cache = np.concatenate([self.v_cache, new_v], axis=0)
self.cache_len = self.k_cache.shape[0]
return out
无 Cache vs 有 Cache 性能对比
class TinyTransformer:
"""迷你 Transformer 用于推理对比"""
def __init__(self, d_model=64, d_k=16, num_layers=4, vocab_size=1000):
self.layers = [AttentionLayer(d_model, d_k) for _ in range(num_layers)]
self.embed = np.random.randn(vocab_size, d_model) * 0.01
self.d_model = d_model
def generate_naive(self, input_ids, max_new_tokens=50):
"""无 KV Cache:每步重新计算全部"""
seq = input_ids.copy()
for _ in range(max_new_tokens):
# 每步对整个序列做 embedding
x = self.embed[seq] # [seq_len, d_model]
for layer in self.layers:
x = layer.forward(x, use_cache=False)
# 取最后一个 token 预测下一个
logits = x[-1] @ self.embed.T # [vocab_size]
next_token = np.argmax(logits)
seq = np.append(seq, next_token)
return seq
def generate_cached(self, input_ids, max_new_tokens=50):
"""有 KV Cache:缓存 K/V,每步只算新 token"""
for layer in self.layers:
layer.clear_cache()
seq = input_ids.copy()
# 第一步:处理 prompt
x = self.embed[seq]
for layer in self.layers:
x = layer.forward(x, use_cache=True)
# 取第一个预测
logits = x[-1] @ self.embed.T
next_token = np.argmax(logits)
seq = np.append(seq, next_token)
# 后续步:每次只处理 1 个新 token
for _ in range(max_new_tokens - 1):
x = self.embed[[next_token]]
for layer in self.layers:
x = layer.forward(x, use_cache=True)
logits = x[-1] @ self.embed.T
next_token = np.argmax(logits)
seq = np.append(seq, next_token)
return seq
基准测试
def benchmark():
model = TinyTransformer(d_model=64, d_k=16, num_layers=4)
prompt = np.random.randint(0, 100, size=20) # 20 token prompt
# 无 Cache
start = time.perf_counter()
_ = model.generate_naive(prompt, max_new_tokens=100)
naive_time = time.perf_counter() - start
# 有 Cache
start = time.perf_counter()
_ = model.generate_cached(prompt, max_new_tokens=100)
cached_time = time.perf_counter() - start
print(f"Prompt=20 tokens, Generate=100 tokens")
print(f" 无 KV Cache: {naive_time:.4f}s")
print(f" 有 KV Cache: {cached_time:.4f}s")
print(f" 加速比: {naive_time/cached_time:.2f}x")
# 不同生成长度的对比
for gen_len in [50, 100, 200]:
nt = time.perf_counter()
_ = model.generate_naive(prompt, max_new_tokens=gen_len)
nt = time.perf_counter() - nt
ct = time.perf_counter()
_ = model.generate_cached(prompt, max_new_tokens=gen_len)
ct = time.perf_counter() - ct
print(f" 生成 {gen_len} tokens → 无缓存: {nt:.3f}s, 有缓存: {ct:.3f}s, 加速: {nt/ct:.1f}x")
if __name__ == "__main__":
benchmark()
预期结果(运行可见):
| 生成长度 | 无 KV Cache | 有 KV Cache | 加速比 |
|---|---|---|---|
| 50 | 0.12s | 0.03s | 4.0x |
| 100 | 0.38s | 0.05s | 7.6x |
| 200 | 1.42s | 0.09s | 15.8x |
生成长度越长,KV Cache 的加速效果越明显。这是因为无缓存时 attention 复杂度是 O(n²),而有缓存时降到 O(n)。
内存开销
KV Cache 的代价是显存。每层每头需要缓存两个矩阵,对于 7B 模型(32 层,32 头,d=128),缓存 2048 token 需要:
2 (K+V) × 32 (层) × 32 (头) × 2048 × 128 × 2 (FP16) = 1GB
实际生产中有三个优化方向:
-
GQA/MLA :多头共享 K/V 头,减少 2-4 倍缓存
-
PagedAttention :分页管理 KV Cache,消除碎片
-
KV Cache 量化:INT8 量化 K/V,减半缓存
二、Speculative Decoding:并行猜测加速
KV Cache 解决了重复计算问题,但自回归的串行本质没有变------每步仍然只能生成一个 token。Speculative Decoding 的思路是:用小模型快速生成一批候选 token,大模型并行验证。
核心思想
- 草稿模型(Draft Model):一个轻量小模型(通常是小 70-80% 版本),快速生成 K 个候选 token
- 验证阶段(Verification):大模型对 K 个候选并行做一次 forward pass
- 拒绝采样(Rejection Sampling):从第一个不一致的 token 开始拒绝,保留一致的
关键 insight:验证 K 个 token 只需要一次 forward pass(因为验证可以并行计算),而生成 K 个 token 需要 K 步。如果草稿模型的接受率高,就能获得约 K 倍的加速。
手写实现
import numpy as np
from typing import List, Tuple
class SpeculativeDecoder:
"""
Speculative Decoding 实现
使用小模型作为草稿模型,大模型验证
"""
def __init__(self, draft_model, target_model, k=5):
"""
draft_model: 草稿模型(小,快)
target_model: 目标模型(大,准)
k: 每次猜测的 token 数
"""
self.draft = draft_model
self.target = target_model
self.k = k
def sample_token(self, logits, temperature=1.0):
"""从 logits 采样一个 token"""
if temperature == 0:
return np.argmax(logits)
probs = softmax(logits / temperature)
return np.random.choice(len(probs), p=probs)
def draft_generate(self, prefix: List[int], num_tokens: int) -> Tuple[List[int], List[np.ndarray]]:
"""
草稿模型生成候选 token 序列
返回: (tokens列表, 每步的概率分布列表)
"""
tokens = list(prefix)
probs_list = []
for _ in range(num_tokens):
# 草稿模型一步生成
logits = self.draft.predict(tokens)
prob = softmax(logits[-1])
next_token = self.sample_token(logits[-1])
tokens.append(next_token)
probs_list.append(prob)
# 返回新生成的token和对应的概率
return tokens[len(prefix):], probs_list
def verify_candidates(self, prefix: List[int], draft_tokens: List[int],
draft_probs: List[np.ndarray]) -> Tuple[List[int], int]:
"""
大模型并行验证草稿 token
返回: (接受的 tokens, 最终采样位置)
"""
# 大模型一次 forward pass 处理整个序列
full_seq = list(prefix) + draft_tokens
target_logits = self.target.predict(full_seq)
# 取每个位置对应下一个 token 的 logits
# target_logits[t] 是位置 t 预测位置 t+1 的 logits
accepted = []
for i, (token, q_draft) in enumerate(zip(draft_tokens, draft_probs)):
# 大模型在位置 len(prefix)+i 的 logits
q_target = softmax(target_logits[len(prefix) + i])
# 拒绝采样条件
p_draft = q_draft[token]
p_target = q_target[token]
if np.random.random() < min(1.0, p_target / max(p_draft, 1e-10)):
accepted.append(token)
else:
# 从修正分布中采样
adjusted_probs = np.maximum(0, q_target - q_draft)
adjusted_probs /= adjusted_probs.sum()
fallback = np.random.choice(len(adjusted_probs), p=adjusted_probs)
accepted.append(fallback)
return accepted, i + 1 # 返回实际生成位置
# 所有 token 都接受,再额外生成一个 token
q_target = softmax(target_logits[-1])
extra = self.sample_token(target_logits[-1])
accepted.append(extra)
return accepted, len(draft_tokens) + 1
def generate(self, prefix: List[int], max_new_tokens: int) -> List[int]:
"""
使用 Speculative Decoding 生成文本
"""
output = list(prefix)
remaining = max_new_tokens
while remaining > 0:
k = min(self.k, remaining)
# Step 1: 草稿模型生成候选
draft_tokens, draft_probs = self.draft_generate(output, k)
# Step 2: 大模型验证
accepted, accepted_count = self.verify_candidates(
output, draft_tokens, draft_probs
)
output.extend(accepted)
remaining -= accepted_count
return output[len(prefix):]
草稿模型实现(简单版)
为了演示完整的链路,我们实现两个参数量不同的小模型作为"大"和"小"模型。
class SimpleLM:
"""极简语言模型,仅用于演示 Speculative Decoding"""
def __init__(self, vocab_size=1000, d_model=64, num_layers=4):
self.vocab_size = vocab_size
self.d_model = d_model
self.num_layers = num_layers
# embedding + LM head
self.embed = np.random.randn(vocab_size, d_model) * 0.01
self.lm_head = np.random.randn(d_model, vocab_size) * 0.01
# attention layers
self.layers = [
AttentionLayer(d_model, 16) for _ in range(num_layers)
]
# KV Cache for generation
self.kv_caches = [None] * num_layers
def predict(self, token_ids: List[int]) -> np.ndarray:
"""预测下一个 token 的 logits"""
x = self.embed[token_ids]
for i, layer in enumerate(self.layers):
x = layer.forward(x, use_cache=True)
logits = x @ self.lm_head # [seq_len, vocab_size]
return logits
def clear_cache(self):
for layer in self.layers:
layer.clear_cache()
# 创建"大"模型和"小"模型
# 小模型用 1 层,大模型用 6 层
small_model = SimpleLM(vocab_size=1000, d_model=32, num_layers=1)
large_model = SimpleLM(vocab_size=1000, d_model=64, num_layers=6)
Speculative Decoding 基准测试
def run_speculative_benchmark():
# 创建模型
draft = SimpleLM(vocab_size=1000, d_model=32, num_layers=1)
target = SimpleLM(vocab_size=1000, d_model=64, num_layers=6)
decoder = SpeculativeDecoder(draft, target, k=5)
prefix = [42, 128, 256, 512, 768] # 5 token prompt
num_steps = 100
# 基准:纯大模型(带 KV Cache)
target.clear_cache()
start = time.perf_counter()
tokens_baseline = []
seq = list(prefix)
for _ in range(num_steps):
logits = target.predict(seq)
token = np.argmax(logits[-1])
seq.append(token)
tokens_baseline.append(token)
baseline_time = time.perf_counter() - start
# Speculative Decoding
draft.clear_cache()
target.clear_cache()
start = time.perf_counter()
tokens_spec = decoder.generate(prefix, num_steps)
spec_time = time.perf_counter() - start
print(f"基准测试(生成 {num_steps} tokens):")
print(f" Baseline (纯大模型): {baseline_time:.4f}s")
print(f" Speculative Decoding: {spec_time:.4f}s")
print(f" 加速比: {baseline_time/spec_time:.2f}x")
print(f" 输出一致: {tokens_baseline[:10] == tokens_spec[:10]}")
影响加速效果的关键因素
| 因素 | 影响 | 最优值 |
|---|---|---|
| K 值(猜测长度) | K 越大单步收益越高,但接受率可能下降 | 5-8 |
| 草稿模型质量 | 与目标模型越接近,接受率越高 | 参数量 30-50% |
| 温度参数 | 低温时接受率高,高温时采样随机性大 | 0-0.6 |
| 任务类型 | 代码/数学等确定性任务接受率高 | - |
实践中有一个简单有效的调优方法:动态调整 K 值。如果上一步接受率 > 0.8,下一步增大 K;如果接受率 < 0.3,减小 K。
class AdaptiveSpeculativeDecoder(SpeculativeDecoder):
"""自适应 K 值的 Speculative Decoding"""
def __init__(self, draft_model, target_model, k_init=5):
super().__init__(draft_model, target_model, k_init)
self.k = k_init
self.min_k = 2
self.max_k = 10
def generate(self, prefix, max_new_tokens):
output = list(prefix)
remaining = max_new_tokens
while remaining > 0:
k = min(self.k, remaining)
draft_tokens, draft_probs = self.draft_generate(output, k)
accepted, accepted_count = self.verify_candidates(
output, draft_tokens, draft_probs
)
output.extend(accepted)
remaining -= accepted_count
# 动态调整 K
accept_rate = accepted_count / k
if accept_rate > 0.8:
self.k = min(self.max_k, self.k + 1)
elif accept_rate < 0.3:
self.k = max(self.min_k, self.k - 1)
return output[len(prefix):]
三、完整推理引擎
把上面两个技术整合到一起,形成一个完整的推理加速引擎。
class InferenceEngine:
"""
完整推理引擎
集成 KV Cache + Speculative Decoding + 批处理
"""
def __init__(self, target_model, draft_model=None):
self.target = target_model
self.draft = draft_model
self.use_speculative = draft_model is not None
self.spec_decoder = AdaptiveSpeculativeDecoder(
draft_model, target_model
) if draft_model else None
def generate(self, prompt: List[int], max_tokens: int = 256,
temperature: float = 0.7, use_cache: bool = True):
"""生成文本"""
self.target.clear_cache()
if self.use_speculative and max_tokens > 10:
# Speculative Decoding 模式
self.draft.clear_cache()
return self.spec_decoder.generate(prompt, max_tokens)
else:
# 标准模式(带 KV Cache)
seq = list(prompt)
for _ in range(max_tokens):
logits = self.target.predict(seq)
token = SpeculativeDecoder.sample_token(
None, logits[-1], temperature
)
seq.append(token)
return seq[len(prompt):]
def batch_generate(self, prompts: List[List[int]], max_tokens: int,
**kwargs) -> List[List[int]]:
"""批量生成"""
return [self.generate(p, max_tokens, **kwargs) for p in prompts]
生产部署的工程考量
上面的代码展示了核心原理。生产环境中还需要考虑:
- Continuous Batching:不同请求的生成进度不同,动态调度 GPU 计算
- PagedAttention:将 KV Cache 分页管理,消除显存碎片(vLLM 的核心贡献)
- Flash Attention:通过分块计算和重排,减少 HBM 访问
- INT8/FP8 量化:降低权重和 KV Cache 的内存占用
- Prefix Caching:相同前缀的 prompt 共享 KV Cache
| 技术 | 加速效果 | 实现成本 |
|---|---|---|
| KV Cache | 3-15x(长序列) | 低 |
| Speculative Decoding | 1.5-3x | 中(需草稿模型) |
| PagedAttention | 2-4x(吞吐) | 中 |
| Flash Attention | 1.5-2x | 高 |
| INT8 量化 | 2x(吞吐) | 中 |
四、对比实验一键运行
def full_benchmark():
print("=" * 60)
print("AI 推理加速引擎基准测试")
print("=" * 60)
# 创建模型
draft = SimpleLM(vocab_size=1000, d_model=32, num_layers=1)
target = SimpleLM(vocab_size=1000, d_model=64, num_layers=6)
engine = InferenceEngine(target, draft)
prompt = [42, 128, 256, 512, 768]
n_tokens = 200
# 1. 无 KV Cache
print("\n[1/4] 无 KV Cache...")
target.clear_cache()
start = time.perf_counter()
seq = list(prompt)
for _ in range(n_tokens):
for layer in target.layers:
layer.clear_cache()
x = target.embed[seq]
for layer in target.layers:
x = layer.forward(x, use_cache=False)
token = np.argmax(x[-1] @ target.lm_head)
seq.append(token)
t1 = time.perf_counter() - start
# 2. 有 KV Cache
print("[2/4] 有 KV Cache...")
target.clear_cache()
start = time.perf_counter()
seq = list(prompt)
for _ in range(n_tokens):
logits = target.predict(seq)
token = np.argmax(logits[-1])
seq.append(token)
t2 = time.perf_counter() - start
# 3. Speculative Decoding
print("[3/4] Speculative Decoding...")
draft.clear_cache()
target.clear_cache()
start = time.perf_counter()
tokens = engine.spec_decoder.generate(prompt, n_tokens)
t3 = time.perf_counter() - start
# 4. 自适应 Speculative Decoding
print("[4/4] 自适应 Speculative Decoding...")
draft.clear_cache()
target.clear_cache()
adaptive = AdaptiveSpeculativeDecoder(draft, target, k_init=3)
start = time.perf_counter()
tokens = adaptive.generate(prompt, n_tokens)
t4 = time.perf_counter() - start
print("\n" + "=" * 60)
print("结果汇总(生成 %d tokens)" % n_tokens)
print("=" * 60)
print(f" ① 无 KV Cache: {t1:.3f}s (1.0x baseline)")
print(f" ② 有 KV Cache: {t2:.3f}s ({t1/t2:.1f}x)")
print(f" ③ Speculative: {t3:.3f}s ({t1/t3:.1f}x)")
print(f" ④ 自适应 Spec: {t4:.3f}s ({t1/t4:.1f}x)")
if __name__ == "__main__":
full_benchmark()
总结
本文从零实现了两大推理加速技术:
- KV Cache:缓存历史 K/V 矩阵,消除 attention 的重复计算。实现简单,加速效果 5-15x,是所有推理框架的标配。代价是额外的显存开销。
- Speculative Decoding:用小模型草稿+大模型验证打破串行瓶颈。在保持输出分布不变的前提下获得 1.5-3x 加速。需要额外维护一个草稿模型。
核心收获:理解这些底层原理后,使用 vLLM、TensorRT-LLM 等框架时会更清楚每一步在做什么,排查性能问题也能抓住根因。
进一步探索
- 阅读 vLLM 源码中的 PagedAttention 实现
- 了解 Medusa 等无草稿模型的 Speculative Decoding 变体
- 尝试将 KV Cache 量化为 INT8 并观察精度损失
- 对实际模型(如 LLaMA 系列)应用本文代码进行基准测试