代码实现:
python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
class AlexNetDataHandler:
    """Build preprocessing pipelines and data loaders for AlexNet experiments.

    Provides ImageNet-style and CIFAR-10 transforms, a CIFAR-10 loader that
    splits the official training set into train/validation subsets, and a
    helper that plots one batch of images.

    Args:
        data_dir: Root directory passed to torchvision datasets (download target).
        batch_size: Batch size used by every DataLoader this class creates.
        num_workers: Number of DataLoader worker processes.
    """

    # CIFAR-10 per-channel normalization statistics (tuples so they cannot be
    # mutated through a shared class attribute).
    CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
    CIFAR10_STD = (0.2470, 0.2435, 0.2616)
    CIFAR10_CLASSES = ('plane', 'car', 'bird', 'cat', 'deer',
                       'dog', 'frog', 'horse', 'ship', 'truck')

    def __init__(self, data_dir='./data', batch_size=128, num_workers=4):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        # ImageNet normalization statistics (used for the original AlexNet setup).
        self.imagenet_mean = [0.485, 0.456, 0.406]
        self.imagenet_std = [0.229, 0.224, 0.225]

    def get_imagenet_transforms(self, img_size=224, resize_size=None):
        """Return ``(train_transform, val_transform)`` for ImageNet-style inputs.

        Args:
            img_size: Side length of the square crop fed to the network.
            resize_size: Size passed to ``Resize`` before the evaluation center
                crop. Defaults to ``round(img_size / 0.875)``, which gives the
                original 256 for the default ``img_size=224``.
        """
        if resize_size is None:
            resize_size = int(round(img_size / 0.875))
        normalize = transforms.Normalize(mean=self.imagenet_mean, std=self.imagenet_std)
        # Training: random crop / flip / color jitter for data augmentation.
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(img_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.ToTensor(),
            normalize,
        ])
        # Validation/test: deterministic resize + center crop.
        val_transform = transforms.Compose([
            transforms.Resize(resize_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            normalize,
        ])
        return train_transform, val_transform

    def get_cifar10_transforms(self):
        """Return ``(train_transform, val_transform)`` for 32x32 CIFAR-10 images."""
        normalize = transforms.Normalize(mean=self.CIFAR10_MEAN, std=self.CIFAR10_STD)
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),  # pad by 4 px, then random 32x32 crop
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            normalize,
        ])
        val_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        return train_transform, val_transform

    @staticmethod
    def _split_indices(num_samples, val_ratio, seed=None):
        """Randomly partition ``range(num_samples)`` into (train_indices, val_indices).

        The train part has ``int((1 - val_ratio) * num_samples)`` indices; the
        rest go to validation. ``seed=None`` uses torch's global RNG.

        Raises:
            ValueError: if ``val_ratio`` is not in ``[0, 1)``.
        """
        if not 0 <= val_ratio < 1:
            raise ValueError(f"val_ratio must be in [0, 1), got {val_ratio!r}")
        train_size = int((1 - val_ratio) * num_samples)
        generator = torch.Generator().manual_seed(seed) if seed is not None else None
        perm = torch.randperm(num_samples, generator=generator).tolist()
        return perm[:train_size], perm[train_size:]

    def _make_loader(self, dataset, shuffle):
        """Wrap *dataset* in a DataLoader using this handler's batch/worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            # Pinned memory only speeds up host->GPU copies; PyTorch warns when
            # it is requested without an accelerator, so enable it only with CUDA.
            pin_memory=torch.cuda.is_available(),
        )

    def load_cifar10(self, val_ratio=0.1, seed=None):
        """Load CIFAR-10 and return ``(train_loader, val_loader, test_loader, classes)``.

        The official training set is split into train/validation subsets. The
        validation subset is drawn from a separate view of the training files
        that uses the deterministic evaluation transform, so validation images
        are not randomly augmented.

        Args:
            val_ratio: Fraction of the training set held out for validation.
            seed: Optional seed that makes the train/validation split reproducible.

        Raises:
            ValueError: if ``val_ratio`` is not in ``[0, 1)``.
        """
        print("加载CIFAR-10数据集...")
        train_transform, val_transform = self.get_cifar10_transforms()

        # Two views over the same training files: augmented (train) and plain (val).
        train_view = datasets.CIFAR10(
            root=self.data_dir,
            train=True,
            download=True,
            transform=train_transform,
        )
        eval_view = datasets.CIFAR10(
            root=self.data_dir,
            train=True,
            download=False,  # files were fetched by the call above
            transform=val_transform,
        )
        test_dataset = datasets.CIFAR10(
            root=self.data_dir,
            train=False,
            download=True,
            transform=val_transform,
        )

        # One permutation shared by both views keeps train and val disjoint.
        train_idx, val_idx = self._split_indices(len(train_view), val_ratio, seed)
        train_dataset = torch.utils.data.Subset(train_view, train_idx)
        val_dataset = torch.utils.data.Subset(eval_view, val_idx)

        train_loader = self._make_loader(train_dataset, shuffle=True)
        val_loader = self._make_loader(val_dataset, shuffle=False)
        test_loader = self._make_loader(test_dataset, shuffle=False)

        classes = self.CIFAR10_CLASSES
        print(f"训练集: {len(train_dataset)} 张图像")
        print(f"验证集: {len(val_dataset)} 张图像")
        print(f"测试集: {len(test_dataset)} 张图像")
        return train_loader, val_loader, test_loader, classes

    @staticmethod
    def _find_normalize_stats(dataset):
        """Return ``(mean, std)`` of the first ``Normalize`` in *dataset*'s transform, else None."""
        # Subset (and random_split results) wrap the real dataset; unwrap them.
        while isinstance(dataset, torch.utils.data.Subset):
            dataset = dataset.dataset
        transform = getattr(dataset, 'transform', None)
        if transform is None:
            return None
        # A Compose exposes its steps as `.transforms`; otherwise check the transform itself.
        for step in getattr(transform, 'transforms', [transform]):
            if isinstance(step, transforms.Normalize):
                return list(step.mean), list(step.std)
        return None

    @staticmethod
    def _denormalize(images, mean, std):
        """Undo per-channel ``(x - mean) / std`` on a (N, C, H, W) batch and clamp to [0, 1]."""
        mean_t = torch.as_tensor(mean, dtype=images.dtype).view(-1, 1, 1)
        std_t = torch.as_tensor(std, dtype=images.dtype).view(-1, 1, 1)
        return torch.clamp(images * std_t + mean_t, 0, 1)

    def visualize_batch(self, data_loader, classes, num_images=8, mean=None, std=None):
        """Plot up to *num_images* images from one batch of *data_loader* with labels.

        Args:
            data_loader: Loader yielding ``(images, labels)`` where images is a
                normalized (N, C, H, W) batch.
            classes: Sequence mapping a label index to a class name.
            num_images: Maximum number of images to show (clamped to batch size).
            mean: Normalization mean to undo. Defaults to the mean of the
                loader dataset's ``Normalize`` transform, else ImageNet's.
            std: Normalization std to undo; same defaulting rule as *mean*.
        """
        images, labels = next(iter(data_loader))
        n = min(num_images, images.size(0))
        if n <= 0:
            return

        if mean is None or std is None:
            inferred = self._find_normalize_stats(getattr(data_loader, 'dataset', None))
            fallback_mean, fallback_std = inferred or (self.imagenet_mean, self.imagenet_std)
            mean = fallback_mean if mean is None else mean
            std = fallback_std if std is None else std
        images = self._denormalize(images, mean, std)

        # Up to 4 columns; the default num_images=8 reproduces the 2x4 / (12, 6) layout.
        cols = min(4, n)
        rows = (n + cols - 1) // cols
        fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)
        axes = axes.ravel()
        for i, ax in enumerate(axes[:n]):
            # (C, H, W) -> (H, W, C) as expected by imshow.
            ax.imshow(images[i].permute(1, 2, 0).numpy())
            ax.set_title(f"标签: {classes[int(labels[i])]}")
            ax.axis('off')
        for ax in axes[n:]:
            ax.axis('off')  # hide unused grid cells
        fig.suptitle('训练批次样本', fontsize=14)
        fig.tight_layout()
        plt.show()