from torch.utils import data
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import os
# Presumably a workaround for the Intel OpenMP "libiomp5 already initialized"
# abort that can happen when torch and numpy/matplotlib each load their own
# copy of the runtime -- TODO confirm it is still needed in this environment.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')
# 3. 使用 Dataset 和 DataLoader 加载数据 (load the data with Dataset and DataLoader)
class ImageFolder(data.Dataset):
    """Dataset of (image, ground-truth mask) pairs read from two sibling folders.

    Images are listed from ``root``. The mask for each image is looked up in
    the sibling folder ``<root>_GT/`` as ``ISIC_<id>_segmentation.png``, where
    ``<id>`` is the part of the image's file name after its last underscore.
    """

    def __init__(self, root, image_size=224, mode='train', augmentation_prob=0.4):
        """Initializes image paths and preprocessing pipelines.

        Args:
            root: Directory containing the input images, e.g. './dataset/train/'.
            image_size: Stored on the instance; not used by the current
                preprocessing, which always resizes to 256x256.
            mode: Label used in the summary message (e.g. 'train').
            augmentation_prob: Stored on the instance; not used in this class.
        """
        self.root = root
        # GT: Ground Truth. Strip trailing separators first so that both
        # './dataset/train/' and './dataset/train' map to './dataset/train_GT/'
        # (the old root[:-1] chopped the last letter when there was no slash).
        self.GT_paths = root.rstrip('/\\') + '_GT/'
        # Sorted so the index -> file mapping is deterministic across runs.
        self.image_paths = sorted(os.path.join(root, name) for name in os.listdir(root))
        self.image_size = image_size
        self.mode = mode
        self.RotationDegree = [0, 90, 180, 270]
        self.augmentation_prob = augmentation_prob
        # NOTE(review): image_size is ignored; outputs are always 256x256.
        # The pipelines are built once here rather than on every __getitem__.
        self._image_transform = T.Compose([
            T.Resize((256, 256)),
            T.ToTensor(),
        ])
        # Masks use nearest-neighbour resizing. The Resize default (bilinear)
        # blends label values and produces non-binary mask edges.
        self._gt_transform = T.Compose([
            T.Resize((256, 256), interpolation=T.InterpolationMode.NEAREST),
            T.ToTensor(),
        ])
        print("image count in {} path :{}".format(self.mode, len(self.image_paths)))

    def _gt_path(self, image_path):
        """Return the ground-truth mask path paired with ``image_path``."""
        # 'ISIC_0000000.jpg' -> 'ISIC_0000000' -> '0000000'. Working on the
        # basename with splitext (instead of slicing 4 chars off the full
        # path) also handles extensions such as '.jpeg'.
        stem = os.path.splitext(os.path.basename(image_path))[0]
        filename = stem.split('_')[-1]
        return self.GT_paths + 'ISIC_' + filename + '_segmentation.png'

    def __getitem__(self, index):
        """Read the image at ``index`` and its mask, and return both as tensors.

        Returns:
            (image, GT): ``image`` is converted to RGB before ToTensor, so it
            always has 3 channels. ``GT`` keeps the mask file's own channel
            count. Both are resized to 256x256.
        """
        image_path = self.image_paths[index]
        GT_path = self._gt_path(image_path)
        # The transforms run inside the `with` blocks (they force the lazy
        # PIL load), so the file handles are released afterwards even on error.
        with Image.open(image_path) as img:
            image = self._image_transform(img.convert('RGB'))
        with Image.open(GT_path) as gt_img:
            GT = self._gt_transform(gt_img)
        return image, GT

    def __len__(self):
        """Returns the total number of image files."""
        return len(self.image_paths)
def get_loader(image_path, image_size, batch_size, num_workers=0, mode='train',
               augmentation_prob=0.4, shuffle=True):
    """Build an ``ImageFolder`` dataset and wrap it in a ``DataLoader``.

    Args:
        image_path: Image directory, passed to ``ImageFolder`` as ``root``.
        image_size: Forwarded to ``ImageFolder``.
        batch_size: Number of samples per batch.
        num_workers: Loader worker processes; 0 loads in the main process.
        mode: Forwarded to ``ImageFolder``.
        augmentation_prob: Forwarded to ``ImageFolder``.
        shuffle: Whether samples are reshuffled every epoch. Defaults to True
            (the previous hard-coded behaviour). Pass False for evaluation
            loaders that need a stable order.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``(image, GT)`` pairs.
    """
    dataset = ImageFolder(root=image_path, image_size=image_size, mode=mode,
                          augmentation_prob=augmentation_prob)
    return data.DataLoader(dataset=dataset,
                           batch_size=batch_size,
                           shuffle=shuffle,
                           num_workers=num_workers)
# 4. 开始画图 (plot one batch of images and ground-truth masks)
if __name__ == '__main__':
    train_path = './dataset/train/'
    image_size = 224
    batch_size = 4
    num_workers = 0
    augmentation_prob = 0.4
    train_loader = get_loader(image_path=train_path,
                              image_size=image_size,
                              batch_size=batch_size,
                              num_workers=num_workers,
                              mode='train',
                              augmentation_prob=augmentation_prob)
    # Fetch only the first batch of images and ground-truth masks.
    img, GT = next(iter(train_loader))
    print(f"img:{img.shape}")
    print(f"GT:{GT.shape}")
    # Plot every sample actually present in the batch. It may hold fewer than
    # batch_size samples when the dataset is small. Images go on the top row,
    # masks on the bottom row.
    n_show = img.shape[0]
    for ii in range(n_show):
        plt.subplot(2, n_show, ii + 1)
        image = img[ii].numpy().transpose(1, 2, 0)
        plt.imshow(image)
        plt.axis("off")
        plt.subplot(2, n_show, ii + n_show + 1)
        GT_i = GT[ii].numpy().transpose(1, 2, 0)
        # A single-channel mask arrives as (H, W, 1). Older matplotlib rejects
        # that shape, so drop the channel and render it as a 2-D grayscale image.
        if GT_i.shape[2] == 1:
            GT_i = GT_i[:, :, 0]
        plt.imshow(GT_i, cmap='gray')
        plt.axis("off")
    # The layout must be adjusted before show(); afterwards it has no visible effect.
    plt.subplots_adjust(hspace=0.3)
    plt.show()