Day 38 - Dataset 和 DataLoader

在深度学习任务中,数据处理是至关重要的一环。面对大规模数据集,显存往往无法一次性存储所有数据,因此需要采用分批训练(Batch Training)的策略。PyTorch 提供了两个核心工具类来解决数据加载和预处理的问题:DatasetDataLoader

本文将深入探讨这两个类的原理、用法以及它们之间的关系,并以经典的 MNIST 手写数字数据集为例进行演示。

一、 PyTorch 数据处理核心架构

在 PyTorch 中,数据处理流程被解耦为两个独立的部分:

  1. Dataset (数据集):负责定义"数据是什么",即如何获取单个样本及其对应的标签,以及如何进行预处理。
  2. DataLoader (数据加载器):负责定义"如何加载数据",即如何将 Dataset 中的样本组装成批次(Batch),并提供多线程加载、随机打乱等功能。

形象比喻

  • Dataset 就像是厨师,他的工作是负责把每一个菜品(样本)切好、洗好、调好味(预处理)。
  • DataLoader 就像是服务员,他的工作是把厨师做好的菜品,按照订单的要求(Batch Size),打包好端给客人(模型)。

二、 Dataset 类详解

torch.utils.data.Dataset 是一个抽象基类,所有自定义的数据集都必须继承它,并实现其核心接口。

1. 核心魔术方法

PyTorch 要求 Dataset 子类必须实现以下两个魔术方法(Magic Methods):

  • len(self)
    • 作用:返回数据集的样本总数。
    • 调用方式 :当使用 len(dataset) 时自动调用。
    • 意义:DataLoader 需要知道数据集的大小,以便计算一个 Epoch 需要多少个 Batch。
  • getitem(self, idx)
    • 作用 :根据索引 idx 获取单个样本的数据和标签。
    • 调用方式 :当使用 dataset[idx] 时自动调用。
    • 意义:这是数据读取和预处理发生的地方。

2. Python 魔术方法原理解析

为了更好地理解 __len____getitem__,我们来看一个简单的 Python 自定义类示例:

复制代码
class MyList:
    def __init__(self):
        self.data = [10, 20, 30, 40, 50]

    # 实现索引访问功能
    def __getitem__(self, idx):
        return self.data[idx]

    # 实现长度获取功能
    def __len__(self):
        return len(self.data)

# 实例化对象
my_list_obj = MyList()

# 1. 测试 __getitem__
# 对象可以直接使用 [] 索引访问,像内置列表一样
print(f"索引为2的元素: {my_list_obj[2]}")  # 输出: 30

# 2. 测试 __len__
# 对象可以直接使用 len() 函数
print(f"列表长度: {len(my_list_obj)}")      # 输出: 5

3. 自定义 Dataset 示例

基于上述原理,一个典型的自定义 Dataset 结构如下:

复制代码
from torch.utils.data import Dataset

class MNIST(Dataset):
    def __init__(self, root, train=True, transform=None):
        """
        初始化:加载文件路径、标签文件等
        """
        # 假设 fetch_mnist_data 是一个自定义函数,用于读取数据
        self.data, self.targets = fetch_mnist_data(root, train) 
        self.transform = transform # 预处理操作流水线
        
    def __len__(self): 
        """
        返回数据集大小
        """
        return len(self.data)
    
    def __getitem__(self, idx): 
        """
        获取指定索引 idx 的样本
        """
        # 1. 根据索引获取原始数据和标签
        img, target = self.data[idx], self.targets[idx]
        
        # 2. 应用预处理(如转 Tensor、归一化等)
        if self.transform is not None:
            img = self.transform(img)
            
        return img, target

三、 实战:使用 torchvision 加载 MNIST

torchvision 是 PyTorch 官方的计算机视觉库,其中 torchvision.datasets 模块内置了许多常用数据集(如 MNIST, CIFAR10, ImageNet 等),它们都已经实现了 Dataset 的接口。

1. 数据预处理 (Transforms)

在加载图像数据时,通常需要进行一系列预处理,如转为张量(Tensor)、归一化(Normalize)等。

复制代码
from torchvision import transforms

# 定义预处理流水线
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为 PyTorch 张量,并将像素值归一化到 [0, 1]
    transforms.Normalize((0.1307,), (0.3081,))  # 标准化:(x - mean) / std。参数为 MNIST 数据集的全局均值和标准差
])

2. 加载数据集

复制代码
from torchvision import datasets

# 加载训练集
train_dataset = datasets.MNIST(
    root='./data',       # 数据存储路径
    train=True,          # True 表示加载训练集
    download=True,       # 如果路径下不存在数据,是否自动下载
    transform=transform  # 应用上面定义的预处理
)

# 加载测试集
test_dataset = datasets.MNIST(
    root='./data',
    train=False,         # False 表示加载测试集
    transform=transform
)

注意 :在 PyTorch 的设计哲学中,数据预处理通常是在加载阶段(即 __getitem__ 被调用时)动态进行的,而不是先处理好再保存。这样做可以节省磁盘空间,并支持动态的数据增强。

3. 查看单个样本

由于 train_dataset 本质上是一个 Dataset 子类,我们可以直接通过索引访问:

复制代码
import matplotlib.pyplot as plt
import torch

# 随机获取一个索引
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item()

# 获取样本(自动调用 __getitem__)
image, label = train_dataset[sample_idx]

# 可视化(需要反归一化以便人眼观察)
def imshow(img):
    img = img * 0.3081 + 0.1307  # 反标准化
    npimg = img.numpy()
    plt.imshow(npimg[0], cmap='gray')
    plt.show()

print(f"Label: {label}")
imshow(image)

四、 DataLoader 类详解

torch.utils.data.DataLoader 是 PyTorch 中用于加载数据的核心工具。它接收一个 Dataset 对象,并根据配置参数生成一个可迭代对象。

1. 核心功能

DataLoader 的主要职责包括:

  • Batching:将多个样本打包成一个批次。
  • Shuffling:在每个 Epoch 开始时打乱数据顺序,防止模型记忆数据的顺序特征。
  • Multiprocessing:使用多进程并行加载数据,加速数据准备过程(避免 CPU 成为瓶颈)。

2. 创建 DataLoader

复制代码
from torch.utils.data import DataLoader

# 训练集加载器
train_loader = DataLoader(
    train_dataset,
    batch_size=64,   # 每个批次包含 64 个样本
    shuffle=True     # 训练时通常需要打乱数据
)

# 测试集加载器
test_loader = DataLoader(
    test_dataset,
    batch_size=1000, # 测试时显存压力较小,可以使用更大的 batch_size
    shuffle=False    # 测试时不需要打乱顺序,以便结果对比
)

关于 Batch Size 的选择

通常选择 2 的幂次方(如 32, 64, 128),这有利于 GPU 的并行计算效率。

五、 总结:Dataset 与 DataLoader 的对比

为了清晰地区分这两个概念,我们可以从以下几个维度进行对比:

|-----------|--------------------------------------|----------------------------------------|
| 维度 | Dataset | DataLoader |
| 核心职责 | 定义"数据内容"和"单个样本获取方式" | 定义"批量加载策略"和"迭代方式" |
| 核心方法 | __getitem__ (获取单个), __len__ (总数) | 内部实现迭代器协议 (__iter__) |
| 预处理位置 | 在 __getitem__ 中定义具体的转换逻辑 | 不负责预处理,直接使用 Dataset 返回的结果 |
| 并行处理 | 无(仅处理单样本逻辑) | 支持多进程加载 (num_workers) |
| 关键参数 | root (路径), transform (变换) | batch_size, shuffle, num_workers |

一句话总结

Dataset 负责把数据从磁盘读出来并处理成模型能看懂的格式(Tensor),而 DataLoader 负责把这些 Tensor 批量、高效、随机地喂给模型进行训练。

相关推荐
算法如诗5 小时前
Python实现基于GA -FCM遗传算法(GA)优化FCM模糊C均值聚类进行多变量时间序列预测
python·均值算法
测试人社区-小明5 小时前
洞察金融科技测试面试:核心能力与趋势解析
人工智能·科技·面试·金融·机器人·自动化·github
LO嘉嘉VE5 小时前
学习笔记二十九:贝叶斯决策论
人工智能·笔记·学习
识途老码5 小时前
python装饰器
开发语言·python
fresh hacker5 小时前
【Python数据分析】速通NumPy
开发语言·python·数据挖掘·数据分析·numpy
猫天意5 小时前
【即插即用模块】AAAI2026 | MHCB+DPA:特征提取+双池化注意力,涨点必备,SCI保二争一!彻底疯狂!!!
网络·人工智能·深度学习·算法·yolo
_codemonster5 小时前
AI大模型入门到实战系列(三)词元(token)和嵌入(embedding)
人工智能·机器学习·embedding
IT_陈寒5 小时前
Java 21新特性实战:这5个改进让我的代码效率提升40%
前端·人工智能·后端
爱笑的眼睛116 小时前
端到端语音识别系统的前沿实践与深度剖析:从RNN-T到Conformer
java·人工智能·python·ai