LLM-leetcode TASK03

Attention MHA GQA | 注意力机制与键值缓存 (MHA / GQA / MQA)

Step 1: 核心思想与痛点

在大语言模型中,注意力机制 (Attention) 决定了模型如何"回顾"并提取历史上下文的信息。随着模型层数加深和序列变长,Attention 模块在推理阶段面临极大的性能挑战。

什么是 KV Cache?为什么它是性能瓶颈?
在自回归生成中,每次生成第 NNN 个 Token 时,我们需要计算它与前面 N−1N-1N−1 个 Token 的相关性。为了避免重复计算前 N−1N-1N−1 个 Token 的特征,我们将其投影后的 Key 和 Value 张量**缓存(Cache)**在显存中,当前步直接拼接读取。
然而,读取巨量的 KV Cache 会面临严重的显存容量瓶颈内存带宽瓶颈 (Memory-bound),导致推理极慢。
从 MHA 到 GQA:大模型架构的进化

  • MHA (Multi-Head Attention): 标准的多头注意力。每个 Query 头都有自己专属的 Key 和 Value 头。其巨大的 KV Cache 让推理寸步难行。
  • MQA (Multi-Query Attention) : 所有的 Query 头共享同一个 Key 和 Value 头。极大地减少了 KV Cache 的占用,但由于表达能力锐减,模型效果往往打折。
  • GQA (Grouped-Query Attention) : LLaMA-2/3 采用的折中方案。将 Query 头分组,每组共享一个 Key 和 Value 头。这在模型效果和显存占用之间取得了良好的工程平衡

MHA不共享头

MQA

GQA:一组 Query 头共享一个 K/V

为什么可以共享K和V


不共用记忆系统可以提升什么表达能力?

1. 不同头可以学习不同的"匹配标准"

Attention 打分是:

scorei(m,n)=Qi(m)⊤Ki(n)\text{score}_{i}(m,n) = Q_i(m)^\top K_i(n)scorei(m,n)=Qi(m)⊤Ki(n)

这里的 QiQ_iQi 和 KiK_iKi 决定了第 iii 个头怎么看"当前 token 和历史 token 是否相关"。

如果不共享 K/V,每个头有自己的:

WQ(i),WK(i)W_Q^{(i)},\quad W_K^{(i)}WQ(i),WK(i)

所以每个头可以学习不同的匹配方式。

例如:

  • head1\text{head}_1head1 可能专门判断:当前词和前一个名词是否相关
  • head2\text{head}_2head2 可能专门判断:当前括号和前面的左括号是否匹配
  • head3\text{head}_3head3 可能专门判断:当前代词和前文实体是否指代一致

也就是说,不共享 K/V 让每个头拥有不同的"检索标准"。

如果共享 K/V,多个头虽然 QiQ_iQi 不同,但它们面对的是同一套 KKK:

Q1K⊤,Q2K⊤,Q3K⊤Q_1K^\top,\quad Q_2K^\top,\quad Q_3K^\topQ1K⊤,Q2K⊤,Q3K⊤

它们的提问方式不同,但"历史 token 被怎样编码成 Key"是一样的。

所以多样性会下降。


2. 不同头可以保存不同类型的历史信息

Value 决定了被取回来的内容:

headi=softmax(QiKi⊤d)Vi\text{head}_i = \text{softmax}\left(\frac{Q_iK_i^\top}{\sqrt d}\right)V_iheadi=softmax(d QiKi⊤)Vi

如果不共享 ViV_iVi,每个头可以把历史 token 编码成不同的信息。

例如同一个历史 token:

xnx_nxn

在不同头里可以被投影成:

V1(n)=xnWV(1)V_1(n)=x_nW_V^{(1)}V1(n)=xnWV(1)

V2(n)=xnWV(2)V_2(n)=x_nW_V^{(2)}V2(n)=xnWV(2)

V3(n)=xnWV(3)V_3(n)=x_nW_V^{(3)}V3(n)=xnWV(3)

它们可以分别表示不同内容:

  • V1V_1V1: 语法信息
  • V2V_2V2: 语义信息
  • V3V_3V3: 位置/结构信息
  • V4V_4V4: 代码缩进、括号、变量依赖

所以不共享 Value 的好处是:

同一个历史 token 可以被多个头保存成不同形态的信息。\boxed{\text{同一个历史 token 可以被多个头保存成不同形态的信息。}}同一个历史 token 可以被多个头保存成不同形态的信息。

如果共享 VVV,那么多个 Query 头最后取回来的都是同一套 Value 表示:

V1=V2=V3V_1 = V_2 = V_3V1=V2=V3

这会形成信息瓶颈。


3. 不同头可以形成更强的"功能分工"

MHA 中,每个头都有完整的:

Qi, Ki, ViQ_i,\ K_i,\ V_iQi, Ki, Vi

所以每个头都像一个独立的小检索器:

组成 说明
自己的问题 QiQ_iQi
自己的索引 KiK_iKi
自己的资料库 ViV_iVi

因此不同头可以分工:

可能学到的功能
head 1 局部邻近 token
head 2 长距离依赖
head 3 主谓关系
head 4 指代关系
head 5 标点/括号结构
head 6 代码变量引用
head 7 段落主题
head 8 任务指令相关内容

这就是所谓的表达能力强

模型可以同时建模多种不同关系,而不是所有头挤在同一套 K/V 表示里。\boxed{\text{模型可以同时建模多种不同关系,而不是所有头挤在同一套 K/V 表示里。}}模型可以同时建模多种不同关系,而不是所有头挤在同一套 K/V 表示里。


4. 共享 K/V 会损失什么?

以 GQA 为例,假设:

Q1,Q2,Q3,Q4Q_1,Q_2,Q_3,Q_4Q1,Q2,Q3,Q4

共享同一个:

K1,V1K_1,V_1K1,V1

那么它们是:

head1=softmax(Q1K1⊤)V1\text{head}_1 = \text{softmax}(Q_1K_1^\top)V_1head1=softmax(Q1K1⊤)V1

head2=softmax(Q2K1⊤)V1\text{head}_2 = \text{softmax}(Q_2K_1^\top)V_1head2=softmax(Q2K1⊤)V1

head3=softmax(Q3K1⊤)V1\text{head}_3 = \text{softmax}(Q_3K_1^\top)V_1head3=softmax(Q3K1⊤)V1

head4=softmax(Q4K1⊤)V1\text{head}_4 = \text{softmax}(Q_4K_1^\top)V_1head4=softmax(Q4K1⊤)V1

虽然 Q1,Q2,Q3,Q4Q_1,Q_2,Q_3,Q_4Q1,Q2,Q3,Q4 不同,所以注意力权重可以不同。

但是它们共享:

K1,V1K_1,V_1K1,V1

所以它们无法分别拥有完全不同的历史索引方式和历史内容表示。

这就像:

四个人可以问不同问题,但只能查同一本索引、同一本资料。\boxed{\text{四个人可以问不同问题,但只能查同一本索引、同一本资料。}}四个人可以问不同问题,但只能查同一本索引、同一本资料。

而 MHA 是:

四个人不仅问题不同,索引系统不同,资料整理方式也不同。\boxed{\text{四个人不仅问题不同,索引系统不同,资料整理方式也不同。}}四个人不仅问题不同,索引系统不同,资料整理方式也不同。


5. 不共享 K/V 的核心优势

一句话:

不共享 K/V 可以让不同注意力头从不同角度理解历史上下文。\boxed{\text{不共享 K/V 可以让不同注意力头从不同角度理解历史上下文。}}不共享 K/V 可以让不同注意力头从不同角度理解历史上下文。

具体提升的是:

  • 关系建模能力
  • 上下文检索能力
  • 多头功能分工能力
  • 复杂依赖捕捉能力
  • 不同语义/语法/结构特征的并行提取能力

6. 为什么实际大模型还要用 GQA?

因为 MHA 虽然表达能力强,但 KV Cache 太大。

所以工程上会牺牲一部分表达能力,换取推理速度和显存:

方案 特点
MHA 表达能力强,但贵
MQA 便宜,但压缩太狠
GQA 折中

所以 GQA 的本质是:

承认不共享 K/V 更强,但为了减少 KV Cache,让几个 Query 头共用一套 K/V。\boxed{\text{承认不共享 K/V 更强,但为了减少 KV Cache,让几个 Query 头共用一套 K/V。}}承认不共享 K/V 更强,但为了减少 KV Cache,让几个 Query 头共用一套 K/V。

最关键的一句话是:

不共享 K/V 的表达能力,主要体现在每个头都可以学习不同的"历史信息组织方式"和"上下文检索方式"。\boxed{\text{不共享 K/V 的表达能力,主要体现在每个头都可以学习不同的"历史信息组织方式"和"上下文检索方式"。}}不共享 K/V 的表达能力,主要体现在每个头都可以学习不同的"历史信息组织方式"和"上下文检索方式"。


Step 2: 核心公式与张量维度

注意力计算公式:

Attention(Q,K,V)=Softmax(QKTdk)V \text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=Softmax(dk QKT)V

张量维度追踪 (Shape Tracking) - 算法工程师的灵魂:

假设 Batch=B, Seq_len=S, Num_Heads=H, Head_Dim=D

  1. 线性投影后:Q 形状为 [B, S, H * D]
  2. 切分多头后:转置为 [B, H, S, D]
  3. 注意力分数计算:Q @ K^T -> [B, H, S, D] @ [B, H, D, S] -> [B, H, S, S]
  4. 乘以 Value:Scores @ V -> [B, H, S, S] @ [B, H, S, D] -> [B, H, S, D]
  5. 最后合并多头:转置回 [B, S, H, D]view[B, S, H * D]

Step 3: 工业界源码映射

在真实的工业界代码中,这段逻辑在哪里?

  • HuggingFace LLaMA : transformers/models/llama/modeling_llama.py 中的 LlamaAttention 类。
  • vLLM (推理框架): 核心关注它的 PagedAttention 实现,用来解决这里 KV Cache 的显存碎片化问题。

Step 4: 动手实战

要求 :请补全下方 GroupedQueryAttentionforward 函数中的 TODO 部分,实现:

  1. 张量的多头切分与重塑 (Reshape)
  2. KV 缓存的拼接逻辑
  3. 注意力分数的计算
python 复制代码
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    将 KV 头复制 n_rep 次,以匹配 Query 头的数量 (GQA/MQA 需要)
    """
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

class GroupedQueryAttention(nn.Module):
    def __init__(self, hidden_dim: int, num_heads: int, num_kv_heads: int = None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.head_dim = hidden_dim // num_heads
        
        # 定义投影矩阵
        self.q_proj = nn.Linear(hidden_dim, num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * self.head_dim, hidden_dim, bias=False)

    def forward(
        self, 
        x: torch.Tensor, 
        attention_mask: torch.Tensor = None, 
        kv_cache: tuple[torch.Tensor, torch.Tensor] = None
    ):
        batch_size, seq_len, _ = x.shape
        
        # 1. 线性投影
        xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        
        # ==========================================
        # TODO 1: Reshape xq, xk, xv 以适配多头注意力计算
        # ==========================================
        # xq = ??? 
        # xk = ???
        # xv = ???
        xq = xq.view(batch_size,seq_len,self.num_heads,self.head_dim).transpose(1,2).contiguous()
        xk = xk.view(batch_size,seq_len,self.num_kv_heads,self.head_dim).transpose(1,2).contiguous()
        xv = xv.view(batch_size,seq_len,self.num_kv_heads,self.head_dim).transpose(1,2).contiguous()



        
        

        # ==========================================
        # TODO 2: 处理 KV Cache
        # ==========================================
        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            # 自回归生成时,正确顺序应该是: # 历史 token 在前,当前 token 在后
            xk = torch.cat([k_cache,xk], dim=2)  # 将缓存的 KV 与当前的 KV 拼接 
            # dim = 2 是因为 KV 的 shape 是 [B, H, S, D],我们希望在 S 维度上拼接, S 维度是序列长度
            xv = torch.cat([v_cache,xv], dim=2)  # 将缓存的 KV 与当前的 KV 拼接
            # xk = ???
            # xv = ???
            
        new_kv_cache = (xk, xv)
        
        # 通过 repeat_kv 把 GQA 的 KV 头数扩充到和 Query 数量一致
        xk = repeat_kv(xk, self.num_queries_per_kv)
        xv = repeat_kv(xv, self.num_queries_per_kv)
        
        # ==========================================
        # TODO 3: 计算注意力分数 (Scaled Dot-Product)
        # ==========================================
        # scores = ???
        scores = torch.matmul(xq,xk.transpose(2,3))/math.sqrt(self.head_dim)
        
        if attention_mask is not None:
            scores = scores + attention_mask
        
        probs = torch.nn.functional.softmax(scores, dim=-1)
         # dim = -1 是因为我们希望在最后一个维度上进行 softmax,最后一个维度是序列长度 S,softmax 会在这个维度上计算注意力分数的归一化概率
        output = torch.matmul(probs, xv)
    
        # probs = ???
        # output = ???
        
        
        # ==========================================
        # TODO 4: 恢复形状并输出
        # [B, H, S, D] -> [B, S, H*D]
        # ==========================================
        # output = ???
        
        # return self.o_proj(output), new_kv_cache
        output = output.transpose(1,2).contiguous().view(batch_size,seq_len,self.num_heads*self.head_dim)
        return self.o_proj(output), new_kv_cache

解析

1. TODO 1 (多头切分与维度转置)

  • 切分多头: 使用 view(batch_size, seq_len, num_heads, head_dim) 将线性投影后的张量从 [B, S, H*D] 重塑为 [B, S, H, D],其中 H 是头数,D 是每个头的维度。
  • 维度转置: 通过 .transpose(1, 2) 将形状从 [B, S, H, D] 转为 [B, H, S, D],这是注意力计算的标准格式,方便后续的矩阵乘法。
  • GQA 的 KV 头数: 注意 xkxv 使用 num_kv_heads 而不是 num_heads,这是 GQA 的核心区别。例如 LLaMA-2 70B 使用 64 个 Query 头但只有 8 个 KV 头。
  • 工程细节: 为什么要 transpose?因为注意力分数计算 Q @ K^T 需要在 [S, D][D, S] 维度上进行矩阵乘法,将 heads 维度放在第二个位置可以让 batch 和 heads 维度自动广播。

2. TODO 2 (KV Cache 拼接)

  • 自回归生成场景: 在推理时,每次只生成一个新 token,但需要用到之前所有 token 的 Key 和 Value。如果每次都重新计算,时间复杂度是 O(N2)O(N^2)O(N2)。
  • Cache 机制: 将历史的 k_cachev_cache 与当前步的 xkxvseq_len 维度(dim=2)拼接,形状从 [B, H, old_len, D] 变为 [B, H, old_len+1, D]
  • 显存优化: GQA 的 KV Cache 只需存储 num_kv_heads 个头,而不是 num_heads 个。例如 LLaMA-2 70B 的 KV Cache 显存占用是 MHA 的 1/8。
  • 工程陷阱: 必须在 repeat_kv 之前进行拼接,否则会重复缓存已扩展的 KV,导致显存浪费。

3. TODO 3 (Scaled Dot-Product Attention)

  • 注意力分数计算: scores = Q @ K^T / sqrt(d_k),其中 xk.transpose(2, 3)[B, H, S, D] 转为 [B, H, D, S],与 xq[B, H, S, D] 相乘得到 [B, H, S, S] 的注意力矩阵。
  • 缩放因子: 除以 sqrt(head_dim) 是为了防止点积结果过大导致 softmax 梯度消失。这是 Transformer 原论文的核心设计。
  • Mask 机制: attention_mask 通常是一个下三角矩阵(Causal Mask),用 -inf 填充上三角部分,确保当前 token 只能看到之前的 token。
  • Softmax 归一化: 在最后一个维度(dim=-1)上进行 softmax,将注意力分数转为概率分布。
  • 加权求和: output = probs @ V 将注意力权重与 Value 相乘,得到加权后的特征表示。

4. TODO 4 (多头合并与输出投影)

  • 维度转置: .transpose(1, 2)[B, H, S, D] 转回 [B, S, H, D]
  • 内存连续性: .contiguous() 确保张量在内存中是连续存储的,这是 view 操作的前提。如果不调用 contiguous()view 可能会报错。
  • 合并多头: .view(batch_size, seq_len, -1)[B, S, H, D] 展平为 [B, S, H*D],其中 -1 自动推断为 num_heads * head_dim
  • 输出投影: 通过 o_proj 线性层将多头特征映射回 hidden_dim,这是标准 Transformer 的最后一步。

进阶思考:GQA 的延迟扩充 (Lazy Expansion)

  • 为什么不直接缓存扩充后的 KV? 如果在缓存时就用 repeat_kv 扩充,显存占用会和 MHA 一样大,失去了 GQA 的优势。
  • 正确做法: 只缓存原始的 num_kv_heads 个头,在每次前向传播时临时扩充。虽然增加了计算量,但由于注意力计算是 Memory-bound(受限于显存带宽而非计算速度),这个开销可以忽略。
  • 工业实践: vLLM、TensorRT-LLM 等推理框架都采用这种延迟扩充策略,在 70B 模型上可以节省数十 GB 的 KV Cache 显存。