pytorch分批加载大数据集
本文处理的数据特点:
(1)数据量大,无法一次读取到内存中
(2)数据是图片或者存储在csv中(每一行是一个sample,包括feature和label)
加载数据集需要继承torch.utils.data 的 Dataset类,并实现 __len__和__getitem__方法。其中
len:返回数据集总数,
getitem:返回指定的数的矩阵和标签。
加载图片
这段代码是一个使用 PyTorch 数据加载和处理机制的例子,主要用于从指定目录加载图片数据,并通过 DataLoader 进行批量处理。
python
from torch.utils.data import Dataset, DataLoader
import torch
import glob
import os
from PIL import Image
class PictureLoad(Dataset):
def __init__(self, paths, size=(10, 10)):
self.paths = glob.glob(paths)
self.size = size
def __len__(self):
return len(self.paths)
def __getitem__(self, item):
try:
img = Image.open(self.paths[item]).resize(self.size)
img_tensor = torch.from_numpy(np.asarray(img)).float() / 255.0 # 转为Tensor并归一化
label = os.path.basename(self.paths[item]).split('.')[0] # 更健壮的文件名提取方式
return img_tensor, label
except IOError:
print(f"Error opening file: {self.paths[item]}") # 处理文件打开错误
return None, None
if __name__ == '__main__':
root_path = os.path.join(os.path.dirname(os.getcwd()), "cap")
pic_paths = os.path.join(root_path, '*.jpg')
picture = Pictureload(pic_paths)
dataloader = DataLoader(picture, batch_size=32, num_workers=2, timeout=2)
for a, b in dataloader:
print(b, a.shape) # 输出标签和图片数据的尺寸,而不是原始数据
表格数据
确保数据以分批方式从文件中加载,且不会一次性将所有数据加载到内存中,适合处理大规模数据文件。
python
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
class DataLoad(Dataset):
def __init__(self, file_path, batch_size=3):
'''
初始化函数,设置文件路径和每批读取的数据大小。
'''
self.file_path = file_path
self.batch_size = batch_size
self.total_data = self._get_total_len()
def _get_total_len(self):
'''
辅助函数用于计算文件中的数据行数。
'''
with open(self.file_path, 'r') as file:
return sum(1 for line in file)
def __len__(self):
'''
返回数据集的总长度。
'''
return self.total_data
def __getitem__(self, idx):
'''
根据索引获取数据,每次从文件中动态加载数据。
'''
if idx * self.batch_size >= self.total_data:
raise IndexError("Index out of range")
skip_rows = idx * self.batch_size if idx > 0 else 0
df = pd.read_csv(self.file_path, skiprows=skip_rows, nrows=self.batch_size, header=None)
data_tensor = torch.tensor(df.values)
return data_tensor
if __name__ == "__main__":
dataset = DataLoad('path_to_your_data.csv', batch_size=32)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, drop_last=False)
for epoch in range(3):
print(f"Epoch {epoch + 1}")
for data in dataloader:
print("Data Batch:")
print(data)
对于两个batch_size的解释:假设 PretrainData 类每次通过其__getitem__ 方法返回一批数据,即32行数据(根据它的 batch_size=32 设定)。当您使用 DataLoader 并设置其 batch_size 为1时,意味着每次从 DataLoader 迭代得到的数据批将包含从 PretrainData 返回的1个独立批次。因此,每个从 DataLoader 返回的数据批将包含1*32=32条数据。