在 PyTorch 中,Dataset
和 DataLoader
是数据加载和处理的重要组件。下面详细介绍 Dataset
类的作用及其 __len__()
和 __getitem__()
方法,以及它们如何与 DataLoader
协作,包括数据打乱(shuffle)和批处理(batching)等功能。
Dataset
类
Dataset
是一个抽象基类,用于表示一个数据集。你需要继承这个基类并实现以下两个方法:
1. __len__()
-
作用: 返回数据集中样本的总数量。
-
返回值: 一个整数,表示数据集中样本的数量。
-
用例 : 当你需要知道数据集的大小时,例如在创建
DataLoader
对象时,DataLoader
需要知道数据集中有多少样本才能正确地进行批处理和打乱操作。
这一步确定你取数据的范围,如果你想一次取两个数据,需要在__len__()里面控制索引的长度
python
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
2. __getitem__(index)
-
作用: 根据给定的索引返回数据集中的一个样本。
-
参数 :
index
,一个整数,表示要获取的样本的索引。 -
返回值: 返回一个样本的数据(和可能的标签),这可以是任何类型,例如一个图像和其标签、一个文本片段等等。
这里其实可操作性很大,比如你想每次dataloader得到的batch里面包含图片的路径,那么在这里return。
python
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
DataLoader
类
DataLoader
是用于批量加载数据的工具,它接受一个 Dataset
对象并提供了以下功能:
1. 批处理(Batching)
-
作用: 将数据集分成小批量,每次从数据集中取出一个批次的数据进行训练或评估。
-
实现 :
DataLoader
根据batch_size
参数将数据分批,每个批次包含batch_size
个样本。
python
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
2. 数据打乱(Shuffling)
-
作用: 在每个 epoch 开始时打乱数据的顺序,有助于模型更好地泛化。
-
实现 : 如果
shuffle=True
,DataLoader
会在每个 epoch 开始时创建一个打乱的索引列表,然后按这些索引顺序提取样本。这里的打乱索引范围就是从_len_函数获取的
过程:
- 创建一个索引列表
[0, 1, 2, ..., len(dataset) - 1]
。 - 如果
shuffle=True
,打乱这个索引列表。 - 使用打乱后的索引列表来从数据集中提取样本。
python
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
3. 并行加载(Multi-threaded Loading)
-
作用: 使用多个子进程并行加载数据,减少数据预处理和加载的时间。
-
参数 :
num_workers
指定用于数据加载的子进程数量。 -
实现 :
DataLoader
会启动num_workers
个子进程来调用Dataset
的__getitem__()
方法并加载数据。
python
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
总结
-
Dataset
:__len__()
: 返回数据集的总样本数。DataLoader
使用这个方法来确定数据的总量,以便进行正确的批处理和打乱。__getitem__(index)
: 根据索引返回单个数据样本。DataLoader
会调用这个方法来获取每个批次的数据样本。
-
DataLoader
:- 负责将数据分批、打乱数据、并行加载等任务。
- 使用
Dataset
的__len__()
来了解数据集的大小,并使用__getitem__()
来获取每个样本。
通过这种方式,Dataset
提供了数据访问的接口,而 DataLoader
管理数据的加载、打乱和批处理等高级功能。