PyTorch入门之【dataset】

参考:https://www.bilibili.com/video/BV1DV4y1y7KG/?spm_id_from=333.999.0.0\&vd_source=98d31d5c9db8c0021988f2c2c25a9620

目录

使用Pytorch自带的dataset

在 PyTorch 中,torchvision.datasets 包中提供了许多经典数据集的实现,你可以使用它们来训练和测试模型。

当然这些数据集是在服务器上的它在使用的时候是联网下载的。首次运行会下载,再次运行就不用下载了。

这里以经典的MNIST 数据集为例。
总代码如下:

python 复制代码
import torch
from torchvision import datasets
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms

# define a transform
transform = transforms.Compose([
    transforms.Resize(24),
    transforms.RandomRotation(10),
    transforms.ToTensor()
])

# download training & testing dataset
training_data = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=transform
)

test_data = datasets.MNIST(
    root='data',
    train=False,
    download=True,
    transform=transform
)

# create label to idx dictionary
labels = {i: training_data.classes[i] for i in range(len(training_data.classes))}

# display images in MNIST
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

# create dataloader
train_data_loader = DataLoader(training_data, batch_size=16, shuffle=True)
test_data_loader = DataLoader(test_data, batch_size=16, shuffle=True)
print(next(iter(train_data_loader))[0].shape)

下面挨个看各个模块的作用:

python 复制代码
# define a transform
transform = transforms.Compose([
    transforms.Resize(24),
    transforms.RandomRotation(10),
    transforms.ToTensor()
])

这段代码定义了一个数据转换管道,它将一系列的图像处理操作串联起来,以便对图像进行预处理。

  • transforms.Grayscale():将彩色图像转换为灰度图像。
  • transforms.Resize(24):调整图像的大小为 24x24 像素。
  • transforms.RandomRotation(10):随机旋转图像最多 10 度,增加数据的多样性和鲁棒性。
  • transforms.ToTensor():将图像转换为张量形式,以便进行后续的数据处理和模型训练。

通过将上述操作按照顺序组合在一起,你可以定义一个 transform 对象,用于对图像数据集中的每个图像进行预处理。该 transform 对象被用于加载 MNIST 数据集,并且在 DataLoader 中配合使用。这样的数据预处理流程在深度学习中非常常见,它能够帮助提高模型训练的效果和泛化能力。你可以根据自己的需求,定制不同的转换操作,以适应不同的任务和数据集特点。

python 复制代码
# download training & testing dataset
training_data = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=transform
)

test_data = datasets.MNIST(
    root='data',
    train=False,
    download=True,
    transform=transform
)

上述代码就是下载training_data和test_data数据。
download=True 参数用于指定是否下载数据集。当该参数设置为 True 时,如果数据集尚未下载,则会自动下载数据集。如果数据集已经存在,将不会再次下载。在加载数据集时 datasets.MNIST() 会检查文件是否下载过。

python 复制代码
# create label to idx dictionary
labels = {i: training_data.classes[i] for i in range(len(training_data.classes))}

这段代码的作用是将 MNIST 训练集的类别标签映射为整数索引,并将其存储在 labels 字典中。

这个MNIST 训练集是用来区分0-9的数据集,故这里就可以将0映射到0,1映射到1以此类推。

python 复制代码
# display images in MNIST
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

上述代码就是将MNIST数据集中随机的生成9个图片打印出来,为了验证一下我们的MNIST数据集是否成功的加载

python 复制代码
# create dataloader
train_data_loader = DataLoader(training_data, batch_size=16, shuffle=True)
test_data_loader = DataLoader(test_data, batch_size=16, shuffle=True)
print(next(iter(train_data_loader))[0].shape)

上述代码用于创建数据加载器 (DataLoader),设置批次以及是否shuffle。

用户自定义的dataset

python 复制代码
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder


# define a transform
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize(24),
    transforms.RandomRotation(10),
    transforms.ToTensor()
])

# create dataset
my_mnist = ImageFolder(root='./my-mnist', transform=transform)

# create label to idx dictionary
labels = {i: my_mnist.classes[i] for i in range(len(my_mnist.classes))}

# display images in MNIST
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(my_mnist), size=(1,)).item()
    img, label = my_mnist[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

# create dataloader
train_data_loader = DataLoader(my_mnist, batch_size=16, shuffle=True)
print(next(iter(train_data_loader))[0].shape)

总的代码几乎差不多,唯一有区别的就是数据是从自己定义的路径下加载的。

使用 ImageFolder 类创建数据集 my_mnist

相关推荐
寻星探路25 分钟前
【深度长文】万字攻克网络原理:从 HTTP 报文解构到 HTTPS 终极加密逻辑
java·开发语言·网络·python·http·ai·https
聆风吟º3 小时前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
User_芊芊君子3 小时前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
ValhallaCoder3 小时前
hot100-二叉树I
数据结构·python·算法·二叉树
智驱力人工智能3 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算
人工不智能5774 小时前
拆解 BERT:Output 中的 Hidden States 到底藏了什么秘密?
人工智能·深度学习·bert
猫头虎4 小时前
如何排查并解决项目启动时报错Error encountered while processing: java.io.IOException: closed 的问题
java·开发语言·jvm·spring boot·python·开源·maven
h64648564h4 小时前
CANN 性能剖析与调优全指南:从 Profiling 到 Kernel 级优化
人工智能·深度学习
心疼你的一切4 小时前
解密CANN仓库:AIGC的算力底座、关键应用与API实战解析
数据仓库·深度学习·aigc·cann
八零后琐话4 小时前
干货:程序员必备性能分析工具——Arthas火焰图
开发语言·python