Pytorch学习之Dataset类

一、前言

学习之前,先了解两个pytorch的助手函数,分别是dir()和help()。

dir()就是directory的缩写,也就是目录的意思。所以,dir()函数就是打开pytorch的目录,看看其下有什么功能、组件等。

打开看见一个 package 里面有什么功能,如:dir(pytorch),输出1,2,3。还可以逐级往下用 dir() 打开,如: dir(pytorch.3) ,输出a, b, c。

help()就是介绍这个组件的说明书。

python 复制代码
import torch
dir(torch) # 这句的输出里面含有'cuda'
python 复制代码
['AVG',
 'AggregationType',
 'AliasDb',
 'AnyType',
	...
 '_C',
 '_StorageBase',
 '_VF',
 '__all__',
 '__annotations__',
 '__builtins__',
	...
 'ctypes',
 'cuda',
 'cuda_path',
 'cuda_version',
 'cudnn_affine_grid_generator',
	...
 'sqrt_',
 ...] # 输出太长删减了不必要的部分
python 复制代码
dir(torch.cuda) # 这句的输出里面就可以找到 'is_available'
python 复制代码
['Any',
 'BFloat16Storage',
	...
 'is_available',
 'is_initialized',
 'list_gpu_processes',
	...
 'traceback',
 'warnings'] # 输出太长删减了不必要的部分
python 复制代码
help(torch.cuda.is_available)
python 复制代码
Help on function is_available in module torch.cuda:

is_available() -> bool
    Returns a bool indicating if CUDA is currently available.

二、Dataset类介绍

python 复制代码
from torch.utils.data import Dataset
import os
 
class MyDataSet(Dataset):
    def __init__(self, root_dir,label_dir):
        self.root_dir=root_dir
        self.label_dir=label_dir
        self.path=os.path.join(self.root_dir,self.label_dir)
        self.img_path=os.listdir(self.path)
 
    def __getitem__(self, idx):
        image_name=self.img_path[idx]
        img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
        img=Image.open(img_item_path)
        label=self.label_dir
        return img, label
 
    def __len__(self):
        return len(self.img_path)

这段代码是借鉴的b站UP主------小土堆的,大家可以关注学习一下。P6. Dataset类代码实战_哔哩哔哩_bilibili

大致介绍一下这个类,里面三个函数,initgetitemlen

__init__可以理解成为整个类提供一个全局变量,方便其他函数使用这些变量;

__getitem__可以理解为获取训练所需要的数据和标签;

__len__理解为获取数据量的长度。

这里只写了一个简单方便理解的,后续会在这几个函数里面,进行不断的变换,但最终目的差不多。

三、调用

python 复制代码
root_dir="dataset/train"
ants_label_dir="ants"
bees_label_dir="bees"
ants_dataset=MyData(root_dir,ants_label)
bees_dataset=MyData(root_dir,bees_label)

img,label=ants_dataset[0]

len(ants_dataset)

train_dataset=ants_dataset+bees_dataset

直接调用定义的dataset类即可,也可以将两个数据集进行叠加生成一个新的训练集。

相关推荐
weixin_409383123 分钟前
a星学习记录 通过父节点从目的地格子坐标回溯起点
学习·cocos·a星
搞机械的假程序猿3 分钟前
普中51单片机学习笔记-DS1302实时时钟芯片
笔记·学习·51单片机
车载测试工程师7 分钟前
CAPL学习-SOME/IP交互层-值处理类函数2
学习·tcp/ip·以太网·capl·canoe
车载测试工程师7 分钟前
CAPL学习-SOME/IP交互层-值处理类函数1
学习·tcp/ip·交互·以太网·capl·canoe
专业开发者27 分钟前
学习模块:Wi-Fi 测试与认证
学习
Mr.Jessy5 小时前
JavaScript高级:构造函数与原型
开发语言·前端·javascript·学习·ecmascript
玄斎9 小时前
MySQL 单表操作通关指南:建库 / 建表 / 插入 / 增删改查
运维·服务器·数据库·学习·程序人生·mysql·oracle
im_AMBER11 小时前
Leetcode 78 识别数组中的最大异常值 | 镜像对之间最小绝对距离
笔记·学习·算法·leetcode
其美杰布-富贵-李12 小时前
HDF5文件学习笔记
数据结构·笔记·学习
d111111111d13 小时前
在STM32函数指针是什么,怎么使用还有典型应用场景。
笔记·stm32·单片机·嵌入式硬件·学习·算法