如何计算kv cache的缓存大小

符号定义

首先,定义一些符号:

( B ):批大小(Batch Size)

( L ):序列长度(Sequence Length),在您的问题中,( L = 1 )

( N ):Transformer 层数(Number of Transformer Layers)

( H ):注意力头数(Number of Attention Heads)

( D ):每个注意力头的维度(Dimension per Head),即 ( D = Hidden Size / H D = \text{Hidden Size} / H D=Hidden Size/H)

( S ):数据类型大小(Size of Data Type),以字节为单位。例如:

FP32(32位浮点数):( S = 4 ) 字节

FP16(16位浮点数):( S = 2 ) 字节

KV 缓存的内存计算

对于每一层的多头注意力机制,我们需要存储 键(Key)值(Value) 的缓存。对于每一层,键和值的缓存大小计算如下:

键缓存(Key Cache)大小:

Size Key = B × H × L × D × S \text{Size}_{\text{Key}} = B \times H \times L \times D \times S SizeKey=B×H×L×D×S

值缓存(Value Cache)大小:

Size Value = B × H × L × D × S \text{Size}_{\text{Value}} = B \times H \times L \times D \times S SizeValue=B×H×L×D×S

因此,每一层的 KV 缓存总大小为:

SizeKV per layer = SizeKey + Size Value = 2 × B × H × L × D × S \text{Size}{\text{KV per layer}} = \text{Size}{\text{Key}} + \text{Size}_{\text{Value}} = 2 \times B \times H \times L \times D \times S SizeKV per layer=SizeKey+SizeValue=2×B×H×L×D×S

由于模型有 ( N ) 层,因此 总的 KV 缓存大小为:

Total SizeKV = N × SizeKV per layer = 2 × B × H × L × D × N × S \text{Total Size}{\text{KV}} = N \times \text{Size}{\text{KV per layer}} = 2 \times B \times H \times L \times D \times N \times S Total SizeKV=N×SizeKV per layer=2×B×H×L×D×N×S

具体示例计算

假设以下参数:

批大小:( B = 1 )

序列长度:( L = 1 ) (即 token 数为 1)

层数:( N = 12 ) (例如,一个小型的 Transformer)

隐藏层尺寸:( Hidden Size = 768 \text{Hidden Size} = 768 Hidden Size=768 )

注意力头数:( H = 12 )

每个头的维度:

D = Hidden Size H = 768 12 = 64 D = \frac{\text{Hidden Size}}{H} = \frac{768}{12} = 64 D=HHidden Size=12768=64

数据类型:FP32,( S = 4 ) 字节

现在,我们计算每一层的 KV 缓存大小:

Size KV per layer = 2 × B × H × L × D × S = 2 × 1 × 12 × 1 × 64 × 4 = 2 × 1 × 12 × 1 × 64 × 4 = 2 × 12 × 64 × 4 = 2 × 12 × 64 × 4 = 6144 字节 \begin{align*} \text{Size}_{\text{KV per layer}} &= 2 \times B \times H \times L \times D \times S \ &= 2 \times 1 \times 12 \times 1 \times 64 \times 4 \ &= 2 \times 1 \times 12 \times 1 \times 64 \times 4 \ &= 2 \times 12 \times 64 \times 4 \ &= 2 \times 12 \times 64 \times 4 \ &= 6144\ \text{字节} \end{align*} SizeKV per layer=2×B×H×L×D×S =2×1×12×1×64×4 =2×1×12×1×64×4 =2×12×64×4 =2×12×64×4 =6144 字节

总的 KV 缓存大小:

Total SizeKV = N × SizeKV per layer = 12 × 6144 = 73728 字节 \begin{align*} \text{Total Size}{\text{KV}} &= N \times \text{Size}{\text{KV per layer}} \ &= 12 \times 6144 \ &= 73728\ \text{字节} \end{align*} Total SizeKV=N×SizeKV per layer =12×6144 =73728 字节

即大约 72 KB。

虽然这个数字看起来不大,但在大型模型中,参数会显著增大。例如,考虑一个具有以下参数的大型模型:

层数:( N = 96 )

隐藏层尺寸:( Hidden Size = 12288 \text{Hidden Size} = 12288 Hidden Size=12288)

注意力头数:( H = 96 )

每个头的维度:

D = 12288 96 = 128 D = \frac{12288}{96} = 128 D=9612288=128

数据类型:FP16,( S = 2 ) 字节

计算每一层的 KV 缓存大小:

Size KV per layer = 2 × B × H × L × D × S = 2 × 1 × 96 × 1 × 128 × 2 = 2 × 96 × 128 × 2 = 2 × 96 × 128 × 2 = 49 , 152 字节 \begin{align*} \text{Size}_{\text{KV per layer}} &= 2 \times B \times H \times L \times D \times S \ &= 2 \times 1 \times 96 \times 1 \times 128 \times 2 \ &= 2 \times 96 \times 128 \times 2 \ &= 2 \times 96 \times 128 \times 2 \ &= 49,152\ \text{字节} \end{align*} SizeKV per layer=2×B×H×L×D×S =2×1×96×1×128×2 =2×96×128×2 =2×96×128×2 =49,152 字节

总的 KV 缓存大小:

Total SizeKV = N × SizeKV per layer = 96 × 49 , 152 = 4 , 719 , 616 字节 \begin{align*} \text{Total Size}{\text{KV}} &= N \times \text{Size}{\text{KV per layer}} \ &= 96 \times 49,152 \ &= 4,719,616\ \text{字节} \end{align*} Total SizeKV=N×SizeKV per layer =96×49,152 =4,719,616 字节

即大约 4.5 MB。

注意事项

模型规模的影响: 可以看到,随着 层数 ( N )、隐藏层尺寸 和 注意力头数 ( H ) 的增加,KV 缓存的内存需求会显著增长。

序列长度的影响: 虽然在 ( L = 1 ) 时,序列长度对内存影响较小,但在生成长序列时,( L ) 会增加,导致 KV 缓存内存占用线性增长。

数据类型的影响: 使用 FP16 可以将内存占用减少一半,但对于大型模型,内存需求仍然很高。

总结

即使 token 数为 1,由于模型的层数、注意力头数、每个头的维度等参数较大,KV 缓存仍然需要消耗较大的内存。

通过以上公式,可以直观地看到各个参数对 KV 缓存内存占用的影响,从而理解为什么在处理单个 token 时仍需要大的内存。

优化建议

减少模型规模: 降低 ( N )、( H ) 或 ( D ) 的值。
使用半精度: 采用 FP16 或更低精度的数据类型。
批量大小优化: 确保 ( B ) 仅为需要的最小值。
序列长度控制: 在可能的情况下,限制生成序列的最大长度 ( L )。

相关推荐
先睡5 小时前
Redis的缓存击穿和缓存雪崩
redis·spring·缓存
CodeWithMe19 小时前
【Note】《深入理解Linux内核》 Chapter 15 :深入理解 Linux 页缓存
linux·spring·缓存
大春儿的试验田21 小时前
高并发收藏功能设计:Redis异步同步与定时补偿机制详解
java·数据库·redis·学习·缓存
likeGhee21 小时前
python缓存装饰器实现方案
开发语言·python·缓存
C182981825751 天前
OOM电商系统订单缓存泄漏,这是泄漏还是溢出
java·spring·缓存
西岭千秋雪_1 天前
Redis性能优化
数据库·redis·笔记·学习·缓存·性能优化
en-route1 天前
HTTP 缓存
网络协议·http·缓存
苦夏木禾2 天前
js请求避免缓存的三种方式
开发语言·javascript·缓存
重庆小透明2 天前
力扣刷题记录【1】146.LRU缓存
java·后端·学习·算法·leetcode·缓存
Java初学者小白2 天前
秋招Day14 - Redis - 应用
java·数据库·redis·缓存