【pytorch】dataset类的使用

一、dataset定义

在PyTorch中,Dataset是数据加载的核心抽象类,其使用流程主要分为自定义数据集和数据加载器配置两部分。

方法的说明:

meth:__getitem__根据索引返回单个样本及其标签。

meth:__len__返回数据集样本总数。

python 复制代码
class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

二、dataset类读取数据

python 复制代码
from torch.utils.data import Dataset
from PIL import Image
import os


class MyDataset(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)

    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path)


ants_dataset = MyDataset(root_dir=r'dataset/train', label_dir='ants')

bees_dataset = MyDataset(root_dir=r'dataset/train', label_dir='bees')

train_data = ants_dataset + bees_dataset
print(len(train_data))
print(train_data[0])

输出

相关推荐
hboot1 天前
AI工程师第四课 - 深度学习入门
pytorch·python·神经网络
weiwei228444 天前
神经网络模型导出及开放标准格式ONNX
pytorch·onnx
程序猿追13 天前
那个右下角的小数字怎么“卡”住我打字——我用 HarmonyOS 自己写了一个字数限制输入框
pytorch·华为·harmonyos
闵孚龙13 天前
《PyTorch 深度修炼》Dataset 和 DataLoader:数据如何喂给模型
人工智能·pytorch·python
bryant_meng13 天前
【VAE】From Pixels to Faces: Building a VAE from Scratch
pytorch·vae·log-sigma2·重参数
装不满的克莱因瓶13 天前
了解多标签图像分类方法——从Sigmoid输出到真实世界复杂视觉理解
人工智能·pytorch·python·深度学习·机器学习·分类·数据挖掘
冷小鱼13 天前
TensorFlow 2.21 进阶实战:从训练优化到生产部署的完整指南
人工智能·pytorch·python·tensorflow
冷小鱼14 天前
PyTorch 2.12 完全指南:从动态图到编译优化的深度学习框架演进
人工智能·pytorch·深度学习
IRevers14 天前
【大模型】Gemma4在ROCm和vLLM部署
人工智能·pytorch·深度学习·大模型·datawhale·vllm·amdev
盼小辉丶14 天前
PyTorch强化学习实战(14)——优先经验回放机制
pytorch·python·深度学习·强化学习