pytorch_trick(4) 模型本地保存与读取方法

模型本地保存与读取方法

同时,借助state_dict()方法,我们可以实现模型或优化器的本地保存于读取。此处以模型为例,优化器的本地保存相关操作类似。

对于模型而言,其实也有state_dict()方法。通过该方法的调用,可以查看模型全部参数信息。

值得注意的是,模型的训练和保存,本质上都是针对模型的参数。而模型的state_dict()则包含了模型当前全部的参数信息。因此,保存了模型的state_dict()就相当于是保存了模型。

python 复制代码
# 设置随机数种子
torch.manual_seed(24)  

# 实例化模型  
tanh_model1 = net_class2(act_fun= torch.tanh, in_features=5, BN_model='pre')
tanh_model1.state_dict()

1、保存模型参数

首先,我们可以将该存有模型全部参数信息的字典对象赋给某个变量。

python 复制代码
t1 = tanh_model1.state_dict()
t1

其次,我们也可以通过torch.save来将该参数保存至本地。

python 复制代码
torch.save(tanh_model1.state_dict(), 'tanh1.pt')

对于torch.save函数来说,第一个参数是需要保存的模型参数,而第二个参数则是保存到本地的文件名。一般来说可以令其后缀为.pt.pth。而当我们需要读取保存的参数结果时,则可以直接使用load_state_dict方法。该方法的使用我们稍后就会谈到。

接下来进行模型训练,也就是模型参数调整。回顾此前学习内容,当我们进行模型训练时,实际上就是借助损失函数和反向传播机制进行梯度求解,然后利用优化器根据梯度值去更新各线性层参数。

python 复制代码
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(tanh_model1.parameters(), lr=0.05)
for X, y in train_loader:
    yhat = tanh_model1.forward(X)
    loss = criterion(yhat, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

训练完一轮之后,我们可以查看模型状态:

python 复制代码
tanh_model1.state_dict()

我们发现模型的参数已经发生了变化。当然,此时t1也随之发生了变化

python 复制代码
t1

2、还原模型参数

此时,如果我们想还原tanh_model1中原始参数,我们只能考虑通过使用load_state_dict方法,将本次保存的原模型参数替换当前的tanh_model1中参数,具体方法如下:

python 复制代码
torch.load('tanh1.pt')
tanh_model1.load_state_dict(torch.load('tanh1.pt'))
tanh_model1.state_dict()

至此,我们就完成了模型训练与保存的基本过程。当然,除了模型可以按照上述方法保存外,优化器也可以类似进行本地存储。

当然,结合此前介绍的深拷贝的相关概念,此处我们能否通过深拷贝的方式将模型参数保存在当前操作空间内然后再替换训练后的模型参数呢?同学们可以自行尝试

相关推荐
产品经理独孤虾1 分钟前
人工智能大模型如何助力电商产品经理打造高效的商品工业属性画像
人工智能·机器学习·ai·大模型·产品经理·商品画像·商品工业属性
老任与码11 分钟前
Spring AI Alibaba(1)——基本使用
java·人工智能·后端·springaialibaba
蹦蹦跳跳真可爱58924 分钟前
Python----OpenCV(图像増强——高通滤波(索贝尔算子、沙尔算子、拉普拉斯算子),图像浮雕与特效处理)
人工智能·python·opencv·计算机视觉
雷羿 LexChien34 分钟前
从 Prompt 管理到人格稳定:探索 Cursor AI 编辑器如何赋能 Prompt 工程与人格风格设计(上)
人工智能·python·llm·编辑器·prompt
两棵雪松1 小时前
如何通过向量化技术比较两段文本是否相似?
人工智能
heart000_11 小时前
128K 长文本处理实战:腾讯混元 + 云函数 SCF 构建 PDF 摘要生成器
人工智能·自然语言处理·pdf
敲键盘的小夜猫1 小时前
LLM复杂记忆存储-多会话隔离案例实战
人工智能·python·langchain
开开心心_Every2 小时前
便捷的Office批量转PDF工具
开发语言·人工智能·r语言·pdf·c#·音视频·symfony
cooldream20092 小时前
「源力觉醒 创作者计划」_基于 PaddlePaddle 部署 ERNIE-4.5-0.3B 轻量级大模型实战指南
人工智能·paddlepaddle·文心大模型
亚里随笔2 小时前
L0:让大模型成为通用智能体的强化学习新范式
人工智能·llm·大语言模型·rlhf