Dataset入门
Pytorch Dataset code:torch/utils/data/dataset.py#L17
Pytorch Dataset tutorial: tutorials/beginner/basics/data_tutorial.html
理论:
PyTorch中的Dataset
是一个抽象类,用来表示数据集的接口,所有其他数据集都需要继承这个类,并且覆写以下三个方法:
-
init:初始化数据集的一些配置,例如加载所有的数据标签。
-
len:以便
len(dataset)
可以返回数据集的大小,例如n。如果n小于数据集长度,则只会取前n个的数据。 -
getitem:输入是数据的索引,以便可以使用
dataset[i]
来获取第i个样本,数据增强一般会在这里做。
代码:
下面是一个自定义的Dataset样例(不可执行):
import cv2
import json
import torch.utils.Dataset as Dataset
class CustomDataset(Dataset):
def __init__(self, imgs_path, labels_path, img_transform=None, label_transform=None):
self.imgs_path = imgs_path # 输入图像的路径,list
self.labels_path = labels_path # 输入图像对应的标签路径,list
self.img_transform = img_transform # 图像的数据增强
self.label_transform = label_transform # 标签的数据增强
def __len__(self):
return len(self.imgs_path) # 返回数据集的长度
def __getitem__(self, idx):
img_path = self.imgs_path[idx]
label_path = self.labels_path[idx]
img = cv2.imread(img_path) # 读取图像
label = json.load(open(label_path)) # 读取标签
if self.img_transform: # 图像的数据增强
img = self.img_transform(img)
if self.label_transform: # 标签的数据增强
label = self.label_transform(label)
return img, label # 返回图像和标签,用于训练
总结:
值得注意的是,Dataset
只负责数据的加载和预处理,对于如何训练数据(例如:是否进行shuffle,是否进行并行加速等)这部分的逻辑是由DataLoader
实现的。通常情况下,我们会将Dataset
和DataLoader
一起使用。
另外,PyTorch还提供了一些常用的数据集,如:ImageFolder
,CIFAR10
,MNIST
等,这些数据集都是继承Dataset
类,同时在init
方法中进行数据的下载,以及在getitem
方法中进行数据的加载和预处理。
Dataset是单线程读取数据,每次只能读取一个样本,不能一次性读取一个mini-batch的数据。
Dataset的主要特性包含:
-
抽象接口:PyTorch通过定义一个抽象
Dataset
类,让用户可以使用统一的方式来加载各种不同的数据,提供了很好的扩展性。 -
懒加载:实际的数据载入并不发生在构造数据集实例时,而是发生在用到这些数据时,这样可以提高内存利用率,并且可以实现对大规模数据的处理。
-
预处理:
Dataset
的一个重要应用就是数据预处理,你可以在getitem
函数中进行任何你的数据预处理过程。
嗨,欢迎大家关注我的公众号《CV之路》,一起讨论问题,一起学习进步~