【即插即用涨点模块】Agent Attention代理注意力:适用于高分辨率场景和多种视觉任务【附源码+注释】

《------往期经典推荐------》

一、AI应用软件开发实战专栏【链接】

项目名称 项目名称
1.【人脸识别与管理系统开发 2.【车牌识别与自动收费管理系统开发
3.【手势识别系统开发 4.【人脸面部活体检测系统开发
5.【图片风格快速迁移软件开发 6.【人脸表表情识别系统
7.【YOLOv8多目标识别与自动标注软件开发 8.【基于深度学习的行人跌倒检测系统
9.【基于深度学习的PCB板缺陷检测系统 10.【基于深度学习的生活垃圾分类目标检测系统
11.【基于深度学习的安全帽目标检测系统 12.【基于深度学习的120种犬类检测与识别系统
13.【基于深度学习的路面坑洞检测系统 14.【基于深度学习的火焰烟雾检测系统
15.【基于深度学习的钢材表面缺陷检测系统 16.【基于深度学习的舰船目标分类检测系统
17.【基于深度学习的西红柿成熟度检测系统 18.【基于深度学习的血细胞检测与计数系统
19.【基于深度学习的吸烟/抽烟行为检测系统 20.【基于深度学习的水稻害虫检测与识别系统
21.【基于深度学习的高精度车辆行人检测与计数系统 22.【基于深度学习的路面标志线检测与识别系统
23.【基于深度学习的智能小麦害虫检测识别系统 24.【基于深度学习的智能玉米害虫检测识别系统
25.【基于深度学习的200种鸟类智能检测与识别系统 26.【基于深度学习的45种交通标志智能检测与识别系统
27.【基于深度学习的人脸面部表情识别系统 28.【基于深度学习的苹果叶片病害智能诊断系统
29.【基于深度学习的智能肺炎诊断系统 30.【基于深度学习的葡萄簇目标检测系统
31.【基于深度学习的100种中草药智能识别系统 32.【基于深度学习的102种花卉智能识别系统
33.【基于深度学习的100种蝴蝶智能识别系统 34.【基于深度学习的水稻叶片病害智能诊断系统
35.【基于与ByteTrack的车辆行人多目标检测与追踪系统 36.【基于深度学习的智能草莓病害检测与分割系统
37.【基于深度学习的复杂场景下船舶目标检测系统 38.【基于深度学习的农作物幼苗与杂草检测系统
39.【基于深度学习的智能道路裂缝检测与分析系统 40.【基于深度学习的葡萄病害智能诊断与防治系统
41.【基于深度学习的遥感地理空间物体检测系统 42.【基于深度学习的无人机视角地面物体检测系统
43.【基于深度学习的木薯病害智能诊断与防治系统 44.【基于深度学习的野外火焰烟雾检测系统
45.【基于深度学习的脑肿瘤智能检测系统 46.【基于深度学习的玉米叶片病害智能诊断与防治系统
47.【基于深度学习的橙子病害智能诊断与防治系统 48.【基于深度学习的车辆检测追踪与流量计数系统
49.【基于深度学习的行人检测追踪与双向流量计数系统 50.【基于深度学习的反光衣检测与预警系统
51.【基于深度学习的危险区域人员闯入检测与报警系统 52.【基于深度学习的高密度人脸智能检测与统计系统
53.【基于深度学习的CT扫描图像肾结石智能检测系统 54.【基于深度学习的水果智能检测系统
55.【基于深度学习的水果质量好坏智能检测系统 56.【基于深度学习的蔬菜目标检测与识别系统
57.【基于深度学习的非机动车驾驶员头盔检测系统 58.【太基于深度学习的阳能电池板检测与分析系统
59.【基于深度学习的工业螺栓螺母检测 60.【基于深度学习的金属焊缝缺陷检测系统
61.【基于深度学习的链条缺陷检测与识别系统 62.【基于深度学习的交通信号灯检测识别
63.【基于深度学习的草莓成熟度检测与识别系统 64.【基于深度学习的水下海生物检测识别系统
65.【基于深度学习的道路交通事故检测识别系统 66.【基于深度学习的安检X光危险品检测与识别系统
67.【基于深度学习的农作物类别检测与识别系统 68.【基于深度学习的危险驾驶行为检测识别系统
69.【基于深度学习的维修工具检测识别系统 70.【基于深度学习的维修工具检测识别系统
71.【基于深度学习的建筑墙面损伤检测系统 72.【基于深度学习的煤矿传送带异物检测系统
73.【基于深度学习的老鼠智能检测系统 74.【基于深度学习的水面垃圾智能检测识别系统
75.【基于深度学习的遥感视角船只智能检测系统 76.【基于深度学习的胃肠道息肉智能检测分割与诊断系统
77.【基于深度学习的心脏超声图像间隔壁检测分割与分析系统 78.【基于深度学习的心脏超声图像间隔壁检测分割与分析系统
79.【基于深度学习的果园苹果检测与计数系统 80.【基于深度学习的半导体芯片缺陷检测系统
81.【基于深度学习的糖尿病视网膜病变检测与诊断系统 82.【基于深度学习的运动鞋品牌检测与识别系统
83.【基于深度学习的苹果叶片病害检测识别系统 84.【基于深度学习的医学X光骨折检测与语音提示系统
85.【基于深度学习的遥感视角农田检测与分割系统 86.【基于深度学习的运动品牌LOGO检测与识别系统
87.【基于深度学习的电瓶车进电梯检测与语音提示系统 88.【基于深度学习的遥感视角地面房屋建筑检测分割与分析系统
89.【基于深度学习的医学CT图像肺结节智能检测与语音提示系统 90.【基于深度学习的舌苔舌象检测识别与诊断系统
91.【基于深度学习的蛀牙智能检测与语音提示系统 92.【基于深度学习的皮肤癌智能检测与语音提示系统

二、机器学习实战专栏【链接】 ,已更新31期,欢迎关注,持续更新中~~
三、深度学习【Pytorch】专栏【链接】
四、【Stable Diffusion绘画系列】专栏【链接】
五、YOLOv8改进专栏【链接】持续更新中~~
六、YOLO性能对比专栏【链接】,持续更新中~

《------正文------》

目录

  • 摘要
  • 方法
    • [1. 代理注意力机制](#1. 代理注意力机制)
    • [2. 代理注意力模块](#2. 代理注意力模块)
  • 创新点
  • 实验结果
  • 总结
  • [Agent Attention源码与注释](#Agent Attention源码与注释)

论文地址:https://arxiv.org/pdf/2312.08874

代码地址: https://github.com/LeapLabTHU/Agent-Attention

摘要

本文提出了一种新的注意力机制------Agent Attention ,旨在在计算效率和表示能力之间取得平衡。传统的Softmax注意力机制虽然具有强大的表达能力,但其计算复杂度较高,限制了其在多种场景中的应用。Agent Attention通过引入一组代理令牌(Agent Tokens)​,减少了查询令牌(Query Tokens)与键值对(Key-Value Pairs)之间的直接交互,从而显著降低了计算复杂度。代理令牌首先作为查询令牌的"代理"从键值对中聚合信息,然后将这些信息广播回查询令牌。由于代理令牌的数量可以设计得远小于查询令牌的数量,Agent Attention在保持全局上下文建模能力的同时,显著提高了计算效率。此外,本文还证明了Agent Attention是线性注意力的一种广义形式,从而实现了Softmax注意力和线性注意力的无缝集成。实验结果表明,Agent Attention在多种视觉任务中表现出色,尤其是在高分辨率场景下。例如,在Stable Diffusion中应用Agent Attention,不仅加速了图像生成过程,还显著提高了生成图像的质量,且无需额外训练。

方法

1. 代理注意力机制

Agent Attention的核心思想是引入一组代理令牌A,作为查询令牌Q的"代理"。代理令牌首先从键值对(K, V)中聚合信息,然后将这些信息广播回查询令牌。具体来说,Agent Attention由两个Softmax注意力操作组成:

  1. 代理聚合(Agent Aggregation)​:代理令牌A作为查询,从键值对(K, V)中聚合信息,生成代理特征VA。
  2. 代理广播(Agent Broadcast)​:代理令牌A作为键,将代理特征VA广播给每个查询令牌Q,形成最终输出。

由于代理令牌的数量可以设计得远小于查询令牌的数量,Agent Attention的计算复杂度从Softmax注意力的O(N²)降低到O(Nn),其中n是代理令牌的数量,N是查询令牌的数量。

2. 代理注意力模块

为了进一步提升Agent Attention的性能,本文还引入了两个改进:

  1. 代理偏置(Agent Bias)​:为了更好利用位置信息,本文设计了一种代理偏置,帮助不同的代理令牌关注不同的区域。
  2. 多样性恢复模块(Diversity Restoration Module)​:为了保持特征多样性,本文采用了深度卷积(DWC)模块。

创新点

  1. 代理令牌的引入:通过引入代理令牌,Agent Attention减少了查询令牌与键值对之间的直接交互,显著降低了计算复杂度。
  2. Softmax与线性注意力的集成:本文证明了Agent Attention是线性注意力的一种广义形式,从而实现了Softmax注意力和线性注意力的无缝集成。
  3. 高效的计算与强大的表达能力:Agent Attention在保持全局上下文建模能力的同时,显著提高了计算效率,尤其适用于高分辨率场景。

实验结果

本文在多个视觉任务上验证了Agent Attention的有效性,包括图像分类、目标检测、语义分割和图像生成。实验结果表明:

  1. 图像分类:在ImageNet-1K数据集上,Agent Attention在多个模型上均取得了显著的性能提升。例如,Agent-PVT-S在参数和计算量仅为PVT-L的30%和40%的情况下,性能超过了PVT-L。
  2. 目标检测:在COCO数据集上,Agent Attention在RetinaNet、Mask R-CNN和Cascade Mask R-CNN框架上均表现出色,显著提高了检测精度。
  3. 语义分割:在ADE20K数据集上,Agent Attention在SemanticFPN和UperNet模型上均取得了显著的性能提升。
  4. 图像生成:在Stable Diffusion中应用Agent Attention,不仅加速了图像生成过程,还显著提高了生成图像的质量,且无需额外训练。

总结

本文提出的Agent Attention是一种新颖的注意力机制,通过引入代理令牌,显著降低了计算复杂度,同时保持了强大的表达能力。Agent Attention不仅适用于多种视觉任务,还在高分辨率场景下表现出色。此外,Agent Attention还可以直接应用于预训练的大型扩散模型,如Stable Diffusion,显著加速图像生成过程并提高生成质量。由于其线性复杂度和强大的表示能力,Agent Attention为视频建模和多模态基础模型等具有超长令牌序列的挑战性任务提供了新的可能性。

Agent Attention源码与注释

python 复制代码
# 论文:Agent Attention: On the Integration of Softmax and Linear Attention
# 论文地址:https://arxiv.org/pdf/2312.08874
# 代码地址: https://github.com/LeapLabTHU/Agent-Attention

import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_


class AgentAttention(nn.Module):
    def __init__(self, dim, num_patches, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
                 sr_ratio=1, agent_num=49, **kwargs):
        super().__init__()
        # 确保维度dim可以被头数num_heads整除
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim  # 输入特征维度
        self.num_patches = num_patches  # 图像分割成的patch数量
        window_size = (int(num_patches ** 0.5), int(num_patches ** 0.5))  # 假设patch是正方形,计算窗口大小
        self.window_size = window_size
        self.num_heads = num_heads  # 注意力头数
        head_dim = dim // num_heads  # 每个头的维度
        self.scale = head_dim ** -0.5  # 缩放因子

        # 定义Q、KV的线性变换层
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        # 定义注意力分数的dropout层
        self.attn_drop = nn.Dropout(attn_drop)
        # 定义输出的线性变换层
        self.proj = nn.Linear(dim, dim)
        # 定义输出的dropout层
        self.proj_drop = nn.Dropout(proj_drop)

        # 如果空间降采样比例大于1,则定义空间降采样层
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

        self.agent_num = agent_num  # 代理token的数量
        # 深度可分离卷积
        self.dwc = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=(3, 3), padding=1, groups=dim)
        # 定义各种位置偏置参数
        self.an_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7))
        self.na_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7))
        self.ah_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, window_size[0] // sr_ratio, 1))
        self.aw_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, 1, window_size[1] // sr_ratio))
        self.ha_bias = nn.Parameter(torch.zeros(1, num_heads, window_size[0], 1, agent_num))
        self.wa_bias = nn.Parameter(torch.zeros(1, num_heads, 1, window_size[1], agent_num))
        # 初始化位置偏置参数
        trunc_normal_(self.an_bias, std=.02)
        trunc_normal_(self.na_bias, std=.02)
        trunc_normal_(self.ah_bias, std=.02)
        trunc_normal_(self.aw_bias, std=.02)
        trunc_normal_(self.ha_bias, std=.02)
        trunc_normal_(self.wa_bias, std=.02)
        pool_size = int(agent_num ** 0.5)  # 计算池化层的输出大小
        self.pool = nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size))  # 自适应平均池化层
        self.softmax = nn.Softmax(dim=-1)  # softmax层用于计算注意力权重

    def forward(self, x, H, W):
        b, n, c = x.shape  # 获取输入特征的batch size、patch数量和特征维度
        num_heads = self.num_heads  # 获取注意力头数
        head_dim = c // num_heads  # 计算每个头的维度
        q = self.q(x)  # 计算Q

        # 如果空间降采样比例大于1,则对特征进行降采样
        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(b, c, H, W)  # 调整特征形状以适应卷积层
            x_ = self.sr(x_).reshape(b, c, -1).permute(0, 2, 1)  # 降采样并调整形状
            x_ = self.norm(x_)  # 归一化
            kv = self.kv(x_).reshape(b, -1, 2, c).permute(2, 0, 1, 3)  # 计算KV并调整形状
        else:
            kv = self.kv(x).reshape(b, -1, 2, c).permute(2, 0, 1, 3)  # 计算KV并调整形状
        k, v = kv[0], kv[1]  # 分离K和V

        # 计算代理token
        agent_tokens = self.pool(q.reshape(b, H, W, c).permute(0, 3, 1, 2)).reshape(b, c, -1).permute(0, 2, 1)
        q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)  # 调整Q的形状
        k = k.reshape(b, n // self.sr_ratio ** 2, num_heads, head_dim).permute(0, 2, 1, 3)  # 调整K的形状
        v = v.reshape(b, n // self.sr_ratio ** 2, num_heads, head_dim).permute(0, 2, 1, 3)  # 调整V的形状
        agent_tokens = agent_tokens.reshape(b, self.agent_num, num_heads, head_dim).permute(0, 2, 1, 3)  # 调整代理token的形状

        kv_size = (self.window_size[0] // self.sr_ratio, self.window_size[1] // self.sr_ratio)  # 计算KV的空间大小
        # 计算位置偏置
        position_bias1 = nn.functional.interpolate(self.an_bias, size=kv_size, mode='bilinear')
        position_bias1 = position_bias1.reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1)
        position_bias2 = (self.ah_bias + self.aw_bias).reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1)
        position_bias = position_bias1 + position_bias2
        # 计算代理注意力分数并应用softmax
        agent_attn = self.softmax((agent_tokens * self.scale) @ k.transpose(-2, -1) + position_bias)
        agent_attn = self.attn_drop(agent_attn)  # 应用dropout
        agent_v = agent_attn @ v  # 计算代理注意力加权和

        # 计算代理token到Q的注意力分数
        agent_bias1 = nn.functional.interpolate(self.na_bias, size=self.window_size, mode='bilinear')
        agent_bias1 = agent_bias1.reshape(1, num_heads, self.agent_num, -1).permute(0, 1, 3, 2).repeat(b, 1, 1, 1)
        agent_bias2 = (self.ha_bias + self.wa_bias).reshape(1, num_heads, -1, self.agent_num).repeat(b, 1, 1, 1)
        agent_bias = agent_bias1 + agent_bias2
        q_attn = self.softmax((q * self.scale) @ agent_tokens.transpose(-2, -1) + agent_bias)
        q_attn = self.attn_drop(q_attn)  # 应用dropout
        x = q_attn @ agent_v  # 计算注意力加权和

        # 调整形状
        x = x.transpose(1, 2).reshape(b, n, c)
        v = v.transpose(1, 2).reshape(b, H // self.sr_ratio, W // self.sr_ratio, c).permute(0, 3, 1, 2)
        if self.sr_ratio > 1:
            v = nn.functional.interpolate(v, size=(H, W), mode='bilinear')  # 上采样
        x = x + self.dwc(v).permute(0, 2, 3, 1).reshape(b, n, c)  # 加上深度可分离卷积的结果

        x = self.proj(x)  # 线性变换
        x = self.proj_drop(x)  # 应用dropout
        return x  # 返回输出特征


if __name__ == '__main__':
    dim = 64
    num_patches = 49

    block = AgentAttention(dim=dim, num_patches=num_patches)

    H, W = 7, 7
    x = torch.rand(1, num_patches, dim)

    # 前向传播
    output = block(x, H, W)
    print(f"Input size: {x.size()}")
    print(f"Output size: {output.size()}")

好了,这篇文章就介绍到这里,喜欢的小伙伴感谢给点个赞和关注,更多精彩内容持续更新~~
关于本篇文章大家有任何建议或意见,欢迎在评论区留言交流!

相关推荐
豆芽8195 小时前
深度学习核心算法
人工智能·python·深度学习·神经网络·机器学习·计算机视觉·卷积神经网络
油泼辣子多加9 小时前
【计算机视觉】数据增强
人工智能·yolo·计算机视觉
即安莉12 小时前
OPENCV数字识别(非手写数字/采用模板匹配)
人工智能·opencv·计算机视觉
数字扫地僧13 小时前
实时图像处理:让你的应用更智能
图像处理·opencv·计算机视觉
楼台的春风19 小时前
【Harris角点检测器详解】
图像处理·人工智能·深度学习·opencv·算法·计算机视觉·嵌入式
jndingxin19 小时前
OpenCV旋转估计(3)图像拼接类cv::detail::MultiBandBlender
人工智能·opencv·计算机视觉
蹦蹦跳跳真可爱58919 小时前
Python----计算机视觉处理(Opencv:边缘填充方式)
人工智能·python·opencv·计算机视觉
FL162386312921 小时前
医学图像分割数据集肺分割数据labelme格式6299张2类别
人工智能·深度学习·计算机视觉
狮歌~资深攻城狮1 天前
OpenCV的基本用法全解析
人工智能·opencv·计算机视觉
巷9551 天前
OpenCV平滑处理:图像去噪与模糊技术详解
人工智能·opencv·计算机视觉