Pytorch学习之Dataset类

一、前言

学习之前,先了解两个pytorch的助手函数,分别是dir()和help()。

dir()就是directory的缩写,也就是目录的意思。所以,dir()函数就是打开pytorch的目录,看看其下有什么功能、组件等。

打开看见一个 package 里面有什么功能,如:dir(pytorch),输出1,2,3。还可以逐级往下用 dir() 打开,如: dir(pytorch.3) ,输出a, b, c。

help()就是介绍这个组件的说明书。

python 复制代码
import torch
dir(torch) # 这句的输出里面含有'cuda'
python 复制代码
['AVG',
 'AggregationType',
 'AliasDb',
 'AnyType',
	...
 '_C',
 '_StorageBase',
 '_VF',
 '__all__',
 '__annotations__',
 '__builtins__',
	...
 'ctypes',
 'cuda',
 'cuda_path',
 'cuda_version',
 'cudnn_affine_grid_generator',
	...
 'sqrt_',
 ...] # 输出太长删减了不必要的部分
python 复制代码
dir(torch.cuda) # 这句的输出里面就可以找到 'is_available'
python 复制代码
['Any',
 'BFloat16Storage',
	...
 'is_available',
 'is_initialized',
 'list_gpu_processes',
	...
 'traceback',
 'warnings'] # 输出太长删减了不必要的部分
python 复制代码
help(torch.cuda.is_available)
python 复制代码
Help on function is_available in module torch.cuda:

is_available() -> bool
    Returns a bool indicating if CUDA is currently available.

二、Dataset类介绍

python 复制代码
from torch.utils.data import Dataset
import os
 
class MyDataSet(Dataset):
    def __init__(self, root_dir,label_dir):
        self.root_dir=root_dir
        self.label_dir=label_dir
        self.path=os.path.join(self.root_dir,self.label_dir)
        self.img_path=os.listdir(self.path)
 
    def __getitem__(self, idx):
        image_name=self.img_path[idx]
        img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
        img=Image.open(img_item_path)
        label=self.label_dir
        return img, label
 
    def __len__(self):
        return len(self.img_path)

这段代码是借鉴的b站UP主------小土堆的,大家可以关注学习一下。P6. Dataset类代码实战_哔哩哔哩_bilibili

大致介绍一下这个类,里面三个函数,initgetitemlen

__init__可以理解成为整个类提供一个全局变量,方便其他函数使用这些变量;

__getitem__可以理解为获取训练所需要的数据和标签;

__len__理解为获取数据量的长度。

这里只写了一个简单方便理解的,后续会在这几个函数里面,进行不断的变换,但最终目的差不多。

三、调用

python 复制代码
root_dir="dataset/train"
ants_label_dir="ants"
bees_label_dir="bees"
ants_dataset=MyData(root_dir,ants_label)
bees_dataset=MyData(root_dir,bees_label)

img,label=ants_dataset[0]

len(ants_dataset)

train_dataset=ants_dataset+bees_dataset

直接调用定义的dataset类即可,也可以将两个数据集进行叠加生成一个新的训练集。

相关推荐
丝斯20111 小时前
AI学习笔记整理(67)——大模型的Benchmark(基准测试)
人工智能·笔记·学习
whale fall1 小时前
2026 年 1-3 月雅思口语完整话题清单(1-4 月通用最终版)
笔记·学习
xian_wwq1 小时前
【学习笔记】对网络安全“三化六防挂图作战”的理解与思考
笔记·学习·三化六防
AI视觉网奇2 小时前
metahuman 购买安装记录
笔记·学习·ue5
winfreedoms3 小时前
java-网络编程——黑马程序员学习笔记
java·网络·学习
五VV3 小时前
【ESP32】SP3手柄与ESP32连接不上问题解决
经验分享·学习
墨黎芜4 小时前
SQL Server从入门到精通——C#与数据库
数据库·学习·信息可视化
wdfk_prog4 小时前
[Linux]学习笔记系列 -- [drivers][dma]stm32-dma
linux·笔记·学习
暖阳之下4 小时前
学习周报三十三
学习
写点什么呢4 小时前
Ltspice_安装与使用
学习·测试工具