导入Dataset类
py
from torch.utils.data import Dataset
# 注意是Dataset(大写)的才是类
通过jupyter
我们可以阅读一下Dataset
类的具体使用方法
py
help(Dataset)
# 或者直接
Dataset??
我们可以看到具体对Dataset
类的解释
从蓝色字体我们可以得出
- 所有的代表map的数据集应该继承这个类
- 所有继承的子类都重写
__getitem__
这个方法,这个方法支持获取数据样本中的指定键 - 同时子类也要重写
__len__
这个方法返回数据集大小 - 子类可以重写
__getitem__
,来加速样本生成
也就是说我们要重写__getitem__
方法与__len__
方法
其他导入包
py
from PIL import Image # 主要用于图像的操作
import os # 文件操作
Image
用于将目标路径的文件转化为可以打开的图片变量
os
用于文件操作
listdir
对目标文件夹中的文件名称列成列表os.path.join
用于将两个地址进行拼接
MyData类的定义
py
class MyData(Dataset): # 创建一个MyData类,同时继承Dataset类
def __init__(self, root_dir, label_dir): # 类似于c++的构造函数
# root_dir 一般设置为训练集文件夹的地址(train)
# label_dir 一般设置为分类文件夹的地址(ants)
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(root_dir, label_dir) # 这个函数的作用是将root_dir的地址与label_dir的地址拼接起来
self.img_path = os.listdir(self.path) # 将特定文件夹地址(path)中的所有文件列成一个list
def __getitem__(self, index): # 重写父类的方法
img_name = self.img_path[index] # 获取对应下标的图片名
img_item_path = os.path.join(self.path, img_name) # 获取图片路径
img = Image.open(img_item_path) # 根据图片路径打开图片
# img.show() 展示图片
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
类的实例化
py
# root_dir 一般设置为训练集文件夹的地址(train)
# label_dir 一般设置为分类文件夹的地址(ants)
root_dir = "hymenoptera_data/train"
ant_label_dir = "ants"
bee_label_dir = "bees"
# 生成对应训练集的图片、标签列表
ants_dataset = MyData(root_dir, ant_label_dir)
bees_dataset = MyData(root_dir, bee_label_dir)
# 列表相加,前提是必须重载__len__方法
train_dataset = ants_dataset + bees_dataset
源码链接
参考资料