# MNIST: build train/test data loaders and preview one training image.
import torch
import torchvision
import matplotlib.pyplot as plt
# Shared MNIST settings. Both splits use the same on-disk root and the same
# preprocessing, so train and test tensors are directly comparable (the
# original used 'mnist data' for one split and 'mnist data/' for the other).
MNIST_ROOT = 'mnist data'
# Per-channel normalization constants (values unchanged from the original
# script; presumably the MNIST training-set mean / std — confirm if changed).
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

batch_size = 512

# Single preprocessing pipeline shared by both splits. ToTensor turns each
# image into a float tensor; Normalize then standardizes it with
# MNIST_MEAN / MNIST_STD. Both transforms are stateless, so sharing one
# Compose instance between the two datasets is safe.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

# Training split, shuffled by the DataLoader.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(MNIST_ROOT, train=True, download=True,
                               transform=mnist_transform),
    batch_size=batch_size, shuffle=True)

# Test split, kept in a fixed order so evaluation is reproducible.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(MNIST_ROOT, train=False, download=True,
                               transform=mnist_transform),
    batch_size=batch_size, shuffle=False)
# Take one batch from the training loader and display its first image.
# x is the batch of images, y the matching labels.
x, y = next(iter(train_loader))
# Each MNIST sample is a single-channel (1, H, W) tensor. Drop the channel
# axis explicitly so imshow gets a plain 2-D (H, W) array, rather than relying
# on matplotlib silently squeezing an (H, W, 1) array. Render with a grayscale
# colormap so the digit is not shown in the default false-color map.
plt.imshow(x[0].squeeze(0), cmap='gray')
plt.title(f'label: {y[0].item()}')  # show the ground-truth label of the sample
plt.axis('off')  # hide the axes; only the image is of interest
plt.show()