PyTorch中的 Dataset、DataLoader 和 enumerate()

PyTorch:关于Dataset,DataLoader 和 enumerate()

本博文主要参考了 Pytorch中DataLoader的使用方法详解pytorch:关于enumerate,Dataset和Dataloader 两篇文章进行总结和归纳。

DataLoader 隶属 PyTorch 中 torch.utils.data 下的一个类,任何继承 torch.utils.data.Data 类的子类均需要重载__getitem__()及__len__()两个函数,且子类在__init__()函数产生的数据路径,将作为 DataLoader 参数 DataSets 的实参。该类将自定义的 Dataset 根据 batch size 大小、是否 shuffle 等封装成一个 Batch Size 大小的 Tensor,用于后面的训练。

Dataset 类构建

在构建数据集类时,除了__init__(self),还要有__len__(self)与__getitem__(self,item)两个方法,这三个是必不可少的,至于其它用于数据处理的函数,可以任意定义。这里的 Dateset 可以指整个数据集,也可以是训练集,测试集等。

javascript 复制代码
class Dataset:
    def __init__(self,...):
        ...
    def __len__(self,...):
        return n
    def __getitem__(self,item):
        return data[item]

正常情况下,该数据集是要继承 Pytorch 中 Dataset 类的,但实际操作中,即使不继承,数据集类构建后仍可以用 Dataloader() 加载的。

在dataset类中,len (self)返回数据集中数据的总个数,getitem (self,item)表示每次返回第 item 条(个)数据。

①__init__:传入数据,或者像下面一样直接在函数里加载数据

②__len__:返回这个数据集一共有多少个 item

③__getitem__:返回一条(个)训练样本的数据,并将其转换成 tensor

在 dataset 实例化时一般要传入数据集的路径,一般在__init__() 函数中指定数据集路径等相关信息(可以通过相关路径读取包含图像名称、标签等相关信息的 json 或者 csv 等类型的文件);通过__getitem__(self,item) 得到对应的图像并将进行 transform 转换(缩放、裁剪、转换成 tensor 等操作),最终以 tensor 的形式返回。

DataLoader 使用

在构建 Dataset 类后,即可使用 DataLoader 加载。DataLoader 中常用参数如下:

  1. dataset:需要载入的数据集,如前面构造的 dataset 类。
  2. batch_size:批大小,在神经网络训练时我们很少逐条数据训练,而是几条数据作为一个 batch 进行训练。
  3. shuffle:是否在打乱数据集样本顺序。True 为打乱,False 反之。
  4. num_workers:这个参数决定了有几个进程来处理 data loading。0 意味着所有的数据都会被 load 进主进程。(默认num_workers=0,在 Windows 系统下需要设置为 0
  5. drop_last:是否舍去最后一个batch的数据(很多情况下数据总数 N 与 batch size 不整除,导致最后一个 batch 不为 batch size)。True 为舍去,False 反之。

注意:使用 DataLoader 读取数据时,为了加快效率,所以使用了多个线程,即 num_workers 不为0,在 windows 系统下报如下的错误。

RuntimeError: Couldn't open shared file mapping: <torch_16716_3565374679>, error code: <1455>

javascript 复制代码
DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 

参照 DataLoader windows平台下 多线程读数据报错 | BrokenPipeError: [Errno 32] Broken pipe | freeze_support() 教程中提到,在 https://github.com/pytorch/pytorch/pull/5585 中给出了一些官方解释,应该是 Windows下的一些线程文件读写的问题。

在 Windows 上,FileMapping 对象应必须在所有相关进程都关闭后,才能释放。启用多线程处理时,子进程将创建 FileMapping,然后主进程将打开它。 之后当子进程将尝试释放它的时候,因为父进程还在引用,所以它的引用计数不为零,无法释放。 但是当前代码没有提供在可能的情况下再次关闭它的机会。这个版本官方说 num_workers=1 是可以用的,更多的线程还在解决,不过现在即便是用 2 个子进程也已经可以了。

加载数据的过程

pytorch 中加载数据的顺序是:

  1. 创建一个 dataset 对象
  2. 创建一个 dataloader 对象
  3. 循环 dataloader 对象,将 data, label 拿到模型中去训练

enumerate() 函数

在对 Dataloader 进行读取时,通常使用 enumerate() 函数,enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标。调用 enumerate(dataloader) 时每次都会读出一个 batch_size 大小的数据。例如,数据集中总共包含 245 张图像,train_loader = dataloader(dataset, batch_size=32, drop_last=True) 被实例化时,经过以下代码后输出的 count 为 224(正好等于32*7),而多出来的 245-224=21 张图像不够一个 batch 因此被 drop 掉了。下面展示了如何从 dataloader 中通过 enumerate() 返回一个batch_size的数据。

javascript 复制代码
for k, images, target in enumerate(dataloader):

其中,k代表下标值,images, target 代表可遍历的数据对象。因为 enumerate(dataloader) 一次会返回一个 batch 的数据,所以返回的 images 为 batch_size 长度的list,target 也为 batch_size 长度的 list。

通常,dataloader 里包含很多个数据对象,那么我们应该怎么保证 batch 就是我们所需要的数据呢?通过 Dataset 的定义可以实现我们需要的数据。Dataset 是用来定义数据从哪里读取,以及如何读取的问题,通过重写 Dataset 抽象类的__getitem__()函数。enumerate(dataloader) 得到的数据就是 getitem() 函数返回的数据,只不过 enumerate(dataloader) 一次会得到 batch_size 个不同 item 的数据组成的 list。

javascript 复制代码
def __getitem__(self, item):
	images = self.data[item]
	target = self.label[item]
	return images, target

返回 item 对应的数据,就是 enumerate(dataloader) 得到的数据的一部分。

javascript 复制代码
def __len__(self):
	return len(self.data)

返回 dataset 中总的数据个数,用于控制返回多少个 batch 的数据,enumerate(dataloader) 一次会返回 batch_size 大小的 list。

Reference

Pytorch中DataLoader的使用方法详解
pytorch:关于enumerate,Dataset和Dataloader
DataLoader windows平台下 多线程读数据报错 | BrokenPipeError: [Errno 32] Broken pipe | freeze_support()

相关推荐
m0_748254098 分钟前
100天精通Python(爬虫篇)——第113天:爬虫基础模块之urllib详细教程大全
开发语言·爬虫·python
cnbestec9 分钟前
Kinova在开源家庭服务机器人TidyBot++研究里大展身手
人工智能·科技·机器人
小爬虫程序猿15 分钟前
深入理解Jsoup与Selenium:Java爬虫的双剑合璧
爬虫·python·selenium
随便写写17 分钟前
Pyside6 基础框架以及三种基础控件
python
deflag21 分钟前
第T4周:TensorFlow实现猴痘识别(Tensorboard的使用)
人工智能·tensorflow·neo4j
夏娃同学30 分钟前
基于Flask后端框架的均值填充
python·flask
HackKong36 分钟前
Python与黑客技术
网络·python·web安全·网络安全·php
四口鲸鱼爱吃盐41 分钟前
Pytorch | 利用GNP针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python·深度学习·神经网络·计算机视觉
进击的小小学生41 分钟前
多因子模型连载
大数据·python·数据分析·区块链
小码贾42 分钟前
OpenCV-Python实战(6)——图相运算
人工智能·python·opencv