深度学习框架PyTorch笔记(三)数据集类(Data Set)与数据加载器(Data Loader)

深度学习框架PyTorch笔记(三)数据集类(Data Set)与数据加载器(Data Loader)

​ 在PyTorch中,数据集(Data Set)和数据加载器(Data Loader)是实现深度学习模型和测试的基本组件。下面将首先介绍数据集(Data Set)和数据加载器(Data Loader)的概念,然后介绍如何创建和使用PyTorch中的数据加载器的一些步骤和示例。

​ **数据集类(Data Set)**是指存储和表示数据的类或接口。它通常用于封装数据,以便能够在机器学习任务中使用。数据集可以是任何形式的数据,比如图像、文本、音频等。数据集的主要目的是提供对数据的标准访问方法,以便可以轻松地将其用于模型训练、验证和测试。

​ **数据加载器(Data Loader)**是一个提供批量加载数据的工具。它通过将数据集分割成小批量,并按照一定的顺序加载到内存中,以提高训练效率。数据加载器常用于训练过程中的数据预处理、批量化操作和数据并行处理等。

​ PyTorch中的 torch.utils.data.Datasettorch.utils.data.DataLoader 是数据加载和处理的核心组件。它们将数据读取与模型训练解耦,提供高效、灵活的数据迭代方式。下面从基础概念、自定义加载器参数、多进程机制等方面进行详细介绍。

1.数据集(Data Set)

1.1 自定义数据集定义实现

Data Set 是一个抽象类,表示一个数据集。任何自定义数据集都必须继承它,自定义DataSet类,必须实现它构造函数和两个方法:

  • __init__: 在 实例化DataSet 对象运行一次。我们初始化包含图像的目录、注释文件和transform与 target_transform.

  • __len__:返回数据集的总样本数。len(dataset)会调用它。

  • __getitem__(self, idx):根据整数索引idx会返回一个样本(通常为特征和标签)。dataset[idx] 会调用它。

其作用就是实现通过索引访问对应的数据以及标签

python 复制代码
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

使用自定义数据集时,可以用将其与torch.utils.data.DataLoader 结合使用,以便进行数据的批量加载和处理和训练。

1.2 两种自定义数据集风格

​ 在PyTorch中,自定义数据集有两个核心设计模式:映射式(Map-Style)可迭代式(Iterable-style) 。它们的差异不仅是实现接口不同,更反映了"随机访问"与"流式读取"两种数据消费范式的根本区别。下面从设计理念、实现细节、多进程交互、适用场景等方面深入解析。

  • Map-style datasets(映射式) :就是上述需要实现 __getitem____len__ 的数据集,它通过索引映射到数据样本。适用于所有数据能一次性放入索引结构(如列表、文件路径列表)的场景。
  • Iterable-style datasets(可迭代式) :当数据集太大无法一次性加载,或数据是流式读取时(如实时日志、数据库流),可以继承 IterableDataset,实现 __iter__ 方法返回一个迭代器。这种数据集不能使用 len(),也无法使用随机采样(shuffle)的 loader,需使用 Sampler 的特定变体。

在后续笔记我们将详细介绍。

1.3 内置数据集

​ PyTorch提供了一些常用数据集类,主要在torchvision.datasetstorchtext.datasetstorchaudio.datasets中。例如:

  • torchvision.datasets.MNISTCIFAR10ImageFolder(从文件夹结构加载图片,子文件夹为类别)
  • torchtext.datasets.IMDB
  • torchaudio.datasets.LIBRISPEECH

这些内置类都继承自 Dataset,使用时可自动下载数据,并提供标准化访问方式。

​ 现在我们来展示一个如何从TorchVision加载了Fahion-MINIST由60000个训练样本和10000个测试样本组成。每个样本包含一个\(28\times{28}\) 灰度图像和一个来自10个类别之一的关联标签。下面使用以下参数加载FashionMINIST数据集:

  • root:是存储路径、测试数据的路径。
  • train:指定训练集或测试数据集。
  • download=True:如果root路径下没有数据,则从网上下载数据。
  • transformtarget_transform是指定特征和标签转换。
python 复制代码
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="./data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="./data",
    train=False,
    download=True,
    transform=ToTensor()
)

我们可以用索引来访问数据集中的样本,用 matplotlib 可视化图形样本。

python 复制代码
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

其运行结果如下:

2. 数据加载器(Data Loader)

数据加载器(Data Loader)DataSet 封装为可迭代对象,负责批量加载、打乱数据、多进程并行加载等功能。其功能如下:

  • 批量加载数据:DataLoader可以从数据集中按照指定的批量大小加载数据。每个批次的数据可以作为一个张量或列表返回,便于进行后续的处理和训练。
  • 数据随机洗牌:通过设置shuffle=True,DataLoader可以在每个迭代周期中对数据进行随机洗牌,以减少模型对数据顺序的依赖性,提高训练效果。
  • 多线程数据加载:DataLoader支持使用多个线程来并行加载数据,加快数据加载的速度,提高训练效率。
  • 数据批次采样:除了按照批量大小加载数据外,DataLoader还支持自定义的数据批次采样方式。可以通过设置batch_sampler参数来指定自定义的批次采样器,例如按照指定的样本顺序或权重进行采样。

数据加载器的API形式核心参数

python 复制代码
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, multiprocessing_context=None,
           generator=None, prefetch_factor=2, persistent_workers=False)
  • dataset :要加载的 Dataset 对象(映射式或可迭代式)。
  • batch_size:每个批次的样本数,默认为 1。
  • shuffle :是否在每个 epoch 开始时打乱数据顺序(仅对映射式有效)。打乱基于 RandomSampler
  • sampler :自定义采样器,继承自 torch.utils.data.Sampler。定义数据索引的抽取策略。如果指定,shuffle 必须为 False
  • batch_sampler :类似 sampler,但每次返回一批索引,与 batch_sizeshufflesampler 互斥。
  • num_workers:用于数据加载的子进程数。0 表示在主进程中加载,通常设置大于 0 可以加速数据预处理,利用多核。
  • collate_fn :函数,定义如何将多个样本列表合并为一个批次。默认 collate_fn 会将所有样本沿第0维堆叠成张量,通常对于同型数据有效。如果样本结构不一致(如不同长度序列),需要自定义。
  • pin_memory :若为 True,数据加载器在返回张量前将其复制到 CUDA 固定内存,加速数据传输到 GPU。仅适用于 CUDA。
  • drop_last :若为 True,丢弃最后一个不完整批次(当总样本数不能被 batch_size 整除时)。在训练时如果要求严格固定批次大小(如 BatchNorm)应设为 True
  • timeout:从 worker 进程获取一个 batch 的超时时间(秒)。如果超时会抛异常。
  • worker_init_fn:每个 worker 进程的初始化函数,参数为 worker id,可用于设置随机种子等。
  • generator:用于生成随机采样的伪随机数生成器,保证可复现性。
  • prefetch_factor:每个 worker 预先加载的 batch 数(默认 2),增加可以让 GPU 更少等待。
  • persistent_workers :若为 True,在数据集被消费一次后不会关闭 worker 进程,可保持 worker 存活以加速后续 epoch。

数据调用案例Demo

python 复制代码
import torch
from torch.utils.data import Dataset, DataLoader


# 自定义数据集类
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


# 自定义数据加载器类
class MyDataLoader(DataLoader):
    def __init__(self, dataset, batch_size=1, shuffle=False, num_workers=0):
        super().__init__(dataset, batch_size, shuffle, num_workers=num_workers)

    def collate_fn(self, batch):
        # 自定义的数据预处理、合并等操作
        # 这里只是简单地将样本转换为Tensor,并进行堆叠
        return torch.stack(batch)


# 自定义数据集类
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)

# 创建数据加载器实例
dataloader = MyDataLoader(dataset, batch_size=2, shuffle=True)

# 遍历数据加载器
for batch in dataloader:
    # batch是一个包含多个样本的张量(或列表)
    # 这里可以对批次数据进行处理
    print(batch)

3.实战案例

python 复制代码
import torch
from sklearn.datasets import load_iris
from torch.utils.data import Dataset, DataLoader
 
# 此函数用于加载鸢尾花数据集
def load_data(shuffle=True):
    x = torch.tensor(load_iris().data)
    y = torch.tensor(load_iris().target)
 
    # 数据归一化
    x_min = torch.min(x, dim=0).values
    x_max = torch.max(x, dim=0).values
    x = (x - x_min) / (x_max - x_min)
 
    if shuffle:
        idx = torch.randperm(x.shape[0])
        x = x[idx]
        y = y[idx]
    return x, y
 
# 自定义鸢尾花数据类
class IrisDataset(Dataset):
    def __init__(self, mode='train', num_train=120, num_dev=15):
        super(IrisDataset, self).__init__()
        x, y = load_data(shuffle=True)
        if mode == 'train':
            self.x, self.y = x[:num_train], y[:num_train]
        elif mode == 'dev':
            self.x, self.y = x[num_train:num_train + num_dev], y[num_train:num_train + num_dev]
        else:
            self.x, self.y = x[num_train + num_dev:], y[num_train + num_dev:]
 
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]
 
    def __len__(self):
        return len(self.x)
 
batch_size = 16
 
# 分别构建训练集、验证集和测试集
train_dataset = IrisDataset(mode='train')
dev_dataset = IrisDataset(mode='dev')
test_dataset = IrisDataset(mode='test')
 
train_loader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True)
dev_loader = DataLoader(dev_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

4.总 结

  • ataset 定义数据源及其访问方式,映射式最常用,流式数据用 IterableDataset
  • DataLoader 封装采样、批处理、多进程加载和内存固定等功能,参数丰富。
  • 通过自定义 samplercollate_fn 可以灵活处理各种数据形式和不平衡问题。
  • 多进程加载是加速训练的关键,需注意内存复制和系统兼容性。

掌握 DatasetDataLoader 的用法与内部机制,能够让你根据实际需求搭建高效的数据管道,将 I/O 瓶颈降到最低,从而充分释放 GPU 计算能力。

5.参考资料

https://cloud.tencent.com/developer/article/2055224?policyId=1003

https://cloud.tencent.com/developer/article/2440506

https://cloud.tencent.com/developer/article/1010379?policyId=1003