【Pytorch】1.读取训练数据集

导入Dataset类

py 复制代码
from torch.utils.data import Dataset
# 注意是Dataset(大写)的才是类

通过jupyter我们可以阅读一下Dataset类的具体使用方法

py 复制代码
help(Dataset)
# 或者直接
Dataset??

我们可以看到具体对Dataset类的解释

从蓝色字体我们可以得出

  • 所有的代表map的数据集应该继承这个类
  • 所有继承的子类都重写__getitem__这个方法,这个方法支持获取数据样本中的指定键
  • 同时子类也要重写__len__这个方法返回数据集大小
  • 子类可以重写__getitem__,来加速样本生成
    也就是说我们要重写__getitem__方法与__len__方法

其他导入包

py 复制代码
from PIL import Image  # 主要用于图像的操作
import os  # 文件操作

Image用于将目标路径的文件转化为可以打开的图片变量
os用于文件操作

  • listdir对目标文件夹中的文件名称列成列表
  • os.path.join用于将两个地址进行拼接

MyData类的定义

py 复制代码
class MyData(Dataset):  # 创建一个MyData类,同时继承Dataset类
    def __init__(self, root_dir, label_dir):  # 类似于c++的构造函数
        # root_dir 一般设置为训练集文件夹的地址(train)
        # label_dir 一般设置为分类文件夹的地址(ants)
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(root_dir, label_dir)  # 这个函数的作用是将root_dir的地址与label_dir的地址拼接起来
        self.img_path = os.listdir(self.path)  # 将特定文件夹地址(path)中的所有文件列成一个list

    def __getitem__(self, index):  # 重写父类的方法
        img_name = self.img_path[index]  # 获取对应下标的图片名
        img_item_path = os.path.join(self.path, img_name)  # 获取图片路径
        img = Image.open(img_item_path)  # 根据图片路径打开图片
        # img.show()    展示图片
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path)

类的实例化

py 复制代码
# root_dir 一般设置为训练集文件夹的地址(train)
# label_dir 一般设置为分类文件夹的地址(ants)
root_dir = "hymenoptera_data/train"
ant_label_dir = "ants"
bee_label_dir = "bees"
# 生成对应训练集的图片、标签列表
ants_dataset = MyData(root_dir, ant_label_dir)
bees_dataset = MyData(root_dir, bee_label_dir)

# 列表相加,前提是必须重载__len__方法
train_dataset = ants_dataset + bees_dataset

源码链接

github

参考资料

PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】

相关推荐
四口鲸鱼爱吃盐4 小时前
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
人工智能·pytorch·分类
leaf_leaves_leaf5 小时前
win11用一条命令给anaconda环境安装GPU版本pytorch,并检查是否为GPU版本
人工智能·pytorch·python
夜雨飘零15 小时前
基于Pytorch实现的说话人日志(说话人分离)
人工智能·pytorch·python·声纹识别·说话人分离·说话人日志
四口鲸鱼爱吃盐6 小时前
Pytorch | 从零构建MobileNet对CIFAR10进行分类
人工智能·pytorch·分类
苏言の狗6 小时前
Pytorch中关于Tensor的操作
人工智能·pytorch·python·深度学习·机器学习
四口鲸鱼爱吃盐11 小时前
Pytorch | 利用VMI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python
四口鲸鱼爱吃盐11 小时前
Pytorch | 利用PI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python
music0ant19 小时前
Idean 处理一个项目引用另外一个项目jar 但jar版本低的问题
java·pycharm·jar
love you joyfully1 天前
目标检测与R-CNN——pytorch与paddle实现目标检测与R-CNN
人工智能·pytorch·目标检测·cnn·paddle
Tttian6221 天前
Pycharm访问MongoDB数据库
数据库·mongodb·pycharm