在 PyTorch 中,保存和加载模型主要有两种场景:一种是仅用于推理 (Inference),另一种是用于恢复训练(Resume Training)。
为了让你更直观地理解,我整理了一个核心对比表,随后会详细展开代码示例。
📊 核心对比:保存方式
| 场景 | 保存内容 | 推荐后缀 | 优点 | 缺点 |
|---|---|---|---|---|
| 仅推理 | 仅模型参数 (state_dict) |
.pt 或 .pth |
文件体积小,代码解耦,跨设备兼容性好。 | 加载前必须先定义模型结构类。 |
| 恢复训练 | 参数 + 优化器 + Epoch + Loss | .tar |
包含训练状态(如优化器动量),可无缝续训。 | 文件体积大,依赖具体的代码结构。 |
1. 仅保存和加载模型参数(最常用)
这是官方强烈推荐的方式。它只保存模型的"学习成果"(权重和偏置),不保存模型的结构代码。
保存代码
python
# 假设 model 是你的网络结构
torch.save(model.state_dict(), 'model_weights.pth')
加载代码
注意:加载时必须先实例化一个同结构的模型对象,然后再把参数"灌"进去。
python
# 1. 必须先定义/实例化模型结构(参数要和保存时一致)
model = MyModelClass(*args, **kwargs)
# 2. 加载参数字典
model.load_state_dict(torch.load('model_weights.pth'))
# 3. 重要:推理前必须调用 eval(),将 Dropout 和 BatchNorm 切换到评估模式
model.eval()
提示 :
model.eval()非常重要,如果不调用,推理时的结果会因为 Dropout 等层的随机性而不准确。
2. 保存检查点(Checkpoint)以恢复训练
如果你训练到一半断电了,或者想接着之前的模型继续跑,就需要保存"检查点"。这不仅包含模型参数,还包含优化器的状态(如动量)、当前的 Epoch 数和你记录的 Loss。
保存代码
通常我们将这些信息打包成一个字典(Dictionary),并建议使用 .tar 后缀。
python
# 假设 train_loader, optimizer, epoch, loss 都是当前状态
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
# 如果有学习率调度器也可以保存
# 'scheduler_state_dict': scheduler.state_dict(),
}, 'checkpoint.tar')
加载代码
加载时需要先初始化模型和优化器,然后加载字典并分别赋值。
python
# 1. 初始化模型和优化器
model = MyModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(model.parameters(), lr=0.01)
# 2. 加载检查点
checkpoint = torch.load('checkpoint.tar')
# 3. 恢复状态
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
# 4. 恢复训练模式
model.train()
3. 常见坑点:GPU 与 CPU 的兼容
如果你是在 GPU 上训练的模型,但想在没有 GPU 的机器上(CPU)加载,或者反过来,需要特别注意 map_location 参数。
场景 A:GPU 训练 -> CPU 加载
python
# 强制将数据映射到 CPU
model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
场景 B:CPU 训练 -> GPU 加载
python
# 1. 先加载到 CPU
model.load_state_dict(torch.load('model.pth', map_location='cpu'))
# 2. 再移动到 GPU
model.to(torch.device('cuda'))
📌 总结建议
- 日常训练:建议每个 Epoch 保存一次 Checkpoint(包含 optimizer),防止意外中断。
- 最终发布 :训练完成后,提取
model.state_dict()单独保存一个文件,用于最终的部署和推理,这样文件更小且更安全。 - 永远记得 :推理前加
model.eval(),续训前加model.train()。