23. 已有模型的修改

已有模型的修改
1. 为什么要进行已有模型的修改
  • PyTorch的torchvision模块包含了很多个已经创建好的模型,我们在使用一些经典模型的时候可以直接使用,但是部分模型不一定完全适用于当前数据集,如果直接网络构建的源码角度来修改模型是比较麻烦的
  • 当我们在团队合作的时候,也是可以直接通过查看别人的网络结构,直接在结构中添加、修改网络结构就可以,增加了开发效率
2.对VGG16模型结构简述
  • 从torchvision中导入VGG16模型

    python 复制代码
    vgg16 = torchvision.models.vgg16(weights=False)
    • weights=False:表示下载初始模型而不下载模型中训练好的诸多参数,本文只针对于模型的修改,不针对具体参数,所以设置为False即可
  • 输出 VGG16 的模型架构

    python 复制代码
    print(vgg16)
    ##########################################################################
    VGG(
      (features): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (6): ReLU(inplace=True)
        (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (8): ReLU(inplace=True)
        (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (11): ReLU(inplace=True)
        (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (13): ReLU(inplace=True)
        (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (15): ReLU(inplace=True)
        (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (18): ReLU(inplace=True)
        (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (20): ReLU(inplace=True)
        (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (22): ReLU(inplace=True)
        (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (25): ReLU(inplace=True)
        (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (27): ReLU(inplace=True)
        (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (29): ReLU(inplace=True)
        (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
      (classifier): Sequential(
        (0): Linear(in_features=25088, out_features=4096, bias=True)
        (1): ReLU(inplace=True)
        (2): Dropout(p=0.5, inplace=False)
        (3): Linear(in_features=4096, out_features=4096, bias=True)
        (4): ReLU(inplace=True)
        (5): Dropout(p=0.5, inplace=False)
        (6): Linear(in_features=4096, out_features=1000, bias=True)
      )
    )
    • 整个VGG16网络架构包括三个大块:features avgpool classifier
    • 每个模块中可以通过中包含具体的层信息也可以清楚的查看到
3. 给模型添加新的模块
  • 整体中添加新的模块结构

    python 复制代码
    modules = nn.Sequential(
        nn.Linear(in_features=1000, out_features=10),
        nn.ReLU(inplace=True),
        nn.Linear(in_features=10, out_features=1),
    )
    vgg16.add_module("add", modules)
    • 通过pytorch构建的网络架构包含 add_module 方法可以用来给网络添加新的结构
    • 上面的代码中,给VGG16的整体添加了一个新的块,块名为 add 块中的网络结构就是 modules 中定义的内容,新的模块将自动添加到VGG架构的末尾
    python 复制代码
    print(vgg16)
    ###############################################################################
    VGG(
      (features): Sequential(
        ....
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
      (classifier): Sequential(
        ......
      )
      (add): Sequential(
        (0): Linear(in_features=1000, out_features=10, bias=True)
        (1): ReLU(inplace=True)
        (2): Linear(in_features=10, out_features=1, bias=True)
      )
    )
  • 某个模块中添加新的结构

    python 复制代码
    modules = nn.Sequential(
        nn.Linear(in_features=1000, out_features=10),
        nn.ReLU(inplace=True),
        nn.Linear(in_features=10, out_features=1),
    )
    vgg16.classifier.add_module("add", modules)
    • . 运算符可以直接选中具体的子模块,再通过 add_module 方法就可以给具体的子模块添加新的网络结构
    • 上面的代码中,给VGG中的 classifier 模块的末尾添加了新的网络序列,网络结构如下所示:
    python 复制代码
    print(vgg16)
    #########################################################
    VGG(
      (features): Sequential(
        .......
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
      (classifier): Sequential(
        (0): Linear(in_features=25088, out_features=4096, bias=True)
        (1): ReLU(inplace=True)
        (2): Dropout(p=0.5, inplace=False)
        (3): Linear(in_features=4096, out_features=4096, bias=True)
        (4): ReLU(inplace=True)
        (5): Dropout(p=0.5, inplace=False)
        (6): Linear(in_features=4096, out_features=1000, bias=True)
          
        (add): Sequential(
          (0): Linear(in_features=1000, out_features=10, bias=True)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=10, out_features=1, bias=True)
        )
      )
    )
  • 修改某个层

    • 通过 . 运算符 可以选中 VGG中的三个主要模块,再通过索引就可以选择到模块中的具体的层,从而进行具体的修改

      python 复制代码
      vgg16.classifier[0] = nn.Linear(in_features=25088, out_features=4096, bias=False)

      上面的代码将 Classifier 中的第一个线性层中的偏置设置为 False,其它信息不变

      python 复制代码
      VGG(
        (features): Sequential(
           ......
        )
        (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
        (classifier): Sequential(
          (0): Linear(in_features=25088, out_features=4096, bias=False)
          .......
        )
      )
相关推荐
Elastic 中国社区官方博客2 小时前
使用 Discord 和 Elastic Agent Builder A2A 构建游戏社区支持机器人
人工智能·elasticsearch·游戏·搜索引擎·ai·机器人·全文检索
2501_933329553 小时前
企业级AI舆情中台架构实践:Infoseek系统如何实现亿级数据实时监测与智能处置?
人工智能·架构
阿杰学AI3 小时前
AI核心知识70——大语言模型之Context Engineering(简洁且通俗易懂版)
人工智能·ai·语言模型·自然语言处理·aigc·数据处理·上下文工程
赛博鲁迅3 小时前
物理AI元年:AI走出屏幕进入现实,88API为机器人装上“最强大脑“
人工智能·机器人
管牛牛3 小时前
图像的卷积操作
人工智能·深度学习·计算机视觉
云卓SKYDROID3 小时前
无人机航线辅助模块技术解析
人工智能·无人机·高科技·云卓科技
琅琊榜首20204 小时前
AI生成脑洞付费短篇小说:从灵感触发到内容落地
大数据·人工智能
imbackneverdie4 小时前
近年来,我一直在用的科研工具
人工智能·自然语言处理·aigc·论文·ai写作·学术·ai工具
roman_日积跬步-终至千里5 小时前
【计算机视觉-作业1】从图像到向量:kNN数据预处理完整流程
人工智能·计算机视觉