Introduction to the Vision Transformer (ViT)
In recent years, models built on the Self-Attention mechanism, and in particular the Transformer, have driven rapid progress in natural language processing. Thanks to the Transformer's computational efficiency and scalability, it has become possible to train models of unprecedented scale, with more than 100B parameters.
ViT is a fusion of natural language processing and computer vision: without relying on convolution operations, it still achieves excellent results on image classification tasks.
Model Structure
The main body of the ViT model is based on the Encoder part of the Transformer (with some structural reorderings, e.g., the position of Normalization differs from the standard Transformer). Its structure diagram [1] is shown below:
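The most important deviation from the standard encoder is the pre-norm layout: Normalization is applied before the attention and feed-forward sub-layers rather than after them. Below is a minimal, illustrative sketch of one such encoder block, using hypothetical ViT-B/16 dimensions; it is a sketch of the idea, not the tutorial's actual implementation:
python
import mindspore.nn as nn

class PreNormEncoderBlock(nn.Cell):
    """One pre-norm Transformer encoder block (illustrative sketch)."""
    def __init__(self, dim=768, num_heads=12, mlp_dim=3072):
        super().__init__()
        self.norm1 = nn.LayerNorm((dim,))
        # mindspore.nn.MultiheadAttention (available in MindSpore 2.x)
        self.attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm((dim,))
        self.mlp = nn.SequentialCell(
            nn.Dense(dim, mlp_dim),
            nn.GELU(),
            nn.Dense(mlp_dim, dim))

    def construct(self, x):
        # normalize first, then attend, then add the residual
        y = self.norm1(x)
        attn_out, _ = self.attention(y, y, y)
        x = x + attn_out
        # the MLP sub-layer follows the same pre-norm pattern
        return x + self.mlp(self.norm2(x))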
Model Features
ViT is applied mainly to image classification. Compared with a conventional Transformer, its structure therefore has the following features:
After the original image is divided into multiple patches, each two-dimensional patch (leaving the channel dimension aside) is flattened into a one-dimensional vector; a class token and position embeddings are then added to form the model input (see the sketch after this list).
The Block structure of the model body is based on the Transformer Encoder, but with the position of Normalization adjusted; the core component is still the Multi-head Attention structure.
After the stacked Blocks, the model attaches a fully connected layer that takes the output corresponding to the class token as its input and uses it for classification. Conventionally, this final fully connected layer is called the head, and the Transformer Encoder part the backbone.
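As referenced in the first feature above, the following NumPy sketch shows the patch-and-token step with ViT-B/16 shapes (random values purely for illustration; in the real model the class token, position embeddings, and a linear projection of the patches are all learned):
python
import numpy as np

# A 224x224x3 image cut into 16x16 patches yields (224 / 16)^2 = 196 patches,
# each flattened into a 16 * 16 * 3 = 768-dimensional vector.
image = np.random.rand(224, 224, 3)
patch_size = 16
grid = 224 // patch_size                                              # 14 patches per side
patches = image.reshape(grid, patch_size, grid, patch_size, 3)
patches = patches.transpose(0, 2, 1, 3, 4).reshape(grid * grid, -1)   # (196, 768)

# Prepend the class token and add position embeddings;
# the sequence length becomes 196 + 1 = 197.
cls_token = np.zeros((1, patches.shape[1]))
tokens = np.concatenate([cls_token, patches], axis=0)                 # (197, 768)
tokens = tokens + np.random.rand(*tokens.shape)                       # position embeddings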
The following walks through an ImageNet classification task implemented with ViT, explained in detail through code.
Note: this tutorial takes a very long time to run on CPU; running it on CPU is not recommended.
Setting Up the Experiment Environment
python
%%capture captured_output
# mindspore==2.2.14 is preinstalled in this environment; to switch MindSpore versions, change the version number below
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
python
from download import download
dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"
path = download(dataset_url, path, kind="zip", replace=True)
python
import os
import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms
data_path = './dataset/'
# ImageNet channel statistics, scaled to the 0-255 pixel range
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)

# training-time augmentation: random crop + resize to 224, random flip,
# normalization, and HWC -> CHW layout conversion
trans_train = [
    transforms.RandomCropDecodeResize(size=224,
                                      scale=(0.08, 1.0),
                                      ratio=(0.75, 1.333)),
    transforms.RandomHorizontalFlip(prob=0.5),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)
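As a quick, optional sanity check (not part of the original tutorial), one batch can be pulled from the pipeline to confirm that its shape matches what the network expects, namely 16 images of 3 x 224 x 224:
python
# one batch from the training pipeline: image (16, 3, 224, 224), label (16,)
sample = next(dataset_train.create_dict_iterator())
print(sample["image"].shape, sample["label"].shape)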
Model Analysis
The following code dissects the internal structure of the ViT model in detail.
Transformer Fundamentals
The Transformer model originates from a paper published in 2017 [2]. The encoder-decoder architecture based on the Attention mechanism proposed in that paper achieved enormous success in natural language processing. The model structure is shown in the figure below:
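At the heart of that architecture is scaled dot-product attention, Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V. The following minimal NumPy sketch (a single head with hypothetical sizes, not the paper's full multi-head module) shows the computation:
python
import numpy as np

def scaled_dot_product_attention(q, k, v):
    """softmax(Q K^T / sqrt(d_k)) V for a single attention head."""
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)                       # query-key similarities
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ v

q = k = v = np.random.rand(197, 64)          # e.g. 197 tokens, head dimension 64
out = scaled_dot_product_attention(q, k, v)  # (197, 64)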
Model Training
Before training begins, the loss function, optimizer, and callback functions need to be set.
Fully training a ViT model takes a long time; in practice, adjust epoch_size according to the needs of your project. As long as step information is printed normally for each epoch, training is in progress; the model output reports metrics such as the current loss value and elapsed time.
python
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import train
from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
# define hyperparameters
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()
# construct model (the ViT class assembled in the model-analysis section above)
network = ViT()
# load ckpt
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"
vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)
# define learning rate (cosine_decay_lr returns one value per training step)
lr = nn.cosine_decay_lr(min_lr=float(0),
                        max_lr=0.00005,
                        total_step=epoch_size * step_size,
                        step_per_epoch=step_size,
                        decay_epoch=10)
# define optimizer
network_opt = nn.Adam(network.trainable_params(), lr, momentum)
# define loss function
class CrossEntropySmooth(LossBase):
    """Cross entropy loss with label smoothing."""

    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = ops.OneHot()
        self.sparse = sparse
        # with smooth_factor=0.1 and 1000 classes, the true class receives
        # probability 0.9 and each of the other 999 classes 0.1 / 999
        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
        self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss
network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)
# set checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)
# initialize model
# "Ascend + mixed precision" can improve performance
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")
# train model
model.train(epoch_size,
            dataset_train,
            callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],
            dataset_sink_mode=False)
Model Inference
Before running inference, we first define a method to preprocess the inference images. This method resizes and normalizes the images so that they match the input data used during training.
This case uses a picture of a Doberman as the inference image to test how the model performs; we expect the model to give a correct prediction.
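Below is a minimal preprocessing sketch for the inference set (assuming the images live under ./dataset/infer; a deterministic resize and center crop replace the random training-time augmentation). It defines the dataset_infer iterated over at the end of the next code block:
python
dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)

trans_infer = [
    transforms.Decode(),
    transforms.Resize(224 + 32),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_infer = dataset_infer.map(operations=trans_infer, input_columns=["image"])
dataset_infer = dataset_infer.batch(1)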
python
import os
import pathlib
import cv2
import numpy as np
from enum import Enum
from typing import Dict, Optional

from PIL import Image
from scipy import io
class Color(Enum):
    """Define enum color."""
    red = (0, 0, 255)
    green = (0, 255, 0)
    blue = (255, 0, 0)
    cyan = (255, 255, 0)
    yellow = (0, 255, 255)
    magenta = (255, 0, 255)
    white = (255, 255, 255)
    black = (0, 0, 0)


def check_file_exist(file_name: str):
    """Raise FileNotFoundError if the given file does not exist."""
    if not os.path.isfile(file_name):
        raise FileNotFoundError(f"File `{file_name}` does not exist.")
def color_val(color):
    """Convert a color given as str, Color, tuple, int or ndarray to a 3-tuple."""
    if isinstance(color, str):
        return Color[color].value
    if isinstance(color, Color):
        return color.value
    if isinstance(color, tuple):
        assert len(color) == 3
        for channel in color:
            assert 0 <= channel <= 255
        return color
    if isinstance(color, int):
        assert 0 <= color <= 255
        return color, color, color
    if isinstance(color, np.ndarray):
        assert color.ndim == 1 and color.size == 3
        assert np.all((color >= 0) & (color <= 255))
        color = color.astype(np.uint8)
        return tuple(color)
    raise TypeError(f'Invalid type for color: {type(color)}')
def imread(image, mode=None):
    """Read an image from a path, or pass an ndarray through unchanged."""
    if isinstance(image, pathlib.Path):
        image = str(image)

    if isinstance(image, np.ndarray):
        pass
    elif isinstance(image, str):
        check_file_exist(image)
        image = Image.open(image)
        if mode:
            image = np.array(image.convert(mode))
    else:
        raise TypeError("Image must be a `ndarray`, `str` or Path object.")

    return image
def imwrite(image, image_path, auto_mkdir=True):
    """Write an image (ndarray) to image_path, creating parent dirs if needed."""
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(image_path))
        if dir_name != '':
            dir_name = os.path.expanduser(dir_name)
            os.makedirs(dir_name, mode=0o777, exist_ok=True)

    image = Image.fromarray(image)
    image.save(image_path)
def imshow(img, win_name='', wait_time=0):
    """Show an image in an OpenCV window."""
    cv2.imshow(win_name, imread(img))
    if wait_time == 0:  # prevent from hanging if the window was closed
        while True:
            ret = cv2.waitKey(1)
            closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
            # break if the user closed the window or pressed any key
            if closed or ret != -1:
                break
    else:
        ret = cv2.waitKey(wait_time)
def show_result(img: str,
                result: Dict[int, float],
                text_color: str = 'green',
                font_scale: float = 0.5,
                row_width: int = 20,
                show: bool = False,
                win_name: str = '',
                wait_time: int = 0,
                out_file: Optional[str] = None) -> None:
    """Mark the prediction results on the picture."""
    img = imread(img, mode="RGB")
    img = img.copy()
    x, y = 0, row_width
    text_color = color_val(text_color)
    for k, v in result.items():
        if isinstance(v, float):
            v = f'{v:.2f}'
        label_text = f'{k}: {v}'
        cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
                    font_scale, text_color)
        y += row_width
    if out_file:
        show = False
        imwrite(img, out_file)
    if show:
        imshow(img, win_name, wait_time)
def index2label():
    """Build a mapping from class index to class name for the ImageNet dataset."""
    metafile = os.path.join(data_path, "ILSVRC2012_devkit_t12/data/meta.mat")
    meta = io.loadmat(metafile, squeeze_me=True)['synsets']

    nums_children = list(zip(*meta))[4]
    # keep only leaf synsets (the 1000 actual classes)
    meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]

    _, wnids, classes = list(zip(*meta))[:3]
    class_names = [tuple(clss.split(', ')) for clss in classes]
    wnid2class = {wnid: clss for wnid, clss in zip(wnids, class_names)}
    wnid2class_sorted = sorted(wnid2class.items(), key=lambda x: x[0])

    mapping = {}
    for index, (_, class_name) in enumerate(wnid2class_sorted):
        mapping[index] = class_name[0]

    return mapping
# Read data for inference
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
    image = image["image"]
    image = ms.Tensor(image)
    prob = model.predict(image)
    label = np.argmax(prob.asnumpy(), axis=1)
    mapping = index2label()
    output = {int(label): mapping[int(label)]}
    print(output)
show_result(img="./dataset/infer/n01440764/ILSVRC2012_test_00000279.JPEG",
            result=output,
            out_file="./dataset/infer/ILSVRC2012_test_00000279.JPEG")
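Because out_file is set here, show_result saves the annotated image next to the original rather than opening a display window. If the checkpoint loaded correctly, the dictionary printed above should map the predicted index to the Doberman class this case expects.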
Learning Reflections
Studying the Vision Transformer (ViT) showed me how the Transformer architecture can be applied to image classification and achieve excellent performance without any convolution operations. ViT splits the input image into fixed-size patches and feeds them into the Transformer as a sequence; through multi-head self-attention, the model can effectively capture the global features of an image. In practice, I learned how to use the MindSpore framework for data processing, model training, and inference, including data augmentation, defining the loss function and optimizer, and evaluating and visualizing the model. Implementing and debugging ViT gave me not only a grasp of its theoretical foundations but also stronger hands-on skills in deep learning and computer vision.