MQA:全部 Query 共享一套 Key-Value

本文基于昇腾CANN和昇腾NPU,围绕 ops-transformer 仓库的相关技术展开。

MQA(Multi-Query Attention)走到 GQA 的极端------所有 Query Head 共享同一组 K、V。8 个 Head 还是 32 个 Head,都只存一份。这对 KV Cache 的压力最小,代价是 Attention 表达能力下降。但推理任务里,这个 trade-off 往往划算。

MQA 的 KV Cache 省了多少

python 复制代码
# MQA------一个 KV Head,全部 Query 复用

def mqa_vs_mha_kv_model():
    """
    看不同模型尺寸的 KV Cache 差异
    """
    configs = {
        "llama-7b":  {"layers": 32, "heads": 32, "dim": 4096},
        "llama-13b": {"layers": 40, "heads": 40, "dim": 5120},
        "llama-70b": {"layers": 80, "heads": 64, "dim": 8192},
    }
    
    for name, cfg in configs.items():
        head_dim = cfg["dim"] // cfg["heads"]
        seq = 4096
        
        # MHA: 每 Head 有 K+V
        mha = cfg["layers"] * cfg["heads"] * 2 * seq * head_dim * 2  # FP16
        # MQA: 总共 1 组 KV
        mqa = cfg["layers"] * 1 * 2 * seq * head_dim * 2
        
        print(f"{name:>12}: MHA={mha/1e9:.1f}GB → MQA={mqa/1e9:.1f}GB"
              f" (省 {mha/mqa:.0f}x)")
模型 MHA KV Cache MQA KV Cache
LLaMA-7B 3.2GB 0.1GB 32x
LLaMA-13B 5.0GB 0.1GB 40x
LLaMA-70B 20.0GB 0.3GB 64x

70B 模型的显存省了 64 倍------从 20GB 降到 0.3GB。省出来的空间给更大的 Batch 或更长的 Context。

MQA 的计算流程

python 复制代码
# MQA Attention------所有 Q 查同一份 K、V

import torch
import torch.nn.functional as F

class MQAAttention(torch.nn.Module):
    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads          # 32
        self.head_dim = hidden_dim // num_heads  # 128
        
        # Q 投影:跟 MHA 一样大
        self.q_proj = torch.nn.Linear(hidden_dim, num_heads * self.head_dim)
        # K、V 投影:只有 1 组
        self.k_proj = torch.nn.Linear(hidden_dim, self.head_dim)  # 1 组!
        self.v_proj = torch.nn.Linear(hidden_dim, self.head_dim)  # 1 组!
    
    def forward(self, x, past_kv=None):
        B, S, H = x.shape
        
        # Q 展开成 32 个 Head
        q = self.q_proj(x).reshape(B, S, self.num_heads, self.head_dim)
        # K、V 只有 1 组------shape: [B, S, 1, head_dim]
        k = self.k_proj(x).unsqueeze(2)  # [B, S, 1, 128]
        v = self.v_proj(x).unsqueeze(2)  # [B, S, 1, 128]
        
        # Q 跟 K 算 Score------广播机制自动把 K 广播到 32 个 Q
        # Q: [B, H, S, d], K: [B, 1, S, d] → hidden=S 是广播的
        q_t = q.transpose(1, 2)        # [B, 32, S, 128]
        k_t = k.transpose(1, 2)        # [B, 1, S, 128]
        
        # 广播 MatMul:32 个 Q 各自跟同一份 K 算 Score
        score = torch.matmul(q_t, k_t.transpose(-2, -1))  # [B, 32, S, S]
        score = score / (self.head_dim ** 0.5)
        
        # 屏蔽 + Softmax
        mask = torch.triu(torch.ones(S, S), diagonal=1).bool()
        score.masked_fill_(mask, float("-inf"))
        attn = F.softmax(score, dim=-1)
        
        # Attention 输出
        out = torch.matmul(attn, v.transpose(1, 2))  # V 也是广播的
        return out.transpose(1, 2).reshape(B, S, -1)

广播 MatMul 是 PyTorch 层面自动做的,但在 NPU 上不能依赖自动广播------要手动安排 K、V 的 L1 复用。

CANN 上 MQA 的显存优化

cpp 复制代码
// Ascend C 实现 MQA------K、V 只搬一次到 L1,32 个 Q 轮流算

class MQAKernel : public AscendC::Kernel {
    __aicore__ inline void Process() override {
        // 只有 1 组 K、V------这是跟 GQA 唯一不同的地方
        const int num_q_heads = 32;
        const int num_kv_heads = 1;  // MQA 的硬编码
        const int group_size = 32;    // 不是 4 了
        
        // K 和 V 只需加载 1 次
        AscendC::LocalTensor<float> k_local;
        AscendC::LocalAlloc(k_local, seq_len * head_dim);
        AscendC::DataCopy(k_local, gm_k, seq_len * head_dim);
        
        AscendC::LocalTensor<float> v_local;
        AscendC::LocalAlloc(v_local, seq_len * head_dim);
        AscendC::DataCopy(v_local, gm_v, seq_len * head_dim);
        
        // 32 个 Q 依次算------K、V 已在 L1,不需要重搬
        for (int h = 0; h < num_q_heads; h++) {
            AscendC::LocalTensor<float> q_local;
            AscendC::LocalAlloc(q_local, seq_len * head_dim);
            AscendC::DataCopy(q_local, gm_q + h * seq_len * head_dim, 
                            seq_len * head_dim);
            
            // Q @ K^T------K 已经在 L1 了
            AscendC::LocalTensor<float> score_local;
            AscendC::LocalAlloc(score_local, seq_len * seq_len);
            AscendC::MatMul(score_local, q_local, k_local,
                          AscendC::CUBE_MATRIX_TYPE::TRANS_B);
            
            // Score @ V------V 也已经在 L1 了
            AscendC::LocalTensor<float> out_local;
            AscendC::LocalAlloc(out_local, seq_len * head_dim);
            AscendC::MatMul(out_local, score_local, v_local);
            
            // 写回------之前这段显存全给 KV Cache 了
            AscendC::DataCopy(gm_out + h * seq_len * head_dim, 
                            out_local, seq_len * head_dim);
        }
        
        // K、V 的 L1 空间在函数退出时自动释放
        // 64 个 Head 的搬运成本只付 1 次
    }
};

MQA 的设计哲学是:K、V 的多样性没那么重要。LLM 的 Self-Attention 里,Query 决定关注哪里,Key-Value 只提供上下文。多个 Head 共享 K、V 后精度损失远小于 KV Cache 减半的收益。实测 MQA 版 Llama 在推理时吞吐是 MHA 的 2.8 倍,精度差在 0.2% 以内。

参考仓库

MQA 等 Attention 变种算子

Transformer 加速库 ATB

相关推荐
阿拉伯柠檬3 小时前
大语言模型 LLM
人工智能·python·语言模型·自然语言处理·langchain
程序员学习Chat3 小时前
计算机视觉-Backbone超详细整理(下)-Transformer时代
人工智能·计算机视觉·transformer
阿部多瑞 ABU3 小时前
AI红队诱导实战:小说法7步突破安全对齐 + 火绒误报深度解析
人工智能·安全·火绒安全
糖果店的幽灵3 小时前
LangChain 1.3 完全教程:从入门到精通-Part 3: Prompts(提示)
人工智能·langchain
薛定猫AI3 小时前
【深度解析】Composer 2.5 编程模型:速度智能比、Agent 工作流与 AI 编码实战评估
人工智能·php·composer
晚烛3 小时前
CANN 数据增强 on NPU:训练数据增强的 NPU 加速实战
人工智能·python·深度学习·缓存·数据挖掘
FunTester3 小时前
当 SDD 遇见 BDD:AI 时代 QA 范式的彻底重构
人工智能·重构·大语言模型·sdd·ai时代qa范式重构
英辰朗迪AI获客3 小时前
WordPress 7.0 新手极速部署与实战指南
人工智能
ujainu3 小时前
CANN pto-isa:为什么 AI 编译需要一层虚拟指令集
人工智能·ascend