最近开始读ds的论文,为了方便巩固知识,记录一下。
总的来说V3论文主要是三个创新点:
- Multi-Head Latent Attention: 通过下采样-上采样的方式,cache一个比较短维度的潜空间向量,代替缓存维度比较长的KV潜空间相连,从而对KV cache进行优化
- DeepSeekMoE with Auxiliary-Loss-Free Load Balancing:不引入损失函数情况,仅仅通过引入偏置项,解决MoE模型负载不均衡问题
- Multi-Token Prediction:将next token prediction转化为多个token的生成任务,进而提升推理性能,帮助模型快速收敛
接下来详细详细解读一下。
Multi-Head Latent Attention
这部分内容我想沿着 MHA -> KV Cache -> MQA -> GQA -> MLA的路径来讲解,会更清晰一些。
MHA (Multi Head Attention)
多头注意力形式如下: 其公式表示如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> A t t e n t i o n ( Q , K , V ) = S o f t m a x ( Q K T d k ) V \mathrm{Attention}(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})=\mathrm{Softmax}\left(\boldsymbol{\frac{Q\overset{\mathrm{T}}{K}}{\sqrt{d_{k}}}}\right)\boldsymbol{V} </math>Attention(Q,K,V)=Softmax dk QKT V
其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> Q \boldsymbol{Q} </math>Q, <math xmlns="http://www.w3.org/1998/Math/MathML"> K \boldsymbol{K} </math>K 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> V \boldsymbol{V} </math>V 是输入矩阵,分别代表查询矩阵,键矩阵和值矩阵, <math xmlns="http://www.w3.org/1998/Math/MathML"> d k d_{k} </math>dk 是向量维度。
以一个长度为4的序列为例,注意力计算过程如下:
在自回归模型中(autoregressive models),会逐个生成文本的每个token,这个过程可能比较慢,因为模型一次只能生成一个token,而且每次新的预测都依赖于之前的上下文。这意味着,要预测第4个token,你需要用到前3个token的信息,这通常涉及到对这些token的表示进行一系列矩阵乘法运算。当序列长度很大时,例如要预测第1001个token,你不仅需要前999个token的信息,还要加上第1000个token的信息,这使得整个QK矩阵非常巨大(1000×1000),进而导致attention的计算量巨大。
KV Cache
KV Cache 是一种优化技术,旨在提高模型在推理阶段的效率。 它通过缓存键(K)和值(V)的值,减少重复计算,从而加速解码器中的矩阵运算。这种技术在处理大规模语言模型时尤为重要,因为它可以显著减少推理时间,尽管会带来额外的内存开销(空间换时间)。后续的 MQA, GQA, MLA 等都是基于KV Cache的改进优化。
为什么只存储K和V,而不存储Q?
- 由于Q是当前时间步的输入,每次生成新token时都会重新计算Q,不需要缓存。
- 而K和V可以复用之前时间步的计算结果,通过缓存K和V,可以避免在每个时间步重新计算它们,从而提升效率。
在实际实验时,每次问答(QA)记录都会增加KV Cache的存储需求,因为模型需要保留之前问答的上下文信息以生成连贯的响应。这种线性增长的内存占用可能会导致在处理长对话或大量问答时,内存需求显著增加,从而影响系统的性能和效率。
由于硬件资源的限制,KV Cache的大小是有限的。当缓存达到其容量上限时,旧的信息可能会被新的信息覆盖或丢弃。其表现为随着问答的进行,早期的对话内容可能会因为KV Cache的容量限制而被移除或覆盖,导致模型逐渐"遗忘"之前的上下文。由于模型无法访问完整的对话历史,其生成的回复可能会变得不够准确或连贯,尤其是在需要依赖早期信息的情况下。所以,在长对话或多轮问答中,模型的性能可能会显著下降,因为它无法有效地利用整个对话历史。
MQA (Multi-Query Attention)
MQA 通过在多个注意力头之间共享同一组K和V,同时为每个注意力头维护不同的Q,减少了计算和内存开销,且不会显著影响模型的性能。
GQA (Group Query Attention)
GQA 是对 Transformer 中使用的传统MHA机制和MQA机制的折中。在标准多头自注意力中,每个注意力头独立处理整个序列。这种方法虽然功能强大,但计算成本高昂,尤其是对于长序列。而MQA虽然通过在多个注意力头之间共享同一组键和值简化了这一过程,但其简化也不可避免的带来了一些精度的损失。GQA 通过将查询分组在一起来解决此问题,从而降低了计算复杂性,而不会显著影响性能。
MLA (Multi Head Latent Attention)
多头潜在注意力 (MLA) 将潜在特征表示纳入注意力机制,以降低计算复杂度并改善上下文表示。MLA的核心是对KV进行压缩后,再送入标准的MHA算法中,用一个更短的k,v向量来进行计算,进而减少KV Cache的大小。
接下来我们来一步一步看:
- 下采样,计算代表 KV Cache 的潜在向量
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> c t K V = W D K V h t \boxed{\mathbf{c}_t^{KV}}=W^{DKV}\mathbf{h}_t </math>ctKV=WDKVht
- 上采样,计算K和V
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> [ k t , 1 C ; k t , 2 C ; . . . ; k t , n h C ] = k t C = W U K c t K V [ v t , 1 C ; v t , 2 C ; . . . ; v t , n h C ] = v t C = W U V c t K V [\mathbf{k}{t,1}^C;\mathbf{k}{t,2}^C;...;\mathbf{k}{t,n_h}^C]=\mathbf{k}t^C=W^{UK}\mathbf{c}t^{KV}\\ [\mathbf{v}{t,1}^C;\mathbf{v}{t,2}^C;...;\mathbf{v}{t,n_h}^C]=\mathbf{v}_t^C=W^{UV}\mathbf{c}_t^{KV} </math>[kt,1C;kt,2C;...;kt,nhC]=ktC=WUKctKV[vt,1C;vt,2C;...;vt,nhC]=vtC=WUVctKV
- 为K引入位置信息(RoPE)
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> k t R = R o P E ( W K R h t ) k t , i = [ k t , i C ; k t R ] \boxed{\mathbf{k}t^R}=\mathrm{RoPE}(W^{KR}\mathbf{h}t)\\ \mathbf{k}{t,i}=[\mathbf{k}{t,i}^C;\mathbf{k}_t^R] </math>ktR=RoPE(WKRht)kt,i=[kt,iC;ktR]
- 为保证对称性和一致性,相同方式计算Q
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> c t Q = W D Q h t [ q t , 1 C ; q t , 2 C ; . . . ; q t , n h C ] = q t C = W U Q c t Q [ q t , 1 R ; q t , 2 R ; . . . ; q t , n h R ] = q t R = R o P E ( W Q R c t Q ) q t , i = [ q t , i C ; q t , i R ] \mathbf{c}t^Q=W^{DQ}\mathbf{h}t\\ [\mathbf{q}{t,1}^C;\mathbf{q}{t,2}^C;...;\mathbf{q}{t,n_h}^C]=\mathbf{q}t^C=W^{UQ}\mathbf{c}t^Q\\ [\mathbf{q}{t,1}^R;\mathbf{q}{t,2}^R;...;\mathbf{q}{t,n_h}^R]=\mathbf{q}t^R=\mathrm{RoPE}(W^{QR}\mathbf{c}t^Q)\\ \mathbf{q}{t,i}=[\mathbf{q}{t,i}^C;\mathbf{q}_{t,i}^R] </math>ctQ=WDQht[qt,1C;qt,2C;...;qt,nhC]=qtC=WUQctQ[qt,1R;qt,2R;...;qt,nhR]=qtR=RoPE(WQRctQ)qt,i=[qt,iC;qt,iR]
- 注意力计算
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> o t , i = ∑ j = 1 t S o f t m a x j ( q t , i T k j , i d h + d h R ) v j , i C u t = W O [ o t , 1 ; o t , 2 ; . . . ; 0 t , n h ] \mathbf{o}{t,i}=\sum{j=1}^t\mathrm{Softmax}j(\frac{\mathbf{q}{t,i}^T\mathbf{k}{j,i}}{\sqrt{d_h+d_h^R}})\mathbf{v}{j,i}^C\\ \mathbf{u}t=W^O[\mathbf{o}{t,1};\mathbf{o}{t,2};...;\mathbf{0}{t,n_h}] </math>ot,i=j=1∑tSoftmaxj(dh+dhR qt,iTkj,i)vj,iCut=WO[ot,1;ot,2;...;0t,nh]
位置编码为什么不直接加在K上?
因为旋转位置编码RoPE与潜向量的计算不兼容,为了同时使用潜向量计算和旋转位置编码RoPE两个技术,只能多创建一个新的向量来编码位置信息,将来通过向量合并将位置信息带入键向量
为什么对于查询向量q,也要进行潜向量的计算?主要是为了特征的对齐,如果只对键 k 和值 v 进行潜在向量计算,而忽略查询 q,会导致 q 和 k、v 的特征空间不一致,影响注意力机制的效果。在注意力机制中,q、k、v是平等的输入,对它们进行相同的潜在向量计算可以保持模型的对称性和一致性。
DeepSeekMoE with Auxiliary-Loss-Free Load Balancing
首先了解两个概念:
- 稠密模型:每一层中的每个神经元都与其他层中的所有神经元相连。这种全连接的架构确保了信息流的畅通无阻,使得模型能够学习到数据中的复杂关系和模式。
- 稀疏模型:并非每个神经元都与所有其他层的神经元相连。通过减少不必要的连接,稀疏模型能够在保证性能的同时,大幅降低计算资源的需求,提高运行效率。
很明显,拥有MoE架构的DeepSeek模型属于稀疏模型。
DeepSeekMoE
先来看一下DeepSeekMoE的基础架构:
总的来说,DeepSeekMoE架构主要是对FFN层进行了改进,包括两个主要策略:
- 共享专家隔离:隔离某些专家作为始终激活的共享专家,其他专家作为路由专家。共享专家捕获和整合不同上下文中的共同知识,减轻其他路由专家中的冗余,从而提高参数效率,并确保每个路由专家通过专注于独特方面而保持专业化。具体而言:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h t ′ = u t + ∑ i = 1 N s F F N i ( s ) ( u t ) + ∑ i = 1 N r g i , t F F N i ( r ) ( u t ) \mathbf{h}_t^{\prime}=\mathbf{u}t+\sum{i=1}^{N_s}\mathrm{FFN}i^{(s)}\left(\mathbf{u}t\right)+\sum{i=1}^{N_r}\mathrm{g}{i,t}\mathrm{FFN}_i^{(r)}\left(\mathbf{u}_t\right) </math>ht′=ut+i=1∑NsFFNi(s)(ut)+i=1∑Nrgi,tFFNi(r)(ut)
-
细粒度专家细分:在保持参数数量不变的同时,通过细分FFN的中间隐藏维度,将专家分割成更细的粒度。专家细分允许将不同的知识被更精确地学习以及被更细致地分配到不同的专家,每个专家都将保持更高水平的专业化,有助于更准确和有针对性的知识获取。具体而言:
-
- 每个路由专家和输入向量之间进行相似度计算:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> s i , t = S i g m o i d ( u t T e i ) s_{i,t}=\mathrm{Sigmoid}\left(\mathbf{u}_t{}^T\mathbf{e}_i\right) </math>si,t=Sigmoid(utTei)
-
- 选取前top k 个
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> g i , t ′ = { s i , t , s i , t ∈ T o p k ( { s j , t ∣ 1 ⩽ j ⩽ N r } , K r ) 0 , o t h e r w i s e , \begin{aligned}&g_{i,t}^{\prime}=\begin{cases}s_{i,t},&s_{i,t}\in\mathrm{Topk}(\{s_{j,t}|1\leqslant j\leqslant N_r\},K_r)\\0,&\mathrm{otherwise},&&\end{cases}\end{aligned} </math>gi,t′={si,t,0,si,t∈Topk({sj,t∣1⩽j⩽Nr},Kr)otherwise,
-
- 对topk做归一化
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> g i , t = g i , t ′ ∑ j = 1 N r g j , t ′ g_{i,t}=\frac{g_{i,t}^{\prime}}{\sum_{j=1}^{N_r}g_{j,t}^{\prime}} </math>gi,t=∑j=1Nrgj,t′gi,t′
Auxiliary-Loss-Free Load Balancing
为了解决MoE模型所面临的负载不均衡问题,即:有些专家被过度使用,而有些专家使用的较少(和推荐算法MMoE模型的"极化"现象有点像),传统方法使用负载损失来平衡负载不均衡现象,但是复杂度较高,为了解决这个问题,DeepSeek不使用负载损失,而是为每个路由专家添加一个偏置项 <math xmlns="http://www.w3.org/1998/Math/MathML"> b i b_i </math>bi,具体而言:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> g i , t ′ = { s i , t , s i , t + b i ∈ T o p k ( { s j , t + b j ∣ 1 ⩽ j ⩽ N r } , K r ) , 0 , otherwise. g_{i,t}^{\prime}=\begin{cases}s_{i,t},&s_{i,t}+b_i\in\mathrm{Topk}(\{s_{j,t}+b_j|1\leqslant j\leqslant N_r\},K_r),\\0,&\text{otherwise.}&&\end{cases} </math>gi,t′={si,t,0,si,t+bi∈Topk({sj,t+bj∣1⩽j⩽Nr},Kr),otherwise.
注意,这里偏置项 <math xmlns="http://www.w3.org/1998/Math/MathML"> b i b_i </math>bi仅仅用在选Top k路由专家上,并不参与后续的计算,参与后续计算的还是 <math xmlns="http://www.w3.org/1998/Math/MathML"> s i , t s_{i,t} </math>si,t。
然后,定义一个超参数𝛾,称为偏置更新速度。在每一步结束时如果相应的专家负荷过重,将减少偏差项 𝛾;如果相应的专家负荷不足,将增加偏差项 𝛾。
尽管DeepSeek-V3 主要依赖无辅助损耗策略来实现负载平衡,但为了防止任何单一序列内的极端不平衡,还是采用了一个complementary sequence-wise balance loss。
Multi-Token Prediction
我们都知道,当前主流的大模型(LLMs)都是decoder-base的模型结构,也就是无论在模型训练还是在推理阶段,对于一个序列的生成过程,都是token-by-token的。每次在生成一个token的时候,都要频繁跟访存交互,加载KV Cache,再通过多层网络做完整的前向计算。对于这样的访存密集型的任务,通常会因为访存效率形成训练或推理的瓶颈。
针对token-by-token生成效率的瓶颈,业界很多方法来优化,包括减少存储的空间和减少访存次数等,进而提升训练和推理性能。这里的MTP方法,也是为了优化训练和推理效率而提出的。
MTP的核心思想是:通过解码阶段的优化,将1-token的生成,转变成multi-token的生成,从而提升训练和推理的性能。具体来说,在训练阶段,一次生成多个后续token,可以一次学习多个位置的label,进而有效提升样本的利用效率,提升训练速度;在推理阶段通过一次生成多个token,实现成倍的推理加速来提升推理性能。
接下来看看Deepseek是怎么做MTP的: 如上图所示,用 D 个顺序的模块,预测 D 个tokens。每个MTP模块的具体结构(如图红框内):
- 输入token首先接入一层共享的embedding layer
- 对于第 i 个token <math xmlns="http://www.w3.org/1998/Math/MathML"> t i t_i </math>ti和第 k 个预测深度,首先将第 k−1 层的的隐层输出做归一化处理,再对第 i+k 位置的token embedding做归一化处理,两个结果进行concat
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h i ′ k = M k [ R M S N o r m ( h i k − 1 ) ; R M S N o r m ( E m b ( t i + k ) ) ] \mathbf{h}_i^{\prime k}=M_k[\mathrm{RMSNorm}(\mathbf{h}i^{k-1});\mathrm{RMSNorm}(\mathrm{Emb}(t{i+k}))] </math>hi′k=Mk[RMSNorm(hik−1);RMSNorm(Emb(ti+k))]
- 输入到Transformer层,获得第 k 个预测深度的输出
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> h 1 : T − k k = T R M k ( h 1 : T − k ′ k ) \mathbf{h}_{1:T-k}^k=\mathrm{TRM}k(\mathbf{h}{1:T-k}^{\prime k}) </math>h1:T−kk=TRMk(h1:T−k′k)
- 将输出通过一个各 Module共享的映射矩阵OutHead 变换,再过 softmax处理,计算出词表 V 维度的输出概率
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> P i + k + 1 k = O u t H e a d ( h i k ) P_{i+k+1}^k=\mathrm{OutHead}(\mathbf{h}_i^k) </math>Pi+k+1k=OutHead(hik)
- 通过CrossEntropyLoss计算每个MTP Module Head的损失
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> L M T P k = CrossEntropy ( P 2 + k : T + 1 k , t 2 + k : T + 1 ) = − 1 T ∑ i = 2 + k T + 1 log P i k [ t i ] \mathcal{L}{\mathrm{MTP}}^k=\text{CrossEntropy}(P{2+k:T+1}^k,t_{2+k:T+1})=-\frac{1}{T}\sum_{i=2+k}^{T+1}\log P_i^k[t_i] </math>LMTPk=CrossEntropy(P2+k:T+1k,t2+k:T+1)=−T1i=2+k∑T+1logPik[ti]
所以总体而言,DeepSeek的实现相对于之前的方法增加了causal chain的连接关系,同时在embedding层增加了残差链接。
更详细的细节可以参考这张图: