深度学习Python基础(2)

二 数据处理

一般来说PyTorch中深度学习训练的流程是这样的:

  1. 创建Dateset

  2. Dataset传递给DataLoader

  3. DataLoader迭代产生训练数据提供给模型

对应的一般都会有这三部分代码

创建Dateset(可以自定义)

dataset = face_dataset # Dataset部分自定义过的face_dataset

Dataset传递给DataLoader

dataloader = torch.utils.data.DataLoader(dataset,batch_size=64,shuffle=False,num_workers=8)

DataLoader迭代产生训练数据提供给模型

for i in range(epoch):

for index,(img,label) in enumerate(dataloader):

pass

到这里应该就PyTorch的数据集和数据传递机制应该就比较清晰明了了。Dataset负责建立索引到样本的映射,DataLoader负责以特定的方式从数据集中迭代的产生一个个batch的样本集合。在enumerate过程中实际上是dataloader按照其参数sampler规定的策略调用了其dataset的getitem方法。其中,还会涉及数据的变化形式。

1 . 数据收集

找数据集,注意数据集格式.

Dataset是DataLoader实例化的一个参数。

CIFAR10是CV训练中经常使用到的一个数据集,在PyTorch中CIFAR10是一个写好的Dataset,我们使用时只需以下代码:

data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True)

datasets.CIFAR10就是一个Datasets子类,data是这个类的一个实例。

用自己在一个文件夹中的数据作为数据集时可以使用ImageFolder这个方便的API。

FaceDataset = datasets.ImageFolder('./data', transform=img_transform)

如何自定义一个数据集

torch.utils.data.Dataset 是一个表示数据集的抽象类。任何自定义的数据集都需要继承这个类并覆写相关方法。

所谓数据集,其实就是一个负责处理索引(index)到样本(sample)映射的一个类(class)。

Pytorch提供两种数据集: Map式数据集Iterable式数据集

Map式数据集

一个Map式的数据集必须要重写getitem (self, index),len(self) 两个内建方法,用来表示从索引到样本的映射(Map).

这样一个数据集dataset,举个例子,当使用dataset[idx]命令时,可以在你的硬盘中读取你的数据集中第idx张图片以及其标签(如果有的话);len(dataset)则会返回这个数据集的容量。

自定义类大致是这样的:

class CustomDataset(data.Dataset):#需要继承data.Dataset

def init(self):

TODO

1. Initialize file path or list of file names.

pass

def getitem(self, index):

TODO

1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).

2. Preprocess the data (e.g. torchvision.Transform).

3. Return a data pair (e.g. image and label).

#这里需要注意的是,第一步:read one data,是一个data

pass

def len(self):

You should change 0 to the total size of your dataset.

return 0

例子-1: 自己实验中写的一个例子:这里我们的图片文件储存在"./data/faces/"文件夹下,图片的名字并不是从1开始,而是从final_train_tag_dict.txt这个文件保存的字典中读取,label信息也是用这个文件中读取。大家可以照着上面的注释阅读这段代码。

from torch.utils import data

import numpy as np

from PIL import Image

class face_dataset(data.Dataset):

def init(self):

self.file_path = './data/faces/'

f=open("final_train_tag_dict.txt","r")

self.label_dict=eval(f.read()) # eval除了 计算 ,还 可以 将str转为dict

f.close()

def getitem(self,index):

label = list(self.label_dict.values())[index-1]

img_id = list(self.label_dict.keys())[index-1]

img_path = self.file_path+str(img_id)+".jpg"

img = np.array(Image.open(img_path))

return img,label

def len(self):

return len(self.label_dict)

Iterable式数据集

一个Iterable(迭代)式数据集是抽象类data.IterableDataset 的子类,并且覆写了iter方法成为一个迭代器。这种数据集主要用于数据大小未知,或者以流的形式的输入,本地文件不固定的情况,需要以迭代的方式来获取样本索引。

2 . 数据划分

数据划分主要是路径的处理。

def makedir(new_dir):

if not os.path.exists(new_dir):

os.makedirs(new_dir)

检测路径是否存在,若不存在,则创建此路径。

dataset_dir = os.path.join("..", "..", "data", "RMB_data")

设置路径,将它们组合在一起。相对于Python文件所在位置的相对路径。

for root, dirs, files in os.walk(dataset_dir):

for sub_dir in dirs:

imgs = os.listdir(os.path.join(root, sub_dir))

imgs = list(filter(lambda x: x.endswith('.jpg'), imgs))

random.shuffle(imgs)

img_count = len(imgs)

train_point = int(img_count * train_pct)

valid_point = int(img_count * (train_pct + valid_pct))

for i in range(img_count):

if i < train_point:

out_dir = os.path.join(train_dir, sub_dir)

elif i < valid_point:

out_dir = os.path.join(valid_dir, sub_dir)

else:

out_dir = os.path.join(test_dir, sub_dir)

makedir(out_dir)

target_path = os.path.join(out_dir, imgs[i])

src_path = os.path.join(dataset_dir, sub_dir, imgs[i])

shutil.copy(src_path, target_path)

os.walk

每一层遍历:

root保存的就是当前遍历的文件夹的绝对路径;

dirs保存当前文件夹下的所有子文件夹的名称(仅一层,孙子文件夹不包括)

files保存当前文件夹下的所有文件的名称

其次,发现它的遍历文件方式,在图的遍历方式中,那可不就是深度遍历嘛!!

  1. os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。

******shutil.copy()******Python中的方法用于将源文件的内容复制到目标文件或目录。它还会保留文件的权限模式,但不会保留文件的其他元数据(例如文件的创建和修改时间)。源必须代表文件,但目标可以是文件或目录。如果目标是目录,则文件将使用源中的基本文件名复制到目标中。另外,目的地必须是可写的。如果目标是文件并且已经存在,则将其替换为源文件,否则将创建一个新文件。

3 . 图像预处理-transforms

3.1 图像标准化

transforms.Normalize(mean,std,inplace)

逐通道的标准化,每个通道先求出平均值和标准差,然后标准化。Inplace表示是否原地操作。

3 .2 图像裁剪

train_transform = transforms.Compose([

transforms.Resize((32, 32)),

transforms.RandomCrop(32, padding=4),

transforms.ToTensor(),

transforms.Normalize(norm_mean, norm_std),

])

(1)transforms.CenterCrop(size)

从图片中心截取size大小的图片。

(2)transforms.RandomCrop(size,padding,padding_mode)

随机裁剪区域。

(3)transforms.RandomResizedCrop(size,scale,ratio)

随机大小,随机长宽比的裁剪。

3 .3 图像旋转

(1)transforms.RandomHorizationalFlip(p)

依据概率p水平翻转。

(2)transforms.RandomVerticalFlip(p)

依据概率p垂直翻转。

(3)transforms.RandomRotation(degrees,resample,expand)

transforms方法

Transforms Methods

一、裁剪

1. transforms.CenterCrop

2. transforms.RandomCrop

3. transforms.RandomResizedCrop

4. transforms.FiveCrop

5. transforms.TenCrop

二、翻转和旋转

1. transforms.RandomHorizontalFlip

2. transforms.RandomVerticalFlip

3. transforms.RandomRotation

三、图像变换

• 1. transforms.Pad

• 2. transforms.ColorJitter

• 3. transforms.Grayscale

• 4. transforms.RandomGrayscale

• 5. transforms.RandomAffine

• 6. transforms.LinearTransformation

• 7. transforms.RandomErasing

• 8. transforms.Lambda

• 9. transforms.Resize

• 10. transforms.Totensor

• 11. transforms.Normalize

四、transforms的操作

• 1. transforms.RandomChoice

• 2. transforms.RandomApply

• 3. transforms.RandomOrder

train_transform = transforms.Compose([

transforms.Resize((224, 224)),

1 CenterCrop

transforms.CenterCrop(512), # 512

2 RandomCrop

transforms.RandomCrop(224, padding=16),

transforms.RandomCrop(224, padding=(16, 64)),

transforms.RandomCrop(224, padding=16, fill=(255, 0, 0)),

transforms.RandomCrop(512, pad_if_needed=True), # pad_if_needed=True

transforms.RandomCrop(224, padding=64, padding_mode='edge'),

transforms.RandomCrop(224, padding=64, padding_mode='reflect'),

transforms.RandomCrop(1024, padding=1024, padding_mode='symmetric'),

3 RandomResizedCrop

transforms.RandomResizedCrop(size=224, scale=(0.5, 0.5)),

4 FiveCrop

transforms.FiveCrop(112),

transforms.Lambda(lambda crops: torch.stack([(transforms.ToTensor()(crop)) for crop in crops])),

5 TenCrop

transforms.TenCrop(112, vertical_flip=False),

transforms.Lambda(lambda crops: torch.stack([(transforms.ToTensor()(crop)) for crop in crops])),

1 Horizontal Flip

transforms.RandomHorizontalFlip(p=1),

2 Vertical Flip

transforms.RandomVerticalFlip(p=0.5),

3 RandomRotation

transforms.RandomRotation(90),

transforms.RandomRotation((90), expand=True),

transforms.RandomRotation(30, center=(0, 0)),

transforms.RandomRotation(30, center=(0, 0), expand=True), # expand only for center rotation

transforms.ToTensor(),

transforms.Normalize(norm_mean, norm_std),

])

4 . 数据加载- D ata Loader

相关推荐
Leweslyh1 小时前
物理信息神经网络(PINN)八课时教案
人工智能·深度学习·神经网络·物理信息神经网络
love you joyfully1 小时前
目标检测与R-CNN——pytorch与paddle实现目标检测与R-CNN
人工智能·pytorch·目标检测·cnn·paddle
该醒醒了~2 小时前
PaddlePaddle推理模型利用Paddle2ONNX转换成onnx模型
人工智能·paddlepaddle
小树苗1932 小时前
DePIN潜力项目Spheron解读:激活闲置硬件,赋能Web3与AI
人工智能·web3
凡人的AI工具箱2 小时前
每天40分玩转Django:Django测试
数据库·人工智能·后端·python·django·sqlite
qyq12 小时前
Django框架与ORM框架
后端·python·django
大多_C2 小时前
BERT outputs
人工智能·深度学习·bert
企业软文推广2 小时前
企业如何选择媒体发稿平台及相关事项?媒介盒子分享
python
Debroon2 小时前
乳腺癌多模态诊断解释框架:CNN + 可解释 AI 可视化
人工智能·神经网络·cnn
反方向的钟儿2 小时前
非结构化数据分析与应用(Unstructured data analysis and applications)(pt3)图像数据分析1
人工智能·计算机视觉·数据分析