【CNN算法理解】:三、AlexNet 训练模块(附代码)

文章目录

概述

本模块提供 AlexNet 网络结构的完整实现和训练框架,包含原始 AlexNet、简化版 AlexNet 以及完整的训练器类,支持多种数据集和训练配置。

AlexNet 网络架构

原始 AlexNet 结构 (AlexNet 类)

设计理念
  • 输入:227×227×3(论文中为224×224×3,计算推导常用227)
  • 输出:1000个类别(ImageNet)
  • 包含8个学习层(5个卷积层,3个全连接层)
  • 使用ReLU激活函数替代传统的tanh函数
  • 引入Local Response Normalization(LRN)
详细层结构
类型 参数 输入尺寸 输出尺寸 参数数量计算
1 Conv2d 11×11, stride=4, 96个滤波器 227×227×3 55×55×96 (11×11×3)×96 = 34,848
1 ReLU - 55×55×96 55×55×96 -
1 LRN size=5, α=0.0001, β=0.75 55×55×96 55×55×96 -
1 MaxPool 3×3, stride=2 55×55×96 27×27×96 -
2 Conv2d 5×5, padding=2, 256个滤波器 27×27×96 27×27×256 (5×5×96)×256 = 614,400
2 ReLU - 27×27×256 27×27×256 -
2 LRN size=5, α=0.0001, β=0.75 27×27×256 27×27×256 -
2 MaxPool 3×3, stride=2 27×27×256 13×13×256 -
3 Conv2d 3×3, padding=1, 384个滤波器 13×13×256 13×13×384 (3×3×256)×384 = 884,736
3 ReLU - 13×13×384 13×13×384 -
4 Conv2d 3×3, padding=1, 384个滤波器 13×13×384 13×13×384 (3×3×384)×384 = 1,327,104
4 ReLU - 13×13×384 13×13×384 -
5 Conv2d 3×3, padding=1, 256个滤波器 13×13×384 13×13×256 (3×3×384)×256 = 884,736
5 ReLU - 13×13×256 13×13×256 -
5 MaxPool 3×3, stride=2 13×13×256 6×6×256 -
FC1 Linear 4096个神经元 6×6×256=9216 4096 9216×4096 = 37,752,832
FC2 Linear 4096个神经元 4096 4096 4096×4096 = 16,777,216
FC3 Linear 1000个神经元 4096 1000 4096×1000 = 4,096,000

总参数数:约61百万(61M)

维度计算示例

第1卷积层计算:

复制代码
输入:227×227×3
卷积核:11×11,步长=4,填充=2
输出尺寸:(227-11+2×2)/4 + 1 = 55
输出通道:96
输出:55×55×96

第1池化层计算:

复制代码
输入:55×55×96
池化核:3×3,步长=2
输出尺寸:(55-3)/2 + 1 = 27
输出:27×27×96

简化版 AlexNet (SimplifiedAlexNet 类)

针对小数据集的修改
修改点 原始AlexNet 简化版AlexNet 说明
输入尺寸 227×227 32×32 适配CIFAR-10
第1卷积层 11×11, stride=4 3×3, stride=1 小图像无需大步长
通道数 96-256-384-384-256 64-192-384-256-256 减少参数数量
LRN 使用 可选去除 现代实践常用BatchNorm替代
全连接层 4096-4096-1000 1024-512-10 匹配CIFAR-10的10个类别
维度变化示例

对于32×32的CIFAR-10图像:

复制代码
Conv1 (3×3, stride=1): 32×32×3 → 32×32×64
Pool1 (2×2, stride=2): 32×32×64 → 16×16×64
Conv2 (3×3): 16×16×64 → 16×16×192
Pool2 (2×2, stride=2): 16×16×192 → 8×8×192
Conv3 (3×3): 8×8×192 → 8×8×384
Conv4 (3×3): 8×8×384 → 8×8×256
Conv5 (3×3): 8×8×256 → 8×8×256
Pool5 (2×2, stride=2): 8×8×256 → 4×4×256
全连接输入:4×4×256 = 4096

附代码

python 复制代码
import torch
import torch.nn as nn

class AlexNet(nn.Module):
    """
    AlexNet网络结构实现

    原始参数:
    - 输入:227×227×3(实际论文为224×224×3,但推导常用227×227)
    - 输出:1000类(ImageNet)
    """

    def __init__(self, num_classes=1000, dropout_rate=0.5):
        super(AlexNet, self).__init__()

        # 特征提取部分(卷积层)
        self.features = nn.Sequential(
            # 第1卷积层: 输入3通道,输出96通道
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),  # 输出: 55×55×96
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2),  # LRN层
            nn.MaxPool2d(kernel_size=3, stride=2),  # 输出: 27×27×96

            # 第2卷积层: 输入96通道,输出256通道
            nn.Conv2d(96, 256, kernel_size=5, padding=2),  # 输出: 27×27×256
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2),  # 输出: 13×13×256

            # 第3卷积层: 输入256通道,输出384通道
            nn.Conv2d(256, 384, kernel_size=3, padding=1),  # 输出: 13×13×384
            nn.ReLU(inplace=True),

            # 第4卷积层: 输入384通道,输出384通道
            nn.Conv2d(384, 384, kernel_size=3, padding=1),  # 输出: 13×13×384
            nn.ReLU(inplace=True),

            # 第5卷积层: 输入384通道,输出256通道
            nn.Conv2d(384, 256, kernel_size=3, padding=1),  # 输出: 13×13×256
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # 输出: 6×6×256
        )

        # 自适应平均池化层(替代固定的展平操作,使网络适应不同输入尺寸)
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        # 分类器部分(全连接层)
        self.classifier = nn.Sequential(
            # 第1全连接层: 输入6×6×256=9216,输出4096
            nn.Dropout(p=dropout_rate),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),

            # 第2全连接层: 输入4096,输出4096
            nn.Dropout(p=dropout_rate),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),

            # 第3全连接层(输出层): 输入4096,输出num_classes
            nn.Linear(4096, num_classes),
        )

        # 权重初始化
        self._initialize_weights()

    def _initialize_weights(self):
        """初始化网络权重(按原始论文方式)"""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # 卷积层使用正态分布初始化
                nn.init.normal_(m.weight, mean=0, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                # 全连接层使用正态分布初始化
                nn.init.normal_(m.weight, mean=0, std=0.01)
                nn.init.constant_(m.bias, 0)

        # 对第2、4、5卷积层的偏置初始化为1(按原始论文)
        nn.init.constant_(self.features[3].bias, 1)
        nn.init.constant_(self.features[8].bias, 1)
        nn.init.constant_(self.features[10].bias, 1)

    def forward(self, x):
        """前向传播"""
        # 特征提取
        x = self.features(x)
        x = self.avgpool(x)

        # 展平
        x = torch.flatten(x, 1)

        # 分类
        x = self.classifier(x)

        return x

    def get_feature_maps(self, x, layer_idx=None):
        """
        获取指定层的特征图(用于可视化)

        参数:
            x: 输入图像
            layer_idx: 层索引列表,None表示获取所有层的特征图
        """
        feature_maps = {}
        layer_names = [
            'conv1', 'pool1', 'conv2', 'pool2',
            'conv3', 'conv4', 'conv5', 'pool5'
        ]

        # 逐层前向传播并保存特征图
        for i, layer in enumerate(self.features):
            x = layer(x)
            if layer_idx is None or i in layer_idx:
                feature_maps[layer_names[i]] = x.detach()

        return feature_maps


class SimplifiedAlexNet(nn.Module):
    """
    简化版AlexNet(适用于CIFAR-10等小尺寸数据集)

    修改点:
    1. 去除第1个卷积层的大步长
    2. 修改全连接层的输入尺寸
    3. 可选去除LRN层(现代实践中常用BatchNorm替代)
    """

    def __init__(self, num_classes=10, use_batchnorm=False):
        super(SimplifiedAlexNet, self).__init__()

        self.use_batchnorm = use_batchnorm

        # 特征提取部分
        self.features = nn.Sequential(
            # 第1卷积层(适配32×32输入)
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),  # 输出: 32×32×64
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 输出: 16×16×64

            # 第2卷积层
            nn.Conv2d(64, 192, kernel_size=3, padding=1),  # 输出: 16×16×192
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 输出: 8×8×192

            # 第3卷积层
            nn.Conv2d(192, 384, kernel_size=3, padding=1),  # 输出: 8×8×384
            nn.ReLU(inplace=True),

            # 第4卷积层
            nn.Conv2d(384, 256, kernel_size=3, padding=1),  # 输出: 8×8×256
            nn.ReLU(inplace=True),

            # 第5卷积层
            nn.Conv2d(256, 256, kernel_size=3, padding=1),  # 输出: 8×8×256
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 输出: 4×4×256
        )

        # 分类器部分
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(256 * 4 * 4, 1024),
            nn.ReLU(inplace=True),

            nn.Dropout(p=0.5),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),

            nn.Linear(512, num_classes),
        )

        # 权重初始化
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
相关推荐
思绪无限4 小时前
YOLOv5至YOLOv12升级:木材表面缺陷检测系统的设计与实现(完整代码+界面+数据集项目)
人工智能·深度学习·目标检测·计算机视觉·木材表面缺陷检测
kishu_iOS&AI4 小时前
深度学习 —— 损失函数
人工智能·pytorch·python·深度学习·线性回归
知识浅谈4 小时前
DeepSeek V4 和 GPT-5.5 在同一天发布了??我也很懵,但对比完我悟了
算法
DeepModel4 小时前
通俗易懂讲透 Q-Learning:从零学会强化学习核心算法
人工智能·学习·算法·机器学习
田梓燊5 小时前
力扣:19.删除链表的倒数第 N 个结点
算法·leetcode·链表
简简单单做算法6 小时前
基于GA遗传优化双BP神经网络的时间序列预测算法matlab仿真
神经网络·算法·matlab·时间序列预测·双bp神经网络
guygg887 小时前
利用遗传算法解决列车优化运行问题的MATLAB实现
开发语言·算法·matlab
武藤一雄7 小时前
19个核心算法(C#版)
数据结构·windows·算法·c#·排序算法·.net·.netcore
sali-tec7 小时前
C# 基于OpenCv的视觉工作流-章52-交点查找
图像处理·人工智能·opencv·算法·计算机视觉
ZhengEnCi7 小时前
01c-循环神经网络RNN详解
人工智能·深度学习