pytorch数据载入
数据载入
在使用pytorch构建和训练模型的过程中,需要经常把原始数据(比如图片、音频)转化为张量的格式,为了方便地批量处理图片数据,pytorch引入了一系列工具来对这个过程进行包装
torch.utils.data.DataLoader
pytorch提供的一个用于数据加载的工具类,用于批量加载数据并为模型提供输入。它可以将数据集包装成一个可迭代的对象,方便地进行数据加载和批处理操作
python
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
参数说明
- dataset :要从中加载数据的数据集(一个**
torch.utils.data.DataLoader的实例
**) - batch_size:每批次要装载多少样品(迷你批次的大小)
- shuffle :设置为True以使数据在每个时期都重新洗牌
- sampler :定义从数据集中抽取样本的策略
- batch_sampler:类似于采样器sampler,但一次返回一个迷你批次的索引,sampler只返回一个下标索引, 与batch_size,shuffle,sampler和drop_last互斥
- num_workers :多少个子流程用于数据加载。 0表示将在主进程中加载数据 (默认值:0)
- collate_fn :把一批 dataset 的实例转化为包含迷你批次数据的张量
- pin_memory :如果为True,则数据加载器在将张量返回之前将其复制到CUDA固定的内存中。 如果您的数据元素是自定义类型,或者您的collate_fn返回的是一个自定义类型的批处理
- drop_last :决定是否将最后一个迷你批次的数据丢掉
- timeout:如果为正,则为从工作人员收集批次的超时值。 应始终为非负数。 (默认值:0)
- worker_init_fn:如果非None,这个函数将在每个工作子进程上被调用,并接收工作进程ID(一个在[0, num_workers - 1]范围内的整数)作为输入,它在设置随机种子之后、但在数据加载之前被调用。(默认:None)
- prefetch_factor:每个子流程预先加载的样本数。 2表示将在所有子流程中预取总共2 * num_workers个样本。 (默认值:2)
- persistent_workers :如果为True,则一次使用数据集后,数据加载器将不会关闭工作进程。 这样可以使Worker Dataset实例保持活动状态。 (默认值:False)
映射类型的数据集
为了能够使用 DataLoader 类,首先需要构造关于单个数据的 torch.ulits.data.Dataset 类,这个类有两种:一种是映射类型(Map-Style),对于这个类型,每个数据都有一个对应的索引,通过输入具体的索引,就能得到对应的数据
python
import torch.utils.data as data
class MyDataset(data.Dataset):
def __init__(self, data_list):
self.data_list = data_list
def __len__(self):
return len(self.data_list)
def __getitem__(self, index):
return self.data_list[index]
一般来说,对于这个类,主要需要重写两个方法:一个是 getitem ,该方法是python内置的操作符方法,对应的操作符是索引操作符 [],通过输入整数数据索引,其大小在0至N-1之间(N微数据的总数目),返回具体的某一条数据记录,这就是该方法需要完成的任务,而具体的逻辑需要根据数据集的类型来决定,另一个方法是 len ,该方法返回数据的总数
在python,如果一个Dataset类重写了该方法,可以通过使用 len 内置函数来获取数据的数目
torchvision工具包的使用
PyTorch:Torchvision的简单介绍与使用-CSDN博客
可迭代类型的数据集
python
from torch.utils.data import IterableDataset
class MyIterableDataset(IterableDataset):
def __init__(self, file_path):
self.file_path = file_path
def __iter__(self):
with open(self.file_path, 'r') as file_obj:
for line in file_obj:
line_data = line.strip('\n').split(',')
yield line_data
if __name__ == '__main__':
dataset = MyIterableDataset('test_csv.csv')
for data in dataset:
print(data)
pytorch模型的保存和加载
序列化和反序列化
由于pytorch的模块和张量的本质是 torch.nn.Module 和 torch.tensor 类的实例,而pytorch自带了一系列的方法,可以将这些类的实例转化为字符串,所以这些势力可以通过python序列化方法进行序列化(serialization)和反序列化(unserialization)
pytorch的实现里集成了python自带的pickle包对模块和张量进行序列化,张量序列化的本质是把张量的信息,包括数据类型和存储位置,以及携带的数据等转化为字符串,而这些字符串时候可以通过使用python自带的文件IO函数进行存储,这个过程是可逆的,即可以通过文件IO函数来读取存储的字符串,然后将字符串逆向解析成pytorch的模块和张量
python
torch.save(obj,f,pickle_module=pickle,pickle_protocol=2)
torch.load(f,map_location=None,pickle_module=pickle,**pickle_load_args)
torch.save 函数传入的第一个参数是pytorch中可以被序列化的对象,包括模型和张量等,第二个参数是存储文件的路径,序列化的结果将会被保留在这个路径里面,第三个参数是默认的,传入的是序列化的库,可以使用pytorch默认的序列化库pickle,第四个参数是pickle协议,即如何把对象转化为字符串的规范,上述使用的协议版本是2
与 torch.save 函数对应的是 torch.load 函数,该函数在给定序列化后的文件路径之后,就能输出 pytorch 的对象,第一个参数是文件路径之后,第二个参数是张量存储位置的映射,如果存储时的模型在CPU上,可以直接使用默认参数,但当存储的模型在GPU上,torch.load 的默认行为是先把模型载入CPU中,然后转移到保存时的GPU上,加入载入模型的时候是在另外一台计算机上,而计算机没有GPU或GPU的型号对不上就会报错
此时可以使用 map_loactin 函数,设置 map_loactin = 'CPU',这样就会把模型保留在CPU里面,不再移动到GPU中,pickle_module 参数和 torch.save 里的同名参数的作用一致
在pytorch中,模型的保存方法有两种,第一种是直接保存模型的实例(因为模型本身可以被序列化),第二种是保存模型的状态字典(State Dict),一个模型的状态字典包含模型所有参数的名字以及名字对应的张量,通过调用 state_dict 方法,就可获取当前模型的状态字典
状态字典的保存和载入
由于pytorch模块的实现依赖具体的pytorch版本,所以会存在一种情况:使用某一个版本保存的序列化文件无法被另一个版本的pytorch载入,相比之下,pytorch的张量变动较小,二状态字典只含有张量参数的名字和张量参数的具体信息,预模块的实现关联较小,因此更加推荐使用 state_dict 方法来获取状态字典,然后保存该张量字典来保存模型,这样可以实现最大限度地减小代码对pytorch版本的依赖性
另外在训练的时候,不仅要保存模型的相关信息,还要保存优化器的相关信息,因为可能需要从存储的检查点出发,继续进行训练,pytorch中参数:当前的学习率,当前梯度的指数移动平均等,通过调用优化器的 state_dict 方法和 load_state_dict 方法,可以让优化器输出和载入相关的状态信息
python
save_info = { # 保存的信息
"iter_num":iter_num, # 迭代步数
"optimizer":optimizer.state_dict, # 优化器的状态字典
"model":model.state_dict(), # 模型的状态字典
}
# 保存信息
torch.save(save_info,save_path)
# 载入信息
save_info = torch.load(save_path)
optimizer.load_stste_dict(save_info["optimizer"])
model.load_stste_dict(save_info["model"])
pytorch数据可视化
tensorboard是一个数据可视化工具,能直观的显示深度学习中张量的变化,从这个变幻的过程中很容易的可以了解到模型在训练中的行为,包括但不限于损失函数的下降趋势是否合理,张量分量的分布是否在训练过程中发生变化
Pytorch:Tensorboard的安装及常用类的使用【图表+图片方法的使用】-CSDN博客
pytorch进阶 可视化工具TensorBoard的使用_pip install future tensorboard-CSDN博客
PyCharm中TensorBoard的安装和使用_phyton怎么安装 tensorborad-CSDN博客
Tensorboard的使用 ---- SummaryWriter类(pytorch版)-CSDN博客