DeepSeek技术解读-从MHA到MLA的完整解读(适合有点基础的同学)

一、传统的多头注意力机制(MHA,Multi-Head Attention):

在标准的Transformer中,多头注意力机制(MHA)通过并行计算多个注意力头来捕捉输入序列中的不同特征。每个注意力头都有自己的查询(Query, Q)、键(Key, K)和值(Value, V)矩阵,他们各自的主要作用如下:

  • 查询矩阵 Q:查询矩阵是你想要寻找某个信息的"问题"。在Transformer中,查询矩阵是输入的一个投影,表示当前token对其他token的"需求"。它帮助你确定自己在序列中的位置需要关注什么内容
  • 键矩阵 K:键矩阵是每个token提供的"信息"或"标识符"。每个token都有一个与之关联的键,用于与查询进行对比,以确定它与查询的相关性。你可以把键想象成词语的"标签"。
  • 值矩阵 V:值是实际的信息,提供了词向量的内容。根据Q与K的匹配程度,V最终用来生成输出向量。

假定:d 是隐向量维度, <math xmlns="http://www.w3.org/1998/Math/MathML"> n h n_h </math>nh是注意力头的数量, <math xmlns="http://www.w3.org/1998/Math/MathML"> d h d_h </math>dh是每个注意力头的维度, <math xmlns="http://www.w3.org/1998/Math/MathML"> h t h_t </math>ht是attention层地t个token的输入隐向量。

  1. 标准的MHA首先使用三个权重矩阵(训练参数) <math xmlns="http://www.w3.org/1998/Math/MathML"> W q , W k , W v ∈ R d h ∗ n h ∗ d W_q,W_k,W_v \in{\mathbb{R}^{d_h*n_h*d}} </math>Wq,Wk,Wv∈Rdh∗nh∗d计算得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> q t , k t , v t q_t,k_t,v_t </math>qt,kt,vt向量。然后 <math xmlns="http://www.w3.org/1998/Math/MathML"> q t , k t , v t q_t,k_t,v_t </math>qt,kt,vt向量拆分成 <math xmlns="http://www.w3.org/1998/Math/MathML"> n h n_h </math>nh份(每个注意力头分一份):

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> [ q 𝑡 , 1 ; q 𝑡 , 2 ; . . . ; q 𝑡 , 𝑛 h ] = q 𝑡 [ k 𝑡 , 1 ; k 𝑡 , 2 ; . . . ; k 𝑡 , 𝑛 h ] = k 𝑡 [ v 𝑡 , 1 ; v 𝑡 , 2 ; . . . ; v 𝑡 , 𝑛 h ] = v 𝑡 [q_{𝑡,1};q_{𝑡,2}; ...; q_{𝑡,𝑛_ℎ}]= q_𝑡 \\ [k_{𝑡,1};k_{𝑡,2}; ...; k_{𝑡,𝑛_ℎ}]= k_𝑡 \\ [v_{𝑡,1};v_{𝑡,2}; ...; v_{𝑡,𝑛_ℎ}]= v_𝑡 </math>[qt,1;qt,2;...;qt,nh]=qt[kt,1;kt,2;...;kt,nh]=kt[vt,1;vt,2;...;vt,nh]=vt

  1. 使用 <math xmlns="http://www.w3.org/1998/Math/MathML"> q t , k t q_t,k_t </math>qt,kt计算注意力得分,并使用注意力权重对 <math xmlns="http://www.w3.org/1998/Math/MathML"> v t v_t </math>vt进行加权求和,得到每个注意力头的结果:

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> o 𝑡 , 𝑖 = ∑ j = 1 t ︁ S o f t m a x 𝑗 ( q 𝑡 , 𝑖 𝑇 k 𝑗 , 𝑖 d h ) v j , i o_{𝑡,𝑖} =\sum^{t}{j=1}{︁Softmax_𝑗 (\frac{q^𝑇 {𝑡,𝑖} k{𝑗,𝑖}}{\sqrt{d_h}})} v{j,i} </math>ot,i=j=1∑t︁Softmaxj(dh qt,iTkj,i)vj,i

  1. 最后把所有注意力头结果向量拼接起来,通过一层限行映射回原始维度:

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> u 𝑡 = 𝑊 𝑂 [ o 𝑡 , 1 ; o 𝑡 , 2 ; . . . ; o 𝑡 , 𝑛 h ] u_𝑡 = 𝑊^𝑂[o_{𝑡,1}; o_{𝑡,2}; ...; o_{𝑡,𝑛_ℎ}] </math>ut=WO[ot,1;ot,2;...;ot,nh]


二、多头潜在注意力机制(MLA,Multi-Head Latent Attention) :

MLA的核心是对value和key进行低秩联合压缩 来减少推理时的键值缓存(KV cache),MLA设计中所有的K和V都需要缓存,MLA只需要缓存一个压缩的向量,并且此向量纬度远远小于 <math xmlns="http://www.w3.org/1998/Math/MathML"> d h n h d_hn_h </math>dhnh,只需要在推理计算时再向上投影生成所有的K和V。具体计算如下:

2.1 对value和key进行低秩联合压缩:

具体的:

  • 生成压缩潜在隐向量(latent vector),其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> 𝑊 𝐷 𝐾 𝑉 ∈ R 𝑑 𝑐 × 𝑑 𝑊^{𝐷𝐾𝑉} ∈ \mathbb{R}^{𝑑_𝑐×𝑑} </math>WDKV∈Rdc×d是下投影矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> c 𝑡 𝐾 𝑉 = 𝑊 𝐷 𝐾 𝑉 h 𝑡 c^{𝐾𝑉}_𝑡 = 𝑊^{𝐷𝐾𝑉}h_𝑡 </math>ctKV=WDKVht。

  • 通过上投影矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> 𝑊 U K , 𝑊 U V ∈ R d h n h ∗ 𝑑 𝑐 𝑊^{UK}, 𝑊^{UV} ∈ \mathbb{R}^{d_hn_h*𝑑_𝑐} </math>WUK,WUV∈Rdhnh∗dc将潜在隐向量分别重建键K矩阵和值V矩阵,注意可以认为是映射成隐向量维度 h ,而不是每个注意力头的维度 : <math xmlns="http://www.w3.org/1998/Math/MathML"> k t 𝐶 = 𝑊 U K c 𝑡 𝐾 𝑉 k^𝐶_t = 𝑊^{UK}c^{𝐾𝑉}_𝑡 </math>ktC=WUKctKV, <math xmlns="http://www.w3.org/1998/Math/MathML"> v t 𝐶 = 𝑊 U V c 𝑡 𝐾 𝑉 v^𝐶_t = 𝑊^{UV}c^{𝐾𝑉}_𝑡 </math>vtC=WUVctKV

  • 应用旋转位置编码(RoPE),引入位置信息。因为传统的MHA中,每个token都对应着自己的K向量,天然包含了位置信息,现在通过一个共用的潜在隐向量映射得到的K是不包含位置信息的。 <math xmlns="http://www.w3.org/1998/Math/MathML"> k t R = R o P E ( W K R h t ) k^R_t = RoPE(W^{KR}h_t) </math>ktR=RoPE(WKRht)。其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> 𝑊 K R ∈ R 𝑑 h R ∗ d 𝑊^{KR} ∈ \mathbb{R}^{𝑑^R_h*d} </math>WKR∈RdhR∗d是用于生成解耦键的矩阵, <math xmlns="http://www.w3.org/1998/Math/MathML"> d h R d^R_h </math>dhR是解耦键的维度。

  • 将位置矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> k t R k^R_t </math>ktR和上投影得到的矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> k t C k^C_t </math>ktC拼接得到最终的地t个位置token的K矩阵: <math xmlns="http://www.w3.org/1998/Math/MathML"> k t = [ k t V ; k t R ] k_t = [k^V_t;k^R_t] </math>kt=[ktV;ktR], <math xmlns="http://www.w3.org/1998/Math/MathML"> v t = v t C v_t=v^C_t </math>vt=vtC。

    因此在推理过程中,为了加速推理,需要将K、V缓存。当采用MLA:只有 <math xmlns="http://www.w3.org/1998/Math/MathML"> k t K V k^{KV}_t </math>ktKV <math xmlns="http://www.w3.org/1998/Math/MathML"> k t R k^R_t </math>ktR需要缓存,只需要缓存 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( d c + d h R ) ∗ l (d_c + d^R_h) * l </math>(dc+dhR)∗l个参数。如果是MLA,所有keys和values向量都需要缓存,则需要缓存 <math xmlns="http://www.w3.org/1998/Math/MathML"> 2 n h d h l 2n_h d_h l </math>2nhdhl 个参数。

2.2 处理query向量

同样的,为了降低训练过程中的内存激活量,对Q也进行类似的处理:

2.3 计算attention输出

最后使用query ( <math xmlns="http://www.w3.org/1998/Math/MathML"> q t , i q_{t,i} </math>qt,i),keys ( <math xmlns="http://www.w3.org/1998/Math/MathML"> k j , i k_{j,i} </math>kj,i)和values ( <math xmlns="http://www.w3.org/1998/Math/MathML"> v j , i C v^C_{j,i} </math>vj,iC)计算attention结果,这里 <math xmlns="http://www.w3.org/1998/Math/MathML"> q t , i q_{t,i} </math>qt,i <math xmlns="http://www.w3.org/1998/Math/MathML"> k j , i k_{j,i} </math>kj,i都拼接了RoPE位置向量,所以纬度是一样的 ,其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> 𝑊 O ∈ R 𝑑 ∗ d h n h 𝑊^O ∈ \mathbb{R}^{𝑑*d_hn_h} </math>WO∈Rd∗dhnh表示输出映射层矩阵 最终得到纬度为d的输出隐向量:

相关推荐
Buling_0几秒前
算法-哈希表篇08-四数之和
数据结构·算法·散列表
AllowM2 分钟前
【LeetCode Hot100】除自身以外数组的乘积|左右乘积列表,Java实现!图解+代码,小白也能秒懂!
java·算法·leetcode
RAN_PAND28 分钟前
STL介绍1:vector、pair、string、queue、map
开发语言·c++·算法
AnnyYoung1 小时前
华为云deepseek大模型平台:deepseek满血版
人工智能·ai·华为云
INDEMIND2 小时前
INDEMIND:AI视觉赋能服务机器人,“零”碰撞避障技术实现全天候安全
人工智能·视觉导航·服务机器人·商用机器人
慕容木木2 小时前
【全网最全教程】使用最强DeepSeekR1+联网的火山引擎,没有生成长度限制,DeepSeek本体的替代品,可本地部署+知识库,注册即可有750w的token使用
人工智能·火山引擎·deepseek·deepseek r1
南 阳2 小时前
百度搜索全面接入DeepSeek-R1满血版:AI与搜索的全新融合
人工智能·chatgpt
企鹅侠客2 小时前
开源免费文档翻译工具 可支持pdf、word、excel、ppt
人工智能·pdf·word·excel·自动翻译
fai厅的秃头姐!3 小时前
C语言03
c语言·数据结构·算法
冰淇淋百宝箱3 小时前
AI 安全时代:SDL与大模型结合的“王炸组合”——技术落地与实战指南
人工智能·安全