大模型算法面试笔记——多头潜在注意力(MLA)

注意力机制解决的问题:传统序列处理模型如RNN和LSTM,捕捉长距离依赖关系的难题。注意力机制允许模型在序列的不同位置之间建立直接联系,有效捕捉远距离依赖关系。

为了减少推理过程中KV Cache占用的显存,GQA和MQA通过head之间共享KV实现,这是一种牺牲性能对存储空间妥协的方案,而MLA通过对KV对做低秩联合压缩来减少推理中的KV缓存,目标是减少kv cache存储量的同时,保存模型的效果。

具体做法是,对于每个token,先通过一个低秩矩阵将KV联合压缩到一个低维向量cKVc^{KV}cKV中,然后通过两个升维矩阵WUKW^{UK}WUK,WUVW^{UV}WUV解压缩回高维,后续进行普通的多头注意力计算,这样每次只需要存这个低维向量。

这样做有个问题,就是压缩和解压操作使计算量增加了,而实际计算中,通过"矩阵吸收"操作,也就是矩阵运算过程中的结合律使多个矩阵合并,从而减少计算量。

具体计算过程如下(对qqq做相同压缩操作,于是也有了cQc^QcQ和WUQW^{UQ}WUQ):

attention⁡=softmax⁡(qkTd)vWO\operatorname{attention}=\operatorname{softmax}(\frac{qk^T}{\sqrt{d}})vW^Oattention=softmax(d qkT)vWO
=softmax⁡(cQWUQ(cKVWUK)Td)cKVWUVWO=\operatorname{softmax}(\frac{c^QW^{UQ}(c^{KV}W^{UK})^T}{\sqrt{d}})c^{KV}W^{UV}W^O=softmax(d cQWUQ(cKVWUK)T)cKVWUVWO
=softmax⁡(cQ(WUQ(WUK)T)(cKV)Td)cKV(WUVWO)=\operatorname{softmax}(\frac{c^Q(W^{UQ}(W^{UK})^T)(c^{KV})^T}{\sqrt{d}})c^{KV}(W^{UV}W^O)=softmax(d cQ(WUQ(WUK)T)(cKV)T)cKV(WUVWO)
=softmax⁡(cQWUQUK(cKV)Td)cKVWUVO=\operatorname{softmax}(\frac{c^QW^{UQUK}(c^{KV})^T}{\sqrt{d}})c^{KV}W^{UVO}=softmax(d cQWUQUK(cKV)T)cKVWUVO

如上所示,计算过程中,由于矩阵乘法结合律,WUQW^{UQ}WUQ,WUKW^{UK}WUK合并成一个矩阵WUQUKW^{UQUK}WUQUK,同理,WUVW^{UV}WUV,WOW^OWO合并成WUVOW^{UVO}WUVO。

对比普通MHA计算公式:

attention⁡=softmax⁡(qkTd)vWO\operatorname{attention}=\operatorname{softmax}(\frac{qk^T}{\sqrt{d}})vW^Oattention=softmax(d qkT)vWO
=softmax⁡(htWQKTd)vWO=\operatorname{softmax}(\frac{h_tW^QK^T}{\sqrt{d}})vW^O=softmax(d htWQKT)vWO

可知,两种注意力机制计算量相同,没有引入额外计算量,而缓存从两个高维KKK,VVV变成了一个低维cKVc^{KV}cKV

接下来是位置编码RoPE的处理,MHA中,RoPE可以通过对qqq,kkk向量乘以一个位置相关的变换矩阵RiR_iRi(iii为当前token所处的位置)。然而,在MLA中,如果做相同的处理将会如下所示:

qiRi(kjRj)T=cQWUQRi(cjKVWUKRj)Tq_iR_i(k_jR_j)^T=c^QW^{UQ}R_i(c^{KV}_jW^{UK}R_j)^TqiRi(kjRj)T=cQWUQRi(cjKVWUKRj)T
=cQWUQRiRjT(WUK)T(cjKV)T=c^QW^{UQ}R_iR^T_j(W^{UK})^T(c_j^{KV})^T=cQWUQRiRjT(WUK)T(cjKV)T

由于RiR_iRi不是一个固定的矩阵,无法实现矩阵吸收来减少计算量。对于这个问题,deepseek的做法是将参与注意力计算的qqq,vvv分成两部分,一部分进行矩阵吸收操作,不带位置信息,一部分进行位置信息计算。

对于qqq,基于潜在向量cQc^QcQ通过矩阵WQRW^{QR}WQR变换为低维向量后进行RoPE变换得到qRq^RqR;对于kkk,直接将输入hth_tht也通过一个矩阵WKRW^{KR}WKR变换后做RoPE变换得到kRk^RkR,其中kRk^RkR按照MQA的处理方式,各个head之间共享,既减少了显存调用又保证了位置编码的全局一致。然后将qRq_RqR,kRk^RkR拼接到前面计算得到的qqq,kkk向量后面,得到最终用于计算注意力的q=[qC;qR]q=[q^C;q^R]q=[qC;qR],k=[kC;kR]k=[k^C;k^R]k=[kC;kR]。

这样计算点积时如下,其中ttt,jjj表示token,iii表示head:

qt,ikj,iT=[qt,iC;qt,iR]×[kj,iC;ktR]=qt,iCkj,iC+qt,iRktRq_{t,i}k_{j,i}^T=[q_{t,i}^C;q_{t,i}^R]\times[k_{j,i}^C;k_t^R]=q^C_{t,i}k_{j,i}^C+q_{t,i}^Rk_t^Rqt,ikj,iT=[qt,iC;qt,iR]×[kj,iC;ktR]=qt,iCkj,iC+qt,iRktR

这样不包含位置编码的部分qt,iCkj,iCq^C_{t,i}k_{j,i}^Cqt,iCkj,iC就可以进行矩阵吸收的处理,每个head缓存一个ct,iKVc_{t,i}^{KV}ct,iKV;后一项按MQA的方式计算,所有head只需缓存一个共享的ktRk^R_tktR。

学习参考资料:zhihu-冷面爸

相关推荐
smj2302_796826521 小时前
解决leetcode第3768题.固定长度子数组中的最小逆序对数目
python·算法·leetcode
F_D_Z1 小时前
简明 | Yolo-v3结构理解摘要
深度学习·神经网络·yolo·计算机视觉·resnet
cynicme1 小时前
力扣3531——统计被覆盖的建筑
算法·leetcode
hd51cc1 小时前
MFC控件 学习笔记二
笔记·学习·mfc
core5122 小时前
深度解析DeepSeek-R1中GRPO强化学习算法
人工智能·算法·机器学习·deepseek·grpo
mit6.8242 小时前
计数if|
算法
a伊雪2 小时前
c++ 引用参数
c++·算法
java1234_小锋2 小时前
Transformer 大语言模型(LLM)基石 - Transformer架构详解 - 自注意力机制(Self-Attention)原理介绍
深度学习·语言模型·transformer
ney187819024742 小时前
分类网络LeNet + FashionMNIST 准确率92.9%
python·深度学习·分类