LLM 前沿架构笔记:多头注意力机制(Attention)的进阶优化策略
在现代 LLM(比如 Qwen、LLaMA 家族等)中,Transformer 的标准多头注意力(MHA)结构正面临表征能力 与显存消耗的双重挑战。主流解决思路主要有以下两类:
核心痛点
-
表征维度受限
- 标准 MHA 将隐藏层特征均分到各头,单头维度(
d_{head})小于主干宽度(hidden_size),易造成信息压缩损失。
- 标准 MHA 将隐藏层特征均分到各头,单头维度(
-
显存墙(Memory Wall)
- 推理阶段(自回归生成时),KV Cache 需存储长序列全部的 Key/Value,特别是高上下文窗口(128K+)下,显存压力骤增。
优化策略一:特征维度解耦(Decoupled Head Dimension)
核心思想
打破「单头维度 = hidden_size / num_heads」的传统,将每个 Attention Head 的特征维度设为独立可控的更大值,提升表征能力。
具体示例(hidden_size = 2048, num_heads = 16)
- 传统做法(强绑定):
dhead, old=hidden_sizenum_heads=204816=128d_{head,\,old} = \frac{hidden\_size}{num\_heads} = \frac{2048}{16} = 128dhead,old=num_headshidden_size=162048=128
- 解耦做法(升维举例:head_dim=256):
dhead, new=256 (显式指定)d_{head,\,new} = 256\ \text{(显式指定)}dhead,new=256 (显式指定)
总内部特征提升为:
Internal Attention Dim=num_heads×dhead, new=16×256=4096\text{Internal Attention Dim} = num\heads \times d{head,\,new} = 16 \times 256 = 4096Internal Attention Dim=num_heads×dhead,new=16×256=4096
计算过程
输入特征矩阵 X(形状 L × 2048),经投影后:
Q=X⋅WQQ = X \cdot W_QQ=X⋅WQ
其中,
WQ∈R2048×4096W_Q \in \mathbb{R}^{2048 \times 4096}WQ∈R2048×4096
变换后,
Q∈RL×4096Q \in \mathbb{R}^{L \times 4096}Q∈RL×4096
Reshape 为多个头:
Qreshaped∈RL×16×256Q_{reshaped} \in \mathbb{R}^{L \times 16 \times 256}Qreshaped∈RL×16×256
优势:
- 提升了注意力内部的表达容量
- 不增加主干(残差流)参数量和算力压力
优化策略二:分组查询注意力(Grouped Query Attention, GQA)
核心思想
让多个 Query 头共用较少数量的 Key/Value 头,是 MHA(每个 Q 头配一个 KV 头)与 MQA(所有 Q 头共用 1 组 KV)之间的"折中":
机制细节(以 16 个 Attention 头,2 个 KV 头举例)
-
输入 X(L × 2048),投影生成 K、V:
K=X⋅WK K = X \cdot W_KK=X⋅WK
V=X⋅WVV = X \cdot W_VV=X⋅WV
其中
WK,WV∈R2048×(2×256)W_K, W_V \in \mathbb{R}^{2048 \times (2 \times 256)}WK,WV∈R2048×(2×256)
输出
K,V∈RL×2×256 K, V \in \mathbb{R}^{L \times 2 \times 256}K,V∈RL×2×256
-
Query 头分组后,按如下方式共享 KV(假设 16 Q 头/2 组,每组 8 个):
- 第 0~7 个 Q 头,对应第 0 号 K/V 头
- 第 8~15 个 Q 头,对应第 1 号 K/V 头
-
组内做标准缩放点积注意力:
Attention(Q,K,V)=softmax(QKTdk)V Attention(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) VAttention(Q,K,V)=softmax(dk QKT)V
其中,dk=256d_k = 256dk=256
优势
- 极致显存优化:KV Cache 从 16 头降为 2 头,显存为原来的 1/8
- 性能无损:大模型实验表明,GQA 下推理/生成能力几乎无损
总结
| 策略类型 | 优化点 | 主要公式 (代码块) | 效果 |
|---|---|---|---|
| 维度解耦 | 表征能力提升 | dhead⟶更大可控d_{head} \longrightarrow \text{更大可控}dhead⟶更大可控 | 特征自由度大幅提升 |
| 分组查询(GQA) | KV Cache 显存极致缩减 | 原KV头数分组合并数\frac{\text{原KV头数}}{\text{分组合并数}}分组合并数原KV头数 | 推理显存消耗降低 N 倍 |
| 注意力计算(通用) | - | Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) VAttention(Q,K,V)=softmax(dk QKT)V | - |
公式用法注意:如需插入 LateX 代码块,形式为:
Q=X⋅WQQ = X \cdot W_QQ=X⋅WQ