Pytorch:复写Dataset函数详解,以及Dataloader如何调用

在 PyTorch 中,DatasetDataLoader 是数据加载和处理的重要组件。下面详细介绍 Dataset 类的作用及其 __len__()__getitem__() 方法,以及它们如何与 DataLoader 协作,包括数据打乱(shuffle)和批处理(batching)等功能。

Dataset

Dataset 是一个抽象基类,用于表示一个数据集。你需要继承这个基类并实现以下两个方法:

1. __len__()

  • 作用: 返回数据集中样本的总数量。

  • 返回值: 一个整数,表示数据集中样本的数量。

  • 用例 : 当你需要知道数据集的大小时,例如在创建 DataLoader 对象时,DataLoader 需要知道数据集中有多少样本才能正确地进行批处理和打乱操作。

这一步确定你取数据的范围,如果你想一次取两个数据,需要在__len__()里面控制索引的长度

python 复制代码
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

2. __getitem__(index)

  • 作用: 根据给定的索引返回数据集中的一个样本。

  • 参数 : index,一个整数,表示要获取的样本的索引。

  • 返回值: 返回一个样本的数据(和可能的标签),这可以是任何类型,例如一个图像和其标签、一个文本片段等等。

这里其实可操作性很大,比如你想每次dataloader得到的batch里面包含图片的路径,那么在这里return。

python 复制代码
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

DataLoader

DataLoader 是用于批量加载数据的工具,它接受一个 Dataset 对象并提供了以下功能:

1. 批处理(Batching)

  • 作用: 将数据集分成小批量,每次从数据集中取出一个批次的数据进行训练或评估。

  • 实现 : DataLoader 根据 batch_size 参数将数据分批,每个批次包含 batch_size 个样本。

python 复制代码
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

2. 数据打乱(Shuffling)

  • 作用: 在每个 epoch 开始时打乱数据的顺序,有助于模型更好地泛化。

  • 实现 : 如果 shuffle=TrueDataLoader 会在每个 epoch 开始时创建一个打乱的索引列表,然后按这些索引顺序提取样本。这里的打乱索引范围就是从_len_函数获取的

过程:

  1. 创建一个索引列表 [0, 1, 2, ..., len(dataset) - 1]
  2. 如果 shuffle=True,打乱这个索引列表。
  3. 使用打乱后的索引列表来从数据集中提取样本。
python 复制代码
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

3. 并行加载(Multi-threaded Loading)

  • 作用: 使用多个子进程并行加载数据,减少数据预处理和加载的时间。

  • 参数 : num_workers 指定用于数据加载的子进程数量。

  • 实现 : DataLoader 会启动 num_workers 个子进程来调用 Dataset__getitem__() 方法并加载数据。

python 复制代码
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

总结

  1. Dataset:

    • __len__(): 返回数据集的总样本数。DataLoader 使用这个方法来确定数据的总量,以便进行正确的批处理和打乱。
    • __getitem__(index): 根据索引返回单个数据样本。DataLoader 会调用这个方法来获取每个批次的数据样本。
  2. DataLoader:

    • 负责将数据分批、打乱数据、并行加载等任务。
    • 使用 Dataset__len__() 来了解数据集的大小,并使用 __getitem__() 来获取每个样本。

通过这种方式,Dataset 提供了数据访问的接口,而 DataLoader 管理数据的加载、打乱和批处理等高级功能。

相关推荐
逻辑君8 小时前
认知神经科学研究报告【20260008】
人工智能·深度学习·神经网络·机器学习
GIS数据转换器8 小时前
延凡智慧水务系统:引领行业变革的智能引擎
大数据·人工智能·无人机·智慧城市
行者无疆_ty8 小时前
小龙虾(OpenClaw)安装教程
人工智能·agent·openclaw·小龙虾
2601_949539458 小时前
家用新能源 SUV 核心技术科普:后排娱乐、空间工程与混动可靠性解析
大数据·网络·人工智能·算法·机器学习
北邮刘老师8 小时前
暗数据:智能体探索世界的下一步
人工智能·大模型·prompt·智能体·智能体互联网
ggabb8 小时前
世界人口血型分布及关联特点
人工智能
弘弘弘弘~8 小时前
项目实战之评论情感分析模型——基于Bert(含任务头)
人工智能·深度学习·bert
明月_清风8 小时前
从提示词到脚手架:LLM 开发的三大工程维度对比
人工智能
南湖北漠8 小时前
奇奇怪怪漫画里面的蛞蝓是带壳的那种鼻涕虫
网络·人工智能·计算机网络·其他·安全·生活
小超同学你好8 小时前
Transformer 23. Qwen 3.5 架构介绍:混合线性/全注意力、MoE 与相对 Qwen 1 / 2 / 3 的演进
人工智能·深度学习·语言模型·架构·transformer