Pytorch学习--神经网络--网络模型的保存与读取

一、网络模型的保存与读取方式1

方法讲解


保存模型

python 复制代码
import torch
import torchvision
model = torchvision.models.vgg16(weights='DEFAULT')
#保存模型和参数
torch.save(model,"save_method1.pth")

读取模型

python 复制代码
import torch
model = torch.load("save_method1.pth")
print(model)

输出:

比较坑人的点

使用 torch.save 必须将该模型的架构引入到该文件中(可以使用from A import B的方式来解决),这里举一个例子来说明

保存模型

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

#保存模型和参数

class Mary(nn.Module):
    def __init__(self):
        super(Mary,self).__init__()
        self.model1 = nn.Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )
    def forward(self,x):
        x = self.model1(x)
        return x
Yorelee = Mary()
torch.save(Yorelee,"save_method1_question.pth")

读取模型

python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

model = torch.load("save_method1_question.pth")

print(model)

报错如下

说明我们还要把 Mary 这个框架复制到读取模型的.py文件中

重新更正后的读取模型代码

python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

class Mary(nn.Module):
    def __init__(self):
        super(Mary,self).__init__()
        self.model1 = nn.Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )
    def forward(self,x):
        x = self.model1(x)
        return x

model = torch.load("save_method1_question.pth")

print(model)
或者
python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
from torch_save import Mary   #这里仅举一个例子


model = torch.load("save_method1_question.pth")

print(model)

二、网络模型的保存与读取方式2

保存模型参数

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear


vgg_model = torchvision.models.vgg16(weights='DEFAULT')
#保存参数
torch.save(vgg_model.state_dict(),"save_method2.pth")

读取模型参数

python 复制代码
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

vgg_model = torchvision.models.vgg16(weights='DEFAULT')
parameter = torch.load("save_method2.pth")
vgg_model.load_state_dict(parameter)
print(vgg_model)
相关推荐
科技小花3 小时前
全球化深水区,数据治理成为企业出海 “核心竞争力”
大数据·数据库·人工智能·数据治理·数据中台·全球化
X56614 小时前
如何在 Laravel 中正确保存嵌套动态表单数据(主服务与子服务)
jvm·数据库·python
zhuiyisuifeng4 小时前
2026前瞻:GPTimage2镜像官网或将颠覆视觉创作
人工智能·gpt
徐健峰4 小时前
GPT-image-2 热门玩法实战(一):AI 看手相 — 一张手掌照片生成专业手相分析图
人工智能·gpt
weixin_370976354 小时前
AI的终极赛跑:进入AGI,还是泡沫破灭?
大数据·人工智能·agi
Slow菜鸟4 小时前
AI学习篇(五) | awesome-design-md 使用说明
人工智能·学习
ZhengEnCi4 小时前
03ab-PyTorch安装教程 📚
python
冬奇Lab5 小时前
RAG 系列(五):Embedding 模型——语义理解的核心
人工智能·llm·aigc
深小乐5 小时前
AI 周刊【2026.04.27-05.03】:Anthropic 9000亿美元估值、英伟达死磕智能体、中央重磅定调AI
人工智能
码点滴5 小时前
什么时候用 DeepSeek V4,而不是 GPT-5/Claude/Gemini?
人工智能·gpt·架构·大模型·deepseek