The Dataset class
Purpose and role
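- In PyTorch, the `Dataset` class is part of the `torch.utils.data` module. It is an abstract base class that defines the standard interface for loading and processing datasets. By subclassing it and implementing its methods, you can build custom datasets for a wide range of machine learning tasks.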
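Basic structure
- Abstract base class definition: `Dataset` is a generic class. It inherits from `Generic[_T_co]`, which means it accepts one covariant type parameter, `_T_co`. It is the base class of all dataset classes and defines the basic interface a dataset should follow.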
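- The main components of the `Dataset` class:
  - Docstring: describes how the class is meant to be used and implemented. It states that every dataset representing a map from keys to data samples should subclass `Dataset` and override `__getitem__` to fetch the sample for a given key. Subclasses may optionally override `__len__` to return the dataset size, which many `Sampler` implementations and the default options of `DataLoader` expect. Subclasses may also optionally implement `__getitems__` to speed up batched sample loading.
  - `__getitem__` method: subclasses must override it so that it returns the sample for a given index. It is not declared with `abc.abstractmethod`; instead, the base implementation raises `NotImplementedError`, so the error appears when a sample is fetched rather than when the object is created.
  - `__getitems__` method: commented out in the base class and entirely optional. It exists to speed up batched loading. When implemented, it should accept a list of sample indices and return a list of samples.
  - `__add__` method: lets two `Dataset` objects be added with `+`. The result is a new `ConcatDataset` that presents both datasets as one contiguous dataset.
  - Comment about `__len__`: explains why `Dataset` deliberately provides no default `__len__`. If a subclass does not implement `__len__`, calling `len()` on it raises `TypeError`, so any code that needs the size fails loudly instead of receiving a misleading default.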
-
python
class Dataset(Generic[_T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`. Subclasses could also
    optionally implement :meth:`__getitems__`, for speedup batched samples
    loading. This method accepts list of indices of samples of batch and returns
    list of samples.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices. To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> _T_co:
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    # def __getitems__(self, indices: List) -> List[_T_co]:
    # Not implemented to prevent false-positives in fetcher check in
    # torch.utils.data._utils.fetch._MapDatasetFetcher

    def __add__(self, other: "Dataset[_T_co]") -> "ConcatDataset[_T_co]":
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py
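The behaviour of the base class described above can be checked with a few lines. This is a minimal sketch that assumes a PyTorch version matching the source shown here; `Squares` is a hypothetical subclass introduced only for illustration.

```python
from torch.utils.data import ConcatDataset, Dataset


class Squares(Dataset):
    """Hypothetical map-style dataset: sample i is i squared."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        return idx * idx


# The base class can be instantiated; errors only appear on use.
base = Dataset()

getitem_error = None
try:
    base[0]
except NotImplementedError:
    getitem_error = "NotImplementedError"

len_error = None
try:
    len(base)  # no default __len__, so Python itself raises TypeError
except TypeError:
    len_error = "TypeError"

# Dataset.__add__ wraps both operands in a ConcatDataset.
combined = Squares(3) + Squares(2)
samples = [combined[i] for i in range(len(combined))]

print(getitem_error, len_error)
print(type(combined).__name__, len(combined), samples)
```

Indices 0 to 2 of `combined` are served by the first operand and indices 3 to 4 by the second, which is why the sample values restart from 0.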
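- In practice, a custom map-style dataset implements two core methods:
  - `__len__(self)`: returns the total number of samples. The base class does not require it, but the default samplers used by `DataLoader` do.
  - `__getitem__(self, idx)`: returns the sample at index `idx`. A sample can be a single data point, or a data point together with its label.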
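- Other commonly used classes that inherit from `Dataset`:
  - `TensorDataset`: wraps one or more tensors with the same first dimension, for example an input tensor and a target tensor, and returns the i-th slice of each as one sample.
  - `ImageFolder`: loads an image dataset from the file system. It is provided by `torchvision.datasets` rather than `torch.utils.data`. It assumes each subdirectory represents one class and treats every image file as one sample.
  - `ConcatDataset`: concatenates several map-style datasets into one larger dataset.
  - `Subset`: selects a subset of a larger dataset by a list of indices.
  - `ChainDataset`: chains several `IterableDataset` objects so that they can be iterated as if they were a single dataset.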
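The built-in subclasses from `torch.utils.data` listed above can be tried on small tensors. This is a sketch under the assumption of a reasonably recent PyTorch. `Counter` is a hypothetical iterable-style dataset written only for this example, and `ImageFolder` is left out because it needs torchvision and image files on disk.

```python
import torch
from torch.utils.data import (
    ChainDataset,
    ConcatDataset,
    IterableDataset,
    Subset,
    TensorDataset,
)

features = torch.arange(12, dtype=torch.float32).reshape(6, 2)
targets = torch.arange(6)

# TensorDataset: sample i is the tuple (features[i], targets[i]).
pairs = TensorDataset(features, targets)

# Subset: exposes only the chosen indices of another dataset.
evens = Subset(pairs, [0, 2, 4])

# ConcatDataset: map-style datasets laid end to end.
both = ConcatDataset([pairs, evens])


class Counter(IterableDataset):
    """Hypothetical stream that yields the integers in [start, stop)."""

    def __init__(self, start, stop):
        self.start = start
        self.stop = stop

    def __iter__(self):
        return iter(range(self.start, self.stop))


# ChainDataset: iterable-style datasets consumed one after another.
chained = ChainDataset([Counter(0, 3), Counter(10, 12)])
chained_items = [item for item in chained]

sizes = (len(pairs), len(evens), len(both))
second_even_target = int(evens[1][1])  # target of original sample 2

print(sizes, second_even_target, chained_items)
```

`ConcatDataset` and `Subset` support random access through `__getitem__`, while `ChainDataset` can only be iterated, which mirrors the map-style versus iterable-style split in the source.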
Usage
- Define a custom dataset class:
python
import torch
from torch.utils.data import Dataset, DataLoader


class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


# Placeholder data and labels for illustration only; substitute your own
custom_data = torch.randn(100, 8)
custom_labels = torch.randint(0, 2, (100,))

# Create the custom dataset
custom_dataset = CustomDataset(custom_data, custom_labels)

# Iterate over the custom dataset with a DataLoader
custom_data_loader = DataLoader(custom_dataset, batch_size=20, shuffle=True)
for batch_idx, (data, target) in enumerate(custom_data_loader):
    # Process the batch of data and targets here
    pass
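The optional `__getitems__` hook mentioned in the docstring can be sketched in the same style. The example assumes a PyTorch version whose map-style fetcher calls `__getitems__` when it is present; on versions that do not, the loader falls back to `__getitem__` and yields the same batches. `BatchedTensorDataset` is a hypothetical name.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class BatchedTensorDataset(Dataset):
    """Hypothetical dataset that can fetch a whole batch in one indexing call."""

    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

    def __getitems__(self, indices):
        # One indexing call per tensor instead of len(indices) separate lookups,
        # then split back into a list of per-sample tuples for the collate step.
        batch_data = self.data[indices]
        batch_labels = self.labels[indices]
        return list(zip(batch_data, batch_labels))


data = torch.arange(20, dtype=torch.float32).reshape(10, 2)
labels = torch.arange(10)
dataset = BatchedTensorDataset(data, labels)

# Calling the hook directly returns a list of samples, as the docstring requires.
direct = dataset.__getitems__([0, 2])

loader = DataLoader(dataset, batch_size=4, shuffle=False)
batches = [(tuple(x.shape), y.tolist()) for x, y in loader]
print(batches)
```

The hook returns a list of samples rather than an already collated batch, so the default `collate_fn` of the `DataLoader` keeps working unchanged.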
Dataset class source code
python
# mypy: allow-untyped-defs
import bisect
import itertools
import math
import warnings
from typing import (
    cast,
    Dict,
    Generic,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)
from typing_extensions import deprecated

# No 'default_generator' in torch/__init__.pyi
from torch import default_generator, Generator, randperm, Tensor

__all__ = [
    "Dataset",
    "IterableDataset",
    "TensorDataset",
    "StackDataset",
    "ConcatDataset",
    "ChainDataset",
    "Subset",
    "random_split",
]

_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
_T_dict = Dict[str, _T_co]
_T_tuple = Tuple[_T_co, ...]
_T_stack = TypeVar("_T_stack", _T_tuple, _T_dict)


class Dataset(Generic[_T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`. Subclasses could also
    optionally implement :meth:`__getitems__`, for speedup batched samples
    loading. This method accepts list of indices of samples of batch and returns
    list of samples.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices. To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> _T_co:
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    # def __getitems__(self, indices: List) -> List[_T_co]:
    # Not implemented to prevent false-positives in fetcher check in
    # torch.utils.data._utils.fetch._MapDatasetFetcher

    def __add__(self, other: "Dataset[_T_co]") -> "ConcatDataset[_T_co]":
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py


class IterableDataset(Dataset[_T_co], Iterable[_T_co]):
    r"""An iterable Dataset.

    All datasets that represent an iterable of data samples should subclass it.
    Such form of datasets is particularly useful when data come from a stream.

    All subclasses should overwrite :meth:`__iter__`, which would return an
    iterator of samples in this dataset.

    When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
    item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
    iterator. When :attr:`num_workers > 0`, each worker process will have a
    different copy of the dataset object, so it is often desired to configure
    each copy independently to avoid having duplicate data returned from the
    workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
    process, returns information about the worker. It can be used in either the
    dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
    :attr:`worker_init_fn` option to modify each copy's behavior.

    Example 1: splitting workload across all workers in :meth:`__iter__`::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
        >>> # xdoctest: +SKIP("Fails on MacOS12")
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super(MyIterableDataset).__init__()
        ...         assert end > start, "this example code only works with end >= start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         worker_info = torch.utils.data.get_worker_info()
        ...         if worker_info is None:  # single-process data loading, return the full iterator
        ...             iter_start = self.start
        ...             iter_end = self.end
        ...         else:  # in a worker process
        ...             # split workload
        ...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
        ...             worker_id = worker_info.id
        ...             iter_start = self.start + worker_id * per_worker
        ...             iter_end = min(iter_start + per_worker, self.end)
        ...         return iter(range(iter_start, iter_end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [tensor([3]), tensor([4]), tensor([5]), tensor([6])]

        >>> # xdoctest: +REQUIRES(POSIX)
        >>> # Mult-process loading with two worker processes
        >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
        >>> # xdoctest: +IGNORE_WANT("non deterministic")
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

        >>> # With even more workers
        >>> # xdoctest: +IGNORE_WANT("non deterministic")
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

    Example 2: splitting workload across all workers using :attr:`worker_init_fn`::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super(MyIterableDataset).__init__()
        ...         assert end > start, "this example code only works with end >= start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [3, 4, 5, 6]
        >>>
        >>> # Directly doing multi-process loading yields duplicate data
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [3, 3, 4, 4, 5, 5, 6, 6]

        >>> # Define a `worker_init_fn` that configures each dataset copy differently
        >>> def worker_init_fn(worker_id):
        ...     worker_info = torch.utils.data.get_worker_info()
        ...     dataset = worker_info.dataset  # the dataset copy in this worker process
        ...     overall_start = dataset.start
        ...     overall_end = dataset.end
        ...     # configure the dataset to only process the split workload
        ...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
        ...     worker_id = worker_info.id
        ...     dataset.start = overall_start + worker_id * per_worker
        ...     dataset.end = min(dataset.start + per_worker, overall_end)
        ...

        >>> # Mult-process loading with the custom `worker_init_fn`
        >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
        [3, 5, 4, 6]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
        [3, 4, 5, 6]
    """

    def __add__(self, other: Dataset[_T_co]):
        return ChainDataset([self, other])

    # No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]


class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Args:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """

    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        assert all(
            tensors[0].size(0) == tensor.size(0) for tensor in tensors
        ), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)


class StackDataset(Dataset[_T_stack]):
    r"""Dataset as a stacking of multiple datasets.

    This class is useful to assemble different parts of complex input data, given as datasets.

    Example:
        >>> # xdoctest: +SKIP
        >>> images = ImageDataset()
        >>> texts = TextDataset()
        >>> tuple_stack = StackDataset(images, texts)
        >>> tuple_stack[0] == (images[0], texts[0])
        >>> dict_stack = StackDataset(image=images, text=texts)
        >>> dict_stack[0] == {'image': images[0], 'text': texts[0]}

    Args:
        *args (Dataset): Datasets for stacking returned as tuple.
        **kwargs (Dataset): Datasets for stacking returned as dict.
    """

    datasets: Union[tuple, dict]

    def __init__(self, *args: Dataset[_T_co], **kwargs: Dataset[_T_co]) -> None:
        if args:
            if kwargs:
                raise ValueError(
                    "Supported either ``tuple``- (via ``args``) or"
                    "``dict``- (via ``kwargs``) like input/output, but both types are given."
                )
            self._length = len(args[0])  # type: ignore[arg-type]
            if any(self._length != len(dataset) for dataset in args):  # type: ignore[arg-type]
                raise ValueError("Size mismatch between datasets")
            self.datasets = args
        elif kwargs:
            tmp = list(kwargs.values())
            self._length = len(tmp[0])  # type: ignore[arg-type]
            if any(self._length != len(dataset) for dataset in tmp):  # type: ignore[arg-type]
                raise ValueError("Size mismatch between datasets")
            self.datasets = kwargs
        else:
            raise ValueError("At least one dataset should be passed")

    def __getitem__(self, index):
        if isinstance(self.datasets, dict):
            return {k: dataset[index] for k, dataset in self.datasets.items()}
        return tuple(dataset[index] for dataset in self.datasets)

    def __getitems__(self, indices: list):
        # add batched sampling support when parent datasets supports it.
        if isinstance(self.datasets, dict):
            dict_batch: List[_T_dict] = [{} for _ in indices]
            for k, dataset in self.datasets.items():
                if callable(getattr(dataset, "__getitems__", None)):
                    items = dataset.__getitems__(indices)  # type: ignore[attr-defined]
                    if len(items) != len(indices):
                        raise ValueError(
                            "Nested dataset's output size mismatch."
                            f" Expected {len(indices)}, got {len(items)}"
                        )
                    for data, d_sample in zip(items, dict_batch):
                        d_sample[k] = data
                else:
                    for idx, d_sample in zip(indices, dict_batch):
                        d_sample[k] = dataset[idx]
            return dict_batch

        # tuple data
        list_batch: List[list] = [[] for _ in indices]
        for dataset in self.datasets:
            if callable(getattr(dataset, "__getitems__", None)):
                items = dataset.__getitems__(indices)  # type: ignore[attr-defined]
                if len(items) != len(indices):
                    raise ValueError(
                        "Nested dataset's output size mismatch."
                        f" Expected {len(indices)}, got {len(items)}"
                    )
                for data, t_sample in zip(items, list_batch):
                    t_sample.append(data)
            else:
                for idx, t_sample in zip(indices, list_batch):
                    t_sample.append(dataset[idx])
        tuple_batch: List[_T_tuple] = [tuple(sample) for sample in list_batch]
        return tuple_batch

    def __len__(self):
        return self._length


class ConcatDataset(Dataset[_T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated
    """

    datasets: List[Dataset[_T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, "datasets should not be an empty iterable"  # type: ignore[arg-type]
        for d in self.datasets:
            assert not isinstance(
                d, IterableDataset
            ), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    @deprecated(
        "`cummulative_sizes` attribute is renamed to `cumulative_sizes`",
        category=FutureWarning,
    )
    def cummulative_sizes(self):
        return self.cumulative_sizes


class ChainDataset(IterableDataset):
    r"""Dataset for chaining multiple :class:`IterableDataset` s.

    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
    """

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        self.datasets = datasets

    def __iter__(self):
        for d in self.datasets:
            assert isinstance(
                d, IterableDataset
            ), "ChainDataset only supports IterableDataset"
            yield from d

    def __len__(self):
        total = 0
        for d in self.datasets:
            assert isinstance(
                d, IterableDataset
            ), "ChainDataset only supports IterableDataset"
            total += len(d)  # type: ignore[arg-type]
        return total


class Subset(Dataset[_T_co]):
    r"""
    Subset of a dataset at specified indices.

    Args:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """

    dataset: Dataset[_T_co]
    indices: Sequence[int]

    def __init__(self, dataset: Dataset[_T_co], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        if isinstance(idx, list):
            return self.dataset[[self.indices[i] for i in idx]]
        return self.dataset[self.indices[idx]]

    def __getitems__(self, indices: List[int]) -> List[_T_co]:
        # add batched sampling support when parent dataset supports it.
        # see torch.utils.data._utils.fetch._MapDatasetFetcher
        if callable(getattr(self.dataset, "__getitems__", None)):
            return self.dataset.__getitems__([self.indices[idx] for idx in indices])  # type: ignore[attr-defined]
        else:
            return [self.dataset[self.indices[idx]] for idx in indices]

    def __len__(self):
        return len(self.indices)


def random_split(
    dataset: Dataset[_T],
    lengths: Sequence[Union[int, float]],
    generator: Optional[Generator] = default_generator,
) -> List[Subset[_T]]:
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    If a list of fractions that sum up to 1 is given,
    the lengths will be computed automatically as
    floor(frac * len(dataset)) for each fraction provided.

    After computing the lengths, if there are any remainders, 1 count will be
    distributed in round-robin fashion to the lengths
    until there are no remainders left.

    Optionally fix the generator for reproducible results, e.g.:

    Example:
        >>> # xdoctest: +SKIP
        >>> generator1 = torch.Generator().manual_seed(42)
        >>> generator2 = torch.Generator().manual_seed(42)
        >>> random_split(range(10), [3, 7], generator=generator1)
        >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths or fractions of splits to be produced
        generator (Generator): Generator used for the random permutation.
    """
    if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
        subset_lengths: List[int] = []
        for i, frac in enumerate(lengths):
            if frac < 0 or frac > 1:
                raise ValueError(f"Fraction at index {i} is not between 0 and 1")
            n_items_in_split = int(
                math.floor(len(dataset) * frac)  # type: ignore[arg-type]
            )
            subset_lengths.append(n_items_in_split)
        remainder = len(dataset) - sum(subset_lengths)  # type: ignore[arg-type]
        # add 1 to all the lengths in round-robin fashion until the remainder is 0
        for i in range(remainder):
            idx_to_add_at = i % len(subset_lengths)
            subset_lengths[idx_to_add_at] += 1
        lengths = subset_lengths
        for i, length in enumerate(lengths):
            if length == 0:
                warnings.warn(
                    f"Length of split at index {i} is 0. "
                    f"This might result in an empty dataset."
                )

    # Cannot verify that dataset is Sized
    if sum(lengths) != len(dataset):  # type: ignore[arg-type]
        raise ValueError(
            "Sum of input lengths does not equal the length of the input dataset!"
        )

    indices = randperm(sum(lengths), generator=generator).tolist()  # type: ignore[arg-type, call-overload]
    lengths = cast(Sequence[int], lengths)
    return [
        Subset(dataset, indices[offset - length : offset])
        for offset, length in zip(itertools.accumulate(lengths), lengths)
    ]
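
The fraction handling in `random_split` is easiest to follow on a small case. The sketch below assumes a PyTorch version that accepts fractional lengths, as the source above does. The ten-element dataset and the seed value are arbitrary choices for illustration.

```python
import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.arange(10))

# floor(10 * 0.25) = 2, floor(10 * 0.25) = 2 and floor(10 * 0.5) = 5 sum to 9,
# so the one leftover sample is handed to the first split (round-robin).
gen = torch.Generator().manual_seed(0)
parts = random_split(dataset, [0.25, 0.25, 0.5], generator=gen)
sizes = [len(p) for p in parts]

# A second generator with the same seed reproduces the same permutation.
gen_again = torch.Generator().manual_seed(0)
parts_again = random_split(dataset, [0.25, 0.25, 0.5], generator=gen_again)

# Each returned object is a Subset holding a slice of the shuffled indices.
all_indices = sorted(i for p in parts for i in p.indices)

print(sizes)  # [3, 2, 5]
```

Because every split is a `Subset` over one shared permutation, the splits never overlap and together cover the whole dataset.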