Day38 Dataset和Dataloader类

目录

一、Dataset类的__getitem__和__len__方法(本质是python的特殊方法)

1、MNIST数据处理

2、Dataset类

__getitem__方法

__len__方法

二、Dataloader类

三、minist手写数据集的了解

[1、MNIST 数据集基本概况](#1、MNIST 数据集基本概况)

[2、MNIST 数据集的特点与应用场景](#2、MNIST 数据集的特点与应用场景)

[3、MNIST 数据的存储与读取](#3、MNIST 数据的存储与读取)

[4、MNIST 数据集的性能基准](#4、MNIST 数据集的性能基准)

[5、MNIST 的扩展与变种](#5、MNIST 的扩展与变种)

四、了解下cifar数据集,尝试获取其中一张图片

[1、CIFAR 数据集概述](#1、CIFAR 数据集概述)

[2、CIFAR-10 数据集详解](#2、CIFAR-10 数据集详解)

[a. 数据规模与结构](#a. 数据规模与结构)

[b. 数据特点](#b. 数据特点)

[3、CIFAR-100 数据集详解](#3、CIFAR-100 数据集详解)

[a. 数据规模与结构](#a. 数据规模与结构)

[b. 与 CIFAR-10 的核心区别](#b. 与 CIFAR-10 的核心区别)

[4、数据预处理与加载(PyTorch 示例)](#4、数据预处理与加载(PyTorch 示例))

标准化参数

[5、CIFAR 数据集的应用与挑战](#5、CIFAR 数据集的应用与挑战)

[a. 典型应用场景](#a. 典型应用场景)

[b. 主要挑战](#b. 主要挑战)

[6、与 MNIST 数据集的对比](#6、与 MNIST 数据集的对比)

7、CIFAR-10图片获取

8、CIFAR-100图片获取


一、Dataset类的__getitem__和__len__方法(本质是python的特殊方法)

1、MNIST数据处理

在遇到大规模数据集时,显存常常无法一次性存储所有数据,所以需要使用分批训练的方法。为此,PyTorch提供了DataLoader类,该类可以自动将数据集切分为多个批次batch,并支持多线程加载数据。此外,还存在Dataset类,该类可以定义数据集的读取方式和预处理方式。

  • DataLoader类:决定数据如何加载
  • Dataset类:告诉程序去哪里找数据,如何读取单个样本,以及如何预处理。

为了引入这些概念,现在接触一个新的而且非常经典的数据集:MNIST手写数字数据集。该数据集包含60000张训练图片和10000张测试图片,每张图片大小为28*28像素,共包含10个类别。因为每个数据的维度比较小,所以既可以视为结构化数据,用机器学习、MLP训练,也可以视为图像数据,用卷积神经网络训练。

torchvision

├── datasets # 视觉数据集(如 MNIST、CIFAR)

├── transforms # 视觉数据预处理(如裁剪、翻转、归一化)

├── models # 预训练模型(如 ResNet、YOLO)

├── utils # 视觉工具函数(如目标检测后处理)

└── io # 图像/视频 IO 操作

复制代码
# 1. 数据预处理,该写法非常类似于管道pipeline
# transforms 模块提供了一系列常用的图像预处理操作

# 先归一化,再标准化
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为张量并归一化到[0,1]
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差,这个值很出名,所以直接使用
])

# 2. 加载MNIST数据集,如果没有会自动下载
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    transform=transform
)

这里稍微有点反逻辑,正常思路应该是先有数据集,后续再处理。但是在pytorch的思路是,数据在加载阶段就处理结束。

2、Dataset类

现在我们想要取出来一个图片,看看长啥样,因为datasets.MNIST本质上集成了torch.utils.data.Dataset,所以自然需要有对应的方法。

复制代码
import matplotlib.pyplot as plt

# 随机选择一张图片,可以重复运行,每次都会随机选择
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
image, label = train_dataset[sample_idx] # 获取图片和标签

这里很难理解,为什么train_dataset[sample_idx]可以获取到图片和标签,是因为 datasets.MNIST这个类继承了torch.utils.data.Dataset类,这个类中有一个方法__getitem__,这个方法会返回一个tuple,tuple中第一个元素是图片,第二个元素是标签。

详细介绍下torch.utils.data.Dataset类

PyTorch 的torch.utils.data.Dataset是一个抽象基类,所有自定义数据集都需要继承它并实现两个核心方法:

  • len():返回数据集的样本总数。
  • getitem(idx):根据索引idx返回对应样本的数据和标签。

PyTorch 要求所有数据集必须实现__getitem__和__len__,这样才能被DataLoader等工具兼容。这是一种接口约定,类似函数参数的规范。这意味着,如果你要创建一个自定义数据集,你需要实现这两个方法,否则PyTorch将无法识别你的数据集。

在 Python 中,getitem__和__len 是类的特殊方法(也叫魔术方法 ),它们不是像普通函数那样直接使用,而是需要在自定义类中进行定义,来赋予类特定的行为。以下是关于这两个方法具体的使用方式:

__getitem__方法

__getitem__方法用于让对象支持索引操作,当使用[]语法访问对象元素时,Python 会自动调用该方法。

复制代码
# 示例代码
class MyList:
    def __init__(self):
        self.data = [10, 20, 30, 40, 50]

    def __getitem__(self, idx):
        return self.data[idx]

# 创建类的实例
my_list_obj = MyList()
# 此时可以使用索引访问元素,这会自动调用__getitem__方法
print(my_list_obj[2])  # 输出:30

通过定义__getitem__方法,让MyList类的实例能够像 Python 内置的列表一样使用索引获取元素。

__len__方法

__len__方法用于返回对象中元素的数量,当使用内置函数len()作用于对象时,Python 会自动调用该方法。

复制代码
class MyList:
    def __init__(self):
        self.data = [10, 20, 30, 40, 50]

    def __len__(self):
        return len(self.data)

# 创建类的实例
my_list_obj = MyList()
# 使用len()函数获取元素数量,这会自动调用__len__方法
print(len(my_list_obj))  # 输出:5

这里定义的__len__方法,使得MyList类的实例可以像普通列表一样被len()函数调用获取长度。

复制代码
# minist数据集的简化版本
class MNIST(Dataset):
    def __init__(self, root, train=True, transform=None):
        # 初始化:加载图片路径和标签
        self.data, self.targets = fetch_mnist_data(root, train) # 这里假设 fetch_mnist_data 是一个函数,用于加载 MNIST 数据集的图片路径和标签
        self.transform = transform # 预处理操作
        
    def __len__(self): 
        return len(self.data)  # 返回样本总数
    
    def __getitem__(self, idx): # 获取指定索引的样本
        # 获取指定索引的图像和标签
        img, target = self.data[idx], self.targets[idx]
        
        # 应用图像预处理(如ToTensor、Normalize)
        if self.transform is not None: # 如果有预处理操作
            img = self.transform(img) # 转换图像格式
        # 这里假设 img 是一个 PIL 图像对象,transform 会将其转换为张量并进行归一化
            
        return img, target  # 返回处理后的图像和标签
  • Dataset = 厨师(准备单个菜品)
  • DataLoader = 服务员(将菜品按订单组合并上桌)

预处理(如切菜、调味)属于厨师的工作,而非服务员。所以在dataset就需要添加预处理步骤。

复制代码
# 可视化原始图像(需要反归一化)
def imshow(img):
    img = img * 0.3081 + 0.1307  # 反标准化
    npimg = img.numpy()
    plt.imshow(npimg[0], cmap='gray') # 显示灰度图像
    plt.show()

print(f"Label: {label}")
imshow(image)

二、Dataloader类

复制代码
# 3. 创建数据加载器
train_loader = DataLoader(
    train_dataset,
    batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关
    shuffle=True # 随机打乱数据
)

test_loader = DataLoader(
    test_dataset,
    batch_size=1000 # 每个批次1000张图片
    # shuffle=False # 测试时不需要打乱数据
)

总结:

维度 Dataset DataLoader
核心职责 定义"数据是什么"和"如何获取单个样本" 定义"如何批量加载数据"和"加载策略"
核心方法 __getitem__(获取单个样本)、__len__(样本总数) 无自定义方法,通过参数控制加载逻辑
预处理位置 __getitem__中通过transform执行预处理 无预处理逻辑,依赖Dataset返回的预处理后数据
并行处理 无(仅单样本处理) 支持多进程加载(num_workers>0
典型参数 root(数据路径)、transform(预处理) batch_sizeshufflenum_workers

核心结论

  • Dataset :定义数据的内容和格式(即"如何获取单个样本"),包括:

    • 数据存储路径/来源(如文件路径、数据库查询)。
    • 原始数据的读取方式(如图像解码为PIL对象、文本读取为字符串)。
    • 样本的预处理逻辑(如裁剪、翻转、归一化等,通常通过transform参数实现)。
    • 返回值格式(如(image_tensor, label))。
  • DataLoader :定义数据的加载方式和批量处理逻辑(即"如何高效批量获取数据"),包括:

    • 批量大小(batch_size)。
    • 是否打乱数据顺序(shuffle)。

三、minist手写数据集的了解

1、MNIST 数据集基本概况

MNIST(Modified National Institute of Standards and Technology database)是机器学习领域最经典的入门级数据集之一,由纽约大学的 Yann LeCun 等人整理,被广泛用于图像分类算法的开发与测试。

  • 数据来源

    • 训练集:来自 250 个不同人手写的数字,其中 50% 是高中学生,50% 是人口普查局员工。
    • 测试集:与训练集独立的另一组手写数字样本。
  • 数据规模

    • 训练集:60,000 张灰度图像
    • 测试集:10,000 张灰度图像
    • 每张图像尺寸:28×28 像素
  • 数据格式

    • 图像:单通道灰度图,像素值范围为 0-255(归一化后通常为 0-1)
    • 标签:0-9 的数字分类,共 10 个类别

2、MNIST 数据集的特点与应用场景

  • 特点

    • 数据标准化:所有图像已统一尺寸并中心化处理
    • 噪声较少:经过筛选和预处理,适合入门级算法验证
    • 类别平衡:每个数字的样本数量相对均匀
  • 典型应用场景

    • 机器学习入门教程的标准案例
    • 新算法的基准测试(如 CNN、神经网络等)
    • 迁移学习的源领域数据
    • 模型轻量化与压缩技术的测试平台

3、MNIST 数据的存储与读取

MNIST 数据集通常以 IDX 格式存储,这是一种用于存储多维度数组的简单文件格式:

  • 文件结构

    • train-images-idx3-ubyte:训练集图像(60,000×28×28)
    • train-labels-idx1-ubyte:训练集标签(60,000 个)
    • t10k-images-idx3-ubyte:测试集图像(10,000×28×28)
    • t10k-labels-idx1-ubyte:测试集标签(10,000 个)

4、MNIST 数据集的性能基准

由于数据规模适中且难度较低,MNIST 已成为衡量算法性能的标准之一:

  • 传统机器学习算法

    • 支持向量机(SVM):准确率约 97%-98%
    • 随机森林:准确率约 95%-97%
  • 深度学习算法

    • 简单全连接神经网络:准确率约 98%-99%
    • 卷积神经网络(CNN):准确率可达 99.5% 以上
    • 最新技术(如胶囊网络):准确率接近 99.7%

5、MNIST 的扩展与变种

  • EMNIST:扩展版 MNIST,包含更多字符类别(大写字母、小写字母、数字)
  • Fashion-MNIST:由 Zalando 提供的服装图像数据集,结构与 MNIST 完全一致,更具挑战性
  • KMNIST:日本手写假名数据集,格式与 MNIST 兼容
  • MNIST-C:添加了各种噪声和干扰的 MNIST 变种,用于测试模型鲁棒性

四、了解下cifar数据集,尝试获取其中一张图片

1、CIFAR 数据集概述

CIFAR(Canadian Institute For Advanced Research)数据集是计算机视觉领域的经典基准数据集,由加拿大高级研究所的 Alex Krizhevsky、Vinod Nair 和 Geoffrey Hinton 创建,主要用于图像分类任务的模型训练与评估。该数据集分为 CIFAR-10 和 CIFAR-100 两个版本,两者在数据结构和应用场景上既有相似性又有明显区别。

2、CIFAR-10 数据集详解

a. 数据规模与结构

  • 图像数量:共 60,000 张 32×32 的彩色图像,其中 50,000 张训练集、10,000 张测试集。

  • 类别分布 :10 个大类,每个类别包含 6,000 张图像,类别如下:

    类别 示例图像
    airplane 飞机
    automobile 汽车
    bird 鸟类
    cat
    deer 鹿
    dog
    frog 青蛙
    horse
    ship
    truck 卡车

b. 数据特点

  • 色彩与尺寸:RGB 三通道彩色图像,固定尺寸 32×32,像素值范围 [0, 255]。
  • 难度挑战:图像分辨率低、目标物体占据像素少,且存在背景干扰、姿态变化等问题,对模型识别能力要求较高。

3、CIFAR-100 数据集详解

a. 数据规模与结构

  • 图像数量:同 CIFAR-10,共 60,000 张 32×32 彩色图像(50,000 训练 + 10,000 测试)。
  • 类别分布 :100 个细分类别,每个类别包含 600 张图像,类别组织为 20 个超类(每个超类包含 5 个子类),例如:
    • 超类 "动物":包含 bear(熊)、tiger(老虎)、lion(狮子)等子类;
    • 超类 "交通工具":包含 car(汽车)、train(火车)、truck(卡车)等子类。

b. 与 CIFAR-10 的核心区别

维度 CIFAR-10 CIFAR-100
类别数量 10 个大类 100 个细分类别(20 超类)
分类难度 较低(类别差异明显) 较高(细分类别易混淆)
典型应用 基础模型验证 细粒度分类研究

4、数据预处理与加载(PyTorch 示例)

标准化参数

  • CIFAR-10
    均值 (0.4914, 0.4822, 0.4465),标准差 (0.2023, 0.1994, 0.2010)
  • CIFAR-100
    均值 (0.5071, 0.4867, 0.4408),标准差 (0.2675, 0.2565, 0.2761)

5、CIFAR 数据集的应用与挑战

a. 典型应用场景

  • 模型性能评估:如 ResNet、DenseNet 等经典网络常以 CIFAR 为基准测试分类准确率;
  • 数据增强研究:由于数据量有限,常用于验证数据增强技术(如 Cutout、Mixup)的效果;
  • 半监督学习与迁移学习:小样本场景下的算法验证(如 FixMatch、SimCLR)。

b. 主要挑战

  • 小尺寸图像:32×32 像素难以捕捉复杂细节,需模型具备强特征提取能力;
  • 类别相似度:CIFAR-100 中同类超类的子类(如不同品种的狗)外观高度相似,分类难度大;
  • 过拟合问题:训练集规模有限(5 万张),深度模型易出现过拟合,需结合正则化策略(如 Dropout、L2 正则)。

6、与 MNIST 数据集的对比

维度 MNIST CIFAR-10 CIFAR-100
图像类型 灰度(1 通道) 彩色(3 通道) 彩色(3 通道)
图像尺寸 28×28 32×32 32×32
类别数量 10(手写数字) 10(物体大类) 100(物体细分类)
任务难度 低(入门级) 中(基础研究) 高(进阶研究)
典型准确率 CNN 可达 99%+ 主流模型约 95% 主流模型约 85%

7、CIFAR-10图片获取

复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

torch.manual_seed(42)

transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为张量并归一化到[0,1]
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))  # CIFAR-10数据集的均值和标准差
])

train_dataset = datasets.CIFAR10(
    root='./data', 
    train=True, 
    download=True,
    transform=transform
)
test_dataset = datasets.CIFAR10(
    root='./data', 
    train=False, 
    download=False,
    transform=transform
)
复制代码
# 随机选择一张图片,可以重复运行,每次都会随机选择
cifar_sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
cifar_image, cifar_label = train_dataset[cifar_sample_idx] # 获取图片和标签

import numpy as np
import matplotlib.pyplot as plt

# 定义CIFAR-10的类别名称
classes = ('plane', 'car', 'bird', 'cat', 'deer', 
           'dog', 'frog', 'horse', 'ship', 'truck')

# 可视化CIFAR-10图像(需要反归一化)
def imshow_cifar(img, label=None):
    # 使用CIFAR-10的均值和标准差进行反归一化
    img = img * np.array([0.2023, 0.1994, 0.2010]).reshape(3, 1, 1) + np.array([0.4914, 0.4822, 0.4465]).reshape(3, 1, 1)
    npimg = img.numpy()
    # 调整通道顺序:[C,H,W] → [H,W,C]
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    # 显示标签(如果提供)
    if label is not None:
        plt.title(f"Label: {label} ({classes[label]})")
    plt.axis('off')
    plt.show()

# 使用示例
print(f"Label: {cifar_label}")
imshow_cifar(cifar_image, cifar_label)

8、CIFAR-100图片获取

复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

torch.manual_seed(42)

# CIFAR-100的归一化参数
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为张量并归一化到[0,1]
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))  # CIFAR-100数据集的均值和标准差
])

# 加载CIFAR-100数据集
train_dataset = datasets.CIFAR100(
    root='./data', 
    train=True, 
    download=True,
    transform=transform
)
test_dataset = datasets.CIFAR100(
    root='./data', 
    train=False, 
    download=False,
    transform=transform
)

# 随机选择一张图片
cifar100_sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item()
cifar100_image, cifar100_label = train_dataset[cifar100_sample_idx]

# 定义CIFAR-100的类别名称(完整列表)
cifar100_classes = [
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 
    'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 
    'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 
    'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 
    'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 
    'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 
    'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 
    'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 
    'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 
    'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 
    'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 
    'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 
    'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
]

# 可视化CIFAR-100图像(需要反归一化)
def imshow_cifar100(img, label=None):
    # 使用CIFAR-100的均值和标准差进行反归一化
    img = img * np.array([0.2675, 0.2565, 0.2761]).reshape(3, 1, 1) + np.array([0.5071, 0.4867, 0.4408]).reshape(3, 1, 1)
    npimg = img.numpy()
    # 调整通道顺序:[C,H,W] → [H,W,C]
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    # 显示标签
    if label is not None:
        plt.title(f"Label: {label} ({cifar100_classes[label]})")
    plt.axis('off')
    plt.show()

# 使用示例
print(f"Label: {cifar100_label}")
imshow_cifar100(cifar100_image, cifar100_label)

@浙大疏锦行

相关推荐
AndrewHZ2 分钟前
【Python与生活】如何实现一个条形码检测算法?
人工智能·pytorch·python·深度学习·算法·生活
白嫖不白嫖6 分钟前
从番茄炒蛋到神经网络:解密AI模型的本质
人工智能·深度学习·神经网络
struggle20259 分钟前
PINA开源程序用于高级建模的 Physics-Informed 神经网络
人工智能·深度学习·神经网络
198912 分钟前
【Dify精讲】第14章:部署架构与DevOps实践
运维·人工智能·python·ai·架构·flask·devops
微信公众号:AI创造财富15 分钟前
文生视频(Text-to-Video)
开发语言·人工智能·python·深度学习·aigc·virtualenv
struggle202515 分钟前
Z-Ant开源程序是简化了微处理器上神经网络的部署和优化
人工智能·深度学习·神经网络
NetX行者23 分钟前
Wordvice AI:Wordvice 推出的免费,基于先进的 AI 技术帮助用户提升英文写作质量
人工智能·ai工具
倔强青铜三34 分钟前
🚀LlamaIndex中文教程(1)----对接Qwen3大模型
人工智能·后端·python
Q_Q51100828535 分钟前
python的校园兼职系统
开发语言·spring boot·python·django·flask·node.js·php
程序员寒山36 分钟前
Ai工具之DeepSiteV2(1):「边聊边建」的智能辅助建站神器
人工智能