PyTorch模型保存方式

PyTorch提供两种主流模型保存方式和一种训练断点保存与恢复的方法。

1. 仅保存模型参数(推荐)

python 复制代码
# 保存
torch.save(model.state_dict(), "model_params.pth")  

# 加载
new_model = TheModelClass()  
new_model.load_state_dict(torch.load("model_params.pth"))
new_model.eval()

核心优势:

  • 文件体积小(仅参数数据)

  • 避免PyTorch版本兼容问题

  • 支持跨模型结构迁移(需设置strict=False

2. 保存完整模型对象
python 复制代码
# 保存
torch.save(model, "full_model.pth")  

# 加载   loaded_model = torch.load("full_model.pth")
loaded_model.eval()

适用场景:

  • 快速原型验证

  • 模型结构包含动态逻辑(如自定义前向传播)

3. 训练断点保存与恢复
python 复制代码
# 保存检查点
checkpoint = {
    'epoch': current_epoch,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'loss': loss_value
}
torch.save(checkpoint, "checkpoint.tar")

# 恢复训练
model = TheModelClass()
optimizer = torch.optim.Adam(model.parameters())
checkpoint = torch.load("checkpoint.tar")
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
model.train()  # 保持训练模式

关键细节:

  • 推荐使用.tar后缀区分普通参数文件

  • 自动恢复学习率调度器等训练状态

相关推荐
apcipot_rain5 分钟前
Python实战——蒙特卡洛模拟分析杀牌游戏技能收益
python·游戏·数学建模
老绿光9 分钟前
Python 字典完全指南:从入门到实战
linux·服务器·python
是小蟹呀^20 分钟前
【总结】LangChain中如何维持记忆
python·langchain·memory
蓝色的杯子21 分钟前
OpenClaw一文详细了解-手搓OpenClaw-4 Tool Runtime
人工智能·python
克里普crirp28 分钟前
电离层TEC地图中添加晨昏线/昼夜转换线
python
Dxy123931021629 分钟前
Python使用PyEnchant详解:打造高效拼写检查工具
开发语言·python
架构师老Y38 分钟前
011、消息队列应用:RabbitMQ、Kafka与Celery
python·架构·kafka·rabbitmq·ruby
枫叶林FYL43 分钟前
【Python高级工程与架构实战】项目四:生产级LLM Agent框架:基于PydanticAI的类型安全企业级实现
人工智能·python·自然语言处理
龙腾AI白云44 分钟前
多模大模型应用实战:智能问答系统开发
python·机器学习·数据分析·django·tornado
Hommy881 小时前
【开源剪映小助手】配置与部署
python·开源·aigc·剪映小助手