二 数据处理
一般来说PyTorch中深度学习训练的流程是这样的:
-
创建Dateset
-
Dataset传递给DataLoader
-
DataLoader迭代产生训练数据提供给模型
对应的一般都会有这三部分代码
创建Dateset(可以自定义)
dataset = face_dataset # Dataset部分自定义过的face_dataset
Dataset传递给DataLoader
dataloader = torch.utils.data.DataLoader(dataset,batch_size=64,shuffle=False,num_workers=8)
DataLoader迭代产生训练数据提供给模型
for i in range(epoch):
for index,(img,label) in enumerate(dataloader):
pass
到这里应该就PyTorch的数据集和数据传递机制应该就比较清晰明了了。Dataset负责建立索引到样本的映射,DataLoader负责以特定的方式从数据集中迭代的产生一个个batch的样本集合。在enumerate过程中实际上是dataloader按照其参数sampler规定的策略调用了其dataset的getitem方法。其中,还会涉及数据的变化形式。
1 . 数据收集
找数据集,注意数据集格式.
Dataset是DataLoader实例化的一个参数。
CIFAR10是CV训练中经常使用到的一个数据集,在PyTorch中CIFAR10是一个写好的Dataset,我们使用时只需以下代码:
data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True)
datasets.CIFAR10就是一个Datasets子类,data是这个类的一个实例。
用自己在一个文件夹中的数据作为数据集时可以使用ImageFolder这个方便的API。
FaceDataset = datasets.ImageFolder('./data', transform=img_transform)
如何自定义一个数据集
torch.utils.data.Dataset 是一个表示数据集的抽象类。任何自定义的数据集都需要继承这个类并覆写相关方法。
所谓数据集,其实就是一个负责处理索引(index)到样本(sample)映射的一个类(class)。
Pytorch提供两种数据集: Map式数据集Iterable式数据集
Map式数据集
一个Map式的数据集必须要重写getitem (self, index),len(self) 两个内建方法,用来表示从索引到样本的映射(Map).
这样一个数据集dataset,举个例子,当使用dataset[idx]命令时,可以在你的硬盘中读取你的数据集中第idx张图片以及其标签(如果有的话);len(dataset)则会返回这个数据集的容量。
自定义类大致是这样的:
class CustomDataset(data.Dataset):#需要继承data.Dataset
def init(self):
TODO
1. Initialize file path or list of file names.
pass
def getitem(self, index):
TODO
1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
2. Preprocess the data (e.g. torchvision.Transform).
3. Return a data pair (e.g. image and label).
#这里需要注意的是,第一步:read one data,是一个data
pass
def len(self):
You should change 0 to the total size of your dataset.
return 0
例子-1: 自己实验中写的一个例子:这里我们的图片文件储存在"./data/faces/"文件夹下,图片的名字并不是从1开始,而是从final_train_tag_dict.txt这个文件保存的字典中读取,label信息也是用这个文件中读取。大家可以照着上面的注释阅读这段代码。
from torch.utils import data
import numpy as np
from PIL import Image
class face_dataset(data.Dataset):
def init(self):
self.file_path = './data/faces/'
f=open("final_train_tag_dict.txt","r")
self.label_dict=eval(f.read()) # eval除了 计算 ,还 可以 将str转为dict
f.close()
def getitem(self,index):
label = list(self.label_dict.values())[index-1]
img_id = list(self.label_dict.keys())[index-1]
img_path = self.file_path+str(img_id)+".jpg"
img = np.array(Image.open(img_path))
return img,label
def len(self):
return len(self.label_dict)
Iterable式数据集
一个Iterable(迭代)式数据集是抽象类data.IterableDataset 的子类,并且覆写了iter方法成为一个迭代器。这种数据集主要用于数据大小未知,或者以流的形式的输入,本地文件不固定的情况,需要以迭代的方式来获取样本索引。
2 . 数据划分
数据划分主要是路径的处理。
def makedir(new_dir):
if not os.path.exists(new_dir):
os.makedirs(new_dir)
检测路径是否存在,若不存在,则创建此路径。
dataset_dir = os.path.join("..", "..", "data", "RMB_data")
设置路径,将它们组合在一起。相对于Python文件所在位置的相对路径。
for root, dirs, files in os.walk(dataset_dir):
for sub_dir in dirs:
imgs = os.listdir(os.path.join(root, sub_dir))
imgs = list(filter(lambda x: x.endswith('.jpg'), imgs))
random.shuffle(imgs)
img_count = len(imgs)
train_point = int(img_count * train_pct)
valid_point = int(img_count * (train_pct + valid_pct))
for i in range(img_count):
if i < train_point:
out_dir = os.path.join(train_dir, sub_dir)
elif i < valid_point:
out_dir = os.path.join(valid_dir, sub_dir)
else:
out_dir = os.path.join(test_dir, sub_dir)
makedir(out_dir)
target_path = os.path.join(out_dir, imgs[i])
src_path = os.path.join(dataset_dir, sub_dir, imgs[i])
shutil.copy(src_path, target_path)
os.walk
每一层遍历:
root保存的就是当前遍历的文件夹的绝对路径;
dirs保存当前文件夹下的所有子文件夹的名称(仅一层,孙子文件夹不包括)
files保存当前文件夹下的所有文件的名称
其次,发现它的遍历文件方式,在图的遍历方式中,那可不就是深度遍历嘛!!
- os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。
******shutil.copy()******Python中的方法用于将源文件的内容复制到目标文件或目录。它还会保留文件的权限模式,但不会保留文件的其他元数据(例如文件的创建和修改时间)。源必须代表文件,但目标可以是文件或目录。如果目标是目录,则文件将使用源中的基本文件名复制到目标中。另外,目的地必须是可写的。如果目标是文件并且已经存在,则将其替换为源文件,否则将创建一个新文件。
3 . 图像预处理-transforms
3.1 图像标准化
transforms.Normalize(mean,std,inplace)
逐通道的标准化,每个通道先求出平均值和标准差,然后标准化。Inplace表示是否原地操作。
3 .2 图像裁剪
train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
(1)transforms.CenterCrop(size)
从图片中心截取size大小的图片。
(2)transforms.RandomCrop(size,padding,padding_mode)
随机裁剪区域。
(3)transforms.RandomResizedCrop(size,scale,ratio)
随机大小,随机长宽比的裁剪。
3 .3 图像旋转
(1)transforms.RandomHorizationalFlip(p)
依据概率p水平翻转。
(2)transforms.RandomVerticalFlip(p)
依据概率p垂直翻转。
(3)transforms.RandomRotation(degrees,resample,expand)
transforms方法
Transforms Methods
一、裁剪
1. transforms.CenterCrop
2. transforms.RandomCrop
3. transforms.RandomResizedCrop
4. transforms.FiveCrop
5. transforms.TenCrop
二、翻转和旋转
1. transforms.RandomHorizontalFlip
2. transforms.RandomVerticalFlip
3. transforms.RandomRotation
三、图像变换
• 1. transforms.Pad
• 2. transforms.ColorJitter
• 3. transforms.Grayscale
• 4. transforms.RandomGrayscale
• 5. transforms.RandomAffine
• 6. transforms.LinearTransformation
• 7. transforms.RandomErasing
• 8. transforms.Lambda
• 9. transforms.Resize
• 10. transforms.Totensor
• 11. transforms.Normalize
四、transforms的操作
• 1. transforms.RandomChoice
• 2. transforms.RandomApply
• 3. transforms.RandomOrder
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
1 CenterCrop
transforms.CenterCrop(512), # 512
2 RandomCrop
transforms.RandomCrop(224, padding=16),
transforms.RandomCrop(224, padding=(16, 64)),
transforms.RandomCrop(224, padding=16, fill=(255, 0, 0)),
transforms.RandomCrop(512, pad_if_needed=True), # pad_if_needed=True
transforms.RandomCrop(224, padding=64, padding_mode='edge'),
transforms.RandomCrop(224, padding=64, padding_mode='reflect'),
transforms.RandomCrop(1024, padding=1024, padding_mode='symmetric'),
3 RandomResizedCrop
transforms.RandomResizedCrop(size=224, scale=(0.5, 0.5)),
4 FiveCrop
transforms.FiveCrop(112),
transforms.Lambda(lambda crops: torch.stack([(transforms.ToTensor()(crop)) for crop in crops])),
5 TenCrop
transforms.TenCrop(112, vertical_flip=False),
transforms.Lambda(lambda crops: torch.stack([(transforms.ToTensor()(crop)) for crop in crops])),
1 Horizontal Flip
transforms.RandomHorizontalFlip(p=1),
2 Vertical Flip
transforms.RandomVerticalFlip(p=0.5),
3 RandomRotation
transforms.RandomRotation(90),
transforms.RandomRotation((90), expand=True),
transforms.RandomRotation(30, center=(0, 0)),
transforms.RandomRotation(30, center=(0, 0), expand=True), # expand only for center rotation
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])