1. 环境与模块准备
导入torch
、torch.nn
、torch.optim
等 PyTorch 核心模块,以及numpy
、torchvision
等工具库,用于模型构建、优化、数据处理;同时定义超参数(如批次大小BATCHSIZE=100
、训练轮数EPOCHES=20
、学习率LR=0.001
等)。
2. 模型定义
构建了多类 CNN 模型,覆盖不同复杂度:
- 基础 CNN 模型 :
CNNNet
、Net
、LeNet
,结构相近,由卷积层(Conv2d
) 、池化层(MaxPool2d
) 、* 全连接层(Linear
)* 组成,通过 ReLU 激活引入非线性,属于轻量型 CNN。 - VGG 模型 :
VGG
(支持 VGG16/VGG19),通过配置字典cfg
定义 "卷积块 + 池化层" 的重复结构,利用_make_layers
方法自动生成层序列,最终接全连接层完成分类,属于深度化 CNN。
3. 数据处理
基于torchvision
加载CIFAR10 数据集,并定义数据变换:
- 训练集:加入
RandomCrop
(随机裁剪)、RandomHorizontalFlip
(水平翻转)增强数据多样性,再通过ToTensor
(转张量)、Normalize
(标准化)统一数据分布。 - 测试集:仅保留
ToTensor
和Normalize
,避免数据增强引入额外噪声。随后通过DataLoader
创建训练 / 测试数据加载器,实现批量数据迭代。
4. 模型训练与评估
采用两种策略开展训练与性能评估:
- 集成学习(投票机制) :将
CNNNet
、Net
、LeNet
封装为列表,共享Adam
优化器与CrossEntropyLoss
损失函数。训练时,每个模型独立前向传播、计算损失并反向传播更新参数;测试时,各模型输出预测结果,通过 "多数表决" 得到集成模型的预测,最终对比集成模型与单模型的准确率。 - 单模型(VGG16)训练 :单独训练
VGG('VGG16')
,流程与集成方法一致,重点跟踪 VGG16 在每轮训练后的测试准确率。
核心意图
通过对比基础 CNN 模型 、集成模型 与深度 VGG 模型 的性能,展现模型结构复杂度 (如网络深度)、集成学习策略对 CIFAR10 图像分类任务准确率的影响。
5.代码










