动手学深度学习(pytorch)学习记录29-网络中的网络(NiN)[学习记录]

目录

介绍

网络中的网络(Network in Network,简称NiN)是一种经典的卷积神经网络结构,由Min Lin等人在2013年提出。NiN的核心思想是在传统的卷积神经网络中引入小型的多层感知机(MLP),以增强网络的特征提取能力。

NiN的主要特点包括:

  • MLP卷积层(MLP Convolution Layers):NiN通过在卷积层之间加入小型的MLP网络来提取更抽象的特征。这些MLP网络实际上是1x1卷积层,它们可以在保持空间结构的同时,对每个像素位置的通道进行全连接操作。

  • 全局平均池化层(Global Average Pooling):NiN去除了容易造成过拟合的全连接层,而是使用全局平均池化层来减少模型参数的数量。这种池化层在所有位置上进行求和,输出固定数量的特征,直接用于分类。

  • NiN块(NiN Blocks):NiN的基本构建单元是NiN块,它由一个普通卷积层和两个1x1卷积层组成。普通卷积层负责提取空间特征,而1x1卷积层则充当逐像素的全连接层,增强了特征的非线性表达能力。

  • 减少参数数量:由于使用了全局平均池化层和1x1卷积层替代传统的全连接层,NiN显著减少了模型的参数数量,有助于缓解过拟合问题。

  • 提高泛化能力:NiN的设计有助于提高模型的泛化能力,因为它通过MLP卷积层和全局平均池化层捕捉到了更丰富的特征表示。

NiN的这些设计影响了后续许多卷积神经网络的结构,尤其是在特征提取和分类器设计方面。尽管NiN是一个相对较老的模型,但它的设计理念仍然对深度学习领域产生了深远的影响。

LeNet、AlexNet和VGG都有相同的设计模式:用一系列的卷积层和汇聚层来提取空间结构特征,然后通过全连接层对特征的表征进行处理。

NiN网络使用的NiN块通过在卷积层之间加入类似于全连接层的1x1卷积层(也称为mlpconv层),增强了网络的非线性特征提取能力。这种设计允许网络在保持空间结构信息的同时,增加了网络的深度和复杂度。

NiN块

以一个普通的卷积层开始,后接两个1×1卷积层(充当带有ReLU激活函数的逐像素全连接层)

python 复制代码
import torch
from torch import nn
from d2l import torch as d2l


def nin_block(in_channels, out_channels, kernel_size, strides, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU())

NiN模型

NiN完全取消了全连接层,而是使用一个NiN块,输出通道数等于标签类别数

python 复制代码
net = nn.Sequential(
    nin_block(1, 96, kernel_size=11, strides=4, padding=0),
    nn.MaxPool2d(3, stride=2),
    nin_block(96, 256, kernel_size=5, strides=1, padding=2),
    nn.MaxPool2d(3, stride=2),
    nin_block(256, 384, kernel_size=3, strides=1, padding=1),
    nn.MaxPool2d(3, stride=2),
    nn.Dropout(0.5),
    # 标签类别数是10
    nin_block(384, 10, kernel_size=3, strides=1, padding=1),
    nn.AdaptiveAvgPool2d((1, 1)),
    # 将四维的输出转成二维的输出,其形状为(批量大小,10)
    nn.Flatten())

创建一个数据样本来查看每个块的输出形状。

python 复制代码
X = torch.rand(size=(1, 1, 224, 224))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t', X.shape)
Sequential output shape:	 torch.Size([1, 96, 54, 54])
MaxPool2d output shape:	 torch.Size([1, 96, 26, 26])
Sequential output shape:	 torch.Size([1, 256, 26, 26])
MaxPool2d output shape:	 torch.Size([1, 256, 12, 12])
Sequential output shape:	 torch.Size([1, 384, 12, 12])
MaxPool2d output shape:	 torch.Size([1, 384, 5, 5])
Dropout output shape:	 torch.Size([1, 384, 5, 5])
Sequential output shape:	 torch.Size([1, 10, 5, 5])
AdaptiveAvgPool2d output shape:	 torch.Size([1, 10, 1, 1])
Flatten output shape:	 torch.Size([1, 10])

训练模型

python 复制代码
lr, num_epochs, batch_size = 0.05, 30, 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.222, train acc 0.919, test acc 0.903
1686.5 examples/sec on cuda:0

· 本文使用了d2l包,这极大地减少了代码编辑量,需要安装d2l包才能运行本文代码
封面图片来源

欢迎点击我的主页查看更多文章。
本人学习地址https://zh-v2.d2l.ai/

恳请大佬批评指正。

相关推荐
找了一圈尾巴30 分钟前
Wend看源码-Java-Map学习
java·学习·map
CM莫问2 小时前
tokenizer、tokenizer.encode、tokenizer.encode_plus比较
人工智能·python·深度学习·语言模型·大模型·tokenizer·文本表示
山山而川粤3 小时前
母婴用品系统|Java|SSM|JSP|
java·开发语言·后端·学习·mysql
MUTA️5 小时前
专业版pycharm与服务器连接
人工智能·python·深度学习·计算机视觉·pycharm
MinIO官方账号7 小时前
使用亚马逊针对 PyTorch 和 MinIO 的 S3 连接器实现可迭代式数据集
人工智能·pytorch·python
四口鲸鱼爱吃盐7 小时前
Pytorch | 利用IE-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python·深度学习·计算机视觉
四口鲸鱼爱吃盐7 小时前
Pytorch | 利用EMI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python
红色的山茶花8 小时前
YOLOv9-0.1部分代码阅读笔记-loss_tal_dual.py
笔记·深度学习·yolo
一棵开花的树,枝芽无限靠近你8 小时前
【PPTist】表格功能
前端·笔记·学习·编辑器·ppt·pptist
呆头鹅AI工作室8 小时前
基于特征工程(pca分析)、小波去噪以及数据增强,同时采用基于注意力机制的BiLSTM、随机森林、ARIMA模型进行序列数据预测
人工智能·深度学习·神经网络·算法·随机森林·回归