昇思25天学习打卡营第六天|应用实践/计算机视觉/Vision Transformer图像分类

心得

运行模型似乎有点靠天意?每次跑模型之前先来个焚香沐浴?总之今天是机器视觉的最后一课了,尽管课程里强调模型跑得慢,可是我的这次运行,居然很快的就看到结果了。

如果一直看我这个系列文章的小伙伴,应该可以看出,我每次课程都是成功跑完整的。每次的文章中,不仅有课程的初始代码,还有跑成功后的显示结果,一起学习的小伙伴遇到问题时可以对着看看。

"

实际上,pos_embedding总共有4种方案。但是经过作者的论证,只有加上pos_embedding和不加pos_embedding有明显影响,至于pos_embedding是一维还是二维对分类结果影响不大,所以,在我们的代码中,也是采用了一维的pos_embedding,由于class_embedding是加在pos_embedding之前,所以pos_embedding的维度会比patch拉伸后的维度加1。

总的而言,ViT模型还是利用了Transformer模型在处理上下文语义时的优势,将图像转换为一种"变种词向量"然后进行处理,而这样转换的意义在于,多个patch之间本身具有空间联系,这类似于一种"空间语义",从而获得了比较好的处理效果。

"

也许有小伙伴愿意尝试一下二维呢?加油哦!

打卡截图

Vision Transformer图像分类

感谢ZOMI酱对本文的贡献。

Vision Transformer(ViT)简介

近些年,随着基于自注意(Self-Attention)结构的模型的发展,特别是Transformer模型的提出,极大地促进了自然语言处理模型的发展。由于Transformers的计算效率和可扩展性,它已经能够训练具有超过100B参数的空前规模的模型。

ViT则是自然语言处理和计算机视觉两个领域的融合结晶。在不依赖卷积操作的情况下,依然可以在图像分类任务上达到很好的效果。

模型结构

ViT模型的主体结构是基于Transformer模型的Encoder部分(部分结构顺序有调整,如:Normalization的位置与标准Transformer不同),其结构图[1]如下:

模型特点

ViT模型主要应用于图像分类领域。因此,其模型结构相较于传统的Transformer有以下几个特点:

  1. 数据集的原图像被划分为多个patch(图像块)后,将二维patch(不考虑channel)转换为一维向量,再加上类别向量与位置向量作为模型输入。
  2. 模型主体的Block结构是基于Transformer的Encoder结构,但是调整了Normalization的位置,其中,最主要的结构依然是Multi-head Attention结构。
  3. 模型在Blocks堆叠后接全连接层,接受类别向量的输出作为输入并用于分类。通常情况下,我们将最后的全连接层称为Head,Transformer Encoder部分为backbone。

下面将通过代码实例来详细解释基于ViT实现ImageNet分类任务。

注意,本教程在CPU上运行时间过长,不建议使用CPU运行。

环境准备与数据读取

开始实验之前,请确保本地已经安装了Python环境并安装了MindSpore。

首先我们需要下载本案例的数据集,可通过http://image-net.org下载完整的ImageNet数据集,本案例应用的数据集是从ImageNet中筛选出来的子集。

运行第一段代码时会自动下载并解压,请确保你的数据集路径如以下结构。

.dataset/
    ├── ILSVRC2012_devkit_t12.tar.gz
    ├── train/
    ├── infer/
    └── val/

[1]:

复制代码
%%capture captured_output
复制代码
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
复制代码
!pip uninstall mindspore -y
复制代码
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14

[2]:

复制代码
# 查看当前 mindspore 版本
复制代码
!pip show mindspore
复制代码
Name: mindspore
Version: 2.2.14
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by: 

[3]:

复制代码
from download import download
复制代码
复制代码
dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
复制代码
path = "./"
复制代码
复制代码
path = download(dataset_url, path, kind="zip", replace=True)
复制代码
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip (489.1 MB)

file_sizes: 100%|█████████████████████████████| 513M/513M [00:02<00:00, 220MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./

[4]:

复制代码
复制代码
import os
复制代码
复制代码
import mindspore as ms
复制代码
from mindspore.dataset import ImageFolderDataset
复制代码
import mindspore.dataset.vision as transforms
复制代码
复制代码
复制代码
data_path = './dataset/'
复制代码
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
复制代码
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
复制代码
复制代码
dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)
复制代码
复制代码
trans_train = [
复制代码
    transforms.RandomCropDecodeResize(size=224,
复制代码
                                      scale=(0.08, 1.0),
复制代码
                                      ratio=(0.75, 1.333)),
复制代码
    transforms.RandomHorizontalFlip(prob=0.5),
复制代码
    transforms.Normalize(mean=mean, std=std),
复制代码
    transforms.HWC2CHW()
复制代码
]
复制代码
复制代码
dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
复制代码
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)

模型解析

下面将通过代码来细致剖析ViT模型的内部结构。

Transformer基本原理

Transformer模型源于2017年的一篇文章[2]。在这篇文章中提出的基于Attention机制的编码器-解码器型结构在自然语言处理领域获得了巨大的成功。模型结构如下图所示:

其主要结构为多个Encoder和Decoder模块所组成,其中Encoder和Decoder的详细结构如下图[2]所示:

Encoder与Decoder由许多结构组成,如:多头注意力(Multi-Head Attention)层,Feed Forward层,Normaliztion层,甚至残差连接(Residual Connection,图中的"Add")。不过,其中最重要的结构是多头注意力(Multi-Head Attention)结构,该结构基于自注意力(Self-Attention)机制,是多个Self-Attention的并行组成。

所以,理解了Self-Attention就抓住了Transformer的核心。

Attention模块

以下是Self-Attention的解释,其核心内容是为输入向量的每个单词学习一个权重。通过给定一个任务相关的查询向量Query向量,计算Query和各个Key的相似性或者相关性得到注意力分布,即得到每个Key对应Value的权重系数,然后对Value进行加权求和得到最终的Attention数值。

在Self-Attention中:

  1. 最初的输入向量首先会经过Embedding层映射成Q(Query),K(Key),V(Value)三个向量,由于是并行操作,所以代码中是映射成为dim x 3的向量然后进行分割,换言之,如果你的输入向量为一个向量序列(𝑥1𝑥1,𝑥2𝑥2,𝑥3𝑥3),其中的𝑥1𝑥1,𝑥2𝑥2,𝑥3𝑥3都是一维向量,那么每一个一维向量都会经过Embedding层映射出Q,K,V三个向量,只是Embedding矩阵不同,矩阵参数也是通过学习得到的。这里大家可以认为,Q,K,V三个矩阵是发现向量之间关联信息的一种手段,需要经过学习得到,至于为什么是Q,K,V三个,主要是因为需要两个向量点乘以获得权重,又需要另一个向量来承载权重向加的结果,所以,最少需要3个矩阵。

𝑞𝑖=𝑊𝑞⋅𝑥𝑖𝑘𝑖=𝑊𝑘⋅𝑥𝑖,𝑣𝑖=𝑊𝑣⋅𝑥𝑖𝑖=1,2,3...(1)(1){𝑞𝑖=𝑊𝑞⋅𝑥𝑖𝑘𝑖=𝑊𝑘⋅𝑥𝑖,𝑖=1,2,3...𝑣𝑖=𝑊𝑣⋅𝑥𝑖

  1. 自注意力机制的自注意主要体现在它的Q,K,V都来源于其自身,也就是该过程是在提取输入的不同顺序的向量的联系与特征,最终通过不同顺序向量之间的联系紧密性(Q与K乘积经过Softmax的结果)来表现出来。Q,K,V得到后就需要获取向量间权重,需要对Q和K进行点乘并除以维度的平方根,对所有向量的结果进行Softmax处理,通过公式(2)的操作,我们获得了向量之间的关系权重。

𝑎1,1=𝑞1⋅𝑘1/𝑑⎯⎯√𝑎1,2=𝑞1⋅𝑘2/𝑑⎯⎯√𝑎1,3=𝑞1⋅𝑘3/𝑑⎯⎯√(2)(2){𝑎1,1=𝑞1⋅𝑘1/𝑑𝑎1,2=𝑞1⋅𝑘2/𝑑𝑎1,3=𝑞1⋅𝑘3/𝑑

𝑆𝑜𝑓𝑡𝑚𝑎𝑥:𝑎̂ 1,𝑖=𝑒𝑥𝑝(𝑎1,𝑖)/∑𝑗𝑒𝑥𝑝(𝑎1,𝑗),𝑗=1,2,3...(3)(3)𝑆𝑜𝑓𝑡𝑚𝑎𝑥:𝑎^1,𝑖=𝑒𝑥𝑝(𝑎1,𝑖)/∑𝑗𝑒𝑥𝑝(𝑎1,𝑗),𝑗=1,2,3...

  1. 其最终输出则是通过V这个映射后的向量与Q,K经过Softmax结果进行weight sum获得,这个过程可以理解为在全局上进行自注意表示。每一组Q,K,V最后都有一个V输出,这是Self-Attention得到的最终结果,是当前向量在结合了它与其他向量关联权重后得到的结果。

𝑏1=∑𝑖𝑎̂ 1,𝑖𝑣𝑖,𝑖=1,2,3...(4)(4)𝑏1=∑𝑖𝑎^1,𝑖𝑣𝑖,𝑖=1,2,3...

通过下图可以整体把握Self-Attention的全部过程。

多头注意力机制就是将原本self-Attention处理的向量分割为多个Head进行处理,这一点也可以从代码中体现,这也是attention结构可以进行并行加速的一个方面。

总结来说,多头注意力机制在保持参数总量不变的情况下,将同样的query, key和value映射到原来的高维空间(Q,K,V)的不同子空间(Q_0,K_0,V_0)中进行自注意力的计算,最后再合并不同子空间中的注意力信息。

所以,对于同一个输入向量,多个注意力机制可以同时对其进行处理,即利用并行计算加速处理过程,又在处理的时候更充分的分析和利用了向量特征。下图展示了多头注意力机制,其并行能力的主要体现在下图中的𝑎1𝑎1和𝑎2𝑎2是同一个向量进行分割获得的。

以下是Multi-Head Attention代码,结合上文的解释,代码清晰的展现了这一过程。

[5]:

复制代码
复制代码
from mindspore import nn, ops
复制代码
复制代码
复制代码
class Attention(nn.Cell):
复制代码
    def __init__(self,
复制代码
                 dim: int,
复制代码
                 num_heads: int = 8,
复制代码
                 keep_prob: float = 1.0,
复制代码
                 attention_keep_prob: float = 1.0):
复制代码
        super(Attention, self).__init__()
复制代码
复制代码
        self.num_heads = num_heads
复制代码
        head_dim = dim // num_heads
复制代码
        self.scale = ms.Tensor(head_dim ** -0.5)
复制代码
复制代码
        self.qkv = nn.Dense(dim, dim * 3)
复制代码
        self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)
复制代码
        self.out = nn.Dense(dim, dim)
复制代码
        self.out_drop = nn.Dropout(p=1.0-keep_prob)
复制代码
        self.attn_matmul_v = ops.BatchMatMul()
复制代码
        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)
复制代码
        self.softmax = nn.Softmax(axis=-1)
复制代码
复制代码
    def construct(self, x):
复制代码
        """Attention construct."""
复制代码
        b, n, c = x.shape
复制代码
        qkv = self.qkv(x)
复制代码
        qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))
复制代码
        qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))
复制代码
        q, k, v = ops.unstack(qkv, axis=0)
复制代码
        attn = self.q_matmul_k(q, k)
复制代码
        attn = ops.mul(attn, self.scale)
复制代码
        attn = self.softmax(attn)
复制代码
        attn = self.attn_drop(attn)
复制代码
        out = self.attn_matmul_v(attn, v)
复制代码
        out = ops.transpose(out, (0, 2, 1, 3))
复制代码
        out = ops.reshape(out, (b, n, c))
复制代码
        out = self.out(out)
复制代码
        out = self.out_drop(out)
复制代码
复制代码
        return out

Transformer Encoder

在了解了Self-Attention结构之后,通过与Feed Forward,Residual Connection等结构的拼接就可以形成Transformer的基础结构,下面代码实现了Feed Forward,Residual Connection结构。

[6]:

复制代码
复制代码
from typing import Optional, Dict
复制代码
复制代码
复制代码
class FeedForward(nn.Cell):
复制代码
    def __init__(self,
复制代码
                 in_features: int,
复制代码
                 hidden_features: Optional[int] = None,
复制代码
                 out_features: Optional[int] = None,
复制代码
                 activation: nn.Cell = nn.GELU,
复制代码
                 keep_prob: float = 1.0):
复制代码
        super(FeedForward, self).__init__()
复制代码
        out_features = out_features or in_features
复制代码
        hidden_features = hidden_features or in_features
复制代码
        self.dense1 = nn.Dense(in_features, hidden_features)
复制代码
        self.activation = activation()
复制代码
        self.dense2 = nn.Dense(hidden_features, out_features)
复制代码
        self.dropout = nn.Dropout(p=1.0-keep_prob)
复制代码
复制代码
    def construct(self, x):
复制代码
        """Feed Forward construct."""
复制代码
        x = self.dense1(x)
复制代码
        x = self.activation(x)
复制代码
        x = self.dropout(x)
复制代码
        x = self.dense2(x)
复制代码
        x = self.dropout(x)
复制代码
复制代码
        return x
复制代码
复制代码
复制代码
class ResidualCell(nn.Cell):
复制代码
    def __init__(self, cell):
复制代码
        super(ResidualCell, self).__init__()
复制代码
        self.cell = cell
复制代码
复制代码
    def construct(self, x):
复制代码
        """ResidualCell construct."""
复制代码
        return self.cell(x) + x

接下来就利用Self-Attention来构建ViT模型中的TransformerEncoder部分,类似于构建了一个Transformer的编码器部分,如下图[1]所示:

  1. ViT模型中的基础结构与标准Transformer有所不同,主要在于Normalization的位置是放在Self-Attention和Feed Forward之前,其他结构如Residual Connection,Feed Forward,Normalization都如Transformer中所设计。

  2. 从Transformer结构的图片可以发现,多个子encoder的堆叠就完成了模型编码器的构建,在ViT模型中,依然沿用这个思路,通过配置超参数num_layers,就可以确定堆叠层数。

  3. Residual Connection,Normalization的结构可以保证模型有很强的扩展性(保证信息经过深层处理不会出现退化的现象,这是Residual Connection的作用),Normalization和dropout的应用可以增强模型泛化能力。

从以下源码中就可以清晰看到Transformer的结构。将TransformerEncoder结构和一个多层感知器(MLP)结合,就构成了ViT模型的backbone部分。

[7]:

复制代码
复制代码
class TransformerEncoder(nn.Cell):
复制代码
    def __init__(self,
复制代码
                 dim: int,
复制代码
                 num_layers: int,
复制代码
                 num_heads: int,
复制代码
                 mlp_dim: int,
复制代码
                 keep_prob: float = 1.,
复制代码
                 attention_keep_prob: float = 1.0,
复制代码
                 drop_path_keep_prob: float = 1.0,
复制代码
                 activation: nn.Cell = nn.GELU,
复制代码
                 norm: nn.Cell = nn.LayerNorm):
复制代码
        super(TransformerEncoder, self).__init__()
复制代码
        layers = []
复制代码
复制代码
        for _ in range(num_layers):
复制代码
            normalization1 = norm((dim,))
复制代码
            normalization2 = norm((dim,))
复制代码
            attention = Attention(dim=dim,
复制代码
                                  num_heads=num_heads,
复制代码
                                  keep_prob=keep_prob,
复制代码
                                  attention_keep_prob=attention_keep_prob)
复制代码
复制代码
            feedforward = FeedForward(in_features=dim,
复制代码
                                      hidden_features=mlp_dim,
复制代码
                                      activation=activation,
复制代码
                                      keep_prob=keep_prob)
复制代码
复制代码
            layers.append(
复制代码
                nn.SequentialCell([
复制代码
                    ResidualCell(nn.SequentialCell([normalization1, attention])),
复制代码
                    ResidualCell(nn.SequentialCell([normalization2, feedforward]))
复制代码
                ])
复制代码
            )
复制代码
        self.layers = nn.SequentialCell(layers)
复制代码
复制代码
    def construct(self, x):
复制代码
        """Transformer construct."""
复制代码
        return self.layers(x)

ViT模型的输入

传统的Transformer结构主要用于处理自然语言领域的词向量(Word Embedding or Word Vector),词向量与传统图像数据的主要区别在于,词向量通常是一维向量进行堆叠,而图片则是二维矩阵的堆叠,多头注意力机制在处理一维词向量的堆叠时会提取词向量之间的联系也就是上下文语义,这使得Transformer在自然语言处理领域非常好用,而二维图片矩阵如何与一维词向量进行转化就成为了Transformer进军图像处理领域的一个小门槛。

在ViT模型中:

  1. 通过将输入图像在每个channel上划分为16*16个patch,这一步是通过卷积操作来完成的,当然也可以人工进行划分,但卷积操作也可以达到目的同时还可以进行一次而外的数据处理;*例如一幅输入224 x 224的图像,首先经过卷积处理得到16 x 16个patch,那么每一个patch的大小就是14 x 14。*

  2. 再将每一个patch的矩阵拉伸成为一个一维向量,从而获得了近似词向量堆叠的效果。上一步得到的14 x 14的patch就转换为长度为196的向量。

这是图像输入网络经过的第一步处理。具体Patch Embedding的代码如下所示:

[8]:

复制代码
复制代码
class PatchEmbedding(nn.Cell):
复制代码
    MIN_NUM_PATCHES = 4
复制代码
复制代码
    def __init__(self,
复制代码
                 image_size: int = 224,
复制代码
                 patch_size: int = 16,
复制代码
                 embed_dim: int = 768,
复制代码
                 input_channels: int = 3):
复制代码
        super(PatchEmbedding, self).__init__()
复制代码
复制代码
        self.image_size = image_size
复制代码
        self.patch_size = patch_size
复制代码
        self.num_patches = (image_size // patch_size) ** 2
复制代码
        self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)
复制代码
复制代码
    def construct(self, x):
复制代码
        """Path Embedding construct."""
复制代码
        x = self.conv(x)
复制代码
        b, c, h, w = x.shape
复制代码
        x = ops.reshape(x, (b, c, h * w))
复制代码
        x = ops.transpose(x, (0, 2, 1))
复制代码
复制代码
        return x

输入图像在划分为patch之后,会经过pos_embedding 和 class_embedding两个过程。

  1. class_embedding主要借鉴了BERT模型的用于文本分类时的思想,在每一个word vector之前增加一个类别值,通常是加在向量的第一位,上一步得到的196维的向量加上class_embedding后变为197维。

  2. 增加的class_embedding是一个可以学习的参数,经过网络的不断训练,最终以输出向量的第一个维度的输出来决定最后的输出类别;由于输入是16 x 16个patch,所以输出进行分类时是取 16 x 16个class_embedding进行分类。

  3. pos_embedding也是一组可以学习的参数,会被加入到经过处理的patch矩阵中。

  4. 由于pos_embedding也是可以学习的参数,所以它的加入类似于全链接网络和卷积的bias。这一步就是创造一个长度维197的可训练向量加入到经过class_embedding的向量中。

实际上,pos_embedding总共有4种方案。但是经过作者的论证,只有加上pos_embedding和不加pos_embedding有明显影响,至于pos_embedding是一维还是二维对分类结果影响不大,所以,在我们的代码中,也是采用了一维的pos_embedding,由于class_embedding是加在pos_embedding之前,所以pos_embedding的维度会比patch拉伸后的维度加1。

总的而言,ViT模型还是利用了Transformer模型在处理上下文语义时的优势,将图像转换为一种"变种词向量"然后进行处理,而这样转换的意义在于,多个patch之间本身具有空间联系,这类似于一种"空间语义",从而获得了比较好的处理效果。

整体构建ViT

以下代码构建了一个完整的ViT模型。

[9]:

复制代码
复制代码
from mindspore.common.initializer import Normal
复制代码
from mindspore.common.initializer import initializer
复制代码
from mindspore import Parameter
复制代码
复制代码
复制代码
def init(init_type, shape, dtype, name, requires_grad):
复制代码
    """Init."""
复制代码
    initial = initializer(init_type, shape, dtype).init_data()
复制代码
    return Parameter(initial, name=name, requires_grad=requires_grad)
复制代码
复制代码
复制代码
class ViT(nn.Cell):
复制代码
    def __init__(self,
复制代码
                 image_size: int = 224,
复制代码
                 input_channels: int = 3,
复制代码
                 patch_size: int = 16,
复制代码
                 embed_dim: int = 768,
复制代码
                 num_layers: int = 12,
复制代码
                 num_heads: int = 12,
复制代码
                 mlp_dim: int = 3072,
复制代码
                 keep_prob: float = 1.0,
复制代码
                 attention_keep_prob: float = 1.0,
复制代码
                 drop_path_keep_prob: float = 1.0,
复制代码
                 activation: nn.Cell = nn.GELU,
复制代码
                 norm: Optional[nn.Cell] = nn.LayerNorm,
复制代码
                 pool: str = 'cls') -> None:
复制代码
        super(ViT, self).__init__()
复制代码
复制代码
        self.patch_embedding = PatchEmbedding(image_size=image_size,
复制代码
                                              patch_size=patch_size,
复制代码
                                              embed_dim=embed_dim,
复制代码
                                              input_channels=input_channels)
复制代码
        num_patches = self.patch_embedding.num_patches
复制代码
复制代码
        self.cls_token = init(init_type=Normal(sigma=1.0),
复制代码
                              shape=(1, 1, embed_dim),
复制代码
                              dtype=ms.float32,
复制代码
                              name='cls',
复制代码
                              requires_grad=True)
复制代码
复制代码
        self.pos_embedding = init(init_type=Normal(sigma=1.0),
复制代码
                                  shape=(1, num_patches + 1, embed_dim),
复制代码
                                  dtype=ms.float32,
复制代码
                                  name='pos_embedding',
复制代码
                                  requires_grad=True)
复制代码
复制代码
        self.pool = pool
复制代码
        self.pos_dropout = nn.Dropout(p=1.0-keep_prob)
复制代码
        self.norm = norm((embed_dim,))
复制代码
        self.transformer = TransformerEncoder(dim=embed_dim,
复制代码
                                              num_layers=num_layers,
复制代码
                                              num_heads=num_heads,
复制代码
                                              mlp_dim=mlp_dim,
复制代码
                                              keep_prob=keep_prob,
复制代码
                                              attention_keep_prob=attention_keep_prob,
复制代码
                                              drop_path_keep_prob=drop_path_keep_prob,
复制代码
                                              activation=activation,
复制代码
                                              norm=norm)
复制代码
        self.dropout = nn.Dropout(p=1.0-keep_prob)
复制代码
        self.dense = nn.Dense(embed_dim, num_classes)
复制代码
复制代码
    def construct(self, x):
复制代码
        """ViT construct."""
复制代码
        x = self.patch_embedding(x)
复制代码
        cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))
复制代码
        x = ops.concat((cls_tokens, x), axis=1)
复制代码
        x += self.pos_embedding
复制代码
复制代码
        x = self.pos_dropout(x)
复制代码
        x = self.transformer(x)
复制代码
        x = self.norm(x)
复制代码
        x = x[:, 0]
复制代码
        if self.training:
复制代码
            x = self.dropout(x)
复制代码
        x = self.dense(x)
复制代码
复制代码
        return x

整体流程图如下所示:

模型训练与推理

模型训练

模型开始训练前,需要设定损失函数,优化器,回调函数等。

完整训练ViT模型需要很长的时间,实际应用时建议根据项目需要调整epoch_size,当正常输出每个Epoch的step信息时,意味着训练正在进行,通过模型输出可以查看当前训练的loss值和时间等指标。

[10]:

复制代码
复制代码
from mindspore.nn import LossBase
复制代码
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
复制代码
from mindspore import train
复制代码
复制代码
# define super parameter
复制代码
epoch_size = 10
复制代码
momentum = 0.9
复制代码
num_classes = 1000
复制代码
resize = 224
复制代码
step_size = dataset_train.get_dataset_size()
复制代码
复制代码
# construct model
复制代码
network = ViT()
复制代码
复制代码
# load ckpt
复制代码
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
复制代码
path = "./ckpt/vit_b_16_224.ckpt"
复制代码
复制代码
vit_path = download(vit_url, path, replace=True)
复制代码
param_dict = ms.load_checkpoint(vit_path)
复制代码
ms.load_param_into_net(network, param_dict)
复制代码
复制代码
# define learning rate
复制代码
lr = nn.cosine_decay_lr(min_lr=float(0),
复制代码
                        max_lr=0.00005,
复制代码
                        total_step=epoch_size * step_size,
复制代码
                        step_per_epoch=step_size,
复制代码
                        decay_epoch=10)
复制代码
复制代码
# define optimizer
复制代码
network_opt = nn.Adam(network.trainable_params(), lr, momentum)
复制代码
复制代码
复制代码
# define loss function
复制代码
class CrossEntropySmooth(LossBase):
复制代码
    """CrossEntropy."""
复制代码
复制代码
    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
复制代码
        super(CrossEntropySmooth, self).__init__()
复制代码
        self.onehot = ops.OneHot()
复制代码
        self.sparse = sparse
复制代码
        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
复制代码
        self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
复制代码
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
复制代码
复制代码
    def construct(self, logit, label):
复制代码
        if self.sparse:
复制代码
            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
复制代码
        loss = self.ce(logit, label)
复制代码
        return loss
复制代码
复制代码
复制代码
network_loss = CrossEntropySmooth(sparse=True,
复制代码
                                  reduction="mean",
复制代码
                                  smooth_factor=0.1,
复制代码
                                  num_classes=num_classes)
复制代码
复制代码
# set checkpoint
复制代码
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
复制代码
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)
复制代码
复制代码
# initialize model
复制代码
# "Ascend + mixed precision" can improve performance
复制代码
ascend_target = (ms.get_context("device_target") == "Ascend")
复制代码
if ascend_target:
复制代码
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")
复制代码
else:
复制代码
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")
复制代码
复制代码
# train model
复制代码
model.train(epoch_size,
复制代码
            dataset_train,
复制代码
            callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],
复制代码
            dataset_sink_mode=False,)
复制代码
Downloading data from https://download-mindspore.osinfra.cn/vision/classification/vit_b_16_224.ckpt (330.2 MB)

file_sizes: 100%|████████████████████████████| 346M/346M [00:18<00:00, 18.7MB/s]
Successfully downloaded file to ./ckpt/vit_b_16_224.ckpt
epoch: 1 step: 125, loss is 2.4269836
Train epoch time: 254275.431 ms, per step time: 2034.203 ms
epoch: 2 step: 125, loss is 1.6255778
Train epoch time: 27142.445 ms, per step time: 217.140 ms
epoch: 3 step: 125, loss is 1.3483672
Train epoch time: 25470.780 ms, per step time: 203.766 ms
epoch: 4 step: 125, loss is 1.1531123
Train epoch time: 25066.077 ms, per step time: 200.529 ms
epoch: 5 step: 125, loss is 1.5591685
Train epoch time: 24959.266 ms, per step time: 199.674 ms
epoch: 6 step: 125, loss is 1.164649
Train epoch time: 24743.208 ms, per step time: 197.946 ms
epoch: 7 step: 125, loss is 1.2841825
Train epoch time: 25585.884 ms, per step time: 204.687 ms
epoch: 8 step: 125, loss is 1.3261327
Train epoch time: 24364.958 ms, per step time: 194.920 ms
epoch: 9 step: 125, loss is 1.2463571
Train epoch time: 24698.548 ms, per step time: 197.588 ms
epoch: 10 step: 125, loss is 1.1417197
Train epoch time: 25268.490 ms, per step time: 202.148 ms

模型验证

模型验证过程主要应用了ImageFolderDataset,CrossEntropySmooth和Model等接口。

ImageFolderDataset主要用于读取数据集。

CrossEntropySmooth是损失函数实例化接口。

Model主要用于编译模型。

与训练过程相似,首先进行数据增强,然后定义ViT网络结构,加载预训练模型参数。随后设置损失函数,评价指标等,编译模型后进行验证。本案例采用了业界通用的评价标准Top_1_Accuracy和Top_5_Accuracy评价指标来评价模型表现。

在本案例中,这两个指标代表了在输出的1000维向量中,以最大值或前5的输出值所代表的类别为预测结果时,模型预测的准确率。这两个指标的值越大,代表模型准确率越高。

[11]:

复制代码
复制代码
dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)
复制代码
复制代码
trans_val = [
复制代码
    transforms.Decode(),
复制代码
    transforms.Resize(224 + 32),
复制代码
    transforms.CenterCrop(224),
复制代码
    transforms.Normalize(mean=mean, std=std),
复制代码
    transforms.HWC2CHW()
复制代码
]
复制代码
复制代码
dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])
复制代码
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)
复制代码
复制代码
# construct model
复制代码
network = ViT()
复制代码
复制代码
# load ckpt
复制代码
param_dict = ms.load_checkpoint(vit_path)
复制代码
ms.load_param_into_net(network, param_dict)
复制代码
复制代码
network_loss = CrossEntropySmooth(sparse=True,
复制代码
                                  reduction="mean",
复制代码
                                  smooth_factor=0.1,
复制代码
                                  num_classes=num_classes)
复制代码
复制代码
# define metric
复制代码
eval_metrics = {'Top_1_Accuracy': train.Top1CategoricalAccuracy(),
复制代码
                'Top_5_Accuracy': train.Top5CategoricalAccuracy()}
复制代码
复制代码
if ascend_target:
复制代码
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O2")
复制代码
else:
复制代码
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O0")
复制代码
复制代码
# evaluate model
复制代码
result = model.eval(dataset_val)
复制代码
print(result)
复制代码
{'Top_1_Accuracy': 0.7495, 'Top_5_Accuracy': 0.928}

从结果可以看出,由于我们加载了预训练模型参数,模型的Top_1_Accuracy和Top_5_Accuracy达到了很高的水平,实际项目中也可以以此准确率为标准。如果未使用预训练模型参数,则需要更多的epoch来训练。

模型推理

在进行模型推理之前,首先要定义一个对推理图片进行数据预处理的方法。该方法可以对我们的推理图片进行resize和normalize处理,这样才能与我们训练时的输入数据匹配。

本案例采用了一张Doberman的图片作为推理图片来测试模型表现,期望模型可以给出正确的预测结果。

[12]:

复制代码
dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)
复制代码
复制代码
trans_infer = [
复制代码
    transforms.Decode(),
复制代码
    transforms.Resize([224, 224]),
复制代码
    transforms.Normalize(mean=mean, std=std),
复制代码
    transforms.HWC2CHW()
复制代码
]
复制代码
复制代码
dataset_infer = dataset_infer.map(operations=trans_infer,
复制代码
                                  input_columns=["image"],
复制代码
                                  num_parallel_workers=1)
复制代码
dataset_infer = dataset_infer.batch(1)

接下来,我们将调用模型的predict方法进行模型。

在推理过程中,通过index2label就可以获取对应标签,再通过自定义的show_result接口将结果写在对应图片上。

[13]:

复制代码
复制代码
import os
复制代码
import pathlib
复制代码
import cv2
复制代码
import numpy as np
复制代码
from PIL import Image
复制代码
from enum import Enum
复制代码
from scipy import io
复制代码
复制代码
复制代码
class Color(Enum):
复制代码
    """dedine enum color."""
复制代码
    red = (0, 0, 255)
复制代码
    green = (0, 255, 0)
复制代码
    blue = (255, 0, 0)
复制代码
    cyan = (255, 255, 0)
复制代码
    yellow = (0, 255, 255)
复制代码
    magenta = (255, 0, 255)
复制代码
    white = (255, 255, 255)
复制代码
    black = (0, 0, 0)
复制代码
复制代码
复制代码
def check_file_exist(file_name: str):
复制代码
    """check_file_exist."""
复制代码
    if not os.path.isfile(file_name):
复制代码
        raise FileNotFoundError(f"File `{file_name}` does not exist.")
复制代码
复制代码
复制代码
def color_val(color):
复制代码
    """color_val."""
复制代码
    if isinstance(color, str):
复制代码
        return Color[color].value
复制代码
    if isinstance(color, Color):
复制代码
        return color.value
复制代码
    if isinstance(color, tuple):
复制代码
        assert len(color) == 3
复制代码
        for channel in color:
复制代码
            assert 0 <= channel <= 255
复制代码
        return color
复制代码
    if isinstance(color, int):
复制代码
        assert 0 <= color <= 255
复制代码
        return color, color, color
复制代码
    if isinstance(color, np.ndarray):
复制代码
        assert color.ndim == 1 and color.size == 3
复制代码
        assert np.all((color >= 0) & (color <= 255))
复制代码
        color = color.astype(np.uint8)
复制代码
        return tuple(color)
复制代码
    raise TypeError(f'Invalid type for color: {type(color)}')
复制代码
复制代码
复制代码
def imread(image, mode=None):
复制代码
    """imread."""
复制代码
    if isinstance(image, pathlib.Path):
复制代码
        image = str(image)
复制代码
复制代码
    if isinstance(image, np.ndarray):
复制代码
        pass
复制代码
    elif isinstance(image, str):
复制代码
        check_file_exist(image)
复制代码
        image = Image.open(image)
复制代码
        if mode:
复制代码
            image = np.array(image.convert(mode))
复制代码
    else:
复制代码
        raise TypeError("Image must be a `ndarray`, `str` or Path object.")
复制代码
复制代码
    return image
复制代码
复制代码
复制代码
def imwrite(image, image_path, auto_mkdir=True):
复制代码
    """imwrite."""
复制代码
    if auto_mkdir:
复制代码
        dir_name = os.path.abspath(os.path.dirname(image_path))
复制代码
        if dir_name != '':
复制代码
            dir_name = os.path.expanduser(dir_name)
复制代码
            os.makedirs(dir_name, mode=777, exist_ok=True)
复制代码
复制代码
    image = Image.fromarray(image)
复制代码
    image.save(image_path)
复制代码
复制代码
复制代码
def imshow(img, win_name='', wait_time=0):
复制代码
    """imshow"""
复制代码
    cv2.imshow(win_name, imread(img))
复制代码
    if wait_time == 0:  # prevent from hanging if windows was closed
复制代码
        while True:
复制代码
            ret = cv2.waitKey(1)
复制代码
复制代码
            closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
复制代码
            # if user closed window or if some key pressed
复制代码
            if closed or ret != -1:
复制代码
                break
复制代码
    else:
复制代码
        ret = cv2.waitKey(wait_time)
复制代码
复制代码
复制代码
def show_result(img: str,
复制代码
                result: Dict[int, float],
复制代码
                text_color: str = 'green',
复制代码
                font_scale: float = 0.5,
复制代码
                row_width: int = 20,
复制代码
                show: bool = False,
复制代码
                win_name: str = '',
复制代码
                wait_time: int = 0,
复制代码
                out_file: Optional[str] = None) -> None:
复制代码
    """Mark the prediction results on the picture."""
复制代码
    img = imread(img, mode="RGB")
复制代码
    img = img.copy()
复制代码
    x, y = 0, row_width
复制代码
    text_color = color_val(text_color)
复制代码
    for k, v in result.items():
复制代码
        if isinstance(v, float):
复制代码
            v = f'{v:.2f}'
复制代码
        label_text = f'{k}: {v}'
复制代码
        cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
复制代码
                    font_scale, text_color)
复制代码
        y += row_width
复制代码
    if out_file:
复制代码
        show = False
复制代码
        imwrite(img, out_file)
复制代码
复制代码
    if show:
复制代码
        imshow(img, win_name, wait_time)
复制代码
复制代码
复制代码
def index2label():
复制代码
    """Dictionary output for image numbers and categories of the ImageNet dataset."""
复制代码
    metafile = os.path.join(data_path, "ILSVRC2012_devkit_t12/data/meta.mat")
复制代码
    meta = io.loadmat(metafile, squeeze_me=True)['synsets']
复制代码
复制代码
    nums_children = list(zip(*meta))[4]
复制代码
    meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
复制代码
复制代码
    _, wnids, classes = list(zip(*meta))[:3]
复制代码
    clssname = [tuple(clss.split(', ')) for clss in classes]
复制代码
    wnid2class = {wnid: clss for wnid, clss in zip(wnids, clssname)}
复制代码
    wind2class_name = sorted(wnid2class.items(), key=lambda x: x[0])
复制代码
复制代码
    mapping = {}
复制代码
    for index, (_, class_name) in enumerate(wind2class_name):
复制代码
        mapping[index] = class_name[0]
复制代码
    return mapping
复制代码
复制代码
复制代码
# Read data for inference
复制代码
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
复制代码
    image = image["image"]
复制代码
    image = ms.Tensor(image)
复制代码
    prob = model.predict(image)
复制代码
    label = np.argmax(prob.asnumpy(), axis=1)
复制代码
    mapping = index2label()
复制代码
    output = {int(label): mapping[int(label)]}
复制代码
    print(output)
复制代码
    show_result(img="./dataset/infer/n01440764/ILSVRC2012_test_00000279.JPEG",
复制代码
                result=output,
复制代码
                out_file="./dataset/infer/ILSVRC2012_test_00000279.JPEG")
复制代码
{236: 'Doberman'}

推理过程完成后,在推理文件夹下可以找到图片的推理结果,可以看出预测结果是Doberman,与期望结果相同,验证了模型的准确性。

总结

本案例完成了一个ViT模型在ImageNet数据上进行训练,验证和推理的过程,其中,对关键的ViT模型结构和原理作了讲解。通过学习本案例,理解源码可以帮助用户掌握Multi-Head Attention,TransformerEncoder,pos_embedding等关键概念,如果要详细理解ViT的模型原理,建议基于源码更深层次的详细阅读。

[14]:

复制代码
复制代码
import time
复制代码
复制代码
print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),'guojun0718')
复制代码
2024-07-16 08:19:33 guojun0718

[ ]:

复制代码
相关推荐
量子-Alex8 分钟前
【多模态聚类】用于无标记视频自监督学习的多模态聚类网络
学习·音视频·聚类
吉大一菜鸡12 分钟前
FPGA学习(基于小梅哥Xilinx FPGA)学习笔记
笔记·学习·fpga开发
Eric.Lee20212 小时前
Paddle OCR 中英文检测识别 - python 实现
人工智能·opencv·计算机视觉·ocr检测
audyxiao0013 小时前
AI一周重要会议和活动概览
人工智能·计算机视觉·数据挖掘·多模态
爱吃西瓜的小菜鸡3 小时前
【C语言】判断回文
c语言·学习·算法
小A1593 小时前
STM32完全学习——SPI接口的FLASH(DMA模式)
stm32·嵌入式硬件·学习
岁岁岁平安4 小时前
spring学习(spring-DI(字符串或对象引用注入、集合注入)(XML配置))
java·学习·spring·依赖注入·集合注入·基本数据类型注入·引用数据类型注入
武昌库里写JAVA4 小时前
Java成长之路(一)--SpringBoot基础学习--SpringBoot代码测试
java·开发语言·spring boot·学习·课程设计
qq_589568104 小时前
数据可视化echarts学习笔记
学习·信息可视化·echarts
兔C5 小时前
微信小程序的轮播图学习报告
学习·微信小程序·小程序