DAY 38 Dataset 和 Dataloader 类
知识点回顾:
- Dataset 类的__getitem__和__len__方法(本质是 python 的特殊方法)
- Dataloader 类
- minist 手写数据集的了解
作业:了解下 cifar 数据集,尝试获取其中一张图片
这份笔记非常清晰地梳理了 PyTorch 数据处理的核心双雄:Dataset 和 DataLoader。通过"厨师与服务员"的比喻,直观地解释了个体样本处理 与批量物流分发的关系。
以下是针对 DAY 38 的排版整理以及为您准备的 CIFAR-10 探究作业指南。
📘 DAY 38:Dataset 与 DataLoader 类
1. 核心概念对比
在处理大规模数据(如图像、长文本)时,内存无法一次性承载,必须采用"分而治之"的策略。
| 维度 | Dataset (厨师) |
DataLoader (服务员) |
|---|---|---|
| 核心职责 | 定义"数据是什么"及"如何获取单菜品" | 定义"如何打包"及"上菜策略" |
| 魔术方法 | __len__: 告知总数 |
__getitem__: 按索引取样 | (内部调用 Dataset 的方法) |
| 关键参数 | root, transform (预处理) | batch_size, shuffle, num_workers |
| 并行计算 | 无 | 支持多进程加载,提高 GPU 利用率 |
2. MNIST 数据集详解
- 规模:60,000 张训练图,10,000 张测试图。
- 规格: 像素,单通道灰度图。
- 类别:0-9 共 10 个数字。
- 特性:由于维度适中,它是验证 MLP(多层感知机)和 CNN(卷积神经网络)的"标准试金石"。
3. 预处理的"反逻辑"
在 PyTorch 中,预处理通常在加载阶段 通过 transform 管道完成。
python
# 典型组合:转为张量 -> 归一化
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
- 注意 :
Dataset并不预先处理好所有图片存储在内存,而是在DataLoader迭代时,通过__getitem__实时进行变换,这极大地节省了内存空间。
📝 今日作业:探索 CIFAR-10 数据集
任务描述 :
了解 CIFAR-10 数据集并尝试获取、显示其中一张图片。
1. 什么是 CIFAR-10?
与纯黑白的 MNIST 不同,CIFAR-10 是更接近现实场景的数据集:
- 内容:包含 10 个类别的彩色汽车、飞机、猫、狗等。
- 规格 : 像素,RGB 三通道彩色图。
- 难度:由于背景复杂、物体姿态各异,识别难度显著高于 MNIST。
2. 作业代码实现
python
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# 1. 定义基本的转换(彩色图像归一化通常使用 0.5)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 2. 下载并加载 CIFAR-10 训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
# 3. 定义类别名称(CIFAR-10 的标签是数字 0-9,对应以下名称)
classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')
# 4. 获取一张随机图片
idx = torch.randint(0, len(trainset), (1,)).item()
image, label = trainset[idx]
# 5. 可视化
def imshow_cifar(img):
img = img / 2 + 0.5 # 反归一化
npimg = img.numpy()
# 注意:PyTorch 张量是 [C, H, W],但 Matplotlib 需要 [H, W, C]
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.title(f"Label: {classes[label]}")
plt.show()
imshow_cifar(image)