机器学习 - save和load训练好的模型

如果已经训练好了一个模型,你就可以save和load这模型。

For saving and loading models in PyTorch, there are three main methods you should be aware of.

PyTorch method What does it do?
torch.save Saves a serialized object to disk using Python's pickle utility. Models, tensors and various other Python objects like dictionaries can be saved using torch.save
torch.load Uses pickle's unpickling features to deserialize and load pickled Python object files (like models, tensors or dictionaries) into memory. You can also set which device to load the object to (CPU, GPU etc)
torch.nn.Module.load_state_dict Loads a model's parameter dictionary (model.state_dict()) using a saved state_dict() object

在 PyTorch 中,pickle 是一个用于序列化和反序列化Python对象的标准库模块。它可以将Python对象转换为字节流 (即序列化),并将字节流转换回Python对象 (即反序列化)。pickle模块在很多情况下都非常有用,特别是在保存和加载模型,保存训练中间状态等方面。

在深度学习中,经常需要保存训练好的模型或者训练过程中的中间结果,以便后续的使用或分析。PyTorch提高了方便的API来保存和加载模型,其中就包括了使用pickle模块进行对象的序列化和反序列化。


save model

python 复制代码
import torch
from pathlib import Path 

# 1. Create models directory
MODEL_PATH = Path("models")
MODEL_PATH.mkdir(parents = True, exist_ok = True)

# 2. Create model save path
MODEL_NAME = "trained_model.pth"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

# 3. Save the model state dict 
print(f"Saving model to: {MODEL_SAVE_PATH}")
torch.save(obj = model_0.state_dict(),
			f = MODEL_SAVE_PATH)

就能看到 trained_model.pth 文件下载到所属的文件夹位置。


Load the saved PyTorch model

You can load it in using torch.nn.Module.load_state_dict(torch.load(f)) where f is the filepath of the saved model state_dict().

Why call torch.load() inside torch.nn.Module.load_state_dict()?

Because you only saved the model's state_dict() which is a dictionary of learned parameters and not the entire model, you first have to load the state_dict() with torch.load() and then pass that state_dict() to a new instance of the model (which is a subclass of nn.Module).

python 复制代码
# Instantiate a new instance of the model 
loaded_model_0 = LinearRegressionModel()

# Load the state_dict of the saved model
loaded_model_0.load_state_dict(torch.load(f=MODEL_SAVE_PATH))

# 结果如下
<All keys matched successfully>

测试 loaded model。

python 复制代码
# Put the loaded model into evaluation model 
loaded_model_0.eval() 

# 2. Use the inference mode context manager to make predictions
with torch.inference_mode():
  loaded_model_preds = loaded_model_0(X_test)

# Compare previous model predictions with loaded model predictions
print(y_preds == loaded_model_preds) 

# 结果如下
tensor([[True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True]])

看到这了,点个赞呗~

相关推荐
全知科技2 分钟前
API安全国家标准发布丨《数据安全技术 数据接口安全风险监测方法》
大数据·人工智能·安全
AI营销干货站7 分钟前
2025 AI市场舆情分析软件测评:原圈科技等3款工具深度对比
大数据·人工智能
微爱帮监所写信寄信8 分钟前
微爱帮监狱寄信写信系统后台PHP框架优化实战手册
android·开发语言·人工智能·网络协议·微信·https·php
民乐团扒谱机12 分钟前
【微实验】量子光梳技术革命:如何用量子压缩突破时频传递的终极极限?
人工智能·敏感性分析·量子力学·双梳
MARS_AI_14 分钟前
AI呼叫中心革命:大模型技术如何重构企业服务体验
人工智能·科技·自然语言处理·信息与通信·agi
EEPI23 分钟前
【论文阅读】Vision Language Models are In-Context Value Learners
论文阅读·人工智能·语言模型
金融Tech趋势派24 分钟前
2026企业微信私有化部署新选择:微盛·企微管家如何助力企业数据安全与运营效率提升?
大数据·人工智能·云计算·企业微信
短视频矩阵源码定制24 分钟前
专业的矩阵系统哪个公司好
大数据·人工智能·矩阵
jimmyleeee25 分钟前
人工智能基础知识笔记三十:模型的混合量化策略
人工智能·笔记
Gofarlic_oms127 分钟前
Cadence许可证全生命周期数据治理方案
java·大数据·运维·开发语言·人工智能·安全·自动化