自定义MyDataSet获取数据及对应label
实例化数据集需要用到 DataSet 类,我们可以自定义来实现对数据集的处理
MyDataSet类代码如下:
python
from PIL import Image
import torch
from torch.utils.data import Dataset
class MyDataSet(Dataset):
"""自定义数据集"""
def __init__(self, images_path: list, images_class: list, transform=None):
self.images_path = images_path
self.images_class = images_class
self.transform = transform
def __len__(self):
return len(self.images_path)
# 获取item对象图像和类别,只对img进行预处理,label不处理
def __getitem__(self, item):
img = Image.open(self.images_path[item])
# RGB为彩色图片,L为灰度图片
if img.mode != 'RGB':
raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
label = self.images_class[item]
if self.transform is not None:
img = self.transform(img)
return img, label
@staticmethod
def collate_fn(batch):
# 官方实现的default_collate可以参考
# https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
# zip(*batch):处理一个batch内的图片,图片为一组,标签为一组
images, labels = tuple(zip(*batch))
images = torch.stack(images, dim=0) # 增加batch维度
labels = torch.as_tensor(labels) #将labels转化为tensor,images在__getitem__方法的transform已经转化为tensor
return images, labels
定义好MyDataSet后,就可以在train类中引用了,具体代码如下:
python
from my_dataset import MyDataSet
# 省略其他代码......
# 这里定义了train和val两种预处理方法
data_transform = {
"train": transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
"val": transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}
# MyDataSet实例化训练数据集
train_dataset = MyDataSet(images_path=train_images_path,
images_class=train_images_label,
transform=data_transform["train"])
# 省略其他代码......
从而实现了在train类中获取一个batch的数据,且该数据为图像一组和label一组,同时经过预处理的数据