pytorch Dataset类代码学习

python 复制代码
from torch.utils.data import  Dataset
from PIL import Image
import os


class my_data(Dataset):
    def __init__(self, root_dir, label_dir): # 初始化类,根据这一个类,来创建特例实例需要调用的一个函数
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)



    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir,self.label_dir, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path)

root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = my_data(root_dir, ants_label_dir)
bees_dataset = my_data(root_dir, bees_label_dir)

train_dataset = ants_dataset + bees_dataset

在控制台中将上述代码粘贴:查看数据集等操作:

python 复制代码
  ...: from PIL import Image
  ...: import os
........................
  ...:     def __len__(self):
  ...:         return len(self.img_path)

创建数据集,包括路径与标签。还有蚂蚁的数据集。

python 复制代码
root_dir = "dataset\train"
ants_label_dir = "ants"
ants_dataset = my_data(root_dir, ants_label_dir)

然而,出现如下的一些报错:

OSError: [WinError 123] 文件名、目录名或卷标语法不正确。: 'dataset\train\\ants'

原因是:

python 复制代码
root_dir = "dataset/train"

斜画线反了,不能直接用复制粘贴里面来的。

完整读取数据集里的图片代码:

python 复制代码
root_dir = "dataset/train"
ants_label_dir = "ants"
ants_dataset = my_data(root_dir, ants_label_dir)
img, label = ants_dataset[1]
img.show()

如果读取出来的图片反复都是一张,则是因为:读取的是上一次成功读取的图片。

错误原因是在这句代码中:

python 复制代码
img, label = ants_dataset[1]

这句中的连接是逗号,并不是.

通过上述的语句,即可实现数据集图片的读取。

两个数据集的相加:

python 复制代码
train_dataset = ants_dataset + bees_dataset

在控制台中,使用同样的方法读取:

python 复制代码
len(ants_dataset)
输出:Out[23]: 124
len(bees_dataset)
输出:Out[24]: 121
img,label = train_dataset[123]
img.show()
img,label = train_dataset[124]
img.show()
相关推荐
@小博的博客2 小时前
【Linux探索学习】第二篇Linux的基本指令(2)——开启Linux学习第二篇
linux·运维·学习
格林威3 小时前
常规线扫描镜头有哪些类型?能做什么?
人工智能·深度学习·数码相机·算法·计算机视觉·视觉检测·工业镜头
lyx33136967594 小时前
#深度学习基础:神经网络基础与PyTorch
pytorch·深度学习·神经网络·参数初始化
007php0074 小时前
某大厂跳动面试:计算机网络相关问题解析与总结
java·开发语言·学习·计算机网络·mysql·面试·职场和发展
知识分享小能手4 小时前
微信小程序入门学习教程,从入门到精通,微信小程序核心 API 详解与案例(13)
前端·javascript·学习·react.js·微信小程序·小程序·vue
递归不收敛4 小时前
吴恩达机器学习课程(PyTorch 适配)学习笔记:3.3 推荐系统全面解析
pytorch·学习·机器学习
B站计算机毕业设计之家5 小时前
智慧交通项目:Python+YOLOv8 实时交通标志系统 深度学习实战(TT100K+PySide6 源码+文档)✅
人工智能·python·深度学习·yolo·计算机视觉·智慧交通·交通标志
又是忙碌的一天7 小时前
前端学习 JavaScript(2)
前端·javascript·学习
蒙奇D索大7 小时前
【数据结构】考研数据结构核心考点:二叉排序树(BST)全方位详解与代码实现
数据结构·笔记·学习·考研·算法·改行学it
玲娜贝儿--努力学习买大鸡腿版7 小时前
推荐算法学习笔记(十九)阿里SIM 模型
笔记·学习·推荐算法