PyTorch深度学习——数据输入和预处理

pytorch数据载入

数据载入

在使用pytorch构建和训练模型的过程中,需要经常把原始数据(比如图片、音频)转化为张量的格式,为了方便地批量处理图片数据,pytorch引入了一系列工具来对这个过程进行包装

torch.utils.data.DataLoader

pytorch提供的一个用于数据加载的工具类,用于批量加载数据并为模型提供输入。它可以将数据集包装成一个可迭代的对象,方便地进行数据加载和批处理操作

Pytorch torch.utils.data.DataLoader 用法详细介绍-CSDN博客

python 复制代码
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

参数说明

  • dataset :要从中加载数据的数据集(一个**torch.utils.data.DataLoader的实例**)
  • batch_size:每批次要装载多少样品(迷你批次的大小)
  • shuffle :设置为True以使数据在每个时期都重新洗牌
  • sampler :定义从数据集中抽取样本的策略
  • batch_sampler:类似于采样器sampler,但一次返回一个迷你批次的索引,sampler只返回一个下标索引, 与batch_size,shuffle,sampler和drop_last互斥
  • num_workers :多少个子流程用于数据加载。 0表示将在主进程中加载数据 (默认值:0)
  • collate_fn :把一批 dataset 的实例转化为包含迷你批次数据的张量
  • pin_memory :如果为True,则数据加载器在将张量返回之前将其复制到CUDA固定的内存中。 如果您的数据元素是自定义类型,或者您的collate_fn返回的是一个自定义类型的批处理
  • drop_last :决定是否将最后一个迷你批次的数据丢掉
  • timeout:如果为正,则为从工作人员收集批次的超时值。 应始终为非负数。 (默认值:0)
  • worker_init_fn:如果非None,这个函数将在每个工作子进程上被调用,并接收工作进程ID(一个在[0, num_workers - 1]范围内的整数)作为输入,它在设置随机种子之后、但在数据加载之前被调用。(默认:None)
  • prefetch_factor:每个子流程预先加载的样本数。 2表示将在所有子流程中预取总共2 * num_workers个样本。 (默认值:2)
  • persistent_workers :如果为True,则一次使用数据集后,数据加载器将不会关闭工作进程。 这样可以使Worker Dataset实例保持活动状态。 (默认值:False)

映射类型的数据集

为了能够使用 DataLoader 类,首先需要构造关于单个数据的 torch.ulits.data.Dataset 类,这个类有两种:一种是映射类型(Map-Style),对于这个类型,每个数据都有一个对应的索引,通过输入具体的索引,就能得到对应的数据

python 复制代码
import torch.utils.data as data

class MyDataset(data.Dataset):
    def __init__(self, data_list):
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        return self.data_list[index]

一般来说,对于这个类,主要需要重写两个方法:一个是 getitem ,该方法是python内置的操作符方法,对应的操作符是索引操作符 [],通过输入整数数据索引,其大小在0至N-1之间(N微数据的总数目),返回具体的某一条数据记录,这就是该方法需要完成的任务,而具体的逻辑需要根据数据集的类型来决定,另一个方法是 len ,该方法返回数据的总数

在python,如果一个Dataset类重写了该方法,可以通过使用 len 内置函数来获取数据的数目

torchvision工具包的使用

PyTorch:Torchvision的简单介绍与使用-CSDN博客

可迭代类型的数据集

python 复制代码
from torch.utils.data import IterableDataset


class MyIterableDataset(IterableDataset):

    def __init__(self, file_path):
        self.file_path = file_path

    def __iter__(self):
        with open(self.file_path, 'r') as file_obj:
            for line in file_obj:
                line_data = line.strip('\n').split(',')
                yield line_data


if __name__ == '__main__':
    dataset = MyIterableDataset('test_csv.csv')
    for data in dataset:
        print(data)

pytorch模型的保存和加载

序列化和反序列化

由于pytorch的模块和张量的本质是 torch.nn.Module 和 torch.tensor 类的实例,而pytorch自带了一系列的方法,可以将这些类的实例转化为字符串,所以这些势力可以通过python序列化方法进行序列化(serialization)和反序列化(unserialization)

pytorch的实现里集成了python自带的pickle包对模块和张量进行序列化,张量序列化的本质是把张量的信息,包括数据类型和存储位置,以及携带的数据等转化为字符串,而这些字符串时候可以通过使用python自带的文件IO函数进行存储,这个过程是可逆的,即可以通过文件IO函数来读取存储的字符串,然后将字符串逆向解析成pytorch的模块和张量

python 复制代码
torch.save(obj,f,pickle_module=pickle,pickle_protocol=2)
torch.load(f,map_location=None,pickle_module=pickle,**pickle_load_args)

torch.save 函数传入的第一个参数是pytorch中可以被序列化的对象,包括模型和张量等,第二个参数是存储文件的路径,序列化的结果将会被保留在这个路径里面,第三个参数是默认的,传入的是序列化的库,可以使用pytorch默认的序列化库pickle,第四个参数是pickle协议,即如何把对象转化为字符串的规范,上述使用的协议版本是2

与 torch.save 函数对应的是 torch.load 函数,该函数在给定序列化后的文件路径之后,就能输出 pytorch 的对象,第一个参数是文件路径之后,第二个参数是张量存储位置的映射,如果存储时的模型在CPU上,可以直接使用默认参数,但当存储的模型在GPU上,torch.load 的默认行为是先把模型载入CPU中,然后转移到保存时的GPU上,加入载入模型的时候是在另外一台计算机上,而计算机没有GPU或GPU的型号对不上就会报错

此时可以使用 map_loactin 函数,设置 map_loactin = 'CPU',这样就会把模型保留在CPU里面,不再移动到GPU中,pickle_module 参数和 torch.save 里的同名参数的作用一致

在pytorch中,模型的保存方法有两种,第一种是直接保存模型的实例(因为模型本身可以被序列化),第二种是保存模型的状态字典(State Dict),一个模型的状态字典包含模型所有参数的名字以及名字对应的张量,通过调用 state_dict 方法,就可获取当前模型的状态字典

状态字典的保存和载入

由于pytorch模块的实现依赖具体的pytorch版本,所以会存在一种情况:使用某一个版本保存的序列化文件无法被另一个版本的pytorch载入,相比之下,pytorch的张量变动较小,二状态字典只含有张量参数的名字和张量参数的具体信息,预模块的实现关联较小,因此更加推荐使用 state_dict 方法来获取状态字典,然后保存该张量字典来保存模型,这样可以实现最大限度地减小代码对pytorch版本的依赖性

另外在训练的时候,不仅要保存模型的相关信息,还要保存优化器的相关信息,因为可能需要从存储的检查点出发,继续进行训练,pytorch中参数:当前的学习率,当前梯度的指数移动平均等,通过调用优化器的 state_dict 方法和 load_state_dict 方法,可以让优化器输出和载入相关的状态信息

python 复制代码
save_info = { # 保存的信息
    "iter_num":iter_num,  # 迭代步数
    "optimizer":optimizer.state_dict,  # 优化器的状态字典
    "model":model.state_dict(),  # 模型的状态字典
}
# 保存信息
torch.save(save_info,save_path)
# 载入信息
save_info = torch.load(save_path)
optimizer.load_stste_dict(save_info["optimizer"])
model.load_stste_dict(save_info["model"])

pytorch数据可视化

tensorboard是一个数据可视化工具,能直观的显示深度学习中张量的变化,从这个变幻的过程中很容易的可以了解到模型在训练中的行为,包括但不限于损失函数的下降趋势是否合理,张量分量的分布是否在训练过程中发生变化

Pytorch:Tensorboard的安装及常用类的使用【图表+图片方法的使用】-CSDN博客

pytorch进阶 可视化工具TensorBoard的使用_pip install future tensorboard-CSDN博客

PyCharm中TensorBoard的安装和使用_phyton怎么安装 tensorborad-CSDN博客

Tensorboard的使用 ---- SummaryWriter类(pytorch版)-CSDN博客

pytorch模型的并行化

多GPU训练:PyTorch中的数据并行与模型并行-CSDN博客

相关推荐
panpantt3213 分钟前
【参会邀请】第二届大数据与数据挖掘国际会议(BDDM 2024)邀您相聚江城!
大数据·人工智能·数据挖掘
lindsayshuo12 分钟前
jetson orin系列开发版安装cuda的gpu版本的opencv
人工智能·opencv
向阳逐梦12 分钟前
ROS机器视觉入门:从基础到人脸识别与目标检测
人工智能·目标检测·计算机视觉
陈鋆37 分钟前
智慧城市初探与解决方案
人工智能·智慧城市
qdprobot38 分钟前
ESP32桌面天气摆件加文心一言AI大模型对话Mixly图形化编程STEAM创客教育
网络·人工智能·百度·文心一言·arduino
QQ395753323738 分钟前
金融量化交易模型的突破与前景分析
人工智能·金融
QQ395753323739 分钟前
金融量化交易:技术突破与模型优化
人工智能·金融
The_Ticker1 小时前
CFD平台如何接入实时行情源
java·大数据·数据库·人工智能·算法·区块链·软件工程
Elastic 中国社区官方博客1 小时前
Elasticsearch 开放推理 API 增加了对 IBM watsonx.ai Slate 嵌入模型的支持
大数据·数据库·人工智能·elasticsearch·搜索引擎·ai·全文检索
jwolf21 小时前
摸一下elasticsearch8的AI能力:语义搜索/vector向量搜索案例
人工智能·搜索引擎