1.前言
在PyTorch中,Dataset
和DataLoader
是两个重要的工具,用于构建输入数据的管道。
(1)Dataset
是一个抽象类,表示数据集,需要实现__len__
和__getitem__
方法。
(2)DataLoader
是一个可迭代的数据加载器,它封装了数据集的加载、批处理、打乱和并行加载等功能。
2.分类任务创建Dataset
和DataLoader
(1)对于分类任务,Dataset
需要返回图像和对应的标签
python
from torch.utils.data import Dataset
from PIL import Image
import os
import torch
class ClassificationDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.transform = transform
self.images = [os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith('.jpg')]
self.labels = [...] # 这里应该是与图像对应的标签列表
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = self.images[idx]
image = Image.open(img_path).convert('RGB')
label = self.labels[idx]
if self.transform:
image = self.transform(image)
return image, label
(2)DataLoader
加载数据
python
from torch.utils.data import DataLoader
transform = ... # 这里定义你的数据预处理流程
dataset = ClassificationDataset(root_dir='path_to_your_data', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
3.检测任务创建Dataset
和DataLoader
(1)Dataset
需要返回图像和对应的边界框信息
python
class DetectionDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.transform = transform
self.images = [os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith('.jpg')]
self.annotations = [...] # 这里应该是与图像对应的边界框信息列表
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = self.images[idx]
image = Image.open(img_path).convert('RGB')
boxes = self.annotations[idx] # 这些是边界框信息
if self.transform:
image, boxes = self.transform(image, boxes)
return image, boxes
(2)DataLoader
加载数据
python
dataloader = DataLoader(DetectionDataset(root_dir='path_to_your_data', transform=transform), batch_size=2, shuffle=True)
4.分割任务创建Dataset
和DataLoader
(1)Dataset
需要返回图像和对应的分割掩码
python
class SegmentationDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.transform = transform
self.images = [os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith('.jpg')]
self.masks = [...] # 这里应该是与图像对应的分割掩码列表
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = self.images[idx]
mask_path = self.masks[idx]
image = Image.open(img_path).convert('RGB')
mask = Image.open(mask_path).convert('L') # 假设掩码是灰度图
if self.transform:
image, mask = self.transform(image, mask)
return image, mask
(2)DataLoader
加载数据
python
dataloader = DataLoader(SegmentationDataset(root_dir='path_to_your_data', transform=transform), batch_size=4, shuffle=True)
在PyTorch的
Dataset
和DataLoader
框架中,idx
(或称为索引)是通过迭代DataLoader
时自动生成的。当你创建一个DataLoader
实例,并在训练循环中迭代它时,DataLoader
会内部调用Dataset
的__getitem__
方法,并自动为你提供索引idx
。