PyTorch模型保存方式

PyTorch提供两种主流模型保存方式和一种训练断点保存与恢复的方法。

1. 仅保存模型参数(推荐)

python 复制代码
# 保存
torch.save(model.state_dict(), "model_params.pth")  

# 加载
new_model = TheModelClass()  
new_model.load_state_dict(torch.load("model_params.pth"))
new_model.eval()

核心优势:

  • 文件体积小(仅参数数据)

  • 避免PyTorch版本兼容问题

  • 支持跨模型结构迁移(需设置strict=False

2. 保存完整模型对象
python 复制代码
# 保存
torch.save(model, "full_model.pth")  

# 加载   loaded_model = torch.load("full_model.pth")
loaded_model.eval()

适用场景:

  • 快速原型验证

  • 模型结构包含动态逻辑(如自定义前向传播)

3. 训练断点保存与恢复
python 复制代码
# 保存检查点
checkpoint = {
    'epoch': current_epoch,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'loss': loss_value
}
torch.save(checkpoint, "checkpoint.tar")

# 恢复训练
model = TheModelClass()
optimizer = torch.optim.Adam(model.parameters())
checkpoint = torch.load("checkpoint.tar")
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
model.train()  # 保持训练模式

关键细节:

  • 推荐使用.tar后缀区分普通参数文件

  • 自动恢复学习率调度器等训练状态

相关推荐
仰望星空的凡人4 小时前
【JS逆向基础】数据库之MongoDB
javascript·数据库·python·mongodb
F_D_Z4 小时前
【PyTorch】图像多分类项目部署
人工智能·pytorch·python·深度学习·分类
pingzhuyan5 小时前
python入门篇12-虚拟环境conda的安装与使用
python·ai·llm·ocr·conda
香蕉可乐荷包蛋5 小时前
排序算法 (Sorting Algorithms)-Python示例
python·算法·排序算法
菜鸟学Python7 小时前
Python web框架王者 Django 5.0发布:20周年了!
前端·数据库·python·django·sqlite
旧时光巷8 小时前
【机器学习-4】 | 集成学习 / 随机森林篇
python·随机森林·机器学习·集成学习·sklearn·boosting·bagging
Ice__Cai9 小时前
Django + Celery 详细解析:构建高效的异步任务队列
分布式·后端·python·django
MediaTea9 小时前
Python 库手册:doctest 文档测试模块
开发语言·python·log4j
2025年一定要上岸9 小时前
【pytest高阶】源码的走读方法及插件hook
运维·前端·python·pytest
angushine9 小时前
Python将Word转换为Excel
python·word·excel