本文核心讲解如何基于PyTorch的Dataset和DataLoader构建高效、可无缝集成到深度学习模型的图像数据加载流水线,同时覆盖自定义数据集处理与内置数据集快速使用两大场景。
一、教程核心目标与示例数据集
1. 核心目标
构建完整的图像数据处理流水线,实现:
- 数据集的结构化划分(训练集+验证集)
- 图像数据的加载与自动化预处理/增强
- 按批次高效读取数据并输入模型
2. 示例数据集
使用5类花卉数据集,包含:郁金香(Tulips)、雏菊(Daisy)、蒲公英(Dandelion)、玫瑰(Roses)、向日葵(Sunflowers),原始数据按"类别文件夹+图像文件"的结构组织。
二、开发环境配置
1. 必需依赖包
通过pip一键安装:
pip install torch torchvision matplotlib opencv-contrib-python imutils
2. 备选方案
若本地环境配置困难,可通过PyImageSearch University获取预配置的Google Colab笔记本,支持Windows/macOS/Linux全平台直接运行,无需本地安装依赖。
三、项目结构与文件说明
下载并解压教程资源后,项目目录结构如下:
├── build_dataset.py # 数据集划分脚本
├── builtin_dataset.py # PyTorch内置数据集加载脚本
├── flower_photos/ # 原始花卉数据集
│ ├── daisy/
│ ├── dandelion/
│ ├── roses/
│ ├── sunflowers/
│ └── tulips/
├── load_and_visualize.py # 数据加载与批次可视化脚本
└── pyimagesearch/
├── config.py # 全局配置文件
└── __init__.py # 包初始化文件
四、全局配置文件(config.py)
集中管理所有可配置参数,避免硬编码:
# specify path to the flowers and mnist dataset
FLOWERS_DATASET_PATH = "flower_photos"
MNIST_DATASET_PATH = "mnist"
# specify the paths to our training and validation set
TRAIN = "train"
VAL = "val"
# set the input height and width
INPUT_HEIGHT = 128
INPUT_WIDTH = 128
# set the batch size and validation data split
BATCH_SIZE = 8
VAL_SPLIT = 0.1
|----------------------|-----------------------|
| 参数名 | 取值/含义 |
| FLOWERS_DATASET_PATH | 原始花卉数据集根目录 |
| MNIST_DATASET_PATH | MNIST内置数据集保存目录 |
| TRAIN/VAL | 训练集/验证集输出目录名 |
| INPUT_HEIGHT/WIDTH | 模型输入图像尺寸(128×128) |
| BATCH_SIZE | 数据批次大小(8) |
| VAL_SPLIT | 验证集占总数据集的比例(0.1,即10%) |
五、数据集划分(build_dataset.py)
将原始数据集按比例划分为训练集和验证集,保持类别分布均匀。
# USAGE
# python build_dataset.py
# import necessary packages
from pyimagesearch import config
from imutils import paths
import numpy as np
import shutil
import os
1. 核心函数:copy_images
-
功能:接收图像路径列表和目标目录,自动按类别创建子文件夹并复制图像
-
关键逻辑:从图像路径中提取类别名(路径倒数第二层),在目标目录下创建对应类别文件夹,再将图像复制到对应位置
def copy_images(imagePaths, folder):
# check if the destination folder exists and if not create it
if not os.path.exists(folder):
os.makedirs(folder)
# loop over the image paths
for path in imagePaths:
# grab image name and its label from the path and create
# a placeholder corresponding to the separate label folder
imageName = path.split(os.path.sep)[-1]
label = path.split(os.path.sep)[-2]
labelFolder = os.path.join(folder, label)
# check to see if the label folder exists and if not create it
if not os.path.exists(labelFolder):
os.makedirs(labelFolder)
# construct the destination image path and copy the current
# image to it
destination = os.path.join(labelFolder, imageName)
shutil.copy(path, destination)
2. 划分流程
-
加载所有图像路径并通过
np.random.shuffle()随机打乱,保证训练集和验证集的类别分布一致 -
按
VAL_SPLIT计算验证集和训练集的图像数量 -
分别将训练集、验证集图像复制到
train/和val/目录load all the image paths and randomly shuffle them
print("[INFO] loading image paths...")
imagePaths = list(paths.list_images(config.FLOWERS_DATASET_PATH))
np.random.shuffle(imagePaths)generate training and validation paths
valPathsLen = int(len(imagePaths) * config.VAL_SPLIT)
trainPathsLen = len(imagePaths) - valPathsLen
trainPaths = imagePaths[:trainPathsLen]
valPaths = imagePaths[trainPathsLen:]copy the training and validation images to their respective
directories
print("[INFO] copying training and validation images...")
copy_images(trainPaths, config.TRAIN)
copy_images(valPaths, config.VAL)
3. 最终目录结构
划分后生成独立的训练集和验证集目录,均保持"根目录→类别文件夹→图像文件"的结构:
├── train/
│ ├── daisy/
│ ├── dandelion/
│ ├── roses/
│ ├── sunflowers/
│ └── tulips/
└── val/
├── daisy/
├── dandelion/
├── roses/
├── sunflowers/
└── tulips/
六、PyTorch Dataset与DataLoader核心实现(load_and_visualize.py)
这是教程的核心部分,讲解如何加载数据、应用增强并构建可迭代的DataLoader。
1. 关键导入
-
ImageFolder:PyTorch内置的图像数据集类,用于加载按类别分文件夹的数据集 -
DataLoader:将Dataset包装为可迭代对象,实现按批次加载 -
transforms:提供图像预处理和数据增强的内置函数USAGE
python load_and_visualize.
import necessary packages
from pyimagesearch import config
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import torch
2. 批次可视化函数:visualize_batch
- 输入:数据批次、类别列表、数据集类型(train/val)
- 核心处理:
-
- 将PyTorch张量(通道优先格式:C×H×W)转换为numpy数组(通道最后格式:H×W×C)
- 将归一化到[0,1]的像素值还原为[0,255]的整数格式
- 绘制批次中所有图像,并标注对应的类别名称
def visualize_batch(batch, classes, dataset_type):
# initialize a figure
fig = plt.figure("{} batch".format(dataset_type),
figsize=(config.BATCH_SIZE, config.BATCH_SIZE))
# loop over the batch size
for i in range(0, config.BATCH_SIZE):
# create a subplot
ax = plt.subplot(2, 4, i + 1)
# grab the image, convert it from channels first ordering to
# channels last ordering, and scale the raw pixel intensities
# to the range [0, 255]
image = batch[0][i].cpu().numpy()
image = image.transpose((1, 2, 0))
image = (image * 255.0).astype("uint8")
# grab the label id and get the label from the classes list
idx = batch[1][i]
label = classes[idx]
# show the image along with the label
plt.imshow(image)
plt.title(label)
plt.axis("off")
# show the plot
plt.tight_layout()
plt.show()
3. 数据预处理与增强
针对训练集和验证集设计不同的变换流水线(验证集不使用数据增强,仅做必要的格式转换):
# 训练集变换:包含数据增强
trainTransforms = transforms.Compose([
transforms.Resize((128, 128)),
transforms.RandomHorizontalFlip(p=0.25), # 25%概率水平翻转
transforms.RandomVerticalFlip(p=0.25), # 25%概率垂直翻转
transforms.RandomRotation(degrees=15), # 随机旋转±15度
transforms.ToTensor() # 转换为张量并归一化到[0,1]
])
# 验证集变换:仅调整尺寸和格式转换
valTransforms = transforms.Compose([
transforms.Resize((128, 128)),
transforms.ToTensor()
])
- 关键说明:
ToTensor()不仅完成数据类型转换,还会自动将PIL图像或numpy数组的像素值从[0,255]归一化到[0,1]区间。
4. ImageFolder数据集创建
-
输入要求:数据集必须遵循"根目录/类别名/图像文件"的结构
-
自动功能:识别所有唯一类别并映射为整数标签(0~4对应5类花卉)
-
代码实现:
trainDataset = ImageFolder(root=config.TRAIN, transform=trainTransforms)
valDataset = ImageFolder(root=config.VAL, transform=valTransforms) -
核心方法:
-
__len__():返回数据集的总样本数__getitem__(index):通过索引获取单个样本,返回格式为(图像张量, 整数标签)
5. DataLoader配置
将Dataset包装为可迭代对象,支持按批次加载和并行处理:
# 训练集DataLoader:开启shuffle打乱样本,优化梯度下降收敛
trainDataLoader = DataLoader(trainDataset, batch_size=8, shuffle=True)
# 验证集DataLoader:无需打乱
valDataLoader = DataLoader(valDataset, batch_size=8)
- 核心作用:将数据集划分为固定大小的批次,支持模型批量处理;训练集开启
shuffle=True可避免模型学习到数据的顺序特征。
6. 批次获取与可视化
通过iter()将DataLoader转换为迭代器,再通过next()获取单个批次,调用visualize_batch函数展示批次中的图像和标签。
七、PyTorch内置数据集使用(builtin_dataset.py)
PyTorch的torchvision.datasets模块提供了大量常用计算机视觉数据集的一键下载和加载功能,包括MNIST、CIFAR-10、CIFAR-100、CelebA等。
1. MNIST数据集加载示例
# 加载训练集(自动下载)
trainDataset = MNIST(root=config.MNIST_DATASET_PATH, train=True, download=True, transform=transforms.ToTensor())
# 加载测试集
valDataset = MNIST(root=config.MNIST_DATASET_PATH, train=False, download=True, transform=transforms.ToTensor())
- 关键参数:
-
train=True:加载训练集;train=False:加载测试集download=True:自动将数据集下载到指定的root目录
2. 适配调整
- 可视化函数修改:MNIST为单通道灰度图,绘制时需指定
cmap="gray" - DataLoader配置与自定义数据集完全一致
八、教程总结
本文完整实现了PyTorch图像数据加载的全流程:
- 完成了自定义数据集的结构化划分,保证训练集和验证集的类别分布均匀
- 掌握了
transforms模块的使用,实现了数据预处理和训练集数据增强 - 理解了
ImageFolder和DataLoader的工作原理,构建了高效的批次数据加载流水线 - 学会了快速加载PyTorch内置数据集,简化常用数据集的使用流程
最终构建的数据加载流水线可直接无缝集成到任意PyTorch深度学习模型中,用于模型训练和验证。