Softmax注意力与线性注意力的优雅融合,Agent Attention推动注意力新升级

来自清华大学的研究者提出了一种新的注意力范式------代理注意力 (Agent Attention)。

近年来,视觉 Transformer 模型得到了极大的发展,相关工作在分类、分割、检测等视觉任务上都取得了很好的效果。然而,将 Transformer 模型应用于视觉领域并不是一件简单的事情。与自然语言不同,视觉图片中的特征数量更多。由于 Softmax 注意力是平方复杂度,直接进行全局自注意力的计算往往会带来过高的计算量。针对这一问题,先前的工作通常通过减少参与自注意力计算的特征数量的方法来降低计算量。例如,设计稀疏注意力机制(如 PVT)或将注意力的计算限制在局部窗口中(如 Swin Transformer)。尽管有效,这样的自注意力方法很容易受到计算模式的影响,同时也不可避免地牺牲了自注意力的全局建模能力。

与 Softmax 注意力不同,线性注意力将 Softmax 解耦为两个独立的函数,从而能够将注意力的计算顺序从 (query・key)・value 调整为 query・(key・value),使得总体的计算复杂度降低为线性。然而,目前的线性注意力方法效果明显逊于 Softmax 注意力,难以实际应用。

注意力模块是 Transformers 的关键组件。全局注意力机制具良好的模型表达能力,但过高的计算成本限制了其在各种场景中的应用。本文提出了一种新的注意力范式,代理注意力 (Agent Attention),同时具有高效性和很强的模型表达能力。

具体来说,代理注意力在传统的注意力三元组 (Q,K,V) 中引入了一组额外的代理向量 A,定义了一种新的四元注意力机制 (Q, A, K, V)。其中,代理向量 A 首先作为查询向量 Q 的代理,从 K 和 V 中聚合信息,然后将信息广播回 Q。由于代理向量的数量可以设计得比查询向量的数量小得多,代理注意力能够以很低的计算成本实现全局信息的建模。

此外,本文证明代理注意力等价于一种线性注意力范式,实现了高性能 Softmax 注意力和高效线性注意力的自然融合。该方法在 ImageNet 上使 DeiT、PVT、Swin Transformer、CSwin Transformer 等模型架构取得了显著的性能提升 ,能够将模型在 CPU 端加速约 2.0 倍、在 GPU 端加速约 1.6 倍 。应用于 Stable Diffusion 时,代理注意力能够将模型生成速度提升约 1.8 倍 ,并显著提高图像生成质量 ,且无需任何额外训练

方法

在本文中,我们创新性地向注意力三元组 (Q,K,V) 引入了一组额外的代理向量 A,定义了一种四元的代理注意力范式 (Q, A, K, V)。如图 1 (c) 所示,在代理注意力中,我们不会直接计算 Q 和 K 之间两两的相似度,而是使用少量的代理向量 A 来收集 K 和 V 中的信息,进而呈递给 Q,以很低的计算成本实现全局信息的建模。从整体结构上看,代理注意力由两个常规 Softmax 注意力操作组成,并且等效为一种广义的线性注意力,实现了高性能 Softmax 注意力和高效线性注意力的自然融合,因而同时具有二者的优点,即:计算复杂度低且模型表达能力强。

图 1:Softmax 注意力、线性注意力与代理注意力机制对比

1. 代理注意力

图 2:代理注意力示意图

上图即为代理注意力的示意图,下面给出具体数学形式。为了书写方便,我们将 Softmax 注意力和线性注意力分别缩写为:

其中,Q,K,V 分别为 Query、Key、Value 矩阵,表示 Softmax 函数,为线性注意力中的映射函数。则代理注意力可以表示为:

另一个等效的表示为:

其中 A 为新定义的代理矩阵。

如公式 (3) 和示意图第一行所示,代理注意力由两个 Softmax 注意力操作组成,分别为代理特征聚合和广播。具体来说,我们首先将 A 作为 Query,在 A、K 和 V 之间进行注意力计算,从所有特征中汇聚信息,得到代理特征。随后,我们将 A 作为 Key,作为 Value,和 Q 进行第二次注意力计算,将代理特征中的全局信息广播回每一个特征,并获得最终输出 O。这样一来,我们避免了 Q 和 K 之间相似度的计算,而是通过代理向量实现了每个 query-key 之间的信息交换。可以看到,在这一计算范式中,少量的代理特征 A 充当了 Q 的 "代理人"------ 从 K 和 V 中收集信息并呈递给 Q,因而本文将这种注意力机制命名为代理注意力 。实际应用中,我们将 A 的数量设置为一个小的超参数 n,从而以线性计算复杂度实现了全局建模。

值得指出的是,如公式 (4) 和示意图第二行所示,代理注意力实际上将高性能的 Softmax 注意力和高效的线性注意力融合在了一起,通过使用两次 Softmax 注意力操作实现了广义线性注意力范式,其中等效映射函数定义为

实际应用中,代理向量可以通过不同的方法获得,例如设置为一组可学习参数,或通过池化等方式从输入特征中得到。我们也可以使用更加优越的方法来获得代理向量,例如 Deformable Points、Token Merging 等。本文中,我们采用简单的池化来获取代理向量。

2. 代理注意力模块

为了更好地发挥代理注意力的潜力,本文进一步做出了两方面的改进。一方面,我们定义了 Agent Bias 以促进不同的代理向量聚焦于图片中不同的位置,从而更好地利用位置信息。另一方面,作为一种广义的线性注意力,代理注意力也面临特征多样性不足的问题,因此我们采用一个轻量化的 DWC 作为多样性恢复模块。

在以上设计的基础上,本文提出了一种新的代理注意力模块,其结构如下图:

图 3:代理注意力模块

结合了 Softmax 注意力和线性注意力的优势,代理注意力模块具有以下特点:

(1) **计算复杂度低且模型表达能力强。**之前的研究通常将 Softmax 注意力和线性注意力视为两种不同的注意力范式,试图解决各自的问题和局限。代理注意力优雅地融合了这两种注意力形式,从而自然地继承了它们的优点,同时享受低计算复杂性和高模型表达能力。

(2) **能够采用更大的感受野。**得益于线性计算复杂度,代理注意力可以自然地采用更大的感受野,而不会增加模型计算量。例如,可以将 Swin Transformer 的 window size 由 7^2 扩大为 56^2,即直接采用全局自注意力,而完全不引入额外计算量。

实验结果

1. 分类任务

代理注意力是一个通用的注意力模块,本文基于 DeiT、PVT、Swin Transformer、CSwin Transformer 等模型架构进行了实验。如下图所示,在 ImageNet 分类任务中,基于代理注意力构建的模型能够取得显著的性能提升。例如,Agent-Swin-S 可以取得超越 Swin-B 的性能,而其参数量和计算量不到后者的 60%。

图 4:ImageNet 图片分类结果

在实际推理速度方面,代理注意力也具有显著的优势。如下图所示,在 CPU/GPU 端,代理注意力模型能够取得 2.0 倍 / 1.6 倍左右的加速,同时取得更好的性能。

图 5:实际测速结果

2. 检测和分割

在检测和分割任务中,相较于基础模型,Agent Transformer 也能够取得十分显著的性能提升,这在一定程度上得益于代理注意力的全局感受野。

图 6:COCO 物体检测与分割结果

图 7:ADE20K 语义分割结果

3.Agent Stable Diffusion

特别值得指出的是,代理注意力可以直接应用于 Stable Diffusion 模型,无需训练 ,即可加速生成并显著提升图片生成质量。如下图所示,将代理注意力应用于 Stable Diffusion 模型,能够将图片生成速度提升约 1.8 倍,同时提升图片的生成质量。

图 8:Stable Diffusion, ToMeSD 和 AgentSD 的定量化结果

下图中给出了生成图片的样例。可以看到,代理注意力能够显著降低 Stable Diffusion 模型生成图片的歧义和错误,同时提升生成速度和生成质量。

图 9:生成图片的样例

4. 高分辨率与大感受野

本文还探究了分辨率和感受野对模型性能的影响。如下图所示,我们基于 Agent-Swin-T 将窗口大小由 7^2 逐步扩大到 56^2。可以看到,随着感受野的扩大,模型性能稳步提升。这说明尽管 Swin 的窗口划分是有效的,但它依然不可避免地损害了模型的全局建模能力。

图 10:感受野大小的影响

下图中,我们将图片分辨率由 256^2 逐步扩大到 384^2。可以看到,在高分辨率的场景下,代理注意力模型持续展现出显著的优势。

图 11:高分辨率场景

总结

本文的贡献主要在三个方面:

(1) 提出了一种新颖、自然、有效且高效 的注意力范式 ------ 代理注意力,它自然地融合了高性能的 Softmax 注意力和高效的线性注意力,以线性计算量实现有效的全局信息建模。

(2) 在分类、检测、分割等诸多任务中充分验证了代理注意力的优越性,特别是在高分辨率、长序列 的场景下,这或为开发大尺度、细粒度、面向实际应用场景的视觉、语言大模型提供了新的方法。

(3) 创新性地以一种无需训练 的方式将代理注意力应用于 Stable Diffusion 模型,显著提升生成速度并提高图片质量,为扩散模型的加速和优化 提供了有效的新研究思路

相关推荐
FL16238631294 分钟前
荔枝成熟度分割数据集labelme格式2263张3类别
人工智能·深度学习
一点.点7 分钟前
DRIVEGPT4: 通过大语言模型实现可解释的端到端自动驾驶
人工智能·语言模型·自然语言处理·自动驾驶
天涯海风1 小时前
介绍一下什么是 AI、 AGI、 ASI
人工智能·agi
zzc9211 小时前
Tensorflow 2.X Debug中的Tensor.numpy问题 @tf.function
人工智能·tensorflow·numpy
我是你们的星光1 小时前
基于深度学习的高效图像失真校正框架总结
人工智能·深度学习·计算机视觉·3d
追逐☞2 小时前
机器学习(11)——xgboost
人工智能·机器学习
智驱力人工智能2 小时前
AI移动监测:仓储环境安全的“全天候守护者”
人工智能·算法·安全·边缘计算·行为识别·移动监测·动物检测
斯普信专业组3 小时前
Apidog MCP服务器,连接API规范和AI编码助手的桥梁
运维·服务器·人工智能
小技工丨3 小时前
LLaMA-Factory:了解webUI参数
人工智能·llm·llama·llama-factory
whaosoft-1433 小时前
w~自动驾驶~合集3
人工智能