23. 已有模型的修改

已有模型的修改
1. 为什么要进行已有模型的修改
  • PyTorch的torchvision模块包含了很多个已经创建好的模型,我们在使用一些经典模型的时候可以直接使用,但是部分模型不一定完全适用于当前数据集,如果直接网络构建的源码角度来修改模型是比较麻烦的
  • 当我们在团队合作的时候,也是可以直接通过查看别人的网络结构,直接在结构中添加、修改网络结构就可以,增加了开发效率
2.对VGG16模型结构简述
  • 从torchvision中导入VGG16模型

    python 复制代码
    vgg16 = torchvision.models.vgg16(weights=False)
    • weights=False:表示下载初始模型而不下载模型中训练好的诸多参数,本文只针对于模型的修改,不针对具体参数,所以设置为False即可
  • 输出 VGG16 的模型架构

    python 复制代码
    print(vgg16)
    ##########################################################################
    VGG(
      (features): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (6): ReLU(inplace=True)
        (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (8): ReLU(inplace=True)
        (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (11): ReLU(inplace=True)
        (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (13): ReLU(inplace=True)
        (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (15): ReLU(inplace=True)
        (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (18): ReLU(inplace=True)
        (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (20): ReLU(inplace=True)
        (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (22): ReLU(inplace=True)
        (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (25): ReLU(inplace=True)
        (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (27): ReLU(inplace=True)
        (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (29): ReLU(inplace=True)
        (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
      (classifier): Sequential(
        (0): Linear(in_features=25088, out_features=4096, bias=True)
        (1): ReLU(inplace=True)
        (2): Dropout(p=0.5, inplace=False)
        (3): Linear(in_features=4096, out_features=4096, bias=True)
        (4): ReLU(inplace=True)
        (5): Dropout(p=0.5, inplace=False)
        (6): Linear(in_features=4096, out_features=1000, bias=True)
      )
    )
    • 整个VGG16网络架构包括三个大块:features avgpool classifier
    • 每个模块中可以通过中包含具体的层信息也可以清楚的查看到
3. 给模型添加新的模块
  • 整体中添加新的模块结构

    python 复制代码
    modules = nn.Sequential(
        nn.Linear(in_features=1000, out_features=10),
        nn.ReLU(inplace=True),
        nn.Linear(in_features=10, out_features=1),
    )
    vgg16.add_module("add", modules)
    • 通过pytorch构建的网络架构包含 add_module 方法可以用来给网络添加新的结构
    • 上面的代码中,给VGG16的整体添加了一个新的块,块名为 add 块中的网络结构就是 modules 中定义的内容,新的模块将自动添加到VGG架构的末尾
    python 复制代码
    print(vgg16)
    ###############################################################################
    VGG(
      (features): Sequential(
        ....
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
      (classifier): Sequential(
        ......
      )
      (add): Sequential(
        (0): Linear(in_features=1000, out_features=10, bias=True)
        (1): ReLU(inplace=True)
        (2): Linear(in_features=10, out_features=1, bias=True)
      )
    )
  • 某个模块中添加新的结构

    python 复制代码
    modules = nn.Sequential(
        nn.Linear(in_features=1000, out_features=10),
        nn.ReLU(inplace=True),
        nn.Linear(in_features=10, out_features=1),
    )
    vgg16.classifier.add_module("add", modules)
    • . 运算符可以直接选中具体的子模块,再通过 add_module 方法就可以给具体的子模块添加新的网络结构
    • 上面的代码中,给VGG中的 classifier 模块的末尾添加了新的网络序列,网络结构如下所示:
    python 复制代码
    print(vgg16)
    #########################################################
    VGG(
      (features): Sequential(
        .......
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
      (classifier): Sequential(
        (0): Linear(in_features=25088, out_features=4096, bias=True)
        (1): ReLU(inplace=True)
        (2): Dropout(p=0.5, inplace=False)
        (3): Linear(in_features=4096, out_features=4096, bias=True)
        (4): ReLU(inplace=True)
        (5): Dropout(p=0.5, inplace=False)
        (6): Linear(in_features=4096, out_features=1000, bias=True)
          
        (add): Sequential(
          (0): Linear(in_features=1000, out_features=10, bias=True)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=10, out_features=1, bias=True)
        )
      )
    )
  • 修改某个层

    • 通过 . 运算符 可以选中 VGG中的三个主要模块,再通过索引就可以选择到模块中的具体的层,从而进行具体的修改

      python 复制代码
      vgg16.classifier[0] = nn.Linear(in_features=25088, out_features=4096, bias=False)

      上面的代码将 Classifier 中的第一个线性层中的偏置设置为 False,其它信息不变

      python 复制代码
      VGG(
        (features): Sequential(
           ......
        )
        (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
        (classifier): Sequential(
          (0): Linear(in_features=25088, out_features=4096, bias=False)
          .......
        )
      )
相关推荐
千天夜35 分钟前
激活函数解析:神经网络背后的“驱动力”
人工智能·深度学习·神经网络
大数据面试宝典36 分钟前
用AI来写SQL:让ChatGPT成为你的数据库助手
数据库·人工智能·chatgpt
封步宇AIGC41 分钟前
量化交易系统开发-实时行情自动化交易-3.4.1.2.A股交易数据
人工智能·python·机器学习·数据挖掘
m0_5236742143 分钟前
技术前沿:从强化学习到Prompt Engineering,业务流程管理的创新之路
人工智能·深度学习·目标检测·机器学习·语言模型·自然语言处理·数据挖掘
HappyAcmen1 小时前
IDEA部署AI代写插件
java·人工智能·intellij-idea
噜噜噜噜鲁先森1 小时前
看懂本文,入门神经网络Neural Network
人工智能
InheritGuo2 小时前
It’s All About Your Sketch: Democratising Sketch Control in Diffusion Models
人工智能·计算机视觉·sketch
weixin_307779132 小时前
证明存在常数c, C > 0,使得在一系列特定条件下,某个特定投资时刻出现的概率与天数的对数成反比
人工智能·算法·机器学习
封步宇AIGC2 小时前
量化交易系统开发-实时行情自动化交易-3.4.1.6.A股宏观经济数据
人工智能·python·机器学习·数据挖掘
Jack黄从零学c++2 小时前
opencv(c++)图像的灰度转换
c++·人工智能·opencv