day42Dataset和Dataloader@浙大疏锦行
下载数据集
python
import torch
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
# 设置随机种子
torch.manual_seed(42)
# 1. 定义预处理
# CIFAR-10 图片是 32x32 的 RGB 图片
transform = transforms.Compose([
transforms.ToTensor(), # 转为 Tensor,范围 [0, 1]
])
# 2. 加载 CIFAR-10 数据集
# root='./data' 指定下载/存放路径,如果数据不存在会自动下载
train_dataset = datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=transform
)
# CIFAR-10 的类别
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
print(f"数据集大小: {len(train_dataset)}")
python
# 3. 随机获取并展示一张图片
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item()
image, label = train_dataset[sample_idx]
print(f"样本索引: {sample_idx}")
print(f"标签索引: {label}")
print(f"对应类别: {classes[label]}")
print(f"图片形状: {image.shape}") # (C, H, W)
# 可视化函数
def imshow(img):
# img: (C, H, W) -> (H, W, C)
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.axis('off') # 不显示坐标轴
plt.show()
imshow(image)
样本索引: 37542
标签索引: 6
对应类别: frog
图片形状: torch.Size([3, 32, 32])
