背景
进行生成器的构建,还有数据增强。并且封装在data.py函数里。
声明:整个数据和代码来自于b站,链接:使用pytorch框架手把手教你利用VGG16网络编写猫狗分类程序_哔哩哔哩_bilibili
我做了复现,并且记录了自己在做这个项目分类时候,一些所思所得。
构建生成器+数据增强
这段代码定义了一个自定义的数据生成器类`DataGenerator`,用于处理图像数据,特别适用于深度学习中的图像分类或物体检测任务。下面是这个脚本的主要功能和流程总结:
- **预处理函数定义**:
-
`preprocess_input(x)`: 将图像像素值归一化到[-1, 1]区间,这是许多深度学习模型的标准输入格式。
-
`cvtColor(image)`: 确保图像为RGB格式,如果输入是灰度或其他格式,则转换为RGB。
2. **`DataGenerator`类**:
-
**初始化 (`init`)**: 接受图像标注信息的列表、图像输入尺寸和是否进行随机数据增强的标志。
-
**数据长度 (`len`)**: 返回数据集的总样本数。
-
**获取样本 (`getitem`)**:
-
读取图像和标签;
-
应用数据增强(如果`random=True`),包括缩放、裁剪、翻转、旋转和色域扭曲;
-
对图像进行预处理(归一化并调整通道顺序);
- 返回处理后的图像数据和标签。 -
**辅助函数**:
-
`rand(a, b)`: 生成一个在[a, b]范围内的随机数。
-
`get_random_data(image, inpt_shape, jitter, hue, sat, val, random)`: 实现数据增强逻辑,包括调整图像大小、添加灰边、随机翻转、旋转以及HSV空间的颜色调整。
总结来说,这个`DataGenerator`类主要用于读取图片文件,并根据给定的参数执行一系列图像预处理和数据增强操作,以便于后续的深度学习模型训练。它能够生成经过标准化处理和增强的数据,提高模型对图像变化的鲁棒性,适合于训练图像识别、分类或检测模型。
【为什么获取样本放在 (__getitem__
)】
def __getitem__(self, index):
annotation_path=self.annotation_lines[index].split(';')[1].split()[0]
image=Image.open(annotation_path)
image=self.get_random_data(image,self.input_shape,random=self.random)
image=np.transpose(preprocess_input(np.array(image).astype(np.float32)),[2,0,1])
y=int(self.annotation_lines[index].split(';')[0])
return image,y
在Python的类定义中,`getitem`方法是一个特殊方法,当实例化的对象使用方括号`[]`操作符访问元素时,该方法会被自动调用。对于数据处理或者数据集类(如上述的`DataGenerator`)而言,实现`getitem`方法具有重要意义,原因如下:
-
**索引访问**: 在深度学习中,我们经常需要按索引从数据集中获取单个样本,用于训练或验证模型。通过重载`getitem`,可以像操作列表一样简单地从数据集中提取数据,例如`data = dataset[i]`。
-
**迭代兼容**: 实现了`getitem`的数据集类可以被用在for循环中,因为Python的迭代协议会尝试使用`getitem`方法来遍历对象。这样使得数据集能够方便地与很多现有的库和框架集成,比如PyTorch的训练循环中直接迭代数据加载器。
-
**兼容PyTorch等框架**: 深度学习框架(如PyTorch)的数据加载器(DataLoader)要求数据集类实现`getitem`和`len`方法,以便能够高效地按批次加载和处理数据。这允许框架根据需要动态地请求和处理数据,特别是在进行批量训练时,提高了内存效率和训练速度。
-
**灵活的数据处理**: 在`getitem`内部,可以根据索引执行复杂的逻辑来处理和返回特定的样本。这包括但不限于加载图像、应用数据增强、预处理等操作,为每个样本提供了定制化的处理流程。
因此,在上述代码中,将获取样本的逻辑放在`getitem`方法中,是为了遵循这些最佳实践和框架要求,确保数据集能够高效、灵活地服务于深度学习模型的训练和验证过程。
完整代码
python
import cv2
import numpy as np
import torch.utils.data as data
from PIL import Image
def preprocess_input(x):
x/=127.5
x-=1.
return x
def cvtColor(image):
if len(np.shape(image))==3 and np.shape(image)[-2]==3:
return image
else:
image=image.convert('RGB')
return image
class DataGenerator(data.Dataset):
def __init__(self,annotation_lines,inpt_shape,random=True):
self.annotation_lines=annotation_lines
self.input_shape=inpt_shape
self.random=random
def __len__(self):
return len(self.annotation_lines)
def __getitem__(self, index):
annotation_path=self.annotation_lines[index].split(';')[1].split()[0]
image=Image.open(annotation_path)
image=self.get_random_data(image,self.input_shape,random=self.random)
image=np.transpose(preprocess_input(np.array(image).astype(np.float32)),[2,0,1])
y=int(self.annotation_lines[index].split(';')[0])
return image,y
def rand(self,a=0,b=1):
return np.random.rand()*(b-a)+a
def get_random_data(self,image,inpt_shape,jitter=.3,hue=.1,sat=1.5,val=1.5,random=True):
image=cvtColor(image)
iw,ih=image.size
h,w=inpt_shape
if not random:
scale=min(w/iw,h/ih)
nw=int(iw*scale)
nh=int(ih*scale)
dx=(w-nw)//2
dy=(h-nh)//2
image=image.resize((nw,nh),Image.BICUBIC)
new_image=Image.new('RGB',(w,h),(128,128,128))
new_image.paste(image,(dx,dy))
image_data=np.array(new_image,np.float32)
return image_data
new_ar=w/h*self.rand(1-jitter,1+jitter)/self.rand(1-jitter,1+jitter)
scale=self.rand(.75,1.25)
if new_ar<1:
nh=int(scale*h)
nw=int(nh*new_ar)
else:
nw=int(scale*w)
nh=int(nw/new_ar)
image=image.resize((nw,nh),Image.BICUBIC)
#将图像多余的部分加上灰条
dx=int(self.rand(0,w-nw))
dy=int(self.rand(0,h-nh))
new_image=Image.new('RGB',(w,h),(128,128,128))
new_image.paste(image,(dx,dy))
image=new_image
#翻转图像
flip=self.rand()<.5
if flip: image=image.transpose(Image.FLIP_LEFT_RIGHT)
rotate=self.rand()<.5
if rotate:
angle=np.random.randint(-15,15)
a,b=w/2,h/2
M=cv2.getRotationMatrix2D((a,b),angle,1)
image=cv2.warpAffine(np.array(image),M,(w,h),borderValue=[128,128,128])
#色域扭曲
hue=self.rand(-hue,hue)
sat=self.rand(1,sat) if self.rand()<.5 else 1/self.rand(1,sat)
val=self.rand(1,val) if self.rand()<.5 else 1/self.rand(1,val)
x=cv2.cvtColor(np.array(image,np.float32)/255,cv2.COLOR_RGB2HSV)#颜色空间转换
x[...,1]*=sat
x[...,2]*=val
x[x[:,:,0]>360,0]=360
x[:,:,1:][x[:,:,1:]>1]=1
x[x<0]=0
image_data=cv2.cvtColor(x,cv2.COLOR_HSV2RGB)*255
return image_data