【深度学习】DataLoader自定义数据集制作

第一步 导包

python 复制代码
import os
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms,models,datasets
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

第二步 自定义数据集

python 复制代码
data_dir = "./flower_data/"
train_dir = data_dir + "/train_filelist"
valid_dir = data_dir + "/val_filelist"
python 复制代码
from torch.utils.data import Dataset,DataLoader
class FlowerDataset(Dataset):
    def __init__(self,root_dir,ann_file,transform=None):
        self.ann_file = ann_file
        self.root_dir = root_dir
        self.img_label = self.load_annotations()
        self.img = [os.path.join(self.root_dir,img) for img in list(self.img_label.keys())]
        self.label = [label for label in list(self.img_label.values())]
        self.transform = transform
        
    def __len__(self):
        return len(self.img)
    
    def __getitem__(self,idx):
        image = Image.open(self.img[idx])
        label = self.label[idx]
        if self.transform:
            image = self.transform(image)
        label = torch.from_numpy(np.array(label))
        return image,label
    
    def load_annotations(self):
        data_infos = {}
        with open(self.ann_file) as f:
            samples = [x.strip().split(" ") for x in f.readlines()]
            for filename,gt_label in samples:
                data_infos[filename] = np.array(gt_label,dtype=np.int64)
        return data_infos

注:ann_file内容格式如下

第三步 自定义transform

python 复制代码
data_transforms = {
    "train":
        transforms.Compose([
            transforms.Resize(64),
            transforms.RandomRotation(45),
            transforms.CenterCrop(64),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
        ]),
    "valid":
        transforms.Compose([
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
        ])
}

第四步 根据自定义Dataset实例化DataLoader

①实例化Dataset

python 复制代码
train_dataset = FlowerDataset(root_dir=train_dir,ann_file="./flower_data/train.txt",transform=data_transforms["train"])
valid_dataset = FlowerDataset(root_dir=train_dir,ann_file="./flower_data/val.txt",transform=data_transforms["valid"])

②实例化DataLoader

python 复制代码
train_loader = DataLoader(train_dataset,batch_size=64,shuffle=True)
val_loader = DataLoader(valid_dataset,batch_size=64,shuffle=True)

③验证图片是否加载正确

python 复制代码
image, label = iter(train_loader).next()
sample = image[0].squeeze()
sample = sample.permute((1, 2, 0)).numpy()
sample *= [0.229, 0.224, 0.225]
sample += [0.485, 0.456, 0.406]
plt.imshow(sample)
plt.show()
print('Label is: {}'.format(label[0].numpy()))


第五步 训练

①前置准备

python 复制代码
dataloaders = {"train":train_loader,"valid":val_loader}

model_name = "resnet"
feature_extract = True

# 是否用GPU训练
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 使用模型
model_ft = models.resnet18()
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Sequential(nn.Linear(num_ftrs, 102))

# 优化器设置
optimizer_ft = optim.Adam(model_ft.parameters(),lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft,step_size=7,gamma=0.1)
criterion = nn.CrossEntropyLoss()

②自定义模型

python 复制代码
def train_model(model,dataloaders,criterion,optimizer,num_epochs=25,is_inception=False,filename="best.pth"):
    since = time.time()
    best_acc = 0
    model.to(device)
    
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]["lr"]]
    
    best_model_wts = copy.deepcopy(model.state_dict())
    
    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch,num_epochs-1))
        print("-"*10)
        
        # 训练和验证
        for phase in ["train","valid"]:
            if phase == "train":
                model.train()
            else:
                model.eval()
                
            running_loss = 0.0
            running_corrects = 0
            
            # 遍历所有数据
            for inputs,labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                # 清零
                optimizer.zero_grad()
                # 只有训练的时候计算和更新梯度
                with torch.set_grad_enabled(phase == "train"):
                    outputs = model(inputs)
                    loss = criterion(outputs,labels)
                    _,preds = torch.max(outputs,1)
                    
                    if phase == "train":
                        loss.backward()
                        optimizer.step()
                        
                # 计算损失
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds==labels.data)
                
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            
            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            
            
            # 得到最好的那次模型
            if phase=="valid" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    "state_dict":model.state_dict(),
                    "best_acc":best_acc,
                    "optimizer":optimizer.state_dict()
                }
                torch.save(state,filename)
                
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                scheduler.step(epoch_loss)#学习率衰减
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)
                
        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    

    # 训练完后用最好的一次当做模型最终的结果,等着一会测试
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs 

③训练模型

python 复制代码
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=20, filename='best.pth')
相关推荐
风象南8 分钟前
我把大脑开源给了AI
人工智能·后端
Johny_Zhao2 小时前
OpenClaw安装部署教程
linux·人工智能·ai·云计算·系统运维·openclaw
飞哥数智坊2 小时前
我帮你读《一人公司(OPC)发展研究》
人工智能
冬奇Lab6 小时前
OpenClaw 源码精读(3):Agent 执行引擎——AI 如何「思考」并与真实世界交互?
人工智能·aigc
没事勤琢磨8 小时前
如何让 OpenClaw 控制使用浏览器:让 AI 像真人一样操控你的浏览器
人工智能
用户5191495848458 小时前
CrushFTP 认证绕过漏洞利用工具 (CVE-2024-4040)
人工智能·aigc
牛马摆渡人5288 小时前
OpenClaw实战--Day1: 本地化
人工智能
前端小豆8 小时前
玩转 OpenClaw:打造你的私有 AI 助手网关
人工智能
BugShare8 小时前
写一个你自己的Agent Skills
人工智能·程序员
机器之心9 小时前
英伟达护城河被AI攻破,字节清华CUDA Agent,让人人能搓CUDA内核
人工智能·openai