一、数据介绍
CIFAR 是机器学习和计算机视觉领域中广泛使用的图像分类基准数据集,由加拿大高级研究学院(Canadian Institute for Advanced Research,CIFAR)的研究团队发布,主要用于小尺寸图像的分类任务,是入门和验证图像分类模型性能的经典数据集。
1、数据集的核心版本
CIFAR 数据集主要分为两个核心版本,二者在类别复杂度和样本划分上有明显区别:
-
CIFAR-10
- 类别数量:10 个互斥的图像类别,包括飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车。
- 样本规模:共 60000 张 32×32 的彩色 RGB 图像,每个类别包含 6000 张图像;其中 50000 张为训练集(每个类别 5000 张),10000 张为测试集(每个类别 1000 张)。
- 特点:类别间区分度相对清晰,适合入门级图像分类模型的训练和验证。
-
CIFAR-100
- 类别层级:包含 100 个细分类别,同时这些类别又归属于 20 个粗分类别(如 "水生哺乳动物" 包含海狮、海豹等细分类别)。
- 样本规模:同样是 60000 张 32×32 的彩色 RGB 图像,每个细分类别包含 600 张图像;训练集 50000 张(每个细分类别 500 张),测试集 10000 张(每个细分类别 100 张)。
- 特点:类别数量更多且部分类别相似度高,任务难度显著高于 CIFAR-10,常用于验证模型的细粒度分类能力。
2、数据集特点
- 图像尺寸小:32×32 的分辨率远低于真实场景的图像,模型学习到的特征相对有限,容易出现过拟合。
- 数据多样性:图像涵盖了自然和人造物体,且包含不同角度、光照和背景的样本,能一定程度上模拟真实世界的图像分布。
- 无标注噪声:数据集标注质量高,无明显标注错误,适合作为模型性能的客观基准。
二、实例化
1. 数据预处理
python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块
import matplotlib.pyplot as plt
# 设置随机种子,确保结果可复现
torch.manual_seed(42)
# 1. 数据预处理,该写法非常类似于管道pipeline
# transforms 模块提供了一系列常用的图像预处理操作
# 1. 数据预处理:先归一化到[0,1],再标准化(适配CIFAR-10的3通道参数)
transform = transforms.Compose([
transforms.ToTensor(), # 转换为张量,将HWC格式的PIL图转为CHW格式,同时归一化到[0,1]
# CIFAR-10经典的均值和标准差(3通道,对应RGB)
transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
std=(0.2470, 0.2435, 0.2616))
])
# 2. 加载CIFAR-10数据集(修改root为压缩包所在文件夹)
train_dataset = datasets.CIFAR10(
root=r"D:\PythonStudy", # 压缩包所在的文件夹路径(关键!不是压缩包本身)
train=True,
download=True, # 检测到文件夹内有压缩包时,仅解压不下载
transform=transform
)
test_dataset = datasets.CIFAR10(
root=r"D:\PythonStudy", # 同训练集的root路径
train=False,
transform=transform
)
# 可选:验证数据集基本信息
print(f"CIFAR-10训练集数量:{len(train_dataset)}")
print(f"CIFAR-10测试集数量:{len(test_dataset)}")
print(f"单张图片形状:{train_dataset[0][0].shape}") # 输出 torch.Size([3, 32, 32])
2. Dataset类
现在我们想要取出来一个图片,看看长啥样,因为datasets.CIFAR10(若使用 CIFAR-100 则为datasets.CIFAR100)本质上继承了torch.utils.data.Dataset,所以自然需要有对应的方法。
python
import matplotlib.pyplot as plt
# 随机选择一张图片,可以重复运行,每次都会随机选择
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
image, label = train_dataset[sample_idx] # 获取图片和标签
PyTorch 的torch.utils.data.Dataset是一个抽象基类,所有自定义数据集都需要继承它并实现两个核心方法:
len():返回数据集的样本总数。
getitem(idx):根据索引idx返回对应样本的数据和标签。
PyTorch 要求所有数据集必须实现__getitem__和__len__,这样才能被DataLoader等工具兼容。这是一种接口约定,类似函数参数的规范。这意味着,如果你要创建一个自定义数据集,你需要实现这两个方法,否则PyTorch将无法识别你的数据集。
在 Python 中,getitem__和__len 是类的特殊方法(也叫魔术方法 ),它们不是像普通函数那样直接使用,而是需要在自定义类中进行定义,来赋予类特定的行为。以下是关于这两个方法具体的使用方式:
1.__getitem__方法
__getitem__方法用于让对象支持索引操作,当使用[]语法访问对象元素时,Python 会自动调用该方法。
通过定义__getitem__方法,让MyList类的实例能够像 Python 内置的列表一样使用索引获取元素。
2.__len__方法__len__方法用于返回对象中元素的数量,当使用内置函数len()作用于对象时,Python 会自动调用该方法。
python
from torch.utils.data import Dataset
# CIFAR-10数据集的简化版本
class CIFAR10(Dataset):
def __init__(self, root, train=True, transform=None):
# 初始化:加载图片路径和标签
self.data, self.targets = fetch_cifar10_data(root, train) 的函数
self.transform = transform # 预处理操作
def __len__(self):
return len(self.data) # 返回样本总数
def __getitem__(self, idx): # 获取指定索引的样本
# 获取指定索引的图像和标签
img, target = self.data[idx], self.targets[idx]
# 应用图像预处理(如ToTensor、Normalize)
if self.transform is not None: # 如果有预处理操作
img = self.transform(img) # 转换图像格式
# 这里假设 img 是一个 PIL 图像对象(CIFAR-10为RGB格式),transform 会将其转换为张量并进行归一化/标准化
return img, target # 返回处理后的图像和标签
1. 模块导入:
from torch.utils.data import Dataset
- 作用 :
Dataset是 PyTorch 中所有自定义数据集的抽象基类 ,自定义数据集必须继承它,且必须实现__len__和__getitem__两个核心方法,否则会报错。- 为什么必须继承 :PyTorch 的
DataLoader(数据加载器)只能处理继承自Dataset的类,这是 PyTorch 数据加载流水线的核心规范。2. 类定义:
class CIFAR10(Dataset)
- 作用 :定义专用于 CIFAR-10 的数据集类,命名贴合数据集类型(若适配 CIFAR-100 则改为
CIFAR100),继承Dataset后具备 PyTorch 数据集的标准特性。3.
__init__方法(初始化)
代码行 核心作用 细节说明 self.data, self.targets = fetch_cifar10_data(root, train)加载数据和标签 ① fetch_cifar10_data是自定义数据加载函数,作用是从root路径读取 CIFAR-10 的图像(RGB PIL 格式)和标签;②train参数区分训练集 / 测试集,逻辑与 MNIST 版一致;③ 最终self.data存储所有图像,self.targets存储所有标签。self.transform = transform接收预处理逻辑 ① 保存外部传入的预处理操作(如 ToTensor()、Normalize());② 预处理逻辑与 MNIST 版通用,仅 CIFAR-10 需适配 3 通道的归一化参数;③ 允许外部灵活调整预处理(如增广、归一化等),不修改数据集类本身。4.
__len__方法(返回样本总数)
- 代码 :
return len(self.data)- 作用 :① 告诉
DataLoader数据集的总样本数,是批量采样、打乱、多进程加载的基础;② 调用len(train_dataset)时会触发该方法,比如验证数据集大小(如 CIFAR-10 训练集返回 50000,测试集返回 10000);③ 逻辑与 MNIST 版完全一致,仅self.data的长度对应 CIFAR-10 的样本数。5.
__getitem__方法(按索引取样本,核心)这是数据集类的核心方法 ,
DataLoader迭代时会反复调用该方法获取单个样本,拆解作用如下:
代码行 核心作用 细节说明 img, target = self.data[idx], self.targets[idx]读取原始样本 ① idx是传入的样本索引(0~len (data)-1);②img是原始 RGB PIL 图像(32×32×3),target是 0-9 的整数标签(对应 CIFAR-10 的类别);③ 逻辑与 MNIST 版一致,仅img从单通道灰度图变为 3 通道 RGB 图。if self.transform is not None: img = self.transform(img)应用预处理 ① 对原始 PIL 图像执行预处理(如 ToTensor()转为 C×H×W 张量、Normalize()标准化);② 预处理是可选的(若transform=None则返回原始 PIL 图像);③ 逻辑与 MNIST 版完全一致,仅预处理参数适配 CIFAR-10 的 3 通道。return img, target返回样本 ① 标准返回格式:(图像张量,标签),是 PyTorch 模型训练 / 测试的输入格式;② 图像张量形状为 [3, 32, 32](CIFAR-10),MNIST 版为[1, 28, 28],仅形状差异,返回逻辑不变。
python
# 3. 取出单个样本
image, label = train_dataset[5]
# 4. 修复后的可视化函数(核心修改)
def imshow(img):
# 关键:将均值/标准差转为PyTorch张量,并调整维度为[3,1,1]适配广播
std = torch.tensor([0.2470, 0.2435, 0.2616]).view(3, 1, 1)
mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
# 反标准化(张量间运算,支持广播)
img = img * std + mean
# 限制像素值在[0,1](避免数值溢出导致显示异常)
img = torch.clamp(img, 0, 1)
# 转为numpy并调整维度(C×H×W → H×W×C)
npimg = img.numpy()
plt.imshow(npimg.transpose((1, 2, 0)))
plt.axis('off') # 隐藏坐标轴
plt.show()
# 5. 调用可视化(逻辑不变)
print(f"Label: {label}")
imshow(image)

3. Dataloader类
DataLoader是 PyTorch 封装的批量数据加载器 ,核心作用是将Dataset(数据集,定义了 "如何取单个样本")封装为 "批量迭代器",适配模型训练 / 测试时的批量输入需求,同时支持打乱、多进程加载等优化。
python
# 3. 创建数据加载器
train_loader = DataLoader(
train_dataset,
batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关
shuffle=True # 随机打乱数据
)
test_loader = DataLoader(
test_dataset,
batch_size=1000 # 每个批次1000张图片
# shuffle=False # 测试时不需要打乱数据
)
训练集加载器(train_loader)
基于训练数据集
train_dataset创建批量加载器:
batch_size=64:每次加载 64 张图片为一个批次(选 2 的幂次方适配 GPU 计算效率);shuffle=True:训练前随机打乱数据顺序,避免模型学 "数据顺序",提升泛化能力。测试集加载器(test_loader)
基于测试数据集
test_dataset创建批量加载器:
batch_size=1000:测试集批次更大(无训练开销,提速评估);- 默认
shuffle=False:测试不打乱数据,保证结果可复现,也节省计算开销。核心:
DataLoader把数据集封装成 "批量迭代器",训练侧重打乱 + 小批次适配 GPU,测试侧重大批次提速 + 不打乱保复现。
三、总结

核心结论
- Dataset 类:定义数据的内容和格式(即 "如何获取单个样本"),包括:
- 数据存储路径 / 来源(如文件路径、数据库查询)。
- 原始数据的读取方式(如图像解码为 PIL 对象、文本读取为字符串)。
- 样本的预处理逻辑(如裁剪、翻转、归一化等,通常通过 transform 参数实现)。
- 返回值格式(如 (image_tensor, label))。
- DataLoader 类:定义数据的加载方式和批量处理逻辑(即 "如何高效批量获取数据"),包括:
- 批量大小(batch_size)。
- 是否打乱数据顺序(shuffle)。
勇闯python的第39天@浙大疏锦行