【ViT】对图片进行分类

✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨

🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。

我是Srlua小谢,在这里我会分享我的知识和经验。🎥

希望在这里,我们能一起探索IT世界的奥妙,提升我们的技能。🔮

记得先点赞👍后阅读哦~ 👏👏

📘📚 所属专栏:人工智能话题分享

欢迎访问我的主页:Srlua小谢 获取更多信息和资源。✨✨🌙🌙

​​

​​

目录

概述

模型结构

模型总体框架

Patch_embed

[Transformer Encoder](#Transformer Encoder)

[MLP Head](#MLP Head)

演示效果

核心逻辑

部署方式

参考文献


本文所有资源均可在该地址处获取。

概述

Transformer架构虽然已经成为自然语言处理任务的标准,但是它在计算机视觉的应用仍然有限,先前的视觉任务中,注意力大多与卷积结合使用。ViT模型的出现,证明了对CNN的依赖是不必要的,直接应用于图像补丁序列的纯Transformer架构可以在图像分类任务中表现良好。


模型结构

模型总体框架

上述是ViT模型的基本框架,可以大致分为三个主要部分

  • Patch_embed(将图片分成一系列的patches)
  • Transformer Encoder(建模不同序列之间的相关性)
  • MLP Head(用于最终的分类结构)

Patch_embed

在标准的Transformer模块中,输入的格式为二维矩阵 [num_token,token_dim] ,但对于图像数据而言,其输入数据的格式为[H,W,C] 的三维矩阵,明显不是Transformer架构需要的。所以需要Patch_embed结构将其转换为Transformer架构的输入。

针对于ViT-B/16而言,将输入图片(224x224)按照大小为(16x16) 的Patch进行划分,生成196个Patch。此时通过线性映射将每个Patch映射到一个长度为768 (16x16x3) 一维向量中。这一步可以通过卷积核大小为16x16,步距为 16 的卷积来实现。最后将长宽进行展平,则得到Transformer需要的输入格式。具体的维度变换如下所示:

224,224,3\] -\> \[14,14,768\] -\> \[196,768

在输入到Transformer Encoder之前还需要加上 [class]token(z00=xclassz00​=xclass​),它在Transformer 编码器 zL0zL0​ 输出处的状态用作图像表示 yy , 在预训练和微调过程中,zL0zL0​ 处都具有一个分类头。

同时需要将Position Embeddin[197,768]叠加(add)到上述的token上.

如上图所示,第一行第一列的位置编码上与其自身的余弦相似度最高,其次是与第一行和第一列的余弦相似度更高,这符合常理。

Transformer Encoder

Transformer Encoder 本身是堆叠Encoder Block L 次,ViT-B/16是12次。主要有以下几部分组成:

  • Layer Norm: 针对NLP领域提出,因为在RNN这类时序网络中,时序的长度并不一定是一个定值,Layer Norm在每个样本的每个特征维度上进行归一化,使得每个特征的均值为0,方差为1,从而有助于提高模型的训练效果和泛化能力。
  • Multi-head Attention: 使用多头注意力机制能够联合来自不同head部分学习到的信息。
  • MLP Block:由全连接+GELU激活函数+Dropout组成,在ViT-B/16的模型结构中,第一个全连接层将输入节点的个数翻4倍,第二个全连接层键还原节点的个数。

MLP Head

通过Transfomer Encoder后输入的shape和输出的shape保持不变,由于我们只需要分类信息,因此只需要提取[class]token 的结果 zL0zL0​ ,之后通过MLP Head得到最后的分类结果。

模型的公式如下,其中E表示token的个数

z0=[xclass;xp1E;xp2E;⋯ ;xpNE]+EposE∈R(P2⋅C)×D,Epos∈R(N+1)×Dzℓ′=MSA(LN(zℓ−1))+zℓ−1,ℓ=1...Lzℓ=MLP(LN(z′ℓ))+z′ℓ,ℓ=1...Ly=LN(zL0)​z0​zℓ′​zℓ​y​​=[xclass​;xp1​E;xp2​E;⋯;xpN​E]+Epos​=MSA(LN(zℓ−1​))+zℓ−1​,=MLP(LN(z′ℓ​))+z′ℓ​,=LN(zL0​)​​​​E∈R(P2⋅C)×D,Epos​∈R(N+1)×Dℓ=1...Lℓ=1...L​


演示效果

可视化输入图片的形式

可视化模型运行结果

核心逻辑

对输入图片进行分块处理

复制代码
class PatchEmbed(nn.Module):
    def __init__(self,img_size=224,patch_size=16,in_c=3,embed_dim=768,norm_layer=None):
        super(PatchEmbed,self).__init__()
        img_size = (img_size,img_size)
        patch_size = (patch_size,patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size[0]//patch_size[0])*(img_size[1]//patch_size[1])
        self.proj = nn.Conv2d(in_c,embed_dim,kernel_size=patch_size,
                               stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim) if norm_layer else nn.Identity()
    
    def forward(self,x):
        # 首先需要判断输入图片的大小符合我们的预期
        B,C,H,W=x.shape
        assert H==self.img_size[0] and W==self.img_size[1],\
            f"input image{H}x{W} does not model {self.img_size[0]}x{self.img_size[1]}"
        # [N,in_c,H,W]->[N,embed_dim,H//16,W//16]->[N,embed_dim,H//16*W//16]
        x = self.proj(x).flatten(2).transpose(1,2)
        
        x = self.norm(x)
        
        return x

多头注意力机制

复制代码
class Attention(nn.Module):
    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0,
                 proj_drop_ratio=0):
        super(Attention,self).__init__()
        self.head_dim = dim//num_heads
        self.num_heads = num_heads
        self.dim = dim
        self.scale = qk_scale or self.head_dim**(0.5)
        
        
        self.qkv = nn.Linear(dim,dim*3,bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim,dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)
        
    def forward(self,x):
        # [batch_size,num_patches+class_token,channel:HxW]
        B,N,C = x.shape
        
        # 将其进行投影,也就是多头自注意力机制所说的矩阵相乘
        # reshape [B,N,C]->[B,N,3,heads,head_dim]->[3,B,heads,N,head_dim]
        qkv = self.qkv(x).reshape(B,N,3,self.num_heads,self.head_dim).permute(2,0,3,1,4)
        # [B,heads,N,head_dim]
        q,k,v=qkv[0],qkv[1],qkv[2]
        # [B,heads,N,N]
        attn = (q@k.transpose(-2,-1))*self.scale
        
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # [B,heads,N,head_dim]->[B,N,heads,head_dim]->[B,N,heads*head_dim]
        # x = attn@v.permute(0,2,1,3).flatten(2)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        
        x = self.proj_drop(self.proj(x))
        return x

MLP 模块

复制代码
class MLP(nn.Module):
    def __init__(self,in_features,hidden_features=None,out_features=None,act_layer=nn.GELU,drop=0.):
        super(MLP,self).__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features,hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features,out_features)
        self.drop = nn.Dropout(drop)
    
    # 根据流程图确定其中的结构,注意是先激活函数之后才是dropout操作   
    def forward(self,x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

Block 结构

复制代码
class Block(nn.Module):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 drop_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm
                 ):
        super(Block,self).__init__()
        
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim=dim,num_heads=num_heads,qkv_bias=qkv_bias,qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio,proj_drop_ratio=drop_ratio)
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        # Linear都需要是int的类型数据
        self.mlp = MLP(dim,int(dim*mlp_ratio),dim,act_layer,drop_ratio)
        
        self.norm2 = norm_layer(dim)
        
    def forward(self,x):
        x = x+self.drop_path(self.attn(self.norm1(x)))
        x = x+self.drop_path(self.mlp(self.norm2(x)))
        
        return x

ViT 模块

复制代码
class VisionTransformer(nn.Module):
    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_c=3,
                 num_classes=1000,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.0,
                 qkv_bias=True,
                 qk_scale=None,
                 representation_size=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 embed_layer=PatchEmbed,
                 norm_layer=None,
                 act_layer=None,
                 ):
        super(VisionTransformer,self).__init__()
        # 首先需要进行初始化操作,还可以对权重进行初始化操作
        self.num_classes = num_classes
        self.embed_dim = self.num_features = embed_dim
        self.num_tokens = 1 
        act_layer = act_layer or nn.GELU
        norm_layer = norm_layer or partial(nn.LayerNorm,eps=1e-6)
        
        self.patch_embed = embed_layer(img_size=img_size,patch_size=patch_size,in_c=in_c,embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        
        # cls_token是针对于每个embed_dim确定一个class
        # pos_embed除了channel 还要针对于每一个patch确定结果
        self.cls_token = nn.Parameter(torch.zeros(1,1,embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1,num_patches+self.num_tokens,embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)
        
        dpr = [x.item() for x in torch.linspace(0,drop_path_ratio,depth)]
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim,num_heads=num_heads,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias,qk_scale=qk_scale,
                  drop_ratio=drop_ratio,attn_drop_ratio=attn_drop_ratio,drop_path_ratio=dpr[i],
                  norm_layer=norm_layer,act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)
        
        # Pre_logits layer 相当于多添加了一个全连接层
        if representation_size:
            self.has_logits=True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc',nn.Linear(embed_dim,representation_size))
                ('out',nn.Tanh())]
            ))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()
            
        self.head = nn.Linear(self.num_features,self.num_classes) if num_classes>0 else nn.Identity()
        
        # 开始对所有的权重进行初始化操作
        nn.init.trunc_normal_(self.pos_embed,std=0.02)
        nn.init.trunc_normal_(self.cls_token,std=0.02)
        self.apply(_init_vit_weights)
        
    def forward(self,x):
        B,C,H,W = x.shape
        #[B,C,H,W]->[B,N,H*W]
        x = self.patch_embed(x)
        # 每次都需要进行操作,所以不能对其本身进行expand操作
        cls_token = self.cls_token.expand(B,-1,-1)
        
        # 注意到后续一个是cat操作一个是add操作,且位置的先后关系
        x = torch.cat((cls_token,x),dim=1)
        
        # self.pos_embed中针对于一个batch值共享
        x = self.pos_drop(x+self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        
        
        x = self.pre_logits(x[:,0])
        x = self.head(x)
        return x  

部署方式

python 3.7.16

  • torch == 1.13.1
  • torchvision == 0.14.1
  • tqdm == 4.66.2
  • pillow == 9.5.0
  • matplotlib == 3.5.3

参考文献

论文下载地址
源码参考地址
参考博客地址

​​

相关推荐
萧鼎10 分钟前
深度探索 Py2neo:用 Python 玩转图数据库 Neo4j
数据库·python·neo4j
华子w90892585926 分钟前
基于 Python Django 和 Spark 的电力能耗数据分析系统设计与实现7000字论文实现
python·spark·django
风铃喵游38 分钟前
让大模型调用MCP服务变得超级简单
前端·人工智能
Rockson1 小时前
使用Ruby接入实时行情API教程
javascript·python
booooooty1 小时前
基于Spring AI Alibaba的多智能体RAG应用
java·人工智能·spring·多智能体·rag·spring ai·ai alibaba
PyAIExplorer1 小时前
基于 OpenCV 的图像 ROI 切割实现
人工智能·opencv·计算机视觉
风口猪炒股指标1 小时前
技术分析、超短线打板模式与情绪周期理论,在市场共识的形成、分歧、瓦解过程中缘起性空的理解
人工智能·博弈论·群体博弈·人生哲学·自我引导觉醒
ai_xiaogui2 小时前
一键部署AI工具!用AIStarter快速安装ComfyUI与Stable Diffusion
人工智能·stable diffusion·部署ai工具·ai应用市场教程·sd快速部署·comfyui一键安装
Tipriest_2 小时前
Python关键字梳理
python·关键字·keyword
聚客AI3 小时前
Embedding进化论:从Word2Vec到OpenAI三代模型技术跃迁
人工智能·llm·掘金·日新计划