大模型算法面试笔记——多头潜在注意力(MLA)

注意力机制解决的问题:传统序列处理模型如RNN和LSTM,捕捉长距离依赖关系的难题。注意力机制允许模型在序列的不同位置之间建立直接联系,有效捕捉远距离依赖关系。

为了减少推理过程中KV Cache占用的显存,GQA和MQA通过head之间共享KV实现,这是一种牺牲性能对存储空间妥协的方案,而MLA通过对KV对做低秩联合压缩来减少推理中的KV缓存,目标是减少kv cache存储量的同时,保存模型的效果。

具体做法是,对于每个token,先通过一个低秩矩阵将KV联合压缩到一个低维向量cKVc^{KV}cKV中,然后通过两个升维矩阵WUKW^{UK}WUK,WUVW^{UV}WUV解压缩回高维,后续进行普通的多头注意力计算,这样每次只需要存这个低维向量。

这样做有个问题,就是压缩和解压操作使计算量增加了,而实际计算中,通过"矩阵吸收"操作,也就是矩阵运算过程中的结合律使多个矩阵合并,从而减少计算量。

具体计算过程如下(对qqq做相同压缩操作,于是也有了cQc^QcQ和WUQW^{UQ}WUQ):

attention⁡=softmax⁡(qkTd)vWO\operatorname{attention}=\operatorname{softmax}(\frac{qk^T}{\sqrt{d}})vW^Oattention=softmax(d qkT)vWO
=softmax⁡(cQWUQ(cKVWUK)Td)cKVWUVWO=\operatorname{softmax}(\frac{c^QW^{UQ}(c^{KV}W^{UK})^T}{\sqrt{d}})c^{KV}W^{UV}W^O=softmax(d cQWUQ(cKVWUK)T)cKVWUVWO
=softmax⁡(cQ(WUQ(WUK)T)(cKV)Td)cKV(WUVWO)=\operatorname{softmax}(\frac{c^Q(W^{UQ}(W^{UK})^T)(c^{KV})^T}{\sqrt{d}})c^{KV}(W^{UV}W^O)=softmax(d cQ(WUQ(WUK)T)(cKV)T)cKV(WUVWO)
=softmax⁡(cQWUQUK(cKV)Td)cKVWUVO=\operatorname{softmax}(\frac{c^QW^{UQUK}(c^{KV})^T}{\sqrt{d}})c^{KV}W^{UVO}=softmax(d cQWUQUK(cKV)T)cKVWUVO

如上所示,计算过程中,由于矩阵乘法结合律,WUQW^{UQ}WUQ,WUKW^{UK}WUK合并成一个矩阵WUQUKW^{UQUK}WUQUK,同理,WUVW^{UV}WUV,WOW^OWO合并成WUVOW^{UVO}WUVO。

对比普通MHA计算公式:

attention⁡=softmax⁡(qkTd)vWO\operatorname{attention}=\operatorname{softmax}(\frac{qk^T}{\sqrt{d}})vW^Oattention=softmax(d qkT)vWO
=softmax⁡(htWQKTd)vWO=\operatorname{softmax}(\frac{h_tW^QK^T}{\sqrt{d}})vW^O=softmax(d htWQKT)vWO

可知,两种注意力机制计算量相同,没有引入额外计算量,而缓存从两个高维KKK,VVV变成了一个低维cKVc^{KV}cKV

接下来是位置编码RoPE的处理,MHA中,RoPE可以通过对qqq,kkk向量乘以一个位置相关的变换矩阵RiR_iRi(iii为当前token所处的位置)。然而,在MLA中,如果做相同的处理将会如下所示:

qiRi(kjRj)T=cQWUQRi(cjKVWUKRj)Tq_iR_i(k_jR_j)^T=c^QW^{UQ}R_i(c^{KV}_jW^{UK}R_j)^TqiRi(kjRj)T=cQWUQRi(cjKVWUKRj)T
=cQWUQRiRjT(WUK)T(cjKV)T=c^QW^{UQ}R_iR^T_j(W^{UK})^T(c_j^{KV})^T=cQWUQRiRjT(WUK)T(cjKV)T

由于RiR_iRi不是一个固定的矩阵,无法实现矩阵吸收来减少计算量。对于这个问题,deepseek的做法是将参与注意力计算的qqq,vvv分成两部分,一部分进行矩阵吸收操作,不带位置信息,一部分进行位置信息计算。

对于qqq,基于潜在向量cQc^QcQ通过矩阵WQRW^{QR}WQR变换为低维向量后进行RoPE变换得到qRq^RqR;对于kkk,直接将输入hth_tht也通过一个矩阵WKRW^{KR}WKR变换后做RoPE变换得到kRk^RkR,其中kRk^RkR按照MQA的处理方式,各个head之间共享,既减少了显存调用又保证了位置编码的全局一致。然后将qRq_RqR,kRk^RkR拼接到前面计算得到的qqq,kkk向量后面,得到最终用于计算注意力的q=[qC;qR]q=[q^C;q^R]q=[qC;qR],k=[kC;kR]k=[k^C;k^R]k=[kC;kR]。

这样计算点积时如下,其中ttt,jjj表示token,iii表示head:

qt,ikj,iT=[qt,iC;qt,iR]×[kj,iC;ktR]=qt,iCkj,iC+qt,iRktRq_{t,i}k_{j,i}^T=[q_{t,i}^C;q_{t,i}^R]\times[k_{j,i}^C;k_t^R]=q^C_{t,i}k_{j,i}^C+q_{t,i}^Rk_t^Rqt,ikj,iT=[qt,iC;qt,iR]×[kj,iC;ktR]=qt,iCkj,iC+qt,iRktR

这样不包含位置编码的部分qt,iCkj,iCq^C_{t,i}k_{j,i}^Cqt,iCkj,iC就可以进行矩阵吸收的处理,每个head缓存一个ct,iKVc_{t,i}^{KV}ct,iKV;后一项按MQA的方式计算,所有head只需缓存一个共享的ktRk^R_tktR。

学习参考资料:zhihu-冷面爸

相关推荐
晨非辰10 分钟前
数据结构排序系列指南:从O(n²)到O(n),计数排序如何实现线性时间复杂度
运维·数据结构·c++·人工智能·后端·深度学习·排序算法
2301_8129148710 分钟前
简单神经网络
人工智能·深度学习·神经网络
小曹要微笑11 分钟前
STM32H7系列全面解析:嵌入式性能的巅峰之作
c语言·stm32·单片机·嵌入式硬件·算法
寻星探路12 分钟前
JavaSE重点总结后篇
java·开发语言·算法
松涛和鸣2 小时前
14、C 语言进阶:函数指针、typedef、二级指针、const 指针
c语言·开发语言·算法·排序算法·学习方法
yagamiraito_4 小时前
757. 设置交集大小至少为2 (leetcode每日一题)
算法·leetcode·go
星释4 小时前
Rust 练习册 57:阿特巴什密码与字符映射技术
服务器·算法·rust
无敌最俊朗@4 小时前
力扣hot100-141.环形链表
算法·leetcode·链表
1***Q7846 小时前
深度学习技术
人工智能·深度学习
WWZZ20257 小时前
快速上手大模型:深度学习10(卷积神经网络2、模型训练实践、批量归一化)
人工智能·深度学习·神经网络·算法·机器人·大模型·具身智能