Pytorch学习之Dataset类

一、前言

学习之前,先了解两个pytorch的助手函数,分别是dir()和help()。

dir()就是directory的缩写,也就是目录的意思。所以,dir()函数就是打开pytorch的目录,看看其下有什么功能、组件等。

打开看见一个 package 里面有什么功能,如:dir(pytorch),输出1,2,3。还可以逐级往下用 dir() 打开,如: dir(pytorch.3) ,输出a, b, c。

help()就是介绍这个组件的说明书。

python 复制代码
import torch
dir(torch) # 这句的输出里面含有'cuda'
python 复制代码
['AVG',
 'AggregationType',
 'AliasDb',
 'AnyType',
	...
 '_C',
 '_StorageBase',
 '_VF',
 '__all__',
 '__annotations__',
 '__builtins__',
	...
 'ctypes',
 'cuda',
 'cuda_path',
 'cuda_version',
 'cudnn_affine_grid_generator',
	...
 'sqrt_',
 ...] # 输出太长删减了不必要的部分
python 复制代码
dir(torch.cuda) # 这句的输出里面就可以找到 'is_available'
python 复制代码
['Any',
 'BFloat16Storage',
	...
 'is_available',
 'is_initialized',
 'list_gpu_processes',
	...
 'traceback',
 'warnings'] # 输出太长删减了不必要的部分
python 复制代码
help(torch.cuda.is_available)
python 复制代码
Help on function is_available in module torch.cuda:

is_available() -> bool
    Returns a bool indicating if CUDA is currently available.

二、Dataset类介绍

python 复制代码
from torch.utils.data import Dataset
import os
 
class MyDataSet(Dataset):
    def __init__(self, root_dir,label_dir):
        self.root_dir=root_dir
        self.label_dir=label_dir
        self.path=os.path.join(self.root_dir,self.label_dir)
        self.img_path=os.listdir(self.path)
 
    def __getitem__(self, idx):
        image_name=self.img_path[idx]
        img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
        img=Image.open(img_item_path)
        label=self.label_dir
        return img, label
 
    def __len__(self):
        return len(self.img_path)

这段代码是借鉴的b站UP主------小土堆的,大家可以关注学习一下。P6. Dataset类代码实战_哔哩哔哩_bilibili

大致介绍一下这个类,里面三个函数,initgetitemlen

__init__可以理解成为整个类提供一个全局变量,方便其他函数使用这些变量;

__getitem__可以理解为获取训练所需要的数据和标签;

__len__理解为获取数据量的长度。

这里只写了一个简单方便理解的,后续会在这几个函数里面,进行不断的变换,但最终目的差不多。

三、调用

python 复制代码
root_dir="dataset/train"
ants_label_dir="ants"
bees_label_dir="bees"
ants_dataset=MyData(root_dir,ants_label)
bees_dataset=MyData(root_dir,bees_label)

img,label=ants_dataset[0]

len(ants_dataset)

train_dataset=ants_dataset+bees_dataset

直接调用定义的dataset类即可,也可以将两个数据集进行叠加生成一个新的训练集。

相关推荐
西岸行者5 天前
学习笔记:SKILLS 能帮助更好的vibe coding
笔记·学习
悠哉悠哉愿意5 天前
【单片机学习笔记】串口、超声波、NE555的同时使用
笔记·单片机·学习
别催小唐敲代码5 天前
嵌入式学习路线
学习
毛小茛5 天前
计算机系统概论——校验码
学习
babe小鑫5 天前
大专经济信息管理专业学习数据分析的必要性
学习·数据挖掘·数据分析
winfreedoms5 天前
ROS2知识大白话
笔记·学习·ros2
在这habit之下5 天前
Linux Virtual Server(LVS)学习总结
linux·学习·lvs
我想我不够好。5 天前
2026.2.25监控学习
学习
im_AMBER5 天前
Leetcode 127 删除有序数组中的重复项 | 删除有序数组中的重复项 II
数据结构·学习·算法·leetcode
CodeJourney_J5 天前
从“Hello World“ 开始 C++
c语言·c++·学习