import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
# Set the random seed so results are reproducible.
# Only PyTorch's global RNG is seeded here (used by torch.randint below);
# NumPy and Python's `random` module are not seeded.
torch.manual_seed(42)
# Define data preprocessing.
# ToTensor converts an image to a CHW float tensor in [0, 1]; Normalize then
# applies (x - 0.5) / 0.5 per channel, mapping values to [-1, 1].
# imshow() reverses this with `img * 0.5 + 0.5`.
transform = transforms.Compose([
    transforms.ToTensor(),  # convert the image to a Tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # normalize
])
# Load the CIFAR-10 train and test splits into ./data, applying `transform`
# to every image.
# Both splits pass download=True so neither depends on the other having
# been constructed first; torchvision skips the download when the verified
# archive is already present.
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
# Class names, indexed by the integer label returned by the dataset.
class_names = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
# Visualization helper.
def imshow(img: torch.Tensor, title=None) -> None:
    """Display a normalized CHW image tensor in a new matplotlib figure.

    Undoes the mean=0.5 / std=0.5 normalization (mapping [-1, 1] back to
    [0, 1]) and shows the result in a 4x4-inch figure with axes hidden.

    Args:
        img: Image tensor in CHW layout, normalized with mean/std 0.5.
            It may live on any device and may require grad; it is not
            modified.
        title: Optional figure title. It is skipped when falsy.
    """
    # Undo normalization: x_norm = (x - 0.5) / 0.5  =>  x = x_norm * 0.5 + 0.5.
    # detach().cpu() lets .numpy() work for GPU tensors or tensors with grad.
    img = img.detach().cpu() * 0.5 + 0.5
    # Float RGB data must lie in [0, 1]. Rounding can push values slightly
    # outside that range, and matplotlib would clip them with a warning.
    img = img.clamp(0.0, 1.0)
    np_img = img.numpy()
    plt.figure(figsize=(4, 4))
    plt.imshow(np.transpose(np_img, (1, 2, 0)))  # CHW -> HWC for matplotlib
    if title:
        plt.title(title)
    plt.axis('off')
    plt.show()
# Randomly pick one training sample and display it with its class name.
# The index comes from torch's RNG, so it is reproducible given the seed.
sample_idx = torch.randint(0, len(train_dataset), (1,)).item()
image, label = train_dataset[sample_idx]
label_name = class_names[label]
print(f"Label: {label} ({label_name})")
imshow(image, f"Label: {label_name}")