【PyTorch】(基础五)---- 图像数据集加载

数据集

torchvision数据集

Torchvision在torchvision.datasets模块中提供了许多内置的数据集,以及用于构建您自己的数据集的实用程序类。关于一些内置数据集目录如下,点击进去之后会有详细的数据集介绍,包括数据集大小、分类类型、以及下载方式等。

接下来我都会使用CIFAR-10数据集为例进行展示,因为其数据量不是很大,下载到本地后对存储的压力不是很大,CIFAR-10数据集由10个类别的60000张32 x32彩色图像组成,每个类别6000张图像。有50000张训练图像和10000张测试图像,用于完成图像分类任务,点进去之后看到详细页面

以下是CIFAR-10数据集的类别和部分图片展示:

torchvision提供的下载CIFAR10数据集的语法如下:

py 复制代码
torchvision.datasets.CIFAR10(root: Union[str, Path], train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)

# CIFAR10为数据集的名字
# root为存放数据的目录
# train表示是否为训练数据集
# download表示是否将数据下载到本地
# transform可以指定如何对数据集进行预处理,通常都是使用ToTensor

# 我们在选择好这些属性参数后的命令如下:
dataset = torchvision.datasets.CIFAR10(root="./dataset", train = True, download = True,transform=torchvision.transforms.ToTensor())

运行后得到的下载链接可以放到迅雷中使用,下载速度更快,下载后将压缩文件复制到当前的目录下面即可,系统再次运行时会对其进行解压使用

第一次下载后,即使继续打开download运行,系统在检测到之后就不需要继续进行下载了,所以download经常处于打开状态后续也无需进行修改

DataLoader

datalorader用于加载数据集并提供迭代器(指定每次取多少个数据,是否随机取数等),使得模型训练过程中的数据读取更加高效。其基本语法如下:

py 复制代码
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

# dataset指定从哪个DataSet中进行读取数据
# batch_size指定每次读取数据的多少
# shuffle指定是否进行随机抽取
# num_workers指定的多线程数量,在Windows中设置大于0可能出问题
# drop_last表示是否保留最后余出的几个数据,比如一共有1024张图片,batch_size设置为100,最终会余出24个数据

【代码示例】使用DataLoader读取CIFAR10数据集,每次读取64张图片

py 复制代码
import torchvision

# 准备的测试数据集
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

test_data = torchvision.datasets.CIFAR10("./dataset/CIFAR10", train=False, transform=torchvision.transforms.ToTensor())

test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

# 测试数据集中第一张图片及target
img, target = test_data[0]
print(img.shape)
print(target)

writer = SummaryWriter("dataloader")
# 观察每一次随机的结果
for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        # print(imgs.shape)
        # print(targets)
        writer.add_images("Epoch: {}".format(epoch), imgs, step)
        step = step + 1

writer.close()

运行结果:

自定义数据集

要实现自定义的数据集类,首先,需要创建一个数据集类。这个类需要继承 torch.utils.data.Dataset 并实现两个方法:__len____getitem__

这里使用一个蚂蚁和蜜蜂的分类数据集(网盘下载),观察一下其目录结构,整个数据集分成训练集(train)和验证集(val)两部分,每部分包含ants和bees两个文件夹,每个文件夹中都是若干图片,其中的文件夹名字就是其图片的label

我们可以创建自定义的Dataset读取具体的数据。

py 复制代码
import os
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

# 定义自定义 Dataset 类
class AntsBeesDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # 遍历根目录下的 ants 和 bees 文件夹
        for label in ['ants', 'bees']:
            label_dir = os.path.join(root_dir, label)
            for filename in os.listdir(label_dir):
                if filename.endswith('.jpg'):
                    self.image_paths.append(os.path.join(label_dir, filename))
                    self.labels.append(0 if label == 'ants' else 1)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        image = Image.open(image_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

# 定义数据变换
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 创建训练数据集
train_dataset = AntsBeesDataset(root_dir='./data_antAndBee/train', transform=transform)

# 创建验证数据集
val_dataset = AntsBeesDataset(root_dir='./data_antAndBee/val', transform=transform)

# 创建数据加载器
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,  # 使用多进程
    pin_memory=True,  # 使用 pinned 内存
)

val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=0,  # 使用多进程
    pin_memory=True,  # 使用 pinned 内存
)

# 检查数据集和数据加载器
print(train_dataset.__len__())  # 输出数据集长度
print(train_dataset.__getitem__(0))  # 输出第一个样本

# TensorBoard Writer
writer = SummaryWriter('logs/log8')

# 每轮读取的结果都是随机的
for epoch in range(2):
    step = 0
    for data in train_loader:
        imgs, targets = data
        # 如果需要将图像堆叠成一个张量,可以在这里进行处理
        # 例如,使用 pad_sequence 或者自定义的填充方法
        # imgs = torch.stack([F.pad(img, (0, max_width - img.size(-1), 0, max_height - img.size(-2))) for img in imgs])
        writer.add_images("Epoch: {}".format(epoch), imgs, step)
        step += 1

writer.close()

上面的方法实现起来很复杂,其实这种目录结构叫做"ImageFolder"格式,其结构如下所示:

复制代码
root/
├── class1/
│   ├── image1.jpg
│   ├── image2.jpg
│   └── ...
├── class2/
│   ├── image1.jpg
│   ├── image2.jpg
│   └── ...
└── ...

torchvision.datasets为我们提供了一种方法用于快速便捷地处理这种格式的数据集,使用方法ImageFolder()可以快速加载对应的数据集

py 复制代码
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torchvision.transforms import transforms

if __name__ == "__main__":
    # 定义数据变换
    transform = transforms.Compose([
        # 改成统一大小并转换成tensor类型
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    # 创建训练数据集
    train_dataset = datasets.ImageFolder(root='./data_antAndBee/train',transform=transform)

    # 创建验证数据集
    val_dataset = datasets.ImageFolder(root='./data_antAndBee/val',transform=transform)

    # 创建数据加载器
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=0,  # 使用多进程
        pin_memory=True  # 使用 pinned 内存
    )

    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=0,  # 使用多进程
        pin_memory=True  # 使用 pinned 内存
    )

    # 检查数据集和数据加载器
    print(train_dataset.classes)  # 输出类别名称
    print(train_dataset.class_to_idx)  # 输出类别到索引的映射

    # TensorBoard Writer
    writer = SummaryWriter('logs/log7')

    # 每轮读取的结果都是随机的
    for epoch in range(2):
        step = 0
        for data in train_loader:
            imgs, targets = data
            writer.add_images("Epoch: {}".format(epoch), imgs, step)
            step += 1

    writer.close()
相关推荐
蹦蹦跳跳真可爱5894 分钟前
Python----PaddlePaddle(深度学习框架PaddlePaddle,概述,安装,衍生工具)
开发语言·人工智能·python·paddlepaddle
机器之心11 分钟前
迈向机器人领域ImageNet,大牛Pieter Abbeel领衔国内外高校共建RoboVerse,统一仿真平台、数据集和基准
人工智能
cccccc语言我来了16 分钟前
飞浆PaddlePaddle 猫狗数据大战
人工智能·python·paddlepaddle
机器之心18 分钟前
Llama 4在测试集上训练?内部员工、官方下场澄清,LeCun转发
人工智能
QQ_77813297419 分钟前
深入解析:Python爬取Bilibili视频的技术创新与高阶实践
python
机器之心20 分钟前
首次引入强化学习!火山引擎Q-Insight让画质理解迈向深度思考
人工智能
涛涛讲AI22 分钟前
Python urllib3 全面指南:从基础到实战应用
开发语言·python·urllib3
阿白630522 分钟前
AI辅助编程_pyThon
python
说私域25 分钟前
基于开源 AI 大模型 AI 智能名片 S2B2C 商城小程序的京城首家无人智慧书店创新模式研究
人工智能·小程序·开源·零售
爱的叹息28 分钟前
关于Spring MVC中传递数组参数的详细说明,包括如何通过逗号分隔的字符串自动转换为数组,以及具体的代码示例和总结表格
python·spring·mvc