pytorch中保存训练模型和加载训练模型的用法

在 PyTorch 中,保存和加载模型主要有两种场景:一种是仅用于推理 (Inference),另一种是用于恢复训练(Resume Training)。

为了让你更直观地理解,我整理了一个核心对比表,随后会详细展开代码示例。

📊 核心对比:保存方式

场景 保存内容 推荐后缀 优点 缺点
仅推理 仅模型参数 (state_dict) .pt.pth 文件体积小,代码解耦,跨设备兼容性好。 加载前必须先定义模型结构类。
恢复训练 参数 + 优化器 + Epoch + Loss .tar 包含训练状态(如优化器动量),可无缝续训。 文件体积大,依赖具体的代码结构。

1. 仅保存和加载模型参数(最常用)

这是官方强烈推荐的方式。它只保存模型的"学习成果"(权重和偏置),不保存模型的结构代码。

保存代码
python 复制代码
# 假设 model 是你的网络结构
torch.save(model.state_dict(), 'model_weights.pth')
加载代码

注意:加载时必须先实例化一个同结构的模型对象,然后再把参数"灌"进去。

python 复制代码
# 1. 必须先定义/实例化模型结构(参数要和保存时一致)
model = MyModelClass(*args, **kwargs)

# 2. 加载参数字典
model.load_state_dict(torch.load('model_weights.pth'))

# 3. 重要:推理前必须调用 eval(),将 Dropout 和 BatchNorm 切换到评估模式
model.eval() 

提示model.eval() 非常重要,如果不调用,推理时的结果会因为 Dropout 等层的随机性而不准确。


2. 保存检查点(Checkpoint)以恢复训练

如果你训练到一半断电了,或者想接着之前的模型继续跑,就需要保存"检查点"。这不仅包含模型参数,还包含优化器的状态(如动量)、当前的 Epoch 数和你记录的 Loss。

保存代码

通常我们将这些信息打包成一个字典(Dictionary),并建议使用 .tar 后缀。

python 复制代码
# 假设 train_loader, optimizer, epoch, loss 都是当前状态
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    # 如果有学习率调度器也可以保存
    # 'scheduler_state_dict': scheduler.state_dict(), 
}, 'checkpoint.tar')
加载代码

加载时需要先初始化模型和优化器,然后加载字典并分别赋值。

python 复制代码
# 1. 初始化模型和优化器
model = MyModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(model.parameters(), lr=0.01)

# 2. 加载检查点
checkpoint = torch.load('checkpoint.tar')

# 3. 恢复状态
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# 4. 恢复训练模式
model.train() 

3. 常见坑点:GPU 与 CPU 的兼容

如果你是在 GPU 上训练的模型,但想在没有 GPU 的机器上(CPU)加载,或者反过来,需要特别注意 map_location 参数。

场景 A:GPU 训练 -> CPU 加载
python 复制代码
# 强制将数据映射到 CPU
model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
场景 B:CPU 训练 -> GPU 加载
python 复制代码
# 1. 先加载到 CPU
model.load_state_dict(torch.load('model.pth', map_location='cpu'))
# 2. 再移动到 GPU
model.to(torch.device('cuda'))

📌 总结建议

  1. 日常训练:建议每个 Epoch 保存一次 Checkpoint(包含 optimizer),防止意外中断。
  2. 最终发布 :训练完成后,提取 model.state_dict() 单独保存一个文件,用于最终的部署和推理,这样文件更小且更安全。
  3. 永远记得 :推理前加 model.eval(),续训前加 model.train()
相关推荐
xinlianyq2 小时前
2026 交互革命:当“图形界面”消亡于智能体(Agent)的语义洪流
人工智能·api
墨染天姬2 小时前
【AI】Gemma 4
人工智能
C+++Python2 小时前
如何学习Python的应用领域知识?
开发语言·python·学习
北京耐用通信2 小时前
工业通信升级:耐达讯自动化CAN转EtherCAT网关的高效落地方案
服务器·人工智能·科技·物联网·自动化·信息与通信
疯狂打码的少年2 小时前
【Day12 Java转Python】Python工程的“骨架”——模块、包与__name__
java·开发语言·python
LarryHai62 小时前
AI 大模型思维链原理:从COT到AOT,解锁大模型的推理潜力
人工智能·aot·cot·tot·大模型推理·大模型思维链
ueotek2 小时前
Ansys Zemax | 在 MATLAB 或 Python 中使用 ZOS-API 进行光线追迹的批次处理
python·matlab·ansys·zemax·光学软件
Lab_AI2 小时前
山东兴文携手创腾科技打造数智化研发新标杆!电子实验记录本ELN在精细化工领域再添标杆用户
人工智能·数字化转型·企业数据管理·数智化转型·电子实验记录本
Henry-SAP2 小时前
SAP MRP PIR消耗机制解析
人工智能·sap·erp