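I. PyTorch Environment Setup and Installation
1.1 Installing the Tools
1.1.1 Downloading Anaconda
Download Anaconda from the Tsinghua University mirror. The version used here is Anaconda3-5.2.0-Windows-x86_64, which ships Python 3.6.5. Note that recent PyTorch releases no longer support Python 3.6, so with this Anaconda version pip can only install an older torch build. A newer Anaconda is the safer choice today.
Index of /anaconda/archive/ | Tsinghua Open Source Mirror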
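1.1.2 Installing PyTorch
The official site (Start Locally | PyTorch) lets you pick a matching build. I don't recommend it, though: my downloads from there failed several times.
Instead, I recommend installing the packages with pip from the Tsinghua PyPI mirror: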
- torch:
  pip install torch -i https://pypi.tuna.tsinghua.edu.cn/simple
- opencv:
  pip install opencv-python -i https://pypi.tuna.tsinghua.edu.cn/simple
- torchvision:
  pip install torchvision -i https://pypi.tuna.tsinghua.edu.cn/simple
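After installing, a quick sanity check confirms that the packages actually import. The sketch below uses a hypothetical helper, check_packages, which is not part of any library. It only relies on the standard-library importlib. Note that cv2 is the import name of the opencv-python package.

```python
import importlib


def check_packages(names):
    """Try to import each package; return {name: version string, or None if missing}."""
    report = {}
    for name in names:
        try:
            module = importlib.import_module(name)
        except ImportError:
            report[name] = None  # not installed in this environment
        else:
            report[name] = getattr(module, "__version__", "unknown")
    return report


if __name__ == "__main__":
    report = check_packages(["torch", "torchvision", "cv2"])
    for name, version in report.items():
        print(f"{name}: {version if version else 'NOT INSTALLED'}")
    if report["torch"]:
        import torch
        print("CUDA available:", torch.cuda.is_available())
```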
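II. Dataset
2.1 What Dataset does
① Fetch each sample together with its label
② Report the total number of samples
2.2 Getting to know Dataset
As the figure above shows, Dataset is meant to be used as an abstract base class for building datasets. On its own it does not know how to load any data, so you create a dataset by writing a subclass. All map-style datasets subclass Dataset and override two methods:
(1) __getitem__, which returns the sample and label for a given index.
(2) __len__, which returns the total number of samples. The base class does not strictly require it, but DataLoader's default sampler relies on it.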
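To make the two methods concrete, here is a minimal sketch of a map-style dataset over in-memory Python lists. The class name PairListDataset and its data are made up for illustration. Indexing `ds[i]` calls `__getitem__` and `len(ds)` calls `__len__`, which is what DataLoader uses to build batches.

```python
from torch.utils.data import Dataset, DataLoader


class PairListDataset(Dataset):
    """Hypothetical toy dataset: sample i is (samples[i], labels[i])."""

    def __init__(self, samples, labels):
        assert len(samples) == len(labels), "need one label per sample"
        self.samples = samples
        self.labels = labels

    def __getitem__(self, idx):
        # ① return one sample and its label
        return self.samples[idx], self.labels[idx]

    def __len__(self):
        # ② return the total number of samples
        return len(self.samples)


if __name__ == "__main__":
    ds = PairListDataset([0.1, 0.2, 0.3, 0.4], [0, 1, 0, 1])
    print(len(ds))  # 4
    print(ds[1])    # (0.2, 1)
    # DataLoader calls __len__ and __getitem__ and stacks samples into batches
    for x, y in DataLoader(ds, batch_size=2):
        print(x, y)
```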
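2.3 Writing a custom Dataset
Image from PIL:
- img = Image.open(image_path): opens the image at image_path and returns it as an Image object
- img.show(): displays the image in the system's default image viewer
os:
- os.path.join(dir1, dir2): joins two path components into a single path
- os.listdir(dir): returns a list of the names (file names only, not full paths) of all entries in the directory dir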
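MyDataset below indexes into the list returned by os.listdir, so two details matter:
- The list holds bare file names, which must be joined back onto the folder with os.path.join.
- Its order is arbitrary, so sorting it makes indices reproducible.

Here is a small standard-library sketch on a temporary folder. The helper list_image_paths and the file names are hypothetical.

```python
import os
import tempfile


def list_image_paths(folder):
    """Return the full paths of the entries in folder, sorted by name."""
    names = sorted(os.listdir(folder))  # bare names such as 'a.jpg'
    return [os.path.join(folder, n) for n in names]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as folder:
        for name in ("b.jpg", "a.jpg"):
            open(os.path.join(folder, name), "w").close()  # empty placeholder files
        print(os.listdir(folder))  # names only; order not guaranteed
        for path in list_image_paths(folder):
            print(os.path.basename(path))  # a.jpg, then b.jpg
```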
```python
from torch.utils.data import Dataset, ConcatDataset
import os
from PIL import Image


class MyDataset(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        # File names of all images in root_dir/label_dir
        self.img_path = os.listdir(self.path)

    def __len__(self):
        return len(self.img_path)

    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir  # the folder name is used as the label; adjust as needed
        return img, label


# Root directory and per-class label directories
root_dir = "hymenoptera_data/train"
ants_label_dir = "ants_img"
bees_label_dir = "bees_img"
# Create one dataset instance per class
ant_datasets = MyDataset(root_dir, ants_label_dir)
bee_datasets = MyDataset(root_dir, bees_label_dir)
# Merge the datasets with ConcatDataset (ants first, then bees)
train_datasets = ConcatDataset([ant_datasets, bee_datasets])
# Fetch the first sample
img, label = train_datasets[0]
print(label)
```
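MyDataset returns the folder name (for example "ants_img") as the label and takes images in os.listdir order. Below is a sketch of a variant that does three things differently:
- maps folder names to integer class indices,
- sorts the file names,
- forces every image to RGB.

The class FolderDataset, the CLASS_TO_IDX mapping and the generated 8x8 demo images are all hypothetical. They exist so the example runs without the hymenoptera data. The demo also shows how ConcatDataset indexes its members: first all ant samples, then all bee samples.

```python
import os
import tempfile
from PIL import Image
from torch.utils.data import Dataset, ConcatDataset

# Hypothetical mapping from folder name to class index
CLASS_TO_IDX = {"ants_img": 0, "bees_img": 1}


class FolderDataset(Dataset):
    """Variant of MyDataset: sorted file names, RGB images, integer labels."""

    def __init__(self, root_dir, label_dir):
        self.path = os.path.join(root_dir, label_dir)
        # os.listdir order is arbitrary; sorting makes indices reproducible
        self.img_names = sorted(os.listdir(self.path))
        self.label = CLASS_TO_IDX[label_dir]

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_path = os.path.join(self.path, self.img_names[idx])
        # convert("RGB") loads the pixels and guarantees 3 channels;
        # the with-block closes the file afterwards
        with Image.open(img_path) as im:
            img = im.convert("RGB")
        return img, self.label


def build_demo_tree(root):
    """Create a tiny fake ants_img/bees_img tree of 8x8 images."""
    for label_dir, n in (("ants_img", 2), ("bees_img", 3)):
        os.makedirs(os.path.join(root, label_dir))
        for i in range(n):
            Image.new("RGB", (8, 8), (i * 40, 0, 0)).save(
                os.path.join(root, label_dir, f"{i}.jpg"))


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        build_demo_tree(root)
        train = ConcatDataset([FolderDataset(root, "ants_img"),
                               FolderDataset(root, "bees_img")])
        print(len(train))       # 5 = 2 ants + 3 bees
        img, label = train[2]   # indices 0-1 are ants, so 2 is the first bee
        print(img.size, label)  # (8, 8) 1
```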
2.4 Converting the dataset to a layout that stores labels in .txt files
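```python
import os

# Root directory and image directory
root_dir = r"D:\桌面\Python\1024\hymenoptera_data\train"
ants_dir = "ants_img"
ant_image_path = os.path.join(root_dir, ants_dir)
ant_path_list = os.listdir(ant_image_path)
# The label is the folder-name prefix: "ants_img" -> "ants"
ants_label = ants_dir.split("_")[0]
# Output directory for the label files
out_dir = "ants_label"
# open(..., 'w') does not create folders, so make sure the output directory exists
os.makedirs(os.path.join(root_dir, out_dir), exist_ok=True)
# Write one .txt file per image
for i in ant_path_list:
    # Output file name: image name without its extension, plus .txt
    file_name = os.path.splitext(i)[0]
    print(file_name)
    output_file_path = os.path.join(root_dir, out_dir, "{}.txt".format(file_name))
    print(output_file_path)
    # Write the label into the text file
    with open(output_file_path, 'w') as f:
        f.write(ants_label)
```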
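The same steps can be wrapped in a reusable function that works for both the ants and the bees folders. Some details here are assumptions, not from the original:
- write_label_files is a hypothetical name.
- The image-extension filter was added so that stray files such as Thumbs.db are skipped.
- The demo folder names are made up.

```python
import os
import tempfile


def write_label_files(root_dir, img_dir, out_dir):
    """For each image in root_dir/img_dir, write its label to root_dir/out_dir/<stem>.txt.

    The label is the part of img_dir before the first '_' ('bees_img' -> 'bees').
    Returns the list of .txt paths written.
    """
    label = img_dir.split("_")[0]
    out_path = os.path.join(root_dir, out_dir)
    os.makedirs(out_path, exist_ok=True)
    written = []
    for name in sorted(os.listdir(os.path.join(root_dir, img_dir))):
        stem, ext = os.path.splitext(name)
        if ext.lower() not in (".jpg", ".jpeg", ".png"):
            continue  # assumption: only these extensions count as images
        txt_path = os.path.join(out_path, stem + ".txt")
        with open(txt_path, "w", encoding="utf-8") as f:
            f.write(label)
        written.append(txt_path)
    return written


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        os.makedirs(os.path.join(root, "bees_img"))
        for name in ("0001.jpg", "notes.txt"):
            open(os.path.join(root, "bees_img", name), "w").close()
        paths = write_label_files(root, "bees_img", "bees_label")
        print([os.path.basename(p) for p in paths])  # ['0001.txt']
```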
III. TensorBoard
3.1 Installing TensorBoard
pip install tensorboard
3.2 SummaryWriter
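As the figure above shows, SummaryWriter writes events to files in the log_dir folder, where TensorBoard parses them. A typical use is watching how the loss changes during training.
```python
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")  # event files are written to the ./logs folder
```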
3.2.1 add_scalar()
```python
def add_scalar(self, tag, scalar_value, global_step=None, walltime=None,
               new_style=False, double_precision=False):
    """Add scalar data to summary.

    Args:
        tag (string): Data identifier  [chart title]
        scalar_value (float or string/blobname): Value to save  [value added; y-axis]
        global_step (int): Global step value to record  [training step; x-axis]
        walltime (float): Optional override default walltime (time.time())
          with seconds after epoch of event
        new_style (boolean): Whether to use new style (tensor field) or old
          style (simple_value field). New style could lead to faster data loading.

    Examples::

        from torch.utils.tensorboard import SummaryWriter
        writer = SummaryWriter()
        x = range(100)
        for i in x:
            writer.add_scalar('y=2x', i * 2, i)
        writer.close()
    """
```
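Then start TensorBoard from a terminal with `tensorboard --logdir=logs`. The directory must match the one passed to SummaryWriter. Open the URL it prints, which is http://localhost:6006 by default.
To avoid port conflicts, for example when several users train on the same server, choose another port with --port: `tensorboard --logdir=logs --port=6007`
Notes:
- The command reads all event files under the logdir folder. Subfolders that contain event files are shown as separate runs.
- Charts are distinguished by tag, and you can write to the same tag many times; all points end up on the same chart. TensorBoard does not fit anything; it only draws the recorded points, and the Smoothing slider only smooths the displayed curve. If a script is rerun into the same folder, old and new points are mixed on one chart. To avoid this, delete the old event files or give each run its own subfolder.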
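For example:
Plot y = x^2 and save the data to the log folder. Since this writes to log, view it with `tensorboard --logdir=log`:
```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("log")
for i in range(100):
    writer.add_scalar("y=x^2", i * i, i)  # tag, y value, x value (step)
writer.close()
```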
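When a script is rerun into the same log folder, its points land on the same tag as the previous run's. A common way to keep runs apart is one subfolder per run, because TensorBoard lists each folder of event files as a separate, individually toggleable run. The sketch below is only an illustration:
- The exponential "loss" is made up and does not come from real training.
- The helper log_fake_loss and the run names are hypothetical.

```python
import math
import os
import tempfile
from torch.utils.tensorboard import SummaryWriter


def log_fake_loss(log_root, run_name, lr, steps=50):
    """Write a synthetic decaying 'loss' curve into log_root/run_name."""
    run_dir = os.path.join(log_root, run_name)
    writer = SummaryWriter(run_dir)
    for step in range(steps):
        loss = math.exp(-lr * step)  # made-up curve, not a real training loss
        writer.add_scalar("train/loss", loss, step)
    writer.close()  # flush the buffered events to disk
    return run_dir


if __name__ == "__main__":
    root = tempfile.mkdtemp()
    for lr in (0.05, 0.1):
        log_fake_loss(root, f"lr_{lr}", lr)  # one subfolder per run
    print("now run: tensorboard --logdir", root)
```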
3.2.2 add_image()
```python
def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'):
    """Add image data to summary.

    Note that this requires the ``pillow`` package.

    Args:
        tag (string): Data identifier  [chart title]
        img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data
            [the image must be a torch.Tensor, numpy.array or string/blobname]
        global_step (int): Global step value to record  [training step]
        walltime (float): Optional override default walltime (time.time())
            seconds after epoch of event

    Shape:
        img_tensor: Default is :math:`(3, H, W)`. You can use ``torchvision.utils.make_grid()`` to
        convert a batch of tensor into 3xHxW format or call ``add_images`` and let us do the job.
        Tensor with :math:`(1, H, W)`, :math:`(H, W)`, :math:`(H, W, 3)` is also suitable as long as
        corresponding ``dataformats`` argument is passed, e.g. ``CHW``, ``HWC``, ``HW``.
    """
```
dataformats = 'HWC' describes the image layout: H is the height, W the width and C the number of channels. By default the channel dimension comes first ('CHW').
For example:
```python
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import numpy as np

writer = SummaryWriter("log")
img_path = r"hymenoptera_data/train/ants_img/0013035.jpg"
img = Image.open(img_path)
print(type(img))
# <class 'PIL.JpegImagePlugin.JpegImageFile'>
img_array = np.array(img)
print(type(img_array))
# <class 'numpy.ndarray'>
print(img_array.shape)  # (H, W, 3): channels come last
writer.add_image("test", img_array, 1, dataformats='HWC')
writer.close()
```
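Converting a PIL image with np.array() gives an array of shape (H, W, C), while add_image() assumes (C, H, W) by default. You therefore have to tell add_image() what each dimension of the shape means by passing dataformats='HWC'.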
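The alternative to passing dataformats is to reorder the axes yourself. The sketch below shows three ways to log the same picture:
- The HWC array with dataformats='HWC'.
- A transposed CHW copy with the default layout.
- A small batch with add_images, whose default layout is 'NCHW'.

The solid red 40x30 image and the helper pil_to_arrays are hypothetical stand-ins for a real photo.

```python
import tempfile
import numpy as np
from PIL import Image
from torch.utils.tensorboard import SummaryWriter


def pil_to_arrays(img):
    """Return the image as an HWC array (numpy/PIL layout) and as a CHW array."""
    hwc = np.array(img.convert("RGB"))  # shape (H, W, 3), dtype uint8
    chw = np.transpose(hwc, (2, 0, 1))  # shape (3, H, W), add_image's default
    return hwc, chw


if __name__ == "__main__":
    img = Image.new("RGB", (40, 30), (255, 0, 0))  # PIL size is (W, H)
    hwc, chw = pil_to_arrays(img)
    print(hwc.shape, chw.shape)  # (30, 40, 3) (3, 30, 40)

    writer = SummaryWriter(tempfile.mkdtemp())
    writer.add_image("hwc", hwc, 0, dataformats="HWC")  # layout stated explicitly
    writer.add_image("chw", chw, 0)                     # default dataformats='CHW'
    batch = np.stack([chw, chw])                        # shape (2, 3, 30, 40)
    writer.add_images("batch", batch, 0)                # default dataformats='NCHW'
    writer.close()
```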
3.2.3 close()
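close() writes the events still buffered in memory to the events file in the target folder. If training stops before close() is called, events that have not yet been flushed are lost, so the folder may hold incomplete data or none at all. The writer only flushes periodically: every 120 seconds by default, or when its queue of pending events fills. flush() writes pending events without closing the writer.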
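SummaryWriter can also be used as a context manager, which calls close() when the with block exits, even if an exception is raised. The sketch below writes three points and then reads them back to confirm they reached disk. It uses EventAccumulator from the tensorboard package's event_processing module, a fairly low-level reader that TensorBoard itself builds on. write_and_read_back is a hypothetical helper.

```python
import tempfile
from torch.utils.tensorboard import SummaryWriter
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator


def write_and_read_back(log_dir, n=3):
    # Leaving the with-block calls writer.close(), so buffered events are flushed
    with SummaryWriter(log_dir) as writer:
        for i in range(n):
            writer.add_scalar("y=2x", i * 2, i)
    # Parse the events file, as TensorBoard would
    acc = EventAccumulator(log_dir)
    acc.Reload()
    return [(e.step, e.value) for e in acc.Scalars("y=2x")]


if __name__ == "__main__":
    print(write_and_read_back(tempfile.mkdtemp()))
```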
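IV. transforms
transforms in torchvision is used to apply transformations to images. It corresponds to the transforms.py file, which defines many classes. Each class takes an image object as input and returns the processed image object.
transforms.py works like a toolbox. The classes defined in it are the tools and the image is the input; after passing through a tool, you get the image you want.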
4.1 transforms.ToTensor
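4.1.1 How to use it
ToTensor converts an image given as a PIL Image or a numpy ndarray into a tensor. For a uint8 image of shape (H, W, C) with values 0 to 255, the result is a float tensor of shape (C, H, W) with values in [0.0, 1.0].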
```python
from torchvision import transforms
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

img_path = "hymenoptera_data/train/ants_img/0013035.jpg"
img = Image.open(img_path)
trans_tensor = transforms.ToTensor()  # 1. instantiate the transform
img_tensor = trans_tensor(img)        # 2. apply it to the image
writer = SummaryWriter("logs")
writer.add_image("test3", img_tensor)  # already CHW, the default layout
writer.close()
```
Using a transform takes two steps: instantiate the class you need, then call the instance on the image.
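4.1.2 Why the Tensor type is needed
A tensor can be thought of as an array that also carries the bookkeeping that backpropagation needs, such as requires_grad, grad and grad_fn, and that can live on a GPU. Data must therefore be converted to tensors before a neural network can be trained on it.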
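Here is a short sketch of what that bookkeeping provides. A tensor created with requires_grad=True records the operations applied to it, and backward() stores the gradient in its .grad attribute. The function y = sum(x^2) is just an illustrative choice.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = (x ** 2).sum()  # y = x1^2 + x2^2 + x3^2, recorded by autograd
print(y.grad_fn)    # the node that produced y (a SumBackward0 object)
y.backward()        # compute dy/dx and store it in x.grad
print(x.grad)       # tensor([2., 4., 6.]), i.e. 2 * x
```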
4.2 Common transforms APIs
4.2.1 Common input image types
- PIL: Image.open()
- tensor: transforms.ToTensor()
- ndarray: cv2.imread() (OpenCV; the channels come back in BGR order, not RGB)
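4.2.2 Common transforms
- ToTensor(): converts an image object to a Tensor
- Normalize(mean, std): normalizes a tensor image channel by channel, output = (input - mean) / std
- Resize(size): resizes a PIL Image and returns the same type (recent torchvision versions also accept tensors). An int size scales the shorter side to that value and keeps the aspect ratio; an (h, w) tuple sets both dimensions.
- Compose(): takes a list of transforms and applies them in order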
```python
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import os

root_path = "hymenoptera_data/train/ants_img"
img_name = "7759525_1363d24e88.jpg"
img_path = os.path.join(root_path, img_name)
img = Image.open(img_path).convert("RGB")  # ensure exactly 3 channels for Normalize
writer = SummaryWriter("logs")

# ToTensor
trans_totensor = transforms.ToTensor()  # instantiation
img_tensor = trans_totensor(img)
writer.add_image("Tensor", img_tensor)

# Normalize: (input - mean) / std per channel, so [0, 1] becomes [-1, 1]
print(img_tensor[0][0][0])
trans_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
img_norm = trans_norm(img_tensor)
print(img_norm[0][0][0])
# add_image expects floats in [0, 1]; negative values are clipped to black
writer.add_image("Normalize", img_norm)

# Resize to exactly 512x512 (aspect ratio is not kept)
print(img.size)
trans_resize = transforms.Resize((512, 512))
img_resize = trans_resize(img)  # return type is still a PIL image
img_resize = trans_totensor(img_resize)
writer.add_image("Resize", img_resize)

# Compose: Resize(512) scales the shorter side to 512 (aspect ratio kept),
# then ToTensor converts the result
trans_resize_2 = transforms.Resize(512)
tran_compose = transforms.Compose([trans_resize_2, trans_totensor])
img_resize2 = tran_compose(img)
writer.add_image("Compose", img_resize2)

writer.close()
```