什么是Transforms
在PyTorch中,transforms是用于对数据进行预处理、增强和变换的操作集合。transforms通常用于数据载入和训练过程中,可以包括数据的归一化、裁剪、翻转、旋转、缩放等操作,以及将数据转换成PyTorch可以处理的Tensor格式。
Transforms的使用
首先导入包
py
from torchvision import transforms
实际上是导入了一个文件,文件名为transforms
,我们目前主要查看一下其中的ToTensor
类
我们可以查看一下其中的源码
我们可以看到这个类的主要作用是 将PIL Image
或者是numpy.ndarray
类型转化为tensor
类型
其中的__call__
方法,类似于c++重载()运算符
使用PIL image格式
具体的使用就可以通过
py
from PIL import Image
from torchvision import transforms
img_path = "dataset/hymenoptera_data/train/ants/0013035.jpg"
PIL_img = Image.open(img_path)
tensor_tans = transforms.ToTensor() # 通过transforms中的ToTensor类创建一个对象
img = tensor_tans(PIL_img) # __call__方法类似于c++中重载了()运算符,我们只需要传入PIL_img格式的图像就可以输出tensor格式的图像
print(img)
我们就可以成功地转化为Tensor
格式了
使用numpy.ndarray格式
首先我们要先通过pip
安装opencv-python
这个库,在终端
输入
pip install opencv-python
安装成功后导入包
py
import cv2
然后使用
py
cv_img = cv2.imread(img_path)
创建出来的图片格式就是numpy.ndarray
格式
为什么要使用Tensor数据类型
Tensor
数据类型包括了我们训练神经网络的一系列参数,以及训练神经网络所需要的格式,这是刚刚我们的格式所不具备的