PyTorch深度学习——数据输入和预处理

pytorch数据载入

数据载入

在使用pytorch构建和训练模型的过程中,需要经常把原始数据(比如图片、音频)转化为张量的格式,为了方便地批量处理图片数据,pytorch引入了一系列工具来对这个过程进行包装

torch.utils.data.DataLoader

pytorch提供的一个用于数据加载的工具类,用于批量加载数据并为模型提供输入。它可以将数据集包装成一个可迭代的对象,方便地进行数据加载和批处理操作

Pytorch torch.utils.data.DataLoader 用法详细介绍-CSDN博客

python 复制代码
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

参数说明

  • dataset :要从中加载数据的数据集(一个**torch.utils.data.DataLoader的实例**)
  • batch_size:每批次要装载多少样品(迷你批次的大小)
  • shuffle :设置为True以使数据在每个时期都重新洗牌
  • sampler :定义从数据集中抽取样本的策略
  • batch_sampler:类似于采样器sampler,但一次返回一个迷你批次的索引,sampler只返回一个下标索引, 与batch_size,shuffle,sampler和drop_last互斥
  • num_workers :多少个子流程用于数据加载。 0表示将在主进程中加载数据 (默认值:0)
  • collate_fn :把一批 dataset 的实例转化为包含迷你批次数据的张量
  • pin_memory :如果为True,则数据加载器在将张量返回之前将其复制到CUDA固定的内存中。 如果您的数据元素是自定义类型,或者您的collate_fn返回的是一个自定义类型的批处理
  • drop_last :决定是否将最后一个迷你批次的数据丢掉
  • timeout:如果为正,则为从工作人员收集批次的超时值。 应始终为非负数。 (默认值:0)
  • worker_init_fn:如果非None,这个函数将在每个工作子进程上被调用,并接收工作进程ID(一个在[0, num_workers - 1]范围内的整数)作为输入,它在设置随机种子之后、但在数据加载之前被调用。(默认:None)
  • prefetch_factor:每个子流程预先加载的样本数。 2表示将在所有子流程中预取总共2 * num_workers个样本。 (默认值:2)
  • persistent_workers :如果为True,则一次使用数据集后,数据加载器将不会关闭工作进程。 这样可以使Worker Dataset实例保持活动状态。 (默认值:False)

映射类型的数据集

为了能够使用 DataLoader 类,首先需要构造关于单个数据的 torch.ulits.data.Dataset 类,这个类有两种:一种是映射类型(Map-Style),对于这个类型,每个数据都有一个对应的索引,通过输入具体的索引,就能得到对应的数据

python 复制代码
import torch.utils.data as data

class MyDataset(data.Dataset):
    def __init__(self, data_list):
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        return self.data_list[index]

一般来说,对于这个类,主要需要重写两个方法:一个是 getitem ,该方法是python内置的操作符方法,对应的操作符是索引操作符 [],通过输入整数数据索引,其大小在0至N-1之间(N微数据的总数目),返回具体的某一条数据记录,这就是该方法需要完成的任务,而具体的逻辑需要根据数据集的类型来决定,另一个方法是 len ,该方法返回数据的总数

在python,如果一个Dataset类重写了该方法,可以通过使用 len 内置函数来获取数据的数目

torchvision工具包的使用

PyTorch:Torchvision的简单介绍与使用-CSDN博客

可迭代类型的数据集

python 复制代码
from torch.utils.data import IterableDataset


class MyIterableDataset(IterableDataset):

    def __init__(self, file_path):
        self.file_path = file_path

    def __iter__(self):
        with open(self.file_path, 'r') as file_obj:
            for line in file_obj:
                line_data = line.strip('\n').split(',')
                yield line_data


if __name__ == '__main__':
    dataset = MyIterableDataset('test_csv.csv')
    for data in dataset:
        print(data)

pytorch模型的保存和加载

序列化和反序列化

由于pytorch的模块和张量的本质是 torch.nn.Module 和 torch.tensor 类的实例,而pytorch自带了一系列的方法,可以将这些类的实例转化为字符串,所以这些势力可以通过python序列化方法进行序列化(serialization)和反序列化(unserialization)

pytorch的实现里集成了python自带的pickle包对模块和张量进行序列化,张量序列化的本质是把张量的信息,包括数据类型和存储位置,以及携带的数据等转化为字符串,而这些字符串时候可以通过使用python自带的文件IO函数进行存储,这个过程是可逆的,即可以通过文件IO函数来读取存储的字符串,然后将字符串逆向解析成pytorch的模块和张量

python 复制代码
torch.save(obj,f,pickle_module=pickle,pickle_protocol=2)
torch.load(f,map_location=None,pickle_module=pickle,**pickle_load_args)

torch.save 函数传入的第一个参数是pytorch中可以被序列化的对象,包括模型和张量等,第二个参数是存储文件的路径,序列化的结果将会被保留在这个路径里面,第三个参数是默认的,传入的是序列化的库,可以使用pytorch默认的序列化库pickle,第四个参数是pickle协议,即如何把对象转化为字符串的规范,上述使用的协议版本是2

与 torch.save 函数对应的是 torch.load 函数,该函数在给定序列化后的文件路径之后,就能输出 pytorch 的对象,第一个参数是文件路径之后,第二个参数是张量存储位置的映射,如果存储时的模型在CPU上,可以直接使用默认参数,但当存储的模型在GPU上,torch.load 的默认行为是先把模型载入CPU中,然后转移到保存时的GPU上,加入载入模型的时候是在另外一台计算机上,而计算机没有GPU或GPU的型号对不上就会报错

此时可以使用 map_loactin 函数,设置 map_loactin = 'CPU',这样就会把模型保留在CPU里面,不再移动到GPU中,pickle_module 参数和 torch.save 里的同名参数的作用一致

在pytorch中,模型的保存方法有两种,第一种是直接保存模型的实例(因为模型本身可以被序列化),第二种是保存模型的状态字典(State Dict),一个模型的状态字典包含模型所有参数的名字以及名字对应的张量,通过调用 state_dict 方法,就可获取当前模型的状态字典

状态字典的保存和载入

由于pytorch模块的实现依赖具体的pytorch版本,所以会存在一种情况:使用某一个版本保存的序列化文件无法被另一个版本的pytorch载入,相比之下,pytorch的张量变动较小,二状态字典只含有张量参数的名字和张量参数的具体信息,预模块的实现关联较小,因此更加推荐使用 state_dict 方法来获取状态字典,然后保存该张量字典来保存模型,这样可以实现最大限度地减小代码对pytorch版本的依赖性

另外在训练的时候,不仅要保存模型的相关信息,还要保存优化器的相关信息,因为可能需要从存储的检查点出发,继续进行训练,pytorch中参数:当前的学习率,当前梯度的指数移动平均等,通过调用优化器的 state_dict 方法和 load_state_dict 方法,可以让优化器输出和载入相关的状态信息

python 复制代码
save_info = { # 保存的信息
    "iter_num":iter_num,  # 迭代步数
    "optimizer":optimizer.state_dict,  # 优化器的状态字典
    "model":model.state_dict(),  # 模型的状态字典
}
# 保存信息
torch.save(save_info,save_path)
# 载入信息
save_info = torch.load(save_path)
optimizer.load_stste_dict(save_info["optimizer"])
model.load_stste_dict(save_info["model"])

pytorch数据可视化

tensorboard是一个数据可视化工具,能直观的显示深度学习中张量的变化,从这个变幻的过程中很容易的可以了解到模型在训练中的行为,包括但不限于损失函数的下降趋势是否合理,张量分量的分布是否在训练过程中发生变化

Pytorch:Tensorboard的安装及常用类的使用【图表+图片方法的使用】-CSDN博客

pytorch进阶 可视化工具TensorBoard的使用_pip install future tensorboard-CSDN博客

PyCharm中TensorBoard的安装和使用_phyton怎么安装 tensorborad-CSDN博客

Tensorboard的使用 ---- SummaryWriter类(pytorch版)-CSDN博客

pytorch模型的并行化

多GPU训练:PyTorch中的数据并行与模型并行-CSDN博客

相关推荐
新智元16 分钟前
Ilya震撼发声!OpenAI前主管亲证:AGI已觉醒,人类还在装睡
人工智能·openai
朱昆鹏25 分钟前
如何通过sessionKey 登录 Claude
前端·javascript·人工智能
汉堡go32 分钟前
1、机器学习与深度学习
人工智能·深度学习·机器学习
只是懒得想了1 小时前
使用 Gensim 进行主题建模(LDA)与词向量训练(Word2Vec)的完整指南
人工智能·自然语言处理·nlp·word2vec·gensim
johnny2331 小时前
OpenAI系列模型介绍、API使用
人工智能
KKKlucifer1 小时前
生成式 AI 冲击下,网络安全如何破局?
网络·人工智能·web安全
LiJieNiub1 小时前
基于 PyTorch 实现 MNIST 手写数字识别
pytorch·深度学习·学习
ARM+FPGA+AI工业主板定制专家2 小时前
基于JETSON ORIN/RK3588+AI相机:机器人-多路视觉边缘计算方案
人工智能·数码相机·机器人
文火冰糖的硅基工坊2 小时前
[创业之路-691]:历史与现实的镜鉴:从三国纷争到华为铁三角的系统性启示
人工智能·科技·华为·重构·架构·创业
chxin140162 小时前
Transformer注意力机制——动手学深度学习10
pytorch·rnn·深度学习·transformer