pytorch分批加载大数据集

pytorch分批加载大数据集

本文处理的数据特点:

(1)数据量大,无法一次读取到内存中

(2)数据是图片或者存储在csv中(每一行是一个sample,包括feature和label)

加载数据集需要继承torch.utils.data 的 Dataset类,并实现 __len__和__getitem__方法。其中

len:返回数据集总数,

getitem:返回指定的数的矩阵和标签。

加载图片

这段代码是一个使用 PyTorch 数据加载和处理机制的例子,主要用于从指定目录加载图片数据,并通过 DataLoader 进行批量处理。

python 复制代码
from torch.utils.data import Dataset, DataLoader
import torch
import glob
import os
from PIL import Image

class PictureLoad(Dataset):
    def __init__(self, paths, size=(10, 10)):
        self.paths = glob.glob(paths)
        self.size = size

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, item):
        try:
            img = Image.open(self.paths[item]).resize(self.size)
            img_tensor = torch.from_numpy(np.asarray(img)).float() / 255.0  # 转为Tensor并归一化
            label = os.path.basename(self.paths[item]).split('.')[0]  # 更健壮的文件名提取方式
            return img_tensor, label
        except IOError:
            print(f"Error opening file: {self.paths[item]}")  # 处理文件打开错误
            return None, None

if __name__ == '__main__':
    root_path = os.path.join(os.path.dirname(os.getcwd()), "cap")
    pic_paths = os.path.join(root_path, '*.jpg')

    picture = Pictureload(pic_paths)
    dataloader = DataLoader(picture, batch_size=32, num_workers=2, timeout=2)

    for a, b in dataloader:
        print(b, a.shape)  # 输出标签和图片数据的尺寸,而不是原始数据

表格数据

确保数据以分批方式从文件中加载,且不会一次性将所有数据加载到内存中,适合处理大规模数据文件。

python 复制代码
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd

class DataLoad(Dataset):
    def __init__(self, file_path, batch_size=3):
        '''
        初始化函数,设置文件路径和每批读取的数据大小。
        '''
        self.file_path = file_path
        self.batch_size = batch_size
        self.total_data = self._get_total_len()

    def _get_total_len(self):
        '''
        辅助函数用于计算文件中的数据行数。
        '''
        with open(self.file_path, 'r') as file:
            return sum(1 for line in file)

    def __len__(self):
        '''
        返回数据集的总长度。
        '''
        return self.total_data

    def __getitem__(self, idx):
        '''
        根据索引获取数据,每次从文件中动态加载数据。
        '''
        if idx * self.batch_size >= self.total_data:
            raise IndexError("Index out of range")

        skip_rows = idx * self.batch_size if idx > 0 else 0
        df = pd.read_csv(self.file_path, skiprows=skip_rows, nrows=self.batch_size, header=None)
        data_tensor = torch.tensor(df.values)
        return data_tensor

if __name__ == "__main__":
    dataset = DataLoad('path_to_your_data.csv', batch_size=32)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, drop_last=False)

    for epoch in range(3):
        print(f"Epoch {epoch + 1}")
        for data in dataloader:
            print("Data Batch:")
            print(data)

对于两个batch_size的解释:假设 PretrainData 类每次通过其__getitem__ 方法返回一批数据,即32行数据(根据它的 batch_size=32 设定)。当您使用 DataLoader 并设置其 batch_size 为1时,意味着每次从 DataLoader 迭代得到的数据批将包含从 PretrainData 返回的1个独立批次。因此,每个从 DataLoader 返回的数据批将包含1*32=32条数据。

相关推荐
IT古董6 分钟前
【深度学习】常见模型-Transformer模型
人工智能·深度学习·transformer
沐雪架构师1 小时前
AI大模型开发原理篇-2:语言模型雏形之词袋模型
人工智能·语言模型·自然语言处理
python算法(魔法师版)2 小时前
深度学习深度解析:从基础到前沿
人工智能·深度学习
小王子10242 小时前
设计模式Python版 组合模式
python·设计模式·组合模式
kakaZhui2 小时前
【llm对话系统】大模型源码分析之 LLaMA 位置编码 RoPE
人工智能·深度学习·chatgpt·aigc·llama
struggle20253 小时前
一个开源 GenBI AI 本地代理(确保本地数据安全),使数据驱动型团队能够与其数据进行互动,生成文本到 SQL、图表、电子表格、报告和 BI
人工智能·深度学习·目标检测·语言模型·自然语言处理·数据挖掘·集成学习
佛州小李哥3 小时前
通过亚马逊云科技Bedrock打造自定义AI智能体Agent(上)
人工智能·科技·ai·语言模型·云计算·aws·亚马逊云科技
Mason Lin3 小时前
2025年1月22日(网络编程 udp)
网络·python·udp
清弦墨客4 小时前
【蓝桥杯】43697.机器人塔
python·蓝桥杯·程序算法
云空4 小时前
《DeepSeek 网页/API 性能异常(DeepSeek Web/API Degraded Performance):网络安全日志》
运维·人工智能·web安全·网络安全·开源·网络攻击模型·安全威胁分析