23. 已有模型的修改

已有模型的修改
1. 为什么要进行已有模型的修改
  • PyTorch的torchvision模块包含了很多个已经创建好的模型,我们在使用一些经典模型的时候可以直接使用,但是部分模型不一定完全适用于当前数据集,如果直接网络构建的源码角度来修改模型是比较麻烦的
  • 当我们在团队合作的时候,也是可以直接通过查看别人的网络结构,直接在结构中添加、修改网络结构就可以,增加了开发效率
2.对VGG16模型结构简述
  • 从torchvision中导入VGG16模型

    python 复制代码
    vgg16 = torchvision.models.vgg16(weights=False)
    • weights=False:表示下载初始模型而不下载模型中训练好的诸多参数,本文只针对于模型的修改,不针对具体参数,所以设置为False即可
  • 输出 VGG16 的模型架构

    python 复制代码
    print(vgg16)
    ##########################################################################
    VGG(
      (features): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (6): ReLU(inplace=True)
        (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (8): ReLU(inplace=True)
        (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (11): ReLU(inplace=True)
        (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (13): ReLU(inplace=True)
        (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (15): ReLU(inplace=True)
        (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (18): ReLU(inplace=True)
        (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (20): ReLU(inplace=True)
        (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (22): ReLU(inplace=True)
        (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (25): ReLU(inplace=True)
        (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (27): ReLU(inplace=True)
        (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (29): ReLU(inplace=True)
        (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
      (classifier): Sequential(
        (0): Linear(in_features=25088, out_features=4096, bias=True)
        (1): ReLU(inplace=True)
        (2): Dropout(p=0.5, inplace=False)
        (3): Linear(in_features=4096, out_features=4096, bias=True)
        (4): ReLU(inplace=True)
        (5): Dropout(p=0.5, inplace=False)
        (6): Linear(in_features=4096, out_features=1000, bias=True)
      )
    )
    • 整个VGG16网络架构包括三个大块:features avgpool classifier
    • 每个模块中可以通过中包含具体的层信息也可以清楚的查看到
3. 给模型添加新的模块
  • 整体中添加新的模块结构

    python 复制代码
    modules = nn.Sequential(
        nn.Linear(in_features=1000, out_features=10),
        nn.ReLU(inplace=True),
        nn.Linear(in_features=10, out_features=1),
    )
    vgg16.add_module("add", modules)
    • 通过pytorch构建的网络架构包含 add_module 方法可以用来给网络添加新的结构
    • 上面的代码中,给VGG16的整体添加了一个新的块,块名为 add 块中的网络结构就是 modules 中定义的内容,新的模块将自动添加到VGG架构的末尾
    python 复制代码
    print(vgg16)
    ###############################################################################
    VGG(
      (features): Sequential(
        ....
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
      (classifier): Sequential(
        ......
      )
      (add): Sequential(
        (0): Linear(in_features=1000, out_features=10, bias=True)
        (1): ReLU(inplace=True)
        (2): Linear(in_features=10, out_features=1, bias=True)
      )
    )
  • 某个模块中添加新的结构

    python 复制代码
    modules = nn.Sequential(
        nn.Linear(in_features=1000, out_features=10),
        nn.ReLU(inplace=True),
        nn.Linear(in_features=10, out_features=1),
    )
    vgg16.classifier.add_module("add", modules)
    • . 运算符可以直接选中具体的子模块,再通过 add_module 方法就可以给具体的子模块添加新的网络结构
    • 上面的代码中,给VGG中的 classifier 模块的末尾添加了新的网络序列,网络结构如下所示:
    python 复制代码
    print(vgg16)
    #########################################################
    VGG(
      (features): Sequential(
        .......
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
      (classifier): Sequential(
        (0): Linear(in_features=25088, out_features=4096, bias=True)
        (1): ReLU(inplace=True)
        (2): Dropout(p=0.5, inplace=False)
        (3): Linear(in_features=4096, out_features=4096, bias=True)
        (4): ReLU(inplace=True)
        (5): Dropout(p=0.5, inplace=False)
        (6): Linear(in_features=4096, out_features=1000, bias=True)
          
        (add): Sequential(
          (0): Linear(in_features=1000, out_features=10, bias=True)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=10, out_features=1, bias=True)
        )
      )
    )
  • 修改某个层

    • 通过 . 运算符 可以选中 VGG中的三个主要模块,再通过索引就可以选择到模块中的具体的层,从而进行具体的修改

      python 复制代码
      vgg16.classifier[0] = nn.Linear(in_features=25088, out_features=4096, bias=False)

      上面的代码将 Classifier 中的第一个线性层中的偏置设置为 False,其它信息不变

      python 复制代码
      VGG(
        (features): Sequential(
           ......
        )
        (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
        (classifier): Sequential(
          (0): Linear(in_features=25088, out_features=4096, bias=False)
          .......
        )
      )
相关推荐
周末程序猿11 分钟前
机器学习|大模型为什么会出现"幻觉"?
人工智能
JoannaJuanCV25 分钟前
大语言模型基石:Transformer
人工智能·语言模型·transformer
飞哥数智坊28 分钟前
Qoder vs CodeBuddy,刚起步就收费,值吗?
人工智能·ai编程
强盛小灵通专卖员29 分钟前
闪电科创,深度学习辅导
人工智能·sci·小论文·大论文·延毕
通街市密人有35 分钟前
IDF: Iterative Dynamic Filtering Networks for Generalizable Image Denoising
人工智能·深度学习·计算机视觉
大千AI助手39 分钟前
TruthfulQA:衡量语言模型真实性的基准
人工智能·语言模型·自然语言处理·llm·模型评估·truthfulqa·事实性基准
蚂蚁RichLab前端团队39 分钟前
🚀🚀🚀 RichLab - 花呗前端团队招贤纳士 - 【转岗/内推/社招】
前端·javascript·人工智能
智数研析社39 分钟前
9120 部 TMDb 高分电影数据集 | 7 列全维度指标 (评分 / 热度 / 剧情)+API 权威源 | 电影趋势分析 / 推荐系统 / NLP 建模用
大数据·人工智能·python·深度学习·数据分析·数据集·数据清洗
救救孩子把1 小时前
2-机器学习与大模型开发数学教程-第0章 预备知识-0-2 数列与级数(收敛性、幂级数)
人工智能·数学·机器学习
yzx9910131 小时前
接口协议全解析:从HTTP到gRPC,如何选择适合你的通信方案?
网络·人工智能·网络协议·flask·pygame