PyTorch模型保存方式

PyTorch提供两种主流模型保存方式和一种训练断点保存与恢复的方法。

1. 仅保存模型参数(推荐)

python 复制代码
# 保存
torch.save(model.state_dict(), "model_params.pth")  

# 加载
new_model = TheModelClass()  
new_model.load_state_dict(torch.load("model_params.pth"))
new_model.eval()

核心优势:

  • 文件体积小(仅参数数据)

  • 避免PyTorch版本兼容问题

  • 支持跨模型结构迁移(需设置strict=False

2. 保存完整模型对象
python 复制代码
# 保存
torch.save(model, "full_model.pth")  

# 加载   loaded_model = torch.load("full_model.pth")
loaded_model.eval()

适用场景:

  • 快速原型验证

  • 模型结构包含动态逻辑(如自定义前向传播)

3. 训练断点保存与恢复
python 复制代码
# 保存检查点
checkpoint = {
    'epoch': current_epoch,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'loss': loss_value
}
torch.save(checkpoint, "checkpoint.tar")

# 恢复训练
model = TheModelClass()
optimizer = torch.optim.Adam(model.parameters())
checkpoint = torch.load("checkpoint.tar")
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
model.train()  # 保持训练模式

关键细节:

  • 推荐使用.tar后缀区分普通参数文件

  • 自动恢复学习率调度器等训练状态

相关推荐
蓝纹绿茶8 分钟前
Python程序使用了Ffmpeg,结束程序后,文件夹中仍然生成音频、视频文件
python·ubuntu·ffmpeg·音视频
mahuifa16 分钟前
OpenCV 开发 -- 图像基本处理
人工智能·python·opencv·计算机视觉
一个java开发1 小时前
distributed.client.Client 用户可调用函数分析
大数据·python
eqwaak01 小时前
Matplotlib 动态显示详解:技术深度与创新思考
网络·python·网络协议·tcp/ip·语言模型·matplotlib
007php0071 小时前
某大厂MySQL面试之SQL注入触点发现与SQLMap测试
数据库·python·sql·mysql·面试·职场和发展·golang
CodeCraft Studio1 小时前
Excel处理控件Aspose.Cells教程:使用 Python 将 Pandas DataFrame 转换为 Excel
python·json·excel·pandas·csv·aspose·dataframe
flashlight_hi1 小时前
LeetCode 分类刷题:2563. 统计公平数对的数目
python·算法·leetcode
java1234_小锋1 小时前
Scikit-learn Python机器学习 - 特征预处理 - 归一化 (Normalization):MinMaxScaler
python·机器学习·scikit-learn
星空的资源小屋2 小时前
网易UU远程,免费电脑远程控制软件
人工智能·python·pdf·电脑
IMER SIMPLE2 小时前
人工智能-python-深度学习-神经网络-MobileNet V1&V2
人工智能·python·深度学习