《昇思25天学习打卡营第3天 | mindspore DataSet 数据集的常见用法》

1. 背景:

使用 mindspore 学习神经网络,打卡第三天;

2. 训练的内容:

使用 mindspore 的常见的数据集 DataSet 的使用方法;

3. 常见的用法小节:

  • 数据集加载

    train_dataset = MnistDataset("MNIST_Data/train", shuffle=False)

  • 数据集迭代(create_tuple_iterator或create_dict_iterator 实现)

    def visualize(dataset):
    figure = plt.figure(figsize=(4,4))
    cols, rows = 3, 3

    复制代码
      plt.subplots_adjust(wspace=0.5, hspace=0.5)
    
      for idx, (image, label) in enumerate(dataset.create_tuple_iterator()):
          figure.add_subplot(rows, cols, idx +1)
          plt.title(int(label))
          plt.axis('off')
          plt.imshow(image.asnumpy().squeeze(), cmap='gray')
          if idx == cols * rows - 1:
              break;
         
      plt.show()

    visualize(train_dataset)

  • 数据集常用操作(shuffer, map, batch):

    shuffer - 随机打乱数据顺序

    train_dataset = train_dataset.shuffle(buffer_size=64)
    visualize(train_dataset)

    image, label = next(train_dataset.create_tuple_iterator())
    print(image.shape, image.dtype)

    map - 对数据进行

    将图像统一除以255,数据类型由uint8转为了float32

    train_dataset = train_dataset.map(vision.Rescale(1.0/255.0, 0), input_columns='image')

    #batch: 有限硬件资源下使用梯度下降进行模型优化的折中方法,可以保证梯度下降的随机性和优化计算量
    train_dataset = train_dataset.batch(batch_size=32)

    batch后的数据增加一维,大小为batch_size。

    image, label = next(train_dataset.create_tuple_iterator())
    print(image.shape, image.dtype)

  • 自定义数据集(可随机访问数据集/可迭代数据集/生成器类型)

    自定义数据加载类,来生成数据集,通过 GeneratorDataset 接口实现数据加载

    实现 getitem, len 方法,进行 索引键直接访问

    class RandomAccessDataset:
    def init(self):
    self._data = np.ones((5, 2))
    self._label = np.zeros((5, 1))

    复制代码
      def __getitem__(self, index):
          return self._data[index], self._label[index]
      
      def __len__(self):
          return len(self._data)

    loader = RandomAccessDataset()
    dataset = GeneratorDataset(source=loader, column_names=['data', 'label'])

    for data in dataset:
    print(data)

    可跌代数据集,实现 iter, next 方法

    应用场景:iter(dataset),读取数据库,远程访问返回的数据流

    class IterableDataset():
    def init(self, start, end):
    self.start = start
    self.end = end

    复制代码
      def __next__(self):
          return next(self.data)
      
      def __iter__(self):
          self.data = iter(range(self.start, self.end))
          return self

    生成器:可迭代数据集类型,依赖 python 的 generator 返回数据

    def my_generator(start, end):
    for i in range(start, end):
    yield i

    dataset = GeneratorDataset(source=lambda: my_generator(3, 6), column_names=['data'])

    for d in dataset:
    print(d)

活动参与链接:

https://xihe.mindspore.cn/events/mindspore-training-camp

相关推荐
市场部需要一个软件开发岗位19 小时前
一个无人机平台+算法监督平台的离线部署指南
java·python·算法·bash·无人机·持续部署
喵手19 小时前
Python爬虫实战:房产数据采集实战 - 链家二手房&安居客租房多页爬虫完整方案(附CSV导出 + SQLite持久化存储)!
爬虫·python·爬虫实战·零基础python爬虫教学·房产数据采集·链家二手房/安居客房源采集·采集结果sqlite导出
不懒不懒19 小时前
【机器学习:下采样 VS 过采样——逻辑回归在信用卡欺诈检测中的实践】
python·numpy·scikit-learn·matplotlib·pip·futurewarning
Leinwin19 小时前
Moltbot 部署至 Azure Web App 完整指南:从本地到云端的安全高效跃迁
后端·python·flask
叫我辉哥e119 小时前
新手进阶Python:办公看板集成AI智能助手+语音交互+自动化问答
python
咚咚王者19 小时前
人工智能之核心技术 深度学习 第九章 框架实操(PyTorch / TensorFlow)
人工智能·pytorch·深度学习
AI人工智能+19 小时前
联机手写签名识别技术通过采集书写时的压力、速度、轨迹等动态特征,构建独特的“行为指纹“
深度学习·联机手写签名识别·手写签名识别
大模型最新论文速读19 小时前
NCoTS:搜索最优推理路径,改进大模型推理效果
人工智能·深度学习·机器学习·语言模型·自然语言处理
真智AI19 小时前
用 FAISS 搭个轻量 RAG 问答(Python)
开发语言·python·faiss
2401_8576835419 小时前
使用Kivy开发跨平台的移动应用
jvm·数据库·python