Pytorch学习之Dataset类

一、前言

学习之前,先了解两个pytorch的助手函数,分别是dir()和help()。

dir()就是directory的缩写,也就是目录的意思。所以,dir()函数就是打开pytorch的目录,看看其下有什么功能、组件等。

打开看见一个 package 里面有什么功能,如:dir(pytorch),输出1,2,3。还可以逐级往下用 dir() 打开,如: dir(pytorch.3) ,输出a, b, c。

help()就是介绍这个组件的说明书。

python 复制代码
import torch
dir(torch) # 这句的输出里面含有'cuda'
python 复制代码
['AVG',
 'AggregationType',
 'AliasDb',
 'AnyType',
	...
 '_C',
 '_StorageBase',
 '_VF',
 '__all__',
 '__annotations__',
 '__builtins__',
	...
 'ctypes',
 'cuda',
 'cuda_path',
 'cuda_version',
 'cudnn_affine_grid_generator',
	...
 'sqrt_',
 ...] # 输出太长删减了不必要的部分
python 复制代码
dir(torch.cuda) # 这句的输出里面就可以找到 'is_available'
python 复制代码
['Any',
 'BFloat16Storage',
	...
 'is_available',
 'is_initialized',
 'list_gpu_processes',
	...
 'traceback',
 'warnings'] # 输出太长删减了不必要的部分
python 复制代码
help(torch.cuda.is_available)
python 复制代码
Help on function is_available in module torch.cuda:

is_available() -> bool
    Returns a bool indicating if CUDA is currently available.

二、Dataset类介绍

python 复制代码
from torch.utils.data import Dataset
import os
 
class MyDataSet(Dataset):
    def __init__(self, root_dir,label_dir):
        self.root_dir=root_dir
        self.label_dir=label_dir
        self.path=os.path.join(self.root_dir,self.label_dir)
        self.img_path=os.listdir(self.path)
 
    def __getitem__(self, idx):
        image_name=self.img_path[idx]
        img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
        img=Image.open(img_item_path)
        label=self.label_dir
        return img, label
 
    def __len__(self):
        return len(self.img_path)

这段代码是借鉴的b站UP主------小土堆的,大家可以关注学习一下。P6. Dataset类代码实战_哔哩哔哩_bilibili

大致介绍一下这个类,里面三个函数,initgetitemlen

__init__可以理解成为整个类提供一个全局变量,方便其他函数使用这些变量;

__getitem__可以理解为获取训练所需要的数据和标签;

__len__理解为获取数据量的长度。

这里只写了一个简单方便理解的,后续会在这几个函数里面,进行不断的变换,但最终目的差不多。

三、调用

python 复制代码
root_dir="dataset/train"
ants_label_dir="ants"
bees_label_dir="bees"
ants_dataset=MyData(root_dir,ants_label)
bees_dataset=MyData(root_dir,bees_label)

img,label=ants_dataset[0]

len(ants_dataset)

train_dataset=ants_dataset+bees_dataset

直接调用定义的dataset类即可,也可以将两个数据集进行叠加生成一个新的训练集。

相关推荐
lingggggaaaa28 分钟前
免杀对抗——C2远控篇&PowerShell&有无文件落地&C#参数调用&绕AMSI&ETW&去混淆特征
c语言·开发语言·笔记·学习·安全·microsoft·c#
wdfk_prog3 小时前
[Linux]学习笔记系列 -- [kernel]workqueue
linux·笔记·学习
wdfk_prog3 小时前
[Linux]学习笔记系列 -- [kernel]usermode_helper
linux·笔记·学习
冬夜戏雪3 小时前
【学习日记】【刷题回溯、贪心、动规】
学习
一只爱做笔记的码农3 小时前
【BootstrapBlazor】移植BootstrapBlazor VS工程到Vscode工程,报error blazor106的问题
笔记·学习·c#
xixixi777774 小时前
“C2隐藏”——命令与控制服务器的隐藏技术
网络·学习·安全·代理·隐藏·合法服务·c2隐藏
名字不相符4 小时前
攻防世界WEB难度一(个人记录)
学习·php·web·萌新
陈天伟教授4 小时前
基于学习的人工智能(4)机器学习基本框架
人工智能·学习·机器学习
7***37455 小时前
DeepSeek在文本分类中的多标签学习
学习·分类·数据挖掘
jiushun_suanli5 小时前
量子纠缠:颠覆认知的宇宙密码
经验分享·学习·量子计算