PyTorch中保存模型的两种方式

文章目录

  • [一、状态字典(State Dictionary)](#一、状态字典(State Dictionary))
  • [二、序列化模型(Serialized Model)](#二、序列化模型(Serialized Model))
  • 三、示例代码

一、状态字典(State Dictionary)

这种保存形式将模型的参数保存为一个字典,其中包含了所有模型的权重和偏置等参数。状态字典保存了模型在训练过程中学到的参数值,而不包含模型的结构。可以使用这个字典来加载模型的参数,并将其应用于相同结构的模型。

在 PyTorch 中,您可以使用 torch.save() 函数将模型的状态字典保存到文件中,例如:

python 复制代码
torch.save(model.state_dict(), 'model.pth')

然后,可以使用 torch.load() 函数加载状态字典并将其应用于相同结构的模型:

python 复制代码
model = MyModel()  # 创建模型对象
model.load_state_dict(torch.load('model.pth'))

这种保存形式非常适用于仅保存和加载模型的参数,而不需要保存和加载模型的结构。

二、序列化模型(Serialized Model)

这种保存形式将整个模型(包括模型的结构、参数等)保存为一个文件。序列化模型保存了模型的完整信息,可以完全恢复模型的状态,包括模型的结构、权重、偏置以及其他相关参数。

在 PyTorch 中,您可以使用 torch.save() 函数直接保存整个模型对象,例如:

python 复制代码
torch.save(model, 'model.pth')

然后,您可以使用 torch.load() 函数加载整个序列化模型:

python 复制代码
model = torch.load('model.pth')

这种保存形式适用于需要保存和加载完整模型信息的情况,包括模型的结构和参数。

三、示例代码

python 复制代码
import torch

class LinearNet(torch.nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_features=input_size, out_features= 5, bias=True),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features= 5, out_features=5, bias=True),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=5, out_features=output_size, bias=True)
        )

    def forward(self,x):
        return self.net(x)

square_net = LinearNet(1,1)

# square_net.load_state_dict(torch.load('weight.pth'))  #直接加载已经训练好的权重

if __name__ == '__main__':

    # print(square_net(torch.tensor([3.16],dtype=torch.float32)))
    # save 方式1
    torch.save(square_net.state_dict(), "./w1.pth")
    my_state_dict = torch.load("./w1.pth")
    print("纯state_dict:\n", my_state_dict)
    print("type:", type(my_state_dict))

    # save 方式2
    torch.save(square_net, "./w2.pth")
    my_state_dict = torch.load("./w2.pth")
    print("\n\n模型结构:\n", my_state_dict)
    print("type:", type(my_state_dict))


    # 执行结果
    '''
    纯state_dict:
    OrderedDict([('net.0.weight', tensor([[ 0.0820],
            [-0.6923],
            [ 0.5066],
            [-0.8931],
            [ 0.0460]])), ('net.0.bias', tensor([ 0.1455,  0.5106,  0.2347,  0.4903, -0.6838])), ('net.2.weight', tensor([[-0.4055, -0.2721,  0.3770, -0.2285,  0.3025],
            [-0.0416,  0.0133, -0.3834, -0.2151,  0.1454],
            [ 0.0749, -0.3664, -0.1901, -0.2829,  0.3957],
            [-0.3567,  0.2668,  0.3343, -0.3351, -0.3808],
            [ 0.4375,  0.1000,  0.1185,  0.2295, -0.3997]])), ('net.2.bias', tensor([-0.2405, -0.2751,  0.1928,  0.3970, -0.0005])), ('net.4.weight', tensor([[-0.4388, -0.2654,  0.3038,  0.2008,  0.0381]])), ('net.4.bias', tensor([0.1847]))])


    模型结构:
    LinearNet(
    (net): Sequential(
        (0): Linear(in_features=1, out_features=5, bias=True)
        (1): Sigmoid()
        (2): Linear(in_features=5, out_features=5, bias=True)
        (3): Sigmoid()
        (4): Linear(in_features=5, out_features=1, bias=True)
    )
    )
    '''
相关推荐
IT古董28 分钟前
【深度学习】常见模型-Transformer模型
人工智能·深度学习·transformer
沐雪架构师1 小时前
AI大模型开发原理篇-2:语言模型雏形之词袋模型
人工智能·语言模型·自然语言处理
python算法(魔法师版)2 小时前
深度学习深度解析:从基础到前沿
人工智能·深度学习
小王子10242 小时前
设计模式Python版 组合模式
python·设计模式·组合模式
kakaZhui3 小时前
【llm对话系统】大模型源码分析之 LLaMA 位置编码 RoPE
人工智能·深度学习·chatgpt·aigc·llama
struggle20253 小时前
一个开源 GenBI AI 本地代理(确保本地数据安全),使数据驱动型团队能够与其数据进行互动,生成文本到 SQL、图表、电子表格、报告和 BI
人工智能·深度学习·目标检测·语言模型·自然语言处理·数据挖掘·集成学习
佛州小李哥4 小时前
通过亚马逊云科技Bedrock打造自定义AI智能体Agent(上)
人工智能·科技·ai·语言模型·云计算·aws·亚马逊云科技
Mason Lin4 小时前
2025年1月22日(网络编程 udp)
网络·python·udp
清弦墨客4 小时前
【蓝桥杯】43697.机器人塔
python·蓝桥杯·程序算法
云空5 小时前
《DeepSeek 网页/API 性能异常(DeepSeek Web/API Degraded Performance):网络安全日志》
运维·人工智能·web安全·网络安全·开源·网络攻击模型·安全威胁分析