PyTorch模型保存方式

PyTorch提供两种主流模型保存方式和一种训练断点保存与恢复的方法。

1. 仅保存模型参数(推荐)

python 复制代码
# 保存
torch.save(model.state_dict(), "model_params.pth")  

# 加载
new_model = TheModelClass()  
new_model.load_state_dict(torch.load("model_params.pth"))
new_model.eval()

核心优势:

  • 文件体积小(仅参数数据)

  • 避免PyTorch版本兼容问题

  • 支持跨模型结构迁移(需设置strict=False

2. 保存完整模型对象
python 复制代码
# 保存
torch.save(model, "full_model.pth")  

# 加载   loaded_model = torch.load("full_model.pth")
loaded_model.eval()

适用场景:

  • 快速原型验证

  • 模型结构包含动态逻辑(如自定义前向传播)

3. 训练断点保存与恢复
python 复制代码
# 保存检查点
checkpoint = {
    'epoch': current_epoch,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'loss': loss_value
}
torch.save(checkpoint, "checkpoint.tar")

# 恢复训练
model = TheModelClass()
optimizer = torch.optim.Adam(model.parameters())
checkpoint = torch.load("checkpoint.tar")
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
model.train()  # 保持训练模式

关键细节:

  • 推荐使用.tar后缀区分普通参数文件

  • 自动恢复学习率调度器等训练状态

相关推荐
A懿轩A23 分钟前
【Java 基础编程】Java 变量与八大基本数据类型详解:从声明到类型转换,零基础也能看懂
java·开发语言·python
Tansmjs40 分钟前
使用Python自动收发邮件
jvm·数据库·python
m0_5613596742 分钟前
用Python监控系统日志并发送警报
jvm·数据库·python
idwangzhen1 小时前
GEO优化系统哪个功能强大
python·信息可视化
许泽宇的技术分享1 小时前
第 1 章:认识 Claude Code
开发语言·人工智能·python
AIFQuant2 小时前
如何利用免费股票 API 构建量化交易策略:实战分享
开发语言·python·websocket·金融·restful
布局呆星2 小时前
SQLite数据库的介绍与使用
数据库·python
2401_838472512 小时前
用Python和Twilio构建短信通知系统
jvm·数据库·python
weixin_452159552 小时前
如何从Python初学者进阶为专家?
jvm·数据库·python
Hello.Reader2 小时前
面向 403 与域名频繁变更的合规爬虫工程实践以 Libvio 系站点为例
爬虫·python·网络爬虫