【论文阅读及代码实现】BiFormer: 具有双水平路由注意的视觉变压器

【论文阅读及代码实现】BiFormer: 具有双水平路由注意的视觉变压器

文章目录

BiFormer: Vision Transformer with Bi-Level Routing Attention

视觉转换器的核心组成部分,注意力是捕捉长期依赖关系的有力工具

计算跨所有空间位置的成对token交互时,计算负担和沉重的内存占用

提出了一种新的动态稀疏注意,通过双层路由实现更灵活的内容感知计算分配

过程:

  • 首先在粗区域级别过滤掉不相关的键值对
  • 然后在剩余候选区域(即路由区域)的联合中应用细粒度的Token到Token
  • 利用稀疏性来节省计算和内存,同时只涉及GPU-friendly的密集矩阵乘法

提出了一种新的通用视觉变压器,称为BiF变压器

一、总体介绍

Transformer有许多适合于构建强大的数据驱动模型的属性

捕获数据中的远程依赖关系

卷积本质上是一个局部算子,与之相反,注意力的一个关键属性是全局接受场,它使视觉转换器能够捕获远程依赖

稀疏关注引入到视觉转换,可以减少相应的计算量

不同语义区域的查询实际上关注的键值对是完全不同的。因此,强制所有查询处理同一组令牌可能不是最优的

需要评估所有查询和键之间的配对亲和力,因此具有相同的vanilla attention复杂性。另一种可能性是基于每个查询的本地上下文来预测注意力偏移量

高效地定位有价值的键值

提出了一种区域到区域路由,核心思想是在粗粒度的区域级别过滤掉最不相关的键值

不是直接在细粒度的令牌级别

应用Token到令Token的注意,这是非常重要的,因为现在假定键值(Q,K,V)对在空间上是分散的

使用BRA作为核心构建块,我们提出了BiFormer,这是一个通用的视觉变压器骨干

BRA使BiFormer能够以内容感知的方式为每个查询处理最相关的键/值Token的一小部分,因此我们的模型实现了更好的计算性能权衡

具体作用:

  • 引入了一种新的双层路由机制,自适应查询的方式实现内容感知的稀疏模式
  • 双级路由关注作为基本构建块
  • 更好的性能和更低的计算量

二、联系工作

Vision transformers

采用基于通道的MLP块进行错位嵌入(通道混合),并采用注意力块进行交叉位置关系建,transformers使用注意力作为卷积的替代方案来实现全局上下文建模

vanilla attention在所有空间位置上两两计算特征亲和性,它会带来很高的计算负担 和沉重的内存占用

Efficient attention mechanisms

稀疏连接模式[6],低秩近似[43]或循环操作[11]来减少vanilla attention的计算和内存复杂性瓶颈,Swin变压器中,将注意力限制在不重叠的局部窗口上,并引入移位窗口操作来实现相邻窗口之间的窗口间通信

手工制作的稀疏模式:

  • 膨胀窗口[41,46]
  • 十字形窗口[14]

不同查询的关注区域可能会有显著差异

双层路由注意的目标是定位几个最相关的键值对,而四叉树注意构建了一个到ken金字塔,并组装来自不同粒度的所有级别的消息

三、具体模型

3.1 注意力

注意力的具体表示:

Q∈RNq×C,键K∈RNkv×C,值V∈RNkv×C作为输入

避免权值集中和梯度消失,引入标量因子√C

基础的构建块是多头自关注(MHSA)

3.2 双级路由注意(BRA)

为了缓解MHSA的可扩展性问题,一些研究[14,29,41,46,48]提出了不同的稀疏关注机制,其中每个查询只关注少量的键值对

探索了一种动态的、查询感知的稀疏注意机制。

整体结构图:

具体操作思想:

  • 在粗区域级别过滤掉大多数不相关的键值对
  • 只保留一小部分路由区域
  • 路由区域的联合中应用细粒度的令牌到令牌关

Region partition and input projection.

特征图X∈RH×W×C

分为S×S个不重叠的区域,使得每个区域包含H×W×S2特征向量

将其转化为

同时将导出查询,键,值张量,Q, K, V∈R s2xHW/S2×C,具有线性投影

Region-to-region routing with directed graph

构造一个有向图来找到参与关系,每个给定区域应该参与的区域

对Q和K应用每个区域的平均值来推导区域级查询和键Qr, Kr∈RS2×C

Qr与转置的Kr之间的矩阵乘法推导出区域到区域亲和图的邻接矩阵

Ar中的条目度量两个区域在语义上的关联程度

步骤是通过仅为每个区域保留top-k连接来修剪关联图

Ir的第i行包含第i区最相关区域的k个指标

区域到区域路由索引矩阵Ir,我们就可以应用细粒度的Token到令Token的注意关注。对于区域i中的每个查询令牌

收集键和值张量

函数LCE(·)使用深度卷积参数化,我们将内核大小设置为5

BRA的计算包括三个部分:

  • 线性投影

  • 区域到区域路由

  • token到token注意

3.4. BiFormer的结构设计

BRA为基本构建块,提出了一种新的通用视觉变压器BiFormer

具体结构:

  1. 第一阶段使用重叠的patch嵌入
  2. 第二到第四阶段使用patch合并模块
  3. 使用Ni连续的BiFormer块来变换特征

将每个注意头设置为32个通道,MLP扩展比e=3。对于BRA,由于输入分辨率不同,我们对4个阶段使用topk = 1,4,16, S2

分类/语义分割/目标检测任务,区域划分因子S = 7/8/16

四、论文实验结果

同样只看在ADE20K,语义分割上的实验效果,与其他的效果来进行对比

基于MMSegmentation[8]在ADE20K[55]数据集上进行了语义分割实验。

采用框架对比:

  • 语义FPN
  • UperNet

主干都使用ImageNet-1K预训练的权重进行初始化,而其他层则使用随机初始化,使用AdamW优化器对模型进行优化,批量大小设置为32

Swin Transformer相同的设置

五、代码理解

从官方代码中给出的代码中我们选取biformer_base来对相应的

通过相应参数,我们可以得知,在构建模型中的数据

由于我下游任务是语义分割,topks的最后一项参数是S=8,s2是64

这里是具体的BRA模块的构成参数导入,由4个阶段的不同来分配不同的参数,因为s=-1改为了s=64,在4个阶段的Attention都为BiLevelRoutingAttention

在代码中的具体使用

原官方代码中有很多if,else的判断选择,但是最后执行的代码为这一段

Biformer的具体函数在

可以看到具体的函数操作

六、遥感实验结果

2023.5.21 resnet50 Vaihingen 256*256(叠切) 3225 100 0.01 SGD OA=83.47% Miou=67.75% F1=80.53% resnet50+BiFormer*4+IRFFN(depth=[3,4,6,3] num_heads=[2, 4, 8, 16],dilation=[1, 2])
2023.5.21 resnet50 Vaihingen 256*256(叠切) 3225 100 0.01 SGD 82.93% 67.61% 80.35% resnet50+BiFormer2+MSDA2+IRFFN(depth=[3,4,6,3] num_heads=[2, 4, 8, 16],dilation=[1, 2])
2023.5.21 resnet50 Vaihingen 256*256(叠切) 3225 100 0.01 SGD 83.24% 67.74% 80.44% resnet50+BiFormer+MSDA+BiFormer+MSDA+IRFFN(depth=[3,4,6,3] num_heads=[2, 4, 8, 16],dilation=[1, 2])

主干网络:resnet50

解码器:Unet的融合解码

初步结论:具有一定提高的效果,但作为轻量级的网络,在实际的使用上效果一般

相关推荐
NAGNIP8 小时前
一文搞懂深度学习中的通用逼近定理!
人工智能·算法·面试
冬奇Lab9 小时前
一天一个开源项目(第36篇):EverMemOS - 跨 LLM 与平台的长时记忆 OS,让 Agent 会记忆更会推理
人工智能·开源·资讯
冬奇Lab9 小时前
OpenClaw 源码深度解析(一):Gateway——为什么需要一个"中枢"
人工智能·开源·源码阅读
AngelPP13 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年13 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼13 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS13 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区14 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈15 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang15 小时前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx