一文搞懂KV-Cache

前几天面试的时候,面试官问我知道什么是KV-Cache吗?我愣在了原地,所以回来赶紧搞懂,把我所理解的和大家一起学习一下。也作为Transformer系列的第五篇。

Transformer系列文章:

一览Transformer整体架构

Transformer------Attention怎么实现集中注意力

Transformer------FeedForward模块在干什么?

从0开始实现Transformer

所有相关源码示例、流程图、模型配置与知识库构建技巧,我也将持续更新在Github:LLMHub,欢迎关注收藏!

希望大家带着下面的问题来学习,我会在文末给出答案。

  1. KV Cache节省了Self-Attention层中哪部分的计算?
  2. KV Cache对MLP层的计算量有影响吗?
  3. KV Cache对block间的数据传输量有影响吗?

在推理阶段,KV-Cache是Transformer加速推理的常用策略。

我们都知道,Transformer的解码器是自回归 的架构,所谓自回归,就是前面的输出会加到现在的输入里面进行不断生成,所以理解了Attention的同学就会意识到这个里面有很多重复性的计算,如果不了解Attention,可以去看一下我之前的文章Transformer------Attention怎么实现集中注意力

那么为什么会有重复性计算呢,我们来看一下

可以看到当前的Attention计算只与几个数据有关:

  1. 当前的query加入的新token,也就是来自模型前一轮的输出,图中的第一轮的"你",第二轮的"是",第三轮的"谁"。
  2. 历史K矩阵:每个Q向量都会依次和K矩阵中的每一行进行计算
  3. 历史V矩阵:Q*K得到的矩阵每一行要与V矩阵进行计算

传统Transformer在进行计算时是在每一轮中将Q,K,V乘以对应的W权重,进行计算Attention的过程,但其实这个计算过程中每一轮新增的向量只是Q中最后一行向量,K中最后一列向量,V中最后一行向量,可以把之前K,V计算的结果进行缓存,当前一轮只利用新加入的Q向量和新的K向量和V向量进行计算,最后将K向量和V向量与原始的向量进行拼接,来大大减少冗余的计算量。

当然KV-Cache会增加内存的使用,是典型的空间换时间操作,所以当序列特别长的时候,KV-Cache的显存开销甚至会超过模型本身,很容易爆显存,比如batch_size=32, head=32, layer=32, dim_size=4096, seq_length=2048, float32类型,则需要占用的显存为 2 * 32 * 4096 * 2048 * 32 * 4 / 1024/1024/1024 /1024 = 64G。

最后,我们回答一下文章开头提出的问题。

  1. KV Cache节省了Self-Attention层中哪部分的计算?

节省的是历史 token 的 Key 和 Value 的重新计算 ,把 历史 token 的 Key/Value 缓存在缓存中。每次只需计算当前 token 的 Q,历史的 K/V 可直接复用。无需重新前向计算 K,V 的线性变换和位置编码,从而节省了大量计算。

  1. KV Cache对MLP层的计算量有影响吗?

没有影响,MLP 层(即 FFN)是每个 token 独立计算的,不依赖历史上下文。所以每个生成的 token 无论如何都要进行一次完整的 MLP 前向传播,KV Cache 只作用于 Self-Attention 层的 Key 和 Value,不涉及 MLP 层。

  1. KV Cache对block间的数据传输量有影响吗?

有影响,通常会减少 block 间传输量(尤其在多卡/分布式环境中)。如果每一步都重新计算历史 Key/Value,就要不断在 block 之间传输所有 token 的 KV 表征。使用 KV Cache 后,每步只需传输当前 token 的 Q(给当前层使用)以及缓存的 KV(已经存储,不重复传)

关于深度学习和大模型相关的知识和前沿技术更新,请关注公众号算法coting

上内容部分参考了

动图看懂什么是KV Cache

LLM(20):漫谈 KV Cache 优化方法,深度理解 StreamingLLM

大模型推理加速:看图学KV Cache

非常感谢,如有侵权请联系删除!

相关推荐
AI视觉网奇2 小时前
rknn yolo11 推理
前端·人工智能·python
AI数据皮皮侠3 小时前
中国各省森林覆盖率等数据(2000-2023年)
大数据·人工智能·python·深度学习·机器学习
西柚小萌新4 小时前
【深入浅出PyTorch】--3.1.PyTorch组成模块1
人工智能·pytorch·python
2401_841495645 小时前
【数据结构】红黑树的基本操作
java·数据结构·c++·python·算法·红黑树·二叉搜索树
西猫雷婶5 小时前
random.shuffle()函数随机打乱数据
开发语言·pytorch·python·学习·算法·线性回归·numpy
鑫宝的学习笔记5 小时前
Vmware虚拟机联网问题,显示:线缆已拔出!!!
人工智能·ubuntu
小李独爱秋5 小时前
机器学习中的聚类理论与K-means算法详解
人工智能·算法·机器学习·支持向量机·kmeans·聚类
comli_cn6 小时前
GSPO论文阅读
论文阅读·人工智能
大有数据可视化6 小时前
数字孪生背后的大数据技术:时序数据库为何是关键?
大数据·数据库·人工智能
Bioinfo Guy6 小时前
Genome Med|RAG-HPO做表型注释:学习一下大语言模型怎么作为发文思路
人工智能·大语言模型·多组学