Pytorch初上手——Dataset自定义数据集与Dataloader数据加载器

引言

很显然,这篇文章的主要内容如标题所示,使用pytorch创建自定义的数据集并进行简单的查看。不废话进入正题,相关资源链接放在文章末尾部分。

Dataset的创建

Dataset是什么

自定义数据集继承自 torch.utils.data.Dataset ,用于定义自己所需要的数据集,为此需要实现三个该类的三个函数:initlengetitem

python 复制代码
from torch.utils.data import Dataset
python 复制代码
class SelfDataset(Dataset):
    def __init__(self):
        
    def __len__(self):
    
    def __getitem__(self):

init

init 函数在实例化 Dataset 对象时运行一次,实现该数据类对象的初始化。我们一般在该函数中定义数据集涉及的文件和所需的变换即transform。

len

len 函数返回我们数据集中样本的数量。

getitem

getitem_ 函数加载并返回给定索引 idx 处的数据集中的样本。如果有transform的话也会在这部分进行。

自定义Dataset数据集

接下来进行示例代码编写,自定义数据集为图像与掩膜的数据集,包含图像和对应掩膜文件。先看一下文件的构成:

python 复制代码
import os
base_dir = "Dataset"
ls = os.listdir(base_dir)
for i in ls:
    print(base_dir+"/"+i)
    for j in os.listdir(base_dir+"/"+i):
        print("  "+base_dir+"/"+i+"/"+j)
python 复制代码
Dataset/Test
  Dataset/Test/Image
Dataset/Train
  Dataset/Train/Image
  Dataset/Train/Mask

明确好文件路径后就开始编写自定义数据集的三个函数了。

声明自定义数据集类ImageMaskDataset

python 复制代码
class ImageMaskDataset(Dataset):

编写__init__部分,参数为图像文件夹路径,掩膜文件夹路径和transform预处理操作

python 复制代码
    def __init__(self,image_dir,mask_dir,transform=None):
        """
        Args:
            image_dir: 图像文件夹路径 (如 'data/train/image')
            mask_dir:  掩膜文件夹路径 (如 'data/train/mask')
            transform: 可选的图像预处理操作
        """
        self.image_dir = image_dir
        self.mask_dir=mask_dir
        self.transform = transform
        self.image_filenames = sorted(os.listdir(image_dir)) #所有图像文件名

编写__len__部分

python 复制代码
    def __len__(self):
        return len(self.image_filenames)

编写__getitem__部分,参数为索引idx,返回索引对应的图像和掩膜。如果定义了transform,则返回变换后的图像和掩膜

python 复制代码
    def __getitem__(self,idx):
        # 获取文件名
        img_name_with_ext = self.image_filenames[idx]
        # 去掉拓展名
        img_basename = os.path.splitext(img_name_with_ext)[0]
        # 掩膜扩展名
        mask_extensions = ['.png']
        mask_path = None

        # 找掩膜
        for ext in mask_extensions:
            potential_path = os.path.join(self.mask_dir, img_basename + ext)
            if os.path.exists(potential_path):
                mask_path = potential_path
                break

        if mask_path is None:
            raise FileNotFoundError(f"找不到掩膜文件: {img_basename} 在 {self.mask_dir} 中")

        # 加载图像和掩膜
        img_path = os.path.join(self.image_dir, img_name_with_ext)
        image = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')

        # transform
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return image,mask  

自定义transform

transform的构建用到torchvision.transforms.v2

python 复制代码
import torchvision.transforms.v2 as transforms
# 定义同时作用于图像和掩膜的变换
transform = transforms.Compose([
    transforms.Resize((512, 512)),# resize
    transforms.RandomHorizontalFlip(p=0.5),# 随机水平翻转
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1)),# 随机仿射变换
    transforms.ToImage(),# 变为tensor
    transforms.ToDtype(torch.float32, scale=True) 
])

transform做好了之后就能在数据集实例化时使用了。

DataLoader

DataLoder是什么

在训练模型时,我们通常希望以"小批量"传递样本,Dataloader就是干这个的。它和Dataset类同样位于 torch.utils.data 下,所以这样导入

python 复制代码
from torch.utils.data import Dataset, DataLoader

DataLoader实例化

现在可以进行数据集和数据加载器实例化了。

python 复制代码
# 实例化 Dataset 和 DataLoader
dataset = ImageMaskDataset('Dataset/Train/Image',"Dataset/Train/Mask",transform=transform)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=0)
print("已载入数据")
python 复制代码
已载入数据

我们来看看效果,写个循环看看dataloader里面是什么

python 复制代码
i = 0
for images, masks in dataloader:
    print(f"Batch images shape: {images.shape}")  # [batch, 3, H, W]
    print(f"Batch masks shape: {masks.shape}")    # [batch, H, W]
    print(f"第{i}轮")
    i+=1
python 复制代码
Batch images shape: torch.Size([10, 3, 512, 512])
Batch masks shape: torch.Size([10, 1, 512, 512])
第0轮
Batch images shape: torch.Size([10, 3, 512, 512])
Batch masks shape: torch.Size([10, 1, 512, 512])
第1轮
Batch images shape: torch.Size([10, 3, 512, 512])
Batch masks shape: torch.Size([10, 1, 512, 512])
第2轮
Batch images shape: torch.Size([10, 3, 512, 512])
Batch masks shape: torch.Size([10, 1, 512, 512])
第3轮
Batch images shape: torch.Size([10, 3, 512, 512])
Batch masks shape: torch.Size([10, 1, 512, 512])

很好,因为定义了batch_size为10,所以输出的一个批次包含10张图片,image为3通道512*512,mask为单通道512*512。

结尾及相关链接

这就是本章的全部内容,实现了自定义Dataset和Dataloader的实例化,简单来说,就是让你具备了数据集加载并查看的能力。下期再聊。

相关链接如下

示例所用数据集:

Huỳnh Trịnh Ngọc. A02025-Medical-Image-Segmentation. https://kaggle.com/competitions/a0-2025-medical-image-segmentation, 2025. Kaggle.

Pytorch链接:

https://pytorch.org/

相关推荐
如竟没有火炬1 小时前
接雨水22
数据结构·python·算法·leetcode·散列表
HackTwoHub1 小时前
AI提示词注入绕过工具:一键绕过Codex/Claude安全限制,CTF夺旗与渗透测试必备神器
网络·人工智能·安全·web安全·系统安全·网络攻击模型·安全架构
诺未科技_NovaTech1 小时前
Microsoft 365 E7 ,“AI+安全+身份”三位一体,打造 AI 时代的一站式操作系统
人工智能·安全·microsoft
小白学大数据1 小时前
均线选股策略研究:基于 Python 数据分析实现
人工智能·python·数据分析
三无推导1 小时前
OpenHuman 开源项目详解:个人 AI 助手架构与核心技术拆解
人工智能·性能优化·架构·开源·ai助手
C137的本贾尼1 小时前
从零认识 Spring AI:Java 开发者的 AI 第一课
python·langchain
源码之家1 小时前
计算机毕业设计:Pyhon健康数据分析系统 Django框架 数据分析 可视化 身体数据分析 大数据(建议收藏)✅
大数据·python·数据挖掘·数据分析·django·lstm·课程设计
薛定猫AI1 小时前
【深度解析】Hermes Agent 与 Hermes Desktop:长期记忆、技能沉淀与多端网关的开源 AI Agent 实战
人工智能·开源
xwz小王子2 小时前
给机器人装上脊髓反射:AT-VLA 如何把触觉塞进 VLA,并把闭环响应压到 40 毫秒
人工智能·机器人