【Pytorch】加载数据

数据集获取:链接: https://pan.baidu.com/s/1jZoTmoFzaTLWh4lKBHVbEA 密码: 5suq

本文基于P5. PyTorch加载数据初认识_哔哩哔哩_bilibili

dataset:提供一种方式去获取数据及其label值,解释:Pytorch中的dataset类------创建适应任意模型的数据集接口_datasetpath-CSDN博客

dataloader:为网络提供不同的数据形式

首先新建一个python文件:read_data

把数据集文件与代码文件放在同一目录下

找到图片,复制路径。

read_data文件代码:

python 复制代码
from torch.utils.data import Dataset
# 读取图片
from PIL import Image
import os


# Dataset 是 PyTorch 的数据集基类。
# Image 用于打开和处理图片。
# os 用于处理文件路径。

# MyData 类继承自 PyTorch 的 Dataset 类,需要实现三个方法:__init__()、__getitem__() 和 __len__()。
class MyData(Dataset):
    # 初始化s
    def __init__(self, root_dir, label_dir):
        # self.root_dir和self.label_dir分别保存图像数据的根目录和标签目录。
        # self.path是root_dir 和 label_dir的连接路径。
        # self.img_path是指定目录下所有文件的列表,即图像文件的名称。
        # 路径
        self.root_dir = root_dir
        # 标签名
        self.label_dir = label_dir
        # 拼接成路径名
        self.path = os.path.join(self.root_dir, self.label_dir)
        # 获取所有图片的编号
        self.img_path = os.listdir(self.path)

    # 传编号
    def __getitem__(self, idx):
        # idx是数据集中的索引。
        # img_name是根据索引获取的图像文件名称。
        # img_item_path是图像的完整路径。
        # Image.open(img_item_path)用于打开图像文件。
        # label是图像的标签(在这个例子中,标签是目录名)。
        # return img, label返回图像和标签的元组。

        # 当前图片的名字
        img_name = self.img_path[idx]
        # 当前图片的地址
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)

        # 打开图片
        # Image.open()返回值是PIL类型格式,可以直接图片展示
        img = Image.open(img_item_path)
        label = self.label_dir

        # 返回样本对{x:y}
        return img, label

    def __len__(self):
        # 返回数据集中图像的数量,即img_path列表的长度。
        # 返回长度
        return len(self.img_path)

# root_dir 是数据的根目录。
# ants_label_dir 和 bees_label_dir 是两个标签目录,分别代表蚂蚁和蜜蜂的图像数据。
# ants_dataset 和 bees_dataset 分别是两个 MyData 实例,表示蚂蚁和蜜蜂的图像数据集。
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)

train_dataset = ants_dataset + bees_dataset

进阶版:

cpp 复制代码
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import ConcatDataset
import numpy as np
from PIL import Image
import os
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

# Dataset 和 DataLoader 用于创建和加载数据集。
# ConcatDataset 用于合并多个数据集。
# Image 用于打开和处理图像。
# os 用于处理文件路径。
# transforms 用于图像预处理。
# SummaryWriter 用于 TensorBoard 日志记录。
# make_grid 用于将多个图像合并成一个网格图像。

writer = SummaryWriter("logs")

class MyData(Dataset):

    def __init__(self, root_dir, image_dir, label_dir, transform):
        self.root_dir = root_dir
        self.image_dir = image_dir
        self.label_dir = label_dir

        self.label_path = os.path.join(self.root_dir, self.label_dir)
        self.image_path = os.path.join(self.root_dir, self.image_dir)

        self.image_list = os.listdir(self.image_path)
        self.label_list = os.listdir(self.label_path)

        # 应用于图像的转换操作(如调整大小和转换为 Tensor)
        self.transform = transform
        # 因为label 和 Image文件名相同,进行一样的排序,可以保证取出的数据和label是一一对应的
        self.image_list.sort()
        self.label_list.sort()

    def __getitem__(self, idx):
        # 根据索引idx获取图像和标签。
        # img_item_path和label_item_path是图像和标签的完整路径。
        # Image.open(img_item_path)
        # 打开图像文件。
        img_name = self.image_list[idx]
        label_name = self.label_list[idx]

        img_item_path = os.path.join(self.root_dir, self.image_dir, img_name)
        label_item_path = os.path.join(self.root_dir, self.label_dir, label_name)

        #获取图片文件
        img = Image.open(img_item_path)

        # 读取标签文件的内容。
        with open(label_item_path, 'r') as f:
            label = f.readline()

        # 应用转换操作self.transform。
        img = self.transform(img)

        # 返回一个字典,包含图像和标签。
        sample = {'img': img, 'label': label}
        return sample

    def __len__(self):
        # 确保图像和标签的数量相同。
        # 返回数据集中图像的数量。
        assert len(self.image_list) == len(self.label_list)
        return len(self.image_list)

if __name__ == '__main__':
    # transform定义了图像预处理操作。
    transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])

    root_dir = "dataset/train"
    image_ants = "ants_image"
    label_ants = "ants_label"
    ants_dataset = MyData(root_dir, image_ants, label_ants, transform)

    image_bees = "bees_image"
    label_bees = "bees_label"
    bees_dataset = MyData(root_dir, image_bees, label_bees, transform)

    train_dataset = ants_dataset + bees_dataset

    # 使用DataLoader创建一个数据加载器,batch_size = 1和num_workers = 2。
    dataloader = DataLoader(train_dataset, batch_size=1, num_workers=2)

    # 使用SummaryWriter将索引为119的图像写入TensorBoard。
    writer.add_image('error', train_dataset[119]['img'])
    writer.close()
相关推荐
Stara051137 分钟前
基于多头自注意力机制(MHSA)增强的YOLOv11主干网络—面向高精度目标检测的结构创新与性能优化
人工智能·python·深度学习·神经网络·目标检测·计算机视觉·yolov11
YuSun_WK40 分钟前
目标跟踪相关综述文章
人工智能·计算机视觉·目标跟踪
一切皆有可能!!44 分钟前
RAG数据处理:PDF/HTML
人工智能·语言模型
kyle~1 小时前
深度学习---知识蒸馏(Knowledge Distillation, KD)
人工智能·深度学习
那雨倾城2 小时前
使用 OpenCV 将图像中标记特定颜色区域
人工智能·python·opencv·计算机视觉·视觉检测
whoarethenext2 小时前
c/c++的opencv的图像预处理讲解
人工智能·opencv·计算机视觉·预处理
金融小师妹3 小时前
应用BERT-GCN跨模态情绪分析:贸易缓和与金价波动的AI归因
大数据·人工智能·算法
武子康3 小时前
大语言模型 10 - 从0开始训练GPT 0.25B参数量 补充知识之模型架构 MoE、ReLU、FFN、MixFFN
大数据·人工智能·gpt·ai·语言模型·自然语言处理
广州智造3 小时前
OptiStruct实例:3D实体转子分析
数据库·人工智能·算法·机器学习·数学建模·3d·性能优化
LuckyTHP4 小时前
java 使用zxing生成条形码(可自定义文字位置、边框样式)
java·开发语言·python