模拟注意力:少量参数放大 Attention 表征能力

论文标题

SAS: Simulated Attention Score

论文地址

https://arxiv.org/pdf/2507.07694

代码

见论文附录

作者背景

摩根士丹利,斯坦福大学,微软研究院,新加坡国立大学,得克萨斯大学奥斯汀分校,香港大学

动机

多头注意力是 Transformer 的核心组件,它通过引入多组 QKV 投影来捕获不同的特征子空间,从而在机器翻译、问答等任务中取得巨大成功。研究表明,注意力头的数量对 Transformer 性能至关重要:在保证每个头的隐藏维度充分大的前提下,注意力头数越多可以使模型效果越好。但问题在于,直接增加头数或维度往往伴随着模型参数量和计算开销的剧增,这在训练和部署中代价高昂

目前也有一些注意力架构旨在提高计算效率,例如共享部分 K 和 V 的 MQA、GQA;使用矩阵分解的 MLA、MFA、TPA 等。但这些方法主要关注降低内存/计算成本,而非提升注意力的表达能力

于是作者希望在不显著增加参数的前提下,设计一种新的注意力架构,实现近似于使用了更多注意力头和更高每头维度的性能提升

本文方法

本文提出 SAS(Simulated Attention Score,模拟注意力分数),核心思想是在注意力计算中引入额外的映射层,将低维的头表示投射到更高维空间,以此"虚拟地"增大注意力头数和每头的隐藏维度

一、扩展注意力头

对于查询Q,其特征维度为 [B, T, H, D],分别表示 batch_size,序列长度,头数和隐藏维度。为了扩充 H,需要把其他维度拉平,得到张量 Q_0,维度为 [B * T * D, H] ;然后使用一个 H * H' 的线性变换得到 Q_1,维度为 [B * T * D, H'],其中 H' > H;Q_1 过一个 ReLU 引入非线性;最后再过一个 H' * H' 的线性层,并加上 Q_1 的残差连接

于是我们获得了更多的注意力头,其中残差连接的引入可以稳定训练;值得注意的是,原始头数 H 和扩展后的头数 H' 都远小于每头的特征维度 D,所以这个两层 MLP 的参数开销相对整模型来说可以忽略不计

除了使用 MLP 来扩展维度,作者还尝试了卷积方案。具体地,将查询 Q 的维度整理成 [B * T, H, D],类似于多通道特征图,然后使用卷积变换将 H 扩展成 H',同样地,H' > H,最后再过第二层卷积以及残差连接

类似地,在 K、V 中都应用上述扩展流程

二、扩展注意力维度

直觉上,每个注意力头内部特征维度 D 越大,其能够捕获的子空间信息越丰富。因此作者进一步在 Q 和 K 上也引入了类似的维度扩展映射。这里之所以不对 V 进行扩展,是因为 V

直接决定了注意力模块的输出张量隐藏维度,扩大 V 的每头维度到 D 会导致后续前馈层的参数量大幅增加,违背了不显著增加计算量的初衷

三、注意力聚合

在标准多头注意力中,会将所有头的输出向量拼接,再通过一个输出投影矩阵 O 映射回模型的隐藏维度。然而,由于 SAS 对注意力头数进行了扩增,若仍按传统方式拼接势必导致输出维度变大,进而导致 O 的参数量大大增加(H * hidden 变为 H' * hidden)。为此,作者提出了参数高效注意力聚合机制,旨在不增加输出层参数规模的情况下完成对多头输出的整合

实现过程非常简单:假设注意力头数扩展了 r 倍,即 r * H = H',那么便把所有头划分成 r 组,每组都按照原本的计算流程与 O 相乘,得到 r 组输出结果,最后取平均作为注意力模块的最终输出传向前馈层

实验结果

作者在多种基准任务和数据集上对SAS进行了验证,包括语言模型预训练及下游任务评估,全面展示了SAS在准确率和效率方面的优势

一、预训练效果

下图对比了SAS与标准MHA、MQA、GQA、MLA、TPA等方法在ArXiv和Books3数据集上的表现。结果表明,无论是短序列训练(长度512)还是长序列训练(长度1024),SAS均取得了最低的验证困惑度

除了取得更好的性能,SAS还加速了模型的收敛。作者报告,在 Books3 数据集、序列长度512的训练中,MHA模型在5万步时达到29.86的验证困惑度,而SAS模型在3万步时就达到了相近的30.49,即 SAS 可以节约 40% 左右的计算资源

此外,作者还在更大的训练长度、更大的模型尺寸上做了验证,结果表明相比于其他注意力机制 SAS 具备稳定的优势

二、下游任务效果

作者评测了在多个下游任务基准(ARC、HellaSwag、PIQA、ScIQ、SocialIQA、WinoGrande)上 SAS 与其他注意力模型的效果,可见在多种参数量、训练数据量的实验设置下,SAS 大部分情况下都表现出了最优性能

相关推荐
新智元16 小时前
AI 教父 Hinton 末日警告!你必须失业,AI 万亿泡沫豪赌才能「赢」
人工智能·openai
新智元16 小时前
CUDA 再见了!寒武纪亮出软件全家桶
人工智能·openai
oe101916 小时前
好文与笔记分享 A Survey of Context Engineering for Large Language Models(下)
人工智能·笔记·语言模型·agent
有为少年16 小时前
告别乱码:OpenCV 中文路径(Unicode)读写的解决方案
人工智能·opencv·计算机视觉
清风与日月16 小时前
halcon分类器使用标准流程
深度学习·目标检测·计算机视觉
渔舟渡简16 小时前
机器学习-回归分析之一元线性回归
机器学习·线性回归
西西阿西哥16 小时前
【随便聊聊】和ChatGPT聊聊潜空间
深度学习·chatgpt
B站计算机毕业设计之家16 小时前
Python招聘数据分析可视化系统 Boss直聘数据 selenium爬虫 Flask框架 数据清洗(附源码)✅
爬虫·python·selenium·机器学习·数据分析·flask
FreeCode16 小时前
LangChain1.0智能体开发:模型使用
人工智能·langchain·agent
张较瘦_17 小时前
[论文阅读] AI+ | 从 “刚性科层” 到 “智能协同”:一文读懂 AI 应对国家安全风险的核心逻辑
论文阅读·人工智能