LLama2中,流程中数据大小的变换如下
Transformer模块
第一次输入,进行prefill,输入x维度为1, 8, 4096
-
构建wq,wk,wv,wo,尺寸均为4096,4096, 与x点乘,得到xq, xk, xv
-
构建KV cache, 尺寸为 batch size, max_seq_len, local_kv_heads, head_dim,对应 1, 8, 32, 128
3.基于kv cache构造 keys, alues,对应的尺寸还是1,8,32,128
-
在最后两个维度对于xq和key进行点乘,得到scores,维度变成【1, 32, 8, 8】
-
将mask与scores相加
-
对于scores进行softmax
-
将scores 1, 32, 8, 8与values 1, 32, 8, 128进行乘法
-
得到output 1, 8, 4096
-
将output再与wo进行乘法1, 8, 4096
-
接下来对于输出进行 ffn_norm的操作
Feedforward模块
11.然后进行feed_forward.得到当前transformer模块的输出 1, 8, 4096
feed_forward的操作如下,虽然代码很小,但是计算量却很大。
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
其中,w1的维度为11008, 4096, w2的维度为4096, 11008, w3的维度为11008, 4096
kv cache的表达如下
python
self.cache_k = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
self.cache_v = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
关于kv cache的细节讨论
llama2设定 local_kv_heads为32,head_dim为128。所以,kv cache的尺寸为 1, 512,32, 128 * 2
对于一个batch的数据来说哦,因为llama2 7B 包含32个transformer,所以,当使用FP32表达时, 对应一个batch的kv cache的大小为128 * 32 * 128 *2 * 32 * 4byte= 0.5GB.
这里,也可以看到几个变量:
* 当batch变大时,kv cache线性增长
* 当batch 的最大长度增大时, Kv cache线性增长。
参考链接: