Day 41 卷积神经网络(CNN)基础与实战

在上一节中,我们尝试使用全连接网络(MLP)处理 CIFAR-10 图像分类任务,但发现准确率难以突破瓶颈。这是因为 MLP 将图像的所有像素展平为一维向量,破坏了图像原本的空间结构信息(如局部纹理、形状边缘等)。今天我们正式引入卷积神经网络(CNN),它通过"卷积"和"池化"操作,专门用于提取图像的空间特征。

1. 为什么需要 CNN?

全连接网络(MLP)处理图像面临两个主要问题:

  1. 参数量爆炸:对于高分辨率图像,全连接层的权重数量巨大,难以训练且容易过拟合。
  2. 空间信息丢失:展平操作忽略了像素之间的邻域关系。

CNN 通过局部感知 (卷积核只看局部区域)和权值共享(同一个卷积核扫描整张图),在大幅减少参数量的同时,有效地提取了图像的平移不变性特征。

2. 数据增强 (Data Augmentation)

在训练深度学习模型时,数据量往往决定了模型的上限。数据增强通过对原始图像进行一系列随机变换,生成形态各异的新样本,从而在不增加实际采集成本的情况下扩展数据集,显著提升模型的泛化能力。

我们在训练集中使用了以下增强策略:

复制代码
train_transform = transforms.Compose([
    # 随机裁剪:在四周填充4像素后,随机裁剪出32x32
    transforms.RandomCrop(32, padding=4),
    # 随机水平翻转:模拟物体方向的变化
    transforms.RandomHorizontalFlip(),
    # 颜色抖动:随机调整亮度、对比度、饱和度、色相
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    # 随机旋转:最大旋转15度
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    # 标准化:使用 CIFAR-10 数据集的均值和标准差
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

注意:测试集通常只进行标准化处理,不进行随机变换,以确保评估结果的稳定性。

3. CNN 模型架构设计

我们构建了一个经典的 CNN 结构,包含三个卷积块和一个分类器。

3.1 核心组件解析

  • 卷积层 (Conv2d):特征提取器。通过滑动窗口(卷积核)提取边缘、纹理等特征。
  • 批量归一化 (BatchNorm2d):加速收敛。对每一批数据的特征图进行归一化(均值0,方差1),解决"内部协变量偏移"问题,使得模型可以使用更大的学习率,并具有一定的正则化效果。
  • 激活函数 (ReLU):引入非线性,增加模型的表达能力。
  • 最大池化 (MaxPool2d):下采样。保留局部区域的最强特征,减小特征图尺寸,降低计算量。

3.2 模型代码实现

复制代码
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        
        # 卷积块 1:输入 3 通道 -> 输出 32 通道
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) # 尺寸减半: 32 -> 16
        
        # 卷积块 2:输入 32 通道 -> 输出 64 通道
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2) # 尺寸减半: 16 -> 8
        
        # 卷积块 3:输入 64 通道 -> 输出 128 通道
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2) # 尺寸减半: 8 -> 4
        
        # 全连接分类器
        # 展平维度计算:128通道 * 4(高) * 4(宽) = 2048
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, 10) # 输出 10 个类别

    def forward(self, x):
        # x: [batch, 3, 32, 32]
        x = self.pool1(self.relu1(self.bn1(self.conv1(x)))) # -> [batch, 32, 16, 16]
        x = self.pool2(self.relu2(self.bn2(self.conv2(x)))) # -> [batch, 64, 8, 8]
        x = self.pool3(self.relu3(self.bn3(self.conv3(x)))) # -> [batch, 128, 4, 4]
        
        # 展平
        x = x.view(-1, 128 * 4 * 4) # -> [batch, 2048]
        
        x = self.dropout(self.relu3(self.fc1(x)))
        x = self.fc2(x)
        return x

3.3 维度变换推导

输入图片尺寸为 32 \\times 32

  1. Block 1: Conv(padding=1) \\rightarrow 32 \\times 32; Pool(2x2) \\rightarrow 16 \\times 16.
  2. Block 2: Conv(padding=1) \\rightarrow 16 \\times 16; Pool(2x2) \\rightarrow 8 \\times 8.
  3. Block 3: Conv(padding=1) \\rightarrow 8 \\times 8; Pool(2x2) \\rightarrow 4 \\times 4.

最终特征图大小为 128 \\times 4 \\times 4

4. 学习率调度器 (Learning Rate Scheduler)

为了进一步提升模型性能,我们引入了学习率调度器。在训练初期,较大的学习率有助于快速下降;在训练后期,较小的学习率有助于模型在极小值附近精细收敛。

我们使用的是 ReduceLROnPlateau,它是一种"监控型"调度器:

  • 机制 :当监控的指标(如验证集 Loss)在 patience 个 epoch 内不再下降时,自动将学习率乘以 factor 进行衰减。

  • 适用场景:几乎适用于所有监督学习任务,特别是在不知道具体何时衰减 LR 最优时。

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min', # 监控指标是越小越好(Loss)
    patience=3, # 容忍 3 个 epoch 不提升
    factor=0.5 # 衰减系数
    )

    在训练循环中更新

    scheduler.step(epoch_test_loss)

5. 训练效果对比

相较于 MLP,CNN 在 CIFAR-10 上的表现有质的飞跃:

  • MLP:通常只能达到 50%-55% 的准确率。
  • 简单 CNN:在本实验中,配合数据增强和 BatchNorm,准确率可以轻松达到 80% 以上。

这一结果证明了 CNN 在提取图像特征方面的强大能力。卷积层作为特征提取器,能够从底层的边缘、颜色,逐层抽象到高层的形状、物体部件,这是全连接网络无法做到的。

6. 总结

  1. 数据增强是提升图像分类模型泛化能力的必备手段。
  2. BatchNorm 是现代 CNN 的标配,能显著加速收敛并稳定训练。
  3. CNN 结构(卷积+池化)通过保留空间结构和参数共享,高效地处理了图像数据。
  4. 学习率调度器 帮助模型在训练后期打破瓶颈,进一步提升精度。
相关推荐
杜子不疼.2 分钟前
计算机视觉热门模型手册:Spring Boot 3.2 自动装配新机制:@AutoConfiguration 使用指南
人工智能·spring boot·计算机视觉
无心水2 小时前
【分布式利器:腾讯TSF】7、TSF高级部署策略全解析:蓝绿/灰度发布落地+Jenkins CI/CD集成(Java微服务实战)
java·人工智能·分布式·ci/cd·微服务·jenkins·腾讯tsf
北辰alk7 小时前
RAG索引流程详解:如何高效解析文档构建知识库
人工智能
九河云7 小时前
海上风电“AI偏航对风”:把发电量提升2.1%,单台年增30万度
大数据·人工智能·数字化转型
wm10437 小时前
机器学习第二讲 KNN算法
人工智能·算法·机器学习
沈询-阿里8 小时前
Skills vs MCP:竞合关系还是互补?深入解析Function Calling、MCP和Skills的本质差异
人工智能·ai·agent·ai编程
xiaobai1788 小时前
测试工程师入门AI技术 - 前序:跨越焦虑,从优势出发开启学习之旅
人工智能·学习
盛世宏博北京8 小时前
云边协同・跨系统联动:智慧档案馆建设与功能落地
大数据·人工智能
Learn-Python8 小时前
MongoDB-only方法
python·sql
TGITCIC9 小时前
讲透知识图谱Neo4j在构建Agent时到底怎么用(二)
人工智能·知识图谱·neo4j·ai agent·ai智能体·大模型落地·graphrag