读取一个batch的图像并且显示出来

1读取一个batch用于训练

我们在训练模型的时候,除了观察图像的标签和尺寸,最好能读取一个batch的图像显示出来,观察原始图像和grountruth是否对应,如果正确才能正式开始后续的训练。

下面以一个皮肤病分割的数据集加以演示。

2.导入所需要的包

python 复制代码
from torch.utils import data
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import os

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

3.使用dataset 和dataloader加载数据

python 复制代码
class ImageFolder(data.Dataset):
    def __init__(self, root, image_size=224, mode='train', augmentation_prob=0.4):
        """Initializes image paths and preprocessing module."""
        self.root = root
        # GT : Ground Truth
        self.GT_paths = root[:-1] + '_GT/'
        self.image_paths = list(map(lambda x: os.path.join(root, x), os.listdir(root)))
        self.image_size = image_size
        self.mode = mode
        self.RotationDegree = [0, 90, 180, 270]
        self.augmentation_prob = augmentation_prob
        print("image count in {} path :{}".format(self.mode, len(self.image_paths)))

    def __getitem__(self, index):
        """Reads an image from a file and preprocesses it and returns."""
        image_path = self.image_paths[index]
        # [:-len(".jpg")]是列表的索引,image_path.split('_')[-1]是选取路径中的最后一段字符,
        # [: -4]指的是截取第0个到第-4个元素,不包括第4个元素
        filename = image_path.split('_')[-1][:-len(".jpg")]
        GT_path = self.GT_paths + 'ISIC_' + filename + '_segmentation.png'

        image = Image.open(image_path)
        GT = Image.open(GT_path)
        # 计算图像的比例
        dada_transform = T.Compose([
            T.Resize((256,256)),
            T.ToTensor()
        ])

        image= dada_transform(image)
        GT = dada_transform(GT)
        return image,GT

    def __len__(self):
        """Returns the total number of font files."""
        return len(self.image_paths)


def get_loader(image_path, image_size, batch_size, num_workers=0, mode='train', augmentation_prob=0.4):
    """Builds and returns Dataloader."""

    dataset = ImageFolder(root=image_path, image_size=image_size, mode=mode, augmentation_prob=augmentation_prob)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers)
    return data_loader

4 开始画图

python 复制代码
if __name__ == '__main__':
    train_path = './dataset/train/'
    image_size = 224
    batch_size =4
    num_workers = 0
    augmentation_prob = 0.4
    train_loader = get_loader(image_path=train_path,
                              image_size=image_size,
                              batch_size=batch_size,
                              num_workers=num_workers,
                              mode='train',
                              augmentation_prob=augmentation_prob)

#从train_loader中获取一个batch的图像和GT
    for step, (img,GT) in enumerate(train_loader):
        if step>0:
            break
    print(f"img:{img.shape}")
    print(f"GT:{GT.shape}")

    for ii in np.arange(4):
        plt.subplot(2, 4, ii + 1)
        image = img[ii, :, :, :].numpy().transpose(1, 2, 0)
        plt.imshow(image)
        plt.axis("off")
        plt.subplot(2, 4, ii + 5)
        GT_i = GT[ii, :, :, :].numpy().transpose(1, 2, 0)
        plt.imshow(GT_i)
        plt.axis("off")
    plt.show()
    plt.subplots_adjust(hspace=0.3)

程序运行的结果如图所示:

相关推荐
AI量化投资实验室18 分钟前
deap系统重构,再新增一个新的因子,年化39.1%,卡玛提升至2.76(附python代码)
大数据·人工智能·重构
张登杰踩26 分钟前
如何快速下载Huggingface上的超大模型,不用梯子,以Deepseek-R1为例子
人工智能
AIGC大时代26 分钟前
分享14分数据分析相关ChatGPT提示词
人工智能·chatgpt·数据分析
TMT星球1 小时前
生数科技携手央视新闻《文博日历》,推动AI视频技术的创新应用
大数据·人工智能·科技
AI视觉网奇1 小时前
图生3d算法学习笔记
人工智能
小锋学长生活大爆炸1 小时前
【DGL系列】dgl中为graph指定CSR/COO/CSC矩阵格式
人工智能·pytorch·深度学习·图神经网络·gnn·dgl
机械心2 小时前
pytorch深度学习模型推理和部署、pytorch&ONNX&tensorRT模型转换以及python和C++版本部署
pytorch·python·深度学习
佛州小李哥2 小时前
在亚马逊云科技上用AI提示词优化功能写出漂亮提示词(上)
人工智能·科技·ai·语言模型·云计算·aws·亚马逊云科技
鸭鸭鸭进京赶烤2 小时前
计算机工程:解锁未来科技之门!
人工智能·科技·opencv·ai·机器人·硬件工程·软件工程
ModelWhale2 小时前
十年筑梦,再创鲸彩!庆祝和鲸科技十周年
人工智能·科技