【CNN算法理解】:三、AlexNet 训练模块(附代码)

文章目录

概述

本模块提供 AlexNet 网络结构的完整实现和训练框架,包含原始 AlexNet、简化版 AlexNet 以及完整的训练器类,支持多种数据集和训练配置。

AlexNet 网络架构

原始 AlexNet 结构 (AlexNet 类)

设计理念
  • 输入:227×227×3(论文中为224×224×3,计算推导常用227)
  • 输出:1000个类别(ImageNet)
  • 包含8个学习层(5个卷积层,3个全连接层)
  • 使用ReLU激活函数替代传统的tanh函数
  • 引入Local Response Normalization(LRN)
详细层结构
类型 参数 输入尺寸 输出尺寸 参数数量计算
1 Conv2d 11×11, stride=4, 96个滤波器 227×227×3 55×55×96 (11×11×3)×96 = 34,848
1 ReLU - 55×55×96 55×55×96 -
1 LRN size=5, α=0.0001, β=0.75 55×55×96 55×55×96 -
1 MaxPool 3×3, stride=2 55×55×96 27×27×96 -
2 Conv2d 5×5, padding=2, 256个滤波器 27×27×96 27×27×256 (5×5×96)×256 = 614,400
2 ReLU - 27×27×256 27×27×256 -
2 LRN size=5, α=0.0001, β=0.75 27×27×256 27×27×256 -
2 MaxPool 3×3, stride=2 27×27×256 13×13×256 -
3 Conv2d 3×3, padding=1, 384个滤波器 13×13×256 13×13×384 (3×3×256)×384 = 884,736
3 ReLU - 13×13×384 13×13×384 -
4 Conv2d 3×3, padding=1, 384个滤波器 13×13×384 13×13×384 (3×3×384)×384 = 1,327,104
4 ReLU - 13×13×384 13×13×384 -
5 Conv2d 3×3, padding=1, 256个滤波器 13×13×384 13×13×256 (3×3×384)×256 = 884,736
5 ReLU - 13×13×256 13×13×256 -
5 MaxPool 3×3, stride=2 13×13×256 6×6×256 -
FC1 Linear 4096个神经元 6×6×256=9216 4096 9216×4096 = 37,752,832
FC2 Linear 4096个神经元 4096 4096 4096×4096 = 16,777,216
FC3 Linear 1000个神经元 4096 1000 4096×1000 = 4,096,000

总参数数:约61百万(61M)

维度计算示例

第1卷积层计算:

复制代码
输入:227×227×3
卷积核:11×11,步长=4,填充=2
输出尺寸:(227-11+2×2)/4 + 1 = 55
输出通道:96
输出:55×55×96

第1池化层计算:

复制代码
输入:55×55×96
池化核:3×3,步长=2
输出尺寸:(55-3)/2 + 1 = 27
输出:27×27×96

简化版 AlexNet (SimplifiedAlexNet 类)

针对小数据集的修改
修改点 原始AlexNet 简化版AlexNet 说明
输入尺寸 227×227 32×32 适配CIFAR-10
第1卷积层 11×11, stride=4 3×3, stride=1 小图像无需大步长
通道数 96-256-384-384-256 64-192-384-256-256 减少参数数量
LRN 使用 可选去除 现代实践常用BatchNorm替代
全连接层 4096-4096-1000 1024-512-10 匹配CIFAR-10的10个类别
维度变化示例

对于32×32的CIFAR-10图像:

复制代码
Conv1 (3×3, stride=1): 32×32×3 → 32×32×64
Pool1 (2×2, stride=2): 32×32×64 → 16×16×64
Conv2 (3×3): 16×16×64 → 16×16×192
Pool2 (2×2, stride=2): 16×16×192 → 8×8×192
Conv3 (3×3): 8×8×192 → 8×8×384
Conv4 (3×3): 8×8×384 → 8×8×256
Conv5 (3×3): 8×8×256 → 8×8×256
Pool5 (2×2, stride=2): 8×8×256 → 4×4×256
全连接输入:4×4×256 = 4096

附代码

python 复制代码
import torch
import torch.nn as nn

class AlexNet(nn.Module):
    """
    AlexNet网络结构实现

    原始参数:
    - 输入:227×227×3(实际论文为224×224×3,但推导常用227×227)
    - 输出:1000类(ImageNet)
    """

    def __init__(self, num_classes=1000, dropout_rate=0.5):
        super(AlexNet, self).__init__()

        # 特征提取部分(卷积层)
        self.features = nn.Sequential(
            # 第1卷积层: 输入3通道,输出96通道
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),  # 输出: 55×55×96
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2),  # LRN层
            nn.MaxPool2d(kernel_size=3, stride=2),  # 输出: 27×27×96

            # 第2卷积层: 输入96通道,输出256通道
            nn.Conv2d(96, 256, kernel_size=5, padding=2),  # 输出: 27×27×256
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2),  # 输出: 13×13×256

            # 第3卷积层: 输入256通道,输出384通道
            nn.Conv2d(256, 384, kernel_size=3, padding=1),  # 输出: 13×13×384
            nn.ReLU(inplace=True),

            # 第4卷积层: 输入384通道,输出384通道
            nn.Conv2d(384, 384, kernel_size=3, padding=1),  # 输出: 13×13×384
            nn.ReLU(inplace=True),

            # 第5卷积层: 输入384通道,输出256通道
            nn.Conv2d(384, 256, kernel_size=3, padding=1),  # 输出: 13×13×256
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # 输出: 6×6×256
        )

        # 自适应平均池化层(替代固定的展平操作,使网络适应不同输入尺寸)
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        # 分类器部分(全连接层)
        self.classifier = nn.Sequential(
            # 第1全连接层: 输入6×6×256=9216,输出4096
            nn.Dropout(p=dropout_rate),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),

            # 第2全连接层: 输入4096,输出4096
            nn.Dropout(p=dropout_rate),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),

            # 第3全连接层(输出层): 输入4096,输出num_classes
            nn.Linear(4096, num_classes),
        )

        # 权重初始化
        self._initialize_weights()

    def _initialize_weights(self):
        """初始化网络权重(按原始论文方式)"""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # 卷积层使用正态分布初始化
                nn.init.normal_(m.weight, mean=0, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                # 全连接层使用正态分布初始化
                nn.init.normal_(m.weight, mean=0, std=0.01)
                nn.init.constant_(m.bias, 0)

        # 对第2、4、5卷积层的偏置初始化为1(按原始论文)
        nn.init.constant_(self.features[3].bias, 1)
        nn.init.constant_(self.features[8].bias, 1)
        nn.init.constant_(self.features[10].bias, 1)

    def forward(self, x):
        """前向传播"""
        # 特征提取
        x = self.features(x)
        x = self.avgpool(x)

        # 展平
        x = torch.flatten(x, 1)

        # 分类
        x = self.classifier(x)

        return x

    def get_feature_maps(self, x, layer_idx=None):
        """
        获取指定层的特征图(用于可视化)

        参数:
            x: 输入图像
            layer_idx: 层索引列表,None表示获取所有层的特征图
        """
        feature_maps = {}
        layer_names = [
            'conv1', 'pool1', 'conv2', 'pool2',
            'conv3', 'conv4', 'conv5', 'pool5'
        ]

        # 逐层前向传播并保存特征图
        for i, layer in enumerate(self.features):
            x = layer(x)
            if layer_idx is None or i in layer_idx:
                feature_maps[layer_names[i]] = x.detach()

        return feature_maps


class SimplifiedAlexNet(nn.Module):
    """
    简化版AlexNet(适用于CIFAR-10等小尺寸数据集)

    修改点:
    1. 去除第1个卷积层的大步长
    2. 修改全连接层的输入尺寸
    3. 可选去除LRN层(现代实践中常用BatchNorm替代)
    """

    def __init__(self, num_classes=10, use_batchnorm=False):
        super(SimplifiedAlexNet, self).__init__()

        self.use_batchnorm = use_batchnorm

        # 特征提取部分
        self.features = nn.Sequential(
            # 第1卷积层(适配32×32输入)
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),  # 输出: 32×32×64
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 输出: 16×16×64

            # 第2卷积层
            nn.Conv2d(64, 192, kernel_size=3, padding=1),  # 输出: 16×16×192
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 输出: 8×8×192

            # 第3卷积层
            nn.Conv2d(192, 384, kernel_size=3, padding=1),  # 输出: 8×8×384
            nn.ReLU(inplace=True),

            # 第4卷积层
            nn.Conv2d(384, 256, kernel_size=3, padding=1),  # 输出: 8×8×256
            nn.ReLU(inplace=True),

            # 第5卷积层
            nn.Conv2d(256, 256, kernel_size=3, padding=1),  # 输出: 8×8×256
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 输出: 4×4×256
        )

        # 分类器部分
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(256 * 4 * 4, 1024),
            nn.ReLU(inplace=True),

            nn.Dropout(p=0.5),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),

            nn.Linear(512, num_classes),
        )

        # 权重初始化
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
相关推荐
九.九4 小时前
ops-transformer:AI 处理器上的高性能 Transformer 算子库
人工智能·深度学习·transformer
春日见4 小时前
拉取与合并:如何让个人分支既包含你昨天的修改,也包含 develop 最新更新
大数据·人工智能·深度学习·elasticsearch·搜索引擎
寻寻觅觅☆5 小时前
东华OJ-基础题-106-大整数相加(C++)
开发语言·c++·算法
化学在逃硬闯CS6 小时前
Leetcode1382. 将二叉搜索树变平衡
数据结构·算法
ceclar1236 小时前
C++使用format
开发语言·c++·算法
Faker66363aaa6 小时前
【深度学习】YOLO11-BiFPN多肉植物检测分类模型,从0到1实现植物识别系统,附完整代码与教程_1
人工智能·深度学习·分类
Gofarlic_OMS7 小时前
科学计算领域MATLAB许可证管理工具对比推荐
运维·开发语言·算法·matlab·自动化
夏鹏今天学习了吗7 小时前
【LeetCode热题100(100/100)】数据流的中位数
算法·leetcode·职场和发展
忙什么果7 小时前
上位机、下位机、FPGA、算法放在哪层合适?
算法·fpga开发