来自 千问
在 PyTorch 中,保存和加载模型是深度学习工程中至关重要的一步。根据你的需求(是继续训练还是仅仅推理),有不同的最佳实践。
这里为你总结了 PyTorch 模型保存与加载的三种主要方式:
1. 仅保存/加载模型参数(⭐ 强烈推荐:用于部署和推理)
这是最常用、最灵活的方式。它只保存模型的权重(state_dict),不保存模型的结构。
- 优点:文件体积小,移植性好,可以在不同代码结构间迁移权重。
- 缺点:加载时必须先定义好模型结构。
python
import torch
import torch.nn as nn
# 假设这是你的模型定义
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
return self.fc(x)
# --- 保存 ---
model = MyModel()
# 训练代码...
torch.save(model.state_dict(), 'model_weights.pth') # 推荐使用 .pth 或 .pt 后缀
# --- 加载 (用于推理) ---
# 必须先重建模型结构
model = MyModel()
# 加载权重
model.load_state_dict(torch.load('model_weights.pth'))
model.eval() # 切换到评估模式,关闭 Dropout/BatchNorm 的训练行为
2. 保存/加载整个模型(不推荐用于生产)
这种方式将模型的结构和参数一起打包保存。
- 优点:加载时不需要重新定义模型类,一行代码搞定。
- 缺点:文件大,依赖具体的类定义路径。如果代码重构了(比如改了类名),加载时可能会报错;且存在安全风险。
python
# --- 保存 ---
torch.save(model, 'whole_model.pth')
# --- 加载 ---
# 直接得到模型实例
model = torch.load('whole_model.pth')
model.eval()
3. 保存检查点(用于断点续训)
如果你需要中断训练并稍后继续(比如服务器断连了),你需要保存更多的信息,包括优化器状态、当前轮数和损失值。
python
# 假设你有优化器
optimizer = torch.optim.Adam(model.parameters())
# --- 保存检查点 ---
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
# 可选:保存随机状态以保证结果可复现
'rng_state': torch.get_rng_state()
}
torch.save(checkpoint, 'checkpoint_epoch_{}.pth'.format(epoch))
# --- 加载检查点 (恢复训练) ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load('checkpoint_epoch_10.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
loss = checkpoint['loss']
# 恢复随机状态
torch.set_rng_state(checkpoint['rng_state'])
model.train() # 别忘了切换回训练模式
💡 关键技巧与避坑指南
-
.eval()和.train()模式:- 推理前 务必调用
model.eval()。这会关闭Dropout和BatchNorm的训练行为,保证推理结果稳定。 - 继续训练前 务必调用
model.train()。
- 推理前 务必调用
-
设备不匹配问题(CPU/GPU) :
如果你在 GPU 上训练,想在 CPU 上推理,或者反之,使用
map_location参数:python# 无论当前是什么设备,都先加载到 CPU 上 state_dict = torch.load('model.pth', map_location='cpu') # 或者自动映射到可用设备 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") state_dict = torch.load('model.pth', map_location=device) -
多卡训练模型的加载 :
如果模型是在多 GPU(
DataParallel)上训练的,保存的权重键名会带有module.前缀。如果在单卡上加载会报错。解决方法是在加载时去掉前缀:pythonstate_dict = torch.load('model.pth') # 去除 'module.' 前缀 from collections import OrderedDict new_state_dict = OrderedDict() for k, v in state_dict.items(): name = k[7:] if k.startswith('module.') else k new_state_dict[name] = v model.load_state_dict(new_state_dict) -
文件扩展名 :
PyTorch 官方推荐使用
.pt或.pth作为文件后缀。
总结建议:
- 如果是为了部署上线 ,使用方式 1(仅保存参数)。
- 如果是为了断点续训 ,使用方式 3(保存检查点)。