Pytorch:复写Dataset函数详解,以及Dataloader如何调用

在 PyTorch 中,DatasetDataLoader 是数据加载和处理的重要组件。下面详细介绍 Dataset 类的作用及其 __len__()__getitem__() 方法,以及它们如何与 DataLoader 协作,包括数据打乱(shuffle)和批处理(batching)等功能。

Dataset

Dataset 是一个抽象基类,用于表示一个数据集。你需要继承这个基类并实现以下两个方法:

1. __len__()

  • 作用: 返回数据集中样本的总数量。

  • 返回值: 一个整数,表示数据集中样本的数量。

  • 用例 : 当你需要知道数据集的大小时,例如在创建 DataLoader 对象时,DataLoader 需要知道数据集中有多少样本才能正确地进行批处理和打乱操作。

这一步确定你取数据的范围,如果你想一次取两个数据,需要在__len__()里面控制索引的长度

python 复制代码
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

2. __getitem__(index)

  • 作用: 根据给定的索引返回数据集中的一个样本。

  • 参数 : index,一个整数,表示要获取的样本的索引。

  • 返回值: 返回一个样本的数据(和可能的标签),这可以是任何类型,例如一个图像和其标签、一个文本片段等等。

这里其实可操作性很大,比如你想每次dataloader得到的batch里面包含图片的路径,那么在这里return。

python 复制代码
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

DataLoader

DataLoader 是用于批量加载数据的工具,它接受一个 Dataset 对象并提供了以下功能:

1. 批处理(Batching)

  • 作用: 将数据集分成小批量,每次从数据集中取出一个批次的数据进行训练或评估。

  • 实现 : DataLoader 根据 batch_size 参数将数据分批,每个批次包含 batch_size 个样本。

python 复制代码
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

2. 数据打乱(Shuffling)

  • 作用: 在每个 epoch 开始时打乱数据的顺序,有助于模型更好地泛化。

  • 实现 : 如果 shuffle=TrueDataLoader 会在每个 epoch 开始时创建一个打乱的索引列表,然后按这些索引顺序提取样本。这里的打乱索引范围就是从_len_函数获取的

过程:

  1. 创建一个索引列表 [0, 1, 2, ..., len(dataset) - 1]
  2. 如果 shuffle=True,打乱这个索引列表。
  3. 使用打乱后的索引列表来从数据集中提取样本。
python 复制代码
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

3. 并行加载(Multi-threaded Loading)

  • 作用: 使用多个子进程并行加载数据,减少数据预处理和加载的时间。

  • 参数 : num_workers 指定用于数据加载的子进程数量。

  • 实现 : DataLoader 会启动 num_workers 个子进程来调用 Dataset__getitem__() 方法并加载数据。

python 复制代码
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

总结

  1. Dataset:

    • __len__(): 返回数据集的总样本数。DataLoader 使用这个方法来确定数据的总量,以便进行正确的批处理和打乱。
    • __getitem__(index): 根据索引返回单个数据样本。DataLoader 会调用这个方法来获取每个批次的数据样本。
  2. DataLoader:

    • 负责将数据分批、打乱数据、并行加载等任务。
    • 使用 Dataset__len__() 来了解数据集的大小,并使用 __getitem__() 来获取每个样本。

通过这种方式,Dataset 提供了数据访问的接口,而 DataLoader 管理数据的加载、打乱和批处理等高级功能。

相关推荐
仙人掌_lz16 分钟前
AI与机器学习ML:利用Python 从零实现神经网络
人工智能·python·机器学习
我感觉。25 分钟前
【医疗电子技术-7.2】血糖监测技术
人工智能·医疗电子
DeepSeek忠实粉丝31 分钟前
微调篇--超长文本微调训练
人工智能·程序员·llm
XiaoQiong.Zhang33 分钟前
简历模板3——数据挖掘工程师5年经验
大数据·人工智能·机器学习·数据挖掘
逸雪飞扬44 分钟前
Gradio 非侵入式修改的离线使用方案
python·html
Akamai中国1 小时前
为何AI推理正推动云计算从集中式向分布式转型
人工智能·云原生·云计算·边缘计算
oil欧哟1 小时前
🧐 如何让 AI 接入自己的 API?开发了一个将 OpenAPI 文档转为 MCP 服务的工具
前端·人工智能·mcp
whoarethenext1 小时前
C++/OpenCV地砖识别系统结合 Libevent 实现网络化 AI 接入
c++·人工智能·opencv
来自外太空的鱼-张小张1 小时前
java将pdf文件转换为图片工具类
java·python·pdf
endNone1 小时前
【机器学习】SAE(Sparse Autoencoders)稀疏自编码器
人工智能·python·深度学习·sae·autoencoder·稀疏自编码器