PyTorch模型保存方式

PyTorch提供两种主流模型保存方式和一种训练断点保存与恢复的方法。

1. 仅保存模型参数(推荐)

python 复制代码
# 保存
torch.save(model.state_dict(), "model_params.pth")  

# 加载
new_model = TheModelClass()  
new_model.load_state_dict(torch.load("model_params.pth"))
new_model.eval()

核心优势:

  • 文件体积小(仅参数数据)

  • 避免PyTorch版本兼容问题

  • 支持跨模型结构迁移(需设置strict=False

2. 保存完整模型对象
python 复制代码
# 保存
torch.save(model, "full_model.pth")  

# 加载   loaded_model = torch.load("full_model.pth")
loaded_model.eval()

适用场景:

  • 快速原型验证

  • 模型结构包含动态逻辑(如自定义前向传播)

3. 训练断点保存与恢复
python 复制代码
# 保存检查点
checkpoint = {
    'epoch': current_epoch,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'loss': loss_value
}
torch.save(checkpoint, "checkpoint.tar")

# 恢复训练
model = TheModelClass()
optimizer = torch.optim.Adam(model.parameters())
checkpoint = torch.load("checkpoint.tar")
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
model.train()  # 保持训练模式

关键细节:

  • 推荐使用.tar后缀区分普通参数文件

  • 自动恢复学习率调度器等训练状态

相关推荐
江上清风山间明月11 分钟前
如何将python开发的window应用打包成exe
开发语言·python·exe·打包
知识分享小能手13 分钟前
Flask入门学习教程,从入门到精通, Flask模板 — 完整知识点与案例代码 (2)
python·学习·flask
不懒不懒18 分钟前
基于 Flask —— 异步任务处理接口服务
后端·python·flask
happybasic21 分钟前
Python库升级标准流程~
linux·前端·python
彦为君26 分钟前
JavaSE-11-BIO/NIO/AIO(多人聊天室)
java·开发语言·python·ai·nio
恣艺28 分钟前
Python 实用工具与机器学习入门:Rich + Tqdm + Faker + Schedule + Scikit-learn
python·机器学习·scikit-learn
测试员周周35 分钟前
【Appium 系列】第14节-断言与验证 — Validator 的设计
android·人工智能·python·功能测试·ios·单元测试·appium
Hanniel39 分钟前
Python __slots__ 入门指南
开发语言·python·性能优化
小白|44 分钟前
tensorflow:昇腾CANN的TensorFlow适配层
人工智能·python·tensorflow
彦为君1 小时前
JavaSE-10-并发编程(11个案例)
java·开发语言·python·ai·nio