本篇技术博文摘要 🌟
- 本文系统阐述了PyTorch中模型保存与加载的完整策略。
- 开篇对比了两种基本方法:一是直接保存整个模型,虽简便但存在模型定义依赖和版本兼容风险;二是推荐仅保存模型参数 (
state_dict),该方法灵活且安全,是实际开发中的首选。- 进而,文章详解了如何保存与加载完整的训练状态 (包括优化器状态、轮次等),以实现精确的训练断点续练。针对部署环节,重点探讨了跨设备加载模型的兼容性处理,如CPU与GPU间转换及多GPU训练模型的加载技巧。
- 同时,深入分析了模型转换与兼容性 核心问题,包括应对PyTorch版本差异的策略,以及将模型转换为TorchScript格式以实现独立于Python环境的部署。
- 文章还汇总了最佳实践 与常见问题解决方案 ,例如处理
Missing key(s) in state_dict报错和CUDA内存溢出等典型难题。- 最后,通过一个图像分类项目的完整代码示例与流程图,直观演示了从训练、保存到加载推理的端到端流程,将理论指导落地为可执行的实践代码,为开发者提供了一站式解决方案。
引言 📘
- 在这个变幻莫测、快速发展的技术时代,与时俱进是每个IT工程师的必修课。
- 我是盛透侧视攻城狮,一名什么都会一丢丢的网络安全工程师,也是众多技术社区的活跃成员以及多家大厂官方认可人员,希望能够与各位在此共同成长。

上节回顾
目录
[本篇技术博文摘要 🌟](#本篇技术博文摘要 🌟)
[引言 📘](#引言 📘)
[1.PyTorch 模型保存和加载](#1.PyTorch 模型保存和加载)
[6.2.1问题1:Missing key(s) in state_dict](#6.2.1问题1:Missing key(s) in state_dict)
[6.2.2问题2:CUDA out of memory](#6.2.2问题2:CUDA out of memory)

1.PyTorch 模型保存和加载
- 在深度学习项目中,模型保存和加载是至关重要的环节,主要原因包括:
- 训练中断恢复:当训练过程意外中断时,可以从保存点继续训练
- 模型部署:将训练好的模型部署到生产环境
- 模型共享:方便团队成员之间共享模型成果
- 迁移学习:保存预训练模型用于其他任务
- 性能评估:保存不同训练阶段的模型进行比较

2.基本保存和加载方法
2.1保存整个模型及示例
- 这是最简单的方法,保存模型的架构和参数:
python
import torch
import torchvision.models as models
# 创建并训练一个模型
model = models.resnet18(pretrained=True)
# ... 训练代码 ...
# 保存整个模型
torch.save(model, 'model.pth')
# 加载整个模型
loaded_model = torch.load('model.pth')
2.1.1保存整个模型优点
- 代码简单直观
- 保存了完整的模型结构
2.1.2保存整个模型缺点
- 文件体积较大
- 对模型类的定义有依赖

2.2推荐方式:仅保存模型参数及示例
- 更推荐的方式是只保存模型的状态字典(state_dict):
python
# 保存模型参数
torch.save(model.state_dict(), 'model_weights.pth')
# 加载模型参数
model = models.resnet18() # 必须先创建相同架构的模型
model.load_state_dict(torch.load('model_weights.pth'))
model.eval() # 设置为评估模式
2.2.1仅保存模型参数优点:
- 文件更小
- 更灵活,可以加载到不同架构中
- 兼容性更好

3.保存和加载训练状态
- 在实际项目中,我们通常还需要保存优化器状态、epoch等信息:
3.1保存和加载训练状态及示例
python
# 保存检查点
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
# 可以添加其他需要保存的信息
}
torch.save(checkpoint, 'checkpoint.pth')
# 加载检查点
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval() # 或者 model.train() 取决于你的需求

4.跨设备加载模型
4.1CPU/GPU兼容性处理示例
python
# 保存时指定map_location
torch.save(model.state_dict(), 'model_weights.pth')
# 加载到CPU(当模型是在GPU上训练时)
device = torch.device('cpu')
model.load_state_dict(torch.load('model_weights.pth', map_location=device))
# 加载到GPU
device = torch.device('cuda')
model.load_state_dict(torch.load('model_weights.pth', map_location=device))
model.to(device)
4.2多GPU训练模型加载示例
python
# 保存多GPU模型
torch.save(model.module.state_dict(), 'multigpu_model.pth')
# 加载到单GPU
model = ModelClass()
model.load_state_dict(torch.load('multigpu_model.pth'))

5.模型转换与兼容性
5.1PyTorch版本兼容性及示例
python
# 保存时指定_use_new_zipfile_serialization=True以获得更好的兼容性
torch.save(model.state_dict(), 'model.pth', _use_new_zipfile_serialization=True)
5.2转换为TorchScript及示例
python
# 将模型转换为TorchScript格式
scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, 'model_scripted.pt')
# 加载TorchScript模型
loaded_script = torch.jit.load('model_scripted.pt')

6.最佳实践与常见问题
6.1最佳实践
- 命名规范 :使用有意义的文件名,如
resnet18_epoch50.pth- 定期保存:每隔几个epoch保存一次检查点
- 验证加载:保存后立即测试加载功能
- 文档记录:记录模型架构和训练参数
- 版本控制:将模型文件纳入版本控制系统

6.2常见问题解决方案
6.2.1问题1:Missing key(s) in state_dict
- 确保模型架构完全匹配,或使用
strict=False参数:
python
model.load_state_dict(torch.load('model.pth'), strict=False)
6.2.2问题2:CUDA out of memory
- 加载时先放到CPU:
python
model.load_state_dict(torch.load('model.pth', map_location='cpu'))

7.图像分类模型保存与加载流程代码汇总及示例
7.1图像分类模型保存与加载流程图

7.2完整代码示例
python
import torch
import torch.nn as nn
import torch.optim as optim
# 定义一个简单模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(10, 2)
def forward(self, x):
return self.fc(x)
# 初始化
model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
# 模拟训练过程
for epoch in range(5):
# 模拟训练步骤
inputs = torch.randn(32, 10)
labels = torch.randint(0, 2, (32,))
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# 每2个epoch保存一次检查点
if epoch % 2 == 0:
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss.item(),
}
torch.save(checkpoint, f'checkpoint_epoch{epoch}.pth')
print(f'Checkpoint saved at epoch {epoch}')
# 最终保存
torch.save(model.state_dict(), 'final_model.pth')
# 加载示例
loaded_model = SimpleModel()
loaded_model.load_state_dict(torch.load('final_model.pth'))
loaded_model.eval()
# 测试加载的模型
test_input = torch.randn(1, 10)
with torch.no_grad():
output = loaded_model(test_input)
print(f'Test output: {output}')

欢迎各位彦祖与热巴畅游本人专栏与技术博客
你的三连是我最大的动力
点击➡️指向的专栏名即可闪现
➡️计算机组成原理****
➡️操作系统
➡️****渗透终极之红队攻击行动********
➡️ 动画可视化数据结构与算法
➡️ 永恒之心蓝队联纵合横防御
➡️****华为高级网络工程师********
➡️****华为高级防火墙防御集成部署********
➡️ 未授权访问漏洞横向渗透利用
➡️****逆向软件破解工程********
➡️****MYSQL REDIS 进阶实操********
➡️****红帽高级工程师
➡️红帽系统管理员********
➡️****HVV 全国各地面试题汇总********

PyTorch版本兼容性