Pytorch中的torch.utils.data.Dataset 类

1、使用方法

python 复制代码
from torch.utils.data import Dataset

2、torch.utils.data.Dataset 类的定义

python 复制代码
class Dataset(Generic[_T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`. Subclasses could also
    optionally implement :meth:`__getitems__`, for speedup batched samples
    loading. This method accepts list of indices of samples of batch and returns
    list of samples.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> _T_co:
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    # def __getitems__(self, indices: List) -> List[_T_co]:
    # Not implemented to prevent false-positives in fetcher check in
    # torch.utils.data._utils.fetch._MapDatasetFetcher

    def __add__(self, other: "Dataset[_T_co]") -> "ConcatDataset[_T_co]":
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

原文解释:

  1. 所有表示从键到数据样本映射的数据集都应该继承自这个类。

    这意味着,如果你有一个数据集,它通过某些键(可能是整数、字符串等)来访问数据样本,那么你应该从 Dataset 类继承来创建你的数据集类。

  2. 所有的子类都应该重写 __getitem__ 方法,支持通过给定的键获取数据样本。
    __getitem__ 是 Python 的特殊方法,用于通过 dataset[key] 这样的语法来获取数据。在你的子类中,你需要实现这个方法,确保它能够返回与给定键对应的数据样本。

  3. 子类也可以选择性地重写 __len__ 方法,该方法通常被许多 Sampler 实现和 DataLoader 的默认选项所使用,用于返回数据集的大小。
    __len__ 方法用于返回数据集中样本的总数。虽然它不是强制要求的,但如果你希望使用 PyTorch 的 SamplerDataLoader,通常需要实现这个方法。

  4. 子类还可以选择性地实现 __getitems__ 方法,以加速批量数据加载。这个方法接受一个包含批次样本索引的列表,并返回一个样本列表。
    __getitems__ 是一个可选的优化方法。如果你需要批量加载数据,实现这个方法可以提高效率。它接受一个索引列表,并返回对应的样本列表。

  5. DataLoader 默认构造一个生成整数索引的采样器(sampler)。要使它能够与具有非整数索引/键的 map-style 数据集一起工作,必须提供一个自定义的采样器。
    DataLoader 默认情况下假设你的数据集是可以通过整数索引访问的(即 dataset[0], dataset[1] 等)。如果你的数据集使用非整数键(比如字符串或其他类型),你需要提供一个自定义的采样器来生成这些键。

3、示例

示例 1:简单的整数索引数据集

假设我们有一个数据集,数据储存在一个列表中,我们可以通过整数索引来访问。

python 复制代码
from torch.utils.data import Dataset

class SimpleDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

# 示例数据
data = [1, 2, 3, 4, 5]

# 创建数据集
dataset = SimpleDataset(data)

# 使用
print(dataset[0])  # 输出:1
print(len(dataset))  # 输出:5

示例 2:字符串键的数据集

假设我们有一个数据集,数据以字典形式存储,键是字符串。

python 复制代码
from torch.utils.data import Dataset

class StringKeyDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.keys = list(data.keys())

    def __getitem__(self, key):
        return self.data[key]

    def __len__(self):
        return len(self.keys)

# 示例数据
data = {"a": 1, "b": 2, "c": 3}

# 创建数据集
dataset = StringKeyDataset(data)

# 使用
print(dataset["a"])  # 输出:1
print(len(dataset))  # 输出:3

注意:如果需要与 DataLoader 一起使用,必须提供一个自定义的采样器,因为默认的采样器生成整数索引。

示例 3:实现 __getitems__ 方法

为了实现批量加载数据,我们可以实现 __getitems__ 方法。

python 复制代码
from torch.utils.data import Dataset

class BatchableDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        return self.data[index]

    def __getitems__(self, indices):
        return [self.data[i] for i in indices]

    def __len__(self):
        return len(self.data)

# 示例数据
data = [10, 20, 30, 40, 50]

# 创建数据集
dataset = BatchableDataset(data)

# 使用
print(dataset[0])  # 输出:10
print(dataset.__getitems__([1, 3]))  # 输出:[20, 40]

示例 4:图像数据集

假设我们有一个图像数据集,图像路径存储在列表中。

python 复制代码
from torch.utils.data import Dataset
from PIL import Image
import os

class ImageDataset(Dataset):
    def __init__(self, img_dir):
        self.img_dir = img_dir
        self.img_names = os.listdir(img_dir)

    def __getitem__(self, index):
        img_name = self.img_names[index]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path)
        return image

    def __len__(self):
        return len(self.img_names)

# 创建数据集
img_dir = "path/to/images"
dataset = ImageDataset(img_dir)

# 使用
print(len(dataset))  # 输出图像数量
print(dataset[0])  # 输出第一张图像

示例 5:自定义采样器

如果你的数据集使用非整数键(如字符串),并且你想与 DataLoader 一起使用,可以定义一个自定义采样器。

python 复制代码
from torch.utils.data import Dataset, DataLoader, Sampler
import random

class StringKeyDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.keys = list(data.keys())

    def __getitem__(self, key):
        return self.data[key]

    def __len__(self):
        return len(self.keys)

class StringSampler(Sampler):
    def __init__(self, keys):
        self.keys = keys

    #每次调用时(如新的epoch开始),先打乱键的顺序,再返回迭代器。
    #实现数据加载时的随机化顺序。
    def __iter__(self):
        random.shuffle(self.keys)
        return iter(self.keys)

    def __len__(self):
        return len(self.keys)

# 示例数据
data = {"a": 1, "b": 2, "c": 3}

# 创建数据集和采样器
dataset = StringKeyDataset(data)
sampler = StringSampler(dataset.keys)

# 使用 DataLoader
dataloader = DataLoader(dataset, sampler=sampler, batch_size=2)

for batch in dataloader:
    print(batch)  # 输出批次数据
相关推荐
心态与习惯2 小时前
深度学习中的 seq2seq 模型
人工智能·深度学习·seq2seq
2501_944526423 小时前
Flutter for OpenHarmony 万能游戏库App实战 - 蜘蛛纸牌游戏实现
android·java·python·flutter·游戏
AI即插即用3 小时前
即插即用系列 | CVPR 2025 AmbiSSL:首个注释模糊感知的半监督医学图像分割框架
图像处理·人工智能·深度学习·计算机视觉·视觉检测
飞Link4 小时前
【Django】Django的静态文件相关配置与操作
后端·python·django
Ulyanov4 小时前
从桌面到云端:构建Web三维战场指挥系统
开发语言·前端·python·tkinter·pyvista·gui开发
CCPC不拿奖不改名5 小时前
两种完整的 Git 分支协作流程
大数据·人工智能·git·python·elasticsearch·搜索引擎·自然语言处理
Coding茶水间5 小时前
基于深度学习的交通标志检测系统演示与介绍(YOLOv12/v11/v8/v5模型+Pyqt5界面+训练代码+数据集)
开发语言·人工智能·深度学习·yolo·目标检测·机器学习
飞Link5 小时前
【论文笔记】《Deep Learning for Time Series Anomaly Detection: A Survey》
rnn·深度学习·神经网络·cnn·transformer
a努力。5 小时前
字节Java面试被问:TCP的BBR拥塞控制算法原理
java·开发语言·python·tcp/ip·elasticsearch·面试·职场和发展
费弗里5 小时前
一个小技巧轻松提升Dash应用debug效率
python·dash