引言
很显然,这篇文章的主要内容如标题所示,使用pytorch创建自定义的数据集并进行简单的查看。不废话进入正题,相关资源链接放在文章末尾部分。
Dataset的创建
Dataset是什么
自定义数据集继承自 torch.utils.data.Dataset ,用于定义自己所需要的数据集,为此需要实现三个该类的三个函数:init、len 和 getitem。
python
from torch.utils.data import Dataset
python
class SelfDataset(Dataset):
def __init__(self):
def __len__(self):
def __getitem__(self):
init
init 函数在实例化 Dataset 对象时运行一次,实现该数据类对象的初始化。我们一般在该函数中定义数据集涉及的文件和所需的变换即transform。
len
len 函数返回我们数据集中样本的数量。
getitem
getitem_ 函数加载并返回给定索引 idx 处的数据集中的样本。如果有transform的话也会在这部分进行。
自定义Dataset数据集
接下来进行示例代码编写,自定义数据集为图像与掩膜的数据集,包含图像和对应掩膜文件。先看一下文件的构成:
python
import os
base_dir = "Dataset"
ls = os.listdir(base_dir)
for i in ls:
print(base_dir+"/"+i)
for j in os.listdir(base_dir+"/"+i):
print(" "+base_dir+"/"+i+"/"+j)
python
Dataset/Test
Dataset/Test/Image
Dataset/Train
Dataset/Train/Image
Dataset/Train/Mask
明确好文件路径后就开始编写自定义数据集的三个函数了。
声明自定义数据集类ImageMaskDataset
python
class ImageMaskDataset(Dataset):
编写__init__部分,参数为图像文件夹路径,掩膜文件夹路径和transform预处理操作
python
def __init__(self,image_dir,mask_dir,transform=None):
"""
Args:
image_dir: 图像文件夹路径 (如 'data/train/image')
mask_dir: 掩膜文件夹路径 (如 'data/train/mask')
transform: 可选的图像预处理操作
"""
self.image_dir = image_dir
self.mask_dir=mask_dir
self.transform = transform
self.image_filenames = sorted(os.listdir(image_dir)) #所有图像文件名
编写__len__部分
python
def __len__(self):
return len(self.image_filenames)
编写__getitem__部分,参数为索引idx,返回索引对应的图像和掩膜。如果定义了transform,则返回变换后的图像和掩膜
python
def __getitem__(self,idx):
# 获取文件名
img_name_with_ext = self.image_filenames[idx]
# 去掉拓展名
img_basename = os.path.splitext(img_name_with_ext)[0]
# 掩膜扩展名
mask_extensions = ['.png']
mask_path = None
# 找掩膜
for ext in mask_extensions:
potential_path = os.path.join(self.mask_dir, img_basename + ext)
if os.path.exists(potential_path):
mask_path = potential_path
break
if mask_path is None:
raise FileNotFoundError(f"找不到掩膜文件: {img_basename} 在 {self.mask_dir} 中")
# 加载图像和掩膜
img_path = os.path.join(self.image_dir, img_name_with_ext)
image = Image.open(img_path).convert('RGB')
mask = Image.open(mask_path).convert('L')
# transform
if self.transform:
image = self.transform(image)
mask = self.transform(mask)
return image,mask
自定义transform
transform的构建用到torchvision.transforms.v2
python
import torchvision.transforms.v2 as transforms
# 定义同时作用于图像和掩膜的变换
transform = transforms.Compose([
transforms.Resize((512, 512)),# resize
transforms.RandomHorizontalFlip(p=0.5),# 随机水平翻转
transforms.RandomAffine(degrees=15, translate=(0.1, 0.1)),# 随机仿射变换
transforms.ToImage(),# 变为tensor
transforms.ToDtype(torch.float32, scale=True)
])
transform做好了之后就能在数据集实例化时使用了。
DataLoader
DataLoder是什么
在训练模型时,我们通常希望以"小批量"传递样本,Dataloader就是干这个的。它和Dataset类同样位于 torch.utils.data 下,所以这样导入
python
from torch.utils.data import Dataset, DataLoader
DataLoader实例化
现在可以进行数据集和数据加载器实例化了。
python
# 实例化 Dataset 和 DataLoader
dataset = ImageMaskDataset('Dataset/Train/Image',"Dataset/Train/Mask",transform=transform)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=0)
print("已载入数据")
python
已载入数据
我们来看看效果,写个循环看看dataloader里面是什么
python
i = 0
for images, masks in dataloader:
print(f"Batch images shape: {images.shape}") # [batch, 3, H, W]
print(f"Batch masks shape: {masks.shape}") # [batch, H, W]
print(f"第{i}轮")
i+=1
python
Batch images shape: torch.Size([10, 3, 512, 512])
Batch masks shape: torch.Size([10, 1, 512, 512])
第0轮
Batch images shape: torch.Size([10, 3, 512, 512])
Batch masks shape: torch.Size([10, 1, 512, 512])
第1轮
Batch images shape: torch.Size([10, 3, 512, 512])
Batch masks shape: torch.Size([10, 1, 512, 512])
第2轮
Batch images shape: torch.Size([10, 3, 512, 512])
Batch masks shape: torch.Size([10, 1, 512, 512])
第3轮
Batch images shape: torch.Size([10, 3, 512, 512])
Batch masks shape: torch.Size([10, 1, 512, 512])
很好,因为定义了batch_size为10,所以输出的一个批次包含10张图片,image为3通道512*512,mask为单通道512*512。
结尾及相关链接
这就是本章的全部内容,实现了自定义Dataset和Dataloader的实例化,简单来说,就是让你具备了数据集加载并查看的能力。下期再聊。
相关链接如下:
示例所用数据集:
Pytorch链接: