动手学深度学习(pytorch)学习记录29-网络中的网络(NiN)[学习记录]

目录

介绍

网络中的网络(Network in Network,简称NiN)是一种经典的卷积神经网络结构,由Min Lin等人在2013年提出。NiN的核心思想是在传统的卷积神经网络中引入小型的多层感知机(MLP),以增强网络的特征提取能力。

NiN的主要特点包括:

  • MLP卷积层(MLP Convolution Layers):NiN通过在卷积层之间加入小型的MLP网络来提取更抽象的特征。这些MLP网络实际上是1x1卷积层,它们可以在保持空间结构的同时,对每个像素位置的通道进行全连接操作。

  • 全局平均池化层(Global Average Pooling):NiN去除了容易造成过拟合的全连接层,而是使用全局平均池化层来减少模型参数的数量。这种池化层在所有位置上进行求和,输出固定数量的特征,直接用于分类。

  • NiN块(NiN Blocks):NiN的基本构建单元是NiN块,它由一个普通卷积层和两个1x1卷积层组成。普通卷积层负责提取空间特征,而1x1卷积层则充当逐像素的全连接层,增强了特征的非线性表达能力。

  • 减少参数数量:由于使用了全局平均池化层和1x1卷积层替代传统的全连接层,NiN显著减少了模型的参数数量,有助于缓解过拟合问题。

  • 提高泛化能力:NiN的设计有助于提高模型的泛化能力,因为它通过MLP卷积层和全局平均池化层捕捉到了更丰富的特征表示。

NiN的这些设计影响了后续许多卷积神经网络的结构,尤其是在特征提取和分类器设计方面。尽管NiN是一个相对较老的模型,但它的设计理念仍然对深度学习领域产生了深远的影响。

LeNet、AlexNet和VGG都有相同的设计模式:用一系列的卷积层和汇聚层来提取空间结构特征,然后通过全连接层对特征的表征进行处理。

NiN网络使用的NiN块通过在卷积层之间加入类似于全连接层的1x1卷积层(也称为mlpconv层),增强了网络的非线性特征提取能力。这种设计允许网络在保持空间结构信息的同时,增加了网络的深度和复杂度。

NiN块

以一个普通的卷积层开始,后接两个1×1卷积层(充当带有ReLU激活函数的逐像素全连接层)

python 复制代码
import torch
from torch import nn
from d2l import torch as d2l


def nin_block(in_channels, out_channels, kernel_size, strides, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU())

NiN模型

NiN完全取消了全连接层,而是使用一个NiN块,输出通道数等于标签类别数

python 复制代码
net = nn.Sequential(
    nin_block(1, 96, kernel_size=11, strides=4, padding=0),
    nn.MaxPool2d(3, stride=2),
    nin_block(96, 256, kernel_size=5, strides=1, padding=2),
    nn.MaxPool2d(3, stride=2),
    nin_block(256, 384, kernel_size=3, strides=1, padding=1),
    nn.MaxPool2d(3, stride=2),
    nn.Dropout(0.5),
    # 标签类别数是10
    nin_block(384, 10, kernel_size=3, strides=1, padding=1),
    nn.AdaptiveAvgPool2d((1, 1)),
    # 将四维的输出转成二维的输出,其形状为(批量大小,10)
    nn.Flatten())

创建一个数据样本来查看每个块的输出形状。

python 复制代码
X = torch.rand(size=(1, 1, 224, 224))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t', X.shape)
复制代码
Sequential output shape:	 torch.Size([1, 96, 54, 54])
MaxPool2d output shape:	 torch.Size([1, 96, 26, 26])
Sequential output shape:	 torch.Size([1, 256, 26, 26])
MaxPool2d output shape:	 torch.Size([1, 256, 12, 12])
Sequential output shape:	 torch.Size([1, 384, 12, 12])
MaxPool2d output shape:	 torch.Size([1, 384, 5, 5])
Dropout output shape:	 torch.Size([1, 384, 5, 5])
Sequential output shape:	 torch.Size([1, 10, 5, 5])
AdaptiveAvgPool2d output shape:	 torch.Size([1, 10, 1, 1])
Flatten output shape:	 torch.Size([1, 10])

训练模型

python 复制代码
lr, num_epochs, batch_size = 0.05, 30, 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
复制代码
loss 0.222, train acc 0.919, test acc 0.903
1686.5 examples/sec on cuda:0

· 本文使用了d2l包,这极大地减少了代码编辑量,需要安装d2l包才能运行本文代码
封面图片来源

欢迎点击我的主页查看更多文章。
本人学习地址https://zh-v2.d2l.ai/

恳请大佬批评指正。

相关推荐
ZH1545589131几秒前
Flutter for OpenHarmony Python学习助手实战:面向对象编程实战的实现
python·学习·flutter
心疼你的一切1 分钟前
模态交响:CANN驱动的跨模态AIGC统一架构
数据仓库·深度学习·架构·aigc·cann
小羊不会打字7 分钟前
CANN 生态中的跨框架兼容桥梁:`onnx-adapter` 项目实现无缝模型迁移
c++·深度学习
简佐义的博客16 分钟前
生信入门进阶指南:学习顶级实验室多组学整合方案,构建肾脏细胞空间分子图谱
人工智能·学习
白日做梦Q16 分钟前
Anchor-free检测器全解析:CenterNet vs FCOS
python·深度学习·神经网络·目标检测·机器学习
近津薪荼18 分钟前
dfs专题4——二叉树的深搜(验证二叉搜索树)
c++·学习·算法·深度优先
饭饭大王66626 分钟前
CANN 生态中的自动化测试利器:`test-automation` 项目保障模型部署可靠性
深度学习
island131429 分钟前
CANN HIXL 通信库深度解析:单边点对点数据传输、异步模型与异构设备间显存直接访问
人工智能·深度学习·神经网络
心疼你的一切34 分钟前
解锁CANN仓库核心能力:从零搭建AIGC轻量文本生成实战(附代码+流程图)
数据仓库·深度学习·aigc·流程图·cann
2的n次方_43 分钟前
CANN ascend-transformer-boost 深度解析:针对大模型的高性能融合算子库与算力优化机制
人工智能·深度学习·transformer