【Pytorch 库】自定义数据集相关的类

  • [torch.utils.data.Dataset 类](#torch.utils.data.Dataset 类)
  • [torch.utils.data.DataLoader 类](#torch.utils.data.DataLoader 类)
  • 自定义数据集示例
    • [1. 自定义 Dataset 类](#1. 自定义 Dataset 类)
    • [2. 在其他 .py 文件中引用和使用该自定义 Dataset](#2. 在其他 .py 文件中引用和使用该自定义 Dataset)
  • [torch_geometric.data.Dataset 类](#torch_geometric.data.Dataset 类)
    • [torch_geometric.data.Dataset VS torch.utils.data.Dataset](#torch_geometric.data.Dataset VS torch.utils.data.Dataset)

详细信息,参阅 torch.utils.data 文档页面

写得很棒的文章:PyTorch加载自己的数据集

PyTorch 数据加载工具 的核心是 torch.utils.data.DataLoader 类。它表示一个 Python 可迭代对象,用于遍历数据集,并支持以下功能:

  • 映射式和可迭代式数据集
  • 自定义数据加载顺序
  • 自动批量处理
  • 单进程和多进程数据加载
  • 自动内存固定

这些选项由 DataLoader 构造函数的 参数 配置,如下:

python 复制代码
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

DataLoader 构造函数中最重要的参数是 dataset,它表示一个数据集对象,用于加载数据。

PyTorch 支持两种不同类型的数据集:

  • 映射式数据集(map-style datasets)
  • 可迭代式数据集(iterable-style datasets)

torch.utils.data.Dataset 类

torch.utils.data.Dataset 是 PyTorch 库中的一个标准类,它是 用于自定义数据集的基类 。这个类是所有数据集的基础,适用于各种类型的数据加载,包括图像、文本、时间序列等

  • 作用Dataset 类的作用是提供一个接口,用于 加载和处理原始数据 。它是 PyTorch 的数据加载机制的一部分,通常 DataLoader 配合使用

  • 主要方法

    • __len__():返回数据集的大小(即样本的数量)。
    • __getitem__():根据索引返回单个数据样本。DataLoader 会使用该方法来迭代数据。

torch.utils.data.Dataset 类 Pytorch 官网文档

torch.utils.data.Dataset 类是一个抽象类,用于表示一个数据集。

  • 所有表示 从键到数据样本映射的数据集 都应当继承此类。
  • 所有子类都应当 重写 __getitem__() 方法 ,用于 根据给定的键获取数据样本
  • 子类还可以选择性地重写 __len__() 方法 ,许多 Sampler 实现和 DataLoader 的默认选项都期望返回数据集的大小
  • 子类还可以 选择性地实现 __getitems__() 方法 ,以加速 批量样本的加载。该方法接受一个包含批次样本索引的列表,并返回一个包含样本的列表。

注意 :默认情况下,DataLoader 会构造一个索引采样器,该采样器返回整数索引。为了使其与使用非整数索引/键的映射风格数据集兼容,必须提供自定义的采样器。

python 复制代码
# Pytorch torch.utils.data.Dataset 类的源码

class Dataset(Generic[_T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`. Subclasses could also
    optionally implement :meth:`__getitems__`, for speedup batched samples
    loading. This method accepts list of indices of samples of batch and returns
    list of samples.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> _T_co:
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    # def __getitems__(self, indices: List) -> List[_T_co]:
    # Not implemented to prevent false-positives in fetcher check in
    # torch.utils.data._utils.fetch._MapDatasetFetcher

    def __add__(self, other: "Dataset[_T_co]") -> "ConcatDataset[_T_co]":
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

torch.utils.data.DataLoader 类

数据加载器(DataLoader)将数据集(dataset)和采样器(sampler)结合起来,并提供对给定数据集的可迭代访问。

DataLoader 支持单进程或多进程加载的 映射式可迭代式 数据集,支持自定义加载顺序、可选的自动批量处理(拼接)以及内存固定(memory pinning)。

这些参数使得 DataLoader 能够在处理数据时非常灵活,支持不同的 数据加载策略并行处理方式内存管理 等。通过合理设置这些参数,可以在训练神经网络时实现高效的数据加载和处理。

  • dataset (Dataset) -- 从中加载数据的数据集。

  • batch_size (int, 可选) -- 每个批次加载多少个样本(默认值:1)。

  • shuffle (bool, 可选) -- 设置为 True 时,每个周期(epoch)都会重新打乱数据 (默认值:False)。

  • sampler (Sampler 或 Iterable, 可选) -- 定义从数据集中抽取样本的策略 。可以是任何实现了 __len__ 的 Iterable。如果指定了 sampler,则不能同时指定 shuffle

  • batch_sampler (Sampler 或 Iterable, 可选) -- 类似于 sampler,但一次返回一个批次的索引。与 batch_sizeshufflesamplerdrop_last 互斥。

  • num_workers (int, 可选) -- 用于数据加载的子进程数量。设置为 0 时,数据将在主进程中加载。(默认值:0)

  • collate_fn (Callable, 可选) -- 将一个样本列表合并为一个 mini-batch 的 Tensor(s)。在使用映射式数据集进行批量加载时使用。

  • pin_memory (bool, 可选) -- 如果为 True,数据加载器将在返回数据之前将 Tensors 复制到设备/CUDA 固定内存中 。如果你的数据元素是自定义类型,或者 collate_fn 返回的是自定义类型的批次,请参见下面的示例。

  • drop_last (bool, 可选) -- 设置为 True 时,如果数据集大小不能被批次大小整除,则丢弃最后一个不完整的批次 。如果设置为 False 且数据集大小不能被批次大小整除,那么最后一个批次将会更小。(默认值:False

  • timeout (数值型, 可选) -- 如果为正,则表示从工作进程收集批次的超时值。应该始终是非负值。(默认值:0)

  • worker_init_fn (Callable, 可选) -- 如果不为 None,则会在每个工作进程中调用该函数,输入为工作进程的 ID(一个整数,在 [0, num_workers - 1] 范围内),在种子设置后、数据加载之前调用。(默认值:None

  • multiprocessing_context (str 或 multiprocessing.context.BaseContext, 可选) -- 如果为 None,则使用操作系统的默认多进程上下文。(默认值:None

  • generator (torch.Generator, 可选) -- 如果不为 None,则该随机数生成器将由 RandomSampler 用于生成随机索引,并且在多进程中用于生成工作进程的基准种子。(默认值:None

  • prefetch_factor (int, 可选,仅限关键字参数) -- 每个工作进程提前加载的批次数。设置为 2 时,所有工作进程将提前加载总计 2 * num_workers 个批次。(默认值取决于 num_workers 设置的值。如果 num_workers=0,默认值为 None;如果 num_workers > 0,默认值为 2)

  • persistent_workers (bool, 可选) -- 如果为 True,数据加载器将在数据集被消费一次后不会关闭工作进程。这样可以保持工作进程中的数据集实例处于活动状态。(默认值:False

  • pin_memory_device (str, 可选) -- 如果 pin_memoryTrue,则为内存固定的设备指定设备名称。

  • in_order (bool, 可选) -- 如果为 False,数据加载器将不强制按照先进先出(FIFO)的顺序返回批次。仅在 num_workers > 0 时生效。(默认值:True

注意

  • 如果使用了 spawn 启动方法,则 worker_init_fn 不能是不可序列化的对象,例如 Lambda 函数。有关 PyTorch 中多进程的更多细节,请参阅"多进程最佳实践"。
  • len(dataloader) 的启发式方法基于所使用的采样器的长度。当数据集是一个 IterableDataset 时,它会根据 len(dataset) / batch_size 来返回一个估计值,并根据 drop_last 设置进行适当的四舍五入,而不管多进程加载配置如何。这是 PyTorch 能做出的 最佳估算 ,因为 PyTorch 相信用户的数据集代码能够正确处理多进程加载,避免重复数据。
    然而,如果数据分片导致多个工作进程的最后一个批次不完整,那么这个估计仍然可能不准确,因为 (1) 一个原本完整的批次可能会被分成多个批次,(2) 当 drop_last 设置为 True 时,可能会丢失多个批次的样本。不幸的是,PyTorch 通常无法检测到这种情况。
    有关这两种数据集类型以及 IterableDataset 如何与多进程数据加载交互的更多细节,请参阅"数据集类型"。
  • 有关随机种子相关的问题,请参阅"可重现性","我的数据加载器工作进程返回相同的随机数"以及"多进程数据加载中的随机性"相关说明。
  • in_order 设置为 False 可能会影响可重现性,并且在数据不平衡的情况下,可能会导致传递给训练器的数据分布偏斜。

自定义数据集示例

1. 自定义 Dataset 类

首先,需要定义一个自定义的 Dataset 类,继承自 torch.utils.data.Dataset

python 复制代码
# dataset.py

import torch
from torch.utils.data import Dataset
import os
from PIL import Image
from torchvision import transforms

class CustomDataset(Dataset):
    def __init__(self, data_dir, mode='train', transform=None):
        """
        :param data_dir: 数据集根目录
        :param mode: 'train' 或 'test',指定加载训练集或测试集
        :param transform: 数据转换(如图像缩放,裁剪,归一化等)
        """
        self.data_dir = data_dir
        self.mode = mode
        self.transform = transform
        
        # 假设数据集结构是这样的:
        # data_dir/
        #    train/
        #        class1/
        #        class2/
        #    test/
        #        class1/
        #        class2/
        
        self.image_paths = []
        self.labels = []

        # 加载数据
        self._load_data()

    def _load_data(self):
        """根据模式加载训练集或测试集数据"""
        # 设置数据目录
        data_folder = os.path.join(self.data_dir, self.mode)
        
        for label, class_name in enumerate(os.listdir(data_folder)):
            class_folder = os.path.join(data_folder, class_name)
            
            if os.path.isdir(class_folder):
                for img_name in os.listdir(class_folder):
                    img_path = os.path.join(class_folder, img_name)
                    self.image_paths.append(img_path)
                    self.labels.append(label)
    
    def __len__(self):
        """返回数据集的大小"""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """根据索引返回数据和标签"""
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        
        # 读取图像
        image = Image.open(img_path).convert('RGB')
        
        # 应用转换
        if self.transform:
            image = self.transform(image)
        
        return image, label

在自定义数据集 CustomDataset 中,通过 mode='train'mode='test' 来决定加载训练集或测试集的数据 。这个 mode 参数可以在创建数据集时传入:

  • mode='train' 时,加载训练集的数据。
  • mode='test' 时,加载测试集的数据。

该示例 假设数据集存储在 data_dir/traindata_dir/test 文件夹下 ,并按类存放在子文件夹中。可以根据自己的数据存储结构修改 _load_data() 方法。

2. 在其他 .py 文件中引用和使用该自定义 Dataset

在其他 .py 文件中,可以引用该数据集类,并根据需要加载训练集或测试集。还可以传递数据增强或转换操作,例如使用 torchvision.transforms 来进行图像处理。

python 复制代码
# main.py

import torch
from torch.utils.data import DataLoader
from dataset import CustomDataset
from torchvision import transforms

# 定义数据增强和预处理操作
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 创建训练集和测试集
train_dataset = CustomDataset(data_dir='./data', mode='train', transform=transform)
test_dataset = CustomDataset(data_dir='./data', mode='test', transform=transform)

# 使用 DataLoader 加载数据集
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# 使用训练集
for images, labels in train_loader:
    print(images.shape, labels.shape)  # 输出每个批次的图像和标签的大小
    # 在此进行训练

torch_geometric.data.Dataset 类

torch_geometric.data.Dataset 类官方文档

在 PyTorch 中,torch.utils.data.Datasettorch_geometric.data.Dataset 都是用来表示数据集的基类,但它们的作用、设计和用途有一些显著的区别,特别是在图神经网络(GNN)方面,

  • torch.utils.data.Dataset通用 ,适用于 各种类型的数据集
  • torch_geometric.data.Dataset专门为处理图数据(如图结构数据、边、节点特征等)而设计的,

torch_geometric.data.Dataset 是 PyTorch Geometric(一个专门为图神经网络设计的扩展库)中的数据集类,继承自 torch.utils.data.Dataset ,专门用于处理图数据,如图结构数据、节点特征 node_features、边特征 edge_index 等。

  • 主要方法

    • __len__():返回数据集中的图的数量
    • get():返回单个图的数据get() 通常是实现单个数据项的加载过程,返回一个 Data 对象 ,其中包含图的结构信息(如 edge_index)、节点特征(如 x)等。
  • 用途:专门用于图神经网络(GNN)任务,适合处理图结构数据,例如社交网络、分子结构、物理网络等。

torch_geometric.data.Dataset VS torch.utils.data.Dataset

主要区别

特性 torch.utils.data.Dataset torch_geometric.data.Dataset
设计目的 用于处理一般的数据集(如图像、文本、时间序列等)。 专门为图神经网络设计,处理图结构数据(如图、边、节点特征)。
继承关系 基本类,PyTorch 数据加载的基类。 继承自 torch.utils.data.Dataset,扩展为图数据处理。
数据存储 存储一般的数据(如图像、文本数据等)。 存储图数据结构,包括 edge_indexx(节点特征)等。
核心方法 __getitem__() 返回单个数据项,通常是一个样本。 get() 返回一个图数据对象,通常是 Data
数据类型 适用于任何类型的数据集。 适用于图结构数据集。
多进程支持 支持多进程数据加载(通过 num_workers)。 同样支持多进程数据加载,但主要针对图数据。
批处理支持 支持自动批处理(batching),使用 DataLoader 也支持批处理,并且能够自动处理图结构的数据。
图数据处理 没有内建的图数据处理支持。 内建支持图数据结构,如 edge_indexx(节点特征)、y(标签)。

关系

  • torch_geometric.data.Datasettorch.utils.data.Dataset 的一个扩展,专门为图数据设计。
  • torch_geometric.data.Dataset 在其基础上提供了对图数据的支持,包括图的边结构、节点特征等,使得 PyTorch Geometric 更适用于 图神经网络(GNN) 等图学习任务。
相关推荐
鸡鸭扣20 分钟前
Docker:3、在VSCode上安装并运行python程序或JavaScript程序
运维·vscode·python·docker·容器·js
paterWang1 小时前
基于 Python 和 OpenCV 的酒店客房入侵检测系统设计与实现
开发语言·python·opencv
东方佑1 小时前
使用Python和OpenCV实现图像像素压缩与解压
开发语言·python·opencv
神秘_博士2 小时前
自制AirTag,支持安卓/鸿蒙/PC/Home Assistant,无需拥有iPhone
arm开发·python·物联网·flutter·docker·gitee
Moutai码农3 小时前
机器学习-生命周期
人工智能·python·机器学习·数据挖掘
小白教程4 小时前
python学习笔记,python处理 Excel、Word、PPT 以及邮件自动化办公
python·python学习·python安装
武陵悭臾4 小时前
网络爬虫学习:借助DeepSeek完善爬虫软件,实现模拟鼠标右键点击,将链接另存为本地文件
python·selenium·网络爬虫·pyautogui·deepseek·鼠标右键模拟·保存链接为htm
代码猪猪傻瓜coding5 小时前
关于 形状信息提取的说明
人工智能·python·深度学习
码界筑梦坊6 小时前
基于Flask的第七次人口普查数据分析系统的设计与实现
后端·python·信息可视化·flask·毕业设计
微笑的Java6 小时前
Python - 爬虫利器 - BeautifulSoup4常用 API
开发语言·爬虫·python