符号定义
首先,定义一些符号:
( B ):批大小(Batch Size)
( L ):序列长度(Sequence Length),在您的问题中,( L = 1 )
( N ):Transformer 层数(Number of Transformer Layers)
( H ):注意力头数(Number of Attention Heads)
( D ):每个注意力头的维度(Dimension per Head),即 ( D = Hidden Size / H D = \text{Hidden Size} / H D=Hidden Size/H)
( S ):数据类型大小(Size of Data Type),以字节为单位。例如:
FP32(32位浮点数):( S = 4 ) 字节
FP16(16位浮点数):( S = 2 ) 字节
KV 缓存的内存计算
对于每一层的多头注意力机制,我们需要存储 键(Key) 和 值(Value) 的缓存。对于每一层,键和值的缓存大小计算如下:
键缓存(Key Cache)大小:
Size Key = B × H × L × D × S \text{Size}_{\text{Key}} = B \times H \times L \times D \times S SizeKey=B×H×L×D×S
值缓存(Value Cache)大小:
Size Value = B × H × L × D × S \text{Size}_{\text{Value}} = B \times H \times L \times D \times S SizeValue=B×H×L×D×S
因此,每一层的 KV 缓存总大小为:
SizeKV per layer = SizeKey + Size Value = 2 × B × H × L × D × S \text{Size}{\text{KV per layer}} = \text{Size}{\text{Key}} + \text{Size}_{\text{Value}} = 2 \times B \times H \times L \times D \times S SizeKV per layer=SizeKey+SizeValue=2×B×H×L×D×S
由于模型有 ( N ) 层,因此 总的 KV 缓存大小为:
Total SizeKV = N × SizeKV per layer = 2 × B × H × L × D × N × S \text{Total Size}{\text{KV}} = N \times \text{Size}{\text{KV per layer}} = 2 \times B \times H \times L \times D \times N \times S Total SizeKV=N×SizeKV per layer=2×B×H×L×D×N×S
具体示例计算
假设以下参数:
批大小:( B = 1 )
序列长度:( L = 1 ) (即 token 数为 1)
层数:( N = 12 ) (例如,一个小型的 Transformer)
隐藏层尺寸:( Hidden Size = 768 \text{Hidden Size} = 768 Hidden Size=768 )
注意力头数:( H = 12 )
每个头的维度:
D = Hidden Size H = 768 12 = 64 D = \frac{\text{Hidden Size}}{H} = \frac{768}{12} = 64 D=HHidden Size=12768=64
数据类型:FP32,( S = 4 ) 字节
现在,我们计算每一层的 KV 缓存大小:
Size KV per layer = 2 × B × H × L × D × S = 2 × 1 × 12 × 1 × 64 × 4 = 2 × 1 × 12 × 1 × 64 × 4 = 2 × 12 × 64 × 4 = 2 × 12 × 64 × 4 = 6144 字节 \begin{align*} \text{Size}_{\text{KV per layer}} &= 2 \times B \times H \times L \times D \times S \ &= 2 \times 1 \times 12 \times 1 \times 64 \times 4 \ &= 2 \times 1 \times 12 \times 1 \times 64 \times 4 \ &= 2 \times 12 \times 64 \times 4 \ &= 2 \times 12 \times 64 \times 4 \ &= 6144\ \text{字节} \end{align*} SizeKV per layer=2×B×H×L×D×S =2×1×12×1×64×4 =2×1×12×1×64×4 =2×12×64×4 =2×12×64×4 =6144 字节
总的 KV 缓存大小:
Total SizeKV = N × SizeKV per layer = 12 × 6144 = 73728 字节 \begin{align*} \text{Total Size}{\text{KV}} &= N \times \text{Size}{\text{KV per layer}} \ &= 12 \times 6144 \ &= 73728\ \text{字节} \end{align*} Total SizeKV=N×SizeKV per layer =12×6144 =73728 字节
即大约 72 KB。
虽然这个数字看起来不大,但在大型模型中,参数会显著增大。例如,考虑一个具有以下参数的大型模型:
层数:( N = 96 )
隐藏层尺寸:( Hidden Size = 12288 \text{Hidden Size} = 12288 Hidden Size=12288)
注意力头数:( H = 96 )
每个头的维度:
D = 12288 96 = 128 D = \frac{12288}{96} = 128 D=9612288=128
数据类型:FP16,( S = 2 ) 字节
计算每一层的 KV 缓存大小:
Size KV per layer = 2 × B × H × L × D × S = 2 × 1 × 96 × 1 × 128 × 2 = 2 × 96 × 128 × 2 = 2 × 96 × 128 × 2 = 49 , 152 字节 \begin{align*} \text{Size}_{\text{KV per layer}} &= 2 \times B \times H \times L \times D \times S \ &= 2 \times 1 \times 96 \times 1 \times 128 \times 2 \ &= 2 \times 96 \times 128 \times 2 \ &= 2 \times 96 \times 128 \times 2 \ &= 49,152\ \text{字节} \end{align*} SizeKV per layer=2×B×H×L×D×S =2×1×96×1×128×2 =2×96×128×2 =2×96×128×2 =49,152 字节
总的 KV 缓存大小:
Total SizeKV = N × SizeKV per layer = 96 × 49 , 152 = 4 , 719 , 616 字节 \begin{align*} \text{Total Size}{\text{KV}} &= N \times \text{Size}{\text{KV per layer}} \ &= 96 \times 49,152 \ &= 4,719,616\ \text{字节} \end{align*} Total SizeKV=N×SizeKV per layer =96×49,152 =4,719,616 字节
即大约 4.5 MB。
注意事项
模型规模的影响: 可以看到,随着 层数 ( N )、隐藏层尺寸 和 注意力头数 ( H ) 的增加,KV 缓存的内存需求会显著增长。
序列长度的影响: 虽然在 ( L = 1 ) 时,序列长度对内存影响较小,但在生成长序列时,( L ) 会增加,导致 KV 缓存内存占用线性增长。
数据类型的影响: 使用 FP16 可以将内存占用减少一半,但对于大型模型,内存需求仍然很高。
总结
即使 token 数为 1,由于模型的层数、注意力头数、每个头的维度等参数较大,KV 缓存仍然需要消耗较大的内存。
通过以上公式,可以直观地看到各个参数对 KV 缓存内存占用的影响,从而理解为什么在处理单个 token 时仍需要大的内存。
优化建议
减少模型规模: 降低 ( N )、( H ) 或 ( D ) 的值。
使用半精度: 采用 FP16 或更低精度的数据类型。
批量大小优化: 确保 ( B ) 仅为需要的最小值。
序列长度控制: 在可能的情况下,限制生成序列的最大长度 ( L )。