【CNN算法理解】:二、AlexNet深度学习的数据集处理(附代码)

续上二、AlexNet深度学习的数据集处理

代码实现:

python 复制代码
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt

class AlexNetDataHandler:
    """AlexNet数据处理器"""

    def __init__(self, data_dir='./data', batch_size=128, num_workers=4):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        # ImageNet标准预处理(用于原始AlexNet)
        self.imagenet_mean = [0.485, 0.456, 0.406]
        self.imagenet_std = [0.229, 0.224, 0.225]

    def get_imagenet_transforms(self, img_size=224):
        """获取ImageNet数据预处理转换"""
        # 训练集转换(包含数据增强)
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(img_size),  # 随机裁剪
            transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # 颜色抖动
            transforms.ToTensor(),
            transforms.Normalize(mean=self.imagenet_mean, std=self.imagenet_std)
        ])

        # 验证集/测试集转换
        val_transform = transforms.Compose([
            transforms.Resize(256),  # 调整大小
            transforms.CenterCrop(img_size),  # 中心裁剪
            transforms.ToTensor(),
            transforms.Normalize(mean=self.imagenet_mean, std=self.imagenet_std)
        ])

        return train_transform, val_transform

    def get_cifar10_transforms(self):
        """获取CIFAR-10数据预处理转换"""
        # CIFAR-10均值和标准差
        cifar_mean = [0.4914, 0.4822, 0.4465]
        cifar_std = [0.2470, 0.2435, 0.2616]

        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),  # 随机裁剪
            transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转
            transforms.ToTensor(),
            transforms.Normalize(mean=cifar_mean, std=cifar_std)
        ])

        val_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=cifar_mean, std=cifar_std)
        ])

        return train_transform, val_transform

    def load_cifar10(self):
        """加载CIFAR-10数据集"""
        print("加载CIFAR-10数据集...")

        train_transform, val_transform = self.get_cifar10_transforms()

        # 下载并加载训练集
        train_dataset = datasets.CIFAR10(
            root=self.data_dir,
            train=True,
            download=True,
            transform=train_transform
        )

        # 下载并加载测试集
        test_dataset = datasets.CIFAR10(
            root=self.data_dir,
            train=False,
            download=True,
            transform=val_transform
        )

        # 从训练集划分验证集
        train_size = int(0.9 * len(train_dataset))
        val_size = len(train_dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(
            train_dataset, [train_size, val_size]
        )

        # 创建数据加载器
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )

        test_loader = DataLoader(
            test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )

        # 类别名称
        classes = ('plane', 'car', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck')

        print(f"训练集: {len(train_dataset)} 张图像")
        print(f"验证集: {len(val_dataset)} 张图像")
        print(f"测试集: {len(test_dataset)} 张图像")

        return train_loader, val_loader, test_loader, classes

    def visualize_batch(self, data_loader, classes, num_images=8):
        """可视化一个批次的数据"""
        # 获取一个批次的数据
        images, labels = next(iter(data_loader))

        # 反归一化以便显示
        mean = torch.tensor(self.imagenet_mean).view(3, 1, 1)
        std = torch.tensor(self.imagenet_std).view(3, 1, 1)
        images = images * std + mean
        images = torch.clamp(images, 0, 1)

        # 创建可视化
        fig, axes = plt.subplots(2, 4, figsize=(12, 6))
        axes = axes.ravel()

        for i in range(num_images):
            # 转换图像维度 (C, H, W) -> (H, W, C)
            img = images[i].permute(1, 2, 0).numpy()

            axes[i].imshow(img)
            axes[i].set_title(f"标签: {classes[labels[i]]}")
            axes[i].axis('off')

        plt.suptitle('训练批次样本', fontsize=14)
        plt.tight_layout()
        plt.show()
相关推荐
九.九9 小时前
ops-transformer:AI 处理器上的高性能 Transformer 算子库
人工智能·深度学习·transformer
春日见9 小时前
拉取与合并:如何让个人分支既包含你昨天的修改,也包含 develop 最新更新
大数据·人工智能·深度学习·elasticsearch·搜索引擎
寻寻觅觅☆9 小时前
东华OJ-基础题-106-大整数相加(C++)
开发语言·c++·算法
偷吃的耗子9 小时前
【CNN算法理解】:三、AlexNet 训练模块(附代码)
深度学习·算法·cnn
化学在逃硬闯CS10 小时前
Leetcode1382. 将二叉搜索树变平衡
数据结构·算法
ceclar12310 小时前
C++使用format
开发语言·c++·算法
Faker66363aaa11 小时前
【深度学习】YOLO11-BiFPN多肉植物检测分类模型,从0到1实现植物识别系统,附完整代码与教程_1
人工智能·深度学习·分类
Gofarlic_OMS11 小时前
科学计算领域MATLAB许可证管理工具对比推荐
运维·开发语言·算法·matlab·自动化
夏鹏今天学习了吗11 小时前
【LeetCode热题100(100/100)】数据流的中位数
算法·leetcode·职场和发展
忙什么果12 小时前
上位机、下位机、FPGA、算法放在哪层合适?
算法·fpga开发