DAY44 Dataset和Dataloader类

目录

[1. Dataset 类:定义"数据是什么"](#1. Dataset 类:定义“数据是什么”)

[2. DataLoader 类:定义"怎么拿数据"](#2. DataLoader 类:定义“怎么拿数据”)

[3. MNIST 手写数字数据集](#3. MNIST 手写数字数据集)

附例:


1. Dataset 类:定义"数据是什么"

Dataset 类是 PyTorch 数据读取的基类,它的核心作用是定义数据的来源、索引方式和预处理逻辑

它要求用户重写两个Python 的特殊方法(Magic Methods),这使得自定义的数据集对象可以像 Python 原生列表(List)一样被操作:

  • __len__(self)

    • 作用:定义数据集的总样本数量。

    • Python 本质 :当你对数据集对象调用 len(dataset) 时,Python 解释器会自动调用这个方法。

    • 重要性DataLoader 需要通过它来计算一个 Epoch 需要遍历多少个 Batch。

  • __getitem__(self, index)

    • 作用 :定义如何获取第 index 个样本

    • Python 本质 :当你使用索引操作 dataset[3] 时,Python 解释器会自动调用这个方法,传入 index=3

    • 内部逻辑:通常包含三个步骤:

      1. 定位 :根据 index 找到对应的文件路径或数据行。

      2. 加载 :读取图片(如 PIL.Image.open)或读取文本/数值。

      3. 变换 (Transform):进行预处理(如归一化、裁剪、转 Tensor)。

    • 返回值 :通常返回一个元组,例如 (image_tensor, label)


2. DataLoader 类:定义"怎么拿数据"

如果说 Dataset 是仓库管理员(知道东西在哪),那么 DataLoader 就是物流配送员(负责打包和运输)。

DataLoader 本身不需要你重写复杂的方法,它通过参数配置来控制加载行为。它的核心职责是将 Dataset 中获取的单个样本 组合成批次(Batch)

  • 核心功能

    • Batching(批处理) :将 __getitem__ 返回的多个样本堆叠(Stack)成一个大 Tensor。例如,将 64 张 [1, 28, 28] 的图片堆叠成 [64, 1, 28, 28]

    • Shuffling(打乱) :在每个 Epoch 开始前打乱数据顺序,防止模型记忆样本顺序(通过 shuffle=True 开启)。

    • Parallelism(多进程) :使用多进程同时读取数据,加速训练(通过 num_workers 参数控制)。

  • Dataset 与 DataLoader 的关系:

    DataLoader 是一个迭代器(Iterator),我们在训练循环 for inputs, labels in dataloader: 中使用的是 DataLoader,而 DataLoader 内部会不断调用 Dataset 的 getitem 方法来凑齐数据。


3. MNIST 手写数字数据集

这是深度学习领域的 "Hello World" 级别的数据集。

  • 数据内容

    • 包含 0~9 的手写数字图片。

    • 训练集:60,000 张。

    • 测试集:10,000 张。

  • 图像特征

    • 尺寸:28 × 28 像素。

    • 通道:单通道(灰度图,Gray Scale)。

    • 数据格式 :通常在 PyTorch 中会被处理为 [1, 28, 28] 的 Tensor。

  • 适用性

    • 由于维度较小(28*28=784特征),它既适合作为结构化数据 用 MLP(多层感知机)训练,也适合作为图像数据用 CNN(卷积神经网络)训练。
  • 可视化

    • 原始数据通常是归一化过的(如均值0.1307,方差0.3081),可视化时需要进行反归一化操作才能看到清晰的黑底白字图像。

总结图解:

复制代码
[硬盘上的文件] 
      ↓ 
Dataset (__getitem__)  --> 负责:找文件 -> 读文件 -> 预处理 -> 返回单条数据 (img, label)
      ↓
DataLoader (batch_size=64) --> 负责:调用64次Dataset -> 打包成一捆 -> 乱序 -> 多进程加速
      ↓
[模型训练循环] (for x, y in loader) --> 获得形状为 [64, 1, 28, 28] 的 Tensor

附例:

cifar 数据集尝试获取其中一张图片

python 复制代码
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# 1. 定义预处理
# 这里我们暂时只转为 Tensor,不进行标准化 (Normalization),
# 这样显示时就不需要反标准化计算,方便直观查看原始图片。
transform = transforms.Compose([
    transforms.ToTensor()
])

# 2. 下载并加载 CIFAR-10 数据集
# root: 数据存放路径
# train: True 表示训练集
# download: True 表示如果本地没有就下载
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

# CIFAR-10 的 10 个类别名称
classes = ('plane', 'car', 'bird', 'cat', 'deer', 
           'dog', 'frog', 'horse', 'ship', 'truck')

# 3. 获取一张图片 (调用 Dataset 的 __getitem__ 方法)
# 我们取第 100 张图片试试
index = 100
image_tensor, label_index = trainset[index]

# 打印一下数据的形状
print(f"数据索引: {index}")
print(f"Tensor 形状: {image_tensor.shape}") # 预期: torch.Size([3, 32, 32])
print(f"标签索引: {label_index} ({classes[label_index]})")

# 4. 可视化函数
def imshow(img_tensor):
    # img_tensor 目前是 [3, 32, 32] (Channel, Height, Width)
    # Matplotlib 画图需要 [32, 32, 3] (Height, Width, Channel)
    
    # 使用 .permute() 交换维度: 0->1, 1->2, 2->0
    # 或者用 numpy.transpose(1, 2, 0)
    img_np = img_tensor.permute(1, 2, 0).numpy()
    
    plt.imshow(img_np)
    plt.axis('off') # 不显示坐标轴
    plt.show()

print("\n显示图片:")
imshow(image_tensor)

@浙大疏锦行

相关推荐
ZH154558913119 小时前
Flutter for OpenHarmony Python学习助手实战:面向对象编程实战的实现
python·学习·flutter
玄同76519 小时前
SQLite + LLM:大模型应用落地的轻量级数据存储方案
jvm·数据库·人工智能·python·语言模型·sqlite·知识图谱
心疼你的一切19 小时前
模态交响:CANN驱动的跨模态AIGC统一架构
数据仓库·深度学习·架构·aigc·cann
User_芊芊君子19 小时前
CANN010:PyASC Python编程接口—简化AI算子开发的Python框架
开发语言·人工智能·python
小羊不会打字19 小时前
CANN 生态中的跨框架兼容桥梁:`onnx-adapter` 项目实现无缝模型迁移
c++·深度学习
白日做梦Q20 小时前
Anchor-free检测器全解析:CenterNet vs FCOS
python·深度学习·神经网络·目标检测·机器学习
饭饭大王66620 小时前
CANN 生态中的自动化测试利器:`test-automation` 项目保障模型部署可靠性
深度学习
island131420 小时前
CANN HIXL 通信库深度解析:单边点对点数据传输、异步模型与异构设备间显存直接访问
人工智能·深度学习·神经网络
喵手20 小时前
Python爬虫实战:公共自行车站点智能采集系统 - 从零构建生产级爬虫的完整实战(附CSV导出 + SQLite持久化存储)!
爬虫·python·爬虫实战·零基础python爬虫教学·采集公共自行车站点·公共自行车站点智能采集系统·采集公共自行车站点导出csv
心疼你的一切20 小时前
解锁CANN仓库核心能力:从零搭建AIGC轻量文本生成实战(附代码+流程图)
数据仓库·深度学习·aigc·流程图·cann