DataLoader自定义数据集制作

如何自定义数据集:

- 1.数据和标签的目录结构先搞定(得知道到哪读数据)

  • 2.写好读取数据和标签路径的函数(根据自己数据集情况来写)
  • 3.完成单个数据与标签读取函数(给dataloader举一个例子)

以花朵数据集为例:

- 原来数据集都是以文件夹为类别ID,现在咱们换一个套路,用txt文件指定数据路径与标签(实际情况基本都这样)

  • 这回咱们的任务就是在txt文件中获取图像路径与标签,然后把他们交给dataloader
  • 核心代码非常简单,按照对应格式传递需要的数据和标签就可以啦

任务1:读取txt文件中的路径和标签

  • 第一个小任务,从标注文件中读取数据和标签
  • 至于你准备存成什么格式,都可以的,一会能取出来东西就行

任务2:分别把数据和标签都存在list里

  • 不是我非让你存list里,因为dataloader到时候会在这里取数据

  • 按照人家要求来,不要耍个性,让整list咱就给人家整

任务3:图像数据路径得完整

  • 因为一会咱得用这个路径去读数据,所以路径得加上前缀
  • 以后大家任务不同,数据不同,怎么加你看着来就行,反正得能读到图像

任务4:把上面那几个事得写在一起

- 1.注意要使用from torch.utils.data import Dataset, DataLoader
  • 2.类名定义class FlowerDataset(Dataset),其中FlowerDataset可以改成自己的名字
  • 3.def init(self, root_dir, ann_file, transform=None):咱们要根据自己任务重写
  • 4.def getitem(self, idx):根据自己任务,返回图像数据和标签数据

任务5:数据预处理(transform)

- 1.预处理的事都在上面的__getitem__中完成,需要对图像和标签咋咋地的,要整啥事,都在上面整

  • 2.返回的数据和标签就是建模时模型的输入和损失函数中标签的输入,一定整明白自己模型要啥
  • 3.预处理这个事是你定的,不同的数据需要的方法也不一样,下面给出的是比较通用的方法

任务6:根据写好的class FlowerDataset(Dataset):来实例化咱们的dataloader

- 1.构建数据集:分别创建训练和验证用的数据集(如果需要测试集也一样的方法)

  • 2.用Torch给的DataLoader方法来实例化(batch啥的自己定,根据你的显存来选合适的)
  • 3.打印看看数据里面是不是有东西了

任务7:用之前先试试,整个数据和标签对应下

- 1.别着急往模型里传,对不对都不知道呢

  • 2.用这个方法:iter(train_loader).next()来试试,得到的数据和标签是啥
  • 3.看不出来就把图画出来,标签打印出来,确保自己整的数据集没啥问题

代码实现

复制代码
import os

from matplotlib import pyplot as plt
from torchvision import transforms, models, datasets
import numpy as np
import torch
from PIL import Image


def load_annotations(ann_file):
    data_infos = {}
    with open(ann_file) as f:
        samples = [x.strip().split(' ') for x in f.readlines()]
        for filename, gt_label in samples:
            data_infos[filename] = np.array(gt_label, dtype=np.int64)
    return data_infos

img_label =load_annotations('./flower_data/train.txt')
image_name = list(img_label.keys())
label = list(img_label.values())

data_dir = './flower_data/'
train_dir = data_dir + '/train_filelist'
valid_dir = data_dir + '/val_filelist'

image_path = [os.path.join(train_dir,img) for img in image_name]

from torch.utils.data import Dataset, DataLoader


class FlowerDataset(Dataset):
    def __init__(self, root_dir, ann_file, transform=None):
        self.ann_file = ann_file
        self.root_dir = root_dir
        self.img_label = self.load_annotations()
        self.img = [os.path.join(self.root_dir, img) for img in list(self.img_label.keys())]
        self.label = [label for label in list(self.img_label.values())]
        self.transform = transform

    def __len__(self):
        return len(self.img)

    def __getitem__(self, idx):
        image = Image.open(self.img[idx])
        label = self.label[idx]
        if self.transform:
            image = self.transform(image)
        label = torch.from_numpy(np.array(label))
        return image, label

    def load_annotations(self):
        data_infos = {}
        with open(self.ann_file) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for filename, gt_label in samples:
                data_infos[filename] = np.array(gt_label, dtype=np.int64)
        return data_infos


data_transforms = {
    'train':
        transforms.Compose([
        transforms.Resize(64),
        transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选
        transforms.CenterCrop(64),#从中心开始裁剪
        transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率
        transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
        transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差
    ]),
    'valid':
        transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

train_dataset = FlowerDataset(root_dir=train_dir, ann_file = './flower_data/train.txt', transform=data_transforms['train'])
val_dataset = FlowerDataset(root_dir=valid_dir, ann_file = './flower_data/val.txt', transform=data_transforms['valid'])
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)

image, label = next(iter(train_loader))
sample = image[0].squeeze()
sample = sample.permute((1, 2, 0)).numpy()
sample *= [0.229, 0.224, 0.225]
sample += [0.485, 0.456, 0.406]
plt.imshow(sample)
plt.show()
print('Label is: {}'.format(label[0].numpy()))
相关推荐
孤狼warrior4 小时前
YOLO目标检测 一千字解析yolo最初的摸样 模型下载,数据集构建及模型训练代码
人工智能·python·深度学习·算法·yolo·目标检测·目标跟踪
Katecat996634 小时前
YOLO11分割算法实现甲状腺超声病灶自动检测与定位_DWR方法应用
python
玩大数据的龙威4 小时前
农经权二轮延包—各种地块示意图
python·arcgis
ZH15455891314 小时前
Flutter for OpenHarmony Python学习助手实战:数据库操作与管理的实现
python·学习·flutter
belldeep4 小时前
python:用 Flask 3 , mistune 2 和 mermaid.min.js 10.9 来实现 Markdown 中 mermaid 图表的渲染
javascript·python·flask
喵手4 小时前
Python爬虫实战:电商价格监控系统 - 从定时任务到历史趋势分析的完整实战(附CSV导出 + SQLite持久化存储)!
爬虫·python·爬虫实战·零基础python爬虫教学·电商价格监控系统·从定时任务到历史趋势分析·采集结果sqlite存储
喵手5 小时前
Python爬虫实战:京东/淘宝搜索多页爬虫实战 - 从反爬对抗到数据入库的完整工程化方案(附CSV导出 + SQLite持久化存储)!
爬虫·python·爬虫实战·零基础python爬虫教学·京东淘宝页面数据采集·反爬对抗到数据入库·采集结果csv导出
B站_计算机毕业设计之家5 小时前
猫眼电影数据可视化与智能分析平台 | Python Flask框架 Echarts 推荐算法 爬虫 大数据 毕业设计源码
python·机器学习·信息可视化·flask·毕业设计·echarts·推荐算法
PPPPPaPeR.5 小时前
光学算法实战:深度解析镜片厚度对前后表面折射/反射的影响(纯Python实现)
开发语言·python·数码相机·算法
JaydenAI5 小时前
[拆解LangChain执行引擎] ManagedValue——一种特殊的只读虚拟通道
python·langchain