PyTorch模型保存方式

PyTorch提供两种主流模型保存方式和一种训练断点保存与恢复的方法。

1. 仅保存模型参数(推荐)

python 复制代码
# 保存
torch.save(model.state_dict(), "model_params.pth")  

# 加载
new_model = TheModelClass()  
new_model.load_state_dict(torch.load("model_params.pth"))
new_model.eval()

核心优势:

  • 文件体积小(仅参数数据)

  • 避免PyTorch版本兼容问题

  • 支持跨模型结构迁移(需设置strict=False

2. 保存完整模型对象
python 复制代码
# 保存
torch.save(model, "full_model.pth")  

# 加载   loaded_model = torch.load("full_model.pth")
loaded_model.eval()

适用场景:

  • 快速原型验证

  • 模型结构包含动态逻辑(如自定义前向传播)

3. 训练断点保存与恢复
python 复制代码
# 保存检查点
checkpoint = {
    'epoch': current_epoch,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'loss': loss_value
}
torch.save(checkpoint, "checkpoint.tar")

# 恢复训练
model = TheModelClass()
optimizer = torch.optim.Adam(model.parameters())
checkpoint = torch.load("checkpoint.tar")
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
model.train()  # 保持训练模式

关键细节:

  • 推荐使用.tar后缀区分普通参数文件

  • 自动恢复学习率调度器等训练状态

相关推荐
yukai080083 分钟前
【最后203篇系列】039 JWT使用
python
独好紫罗兰29 分钟前
对python的再认识-基于数据结构进行-a006-元组-拓展
开发语言·数据结构·python
Dfreedom.31 分钟前
图像直方图完全解析:从原理到实战应用
图像处理·python·opencv·直方图·直方图均衡化
铉铉这波能秀1 小时前
LeetCode Hot100数据结构背景知识之集合(Set)Python2026新版
数据结构·python·算法·leetcode·哈希算法
怒放吧德德1 小时前
Python3基础:基础实战巩固,从“会用”到“活用”
后端·python
aiguangyuan1 小时前
基于BERT的中文命名实体识别实战解析
人工智能·python·nlp
喵手1 小时前
Python爬虫实战:知识挖掘机 - 知乎问答与专栏文章的深度分页采集系统(附CSV导出 + SQLite持久化存储)!
爬虫·python·爬虫实战·零基础python爬虫教学·采集知乎问答与专栏文章·采集知乎数据·采集知乎数据存储sqlite
铉铉这波能秀1 小时前
LeetCode Hot100数据结构背景知识之元组(Tuple)Python2026新版
数据结构·python·算法·leetcode·元组·tuple
kali-Myon1 小时前
2025春秋杯网络安全联赛冬季赛-day2
python·安全·web安全·ai·php·pwn·ctf
Olamyh2 小时前
【 超越 ReAct:手搓 Plan-and-Execute (Planner) Agent】
python·ai