简介
在实际应用场景中,由于训练数据集不足,很少有人会从头开始训练整个网络。普遍做法是使用在大型数据集上预训练的模型,然后将其作为初始化权重或固定特征提取器,用于特定任务。本章使用迁移学习方法对ImageNet数据集中狼和狗图像进行分类。
数据准备
下载数据集
使用download接口下载狗与狼分类数据集,数据集目录结构如下:
datasets-Canidae/data/
└── Canidae
    ├── train
    │   ├── dogs
    │   └── wolves
    └── val
        ├── dogs
        └── wolves
加载数据集
使用 `mindspore.dataset.ImageFolderDataset` 接口来加载数据集,并进行图像增强操作。
python
import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
# --- Input pipeline settings ---
image_size = 224  # side length (pixels) of the square crop fed to the network
batch_size = 18   # number of samples per batch
workers = 4       # parallel workers for dataset loading and mapping

# --- Optimisation settings ---
num_epochs = 5    # number of passes over the training set
lr = 0.001        # learning rate handed to the optimizer
momentum = 0.9    # momentum coefficient handed to the optimizer

# --- Dataset locations (one sub-folder per class) ---
data_path_train = "./datasets-Canidae/data/Canidae/train/"
data_path_val = "./datasets-Canidae/data/Canidae/val/"
def create_dataset_canidae(dataset_path, usage):
    """Build a batched image pipeline from a class-per-folder directory.

    Args:
        dataset_path: Directory passed to ``ds.ImageFolderDataset``.
        usage: ``"train"`` selects the random-crop/flip augmentation;
            any other value selects the deterministic resize + centre-crop
            pipeline.

    Returns:
        The mapped dataset, batched with the module-level ``batch_size``.
    """
    loader = ds.ImageFolderDataset(dataset_path, num_parallel_workers=workers, shuffle=True)

    # Per-channel statistics expressed on the 0-255 pixel scale
    # (presumably the standard ImageNet values).
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    # Extra pixels added to the resize target before centre-cropping.
    scale = 32

    if usage == "train":
        # Decode, randomly crop and resize in one op, then flip at random.
        leading_ops = [
            vision.RandomCropDecodeResize(size=image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            vision.RandomHorizontalFlip(prob=0.5),
        ]
    else:
        # Deterministic path: decode, resize slightly larger, centre-crop.
        leading_ops = [
            vision.Decode(),
            vision.Resize(image_size + scale),
            vision.CenterCrop(image_size),
        ]

    # Both modes finish with normalisation and an HWC -> CHW layout change.
    trans = leading_ops + [
        vision.Normalize(mean=mean, std=std),
        vision.HWC2CHW(),
    ]

    loader = loader.map(operations=trans, input_columns='image', num_parallel_workers=workers)
    return loader.batch(batch_size)
# Training split uses the augmenting pipeline; validation uses the
# deterministic one.
dataset_train = create_dataset_canidae(dataset_path=data_path_train, usage="train")
dataset_val = create_dataset_canidae(dataset_path=data_path_val, usage="val")
数据集可视化
通过 `create_dict_iterator` 接口创建数据迭代器,使用 `next` 迭代访问数据集,并进行可视化展示。
python
import matplotlib.pyplot as plt
import numpy as np

# Pull a single batch from the training pipeline for a visual sanity check.
data = next(dataset_train.create_dict_iterator())
images = data["image"]
labels = data["label"]

# Label index -> folder name (assumes labels follow the sorted folder
# order "dogs", "wolves" -- confirm against the dataset).
class_name = {0: "dogs", 1: "wolves"}

# Per-channel statistics on the 0-1 scale. They do not change between
# samples, so they are built once instead of on every loop iteration.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

plt.figure(figsize=(5, 5))
# Show at most four samples; a batch shorter than four must not raise.
for i in range(min(4, images.shape[0])):
    # CHW -> HWC, the layout matplotlib expects.
    data_image = np.transpose(images[i].asnumpy(), (1, 2, 0))
    # Undo the pipeline's normalisation, which yields values on the 0-1 scale.
    data_image = np.clip(std * data_image + mean, 0, 1)
    plt.subplot(2, 2, i + 1)
    plt.imshow(data_image)
    plt.title(class_name[int(labels[i].asnumpy())])
    plt.axis("off")
plt.show()
构建ResNet50模型
定义残差块
定义基础的ResidualBlockBase和扩展的ResidualBlock,包含卷积层、归一化层和ReLU激活函数。
搭建ResNet50网络
包括初始的卷积层、池化层、四个残差块组、平均池化层和全连接层。
python
from typing import Type, Union, List, Optional
from mindspore import nn, train
from mindspore.common.initializer import Normal
class ResidualBlockBase(nn.Cell):
    """Two-convolution residual block (3x3 -> 3x3) with a skip connection.

    Args:
        in_channel: Channels of the incoming feature map.
        out_channel: Channels produced by both convolutions.
        stride: Stride of the first convolution.
        norm: Optional normalisation cell; a ``BatchNorm2d(out_channel)`` is
            created when this is falsy.
        down_sample: Optional cell applied to the input so the shortcut
            matches the main path before the addition.
    """

    # Output channels equal ``out_channel * expansion``.
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, norm=None, down_sample=None):
        super().__init__()
        conv_init = Normal(mean=0, sigma=0.02)

        if norm:
            self.norm = norm
        else:
            self.norm = nn.BatchNorm2d(out_channel)

        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, weight_init=conv_init)
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, weight_init=conv_init)
        self.relu = nn.ReLU()
        self.down_sample = down_sample

    def construct(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        # Main path: conv -> norm -> relu -> conv -> norm.
        # The same ``self.norm`` cell is applied after both convolutions.
        main = self.relu(self.norm(self.conv1(x)))
        main = self.norm(self.conv2(main))

        # Shortcut: identity unless a projection was supplied.
        if self.down_sample is None:
            shortcut = x
        else:
            shortcut = self.down_sample(x)

        return self.relu(main + shortcut)
class ResidualBlock(nn.Cell):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus shortcut.

    Args:
        in_channel: Channels of the incoming feature map.
        out_channel: Channels of the two inner convolutions; the block
            outputs ``out_channel * expansion`` channels.
        stride: Stride of the 3x3 convolution.
        down_sample: Optional cell applied to the input so the shortcut
            matches the main path before the addition.
    """

    # The last 1x1 convolution widens the feature map by this factor.
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, down_sample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=1, weight_init=Normal(mean=0, sigma=0.02))
        self.norm1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=stride, weight_init=Normal(mean=0, sigma=0.02))
        self.norm2 = nn.BatchNorm2d(out_channel)
        self.conv3 = nn.Conv2d(out_channel, out_channel * self.expansion, kernel_size=1, weight_init=Normal(mean=0, sigma=0.02))
        self.norm3 = nn.BatchNorm2d(out_channel * self.expansion)
        self.relu = nn.ReLU()
        self.down_sample = down_sample

    def construct(self, x):
        """Return ``relu(bottleneck(x) + shortcut(x))``."""
        identity = x

        # 1x1 convolution from in_channel to out_channel.
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu(out)

        # 3x3 convolution. It must consume the previous stage's output, not
        # the block input: conv2 is built for ``out_channel`` input channels.
        out = self.conv2(out)
        out = self.norm2(out)
        out = self.relu(out)

        # 1x1 convolution widening to out_channel * expansion.
        out = self.conv3(out)
        out = self.norm3(out)

        # Project the shortcut when its shape differs from the main path.
        if self.down_sample is not None:
            identity = self.down_sample(x)

        out += identity
        out = self.relu(out)
        return out
def make_layer(last_out_channel, block, channel, block_nums, stride=1):
    """Stack ``block_nums`` residual blocks into one sequential stage.

    Args:
        last_out_channel: Channels produced by the previous stage.
        block: Residual block class; must expose an ``expansion`` attribute.
        channel: Base channel count passed to every block.
        block_nums: Total number of blocks in the stage.
        stride: Stride of the first block (later blocks use the default).

    Returns:
        An ``nn.SequentialCell`` holding the blocks in order.
    """
    stage_out_channel = channel * block.expansion

    # The first block needs a projection shortcut whenever it changes the
    # spatial size or the channel count.
    if stride == 1 and last_out_channel == stage_out_channel:
        down_sample = None
    else:
        down_sample = nn.SequentialCell([
            nn.Conv2d(last_out_channel, stage_out_channel, kernel_size=1, stride=stride, weight_init=Normal(mean=0, sigma=0.02)),
            nn.BatchNorm2d(stage_out_channel, gamma_init=Normal(mean=1, sigma=0.02))
        ])

    head_block = block(last_out_channel, channel, stride=stride, down_sample=down_sample)
    # Remaining blocks take the stage's own output width as input.
    tail_blocks = [block(stage_out_channel, channel) for _ in range(1, block_nums)]

    return nn.SequentialCell([head_block] + tail_blocks)
class ResNet(nn.Cell):
    """ResNet backbone with a linear classification head.

    Args:
        block: Residual block class used by all four stages; must expose an
            ``expansion`` attribute.
        layer_nums: Sequence of four ints, the number of blocks per stage.
        num_classes: Number of output units of the final dense layer.
        input_channel: Number of input features of the final dense layer
            (not the number of image channels; the stem convolution is fixed
            at 3 input channels).
    """

    def __init__(self, block, layer_nums, num_classes, input_channel):
        super(ResNet, self).__init__()
        self.relu = nn.ReLU()
        # Stem: 7x7 stride-2 convolution followed by a 3x3 stride-2 max pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=Normal(mean=0, sigma=0.02))
        self.norm = nn.BatchNorm2d(64)
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        # Four residual stages; every stage after the first halves the
        # spatial size via stride 2.
        self.layer1 = make_layer(64, block, 64, layer_nums[0])
        self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)
        self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)
        self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2)
        # Global average pooling to 1x1. A default nn.AvgPool2d() has
        # kernel_size=1 and leaves the spatial size untouched, so any input
        # whose final feature map is larger than 1x1 (e.g. 224x224 images)
        # would not match the dense layer's input size after flattening.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)

    def construct(self, x):
        """Run stem, residual stages, global pooling and the dense head."""
        x = self.conv1(x)
        x = self.norm(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avg_pool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x
def _resnet(model_url, block, layers, num_classes, input_channel):
    """Construct a ``ResNet`` from the given block type and stage depths.

    Args:
        model_url: Accepted for interface compatibility; not used here.
        block: Residual block class forwarded to ``ResNet``.
        layers: Per-stage block counts forwarded to ``ResNet``.
        num_classes: Output size of the classification head.
        input_channel: Input feature count of the classification head.

    Returns:
        The newly constructed ``ResNet`` instance.
    """
    return ResNet(block, layers, num_classes, input_channel)
def resnet50(num_classes=2, input_channel=2048):
    """Build a ResNet-50: bottleneck blocks with stage depths 3, 4, 6, 3.

    Args:
        num_classes: Output size of the classification head.
        input_channel: Input feature count of the classification head.

    Returns:
        The network produced by ``_resnet``.
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet("", ResidualBlock, stage_depths, num_classes, input_channel)
模型加载与冻结预训练权重
加载预训练的ResNet50模型权重,并将骨干网络的全部参数设置为不参与梯度更新(`requires_grad = False`),使其作为固定特征提取器。
python
from mindspore import load_checkpoint, load_param_into_net
def load_pretrained_weights(model, ckpt_path):
    """Load checkpoint weights into ``model``, skipping mismatched shapes.

    Checkpoint entries whose name exists in ``model`` but whose shape
    differs (for example a classification head trained for a different
    number of classes) are dropped before loading, so that one incompatible
    tensor does not prevent the rest of the weights from being loaded.
    The skipped names are printed so the omission is visible.

    Args:
        model: Network whose parameters are to be filled in.
        ckpt_path: Path of the checkpoint file to read.
    """
    param_dict = load_checkpoint(ckpt_path)

    net_shapes = {param.name: tuple(param.shape) for param in model.get_parameters()}
    skipped = [
        name
        for name, value in param_dict.items()
        if name in net_shapes and tuple(value.shape) != net_shapes[name]
    ]
    for name in skipped:
        del param_dict[name]
    if skipped:
        print("Skipped checkpoint parameters with mismatched shapes:", skipped)

    load_param_into_net(model, param_dict)
# Checkpoint file holding the pretrained weights.
pretrained_ckpt_path = "resnet50.ckpt"

# Build the backbone with its default head and fill it from the checkpoint.
model = resnet50()
load_pretrained_weights(model, ckpt_path=pretrained_ckpt_path)

# Exclude every backbone parameter from gradient updates so the network
# acts as a fixed feature extractor.
for backbone_param in model.get_parameters():
    backbone_param.requires_grad = False
微调网络
微调ResNet50的最后一个全连接层以适应新任务。
python
class ResNetFinetune(nn.Cell):
    """Backbone feature extractor followed by a new dense classifier.

    The wrapped network's own ``fc`` layer is bypassed. Calling the backbone
    end to end would return class logits, not the pooled features the new
    head is sized for, so ``construct`` runs the backbone's feature stages
    directly and pools them itself.

    Args:
        model: Backbone exposing ``conv1``, ``norm``, ``relu``, ``max_pool``,
            ``layer1``..``layer4`` and an ``fc`` dense layer (as the
            ``ResNet`` class in this file does).
        num_classes: Number of output units of the new classifier.
    """

    def __init__(self, model, num_classes):
        super(ResNetFinetune, self).__init__()
        self.base_model = model
        # Size the new head from the backbone classifier's input width
        # (2048 for the ResNet-50 built by ``resnet50()``).
        self.fc = nn.Dense(model.fc.in_channels, num_classes)

    def construct(self, x):
        """Return logits of the new head for an image batch ``x``."""
        # Backbone stem.
        x = self.base_model.conv1(x)
        x = self.base_model.norm(x)
        x = self.base_model.relu(x)
        x = self.base_model.max_pool(x)
        # Residual stages.
        x = self.base_model.layer1(x)
        x = self.base_model.layer2(x)
        x = self.base_model.layer3(x)
        x = self.base_model.layer4(x)
        # Global average over the two spatial axes -> (batch, channels).
        x = x.mean(axis=(2, 3))
        x = self.fc(x)
        return x


finetune_model = ResNetFinetune(model, num_classes=2)
模型训练
定义训练参数和优化器,使用交叉熵损失函数。
python
import mindspore.nn as nn
# Cross-entropy over logits with integer (sparse) labels, averaged per batch.
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

# Only parameters still marked trainable are handed to the optimizer.
optimizer = nn.Momentum(
    finetune_model.trainable_params(),
    learning_rate=lr,
    momentum=momentum,
)

# NOTE: from here on ``model`` refers to the high-level training wrapper,
# no longer to the bare backbone network.
model = ms.Model(
    finetune_model,
    loss_fn,
    optimizer,
    metrics={"accuracy"},
)
训练和验证模型
进行模型训练并在验证集上进行评估。
python
# Fit on the training split; data sinking is disabled.
model.train(num_epochs, dataset_train, dataset_sink_mode=False)

# Evaluate on the validation split and report the accuracy metric.
metric = model.eval(dataset_val, dataset_sink_mode=False)
val_accuracy = metric['accuracy']
print("Validation accuracy: ", val_accuracy)
总结
迁移学习方法能有效利用预训练模型在新任务上的表现,通过微调模型权重,可以在有限的数据上获得较好的性能。通过对ResNet50模型进行特定任务的微调,实现了对狼和狗图像的分类,提高了模型的泛化能力。