【ViT】对图片进行分类(论文复现)
本文所涉及所有资源均在传知代码平台可获取
文章目录
概述
Transformer架构虽然已经成为自然语言处理任务的标准,但是它在计算机视觉的应用仍然有限,先前的视觉任务中,注意力大多与卷积结合使用。ViT模型的出现,证明了对CNN的依赖是不必要的,直接应用于图像补丁序列的纯Transformer架构可以在图像分类任务中表现良好
模型结构
模型总体框架
上述是ViT模型的基本框架,可以大致分为三个主要部分
- Patch_embed(将图片分成一系列的patches)
- Transformer Encoder(建模不同序列之间的相关性)
- MLP Head(用于最终的分类结构)
Patch_embed
在标准的Transformer模块中,输入的格式为二维矩阵 [num_token,token_dim] ,但对于图像数据而言,其输入数据的格式为[H,W,C] 的三维矩阵,明显不是Transformer架构需要的。所以需要Patch_embed结构将其转换为Transformer架构的输入。
针对于ViT-B/16而言,将输入图片(224x224)按照大小为(16x16) 的Patch进行划分,生成196个Patch。此时通过线性映射将每个Patch映射到一个长度为768 (16x16x3) 一维向量中。这一步可以通过卷积核大小为16x16,步距为 16 的卷积来实现。最后将长宽进行展平,则得到Transformer需要的输入格式。具体的维度变换如下所示:
[224,224,3] -> [14,14,768] -> [196,768]
在输入到Transformer Encoder之前还需要加上 [class]token(z00=xclass*z*00=*x**c**l**a**s**s*) ,它在Transformer 编码器 zL0z**L 0 输出处的状态用作图像表示 yy , 在预训练和微调过程中,zL0z**L0 处都具有一个分类头。
同时需要将Position Embeddin[197,768]叠加(add)到上述的token上
如上图所示,第一行第一列的位置编码上与其自身的余弦相似度最高,其次是与第一行和第一列的余弦相似度更高,这符合常理
Transformer Encoder
Transformer Encoder 本身是堆叠Encoder Block L 次,ViT-B/16是12次。主要有以下几部分组成:
- Layer Norm: 针对NLP领域提出,因为在RNN这类时序网络中,时序的长度并不一定是一个定值,Layer Norm在每个样本的每个特征维度上进行归一化,使得每个特征的均值为0,方差为1,从而有助于提高模型的训练效果和泛化能力。
- Multi-head Attention: 使用多头注意力机制能够联合来自不同head部分学习到的信息。
- MLP Block:由全连接+GELU激活函数+Dropout组成,在ViT-B/16的模型结构中,第一个全连接层将输入节点的个数翻4倍,第二个全连接层键还原节点的个数
MLP Head
通过Transfomer Encoder后输入的shape和输出的shape保持不变,由于我们只需要分类信息,因此只需要提取[class]token 的结果 zL0z**L0 ,之后通过MLP Head得到最后的分类结果。
模型的公式如下,其中E表示token的个数
演示效果
可视化输入图片的形式
可视化模型运行结果
核心逻辑
对输入图片进行分块处理
bash
class PatchEmbed(nn.Module):
def __init__(self,img_size=224,patch_size=16,in_c=3,embed_dim=768,norm_layer=None):
super(PatchEmbed,self).__init__()
img_size = (img_size,img_size)
patch_size = (patch_size,patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = (img_size[0]//patch_size[0])*(img_size[1]//patch_size[1])
self.proj = nn.Conv2d(in_c,embed_dim,kernel_size=patch_size,
stride=patch_size)
self.norm = nn.LayerNorm(embed_dim) if norm_layer else nn.Identity()
def forward(self,x):
# 首先需要判断输入图片的大小符合我们的预期
B,C,H,W=x.shape
assert H==self.img_size[0] and W==self.img_size[1],\
f"input image{H}x{W} does not model {self.img_size[0]}x{self.img_size[1]}"
# [N,in_c,H,W]->[N,embed_dim,H//16,W//16]->[N,embed_dim,H//16*W//16]
x = self.proj(x).flatten(2).transpose(1,2)
x = self.norm(x)
return x
多头注意力机制
bash
class Attention(nn.Module):
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop_ratio=0,
proj_drop_ratio=0):
super(Attention,self).__init__()
self.head_dim = dim//num_heads
self.num_heads = num_heads
self.dim = dim
self.scale = qk_scale or self.head_dim**(0.5)
self.qkv = nn.Linear(dim,dim*3,bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop_ratio)
self.proj = nn.Linear(dim,dim)
self.proj_drop = nn.Dropout(proj_drop_ratio)
def forward(self,x):
# [batch_size,num_patches+class_token,channel:HxW]
B,N,C = x.shape
# 将其进行投影,也就是多头自注意力机制所说的矩阵相乘
# reshape [B,N,C]->[B,N,3,heads,head_dim]->[3,B,heads,N,head_dim]
qkv = self.qkv(x).reshape(B,N,3,self.num_heads,self.head_dim).permute(2,0,3,1,4)
# [B,heads,N,head_dim]
q,k,v=qkv[0],qkv[1],qkv[2]
# [B,heads,N,N]
attn = (q@k.transpose(-2,-1))*self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
# [B,heads,N,head_dim]->[B,N,heads,head_dim]->[B,N,heads*head_dim]
# x = attn@v.permute(0,2,1,3).flatten(2)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj_drop(self.proj(x))
return x
MLP 模块
bash
class MLP(nn.Module):
def __init__(self,in_features,hidden_features=None,out_features=None,act_layer=nn.GELU,drop=0.):
super(MLP,self).__init__()
hidden_features = hidden_features or in_features
out_features = out_features or in_features
self.fc1 = nn.Linear(in_features,hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features,out_features)
self.drop = nn.Dropout(drop)
# 根据流程图确定其中的结构,注意是先激活函数之后才是dropout操作
def forward(self,x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
Block 结构
bash
class Block(nn.Module):
def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
attn_drop_ratio=0.,
drop_path_ratio=0.,
drop_ratio=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm
):
super(Block,self).__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim=dim,num_heads=num_heads,qkv_bias=qkv_bias,qk_scale=qk_scale,
attn_drop_ratio=attn_drop_ratio,proj_drop_ratio=drop_ratio)
self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
# Linear都需要是int的类型数据
self.mlp = MLP(dim,int(dim*mlp_ratio),dim,act_layer,drop_ratio)
self.norm2 = norm_layer(dim)
def forward(self,x):
x = x+self.drop_path(self.attn(self.norm1(x)))
x = x+self.drop_path(self.mlp(self.norm2(x)))
return x
ViT 模块
bash
class VisionTransformer(nn.Module):
def __init__(self,
img_size=224,
patch_size=16,
in_c=3,
num_classes=1000,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
representation_size=None,
drop_ratio=0.,
attn_drop_ratio=0.,
drop_path_ratio=0.,
embed_layer=PatchEmbed,
norm_layer=None,
act_layer=None,
):
super(VisionTransformer,self).__init__()
# 首先需要进行初始化操作,还可以对权重进行初始化操作
self.num_classes = num_classes
self.embed_dim = self.num_features = embed_dim
self.num_tokens = 1
act_layer = act_layer or nn.GELU
norm_layer = norm_layer or partial(nn.LayerNorm,eps=1e-6)
self.patch_embed = embed_layer(img_size=img_size,patch_size=patch_size,in_c=in_c,embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
# cls_token是针对于每个embed_dim确定一个class
# pos_embed除了channel 还要针对于每一个patch确定结果
self.cls_token = nn.Parameter(torch.zeros(1,1,embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1,num_patches+self.num_tokens,embed_dim))
self.pos_drop = nn.Dropout(p=drop_ratio)
dpr = [x.item() for x in torch.linspace(0,drop_path_ratio,depth)]
self.blocks = nn.Sequential(*[
Block(dim=embed_dim,num_heads=num_heads,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias,qk_scale=qk_scale,
drop_ratio=drop_ratio,attn_drop_ratio=attn_drop_ratio,drop_path_ratio=dpr[i],
norm_layer=norm_layer,act_layer=act_layer)
for i in range(depth)
])
self.norm = norm_layer(embed_dim)
# Pre_logits layer 相当于多添加了一个全连接层
if representation_size:
self.has_logits=True
self.num_features = representation_size
self.pre_logits = nn.Sequential(OrderedDict([
('fc',nn.Linear(embed_dim,representation_size))
('out',nn.Tanh())]
))
else:
self.has_logits = False
self.pre_logits = nn.Identity()
self.head = nn.Linear(self.num_features,self.num_classes) if num_classes>0 else nn.Identity()
# 开始对所有的权重进行初始化操作
nn.init.trunc_normal_(self.pos_embed,std=0.02)
nn.init.trunc_normal_(self.cls_token,std=0.02)
self.apply(_init_vit_weights)
def forward(self,x):
B,C,H,W = x.shape
#[B,C,H,W]->[B,N,H*W]
x = self.patch_embed(x)
# 每次都需要进行操作,所以不能对其本身进行expand操作
cls_token = self.cls_token.expand(B,-1,-1)
# 注意到后续一个是cat操作一个是add操作,且位置的先后关系
x = torch.cat((cls_token,x),dim=1)
# self.pos_embed中针对于一个batch值共享
x = self.pos_drop(x+self.pos_embed)
x = self.blocks(x)
x = self.norm(x)
x = self.pre_logits(x[:,0])
x = self.head(x)
return x
文章代码资源点击附件获取