论文阅读笔记——Multi-Token Attention

MTA 论文

在 Transformer 中计算注意力权重时,仅依赖单个 Q 和 K 的相似度,无法有效捕捉多标记组合信息。(对于 A、B 两个词,单标记注意力需要分别计算两个词的注意力分数,再通过后处理定位共同出现的位置或通过多层隐式堆叠,增加模型深度和容量)。MTA 显示建模多标记依赖,同时不牺牲全局交互和额外参数。(通过卷积运算让他能够看到邻近的Q、K 以及其他注意力头的信息)

在 Transformer 其他部分,如 FFN 的输入/输出加卷积,主要是为了捕捉词元表示 之间的局部依赖关系,不直接改变注意力机制本身如何计算相关性
MTA 的卷积直接作用在 Q K T / A QK^T/A QKT/A ,意味着卷积直接参与了决定哪些上下文位置应该被关注的过程,在处理词元间的关系强度

提出两种方式:pre-softmax convolution 和 post-softmax convolution,MTA 默认采用 Pre-softmax Q-K Convolution 和 Post-softmax Head Mixing Convolution。二者区别在于是在 softmax 之前还是之后进行。

Q-K convolution

a i j = S o f t m a x ( ∑ i ′ = 0 c q − 1 ∑ j ′ = − ⌊ c k / 2 ⌋ ⌈ c k / 2 ⌉ − 1 1 i ≥ j − j ′ θ i ′ , j ′ q i − i ′ k j − j ′ ⊤ / d ) ( 1 ) a_{ij}=\mathrm{Softmax}\left(\sum_{i^{\prime}=0}^{c_{q}-1}\sum_{j^{\prime}=-\lfloor c_{k}/2\rfloor}^{\lceil c_{k}/2\rceil-1}\mathbf{1}{i\geq j- j^{\prime}}\theta{i^{\prime},j^{\prime}}q_{i-i^{\prime}}k_{j-j^{\prime}}^{\top}/\sqrt{d}\right) \qquad \qquad(1) aij=Softmax i′=0∑cq−1j′=−⌊ck/2⌋∑⌈ck/2⌉−11i≥j−j′θi′,j′qi−i′kj−j′⊤/d (1)

在卷积中,为防止未来信息泄露,需要做 Masking。理想的 Masking 比较复杂(见式(1)),采用一种简化形式:用 0 Mask 掉未来的 Q K T QK^T QKT 值,做卷积,再用 − ∞ -\infty −∞ Mask 掉结果中非法位置,再做 Softmax。
A = S o f t m a x ( M a s k − ∞ ( C o n v 2 d θ ( M a s k 0 ( A ^ ) ) ) ) . A=\mathrm{Softmax}\left(\mathrm{Mask}{-\infty}\left(\mathrm{Conv}2\mathrm{d}\theta\left(\mathrm{Mask}_0(\hat{A})\right)\right)\right). A=Softmax(Mask−∞(Conv2dθ(Mask0(A^)))).

Head Mixing Convolution

允许不同注意力头之间共享信息,放大重要信号。将 M 个头分成 M / c h M/c_h M/ch 个组,每组 c h c_h ch 个头。在每组的头内左 1D 卷积。同样可以在 softmax 之前或之后进行。

Group Normalization with depth scaling

改善梯度流,对抗深层网络中残差连接可能带来的主导效应(让模型更关注注意力部分输出,而不是仅仅传递上一层信息)。

在每个头的输出上独立应用组归一化,并结合一个随层数变化的缩放因子。

核心矛盾:在「增强注意力精度」和「保持计算效率」之间尚未找到完美平衡,当前更适合对计算资源不敏感的高精度场景。

实验结果

1.找字母块任务 ,验证 MTA 能够解决 [多条件匹配] 问题。

MTA 错误率接近 0% ,而 Transformer 失败率超 50%

2.LLM,在 105B 词元数据上训练 880M 参数模型

  • MTA 仅在 1/4 的层 使用 Key-Query 卷积(核大小: c q = 6 , c k = 11 c_q=6,c_k=11 cq=6,ck=11)。
  • 所有层使用 Head 卷积(核大小 c h = 2 c_h=2 ch=2)。
相关推荐
SHIPKING3931 小时前
【机器学习&深度学习】什么是下游任务模型?
人工智能·深度学习·机器学习
巴伦是只猫1 小时前
【机器学习笔记Ⅰ】11 多项式回归
笔记·机器学习·回归
DKPT4 小时前
Java桥接模式实现方式与测试方法
java·笔记·学习·设计模式·桥接模式
巴伦是只猫6 小时前
【机器学习笔记Ⅰ】13 正则化代价函数
人工智能·笔记·机器学习
伍哥的传说6 小时前
React 各颜色转换方法、颜色值换算工具HEX、RGB/RGBA、HSL/HSLA、HSV、CMYK
深度学习·神经网络·react.js
要努力啊啊啊8 小时前
YOLOv3-SPP Auto-Anchor 聚类调试指南!
人工智能·深度学习·yolo·目标检测·目标跟踪·数据挖掘
**梯度已爆炸**9 小时前
NLP文本预处理
人工智能·深度学习·nlp
Liudef069 小时前
FLUX.1-Kontext 高效训练 LoRA:释放大语言模型定制化潜能的完整指南
人工智能·语言模型·自然语言处理·ai作画·aigc
静心问道10 小时前
大型语言模型中的自动化思维链提示
人工智能·语言模型·大模型
汀沿河10 小时前
2 大模型高效参数微调;prompt tunning
人工智能·深度学习·prompt