机器学习 - save和load训练好的模型

如果已经训练好了一个模型,你就可以save和load这模型。

For saving and loading models in PyTorch, there are three main methods you should be aware of.

PyTorch method What does it do?
torch.save Saves a serialized object to disk using Python's pickle utility. Models, tensors and various other Python objects like dictionaries can be saved using torch.save
torch.load Uses pickle's unpickling features to deserialize and load pickled Python object files (like models, tensors or dictionaries) into memory. You can also set which device to load the object to (CPU, GPU etc)
torch.nn.Module.load_state_dict Loads a model's parameter dictionary (model.state_dict()) using a saved state_dict() object

在 PyTorch 中,pickle 是一个用于序列化和反序列化Python对象的标准库模块。它可以将Python对象转换为字节流 (即序列化),并将字节流转换回Python对象 (即反序列化)。pickle模块在很多情况下都非常有用,特别是在保存和加载模型,保存训练中间状态等方面。

在深度学习中,经常需要保存训练好的模型或者训练过程中的中间结果,以便后续的使用或分析。PyTorch提高了方便的API来保存和加载模型,其中就包括了使用pickle模块进行对象的序列化和反序列化。


save model

python 复制代码
import torch
from pathlib import Path 

# 1. Create models directory
MODEL_PATH = Path("models")
MODEL_PATH.mkdir(parents = True, exist_ok = True)

# 2. Create model save path
MODEL_NAME = "trained_model.pth"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

# 3. Save the model state dict 
print(f"Saving model to: {MODEL_SAVE_PATH}")
torch.save(obj = model_0.state_dict(),
			f = MODEL_SAVE_PATH)

就能看到 trained_model.pth 文件下载到所属的文件夹位置。


Load the saved PyTorch model

You can load it in using torch.nn.Module.load_state_dict(torch.load(f)) where f is the filepath of the saved model state_dict().

Why call torch.load() inside torch.nn.Module.load_state_dict()?

Because you only saved the model's state_dict() which is a dictionary of learned parameters and not the entire model, you first have to load the state_dict() with torch.load() and then pass that state_dict() to a new instance of the model (which is a subclass of nn.Module).

python 复制代码
# Instantiate a new instance of the model 
loaded_model_0 = LinearRegressionModel()

# Load the state_dict of the saved model
loaded_model_0.load_state_dict(torch.load(f=MODEL_SAVE_PATH))

# 结果如下
<All keys matched successfully>

测试 loaded model。

python 复制代码
# Put the loaded model into evaluation model 
loaded_model_0.eval() 

# 2. Use the inference mode context manager to make predictions
with torch.inference_mode():
  loaded_model_preds = loaded_model_0(X_test)

# Compare previous model predictions with loaded model predictions
print(y_preds == loaded_model_preds) 

# 结果如下
tensor([[True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True]])

看到这了,点个赞呗~

相关推荐
深小乐3 分钟前
从 AI Skills 学实战技能(六):让 AI 帮你总结网页、PDF、视频
人工智能
宝贝儿好10 分钟前
【LLM】第二章:文本表示:词袋模型、小案例:基于文本的推荐系统(酒店推荐)
人工智能·python·深度学习·神经网络·自然语言处理·机器人·语音识别
周末程序猿28 分钟前
详解 karpathy 的 microgpt:实现一个浏览器运行的 gpt
人工智能·llm
ACP广源盛1392462567334 分钟前
破局 Type‑C 切换器痛点@ACP#GSV6155+LH3828/GSV2221+LH3828 黄金方案
c语言·开发语言·网络·人工智能·嵌入式硬件·计算机外设·电脑
xixixi7777742 分钟前
通信领域的“中国速度”:从5G-A到6G,从地面到星空
人工智能·5g·安全·ai·fpga开发·多模态
Dfreedom.1 小时前
计算机视觉全景图
人工智能·算法·计算机视觉·图像算法
EasyDSS1 小时前
智能会议管理系统/私有化视频会议平台EasyDSS私有化部署构建企业级私域视频全场景解决方案
人工智能·音视频
zhanghongbin012 小时前
成本追踪:AI API 成本计算与预算管理
人工智能
YBAdvanceFu2 小时前
从零构建智能体:深入理解 ReAct Plan Solve Reflection 三大经典范式
人工智能·python·机器学习·数据挖掘·多智能体·智能体
啦啦啦在冲冲冲2 小时前
多头注意力机制的优势是啥,遇到长文本的情况,可以从哪些情况优化呢
人工智能·深度学习