Pytorch使用Dataset加载数据

1、前言:

在阅读之前,需要配置好对应pytorch版本。

对于一般学习,使用cpu版本的即可。参考教程点我

导入pytorch包,使用如下命令即可。

python 复制代码
import torch   # 注意虽然叫pytorch,但是在引用时是引用torch

2、神经网络获取数据

神经网络获取数据主要用到Dataset和Dataloader 两个方法
Dataset 主要用于获取数据以及对应的真实label
Dataloader 主要为后面的网络提供不同的数据形式

在torch.utils.data包内提供了DataSet类,可在Pytorch官网看到对应的描述

python 复制代码
class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`. Subclasses could also
    optionally implement :meth:`__getitems__`, for speedup batched samples
    loading. This method accepts list of indices of samples of batch and returns
    list of samples.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    # def __getitems__(self, indices: List) -> List[T_co]:
    # Not implemented to prevent false-positives in fetcher check in
    # torch.utils.data._utils.fetch._MapDatasetFetcher

    def __add__(self, other: "Dataset[T_co]") -> "ConcatDataset[T_co]":
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

根据上述描述可知,Dataset是一个抽象类,用于表示数据集。你可以通过继承这个类并实现以下方法来自定义数据集:

python 复制代码
__len__(self): 返回数据集的大小,即数据集中有多少个样本。
__getitem__(self, idx): 根据索引 idx 返回数据集中的一个样本和对应的标签。

3、案例

使用Dataset读取文件夹E:\Python_learning\Deep_learning\dataset\hymenoptera_data\train\ants下所有图片。并获取对应的label,该数据集的文件夹的名字为对应的标签,而文件夹内为对应的训练集的图片

python 复制代码
import os
from torch.utils.data import Dataset
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms


class MyDataset(Dataset):
    def __init__(self, root_path, label):
        self.root_path = root_path
        self.label = label
        self.img_path = os.path.join(root_path, label)  # 拼接路径
        print(f"图片路径: {self.img_path}")  # 打印路径以进行调试
        try:
            self.img_path_list = os.listdir(self.img_path)  # 列出文件夹中的文件
            print(f"图片列表: {self.img_path_list}")  # 打印图片列表以进行调试
        except PermissionError as e:
            print(f"权限错误: {e}")
        except FileNotFoundError as e:
            print(f"文件未找到错误: {e}")

    def __getitem__(self, index):
        img_index = self.img_path_list[index]
        img_path = os.path.join(self.img_path, img_index)
        try:
            img = Image.open(img_path)
        except Exception as e:
            print(f"读取图片时出错: {e}, 图片路径: {img_path}")
            raise e
        label = self.label
        return img, label

    def __len__(self):
        return len(self.img_path_list)


# 实例化这个类
my_data = MyDataset(root_path=r'E:\Python_learning\Deep_learning\dataset\hymenoptera_data\train', label='ants')
writer = SummaryWriter('logs')
for i in range(my_data.__len__()):
    img, label = my_data[i]  # 依次获取对应的图片
    # 此处img为PIL Image, 使用transforms中的ToTensor方法转化为tensor格式
    writer.add_image(tag=label, img_tensor=transforms.ToTensor()(img), global_step=i)
writer.close()
print(f"当前文件夹下{i + 1}张图片已读取完毕,请在Tensorboard中查看")

在控制台输入tensorboard --logdir='E:\Python_learning\Deep_learning\note\logs'打开tensorboard查看

相关推荐
咚咚王者4 分钟前
人工智能之数据分析 Pandas:第六章 数据清洗
人工智能·数据分析·pandas
geneculture7 分钟前
融合全部讨论精华的融智学认知与实践总览图:掌握在复杂世界中锚定自我、有效行动、并参与塑造近未来的元能力
大数据·人工智能·数据挖掘·信息科学·融智学的重要应用·信智序位·全球软件定位系统
闲人编程7 分钟前
GraphQL与REST API对比与实践
后端·python·api·graphql·rest·codecapsule
永霖光电_UVLED11 分钟前
安森美与英诺赛科将合作推进氮化镓(GaN)功率器件的量产应用
人工智能·神经网络·生成对抗网络
Dev7z13 分钟前
基于深度学习的脑肿瘤自动诊断和分析系统的研究与实现(Web界面+数据集+训练代码)
人工智能·深度学习
珠海西格电力16 分钟前
零碳园区数字感知基础架构规划:IoT 设备布点与传输管网衔接设计
大数据·运维·人工智能·物联网·智慧城市·能源
AI即插即用19 分钟前
即插即用系列 | WACV 2024 D-LKA:超越 Transformer?D-LKA Net 如何用可变形大核卷积刷新医学图像分割
图像处理·人工智能·深度学习·目标检测·计算机视觉·视觉检测·transformer
草莓熊Lotso23 分钟前
《算法闯关指南:动态规划算法--斐波拉契数列模型》--04.解码方法
c++·人工智能·算法·动态规划
winfredzhang27 分钟前
深入剖析 wxPython 配置文件编辑器
python·编辑器·wxpython·ini配置
狂放不羁霸27 分钟前
电子科技大学2025年机器学习期末考试回忆
人工智能·机器学习