深度学习进阶(二十九)现代 LLM 的核心架构设计其四:GQA

上一篇我们介绍了 KV Cache:它把每一步重复的 K、V 计算存进缓存,让自回归推理的计算量骤降。

但这个加速不是没有代价的。KV Cache 的大小正比于多项参数,因此又反过来推动了注意力结构本身的改进。

这便是本篇内容:分组查询注意力(Grouped-Query Attention,GQA)

1.多头注意力 MHA

我们在前面展开过:标准 Transformer 使用多头注意力机制,\(H\) 个注意力头各自拥有独立的 Q、K、V 投影矩阵:

\\\text{head}_i = \\text{Attention}(\\mathbf{x} \\mathbf{W}_i\^Q, \\mathbf{x} \\mathbf{W}_i\^K, \\mathbf{x} \\mathbf{W}_i\^V) \\

而其中每个头独立学习不同的注意力模式。最终的输出是 \(H\) 个头的拼接:

\\\text{MHA}(\\mathbf{x}) = \\text{Concat}(\\text{head}_1, \\dots, \\text{head}_H) \\mathbf{W}\^O \\

这本身是为了增加表达能力的合理设置,但 KV Cache 出现后,KV Cache 需要为每个头单独存储一份 K 和 V。这一结构设计带来了较大的内存压力。

2. 多查询注意力 MQA

19 年,Shazeer(就是 SwiGLU 那位)在 Fast Transformer Decoding: One Write-Head is All You Need 提出了一个激进方案,即多查询注意力(Multi-Query Attention,MQA)。

在 MQA 中,\(H\) 个 Query Head 共享同一组 K 和 V,只有一个 K 头和一个 V 头:

\\\text{head}_i = \\text{Attention}(\\mathbf{Q}_i, \\mathbf{K}_{\\text{shared}}, \\mathbf{V}_{\\text{shared}}) \\

意思是无论有多少个 Query Head,它们查的都是同一份 K 和 V。

这样,KV Cache 的大小瞬间降到 MHA 的 \(1/H\)。对于 64 头的模型,直接省了 98.4% 的 KV Cache 内存。

但代价也很明显:不同 Query Head 已经被证明会关注不同模式把它们绑定到同一份 K、V 上,必然损失表达能力。

实验结果也印证了这一点:MQA 的训练更不稳定,在质量敏感的任务上效果有明显下降。

3. 分组查询注意力 GQA

目前的主流方案来自 23 年的论文 GQA: Training Generalized Multi-Query Transformer for Multi-Head Attention ,它其实更像是前两个方案的折中:

把 \(H\) 个 Query Head 分成 \(G\) 组,每组共享一个 K 头和一个 V 头。\(G\) 是一个可调参数。

这其实是把质量与效率的权衡变成了一个连续可调的超参数:你想省多少显存,就设置多少组。

举个例子,假设 \(H=8, G=4\),那么其对应关系即如下:

Query Head 使用的 KV Head
Q₀ KV₀
Q₁ KV₀
Q₂ KV₁
Q₃ KV₁
Q₄ KV₂
Q₅ KV₂
Q₆ KV₃
Q₇ KV₃

于是注意力实际上是这样的:

\\\text{head}_0 = \\text{Attention}(Q_0,K_0,V_0) \\

\\\text{head}_1 = \\text{Attention}(Q_1,K_0,V_0) \\

\\\text{head}_2 = \\text{Attention}(Q_2,K_1,V_1) \\

然后,所有组的输出再拼接到一起:

\\\text{GQA}(\\mathbf{x}) = \\text{Concat}(\\text{head}_{0}, \\text{head}_{1}, \\dots, \\text{head}_{H-1}) \\mathbf{W}\^O \\

看得出来,GQA 的改动非常小,它只改变了 K 和 V 的投影矩阵列数,简单对比如下:

  1. MHA:\(\mathbf{W}^Q \in \mathbb{R}^{d \times H d_h}\),\(\mathbf{W}^K \in \mathbb{R}^{d \times H d_h}\),\(\mathbf{W}^V \in \mathbb{R}^{d \times H d_h}\)
  2. GQA:\(\mathbf{W}^Q \in \mathbb{R}^{d \times H d_h}\),\(\mathbf{W}^K \in \mathbb{R}^{d \times G d_h}\),\(\mathbf{W}^V \in \mathbb{R}^{d \times G d_h}\)

GQA 的 K、V 列数从 \(H \times d_h\) 缩小到 \(G \times d_h\),Q 保持不动。这意味参数量节省了 \(2 \times (H - G) \times d \times d_h\),同时 KV Cache 也相应缩小。

而从实现角度看,现代框架通常不会真的复制 K、V。而是在进入注意力计算前,先针对头索引 \(h\) 构造一个映射:

\\\text{group_id}(h)=\\left\\lfloor\\frac{h}{H/G}\\right\\rfloor \\

然后计算时直接索引:

\K'_h = K_{\\text{group_id}(h)} \\

\V'_h = V_{\\text{group_id}(h)} \\

这样来实现只共享内存,不会真的复制数据。

4. 大模型中的实际配置

GQA 在提出之后迅速成为主流方案。如今绝大多数开源大模型都已经放弃传统 MHA,转而采用 GQA 来控制 KV Cache 的规模。

一些代表性开源模型如下:

模型 KV 头数 Query 头数 分组比例
LLaMA 2 70B 8 64 \(G=8\)
LLaMA 3 8B 8 32 \(G=8\)
LLaMA 3 70B 8 64 \(G=8\)
LLaMA 3 405B 8 128 \(G=8\)
Mistral 7B 8 32 \(G=8\)
Mixtral 8x7B 8 32 \(G=8\)
Qwen 2.5 7B 8 28 \(G=8\)
Qwen 2.5 72B 8 64 \(G=8\)
Gemma 2 9B 8 16 \(G=8\)

值得一提的是: 8 个 KV Head 几乎成为行业默认值。

这是因为对于常见的 \(32 \sim 128\) 个 Query Head 而言,8 个 KV Head 已经能够保留足够丰富的注意力模式,同时又能让 KV Cache 缩小到原来的 \(\frac18\sim\frac1{16}\) 左右,这是大量实践下的优解。

不过事情并没有在这里结束,但随着模型规模继续增大,研究者们又有了新发现:

KV Cache 的瓶颈不仅来自 Head 数量,还来自每个 Head 内部庞大的特征维度。

于是 DeepSeek 在 V2 中提出了多头潜在注意力(Multi-head Latent Attention, MLA)。

其思路是直接将 K 和 V 压缩到一个低维潜在空间中进行存储,简单来说就是压缩 KV 本身的表示维度。

但目前这种技术还是只集中在 DeepSeek 本身的生态中,GQA 仍然是当前开源模型的主流选择。

相关推荐
林间码客1 小时前
《人工智能概论》实验4 知识点复习提纲
人工智能·深度学习·机器学习
CJH(本人账号)2 小时前
AI Agent 安全危机:当你的“智能助手“变成攻击者的“远程武器“
网络·人工智能·安全·ai·开源·github
人工智能培训2 小时前
从GPT到开源大模型
人工智能·gpt·深度学习·机器学习·容器·知识图谱
要开心吖ZSH2 小时前
AI医疗分诊与健康咨询助手agent开发——(1)从零搭建SpringBoot与AI对话系统:后端骨架 + 前端对话页 + SSE流式输出
java·ai·agent·健康医疗
笨蛋©2 小时前
2026制造业实战:数字化检测计划(Inspection Plan)编制流程与质量管理标准化
ai·数字化·cad·质量管理·制造业
红宝村村长2 小时前
OPD Reverse KL
机器学习
AI原来如此2 小时前
Claude与ChatGPT激战正酣,国内AI中转站却突破2000家
人工智能·ai·chatgpt·大模型·编程
keykey6.3 小时前
LSTM 文本情感分析:从词嵌入到分类实战
开发语言·人工智能·深度学习·机器学习
数智工坊3 小时前
周志华《Machine Learning》学习笔记--第十三章--半监督学习
笔记·学习·机器学习