1. Dataset 类:"存数据的容器"
你可以把它理解成一个数据盒子,里面装着你的数据集(比如图片、标签)。要让这个 "盒子" 能用,得给它加两个 "功能按钮"(Python 的特殊方法):
__getitem__(self, idx):按索引idx取数据(比如dataset[0]就能拿到第 1 个样本)。__len__(self):返回数据集的总样本数(比如len(dataset)知道一共有多少数据)。
2. DataLoader 类:"给模型端菜的服务员"
Dataset 是 "装菜的盘子",DataLoader 就是 "把菜分成小份、端给模型吃" 的人。它的作用是:
- 把
Dataset里的数据分成批次(比如一次给模型喂 32 个样本,而不是全塞进去)。 - 支持打乱数据(避免模型学 "顺序" 而不是 "规律")。
- 支持多线程加载(加快数据读取速度)。
3. MNIST 手写数据集
这是一个 "手写数字图片集",里面是 0-9 的手写数字(每张图是 28x28 的黑白色),是深度学习入门常用的 "练习数据集"。
作业:获取 CIFAR 数据集的一张图片
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
# 1. 加载CIFAR数据集(自动下载到本地)
transform = transforms.ToTensor() # 把图片转成PyTorch能处理的格式
cifar_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# 2. 取第1张图和它的标签
img, label = cifar_dataset[0] # 用Dataset的__getitem__取数据
# 3. 显示图片
plt.imshow(img.permute(1, 2, 0)) # 调整格式(PyTorch是[通道,高,宽],plt需要[高,宽,通道])
plt.title(f"Label: {label}") # 显示标签(CIFAR10的标签是0-9,对应不同类别)
plt.show()