# Load the first CIFAR-10 training image and display it with its class label.
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# Preprocessing: convert each dataset image into a torch Tensor.
transform = transforms.ToTensor()

# CIFAR-10 class names; the integer label returned by the dataset indexes this tuple.
# NOTE(review): order presumably mirrors `cifar10_dataset.classes` — verify if it changes.
classes = (
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
)

# Open the CIFAR-10 training split under ./data, downloading it first if needed.
cifar10_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Indexing the dataset yields an (image, label) pair; inspect the very first sample.
image, label = cifar10_dataset[0]

print("Label index:", label)
label_name = classes[label]
print("Label name:", label_name)
print("Image shape:", image.shape)  # channels-first, expected [3, 32, 32]

# imshow wants (height, width, channels), so move the channel axis to the end.
image_np = np.transpose(image.numpy(), (1, 2, 0))

# Render the sample with its class name as the title and no axis ticks.
plt.imshow(image_np)
plt.title(label_name)
plt.axis('off')
plt.show()