- torch.utils.data.DataLoader是PyTorch数据加载工具的核心;
- 表示一个Python可迭代数据集;
DataLoader支持的数据集类型
- map-style 和 iterable-style 的数据集;
- 可定制的数据加载顺序;
- 自动批量数据集;
- 单进程和多进程数据加载;
- 自动内存固定;
DataLoader构造函数
python
DataLoader(dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
multiprocessing_context=None,
generator=None,
*
prefetch_factor=2,
persistent_workers=False,
pin_memory_device='')
Dataset 类型
- DataLoader构造函数中的最重要的参数是dataset;
- dataset指示了加载数据的数据集对象;
- PyTorch支持两种不同类型数据集;
- map-style数据集;
- 可迭代类型数据集;
Map-Style数据集
- 实现了__getitem__()函数和__len__()函数;
- 表示一个从指数/键值到数据样本的映射;
- dataset[idx],可以从磁盘的文件夹中读取第idx序列的图像和相应的标签;
可迭代类型数据集
- 是IterableDataset子类的一个实例;
- 实现了__iter__()函数原型;
- 表示了一个可迭代的数据样本集;
- 该类型的数据集特别适合代价较高的随机读取或者不可随机读取的场合;
- 批量数据的大小依赖于获取的数据;
- 调用iter(dataset),可以返回一个数据流、远程的服务器或者实时生成的日志;
数据加载顺序和采样器
- 对于可迭代数据集,数据加载的顺序取决于用户的定义;
- 上述特性允许简单的实现块读取和动态批量大小数据读取;
- 对于map-style类型的数据集:
- torch.utils.data.Sampler类用于指定数据加载中的指数/键值序列;
- 他们表示数据集的指数可迭代对象;
- 在随机梯度下降(SGD)的情况下:
- Sampler可以随机的排序指数列表且可即时生成一个指数列表;
- 或者可以生成一个mini-batch SGD的小量值的指数序列;
- 基于shuffle参数可以自动的构建一个序列或者被洗牌的采样器给到DataLoader;
- 相反的,用户可用采样器参数指定一个定制的采样器对象,生成下一个要获取的指标/键值;
- 定制的采样器可生成批量指数的列表并以batch_sampler参数传递给DataLoader;
- 通过batch_size和drop_last参数可激活自动批量模式;
加载批量的和非批量的数据
- 通过参数batch_size,drop_last,batch_sampler,和collate_fn;
- DataLoader支持自动整理获取的数据样本到批量集合中;
自动批量(默认)
- 最常见的情况;
- 对应于获取一个小量数据,且整理他们到一个批次样本中;
- 也就是包含张量的一个维度作为批量的维度(通常是第一维);
- 当batch_size(默认1)非空时,数据加载器生成批量的样本;
- batch_size和drop_last参数被用于指定数据加载器如何获得批量的数据集键值;
- 对于映射数据集,用户可以指定batch_sampler,一次生成一个键值的列表;
失能自动批量
- 某些情况下,用户可能想要手动处理批量数据集,或简单的加载几个样本;
- 参数batch_size和batch_sampler都为None时,自动批量失能;
- 每一个从数据集获取的样本被传递给collate_fn作为参数的函数所处理;
- 自动批量失能时,默认collate_fn简单转换Numpy数组为PyTorch张量;
- 保持一切不受影响;
单一进程和多进程数据加载
- DataLoader默认使用单一进行加载数据;
- 在一个python进程中,GIL(Global Interpreter Lock)阻止跨线程真全并行python代码运行;
- 为了避免数据加载的计算代码;
- PyTorch提供通过设置num_worker参数为正整数的简单设置切换执行多线程数据加载;
单一进程数据加载(默认)
- 该模式下,在DataLoader初始化的同一进程中获取数据;
- 因此数据加载可能会阻塞计算;
- 在跨进程共享数据的资源(共享的内存和文件描述符)有限时,该模式被优先考虑;
- 或者当整个数据集较小且可以被整体加载进内存时,该模式被优先考虑;
- 另外,单进程加载通常显式更多的可读性错误追踪信息,更有利与调试;
多进程数据加载
- 设置num_worker为正整数将会开启多进程数据加载;
- 多进程的数量为num_worker的数量;
- 一些次数的迭代后,加载器工作进程将会同父进程消耗相同的CPU内存;
- 这在数据集含大量数据(比如在数据集构建时加载大量的文件名列表)时可能有问题;
- 或者用户使用了多个进程(总的内存消耗=number_of_workers*size_of_parent_process)时可能会有问题;
- 最简单的应变方式为使用非参考计数的表示(比如,Pandas,Numpy或者PyArrow对象)替换python对象;
- 该模式下,在一个DataLoader迭代器被创建时,num_workers数量的进程被创建;
- dataset,collate_fn和worker_init_fn被传递到每一个进程;
- 上述三者被用于进程初始化和数据的获取;
- 这意味着在工作进程的运行中内部IO,转换操作同数据获取被一同处理;
- torch.utils.data.get_worker_info()在一个工作进程中返回多种有用信息(包括:进程id,数据集副本,初始化速度等);
- 且在主进程中返回None;
- 用户可以在数据集代码中使用torch.utils.data.get_worker_info()函数;
- worker_init_fn独立配置每一个数据集副本以确定是否在工作进程中运行代码;
- 对映射类型的数据集,主进程使用采样器生成索引并传送到工作进程中;
- 任何随机洗牌在主程序中执行,主程序通过分配索引确定数据加载顺序;
- 对可迭代类型数据集,每一个工作进程获取一个数据集对象的副本;
- 原始的多进程加载将导致数据的复制;
- 使用torch.utils.data.get_worker_info()和worker_init_fn,用户可以独立配置每一个副本;
- 多线程加载中,drop_last参数去掉每一个进程中的可迭代数据集副本的不完整批量数据;
- 当迭代的最后一位被达到时进程被关闭;
- 基于多进程中使用CUDA和共享CUDA张量的细节原因在多进程加载中不推荐返回CUDA张量;
- 推荐使用自动内存固定(设置pin_memory=True),能够更快传递数据到CUDA使能的GPU中;
基于平台的行为
- 由于工作进程依赖于Python multiprocessing,进程启动行为windows和Unix是有区别的;
- UNix上,fork() 是默认的multiprocessing启动方法;
- 使用fork(),直接通过克隆的地址空间,子工作进程可获取dataset和Python参数;
- windows或者MacOS上,spawn()是默认的multiprocessing启动方法;
- 使用spawn(),另一个解释器被启动,运行用户的主要脚本;
- 以及通过pickle序列化,接收数据集的内部工作进程函数、collate_fn和其他参数;
- 以上单独的序列化意味着应该采取两个步骤以确保在使用多进程数据加载时与windows兼容;
- 打包大部分主要的脚本代码在if name=='main':程序块中;
- 确保当每一个工作进程被启动时,if name=='main':不再次启动;
- 你可以在if name=='main':程序块中放置数据集和DataLoader·实例创建逻辑,因为在工作进程中其不需要被再次执行;
- 确保collate_fn,worker_init_fn或者dataset代码声明在顶级的__main__检查之外的定义中;
- 这就保证了上述代码声明在工作进程中是可用的;
多进程数据加载的随机性
- 默认情况下,每一个工作进程将具有自己的PyTorch种子,设置为base_seed+worker_id;
- 这里base_seed是一个由主进程使用他的RNG或者一个指定的生成器生成的长周期数据;
- 然而,用于其他库的种子可以通过初始化工作进程被复制;
- 导致每一个工作进程返回一致的随机数;
- 在worker_init_fn中,你可以获取PyTroch种子集用于每一个工作进程,使用
- torch.utils.data.get_worker_info().seed或者torch.initial_seed();
- 也可以使用上述两个种子为其他库在数据加载之前设置种子;
内存锁定
- 主机到GPU的拷贝更快,当数据来自锁定内存时;
- 对数据加载,DataLoader的pin_memory=True时,自动将获取的数据张量放到锁定内存中;
- 默认的内存锁定逻辑仅识别张量和映射以及包含张量的可迭代对象;
- 默认情况下,锁定逻辑观察到一个批量定制数据类型(当有一个collate_fn返回一个定制批量类型时);
- 或者批量数据中的每一个单元都是定制类型时;
- 锁定逻辑不能识别他们,将返回不在锁定内存中的批量数据(或者单元);
- 为了使能内存锁定用于定制批量数据或者数据类型,定义一个pin_memory()方法在你的定制类型中;
内存锁定实例
python
import torch
from torch.utils.data import DataLoader
class SimpleCustomBatch:
def __init__(self,data):
transposed_data=list(zip(*data))
self.inp=torch.stack(transposed_data[0],0)
self.tgt=torch.stack(transposed_data[1],0)
def pin_memory(self):
self.inp=self.inp.pin_memory()
self.tgt=self.tgt.pin_memory()
return self
def collate_wrapper(batch):
return SimpleCustomBatch(batch)
inps=torch.arange(10*5,dtype=torch.float32).view(10,5)
tgts=torch.arange(10*5,dtype=torch.float32).view(10,5)
dataset=TensorDataset(inps,tgts)
loader=DataLoader(dataset,
batch_size=2,
collate_fn=collate_wrapper,
pin_memory=True)
for batch_ndx,sample in enumerate(loader):
print(sample.inp.is_pinned())
print(sample.tgt.is_pinned())
DataLoader参数解析
- DataLoader合并一个数据和一个采样器,提供一个可迭代的采样器;
- DataLoader在单线程或多线程模式下,支持映射类型数据集和可迭代类型数据集;
- DataLoader支持定制的加载顺序和可优化的自动批量整理和内存锁定;
- dataset(Dataset)---加载数据的数据集;
- batch_size(Int,optional)---每一批次加载多少数据样本(默认为1);
- shuffle(bool,optional)---设置为True,每一代都进行数据洗牌(默认False);
- sampler(Sampler or iterable,optional)---定义从数据集抽取样本的策略,可以是实现__len__功能的任意Iterable对象,如果指定的话,shuffle必须不指定;
- batch_sampler(Sampler or iterable,optional)---类似与sampler,但是一次返回一个索引批次,同batch_size,shuffle,sampler和drop_last相互排斥;
- num_workers(int,optional)---多少子进程用于数据加载,0意味着将在主进程加载数据,默认0;
- collate_fn(Callable,optional)---合并一个样本列表为一个张量mini-batch的型式,当从映射数据集使用批次加载时使用;
- pin_memory(bool,optional)---如果设置为True,在返回数据之前数据加载器将拷贝张量到设备/CUDA的锁定内存区.如果你的数据单元是一个定制类型,或者你的collate_fn返回一个批次定制类型,参考文档中的实例;
- drop_last(bool,optional)---设置为true丢弃最后不完整的批次,如果数据集的大小不能被批量大小整除的话。如果设置为False,且数据集的大小不被批次大小整除,最后的批次将很小(默认False);
- timeout(numeric,optional)---如果为正数,超时值用于收集一个工作进程的批次,应当总是非正数(默认0);
- worker_init_fn(Callable,optional)---如果非空,该函数将会被在每一个工作子进程被调用,以工作进程id(一个正数,在【0,num_workers-1】范围内)作为输入;
- multiprocessing_context(str or multiprocessing.context.BaseContext,optional)---如果为空,操作系统的默认多进程上下文会被使用(默认为空);
- generator(torch.Generator,optional)---如果非空,随机采样器将使用RNG生成随机指标并多进程申城base_seed用于工作进程(默认为空);
- prefetch_factor(int,optional,keywork-only arg)---每一个工作进程提前加载的批次数量,2意味着对于所有的工作进程将有总数为2*num_worker批次的预取数据(默认值依赖于参数num_worker的值,如果num_worker=0,默认值为空,否则,num_worker>0,默认值为2);
- persistent_workers(bool,optional)---如果为True,数据加载器在一个数据集被消耗一次之后将不关闭工作进程,这允许保持工作进程数据集实例为活动状态(默认为False);
- pin_memory_device(str,optional)---如果为True,设备锁定内存运行.