【infra之路】Transformer 核心计算流

无论是搞大模型应用、RAG,还是深入到基础设施层的 vLLM 推理优化,Self-Attention 的矩阵乘法维度变化KV Cache 的底层原理都是最核心的硬核知识。

我们可以把整个计算流程拆成两部分:第一部分是没有 Cache 时的标准 Self-Attention 维度推导 (通常在训练或 Prefill 阶段),第二部分是加入 KV Cache 后的推理加速流


一、 标准 Self-Attention 核心计算流(无 Cache)

假设我们输入一个句子的 Token 序列,通过 Embedding 层后得到了输入矩阵 X X X。为了让维度变化最清晰,我们先隐去 Batch Size ( B B B) 和多头注意力(Multi-Head)的干扰 ,只看单头在序列和特征维度上的变化。

1. 基础符号定义

  • N N N: 序列长度 (Sequence Length,即 Token 数量)
  • d d d: 隐藏层维度 (Embedding Dimension / Hidden Size)

此时,输入矩阵 X X X 的维度是:( N , d ) (N, d) (N,d)

2. 生成 Q , K , V Q, K, V Q,K,V 矩阵

输入 X X X 分别乘以三个可学习的权重矩阵 W Q , W K , W V W_Q, W_K, W_V WQ,WK,WV(它们的维度都是 ( d , d ) (d, d) (d,d)):

Q = X ⋅ W Q → ( N , d ) × ( d , d ) = ( N , d ) Q = X \cdot W_Q \quad \rightarrow \quad (N, d) \times (d, d) = \mathbf{(N, d)} Q=X⋅WQ→(N,d)×(d,d)=(N,d)

K = X ⋅ W K → ( N , d ) × ( d , d ) = ( N , d ) K = X \cdot W_K \quad \rightarrow \quad (N, d) \times (d, d) = \mathbf{(N, d)} K=X⋅WK→(N,d)×(d,d)=(N,d)

V = X ⋅ W V → ( N , d ) × ( d , d ) = ( N , d ) V = X \cdot W_V \quad \rightarrow \quad (N, d) \times (d, d) = \mathbf{(N, d)} V=X⋅WV→(N,d)×(d,d)=(N,d)

3. 核心双矩阵乘法(Attention 经典公式)

Attention 的公式大家耳熟能详:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V

我们拆解这两次关键的矩阵乘法:

第一次矩阵乘法: Q × K T Q \times K^T Q×KT(计算注意力分数)
  • Q Q Q 的维度: ( N , d ) (N, d) (N,d)
  • K T K^T KT 的维度: ( d , N ) (d, N) (d,N)
  • 计算: ( N , d ) × ( d , N ) = ( N , N ) (N, d) \times (d, N) = \mathbf{(N, N)} (N,d)×(d,N)=(N,N)

💡 本质理解 :这个 ( N , N ) (N, N) (N,N) 的矩阵就是注意力权重矩阵(Attention Maps) 。矩阵中的第 i i i 行第 j j j 列,代表第 i i i 个 Token 对第 j j j 个 Token 的关注度。经过 Softmax 归一化后,它变成了一个每一行求和都为 1 的权重概率矩阵。

第二次矩阵乘法: Score × V \text{Score} \times V Score×V(加权语义融合)
  • Score \text{Score} Score (经过 Softmax) 的维度: ( N , N ) (N, N) (N,N)
  • V V V 的维度: ( N , d ) (N, d) (N,d)
  • 计算: ( N , N ) × ( N , d ) = ( N , d ) (N, N) \times (N, d) = \mathbf{(N, d)} (N,N)×(N,d)=(N,d)

💡 本质理解 :利用刚才算出来的 ( N , N ) (N, N) (N,N) 相互关系权重,对包含了实际语义信息的 V V V 矩阵进行加权求和。最终输出的维度依然是 ( N , d ) (N, d) (N,d) ,完美和输入 X X X 保持一致,这就使得 Transformer 可以堆叠很多层。


二、 为什么需要 KV Cache?(推理痛点)

大模型在生成(Generation/Decoding)阶段 ,是自回归(Autoregressive)的------也就是说,每次只生成一个新 Token

传统的"傻瓜式"计算流(无 Cache)

假设当前已经生成了 4 个 Token,准备生成第 5 个 Token:

  1. 把前 4 个 Token 和第 5 个 Token 拼在一起,组成长度为 5 的序列输入 X n e w X_{new} Xnew,维度 ( 5 , d ) (5, d) (5,d)。
  2. 重新计算 Q , K , V Q, K, V Q,K,V(维度都是 ( 5 , d ) (5, d) (5,d))。
  3. 重新做 Q × K T → ( 5 , 5 ) Q \times K^T \rightarrow (5, 5) Q×KT→(5,5)。
  4. 重新乘以 V → ( 5 , d ) V \rightarrow (5, d) V→(5,d)。

痛点在哪里?

当你生成第 1000 个 Token 时,前 999 个 Token 已经生成过了,它们的语义并没有变。但为了算出第 1000 个 Token,你不得不把前 999 个 Token 重新转成 K K K 和 V V V 参与整场矩阵乘法。

这带来了巨大的 O ( N 2 ) O(N^2) O(N2) 计算浪费 ,并且由于频繁搬运历史数据,推理会极快地卡在 Memory Bound(内存带宽瓶颈)


三、 加入 KV Cache 后的矩阵乘法流(逐 Token 解密)

KV Cache 的核心逻辑: 既然前面的 Token 不变,那它们映射出来的 K K K 和 V V V 矩阵也绝对不会变。我们把它们缓存(Cache)在显存里,每次新进来一个 Token,只算这一个 Token 的 Q , K , V Q, K, V Q,K,V。

我们来看看此时奇妙的维度变化:

1. 状态定义

  • 假设历史已经处理了 t t t 个 Token(它们对应的 K c a c h e K_{cache} Kcache 和 V c a c h e V_{cache} Vcache 已经存在显存里)。
  • 当前步(Current Step)只输入 1 个新 Token 。此时的输入 X c u r r e n t X_{current} Xcurrent 维度是 ( 1 , d ) (1, d) (1,d)

2. 生成当前步的 Q , K , V Q, K, V Q,K,V

  • Q c u r r e n t = X c u r r e n t ⋅ W Q → ( 1 , d ) × ( d , d ) = ( 1 , d ) Q_{current} = X_{current} \cdot W_Q \rightarrow (1, d) \times (d, d) = \mathbf{(1, d)} Qcurrent=Xcurrent⋅WQ→(1,d)×(d,d)=(1,d)
  • K c u r r e n t = X c u r r e n t ⋅ W K → ( 1 , d ) × ( d , d ) = ( 1 , d ) K_{current} = X_{current} \cdot W_K \rightarrow (1, d) \times (d, d) = \mathbf{(1, d)} Kcurrent=Xcurrent⋅WK→(1,d)×(d,d)=(1,d)
  • V c u r r e n t = X c u r r e n t ⋅ W V → ( 1 , d ) × ( d , d ) = ( 1 , d ) V_{current} = X_{current} \cdot W_V \rightarrow (1, d) \times (d, d) = \mathbf{(1, d)} Vcurrent=Xcurrent⋅WV→(1,d)×(d,d)=(1,d)

3. 更新 Cache(拼接)

把新算出来的这 1 个 Token 的 K , V K, V K,V 拼接到历史缓存中:

  • K c a c h e d _ n e w = Concat ( K c a c h e , K c u r r e n t ) → ( t + 1 , d ) K_{cached\new} = \text{Concat}(K{cache}, K_{current}) \quad \rightarrow \quad \mathbf{(t+1, d)} Kcached_new=Concat(Kcache,Kcurrent)→(t+1,d)
  • V c a c h e d _ n e w = Concat ( V c a c h e , V c u r r e n t ) → ( t + 1 , d ) V_{cached\new} = \text{Concat}(V{cache}, V_{current}) \quad \rightarrow \quad \mathbf{(t+1, d)} Vcached_new=Concat(Vcache,Vcurrent)→(t+1,d)

4. 关键:带 Cache 的矩阵乘法

现在,我们用只有 1 行 的 Q c u r r e n t Q_{current} Qcurrent,去和更新后的 K , V K, V K,V Cache 做运算。

第一次矩阵乘法: Q c u r r e n t × K c a c h e d _ n e w T Q_{current} \times K_{cached\_new}^T Qcurrent×Kcached_newT
  • Q c u r r e n t Q_{current} Qcurrent 维度: ( 1 , d ) (1, d) (1,d)
  • K c a c h e d _ n e w T K_{cached\_new}^T Kcached_newT 维度: ( d , t + 1 ) (d, t+1) (d,t+1)
  • 计算: ( 1 , d ) × ( d , t + 1 ) = ( 1 , t + 1 ) (1, d) \times (d, t+1) = \mathbf{(1, t+1)} (1,d)×(d,t+1)=(1,t+1)

💡 维度减负 :原先不加 Cache 时,这里算出来的是 ( t + 1 , t + 1 ) (t+1, t+1) (t+1,t+1) 的庞大方阵。现在由于 Q Q Q 只有 1 行,算出来的 Attention Score 变成了 ( 1 , t + 1 ) (1, t+1) (1,t+1) 的一行向量 !它代表这 1 个新 Token 对过去所有 t + 1 t+1 t+1 个 Token 的注意力权重。

第二次矩阵乘法: Score × V c a c h e d _ n e w \text{Score} \times V_{cached\_new} Score×Vcached_new
  • Score \text{Score} Score (经过 Softmax) 维度: ( 1 , t + 1 ) (1, t+1) (1,t+1)
  • V c a c h e d _ n e w V_{cached\_new} Vcached_new 维度: ( t + 1 , d ) (t+1, d) (t+1,d)
  • 计算: ( 1 , t + 1 ) × ( t + 1 , d ) = ( 1 , d ) (1, t+1) \times (t+1, d) = \mathbf{(1, d)} (1,t+1)×(t+1,d)=(1,d)

最终输出维度是 ( 1 , d ) (1, d) (1,d),恰好就是当前步预测下一个 Token 所需要的全部特征向量。


总结:KV Cache 为什么能加速?

指标 传统无 Cache 模式 (每个 Token 生成时) 加入 KV Cache 模式 (Decoding 阶段)
输入 Q Q Q 的行数 t + 1 t+1 t+1 (把历史所有人拉过来一起算) 1 1 1 (只关心当前最新的这个 Token)
Attention 矩阵维度 ( t + 1 , t + 1 ) (t+1, t+1) (t+1,t+1) (大方阵,计算量随序列线性暴涨) ( 1 , t + 1 ) (1, t+1) (1,t+1) (扁平的行向量,省下极多计算)
核心计算开销 O ( ( t + 1 ) 2 ⋅ d ) O((t+1)^2 \cdot d) O((t+1)2⋅d) O ( ( t + 1 ) ⋅ d ) O((t+1) \cdot d) O((t+1)⋅d)

最后一句话总结 KV Cache 的本质:

KV Cache 将 Decoding 阶段的计算复杂度从 平方阶 O ( N 2 ) O(N^2) O(N2) 降到了 线性阶 O ( N ) O(N) O(N) 。它通过空间换时间 的策略,让大模型在每次生成新 Token 时,不再做重复的矩阵历史映射,只需要让"当前步的一行 Q Q Q"去和"历史所有的 K , V K, V K,V 缓存"做快速内积,从而实现了数十倍的推理速度提升。

相关推荐
huangdong_15 小时前
电商图片智能分类算法:主图/属性图/详情图自动识别技术
人工智能·分类·数据挖掘
电商API_1800790524715 小时前
价格波动预警|用API实时监控淘宝京东商品价格,实现自动化竞品调价与捡漏
大数据·运维·数据库·人工智能·数据挖掘·自动化
美狐美颜sdk15 小时前
直播APP开发如何实现美颜功能?低成本美颜SDK方案推荐
android·人工智能·ios·第三方美颜sdk·视频美颜sdk
码农阿强15 小时前
DeepSeek-V4 Flash/Pro 技术深度解析:成本下降与场景适配
人工智能·ai·aigc·个人开发
AI行业学习15 小时前
CC-Switch Windows + macOS 下载安装配置全流程
java·开发语言·人工智能·python
LT101579744415 小时前
2026年性能测试平台报告生成:专业可视化与合规适配指南
大数据·数据库·人工智能
kjmkq16 小时前
2026实战效果优选GEO服务商测评:效果好+服务优首选合作
大数据·人工智能
明志数科16 小时前
机器人数据采集方案设计:从场景到落地的完整指南
人工智能·数据挖掘
neocheng_52216 小时前
周末独处充电,深耕AI技能打造长期竞争力
人工智能