Machine Learning - Saving and Loading a Trained Model

Once you have trained a model, you can save it to disk and load it back later.

For saving and loading models in PyTorch, there are three main methods you should be aware of.

| PyTorch method | What does it do? |
| --- | --- |
| torch.save | Saves a serialized object to disk using Python's pickle utility. Models, tensors and various other Python objects like dictionaries can be saved using torch.save. |
| torch.load | Uses pickle's unpickling features to deserialize and load pickled Python object files (like models, tensors or dictionaries) into memory. You can also set which device to load the object to (CPU, GPU etc). |
| torch.nn.Module.load_state_dict | Loads a model's parameter dictionary (model.state_dict()) using a saved state_dict() object. |
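As a rough illustration of the first two methods, the sketch below saves a small dictionary with torch.save and reads it back with torch.load, using map_location to choose the target device. The dictionary contents and the file name checkpoint.pt are made up for this example, and a temporary directory is used so nothing is left on disk.

```python
import tempfile
from pathlib import Path

import torch

# Hypothetical object to save: a dictionary holding an int and a tensor.
checkpoint = {"step": 10, "weights": torch.arange(4, dtype=torch.float32)}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "checkpoint.pt"  # illustrative file name

    # Serialize the dictionary to disk.
    torch.save(obj=checkpoint, f=path)

    # Deserialize it, asking for all tensors to be placed on the CPU.
    restored = torch.load(f=path, map_location="cpu")

print(restored["step"])
print(torch.equal(restored["weights"], checkpoint["weights"]))
print(restored["weights"].device)
```

Passing map_location="cpu" is a common choice when a file saved on a GPU machine has to be opened on a machine without one.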

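pickle is a Python standard-library module for serializing and deserializing Python objects, and PyTorch relies on it for saving and loading. It converts a Python object into a byte stream (serialization) and converts a byte stream back into a Python object (deserialization). The pickle module is useful in many situations, particularly for saving and loading models and for saving intermediate training state.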
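The round trip described above can be sketched with the standard library alone. The dictionary below is a made-up stand-in for some training state, not anything PyTorch produces.

```python
import pickle

# A plain Python object standing in for "training state" (hypothetical values).
checkpoint = {"epoch": 3, "weights": [0.7, 0.3], "name": "demo"}

# Serialize: Python object -> byte stream.
payload = pickle.dumps(checkpoint)

# Deserialize: byte stream -> a new Python object with the same contents.
restored = pickle.loads(payload)

print(type(payload).__name__)  # bytes
print(restored == checkpoint)  # True
print(restored is checkpoint)  # False
```

The restored object is equal to the original but is a separate copy, which is the same behavior you get when a saved state_dict() is loaded back from disk.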

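In deep learning you often need to save a trained model, or intermediate results produced during training, so that you can use or analyze them later. PyTorch provides a convenient API for saving and loading models, which uses the pickle module to serialize and deserialize objects.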
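One common way to save intermediate training state is to bundle several state dictionaries into a single checkpoint dictionary. The sketch below is only an illustration under assumptions: the nn.Linear model, the SGD optimizer, the epoch number and the dictionary key names are all chosen for this example, and an in-memory buffer stands in for a file on disk.

```python
import io

import torch
from torch import nn

# Hypothetical model and optimizer standing in for a real training setup.
model = nn.Linear(in_features=1, out_features=1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Bundle everything needed to resume training (key names are illustrative).
checkpoint = {
    "epoch": 5,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}

# Save to an in-memory buffer; a file path would work the same way.
buffer = io.BytesIO()
torch.save(obj=checkpoint, f=buffer)
buffer.seek(0)

# Restore into freshly created objects.
restored = torch.load(f=buffer)
resumed_model = nn.Linear(in_features=1, out_features=1)
resumed_optimizer = torch.optim.SGD(resumed_model.parameters(), lr=0.01)
resumed_model.load_state_dict(restored["model_state_dict"])
resumed_optimizer.load_state_dict(restored["optimizer_state_dict"])

print(restored["epoch"])
print(torch.equal(resumed_model.weight, model.weight))
```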


Save the model

import torch
from pathlib import Path 

# 1. Create models directory
MODEL_PATH = Path("models")
MODEL_PATH.mkdir(parents=True, exist_ok=True)

# 2. Create model save path
MODEL_NAME = "trained_model.pth"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

# 3. Save the model's state_dict (model_0 is the model trained in the earlier steps)
print(f"Saving model to: {MODEL_SAVE_PATH}")
torch.save(obj=model_0.state_dict(),
           f=MODEL_SAVE_PATH)

You should now see the trained_model.pth file saved in the models directory.


Load the saved PyTorch model

You can load it in using torch.nn.Module.load_state_dict(torch.load(f)) where f is the filepath of the saved model state_dict().

Why call torch.load() inside torch.nn.Module.load_state_dict()?

Because you only saved the model's state_dict() which is a dictionary of learned parameters and not the entire model, you first have to load the state_dict() with torch.load() and then pass that state_dict() to a new instance of the model (which is a subclass of nn.Module).
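The two steps can be separated to make this clearer. In the sketch below, TinyLinearModel is a hypothetical stand-in for the model class used in this post, and an in-memory buffer replaces the file on disk so the example runs on its own.

```python
import io

import torch
from torch import nn


class TinyLinearModel(nn.Module):
    """Hypothetical stand-in for the post's model class."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(in_features=1, out_features=1)

    def forward(self, x):
        return self.linear(x)


trained_model = TinyLinearModel()

# Save only the parameters, not the model object itself.
buffer = io.BytesIO()
torch.save(obj=trained_model.state_dict(), f=buffer)
buffer.seek(0)

# Step 1: torch.load returns a dictionary of tensors, not a model.
state_dict = torch.load(f=buffer)
print(list(state_dict.keys()))

# Step 2: a new instance provides the architecture; load_state_dict fills in the values.
new_model = TinyLinearModel()
result = new_model.load_state_dict(state_dict)
print(result)
```

Because the file holds only parameter values, the class definition has to be available wherever the model is loaded.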

# Instantiate a new instance of the model 
loaded_model_0 = LinearRegressionModel()

# Load the state_dict of the saved model
loaded_model_0.load_state_dict(torch.load(f=MODEL_SAVE_PATH))

# Output:
<All keys matched successfully>

Test the loaded model.

# 1. Put the loaded model into evaluation mode
loaded_model_0.eval()

# 2. Use the inference mode context manager to make predictions
with torch.inference_mode():
    loaded_model_preds = loaded_model_0(X_test)

# 3. Compare the original model's predictions (y_preds) with the loaded model's predictions
print(y_preds == loaded_model_preds)

# Output:
tensor([[True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True]])
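The element-wise comparison prints one boolean per prediction. If a single yes/no answer is more convenient, torch.equal or torch.allclose can be used instead. The tensors below are made-up stand-ins for y_preds and loaded_model_preds, so this is only a sketch of the check, not output from the model in this post.

```python
import torch

# Made-up stand-ins for the two sets of predictions.
y_preds = torch.tensor([[0.1], [0.2], [0.3]])
loaded_model_preds = y_preds.clone()

# Exact match: same shape and identical values.
print(torch.equal(y_preds, loaded_model_preds))  # True

# Approximate match: tolerates tiny floating-point differences.
print(torch.allclose(y_preds, loaded_model_preds))  # True

# Collapse the element-wise comparison into one boolean.
print(bool((y_preds == loaded_model_preds).all()))  # True
```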
