Pytorch加载数据的Dateset类和DataLoader类

Pytorch提供了Dataset类和DataLoader类专门用于处理数据,他们既可以加载Pytorch预置的数据集,也可以加载自定义数据集。其中数据集类Dataset负责存储样本以及他们对应的标签;数据加载类DataLoader负责迭代访问数据集中的样本。

Dataset

  • 映射型数据集

    继承自Dataset类,表示一个从索引到样本的映射(索引可以不是整数),这样我们就可以方便地通过dataset[idx] 来访问指定得到索引的样本。这也是目前最常见的数据集类型。映射型数据集必须实现 getitem ()函数,其负责根据指定的key返回对应的样本。一般还会实现 len() 用于返回数据集的大小。

  • 迭代型数据集

    继承自 Dataset,表示可迭代的数据集,它可以通过 iter(dataset)以数据流(steam)的形式访问,适用于访问大数据集或者远程服务器上产生的数据,迭代数据集必须实现__iter__()函数,用于返回一个样本迭代器(iteror).
    多线程访问数据集

python 复制代码
from torch.utils.data import IterableDataset, DataLoader

class MyIterableDataset(IterableDataset):

    def __init__(self, start, end):
        super(MyIterableDataset).__init__()
        assert end > start
        self.start = start
        self.end = end

    def __iter__(self):
        return iter(range(self.start, self.end))

ds = MyIterableDataset(start=3, end=7) # [3, 4, 5, 6]
# Single-process loading
print(list(DataLoader(ds, num_workers=0)))
# Directly doing multi-process loading
print(list(DataLoader(ds, num_workers=2)))

# output
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]

[tensor([3]), tensor([3]), tensor([4]), tensor([4]), tensor([5]), tensor([5]), tensor([6]), tensor([6])]

创建一个自定义的映射型数据集

python 复制代码
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

可以看到,我们实现了 init ()、len () 和 getitem() 三个函数,其中:

  • init() 初始化数据集参数,这里设置了图像的存储目录、标签(通过读取标签 csv 文件)以及样本和标签的数据转换函数;
  • len() 返回数据集中样本的个数;
  • getitem() 映射型数据集的核心,根据给定的索引 idx 返回样本。这里会根据索引从目录下通过 read_image 读取图片和从 csv 文件中读取图片标签,并且返回处理后的图像和标签。

DataLoader

前面的数据集Dataset类提供了一种按照索引访问样本的方式,不过在实际训练模型时,我们都需要先将数据集切分为很多的mini-batches,然后按照批次将样本送入模型,并且循环这一过程,每一个完整遍历所有样本成为一个epoch。

python 复制代码
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")

img = train_features[0].squeeze()
label = train_labels[0]
print(img.shape)
print(f"Label: {label}")
相关推荐
Ankie Wan4 分钟前
notepad++技巧:查找和替换:扩展 or 正则表达式
python·正则表达式·notepad++
带娃的IT创业者4 分钟前
《AI大模型趣味实战》智能Agent和MCP协议的应用实例:搭建一个能阅读DOC文件并实时显示润色改写过程的Python Flask应用
人工智能·python·flask
一只韩非子9 分钟前
什么是MCP?为什么引入MCP?(通俗易懂版)
人工智能·aigc·mcp
JavaEdge在掘金12 分钟前
启动nginx报错,80 failed (97: Address family not supported by protocol)
python
新智元12 分钟前
毛骨悚然!o3 精准破译照片位置,只靠几行 Python 代码?人类在 AI 面前已裸奔
人工智能·openai
纪元A梦19 分钟前
华为OD机试真题——绘图机器(2025A卷:100分)Java/python/JavaScript/C++/C/GO最佳实现
java·javascript·c++·python·华为od·go·华为od机试题
程序员小远31 分钟前
接口测试和单元测试详解
自动化测试·软件测试·python·测试工具·单元测试·测试用例·接口测试
Tech Synapse40 分钟前
电商商品推荐系统实战:基于TensorFlow Recommenders构建智能推荐引擎
人工智能·python·tensorflow
帅帅的Python40 分钟前
2015-2023 各省 GDP 数据,用QuickBI 进行数据可视化——堆叠图!
大数据·人工智能
聿小翼43 分钟前
selenium-wire 与 googletrans 的爱恨情仇
python