【大模型面试八股 1】Transformer注意力机制:MHA、MQA、GQA、MLA原理解析与对比

《Transformer注意力机制:MHA、MQA、GQA、MLA原理解析与对比》


深入解析:Transformer中的注意力机制演进---MHA、MQA、GQA与MLA深度分析

在自然语言处理(NLP)、计算机视觉(CV)等领域的突破性进展中,Transformer架构作为核心模型之一,已经成为众多任务的标准框架。然而,随着模型规模的持续增长,传统的多头注意力机制(MHA)在计算效率和内存消耗方面逐渐暴露出瓶颈。为了解决这些问题,研究者提出了新型的注意力机制,如多查询注意力(MQA)分组查询注意力(GQA)多头潜在注意力(MLA),它们通过不同的优化策略提升了Transformer的效率。

本文将详细探讨这些注意力机制的原理、公式推导、局限性,并进行对比分析,帮助大家理解它们的特点及适用场景。


1. 多头注意力机制(MHA)

可参考我的以往文章:

【Transformer原理详解1】:https://blog.csdn.net/qq_52920290/article/details/157328797?spm=1001.2014.3001.5501

【Transformer原理详解2】:https://blog.csdn.net/qq_52920290/article/details/157440788?spm=1001.2014.3001.5501

原理与公式推导

多头注意力(MHA) 是Transformer的核心模块,其目的是通过多个并行的注意力头来捕捉不同子空间的信息,从而增强模型的表达能力。

给定输入序列 X = [ x 1 , x 2 , . . . , x n ] ∈ R n × d X = [x_1, x_2, ..., x_n] \in \mathbb{R}^{n \times d} X=[x1,x2,...,xn]∈Rn×d,其中 n n n 是序列长度, d d d 是隐藏层维度。

1) 线性投影生成 Q, K, V

对于每个头 i i i,我们使用线性变换生成Query (Q)Key (K)Value (V)

Q i = X W i Q , K i = X W i K , V i = X W i V Q_i = X W_i^Q, \quad K_i = X W_i^K, \quad V_i = X W_i^V Qi=XWiQ,Ki=XWiK,Vi=XWiV

2) 缩放点积注意力

对于每个头的注意力计算,使用缩放点积注意力

Attention i = softmax ( Q i K i ⊤ d h ) V i \text{Attention}_i = \text{softmax}\left(\frac{Q_i K_i^\top}{\sqrt{d_h}}\right)V_i Attentioni=softmax(dh QiKi⊤)Vi

其中 d h = d h d_h = \frac{d}{h} dh=hd 是每个头的维度, h h h 是头的数量。

3) 多头拼接与输出映射

最终,将多个头的输出拼接在一起,通过线性映射得到最终的输出:

MHA ( X ) = Concat ( Attention 1 , ... , Attention h ) W O \text{MHA}(X) = \text{Concat}(\text{Attention}_1, \dots, \text{Attention}_h) W^O MHA(X)=Concat(Attention1,...,Attentionh)WO

局限性

  • 计算和内存开销高 :随着序列长度和头数的增加,计算复杂度为 O ( n 2 ⋅ d ) O(n^2 \cdot d) O(n2⋅d),导致内存消耗也随之增加。
  • KV缓存问题 :推理时需要缓存每个头的KeyValue,这会显著增加显存消耗。

2. 多查询注意力(MQA)

原理与公式推导

多查询注意力(MQA)通过将多个注意力头的查询共享同一组KeyValue,从而减少计算和内存开销。

1) 查询保持独立

每个注意力头独立生成Query ,但共享相同的KeyValue

Q i = X W i Q Q_i = X W_i^Q Qi=XWiQ

2) 共享Key和Value

所有的头共享同一组KeyValue

K = X W K , V = X W V K = X W^K, \quad V = X W^V K=XWK,V=XWV

3) 注意力计算

每个头的注意力计算如下:

Attention i = softmax ( Q i K ⊤ d h ) V \text{Attention}_i = \text{softmax}\left(\frac{Q_i K^\top}{\sqrt{d_h}}\right)V Attentioni=softmax(dh QiK⊤)V

局限性

  • 表达能力下降 :由于所有头共享同一组KeyValue,不同头之间的表达差异性降低,可能导致性能下降。
  • 对长序列任务不适应:在需要捕捉长距离依赖的任务中,MQA的效果不如MHA。

3. 分组查询注意力(GQA)

原理与公式推导

分组查询注意力(GQA)MHAMQA 的折衷方法,通过将头分组,每个组共享同一组KeyValue,但组内的头仍然独立计算。

1) 分组计算

假设有 h h h 个头,分为 g g g 组,每组包含 h / g h/g h/g 个头:

K ( j ) = X W j K , V ( j ) = X W j V K^{(j)} = X W_j^K, \quad V^{(j)} = X W_j^V K(j)=XWjK,V(j)=XWjV

组内每个头的注意力计算:

Attention i = softmax ( Q i K ( j ) ⊤ d h ) V ( j ) \text{Attention}_i = \text{softmax}\left(\frac{Q_i K^{(j)\top}}{\sqrt{d_h}}\right)V^{(j)} Attentioni=softmax(dh QiK(j)⊤)V(j)

局限性

  • 超参数依赖 :分组数 g g g 的选择对性能有较大影响,需要调参。
  • 复杂度增加:分组数的选择增加了模型的调参复杂度。

4. 多头潜在注意力(MLA)

为了更好地理解MLA(Multi-Head Latent Attention) ,我们将详细讲解其公式推导原理,并结合相关图示对其进行全面的说明。

【参考原文】:DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model

arxiv: https://arxiv.org/abs/2405.04434

4.1 MLA原理概述

MLA 是一种通过低秩压缩(Low-Rank Compression)来优化传统多头注意力(MHA )机制的方案。其主要目标是减少 Key 和Value 在注意力计算中的内存开销,同时保持模型的表达能力。它通过引入潜在空间(Latent Space)来对输入的Query (Q)Key (K)和Value (V) 进行低秩压缩,从而减小计算和内存负担。

关键步骤:
  1. 低秩压缩 :将KeyValue 映射到低秩潜在空间,对Query也做了同样处理。
  2. 位置编码(RoPE) :使用RoPE进行位置编码,以便更好地捕捉输入序列的顺序信息。
  3. 多头并行计算:通过多个并行的注意力头来增强模型的表达能力。
  4. 缓存机制:在推理阶段缓存计算结果,以减少计算复杂度。

4.2 公式推导

4.2.1 低秩压缩

假设输入为一个隐藏状态 h t ∈ R n × d h_t \in \mathbb{R}^{n \times d} ht∈Rn×d,其中 n n n 是序列长度, d d d 是每个词向量的维度。

MHA 中,多个注意力头使用不同的QueryKeyValue进行计算:

Q i = X W i Q , K i = X W i K , V i = X W i V Q_i = X W_i^Q, \quad K_i = X W_i^K, \quad V_i = X W_i^V Qi=XWiQ,Ki=XWiK,Vi=XWiV

其中, W i Q W_i^Q WiQ, W i K W_i^K WiK, W i V W_i^V WiV 是每个注意力头的参数矩阵。

MLA 中,为了节省内存和计算资源,我们使用低秩压缩 来减少KeyValue的维度。

1) 低秩映射

将原始的KeyValue映射到低秩潜在空间:

Q t = X W Q , K t = X W K , V t = X W V Q_t = X W^Q, \quad K_t = X W^K, \quad V_t = X W^V Qt=XWQ,Kt=XWK,Vt=XWV

通过低秩分解,我们将原始的KeyValue映射到潜在空间:

Q t latent = f ( Q t ) , K t latent = f ( K t ) , V t latent = f ( V t ) Q_t^{\text{latent}} = f(Q_t), \quad K_t^{\text{latent}} = f(K_t), \quad V_t^{\text{latent}} = f(V_t) Qtlatent=f(Qt),Ktlatent=f(Kt),Vtlatent=f(Vt)

其中, f ( ⋅ ) f(\cdot) f(⋅) 表示低秩压缩操作,通常通过矩阵分解因式分解方法进行。

这样,通过压缩后的 Q t Q_t Qt, K t K_t Kt, 和 V t V_t Vt,我们可以在一个较低维度的潜在空间内进行计算,从而降低了计算和内存的开销。

2) 注意力计算

计算每个头的注意力时,使用低秩压缩后的 Q t Q_t Qt、 K t K_t Kt 和 V t V_t Vt 来计算注意力权重。注意力的计算公式为:

Attention i = softmax ( Q t latent K t latent ⊤ d h ) V t latent \text{Attention}_i = \text{softmax}\left(\frac{Q_t^{\text{latent}} K_t^{\text{latent}\top}}{\sqrt{d_h}}\right)V_t^{\text{latent}} Attentioni=softmax(dh QtlatentKtlatent⊤)Vtlatent

其中, d h d_h dh 是每个头的维度,通常为 d h = d h d_h = \frac{d}{h} dh=hd, h h h 是头数。

最终的输出通过多个头进行拼接(Concatenation)并通过线性变换得到:

MLA ( X ) = Concat ( Attention 1 , ... , Attention h ) W O \text{MLA}(X) = \text{Concat}(\text{Attention}_1, \dots, \text{Attention}_h) W^O MLA(X)=Concat(Attention1,...,Attentionh)WO

4.2.2 位置编码(RoPE)

MLA中,我们使用**RoPE(Rotary Positional Encoding)**来处理输入的顺序信息。RoPE方法通过旋转的方式嵌入位置信息,使得模型能够更好地理解序列中的相对位置。

RoPE应用

在图示中,RoPE 被应用于压缩后的QueryKey/Value,以保持位置编码的同时避免内存开销过大。RoPE的应用公式为:

Q t latent = RoPE ( Q t latent ) , K t latent = RoPE ( K t latent ) Q_t^{\text{latent}} = \text{RoPE}(Q_t^{\text{latent}}), \quad K_t^{\text{latent}} = \text{RoPE}(K_t^{\text{latent}}) Qtlatent=RoPE(Qtlatent),Ktlatent=RoPE(Ktlatent)

这使得 Q t Q_t Qt 和 K t K_t Kt 在潜在空间中不仅包含语义信息,还能有效编码顺序信息。

4.3 MLA的工作流程

  1. 输入(Hidden State) :输入的隐藏状态 h t h_t ht 通过线性映射得到Query (Q)Key (K)Value (V)
  2. 低秩压缩 :对QKV 进行低秩压缩,得到Latent ( Q t Q_t Qt ) Latent ( K t V t K_tV_t KtVt )
  3. RoPE :将旋转位置编码(RoPE)应用于压缩后的 Q t Q_t Qt 和 K t K_t Kt。
  4. 注意力计算 :使用低秩压缩后的 Q t Q_t Qt 和 K t K_t Kt 进行注意力计算。
  5. 多头拼接:将多个注意力头的结果拼接起来,并通过一个线性变换输出最终结果。

4.4 MLA与MHA的对比

特性 MHA(多头注意力) MLA(多头潜在注意力)
Key/Value 每个头有独立的Key/Value 所有头共享低秩压缩后的Key/Value
内存消耗 低(通过低秩压缩)
计算复杂度 O ( n 2 ⋅ d ) O(n^2 \cdot d) O(n2⋅d) O ( n 2 ⋅ d ) O(n^2 \cdot d) O(n2⋅d)(但低秩计算更高效)
表达能力 接近MHA,但略有压缩损失
适用场景 适用于计算资源丰富的任务 适用于长序列任务和显存有限的任务

4.5 局限性

  • 压缩效果有限:过度压缩可能丧失一些表达能力,影响性能。
  • 实现复杂:相较于MHA,MLA需要更复杂的实现和低秩压缩过程。

5. 对比表

机制 意图 Key/Value特性 参数&内存开销 表达能力 计算复杂度 局限性
MHA 多子空间并行建模 每头独立 高(与头数正比) 最强 O ( n 2 ⋅ d ) O(n^2 \cdot d) O(n2⋅d) 内存&计算高
MQA 极致内存压缩 所有头共享K/V 极低 O ( n 2 ⋅ d ) O(n^2 \cdot d) O(n2⋅d) 信息表达不足
GQA 性能/效率折中 每组共享K/V 中等 较好 O ( n 2 ⋅ d ) O(n^2 \cdot d) O(n2⋅d) 分组超参敏感
MLA 低秩压缩&长序列优化 压缩潜变量空间 低(潜维度控制) 接近MHA O ( n 2 ⋅ d ) O(n^2 \cdot d) O(n2⋅d)(低秩) 结构复杂

6. 总结

在实际应用中,选择合适的注意力机制取决于任务的计算资源、内存需求以及表达能力要求。MHA 适用于需要强表达能力和资源充足的任务,MQA 则适用于内存受限的场景,GQA 提供了计算与内存的平衡,而MLA则在长序列任务和显存紧张的环境中提供了最优的性能。

MLA 在长序列任务中提供了显著的优势,尤其在内存受限的情况下,通过低秩压缩潜在空间计算 显著减少了显存的使用,并提高了计算效率。如果你正在开发一个需要高效处理长序列的模型,或者在显存有限的情况下进行推理,MLA将是一个值得考虑的优化方案。


版权声明:本文为CSDN博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。

相关推荐
互联科技报2 小时前
2026Ai短视频工具市场报告:行业规模、占有率及内容特工队AI排名
人工智能
小马爱打代码2 小时前
Spring AI 进阶:RAG 技术原理拆解与本地知识库检索落地
人工智能·深度学习·spring
freepopo2 小时前
天津商业空间设计:本地团队的美学落地方案 [特殊字符]
大数据·人工智能·python
想你依然心痛2 小时前
AI 视频生成新时代:Wan2.2-T2V-A5B 文本转视频实战
人工智能·音视频·ai视频
咚咚王者2 小时前
人工智能之核心技术 深度学习 第五章 Transformer模型
人工智能·深度学习·transformer
zuozewei2 小时前
7D-AI系列:AI大模型应用性能核心指标:TTFT 与 TPOT 详解
人工智能
WZGL12302 小时前
银发经济新浪潮:适老化设计如何催生“青春化”市场
人工智能
逄逄不是胖胖2 小时前
《动手学深度学习》-57长短期记忆网络LSTM
人工智能·深度学习·lstm
LASDAaaa12312 小时前
基于DETR的花卉种类识别与分类系统详解
人工智能·数据挖掘
数琨创享TQMS质量数智化2 小时前
国有大型交通运输设备制造集团QMS质量管理平台案例
大数据·人工智能·物联网