DAY 39 Dataset和Dataloader

一、数据介绍

CIFAR 是机器学习和计算机视觉领域中广泛使用的图像分类基准数据集,由加拿大高级研究学院(Canadian Institute for Advanced Research,CIFAR)的研究团队发布,主要用于小尺寸图像的分类任务,是入门和验证图像分类模型性能的经典数据集。

1、数据集的核心版本

CIFAR 数据集主要分为两个核心版本,二者在类别复杂度和样本划分上有明显区别:

  1. CIFAR-10

    • 类别数量:10 个互斥的图像类别,包括飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车。
    • 样本规模:共 60000 张 32×32 的彩色 RGB 图像,每个类别包含 6000 张图像;其中 50000 张为训练集(每个类别 5000 张),10000 张为测试集(每个类别 1000 张)。
    • 特点:类别间区分度相对清晰,适合入门级图像分类模型的训练和验证。
  2. CIFAR-100

    • 类别层级:包含 100 个细分类别,同时这些类别又归属于 20 个粗分类别(如 "水生哺乳动物" 包含海狮、海豹等细分类别)。
    • 样本规模:同样是 60000 张 32×32 的彩色 RGB 图像,每个细分类别包含 600 张图像;训练集 50000 张(每个细分类别 500 张),测试集 10000 张(每个细分类别 100 张)。
    • 特点:类别数量更多且部分类别相似度高,任务难度显著高于 CIFAR-10,常用于验证模型的细粒度分类能力。

2、数据集特点

  1. 图像尺寸小:32×32 的分辨率远低于真实场景的图像,模型学习到的特征相对有限,容易出现过拟合。
  2. 数据多样性:图像涵盖了自然和人造物体,且包含不同角度、光照和背景的样本,能一定程度上模拟真实世界的图像分布。
  3. 无标注噪声:数据集标注质量高,无明显标注错误,适合作为模型性能的客观基准。

二、实例化

1. 数据预处理

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块
import matplotlib.pyplot as plt

# 设置随机种子,确保结果可复现
torch.manual_seed(42)

# 1. 数据预处理,该写法非常类似于管道pipeline
# transforms 模块提供了一系列常用的图像预处理操作

# 1. 数据预处理:先归一化到[0,1],再标准化(适配CIFAR-10的3通道参数)
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为张量,将HWC格式的PIL图转为CHW格式,同时归一化到[0,1]
    # CIFAR-10经典的均值和标准差(3通道,对应RGB)
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), 
                         std=(0.2470, 0.2435, 0.2616))  
])

# 2. 加载CIFAR-10数据集(修改root为压缩包所在文件夹)
train_dataset = datasets.CIFAR10(
    root=r"D:\PythonStudy",  # 压缩包所在的文件夹路径(关键!不是压缩包本身)
    train=True,             
    download=True,           # 检测到文件夹内有压缩包时,仅解压不下载
    transform=transform     
)

test_dataset = datasets.CIFAR10(
    root=r"D:\PythonStudy",  # 同训练集的root路径
    train=False,            
    transform=transform
)

# 可选:验证数据集基本信息
print(f"CIFAR-10训练集数量:{len(train_dataset)}")
print(f"CIFAR-10测试集数量:{len(test_dataset)}")
print(f"单张图片形状:{train_dataset[0][0].shape}")  # 输出 torch.Size([3, 32, 32])

2. Dataset类

现在我们想要取出来一个图片,看看长啥样,因为datasets.CIFAR10(若使用 CIFAR-100 则为datasets.CIFAR100)本质上继承了torch.utils.data.Dataset,所以自然需要有对应的方法。

python 复制代码
import matplotlib.pyplot as plt

# 随机选择一张图片,可以重复运行,每次都会随机选择
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
image, label = train_dataset[sample_idx] # 获取图片和标签

PyTorch 的torch.utils.data.Dataset是一个抽象基类,所有自定义数据集都需要继承它并实现两个核心方法:

  • len():返回数据集的样本总数。

  • getitem(idx):根据索引idx返回对应样本的数据和标签。

PyTorch 要求所有数据集必须实现__getitem__和__len__,这样才能被DataLoader等工具兼容。这是一种接口约定,类似函数参数的规范。这意味着,如果你要创建一个自定义数据集,你需要实现这两个方法,否则PyTorch将无法识别你的数据集。

在 Python 中,getitem__和__len 是类的特殊方法(也叫魔术方法 ),它们不是像普通函数那样直接使用,而是需要在自定义类中进行定义,来赋予类特定的行为。以下是关于这两个方法具体的使用方式:

1.__getitem__方法

__getitem__方法用于让对象支持索引操作,当使用[]语法访问对象元素时,Python 会自动调用该方法。

通过定义__getitem__方法,让MyList类的实例能够像 Python 内置的列表一样使用索引获取元素。
2.__len__方法

__len__方法用于返回对象中元素的数量,当使用内置函数len()作用于对象时,Python 会自动调用该方法。

python 复制代码
from torch.utils.data import Dataset

# CIFAR-10数据集的简化版本
class CIFAR10(Dataset):
    def __init__(self, root, train=True, transform=None):
        # 初始化:加载图片路径和标签
        self.data, self.targets = fetch_cifar10_data(root, train) 的函数
        self.transform = transform  # 预处理操作
        
    def __len__(self): 
        return len(self.data)  # 返回样本总数
    
    def __getitem__(self, idx):  # 获取指定索引的样本
        # 获取指定索引的图像和标签
        img, target = self.data[idx], self.targets[idx]
        
        # 应用图像预处理(如ToTensor、Normalize)
        if self.transform is not None:  # 如果有预处理操作
            img = self.transform(img)  # 转换图像格式
        # 这里假设 img 是一个 PIL 图像对象(CIFAR-10为RGB格式),transform 会将其转换为张量并进行归一化/标准化
            
        return img, target  # 返回处理后的图像和标签

1. 模块导入:from torch.utils.data import Dataset

  • 作用Dataset 是 PyTorch 中所有自定义数据集的抽象基类 ,自定义数据集必须继承它,且必须实现 __len____getitem__ 两个核心方法,否则会报错。
  • 为什么必须继承 :PyTorch 的 DataLoader(数据加载器)只能处理继承自 Dataset 的类,这是 PyTorch 数据加载流水线的核心规范。

2. 类定义:class CIFAR10(Dataset)

  • 作用 :定义专用于 CIFAR-10 的数据集类,命名贴合数据集类型(若适配 CIFAR-100 则改为CIFAR100),继承Dataset后具备 PyTorch 数据集的标准特性。

3. __init__ 方法(初始化)

代码行 核心作用 细节说明
self.data, self.targets = fetch_cifar10_data(root, train) 加载数据和标签 fetch_cifar10_data 是自定义数据加载函数,作用是从root路径读取 CIFAR-10 的图像(RGB PIL 格式)和标签;② train参数区分训练集 / 测试集,逻辑与 MNIST 版一致;③ 最终self.data存储所有图像,self.targets存储所有标签。
self.transform = transform 接收预处理逻辑 ① 保存外部传入的预处理操作(如ToTensor()Normalize());② 预处理逻辑与 MNIST 版通用,仅 CIFAR-10 需适配 3 通道的归一化参数;③ 允许外部灵活调整预处理(如增广、归一化等),不修改数据集类本身。

4. __len__ 方法(返回样本总数)

  • 代码return len(self.data)
  • 作用 :① 告诉DataLoader数据集的总样本数,是批量采样、打乱、多进程加载的基础;② 调用len(train_dataset)时会触发该方法,比如验证数据集大小(如 CIFAR-10 训练集返回 50000,测试集返回 10000);③ 逻辑与 MNIST 版完全一致,仅self.data的长度对应 CIFAR-10 的样本数。

5. __getitem__ 方法(按索引取样本,核心)

这是数据集类的核心方法DataLoader 迭代时会反复调用该方法获取单个样本,拆解作用如下:

代码行 核心作用 细节说明
img, target = self.data[idx], self.targets[idx] 读取原始样本 idx是传入的样本索引(0~len (data)-1);② img是原始 RGB PIL 图像(32×32×3),target是 0-9 的整数标签(对应 CIFAR-10 的类别);③ 逻辑与 MNIST 版一致,仅img从单通道灰度图变为 3 通道 RGB 图。
if self.transform is not None: img = self.transform(img) 应用预处理 ① 对原始 PIL 图像执行预处理(如ToTensor()转为 C×H×W 张量、Normalize()标准化);② 预处理是可选的(若transform=None则返回原始 PIL 图像);③ 逻辑与 MNIST 版完全一致,仅预处理参数适配 CIFAR-10 的 3 通道。
return img, target 返回样本 ① 标准返回格式:(图像张量,标签),是 PyTorch 模型训练 / 测试的输入格式;② 图像张量形状为[3, 32, 32](CIFAR-10),MNIST 版为[1, 28, 28],仅形状差异,返回逻辑不变。
python 复制代码
# 3. 取出单个样本
image, label = train_dataset[5]

# 4. 修复后的可视化函数(核心修改)
def imshow(img):
    # 关键:将均值/标准差转为PyTorch张量,并调整维度为[3,1,1]适配广播
    std = torch.tensor([0.2470, 0.2435, 0.2616]).view(3, 1, 1)
    mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
    
    # 反标准化(张量间运算,支持广播)
    img = img * std + mean  
    # 限制像素值在[0,1](避免数值溢出导致显示异常)
    img = torch.clamp(img, 0, 1)
    
    # 转为numpy并调整维度(C×H×W → H×W×C)
    npimg = img.numpy()
    plt.imshow(npimg.transpose((1, 2, 0)))
    plt.axis('off')  # 隐藏坐标轴
    plt.show()

# 5. 调用可视化(逻辑不变)
print(f"Label: {label}")
imshow(image)

3. Dataloader类

DataLoader是 PyTorch 封装的批量数据加载器 ,核心作用是将Dataset(数据集,定义了 "如何取单个样本")封装为 "批量迭代器",适配模型训练 / 测试时的批量输入需求,同时支持打乱、多进程加载等优化。

python 复制代码
# 3. 创建数据加载器
train_loader = DataLoader(
    train_dataset,
    batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关
    shuffle=True # 随机打乱数据
)

test_loader = DataLoader(
    test_dataset,
    batch_size=1000 # 每个批次1000张图片
    # shuffle=False # 测试时不需要打乱数据
)

训练集加载器(train_loader)

基于训练数据集train_dataset创建批量加载器:

  • batch_size=64:每次加载 64 张图片为一个批次(选 2 的幂次方适配 GPU 计算效率);
  • shuffle=True:训练前随机打乱数据顺序,避免模型学 "数据顺序",提升泛化能力。

测试集加载器(test_loader)

基于测试数据集test_dataset创建批量加载器:

  • batch_size=1000:测试集批次更大(无训练开销,提速评估);
  • 默认shuffle=False:测试不打乱数据,保证结果可复现,也节省计算开销。

核心:DataLoader把数据集封装成 "批量迭代器",训练侧重打乱 + 小批次适配 GPU,测试侧重大批次提速 + 不打乱保复现。

三、总结

核心结论

  • Dataset 类:定义数据的内容和格式(即 "如何获取单个样本"),包括:
    • 数据存储路径 / 来源(如文件路径、数据库查询)。
    • 原始数据的读取方式(如图像解码为 PIL 对象、文本读取为字符串)。
    • 样本的预处理逻辑(如裁剪、翻转、归一化等,通常通过 transform 参数实现)。
    • 返回值格式(如 (image_tensor, label))。
  • DataLoader 类:定义数据的加载方式和批量处理逻辑(即 "如何高效批量获取数据"),包括:
    • 批量大小(batch_size)。
    • 是否打乱数据顺序(shuffle)。

勇闯python的第39天@浙大疏锦行

相关推荐
冬奇Lab21 分钟前
一天一个开源项目(第36篇):EverMemOS - 跨 LLM 与平台的长时记忆 OS,让 Agent 会记忆更会推理
人工智能·开源·资讯
冬奇Lab22 分钟前
OpenClaw 源码深度解析(一):Gateway——为什么需要一个"中枢"
人工智能·开源·源码阅读
AngelPP4 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年4 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼4 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS5 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区6 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈6 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang6 小时前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx
shengjk18 小时前
NanoClaw 深度剖析:一个"AI 原生"架构的个人助手是如何运转的?
人工智能