【infra】kv cache, flash attn

KV Cache 的数学建模与原理解析

简单来说,它的本质是"空间换时间"。在 Transformer 模型逐字生成文本(Autoregressive Decoding)时,模型为了生成当前 Token,需要计算当前 Token 与之前所有已生成 Token 的注意力权重。如果没有 KV Cache,模型在每生成一个新词时,都要把前面所有的词重新计算一遍,这会造成极大的算力浪费。

为了避免这种重复计算,系统会将之前计算过的所有 Token 的 Key (K)Value (V) 向量缓存下来。当生成新 Token 时,只需要计算当前新 Token 的 Q、K、V,然后提取缓存中历史 Token 的 K 和 V 进行注意力计算。

将时间复杂度从 O(t2)O(t^2)O(t2) 降低至 O(t)O(t)O(t),空间复杂度是从 O(1)O(1)O(1) 变成了 O(t)O(t)O(t)。空间是指显存的占用,但是一直占用还是比直接算的临时变量占用要小的。

LLM生成的过程中,因为是自回归的,所以Q的seq_len永远是1,KV充当memory的作用,seq_len是当前生成的序列长度

1. 自回归解码的数学表示与计算冗余

在标准的 Transformer 架构中,注意力机制(Scaled Dot-Product Attention)的通用公式定义为:

Attention(Q,K,V)=softmax(QKTdk)V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V

对于自回归生成任务,假设当前模型已经生成了长度为 t−1t-1t−1 的序列,其特征矩阵表示为 X1:t−1∈R(t−1)×dX_{1:t-1} \in \mathbb{R}^{(t-1) \times d}X1:t−1∈R(t−1)×d(其中 ddd 为隐层维度)。

当预测第 ttt 个 Token 时,如果不加优化,模型会将当前所有已知的 ttt 个 Token 一起输入进行线性映射:

Q1:t=X1:tWQ,K1:t=X1:tWK,V1:t=X1:tWV Q_{1:t} = X_{1:t}W_Q, \quad K_{1:t} = X_{1:t}W_K, \quad V_{1:t} = X_{1:t}W_V Q1:t=X1:tWQ,K1:t=X1:tWK,V1:t=X1:tWV

此时,K1:tK_{1:t}K1:t 和 V1:tV_{1:t}V1:t 的维度均为 t×dkt \times d_kt×dk。由于因果掩码(Causal Mask)的存在,前 t−1t-1t−1 个 Token 的状态计算包含了大量的重复映射。为了消除这种冗余,我们引入了 KV Cache 机制。

2. KV Cache 的状态更新方程(维度拆解)

在 KV Cache 机制下,我们将推理过程解耦为增量计算 。在预测第 ttt 个 Token 时,模型实际上只需要处理当前最新生成的那一个 Token ,即 xt∈R1×dx_t \in \mathbb{R}^{1 \times d}xt∈R1×d。

2.1 查询向量(Query)的瞬时性建模

由于当前的目标仅仅是基于 xtx_txt 提取历史信息以预测下一个词,因此我们需要且仅需要计算当前时间步的 Query 向量:

qt=xtWQ q_t = x_t W_Q qt=xtWQ

此时,qtq_tqt 的维度严格为 1×dk1 \times d_k1×dk。这在数学上表征了当前单步的"瞬时查询意图"。

2.2 键值对(Key-Value)的增量拼接

针对当前的输入 xtx_txt,我们同样计算其对应的单步 Key 和 Value:

kt=xtWK∈R1×dk k_t = x_t W_K \in \mathbb{R}^{1 \times d_k} kt=xtWK∈R1×dk

vt=xtWV∈R1×dv v_t = x_t W_V \in \mathbb{R}^{1 \times d_v} vt=xtWV∈R1×dv

此时,显存中已经维护了前 t−1t-1t−1 步的缓存矩阵 K≤t−1K_{\le t-1}K≤t−1 和 V≤t−1V_{\le t-1}V≤t−1。状态更新方程表现为矩阵在序列长度维度(Sequence Length Dimension)上的拼接(Concatenation):

K≤t=Concat(K≤t−1,kt)∈Rt×dk K_{\le t} = \text{Concat}(K_{\le t-1}, k_t) \in \mathbb{R}^{t \times d_k} K≤t=Concat(K≤t−1,kt)∈Rt×dk

V≤t=Concat(V≤t−1,vt)∈Rt×dv V_{\le t} = \text{Concat}(V_{\le t-1}, v_t) \in \mathbb{R}^{t \times d_v} V≤t=Concat(V≤t−1,vt)∈Rt×dv

2.3 单步 Attention 结果计算

基于更新后的矩阵,第 ttt 步的注意力输出计算如下:

ht=softmax(qtK≤tTdk)V≤t h_t = \text{softmax}\left(\frac{q_t K_{\le t}^T}{\sqrt{d_k}}\right) V_{\le t} ht=softmax(dk qtK≤tT)V≤t

维度验算:

  1. 投影计算:qt (1×dk)×K≤tT (dk×t)=Scorest (1×t)q_t \ (1 \times d_k) \times K_{\le t}^T \ (d_k \times t) = \text{Scores}_t \ (1 \times t)qt (1×dk)×K≤tT (dk×t)=Scorest (1×t)
  2. 加权求和:Scorest (1×t)×V≤t (t×dv)=ht (1×dv)\text{Scores}t \ (1 \times t) \times V{\le t} \ (t \times d_v) = h_t \ (1 \times d_v)Scorest (1×t)×V≤t (t×dv)=ht (1×dv)

最终输出的 hth_tht 恰好为一个 Token 的隐特征表示,直接用于后续 FFN 层的计算。

3. 为什么 KV 可以复用?(不变性证明)

KV Cache 成立的理论基础在于自回归机制下的状态局部不变性

对于序列中的任意历史位置 iii(i<ti < ti<t),其输入特征为 xix_ixi。在因果掩码的约束下,xix_ixi 的生成仅依赖于 X1:iX_{1:i}X1:i,与未来的输入 X>iX_{>i}X>i 严格解耦。

同时,投影矩阵 WKW_KWK 和 WVW_VWV 在推理阶段是全局共享且冻结的常数矩阵。

因此,映射函数 fK(xi)=xiWKf_K(x_i) = x_i W_KfK(xi)=xiWK 和 fV(xi)=xiWVf_V(x_i) = x_i W_VfV(xi)=xiWV 满足:

∀t>i,∂ki∂xt=0,∂vi∂xt=0 \forall t > i, \quad \frac{\partial k_i}{\partial x_t} = 0, \quad \frac{\partial v_i}{\partial x_t} = 0 ∀t>i,∂xt∂ki=0,∂xt∂vi=0

这意味着,无论序列如何向后延伸,历史位置 iii 所计算出的 kik_iki 和 viv_ivi 恒定不变 。这种数学上的绝对静态性,使得我们可以将它们直接缓存在 GPU 的显存中,避免 O(t2)O(t^2)O(t2) 的重复矩阵乘法计算。

4. 为什么没有 Q Cache?(查询的映射解耦)

理解不缓存 QQQ 的核心在于明确 Attention 机制的数学本质:Attention 本质上是当前状态与历史状态之间的相似度度量(内积空间投影)。

在时刻 ttt,注意力得分矩阵的第 ttt 行元素构成为:

Scoret,j=qt⋅kjTdk(j≤t) \text{Score}_{t, j} = \frac{q_t \cdot k_j^T}{\sqrt{d_k}} \quad (j \le t) Scoret,j=dk qt⋅kjT(j≤t)

我们可以观察到:

  1. qtq_tqt 的生命周期仅限于当前时刻 ttt 。它是用来衡量 xtx_txt 与历史 K≤tK_{\le t}K≤t 之间相关性的算子。
  2. 当时间步推进到 t+1t+1t+1 时,模型需要度量的是新输入 xt+1x_{t+1}xt+1 与历史信息的相似度,即计算 qt+1⋅K≤t+1Tq_{t+1} \cdot K_{\le t+1}^Tqt+1⋅K≤t+1T。
  3. 历史查询向量集合 Q<t=[q1,q2,...,qt−1]Q_{<t} = [q_1, q_2, \dots, q_{t-1}]Q<t=[q1,q2,...,qt−1] 参与的是过去时刻的输出计算,它们不参与 第 ttt 时刻的任何代数运算。

结论:
KKK 和 VVV 构成了全局累加的上下文特征基(Contextual Basis) ,需要被持久化保存;而 QQQ 则是游离于特征基之外的瞬时算子(Instantaneous Operator) 。用完即弃的数学属性决定了无需且不能为 QQQ 分配显存缓存。


关于 KV Cache 的必须了解的核心知识

为了应对工业级的大模型部署,尤其是需要处理超长上下文或多模态任务时,以下几个维度是必须掌握的:

1. 推理的两个阶段:Prefill 与 Decode

理解 KV Cache,首先要明白 LLM 推理的两个截然不同的阶段:

  • Prefill(预填充阶段): 处理用户的输入 Prompt。此时所有 Token 都是已知的,模型会并行计算所有输入 Token 的 K 和 V,并存入 KV Cache。这个阶段是计算密集型(Compute-Bound),GPU 算力拉满。
  • Decode(解码阶段): 模型开始逐字生成回复。此时模型需要频繁从显存(HBM)中把过去所有 Token 的 KV Cache 搬运到计算单元(SRAM)中参与计算。由于每次只生成一个 Token,计算量很小,但数据搬运量极大。因此这个阶段是访存密集型(Memory-Bound)

prefill就是用户提问计算嵌入的过程,decode就是llm开始生成answer的过程。

2. KV Cache 的显存占用计算

KV Cache 是 LLM 爆显存的罪魁祸首之一。随着序列长度(Sequence Length)的增加,它会线性膨胀。一个 Token 的 KV Cache 占用字节数计算公式如下:

Bytes per Token=2×L×Hnum×Hdim×B×Pbytes\text{Bytes per Token} = 2 \times L \times H_{\text{num}} \times H_{\text{dim}} \times B \times P_{\text{bytes}}Bytes per Token=2×L×Hnum×Hdim×B×Pbytes

  • 222:代表 K 和 V 两个矩阵。
  • LLL:Transformer 的层数(Layers)。
  • HnumH_{\text{num}}Hnum:注意力头数(Heads)。
  • HdimH_{\text{dim}}Hdim:每个头的维度(Head Dimension)。
  • BBB:Batch Size。
  • PbytesP_{\text{bytes}}Pbytes:精度占用的字节数(例如 FP16 为 2 bytes)。

举例: 在标准 LLaMA-2-7B(FP16)中,如果 Batch Size 为 1,每个 Token 的 KV Cache 大约占用 0.5 MB。如果上下文长度达到 128k,仅单条请求的 KV Cache 就需要 64 GB 的显存。

3. 工业界应对 KV Cache 压力的三大优化方向

在实际的工程部署中,直接使用朴素的 KV Cache 是不可行的,必须结合以下优化技术:

A. 架构层面:MQA 与 GQA

为了从模型结构上减少 KV Cache 的大小,现在的主流模型(如 LLaMA-3, Qwen)普遍放弃了标准的多头注意力(MHA),转而使用:

  • MQA (Multi-Query Attention): 所有 Query 头共享同一组 K 和 V 头。KV Cache 极小,但可能损失一定的模型表现力。
  • GQA (Grouped-Query Attention): MHA 和 MQA 的折中方案。将 Query 头分组,每组共享一组 K 和 V。这是目前工业界最主流的做法,能在保证性能的同时大幅缩减 KV Cache。

B. 系统工程层面:PagedAttention (vLLM)

在传统的 KV Cache 管理中,系统通常会为每个请求预先分配一块连续的显存,这会导致严重的显存碎片化(浪费高达 60%)。

  • PagedAttention 引入了类似 Linux 操作系统中虚拟内存分页(Paging)的思想。它将 KV Cache 划分为固定大小的块(Blocks),在物理显存中非连续存储,并通过一张映射表来管理。这使得显存利用率接近 100%,极大地提升了系统并发吞吐量。

C. 算法与量化层面:

  • KV Cache Quantization: 将原本 FP16/BF16 的缓存量化为 INT8、INT4 甚至更低比特,直接将显存占用砍半或更多。
  • 上下文截断与滑动窗口(如 StreamingLLM / Attention Sinks): 发现只有最前面的几个 Token(System Prompt/注意力基底)和最近生成的 Token 最重要。因此可以只缓存这些关键 Token,丢弃中间的 Token,从而实现理论上的"无限长度"生成而不爆显存。

如果说 KV Cache 解决的是"生成新词时,如何避免重复计算"的问题,那么 FlashAttention 解决的就是"处理超长文本时,如何打破 GPU 内存读写的物理瓶颈"的问题。特别是你在向 VLM/VLA 或长文本大模型方向转型的过程中,FlashAttention 及其变体是必问的核心考点。

1. 痛点:标准 Attention 为什么"慢"且"吃显存"?

在标准 Transformer 中,计算 Attention 的公式是:

Attention(Q,K,V)=softmax(QKTdk)VAttention(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dk QKT)V

假设输入序列长度(上下文长度)为 NNN。

  • Q,K,VQ, K, VQ,K,V 矩阵的大小是 N×dN \times dN×d。
  • QQQ 和 KTK^TKT 相乘,会生成一个 N×NN \times NN×N 的注意力分数矩阵(Attention Matrix)
  • 接着对这个 N×NN \times NN×N 矩阵做 Softmax。
  • 最后再和 VVV 相乘。

物理硬件的噩梦(Memory Wall):

GPU 的计算单元(ALU)计算速度极快,但它的主显存(HBM,High Bandwidth Memory,比如 80GB 的显存)读写速度相对较慢。

在标准 Attention 中,GPU 需要把那个庞大的 N×NN \times NN×N 矩阵从计算单元写回到 HBM,然后再读出来做 Softmax,再写回去,再读出来和 VVV 相乘。

当上下文 NNN 变大时(比如 32k, 128k),这个 N×NN \times NN×N 矩阵会呈平方级爆炸。模型 90% 的时间没有在做数学计算,而是在"等显存搬运数据" 。这就叫访存密集型瓶颈(Memory-Bound)


2. FlashAttention 的破局点:硬件感知(Hardware-Aware)

FlashAttention 的作者(Tri Dao 等人)敏锐地发现:GPU 除了有一块很大但慢的 HBM(主显存),还有一块很小但极快的片上缓存(SRAM)。

FlashAttention 的核心思想就是:想尽一切办法,把计算过程留在 SRAM 里,绝不把中间那个 N×NN \times NN×N 的庞大矩阵写回主显存 HBM。

为了做到这一点,它引入了两个硬核创新:

核心机制一:分块计算 (Tiling)

既然 SRAM 太小,装不下整个 N×NN \times NN×N 的矩阵,那就把输入矩阵 Q,K,VQ, K, VQ,K,V 切成小块(Blocks)

每次只从主显存(HBM)中搬一小块 QQQ 和一小块 K,VK, VK,V 到 SRAM 中,在高速缓存里直接计算它们的分块 Attention 结果,然后动态更新最终结果。通过这种"蚂蚁搬家"的方式,直接绕过了生成全局 N×NN \times NN×N 矩阵的需求。

核心机制二:在线 Softmax (Online Softmax)

分块计算遇到了一个巨大的数学障碍:Softmax 怎么分块?

标准的 Softmax 需要知道一整行的所有元素,先找出最大值,再求指数和,作为分母。如果你只拿到了一小块数据,你是算不出真正的 Softmax 的。

FlashAttention 使用了巧妙的代数技巧(基于 Safe Softmax 的缩放原理),在每次处理一个新块时,动态保存并更新两个局部变量(当前的最大值和当前的指数和)。这样,即便不看全貌,也能一步步通过局部更新,算出和全局 Softmax 一模一样的精确结果。


3. 为什么它叫"Flash"?

因为引入了 Tiling 和 Online Softmax,显存的读写次数(Memory Accesses)发生了质的改变:

  • 标准 Attention 读写复杂度:O(N2)O(N^2)O(N2)
  • FlashAttention 读写复杂度:O(N)O(N)O(N)

虽然在数学上,它执行的乘加计算量(FLOPs)并没有减少(甚至因为重计算稍微多了一点点),但因为它大大减少了慢速显存的读写等待时间,实际运行速度(Wall-clock time)反而快了 2 到 4 倍

关键考点:精确无损

过去很多试图解决 O(N2)O(N^2)O(N2) 问题的算法(如 Sparse Attention, Linformer 等)都是"近似算法",会损失模型精度。而 FlashAttention 是精确注意力(Exact Attention),它的计算结果与标准 Attention 没有任何区别(在浮点误差允许范围内),这使得它能无缝替换所有大模型底层的注意力代码。


4. 训练与推理(结合 KV Cache)视角的差异

你需要将 FlashAttention 和前文讲的 KV Cache 结合起来看:

维度 FlashAttention KV Cache (PagedAttention)
主要解决什么问题 解决 N×NN \times NN×N 注意力矩阵带来的显存读写瓶颈 (Memory Wall)。 解决解码时历史 Token 重复计算的问题。
显存复杂度变化 将 Attention 过程的峰值显存从 O(N2)O(N^2)O(N2) 降到 O(N)O(N)O(N)。 将每个生成步骤的计算量从 O(N)O(N)O(N) 降到 O(1)O(1)O(1)(但占据庞大显存)。
推理阶段的作用 Prefill(预填充) 阶段绝对关键!因为需要一次性处理极长的 Prompt,必须用 FA 加速。 Decode(解码) 阶段绝对关键!因为逐字生成,必须缓存历史。

注:在 Decode 阶段,因为每次只生成 1 个 Token,Q 的长度是 1,此时计算的不再是矩阵乘矩阵(GEMM),而是矩阵乘向量(GEMV),此时也有专门针对这种场景优化的 FlashDecoding 技术。

5. 面试/工业界常考演进:FA2 与 FA3

如果你在面试算法工程岗,知道最初的 FA 是不够的:

  • FlashAttention-2 (FA2): 进一步减少了非矩阵乘法(non-matmul)操作的比例,优化了线程块(Thread block)在 GPU 上的工作分配(Work Partitioning),特别是在多头注意力之间做了更好的负载均衡,使其更逼近硬件的理论极限(可达 GPU 峰值性能的 70%+)。
  • FlashAttention-3 (FA3): 为最新的 Hopper 架构(如 H100 显卡)量身定制,引入了异步计算(计算与显存加载完全重叠)并原生支持了 FP8 低精度计算格式。

总结来说: KV Cache 让大模型"能够一边记一边聊",而 FlashAttention 让大模型"能够一眼看完整本财报而 GPU 不爆炸"。这两个技术的结合,才是当下长文本大模型(如 Kimi、Claude 3 200k)得以落地的物理基石。

可参考

https://zhuanlan.zhihu.com/p/1965867186716407094

https://zhuanlan.zhihu.com/p/2015196808893192187

相关推荐
石榴树下的七彩鱼1 小时前
AI抠图效果实测:基于Python的3种背景移除模型对比
开发语言·人工智能·python·ai抠图·石榴智能·背景移除·rmbg
中杯可乐多加冰1 小时前
Graphiti:让AI拥有“记忆“这件事,终于有人做对了
人工智能
碳基硅坊1 小时前
LoRA微调Qwen3-VL-8B-Instruct做产品质量检查
人工智能·qwen3-vl-8b
shchojj1 小时前
Generative AI applications -- Writing
人工智能
AirDroid_cn1 小时前
macOS Sequoia 通知摘要:如何启用AI生成的通知摘要,并排除特定应用?
人工智能·macos
霍夫曼vx_helloworld73521 小时前
经典图像检测技术概述
图像处理·人工智能·计算机视觉
AI人工智能+1 小时前
营业执照识别技术通过计算机视觉与人工智能技术,实现企业证照信息的自动化采集
人工智能·深度学习·ocr·营业执照识别
目黑live +wacyltd1 小时前
算法备案的实操指南(含截图示例)
人工智能·算法·llm·大模型备案·算法备案
guslegend1 小时前
第2节:工程初始化
人工智能·大模型