神经网络入门实战:(十七)VGG16模型的加载、保存与修改,可以应用到其他网络模型上

VGG16模型的加载、保存与修改

该模型主要是用来 识别 ImageNet 数据集的。

16 的由来:该模型包含16个权重层(13个卷积层和3个全连接层),因此得名VGG16。

模型官网:vgg16 --- Torchvision 0.20 documentation

(一) 从官网加载 VGG16 模型

① 在 pycharm 中加载官网 ++已经训练好++ 的 VGG16 模型的指令:

  • 使用该网络模型进行 ImageNet 的 1000 种类别的 分类任务(完整权重)

    python 复制代码
    vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)

    其中参数 weights 指定了要加载的预训练权重;V1 通常指的是模型的第一个版本或第一个公开发布的权重集。

    2024年目前情况下,这个参数是默认的,即 weights 可直接指定为 DEFAULT :

    python 复制代码
    weights=torchvision.models.VGG16_Weights.DEFAULT 

    不过 DEFAULT 一般都会指向新的权重版本,所以当官方更新最新权重之后(即 IMAGENET1K_V1 不再是最新的),就要慎重使用此命令。

  • 仅提供在 特征提取阶段 训练过的权重(只有 13 个卷积层训练好的权重):

    python 复制代码
    vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_FEATURES)

    这些权重可能缺少分类器模块(即全连接层)的权重。

加载好的VGG16模型参数如下:

python 复制代码
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

② 加载未经训练的 VGG16 模型:

python 复制代码
vgg16_true = torchvision.models.vgg16(weights=None)

这意味着模型的所有参数都将被自动初始化,而不是使用在大型数据集上训练好的权重。通常是:

  • 权重:通常是随机初始化,如均匀分布或正态分布等,在某些情况下,还可以自定义初始化方法。
  • 偏置:默认情况下,如果没有特别指定,偏置通常会被初始化为零(0),不过这也可以自定义。

(二) 保存完整模型

加载完成后,模型的权重会直接加载到内存中。 在加载模型的代码后面写保存模型的代码:

  • 如果没有指定路径,那么整个模型会默认放到:C:\Users\用户名/.cache\torch\hub\checkpoints\vgg16-397923af.pth

  • 指定路径保存torch.save(vgg16_true,"E:\\5_NN_model\\vgg16_IMAGENET1K_V1_pth")

    其中,vgg16_true 为代码中的模型名;vgg16_IMAGENET1K_V1_pth 为在指定文件夹中的文件名。

    注意:文件名一般要以 .pth 为后缀!!

  • 保存在当前路径下:torch.save(vgg16_true,"vgg16_IMAGENET1K_V1_pth")

    不用写清楚整个路径名,只需要给出保存之后的文件名即可。

(三) 只保存权重和偏置(官方推荐)

因为占内存相对较小

python 复制代码
torch.save(vgg16_None.state_dict(), 'E:\\5_NN_model\\vgg16_weights_bias.pth')

上述代码实际上是将模型 vgg16_None权重和偏置参数 (即模型的 状态字典 state_dict)保存为一个文件。这里的 state_dict 是一个从参数名称映射到参数张量的字典对象。

具体来说:

  • 模型权重:包括卷积层、全连接层等的权重。
  • 偏置参数:包括卷积层、全连接层等的偏置。
  • 不包括:模型的结构信息(如层的类型、顺序等)和训练时的超参数(如学习率、批次大小等)。

要完全恢复模型,你需要在加载这些权重之前定义一个具有相同结构的模型实例。

(四) 从本地加载模型

注意!!:如果本地文件是个完整的模型,那么就不能只从中加载权重和偏置;否则会报错(相当于:你保存了什么,就只能加载什么)。

  1. 之前保存了完整的模型:

    加载整个模型,包含权重:

    python 复制代码
    vgg16_None_new = torch.load("E:\\5_NN_model\\vgg16_None.pth") # 此代码运行之后,系统发出警告,不过可以忽视
    # 等效于:
    vgg16_None_new = torch.load("E:\\5_NN_model\\vgg16_None.pth", weights_only=False)

    随后使用 print 打印出的会是++整个完整架构++。

  2. 之前只保存了权重和偏置:

    python 复制代码
    vgg16_weights_bias_new = torch.load("E:\\5_NN_model\\vgg16_weights_bias.pth") # 主流
    # 等效于:
    vgg16_weights_bias_new = torch.load("E:\\5_NN_model\\vgg16_weights_bias.pth", weights_only=True)

    随后使用 print 打印出的只是字典形式的权重和偏置。

    如果想要将其应用到模型中,那么就需要先编写模型的架构,或者从官网加载模型,之后导入自己本地的权重和偏置:

    python 复制代码
    model = models.vgg16()  
    model = torch.load_state_dict("E:\\5_NN_model\\vgg16_None2.pth") # load_state_dict 函数
  3. 大坑:

    如果是自己一层一层写的模型,那么从本地导入模型之前,仍然需要将模型的定义再写一遍!!!只不过不需要实例化了:

    python 复制代码
    # 假设在别的py文件中已经保存过CIFAR10_NET_Instance模型了
    
    class CIFAR10_NET(nn.Module):
    	def __init__(self):
    		super(CIFAR10_NET, self).__init__()
    		self.model = nn.Sequential(
    			nn.Conv2d(3, 32, 5, padding=2),  # 输入输出尺寸相同,故根据公式计算出padding的值
    			nn.MaxPool2d(2, 2),
    			nn.Conv2d(32, 32, 5, padding=2),
    			nn.MaxPool2d(2, 2),
    			nn.Conv2d(32, 64, 5, padding=2),
    			nn.MaxPool2d(2, 2),
    			nn.Flatten(),
    			nn.Linear(1024, 64),
    			nn.Linear(64, 10)
    		)
    
    	def forward(self, x):
    		x = self.model(x)
    		return x
    
    # 从本地加载完整的模型
    CIFAR10_NET_Instance_new = torch.load("E:\\5_NN_model\\CIFAR10_NET_Instance.pth")
    print(CIFAR10_NET_Instance_new)

    不过可以通过导入该源码,来解决这个问题:

    不过,源码的文件名,只能以字母未开头!!

    python 复制代码
    # 假设该模型定义在了一个名为 nn_loss_network 的文件中
    ...
    from nn_loss_network import * # 这个星号就表示导入该文件中的所有类
    ...
    CIFAR10_NET_Instance_new = torch.load("E:\\5_NN_model\\CIFAR10_NET_Instance.pth")
    print(CIFAR10_NET_Instance_new)

    这样就不会报错了。

(五) 修改VGG16

  • 添加子层

    python 复制代码
    vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1) # 加载原始VGG16模型
    # 想在classifier分类块中添加一个 1000 --> 10 的线性层
    vgg16_true.classifier.add_module('7',nn.Linear(in_features=1000, out_features=10))
    print(vgg16_true)

    运行结果:

    python 复制代码
      ...
      (classifier): Sequential(
        (0): Linear(in_features=25088, out_features=4096, bias=True)
        (1): ReLU(inplace=True)
        (2): Dropout(p=0.5, inplace=False)
        (3): Linear(in_features=4096, out_features=4096, bias=True)
        (4): ReLU(inplace=True)
        (5): Dropout(p=0.5, inplace=False)
        (6): Linear(in_features=4096, out_features=1000, bias=True)
        (7): Linear(in_features=1000, out_features=10, bias=True)
      )
     ...
  • 修改子层

    python 复制代码
    vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1) # 加载原始VGG16模型
    # 想在classifier分类块中将最后一个线性层修改为 1000 --> 10 的线性层
    vgg16_true.classifier[6] = nn.Linear(in_features=1000, out_features=10)
    # 法则:classifier[6]中的6,指的是classifier块中,名为6的子层,其实就是那一层前面的 () 内的内容。
    print(vgg16_true)

    运行结果:

    python 复制代码
      ...
      (classifier): Sequential(
        (0): Linear(in_features=25088, out_features=4096, bias=True)
        (1): ReLU(inplace=True)
        (2): Dropout(p=0.5, inplace=False)
        (3): Linear(in_features=4096, out_features=4096, bias=True)
        (4): ReLU(inplace=True)
        (5): Dropout(p=0.5, inplace=False)
        (6): Linear(in_features=1000, out_features=10, bias=True)
      )
     ...

上一篇 下一篇
神经网络入门实战(十六) 待发布
相关推荐
董董灿是个攻城狮40 分钟前
020:为什么 Resnet 如此重要?
人工智能·计算机视觉·cnn
汪子熙1 小时前
为什么 BERT 仅使用 Transformer 的编码器部分,而不使用解码器部分?
人工智能
AI服务老曹1 小时前
满足不同场景的需求的智慧物流开源了
人工智能·开源·自动化·能源
程序猿阿伟1 小时前
《开源与合作:驱动鸿蒙Next系统中人工智能技术创新发展的双引擎》
人工智能·开源·harmonyos
pchmi1 小时前
C# OpenCV机器视觉:常用滤波算法
人工智能·opencv·c#·机器视觉·中值滤波
茫茫人海一粒沙1 小时前
稀疏检索、密集检索与混合检索:概念、技术与应用
人工智能
BUG制造机.2 小时前
HTTP / 2
网络·网络协议·http
明天不吃。2 小时前
【网络原理】万字详解 HTTP 协议
网络·网络协议·http
代数狂人2 小时前
NPC与AI深度融合结合雷鸟X3Pro AR智能眼镜:引领游戏行业沉浸式与增强现实新纪元的畅想
人工智能·游戏·ar
计算机毕设定制辅导-无忧学长2 小时前
中型项目中 HTTP 的挑战与解决方案
网络·网络协议·http