注意力机制解决的问题:传统序列处理模型如RNN和LSTM,捕捉长距离依赖关系的难题。注意力机制允许模型在序列的不同位置之间建立直接联系,有效捕捉远距离依赖关系。
为了减少推理过程中KV Cache占用的显存,GQA和MQA通过head之间共享KV实现,这是一种牺牲性能对存储空间妥协的方案,而MLA通过对KV对做低秩联合压缩来减少推理中的KV缓存,目标是减少kv cache存储量的同时,保存模型的效果。
具体做法是,对于每个token,先通过一个低秩矩阵将KV联合压缩到一个低维向量cKVc^{KV}cKV中,然后通过两个升维矩阵WUKW^{UK}WUK,WUVW^{UV}WUV解压缩回高维,后续进行普通的多头注意力计算,这样每次只需要存这个低维向量。
这样做有个问题,就是压缩和解压操作使计算量增加了,而实际计算中,通过"矩阵吸收"操作,也就是矩阵运算过程中的结合律使多个矩阵合并,从而减少计算量。
具体计算过程如下(对qqq做相同压缩操作,于是也有了cQc^QcQ和WUQW^{UQ}WUQ):
attention=softmax(qkTd)vWO\operatorname{attention}=\operatorname{softmax}(\frac{qk^T}{\sqrt{d}})vW^Oattention=softmax(d qkT)vWO
=softmax(cQWUQ(cKVWUK)Td)cKVWUVWO=\operatorname{softmax}(\frac{c^QW^{UQ}(c^{KV}W^{UK})^T}{\sqrt{d}})c^{KV}W^{UV}W^O=softmax(d cQWUQ(cKVWUK)T)cKVWUVWO
=softmax(cQ(WUQ(WUK)T)(cKV)Td)cKV(WUVWO)=\operatorname{softmax}(\frac{c^Q(W^{UQ}(W^{UK})^T)(c^{KV})^T}{\sqrt{d}})c^{KV}(W^{UV}W^O)=softmax(d cQ(WUQ(WUK)T)(cKV)T)cKV(WUVWO)
=softmax(cQWUQUK(cKV)Td)cKVWUVO=\operatorname{softmax}(\frac{c^QW^{UQUK}(c^{KV})^T}{\sqrt{d}})c^{KV}W^{UVO}=softmax(d cQWUQUK(cKV)T)cKVWUVO
如上所示,计算过程中,由于矩阵乘法结合律,WUQW^{UQ}WUQ,WUKW^{UK}WUK合并成一个矩阵WUQUKW^{UQUK}WUQUK,同理,WUVW^{UV}WUV,WOW^OWO合并成WUVOW^{UVO}WUVO。
对比普通MHA计算公式:
attention=softmax(qkTd)vWO\operatorname{attention}=\operatorname{softmax}(\frac{qk^T}{\sqrt{d}})vW^Oattention=softmax(d qkT)vWO
=softmax(htWQKTd)vWO=\operatorname{softmax}(\frac{h_tW^QK^T}{\sqrt{d}})vW^O=softmax(d htWQKT)vWO
可知,两种注意力机制计算量相同,没有引入额外计算量,而缓存从两个高维KKK,VVV变成了一个低维cKVc^{KV}cKV
接下来是位置编码RoPE的处理,MHA中,RoPE可以通过对qqq,kkk向量乘以一个位置相关的变换矩阵RiR_iRi(iii为当前token所处的位置)。然而,在MLA中,如果做相同的处理将会如下所示:
qiRi(kjRj)T=cQWUQRi(cjKVWUKRj)Tq_iR_i(k_jR_j)^T=c^QW^{UQ}R_i(c^{KV}_jW^{UK}R_j)^TqiRi(kjRj)T=cQWUQRi(cjKVWUKRj)T
=cQWUQRiRjT(WUK)T(cjKV)T=c^QW^{UQ}R_iR^T_j(W^{UK})^T(c_j^{KV})^T=cQWUQRiRjT(WUK)T(cjKV)T
由于RiR_iRi不是一个固定的矩阵,无法实现矩阵吸收来减少计算量。对于这个问题,deepseek的做法是将参与注意力计算的qqq,vvv分成两部分,一部分进行矩阵吸收操作,不带位置信息,一部分进行位置信息计算。
对于qqq,基于潜在向量cQc^QcQ通过矩阵WQRW^{QR}WQR变换为低维向量后进行RoPE变换得到qRq^RqR;对于kkk,直接将输入hth_tht也通过一个矩阵WKRW^{KR}WKR变换后做RoPE变换得到kRk^RkR,其中kRk^RkR按照MQA的处理方式,各个head之间共享,既减少了显存调用又保证了位置编码的全局一致。然后将qRq_RqR,kRk^RkR拼接到前面计算得到的qqq,kkk向量后面,得到最终用于计算注意力的q=[qC;qR]q=[q^C;q^R]q=[qC;qR],k=[kC;kR]k=[k^C;k^R]k=[kC;kR]。
这样计算点积时如下,其中ttt,jjj表示token,iii表示head:
qt,ikj,iT=[qt,iC;qt,iR]×[kj,iC;ktR]=qt,iCkj,iC+qt,iRktRq_{t,i}k_{j,i}^T=[q_{t,i}^C;q_{t,i}^R]\times[k_{j,i}^C;k_t^R]=q^C_{t,i}k_{j,i}^C+q_{t,i}^Rk_t^Rqt,ikj,iT=[qt,iC;qt,iR]×[kj,iC;ktR]=qt,iCkj,iC+qt,iRktR
这样不包含位置编码的部分qt,iCkj,iCq^C_{t,i}k_{j,i}^Cqt,iCkj,iC就可以进行矩阵吸收的处理,每个head缓存一个ct,iKVc_{t,i}^{KV}ct,iKV;后一项按MQA的方式计算,所有head只需缓存一个共享的ktRk^R_tktR。
学习参考资料:zhihu-冷面爸