1. 读取一个batch用于训练
我们在训练模型的时候,除了检查图像的标签和尺寸,最好还能读取一个batch的图像并显示出来,确认原始图像和ground truth(标注掩码)是否一一对应;确认无误之后,再正式开始后续的训练。
下面以一个皮肤病分割的数据集加以演示。
2. 导入所需要的包
python
from torch.utils import data
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
3. 使用 Dataset 和 DataLoader 加载数据
class ImageFolder(data.Dataset):
    """Dataset that pairs every image in ``root`` with its ground-truth mask.

    Masks are looked up in a sibling directory named ``<root>_GT/`` under the
    file name ``ISIC_<id>_segmentation.png``, where ``<id>`` is the part of
    the image file name after its last underscore (extension removed).

    Args:
        root: Directory containing the input images, with or without a
            trailing path separator.
        image_size: Stored on the instance. NOTE: it is not applied; the
            resize target is fixed at 256x256 (see ``_RESIZE_TO``).
        mode: Label used in the start-up message; stored on the instance.
        augmentation_prob: Stored on the instance; not used by this class.
    """

    # Output (height, width) of every image and mask returned by the dataset.
    _RESIZE_TO = (256, 256)

    def __init__(self, root, image_size=224, mode='train', augmentation_prob=0.4):
        """Collect the image paths and build the preprocessing pipelines."""
        self.root = root
        # GT: Ground Truth
        self.GT_paths = self._gt_dir(root)
        # Sorted so that index -> file is reproducible across runs/platforms.
        self.image_paths = [os.path.join(root, name)
                            for name in sorted(os.listdir(root))]
        self.image_size = image_size
        self.mode = mode
        self.RotationDegree = [0, 90, 180, 270]
        self.augmentation_prob = augmentation_prob

        # Built once here instead of on every __getitem__ call.
        self._image_transform = T.Compose([
            T.Resize(self._RESIZE_TO),
            T.ToTensor(),
        ])
        # Nearest-neighbour keeps mask values unchanged; the default bilinear
        # resize would blend foreground and background along the boundary.
        self._gt_transform = T.Compose([
            T.Resize(self._RESIZE_TO,
                     interpolation=T.InterpolationMode.NEAREST),
            T.ToTensor(),
        ])
        print("image count in {} path :{}".format(self.mode, len(self.image_paths)))

    @staticmethod
    def _gt_dir(root):
        """Return the mask directory for ``root``: ``<root>_GT/``."""
        # Strip separators instead of blindly dropping the last character, so
        # './train' and './train/' both map to './train_GT/'.
        return root.rstrip('/\\') + '_GT/'

    def _gt_path(self, image_path):
        """Return the mask path that corresponds to ``image_path``."""
        # Work on the file name only and let splitext remove the extension,
        # so '.jpeg' and other extension lengths are handled too.
        stem = os.path.splitext(os.path.basename(image_path))[0]
        image_id = stem.split('_')[-1]
        return self.GT_paths + 'ISIC_' + image_id + '_segmentation.png'

    def __getitem__(self, index):
        """Load, resize and convert one (image, mask) pair to tensors.

        Returns:
            A tuple ``(image, GT)``: the RGB image and the single-channel
            mask, each produced by ``Resize`` followed by ``ToTensor``.
        """
        image_path = self.image_paths[index]
        GT_path = self._gt_path(image_path)

        # convert() returns a loaded copy, so the files can be closed at once.
        # Fixed modes keep the channel count identical for every sample,
        # which the default batch collation requires.
        with Image.open(image_path) as raw_image:
            image = self._image_transform(raw_image.convert('RGB'))
        with Image.open(GT_path) as raw_GT:
            GT = self._gt_transform(raw_GT.convert('L'))

        return image, GT

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_paths)


def get_loader(image_path, image_size, batch_size, num_workers=0, mode='train', augmentation_prob=0.4):
    """Build a shuffling ``DataLoader`` over an ``ImageFolder`` dataset.

    Args:
        image_path: Directory with the input images (passed as ``root``).
        image_size: Forwarded to ``ImageFolder``.
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes.
        mode: Forwarded to ``ImageFolder``.
        augmentation_prob: Forwarded to ``ImageFolder``.

    Returns:
        A ``torch.utils.data.DataLoader`` created with ``shuffle=True``.
    """
    dataset = ImageFolder(root=image_path, image_size=image_size,
                          mode=mode, augmentation_prob=augmentation_prob)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers)
    return data_loader
4. 开始画图
if __name__ == '__main__':
    # Demo: load one batch and plot images (top row) above their
    # ground-truth masks (bottom row) to check that they correspond.
    train_path = './dataset/train/'
    image_size = 224
    batch_size = 4
    num_workers = 0
    augmentation_prob = 0.4
    train_loader = get_loader(image_path=train_path,
                              image_size=image_size, batch_size=batch_size,
                              num_workers=num_workers, mode='train',
                              augmentation_prob=augmentation_prob)

    # Take exactly one batch of images and ground-truth masks from the loader.
    img, GT = next(iter(train_loader))
    print(f"img:{img.shape}")
    print(f"GT:{GT.shape}")

    # Use the actual batch length: it can be smaller than batch_size when the
    # dataset holds fewer samples.
    n_show = img.shape[0]
    for ii in range(n_show):
        plt.subplot(2, n_show, ii + 1)
        # Channels-first tensor -> channels-last array, as imshow expects.
        image = img[ii].numpy().transpose(1, 2, 0)
        plt.imshow(image)
        plt.axis("off")

        plt.subplot(2, n_show, ii + n_show + 1)
        GT_i = GT[ii].numpy().transpose(1, 2, 0)
        if GT_i.shape[-1] == 1:
            # Drop the singleton channel so the mask is a plain 2-D image.
            GT_i = np.squeeze(GT_i, axis=-1)
        plt.imshow(GT_i)
        plt.axis("off")

    # Spacing must be set before show(); afterwards it has no visible effect.
    plt.subplots_adjust(hspace=0.3)
    plt.show()
程序运行的结果如图所示: