第一步 导包
python
import os
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms,models,datasets
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image
第二步 自定义数据集
python
# Root of the flower dataset and its per-split image folders.
# os.path.join avoids the doubled separator ("./flower_data//train_filelist")
# that plain string concatenation produced from the trailing "/".
data_dir = "./flower_data/"
train_dir = os.path.join(data_dir, "train_filelist")
valid_dir = os.path.join(data_dir, "val_filelist")
python
from torch.utils.data import Dataset, DataLoader


class FlowerDataset(Dataset):
    """Image-classification dataset described by an annotation text file.

    Each non-blank line of ``ann_file`` holds an image file name (relative to
    ``root_dir``) and an integer class label separated by whitespace, e.g.
    ``image_06734.jpg 0``.

    Args:
        root_dir: Directory the file names in ``ann_file`` are relative to.
        ann_file: Path of the annotation text file.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, root_dir, ann_file, transform=None):
        self.ann_file = ann_file
        self.root_dir = root_dir
        # {file name: np.int64 label}, in annotation-file order.
        self.img_label = self.load_annotations()
        self.img = [os.path.join(self.root_dir, img) for img in self.img_label]
        self.label = list(self.img_label.values())
        self.transform = transform

    def __len__(self):
        return len(self.img)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is the (optionally transformed) RGB image; ``label`` is a
        0-d int64 tensor.
        """
        # The context manager closes the file handle as soon as the pixels are
        # decoded, instead of leaking one open file per sample. Converting to
        # RGB guarantees 3 channels even for grayscale/RGBA/palette files.
        with Image.open(self.img[idx]) as img:
            image = img.convert("RGB")
        label = self.label[idx]
        if self.transform:
            image = self.transform(image)
        label = torch.from_numpy(np.array(label))
        return image, label

    def load_annotations(self):
        """Parse ``self.ann_file`` into ``{file name: np.int64 label}``.

        Blank lines are skipped and any run of whitespace separates the two
        fields.

        Raises:
            ValueError: if a non-blank line does not have exactly two fields.
        """
        data_infos = {}
        with open(self.ann_file, encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                parts = line.split()
                if not parts:
                    continue  # tolerate blank / trailing empty lines
                if len(parts) != 2:
                    raise ValueError(
                        f"{self.ann_file}:{lineno}: expected '<filename> <label>', "
                        f"got {line.rstrip()!r}"
                    )
                filename, gt_label = parts
                data_infos[filename] = np.array(gt_label, dtype=np.int64)
        return data_infos
注：ann_file 的内容格式如下（每行为“图片文件名 类别标签”，中间以空格分隔）：
![](https://i-blog.csdnimg.cn/direct/b8d35b9515b94504bd10f684f0c45624.png)
第三步 自定义transform
python
# Per-channel (R, G, B) mean / std used by the Normalize step of both pipelines.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

data_transforms = dict(
    # Training pipeline: resize (shorter side -> 64 px), random rotation within
    # +/-45 degrees, 64x64 center crop, random horizontal and vertical flips
    # (p=0.5 each), then tensor conversion and normalization.
    train=transforms.Compose(
        [
            transforms.Resize(64),
            transforms.RandomRotation(45),
            transforms.CenterCrop(64),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
        ]
    ),
    # Validation pipeline: deterministic resize + 64x64 center crop only.
    valid=transforms.Compose(
        [
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
            transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
        ]
    ),
)
第四步 根据自定义Dataset实例化DataLoader
①实例化Dataset
python
# Each split resolves its file names against its own image folder. The
# validation set previously used train_dir, so every val.txt entry was looked
# up inside the training folder (and valid_dir was never used).
# Annotation paths are derived from data_dir so a single setting relocates
# the whole dataset.
train_dataset = FlowerDataset(
    root_dir=train_dir,
    ann_file=os.path.join(data_dir, "train.txt"),
    transform=data_transforms["train"],
)
valid_dataset = FlowerDataset(
    root_dir=valid_dir,
    ann_file=os.path.join(data_dir, "val.txt"),
    transform=data_transforms["valid"],
)
②实例化DataLoader
python
# Shuffle only the training data; evaluation metrics do not depend on order,
# and a fixed validation order keeps batches reproducible between runs.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False)
③验证图片是否加载正确
python
# Grab one training batch. DataLoader iterators implement the Python 3
# iterator protocol, so use the builtin next(); the Python-2-style `.next()`
# method is not available on current PyTorch releases.
image, label = next(iter(train_loader))
# First sample: CHW tensor -> HWC numpy array, the layout imshow expects.
sample = image[0].permute((1, 2, 0)).numpy()
# Undo Normalize(mean, std): x_original = x * std + mean.
sample = sample * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]
# Floating-point round-off in the normalize/denormalize round trip can land
# values just outside [0, 1], which makes imshow warn and clip; clip here.
sample = np.clip(sample, 0, 1)
plt.imshow(sample)
plt.show()
print('Label is: {}'.format(label[0].item()))
第五步 训练
①前置准备
python
# Phase name -> loader, consumed by the training loop.
dataloaders = dict(train=train_loader, valid=val_loader)
model_name = "resnet"
feature_extract = True  # NOTE(review): not read by any code visible here

# Train on the first CUDA GPU when one is present, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# ResNet-18 (no pretrained weights requested) with its final fully-connected
# layer replaced by a 102-output linear classifier.
model_ft = models.resnet18()
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Sequential(nn.Linear(num_ftrs, 102))

# Optimizer over all parameters, LR decayed x0.1 every 7 scheduler steps,
# and cross-entropy loss for classification.
optimizer_ft = optim.Adam(params=model_ft.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer_ft, step_size=7, gamma=0.1)
criterion = nn.CrossEntropyLoss()
②定义训练函数
python
def _run_phase(model, loader, criterion, optimizer, device, is_train):
    """Make one pass over ``loader``; return ``(mean_loss, accuracy)``.

    Gradients are computed and ``optimizer`` is stepped only when ``is_train``
    is true. ``accuracy`` is a 0-d float64 tensor on ``device``.

    Raises:
        ValueError: if ``loader.dataset`` is empty (the averages are undefined).
    """
    n_samples = len(loader.dataset)
    if n_samples == 0:
        raise ValueError("dataloader has an empty dataset")
    model.train(is_train)  # train(False) is exactly model.eval()
    running_loss = 0.0
    running_corrects = 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # Build the autograd graph only while training.
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)
            if is_train:
                loss.backward()
                optimizer.step()
        # Assumes criterion averages over the batch (CrossEntropyLoss default);
        # re-weighting by batch size makes the epoch mean per-sample.
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
    return running_loss / n_samples, running_corrects.double() / n_samples


def train_model(model, dataloaders, criterion, optimizer, num_epochs=25,
                is_inception=False, filename="best.pth", scheduler=None,
                device=None):
    """Train ``model`` and keep the weights with the best validation accuracy.

    Every epoch runs a "train" then a "valid" pass over ``dataloaders``. Each
    time validation accuracy improves, a checkpoint with ``state_dict``,
    ``best_acc`` and ``optimizer`` is written to ``filename``.

    Args:
        model: The network to train; it is moved to ``device``.
        dataloaders: Mapping with "train" and "valid" DataLoaders.
        criterion: Loss function ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of epochs.
        is_inception: Accepted for compatibility; not used.
        filename: Path of the best-model checkpoint.
        scheduler: LR scheduler stepped once per epoch. Defaults to the
            module-level ``scheduler`` if one exists; ``ReduceLROnPlateau`` is
            given the epoch's validation loss, others are stepped without args.
        device: torch device to train on. Defaults to the module-level
            ``device``.

    Returns:
        ``(model, val_acc_history, train_acc_history, valid_losses,
        train_losses, LRs)``. ``model`` has the best validation weights
        loaded. ``LRs`` holds the initial learning rate plus one entry per
        epoch.
    """
    # Fall back to the objects created in the setup cell so existing calls
    # that pass neither argument keep working.
    if scheduler is None:
        scheduler = globals().get("scheduler")
    if device is None:
        device = globals()["device"]

    since = time.time()
    best_acc = 0
    model.to(device)
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]["lr"]]
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch, num_epochs - 1))
        print("-" * 10)
        # Train, then validate.
        for phase in ["train", "valid"]:
            epoch_loss, epoch_acc = _run_phase(
                model, dataloaders[phase], criterion, optimizer, device,
                is_train=(phase == "train"),
            )
            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            if phase == "valid":
                # Keep (and checkpoint) the best model seen so far.
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                    state = {
                        "state_dict": model.state_dict(),
                        "best_acc": best_acc,
                        "optimizer": optimizer.state_dict(),
                    }
                    torch.save(state, filename)
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
            else:
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)

        # Learning-rate decay, once per epoch. Epoch-based schedulers such as
        # StepLR must be stepped WITHOUT arguments: the previous
        # scheduler.step(epoch_loss) passed the loss as the deprecated `epoch`
        # argument, so the LR followed the loss value instead of the epoch
        # count. Only ReduceLROnPlateau expects a metric.
        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(valid_losses[-1])
            else:
                scheduler.step()
        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    # Restore the best validation weights as the final model.
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs
③训练模型
python
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=20, filename='best.pth')
![](https://i-blog.csdnimg.cn/direct/46ca95fb1502418b9a49a67a73693e22.png)