如何计算kv cache的缓存大小

符号定义

首先,定义一些符号:

( B ):批大小(Batch Size)

( L ):序列长度(Sequence Length),在您的问题中,( L = 1 )

( N ):Transformer 层数(Number of Transformer Layers)

( H ):注意力头数(Number of Attention Heads)

( D ):每个注意力头的维度(Dimension per Head),即 ( D = Hidden Size / H D = \text{Hidden Size} / H D=Hidden Size/H)

( S ):数据类型大小(Size of Data Type),以字节为单位。例如:

FP32(32位浮点数):( S = 4 ) 字节

FP16(16位浮点数):( S = 2 ) 字节

KV 缓存的内存计算

对于每一层的多头注意力机制,我们需要存储 键(Key)值(Value) 的缓存。对于每一层,键和值的缓存大小计算如下:

键缓存(Key Cache)大小:

Size Key = B × H × L × D × S \text{Size}_{\text{Key}} = B \times H \times L \times D \times S SizeKey=B×H×L×D×S

值缓存(Value Cache)大小:

Size Value = B × H × L × D × S \text{Size}_{\text{Value}} = B \times H \times L \times D \times S SizeValue=B×H×L×D×S

因此,每一层的 KV 缓存总大小为:

SizeKV per layer = SizeKey + Size Value = 2 × B × H × L × D × S \text{Size}{\text{KV per layer}} = \text{Size}{\text{Key}} + \text{Size}_{\text{Value}} = 2 \times B \times H \times L \times D \times S SizeKV per layer=SizeKey+SizeValue=2×B×H×L×D×S

由于模型有 ( N ) 层,因此 总的 KV 缓存大小为:

Total SizeKV = N × SizeKV per layer = 2 × B × H × L × D × N × S \text{Total Size}{\text{KV}} = N \times \text{Size}{\text{KV per layer}} = 2 \times B \times H \times L \times D \times N \times S Total SizeKV=N×SizeKV per layer=2×B×H×L×D×N×S

具体示例计算

假设以下参数:

批大小:( B = 1 )

序列长度:( L = 1 ) (即 token 数为 1)

层数:( N = 12 ) (例如,一个小型的 Transformer)

隐藏层尺寸:( Hidden Size = 768 \text{Hidden Size} = 768 Hidden Size=768 )

注意力头数:( H = 12 )

每个头的维度:

D = Hidden Size H = 768 12 = 64 D = \frac{\text{Hidden Size}}{H} = \frac{768}{12} = 64 D=HHidden Size=12768=64

数据类型:FP32,( S = 4 ) 字节

现在,我们计算每一层的 KV 缓存大小:

Size KV per layer = 2 × B × H × L × D × S = 2 × 1 × 12 × 1 × 64 × 4 = 2 × 1 × 12 × 1 × 64 × 4 = 2 × 12 × 64 × 4 = 2 × 12 × 64 × 4 = 6144 字节 \begin{align*} \text{Size}_{\text{KV per layer}} &= 2 \times B \times H \times L \times D \times S \ &= 2 \times 1 \times 12 \times 1 \times 64 \times 4 \ &= 2 \times 1 \times 12 \times 1 \times 64 \times 4 \ &= 2 \times 12 \times 64 \times 4 \ &= 2 \times 12 \times 64 \times 4 \ &= 6144\ \text{字节} \end{align*} SizeKV per layer=2×B×H×L×D×S =2×1×12×1×64×4 =2×1×12×1×64×4 =2×12×64×4 =2×12×64×4 =6144 字节

总的 KV 缓存大小:

Total SizeKV = N × SizeKV per layer = 12 × 6144 = 73728 字节 \begin{align*} \text{Total Size}{\text{KV}} &= N \times \text{Size}{\text{KV per layer}} \ &= 12 \times 6144 \ &= 73728\ \text{字节} \end{align*} Total SizeKV=N×SizeKV per layer =12×6144 =73728 字节

即大约 72 KB。

虽然这个数字看起来不大,但在大型模型中,参数会显著增大。例如,考虑一个具有以下参数的大型模型:

层数:( N = 96 )

隐藏层尺寸:( Hidden Size = 12288 \text{Hidden Size} = 12288 Hidden Size=12288)

注意力头数:( H = 96 )

每个头的维度:

D = 12288 96 = 128 D = \frac{12288}{96} = 128 D=9612288=128

数据类型:FP16,( S = 2 ) 字节

计算每一层的 KV 缓存大小:

Size KV per layer = 2 × B × H × L × D × S = 2 × 1 × 96 × 1 × 128 × 2 = 2 × 96 × 128 × 2 = 2 × 96 × 128 × 2 = 49 , 152 字节 \begin{align*} \text{Size}_{\text{KV per layer}} &= 2 \times B \times H \times L \times D \times S \ &= 2 \times 1 \times 96 \times 1 \times 128 \times 2 \ &= 2 \times 96 \times 128 \times 2 \ &= 2 \times 96 \times 128 \times 2 \ &= 49,152\ \text{字节} \end{align*} SizeKV per layer=2×B×H×L×D×S =2×1×96×1×128×2 =2×96×128×2 =2×96×128×2 =49,152 字节

总的 KV 缓存大小:

Total SizeKV = N × SizeKV per layer = 96 × 49 , 152 = 4 , 719 , 616 字节 \begin{align*} \text{Total Size}{\text{KV}} &= N \times \text{Size}{\text{KV per layer}} \ &= 96 \times 49,152 \ &= 4,719,616\ \text{字节} \end{align*} Total SizeKV=N×SizeKV per layer =96×49,152 =4,719,616 字节

即大约 4.5 MB。

注意事项

模型规模的影响: 可以看到,随着 层数 ( N )、隐藏层尺寸 和 注意力头数 ( H ) 的增加,KV 缓存的内存需求会显著增长。

序列长度的影响: 虽然在 ( L = 1 ) 时,序列长度对内存影响较小,但在生成长序列时,( L ) 会增加,导致 KV 缓存内存占用线性增长。

数据类型的影响: 使用 FP16 可以将内存占用减少一半,但对于大型模型,内存需求仍然很高。

总结

即使 token 数为 1,由于模型的层数、注意力头数、每个头的维度等参数较大,KV 缓存仍然需要消耗较大的内存。

通过以上公式,可以直观地看到各个参数对 KV 缓存内存占用的影响,从而理解为什么在处理单个 token 时仍需要大的内存。

优化建议

减少模型规模: 降低 ( N )、( H ) 或 ( D ) 的值。
使用半精度: 采用 FP16 或更低精度的数据类型。
批量大小优化: 确保 ( B ) 仅为需要的最小值。
序列长度控制: 在可能的情况下,限制生成序列的最大长度 ( L )。

相关推荐
confident343 分钟前
hibernate 配置 二级 缓存
java·缓存·hibernate
weixin_425878231 小时前
Nginx 缓存那些事儿:原理、配置和最佳实践
运维·nginx·缓存
我们的五年2 小时前
【Linux课程学习】:第20弹---信号入门专题(基础部分)
linux·服务器·后端·学习·缓存
weisian1519 小时前
Redis篇-4--原理篇3--Redis发布/订阅(Pub/Sub)
数据库·redis·缓存
Swift社区12 小时前
巧用缓存:高效实现基于 read4 的文件读取方法
缓存
宁静@星空18 小时前
004-Redis 持久化
数据库·redis·缓存
weisian15119 小时前
Redis篇-1--入门介绍
java·数据库·redis·缓存·lua
wu@5555521 小时前
pika:适用于大数据量持久化的类redis组件|jedis集成pika(二)
数据库·redis·缓存
码农老起1 天前
中间件的分类与实践:从消息到缓存
缓存·中间件