3. pytorch中数据集加载和处理

文章目录

    • 前言
    • [1. 自定义数据集](#1. 自定义数据集)
    • [2. 使用 DataLoader 加载数据](#2. 使用 DataLoader 加载数据)
    • [3. 预处理和数据增强](#3. 预处理和数据增强)
      • [3.1 为什么需要预处理和数据增强?](#3.1 为什么需要预处理和数据增强?)
      • [3.2 核心工具:torchvision.transforms](#3.2 核心工具:torchvision.transforms)
      • [3.3 常见预处理操作详细讲解](#3.3 常见预处理操作详细讲解)
      • [3.4 图像数据增强](#3.4 图像数据增强)
      • [3.5 训练 vs 验证/测试 的 transforms 区别(超级重要!)](#3.5 训练 vs 验证/测试 的 transforms 区别(超级重要!))
      • [3.6 如何在 Dataset 中使用 transforms?](#3.6 如何在 Dataset 中使用 transforms?)
    • [4. 加载图像数据集](#4. 加载图像数据集)
    • 总结

前言

在深度学习训练过程中,数据处理和加载是第一步。它涉及如何读取数据、清洗数据、批量提供给模型训练。如果数据加载慢或不高效,会拖慢整个训练过程。

为了高效地处理数据,PyTorch 提供了强大的工具,包括 torch.utils.data.Datasettorch.utils.data.DataLoader,帮助我们管理数据集、批量加载和数据增强等任务。

pyTorch 数据处理与加载的介绍:

  • 自定义 Dataset :通过继承 torch.utils.data.Dataset 来加载自己的数据集。
  • DataLoaderDataLoader 按批次加载数据,支持多线程加载并进行数据打乱。
  • 数据预处理与增强 :使用 torchvision.transforms 进行常见的图像预处理和增强操作,提高模型的泛化能力。
  • 加载标准数据集torchvision.datasets 提供了许多常见的数据集,简化了数据加载过程。
  • 多个数据源 :通过组合多个 Dataset 实例来处理来自不同来源的数据。

1. 自定义数据集

torch.utils.data.Dataset 是一个抽象类(就像一个模板),PyTorch 要求我们自己"填空"来创建真正能用的数据集。

自定义 Dataset 的核心目的:告诉 PyTorch "我的数据长什么样、有多少个样本、怎么取出一个样本"。

必须实现两个方法:

  • __len__(self):告诉 PyTorch 数据集总共有多少个样本(就像问"列表有多长")。
  • __getitem__(self, idx):告诉 PyTorch "第 idx 个样本是什么"(idx 从 0 开始)。

假设我们有一个简单的 CSV 文件或一些列表数据,我们可以通过继承 Dataset 类来创建自己的数据集。

下面代码的例子就是模拟一个简单的二分类任务(判断一个东西属于类别 0 还是类别 1)。

  • 每个样本(一条数据)都有2个特征(比如可以想象成一个人的"身高"和"体重",或者任何两个数字指标)。

  • 每个样本都有一个

    标签(label),告诉我们这个样本属于哪个类别:

    • 1 → 正类(比如"是猫")
    • 0 → 负类(比如"是狗")

代码如下:

python 复制代码
import torch
from torch.utils.data import Dataset

# 自定义数据集类
class MyDataset(Dataset):
    def __init__(self, X_data, Y_data):
        """
        初始化时,把外部传进来的数据保存到类里面
        """
        self.X_data = X_data   # 保存所有输入特征
        self.Y_data = Y_data   # 保存所有标签

    def __len__(self):
        """返回数据集大小 = 样本数量"""
        return len(self.X_data)  # 这里是 4

    def __getitem__(self, idx):
        """根据索引 idx 返回第 idx 个样本"""
        x = torch.tensor(self.X_data[idx], dtype=torch.float32)  # 把特征转成 Tensor
        y = torch.tensor(self.Y_data[idx], dtype=torch.float32)  # 把标签转成 Tensor
        return x, y  # 返回一个元组:(特征, 标签)
    
# 示例数据
X_data = [[1, 2], [3, 4], [5, 6], [7, 8]]  # 输入特征
Y_data = [1, 0, 1, 0]  # 目标标签

# 创建数据集实例
dataset = MyDataset(X_data, Y_data)

# 测试
print(len(dataset))          # 输出: 4
print(dataset[0])            # 输出: (tensor([1., 2.]), tensor(1.))
print(dataset[1])            # 输出: (tensor([3., 4.]), tensor(0.))
print(dataset[2])            # 输出: (tensor([5., 6.]), tensor(1.))

X_data 和 Y_data 如何一一对应:

python 复制代码
X_data = [[1, 2], [3, 4], [5, 6], [7, 8]]  # 输入特征(4个样本,每个样本2个特征)
Y_data = [1, 0, 1, 0]                      # 目标标签(4个标签)

把它们按索引对齐来看:

索引 (idx) X_data 中的样本(输入特征) Y_data 中的标签(目标类别) 含义解释(举例)
0 [1, 2] 1 第1个样本:特征是1和2,属于类别1(正类)
1 [3, 4] 0 第2个样本:特征是3和4,属于类别0(负类)
2 [5, 6] 1 第3个样本:特征是5和6,属于类别1(正类)
3 [7, 8] 0 第4个样本:特征是7和8,属于类别0(负类)

注意:一般矩阵第一层都可以尽量当作里欸向量来看。

python 复制代码
X_data = [
    [1, 2],  # <--这是 idx=0 的样本,它是一个包含两个数字的整体
    [3, 4],  # <--这是 idx=1 的样本
    ...
]

[1, 2] 看起来是两个数,但对于数据集来说,它是一个样本(Sample)

2. 使用 DataLoader 加载数据

DataLoader 是 PyTorch 提供的一个重要工具,用于从 Dataset 中按批次(batch)加载数据。

DataLoader 允许我们批量读取数据并进行多线程加载,从而提高训练效率。

为什么重要

  • 深度学习模型通常不能一次处理全部数据(内存不够),需要分成小批次(mini-batch)训练。
  • DataLoader 自动帮你:
    • 分批取数据
    • 打乱顺序(防止模型学到数据顺序的偏见)
    • 多线程并行加载(加速,尤其处理图像时)
  • 小批次取出数据并且加载之后,就可以放入神经网络训练了,只不过每一个小批次训练都是训练取出来的几个样本,直到一个epoch训练完。

创建DataLoader:

python 复制代码
from torch.utils.data import DataLoader

# 假设你已经创建好了 dataset(上节课的 MyDataset,里面有4个样本)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

关键参数解释(重点记这些):

参数 含义 常用设置 备注
dataset 要加载的 Dataset 对象(必须传) 你的自定义 dataset
batch_size 每个批次包含多少个样本 常用 16、32、64、128 越大越快,但占内存多
shuffle 是否在每个 epoch 开始前打乱数据顺序 训练时 True,验证/测试 False 非常重要!防止过拟合
drop_last 如果最后一个 batch 样本不够 batch_size 个,是否丢弃它 训练时常 True,测试时 False 避免 batch 太小影响梯度
num_workers 多线程加载数据用的线程数(并行加速) 0(单线程)~ 4/8(看CPU核数) Windows 有时设>0会出错,先用0
pin_memory 是否将数据固定在内存中(加速 CPU→GPU 传输) 用 GPU 时 True

具体代码:

python 复制代码
import torch
from torch.utils.data import Dataset, DataLoader

# 上节课的 MyDataset
class MyDataset(Dataset):
    def __init__(self, X_data, Y_data):
        self.X_data = X_data
        self.Y_data = Y_data
    def __len__(self):
        return len(self.X_data)
    def __getitem__(self, idx):
        x = torch.tensor(self.X_data[idx], dtype=torch.float32)
        y = torch.tensor(self.Y_data[idx], dtype=torch.float32)
        return x, y

# 数据
X_data = [[1, 2], [3, 4], [5, 6], [7, 8]]
Y_data = [1, 0, 1, 0]

dataset = MyDataset(X_data, Y_data)

# 创建 DataLoader
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# 遍历打印
for epoch in range(1):
    print(f"Epoch {epoch + 1}")
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        print(f'  Batch {batch_idx + 1}:')
        print(f'    Inputs shape: {inputs.shape} → {inputs}')
        print(f'    Labels: {labels}')

讲解:

  • for epoch in range(1):一个 epoch 表示"把整个数据集看一遍"。

  • enumerate(dataloader):遍历 DataLoader,每次取出一个 batch,并给出批次编号(batch_idx 从 0 开始)。

  • (inputs, labels):解包每个 batch 返回的内容。

    • inputs:一个 Tensor,形状是 [batch_size, 特征维度](这里是 [2, 2])

    • labels:一个 Tensor,形状是 [batch_size](这里是 [2])

代码输出结果:

text 复制代码
Epoch 1
  Batch 1:
    Inputs shape: torch.Size([2, 2]) → tensor([[5., 6.],
        [7., 8.]])
    Labels: tensor([1., 0.])
  Batch 2:
    Inputs shape: torch.Size([2, 2]) → tensor([[1., 2.],
        [3., 4.]])
    Labels: tensor([1., 0.])

每次循环中,DataLoader 会返回一个批次的数据,包括输入特征(inputs)和目标标签(labels)。

相当于对自定义dataset中的__getitem__的调用。

3. 预处理和数据增强

3.1 为什么需要预处理和数据增强?

  • 预处理(Preprocessing):
    • 让所有图像格式统一(大小、数据类型、数值范围)。
    • 使模型更容易收敛(尤其是使用预训练模型时,必须匹配它的预处理方式)。
  • 数据增强(Data Augmentation):
    • 通过随机变换"人工制造"更多训练样本。
    • 让模型看到同一张图片的不同版本(如翻转、旋转),学会关注本质特征,而不是记住具体样子。
    • 显著提高模型在真实场景(新数据)上的表现。

记住一句话: 训练时多用增强(随机变换),验证/测试时只用固定预处理(不要随机)。

3.2 核心工具:torchvision.transforms

  • 所有操作都定义在 torchvision.transforms 中。
  • 输入:通常是 PIL.Image(用 Pillow 打开的图像)或 Tensor。
  • 输出:变换后的图像(PIL 或 Tensor)。
  • 最常用方式 :用 transforms.Compose([]) 把多个操作按顺序组合成一个"流水线"。

3.3 常见预处理操作详细讲解

python 复制代码
import torchvision.transforms as transforms
from PIL import Image

transform = transforms.Compose([
    transforms.Resize((128, 128)),      # 调整大小
    transforms.ToTensor(),              # 转成 Tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225])  # 标准化
])

image = Image.open('image.jpg')         # 用 Pillow 打开图像(RGB模式)
image_tensor = transform(image)         # 应用整个流水线
print(image_tensor.shape)               # 输出: torch.Size([3, 128, 128])

逐个操作解释

操作 作用 输入 → 输出 关键点
Resize((128, 128)) 把图像统一调整到 128×128 像素(可指定单个整数如 128,表示短边) PIL.Image → PIL.Image 防止不同大小图像无法 batch
ToTensor() 1. 把 PIL.Image 或 numpy 数组转为 PyTorch Tensor 2. 通道顺序从 HWC → CHW 3. 像素值从 [0,255] → [0.0, 1.0] PIL.Image → Tensor([C, H, W]) 必须步骤!模型只能吃 Tensor
Normalize(mean, std) 对每个通道进行:(x - mean) / std 使数据均值≈0,标准差≈1 Tensor → Tensor(值可能负数) 使用预训练模型(如 ResNet)时必须匹配 ImageNet 的 mean/std

Normalize 的 mean/std 值来源

  • [0.485, 0.456, 0.406][0.229, 0.224, 0.225] 是 ImageNet 数据集上计算的统计值。
  • 几乎所有 torchvision 预训练模型都要求用这个值!

形状变化过程

  • 原始图像:H × W × 3 (PIL)
  • ToTensor 后:3 × H × W (Tensor, 值 [0,1])
  • Resize 后:3 × 128 × 128

3.4 图像数据增强

python 复制代码
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),     # 随机水平翻转
    transforms.RandomRotation(30),         # 随机旋转 ±30 度
    transforms.RandomResizedCrop(128),     # 随机位置裁剪 + 调整大小
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225])
])

常用增强操作解释(这些都是随机的,只在训练时用):

操作 作用 参数说明 推荐场景
RandomHorizontalFlip(p=0.5) 以 p 概率水平翻转图像(左右镜像) p 默认 0.5 几乎所有任务都用
RandomVerticalFlip(p=0.5) 垂直翻转(上下镜像) 适合非方向性任务(如细胞图像)
RandomRotation(degrees) 随机旋转 ±degrees 度 如 30 表示 -30~+30 旋转不变任务慎用
RandomResizedCrop(size) 随机比例、随机位置裁剪,然后 resize 到指定大小 size=128 很强大!模拟不同尺度
ColorJitter(brightness, contrast, saturation, hue) 随机调整亮度、对比度、饱和度、色相 每个参数可设范围,如 brightness=0.2 光照变化大的场景
RandomGrayscale(p=0.1) 以 p 概率转灰度图 分类任务可加

关键特点

  • 每次调用 transform(image) 时,随机操作会产生不同结果
  • 一张图在不同 epoch 或同一 epoch 不同 batch 中看起来都不一样 → 相当于数据量变多了

3.5 训练 vs 验证/测试 的 transforms 区别(超级重要!)

python 复制代码
# 训练集 transforms(带增强)
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(128),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# 验证/测试集 transforms(只预处理,不增强)
val_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.CenterCrop(128),    # 中心裁剪,更稳定
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

3.6 如何在 Dataset 中使用 transforms?

在上次的自定义 Dataset 中,直接在 __init__ 或外部传入:

python 复制代码
dataset = MyCustomDataset(root_dir='data/', transform=train_transform)

DataLoader 会自动在每次 __getitem__ 时应用 transform。

4. 加载图像数据集

对于图像数据集,torchvision.datasets 提供了许多常见数据集(如 CIFAR-10、ImageNet、MNIST 等)以及用于加载图像数据的工具。

用 PyTorch 快速加载标准图像数据集**。torchvision 内置了很多经典数据集,你几乎不需要自己准备数据,就能开始训练模型。**

示例:加载 MNIST 数据集(手写数字识别)

python 复制代码
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# 1. 定义预处理(transforms)
transform = transforms.Compose([
    transforms.ToTensor(),                          # 转成 Tensor,并缩放到 [0.0, 1.0]
    transforms.Normalize((0.5,), (0.5,))            # 标准化:(x - 0.5)/0.5 → [-1, 1]
])

# 2. 加载数据集
train_dataset = datasets.MNIST(
    root='./data',          # 数据保存路径(会自动创建)
    train=True,             # True=训练集(60,000 张),False=测试集(10,000 张)
    download=True,          # 第一次运行时自动下载
    transform=transform     # 应用上面定义的预处理
)

test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,          # 测试集也会下载(如果还没下)
    transform=transform
)

# 3. 创建 DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader  = DataLoader(test_dataset,  batch_size=64, shuffle=False)

# 4. 迭代查看一个 batch
for inputs, labels in train_loader:
    print(inputs.shape)   # 示例输出: torch.Size([64, 1, 28, 28])
    print(labels.shape)   # 示例输出: torch.Size([64])
    break                 # 只看第一个 batch 就行

详细解释

(1)transform 的选择(为什么这样写?)

MNIST单通道灰度图像(28×28 像素),不是彩色(RGB)。

  • ToTensor()
    • PIL.Image(或 numpy)转为 Tensor
    • 通道顺序:HWC → CHW
    • 像素值:0~255 → 0.0~1.0
    • 输出形状:[1, 28, 28](1 个通道)
  • Normalize((0.5,), (0.5,))
    • 参数是元组,逗号不能少。因为灰度图只有 1 个通道。
    • 计算公式:(x - mean) / std
    • 这里把 [0,1] → [-1,1],很多简单模型喜欢这种对称分布。
    • 如果用预训练模型(MNIST 通常不用),才需要 ImageNet 的 mean/std。

(2)datasets.MNIST 的参数详解

参数 含义 常用值
root 数据存放根目录 './data' 或任何路径
train True=训练集(60000 张),False=测试集(10000 张)
download 是否自动下载 第一次 True,以后 False
transform 对每张图片应用的变换 你的 transform 对象

(3)DataLoader 设置区别

  • 训练 loader:shuffle=True(打乱顺序,防止模型记住顺序)
  • 测试 loader:shuffle=False(保持固定顺序,便于复现和评估)

(4)输出形状解释

  • inputs.shape: [64, 1, 28, 28]
    • 64:batch_size,表示64个样本
    • 1:通道数(灰度图)
    • 28, 28:高度和宽度
  • labels.shape: [64]
    • 每个样本一个数字标签(0~9)

运行上面代码之后,当前文件夹下面会出现data文件夹,如图所示:

当运行 datasets.MNIST(..., download=True) 时,PyTorch 会自动从官网下载并解压这些文件到 ./data/MNIST/raw 目录下。你看到的就是这些原始的二进制文件(不是图片jpg,而是特殊格式的压缩包和数据文件)。

  • 训练集图像(60,000 张手写数字图片):
    • train-images-idx3-ubyte.gz:压缩包(.gz)
    • train-images-idx3-ubyte:解压后的原始数据文件(包含所有训练图像的像素值)
  • 训练集标签(60,000 个数字 0~9):
    • train-labels-idx1-ubyte.gz:压缩包
    • train-labels-idx1-ubyte:解压后的标签文件

这些文件的内容是什么?

  • 不是普通的 jpg/png,而是一种叫 IDX 的旧格式(MNIST 从1990年代就用这个)。
  • 图像文件:里面打包了所有 28×28 灰度像素(每个像素 0~255,表示黑到白)。
  • 标签文件:每个样本一个字节(0~9)。

PyTorch 的 datasets.MNIST 会自动读取这些文件,转成我们能用的 Dataset 对象。

总结

PyTorch 数据处理的核心流程:Dataset 定义数据 → transforms 预处理/增强 → DataLoader 批量加载

  • 自定义 Dataset :继承 torch.utils.data.Dataset,实现 lengetitem,灵活加载任意数据。
  • DataLoader:从 Dataset 中按 batch_size 批量加载,支持 shuffle(训练打乱)和 num_workers(多线程加速)。
  • transforms:使用 torchvision.transforms.Compose 组合操作。预处理统一格式(Resize + ToTensor + Normalize);数据增强(RandomFlip/Rotation/Crop 等)只用于训练集,提高泛化能力。
  • 标准数据集:torchvision.datasets(如 MNIST、CIFAR10)一键下载 + 加载,直接配合 transforms 和 DataLoader 使用。

记住关键区别:训练时用随机增强 + shuffle验证/测试时用固定预处理 + 不打乱

相关推荐
Robot侠2 小时前
ROS1从入门到精通 10:URDF机器人建模(从零构建机器人模型)
人工智能·机器人·ros·机器人操作系统·urdf机器人建模
haiyu_y2 小时前
Day 46 TensorBoard 使用介绍
人工智能·深度学习·神经网络
阿里云大数据AI技术2 小时前
DataWorks 又又又升级了,这次我们通过 Arrow 列存格式让数据同步速度提升10倍!
大数据·人工智能
做科研的周师兄2 小时前
中国土壤有机质数据集
人工智能·算法·机器学习·分类·数据挖掘
IT一氪2 小时前
一款 AI 驱动的 Word 文档翻译工具
人工智能·word
Data_agent2 小时前
京东图片搜索商品API,json数据返回
数据库·python·json
lovingsoft2 小时前
Vibe coding 氛围编程
人工智能
深盾科技2 小时前
融合C++与Python:兼顾开发效率与运行性能
java·c++·python
百***07452 小时前
GPT-Image-1.5 极速接入全流程及关键要点
人工智能·gpt·计算机视觉