数据集(Dataset)和数据加载器(DataLoader)-pytroch学习3

pytorch网站学习

处理数据样本的代码往往会变得很乱、难以维护;理想情况下,我们希望把数据部分的代码和模型训练部分分开写,这样更容易阅读、也更好维护。

简单说:数据和模型最好"分工明确",不要写在一起。

PyTorch 提供了两个数据处理的"基本工具":

  • torch.utils.data.Dataset

  • torch.utils.data.DataLoader

    它们可以用来处理官方内置的数据集 ,也可以用来加载你自己的数据。

    Dataset 存储样本及其对应的标签,而 DataLoader 则在 Dataset 周围封装了一个迭代器,以便轻松访问这些样本。

  • Dataset:用于存储样本和对应的标签,类似一个"数据库",它记录了所有数据。

  • DataLoader:基于 Dataset 封装了一个可迭代对象,方便你在训练过程中一次取出一个批次(batch)的数据。

  • Dataset = 数据仓库,负责"存"数据

  • DataLoader = 快递员,负责"送"数据,一批一批送给模型训练用

PyTorch 提供了 Dataset(负责存数据)和 DataLoader(负责送数据)两个工具,可以方便地管理、加载各种数据

PyTorch 的领域库提供了许多预加载的数据集(例如 FashionMNIST),这些数据集都是 torch.utils.data.Dataset 的子类,,例如,FashionMNIST 数据集就是一个专门用于服装图像识别的预加载数据集,它已经按照 Dataset 接口组织好了数据,你可以直接用来训练和测试模型

参数解释:

root:这是用来存放训练/测试数据的文件夹路径。

train:指定是加载训练集(train=True)还是测试集(train=False)。

download=True:如果你指定的 root 路径下没有数据,它会自动联网下载。

transformtarget_transform

  • transform 是对图像特征做的变换(比如转为张量、归一化等)

  • target_transform 是对标签做的变换(比如 one-hot 编码)

    from torchvision import datasets, transforms

    定义图像的预处理操作:把图片转成张量

    transform = transforms.ToTensor()

    加载训练集

    train_data = datasets.FashionMNIST(
    root="data", # 数据保存目录
    train=True, # 加载训练集
    download=True, # 如果没有就下载
    transform=transform # 图像预处理
    )

    加载测试集

    test_data = datasets.FashionMNIST(
    root="data",
    train=False, # 加载测试集
    download=True,
    transform=transform
    )

如何手动取出数据集里的样本,并把它们可视化显示出来

遍历和可视化数据集

我们可以像访问列表那样,用下标手动访问数据集:training_data[index]

我们使用 matplotlib 来把训练数据中的一些样本画出来进行可视化。

复制代码
什么是 training_data[index]?
在 PyTorch 中,像 training_data 这种数据集对象,其实可以像列表(list)一样使用:


image, label = training_data[0]  # 取出第一个样本(包括图像和标签)
image 是一张 28×28 的图(张量)

label 是它的标签(比如 "T-shirt/top")

# 标签编号和对应的文字(类别)之间的映射关系
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}

figure = plt.figure(figsize=(8, 8))  # 创建一个图形窗口,大小为 8x8 英寸
cols, rows = 3, 3                    # 准备画一个 3 行 3 列 的图像网格(共 9 张图)

for i in range(1, cols * rows + 1):  # 循环9次(从1到9)
    sample_idx = torch.randint(len(training_data), size=(1,)).item()  # 随机选一个样本索引
    img, label = training_data[sample_idx]  # 从训练集中取出图像和标签
​
    figure.add_subplot(rows, cols, i)  # 添加一个子图(3x3 的第 i 个格子)
    plt.title(labels_map[label])       # 设置图像标题为标签名称(比如 "Sneaker")
    plt.axis("off")                    # 不显示坐标轴
    plt.imshow(img.squeeze(), cmap="gray")  # 显示图像(压缩维度 + 灰度图)
plt.show()  # 显示整张图(9张图一起展示)


​

如何自己创建一个自定义的数据集(Custom Dataset),让 PyTorch 能读取自己的图片和标签,比如本地的一些图片文件和 CSV 表格。

为你自己的文件创建一个自定义数据集

自定义 Dataset 类时,必须实现三个函数:__init__(初始化)、__len__(返回样本总数) 和 __getitem__(获取指定样本)

如果你不是用官方的数据集(比如 FashionMNIST),而是用你自己文件夹里的图片 + CSV 表里的标签,那就需要自己写一个"自定义数据集类":

  • __init__():定义数据集在哪里、怎么加载图片和标签

  • __len__():告诉 PyTorch 你一共有多少张图(样本数量)
    __len__ 函数

    这个函数的作用是:返回数据集中样本(图片)的数量。

  • __getitem__():定义怎么通过索引取出一张图和它的标签(比如 dataset[0]

    import os # 用于路径拼接
    import pandas as pd # 用于读取 CSV 文件
    from torchvision.io import read_image # 用于读取图像(转为张量)
    from torch.utils.data import Dataset # 自定义数据集要继承这个类

    自定义图片数据集类,继承自 PyTorch 的 Dataset 基类

    class CustomImageDataset(Dataset):
    # 初始化函数:加载CSV标签表、图片文件夹路径、图像和标签的预处理方法
    def init(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file) # 读取CSV文件,包含图片文件名和对应标签
    self.img_dir = img_dir # 图片所在的文件夹路径
    self.transform = transform # 图像的预处理方法(例如缩放、归一化)
    self.target_transform = target_transform # 标签的预处理方法(例如转one-hot)

    复制代码
      # 返回数据集中样本的总数量
      def __len__(self):
          return len(self.img_labels)  # 返回 CSV 中的行数(也就是图片数量)
    
      # 按照索引返回一张图片和它的标签
      def __getitem__(self, idx):
          # 根据索引从CSV中获取图片文件名,并拼接成完整路径
          img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    
          # 使用 torchvision.io.read_image 读取图片(返回的是Tensor格式)
          image = read_image(img_path)
    
          # 获取对应的标签(CSV第二列)
          label = self.img_labels.iloc[idx, 1]
    
          # 如果定义了图像预处理,就应用它
          if self.transform:
              image = self.transform(image)
    
          # 如果定义了标签预处理,就应用它
          if self.target_transform:
              label = self.target_transform(label)
    
          # 返回一对数据:(图像,标签)
          return image, label

__init__ 函数

当我们创建 Dataset 数据集对象时,这个 __init__ 函数会被运行一次。

在这个函数中,我们设置好图像所在的文件夹路径、标签文件(CSV),以及两种预处理方法(transform)

这个时候 Python 就会自动去运行你写的 __init__ 函数,完成以下事情:

做什么 举例
读入标签文件 从 CSV 读出每张图对应的标签
记住图片路径 比如你的图片都在 "images/" 文件夹里
保存预处理方法 如果你要对图像做缩放、归一化等处理,也在这里传进来

你可以把 __getitem__() 想象成这样一个问题:

你对 PyTorch 说:"嘿,帮我从数据集中拿出第 5 张图像,还有它的标签。"

PyTorch 就会执行你写的 __getitem__(5),然后:

  1. 去 CSV 表里看第5行,拿到图像文件名,比如 img5.png

  2. 拼成路径,比如 images/img5.png

  3. read_image() 把它读成模型能用的格式(张量)

  4. 拿到它的标签,比如 label=2(代表"Pullover")

  5. 如果你有设置 transform,就先处理一下

  6. 返回 (图像张量, 标签) 给你

使用 DataLoader 为训练准备数据

Dataset(数据集)每次只能取出一条数据(特征和标签)。

而在训练模型时,我们通常希望将样本按小批量(minibatch)送入模型,
并且在每一轮训练(epoch)中
打乱数据的顺序
,以减少模型过拟合,

同时利用 Python 的多进程功能来加快数据的读取速度。

DataLoader 是一个可迭代对象,它通过一个简单的 API 帮我们封装了以上所有复杂操作。

这里的 API 就是"别人已经写好的功能接口",你只要用很简单的方式去"调用它",就可以完成很复杂的事情。

就像你开车,不用知道发动机怎么工作,你只需要踩油门,这个"油门"就是给你用的 API。

没有 DataLoader 时的问题 DataLoader 自动帮你做了什么
一次只能读一张图 ✅ 自动按 batch_size 读多张图
每次都按固定顺序读 ✅ 每轮训练前自动打乱数据
读取慢(尤其是大数据) ✅ 用多进程后台加速加载数据
写代码复杂 ✅ 封装好,只要一行就能搞定

minibatch (中文叫"小批量")指的是:**每次训练时不把所有数据一次性喂给模型,而是一次取出一小部分来训练。**举个例子:

你有 10,000 张训练图像,不可能一次性都送给模型(太慢/太耗显存)。

你可以这样设置:

batch_size = 64

就是:每次训练用 64 张图,学完一批,再取下一批。

这种方式叫:小批量训练(mini-batch training)

什么是 shuffle(打乱数据)?

定义:shuffle 指的是:在每轮训练开始前,把训练数据的顺序随机打乱。

为什么要打乱?

假如你的数据是按类别排好顺序的(比如先全是猫,后全是狗):

模型可能先学猫学很久,突然一下全是狗,这样容易 过拟合某一类,泛化能力差

所以我们会在每个 epoch 前加个参数:

DataLoader(..., shuffle=True)

表示:每一轮训练前,重新随机排序数据。

什么是多进程加载(num_workers)?

定义:PyTorch 可以使用多个"后台工作进程(线程)"同时从磁盘里读取图片,加快加载速度。

举个例子:

你用 DataLoader 加载数据时可以设置:

DataLoader(dataset, batch_size=64, num_workers=4)

意思是:开 4 个后台进程来同时读数据!

就像你点外卖,找了 4 个骑手一起送菜,当然比 1 个骑手送得快。

复制代码
from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)什么意思啊

这段代码是用 PyTorch 的 DataLoader,将训练数据和测试数据按小批量分组,并在每轮开始时随机打乱顺序,方便高效地进行模型训练和测试。

遍历 DataLoader

我们已经把数据集加载进了 DataLoader,现在可以根据需要对数据集进行迭代(逐批处理)。

下面的每次迭代都会返回一批 train_features(训练特征)和 train_labels(标签),每批包含 64 个样本和对应的标签(即 batch_size=64)。

因为我们设置了 shuffle=True,所以在我们把所有批次迭代完之后,数据会被自动打乱顺序。

(如果你想更精细地控制数据加载的顺序,可以了解一下 PyTorch 的 Sampler 机制。)

Samplers 是 PyTorch 中 更灵活地控制数据加载顺序 的工具。

如果你想自己控制"数据加载顺序"、"打乱方式"、"分组策略"等,就可以用 Sampler 来代替 shuffle=True

Sampler 是一个类,用来控制 DataLoader 在每一轮训练中应该以什么顺序取数据的索引

常见的 Sampler 类型

Sampler 类别 作用
SequentialSampler 按顺序取数据(默认用于 shuffle=False
RandomSampler 随机打乱数据(默认用于 shuffle=True
SubsetRandomSampler 只随机抽样部分数据(适合做验证集)
WeightedRandomSampler 按权重随机抽样(处理数据不平衡)
复制代码
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

从训练集里拿出一批数据,并显示其中一张图片和它的标签

相关推荐
一个真正のman.4 小时前
c加加学习之day01
学习
蔗理苦5 小时前
2025-04-03 Latex学习1——本地配置Latex + VScode环境
ide·vscode·学习·latex
charlie1145141917 小时前
从0开始的构建的天气预报小时钟(基于STM32F407ZGT6,ESP8266 + SSD1309)——第2章——构建简单的ESP8266驱动
stm32·单片机·物联网·学习·c·esp8266
南宫生7 小时前
Java迭代器【设计模式之迭代器模式】
java·学习·设计模式·kotlin·迭代器模式
虾球xz7 小时前
游戏引擎学习第203天
学习·游戏引擎
WDeLiang8 小时前
Flask学习笔记 - 模板渲染
笔记·学习·flask
明月清了个风8 小时前
数据结构与算法学习笔记----贪心区间问题
笔记·学习·算法·贪心算法
因为奋斗超太帅啦8 小时前
MySQL学习笔记(一)——MySQL下载安装配置
笔记·学习·mysql
aoxiang_ywj9 小时前
【Linux】内核驱动学习笔记(二)
linux·笔记·学习
WhyNot?10 小时前
深度学习入门(三):神经网络的学习
深度学习·神经网络·学习