Pytorch学习之Dataset类

一、前言

学习之前,先了解两个pytorch的助手函数,分别是dir()和help()。

dir()就是directory的缩写,也就是目录的意思。所以,dir()函数就是打开pytorch的目录,看看其下有什么功能、组件等。

打开看见一个 package 里面有什么功能,如:dir(pytorch),输出1,2,3。还可以逐级往下用 dir() 打开,如: dir(pytorch.3) ,输出a, b, c。

help()就是介绍这个组件的说明书。

python 复制代码
import torch
dir(torch) # 这句的输出里面含有'cuda'
python 复制代码
['AVG',
 'AggregationType',
 'AliasDb',
 'AnyType',
	...
 '_C',
 '_StorageBase',
 '_VF',
 '__all__',
 '__annotations__',
 '__builtins__',
	...
 'ctypes',
 'cuda',
 'cuda_path',
 'cuda_version',
 'cudnn_affine_grid_generator',
	...
 'sqrt_',
 ...] # 输出太长删减了不必要的部分
python 复制代码
dir(torch.cuda) # 这句的输出里面就可以找到 'is_available'
python 复制代码
['Any',
 'BFloat16Storage',
	...
 'is_available',
 'is_initialized',
 'list_gpu_processes',
	...
 'traceback',
 'warnings'] # 输出太长删减了不必要的部分
python 复制代码
help(torch.cuda.is_available)
python 复制代码
Help on function is_available in module torch.cuda:

is_available() -> bool
    Returns a bool indicating if CUDA is currently available.

二、Dataset类介绍

python 复制代码
from torch.utils.data import Dataset
import os
 
class MyDataSet(Dataset):
    def __init__(self, root_dir,label_dir):
        self.root_dir=root_dir
        self.label_dir=label_dir
        self.path=os.path.join(self.root_dir,self.label_dir)
        self.img_path=os.listdir(self.path)
 
    def __getitem__(self, idx):
        image_name=self.img_path[idx]
        img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
        img=Image.open(img_item_path)
        label=self.label_dir
        return img, label
 
    def __len__(self):
        return len(self.img_path)

这段代码是借鉴的b站UP主------小土堆的,大家可以关注学习一下。P6. Dataset类代码实战_哔哩哔哩_bilibili

大致介绍一下这个类,里面三个函数,initgetitemlen

__init__可以理解成为整个类提供一个全局变量,方便其他函数使用这些变量;

__getitem__可以理解为获取训练所需要的数据和标签;

__len__理解为获取数据量的长度。

这里只写了一个简单方便理解的,后续会在这几个函数里面,进行不断的变换,但最终目的差不多。

三、调用

python 复制代码
root_dir="dataset/train"
ants_label_dir="ants"
bees_label_dir="bees"
ants_dataset=MyData(root_dir,ants_label)
bees_dataset=MyData(root_dir,bees_label)

img,label=ants_dataset[0]

len(ants_dataset)

train_dataset=ants_dataset+bees_dataset

直接调用定义的dataset类即可,也可以将两个数据集进行叠加生成一个新的训练集。

相关推荐
光影少年4 小时前
rust生态及学习路线,应用领域
开发语言·学习·rust
华清远见成都中心4 小时前
学习物联网可以做什么工作?
物联网·学习
折翼的恶魔4 小时前
前端学习之布局
前端·学习
递归不收敛4 小时前
吴恩达机器学习课程(PyTorch适配)学习笔记:1.2 优化算法实践
pytorch·学习·机器学习
代码or搬砖5 小时前
Git学习笔记(二)
笔记·git·学习
UpYoung!5 小时前
【Typora——MD编辑器】Typora最新 V1.12.1版:轻量级 Markdown 编辑器详细图文下载安装使用教程
学习·数学建模·编辑器·运维开发·个人开发
报错小能手5 小时前
linux学习笔记(26)计算机网络基础
linux·笔记·学习
做科研的周师兄6 小时前
中国逐日格点降水数据集V2(1960–2024,0.1°)
人工智能·学习·机器学习·支持向量机·聚类
拥友LikT6 小时前
计算机网络基础篇——如何学习计算机网络?
学习·计算机网络