神经网络入门实战:(十七)VGG16模型的加载、保存与修改,可以应用到其他网络模型上

VGG16模型的加载、保存与修改

该模型主要是用来 识别 ImageNet 数据集的。

16 的由来:该模型包含16个权重层(13个卷积层和3个全连接层),因此得名VGG16。

模型官网:vgg16 --- Torchvision 0.20 documentation

(一) 从官网加载 VGG16 模型

① 在 pycharm 中加载官网 ++已经训练好++ 的 VGG16 模型的指令:

  • 使用该网络模型进行 ImageNet 的 1000 种类别的 分类任务(完整权重)

    python 复制代码
    vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)

    其中参数 weights 指定了要加载的预训练权重;V1 通常指的是模型的第一个版本或第一个公开发布的权重集。

    2024年目前情况下,这个参数是默认的,即 weights 可直接指定为 DEFAULT :

    python 复制代码
    weights=torchvision.models.VGG16_Weights.DEFAULT 

    不过 DEFAULT 一般都会指向新的权重版本,所以当官方更新最新权重之后(即 IMAGENET1K_V1 不再是最新的),就要慎重使用此命令。

  • 仅提供在 特征提取阶段 训练过的权重(只有 13 个卷积层训练好的权重):

    python 复制代码
    vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_FEATURES)

    这些权重可能缺少分类器模块(即全连接层)的权重。

加载好的VGG16模型参数如下:

python 复制代码
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

② 加载未经训练的 VGG16 模型:

python 复制代码
vgg16_true = torchvision.models.vgg16(weights=None)

这意味着模型的所有参数都将被自动初始化,而不是使用在大型数据集上训练好的权重。通常是:

  • 权重:通常是随机初始化,如均匀分布或正态分布等,在某些情况下,还可以自定义初始化方法。
  • 偏置:默认情况下,如果没有特别指定,偏置通常会被初始化为零(0),不过这也可以自定义。

(二) 保存完整模型

加载完成后,模型的权重会直接加载到内存中。 在加载模型的代码后面写保存模型的代码:

  • 如果没有指定路径,那么整个模型会默认放到:C:\Users\用户名/.cache\torch\hub\checkpoints\vgg16-397923af.pth

  • 指定路径保存torch.save(vgg16_true,"E:\\5_NN_model\\vgg16_IMAGENET1K_V1_pth")

    其中,vgg16_true 为代码中的模型名;vgg16_IMAGENET1K_V1_pth 为在指定文件夹中的文件名。

    注意:文件名一般要以 .pth 为后缀!!

  • 保存在当前路径下:torch.save(vgg16_true,"vgg16_IMAGENET1K_V1_pth")

    不用写清楚整个路径名,只需要给出保存之后的文件名即可。

(三) 只保存权重和偏置(官方推荐)

因为占内存相对较小

python 复制代码
torch.save(vgg16_None.state_dict(), 'E:\\5_NN_model\\vgg16_weights_bias.pth')

上述代码实际上是将模型 vgg16_None权重和偏置参数 (即模型的 状态字典 state_dict)保存为一个文件。这里的 state_dict 是一个从参数名称映射到参数张量的字典对象。

具体来说:

  • 模型权重:包括卷积层、全连接层等的权重。
  • 偏置参数:包括卷积层、全连接层等的偏置。
  • 不包括:模型的结构信息(如层的类型、顺序等)和训练时的超参数(如学习率、批次大小等)。

要完全恢复模型,你需要在加载这些权重之前定义一个具有相同结构的模型实例。

(四) 从本地加载模型

注意!!:如果本地文件是个完整的模型,那么就不能只从中加载权重和偏置;否则会报错(相当于:你保存了什么,就只能加载什么)。

  1. 之前保存了完整的模型:

    加载整个模型,包含权重:

    python 复制代码
    vgg16_None_new = torch.load("E:\\5_NN_model\\vgg16_None.pth") # 此代码运行之后,系统发出警告,不过可以忽视
    # 等效于:
    vgg16_None_new = torch.load("E:\\5_NN_model\\vgg16_None.pth", weights_only=False)

    随后使用 print 打印出的会是++整个完整架构++。

  2. 之前只保存了权重和偏置:

    python 复制代码
    vgg16_weights_bias_new = torch.load("E:\\5_NN_model\\vgg16_weights_bias.pth") # 主流
    # 等效于:
    vgg16_weights_bias_new = torch.load("E:\\5_NN_model\\vgg16_weights_bias.pth", weights_only=True)

    随后使用 print 打印出的只是字典形式的权重和偏置。

    如果想要将其应用到模型中,那么就需要先编写模型的架构,或者从官网加载模型,之后导入自己本地的权重和偏置:

    python 复制代码
    model = models.vgg16()  
    model = torch.load_state_dict("E:\\5_NN_model\\vgg16_None2.pth") # load_state_dict 函数
  3. 大坑:

    如果是自己一层一层写的模型,那么从本地导入模型之前,仍然需要将模型的定义再写一遍!!!只不过不需要实例化了:

    python 复制代码
    # 假设在别的py文件中已经保存过CIFAR10_NET_Instance模型了
    
    class CIFAR10_NET(nn.Module):
    	def __init__(self):
    		super(CIFAR10_NET, self).__init__()
    		self.model = nn.Sequential(
    			nn.Conv2d(3, 32, 5, padding=2),  # 输入输出尺寸相同,故根据公式计算出padding的值
    			nn.MaxPool2d(2, 2),
    			nn.Conv2d(32, 32, 5, padding=2),
    			nn.MaxPool2d(2, 2),
    			nn.Conv2d(32, 64, 5, padding=2),
    			nn.MaxPool2d(2, 2),
    			nn.Flatten(),
    			nn.Linear(1024, 64),
    			nn.Linear(64, 10)
    		)
    
    	def forward(self, x):
    		x = self.model(x)
    		return x
    
    # 从本地加载完整的模型
    CIFAR10_NET_Instance_new = torch.load("E:\\5_NN_model\\CIFAR10_NET_Instance.pth")
    print(CIFAR10_NET_Instance_new)

    不过可以通过导入该源码,来解决这个问题:

    不过,源码的文件名,只能以字母未开头!!

    python 复制代码
    # 假设该模型定义在了一个名为 nn_loss_network 的文件中
    ...
    from nn_loss_network import * # 这个星号就表示导入该文件中的所有类
    ...
    CIFAR10_NET_Instance_new = torch.load("E:\\5_NN_model\\CIFAR10_NET_Instance.pth")
    print(CIFAR10_NET_Instance_new)

    这样就不会报错了。

(五) 修改VGG16

  • 添加子层

    python 复制代码
    vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1) # 加载原始VGG16模型
    # 想在classifier分类块中添加一个 1000 --> 10 的线性层
    vgg16_true.classifier.add_module('7',nn.Linear(in_features=1000, out_features=10))
    print(vgg16_true)

    运行结果:

    python 复制代码
      ...
      (classifier): Sequential(
        (0): Linear(in_features=25088, out_features=4096, bias=True)
        (1): ReLU(inplace=True)
        (2): Dropout(p=0.5, inplace=False)
        (3): Linear(in_features=4096, out_features=4096, bias=True)
        (4): ReLU(inplace=True)
        (5): Dropout(p=0.5, inplace=False)
        (6): Linear(in_features=4096, out_features=1000, bias=True)
        (7): Linear(in_features=1000, out_features=10, bias=True)
      )
     ...
  • 修改子层

    python 复制代码
    vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1) # 加载原始VGG16模型
    # 想在classifier分类块中将最后一个线性层修改为 1000 --> 10 的线性层
    vgg16_true.classifier[6] = nn.Linear(in_features=1000, out_features=10)
    # 法则:classifier[6]中的6,指的是classifier块中,名为6的子层,其实就是那一层前面的 () 内的内容。
    print(vgg16_true)

    运行结果:

    python 复制代码
      ...
      (classifier): Sequential(
        (0): Linear(in_features=25088, out_features=4096, bias=True)
        (1): ReLU(inplace=True)
        (2): Dropout(p=0.5, inplace=False)
        (3): Linear(in_features=4096, out_features=4096, bias=True)
        (4): ReLU(inplace=True)
        (5): Dropout(p=0.5, inplace=False)
        (6): Linear(in_features=1000, out_features=10, bias=True)
      )
     ...

上一篇 下一篇
神经网络入门实战(十六) 待发布
相关推荐
低头不见1 小时前
tcp的粘包拆包问题,如何解决?
网络·网络协议·tcp/ip
羑悻的小杀马特1 小时前
OpenCV 引擎:驱动实时应用开发的科技狂飙
人工智能·科技·opencv·计算机视觉
SKYDROID云卓小助手2 小时前
三轴云台之相机技术篇
运维·服务器·网络·数码相机·音视频
guanshiyishi4 小时前
ABeam 德硕 | 中国汽车市场(2)——新能源车的崛起与中国汽车市场机遇与挑战
人工智能
极客天成ScaleFlash4 小时前
极客天成NVFile:无缓存直击存储性能天花板,重新定义AI时代并行存储新范式
人工智能·缓存
yuzhangfeng5 小时前
【云计算物理网络】从传统网络到SDN:云计算的网络演进之路
网络·云计算
TDengine (老段)5 小时前
TDengine 中的关联查询
大数据·javascript·网络·物联网·时序数据库·tdengine·iotdb
zhu12893035565 小时前
网络安全的现状与防护措施
网络·安全·web安全
澳鹏Appen6 小时前
AI安全:构建负责任且可靠的系统
人工智能·安全
蹦蹦跳跳真可爱5896 小时前
Python----机器学习(KNN:使用数学方法实现KNN)
人工智能·python·机器学习