大模型算法面试笔记——多头潜在注意力(MLA)

注意力机制解决的问题:传统序列处理模型如RNN和LSTM,捕捉长距离依赖关系的难题。注意力机制允许模型在序列的不同位置之间建立直接联系,有效捕捉远距离依赖关系。

为了减少推理过程中KV Cache占用的显存,GQA和MQA通过head之间共享KV实现,这是一种牺牲性能对存储空间妥协的方案,而MLA通过对KV对做低秩联合压缩来减少推理中的KV缓存,目标是减少kv cache存储量的同时,保存模型的效果。

具体做法是,对于每个token,先通过一个低秩矩阵将KV联合压缩到一个低维向量cKVc^{KV}cKV中,然后通过两个升维矩阵WUKW^{UK}WUK,WUVW^{UV}WUV解压缩回高维,后续进行普通的多头注意力计算,这样每次只需要存这个低维向量。

这样做有个问题,就是压缩和解压操作使计算量增加了,而实际计算中,通过"矩阵吸收"操作,也就是矩阵运算过程中的结合律使多个矩阵合并,从而减少计算量。

具体计算过程如下(对qqq做相同压缩操作,于是也有了cQc^QcQ和WUQW^{UQ}WUQ):

attention⁡=softmax⁡(qkTd)vWO\operatorname{attention}=\operatorname{softmax}(\frac{qk^T}{\sqrt{d}})vW^Oattention=softmax(d qkT)vWO
=softmax⁡(cQWUQ(cKVWUK)Td)cKVWUVWO=\operatorname{softmax}(\frac{c^QW^{UQ}(c^{KV}W^{UK})^T}{\sqrt{d}})c^{KV}W^{UV}W^O=softmax(d cQWUQ(cKVWUK)T)cKVWUVWO
=softmax⁡(cQ(WUQ(WUK)T)(cKV)Td)cKV(WUVWO)=\operatorname{softmax}(\frac{c^Q(W^{UQ}(W^{UK})^T)(c^{KV})^T}{\sqrt{d}})c^{KV}(W^{UV}W^O)=softmax(d cQ(WUQ(WUK)T)(cKV)T)cKV(WUVWO)
=softmax⁡(cQWUQUK(cKV)Td)cKVWUVO=\operatorname{softmax}(\frac{c^QW^{UQUK}(c^{KV})^T}{\sqrt{d}})c^{KV}W^{UVO}=softmax(d cQWUQUK(cKV)T)cKVWUVO

如上所示,计算过程中,由于矩阵乘法结合律,WUQW^{UQ}WUQ,WUKW^{UK}WUK合并成一个矩阵WUQUKW^{UQUK}WUQUK,同理,WUVW^{UV}WUV,WOW^OWO合并成WUVOW^{UVO}WUVO。

对比普通MHA计算公式:

attention⁡=softmax⁡(qkTd)vWO\operatorname{attention}=\operatorname{softmax}(\frac{qk^T}{\sqrt{d}})vW^Oattention=softmax(d qkT)vWO
=softmax⁡(htWQKTd)vWO=\operatorname{softmax}(\frac{h_tW^QK^T}{\sqrt{d}})vW^O=softmax(d htWQKT)vWO

可知,两种注意力机制计算量相同,没有引入额外计算量,而缓存从两个高维KKK,VVV变成了一个低维cKVc^{KV}cKV

接下来是位置编码RoPE的处理,MHA中,RoPE可以通过对qqq,kkk向量乘以一个位置相关的变换矩阵RiR_iRi(iii为当前token所处的位置)。然而,在MLA中,如果做相同的处理将会如下所示:

qiRi(kjRj)T=cQWUQRi(cjKVWUKRj)Tq_iR_i(k_jR_j)^T=c^QW^{UQ}R_i(c^{KV}_jW^{UK}R_j)^TqiRi(kjRj)T=cQWUQRi(cjKVWUKRj)T
=cQWUQRiRjT(WUK)T(cjKV)T=c^QW^{UQ}R_iR^T_j(W^{UK})^T(c_j^{KV})^T=cQWUQRiRjT(WUK)T(cjKV)T

由于RiR_iRi不是一个固定的矩阵,无法实现矩阵吸收来减少计算量。对于这个问题,deepseek的做法是将参与注意力计算的qqq,vvv分成两部分,一部分进行矩阵吸收操作,不带位置信息,一部分进行位置信息计算。

对于qqq,基于潜在向量cQc^QcQ通过矩阵WQRW^{QR}WQR变换为低维向量后进行RoPE变换得到qRq^RqR;对于kkk,直接将输入hth_tht也通过一个矩阵WKRW^{KR}WKR变换后做RoPE变换得到kRk^RkR,其中kRk^RkR按照MQA的处理方式,各个head之间共享,既减少了显存调用又保证了位置编码的全局一致。然后将qRq_RqR,kRk^RkR拼接到前面计算得到的qqq,kkk向量后面,得到最终用于计算注意力的q=qC;qRq=q\^C;q\^Rq=qC;qR,k=kC;kRk=k\^C;k\^Rk=kC;kR

这样计算点积时如下,其中ttt,jjj表示token,iii表示head:

qt,ikj,iT=qt,iC;qt,iR×kj,iC;ktR=qt,iCkj,iC+qt,iRktRq_{t,i}k_{j,i}^T=q_{t,i}\^C;q_{t,i}\^R\timesk_{j,i}\^C;k_t\^R=q^C_{t,i}k_{j,i}^C+q_{t,i}^Rk_t^Rqt,ikj,iT=qt,iC;qt,iR×kj,iC;ktR=qt,iCkj,iC+qt,iRktR

这样不包含位置编码的部分qt,iCkj,iCq^C_{t,i}k_{j,i}^Cqt,iCkj,iC就可以进行矩阵吸收的处理,每个head缓存一个ct,iKVc_{t,i}^{KV}ct,iKV;后一项按MQA的方式计算,所有head只需缓存一个共享的ktRk^R_tktR。

学习参考资料:zhihu-冷面爸

相关推荐
AI小老六5 小时前
SkillOpt 架构拆解:把 Skill 文本当参数,用执行轨迹训练 Agent
后端·算法·ai编程
胡萝卜术6 小时前
从“分数打架”到“排名投票”:为什么你的ChatBI必须用RRF?
算法·设计模式·面试
Asize7 小时前
初识DFS 与 BFS:递归、队列与图遍历
算法
罗西的思考20 小时前
机器人 / 强化学习】HIL-SERL:人类在环驱动的具身智能进化框架
人工智能·算法·机器学习
美团技术团队1 天前
LongCat 开源 VitaBench 2.0:长期动态智能体基准新标杆
人工智能·算法
RainCity1 天前
Java Swing 自定义组件库分享(十二)
java·笔记·后端
To_OC2 天前
LC 207 课程表:刚学图论那会儿,我连这是拓扑排序都没看出来
javascript·算法·leetcode
To_OC2 天前
LC 208 实现 Trie 前缀树:曾被名字劝退,写完发现是送分题
javascript·算法·leetcode
BadBadBad__AK2 天前
线段树维护区间 k 次方和
c++·数学·算法·stl