无论是搞大模型应用、RAG,还是深入到基础设施层的 vLLM 推理优化,Self-Attention 的矩阵乘法维度变化 和 KV Cache 的底层原理都是最核心的硬核知识。
我们可以把整个计算流程拆成两部分:第一部分是没有 Cache 时的标准 Self-Attention 维度推导 (通常在训练或 Prefill 阶段),第二部分是加入 KV Cache 后的推理加速流。
一、 标准 Self-Attention 核心计算流(无 Cache)
假设我们输入一个句子的 Token 序列,通过 Embedding 层后得到了输入矩阵 X X X。为了让维度变化最清晰,我们先隐去 Batch Size ( B B B) 和多头注意力(Multi-Head)的干扰 ,只看单头在序列和特征维度上的变化。

1. 基础符号定义
- N N N: 序列长度 (Sequence Length,即 Token 数量)
- d d d: 隐藏层维度 (Embedding Dimension / Hidden Size)
此时,输入矩阵 X X X 的维度是:( N , d ) (N, d) (N,d)。
2. 生成 Q , K , V Q, K, V Q,K,V 矩阵
输入 X X X 分别乘以三个可学习的权重矩阵 W Q , W K , W V W_Q, W_K, W_V WQ,WK,WV(它们的维度都是 ( d , d ) (d, d) (d,d)):
Q = X ⋅ W Q → ( N , d ) × ( d , d ) = ( N , d ) Q = X \cdot W_Q \quad \rightarrow \quad (N, d) \times (d, d) = \mathbf{(N, d)} Q=X⋅WQ→(N,d)×(d,d)=(N,d)
K = X ⋅ W K → ( N , d ) × ( d , d ) = ( N , d ) K = X \cdot W_K \quad \rightarrow \quad (N, d) \times (d, d) = \mathbf{(N, d)} K=X⋅WK→(N,d)×(d,d)=(N,d)
V = X ⋅ W V → ( N , d ) × ( d , d ) = ( N , d ) V = X \cdot W_V \quad \rightarrow \quad (N, d) \times (d, d) = \mathbf{(N, d)} V=X⋅WV→(N,d)×(d,d)=(N,d)
3. 核心双矩阵乘法(Attention 经典公式)
Attention 的公式大家耳熟能详:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V
我们拆解这两次关键的矩阵乘法:
第一次矩阵乘法: Q × K T Q \times K^T Q×KT(计算注意力分数)
- Q Q Q 的维度: ( N , d ) (N, d) (N,d)
- K T K^T KT 的维度: ( d , N ) (d, N) (d,N)
- 计算: ( N , d ) × ( d , N ) = ( N , N ) (N, d) \times (d, N) = \mathbf{(N, N)} (N,d)×(d,N)=(N,N)
💡 本质理解 :这个 ( N , N ) (N, N) (N,N) 的矩阵就是注意力权重矩阵(Attention Maps) 。矩阵中的第 i i i 行第 j j j 列,代表第 i i i 个 Token 对第 j j j 个 Token 的关注度。经过 Softmax 归一化后,它变成了一个每一行求和都为 1 的权重概率矩阵。
第二次矩阵乘法: Score × V \text{Score} \times V Score×V(加权语义融合)
- Score \text{Score} Score (经过 Softmax) 的维度: ( N , N ) (N, N) (N,N)
- V V V 的维度: ( N , d ) (N, d) (N,d)
- 计算: ( N , N ) × ( N , d ) = ( N , d ) (N, N) \times (N, d) = \mathbf{(N, d)} (N,N)×(N,d)=(N,d)
💡 本质理解 :利用刚才算出来的 ( N , N ) (N, N) (N,N) 相互关系权重,对包含了实际语义信息的 V V V 矩阵进行加权求和。最终输出的维度依然是 ( N , d ) (N, d) (N,d) ,完美和输入 X X X 保持一致,这就使得 Transformer 可以堆叠很多层。
二、 为什么需要 KV Cache?(推理痛点)
大模型在生成(Generation/Decoding)阶段 ,是自回归(Autoregressive)的------也就是说,每次只生成一个新 Token。
传统的"傻瓜式"计算流(无 Cache)
假设当前已经生成了 4 个 Token,准备生成第 5 个 Token:
- 把前 4 个 Token 和第 5 个 Token 拼在一起,组成长度为 5 的序列输入 X n e w X_{new} Xnew,维度 ( 5 , d ) (5, d) (5,d)。
- 重新计算 Q , K , V Q, K, V Q,K,V(维度都是 ( 5 , d ) (5, d) (5,d))。
- 重新做 Q × K T → ( 5 , 5 ) Q \times K^T \rightarrow (5, 5) Q×KT→(5,5)。
- 重新乘以 V → ( 5 , d ) V \rightarrow (5, d) V→(5,d)。
痛点在哪里?
当你生成第 1000 个 Token 时,前 999 个 Token 已经生成过了,它们的语义并没有变。但为了算出第 1000 个 Token,你不得不把前 999 个 Token 重新转成 K K K 和 V V V 参与整场矩阵乘法。
这带来了巨大的 O ( N 2 ) O(N^2) O(N2) 计算浪费 ,并且由于频繁搬运历史数据,推理会极快地卡在 Memory Bound(内存带宽瓶颈)。
三、 加入 KV Cache 后的矩阵乘法流(逐 Token 解密)
KV Cache 的核心逻辑: 既然前面的 Token 不变,那它们映射出来的 K K K 和 V V V 矩阵也绝对不会变。我们把它们缓存(Cache)在显存里,每次新进来一个 Token,只算这一个 Token 的 Q , K , V Q, K, V Q,K,V。
我们来看看此时奇妙的维度变化:
1. 状态定义
- 假设历史已经处理了 t t t 个 Token(它们对应的 K c a c h e K_{cache} Kcache 和 V c a c h e V_{cache} Vcache 已经存在显存里)。
- 当前步(Current Step)只输入 1 个新 Token 。此时的输入 X c u r r e n t X_{current} Xcurrent 维度是 ( 1 , d ) (1, d) (1,d)。
2. 生成当前步的 Q , K , V Q, K, V Q,K,V
- Q c u r r e n t = X c u r r e n t ⋅ W Q → ( 1 , d ) × ( d , d ) = ( 1 , d ) Q_{current} = X_{current} \cdot W_Q \rightarrow (1, d) \times (d, d) = \mathbf{(1, d)} Qcurrent=Xcurrent⋅WQ→(1,d)×(d,d)=(1,d)
- K c u r r e n t = X c u r r e n t ⋅ W K → ( 1 , d ) × ( d , d ) = ( 1 , d ) K_{current} = X_{current} \cdot W_K \rightarrow (1, d) \times (d, d) = \mathbf{(1, d)} Kcurrent=Xcurrent⋅WK→(1,d)×(d,d)=(1,d)
- V c u r r e n t = X c u r r e n t ⋅ W V → ( 1 , d ) × ( d , d ) = ( 1 , d ) V_{current} = X_{current} \cdot W_V \rightarrow (1, d) \times (d, d) = \mathbf{(1, d)} Vcurrent=Xcurrent⋅WV→(1,d)×(d,d)=(1,d)
3. 更新 Cache(拼接)
把新算出来的这 1 个 Token 的 K , V K, V K,V 拼接到历史缓存中:
- K c a c h e d _ n e w = Concat ( K c a c h e , K c u r r e n t ) → ( t + 1 , d ) K_{cached\new} = \text{Concat}(K{cache}, K_{current}) \quad \rightarrow \quad \mathbf{(t+1, d)} Kcached_new=Concat(Kcache,Kcurrent)→(t+1,d)
- V c a c h e d _ n e w = Concat ( V c a c h e , V c u r r e n t ) → ( t + 1 , d ) V_{cached\new} = \text{Concat}(V{cache}, V_{current}) \quad \rightarrow \quad \mathbf{(t+1, d)} Vcached_new=Concat(Vcache,Vcurrent)→(t+1,d)
4. 关键:带 Cache 的矩阵乘法
现在,我们用只有 1 行 的 Q c u r r e n t Q_{current} Qcurrent,去和更新后的 K , V K, V K,V Cache 做运算。
第一次矩阵乘法: Q c u r r e n t × K c a c h e d _ n e w T Q_{current} \times K_{cached\_new}^T Qcurrent×Kcached_newT
- Q c u r r e n t Q_{current} Qcurrent 维度: ( 1 , d ) (1, d) (1,d)
- K c a c h e d _ n e w T K_{cached\_new}^T Kcached_newT 维度: ( d , t + 1 ) (d, t+1) (d,t+1)
- 计算: ( 1 , d ) × ( d , t + 1 ) = ( 1 , t + 1 ) (1, d) \times (d, t+1) = \mathbf{(1, t+1)} (1,d)×(d,t+1)=(1,t+1)
💡 维度减负 :原先不加 Cache 时,这里算出来的是 ( t + 1 , t + 1 ) (t+1, t+1) (t+1,t+1) 的庞大方阵。现在由于 Q Q Q 只有 1 行,算出来的 Attention Score 变成了 ( 1 , t + 1 ) (1, t+1) (1,t+1) 的一行向量 !它代表这 1 个新 Token 对过去所有 t + 1 t+1 t+1 个 Token 的注意力权重。
第二次矩阵乘法: Score × V c a c h e d _ n e w \text{Score} \times V_{cached\_new} Score×Vcached_new
- Score \text{Score} Score (经过 Softmax) 维度: ( 1 , t + 1 ) (1, t+1) (1,t+1)
- V c a c h e d _ n e w V_{cached\_new} Vcached_new 维度: ( t + 1 , d ) (t+1, d) (t+1,d)
- 计算: ( 1 , t + 1 ) × ( t + 1 , d ) = ( 1 , d ) (1, t+1) \times (t+1, d) = \mathbf{(1, d)} (1,t+1)×(t+1,d)=(1,d)
最终输出维度是 ( 1 , d ) (1, d) (1,d),恰好就是当前步预测下一个 Token 所需要的全部特征向量。
总结:KV Cache 为什么能加速?
| 指标 | 传统无 Cache 模式 (每个 Token 生成时) | 加入 KV Cache 模式 (Decoding 阶段) |
|---|---|---|
| 输入 Q Q Q 的行数 | t + 1 t+1 t+1 (把历史所有人拉过来一起算) | 1 1 1 (只关心当前最新的这个 Token) |
| Attention 矩阵维度 | ( t + 1 , t + 1 ) (t+1, t+1) (t+1,t+1) (大方阵,计算量随序列线性暴涨) | ( 1 , t + 1 ) (1, t+1) (1,t+1) (扁平的行向量,省下极多计算) |
| 核心计算开销 | O ( ( t + 1 ) 2 ⋅ d ) O((t+1)^2 \cdot d) O((t+1)2⋅d) | O ( ( t + 1 ) ⋅ d ) O((t+1) \cdot d) O((t+1)⋅d) |
最后一句话总结 KV Cache 的本质:
KV Cache 将 Decoding 阶段的计算复杂度从 平方阶 O ( N 2 ) O(N^2) O(N2) 降到了 线性阶 O ( N ) O(N) O(N) 。它通过空间换时间 的策略,让大模型在每次生成新 Token 时,不再做重复的矩阵历史映射,只需要让"当前步的一行 Q Q Q"去和"历史所有的 K , V K, V K,V 缓存"做快速内积,从而实现了数十倍的推理速度提升。