【PyTorch攻略(2/7)】 加载数据集

一、说明

PyTorch提供了两个数据原语:torch.utils.data.DataLoadertorch.utils.data.Dataset ,允许您使用预加载的数据集以及您自己的数据。数据集 存储样本及其相应的标签,DataLoader 围绕数据集包装一个可迭代对象,以便轻松访问样本。

PyTorch域库提供了许多示例预加载数据集,例如FashionMNIST ,它子类torch.utils.data.Dataset并实现特定于特定数据的函数。可以在此处找到它们并用作原型设计和基准测试模型的示例:

  • 图像数据集
  • 文本数据集
  • 音频数据集

二、加载数据集

我们将从TorchVision加载FashionMNIST数据集。FashionMNIST是Zalando文章图像的数据集,由60,000个训练示例和10,000个测试示例组成。每个示例包含一个 28x28 灰度图像和一个来自 10 个类之一的相关标签。

  • 每张图片高 28 像素,宽 28 像素,共 784 像素。
  • 这 10 个类告诉它是什么类型的图像。例如,T型短裤/上衣,裤子,套头衫,连衣裙,包,踝靴等。
  • 灰度是介于 0 到 255 之间的值,用于测量黑白图像的强度。强度值从白色增加到黑色。例如,白色为 0,黑色为 255。

我们使用以下参数加载FashionMNIST 数据集:

  • 是存储训练/测试数据的路径。
  • 训练指定训练或测试数据集。
  • download = 如果数据在根目录中不可用,则 True 从互联网下载数据。
  • 转换指定特征和标注转换
ajz 复制代码
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plt

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

三、迭代和可视化数据集

我们可以像列表一样手动索引数据集:*training_data[index]。*我们使用 matplotlib 来可视化训练数据中的一些样本。

ajz 复制代码
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx] # Iterate training data
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

四、准备数据以使用数据加载程序进行训练

数据集检索数据集的特征并一次标记一个样本。在训练模型时,我们通常希望以"小批量"方式传递样本,在每个时期重新洗牌数据以减少模型过度拟合,并使用 Python 的多处理来加速数据检索。

在机器学习中,需要指定数据集中的特征和标签。要素是输入,标注是输出。我们训练特征并训练模型来预测标签。

DataLoader 是一个迭代对象,它在一个简单的 API 中为我们抽象了这种复杂性。要使用数据加载器,我们需要设置以下参数:

  • 数据是将用于训练模型的训练数据;以及用于评估模型的测试数据。
  • 批大小是每个批中要处理的记录数。
  • 随机播放是按索引随机抽取的数据。
ajz 复制代码
from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

五、遍历数据加载器

我们已将数据集加载到数据加载器 中,并可以根据需要循环访问数据集。下面的每次迭代都会返回一批train_featurestrain_labels (分别包含 batch_size = 64 个要素和标注)。由于我们指定了 shuffle = True,因此在我们遍历所有批次后,数据将被洗牌。

ajz 复制代码
# Display image and label
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")

img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

NOrmalization是一种常见的数据预处理技术,用于缩放或转换数据,以确保每个特征的学习贡献相等。例如,灰度图像中的每个像素都有一个介于 0 到 255 之间的值,这些值是特征。如果一个像素值为 17,另一个像素值为 197。像素重要性的分布将不均匀,因为较高的像素体积会偏离学习。归一化会更改数据的范围,而不会扭曲其在我们的功能之间的区别。进行此预处理是为了避免:

  • 预测精度降低
  • 模型学习的难度
  • 特征数据范围的不利分布

六、变换

数据并不总是以训练机器学习算法所需的最终处理形式出现。我们使用 transform 对数据执行一些操作,并使其适合训练。

所有 TorchVision 数据集都有两个参数:用于修改特征的转换和用于修改接受包含转换 逻辑的可调用对象的标签target_transformtorchvision.transform模块提供了几种开箱即用的常用变换。

FashionMNIST 功能采用 PIL 图像格式,标签为整数。对于训练,我们需要特征作为规范化张量,标签作为独热编码张量。为了进行这些转换,我们使用ToTensorLamda

ajz 复制代码
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

七、ToTensor()

ToTensor 将 PIL 图像或 NumPy ndarray 转换为 FloatTensor,并将图像的像素强度值缩放在 [0., 1.] 范围内。

八、Lambda()

Lambda 应用任何用户定义的 lambda 函数。在这里,我们定义了一个函数来将整数转换为独热编码张量。它首先创建一个大小为 10(我们数据集中的标签数量)的张量并调用 scatter ,它在索引上分配一个 value=1 ,如标签 y 给出的那样。您也可以将torch.nn.functional.one_hot用作其他选项。

ajz 复制代码
target_transform = Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))

下一>> PyTorch 简介 (3/7)

相关推荐
网易独家音乐人Mike Zhou2 小时前
【卡尔曼滤波】数据预测Prediction观测器的理论推导及应用 C语言、Python实现(Kalman Filter)
c语言·python·单片机·物联网·算法·嵌入式·iot
安静读书2 小时前
Python解析视频FPS(帧率)、分辨率信息
python·opencv·音视频
小陈phd2 小时前
OpenCV从入门到精通实战(九)——基于dlib的疲劳监测 ear计算
人工智能·opencv·计算机视觉
Guofu_Liao3 小时前
大语言模型---LoRA简介;LoRA的优势;LoRA训练步骤;总结
人工智能·语言模型·自然语言处理·矩阵·llama
小二·4 小时前
java基础面试题笔记(基础篇)
java·笔记·python
小喵要摸鱼5 小时前
Python 神经网络项目常用语法
python
一念之坤7 小时前
零基础学Python之数据结构 -- 01篇
数据结构·python
wxl7812277 小时前
如何使用本地大模型做数据分析
python·数据挖掘·数据分析·代码解释器
NoneCoder7 小时前
Python入门(12)--数据处理
开发语言·python
ZHOU_WUYI7 小时前
3.langchain中的prompt模板 (few shot examples in chat models)
人工智能·langchain·prompt