23. 已有模型的修改

已有模型的修改
1. 为什么要进行已有模型的修改
  • PyTorch的torchvision模块包含了很多个已经创建好的模型,我们在使用一些经典模型的时候可以直接使用,但是部分模型不一定完全适用于当前数据集,如果直接网络构建的源码角度来修改模型是比较麻烦的
  • 当我们在团队合作的时候,也是可以直接通过查看别人的网络结构,直接在结构中添加、修改网络结构就可以,增加了开发效率
2.对VGG16模型结构简述
  • 从torchvision中导入VGG16模型

    python 复制代码
    vgg16 = torchvision.models.vgg16(weights=False)
    • weights=False:表示下载初始模型而不下载模型中训练好的诸多参数,本文只针对于模型的修改,不针对具体参数,所以设置为False即可
  • 输出 VGG16 的模型架构

    python 复制代码
    print(vgg16)
    ##########################################################################
    VGG(
      (features): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (6): ReLU(inplace=True)
        (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (8): ReLU(inplace=True)
        (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (11): ReLU(inplace=True)
        (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (13): ReLU(inplace=True)
        (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (15): ReLU(inplace=True)
        (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (18): ReLU(inplace=True)
        (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (20): ReLU(inplace=True)
        (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (22): ReLU(inplace=True)
        (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (25): ReLU(inplace=True)
        (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (27): ReLU(inplace=True)
        (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (29): ReLU(inplace=True)
        (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
      (classifier): Sequential(
        (0): Linear(in_features=25088, out_features=4096, bias=True)
        (1): ReLU(inplace=True)
        (2): Dropout(p=0.5, inplace=False)
        (3): Linear(in_features=4096, out_features=4096, bias=True)
        (4): ReLU(inplace=True)
        (5): Dropout(p=0.5, inplace=False)
        (6): Linear(in_features=4096, out_features=1000, bias=True)
      )
    )
    • 整个VGG16网络架构包括三个大块:features avgpool classifier
    • 每个模块中可以通过中包含具体的层信息也可以清楚的查看到
3. 给模型添加新的模块
  • 整体中添加新的模块结构

    python 复制代码
    modules = nn.Sequential(
        nn.Linear(in_features=1000, out_features=10),
        nn.ReLU(inplace=True),
        nn.Linear(in_features=10, out_features=1),
    )
    vgg16.add_module("add", modules)
    • 通过pytorch构建的网络架构包含 add_module 方法可以用来给网络添加新的结构
    • 上面的代码中,给VGG16的整体添加了一个新的块,块名为 add 块中的网络结构就是 modules 中定义的内容,新的模块将自动添加到VGG架构的末尾
    python 复制代码
    print(vgg16)
    ###############################################################################
    VGG(
      (features): Sequential(
        ....
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
      (classifier): Sequential(
        ......
      )
      (add): Sequential(
        (0): Linear(in_features=1000, out_features=10, bias=True)
        (1): ReLU(inplace=True)
        (2): Linear(in_features=10, out_features=1, bias=True)
      )
    )
  • 某个模块中添加新的结构

    python 复制代码
    modules = nn.Sequential(
        nn.Linear(in_features=1000, out_features=10),
        nn.ReLU(inplace=True),
        nn.Linear(in_features=10, out_features=1),
    )
    vgg16.classifier.add_module("add", modules)
    • . 运算符可以直接选中具体的子模块,再通过 add_module 方法就可以给具体的子模块添加新的网络结构
    • 上面的代码中,给VGG中的 classifier 模块的末尾添加了新的网络序列,网络结构如下所示:
    python 复制代码
    print(vgg16)
    #########################################################
    VGG(
      (features): Sequential(
        .......
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
      (classifier): Sequential(
        (0): Linear(in_features=25088, out_features=4096, bias=True)
        (1): ReLU(inplace=True)
        (2): Dropout(p=0.5, inplace=False)
        (3): Linear(in_features=4096, out_features=4096, bias=True)
        (4): ReLU(inplace=True)
        (5): Dropout(p=0.5, inplace=False)
        (6): Linear(in_features=4096, out_features=1000, bias=True)
          
        (add): Sequential(
          (0): Linear(in_features=1000, out_features=10, bias=True)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=10, out_features=1, bias=True)
        )
      )
    )
  • 修改某个层

    • 通过 . 运算符 可以选中 VGG中的三个主要模块,再通过索引就可以选择到模块中的具体的层,从而进行具体的修改

      python 复制代码
      vgg16.classifier[0] = nn.Linear(in_features=25088, out_features=4096, bias=False)

      上面的代码将 Classifier 中的第一个线性层中的偏置设置为 False,其它信息不变

      python 复制代码
      VGG(
        (features): Sequential(
           ......
        )
        (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
        (classifier): Sequential(
          (0): Linear(in_features=25088, out_features=4096, bias=False)
          .......
        )
      )
相关推荐
PM老周2 分钟前
DORA2025:如何用AI提升研发效能(以 ONES MCP Server 为例)
大数据·人工智能
皇族崛起4 分钟前
【众包 + AI智能体】AI境生态巡查平台边防借鉴价值专项调研——以广西边境线治理为例
大数据·人工智能
zhaodiandiandian23 分钟前
AI大模型:重构产业生态的核心引擎
人工智能·重构
沈浩(种子思维作者)28 分钟前
百项可控核聚变实现方式的全息太极矩阵
人工智能
_codemonster29 分钟前
自然语言处理容易混淆知识点(二)BERT和BERTopic的区别
人工智能·自然语言处理·bert
JoannaJuanCV32 分钟前
自动驾驶—CARLA仿真(9)visualize_multiple_sensors demo
人工智能·自动驾驶·pygame
良策金宝AI38 分钟前
全球工程软件格局重塑:中国AI原生平台的机会窗口
大数据·运维·人工智能
小笔学长39 分钟前
毕业论文答辩 PPT:从内容到呈现的全流程设计指南
人工智能·powerpoint
dagouaofei40 分钟前
长文档也能转成PPT:AI自动拆分章节并生成页面
人工智能·python·powerpoint
IT_陈寒42 分钟前
SpringBoot 3.2 实战:用这5个新特性让你的API性能提升40%
前端·人工智能·后端