PyTorch模型保存方式

PyTorch提供两种主流模型保存方式和一种训练断点保存与恢复的方法。

1. 仅保存模型参数(推荐)

python 复制代码
# 保存
torch.save(model.state_dict(), "model_params.pth")  

# 加载
new_model = TheModelClass()  
new_model.load_state_dict(torch.load("model_params.pth"))
new_model.eval()

核心优势:

  • 文件体积小(仅参数数据)

  • 避免PyTorch版本兼容问题

  • 支持跨模型结构迁移(需设置strict=False

2. 保存完整模型对象
python 复制代码
# 保存
torch.save(model, "full_model.pth")  

# 加载   loaded_model = torch.load("full_model.pth")
loaded_model.eval()

适用场景:

  • 快速原型验证

  • 模型结构包含动态逻辑(如自定义前向传播)

3. 训练断点保存与恢复
python 复制代码
# 保存检查点
checkpoint = {
    'epoch': current_epoch,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'loss': loss_value
}
torch.save(checkpoint, "checkpoint.tar")

# 恢复训练
model = TheModelClass()
optimizer = torch.optim.Adam(model.parameters())
checkpoint = torch.load("checkpoint.tar")
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
model.train()  # 保持训练模式

关键细节:

  • 推荐使用.tar后缀区分普通参数文件

  • 自动恢复学习率调度器等训练状态

相关推荐
许彰午29 分钟前
30_Java Stream流操作全解
java·windows·python
秋91 小时前
3年经验Python后端转AI Engineer:3个月实战转型计划(2026版)
开发语言·人工智能·python
2601_956319881 小时前
期货夜盘无人值守监控什么:断线、无成交与拒单信号
python·区块链
CTA终结者1 小时前
期货量化目标仓和净持仓对不齐:天勤 TargetPosTask 与 pos 偏差排查
python·区块链
科技林总2 小时前
解决vllm服务漏扫问题
python·安全
祭曦念2 小时前
古诗小集开发实战:从零开发一款 HarmonyOS 古诗鉴赏应用
pytorch·深度学习·harmonyos
财经资讯数据_灵砚智能3 小时前
基于全球经济类多源新闻的NLP情感分析与数据可视化(夜间-次晨)2026年6月10日
大数据·人工智能·python·ai·信息可视化·自然语言处理·灵砚智能
namexingyun3 小时前
拆解Fable 5三重安全护栏:模型路由、蒸馏防护与生物安全分类器的技术原理 - 微元算力(weytoken)
java·人工智能·python·安全·架构·ai编程
chenment3 小时前
别再为每个模型单独写一套队列了:用 200 行代码封装多模态统一调用层
人工智能·python·产品
啊森要自信4 小时前
【GUI自动化测试】控件、鼠标键盘操作与多场景自动化
c语言·开发语言·python·adb·ipython