Pytorch:复写Dataset函数详解,以及Dataloader如何调用

在 PyTorch 中,DatasetDataLoader 是数据加载和处理的重要组件。下面详细介绍 Dataset 类的作用及其 __len__()__getitem__() 方法,以及它们如何与 DataLoader 协作,包括数据打乱(shuffle)和批处理(batching)等功能。

Dataset

Dataset 是一个抽象基类,用于表示一个数据集。你需要继承这个基类并实现以下两个方法:

1. __len__()

  • 作用: 返回数据集中样本的总数量。

  • 返回值: 一个整数,表示数据集中样本的数量。

  • 用例 : 当你需要知道数据集的大小时,例如在创建 DataLoader 对象时,DataLoader 需要知道数据集中有多少样本才能正确地进行批处理和打乱操作。

这一步确定你取数据的范围,如果你想一次取两个数据,需要在__len__()里面控制索引的长度

python 复制代码
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

2. __getitem__(index)

  • 作用: 根据给定的索引返回数据集中的一个样本。

  • 参数 : index,一个整数,表示要获取的样本的索引。

  • 返回值: 返回一个样本的数据(和可能的标签),这可以是任何类型,例如一个图像和其标签、一个文本片段等等。

这里其实可操作性很大,比如你想每次dataloader得到的batch里面包含图片的路径,那么在这里return。

python 复制代码
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

DataLoader

DataLoader 是用于批量加载数据的工具,它接受一个 Dataset 对象并提供了以下功能:

1. 批处理(Batching)

  • 作用: 将数据集分成小批量,每次从数据集中取出一个批次的数据进行训练或评估。

  • 实现 : DataLoader 根据 batch_size 参数将数据分批,每个批次包含 batch_size 个样本。

python 复制代码
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

2. 数据打乱(Shuffling)

  • 作用: 在每个 epoch 开始时打乱数据的顺序,有助于模型更好地泛化。

  • 实现 : 如果 shuffle=TrueDataLoader 会在每个 epoch 开始时创建一个打乱的索引列表,然后按这些索引顺序提取样本。这里的打乱索引范围就是从_len_函数获取的

过程:

  1. 创建一个索引列表 [0, 1, 2, ..., len(dataset) - 1]
  2. 如果 shuffle=True,打乱这个索引列表。
  3. 使用打乱后的索引列表来从数据集中提取样本。
python 复制代码
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

3. 并行加载(Multi-threaded Loading)

  • 作用: 使用多个子进程并行加载数据,减少数据预处理和加载的时间。

  • 参数 : num_workers 指定用于数据加载的子进程数量。

  • 实现 : DataLoader 会启动 num_workers 个子进程来调用 Dataset__getitem__() 方法并加载数据。

python 复制代码
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

总结

  1. Dataset:

    • __len__(): 返回数据集的总样本数。DataLoader 使用这个方法来确定数据的总量,以便进行正确的批处理和打乱。
    • __getitem__(index): 根据索引返回单个数据样本。DataLoader 会调用这个方法来获取每个批次的数据样本。
  2. DataLoader:

    • 负责将数据分批、打乱数据、并行加载等任务。
    • 使用 Dataset__len__() 来了解数据集的大小,并使用 __getitem__() 来获取每个样本。

通过这种方式,Dataset 提供了数据访问的接口,而 DataLoader 管理数据的加载、打乱和批处理等高级功能。

相关推荐
2502_927161283 分钟前
DAY 42 Grad-CAM与Hook函数
人工智能
Hello123网站20 分钟前
Flowith-节点式GPT-4 驱动的AI生产力工具
人工智能·ai工具
王者鳜錸33 分钟前
PYTHON让繁琐的工作自动化-猜数字游戏
python·游戏·自动化
yzx99101342 分钟前
Yolov模型的演变
人工智能·算法·yolo
若天明1 小时前
深度学习-计算机视觉-微调 Fine-tune
人工智能·python·深度学习·机器学习·计算机视觉·ai·cnn
爱喝奶茶的企鹅1 小时前
Ethan独立开发新品速递 | 2025-08-19
人工智能
J_bean2 小时前
Spring AI Alibaba 项目接入兼容 OpenAI API 的大模型
人工智能·spring·大模型·openai·spring ai·ai alibaba
SelectDB2 小时前
Apache Doris 4.0 AI 能力揭秘(一):AI 函数之 LLM 函数介绍
数据库·人工智能·数据分析
倔强青铜三2 小时前
苦练Python第39天:海象操作符 := 的入门、实战与避坑指南
人工智能·python·面试
飞哥数智坊2 小时前
GPT-5 初战:我用 Windsurf,体验了“结对编程”式的AI开发
人工智能·windsurf