1. 背景:
使用 mindspore 学习神经网络,打卡第三天;
2. 训练的内容:
使用 mindspore 的常见的数据集 DataSet 的使用方法;
3. 常见的用法小节:
-
数据集加载
train_dataset = MnistDataset("MNIST_Data/train", shuffle=False)
-
数据集迭代(create_tuple_iterator或create_dict_iterator 实现)
def visualize(dataset):
figure = plt.figure(figsize=(4,4))
cols, rows = 3, 3plt.subplots_adjust(wspace=0.5, hspace=0.5) for idx, (image, label) in enumerate(dataset.create_tuple_iterator()): figure.add_subplot(rows, cols, idx +1) plt.title(int(label)) plt.axis('off') plt.imshow(image.asnumpy().squeeze(), cmap='gray') if idx == cols * rows - 1: break; plt.show()
visualize(train_dataset)
-
数据集常用操作(shuffer, map, batch):
shuffer - 随机打乱数据顺序
train_dataset = train_dataset.shuffle(buffer_size=64)
visualize(train_dataset)image, label = next(train_dataset.create_tuple_iterator())
print(image.shape, image.dtype)map - 对数据进行
将图像统一除以255,数据类型由uint8转为了float32
train_dataset = train_dataset.map(vision.Rescale(1.0/255.0, 0), input_columns='image')
#batch: 有限硬件资源下使用梯度下降进行模型优化的折中方法,可以保证梯度下降的随机性和优化计算量
train_dataset = train_dataset.batch(batch_size=32)batch后的数据增加一维,大小为batch_size。
image, label = next(train_dataset.create_tuple_iterator())
print(image.shape, image.dtype) -
自定义数据集(可随机访问数据集/可迭代数据集/生成器类型)
自定义数据加载类,来生成数据集,通过 GeneratorDataset 接口实现数据加载
实现 getitem, len 方法,进行 索引键直接访问
class RandomAccessDataset:
def init(self):
self._data = np.ones((5, 2))
self._label = np.zeros((5, 1))def __getitem__(self, index): return self._data[index], self._label[index] def __len__(self): return len(self._data)
loader = RandomAccessDataset()
dataset = GeneratorDataset(source=loader, column_names=['data', 'label'])for data in dataset:
print(data)可跌代数据集,实现 iter, next 方法
应用场景:iter(dataset),读取数据库,远程访问返回的数据流
class IterableDataset():
def init(self, start, end):
self.start = start
self.end = enddef __next__(self): return next(self.data) def __iter__(self): self.data = iter(range(self.start, self.end)) return self
生成器:可迭代数据集类型,依赖 python 的 generator 返回数据
def my_generator(start, end):
for i in range(start, end):
yield idataset = GeneratorDataset(source=lambda: my_generator(3, 6), column_names=['data'])
for d in dataset:
print(d)