Transformers在计算机视觉领域中的应用【第1篇:ViT——Transformer杀入CV界之开山之作】

目录

  • [1 模型结构](#1 模型结构)
  • [2 模型的前向过程](#2 模型的前向过程)
  • [3 思考](#3 思考)
  • [4 结论](#4 结论)

论文:AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE

代码:https://github.com/google-research/vision_transformer

Huggingface:https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py

1 模型结构

在模型的设计上,作者尽可能地按照最原始的Transformer来做的,好处是可以把NLP那边已经成功的Transformer架构直接拿过来用,就不用自己去魔改模型了。而且因为Transformer已经在NLP火了这么多年了,有一些写得非常高效的实现,ViT也可以直接拿过来使用。

下图是ViT的模型总览图。给定一张图,先把这张图打成一系列patch,然后这些patch变成一个序列,每个patch会通过一个叫线性投影层的操作特到一个特征,即论文里说的patch embedding。因为自注意力是所有元素两两之间去做交互,所以它本身不存在顺序的问题,但是对于图片来说是一个整体,每个patch是有顺序的,例如图中的1、2、3, ...。如果顺序颠倒了,就不是这张图片了。类似NLP那边,作者给patch embedding加上了位置编码position embedding,这样整体的token就包含了这个图片块原本有的图像信息,又包含了这个图像块所在的位置信息。得到这样的一个个token之后,就跟NLP那边是完全一样的了,直接输入到Transformer Encoder,编码器就会反馈给我们很多的输出。

问题来了,这么多输出,我们该拿哪个输出去做最后的分类呢?借鉴BERT,BERT模型有一个extra learnable embedding的特殊字符,即叫cls的分类字符,同样的,ViT也加了这样的一个特殊字符,用星号*代替,也是有position embedding,位置信息永远是0。

因为所有的token都跟其他所有的token做交互信息,所以cls embedding能够从别的embedding里学到有用的信息,从而我们只需要根据cls的输出做最后的判断即可。最后的MLP head,实际上就是一个通用的分类头,最后用交叉熵去进行模型的训练。

这个编码器是一个标准的Transformer Encoder,具体的结果列在图的右边。对于这些patches,先输入进来做一次Layer Norm,然后再做multi-head self-attention,然后再经过Layer Norm,再做MLP。这就是一个Transformer Block,可以把它叠加L次,就得到了Transformer Encoder。

ViT的模型结构是非常简洁的,特殊之处就在于如何把图片变成一系列的token。

ViT模型大小:

2 模型的前向过程

  • 输入图片

    • 对于输入图片,X = 224×224×3,3表示RGB通道;patch大小是16×16,那么得到的token数N = (224/16)×(224/16) = 196,即一共得到196个图像块,每个图像块的维度是16×16×3=768;所以我们把原来的图片变成了196×768。
  • 线性投射层

    • 线性投射层实际上就是一个全连接层,用E来表示,维度是768×768,右边这个768是文章中的符号D,可以变大,Transformer有多大,D就可以设置多大。但是左边的768是patch的大小计算过来的,是不变的。

    • 经过线性投射层之后,就得到了patch embedding,具体来说就是X * E = [196×768]×[768×768] = 196×768。意思是指这里得到了196个token,每个token的维度是768。这里的中括号表示矩阵,X*E表示矩阵相乘。至此,已经成功把vision的问题转化成NLP的问题了。

    • 这里还需要加上一个额外的cls token,所以最终序列的长度就是整体进入Transformer Encoder中的序列,长度是(196+1)×768 = 197×768,197是指196个图片块对应的token和1个特殊字符cls token。

  • 位置编码

    • 我们还需要加上位置编码信息,我们不可能把1、2、3,... 这些字符直接给Transformer去学,具体的做法是:设置一个表,表的每一行代表这里的1、2、3,... 这些序号,每一行都是一个向量,维度和D是一样的,都是768,这个向量也是可学习的。然后我们把这个位置信息加入到token里面,注意这里是用的Add,而不是拼接concatenation。加完位置编码之后,序列还是197×768。至此,我们就做完了整个图片的预处理了。
  • Transformer Encoder

    • 编码器的输入就是一个197×768的tensor,经过Layer Norm之后还是197×768。

    • 随后经过多头注意力,变成了三份:q、k、v,因为这里做的是多头注意力,假设用的是ViT的base版本,用了12个头,那么q、k、v的维度分别都变成了197×768/12 = 197×64,最后再把12个头的输出拼接起来,又变回197×768了,所以多头注意力的最终输出结果还是197×768。

    • 然后再经过Layer Norm,还是197×768,再过一层MLP,会把维度相对应的放大,一般是放大4倍,所以会变成197×768×4 = 197×3072,然后再缩小回去,再变成197×768输出。

    • 这就是一个Transformer Block的前向传播过程,进去是197×768,出来还是197×768,序列的长度和每个token对应的维度大小都是一样的,所以可以在一个Block上不停地往上叠加Block,想加多少加多少。最后有L层Block组成Transformer Encoder。

3 思考

论文是真的把计算机视觉当做自然语言处理去做,只使用了标准的Transformer编码器,而没有使用计算机视觉中常用的卷积神经网络。

把图片分成一系列16×16的patch的原因是:self-attention需要对所有token两两进行计算,操作复杂度是n×n,如果是把每个像素点当做token去输入到编码器中,那么对于一张224×224的图片,就有50176个token,对于检测任务的输入,图片大小通常是600×600,这样的计算复杂度是非常大的,需要很大的计算资源都难以实现,所以需要把图片分成一系列小的patch。

对于中等大小的数据集,ViT的结果反而会比ResNet的结果要弱几个点,主要的原因:Transformer跟卷积神经网络相比,缺少一些CNN所独有的归纳偏置(inductive bias)。归纳偏置指的是一种先验知识,或者一种我们提前做好的假设。常说的归纳偏置有两种:

  1. 局部性(locality):CNN是以滑动窗口这种形式一点一点地在图片上进行卷积的,它假设图片上相邻的区域会有相邻的特征,比如说桌子和椅子大概率会在相邻的位置,靠的越近的东西,相关性越强。

  2. 平移等变性(translation equivariance):公式形式是f(g(x)) = g(f(x)),无论是先计算g函数还是f函数,最终的结果都是不变的,可以把f理解成卷积,g理解成平移,也就是说,无论是先做平移还是先做卷积,结果都是不变的。因为卷积核相当于一个模板template,无论这个图片同样的物体移动到哪里,只要是同样的输入进来,遇到了同样的卷积核,那他的输出永远是一样的。

一旦CNN有了这两个归纳偏置之后,就有了很多的先验信息,所以就需要相对少的数据去学一个比较好的模型。但是对于Transformer来说,没有这些先验信息,所以它的所有这些能力、对视觉世界的感知全部都需要从这些数据里自己学。

为了验证这个假设,作者在大的数据集上训练,发现大规模的预训练比归纳偏置要好。ViT只要在足够的数据去预训练的情况下,就能在下游任务上获得很好的迁移学习效果。

4 结论

这篇论文直接拿NLP领域里标准的Transformer来解决计算机视觉问题,除了在刚开始的抽图像块的时候和位置编码用了一些图像特有的归纳偏置之外,就再也没有引入任何图像特有的归纳偏置了,把图片理解成一个序列的图像块,就像一个句子里有很多单词一样,可以直接用NLP里一个标准的Transformer来做图像分类。这个简单、扩展性很好的策略,跟大规模预训练结合起来的时候,效果出奇的好。ViT在很多图像分类的benchmark上超过了之前最好的方法。

ViT还能做检测和分割这两个最主流的视觉任务,后面出来了一系列的工作,例如:ViT-FRCNN(检测)、SETR(分割)、Swin Transformer(将多尺度的设计融合到Transformer里)等等。

ViT实现了CV和NLP大一统之后,多模态任务就可以用一个Transformer去解决了,这篇论文的影响力相当巨大。

相关推荐
catchadmin2 分钟前
Laravel AI SDK 正式发布
人工智能·php·laravel
哈__2 分钟前
CANN优化GAN生成对抗网络推理:判别器加速与生成质量平衡
人工智能·神经网络·生成对抗网络
方见华Richard3 分钟前
世毫九实验室技术优势拆解与对比分析(2026)
人工智能·交互·学习方法·原型模式·空间计算
梵得儿SHI3 分钟前
(第十篇)Spring AI 核心技术攻坚全梳理:企业级能力矩阵 + 四大技术栈攻坚 + 性能优化 Checklist + 实战项目预告
java·人工智能·spring·rag·企业级ai应用·springai技术体系·多模态和安全防护
chian-ocean5 分钟前
深入 CANN 生态:使用 `modelzoo-samples` 快速部署视觉模型
人工智能
勾股导航5 分钟前
Windows安装GPU环境
人工智能·windows·gnu
小羊不会打字9 分钟前
探索 CANN 生态:深入解析 `ops-transformer` 项目
人工智能·深度学习·transformer
哈__10 分钟前
CANN加速多模态融合推理:跨模态对齐与特征交互优化
人工智能·交互
红迅低代码平台(redxun)11 分钟前
构建企业“第二大脑“:AI低代码平台如何打造智能知识中枢?
人工智能·低代码·ai agent·ai开发平台·智能体开发平台·红迅软件
Loo国昌12 分钟前
【大模型应用开发】第六阶段:模型安全与可解释性
人工智能·深度学习·安全·transformer