DAY 39 Dataset和Dataloader

一、数据介绍

CIFAR 是机器学习和计算机视觉领域中广泛使用的图像分类基准数据集,由加拿大高级研究学院(Canadian Institute for Advanced Research,CIFAR)的研究团队发布,主要用于小尺寸图像的分类任务,是入门和验证图像分类模型性能的经典数据集。

1、数据集的核心版本

CIFAR 数据集主要分为两个核心版本,二者在类别复杂度和样本划分上有明显区别:

  1. CIFAR-10

    • 类别数量:10 个互斥的图像类别,包括飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车。
    • 样本规模:共 60000 张 32×32 的彩色 RGB 图像,每个类别包含 6000 张图像;其中 50000 张为训练集(每个类别 5000 张),10000 张为测试集(每个类别 1000 张)。
    • 特点:类别间区分度相对清晰,适合入门级图像分类模型的训练和验证。
  2. CIFAR-100

    • 类别层级:包含 100 个细分类别,同时这些类别又归属于 20 个粗分类别(如 "水生哺乳动物" 包含海狮、海豹等细分类别)。
    • 样本规模:同样是 60000 张 32×32 的彩色 RGB 图像,每个细分类别包含 600 张图像;训练集 50000 张(每个细分类别 500 张),测试集 10000 张(每个细分类别 100 张)。
    • 特点:类别数量更多且部分类别相似度高,任务难度显著高于 CIFAR-10,常用于验证模型的细粒度分类能力。

2、数据集特点

  1. 图像尺寸小:32×32 的分辨率远低于真实场景的图像,模型学习到的特征相对有限,容易出现过拟合。
  2. 数据多样性:图像涵盖了自然和人造物体,且包含不同角度、光照和背景的样本,能一定程度上模拟真实世界的图像分布。
  3. 无标注噪声:数据集标注质量高,无明显标注错误,适合作为模型性能的客观基准。

二、实例化

1. 数据预处理

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块
import matplotlib.pyplot as plt

# 设置随机种子,确保结果可复现
torch.manual_seed(42)

# 1. 数据预处理,该写法非常类似于管道pipeline
# transforms 模块提供了一系列常用的图像预处理操作

# 1. 数据预处理:先归一化到[0,1],再标准化(适配CIFAR-10的3通道参数)
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为张量,将HWC格式的PIL图转为CHW格式,同时归一化到[0,1]
    # CIFAR-10经典的均值和标准差(3通道,对应RGB)
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), 
                         std=(0.2470, 0.2435, 0.2616))  
])

# 2. 加载CIFAR-10数据集(修改root为压缩包所在文件夹)
train_dataset = datasets.CIFAR10(
    root=r"D:\PythonStudy",  # 压缩包所在的文件夹路径(关键!不是压缩包本身)
    train=True,             
    download=True,           # 检测到文件夹内有压缩包时,仅解压不下载
    transform=transform     
)

test_dataset = datasets.CIFAR10(
    root=r"D:\PythonStudy",  # 同训练集的root路径
    train=False,            
    transform=transform
)

# 可选:验证数据集基本信息
print(f"CIFAR-10训练集数量:{len(train_dataset)}")
print(f"CIFAR-10测试集数量:{len(test_dataset)}")
print(f"单张图片形状:{train_dataset[0][0].shape}")  # 输出 torch.Size([3, 32, 32])

2. Dataset类

现在我们想要取出来一个图片,看看长啥样,因为datasets.CIFAR10(若使用 CIFAR-100 则为datasets.CIFAR100)本质上继承了torch.utils.data.Dataset,所以自然需要有对应的方法。

python 复制代码
import matplotlib.pyplot as plt

# 随机选择一张图片,可以重复运行,每次都会随机选择
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
image, label = train_dataset[sample_idx] # 获取图片和标签

PyTorch 的torch.utils.data.Dataset是一个抽象基类,所有自定义数据集都需要继承它并实现两个核心方法:

  • len():返回数据集的样本总数。

  • getitem(idx):根据索引idx返回对应样本的数据和标签。

PyTorch 要求所有数据集必须实现__getitem__和__len__,这样才能被DataLoader等工具兼容。这是一种接口约定,类似函数参数的规范。这意味着,如果你要创建一个自定义数据集,你需要实现这两个方法,否则PyTorch将无法识别你的数据集。

在 Python 中,getitem__和__len 是类的特殊方法(也叫魔术方法 ),它们不是像普通函数那样直接使用,而是需要在自定义类中进行定义,来赋予类特定的行为。以下是关于这两个方法具体的使用方式:

1.__getitem__方法

__getitem__方法用于让对象支持索引操作,当使用[]语法访问对象元素时,Python 会自动调用该方法。

通过定义__getitem__方法,让MyList类的实例能够像 Python 内置的列表一样使用索引获取元素。
2.__len__方法

__len__方法用于返回对象中元素的数量,当使用内置函数len()作用于对象时,Python 会自动调用该方法。

python 复制代码
from torch.utils.data import Dataset

# CIFAR-10数据集的简化版本
class CIFAR10(Dataset):
    def __init__(self, root, train=True, transform=None):
        # 初始化:加载图片路径和标签
        self.data, self.targets = fetch_cifar10_data(root, train) 的函数
        self.transform = transform  # 预处理操作
        
    def __len__(self): 
        return len(self.data)  # 返回样本总数
    
    def __getitem__(self, idx):  # 获取指定索引的样本
        # 获取指定索引的图像和标签
        img, target = self.data[idx], self.targets[idx]
        
        # 应用图像预处理(如ToTensor、Normalize)
        if self.transform is not None:  # 如果有预处理操作
            img = self.transform(img)  # 转换图像格式
        # 这里假设 img 是一个 PIL 图像对象(CIFAR-10为RGB格式),transform 会将其转换为张量并进行归一化/标准化
            
        return img, target  # 返回处理后的图像和标签

1. 模块导入:from torch.utils.data import Dataset

  • 作用Dataset 是 PyTorch 中所有自定义数据集的抽象基类 ,自定义数据集必须继承它,且必须实现 __len____getitem__ 两个核心方法,否则会报错。
  • 为什么必须继承 :PyTorch 的 DataLoader(数据加载器)只能处理继承自 Dataset 的类,这是 PyTorch 数据加载流水线的核心规范。

2. 类定义:class CIFAR10(Dataset)

  • 作用 :定义专用于 CIFAR-10 的数据集类,命名贴合数据集类型(若适配 CIFAR-100 则改为CIFAR100),继承Dataset后具备 PyTorch 数据集的标准特性。

3. __init__ 方法(初始化)

代码行 核心作用 细节说明
self.data, self.targets = fetch_cifar10_data(root, train) 加载数据和标签 fetch_cifar10_data 是自定义数据加载函数,作用是从root路径读取 CIFAR-10 的图像(RGB PIL 格式)和标签;② train参数区分训练集 / 测试集,逻辑与 MNIST 版一致;③ 最终self.data存储所有图像,self.targets存储所有标签。
self.transform = transform 接收预处理逻辑 ① 保存外部传入的预处理操作(如ToTensor()Normalize());② 预处理逻辑与 MNIST 版通用,仅 CIFAR-10 需适配 3 通道的归一化参数;③ 允许外部灵活调整预处理(如增广、归一化等),不修改数据集类本身。

4. __len__ 方法(返回样本总数)

  • 代码return len(self.data)
  • 作用 :① 告诉DataLoader数据集的总样本数,是批量采样、打乱、多进程加载的基础;② 调用len(train_dataset)时会触发该方法,比如验证数据集大小(如 CIFAR-10 训练集返回 50000,测试集返回 10000);③ 逻辑与 MNIST 版完全一致,仅self.data的长度对应 CIFAR-10 的样本数。

5. __getitem__ 方法(按索引取样本,核心)

这是数据集类的核心方法DataLoader 迭代时会反复调用该方法获取单个样本,拆解作用如下:

代码行 核心作用 细节说明
img, target = self.data[idx], self.targets[idx] 读取原始样本 idx是传入的样本索引(0~len (data)-1);② img是原始 RGB PIL 图像(32×32×3),target是 0-9 的整数标签(对应 CIFAR-10 的类别);③ 逻辑与 MNIST 版一致,仅img从单通道灰度图变为 3 通道 RGB 图。
if self.transform is not None: img = self.transform(img) 应用预处理 ① 对原始 PIL 图像执行预处理(如ToTensor()转为 C×H×W 张量、Normalize()标准化);② 预处理是可选的(若transform=None则返回原始 PIL 图像);③ 逻辑与 MNIST 版完全一致,仅预处理参数适配 CIFAR-10 的 3 通道。
return img, target 返回样本 ① 标准返回格式:(图像张量,标签),是 PyTorch 模型训练 / 测试的输入格式;② 图像张量形状为[3, 32, 32](CIFAR-10),MNIST 版为[1, 28, 28],仅形状差异,返回逻辑不变。
python 复制代码
# 3. 取出单个样本
image, label = train_dataset[5]

# 4. 修复后的可视化函数(核心修改)
def imshow(img):
    # 关键:将均值/标准差转为PyTorch张量,并调整维度为[3,1,1]适配广播
    std = torch.tensor([0.2470, 0.2435, 0.2616]).view(3, 1, 1)
    mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
    
    # 反标准化(张量间运算,支持广播)
    img = img * std + mean  
    # 限制像素值在[0,1](避免数值溢出导致显示异常)
    img = torch.clamp(img, 0, 1)
    
    # 转为numpy并调整维度(C×H×W → H×W×C)
    npimg = img.numpy()
    plt.imshow(npimg.transpose((1, 2, 0)))
    plt.axis('off')  # 隐藏坐标轴
    plt.show()

# 5. 调用可视化(逻辑不变)
print(f"Label: {label}")
imshow(image)

3. Dataloader类

DataLoader是 PyTorch 封装的批量数据加载器 ,核心作用是将Dataset(数据集,定义了 "如何取单个样本")封装为 "批量迭代器",适配模型训练 / 测试时的批量输入需求,同时支持打乱、多进程加载等优化。

python 复制代码
# 3. 创建数据加载器
train_loader = DataLoader(
    train_dataset,
    batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关
    shuffle=True # 随机打乱数据
)

test_loader = DataLoader(
    test_dataset,
    batch_size=1000 # 每个批次1000张图片
    # shuffle=False # 测试时不需要打乱数据
)

训练集加载器(train_loader)

基于训练数据集train_dataset创建批量加载器:

  • batch_size=64:每次加载 64 张图片为一个批次(选 2 的幂次方适配 GPU 计算效率);
  • shuffle=True:训练前随机打乱数据顺序,避免模型学 "数据顺序",提升泛化能力。

测试集加载器(test_loader)

基于测试数据集test_dataset创建批量加载器:

  • batch_size=1000:测试集批次更大(无训练开销,提速评估);
  • 默认shuffle=False:测试不打乱数据,保证结果可复现,也节省计算开销。

核心:DataLoader把数据集封装成 "批量迭代器",训练侧重打乱 + 小批次适配 GPU,测试侧重大批次提速 + 不打乱保复现。

三、总结

核心结论

  • Dataset 类:定义数据的内容和格式(即 "如何获取单个样本"),包括:
    • 数据存储路径 / 来源(如文件路径、数据库查询)。
    • 原始数据的读取方式(如图像解码为 PIL 对象、文本读取为字符串)。
    • 样本的预处理逻辑(如裁剪、翻转、归一化等,通常通过 transform 参数实现)。
    • 返回值格式(如 (image_tensor, label))。
  • DataLoader 类:定义数据的加载方式和批量处理逻辑(即 "如何高效批量获取数据"),包括:
    • 批量大小(batch_size)。
    • 是否打乱数据顺序(shuffle)。

勇闯python的第39天@浙大疏锦行

相关推荐
小雨叔2 小时前
内容管理趋势:无头CMS+AI,正在重构企业内容运营逻辑
人工智能·重构·内容运营
玦尘、2 小时前
《统计学习方法》第7章——支持向量机SVM(下)【学习笔记】
机器学习·支持向量机·学习方法
XiaoMu_0012 小时前
验证码识别系统
python·深度学习
AI_56782 小时前
从“内存溢出”到“稳定运行”——Spark OOM的终极解决方案
人工智能·spark
未来之窗软件服务2 小时前
一体化系统(九)高级表格自己编程如何选择——东方仙盟练气期
大数据·人工智能·仙盟创梦ide·东方仙盟·东方仙盟sdk·东方仙盟一体化·万象exce
学习是生活的调味剂2 小时前
实战LLaMA2-7B指令微调
人工智能·alpaca
康实训3 小时前
养老实训室建设
人工智能·机器学习·实训室·养老实训室·实训室建设
Code_流苏3 小时前
GPT-5.1深度解析:更智能更自然,日常体验依旧出色!
人工智能·gpt·ai·深度解析·gpt5.1·日常体验
风吹稻香飘3 小时前
【无标题】
人工智能·ai