# Python PyTorch 获取 MNIST 数据
- [1 PyTorch 获取 MNIST 数据](#1-pytorch-获取-mnist-数据)
- [2 PyTorch 保存 MNIST 数据](#2-pytorch-保存-mnist-数据)
- [3 PyTorch 显示 MNIST 数据](#3-pytorch-显示-mnist-数据)
## 1 PyTorch 获取 MNIST 数据
python
复制代码
import torch
import numpy as np
import matplotlib.pyplot as plt # type: ignore
from torchvision import datasets, transforms
def mnist_get(root='./data'):
    """Download (if needed) and load the MNIST train and test splits as NumPy arrays.

    Args:
        root: Directory where torchvision stores / looks for the MNIST files.
            Defaults to ``'./data'``, the value previously hard-coded.

    Returns:
        A tuple ``(train_image, train_label, test_image, test_label)`` where the
        images come from ``dataset.data`` and the labels from
        ``dataset.targets``, each converted with ``.numpy()``.
    """
    print(torch.__version__)
    # Define the data transform.
    # NOTE: torchvision applies `transform` only when samples are fetched via
    # dataset[i] (e.g. through a DataLoader). The raw `.data` tensors read
    # below bypass it, so the returned arrays are NOT normalized.
    transform = transforms.Compose([
        transforms.ToTensor(),  # convert the image to a tensor
        transforms.Normalize((0.5,), (0.5,)),  # normalize the image data
    ])
    # Fetch the data. train=True selects the 60k training split; the
    # original code passed train=False here, so its "training" arrays were
    # really a copy of the test split.
    train_data = datasets.MNIST(root=root, train=True, download=True, transform=transform)
    test_data = datasets.MNIST(root=root, train=False, download=True, transform=transform)
    # Training data
    train_image = train_data.data.numpy()
    train_label = train_data.targets.numpy()
    # Test data
    test_image = test_data.data.numpy()
    test_label = test_data.targets.numpy()
    return train_image, train_label, test_image, test_label
## 2 PyTorch 保存 MNIST 数据
python
复制代码
import torch
import numpy as np
import matplotlib.pyplot as plt # type: ignore
from torchvision import datasets, transforms
def mnist_save(mnist_path, root='./data'):
    """Download (if needed) MNIST and save both splits to a single ``.npz`` file.

    Args:
        mnist_path: Destination path passed to ``np.savez``.
        root: Directory where torchvision stores / looks for the MNIST files.
            Defaults to ``'./data'``, the value previously hard-coded.

    The archive contains the keys ``train_data``, ``train_label``,
    ``test_data`` and ``test_label``. The arrays are taken from the raw
    ``.data`` / ``.targets`` tensors, so ``transform`` is not applied to them.
    """
    print(torch.__version__)
    # Define the data transform (only used for dataset[i] access, not for `.data`).
    transform = transforms.Compose([
        transforms.ToTensor(),  # convert the image to a tensor
        transforms.Normalize((0.5,), (0.5,)),  # normalize the image data
    ])
    # Fetch the data. train=True is required for the training split; the
    # original passed train=False for both, saving the test set twice.
    train_data = datasets.MNIST(root=root, train=True, download=True, transform=transform)
    test_data = datasets.MNIST(root=root, train=False, download=True, transform=transform)
    # Training data
    train_image = train_data.data.numpy()
    train_label = train_data.targets.numpy()
    # Test data
    test_image = test_data.data.numpy()
    test_label = test_data.targets.numpy()
    np.savez(mnist_path, train_data=train_image, train_label=train_label,
             test_data=test_image, test_label=test_label)
if __name__ == '__main__':
    # Run only when executed as a script, so importing this module does not
    # trigger a download and a file write.
    # NOTE: machine-specific path; adjust for your environment.
    mnist_path = 'C:\\Users\\Hyacinth\\Desktop\\mnist.npz'
    mnist_save(mnist_path)
## 3 PyTorch 显示 MNIST 数据
python
复制代码
import torch
import numpy as np
import matplotlib.pyplot as plt # type: ignore
from torchvision import datasets, transforms
def mnist_show(mnist_path):
    """Plot up to the first 100 training images from an MNIST ``.npz`` file.

    Args:
        mnist_path: Path to an archive with ``train_data`` and ``train_label``
            keys, as written by ``np.savez`` in the save step.

    For each plotted image this prints ``"<index>, <label>"`` and draws the
    image into a 10x10 grid of subplots, then shows the figure.
    """
    # `with` closes the NpzFile handle that np.load opens for .npz archives.
    with np.load(mnist_path) as data:
        image = data['train_data'][0:100]
        label = data['train_label'].reshape(-1, )
    # Guard against archives holding fewer than 100 training images.
    count = min(100, len(image), len(label))
    plt.figure(figsize=(10, 10))
    for i in range(count):
        # Index and class label are integers; '%f' printed them as e.g. 5.000000.
        print('%d, %d' % (i, label[i]))
        plt.subplot(10, 10, i + 1)
        # MNIST digits are single-channel; use a gray colormap instead of the default.
        plt.imshow(image[i], cmap='gray')
    plt.show()
if __name__ == '__main__':
    # Run only when executed as a script, so importing this module does not
    # open a plot window.
    # NOTE: machine-specific path; adjust for your environment.
    mnist_path = 'C:\\Users\\Hyacinth\\Desktop\\mnist.npz'
    mnist_show(mnist_path)