基于PyTorch的深度学习6——数据处理工具箱1

PyTorch涉及数据处理(数据装载、数据预处理、数据增强等)主要工具包及相互关系如图所示

torch.utils.data工具包,它包括以下4个类。

1)Dataset:是一个抽象类,其他数据集需要继承这个类,并且覆写其中的两个方法(getitem_、len)。

2)DataLoader:定义一个新的迭代器,实现批量(batch)读取,打乱数据(shuffle)并提供并行加速等功能。

3)random_split:把数据集随机拆分为给定长度的非重叠的新数据集。

4)*sampler:多种采样函数。

PyTorch可视化处理工具(Torchvision),其是PyTorch的一个视觉处理工具包,独立于PyTorch,需要另外安装。

pip install torchvision #或conda install torchvision

它包括4个类,各类的主要功能如下。

1)datasets:提供常用的数据集加载,设计上都是继承自torch.utils.data.Dataset,主要包括MMIST、CIFAR10/100、ImageNet和COCO等。

2)models:提供深度学习中各种经典的网络结构以及训练好的模型(如果选择pretrained=True)​,包括AlexNet、VGG系列、ResNet系列、Inception系列等。

3)transforms:常用的数据预处理操作,主要包括对Tensor及PIL Image对象的操作。

4)utils:含两个函数,一个是make_grid,它能将多张图片拼接在一个网格中;另一个是save_img,它能将Tensor保存成图片。

------------------------------------------------------utils.data

utils.data包括Dataset和DataLoader。

torch.utils.data.Dataset为抽象类。自定义数据集需要继承这个类,并实现两个函数,一个是__len__,另一个是__getitem__,前者提供数据的大小(size),后者通过给定索引获取数据和标签。

__getitem__一次只能获取一个数据,所以需要通过torch.utils.data.DataLoader来定义一个新的迭代器,实现batch读取。

首先我们来定义一个简单的数据集,然后通过具体使用Dataset及DataLoader,给读者一个直观的认识

1)导入需要的模块

复制代码
import torch
from torch.utils import data
import numpy as np

2)定义获取数据集的类

该类继承基类Dataset,自定义一个数据集及对应标签

复制代码
import torch
from torch.utils.data import Dataset
import numpy as np

class TestDataset(Dataset):
    def __init__(self):
        super(TestDataset, self).__init__()  # 显式调用父类初始化
        self.Data = np.asarray([[1, 2], [3, 4], [2, 1], [3, 4], [4, 5]], dtype=np.float32)
        self.Label = np.asarray([0, 1, 0, 1, 2], dtype=np.int64)

    def __getitem__(self, index):
        # 将numpy数据转换为Tensor
        txt = torch.from_numpy(self.Data[index])  # 默认float32
        label = torch.tensor(self.Label[index], dtype=torch.long)  # 显式指定类型
        return txt, label

    def __len__(self):
        return len(self.Data)

3)获取数据集中的数据

复制代码
Test=TestDataset()
print(Test[2]) #相当于调用__getitem__(2)
print(Test.__len__())

以上数据以tuple返回,每次只返回一个样本。实际上,Dateset只负责数据的抽取,调用一次__getitem__只返回一个样本。如果希望批量处理(batch),还要同时进行shuffle和并行加速等操作,可选择DataLoader。

DataLoader的格式为:

复制代码
data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=<function default_collate at 0x7f108ee01620>,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None,
)

主要参数说明:

• dataset:加载的数据集。

• batch_size:批大小。

• shuffle:是否将数据打乱。

• sampler:样本抽样。

• num_workers:使用多进程加载的进程数,0代表不使用多进程。

• collate_fn:如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可。

• pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些。

• drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃。

复制代码
test_loader=data.DataLoader(Test,batch_size=2,shuffle=False,num_workers=2)
for i,traindata in enumerate(test_loader):
    print('i:',i)
    Data,Label=traindata
    print('data:',Data)
    print('Label:',Label)

运行结果:

复制代码
i: 0
data: tensor([[1, 2],
        [3, 4]])
Label: tensor([0, 1])
i: 1
data: tensor([[2, 1],
        [3, 4]])
Label: tensor([0, 1])
i: 2
data: tensor([[4, 5]])
Label: tensor([2])

从这个结果可以看出,这是批量读取。我们可以像使用迭代器一样使用它,比如对它进行循环操作。不过由于它不是迭代器,我们可以通过iter命令将其转换为迭代器。

复制代码
dataiter=iter(test_loader)
imgs,labels=next(dataiter)

一般用data.Dataset处理同一个目录下的数据。如果数据在不同目录下,因为不同的目录代表不同类别(这种情况比较普遍)​,使用data.Dataset来处理就很不方便。不过,使用PyTorch另一种可视化数据处理工具(即torchvision)就非常方便,不但可以自动获取标签,还提供很多数据预处理、数据增强等转换函数

相关推荐
爱学习的程序媛3 分钟前
“数字孪生”详解与前端技术栈
前端·人工智能·计算机视觉·智慧城市·信息与通信
数业智能心大陆3 分钟前
科技赋能心育服务,心大陆 AI 减压舱守护校园心灵健康
人工智能·心理健康
程序员Sunday11 分钟前
Claude Code 生态爆发:5个必知的新工具
前端·人工智能·后端
智算菩萨31 分钟前
【How Far Are We From AGI】6 AGI的进化论——从胚胎到终极的三级跃迁与发展路线图
论文阅读·人工智能·深度学习·ai·agi
夏同学Xavi40 分钟前
skls-mgr:统一管理 Agent Skills 的 CLI 工具
人工智能·程序员·命令行
梦醒过后说珍重43 分钟前
【超分实战】拒绝灾难性遗忘!记一次原生4K医疗影像(SurgiSR4K)的模型微调踩坑实录
深度学习
天青色等烟雨0944 分钟前
Skill的终局:不是被生成,而是能进化
人工智能·agent
FPGA-ADDA1 小时前
第四篇:嵌入式系统常用通信接口详解(I2C、SPI、UART、RS232/485、CAN、USB)
人工智能·单片机·嵌入式硬件·fpga开发·信息与通信
智算菩萨1 小时前
【How Far Are We From AGI】7 AGI的七重奏——从实验室到现实世界的应用图景与文明展望
论文阅读·人工智能·ai·agi·感知
一招定胜负1 小时前
从 TXT 到 CSV 再到 Flask 部署:语音转写 AI 总结全流程实战
人工智能