Pytorch初上手——Dataset自定义数据集与Dataloader数据加载器

引言

很显然,这篇文章的主要内容如标题所示,使用pytorch创建自定义的数据集并进行简单的查看。不废话进入正题,相关资源链接放在文章末尾部分。

Dataset的创建

Dataset是什么

自定义数据集继承自 torch.utils.data.Dataset ,用于定义自己所需要的数据集,为此需要实现三个该类的三个函数:initlengetitem

python 复制代码
from torch.utils.data import Dataset
python 复制代码
class SelfDataset(Dataset):
    def __init__(self):
        
    def __len__(self):
    
    def __getitem__(self):

init

init 函数在实例化 Dataset 对象时运行一次,实现该数据类对象的初始化。我们一般在该函数中定义数据集涉及的文件和所需的变换即transform。

len

len 函数返回我们数据集中样本的数量。

getitem

getitem_ 函数加载并返回给定索引 idx 处的数据集中的样本。如果有transform的话也会在这部分进行。

自定义Dataset数据集

接下来进行示例代码编写,自定义数据集为图像与掩膜的数据集,包含图像和对应掩膜文件。先看一下文件的构成:

python 复制代码
import os
base_dir = "Dataset"
ls = os.listdir(base_dir)
for i in ls:
    print(base_dir+"/"+i)
    for j in os.listdir(base_dir+"/"+i):
        print("  "+base_dir+"/"+i+"/"+j)
python 复制代码
Dataset/Test
  Dataset/Test/Image
Dataset/Train
  Dataset/Train/Image
  Dataset/Train/Mask

明确好文件路径后就开始编写自定义数据集的三个函数了。

声明自定义数据集类ImageMaskDataset

python 复制代码
class ImageMaskDataset(Dataset):

编写__init__部分,参数为图像文件夹路径,掩膜文件夹路径和transform预处理操作

python 复制代码
    def __init__(self,image_dir,mask_dir,transform=None):
        """
        Args:
            image_dir: 图像文件夹路径 (如 'data/train/image')
            mask_dir:  掩膜文件夹路径 (如 'data/train/mask')
            transform: 可选的图像预处理操作
        """
        self.image_dir = image_dir
        self.mask_dir=mask_dir
        self.transform = transform
        self.image_filenames = sorted(os.listdir(image_dir)) #所有图像文件名

编写__len__部分

python 复制代码
    def __len__(self):
        return len(self.image_filenames)

编写__getitem__部分,参数为索引idx,返回索引对应的图像和掩膜。如果定义了transform,则返回变换后的图像和掩膜

python 复制代码
    def __getitem__(self,idx):
        # 获取文件名
        img_name_with_ext = self.image_filenames[idx]
        # 去掉拓展名
        img_basename = os.path.splitext(img_name_with_ext)[0]
        # 掩膜扩展名
        mask_extensions = ['.png']
        mask_path = None

        # 找掩膜
        for ext in mask_extensions:
            potential_path = os.path.join(self.mask_dir, img_basename + ext)
            if os.path.exists(potential_path):
                mask_path = potential_path
                break

        if mask_path is None:
            raise FileNotFoundError(f"找不到掩膜文件: {img_basename} 在 {self.mask_dir} 中")

        # 加载图像和掩膜
        img_path = os.path.join(self.image_dir, img_name_with_ext)
        image = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')

        # transform
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return image,mask  

自定义transform

transform的构建用到torchvision.transforms.v2

python 复制代码
import torchvision.transforms.v2 as transforms
# 定义同时作用于图像和掩膜的变换
transform = transforms.Compose([
    transforms.Resize((512, 512)),# resize
    transforms.RandomHorizontalFlip(p=0.5),# 随机水平翻转
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1)),# 随机仿射变换
    transforms.ToImage(),# 变为tensor
    transforms.ToDtype(torch.float32, scale=True) 
])

transform做好了之后就能在数据集实例化时使用了。

DataLoader

DataLoder是什么

在训练模型时,我们通常希望以"小批量"传递样本,Dataloader就是干这个的。它和Dataset类同样位于 torch.utils.data 下,所以这样导入

python 复制代码
from torch.utils.data import Dataset, DataLoader

DataLoader实例化

现在可以进行数据集和数据加载器实例化了。

python 复制代码
# 实例化 Dataset 和 DataLoader
dataset = ImageMaskDataset('Dataset/Train/Image',"Dataset/Train/Mask",transform=transform)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=0)
print("已载入数据")
python 复制代码
已载入数据

我们来看看效果,写个循环看看dataloader里面是什么

python 复制代码
i = 0
for images, masks in dataloader:
    print(f"Batch images shape: {images.shape}")  # [batch, 3, H, W]
    print(f"Batch masks shape: {masks.shape}")    # [batch, H, W]
    print(f"第{i}轮")
    i+=1
python 复制代码
Batch images shape: torch.Size([10, 3, 512, 512])
Batch masks shape: torch.Size([10, 1, 512, 512])
第0轮
Batch images shape: torch.Size([10, 3, 512, 512])
Batch masks shape: torch.Size([10, 1, 512, 512])
第1轮
Batch images shape: torch.Size([10, 3, 512, 512])
Batch masks shape: torch.Size([10, 1, 512, 512])
第2轮
Batch images shape: torch.Size([10, 3, 512, 512])
Batch masks shape: torch.Size([10, 1, 512, 512])
第3轮
Batch images shape: torch.Size([10, 3, 512, 512])
Batch masks shape: torch.Size([10, 1, 512, 512])

很好,因为定义了batch_size为10,所以输出的一个批次包含10张图片,image为3通道512*512,mask为单通道512*512。

结尾及相关链接

这就是本章的全部内容,实现了自定义Dataset和Dataloader的实例化,简单来说,就是让你具备了数据集加载并查看的能力。下期再聊。

相关链接如下

示例所用数据集:

Huỳnh Trịnh Ngọc. A02025-Medical-Image-Segmentation. https://kaggle.com/competitions/a0-2025-medical-image-segmentation, 2025. Kaggle.

Pytorch链接:

https://pytorch.org/

相关推荐
青岛前景互联信息技术有限公司14 小时前
AI驱动的消防通信指挥系统:实现风险预警与智能接处警的秒级响应
大数据·人工智能·物联网
美团技术团队14 小时前
报名|ACL'26 美团中稿精选:从能力评测到推理优化,构建生成新范式
人工智能
Legend NO2414 小时前
非结构化数据治理全解:从合规痛点、中台架构到 AI 智能化分类落地
大数据·人工智能·架构
闻道参看14 小时前
智能搜索生态驱动的流量卡位实操:中小微入局者的 GEO 优化 服务选型全维度实证分析
大数据·人工智能
Bruce_Liuxiaowei14 小时前
当Windows成为Agent的监狱-操作系统级Agent安全架构深度解读
人工智能·windows·安全·安全架构·智能体
王_teacher14 小时前
ResNet-18网络模型+原理解析+Pytorch实现+手写模型
人工智能·cnn·卷积神经网络
树谷-胡老师14 小时前
2024年中国大型数据中心空间分布及环境属性数据集
人工智能·机器学习
老码观察14 小时前
设计模式实战解读(十一):外观模式——给复杂系统套一层壳
python·设计模式·外观模式
小马哥crazymxm14 小时前
自动驾驶“跨化身”!Sensor2Sensor用4D高斯泼溅+扩散模型,把网络行车记录仪变成高精度LiDAR真数据
人工智能·机器学习·自动驾驶
ss27314 小时前
【Python实战】基于FastAPI的绿植养护管理系统 - 完整项目
python·fastapi