pytorch中保存训练模型和加载训练模型的用法

在 PyTorch 中,保存和加载模型主要有两种场景:一种是仅用于推理 (Inference),另一种是用于恢复训练(Resume Training)。

为了让你更直观地理解,我整理了一个核心对比表,随后会详细展开代码示例。

📊 核心对比:保存方式

场景 保存内容 推荐后缀 优点 缺点
仅推理 仅模型参数 (state_dict) .pt.pth 文件体积小,代码解耦,跨设备兼容性好。 加载前必须先定义模型结构类。
恢复训练 参数 + 优化器 + Epoch + Loss .tar 包含训练状态(如优化器动量),可无缝续训。 文件体积大,依赖具体的代码结构。

1. 仅保存和加载模型参数(最常用)

这是官方强烈推荐的方式。它只保存模型的"学习成果"(权重和偏置),不保存模型的结构代码。

保存代码
python 复制代码
# 假设 model 是你的网络结构
torch.save(model.state_dict(), 'model_weights.pth')
加载代码

注意:加载时必须先实例化一个同结构的模型对象,然后再把参数"灌"进去。

python 复制代码
# 1. 必须先定义/实例化模型结构(参数要和保存时一致)
model = MyModelClass(*args, **kwargs)

# 2. 加载参数字典
model.load_state_dict(torch.load('model_weights.pth'))

# 3. 重要:推理前必须调用 eval(),将 Dropout 和 BatchNorm 切换到评估模式
model.eval() 

提示model.eval() 非常重要,如果不调用,推理时的结果会因为 Dropout 等层的随机性而不准确。


2. 保存检查点(Checkpoint)以恢复训练

如果你训练到一半断电了,或者想接着之前的模型继续跑,就需要保存"检查点"。这不仅包含模型参数,还包含优化器的状态(如动量)、当前的 Epoch 数和你记录的 Loss。

保存代码

通常我们将这些信息打包成一个字典(Dictionary),并建议使用 .tar 后缀。

python 复制代码
# 假设 train_loader, optimizer, epoch, loss 都是当前状态
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    # 如果有学习率调度器也可以保存
    # 'scheduler_state_dict': scheduler.state_dict(), 
}, 'checkpoint.tar')
加载代码

加载时需要先初始化模型和优化器,然后加载字典并分别赋值。

python 复制代码
# 1. 初始化模型和优化器
model = MyModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(model.parameters(), lr=0.01)

# 2. 加载检查点
checkpoint = torch.load('checkpoint.tar')

# 3. 恢复状态
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# 4. 恢复训练模式
model.train() 

3. 常见坑点:GPU 与 CPU 的兼容

如果你是在 GPU 上训练的模型,但想在没有 GPU 的机器上(CPU)加载,或者反过来,需要特别注意 map_location 参数。

场景 A:GPU 训练 -> CPU 加载
python 复制代码
# 强制将数据映射到 CPU
model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
场景 B:CPU 训练 -> GPU 加载
python 复制代码
# 1. 先加载到 CPU
model.load_state_dict(torch.load('model.pth', map_location='cpu'))
# 2. 再移动到 GPU
model.to(torch.device('cuda'))

📌 总结建议

  1. 日常训练:建议每个 Epoch 保存一次 Checkpoint(包含 optimizer),防止意外中断。
  2. 最终发布 :训练完成后,提取 model.state_dict() 单独保存一个文件,用于最终的部署和推理,这样文件更小且更安全。
  3. 永远记得 :推理前加 model.eval(),续训前加 model.train()
相关推荐
羊羊小栈2 小时前
基于「YOLO目标检测 + 多模态AI分析」的篮球动作规范智能检测分析预警系统
人工智能·yolo·目标检测·计算机视觉·毕业设计·大作业
天上路人2 小时前
双波束拾音技术在双向翻译机中的应用 —— 基于 A-59F 模组的原理、效果与场景解析
人工智能·语音识别
Fleshy数模2 小时前
基于 LangChain 实现 PDF 文档检索:从加载到向量检索全流程
人工智能·数据挖掘·langchain·大模型
小袁说公考2 小时前
公考培训机构2025年度测评:财务健康度与用户体验重构排名格局
大数据·人工智能·经验分享·笔记·其他·重构·ux
xinlianyq2 小时前
2026 AI 生成电商短视频工具推荐与性价比分析,电商垂类与综合工具
人工智能·ai
跨境猫小妹2 小时前
爆款复制难度提高之后跨境卖家如何转向稳定型商品布局
大数据·人工智能·产品运营·跨境电商·营销策略
跟尚西学PowerBI2 小时前
【供应链AI实践案例】OpenClaw+PowerBI 打造 AI 智能库存预警实战
大数据·人工智能·数据分析·openclaw
动物园猫2 小时前
交通标识与信号灯数据集分享(适用于YOLO系列深度学习检测任务)
人工智能·深度学习·yolo
weixin_377634842 小时前
【SkillRL】强化学习详解
人工智能
吃好睡好便好2 小时前
在Matlab中绘制抛物三维曲面图
开发语言·人工智能·学习·算法·matlab·信息可视化