Pytorch学习--神经网络--网络模型的保存与读取

一、网络模型的保存与读取方式1

方法讲解


保存模型

python 复制代码
import torch
import torchvision
model = torchvision.models.vgg16(weights='DEFAULT')
#保存模型和参数
torch.save(model,"save_method1.pth")

读取模型

python 复制代码
import torch
model = torch.load("save_method1.pth")
print(model)

输出:

比较坑人的点

使用 torch.save 必须将该模型的架构引入到该文件中(可以使用from A import B的方式来解决),这里举一个例子来说明

保存模型

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

#保存模型和参数

class Mary(nn.Module):
    def __init__(self):
        super(Mary,self).__init__()
        self.model1 = nn.Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )
    def forward(self,x):
        x = self.model1(x)
        return x
Yorelee = Mary()
torch.save(Yorelee,"save_method1_question.pth")

读取模型

python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

model = torch.load("save_method1_question.pth")

print(model)

报错如下

说明我们还要把 Mary 这个框架复制到读取模型的.py文件中

重新更正后的读取模型代码

python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

class Mary(nn.Module):
    def __init__(self):
        super(Mary,self).__init__()
        self.model1 = nn.Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )
    def forward(self,x):
        x = self.model1(x)
        return x

model = torch.load("save_method1_question.pth")

print(model)
或者
python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
from torch_save import Mary   #这里仅举一个例子


model = torch.load("save_method1_question.pth")

print(model)

二、网络模型的保存与读取方式2

保存模型参数

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear


vgg_model = torchvision.models.vgg16(weights='DEFAULT')
#保存参数
torch.save(vgg_model.state_dict(),"save_method2.pth")

读取模型参数

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

vgg_model = torchvision.models.vgg16(weights='DEFAULT')
parameter = torch.load("save_method2.pth")
vgg_model.load_state_dict(parameter)
print(vgg_model)
相关推荐
Solyn_HAN2 分钟前
Python 生信进阶:Biopython 库完全指南(序列处理 + 数据库交互)
python·生物信息学·biopython
only-code9 分钟前
SeqXGPT:Sentence-Level AI-Generated Text Detection —— 把大模型的“波形”变成测谎仪
人工智能·大语言模型·ai检测·文本检测
九河_12 分钟前
解决pip install gym==0.19.0安装失败问题
开发语言·python·pip·gym
AI科技星14 分钟前
引力编程时代:人类文明存续与升维
数据结构·人工智能·经验分享·算法·计算机视觉
老胡说科技2 小时前
美砺科技谢秀鹏:让“看见”走在“相信”之前,AI驱动下的数字化范式革命,从“技术长征”到“生态协同”
人工智能·科技
iamohenry2 小时前
古早味的心理咨询聊天机器人
python·自然语言处理
早睡冠军候选人3 小时前
Ansible学习----管理复杂的 Play 和 Playbook 内容
运维·学习·云原生·ansible
LBuffer4 小时前
破解入门学习笔记题四十六
数据库·笔记·学习
endcy20165 小时前
基于Spring AI的RAG和智能体应用实践
人工智能·ai·系统架构