23. 已有模型的修改

已有模型的修改
1. 为什么要进行已有模型的修改
  • PyTorch的torchvision模块包含了很多个已经创建好的模型,我们在使用一些经典模型的时候可以直接使用,但是部分模型不一定完全适用于当前数据集,如果直接网络构建的源码角度来修改模型是比较麻烦的
  • 当我们在团队合作的时候,也是可以直接通过查看别人的网络结构,直接在结构中添加、修改网络结构就可以,增加了开发效率
2.对VGG16模型结构简述
  • 从torchvision中导入VGG16模型

    python 复制代码
    vgg16 = torchvision.models.vgg16(weights=False)
    • weights=False:表示下载初始模型而不下载模型中训练好的诸多参数,本文只针对于模型的修改,不针对具体参数,所以设置为False即可
  • 输出 VGG16 的模型架构

    python 复制代码
    print(vgg16)
    ##########################################################################
    VGG(
      (features): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (6): ReLU(inplace=True)
        (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (8): ReLU(inplace=True)
        (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (11): ReLU(inplace=True)
        (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (13): ReLU(inplace=True)
        (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (15): ReLU(inplace=True)
        (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (18): ReLU(inplace=True)
        (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (20): ReLU(inplace=True)
        (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (22): ReLU(inplace=True)
        (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (25): ReLU(inplace=True)
        (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (27): ReLU(inplace=True)
        (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (29): ReLU(inplace=True)
        (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
      (classifier): Sequential(
        (0): Linear(in_features=25088, out_features=4096, bias=True)
        (1): ReLU(inplace=True)
        (2): Dropout(p=0.5, inplace=False)
        (3): Linear(in_features=4096, out_features=4096, bias=True)
        (4): ReLU(inplace=True)
        (5): Dropout(p=0.5, inplace=False)
        (6): Linear(in_features=4096, out_features=1000, bias=True)
      )
    )
    • 整个VGG16网络架构包括三个大块:features avgpool classifier
    • 每个模块中可以通过中包含具体的层信息也可以清楚的查看到
3. 给模型添加新的模块
  • 整体中添加新的模块结构

    python 复制代码
    modules = nn.Sequential(
        nn.Linear(in_features=1000, out_features=10),
        nn.ReLU(inplace=True),
        nn.Linear(in_features=10, out_features=1),
    )
    vgg16.add_module("add", modules)
    • 通过pytorch构建的网络架构包含 add_module 方法可以用来给网络添加新的结构
    • 上面的代码中,给VGG16的整体添加了一个新的块,块名为 add 块中的网络结构就是 modules 中定义的内容,新的模块将自动添加到VGG架构的末尾
    python 复制代码
    print(vgg16)
    ###############################################################################
    VGG(
      (features): Sequential(
        ....
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
      (classifier): Sequential(
        ......
      )
      (add): Sequential(
        (0): Linear(in_features=1000, out_features=10, bias=True)
        (1): ReLU(inplace=True)
        (2): Linear(in_features=10, out_features=1, bias=True)
      )
    )
  • 某个模块中添加新的结构

    python 复制代码
    modules = nn.Sequential(
        nn.Linear(in_features=1000, out_features=10),
        nn.ReLU(inplace=True),
        nn.Linear(in_features=10, out_features=1),
    )
    vgg16.classifier.add_module("add", modules)
    • . 运算符可以直接选中具体的子模块,再通过 add_module 方法就可以给具体的子模块添加新的网络结构
    • 上面的代码中,给VGG中的 classifier 模块的末尾添加了新的网络序列,网络结构如下所示:
    python 复制代码
    print(vgg16)
    #########################################################
    VGG(
      (features): Sequential(
        .......
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
      (classifier): Sequential(
        (0): Linear(in_features=25088, out_features=4096, bias=True)
        (1): ReLU(inplace=True)
        (2): Dropout(p=0.5, inplace=False)
        (3): Linear(in_features=4096, out_features=4096, bias=True)
        (4): ReLU(inplace=True)
        (5): Dropout(p=0.5, inplace=False)
        (6): Linear(in_features=4096, out_features=1000, bias=True)
          
        (add): Sequential(
          (0): Linear(in_features=1000, out_features=10, bias=True)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=10, out_features=1, bias=True)
        )
      )
    )
  • 修改某个层

    • 通过 . 运算符 可以选中 VGG中的三个主要模块,再通过索引就可以选择到模块中的具体的层,从而进行具体的修改

      python 复制代码
      vgg16.classifier[0] = nn.Linear(in_features=25088, out_features=4096, bias=False)

      上面的代码将 Classifier 中的第一个线性层中的偏置设置为 False,其它信息不变

      python 复制代码
      VGG(
        (features): Sequential(
           ......
        )
        (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
        (classifier): Sequential(
          (0): Linear(in_features=25088, out_features=4096, bias=False)
          .......
        )
      )
相关推荐
2301_7766816514 分钟前
【用「概率思维」重新理解生活】
开发语言·人工智能·自然语言处理
蜡笔小新..18 分钟前
从零开始:用PyTorch构建CIFAR-10图像分类模型达到接近1的准确率
人工智能·pytorch·机器学习·分类·cifar-10
富唯智能37 分钟前
转运机器人可以绕障吗?
人工智能·智能机器人·转运机器人
视觉语言导航1 小时前
湖南大学3D场景问答最新综述!3D-SQA:3D场景问答助力具身智能场景理解
人工智能·深度学习·具身智能
AidLux1 小时前
端侧智能重构智能监控新路径 | 2025 高通边缘智能创新应用大赛第三场公开课来袭!
大数据·人工智能
引量AI1 小时前
TikTok矩阵运营干货:从0到1打造爆款矩阵
人工智能·矩阵·自动化·tiktok矩阵·海外社媒
Hi-Dison2 小时前
神经网络极简入门技术分享
人工智能·深度学习·神经网络
奋斗者1号2 小时前
机器学习之决策树模型:从基础概念到条件类型详解
人工智能·决策树·机器学习
LinkTime_Cloud2 小时前
谷歌引入 AI 反诈系统:利用语言模型分析潜在恶意网站
人工智能·语言模型·自然语言处理
张小九992 小时前
PyTorch的dataloader制作自定义数据集
人工智能·pytorch·python