Pytorch学习--神经网络--网络模型的保存与读取

一、网络模型的保存与读取方式1

方法讲解


保存模型

python 复制代码
import torch
import torchvision
model = torchvision.models.vgg16(weights='DEFAULT')
#保存模型和参数
torch.save(model,"save_method1.pth")

读取模型

python 复制代码
import torch
model = torch.load("save_method1.pth")
print(model)

输出:

比较坑人的点

使用 torch.save 必须将该模型的架构引入到该文件中(可以使用from A import B的方式来解决),这里举一个例子来说明

保存模型

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

#保存模型和参数

class Mary(nn.Module):
    def __init__(self):
        super(Mary,self).__init__()
        self.model1 = nn.Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )
    def forward(self,x):
        x = self.model1(x)
        return x
Yorelee = Mary()
torch.save(Yorelee,"save_method1_question.pth")

读取模型

python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

model = torch.load("save_method1_question.pth")

print(model)

报错如下

说明我们还要把 Mary 这个框架复制到读取模型的.py文件中

重新更正后的读取模型代码

python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

class Mary(nn.Module):
    def __init__(self):
        super(Mary,self).__init__()
        self.model1 = nn.Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )
    def forward(self,x):
        x = self.model1(x)
        return x

model = torch.load("save_method1_question.pth")

print(model)
或者
python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
from torch_save import Mary   #这里仅举一个例子


model = torch.load("save_method1_question.pth")

print(model)

二、网络模型的保存与读取方式2

保存模型参数

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear


vgg_model = torchvision.models.vgg16(weights='DEFAULT')
#保存参数
torch.save(vgg_model.state_dict(),"save_method2.pth")

读取模型参数

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

vgg_model = torchvision.models.vgg16(weights='DEFAULT')
parameter = torch.load("save_method2.pth")
vgg_model.load_state_dict(parameter)
print(vgg_model)
相关推荐
CoovallyAIHub3 分钟前
YOLO11算法深度解析:四大工业场景实战,开源数据集助力AI质检落地
深度学习·算法·计算机视觉
智算菩萨5 分钟前
未来家居可能的新变化:从“智能设备堆叠”到“自适应生活系统”
人工智能·生活
STLearner5 分钟前
AAAI 2026 | 时空数据(Spatial-temporal)论文总结[下](自动驾驶,天气预报,城市科学,POI推荐等)
人工智能·python·深度学习·机器学习·数据挖掘·自动驾驶·智慧城市
后端小张5 分钟前
【AI 学习】LangChain框架深度解析:从核心组件到企业级应用实战
java·人工智能·学习·langchain·tensorflow·gpt-3·ai编程
NAGNIP8 分钟前
LongCat-Flash-Omni:美团的全模态大模型
人工智能
未来之窗软件服务8 分钟前
幽冥大陆(六十二) 多数据库交叉链接系统Go语言—东方仙盟筑基期
数据库·人工智能·oracle·golang·数据库集群·仙盟创梦ide·东方仙盟
Coder个人博客16 分钟前
三大DDS实现对比分析(CycloneDDS/Fast DDS/OpenDDS)
人工智能·自动驾驶·dds
郝学胜-神的一滴17 分钟前
人工智能与机器学习:从理论到实践的技术全景
人工智能·python·程序人生·算法·机器学习
长安牧笛17 分钟前
开发中老年发型设计推荐系统,输入脸型,年龄,推荐适合的发型,提供效果图参考。
python
liangshanbo121518 分钟前
AI给我的调理方案
人工智能·中医调理