Pytorch学习--神经网络--网络模型的保存与读取

一、网络模型的保存与读取方式1

方法讲解


保存模型

python 复制代码
import torch
import torchvision
model = torchvision.models.vgg16(weights='DEFAULT')
#保存模型和参数
torch.save(model,"save_method1.pth")

读取模型

python 复制代码
import torch
model = torch.load("save_method1.pth")
print(model)

输出:

比较坑人的点

使用 torch.save 必须将该模型的架构引入到该文件中(可以使用from A import B的方式来解决),这里举一个例子来说明

保存模型

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

#保存模型和参数

class Mary(nn.Module):
    def __init__(self):
        super(Mary,self).__init__()
        self.model1 = nn.Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )
    def forward(self,x):
        x = self.model1(x)
        return x
Yorelee = Mary()
torch.save(Yorelee,"save_method1_question.pth")

读取模型

python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

model = torch.load("save_method1_question.pth")

print(model)

报错如下

说明我们还要把 Mary 这个框架复制到读取模型的.py文件中

重新更正后的读取模型代码

python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

class Mary(nn.Module):
    def __init__(self):
        super(Mary,self).__init__()
        self.model1 = nn.Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )
    def forward(self,x):
        x = self.model1(x)
        return x

model = torch.load("save_method1_question.pth")

print(model)
或者
python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
from torch_save import Mary   #这里仅举一个例子


model = torch.load("save_method1_question.pth")

print(model)

二、网络模型的保存与读取方式2

保存模型参数

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear


vgg_model = torchvision.models.vgg16(weights='DEFAULT')
#保存参数
torch.save(vgg_model.state_dict(),"save_method2.pth")

读取模型参数

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

vgg_model = torchvision.models.vgg16(weights='DEFAULT')
parameter = torch.load("save_method2.pth")
vgg_model.load_state_dict(parameter)
print(vgg_model)
相关推荐
幼稚园的山代王5 分钟前
Prompt Enginering(提示工程)先进技术
java·人工智能·ai·chatgpt·langchain·prompt
dfsj660116 分钟前
LLMs 系列科普文(14)
人工智能·深度学习·算法
摘取一颗天上星️12 分钟前
深入解析机器学习的心脏:损失函数及其背后的奥秘
人工智能·深度学习·机器学习·损失函数·梯度下降
远方160919 分钟前
20-Oracle 23 ai free Database Sharding-特性验证
数据库·人工智能·oracle
znhy605821 分钟前
智能终端与边缘计算按章复习
人工智能·边缘计算
__Benco27 分钟前
OpenHarmony平台驱动使用(十五),SPI
人工智能·驱动开发·harmonyos
Listennnn28 分钟前
AI系统的构建
人工智能·系统架构
新智元31 分钟前
全球 30 名顶尖数学家秘密集会围剿 AI,当场破防!惊呼已接近数学天才
人工智能·openai
nenchoumi311931 分钟前
AirSim/Cosys-AirSim 游戏开发(一)XBox 手柄 Windows + python 连接与读取
windows·python·xbox
GoodStudyAndDayDayUp32 分钟前
初入 python Django 框架总结
数据库·python·django