dataset数据集
作用:
- 存储数据集的信息
- 获取数据集长度
__len__
- 获取数据集某特定条目的内容
__getitem__
dataloader 数据加载器
作用:
- 从数据集中随机加载数据, 并拼接为一个 batch
- 实现迭代器, 可以使用时, 迭代获取数据内容
代码实现:
python
import numpy as np
class ImageDataset():
def __init__(self, raw_data):
"""
数据集初始化
"""
self.raw_data = raw_data
def __len__(self):
"""
返回数据集的长度
"""
return len(self.raw_data)
def __getitem__(self, index):
"""
根据索引获取数据集中某一条数据
"""
image, label = self.raw_data[index]
return image, label
class DataLoader():
def __init__(self, dataset, batch_size):
self.dataset = dataset
self.batch_size = batch_size
def __iter__(self):
self.indexes = np.arange(len(self.dataset))
self.cursor = 0
np.random.shuffle(self.indexes)
return self
def __next__(self):
# 计算起始索引和终止索引
begin = self.cursor
end = self.cursor + self.batch_size
# 若超出范围,抛出停止迭代异常
if end > len(self.dataset):
raise StopIteration
# 更新游标位置
self.cursor = end
# 根据索引获取对应的数据
batch_data = []
for index in self.indexes[begin:end]:
item = self.dataset[index]
batch_data.append(item)
return batch_data
if __name__ == "__main__":
images = [[f"image{i}", i] for i in range(10)]
dataset = ImageDataset(images)
loader = DataLoader(dataset, batch_size=5)
for index, batch_data in enumerate(loader, 1):
print(f"第{index}个批次:", batch_data)
代码中存在的问题:
当最后一个batch的样本数量不足 batch_size
时,比如总样本数不是 batch_size
的整数倍,不会返回最后一个不足的batch
改进后的 DataLoader:
python
class DataLoader():
def __init__(self,dataset, batch_size, shuffle=True):
self.dataset = dataset
self.batch_size = batch_size
self.shuffle = shuffle
def __iter__(self):
"""
初始化迭代器, 每个epoch开始时自动调用
"""
self.cursor = 0
self.indexes = np.arange(len(self.dataset))
if self.shuffle:
np.random.shuffle(self.indexes)
return self
def __next__(self):
"""
获取下一批次数据
"""
begin = self.cursor
end = self.cursor + self.batch_size
# 当剩余数据不足一个批次时全部返回剩余数据
if begin >= len(self.dataset):
raise StopIteration
end = min(end, len(self.dataset))
self.cursor = end
batch_data = []
for index in self.indexes[begin:end]:
item = self.dataset[index]
batch_data.append(item)
return batch_data
本文参考: