机器学习 - save和load训练好的模型

如果已经训练好了一个模型,你就可以save和load这模型。

For saving and loading models in PyTorch, there are three main methods you should be aware of.

PyTorch method What does it do?
torch.save Saves a serialized object to disk using Python's pickle utility. Models, tensors and various other Python objects like dictionaries can be saved using torch.save
torch.load Uses pickle's unpickling features to deserialize and load pickled Python object files (like models, tensors or dictionaries) into memory. You can also set which device to load the object to (CPU, GPU etc)
torch.nn.Module.load_state_dict Loads a model's parameter dictionary (model.state_dict()) using a saved state_dict() object

在 PyTorch 中,pickle 是一个用于序列化和反序列化Python对象的标准库模块。它可以将Python对象转换为字节流 (即序列化),并将字节流转换回Python对象 (即反序列化)。pickle模块在很多情况下都非常有用,特别是在保存和加载模型,保存训练中间状态等方面。

在深度学习中,经常需要保存训练好的模型或者训练过程中的中间结果,以便后续的使用或分析。PyTorch提高了方便的API来保存和加载模型,其中就包括了使用pickle模块进行对象的序列化和反序列化。


save model

python 复制代码
import torch
from pathlib import Path 

# 1. Create models directory
MODEL_PATH = Path("models")
MODEL_PATH.mkdir(parents = True, exist_ok = True)

# 2. Create model save path
MODEL_NAME = "trained_model.pth"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

# 3. Save the model state dict 
print(f"Saving model to: {MODEL_SAVE_PATH}")
torch.save(obj = model_0.state_dict(),
			f = MODEL_SAVE_PATH)

就能看到 trained_model.pth 文件下载到所属的文件夹位置。


Load the saved PyTorch model

You can load it in using torch.nn.Module.load_state_dict(torch.load(f)) where f is the filepath of the saved model state_dict().

Why call torch.load() inside torch.nn.Module.load_state_dict()?

Because you only saved the model's state_dict() which is a dictionary of learned parameters and not the entire model, you first have to load the state_dict() with torch.load() and then pass that state_dict() to a new instance of the model (which is a subclass of nn.Module).

python 复制代码
# Instantiate a new instance of the model 
loaded_model_0 = LinearRegressionModel()

# Load the state_dict of the saved model
loaded_model_0.load_state_dict(torch.load(f=MODEL_SAVE_PATH))

# 结果如下
<All keys matched successfully>

测试 loaded model。

python 复制代码
# Put the loaded model into evaluation model 
loaded_model_0.eval() 

# 2. Use the inference mode context manager to make predictions
with torch.inference_mode():
  loaded_model_preds = loaded_model_0(X_test)

# Compare previous model predictions with loaded model predictions
print(y_preds == loaded_model_preds) 

# 结果如下
tensor([[True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True]])

看到这了,点个赞呗~

相关推荐
hay_lee2 分钟前
渐进式披露:Agent Skills让AI开发标准化
人工智能
阿里云云原生3 分钟前
探秘 AgentRun丨动态下发+权限隔离,重构 AI Agent 安全体系
人工智能·安全·阿里云·重构·agentrun
veminhe9 分钟前
人工智能学习笔记
人工智能
苍何fly9 分钟前
用腾讯版 Claude Code 做了个小红书封面图 Skills,已开源!
人工智能·经验分享
hnult12 分钟前
全功能学练考证在线考试平台,赋能技能认证
大数据·人工智能·笔记·课程设计
gang_unerry12 分钟前
量子退火与机器学习(4): 大模型 1-bit 量子化中的 QEP 与 QQA 准量子退火技术
人工智能·python·机器学习·量子计算
青瓷程序设计24 分钟前
【交通标志识别系统】python+深度学习+算法模型+Resnet算法+人工智能+2026计算机毕设项目
人工智能·python·深度学习
Mr.huang25 分钟前
RNN系列模型演进及其解决的问题
人工智能·rnn·lstm
香芋Yu29 分钟前
【深度学习教程——01_深度基石(Foundation)】05_数据太多怎么吃?Mini-batch训练的设计模式
深度学习·设计模式·batch
智驱力人工智能30 分钟前
货车走快车道检测 高速公路安全治理的工程实践与价值闭环 高速公路货车占用小客车道抓拍系统 城市快速路货车违规占道AI识别
人工智能·opencv·算法·安全·yolo·目标检测·边缘计算