【CNN算法理解】:二、AlexNet深度学习的数据集处理(附代码)

续上二、AlexNet深度学习的数据集处理

代码实现:

python 复制代码
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt

class AlexNetDataHandler:
    """AlexNet数据处理器"""

    def __init__(self, data_dir='./data', batch_size=128, num_workers=4):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        # ImageNet标准预处理(用于原始AlexNet)
        self.imagenet_mean = [0.485, 0.456, 0.406]
        self.imagenet_std = [0.229, 0.224, 0.225]

    def get_imagenet_transforms(self, img_size=224):
        """获取ImageNet数据预处理转换"""
        # 训练集转换(包含数据增强)
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(img_size),  # 随机裁剪
            transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # 颜色抖动
            transforms.ToTensor(),
            transforms.Normalize(mean=self.imagenet_mean, std=self.imagenet_std)
        ])

        # 验证集/测试集转换
        val_transform = transforms.Compose([
            transforms.Resize(256),  # 调整大小
            transforms.CenterCrop(img_size),  # 中心裁剪
            transforms.ToTensor(),
            transforms.Normalize(mean=self.imagenet_mean, std=self.imagenet_std)
        ])

        return train_transform, val_transform

    def get_cifar10_transforms(self):
        """获取CIFAR-10数据预处理转换"""
        # CIFAR-10均值和标准差
        cifar_mean = [0.4914, 0.4822, 0.4465]
        cifar_std = [0.2470, 0.2435, 0.2616]

        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),  # 随机裁剪
            transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转
            transforms.ToTensor(),
            transforms.Normalize(mean=cifar_mean, std=cifar_std)
        ])

        val_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=cifar_mean, std=cifar_std)
        ])

        return train_transform, val_transform

    def load_cifar10(self):
        """加载CIFAR-10数据集"""
        print("加载CIFAR-10数据集...")

        train_transform, val_transform = self.get_cifar10_transforms()

        # 下载并加载训练集
        train_dataset = datasets.CIFAR10(
            root=self.data_dir,
            train=True,
            download=True,
            transform=train_transform
        )

        # 下载并加载测试集
        test_dataset = datasets.CIFAR10(
            root=self.data_dir,
            train=False,
            download=True,
            transform=val_transform
        )

        # 从训练集划分验证集
        train_size = int(0.9 * len(train_dataset))
        val_size = len(train_dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(
            train_dataset, [train_size, val_size]
        )

        # 创建数据加载器
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )

        test_loader = DataLoader(
            test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )

        # 类别名称
        classes = ('plane', 'car', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck')

        print(f"训练集: {len(train_dataset)} 张图像")
        print(f"验证集: {len(val_dataset)} 张图像")
        print(f"测试集: {len(test_dataset)} 张图像")

        return train_loader, val_loader, test_loader, classes

    def visualize_batch(self, data_loader, classes, num_images=8):
        """可视化一个批次的数据"""
        # 获取一个批次的数据
        images, labels = next(iter(data_loader))

        # 反归一化以便显示
        mean = torch.tensor(self.imagenet_mean).view(3, 1, 1)
        std = torch.tensor(self.imagenet_std).view(3, 1, 1)
        images = images * std + mean
        images = torch.clamp(images, 0, 1)

        # 创建可视化
        fig, axes = plt.subplots(2, 4, figsize=(12, 6))
        axes = axes.ravel()

        for i in range(num_images):
            # 转换图像维度 (C, H, W) -> (H, W, C)
            img = images[i].permute(1, 2, 0).numpy()

            axes[i].imshow(img)
            axes[i].set_title(f"标签: {classes[labels[i]]}")
            axes[i].axis('off')

        plt.suptitle('训练批次样本', fontsize=14)
        plt.tight_layout()
        plt.show()
相关推荐
徐小夕11 小时前
pxcharts Ultra V2.3更新:多维表一键导出 PDF,渲染兼容性拉满!
vue.js·算法·github
CoovallyAIHub12 小时前
OpenClaw一脚踩碎传统CV?机器终于不再只是看世界
深度学习·算法·计算机视觉
CoovallyAIHub12 小时前
仅凭单目相机实现3D锥桶定位?UNet-RKNet破解自动驾驶锥桶检测难题
深度学习·算法·计算机视觉
zone773912 小时前
002:RAG 入门-LangChain 读取文本
后端·算法·面试
得物技术13 小时前
得物社区搜推公式融合调参框架-加乘树3.0实战
算法
会员源码网1 天前
使用`mysql_*`废弃函数(PHP7+完全移除,导致代码无法运行)
后端·算法
木心月转码ing1 天前
Hot100-Day10-T438T438找到字符串中所有字母异位词
算法
HelloReader1 天前
Wi-Fi CSI 感知技术用无线信号“看见“室内的人
算法
颜酱2 天前
二叉树分解问题思路解题模式
javascript·后端·算法
qianpeng8972 天前
水声匹配场定位原理及实验
算法