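1. Definition of Dataset
In PyTorch, Dataset is the core abstract class for data loading. Working with it involves two main steps: defining a custom dataset and configuring a data loader.
Methods:
__getitem__: returns a single sample and its label for a given index.
__len__: returns the total number of samples in the dataset.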
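The two steps and the two methods can be seen together in a small example. This is a minimal sketch, not a recommended setup: SquaresDataset is a hypothetical toy class made up for illustration, and the batch_size and shuffle values are arbitrary.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i * i)."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, idx):
        # Return one sample and its label for the given integer index.
        x = torch.tensor([float(idx)])
        y = torch.tensor([float(idx * idx)])
        return x, y

    def __len__(self):
        # Total number of samples; the default sampler relies on this.
        return self.n


# Step 1: build the custom dataset.
dataset = SquaresDataset(10)

# Step 2: configure the data loader (example settings only).
loader = DataLoader(dataset, batch_size=4, shuffle=False)

# Each batch stacks the individual samples along a new first dimension.
batch_shapes = [(tuple(x.shape), tuple(y.shape)) for x, y in loader]
print(len(dataset), len(loader), batch_shapes)
```

With 10 samples and a batch size of 4, the loader yields three batches, the last one holding the remaining 2 samples.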
```python
from typing import Generic, TypeVar

T_co = TypeVar("T_co", covariant=True)


# Excerpt from the PyTorch source; only the class docstring is shown.
class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices. To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """
```
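The note about integral indices can be illustrated without PyTorch itself. The sketch below is a deliberate simplification, not how DataLoader is implemented: KeyedDataset and fetch_all are hypothetical names, and fetch_all only mimics the idea that a loader asks a sampler for keys and then looks each key up in the dataset.

```python
class KeyedDataset:
    """Hypothetical map-style dataset keyed by strings instead of integers."""

    def __init__(self, samples):
        self.samples = dict(samples)

    def __getitem__(self, key):
        return self.samples[key]

    def __len__(self):
        return len(self.samples)


def fetch_all(dataset, sampler):
    # Simplified stand-in for a loader: take keys from the sampler
    # and fetch the matching sample for each one.
    return [dataset[key] for key in sampler]


ds = KeyedDataset({"ant": 0, "bee": 1})

# A default-style index sampler yields 0, 1, ..., len(ds) - 1,
# which are not valid keys for this dataset.
try:
    fetch_all(ds, range(len(ds)))
    default_ok = True
except KeyError:
    default_ok = False

# A custom sampler that yields the dataset's own keys works.
custom = fetch_all(ds, ["ant", "bee"])
print(default_ok, custom)  # False [0, 1]
```

This is why a dataset whose keys are not integers needs a sampler that produces those keys.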
2. Reading data with the Dataset class
```python
import os

from PIL import Image
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir  # the folder name doubles as the class label
        self.path = os.path.join(self.root_dir, self.label_dir)
        # os.listdir returns names in arbitrary order; sort for stable indices.
        self.img_path = sorted(os.listdir(self.path))

    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.path, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path)


ants_dataset = MyDataset(root_dir=r'dataset/train', label_dir='ants')
bees_dataset = MyDataset(root_dir=r'dataset/train', label_dir='bees')
# Adding two datasets yields a ConcatDataset that covers both.
train_data = ants_dataset + bees_dataset
print(len(train_data))
print(train_data[0])
```
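Output: the first line is the total number of files in the ants and bees folders, and the second is a tuple holding a PIL image object and the label 'ants'.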
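Adding two Dataset objects with + returns a ConcatDataset, which is why the length is the sum of both parts and index 0 comes from the first dataset. The index arithmetic can be sketched in plain Python. This is a simplified illustration under assumptions: locate is a hypothetical helper, plain lists stand in for the two datasets, the sizes are made up, and the real class handles more cases (negative indices, for example).

```python
import bisect
from itertools import accumulate


def locate(sizes, idx):
    """Map a global index to (dataset number, local index)."""
    cumulative = list(accumulate(sizes))  # e.g. [3, 5] for sizes [3, 2]
    if not 0 <= idx < cumulative[-1]:
        raise IndexError("index out of range")
    # Find the first dataset whose cumulative size exceeds idx.
    which = bisect.bisect_right(cumulative, idx)
    local = idx if which == 0 else idx - cumulative[which - 1]
    return which, local


# Stand-ins for the two datasets; the contents are made up for illustration.
ants = ["ant0", "ant1", "ant2"]
bees = ["bee0", "bee1"]
parts = [ants, bees]
sizes = [len(p) for p in parts]

first = locate(sizes, 0)   # first sample of the first dataset
fourth = locate(sizes, 3)  # first sample of the second dataset
print(sum(sizes), parts[first[0]][first[1]], parts[fourth[0]][fourth[1]])
```

Indices below the size of the first part resolve to it unchanged; larger indices are shifted down by the sizes of the parts before them.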
