手撕大模型 | MQA 和 GQA 原理解析

一、前言

大模型(参数规模通常数十亿至万亿级)在处理复杂任务时面临三大核心问题:

  1. 显式关联的局限性:传统 Multi-head Attention 依赖输入数据的显式特征(如文本中的词向量、图像中的像素特征)计算注意力,难以捕捉深层语义(如 "同义词替换""上下文隐喻")或抽象结构(如 "逻辑推理链")。
  2. 数据效率与泛化瓶颈:大模型训练需海量数据,但在低资源语言、专业领域(如医学、法律)中,显式关联数据稀缺,导致模型泛化能力骤降。
  3. 多模态融合难点:跨模态任务(如图文生成、视频理解)中,不同模态的特征空间差异大(如文本的离散符号 vs 图像的连续像素),显式关联(如 "图像中的猫" 与文本 "猫")之外的隐式关联(如 "图像风格" 与 "文本情感")难以建模。

在前面的文章中,笔者已经讲解了 LLM 推理的关键技术-KV Cache(【手撕大模型】KVCache 原理及代码解析),但是随着大模型功能的不断强化,其容量也在增加,当前的 KVCache 技术已经不能满足发展需要了,所以,各种针对于 KVCache 优化的技术应时而生。

二、优化 KV cache 的方法

参考 zhuanlan.zhihu.com/p/167300361...

当前,业界针对 KV Cache 的优化方法可以总结为有四类:

  1. 共享 KV:多个 Head 共享使用 1 组 KV,将原来每个 Head 一个 KV,变成 1 组 Head 一个 KV,来压缩 KV 的存储。代表方法:GQA,MQA 等。
  2. 窗口 KV:针对长序列控制一个计算 KV 的窗口,KV cache 只保存窗口内的结果(窗口长度远小于序列长度),超出窗口的 KV 会被丢弃,通过这种方法能减少 KV 的存储,当然也会损失一定的长文推理效果。代表方法:Longformer 等。
  3. 量化压缩:基于量化的方法,通过更低的 Bit 位来保存 KV,将单 KV 结果进一步压缩,代表方法:INT8/INT4 等。
  4. 计算优化:通过优化计算过程,减少访存换入换出的次数,让更多计算在片上存储 SRAM 进行,以提升推理性能,代表方法:flashAttention 等。

共享 KV 主要有两种方法,MQA 和 GQA 都是 Google 提出的,详见: MQA(2019)GQA(2023)

三、MQA &

MQA(多查询注意力)和 GQA(分组查询注意力)作为自注意力机制的优化版本,主要作用是加快推理进程、减少内存占用,同时努力维持模型原有的性能表现。

以 Llama 7B 模型为例,其隐藏层维度为 4096,这意味着每个 K、V 向量都包含 4096 个数据。若采用半精度浮点(float16)格式存储,单个 Transformer 模块中,单序列的 K、V 缓存空间就达到 4096×2×2=16KB。由于 Llama 2 包含 32 个 Transformer 模块,单个序列在整个模型中的缓存需求便为 16KB×32=512KB。

那么多序列的情况呢?倘若句子长度为 1024,缓存空间就会增至 512MB。目前英伟达性能顶尖的 H100 显卡,其 SRAM 缓存约为 50MB,A100 则为 40MB,显然难以满足需求。尽管可将数据存于 GPU 显存(DRAM),但会对性能产生影响。7B 规模的模型已是如此,175B 规模的模型面临的问题更严峻。

解决这一问题的思路可从硬件与软件两方面展开:

  • 硬件层面,可借助 HBM(高带宽内存)提高数据读取速度;或摆脱冯・诺依曼架构的束缚,改变计算单元从内存读取数据的方式,转而以存储为核心,构建计算与存储一体化的 "存内计算" 模式,例如采用 "忆阻器" 技术。
  • 软件层面则通过算法优化来解决,Llama 2 所采用的 GQA(分组查询注意力)便是其中一种方案。

下面将通过图示来展示 MQA、GQA 与传统 MHA(多头注意力)的差异:

多头注意力机制(MHA)就是多个头各自拥有自己的 Q,K,V 来算各自的 Self-Attention,而 MQA(Multi Query Attention)就是 Q 依然保持多头,但是 K,V 只有一个,所有多头的 Q 共享一个 K,V ,这样做虽然能最大程度减少 KV Cache 所需的缓存空间,但是可想而知参数的减少意味着精度的下降,所以为了在精度和计算之间做一个 trade-off,GQA (Group Query Attention)孕育而生,即 Q 依然是多头,但是分组共享 K,V,即减少了 K,V 缓存所需的缓存空间,也暴露了大部分参数不至于精度损失严重。

四、MQA

MQA 的思路比较简单,详见上图,每一层的所有 Head,共享同一个 KV 来计算 Attention。相对于 MHA 的单个 Token 需要保存的 KV 数减少了 n_h 倍(head 数量),即每一层共享使用一个 Q 向量和一个 V 向量。

使用 MQA 的模型包括PaLMStarCoderGemini等。很明显,MQA 直接将 KV Cache 减少到了原来的 1/n_h,这是非常可观的,单从节省显存角度看已经是天花板了。

效果方面,目前看来大部分任务的损失都比较有限,且 MQA 的支持者相信这部分损失可以通过进一步训练来弥补回。此外,注意到 MQA 由于共享了 K、V,将会导致 Attention 的参数量减少了将近一半,而为了模型总参数量的不变,通常会相应地增大 FFN/GLU 的规模,这也能弥补一部分效果损失。

五、GQA

GQA 是平衡了 MQA 和 MHA 的一种折中的方法,不是每个 Head 一个 KV,也不是所有 Head 共享一个 KV,而是对所有 Head 分组,比如分组数为 g ,那么每组: n_h/g 个 Head 共享一个 KV。当 g=1 时,GQA 就等价于 MQA,当 g=n_h 时, GQA 就等价于 MHA。

为了方便更清晰的理解 GQA 和 MQA ,使用一个 Token 计算 KV 过程来进行演示:

总结下单 token 计算下,几种方法 KV Cache 的存储量(模型层数:l,每层 Head 数量:n_h )

六、参考链接

zhuanlan.zhihu.com/p/167300361...

54376)]

六、参考链接

zhuanlan.zhihu.com/p/167300361...

spaces.ac.cn/archives/10...

相关推荐
地平线开发者4 小时前
征程 6 | BPU trace 简介与实操
算法·自动驾驶
Wnq100724 小时前
如何在移动 的巡检机器人上,实现管道跑冒滴漏的视觉识别
数码相机·opencv·机器学习·计算机视觉·目标跟踪·自动驾驶
韩曙亮6 小时前
【自动驾驶】自动驾驶概述 ⑨ ( 自动驾驶软件系统概述 | 预测系统 | 决策规划 | 控制系统 )
人工智能·机器学习·自动驾驶·激光雷达·决策规划·控制系统·预测系统
IT古董1 天前
【第五章:计算机视觉-计算机视觉在工业制造领域中的应用】1.工业缺陷分割-(1)工业品缺陷风格基础知识:割任务定义、数据集介绍
计算机视觉·3d·自动驾驶
Mr.Winter`2 天前
深度强化学习 | 基于SAC算法的动态避障(ROS C++仿真)
人工智能·深度学习·神经网络·机器人·自动驾驶·ros·具身智能
酌量2 天前
路径平滑优化详解(二次规划): 数学建模与目标函数推导
经验分享·笔记·学习·机器人·自动驾驶
青云交2 天前
Java 大视界 -- Java 大数据在智慧交通自动驾驶仿真与测试数据处理中的应用
java·大数据·自动驾驶·数据存储·算法优化·智慧交通·测试数据处理
箫乾2 天前
第78篇:AI+交通:自动驾驶、智能交通管理与物流优化
人工智能·机器学习·自动驾驶
ARM+FPGA+AI工业主板定制专家2 天前
基于JETSON/RK3588机器人高动态双目视觉系统方案
人工智能·机器学习·fpga开发·机器人·自动驾驶