现有模型的保存与加载(PyTorch版)

我们以VGG16网络为例,来说明现有模型的保存与加载操作。

保存与加载方式均有两种,接下来我们分别来学习这两种方式。注意:保存与加载不在同一个py文件中,我们设定保存操作在save.py文件中,而加载操作在load.py文件中。

保存模型的两种方式如下代码所示,第一种为既保存模型结构,又保存模型参数;第二种只保存模型参数,并且以字典的形式保存。

python 复制代码
import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1,模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth") # 保存路径:vgg16_method1.pth

# 保存方式2,模型参数
torch.save(vgg16.state_dict(), "vgg16_method2.pth") # 保存路径:vgg16_method2.pth

加载模型的两种方式如下代码所示。

python 复制代码
import torch

# 方式1 --》保存方式1,加载模型
model1 = torch.load("vgg16_method1.pth")
print(model1)

# 方式2 --》保存方式2,加载模型
model2 = torch.load("vgg16_method2.pth")
print(model2)

打印结果为:

model1结果为:

python 复制代码
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

上述结果是VGG16网络模型结构以及网络模型参数。

model2结果为:(由于结果太多,故只给出部分结果)

python 复制代码
OrderedDict([('features.0.weight', tensor([[[[-0.1638, -0.0292,  0.0316],
          [-0.0149,  0.0681,  0.0458],
          [ 0.0633, -0.0374, -0.0047]],

         [[-0.0123, -0.0461,  0.0343],
          [ 0.0207, -0.0128,  0.0107],
          [-0.0181,  0.0154,  0.0320]],

         [[-0.0759, -0.1384, -0.0318],
          [ 0.0244, -0.0424,  0.0332],
          [-0.0244,  0.0524,  0.1292]]],
..........................................

上述结果是VGG16网络模型参数。

那我们要是想用通过保存方式2所保存的模型参数,该如何使用呢?请看下面代码。

我们先搭建出网络模型结构,随后将保存好的网络模型参数加载到网络模型结构中去。

python 复制代码
vgg16 = torchvision.models.vgg16(pretrained=False) # 网络模型结构
vgg16.load_state_dict(torch.load("vgg16_method2.pth")) # 加载保存的网络模型参数
print(vgg16)

保存方式1有一个小小的陷阱。

我们通过自己搭建一个网络来说明这个陷阱。

我们在save.py文件中搭建我们的网络结构,并将其保存。

python 复制代码
# 陷阱1
class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=5)
    def forward(self, x):
        x = self.conv1(x)
        return x


tudui = Tudui()
torch.save(tudui, "tudui_method1.pth")

接下来我们按照加载方式1的方法在load.py文件中加载这个模型。

python 复制代码
# 陷阱1
model = torch.load("tudui_method1.pth")
print(model)

打印结果为:

python 复制代码
AttributeError: Can't get attribute 'Tudui' on <module '__main__' from 'D:/graduate0/pytorch_practice/model_load.py'>

我们发现报错了,错误的原因是不能得到Tudui这个属性。

我们把网络结构添加在load.py文件中。注意:此时不需要创建网络模型,即不用运行tudui=Tudui()这句代码。

python 复制代码
from torch import nn
class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=5)
    def forward(self, x):
        x = self.conv1(x)
        return x


# tudui = Tudui()
model = torch.load("tudui_method1.pth")
print(model)

打印结果为:

python 复制代码
Tudui(
  (conv1): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1))
)
相关推荐
Network_Engineer几秒前
从零手写LSTM:从门控原理到PyTorch源码级实现
人工智能·pytorch·lstm
芝士爱知识a4 分钟前
AlphaGBM 深度解析:下一代基于 AI 与蒙特卡洛的智能期权分析平台
数据结构·人工智能·python·股票·alphagbm·ai 驱动的智能期权分析·期权
weixin_6684 分钟前
GitHub 2026年AI项目热度分析报告-AI分析-分享
人工智能·github
vlln6 分钟前
【论文速读】达尔文哥德尔机 (Darwin Gödel Machine): 自进化智能体的开放式演化
人工智能·深度学习·ai agent
Katecat996639 分钟前
目标检测咖啡果实成熟度检测:RetinaNet-X101模型实现
人工智能·目标检测·目标跟踪
AAD5558889911 分钟前
基于Mask_RCNN的猫科动物目标检测识别模型实现与分析
人工智能·目标检测·计算机视觉
Katecat9966315 分钟前
基于YOLOv8和MAFPN的骆驼目标检测系统实现
人工智能·yolo·目标检测
合力亿捷-小亿18 分钟前
2026年AI语音机器人测评推荐:复杂噪声环境下语义识别准确率对比分析
人工智能·机器人
子夜江寒18 分钟前
基于 LSTM 的中文情感分类项目解析
人工智能·分类·lstm