文章目录
摘要
本周学习了Vision Transformer (ViT) 的基本原理及其实现,并完成了基于PyTorch的模型训练、验证和预测任务。深入理解了ViT如何将图像分割成patch作为输入序列,并结合Transformer Encoder处理。通过迁移学习在花类数据集上训练模型,并验证了模型在预测任务中的优越性能。
Abstract
This week, I studied the fundamental principles and implementation of Vision Transformer (ViT) and completed model training, validation, and prediction tasks using PyTorch. I gained a deep understanding of how ViT splits an image into patches as input sequences and processes them using the Transformer Encoder. By leveraging transfer learning, I trained the model on a flower dataset and validated its superior performance in prediction tasks.
Vision Transformer
1 原理
- 数据处理
我认为ViT的关键在于理解怎么将图片当作一个序列输入进模型之中。我们先看看ViT整体结构图,如下图所示
论文中提到将 224x224x3 的图像作为输入,将图像分为 16x16x3 大小的patch,也就是说将输入图像分为了 224 × 224 × 3 16 × 16 × 3 = 196 \frac{224×224×3}{16×16×3}=196 16×16×3224×224×3=196 个patch。其中每个patch拉直之后的维度为 16×16×3=768维,也就是Linear Projection of Flattened Patches层下面分割的小图像。
在具体实现中,使用卷积核大小为 16x16x3 、步距为16、卷积核个数为768的卷积层,就能将3维图像转换为Transformer所需要的输入token[组数,维度]。
-
全连接层
上述[196,768]的token将传入Linear Projection of Flattened Patches层,该层是 768x768 的全连接层,该层输出认为 196x768 。
-
位置编码
将经过全连接层后的输出进行位置编码,其位置编码和Transformer中的时序编码有异曲同工之妙,前者可以通过位置编码表示出token之间关于原输入图像的一些位置信息,后者可以表示输入先后的时序信息。
该模型位置编码通过类似于坐标的形式表达,直接于输入相加,不改变维度大小。如下图所示:
进行位置编码后,还需要加上一个特殊字符(最左输入0*),输入总组数从之前的196变为197,传入Transformer Encoder的token为[197,768]。
- Transformer Encoder
ViT采用的是Transformer中编码器进行叠加,但其中的参数数量有所不同。
经过位置编码和加入特殊字符的token[197,768]传入编码器,首先经过层归一化,再经过多头自注意力。这里的多头自注意力是采用12个头,也就是将768维分为12份,每份(Q、K、V)64维度,计算之后再进行合并为768维。
ViT中的编码器仍是采用残差连接,再经过一次层归一化后,就进入单个Transformer Encoder的最后一层MLP(多层感知机)。MLP将经过多头自注意力的输出维度升高4倍,即从768变为3072,最后再将维度降至768维
ViT中的编码器输入和输出都是768维,也就是在硬件运行的情况下一直叠加,论文中也是将该模块叠加了L块。
- 输出
ViT中的编码器输入和输出都是768维,也就是在硬件运行的情况下一直叠加,论文中也是将该模块叠加了L块。
最后,通过全连接层和softmax进行概率输出即可
2 代码
在理解完ViT的原理之后,我们来看看PyTorch代码如何实现。这里以ViT-base模型,输入图像 224x224x3,patch大小 16x16x3 为例
花类数据集:
https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgzy
训练模型代码如下,需要自行更改数据集路径和权重路径。
python
import os
import math
import argparse
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from my_dataset import MyDataSet
from vit_model import vit_base_patch16_224_in21k as create_model
from utils import read_split_data, train_one_epoch, evaluate
def main(args):
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
if os.path.exists("../weights") is False:
os.makedirs("../weights")
tb_writer = SummaryWriter()
train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)
data_transform = {
"train": transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
"val": transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}
# 实例化训练数据集
train_dataset = MyDataSet(images_path=train_images_path,
images_class=train_images_label,
transform=data_transform["train"])
# 实例化验证数据集
val_dataset = MyDataSet(images_path=val_images_path,
images_class=val_images_label,
transform=data_transform["val"])
batch_size = args.batch_size
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
print('Using {} dataloader workers every process'.format(nw))
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=True,
num_workers=nw,
collate_fn=train_dataset.collate_fn)
val_loader = torch.utils.data.DataLoader(val_dataset,
batch_size=batch_size,
shuffle=False,
pin_memory=True,
num_workers=nw,
collate_fn=val_dataset.collate_fn)
model = create_model(num_classes=args.num_classes, has_logits=False).to(device)
if args.weights != "":
assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
weights_dict = torch.load(args.weights, map_location=device)
# 删除不需要的权重
# del_keys = ['head.weight', 'head.bias'] if model.has_logits \
# else ['pre_logits.fc.weight', 'pre_logits.fc.bias', 'head.weight', 'head.bias']
del_keys = ['head.weight', 'head.bias']
for k in del_keys:
del weights_dict[k]
print(model.load_state_dict(weights_dict, strict=False))
if args.freeze_layers:
for name, para in model.named_parameters():
# 除head, pre_logits外,其他权重全部冻结
if "head" not in name and "pre_logits" not in name:
para.requires_grad_(False)
else:
print("training {}".format(name))
pg = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=5E-5)
# Scheduler https://arxiv.org/pdf/1812.01187.pdf
lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf # cosine
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
for epoch in range(args.epochs):
# train
train_loss, train_acc = train_one_epoch(model=model,
optimizer=optimizer,
data_loader=train_loader,
device=device,
epoch=epoch)
scheduler.step()
# validate
val_loss, val_acc = evaluate(model=model,
data_loader=val_loader,
device=device,
epoch=epoch)
tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
tb_writer.add_scalar(tags[0], train_loss, epoch)
tb_writer.add_scalar(tags[1], train_acc, epoch)
tb_writer.add_scalar(tags[2], val_loss, epoch)
tb_writer.add_scalar(tags[3], val_acc, epoch)
tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)
torch.save(model.state_dict(), "../weights/model-{}.pth".format(epoch))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--num_classes', type=int, default=5)
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=8)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--lrf', type=float, default=0.01)
# 数据集所在根目录
# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
parser.add_argument('--data_path', type=str, default='../data/flower_photos', help='path to dataset')
parser.add_argument('--model-name', default='', help='create model name')
# 预训练权重路径,如果不想载入就设置为空字符
parser.add_argument('--weights', type=str, default='../weights/vit_base_patch16_224.pth', help='path to initial weights')
# 是否冻结权重
parser.add_argument('--freeze-layers', type=bool, default=True)
parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')
opt = parser.parse_args()
main(opt)
训练结果如下:
因为是迁移学习的原因,只需要进行微调即可,所以9epoch之后准确率就达到97.9%了。
每训练一个epoch,就会将训练模型保存至weights文件夹,如下图所示
通过上述代码的训练之后,我们可以将保存的模型model-9.pth引入预测代码进行预测啦!需自行更改权重路径,以及需要测试的图片路径。
python
import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from vit_model import vit_base_patch16_224_in21k as create_model
def main():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_transform = transforms.Compose(
[transforms.Resize(254),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
# load image
img_path = "../data/Image/flower.png"
assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
img = Image.open(img_path)
img2 = img
plt.imshow(img)
plt.show()
img = img.convert('RGB')
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0) # [1, 3, 224, 224]
# read class_indict
json_path = 'class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
with open(json_path, "r") as f:
class_indict = json.load(f)
# create model
model = create_model(num_classes=5, has_logits=False).to(device) # num_classes=5:表示模型将被训练来识别5个不同的类别;has_logits=False:模型不直接输出logits,在实际应用中,这通常意味着模型的输出层之后可能会跟随一个softmax激活函数
# load model weights
model_weight_path = "../weights/model-9.pth" # 采用第10轮训练的参数
model.load_state_dict(torch.load(model_weight_path, map_location=device))
model.eval()
with torch.no_grad():
# predict class
output = torch.squeeze(model(img.to(device))).cpu()
predict = torch.softmax(output, dim=0)
predict_cla = torch.argmax(predict).numpy()
print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)], predict[predict_cla].numpy())
plt.title(print_res)
for i in range(len(predict)):
print("class: {:10} prob: {:.3}".format(class_indict[str(i)], predict[i].numpy()))
plt.imshow(img2)
plt.show()
if __name__ == '__main__':
main()
模型预测结果如下所示:
模型预测结果几乎100%为sunflowers