Attention MHA GQA | 注意力机制与键值缓存 (MHA / GQA / MQA)
Step 1: 核心思想与痛点
在大语言模型中,注意力机制 (Attention) 决定了模型如何"回顾"并提取历史上下文的信息。随着模型层数加深和序列变长,Attention 模块在推理阶段面临极大的性能挑战。
什么是 KV Cache?为什么它是性能瓶颈?
在自回归生成中,每次生成第 NNN 个 Token 时,我们需要计算它与前面 N−1N-1N−1 个 Token 的相关性。为了避免重复计算前 N−1N-1N−1 个 Token 的特征,我们将其投影后的 Key 和 Value 张量**缓存(Cache)**在显存中,当前步直接拼接读取。
然而,读取巨量的 KV Cache 会面临严重的显存容量瓶颈 和内存带宽瓶颈 (Memory-bound),导致推理极慢。
从 MHA 到 GQA:大模型架构的进化
- MHA (Multi-Head Attention): 标准的多头注意力。每个 Query 头都有自己专属的 Key 和 Value 头。其巨大的 KV Cache 让推理寸步难行。
- MQA (Multi-Query Attention) : 所有的 Query 头共享同一个 Key 和 Value 头。极大地减少了 KV Cache 的占用,但由于表达能力锐减,模型效果往往打折。
- GQA (Grouped-Query Attention) : LLaMA-2/3 采用的折中方案。将 Query 头分组,每组共享一个 Key 和 Value 头。这在模型效果和显存占用之间取得了良好的工程平衡。
MHA不共享头

MQA

GQA:一组 Query 头共享一个 K/V

为什么可以共享K和V


不共用记忆系统可以提升什么表达能力?
1. 不同头可以学习不同的"匹配标准"
Attention 打分是:
scorei(m,n)=Qi(m)⊤Ki(n)\text{score}_{i}(m,n) = Q_i(m)^\top K_i(n)scorei(m,n)=Qi(m)⊤Ki(n)
这里的 QiQ_iQi 和 KiK_iKi 决定了第 iii 个头怎么看"当前 token 和历史 token 是否相关"。
如果不共享 K/V,每个头有自己的:
WQ(i),WK(i)W_Q^{(i)},\quad W_K^{(i)}WQ(i),WK(i)
所以每个头可以学习不同的匹配方式。
例如:
- head1\text{head}_1head1 可能专门判断:当前词和前一个名词是否相关
- head2\text{head}_2head2 可能专门判断:当前括号和前面的左括号是否匹配
- head3\text{head}_3head3 可能专门判断:当前代词和前文实体是否指代一致
也就是说,不共享 K/V 让每个头拥有不同的"检索标准"。
如果共享 K/V,多个头虽然 QiQ_iQi 不同,但它们面对的是同一套 KKK:
Q1K⊤,Q2K⊤,Q3K⊤Q_1K^\top,\quad Q_2K^\top,\quad Q_3K^\topQ1K⊤,Q2K⊤,Q3K⊤
它们的提问方式不同,但"历史 token 被怎样编码成 Key"是一样的。
所以多样性会下降。
2. 不同头可以保存不同类型的历史信息
Value 决定了被取回来的内容:
headi=softmax(QiKi⊤d)Vi\text{head}_i = \text{softmax}\left(\frac{Q_iK_i^\top}{\sqrt d}\right)V_iheadi=softmax(d QiKi⊤)Vi
如果不共享 ViV_iVi,每个头可以把历史 token 编码成不同的信息。
例如同一个历史 token:
xnx_nxn
在不同头里可以被投影成:
V1(n)=xnWV(1)V_1(n)=x_nW_V^{(1)}V1(n)=xnWV(1)
V2(n)=xnWV(2)V_2(n)=x_nW_V^{(2)}V2(n)=xnWV(2)
V3(n)=xnWV(3)V_3(n)=x_nW_V^{(3)}V3(n)=xnWV(3)
它们可以分别表示不同内容:
- V1V_1V1: 语法信息
- V2V_2V2: 语义信息
- V3V_3V3: 位置/结构信息
- V4V_4V4: 代码缩进、括号、变量依赖
所以不共享 Value 的好处是:
同一个历史 token 可以被多个头保存成不同形态的信息。\boxed{\text{同一个历史 token 可以被多个头保存成不同形态的信息。}}同一个历史 token 可以被多个头保存成不同形态的信息。
如果共享 VVV,那么多个 Query 头最后取回来的都是同一套 Value 表示:
V1=V2=V3V_1 = V_2 = V_3V1=V2=V3
这会形成信息瓶颈。
3. 不同头可以形成更强的"功能分工"
MHA 中,每个头都有完整的:
Qi, Ki, ViQ_i,\ K_i,\ V_iQi, Ki, Vi
所以每个头都像一个独立的小检索器:
| 组成 | 说明 |
|---|---|
| 自己的问题 | QiQ_iQi |
| 自己的索引 | KiK_iKi |
| 自己的资料库 | ViV_iVi |
因此不同头可以分工:
| 头 | 可能学到的功能 |
|---|---|
| head 1 | 局部邻近 token |
| head 2 | 长距离依赖 |
| head 3 | 主谓关系 |
| head 4 | 指代关系 |
| head 5 | 标点/括号结构 |
| head 6 | 代码变量引用 |
| head 7 | 段落主题 |
| head 8 | 任务指令相关内容 |
这就是所谓的表达能力强:
模型可以同时建模多种不同关系,而不是所有头挤在同一套 K/V 表示里。\boxed{\text{模型可以同时建模多种不同关系,而不是所有头挤在同一套 K/V 表示里。}}模型可以同时建模多种不同关系,而不是所有头挤在同一套 K/V 表示里。
4. 共享 K/V 会损失什么?
以 GQA 为例,假设:
Q1,Q2,Q3,Q4Q_1,Q_2,Q_3,Q_4Q1,Q2,Q3,Q4
共享同一个:
K1,V1K_1,V_1K1,V1
那么它们是:
head1=softmax(Q1K1⊤)V1\text{head}_1 = \text{softmax}(Q_1K_1^\top)V_1head1=softmax(Q1K1⊤)V1
head2=softmax(Q2K1⊤)V1\text{head}_2 = \text{softmax}(Q_2K_1^\top)V_1head2=softmax(Q2K1⊤)V1
head3=softmax(Q3K1⊤)V1\text{head}_3 = \text{softmax}(Q_3K_1^\top)V_1head3=softmax(Q3K1⊤)V1
head4=softmax(Q4K1⊤)V1\text{head}_4 = \text{softmax}(Q_4K_1^\top)V_1head4=softmax(Q4K1⊤)V1
虽然 Q1,Q2,Q3,Q4Q_1,Q_2,Q_3,Q_4Q1,Q2,Q3,Q4 不同,所以注意力权重可以不同。
但是它们共享:
K1,V1K_1,V_1K1,V1
所以它们无法分别拥有完全不同的历史索引方式和历史内容表示。
这就像:
四个人可以问不同问题,但只能查同一本索引、同一本资料。\boxed{\text{四个人可以问不同问题,但只能查同一本索引、同一本资料。}}四个人可以问不同问题,但只能查同一本索引、同一本资料。
而 MHA 是:
四个人不仅问题不同,索引系统不同,资料整理方式也不同。\boxed{\text{四个人不仅问题不同,索引系统不同,资料整理方式也不同。}}四个人不仅问题不同,索引系统不同,资料整理方式也不同。
5. 不共享 K/V 的核心优势
一句话:
不共享 K/V 可以让不同注意力头从不同角度理解历史上下文。\boxed{\text{不共享 K/V 可以让不同注意力头从不同角度理解历史上下文。}}不共享 K/V 可以让不同注意力头从不同角度理解历史上下文。
具体提升的是:
- 关系建模能力
- 上下文检索能力
- 多头功能分工能力
- 复杂依赖捕捉能力
- 不同语义/语法/结构特征的并行提取能力
6. 为什么实际大模型还要用 GQA?
因为 MHA 虽然表达能力强,但 KV Cache 太大。
所以工程上会牺牲一部分表达能力,换取推理速度和显存:
| 方案 | 特点 |
|---|---|
| MHA | 表达能力强,但贵 |
| MQA | 便宜,但压缩太狠 |
| GQA | 折中 |
所以 GQA 的本质是:
承认不共享 K/V 更强,但为了减少 KV Cache,让几个 Query 头共用一套 K/V。\boxed{\text{承认不共享 K/V 更强,但为了减少 KV Cache,让几个 Query 头共用一套 K/V。}}承认不共享 K/V 更强,但为了减少 KV Cache,让几个 Query 头共用一套 K/V。
最关键的一句话是:
不共享 K/V 的表达能力,主要体现在每个头都可以学习不同的"历史信息组织方式"和"上下文检索方式"。\boxed{\text{不共享 K/V 的表达能力,主要体现在每个头都可以学习不同的"历史信息组织方式"和"上下文检索方式"。}}不共享 K/V 的表达能力,主要体现在每个头都可以学习不同的"历史信息组织方式"和"上下文检索方式"。
Step 2: 核心公式与张量维度
注意力计算公式:
Attention(Q,K,V)=Softmax(QKTdk)V \text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=Softmax(dk QKT)V
张量维度追踪 (Shape Tracking) - 算法工程师的灵魂:
假设 Batch=B, Seq_len=S, Num_Heads=H, Head_Dim=D
- 线性投影后:
Q形状为[B, S, H * D] - 切分多头后:转置为
[B, H, S, D] - 注意力分数计算:
Q @ K^T->[B, H, S, D] @ [B, H, D, S]->[B, H, S, S] - 乘以 Value:
Scores @ V->[B, H, S, S] @ [B, H, S, D]->[B, H, S, D] - 最后合并多头:转置回
[B, S, H, D]并view成[B, S, H * D]。
Step 3: 工业界源码映射
在真实的工业界代码中,这段逻辑在哪里?
- HuggingFace LLaMA :
transformers/models/llama/modeling_llama.py中的LlamaAttention类。 - vLLM (推理框架): 核心关注它的 PagedAttention 实现,用来解决这里 KV Cache 的显存碎片化问题。
Step 4: 动手实战
要求 :请补全下方 GroupedQueryAttention 的 forward 函数中的 TODO 部分,实现:
- 张量的多头切分与重塑 (Reshape)
- KV 缓存的拼接逻辑
- 注意力分数的计算
python
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
将 KV 头复制 n_rep 次,以匹配 Query 头的数量 (GQA/MQA 需要)
"""
batch, num_kv_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)
class GroupedQueryAttention(nn.Module):
def __init__(self, hidden_dim: int, num_heads: int, num_kv_heads: int = None):
super().__init__()
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.head_dim = hidden_dim // num_heads
# 定义投影矩阵
self.q_proj = nn.Linear(hidden_dim, num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(num_heads * self.head_dim, hidden_dim, bias=False)
def forward(
self,
x: torch.Tensor,
attention_mask: torch.Tensor = None,
kv_cache: tuple[torch.Tensor, torch.Tensor] = None
):
batch_size, seq_len, _ = x.shape
# 1. 线性投影
xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# ==========================================
# TODO 1: Reshape xq, xk, xv 以适配多头注意力计算
# ==========================================
# xq = ???
# xk = ???
# xv = ???
xq = xq.view(batch_size,seq_len,self.num_heads,self.head_dim).transpose(1,2).contiguous()
xk = xk.view(batch_size,seq_len,self.num_kv_heads,self.head_dim).transpose(1,2).contiguous()
xv = xv.view(batch_size,seq_len,self.num_kv_heads,self.head_dim).transpose(1,2).contiguous()
# ==========================================
# TODO 2: 处理 KV Cache
# ==========================================
if kv_cache is not None:
k_cache, v_cache = kv_cache
# 自回归生成时,正确顺序应该是: # 历史 token 在前,当前 token 在后
xk = torch.cat([k_cache,xk], dim=2) # 将缓存的 KV 与当前的 KV 拼接
# dim = 2 是因为 KV 的 shape 是 [B, H, S, D],我们希望在 S 维度上拼接, S 维度是序列长度
xv = torch.cat([v_cache,xv], dim=2) # 将缓存的 KV 与当前的 KV 拼接
# xk = ???
# xv = ???
new_kv_cache = (xk, xv)
# 通过 repeat_kv 把 GQA 的 KV 头数扩充到和 Query 数量一致
xk = repeat_kv(xk, self.num_queries_per_kv)
xv = repeat_kv(xv, self.num_queries_per_kv)
# ==========================================
# TODO 3: 计算注意力分数 (Scaled Dot-Product)
# ==========================================
# scores = ???
scores = torch.matmul(xq,xk.transpose(2,3))/math.sqrt(self.head_dim)
if attention_mask is not None:
scores = scores + attention_mask
probs = torch.nn.functional.softmax(scores, dim=-1)
# dim = -1 是因为我们希望在最后一个维度上进行 softmax,最后一个维度是序列长度 S,softmax 会在这个维度上计算注意力分数的归一化概率
output = torch.matmul(probs, xv)
# probs = ???
# output = ???
# ==========================================
# TODO 4: 恢复形状并输出
# [B, H, S, D] -> [B, S, H*D]
# ==========================================
# output = ???
# return self.o_proj(output), new_kv_cache
output = output.transpose(1,2).contiguous().view(batch_size,seq_len,self.num_heads*self.head_dim)
return self.o_proj(output), new_kv_cache
解析
1. TODO 1 (多头切分与维度转置)
- 切分多头: 使用
view(batch_size, seq_len, num_heads, head_dim)将线性投影后的张量从[B, S, H*D]重塑为[B, S, H, D],其中H是头数,D是每个头的维度。 - 维度转置: 通过
.transpose(1, 2)将形状从[B, S, H, D]转为[B, H, S, D],这是注意力计算的标准格式,方便后续的矩阵乘法。 - GQA 的 KV 头数: 注意
xk和xv使用num_kv_heads而不是num_heads,这是 GQA 的核心区别。例如 LLaMA-2 70B 使用 64 个 Query 头但只有 8 个 KV 头。 - 工程细节: 为什么要 transpose?因为注意力分数计算
Q @ K^T需要在[S, D]和[D, S]维度上进行矩阵乘法,将 heads 维度放在第二个位置可以让 batch 和 heads 维度自动广播。
2. TODO 2 (KV Cache 拼接)
- 自回归生成场景: 在推理时,每次只生成一个新 token,但需要用到之前所有 token 的 Key 和 Value。如果每次都重新计算,时间复杂度是 O(N2)O(N^2)O(N2)。
- Cache 机制: 将历史的
k_cache和v_cache与当前步的xk、xv在seq_len维度(dim=2)拼接,形状从[B, H, old_len, D]变为[B, H, old_len+1, D]。 - 显存优化: GQA 的 KV Cache 只需存储
num_kv_heads个头,而不是num_heads个。例如 LLaMA-2 70B 的 KV Cache 显存占用是 MHA 的 1/8。 - 工程陷阱: 必须在
repeat_kv之前进行拼接,否则会重复缓存已扩展的 KV,导致显存浪费。
3. TODO 3 (Scaled Dot-Product Attention)
- 注意力分数计算:
scores = Q @ K^T / sqrt(d_k),其中xk.transpose(2, 3)将[B, H, S, D]转为[B, H, D, S],与xq的[B, H, S, D]相乘得到[B, H, S, S]的注意力矩阵。 - 缩放因子: 除以
sqrt(head_dim)是为了防止点积结果过大导致 softmax 梯度消失。这是 Transformer 原论文的核心设计。 - Mask 机制:
attention_mask通常是一个下三角矩阵(Causal Mask),用-inf填充上三角部分,确保当前 token 只能看到之前的 token。 - Softmax 归一化: 在最后一个维度(
dim=-1)上进行 softmax,将注意力分数转为概率分布。 - 加权求和:
output = probs @ V将注意力权重与 Value 相乘,得到加权后的特征表示。
4. TODO 4 (多头合并与输出投影)
- 维度转置:
.transpose(1, 2)将[B, H, S, D]转回[B, S, H, D]。 - 内存连续性:
.contiguous()确保张量在内存中是连续存储的,这是view操作的前提。如果不调用contiguous(),view可能会报错。 - 合并多头:
.view(batch_size, seq_len, -1)将[B, S, H, D]展平为[B, S, H*D],其中-1自动推断为num_heads * head_dim。 - 输出投影: 通过
o_proj线性层将多头特征映射回hidden_dim,这是标准 Transformer 的最后一步。
进阶思考:GQA 的延迟扩充 (Lazy Expansion)
- 为什么不直接缓存扩充后的 KV? 如果在缓存时就用
repeat_kv扩充,显存占用会和 MHA 一样大,失去了 GQA 的优势。 - 正确做法: 只缓存原始的
num_kv_heads个头,在每次前向传播时临时扩充。虽然增加了计算量,但由于注意力计算是 Memory-bound(受限于显存带宽而非计算速度),这个开销可以忽略。 - 工业实践: vLLM、TensorRT-LLM 等推理框架都采用这种延迟扩充策略,在 70B 模型上可以节省数十 GB 的 KV Cache 显存。