PyTorch模型保存方式

PyTorch提供两种主流模型保存方式和一种训练断点保存与恢复的方法。

1. 仅保存模型参数(推荐)

python 复制代码
# 保存
torch.save(model.state_dict(), "model_params.pth")  

# 加载
new_model = TheModelClass()  
new_model.load_state_dict(torch.load("model_params.pth"))
new_model.eval()

核心优势:

  • 文件体积小(仅参数数据)

  • 避免PyTorch版本兼容问题

  • 支持跨模型结构迁移(需设置strict=False

2. 保存完整模型对象
python 复制代码
# 保存
torch.save(model, "full_model.pth")  

# 加载   loaded_model = torch.load("full_model.pth")
loaded_model.eval()

适用场景:

  • 快速原型验证

  • 模型结构包含动态逻辑(如自定义前向传播)

3. 训练断点保存与恢复
python 复制代码
# 保存检查点
checkpoint = {
    'epoch': current_epoch,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'loss': loss_value
}
torch.save(checkpoint, "checkpoint.tar")

# 恢复训练
model = TheModelClass()
optimizer = torch.optim.Adam(model.parameters())
checkpoint = torch.load("checkpoint.tar")
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
model.train()  # 保持训练模式

关键细节:

  • 推荐使用.tar后缀区分普通参数文件

  • 自动恢复学习率调度器等训练状态

相关推荐
'需尽欢'1 小时前
基于 Flask+Vue+MySQL的研学网站
python·mysql·flask
新子y2 小时前
【小白笔记】最大交换 (Maximum Swap)问题
笔记·python
程序员爱钓鱼3 小时前
Python编程实战 · 基础入门篇 | Python的缩进与代码块
后端·python
pr_note4 小时前
python|if判断语法对比
python
apocelipes6 小时前
golang unique包和字符串内部化
java·python·性能优化·golang
Geoking.7 小时前
NumPy zeros() 函数详解
python·numpy
Full Stack Developme7 小时前
java.text 包详解
java·开发语言·python
丁浩6668 小时前
Python机器学习---2.算法:逻辑回归
python·算法·机器学习
B站_计算机毕业设计之家8 小时前
计算机毕业设计:Python农业数据可视化分析系统 气象数据 农业生产 粮食数据 播种数据 爬虫 Django框架 天气数据 降水量(源码+文档)✅
大数据·爬虫·python·机器学习·信息可视化·课程设计·农业
Q_Q5110082859 小时前
python+uniapp基于微信小程序的旅游信息系统
spring boot·python·微信小程序·django·flask·uni-app·node.js