Day 41 卷积神经网络(CNN)基础与实战

在上一节中,我们尝试使用全连接网络(MLP)处理 CIFAR-10 图像分类任务,但发现准确率难以突破瓶颈。这是因为 MLP 将图像的所有像素展平为一维向量,破坏了图像原本的空间结构信息(如局部纹理、形状边缘等)。今天我们正式引入卷积神经网络(CNN),它通过"卷积"和"池化"操作,专门用于提取图像的空间特征。

1. 为什么需要 CNN?

全连接网络(MLP)处理图像面临两个主要问题:

  1. 参数量爆炸:对于高分辨率图像,全连接层的权重数量巨大,难以训练且容易过拟合。
  2. 空间信息丢失:展平操作忽略了像素之间的邻域关系。

CNN 通过局部感知 (卷积核只看局部区域)和权值共享(同一个卷积核扫描整张图),在大幅减少参数量的同时,有效地提取了图像的平移不变性特征。

2. 数据增强 (Data Augmentation)

在训练深度学习模型时,数据量往往决定了模型的上限。数据增强通过对原始图像进行一系列随机变换,生成形态各异的新样本,从而在不增加实际采集成本的情况下扩展数据集,显著提升模型的泛化能力。

我们在训练集中使用了以下增强策略:

复制代码
train_transform = transforms.Compose([
    # 随机裁剪:在四周填充4像素后,随机裁剪出32x32
    transforms.RandomCrop(32, padding=4),
    # 随机水平翻转:模拟物体方向的变化
    transforms.RandomHorizontalFlip(),
    # 颜色抖动:随机调整亮度、对比度、饱和度、色相
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    # 随机旋转:最大旋转15度
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    # 标准化:使用 CIFAR-10 数据集的均值和标准差
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

注意:测试集通常只进行标准化处理,不进行随机变换,以确保评估结果的稳定性。

3. CNN 模型架构设计

我们构建了一个经典的 CNN 结构,包含三个卷积块和一个分类器。

3.1 核心组件解析

  • 卷积层 (Conv2d):特征提取器。通过滑动窗口(卷积核)提取边缘、纹理等特征。
  • 批量归一化 (BatchNorm2d):加速收敛。对每一批数据的特征图进行归一化(均值0,方差1),解决"内部协变量偏移"问题,使得模型可以使用更大的学习率,并具有一定的正则化效果。
  • 激活函数 (ReLU):引入非线性,增加模型的表达能力。
  • 最大池化 (MaxPool2d):下采样。保留局部区域的最强特征,减小特征图尺寸,降低计算量。

3.2 模型代码实现

复制代码
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        
        # 卷积块 1:输入 3 通道 -> 输出 32 通道
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) # 尺寸减半: 32 -> 16
        
        # 卷积块 2:输入 32 通道 -> 输出 64 通道
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2) # 尺寸减半: 16 -> 8
        
        # 卷积块 3:输入 64 通道 -> 输出 128 通道
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2) # 尺寸减半: 8 -> 4
        
        # 全连接分类器
        # 展平维度计算:128通道 * 4(高) * 4(宽) = 2048
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, 10) # 输出 10 个类别

    def forward(self, x):
        # x: [batch, 3, 32, 32]
        x = self.pool1(self.relu1(self.bn1(self.conv1(x)))) # -> [batch, 32, 16, 16]
        x = self.pool2(self.relu2(self.bn2(self.conv2(x)))) # -> [batch, 64, 8, 8]
        x = self.pool3(self.relu3(self.bn3(self.conv3(x)))) # -> [batch, 128, 4, 4]
        
        # 展平
        x = x.view(-1, 128 * 4 * 4) # -> [batch, 2048]
        
        x = self.dropout(self.relu3(self.fc1(x)))
        x = self.fc2(x)
        return x

3.3 维度变换推导

输入图片尺寸为 32 \\times 32

  1. Block 1: Conv(padding=1) \\rightarrow 32 \\times 32; Pool(2x2) \\rightarrow 16 \\times 16.
  2. Block 2: Conv(padding=1) \\rightarrow 16 \\times 16; Pool(2x2) \\rightarrow 8 \\times 8.
  3. Block 3: Conv(padding=1) \\rightarrow 8 \\times 8; Pool(2x2) \\rightarrow 4 \\times 4.

最终特征图大小为 128 \\times 4 \\times 4

4. 学习率调度器 (Learning Rate Scheduler)

为了进一步提升模型性能,我们引入了学习率调度器。在训练初期,较大的学习率有助于快速下降;在训练后期,较小的学习率有助于模型在极小值附近精细收敛。

我们使用的是 ReduceLROnPlateau,它是一种"监控型"调度器:

  • 机制 :当监控的指标(如验证集 Loss)在 patience 个 epoch 内不再下降时,自动将学习率乘以 factor 进行衰减。

  • 适用场景:几乎适用于所有监督学习任务,特别是在不知道具体何时衰减 LR 最优时。

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min', # 监控指标是越小越好(Loss)
    patience=3, # 容忍 3 个 epoch 不提升
    factor=0.5 # 衰减系数
    )

    在训练循环中更新

    scheduler.step(epoch_test_loss)

5. 训练效果对比

相较于 MLP,CNN 在 CIFAR-10 上的表现有质的飞跃:

  • MLP:通常只能达到 50%-55% 的准确率。
  • 简单 CNN:在本实验中,配合数据增强和 BatchNorm,准确率可以轻松达到 80% 以上。

这一结果证明了 CNN 在提取图像特征方面的强大能力。卷积层作为特征提取器,能够从底层的边缘、颜色,逐层抽象到高层的形状、物体部件,这是全连接网络无法做到的。

6. 总结

  1. 数据增强是提升图像分类模型泛化能力的必备手段。
  2. BatchNorm 是现代 CNN 的标配,能显著加速收敛并稳定训练。
  3. CNN 结构(卷积+池化)通过保留空间结构和参数共享,高效地处理了图像数据。
  4. 学习率调度器 帮助模型在训练后期打破瓶颈,进一步提升精度。
相关推荐
Dingdangcat864 小时前
基于改进YOLO11-C2PSA-SEFFN的工业环境气体泄漏检测与定位系统实现
python
AI视觉网奇4 小时前
live2d 单图转模型 单图生成模型
java·前端·python
祝威廉4 小时前
摘下数据分析的皇冠:机器学习,InfiniSynapse 金融评分卡案例
人工智能·机器学习·金融·数据挖掘·数据分析
产品何同学4 小时前
复刻DeepSeek与GPT!AI智能对话Web高保真原型设计全解析
人工智能·gpt·墨刀·高保真原型·deepseek·ai智能写作·ai智能对话
杭州泽沃电子科技有限公司4 小时前
变流器与变压器:风电并网智能监测的“守护神”与“稳定锚”
人工智能·智能监测·发电
咸鱼加辣4 小时前
【python面试题】LRUCache
开发语言·python
中國龍在廣州4 小时前
“太空数据中心”成AI必争之地?
人工智能·深度学习·算法·机器学习·机器人
LitchiCheng4 小时前
WSL2 中 pynput 无法捕获按键输入?
开发语言·python
中年程序员一枚4 小时前
Python 中处理视频添加 / 替换音频
开发语言·python·音视频