图像由作者使用 AI 生成
真正的记忆艺术,是专注的艺术。------ Samuel Johnson
大型语言模型(LLMs)展现出了惊人的能力,从架构上来看,它们基于 transformer 构建。在每一层中,transformer 都包含几十个注意力头。注意力头在信息混合中起着非常重要的作用,因此本文作者特别研究了它们在处理信息中的角色。有个很有意思的现象是:一些看起来没什么意义的 token(通常是第一个)会收到最多的注意力,因此被称作"注意力黑洞(attention sinks)"。
关于"注意力黑洞"这个现象,有几个疑问:
- 它们为什么会出现?
- 它们为什么重要,或者说,它们到底有什么用?
- 为什么注意力黑洞通常出现在上下文的第一个位置?
这篇文章就是专门来讲这些事的。
注意力黑洞是个很奇怪的现象;在没有明显原因的情况下,某个特定 token(或一组 token)会在整个注意力机制中变得特别有影响力。transformer 里的注意力机制允许模型在预测时专注于输入序列的不同部分。核心思想是通过 token 之间的相互比较来计算它们的重要性(也就是注意力值)。
注意力黑洞,其实就是某个 token 获得了不成比例的高注意力权重。叫它"黑洞",是因为它像在注意力图中形成了一个"坑",把其他 token 的注意力都吸过去了。
为什么会这样?为什么一个 token 会变成注意力黑洞?
目前没人确切知道,但有几个猜想:
- 频率高:一个频繁出现的 token 可能会受到更多关注;
- 位置优先:出现在序列开头或结尾的 token 可能有"特权";
- 上下文语义:语义上很强的 token 会引导模型关注它;
- 训练数据的偏差:训练数据本身的偏差可能让模型更关注某些 token。
那为什么这东西重要呢?
传统观点认为,注意力黑洞可能会影响模型的可解释性。如果模型大部分注意力都集中在一个 token 上,那我们研究序列中其他 token 被关注的程度就没太大意义,也就无法真正理解模型的决策过程。另外,如果某些 token 成为注意力黑洞,模型中的信息流可能会被改变,导致模型忽视了其它 token 中的重要上下文信息。这也可能意味着注意力黑洞降低了模型效率,并反映出模型存在偏差。
最近有一项新研究,对注意力黑洞为何形成、它们起什么作用提出了新看法:
本研究不仅展示了注意力黑洞是怎么出现的,还讨论了它们为何有用,尤其是在处理长上下文时。我们明确指出,这种模型学到的行为对于有效学习长上下文是必要的。------[来源]
研究作者指出,尤其是在处理长上下文时,transformer 在表示信息时会遇到一些问题。当 transformer 层数很多时,会出现所谓的秩塌陷(rank collapse)或者过度平滑(over smoothing) 。注意力机制虽然允许信息在整个序列中进行混合并提取,但如果混合次数太多,最终得到的表示就会趋于收敛、变得无信息。这个现象在处理长上下文时会更加严重。
还有一个现象叫做表示塌陷(representational collapse) ,就是说当输入序列很长时,transformer 会逐渐"破坏"掉靠后的 token 的信息。也就是说层数越多、上下文越长,transformer 的信息提取能力就越差。研究者认为这跟注意力黑洞现象有关,他们提出了如下假设:
本研究将这些问题与注意力黑洞联系起来。我们展示了 transformer 使用特定的注意力黑洞模式,是为了对抗表示塌陷,确保不同 token 的表示保持有意义的差异。------[来源]
秩塌陷是指模型在某一层失去了表示的多样性,输出变得太相似;而表示塌陷则是指从一层到下一层,token 表示变化太小,模型没学到新东西,陷入了模式化的输出。通常来说,秩塌陷出现时,也会伴随表示塌陷(但反过来不一定成立)。
研究指出这些现象是由于上下文过长或层数太多导致的"灾难性混合"效应。为了防止这种情况,transformer 自发学会了一种"防御机制"------注意力黑洞,就是其中之一。
总结一下,注意力机制让 transformer 能混合信息,但如果混合得太随意,那么一个 token 的小改动可能会影响整个序列(或者说影响很多 token),这就是"混过头"了(overmixing)。这会引发秩塌陷或表示塌陷。注意力黑洞则像一个机制,限制了模型的信息混合。因为如果注意力黑洞把其他 token 的注意力吸走,信息就没法在模型里乱传了。这样,模型的表示更稳定,对输入中其它位置的微小变化也不那么敏感。
为了验证这个想法,研究者在 Gemma 7B 模型上做了扰动分析,观察有没有注意力黑洞的情况下,扰动对模型的影响。结果显示,在没有注意力黑洞时,扰动的影响更大。
通过对注意力图的平滑处理也可以确认这个现象------当去掉注意力黑洞后,注意力图变得更模糊了。
还有一个很有趣的现象是:有的注意力头表现出一种"if-else"风格的行为。比如某个注意力头在前一个 token 含有撇号(')时才会活跃;如果没有撇号,它就几乎不关注这个序列。这时候,注意力黑洞就会吸走大部分注意力,从而避免对 token 表示的不必要修改。
作者还探索了上下文长度是否会影响注意力黑洞的形成。为此他们从头训练了一个 LLM,并改变了上下文长度。直觉上来说:上下文越长,token 之间信息混合就越多,注意力黑洞就越强烈。他们的观察确实支持这个观点:在长上下文中训练的模型中,注意力黑洞更常见;而在非常短的上下文下训练的模型中,几乎不存在注意力黑洞。
此外,随着模型规模的增长,注意力黑洞也会变得更强。在小模型中,注意力头更活跃、信息混合更频繁;而在大模型中,我们会看到更强的注意力黑洞(为了防止混过头)。
研究者还想知道 token(序列开始的 token)在 transformer 中是不是在注意力黑洞形成中起了特别作用?还是说仅仅因为它处在第一个位置?
实验结果表明是后者:如果你训练时用了 放在第一个位置,模型就会用它当黑洞;如果你推理时把它去掉,注意力黑洞就会消失。不过,即使你训练时不用 ,模型还是会在第一个位置形成一个黑洞(虽然会弱一些)。总之,预训练策略会影响注意力黑洞的形成方式,但无论如何,第一个 token 成为注意力黑洞几乎是不可避免的。
本研究提出了对注意力黑洞的新理解:它们是 transformer 架构中为应对"信息过度压缩"和"过度混合"而自然形成的应对机制。我们的分析表明,将大量注意力引导到 token 上,可以让模型对 token 扰动更不敏感。------[来源]
简单说,注意力黑洞在 transformer 模型中起着重要作用。它防止了过度混合,从而避免了秩塌陷与表示塌陷。注意力黑洞在训练过程中自发出现,像是模型的一种"自我防御机制"。通过集中注意力在一个 token 上,注意力头变得更稳定、对外部扰动更鲁棒。随着模型规模或上下文长度的增加,这种黑洞会变得更明显。
当然还有许多没解答的问题。如果注意力黑洞确实在信息流动中发挥作用,那我们就需要搞清楚:当模型学习新任务,比如微调时,它到底在起什么作用。理解注意力黑洞的角色,可以帮我们训练出更稳的模型,设计出更强健的微调策略。除此之外,更好地理解注意力黑洞还能帮我们改进文本生成,尤其是对长文本的生成效果,也可能在模型扩展方面带来实际价值。