Pytorch学习--神经网络--网络模型的保存与读取

一、网络模型的保存与读取方式1

方法讲解


保存模型

python 复制代码
import torch
import torchvision
model = torchvision.models.vgg16(weights='DEFAULT')
#保存模型和参数
torch.save(model,"save_method1.pth")

读取模型

python 复制代码
import torch
model = torch.load("save_method1.pth")
print(model)

输出:

比较坑人的点

使用 torch.save 必须将该模型的架构引入到该文件中(可以使用from A import B的方式来解决),这里举一个例子来说明

保存模型

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

#保存模型和参数

class Mary(nn.Module):
    def __init__(self):
        super(Mary,self).__init__()
        self.model1 = nn.Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )
    def forward(self,x):
        x = self.model1(x)
        return x
Yorelee = Mary()
torch.save(Yorelee,"save_method1_question.pth")

读取模型

python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

model = torch.load("save_method1_question.pth")

print(model)

报错如下

说明我们还要把 Mary 这个框架复制到读取模型的.py文件中

重新更正后的读取模型代码

python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

class Mary(nn.Module):
    def __init__(self):
        super(Mary,self).__init__()
        self.model1 = nn.Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )
    def forward(self,x):
        x = self.model1(x)
        return x

model = torch.load("save_method1_question.pth")

print(model)
或者
python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
from torch_save import Mary   #这里仅举一个例子


model = torch.load("save_method1_question.pth")

print(model)

二、网络模型的保存与读取方式2

保存模型参数

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear


vgg_model = torchvision.models.vgg16(weights='DEFAULT')
#保存参数
torch.save(vgg_model.state_dict(),"save_method2.pth")

读取模型参数

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

vgg_model = torchvision.models.vgg16(weights='DEFAULT')
parameter = torch.load("save_method2.pth")
vgg_model.load_state_dict(parameter)
print(vgg_model)
相关推荐
wanzhong233313 分钟前
ArcGIS学习-20 实战-地形研究
学习
研梦非凡15 分钟前
CVPR 2025|基于视觉语言模型的零样本3D视觉定位
人工智能·深度学习·计算机视觉·3d·ai·语言模型·自然语言处理
wanzhong233318 分钟前
ArcGIS学习-20 实战-县域水文分析
学习·arcgis
Monkey的自我迭代19 分钟前
多目标轮廓匹配
人工智能·opencv·计算机视觉
每日新鲜事20 分钟前
Saucony索康尼推出全新 WOOOLLY 运动生活羊毛系列 生动无理由,从专业跑步延展运动生活的每一刻
大数据·人工智能
空白到白25 分钟前
机器学习-聚类
人工智能·算法·机器学习·聚类
小马学嵌入式~28 分钟前
嵌入式 SQLite 数据库开发笔记
linux·c语言·数据库·笔记·sql·学习·sqlite
索迪迈科技31 分钟前
java后端工程师进修ing(研一版 || day40)
java·开发语言·学习·算法
中新赛克43 分钟前
双引擎驱动!中新赛克AI安全方案入选网安创新大赛优胜榜单
人工智能·安全
飞哥数智坊1 小时前
解决AI幻觉,只能死磕模型?OpenAI给出不一样的思路
人工智能·openai