pytorch分批加载大数据集

pytorch分批加载大数据集

本文处理的数据特点:

(1)数据量大,无法一次读取到内存中

(2)数据是图片或者存储在csv中(每一行是一个sample,包括feature和label)

加载数据集需要继承torch.utils.data 的 Dataset类,并实现 __len__和__getitem__方法。其中

len:返回数据集总数,

getitem:返回指定的数的矩阵和标签。

加载图片

这段代码是一个使用 PyTorch 数据加载和处理机制的例子,主要用于从指定目录加载图片数据,并通过 DataLoader 进行批量处理。

python 复制代码
from torch.utils.data import Dataset, DataLoader
import torch
import glob
import os
from PIL import Image

class PictureLoad(Dataset):
    def __init__(self, paths, size=(10, 10)):
        self.paths = glob.glob(paths)
        self.size = size

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, item):
        try:
            img = Image.open(self.paths[item]).resize(self.size)
            img_tensor = torch.from_numpy(np.asarray(img)).float() / 255.0  # 转为Tensor并归一化
            label = os.path.basename(self.paths[item]).split('.')[0]  # 更健壮的文件名提取方式
            return img_tensor, label
        except IOError:
            print(f"Error opening file: {self.paths[item]}")  # 处理文件打开错误
            return None, None

if __name__ == '__main__':
    root_path = os.path.join(os.path.dirname(os.getcwd()), "cap")
    pic_paths = os.path.join(root_path, '*.jpg')

    picture = Pictureload(pic_paths)
    dataloader = DataLoader(picture, batch_size=32, num_workers=2, timeout=2)

    for a, b in dataloader:
        print(b, a.shape)  # 输出标签和图片数据的尺寸,而不是原始数据

表格数据

确保数据以分批方式从文件中加载,且不会一次性将所有数据加载到内存中,适合处理大规模数据文件。

python 复制代码
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd

class DataLoad(Dataset):
    def __init__(self, file_path, batch_size=3):
        '''
        初始化函数,设置文件路径和每批读取的数据大小。
        '''
        self.file_path = file_path
        self.batch_size = batch_size
        self.total_data = self._get_total_len()

    def _get_total_len(self):
        '''
        辅助函数用于计算文件中的数据行数。
        '''
        with open(self.file_path, 'r') as file:
            return sum(1 for line in file)

    def __len__(self):
        '''
        返回数据集的总长度。
        '''
        return self.total_data

    def __getitem__(self, idx):
        '''
        根据索引获取数据,每次从文件中动态加载数据。
        '''
        if idx * self.batch_size >= self.total_data:
            raise IndexError("Index out of range")

        skip_rows = idx * self.batch_size if idx > 0 else 0
        df = pd.read_csv(self.file_path, skiprows=skip_rows, nrows=self.batch_size, header=None)
        data_tensor = torch.tensor(df.values)
        return data_tensor

if __name__ == "__main__":
    dataset = DataLoad('path_to_your_data.csv', batch_size=32)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, drop_last=False)

    for epoch in range(3):
        print(f"Epoch {epoch + 1}")
        for data in dataloader:
            print("Data Batch:")
            print(data)

对于两个batch_size的解释:假设 PretrainData 类每次通过其__getitem__ 方法返回一批数据,即32行数据(根据它的 batch_size=32 设定)。当您使用 DataLoader 并设置其 batch_size 为1时,意味着每次从 DataLoader 迭代得到的数据批将包含从 PretrainData 返回的1个独立批次。因此,每个从 DataLoader 返回的数据批将包含1*32=32条数据。

相关推荐
山海青风1 分钟前
藏文TTS介绍:4 神经网络 TTS 的随机性与自然度
人工智能·python·神经网络·音视频
曲幽3 分钟前
FastAPI入门:从简介到实战,对比Flask帮你选对框架
python·flask·fastapi·web·route·uv·uvicorn·docs
万邦科技Lafite3 分钟前
淘宝开放API批量上架商品操作指南(2025年最新版)
开发语言·数据库·python·开放api·电商开放平台·淘宝开放平台
Deepoch3 分钟前
从“单体智能”到“群体协同”:机器狗集群的分布式智能演进之路
人工智能·科技·开发板·具身模型·deepoc·机械狗
人工智能技术咨询.3 分钟前
【无标题】基于Tensorflow库的RNN模型预测实战
人工智能
yumgpkpm4 分钟前
Cloudera CDH5|CDH6|CDP7.1.7|CDP7.3|CMP 7.3的产品优势分析(在华为鲲鹏 ARM 麒麟KylinOS、统信UOS)
大数据·人工智能·hadoop·深度学习·spark·transformer·cloudera
IT_陈寒5 分钟前
JavaScript 性能优化实战:7 个让你的应用提速 50%+ 的 V8 引擎技巧
前端·人工智能·后端
十三画者5 分钟前
【文献分享】vConTACT3机器学习能够实现可扩展且系统的病毒分类体系的构建
人工智能·算法·机器学习·数据挖掘·数据分析
雪下的新火7 分钟前
AI工具-腾讯混元3D使用简述:
人工智能·游戏引擎·aigc·blender·ai工具·笔记分享
亚里随笔8 分钟前
简约而不简单:JustRL如何用最简RL方案实现1.5B模型突破性性能
人工智能·深度学习·机器学习·语言模型·llm·rl