Transformers在计算机视觉领域中的应用【第1篇:ViT——Transformer杀入CV界之开山之作】

目录

  • [1 模型结构](#1 模型结构)
  • [2 模型的前向过程](#2 模型的前向过程)
  • [3 思考](#3 思考)
  • [4 结论](#4 结论)

论文:AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE

代码:https://github.com/google-research/vision_transformer

Huggingface:https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py

1 模型结构

在模型的设计上,作者尽可能地按照最原始的Transformer来做的,好处是可以把NLP那边已经成功的Transformer架构直接拿过来用,就不用自己去魔改模型了。而且因为Transformer已经在NLP火了这么多年了,有一些写得非常高效的实现,ViT也可以直接拿过来使用。

下图是ViT的模型总览图。给定一张图,先把这张图打成一系列patch,然后这些patch变成一个序列,每个patch会通过一个叫线性投影层的操作特到一个特征,即论文里说的patch embedding。因为自注意力是所有元素两两之间去做交互,所以它本身不存在顺序的问题,但是对于图片来说是一个整体,每个patch是有顺序的,例如图中的1、2、3, ...。如果顺序颠倒了,就不是这张图片了。类似NLP那边,作者给patch embedding加上了位置编码position embedding,这样整体的token就包含了这个图片块原本有的图像信息,又包含了这个图像块所在的位置信息。得到这样的一个个token之后,就跟NLP那边是完全一样的了,直接输入到Transformer Encoder,编码器就会反馈给我们很多的输出。

问题来了,这么多输出,我们该拿哪个输出去做最后的分类呢?借鉴BERT,BERT模型有一个extra learnable embedding的特殊字符,即叫cls的分类字符,同样的,ViT也加了这样的一个特殊字符,用星号*代替,也是有position embedding,位置信息永远是0。

因为所有的token都跟其他所有的token做交互信息,所以cls embedding能够从别的embedding里学到有用的信息,从而我们只需要根据cls的输出做最后的判断即可。最后的MLP head,实际上就是一个通用的分类头,最后用交叉熵去进行模型的训练。

这个编码器是一个标准的Transformer Encoder,具体的结果列在图的右边。对于这些patches,先输入进来做一次Layer Norm,然后再做multi-head self-attention,然后再经过Layer Norm,再做MLP。这就是一个Transformer Block,可以把它叠加L次,就得到了Transformer Encoder。

ViT的模型结构是非常简洁的,特殊之处就在于如何把图片变成一系列的token。

ViT模型大小:

2 模型的前向过程

  • 输入图片

    • 对于输入图片,X = 224×224×3,3表示RGB通道;patch大小是16×16,那么得到的token数N = (224/16)×(224/16) = 196,即一共得到196个图像块,每个图像块的维度是16×16×3=768;所以我们把原来的图片变成了196×768。
  • 线性投射层

    • 线性投射层实际上就是一个全连接层,用E来表示,维度是768×768,右边这个768是文章中的符号D,可以变大,Transformer有多大,D就可以设置多大。但是左边的768是patch的大小计算过来的,是不变的。

    • 经过线性投射层之后,就得到了patch embedding,具体来说就是X * E = [196×768]×[768×768] = 196×768。意思是指这里得到了196个token,每个token的维度是768。这里的中括号表示矩阵,X*E表示矩阵相乘。至此,已经成功把vision的问题转化成NLP的问题了。

    • 这里还需要加上一个额外的cls token,所以最终序列的长度就是整体进入Transformer Encoder中的序列,长度是(196+1)×768 = 197×768,197是指196个图片块对应的token和1个特殊字符cls token。

  • 位置编码

    • 我们还需要加上位置编码信息,我们不可能把1、2、3,... 这些字符直接给Transformer去学,具体的做法是:设置一个表,表的每一行代表这里的1、2、3,... 这些序号,每一行都是一个向量,维度和D是一样的,都是768,这个向量也是可学习的。然后我们把这个位置信息加入到token里面,注意这里是用的Add,而不是拼接concatenation。加完位置编码之后,序列还是197×768。至此,我们就做完了整个图片的预处理了。
  • Transformer Encoder

    • 编码器的输入就是一个197×768的tensor,经过Layer Norm之后还是197×768。

    • 随后经过多头注意力,变成了三份:q、k、v,因为这里做的是多头注意力,假设用的是ViT的base版本,用了12个头,那么q、k、v的维度分别都变成了197×768/12 = 197×64,最后再把12个头的输出拼接起来,又变回197×768了,所以多头注意力的最终输出结果还是197×768。

    • 然后再经过Layer Norm,还是197×768,再过一层MLP,会把维度相对应的放大,一般是放大4倍,所以会变成197×768×4 = 197×3072,然后再缩小回去,再变成197×768输出。

    • 这就是一个Transformer Block的前向传播过程,进去是197×768,出来还是197×768,序列的长度和每个token对应的维度大小都是一样的,所以可以在一个Block上不停地往上叠加Block,想加多少加多少。最后有L层Block组成Transformer Encoder。

3 思考

论文是真的把计算机视觉当做自然语言处理去做,只使用了标准的Transformer编码器,而没有使用计算机视觉中常用的卷积神经网络。

把图片分成一系列16×16的patch的原因是:self-attention需要对所有token两两进行计算,操作复杂度是n×n,如果是把每个像素点当做token去输入到编码器中,那么对于一张224×224的图片,就有50176个token,对于检测任务的输入,图片大小通常是600×600,这样的计算复杂度是非常大的,需要很大的计算资源都难以实现,所以需要把图片分成一系列小的patch。

对于中等大小的数据集,ViT的结果反而会比ResNet的结果要弱几个点,主要的原因:Transformer跟卷积神经网络相比,缺少一些CNN所独有的归纳偏置(inductive bias)。归纳偏置指的是一种先验知识,或者一种我们提前做好的假设。常说的归纳偏置有两种:

  1. 局部性(locality):CNN是以滑动窗口这种形式一点一点地在图片上进行卷积的,它假设图片上相邻的区域会有相邻的特征,比如说桌子和椅子大概率会在相邻的位置,靠的越近的东西,相关性越强。

  2. 平移等变性(translation equivariance):公式形式是f(g(x)) = g(f(x)),无论是先计算g函数还是f函数,最终的结果都是不变的,可以把f理解成卷积,g理解成平移,也就是说,无论是先做平移还是先做卷积,结果都是不变的。因为卷积核相当于一个模板template,无论这个图片同样的物体移动到哪里,只要是同样的输入进来,遇到了同样的卷积核,那他的输出永远是一样的。

一旦CNN有了这两个归纳偏置之后,就有了很多的先验信息,所以就需要相对少的数据去学一个比较好的模型。但是对于Transformer来说,没有这些先验信息,所以它的所有这些能力、对视觉世界的感知全部都需要从这些数据里自己学。

为了验证这个假设,作者在大的数据集上训练,发现大规模的预训练比归纳偏置要好。ViT只要在足够的数据去预训练的情况下,就能在下游任务上获得很好的迁移学习效果。

4 结论

这篇论文直接拿NLP领域里标准的Transformer来解决计算机视觉问题,除了在刚开始的抽图像块的时候和位置编码用了一些图像特有的归纳偏置之外,就再也没有引入任何图像特有的归纳偏置了,把图片理解成一个序列的图像块,就像一个句子里有很多单词一样,可以直接用NLP里一个标准的Transformer来做图像分类。这个简单、扩展性很好的策略,跟大规模预训练结合起来的时候,效果出奇的好。ViT在很多图像分类的benchmark上超过了之前最好的方法。

ViT还能做检测和分割这两个最主流的视觉任务,后面出来了一系列的工作,例如:ViT-FRCNN(检测)、SETR(分割)、Swin Transformer(将多尺度的设计融合到Transformer里)等等。

ViT实现了CV和NLP大一统之后,多模态任务就可以用一个Transformer去解决了,这篇论文的影响力相当巨大。

相关推荐
顾道长生'1 小时前
(NIPS-2024)PISSA:大型语言模型的主奇异值和奇异向量适配
人工智能·语言模型·自然语言处理
语音之家1 小时前
CultureLLM 与 CulturePark:增强大语言模型对多元文化的理解
人工智能·语言模型·自然语言处理
Tasfa1 小时前
【AI系列】从零开始学习大模型GPT (1)- Build a Large Language Model (From Scratch)
人工智能·gpt·学习
一个处女座的程序猿1 小时前
LLMs之o3:《Deliberative Alignment: Reasoning Enables Safer Language Models》翻译与解读
人工智能·深度学习·机器学习
静静AI学堂1 小时前
动态头部:利用注意力机制统一目标检测头部
人工智能·目标检测·计算机视觉
嵌入式小强工作室1 小时前
stm32能跑人工智能么
人工智能·stm32·嵌入式硬件
像污秽一样2 小时前
动手学深度学习-深度学习计算-1层和块
人工智能·深度学习
迪小莫学AI2 小时前
精准识别花生豆:基于EfficientNetB0的深度学习检测与分类项目
人工智能·深度学习·分类
编程迪2 小时前
自研PHP版本AI口播数字人系统源码适配支持公众号H5小程序
人工智能·数字人系统源码·口播数字人·数字人小程序·数字人开源
Anna_Tong2 小时前
人工智能的视觉天赋:一文读懂卷积神经网络
人工智能·神经网络·cnn