读取一个batch的图像并且显示出来

1读取一个batch用于训练

我们在训练模型的时候,除了观察图像的标签和尺寸,最好能读取一个batch的图像显示出来,观察原始图像和grountruth是否对应,如果正确才能正式开始后续的训练。

下面以一个皮肤病分割的数据集加以演示。

2.导入所需要的包

python 复制代码
from torch.utils import data
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import os

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

3.使用dataset 和dataloader加载数据

python 复制代码
class ImageFolder(data.Dataset):
    def __init__(self, root, image_size=224, mode='train', augmentation_prob=0.4):
        """Initializes image paths and preprocessing module."""
        self.root = root
        # GT : Ground Truth
        self.GT_paths = root[:-1] + '_GT/'
        self.image_paths = list(map(lambda x: os.path.join(root, x), os.listdir(root)))
        self.image_size = image_size
        self.mode = mode
        self.RotationDegree = [0, 90, 180, 270]
        self.augmentation_prob = augmentation_prob
        print("image count in {} path :{}".format(self.mode, len(self.image_paths)))

    def __getitem__(self, index):
        """Reads an image from a file and preprocesses it and returns."""
        image_path = self.image_paths[index]
        # [:-len(".jpg")]是列表的索引,image_path.split('_')[-1]是选取路径中的最后一段字符,
        # [: -4]指的是截取第0个到第-4个元素,不包括第4个元素
        filename = image_path.split('_')[-1][:-len(".jpg")]
        GT_path = self.GT_paths + 'ISIC_' + filename + '_segmentation.png'

        image = Image.open(image_path)
        GT = Image.open(GT_path)
        # 计算图像的比例
        dada_transform = T.Compose([
            T.Resize((256,256)),
            T.ToTensor()
        ])

        image= dada_transform(image)
        GT = dada_transform(GT)
        return image,GT

    def __len__(self):
        """Returns the total number of font files."""
        return len(self.image_paths)


def get_loader(image_path, image_size, batch_size, num_workers=0, mode='train', augmentation_prob=0.4):
    """Builds and returns Dataloader."""

    dataset = ImageFolder(root=image_path, image_size=image_size, mode=mode, augmentation_prob=augmentation_prob)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers)
    return data_loader

4 开始画图

python 复制代码
if __name__ == '__main__':
    train_path = './dataset/train/'
    image_size = 224
    batch_size =4
    num_workers = 0
    augmentation_prob = 0.4
    train_loader = get_loader(image_path=train_path,
                              image_size=image_size,
                              batch_size=batch_size,
                              num_workers=num_workers,
                              mode='train',
                              augmentation_prob=augmentation_prob)

#从train_loader中获取一个batch的图像和GT
    for step, (img,GT) in enumerate(train_loader):
        if step>0:
            break
    print(f"img:{img.shape}")
    print(f"GT:{GT.shape}")

    for ii in np.arange(4):
        plt.subplot(2, 4, ii + 1)
        image = img[ii, :, :, :].numpy().transpose(1, 2, 0)
        plt.imshow(image)
        plt.axis("off")
        plt.subplot(2, 4, ii + 5)
        GT_i = GT[ii, :, :, :].numpy().transpose(1, 2, 0)
        plt.imshow(GT_i)
        plt.axis("off")
    plt.show()
    plt.subplots_adjust(hspace=0.3)

程序运行的结果如图所示:

相关推荐
Mr数据杨2 小时前
【Dv3Admin】插件 dv3admin_chatgpt 集成大语言模型智能模块
人工智能·语言模型·chatgpt
zm-v-159304339862 小时前
AI 赋能 Copula 建模:大语言模型驱动的相关性分析革新
人工智能·语言模型·自然语言处理
zhz52144 小时前
AI数字人融合VR全景:从技术突破到可信场景落地
人工智能·vr·ai编程·ai数字人·ai agent·智能体
数据与人工智能律师4 小时前
虚拟主播肖像权保护,数字时代的法律博弈
大数据·网络·人工智能·算法·区块链
武科大许志伟4 小时前
武汉科技大学人工智能与演化计算实验室许志伟课题组参加2025中国膜计算论坛
人工智能·科技
哲讯智能科技4 小时前
【无标题】威灏光电&哲讯科技MES项目启动会圆满举行
人工智能
__Benco4 小时前
OpenHarmony平台驱动开发(十七),UART
人工智能·驱动开发·harmonyos
小oo呆4 小时前
【自然语言处理与大模型】Windows安装RAGFlow并接入本地Ollama模型
人工智能·自然语言处理
开放知识图谱4 小时前
论文浅尝 | HOLMES:面向大语言模型多跳问答的超关系知识图谱方法(ACL2024)
人工智能·语言模型·自然语言处理·知识图谱
weixin_444579304 小时前
基于Llama3的开发应用(二):大语言模型的工业部署
人工智能·语言模型·自然语言处理