PyTorch图像数据加载实战指南

本文核心讲解如何基于PyTorch的DatasetDataLoader构建高效、可无缝集成到深度学习模型的图像数据加载流水线,同时覆盖自定义数据集处理与内置数据集快速使用两大场景。

一、教程核心目标与示例数据集

1. 核心目标

构建完整的图像数据处理流水线,实现:

  • 数据集的结构化划分(训练集+验证集)
  • 图像数据的加载与自动化预处理/增强
  • 按批次高效读取数据并输入模型

2. 示例数据集

使用5类花卉数据集,包含:郁金香(Tulips)、雏菊(Daisy)、蒲公英(Dandelion)、玫瑰(Roses)、向日葵(Sunflowers),原始数据按"类别文件夹+图像文件"的结构组织。

二、开发环境配置

1. 必需依赖包

通过pip一键安装:

复制代码
pip install torch torchvision matplotlib opencv-contrib-python imutils

2. 备选方案

若本地环境配置困难,可通过PyImageSearch University获取预配置的Google Colab笔记本,支持Windows/macOS/Linux全平台直接运行,无需本地安装依赖。

三、项目结构与文件说明

下载并解压教程资源后,项目目录结构如下:

复制代码
├── build_dataset.py          # 数据集划分脚本
├── builtin_dataset.py        # PyTorch内置数据集加载脚本
├── flower_photos/            # 原始花卉数据集
│   ├── daisy/
│   ├── dandelion/
│   ├── roses/
│   ├── sunflowers/
│   └── tulips/
├── load_and_visualize.py     # 数据加载与批次可视化脚本
└── pyimagesearch/
    ├── config.py             # 全局配置文件
    └── __init__.py           # 包初始化文件

四、全局配置文件(config.py

集中管理所有可配置参数,避免硬编码:

复制代码
# specify path to the flowers and mnist dataset
FLOWERS_DATASET_PATH = "flower_photos"
MNIST_DATASET_PATH = "mnist"
# specify the paths to our training and validation set 
TRAIN = "train"
VAL = "val"
# set the input height and width
INPUT_HEIGHT = 128
INPUT_WIDTH = 128
# set the batch size and validation data split
BATCH_SIZE = 8
VAL_SPLIT = 0.1

|----------------------|-----------------------|
| 参数名 | 取值/含义 |
| FLOWERS_DATASET_PATH | 原始花卉数据集根目录 |
| MNIST_DATASET_PATH | MNIST内置数据集保存目录 |
| TRAIN/VAL | 训练集/验证集输出目录名 |
| INPUT_HEIGHT/WIDTH | 模型输入图像尺寸(128×128) |
| BATCH_SIZE | 数据批次大小(8) |
| VAL_SPLIT | 验证集占总数据集的比例(0.1,即10%) |

五、数据集划分(build_dataset.py)

将原始数据集按比例划分为训练集和验证集,保持类别分布均匀。

复制代码
# USAGE  
# python build_dataset.py
# import necessary packages
from pyimagesearch import config  
from imutils import paths  
import numpy as np  
import shutil  
import os

1. 核心函数:copy_images

  • 功能:接收图像路径列表和目标目录,自动按类别创建子文件夹并复制图像

  • 关键逻辑:从图像路径中提取类别名(路径倒数第二层),在目标目录下创建对应类别文件夹,再将图像复制到对应位置

    def copy_images(imagePaths, folder):
    # check if the destination folder exists and if not create it
    if not os.path.exists(folder):
    os.makedirs(folder)
    # loop over the image paths
    for path in imagePaths:
    # grab image name and its label from the path and create
    # a placeholder corresponding to the separate label folder
    imageName = path.split(os.path.sep)[-1]
    label = path.split(os.path.sep)[-2]
    labelFolder = os.path.join(folder, label)
    # check to see if the label folder exists and if not create it
    if not os.path.exists(labelFolder):
    os.makedirs(labelFolder)
    # construct the destination image path and copy the current
    # image to it
    destination = os.path.join(labelFolder, imageName)
    shutil.copy(path, destination)

2. 划分流程

  1. 加载所有图像路径并通过np.random.shuffle()随机打乱,保证训练集和验证集的类别分布一致

  2. VAL_SPLIT计算验证集和训练集的图像数量

  3. 分别将训练集、验证集图像复制到train/val/目录

    load all the image paths and randomly shuffle them

    print("[INFO] loading image paths...")
    imagePaths = list(paths.list_images(config.FLOWERS_DATASET_PATH))
    np.random.shuffle(imagePaths)

    generate training and validation paths

    valPathsLen = int(len(imagePaths) * config.VAL_SPLIT)
    trainPathsLen = len(imagePaths) - valPathsLen
    trainPaths = imagePaths[:trainPathsLen]
    valPaths = imagePaths[trainPathsLen:]

    copy the training and validation images to their respective

    directories

    print("[INFO] copying training and validation images...")
    copy_images(trainPaths, config.TRAIN)
    copy_images(valPaths, config.VAL)

3. 最终目录结构

划分后生成独立的训练集和验证集目录,均保持"根目录→类别文件夹→图像文件"的结构:

复制代码
├── train/
│   ├── daisy/
│   ├── dandelion/
│   ├── roses/
│   ├── sunflowers/
│   └── tulips/
└── val/
    ├── daisy/
    ├── dandelion/
    ├── roses/
    ├── sunflowers/
    └── tulips/

六、PyTorch Dataset与DataLoader核心实现(load_and_visualize.py)

这是教程的核心部分,讲解如何加载数据、应用增强并构建可迭代的DataLoader。

1. 关键导入

  • ImageFolder:PyTorch内置的图像数据集类,用于加载按类别分文件夹的数据集

  • DataLoader:将Dataset包装为可迭代对象,实现按批次加载

  • transforms:提供图像预处理和数据增强的内置函数

    USAGE

    python load_and_visualize.

    import necessary packages

    from pyimagesearch import config
    from torchvision.datasets import ImageFolder
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import matplotlib.pyplot as plt
    import torch

2. 批次可视化函数:visualize_batch

  • 输入:数据批次、类别列表、数据集类型(train/val)
  • 核心处理:
    1. 将PyTorch张量(通道优先格式:C×H×W)转换为numpy数组(通道最后格式:H×W×C)
    2. 将归一化到[0,1]的像素值还原为[0,255]的整数格式
    3. 绘制批次中所有图像,并标注对应的类别名称

    def visualize_batch(batch, classes, dataset_type):
    # initialize a figure
    fig = plt.figure("{} batch".format(dataset_type),
    figsize=(config.BATCH_SIZE, config.BATCH_SIZE))
    # loop over the batch size
    for i in range(0, config.BATCH_SIZE):
    # create a subplot
    ax = plt.subplot(2, 4, i + 1)
    # grab the image, convert it from channels first ordering to
    # channels last ordering, and scale the raw pixel intensities
    # to the range [0, 255]
    image = batch[0][i].cpu().numpy()
    image = image.transpose((1, 2, 0))
    image = (image * 255.0).astype("uint8")
    # grab the label id and get the label from the classes list
    idx = batch[1][i]
    label = classes[idx]
    # show the image along with the label
    plt.imshow(image)
    plt.title(label)
    plt.axis("off")
    # show the plot
    plt.tight_layout()
    plt.show()

3. 数据预处理与增强

针对训练集和验证集设计不同的变换流水线(验证集不使用数据增强,仅做必要的格式转换):

复制代码
# 训练集变换:包含数据增强
trainTransforms = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(p=0.25),  # 25%概率水平翻转
    transforms.RandomVerticalFlip(p=0.25),    # 25%概率垂直翻转
    transforms.RandomRotation(degrees=15),    # 随机旋转±15度
    transforms.ToTensor()                     # 转换为张量并归一化到[0,1]
])

# 验证集变换:仅调整尺寸和格式转换
valTransforms = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])
  • 关键说明:ToTensor()不仅完成数据类型转换,还会自动将PIL图像或numpy数组的像素值从[0,255]归一化到[0,1]区间。

4. ImageFolder数据集创建

  • 输入要求:数据集必须遵循"根目录/类别名/图像文件"的结构

  • 自动功能:识别所有唯一类别并映射为整数标签(0~4对应5类花卉)

  • 代码实现:

    trainDataset = ImageFolder(root=config.TRAIN, transform=trainTransforms)
    valDataset = ImageFolder(root=config.VAL, transform=valTransforms)

  • 核心方法:

    • __len__():返回数据集的总样本数
    • __getitem__(index):通过索引获取单个样本,返回格式为(图像张量, 整数标签)

5. DataLoader配置

将Dataset包装为可迭代对象,支持按批次加载和并行处理:

复制代码
# 训练集DataLoader:开启shuffle打乱样本,优化梯度下降收敛
trainDataLoader = DataLoader(trainDataset, batch_size=8, shuffle=True)
# 验证集DataLoader:无需打乱
valDataLoader = DataLoader(valDataset, batch_size=8)
  • 核心作用:将数据集划分为固定大小的批次,支持模型批量处理;训练集开启shuffle=True可避免模型学习到数据的顺序特征。

6. 批次获取与可视化

通过iter()将DataLoader转换为迭代器,再通过next()获取单个批次,调用visualize_batch函数展示批次中的图像和标签。

七、PyTorch内置数据集使用(builtin_dataset.py)

PyTorch的torchvision.datasets模块提供了大量常用计算机视觉数据集的一键下载和加载功能,包括MNIST、CIFAR-10、CIFAR-100、CelebA等。

1. MNIST数据集加载示例

复制代码
# 加载训练集(自动下载)
trainDataset = MNIST(root=config.MNIST_DATASET_PATH, train=True, download=True, transform=transforms.ToTensor())
# 加载测试集
valDataset = MNIST(root=config.MNIST_DATASET_PATH, train=False, download=True, transform=transforms.ToTensor())
  • 关键参数:
    • train=True:加载训练集;train=False:加载测试集
    • download=True:自动将数据集下载到指定的root目录

2. 适配调整

  • 可视化函数修改:MNIST为单通道灰度图,绘制时需指定cmap="gray"
  • DataLoader配置与自定义数据集完全一致

八、教程总结

本文完整实现了PyTorch图像数据加载的全流程:

  1. 完成了自定义数据集的结构化划分,保证训练集和验证集的类别分布均匀
  2. 掌握了transforms模块的使用,实现了数据预处理和训练集数据增强
  3. 理解了ImageFolderDataLoader的工作原理,构建了高效的批次数据加载流水线
  4. 学会了快速加载PyTorch内置数据集,简化常用数据集的使用流程

最终构建的数据加载流水线可直接无缝集成到任意PyTorch深度学习模型中,用于模型训练和验证。

参考文章:Image Data Loaders in PyTorch - PyImageSearch

相关推荐
博.闻广见11 小时前
AI_Python基础-4.标准库与IO
开发语言·python
程序猿编码11 小时前
大模型的“文字障眼法“:FlipAttack 文本反转越狱技术全解析
linux·python·ai·大模型
Yunzenn11 小时前
深度分析字节最新研究cola-DLM第 01 章:语言生成的三次范式之争 —— 从 RNN 到 AR 到扩散
linux·人工智能·rnn·深度学习·机器学习·架构·transformer
m0_6346667311 小时前
Stability Audio 3.0 把 AI 音乐推过了一个门槛:从“音频片段”走向“完整歌曲”
人工智能·音视频
名不经传的养虾人11 小时前
从0到1:企业级AI项目迭代日记 Vol.30|看不见的地基:从“能用”到“可信”的30天
人工智能·ai编程·企业ai
晚烛11 小时前
CANN 数据流与内存优化:L1/L2 缓存机制与计算重叠深度解析
人工智能·python·缓存
xiao5kou4chang6kai411 小时前
如何用Python处理气象海洋数据?台风数据爬取、SST的EOF分析、WRF剖面图绘制
python·气象·台风·wrf·海洋
Reload.11 小时前
CZ航司,shopping JS逆向 acw_sc__v2
开发语言·javascript·python·网络爬虫·ecmascript
薛定猫AI11 小时前
【深度解析】从 Antigravity 2.0 看 AI Agent 的产品化演进:动态子代理、项目工作区与多模型编排实战
人工智能