【CNN算法理解】:二、AlexNet深度学习的数据集处理(附代码)

续上二、AlexNet深度学习的数据集处理

代码实现:

python 复制代码
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt

class AlexNetDataHandler:
    """AlexNet数据处理器"""

    def __init__(self, data_dir='./data', batch_size=128, num_workers=4):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        # ImageNet标准预处理(用于原始AlexNet)
        self.imagenet_mean = [0.485, 0.456, 0.406]
        self.imagenet_std = [0.229, 0.224, 0.225]

    def get_imagenet_transforms(self, img_size=224):
        """获取ImageNet数据预处理转换"""
        # 训练集转换(包含数据增强)
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(img_size),  # 随机裁剪
            transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # 颜色抖动
            transforms.ToTensor(),
            transforms.Normalize(mean=self.imagenet_mean, std=self.imagenet_std)
        ])

        # 验证集/测试集转换
        val_transform = transforms.Compose([
            transforms.Resize(256),  # 调整大小
            transforms.CenterCrop(img_size),  # 中心裁剪
            transforms.ToTensor(),
            transforms.Normalize(mean=self.imagenet_mean, std=self.imagenet_std)
        ])

        return train_transform, val_transform

    def get_cifar10_transforms(self):
        """获取CIFAR-10数据预处理转换"""
        # CIFAR-10均值和标准差
        cifar_mean = [0.4914, 0.4822, 0.4465]
        cifar_std = [0.2470, 0.2435, 0.2616]

        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),  # 随机裁剪
            transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转
            transforms.ToTensor(),
            transforms.Normalize(mean=cifar_mean, std=cifar_std)
        ])

        val_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=cifar_mean, std=cifar_std)
        ])

        return train_transform, val_transform

    def load_cifar10(self):
        """加载CIFAR-10数据集"""
        print("加载CIFAR-10数据集...")

        train_transform, val_transform = self.get_cifar10_transforms()

        # 下载并加载训练集
        train_dataset = datasets.CIFAR10(
            root=self.data_dir,
            train=True,
            download=True,
            transform=train_transform
        )

        # 下载并加载测试集
        test_dataset = datasets.CIFAR10(
            root=self.data_dir,
            train=False,
            download=True,
            transform=val_transform
        )

        # 从训练集划分验证集
        train_size = int(0.9 * len(train_dataset))
        val_size = len(train_dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(
            train_dataset, [train_size, val_size]
        )

        # 创建数据加载器
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )

        test_loader = DataLoader(
            test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )

        # 类别名称
        classes = ('plane', 'car', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck')

        print(f"训练集: {len(train_dataset)} 张图像")
        print(f"验证集: {len(val_dataset)} 张图像")
        print(f"测试集: {len(test_dataset)} 张图像")

        return train_loader, val_loader, test_loader, classes

    def visualize_batch(self, data_loader, classes, num_images=8):
        """可视化一个批次的数据"""
        # 获取一个批次的数据
        images, labels = next(iter(data_loader))

        # 反归一化以便显示
        mean = torch.tensor(self.imagenet_mean).view(3, 1, 1)
        std = torch.tensor(self.imagenet_std).view(3, 1, 1)
        images = images * std + mean
        images = torch.clamp(images, 0, 1)

        # 创建可视化
        fig, axes = plt.subplots(2, 4, figsize=(12, 6))
        axes = axes.ravel()

        for i in range(num_images):
            # 转换图像维度 (C, H, W) -> (H, W, C)
            img = images[i].permute(1, 2, 0).numpy()

            axes[i].imshow(img)
            axes[i].set_title(f"标签: {classes[labels[i]]}")
            axes[i].axis('off')

        plt.suptitle('训练批次样本', fontsize=14)
        plt.tight_layout()
        plt.show()
相关推荐
Boop_wu1 天前
[Java 算法] 动态规划2
算法·leetcode·动态规划
mxbb.1 天前
“Hello 神经网络!”
人工智能·深度学习·神经网络
yugi9878381 天前
非支配排序遗传算法NSGA-III详解与MATLAB实现
算法
ballball~~1 天前
ISP-Tone Mapping
图像处理·算法·isp
米粒11 天前
力扣算法刷题 Day22
算法·leetcode·职场和发展
科德航空的张先生1 天前
飞行错觉(空间定向障碍)地面模拟训练系统
人工智能·算法
老四啊laosi1 天前
[双指针] 2. 力扣--复写零
算法·leetcode·双指针·复写零
ballball~~1 天前
ISP-Gamma
图像处理·算法·isp
机器学习之心1 天前
HHO-LSBoost哈里斯鹰算法优化最小二乘提升多输入回归预测MATLAB代码
算法·matlab·回归·hho-lsboost
ballball~~1 天前
ISP-Demosaic
图像处理·数码相机·算法