《------往期经典推荐------》
二、机器学习实战专栏【链接】 ,已更新31期,欢迎关注,持续更新中~~
三、深度学习【Pytorch】专栏【链接】
四、【Stable Diffusion绘画系列】专栏【链接】
五、YOLOv8改进专栏【链接】,持续更新中~~
六、YOLO性能对比专栏【链接】,持续更新中~
《------正文------》
目录
- 摘要
- 方法
-
- [1. 代理注意力机制](#1. 代理注意力机制)
- [2. 代理注意力模块](#2. 代理注意力模块)
- 创新点
- 实验结果
- 总结
- [Agent Attention源码与注释](#Agent Attention源码与注释)

摘要

本文提出了一种新的注意力机制------Agent Attention ,旨在在计算效率和表示能力之间取得平衡。传统的Softmax注意力机制虽然具有强大的表达能力,但其计算复杂度较高,限制了其在多种场景中的应用。Agent Attention通过引入一组代理令牌(Agent Tokens),减少了查询令牌(Query Tokens)与键值对(Key-Value Pairs)之间的直接交互,从而显著降低了计算复杂度。代理令牌首先作为查询令牌的"代理"从键值对中聚合信息,然后将这些信息广播回查询令牌。由于代理令牌的数量可以设计得远小于查询令牌的数量,Agent Attention在保持全局上下文建模能力的同时,显著提高了计算效率。此外,本文还证明了Agent Attention是线性注意力的一种广义形式,从而实现了Softmax注意力和线性注意力的无缝集成。实验结果表明,Agent Attention在多种视觉任务中表现出色,尤其是在高分辨率场景下。例如,在Stable Diffusion中应用Agent Attention,不仅加速了图像生成过程,还显著提高了生成图像的质量,且无需额外训练。
方法
1. 代理注意力机制

Agent Attention的核心思想是引入一组代理令牌A,作为查询令牌Q的"代理"。代理令牌首先从键值对(K, V)中聚合信息,然后将这些信息广播回查询令牌。具体来说,Agent Attention由两个Softmax注意力操作组成:
- 代理聚合(Agent Aggregation):代理令牌A作为查询,从键值对(K, V)中聚合信息,生成代理特征VA。
- 代理广播(Agent Broadcast):代理令牌A作为键,将代理特征VA广播给每个查询令牌Q,形成最终输出。
由于代理令牌的数量可以设计得远小于查询令牌的数量,Agent Attention的计算复杂度从Softmax注意力的O(N²)降低到O(Nn),其中n是代理令牌的数量,N是查询令牌的数量。
2. 代理注意力模块
为了进一步提升Agent Attention的性能,本文还引入了两个改进:
- 代理偏置(Agent Bias):为了更好利用位置信息,本文设计了一种代理偏置,帮助不同的代理令牌关注不同的区域。
- 多样性恢复模块(Diversity Restoration Module):为了保持特征多样性,本文采用了深度卷积(DWC)模块。
创新点

- 代理令牌的引入:通过引入代理令牌,Agent Attention减少了查询令牌与键值对之间的直接交互,显著降低了计算复杂度。
- Softmax与线性注意力的集成:本文证明了Agent Attention是线性注意力的一种广义形式,从而实现了Softmax注意力和线性注意力的无缝集成。
- 高效的计算与强大的表达能力:Agent Attention在保持全局上下文建模能力的同时,显著提高了计算效率,尤其适用于高分辨率场景。
实验结果

本文在多个视觉任务上验证了Agent Attention的有效性,包括图像分类、目标检测、语义分割和图像生成。实验结果表明:
- 图像分类:在ImageNet-1K数据集上,Agent Attention在多个模型上均取得了显著的性能提升。例如,Agent-PVT-S在参数和计算量仅为PVT-L的30%和40%的情况下,性能超过了PVT-L。
- 目标检测:在COCO数据集上,Agent Attention在RetinaNet、Mask R-CNN和Cascade Mask R-CNN框架上均表现出色,显著提高了检测精度。
- 语义分割:在ADE20K数据集上,Agent Attention在SemanticFPN和UperNet模型上均取得了显著的性能提升。
- 图像生成:在Stable Diffusion中应用Agent Attention,不仅加速了图像生成过程,还显著提高了生成图像的质量,且无需额外训练。
总结
本文提出的Agent Attention是一种新颖的注意力机制,通过引入代理令牌,显著降低了计算复杂度,同时保持了强大的表达能力。Agent Attention不仅适用于多种视觉任务,还在高分辨率场景下表现出色。此外,Agent Attention还可以直接应用于预训练的大型扩散模型,如Stable Diffusion,显著加速图像生成过程并提高生成质量。由于其线性复杂度和强大的表示能力,Agent Attention为视频建模和多模态基础模型等具有超长令牌序列的挑战性任务提供了新的可能性。
Agent Attention源码与注释
python
# 论文:Agent Attention: On the Integration of Softmax and Linear Attention
# 论文地址:https://arxiv.org/pdf/2312.08874
# 代码地址: https://github.com/LeapLabTHU/Agent-Attention
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_
class AgentAttention(nn.Module):
def __init__(self, dim, num_patches, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
sr_ratio=1, agent_num=49, **kwargs):
super().__init__()
# 确保维度dim可以被头数num_heads整除
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
self.dim = dim # 输入特征维度
self.num_patches = num_patches # 图像分割成的patch数量
window_size = (int(num_patches ** 0.5), int(num_patches ** 0.5)) # 假设patch是正方形,计算窗口大小
self.window_size = window_size
self.num_heads = num_heads # 注意力头数
head_dim = dim // num_heads # 每个头的维度
self.scale = head_dim ** -0.5 # 缩放因子
# 定义Q、KV的线性变换层
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
# 定义注意力分数的dropout层
self.attn_drop = nn.Dropout(attn_drop)
# 定义输出的线性变换层
self.proj = nn.Linear(dim, dim)
# 定义输出的dropout层
self.proj_drop = nn.Dropout(proj_drop)
# 如果空间降采样比例大于1,则定义空间降采样层
self.sr_ratio = sr_ratio
if sr_ratio > 1:
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
self.norm = nn.LayerNorm(dim)
self.agent_num = agent_num # 代理token的数量
# 深度可分离卷积
self.dwc = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=(3, 3), padding=1, groups=dim)
# 定义各种位置偏置参数
self.an_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7))
self.na_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7))
self.ah_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, window_size[0] // sr_ratio, 1))
self.aw_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, 1, window_size[1] // sr_ratio))
self.ha_bias = nn.Parameter(torch.zeros(1, num_heads, window_size[0], 1, agent_num))
self.wa_bias = nn.Parameter(torch.zeros(1, num_heads, 1, window_size[1], agent_num))
# 初始化位置偏置参数
trunc_normal_(self.an_bias, std=.02)
trunc_normal_(self.na_bias, std=.02)
trunc_normal_(self.ah_bias, std=.02)
trunc_normal_(self.aw_bias, std=.02)
trunc_normal_(self.ha_bias, std=.02)
trunc_normal_(self.wa_bias, std=.02)
pool_size = int(agent_num ** 0.5) # 计算池化层的输出大小
self.pool = nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size)) # 自适应平均池化层
self.softmax = nn.Softmax(dim=-1) # softmax层用于计算注意力权重
def forward(self, x, H, W):
b, n, c = x.shape # 获取输入特征的batch size、patch数量和特征维度
num_heads = self.num_heads # 获取注意力头数
head_dim = c // num_heads # 计算每个头的维度
q = self.q(x) # 计算Q
# 如果空间降采样比例大于1,则对特征进行降采样
if self.sr_ratio > 1:
x_ = x.permute(0, 2, 1).reshape(b, c, H, W) # 调整特征形状以适应卷积层
x_ = self.sr(x_).reshape(b, c, -1).permute(0, 2, 1) # 降采样并调整形状
x_ = self.norm(x_) # 归一化
kv = self.kv(x_).reshape(b, -1, 2, c).permute(2, 0, 1, 3) # 计算KV并调整形状
else:
kv = self.kv(x).reshape(b, -1, 2, c).permute(2, 0, 1, 3) # 计算KV并调整形状
k, v = kv[0], kv[1] # 分离K和V
# 计算代理token
agent_tokens = self.pool(q.reshape(b, H, W, c).permute(0, 3, 1, 2)).reshape(b, c, -1).permute(0, 2, 1)
q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3) # 调整Q的形状
k = k.reshape(b, n // self.sr_ratio ** 2, num_heads, head_dim).permute(0, 2, 1, 3) # 调整K的形状
v = v.reshape(b, n // self.sr_ratio ** 2, num_heads, head_dim).permute(0, 2, 1, 3) # 调整V的形状
agent_tokens = agent_tokens.reshape(b, self.agent_num, num_heads, head_dim).permute(0, 2, 1, 3) # 调整代理token的形状
kv_size = (self.window_size[0] // self.sr_ratio, self.window_size[1] // self.sr_ratio) # 计算KV的空间大小
# 计算位置偏置
position_bias1 = nn.functional.interpolate(self.an_bias, size=kv_size, mode='bilinear')
position_bias1 = position_bias1.reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1)
position_bias2 = (self.ah_bias + self.aw_bias).reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1)
position_bias = position_bias1 + position_bias2
# 计算代理注意力分数并应用softmax
agent_attn = self.softmax((agent_tokens * self.scale) @ k.transpose(-2, -1) + position_bias)
agent_attn = self.attn_drop(agent_attn) # 应用dropout
agent_v = agent_attn @ v # 计算代理注意力加权和
# 计算代理token到Q的注意力分数
agent_bias1 = nn.functional.interpolate(self.na_bias, size=self.window_size, mode='bilinear')
agent_bias1 = agent_bias1.reshape(1, num_heads, self.agent_num, -1).permute(0, 1, 3, 2).repeat(b, 1, 1, 1)
agent_bias2 = (self.ha_bias + self.wa_bias).reshape(1, num_heads, -1, self.agent_num).repeat(b, 1, 1, 1)
agent_bias = agent_bias1 + agent_bias2
q_attn = self.softmax((q * self.scale) @ agent_tokens.transpose(-2, -1) + agent_bias)
q_attn = self.attn_drop(q_attn) # 应用dropout
x = q_attn @ agent_v # 计算注意力加权和
# 调整形状
x = x.transpose(1, 2).reshape(b, n, c)
v = v.transpose(1, 2).reshape(b, H // self.sr_ratio, W // self.sr_ratio, c).permute(0, 3, 1, 2)
if self.sr_ratio > 1:
v = nn.functional.interpolate(v, size=(H, W), mode='bilinear') # 上采样
x = x + self.dwc(v).permute(0, 2, 3, 1).reshape(b, n, c) # 加上深度可分离卷积的结果
x = self.proj(x) # 线性变换
x = self.proj_drop(x) # 应用dropout
return x # 返回输出特征
if __name__ == '__main__':
dim = 64
num_patches = 49
block = AgentAttention(dim=dim, num_patches=num_patches)
H, W = 7, 7
x = torch.rand(1, num_patches, dim)
# 前向传播
output = block(x, H, W)
print(f"Input size: {x.size()}")
print(f"Output size: {output.size()}")

好了,这篇文章就介绍到这里,喜欢的小伙伴感谢给点个赞和关注,更多精彩内容持续更新~~
关于本篇文章大家有任何建议或意见,欢迎在评论区留言交流!