【Pytorch】数据集的加载和处理(二)

【Pytorch】数据集的加载和处理(一)
Pytorch中张量可以是一维、二维、三维或者更高维度的数据结构。一维张量类似于向量,二维张量类似于矩阵,三维张量类似一系列矩阵的堆叠。

目录

将张量包装为数据集

创建数据加载器

数据转换(图像转换)


将张量包装为数据集

导入MNIST训练数据集并提取数据和标签

复制代码
import torch
import torchvision
from torchvision import datasets
train_data=datasets.MNIST("./data",train=True,download=True)
x_train, y_train=train_data.data,train_data.targets

导入MNIST验证数据集并提取数据和标签

复制代码
val_data=datasets.MNIST("./data", train=False, download=True)
x_val,y_val=val_data.data, val_data.targets

使用 TensorDataset类将张量包装为数据集

复制代码
from torch.utils.data import TensorDataset
train_ds = TensorDataset(x_train, y_train)
val_ds = TensorDataset(x_val, y_val)

for x,y in train_ds:
    print(x.shape,y.item())
    break

创建数据加载器

通过DataLoader从数据集创建数据加载器

复制代码
from torch.utils.data import DataLoader
train_dl = DataLoader(train_ds, batch_size=100)
val_dl = DataLoader(val_ds, batch_size=100)

for xb,yb in train_dl:
    print(xb.shape)
    print(yb.shape)
    break

数据转换(图像转换)

通过 transform 类进行简单的图像转换

导入库和训练数据集

复制代码
import torchvision
import matplotlib.pyplot as plt
from torchvision import datasets
from torchvision import transforms
train_data=datasets.MNIST("./data", train=True, download=True)

借助transform类定义旋转

复制代码
data_transform = transforms.Compose
([
    transforms.RandomHorizontalFlip(p=1),
    transforms.RandomVerticalFlip(p=1),
    transforms.ToTensor(),
])

对训练数据集中图像进行旋转并打印对比

复制代码
img = train_data[5][0]
img_tr=data_transform(img)
img_tr_np=img_tr.numpy()

plt.subplot(1,2,1)
plt.imshow(img,cmap="gray")
plt.title("original")
plt.subplot(1,2,2)
plt.imshow(img_tr_np[0],cmap="gray");
plt.title("transformed 180")
相关推荐
abc123456sdggfd几秒前
HTML5中Vuex持久化插件中WebStorage的底层配置
jvm·数据库·python
小龙Guo2 分钟前
Yolo 多任务推理,摄像头+视频实时推理,实现关键点、分割、检测等模型推理部署
python·yolo·关键点检测·模型推理
pele3 分钟前
Go语言如何发GET请求_Go语言HTTP GET请求教程【总结】
jvm·数据库·python
云烟成雨TD3 分钟前
Spring AI Alibaba 1.x 系列【33】Human-in-the-Loop(人在回路)演示
java·人工智能·spring
weixin_580614004 分钟前
Go 语言中 go install 命令的正确用法与常见误区详解
jvm·数据库·python
qq_654366984 分钟前
Bootstrap 5移除jQuery依赖 Bootstrap 5如何不使用jQuery
jvm·数据库·python
今天你TLE了吗5 分钟前
LLM到Agent&RAG——AI概念概述 第五章:Skill
人工智能·笔记·后端·学习
网安情报局6 分钟前
弹性云服务器跟游戏行业有什么关系?
人工智能
m0_676544387 分钟前
CSS如何实现元素悬浮在页面底部_利用fixed定位与底部间距
jvm·数据库·python
weixin_568996067 分钟前
Redis怎样监控当前发生了多少次内存驱逐
jvm·数据库·python