PyTorch模型保存方式

PyTorch提供两种主流模型保存方式和一种训练断点保存与恢复的方法。

1. 仅保存模型参数(推荐)

python 复制代码
# 保存
torch.save(model.state_dict(), "model_params.pth")  

# 加载
new_model = TheModelClass()  
new_model.load_state_dict(torch.load("model_params.pth"))
new_model.eval()

核心优势:

  • 文件体积小(仅参数数据)

  • 避免PyTorch版本兼容问题

  • 支持跨模型结构迁移(需设置strict=False

2. 保存完整模型对象
python 复制代码
# 保存
torch.save(model, "full_model.pth")  

# 加载   loaded_model = torch.load("full_model.pth")
loaded_model.eval()

适用场景:

  • 快速原型验证

  • 模型结构包含动态逻辑(如自定义前向传播)

3. 训练断点保存与恢复
python 复制代码
# 保存检查点
checkpoint = {
    'epoch': current_epoch,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'loss': loss_value
}
torch.save(checkpoint, "checkpoint.tar")

# 恢复训练
model = TheModelClass()
optimizer = torch.optim.Adam(model.parameters())
checkpoint = torch.load("checkpoint.tar")
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
model.train()  # 保持训练模式

关键细节:

  • 推荐使用.tar后缀区分普通参数文件

  • 自动恢复学习率调度器等训练状态

相关推荐
邪恶马铃薯几秒前
python基本语法元素
python
我想睡觉26120 分钟前
Python训练营打卡DAY51
开发语言·人工智能·python·深度学习·机器学习
不叫猫先生21 分钟前
Bright Data网页抓取工具实战:BOSS直聘爬虫 + PandasAI分析洞察前端岗位市场趋势
爬虫·python·ai·代理
Blossom.11821 分钟前
基于生成对抗网络(GAN)的图像生成与编辑:原理、应用与实践
人工智能·python·深度学习·机器学习·计算机视觉·分类·tensorflow
猛犸MAMMOTH42 分钟前
Python打卡第51天
开发语言·python·深度学习
gavin carter1 小时前
gitHub hexo 个人博客升级版
python·github·hexo
苏苏susuus1 小时前
深度学习:PyTorch简介
人工智能·pytorch·深度学习
社会零时工2 小时前
【python】基于pycharm的海康相机SDK二次开发
python·opencv·pycharm·相机
火车叼位2 小时前
从 pip 到 pipx:隔离、轻量、跨平台的 Python 工具管理
python
魔都吴所谓2 小时前
【无标题】
python