Vision-LSTM: xLSTM 作为通用视觉主干

摘要

尽管Transformer最初是为自然语言处理引入的,但它现在已经被广泛用作计算机视觉中的通用主干结构。最近,长短期记忆(LSTM)已被扩展为一种可扩展且性能优越的架构------xLSTM,它通过指数门控和可并行化的矩阵内存结构克服了LSTM长期以来存在的限制。在本报告中,我们介绍了Vision-LSTM(ViL),它是将xLSTM构建块应用于计算机视觉的一种适配。ViL由一系列xLSTM块组成,其中奇数块从上到下处理补丁标记的序列,而偶数块则从下到上处理。实验表明,ViL有望作为计算机视觉架构的新通用主干进一步部署。

项目页面: https://nx-ai.github.io/vision-1stm/

1、引言

语言建模架构,如Transformer [39,1,33] 或最近的状态空间模型(State Space Models) [16, 17] 如Mamba [15],通常被适应到计算机视觉领域,以利用其强大的建模能力。然而,在自然语言处理中,输入句子通常通过离散词汇表编码为表示单词或常见子词的标记(tokens)。为了将图像编码为一组标记,Vision Transformer(ViT) [13] 提出将输入图像划分为非重叠的块(例如16x16像素),将它们线性投影为所谓的块标记的序列,并向这些标记添加位置信息。然后,这个序列可以被语言建模架构处理。

Extended Long Short-Term Memory(xLSTM)家族 [4] 最近被引入为一种新的语言建模架构。它展示了LSTM在大型语言模型(LLM)时代的复兴,在性能上与Transformer和状态空间模型(SSMs)等架构相媲美。与现有的Transformers或状态空间模型的视觉版本(如ViT [13] 或Vision Mamba [44])类似,这些视觉版本在各种计算机视觉任务中取得了显著成果 [31, 22, 28, 30, 3],我们引入了Vision LSTM(ViL)------一个使用xLSTM块作为其核心组件的通用计算机视觉主干。为了调整xLSTM(一个自回归模型)以适应计算机视觉(一个通常不是自回归的领域),我们采用了一个交替堆叠的mLSTM块 [4],其中奇数块从左上角到右下角逐行处理块,偶数块从右下角到左上角处理块。这种简单的交替设计允许ViL有效地处理非序列输入,如图像,而不会引入额外的计算量。

与状态空间模型(SSM)的视觉适应类似[23, 44],Vision-LSTM(ViL)在计算和内存复杂性方面相对于序列长度展现出线性关系,这使得它对于受益于高分辨率图像的任务具有吸引力,如医学影像[8,18,38,41]、分割[22,9]或物理模拟[5,27,6,2]。相比之下,由于自注意力机制,ViT的计算复杂性呈二次方增长,这使得将它们应用于高分辨率任务成本高昂。

2、方法

Vision-LSTM(ViL)是计算机视觉任务的通用主干网络,它由xLSTM块以残差方式构建而成,如图1所示。遵循ViT[13]的做法,ViL首先通过共享的线性投影将图像分割成不重叠的块,然后向每个块标记添加可学习的位置嵌入。ViL的核心是交替的mLSTM块,这些块可以完全并行化,并配备了矩阵存储和协方差更新规则。奇数mLSTM块从左上角到右下角处理块标记,而偶数块则从右下角到左上角处理。

3、实验

我们在ImageNet-1K数据集[12]上进行了实验,该数据集包含130万张训练图像和5万张验证图像,每张图像都属于1000个类别之一。我们的比较主要集中在使用序列建模主干且参数数量大致可比较的模型上。

我们在224x224分辨率下对ViL模型进行了800个周期(tiny, tiny+)或400个周期(small, small+, base)的训练,学习率设置为1e-3,并使用余弦衰减计划。详细的超参数可以在附录5中找到。

为了与Vision Mamba(Vim)[44]进行公平的比较,我们向我们的模型添加了额外的块,以匹配tiny和small变种的参数数量(分别表示为ViL-T+和ViL-S+)。请注意, V i L L \mathrm{ViL}_{\mathrm{L}} ViLL所需的计算量显著少于 V i m \mathrm{Vim} Vim,因为 V i L \mathrm{ViL} ViL以交替的方式遍历序列,而Vim在每个块中遍历序列两次。这一点在Vim使用优化的CUDA内核时仍然成立,而目前mLSTM还没有这样的优化内核,这将会使ViL的速度更快。我们在附录A.1中比较了运行时间,其中ViL的速度比Vim快高达69%。

由于ViT在视觉社区中已经被广泛接受,它们多年来经历了多次优化迭代[13,34,36,35,37,19]。由于这项工作是首次将xLSTM应用于计算机视觉,我们并不期望在所有情况下都能超越ViT多年来的超参数调优结果。尽管如此,表1中的结果显示,在较小的规模上,ViL相对于经过大量优化的ViT协议(DeiT、DeiT-II、DeiT-III)表现出更有利的结果,只有训练时间是ViL-S两倍的DeiT-III-S稍微表现得更好。在"base"规模上,ViL超越了最初的ViT[13]模型,并达到了与DeiT[34]相当的结果。请注意,由于在这个规模上训练模型的成本很高,ViL-B的超参数远非最优。作为参考,训练ViL-B大约需要600个A100 GPU小时,或在32个A100 GPU上需要19小时。

通过在"长序列微调"设置下对模型进行微调[44],可以进一步提高性能,该设置使用729的序列长度对模型进行30个周期的微调,并通过在连续的块标记之间使用 50 % 50\% 50%的重叠来实现。

尽管没有利用卷积固有的归纳偏差,ViL也展示了与基于 C N N \mathrm{CNN} CNN的模型(如ConvNeXt[24])相竞争的性能。

块设计。我们在表2中研究了设计ViL块的不同方式。简单的单向 x L S T M \mathrm{xLSTM} xLSTM块无法达到竞争性的性能,因为xLSTM的自回归性质不适合图像分类。以双向方式遍历块,即在每个块中引入第二个mLSTM层,该层以反向方式遍历序列(类似于Vim[44]),可以提高性能,但也需要更多的参数和浮点运算(FLOPS)。共享前向和反向mLSTM的参数可以使模型更加参数高效,但仍然需要更多的计算量,并且这些参数的过载会导致性能下降。使用交替块可以在保持计算和参数效率的同时提高性能。我们还探索了四向设计(类似于[23]),它指的是按行(两个方向)和按列(两个方向)遍历序列。双向遍历仅按行(两个方向)遍历序列。图2展示了不同的遍历路径。

由于双向和四向块的成本增加,这项研究是在一个严重减少的设置中进行的。我们在ImageNet-1K的一个子集上进行训练,该子集仅包含100个类别的样本,在 128 × 128 128 \times 128 128×128分辨率下训练400个周期。这特别必要,因为我们的四向实现与torch.compile(PyTorch[29]的一个通用速度优化方法)不兼容,这导致更长的运行时间,如表2最后一列所示。由于这一技术限制,我们选择交替双向块作为我们的核心设计。

3.1 分类设计

为了使用ViT执行分类,需要将令牌序列池化成一个单独的令牌,然后将其作为分类头部的输入。最常见的池化方法是(i)在序列的开始处添加一个可学习的[CLS]令牌,或者(ii)将所有块令牌平均以产生一个[AVG]令牌。使用[CLS]或[AVG]令牌通常是一个超参数,这两种变体都能实现可比较的性能。相反,自回归模型通常需要专门的分类设计。例如,Vim[44]要求[CLS]令牌位于序列的中间,如果采用其他分类设计(例如,在序列的开始和结束处使用[AVG]令牌或两个[CLS]令牌),则会遭受严重的性能损失。由于其自回归性质,我们在表3中探索了ViL的不同分类设计。[AVG]表示所有块令牌的平均值,"Middle Patch"使用中间的块令牌,"Middle [CLS]"在序列中间使用一个[CLS]令牌,"Bilateral [AVG]"使用第一个和最后一个块令牌的平均值。我们发现ViL对分类设计相对稳健,所有性能都在0.6%以内。我们选择"Bilateral [AVG]"而不是"Middle [CLS]",因为ImageNet-1K已知存在中心偏差,即物体通常位于图片的中间。通过使用"Bilateral [AVG]",我们避免了利用这种偏差,使我们的模型更加通用。

为了与以前使用单个令牌作为分类头输入的架构保持可比性,我们取第一个和最后一个块令牌的平均值。为了获得最佳性能,我们建议将这两个令牌连接起来("Bilateral Concat")而不是取平均值。这与自监督视觉变换器(如DINOv2[28])中的常见做法相似,这些变换器使用两个分别附加在[CLS]和[AVG]令牌上的目标进行训练,因此从连接[CLS]和[AVG]令牌的表示中受益。这一方向也在视觉SSM模型[40]中得到了探索,其中多个[CLS]令牌分布在序列中,然后用作分类器的输入。类似的方向也可能提高ViL的性能。

4、结论

受到xLSTM在语言建模中成功的启发,我们介绍了ViL,这是一种将xLSTM架构适应于视觉任务的架构。ViL以交替的方式处理一系列块令牌。奇数块从左上角到右下角逐行处理图像块,偶数块则从右下角到左上角处理。我们的新架构在ImageNet-1K分类任务上优于基于SSM的视觉架构,也优于经过优化的ViT模型。值得注意的是,ViL在公平比较中能够超越ViT训练流程,这些流程是多年超参数调整和变换器改进的结果。

在未来,我们看到在高分辨率图像需要最佳性能时应用ViL的潜力,如语义分割或医学成像。在这些设置中,由于自注意力的二次复杂性,变换器会受到高计算成本的困扰,而ViL由于其线性复杂性则不存在这个问题。此外,改进预训练方案(例如通过自监督学习)、探索更好的超参数设置或借鉴变换器的技术(例如LayerScale[35])是ViL的有希望的发展方向。

致谢

我们感谢欧洲高性能计算联合企业(EuroHPC Joint Undertaking)授予我们访问捷克IT4Innovations的Karolina、卢森堡LuxProvide的MeluXina、意大利CINECA的Leonardo以及芬兰CSC的LUMI的权限。

埃利斯林茨单位、LIT人工智能实验室和机器学习研究所得到了上奥地利州联邦政府的支持。我们感谢以下项目的支持:医学认知计算中心(MC3)、INCONTROL-RL(FFG-881064)、PRIMAL(FFG-873979)、S3AI(FFG-872172)、DL for GranularFlow(FFG-871302)、EPILEPSIA(FFG-892171)、AIRI FG 9-N(FWF-36284、FWF36235)、AI4GreenHeatingGrids(FFG-899943)、INTEGRATE(FFG-892418)、ELISE(H2020-ICT-2019-3 ID:951847)、Stars4Waters(HORIZON-CL6-2021-CLIMATE-01-01)。我们感谢Audi.JKU深度学习中心、TGW LOGISTICS GROUP GMBH、Silicon Austria Labs (SAL)、FILL Gesellschaft mbH、Anyline GmbH、Google、ZF Friedrichshafen AG、Robert Bosch GmbH、UCB Biopharma SRL、Merck Healthcare KGaA、Verbund AG、GLS(滑铁卢大学)、Software Competence Center Hagenberg GmbH、 Borealis AG、TÜV Austria、Frauscher Sensonic、TRUMPF和NVIDIA公司。

A. 扩展结果

A.1 ViL与Vim的运行时间比较

在表4中,我们比较了ViL和Vim[44]在ImageNet-1K数据集上训练单个epoch所需的时间。我们遵循ViTs的缩放程序,使用192(T)、384(S)、768(B)、1024(L)作为潜在维度,其中大规模模型还会将块的数量加倍。

A.2 更长训练时间的影响

我们研究了ViL-T+进行更长时间训练的效果,并按照表5中的超参数设置训练了400和800个epoch。经过400个epoch后,ViL-T+达到了77.2%的准确率,而当训练时长翻倍后,模型达到了78.1%的准确率。

我们使用表6中的超参数设置,对DeiT-III-T进行了实现,它在400个epoch后达到了75.6%的准确率,并在800个epoch后达到了76.2%的准确率。

B. 实现细节

B.1 硬件

我们在混合的自定义硬件服务器(主要是A100和A40 GPU)和配备8个A100或4个A100节点的公共研究集群上训练模型。

我们估计本项目使用的A100 GPU小时数总计为25,000小时。这一估计包括从初步探索、方法开发、分析到评估的所有工作。

B. 2 FLOPS 计算

我们使用fvcore库来计算表2中的浮点运算次数(FLOPS)^[1]^。由于fvcore并不支持ViL的所有操作,因此计算出的FLOPS并非100%准确,但它仍然是一个很好的参考点,用于比较不同块设计之间的相对计算量。

B.3 ViL 超参数

B.4、DeiT-III 重新实现的超参数

相关推荐
Elastic 中国社区官方博客1 小时前
使用 Elastic AI Assistant for Search 和 Azure OpenAI 实现从 0 到 60 的转变
大数据·人工智能·elasticsearch·microsoft·搜索引擎·ai·azure
江_小_白2 小时前
自动驾驶之激光雷达
人工智能·机器学习·自动驾驶
yusaisai大鱼4 小时前
TensorFlow如何调用GPU?
人工智能·tensorflow
珠海新立电子科技有限公司6 小时前
FPC柔性线路板与智能生活的融合
人工智能·生活·制造
IT古董6 小时前
【机器学习】机器学习中用到的高等数学知识-8. 图论 (Graph Theory)
人工智能·机器学习·图论
曼城周杰伦6 小时前
自然语言处理:第六十三章 阿里Qwen2 & 2.5系列
人工智能·阿里云·语言模型·自然语言处理·chatgpt·nlp·gpt-3
余炜yw7 小时前
【LSTM实战】跨越千年,赋诗成文:用LSTM重现唐诗的韵律与情感
人工智能·rnn·深度学习
莫叫石榴姐7 小时前
数据科学与SQL:组距分组分析 | 区间分布问题
大数据·人工智能·sql·深度学习·算法·机器学习·数据挖掘
如若1238 小时前
利用 `OpenCV` 和 `Matplotlib` 库进行图像读取、颜色空间转换、掩膜创建、颜色替换
人工智能·opencv·matplotlib
YRr YRr8 小时前
深度学习:神经网络中的损失函数的使用
人工智能·深度学习·神经网络