现代卷积网络实战系列3:PyTorch从零构建AlexNet训练MNIST数据集

1、AlexNet

AlexNet提出了一下5点改进:

  1. 使用了Dropout,防止过拟合
  2. 使用Relu作为激活函数,极大提高了特征提取效果
  3. 使用MaxPooling池化进行特征降维,极大提高了特征提取效果
  4. 首次使用GPU进行训练
  5. 使用了LRN局部响应归一化(对局部神经元的活动创建竞争机制,使得其中响应比较大的值变得相对更大,并抑制其他反馈较小的神经元,增强了模型的泛化能力)

2、AlexNet网络结构

AlexNet(

(feature): Sequential(

(0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(1, 1))

(1): ReLU(inplace=True)

(2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

(3): ReLU(inplace=True)

(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

(5): Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

(6): ReLU(inplace=True)

(7): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

(8): ReLU(inplace=True)

(9): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

(10): ReLU(inplace=True)

(11): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)

)

(classifier): Sequential(

(0): Dropout(p=0.5, inplace=False)

(1): Linear(in_features=4608, out_features=2048, bias=True)

(2): ReLU(inplace=True)

(3): Dropout(p=0.5, inplace=False)

(4): Linear(in_features=2048, out_features=1024, bias=True)

(5): ReLU(inplace=True)

(6): Linear(in_features=1024, out_features=10, bias=True)

)

)

3、PyTorch构建AlexNet

python 复制代码
class AlexNet(nn.Module):
    def __init__(self, num=10):
        super(AlexNet, self).__init__()
        self.feature = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 96, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(96, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=1),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(32 * 12 * 12, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(2048, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, num),
        )

    def forward(self, x):
        x = self.feature(x)
        x = x.view(-1, 32 * 12 * 12)
        x = self.classifier(x)
        return x

10个epoch训练过程的打印:

D:\conda\envs\pytorch\python.exe A:\0_MNIST\train.py

Reading data...

train_data: (60000, 28, 28) train_label (60000,)

test_data: (10000, 28, 28) test_label (10000,)

Initialize neural network

test loss: 2302.56

test accuracy: 10.1 %

epoch step: 1

training loss: 167.49

test loss: 46.66

test accuracy: 98.73 %

epoch step: 2

training loss: 59.43

test loss: 36.14

test accuracy: 98.95 %

epoch step: 3

training loss: 49.94

test loss: 24.93

test accuracy: 99.22 %

epoch step: 4

training loss: 38.7

test loss: 20.42

test accuracy: 99.45 %

epoch step: 5

training loss: 35.07

test loss: 26.18

test accuracy: 99.17 %

epoch step: 6

training loss: 30.65

test loss: 22.65

test accuracy: 99.34 %

epoch step: 7

training loss: 26.34

test loss: 20.5

test accuracy: 99.31 %

epoch step: 8

training loss: 26.24

test loss: 27.69

test accuracy: 99.11 %

epoch step: 9

training loss: 23.14

test loss: 22.55

test accuracy: 99.39 %

epoch step: 10

training loss: 20.22

test loss: 28.51

test accuracy: 99.24 %

Training finished

进程已结束,退出代码为 0

效果已经非常好了

相关推荐
CoovallyAIHub2 天前
仿生学突破:SILD模型如何让无人机在电力线迷宫中发现“隐形威胁”
深度学习·算法·计算机视觉
CoovallyAIHub2 天前
从春晚机器人到零样本革命:YOLO26-Pose姿态估计实战指南
深度学习·算法·计算机视觉
CoovallyAIHub2 天前
Le-DETR:省80%预训练数据,这个实时检测Transformer刷新SOTA|Georgia Tech & 北交大
深度学习·算法·计算机视觉
CoovallyAIHub2 天前
强化学习凭什么比监督学习更聪明?RL的“聪明”并非来自算法,而是因为它学会了“挑食”
深度学习·算法·计算机视觉
CoovallyAIHub2 天前
YOLO-IOD深度解析:打破实时增量目标检测的三重知识冲突
深度学习·算法·计算机视觉
用户1474853079742 天前
AI-动手深度学习环境搭建-d2l
深度学习
OpenBayes贝式计算2 天前
解决视频模型痛点,TurboDiffusion 高效视频扩散生成系统;Google Streetview 涵盖多个国家的街景图像数据集
人工智能·深度学习·机器学习
OpenBayes贝式计算2 天前
OCR教程汇总丨DeepSeek/百度飞桨/华中科大等开源创新技术,实现OCR高精度、本地化部署
人工智能·深度学习·机器学习
在人间耕耘3 天前
HarmonyOS Vision Kit 视觉AI实战:把官方 Demo 改造成一套能长期复用的组件库
人工智能·深度学习·harmonyos
homelook3 天前
Transformer与电池管理系统(BMS)的结合是当前 智能电池管理 的前沿研究方向
人工智能·深度学习·transformer