【单点知识】基于实例详解PyTorch中的DataLoader类

文章目录

      • [0. 前言](#0. 前言)
      • [1. DataLoader的功能](#1. DataLoader的功能)
        • [1.1 可处理映射式/可迭代式数据集](#1.1 可处理映射式/可迭代式数据集)
        • [1.2 可自定义数据加载顺序](#1.2 可自定义数据加载顺序)
        • [1.3 可自动批量化打包数据](#1.3 可自动批量化打包数据)
        • [1.4 可支持多进程加载](#1.4 可支持多进程加载)
        • [1.5 可pin住内存](#1.5 可pin住内存)
      • [2. DataLoader的调用](#2. DataLoader的调用)
        • [2.1 DataLoader的调用方法](#2.1 DataLoader的调用方法)
        • [2.2 DataLoader的参数说明](#2.2 DataLoader的参数说明)
      • [3. DataLoader的使用实例](#3. DataLoader的使用实例)
      • [4. 总结](#4. 总结)

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

在深度学习中,数据的预处理和加载方式对模型训练的效率与效果具有重要影响。PyTorch提供了一种强大的工具------DataLoader,它能够高效地将数据集转化为适合模型训练的小批量数据,并支持多线程并行加载机制,极大地提升了数据读取速度。本文将详细介绍PyTorch中的DataLoader。

本文的说明参考了PyTorch官网文件:https://pytorch.org/docs/stable/data.html

DataLoader通常会结合ImageFolder和transforms类(即构建Dataset过程)一起使用,这两个类已经在此前文章中专题说明过:

1. DataLoader的功能

根据DataLoader的官方文档说明,将从以下5个方面说明DataLoader的功能:

1.1 可处理映射式/可迭代式数据集

PyTorch 的 DataLoader 能够处理两种形式的数据集:映射式数据集(map-style)和可迭代式数据集(iterable-style)。映射式数据集指的是那些可以通过索引直接访问其元素的数据集,它们需要实现 __getitem__ 方法和可选的 __len__ 方法。例如,torch.utils.data.Dataset 子类就是这样一种数据集,可以通过 dataset[i] 获取第 i 个样本。

可迭代式数据集则不依赖于索引访问,而是可以直接迭代出数据,这类似于 Python 中的迭代器协议。DataLoader 能够兼容这两种风格的数据集,并将它们转化为合适的形式以供模型训练使用。

1.2 可自定义数据加载顺序

DataLoader 支持通过 sampler 参数来自定义数据加载的顺序。默认情况下,如果设置了 shuffle=True,那么在每个 epoch 开始时,数据加载器会打乱数据集的顺序。但也可以传入自定义的采样器,如 RandomSampler 实现随机抽样,或 SequentialSampler 保持原有顺序,甚至可以使用 WeightedRandomSampler 来实现加权随机抽样等复杂逻辑。

1.3 可自动批量化打包数据

DataLoader 自动将数据集中的样本打包成小批量,这是通过设置 batch_size 参数来实现的。每次调用 DataLoader 的迭代器时,都会返回一个包含 batch_size 个样本的数据批次,这对于训练深度学习模型是非常关键的,因为大多数模型都需要按照批次进行前向传播和反向传播计算。

1.4 可支持多进程加载

为了加速数据加载过程,特别是对于大型数据集,DataLoader 提供了多进程支持。通过设置 num_workers 参数,可以启动多个工作进程来并发地加载数据。这意味着数据准备和模型训练可以同时进行,极大地提高了整体效率。需要注意的是,启用多进程数据加载时需要考虑数据集是否线程安全,并确保在CPU或系统资源充足的环境下运行。

注意: num_workers 参数并不一定是越大越好。以下是考虑 num_workers 设置时需要权衡的因素:

  1. 系统资源

    • CPU核心数num_workers 应设置得小于或等于可用 CPU 核心数(包括超线程)。若设置得过高,可能会导致过多的上下文切换,反而降低性能。
    • 内存限制 :增加 num_workers 可能会增加内存消耗,因为每个工作进程都会缓存一部分数据。过高的 num_workers 可能会导致内存溢出。
  2. I/O 瓶颈

    • 如果数据读取主要是由硬盘 I/O 速率决定的瓶颈,那么超过某个点后增加 num_workers 不会带来进一步的速度提升,反而可能由于争抢 I/O 资源而造成负面影响。
  3. GPU 同步

    • 当数据加载速度远大于 GPU 计算速度时,更多的 num_workers 可能不会显著提高训练效率,因为 GPU 处理速度成为瓶颈。
  4. 同步与异步行为

    • PyTorch 的 DataLoader 默认实现了一个队列系统来进行数据加载的同步操作。过多的 num_workers 可能会导致队列中积累过多的数据,这些数据在被 GPU 使用前需要等待,因此并不会提高整体吞吐量。

通常的经验值可能是把num_workers设定在 CPU 核心数的一半到全部之间 。不过最佳实践是要根据具体的硬件配置、数据集大小和读取速度以及模型训练速度等因素进行调整。在某些情况下,如遇到操作系统兼容性问题或者为了避免不必要的复杂性,也可能需要将 num_workers 设置为较小的值甚至 0。在 Windows 系统中,由于进程间通信的限制,有时必须将 num_workers 设置为 0 才能避免错误。

1.5 可pin住内存

如果是在 GPU 上进行训练,DataLoader 可以通过设置 pin_memory=True 来自动将 CPU 内存中的数据拷贝至 CUDA 可以直接访问的内存区域(即"pin"住内存),这样在数据从 CPU 到 GPU 的转移过程中可以享受到更快的速度。这是因为被pin住内存可以利用异步内存复制操作,避免同步等待,从而使得数据流水线更为顺畅。

2. DataLoader的调用

在PyTorch中,DataLoader是基于torch.utils.data接口进行工作的,DataLoader也是torch.utils.data中的核心类。

At the heart of PyTorch data loading utility is the torch.utils.data.DataLoader class.

2.1 DataLoader的调用方法

DataLoader的调用方法如下:

python 复制代码
from torch.utils.data import DataLoader

dataset = ...
loader = DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, prefetch_factor=2,
           persistent_workers=False)
2.2 DataLoader的参数说明
  • dataset : (Dataset) 必须是一个实现了 __getitem____len__ 方法的数据集对象,用于定义如何访问和获取样本及其对应的标签。

  • batch_size: (int) 指定每个批次加载多少个样本。这是训练神经网络时常见的参数,控制着每一步梯度更新所使用的样本数量。

  • shuffle : (bool) 若为 True,则在每个 epoch 开始时都会对数据集进行随机打乱顺序。这对于防止模型过拟合和确保模型看到所有样本组合至关重要,尤其是在训练阶段。

  • sampler : (Sampler) 可选参数,用于指定数据抽样的策略。如果提供了 sampler,则 shuffle 参数将被忽略。例如,可以使用此参数实现分布式训练时的数据分片。

  • batch_sampler : (Sampler) 可直接指定一个批量抽样器,它可以返回一批批索引而不是单个索引。如果指定了 batch_sampler,则 batch_sizeshuffle 将被忽略。

  • num_workers: (int) 上面已经详细说明,不再赘述。

  • collate_fn : (callable) 自定义函数,用于合并多个样本到一个批次。默认的 collate_fn 会堆叠具有相同形状的张量。用户可以自定义此函数以满足特殊的数据整合需求。

  • pin_memory : (bool) 若为 True,则在数据加载后将其移到 CUDA 可以直接访问的页锁定内存中,从而加快数据从 CPU 到 GPU 的传输速度。

  • drop_last : (bool) 若为 True,并且最后一个批次的样本数量小于 batch_size,则丢弃该批次。这在保证所有批次样本数量一致时很有用。

  • timeout: (float) 设置数据加载过程中阻塞的最大秒数。如果设置为0,则无限期等待。

  • worker_init_fn: (callable) 用户自定义的回调函数,用于初始化每个工作进程。可以在每个工作进程中设置不同的随机种子等。

  • prefetch_factor: (int) 预取因子,决定了工作进程在向主进程提交批次的同时,提前生成多少个额外的批次。增加此值可以减少潜在的 I/O 瓶颈,但也可能增加内存占用。

  • persistent_workers : (bool) 若为 True,则保留工作进程在多个数据加载迭代之间,这样可以避免每次重新启动工作进程带来的开销,尤其在长时间运行的任务中效果更明显。然而,这要求 num_workers > 0 并且 multiprocessing.get_start_method() 返回 'fork' 或 'spawn'。

3. DataLoader的使用实例

创建了一个 ImageFolder 数据集并应用了 transforms.ToTensor() 转换之后,你已经正确地设置了数据预处理流程,确保了 DataLoader 在批处理时接收到的是 Tensor 类型的数据。接下来要查看 DataLoader 中的特定元素,你可以通过迭代的方式来访问它们:

python 复制代码
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# 数据预处理
transform = transforms.Compose([transforms.Resize((200,200)), transforms.ToTensor()])

# 加载数据集
dataset = ImageFolder('D:\\DL\\pretrain\\hymenoptera\\hymenoptera_data', transform=transform)

# 创建 DataLoader
loader = DataLoader(dataset,batch_size=4)

# 通过迭代的方式访问 DataLoader 中的元素
for i, (images, labels) in enumerate(loader):
    if i == 0:  # 仅显示第一个批次的数据
        print(f"第{i}个批次的图像张量:")
        print(images.shape)  # 显示图像张量的形状
        print("对应的标签:", labels)
    break

上述代码会打印出 DataLoader 中的第一个批次(索引为0)的图像张量及其对应的类别标签。输出为:

python 复制代码
第0个批次的图像张量:
torch.Size([4, 3, 200, 200])
对应的标签: tensor([0, 0, 0, 0])

如果你想查看特定索引位置的样本(即上文所述的映射式map-style),但不是按批次而是直接查看原始数据集中的某个样本,那么可以直接从 ImageFolder 数据集中索引该样本,而非通过 DataLoader

python 复制代码
# 获取数据集中的特定样本(假设索引为10)
sample_idx = 10
image, label = dataset[sample_idx]
print(f"索引 {sample_idx} 的图像张量:")
print(image.shape)
print("对应的标签:", label)

输出为:

python 复制代码
索引 10 的图像张量:
torch.Size([3, 200, 200])
对应的标签: 0

再次强调,在实际应用中,由于 DataLoader 主要是用来进行批处理的,所以直接从其中索引单个元素并不常见。如需查看单个样本,通常是从原始 dataset 中访问。如果确实需要从 DataLoader 中一次性获取单个样本而不是一批次,需要特殊处理,例如通过 next(iter(loader)) 或者额外编写逻辑来实现。

4. 总结

最后我想再总结下 DataLoaderDataset的关系:

  • DataLoader 依赖于 Dataset 来获取原始数据,它的目的是为了更好地管理和高效地喂入数据给训练过程。
  • 使用时,首先需要基于 Dataset 构建好数据集实例,然后将这个数据集实例传给 DataLoader 构造函数,配置好加载参数后得到一个数据加载器。

总的来说,Dataset 负责定义数据源和访问逻辑,而 DataLoader 负责根据这些定义好的逻辑,按需以适合训练的形式加载和提供数据。

相关推荐
秀儿还能再秀30 分钟前
机器学习——简单线性回归、逻辑回归
笔记·python·学习·机器学习
图片转成excel表格1 小时前
WPS Office Excel 转 PDF 后图片丢失的解决方法
人工智能·科技·深度学习
阿_旭2 小时前
如何使用OpenCV和Python进行相机校准
python·opencv·相机校准·畸变校准
幸运的星竹2 小时前
使用pytest+openpyxl做接口自动化遇到的问题
python·自动化·pytest
李歘歘2 小时前
万字长文解读深度学习——多模态模型CLIP、BLIP、ViLT
人工智能·深度学习
kali-Myon3 小时前
ctfshow-web入门-SSTI(web361-web368)上
前端·python·学习·安全·web安全·web
B站计算机毕业设计超人3 小时前
计算机毕业设计Python+大模型农产品价格预测 ARIMA自回归模型 农产品可视化 农产品爬虫 机器学习 深度学习 大数据毕业设计 Django Flask
大数据·爬虫·python·深度学习·机器学习·课程设计·数据可视化
新手小白勇闯新世界3 小时前
深度学习知识点5-马尔可夫链
人工智能·深度学习·计算机视觉