【深度学习】DataLoader自定义数据集制作

第一步 导包

python 复制代码
import os
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms,models,datasets
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

第二步 自定义数据集

python 复制代码
data_dir = "./flower_data/"
train_dir = data_dir + "/train_filelist"
valid_dir = data_dir + "/val_filelist"
python 复制代码
from torch.utils.data import Dataset,DataLoader
class FlowerDataset(Dataset):
    def __init__(self,root_dir,ann_file,transform=None):
        self.ann_file = ann_file
        self.root_dir = root_dir
        self.img_label = self.load_annotations()
        self.img = [os.path.join(self.root_dir,img) for img in list(self.img_label.keys())]
        self.label = [label for label in list(self.img_label.values())]
        self.transform = transform
        
    def __len__(self):
        return len(self.img)
    
    def __getitem__(self,idx):
        image = Image.open(self.img[idx])
        label = self.label[idx]
        if self.transform:
            image = self.transform(image)
        label = torch.from_numpy(np.array(label))
        return image,label
    
    def load_annotations(self):
        data_infos = {}
        with open(self.ann_file) as f:
            samples = [x.strip().split(" ") for x in f.readlines()]
            for filename,gt_label in samples:
                data_infos[filename] = np.array(gt_label,dtype=np.int64)
        return data_infos

注:ann_file内容格式如下

第三步 自定义transform

python 复制代码
data_transforms = {
    "train":
        transforms.Compose([
            transforms.Resize(64),
            transforms.RandomRotation(45),
            transforms.CenterCrop(64),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
        ]),
    "valid":
        transforms.Compose([
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
        ])
}

第四步 根据自定义Dataset实例化DataLoader

①实例化Dataset

python 复制代码
train_dataset = FlowerDataset(root_dir=train_dir,ann_file="./flower_data/train.txt",transform=data_transforms["train"])
valid_dataset = FlowerDataset(root_dir=train_dir,ann_file="./flower_data/val.txt",transform=data_transforms["valid"])

②实例化DataLoader

python 复制代码
train_loader = DataLoader(train_dataset,batch_size=64,shuffle=True)
val_loader = DataLoader(valid_dataset,batch_size=64,shuffle=True)

③验证图片是否加载正确

python 复制代码
image, label = iter(train_loader).next()
sample = image[0].squeeze()
sample = sample.permute((1, 2, 0)).numpy()
sample *= [0.229, 0.224, 0.225]
sample += [0.485, 0.456, 0.406]
plt.imshow(sample)
plt.show()
print('Label is: {}'.format(label[0].numpy()))


第五步 训练

①前置准备

python 复制代码
dataloaders = {"train":train_loader,"valid":val_loader}

model_name = "resnet"
feature_extract = True

# 是否用GPU训练
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 使用模型
model_ft = models.resnet18()
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Sequential(nn.Linear(num_ftrs, 102))

# 优化器设置
optimizer_ft = optim.Adam(model_ft.parameters(),lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft,step_size=7,gamma=0.1)
criterion = nn.CrossEntropyLoss()

②自定义模型

python 复制代码
def train_model(model,dataloaders,criterion,optimizer,num_epochs=25,is_inception=False,filename="best.pth"):
    since = time.time()
    best_acc = 0
    model.to(device)
    
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]["lr"]]
    
    best_model_wts = copy.deepcopy(model.state_dict())
    
    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch,num_epochs-1))
        print("-"*10)
        
        # 训练和验证
        for phase in ["train","valid"]:
            if phase == "train":
                model.train()
            else:
                model.eval()
                
            running_loss = 0.0
            running_corrects = 0
            
            # 遍历所有数据
            for inputs,labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                # 清零
                optimizer.zero_grad()
                # 只有训练的时候计算和更新梯度
                with torch.set_grad_enabled(phase == "train"):
                    outputs = model(inputs)
                    loss = criterion(outputs,labels)
                    _,preds = torch.max(outputs,1)
                    
                    if phase == "train":
                        loss.backward()
                        optimizer.step()
                        
                # 计算损失
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds==labels.data)
                
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            
            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            
            
            # 得到最好的那次模型
            if phase=="valid" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    "state_dict":model.state_dict(),
                    "best_acc":best_acc,
                    "optimizer":optimizer.state_dict()
                }
                torch.save(state,filename)
                
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                valid_losses.append(epoch_loss)
                scheduler.step(epoch_loss)#学习率衰减
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)
                
        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    

    # 训练完后用最好的一次当做模型最终的结果,等着一会测试
    model.load_state_dict(best_model_wts)
    return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs 

③训练模型

python 复制代码
model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=20, filename='best.pth')
相关推荐
Hero_HL5 分钟前
Towards Open World Object Detection概述(论文)
人工智能·目标检测·计算机视觉
远方16099 分钟前
10-Oracle 23 ai Vector Search 概述和参数
人工智能·oracle
fydw_71515 分钟前
Celery 核心概念详解及示例
人工智能·机器学习
audyxiao00126 分钟前
计算机视觉顶刊《International Journal of Computer Vision》2025年5月前沿热点可视化分析
图像处理·人工智能·opencv·目标检测·计算机视觉·大模型·视觉检测
蹦蹦跳跳真可爱58930 分钟前
Python----目标检测(训练YOLOV8网络)
人工智能·python·yolo·目标检测
张较瘦_32 分钟前
[论文阅读] 人工智能+项目管理 | 当 PMBOK 遇见 AI:传统项目管理框架的破局之路
论文阅读·人工智能
Q同学35 分钟前
Qwen3开源最新Embedding模型
深度学习·神经网络·llm
Leinwin36 分钟前
行业案例 | ASOS 借助 Azure AI Foundry(国际版)为年轻时尚爱好者打造惊喜体验
人工智能·microsoft·azure
用户849137175471643 分钟前
🚀 为什么猫和狗更像?用“向量思维”教会 AI 懂语义!
人工智能·llm
AI大模型知识1 小时前
Qwen3+Ollama本地部署MCP初体验
人工智能·llm