Pytorch(一)

一.PyTorch环境配置及安装

1.1 工具安装

1.1.1 Anaconda下载

清华大学镜像站下载,版本为Anaconda3-5.2.0-Windows-x86_64(对应python3.6.5)

Index of /anaconda/archive/ | 清华大学开源软件镜像站 | Tsinghua Open Source Mirror

1.1.2 Pytorch安装

进入官网。选择合适版本下载Start Locally | PyTorch(不推荐),因为我下载了好几次失败了

推荐使用下面的方法:

  • torch下载:

    复制代码
    pip install torch -i https://pypi.tuna.tsinghua.edu.cn/simple
  • opencv下载:

    复制代码
    pip install opencv-python -i https://pypi.tuna.tsinghua.edu.cn/simple
  • torchvision

    复制代码
    pip install torchvision -i https://pypi.tuna.tsinghua.edu.cn/simple

二、DataSet

2.1 DataSet的作用

①获取每个数据及其label

②获取数据总数

2.2 认识DataSet

由上图可知dataset是一个抽象类,可以用来创造数据集,而抽象类不能实例化,需要构造抽象类的子类来创建数据集,所有的datasets继承这个类,并重写两个方法:(1)get_item 方法获取数据和label(2)len:获取数据总数

2.3 重写dataset

PIL中的Image

  • img=Image.open(image_path) 读取图像路径作为一个变量

  • img.show() 打开图片

os:

  • os.path.join(dir1,dir2):将两个路径合并在一起

  • os.listdir(dir):将目标路径dir中的所有文件路径生成一个列表

python 复制代码
from torch.utils.data import Dataset, ConcatDataset
import os
from PIL import Image
​
class MyDataset(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)
​
    def __len__(self):
        return len(self.img_path)
​
    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir  # 这里可能需要根据实际情况调整
        return img, label
​
# 定义根目录和标签目录
root_dir = "hymenoptera_data\\train"
ants_label_dir = "ants_img"
bees_label_dir = "bees_img"
​
# 创建数据集实例
ant_datasets = MyDataset(root_dir, ants_label_dir)
bee_datasets = MyDataset(root_dir, bees_label_dir)
​
# 使用 ConcatDataset 合并数据集
train_datasets = ConcatDataset([ant_datasets, bee_datasets])
​
# 获取第一个样本
img, label = train_datasets[0]
print(label)

2.4 将数据集格式转化为txt存放label格式

python 复制代码
import os
# 设置根目录和图像目录
root_dir = r"D:\桌面\Python\1024\hymenoptera_data\train"
ants_dir = "ants_img"
ant_image_path = os.path.join(root_dir, ants_dir)
ant_path_list = os.listdir(ant_image_path)
# 获取标签
ants_label = ants_dir.split("_")[0]
# 输出目录
out_dir = "ants_label"
# 遍历图像目录中的所有文件
for i in ant_path_list:
    # 创建输出文件路径
    file_name=i.split('.jpg')[0]
    print(file_name)
    output_file_path = os.path.join(root_dir, out_dir, "{}.txt".format(file_name))
    print(output_file_path)
    # 写入标签到文本文件
    with open(output_file_path, 'w') as f:
        f.write(ants_label)

三、TensorBoard

3.1 安装TensorBoard

pip install tensorboard

3.2 SummaryWriter

由上图可知:此类是将事件写进log_dir文件夹中,被TensorBoard解析,用于观察损失函数的变化

复制代码
from torch.utils.tensorboard import SummaryWriter
writer=SummaryWriter("logs")

3.2.1 add_scalar()

def add_scalar(self,tag,scalar_value,global_step=None,walltime=None,new_style=False, double_precision=False,): """Add scalar data to summary. Args: tag (string): Data identifier 标题 scalar_value (float or string/blobname): Value to save 添加的值,y轴 global_step (int): Global step value to record 训练步数,x轴 walltime (float): Optional override default walltime (time.time()) with seconds after epoch of event new_style (boolean): Whether to use new style (tensor field) or old style (simple_value field). New style could lead to faster data loading. Examples::

复制代码
     from torch.utils.tensorboard import SummaryWriter
     writer = SummaryWriter()
     x = range(100)
     for i in x:
         writer.add_scalar('y=2x', i * 2, i)
     writer.close()
 """

在终端输入:tensorboard --logdir=logs

可通过参数 --port=xx 指定端口,以防止端口冲突(多用户使用同一服务器进行训练时):tensorboard --logdir=logs --port=6007

注:

  • 这个命令会打开 logdir文件夹下的所有事件文件。

  • 图表以 tag 作为区分,可向一个 tag多次写入数据,会自动进行拟合。

举个栗子

绘制y=x^2图像,并将图像存至log文件中

python 复制代码
from torch.utils.tensorboard import SummaryWriter
writer=SummaryWriter("log")
for i in range(100):
    writer.add_scalar("y=x^2",i*i,i)
writer.close()

3.2.2 add_image()

def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'):

"""Add image data to summary.

Note that this requires the pillow package.

Args:

tag (string): Data identifier 标题

img_tensor (torch.Tensor, numpy.array, or string/blobname): Image data 要求图像的数据类型为torch.Tensor,numpy.array,string/blobname

global_step (int): Global step value to record 训练步数

walltime (float): Optional override default walltime (time.time())

seconds after epoch of event

Shape:

img_tensor: Default is :math:(3, H, W). You can use torchvision.utils.make_grid() to

convert a batch of tensor into 3xHxW format or call add_images and let us do the job.

Tensor with :math:(1, H, W), :math:(H, W), :math:(H, W, 3) is also suitable as long as

corresponding dataformats argument is passed, e.g. CHW, HWC, HW.dataformats = 'HWC':图片的格式High高度,Weight宽度,channel通道数;默认通道数在前

"""

举个栗子
python 复制代码
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import numpy as np
​
writer=SummaryWriter("log")
img_path=r"hymenoptera_data/train/ants_img/0013035.jpg"
img =Image.open(img_path)
print(type(img))
#<class 'PIL.JpegImagePlugin.JpegImageFile'>
img_array=np.array(img)
print(type(img_array))
#<class 'numpy.ndarray'>
writer.add_image("test", img_array, 1, dataformats = 'HWC')
writer.close()

PILnumpy,需要在add_image()中指定shape中每一个数字/维表示的含义

3.2.3 close()

把缓存中保存的数据写到目标events文件中,一旦训练中断没有close,则你的保存目录中不会有相应的数据。

四、transforms

torchvision中的transforms主要是对图片进行一些变换。 tranforms对应 tranforms.py 文件,里面定义了很多类,输入一个图片对象,返回经过处理的图片对象。

transforms.py就像一个工具箱,里面定义的各种类就像各种工具,图片就是输入对象,经过工具处理,输出期望的图片结果。

4.1 transforms.ToTensor

4.1.1 如何用

ToTensor功能是将PIL Image类型或者numpy ndarry类型的图片对象转换为tensor类型。

python 复制代码
from torchvision import transforms
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
img_path="hymenoptera_data/train/ants_img/0013035.jpg"
img=Image.open(img_path)
trans_tensor=transforms.ToTensor()
img_tensor=trans_tensor(img)
writer=SummaryWriter("logs")
writer.add_image("test3",img_tensor)
writer.close()

使用transforms的方法就是 先实例化选中的类,然后用实例化的对象去处理图片就行。

4.1.2 为什么需要Tensor类型

tensor 数据类型可以理解为包装了反向神经网络一些理论基础参数。在神经网络中,要将数据先转换为Tensor类型,再进行训练。

4.2 常用的Transforms API函数

4.2.1 常用的输入图片数据类型

  • PIL:Image.open()

  • tensor:transforms.ToTensor()

  • ndarrays:cv.Imcread()

4.2.2 常用的Transform

  • ToTensor():将图片对象类型转换为Tensor

  • Normalize():对图像像素进行归一化计算

  • Resize():重新设置PIL Image的大小,返回的也是PIL Image格式

  • Compose():输入为transforms类型参数列表

python 复制代码
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import os
​
root_path = "hymenoptera_data/train/ants"
img_name = "7759525_1363d24e88.jpg"
img_path = os.path.join(root_path,img_name)
img = Image.open(img_path)
​
writer = SummaryWriter("logs")
​
# ToTensor
trans_totensor = transforms.ToTensor() # instantiation
img_tensor = trans_totensor(img)
writer.add_image("Tensor", img_tensor)
​
# Normalize
print(img_tensor[0][0][0])
trans_norm = transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
img_norm = trans_norm(img_tensor)
print(img_norm[0][0][0])
writer.add_image("Normalize", img_norm)
​
#Resize
print(img.size)
trans_resize = transforms.Resize((512,512))
img_resize = trans_resize(img) # return type still is PIL image
img_resize = trans_totensor(img_resize)
writer.add_image("Resize", img_resize)
​
# Compose - resize -2
trans_resize_2 = transforms.Resize(512)
tran_compose = transforms.Compose([trans_resize_2, trans_totensor])
img_resize2 = tran_compose(img)
writer.add_image("Compose", img_resize2)
​
writer.close()
复制代码
相关推荐
靴子学长3 小时前
基于字节大模型的论文翻译(含免费源码)
人工智能·深度学习·nlp
海棠AI实验室4 小时前
AI的进阶之路:从机器学习到深度学习的演变(一)
人工智能·深度学习·机器学习
苏言の狗6 小时前
Pytorch中关于Tensor的操作
人工智能·pytorch·python·深度学习·机器学习
paixiaoxin9 小时前
CV-OCR经典论文解读|An Empirical Study of Scaling Law for OCR/OCR 缩放定律的实证研究
人工智能·深度学习·机器学习·生成对抗网络·计算机视觉·ocr·.net
weixin_515202499 小时前
第R3周:RNN-心脏病预测
人工智能·rnn·深度学习
吕小明么11 小时前
OpenAI o3 “震撼” 发布后回归技术本身的审视与进一步思考
人工智能·深度学习·算法·aigc·agi
CSBLOG12 小时前
深度学习试题及答案解析(一)
人工智能·深度学习
小陈phd12 小时前
深度学习之超分辨率算法——SRCNN
python·深度学习·tensorflow·卷积
王国强200913 小时前
动手学人工智能-深度学习计算5-文件读写操作
深度学习
威化饼的一隅14 小时前
【多模态】swift-3框架使用
人工智能·深度学习·大模型·swift·多模态