【单点知识】基于实例详解PyTorch中的DataLoader类

文章目录

      • [0. 前言](#0. 前言)
      • [1. DataLoader的功能](#1. DataLoader的功能)
        • [1.1 可处理映射式/可迭代式数据集](#1.1 可处理映射式/可迭代式数据集)
        • [1.2 可自定义数据加载顺序](#1.2 可自定义数据加载顺序)
        • [1.3 可自动批量化打包数据](#1.3 可自动批量化打包数据)
        • [1.4 可支持多进程加载](#1.4 可支持多进程加载)
        • [1.5 可pin住内存](#1.5 可pin住内存)
      • [2. DataLoader的调用](#2. DataLoader的调用)
        • [2.1 DataLoader的调用方法](#2.1 DataLoader的调用方法)
        • [2.2 DataLoader的参数说明](#2.2 DataLoader的参数说明)
      • [3. DataLoader的使用实例](#3. DataLoader的使用实例)
      • [4. 总结](#4. 总结)

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

在深度学习中,数据的预处理和加载方式对模型训练的效率与效果具有重要影响。PyTorch提供了一种强大的工具------DataLoader,它能够高效地将数据集转化为适合模型训练的小批量数据,并支持多线程并行加载机制,极大地提升了数据读取速度。本文将详细介绍PyTorch中的DataLoader。

本文的说明参考了PyTorch官网文件:https://pytorch.org/docs/stable/data.html

DataLoader通常会结合ImageFolder和transforms类(即构建Dataset过程)一起使用,这两个类已经在此前文章中专题说明过:

1. DataLoader的功能

根据DataLoader的官方文档说明,将从以下5个方面说明DataLoader的功能:

1.1 可处理映射式/可迭代式数据集

PyTorch 的 DataLoader 能够处理两种形式的数据集:映射式数据集(map-style)和可迭代式数据集(iterable-style)。映射式数据集指的是那些可以通过索引直接访问其元素的数据集,它们需要实现 __getitem__ 方法和可选的 __len__ 方法。例如,torch.utils.data.Dataset 子类就是这样一种数据集,可以通过 dataset[i] 获取第 i 个样本。

可迭代式数据集则不依赖于索引访问,而是可以直接迭代出数据,这类似于 Python 中的迭代器协议。DataLoader 能够兼容这两种风格的数据集,并将它们转化为合适的形式以供模型训练使用。

1.2 可自定义数据加载顺序

DataLoader 支持通过 sampler 参数来自定义数据加载的顺序。默认情况下,如果设置了 shuffle=True,那么在每个 epoch 开始时,数据加载器会打乱数据集的顺序。但也可以传入自定义的采样器,如 RandomSampler 实现随机抽样,或 SequentialSampler 保持原有顺序,甚至可以使用 WeightedRandomSampler 来实现加权随机抽样等复杂逻辑。

1.3 可自动批量化打包数据

DataLoader 自动将数据集中的样本打包成小批量,这是通过设置 batch_size 参数来实现的。每次调用 DataLoader 的迭代器时,都会返回一个包含 batch_size 个样本的数据批次,这对于训练深度学习模型是非常关键的,因为大多数模型都需要按照批次进行前向传播和反向传播计算。

1.4 可支持多进程加载

为了加速数据加载过程,特别是对于大型数据集,DataLoader 提供了多进程支持。通过设置 num_workers 参数,可以启动多个工作进程来并发地加载数据。这意味着数据准备和模型训练可以同时进行,极大地提高了整体效率。需要注意的是,启用多进程数据加载时需要考虑数据集是否线程安全,并确保在CPU或系统资源充足的环境下运行。

注意: num_workers 参数并不一定是越大越好。以下是考虑 num_workers 设置时需要权衡的因素:

  1. 系统资源

    • CPU核心数num_workers 应设置得小于或等于可用 CPU 核心数(包括超线程)。若设置得过高,可能会导致过多的上下文切换,反而降低性能。
    • 内存限制 :增加 num_workers 可能会增加内存消耗,因为每个工作进程都会缓存一部分数据。过高的 num_workers 可能会导致内存溢出。
  2. I/O 瓶颈

    • 如果数据读取主要是由硬盘 I/O 速率决定的瓶颈,那么超过某个点后增加 num_workers 不会带来进一步的速度提升,反而可能由于争抢 I/O 资源而造成负面影响。
  3. GPU 同步

    • 当数据加载速度远大于 GPU 计算速度时,更多的 num_workers 可能不会显著提高训练效率,因为 GPU 处理速度成为瓶颈。
  4. 同步与异步行为

    • PyTorch 的 DataLoader 默认实现了一个队列系统来进行数据加载的同步操作。过多的 num_workers 可能会导致队列中积累过多的数据,这些数据在被 GPU 使用前需要等待,因此并不会提高整体吞吐量。

通常的经验值可能是把num_workers设定在 CPU 核心数的一半到全部之间 。不过最佳实践是要根据具体的硬件配置、数据集大小和读取速度以及模型训练速度等因素进行调整。在某些情况下,如遇到操作系统兼容性问题或者为了避免不必要的复杂性,也可能需要将 num_workers 设置为较小的值甚至 0。在 Windows 系统中,由于进程间通信的限制,有时必须将 num_workers 设置为 0 才能避免错误。

1.5 可pin住内存

如果是在 GPU 上进行训练,DataLoader 可以通过设置 pin_memory=True 来自动将 CPU 内存中的数据拷贝至 CUDA 可以直接访问的内存区域(即"pin"住内存),这样在数据从 CPU 到 GPU 的转移过程中可以享受到更快的速度。这是因为被pin住内存可以利用异步内存复制操作,避免同步等待,从而使得数据流水线更为顺畅。

2. DataLoader的调用

在PyTorch中,DataLoader是基于torch.utils.data接口进行工作的,DataLoader也是torch.utils.data中的核心类。

At the heart of PyTorch data loading utility is the torch.utils.data.DataLoader class.

2.1 DataLoader的调用方法

DataLoader的调用方法如下:

python 复制代码
from torch.utils.data import DataLoader

dataset = ...
loader = DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, prefetch_factor=2,
           persistent_workers=False)
2.2 DataLoader的参数说明
  • dataset : (Dataset) 必须是一个实现了 __getitem____len__ 方法的数据集对象,用于定义如何访问和获取样本及其对应的标签。

  • batch_size: (int) 指定每个批次加载多少个样本。这是训练神经网络时常见的参数,控制着每一步梯度更新所使用的样本数量。

  • shuffle : (bool) 若为 True,则在每个 epoch 开始时都会对数据集进行随机打乱顺序。这对于防止模型过拟合和确保模型看到所有样本组合至关重要,尤其是在训练阶段。

  • sampler : (Sampler) 可选参数,用于指定数据抽样的策略。如果提供了 sampler,则 shuffle 参数将被忽略。例如,可以使用此参数实现分布式训练时的数据分片。

  • batch_sampler : (Sampler) 可直接指定一个批量抽样器,它可以返回一批批索引而不是单个索引。如果指定了 batch_sampler,则 batch_sizeshuffle 将被忽略。

  • num_workers: (int) 上面已经详细说明,不再赘述。

  • collate_fn : (callable) 自定义函数,用于合并多个样本到一个批次。默认的 collate_fn 会堆叠具有相同形状的张量。用户可以自定义此函数以满足特殊的数据整合需求。

  • pin_memory : (bool) 若为 True,则在数据加载后将其移到 CUDA 可以直接访问的页锁定内存中,从而加快数据从 CPU 到 GPU 的传输速度。

  • drop_last : (bool) 若为 True,并且最后一个批次的样本数量小于 batch_size,则丢弃该批次。这在保证所有批次样本数量一致时很有用。

  • timeout: (float) 设置数据加载过程中阻塞的最大秒数。如果设置为0,则无限期等待。

  • worker_init_fn: (callable) 用户自定义的回调函数,用于初始化每个工作进程。可以在每个工作进程中设置不同的随机种子等。

  • prefetch_factor: (int) 预取因子,决定了工作进程在向主进程提交批次的同时,提前生成多少个额外的批次。增加此值可以减少潜在的 I/O 瓶颈,但也可能增加内存占用。

  • persistent_workers : (bool) 若为 True,则保留工作进程在多个数据加载迭代之间,这样可以避免每次重新启动工作进程带来的开销,尤其在长时间运行的任务中效果更明显。然而,这要求 num_workers > 0 并且 multiprocessing.get_start_method() 返回 'fork' 或 'spawn'。

3. DataLoader的使用实例

创建了一个 ImageFolder 数据集并应用了 transforms.ToTensor() 转换之后,你已经正确地设置了数据预处理流程,确保了 DataLoader 在批处理时接收到的是 Tensor 类型的数据。接下来要查看 DataLoader 中的特定元素,你可以通过迭代的方式来访问它们:

python 复制代码
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# 数据预处理
transform = transforms.Compose([transforms.Resize((200,200)), transforms.ToTensor()])

# 加载数据集
dataset = ImageFolder('D:\\DL\\pretrain\\hymenoptera\\hymenoptera_data', transform=transform)

# 创建 DataLoader
loader = DataLoader(dataset,batch_size=4)

# 通过迭代的方式访问 DataLoader 中的元素
for i, (images, labels) in enumerate(loader):
    if i == 0:  # 仅显示第一个批次的数据
        print(f"第{i}个批次的图像张量:")
        print(images.shape)  # 显示图像张量的形状
        print("对应的标签:", labels)
    break

上述代码会打印出 DataLoader 中的第一个批次(索引为0)的图像张量及其对应的类别标签。输出为:

python 复制代码
第0个批次的图像张量:
torch.Size([4, 3, 200, 200])
对应的标签: tensor([0, 0, 0, 0])

如果你想查看特定索引位置的样本(即上文所述的映射式map-style),但不是按批次而是直接查看原始数据集中的某个样本,那么可以直接从 ImageFolder 数据集中索引该样本,而非通过 DataLoader

python 复制代码
# 获取数据集中的特定样本(假设索引为10)
sample_idx = 10
image, label = dataset[sample_idx]
print(f"索引 {sample_idx} 的图像张量:")
print(image.shape)
print("对应的标签:", label)

输出为:

python 复制代码
索引 10 的图像张量:
torch.Size([3, 200, 200])
对应的标签: 0

再次强调,在实际应用中,由于 DataLoader 主要是用来进行批处理的,所以直接从其中索引单个元素并不常见。如需查看单个样本,通常是从原始 dataset 中访问。如果确实需要从 DataLoader 中一次性获取单个样本而不是一批次,需要特殊处理,例如通过 next(iter(loader)) 或者额外编写逻辑来实现。

4. 总结

最后我想再总结下 DataLoaderDataset的关系:

  • DataLoader 依赖于 Dataset 来获取原始数据,它的目的是为了更好地管理和高效地喂入数据给训练过程。
  • 使用时,首先需要基于 Dataset 构建好数据集实例,然后将这个数据集实例传给 DataLoader 构造函数,配置好加载参数后得到一个数据加载器。

总的来说,Dataset 负责定义数据源和访问逻辑,而 DataLoader 负责根据这些定义好的逻辑,按需以适合训练的形式加载和提供数据。

相关推荐
小白—人工智能6 分钟前
数据可视化 —— 多边图应用(大全)
python·信息可视化·数据可视化
noravinsc14 分钟前
使用django实现windows任务调度管理
python·django·sqlite
hvinsion15 分钟前
【Python 开源】你的 Windows 关机助手——PyQt5 版定时关机工具
windows·python·开源·定时关机
只因在人海中多看了你一眼15 分钟前
Django从零搭建卖家中心注册页面实战
python·django
亿牛云爬虫专家21 分钟前
Pyppeteer实战:基于Python的无头浏览器控制新选择
python·数据采集·爬虫代理·代理ip·无头浏览器·小红书·pyppeteer
小森776726 分钟前
(四)机器学习---逻辑回归及其Python实现
人工智能·python·算法·机器学习·逻辑回归·线性回归
生信碱移29 分钟前
入门级宏基因组数据分析教程,从实验到分析与应用
人工智能·经验分享·python·神经网络·数据挖掘·数据分析·数据可视化
码农不惑38 分钟前
Django的定制以及admin
数据库·python·django·sqlite
补三补四1 小时前
【深度学习基础】——机器的神经元:感知机
人工智能·深度学习·算法·机器学习
杂学者1 小时前
python 办公自动化------ excel文件的操作,读取、写入
python·excel