神经网络入门实战:(十七)VGG16模型的加载、保存与修改,可以应用到其他网络模型上

VGG16模型的加载、保存与修改

该模型主要是用来 识别 ImageNet 数据集的。

16 的由来:该模型包含16个权重层(13个卷积层和3个全连接层),因此得名VGG16。

模型官网:vgg16 --- Torchvision 0.20 documentation

(一) 从官网加载 VGG16 模型

① 在 pycharm 中加载官网 ++已经训练好++ 的 VGG16 模型的指令:

  • 使用该网络模型进行 ImageNet 的 1000 种类别的 分类任务(完整权重)

    python 复制代码
    vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)

    其中参数 weights 指定了要加载的预训练权重;V1 通常指的是模型的第一个版本或第一个公开发布的权重集。

    2024年目前情况下,这个参数是默认的,即 weights 可直接指定为 DEFAULT :

    python 复制代码
    weights=torchvision.models.VGG16_Weights.DEFAULT 

    不过 DEFAULT 一般都会指向新的权重版本,所以当官方更新最新权重之后(即 IMAGENET1K_V1 不再是最新的),就要慎重使用此命令。

  • 仅提供在 特征提取阶段 训练过的权重(只有 13 个卷积层训练好的权重):

    python 复制代码
    vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_FEATURES)

    这些权重可能缺少分类器模块(即全连接层)的权重。

加载好的VGG16模型参数如下:

python 复制代码
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

② 加载未经训练的 VGG16 模型:

python 复制代码
vgg16_true = torchvision.models.vgg16(weights=None)

这意味着模型的所有参数都将被自动初始化,而不是使用在大型数据集上训练好的权重。通常是:

  • 权重:通常是随机初始化,如均匀分布或正态分布等,在某些情况下,还可以自定义初始化方法。
  • 偏置:默认情况下,如果没有特别指定,偏置通常会被初始化为零(0),不过这也可以自定义。

(二) 保存完整模型

加载完成后,模型的权重会直接加载到内存中。 在加载模型的代码后面写保存模型的代码:

  • 如果没有指定路径,那么整个模型会默认放到:C:\Users\用户名/.cache\torch\hub\checkpoints\vgg16-397923af.pth

  • 指定路径保存torch.save(vgg16_true,"E:\\5_NN_model\\vgg16_IMAGENET1K_V1_pth")

    其中,vgg16_true 为代码中的模型名;vgg16_IMAGENET1K_V1_pth 为在指定文件夹中的文件名。

    注意:文件名一般要以 .pth 为后缀!!

  • 保存在当前路径下:torch.save(vgg16_true,"vgg16_IMAGENET1K_V1_pth")

    不用写清楚整个路径名,只需要给出保存之后的文件名即可。

(三) 只保存权重和偏置(官方推荐)

因为占内存相对较小

python 复制代码
torch.save(vgg16_None.state_dict(), 'E:\\5_NN_model\\vgg16_weights_bias.pth')

上述代码实际上是将模型 vgg16_None权重和偏置参数 (即模型的 状态字典 state_dict)保存为一个文件。这里的 state_dict 是一个从参数名称映射到参数张量的字典对象。

具体来说:

  • 模型权重:包括卷积层、全连接层等的权重。
  • 偏置参数:包括卷积层、全连接层等的偏置。
  • 不包括:模型的结构信息(如层的类型、顺序等)和训练时的超参数(如学习率、批次大小等)。

要完全恢复模型,你需要在加载这些权重之前定义一个具有相同结构的模型实例。

(四) 从本地加载模型

注意!!:如果本地文件是个完整的模型,那么就不能只从中加载权重和偏置;否则会报错(相当于:你保存了什么,就只能加载什么)。

  1. 之前保存了完整的模型:

    加载整个模型,包含权重:

    python 复制代码
    vgg16_None_new = torch.load("E:\\5_NN_model\\vgg16_None.pth") # 此代码运行之后,系统发出警告,不过可以忽视
    # 等效于:
    vgg16_None_new = torch.load("E:\\5_NN_model\\vgg16_None.pth", weights_only=False)

    随后使用 print 打印出的会是++整个完整架构++。

  2. 之前只保存了权重和偏置:

    python 复制代码
    vgg16_weights_bias_new = torch.load("E:\\5_NN_model\\vgg16_weights_bias.pth") # 主流
    # 等效于:
    vgg16_weights_bias_new = torch.load("E:\\5_NN_model\\vgg16_weights_bias.pth", weights_only=True)

    随后使用 print 打印出的只是字典形式的权重和偏置。

    如果想要将其应用到模型中,那么就需要先编写模型的架构,或者从官网加载模型,之后导入自己本地的权重和偏置:

    python 复制代码
    model = models.vgg16()  
    model = torch.load_state_dict("E:\\5_NN_model\\vgg16_None2.pth") # load_state_dict 函数
  3. 大坑:

    如果是自己一层一层写的模型,那么从本地导入模型之前,仍然需要将模型的定义再写一遍!!!只不过不需要实例化了:

    python 复制代码
    # 假设在别的py文件中已经保存过CIFAR10_NET_Instance模型了
    
    class CIFAR10_NET(nn.Module):
    	def __init__(self):
    		super(CIFAR10_NET, self).__init__()
    		self.model = nn.Sequential(
    			nn.Conv2d(3, 32, 5, padding=2),  # 输入输出尺寸相同,故根据公式计算出padding的值
    			nn.MaxPool2d(2, 2),
    			nn.Conv2d(32, 32, 5, padding=2),
    			nn.MaxPool2d(2, 2),
    			nn.Conv2d(32, 64, 5, padding=2),
    			nn.MaxPool2d(2, 2),
    			nn.Flatten(),
    			nn.Linear(1024, 64),
    			nn.Linear(64, 10)
    		)
    
    	def forward(self, x):
    		x = self.model(x)
    		return x
    
    # 从本地加载完整的模型
    CIFAR10_NET_Instance_new = torch.load("E:\\5_NN_model\\CIFAR10_NET_Instance.pth")
    print(CIFAR10_NET_Instance_new)

    不过可以通过导入该源码,来解决这个问题:

    不过,源码的文件名,只能以字母未开头!!

    python 复制代码
    # 假设该模型定义在了一个名为 nn_loss_network 的文件中
    ...
    from nn_loss_network import * # 这个星号就表示导入该文件中的所有类
    ...
    CIFAR10_NET_Instance_new = torch.load("E:\\5_NN_model\\CIFAR10_NET_Instance.pth")
    print(CIFAR10_NET_Instance_new)

    这样就不会报错了。

(五) 修改VGG16

  • 添加子层

    python 复制代码
    vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1) # 加载原始VGG16模型
    # 想在classifier分类块中添加一个 1000 --> 10 的线性层
    vgg16_true.classifier.add_module('7',nn.Linear(in_features=1000, out_features=10))
    print(vgg16_true)

    运行结果:

    python 复制代码
      ...
      (classifier): Sequential(
        (0): Linear(in_features=25088, out_features=4096, bias=True)
        (1): ReLU(inplace=True)
        (2): Dropout(p=0.5, inplace=False)
        (3): Linear(in_features=4096, out_features=4096, bias=True)
        (4): ReLU(inplace=True)
        (5): Dropout(p=0.5, inplace=False)
        (6): Linear(in_features=4096, out_features=1000, bias=True)
        (7): Linear(in_features=1000, out_features=10, bias=True)
      )
     ...
  • 修改子层

    python 复制代码
    vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1) # 加载原始VGG16模型
    # 想在classifier分类块中将最后一个线性层修改为 1000 --> 10 的线性层
    vgg16_true.classifier[6] = nn.Linear(in_features=1000, out_features=10)
    # 法则:classifier[6]中的6,指的是classifier块中,名为6的子层,其实就是那一层前面的 () 内的内容。
    print(vgg16_true)

    运行结果:

    python 复制代码
      ...
      (classifier): Sequential(
        (0): Linear(in_features=25088, out_features=4096, bias=True)
        (1): ReLU(inplace=True)
        (2): Dropout(p=0.5, inplace=False)
        (3): Linear(in_features=4096, out_features=4096, bias=True)
        (4): ReLU(inplace=True)
        (5): Dropout(p=0.5, inplace=False)
        (6): Linear(in_features=1000, out_features=10, bias=True)
      )
     ...

上一篇 下一篇
神经网络入门实战(十六) 待发布
相关推荐
_Itachi__2 分钟前
Model.eval() 与 torch.no_grad() PyTorch 中的区别与应用
人工智能·pytorch·python
白光白光25 分钟前
大语言模型训练的两个阶段
人工智能·机器学习·语言模型
巷9551 小时前
OpenCV图像金字塔详解:原理、实现与应用
人工智能·opencv·计算机视觉
科技小E1 小时前
WebRTC实时音视频通话技术EasyRTC嵌入式音视频通信SDK,助力智慧物流打造实时高效的物流管理体系
人工智能·音视频
BioRunYiXue1 小时前
一文了解氨基酸的分类、代谢和应用
人工智能·深度学习·算法·机器学习·分类·数据挖掘·代谢组学
firshman_start1 小时前
第十五章,SSL VPN
网络
Johnstons1 小时前
AnaTraf:深度解析网络性能分析(NPM)
前端·网络·安全·web安全·npm·网络流量监控·网络流量分析
落——枫1 小时前
路由交换实验
网络
Johny_Zhao2 小时前
K8S+nginx+MYSQL+TOMCAT高可用架构企业自建网站
linux·网络·mysql·nginx·网络安全·信息安全·tomcat·云计算·shell·yum源·系统运维·itsm
小诸葛的博客2 小时前
华为ensp实现跨vlan通信
网络·华为·智能路由器