一、前言
学习之前,先了解两个pytorch的助手函数,分别是dir()和help()。
dir()就是directory的缩写,也就是目录的意思。所以,dir()函数就是打开pytorch的目录,看看其下有什么功能、组件等。
打开看见一个 package
里面有什么功能,如:dir(pytorch)
,输出1,2,3
。还可以逐级往下用 dir()
打开,如: dir(pytorch.3)
,输出a, b, c。
help()就是介绍这个组件的说明书。
python
import torch
dir(torch) # 这句的输出里面含有'cuda'
python
['AVG',
'AggregationType',
'AliasDb',
'AnyType',
...
'_C',
'_StorageBase',
'_VF',
'__all__',
'__annotations__',
'__builtins__',
...
'ctypes',
'cuda',
'cuda_path',
'cuda_version',
'cudnn_affine_grid_generator',
...
'sqrt_',
...] # 输出太长删减了不必要的部分
python
dir(torch.cuda) # 这句的输出里面就可以找到 'is_available'
python
['Any',
'BFloat16Storage',
...
'is_available',
'is_initialized',
'list_gpu_processes',
...
'traceback',
'warnings'] # 输出太长删减了不必要的部分
python
help(torch.cuda.is_available)
python
Help on function is_available in module torch.cuda:
is_available() -> bool
Returns a bool indicating if CUDA is currently available.
二、Dataset类介绍
python
from torch.utils.data import Dataset
import os
class MyDataSet(Dataset):
def __init__(self, root_dir,label_dir):
self.root_dir=root_dir
self.label_dir=label_dir
self.path=os.path.join(self.root_dir,self.label_dir)
self.img_path=os.listdir(self.path)
def __getitem__(self, idx):
image_name=self.img_path[idx]
img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
img=Image.open(img_item_path)
label=self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
这段代码是借鉴的b站UP主------小土堆的,大家可以关注学习一下。P6. Dataset类代码实战_哔哩哔哩_bilibili
大致介绍一下这个类,里面三个函数,init、getitem、len。
__init__可以理解成为整个类提供一个全局变量,方便其他函数使用这些变量;
__getitem__可以理解为获取训练所需要的数据和标签;
__len__理解为获取数据量的长度。
这里只写了一个简单方便理解的,后续会在这几个函数里面,进行不断的变换,但最终目的差不多。
三、调用
python
root_dir="dataset/train"
ants_label_dir="ants"
bees_label_dir="bees"
ants_dataset=MyData(root_dir,ants_label)
bees_dataset=MyData(root_dir,bees_label)
img,label=ants_dataset[0]
len(ants_dataset)
train_dataset=ants_dataset+bees_dataset
直接调用定义的dataset类即可,也可以将两个数据集进行叠加生成一个新的训练集。