问题描述
数据集为h5文件,文件较大无法全部读入内存,故使用自定义Dataset和DataLoader从硬盘中读取,再使用Pytorch训练模型。在多张GPU同时训练时,希望通过设置 DataLoader(..., num_workers=8, ...)
使用多进程读取数据,加速训练,但遇到报错 h5py objects cannot be pickled,后查询得知基本版h5py不支持多进程操作。
设置num_workers=0
可以解决报错,但无法加速训练
有博主(https://blog.csdn.net/qq_36468195/article/details/114922648)建议设置 DataLoader(..., num_workers=0, ...)
,通过主线程读取数据,这样可以见解决报错,但无法提升训练速度,主进程在读取数据时会阻塞训练,使得训练时间延长。
解决方法
安装 h5pickle
模块
pip install h5pickle
在实例化Dataset中打开h5文件时,使用h5pickle
模块替代h5py
模块打开文件.
import torch
import h5py
import h5pickle
class MyDataset(torch.utils.data.Dataset):
def __init__(self, h5_file, ...):
...
...
# self.h5_file_handle = h5py.File(h5_file, "r")
self.h5_file_handle = h5pickle.File(h5_file, "r")
...
...