【CNN算法理解】:二、AlexNet深度学习的数据集处理(附代码)

续上二、AlexNet深度学习的数据集处理

代码实现:

python 复制代码
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt

class AlexNetDataHandler:
    """AlexNet数据处理器"""

    def __init__(self, data_dir='./data', batch_size=128, num_workers=4):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        # ImageNet标准预处理(用于原始AlexNet)
        self.imagenet_mean = [0.485, 0.456, 0.406]
        self.imagenet_std = [0.229, 0.224, 0.225]

    def get_imagenet_transforms(self, img_size=224):
        """获取ImageNet数据预处理转换"""
        # 训练集转换(包含数据增强)
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(img_size),  # 随机裁剪
            transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # 颜色抖动
            transforms.ToTensor(),
            transforms.Normalize(mean=self.imagenet_mean, std=self.imagenet_std)
        ])

        # 验证集/测试集转换
        val_transform = transforms.Compose([
            transforms.Resize(256),  # 调整大小
            transforms.CenterCrop(img_size),  # 中心裁剪
            transforms.ToTensor(),
            transforms.Normalize(mean=self.imagenet_mean, std=self.imagenet_std)
        ])

        return train_transform, val_transform

    def get_cifar10_transforms(self):
        """获取CIFAR-10数据预处理转换"""
        # CIFAR-10均值和标准差
        cifar_mean = [0.4914, 0.4822, 0.4465]
        cifar_std = [0.2470, 0.2435, 0.2616]

        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),  # 随机裁剪
            transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转
            transforms.ToTensor(),
            transforms.Normalize(mean=cifar_mean, std=cifar_std)
        ])

        val_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=cifar_mean, std=cifar_std)
        ])

        return train_transform, val_transform

    def load_cifar10(self):
        """加载CIFAR-10数据集"""
        print("加载CIFAR-10数据集...")

        train_transform, val_transform = self.get_cifar10_transforms()

        # 下载并加载训练集
        train_dataset = datasets.CIFAR10(
            root=self.data_dir,
            train=True,
            download=True,
            transform=train_transform
        )

        # 下载并加载测试集
        test_dataset = datasets.CIFAR10(
            root=self.data_dir,
            train=False,
            download=True,
            transform=val_transform
        )

        # 从训练集划分验证集
        train_size = int(0.9 * len(train_dataset))
        val_size = len(train_dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(
            train_dataset, [train_size, val_size]
        )

        # 创建数据加载器
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )

        test_loader = DataLoader(
            test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )

        # 类别名称
        classes = ('plane', 'car', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck')

        print(f"训练集: {len(train_dataset)} 张图像")
        print(f"验证集: {len(val_dataset)} 张图像")
        print(f"测试集: {len(test_dataset)} 张图像")

        return train_loader, val_loader, test_loader, classes

    def visualize_batch(self, data_loader, classes, num_images=8):
        """可视化一个批次的数据"""
        # 获取一个批次的数据
        images, labels = next(iter(data_loader))

        # 反归一化以便显示
        mean = torch.tensor(self.imagenet_mean).view(3, 1, 1)
        std = torch.tensor(self.imagenet_std).view(3, 1, 1)
        images = images * std + mean
        images = torch.clamp(images, 0, 1)

        # 创建可视化
        fig, axes = plt.subplots(2, 4, figsize=(12, 6))
        axes = axes.ravel()

        for i in range(num_images):
            # 转换图像维度 (C, H, W) -> (H, W, C)
            img = images[i].permute(1, 2, 0).numpy()

            axes[i].imshow(img)
            axes[i].set_title(f"标签: {classes[labels[i]]}")
            axes[i].axis('off')

        plt.suptitle('训练批次样本', fontsize=14)
        plt.tight_layout()
        plt.show()
相关推荐
Tairitsu_H13 分钟前
[LC优选算法#2] 滑动窗口 | 长度最小的子数组 | 无重复字符的最长子串 | 最大连续1的个数
算法
小欣加油14 分钟前
leetcode3689最大子数组总值I
c++·算法·leetcode·职场和发展·贪心算法
JobDocLS19 分钟前
Jetson Orin的用法
深度学习
下午写HelloWorld22 分钟前
【概念与应用】轻量级加密算法LEA、动态脱敏算法DDA、零知识证明ZKP和优化协同交互协议OCIP
算法·区块链·密码学·安全架构·零知识证明
me83237 分钟前
【AI面试】小白理解大模型:自注意力机制如何使大模型能够捕捉长距离依赖关系,它跟RNN有什么区别?
人工智能·rnn·深度学习·ai
飞舞哲42 分钟前
三维点云最小二乘拟合MATLAB程序
开发语言·算法·matlab
Kobebryant-Manba44 分钟前
学习模型构造
python·深度学习·学习
LaughingZhu1 小时前
Product Hunt 每日热榜 | 2026-06-09
人工智能·经验分享·深度学习·神经网络·产品运营
Coder-magician1 小时前
《代码随想录》刷题打卡day12:二叉树part02
数据结构·c++·算法
极光代码工作室1 小时前
基于NLP的论文关键词提取系统
python·深度学习·自然语言处理·nlp