Llama中模块参数大小

LLama2中,流程中数据大小的变换如下

Transformer模块

第一次输入,进行prefill,输入x维度为[1, 8, 4096]

  1. 构建wq,wk,wv,wo,尺寸均为[4096,4096], 与x点乘,得到xq, xk, xv

  2. 构建KV cache, 尺寸为 [batch size, max_seq_len, local_kv_heads, head_dim],对应 [1, 8, 32, 128]

3.基于kv cache构造 keys, alues,对应的尺寸还是[1,8,32,128]

  1. 在最后两个维度对于xq和key进行点乘,得到scores,维度变成【1, 32, 8, 8】

  2. 将mask与scores相加

  3. 对于scores进行softmax

  4. 将scores [1, 32, 8, 8]与values [1, 32, 8, 128]进行乘法

  5. 得到output [1, 8, 4096]

  6. 将output再与wo进行乘法[1, 8, 4096]

  7. 接下来对于输出进行 ffn_norm的操作

Feedforward模块

11.然后进行feed_forward.得到当前transformer模块的输出 [1, 8, 4096]

feed_forward的操作如下,虽然代码很小,但是计算量却很大。

复制代码
    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

其中,w1的维度为[11008, 4096], w2的维度为[4096, 11008], w3的维度为[11008, 4096]

kv cache的表达如下

python 复制代码
        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )

关于kv cache的细节讨论

llama2设定 local_kv_heads为32,head_dim为128。所以,kv cache的尺寸为 [1, 512,32, 128] * 2

对于一个batch的数据来说哦,因为llama2 7B 包含32个transformer,所以,当使用FP32表达时, 对应一个batch的kv cache的大小为128 * 32 * 128 *2 * 32 * 4byte= 0.5GB.

这里,也可以看到几个变量:

* 当batch变大时,kv cache线性增长

* 当batch 的最大长度增大时, Kv cache线性增长。

参考链接:

https://arxiv.org/pdf/1911.02150

相关推荐
狂炫冰美式23 分钟前
3天,1人,从0到付费产品:AI时代个人开发者的生存指南
前端·人工智能·后端
LCG元1 小时前
垂直Agent才是未来:详解让大模型"专业对口"的三大核心技术
人工智能
我不是QI1 小时前
周志华《机器学习—西瓜书》二
人工智能·安全·机器学习
操练起来1 小时前
【昇腾CANN训练营·第八期】Ascend C生态兼容:基于PyTorch Adapter的自定义算子注册与自动微分实现
人工智能·pytorch·acl·昇腾·cann
KG_LLM图谱增强大模型2 小时前
[500页电子书]构建自主AI Agent系统的蓝图:谷歌重磅发布智能体设计模式指南
人工智能·大模型·知识图谱·智能体·知识图谱增强大模型·agenticai
声网2 小时前
活动推荐丨「实时互动 × 对话式 AI」主题有奖征文
大数据·人工智能·实时互动
caiyueloveclamp2 小时前
【功能介绍03】ChatPPT好不好用?如何用?用户操作手册来啦!——【AI溯源篇】
人工智能·信息可视化·powerpoint·ai生成ppt·aippt
q***48412 小时前
Vanna AI:告别代码,用自然语言轻松查询数据库,领先的RAG2SQL技术让结果更智能、更精准!
人工智能·microsoft
LCG元2 小时前
告别空谈!手把手教你用LangChain构建"能干活"的垂直领域AI Agent
人工智能
想你依然心痛3 小时前
视界无界:基于Rokid眼镜的AI商务同传系统开发与实践
人工智能·智能硬件·rokid·ai眼镜·ar技术