pytorch中保存训练模型和加载训练模型的用法

在 PyTorch 中,保存和加载模型主要有两种场景:一种是仅用于推理 (Inference),另一种是用于恢复训练(Resume Training)。

为了让你更直观地理解,我整理了一个核心对比表,随后会详细展开代码示例。

📊 核心对比:保存方式

场景 保存内容 推荐后缀 优点 缺点
仅推理 仅模型参数 (state_dict) .pt.pth 文件体积小,代码解耦,跨设备兼容性好。 加载前必须先定义模型结构类。
恢复训练 参数 + 优化器 + Epoch + Loss .tar 包含训练状态(如优化器动量),可无缝续训。 文件体积大,依赖具体的代码结构。

1. 仅保存和加载模型参数(最常用)

这是官方强烈推荐的方式。它只保存模型的"学习成果"(权重和偏置),不保存模型的结构代码。

保存代码
python 复制代码
# 假设 model 是你的网络结构
torch.save(model.state_dict(), 'model_weights.pth')
加载代码

注意:加载时必须先实例化一个同结构的模型对象,然后再把参数"灌"进去。

python 复制代码
# 1. 必须先定义/实例化模型结构(参数要和保存时一致)
model = MyModelClass(*args, **kwargs)

# 2. 加载参数字典
model.load_state_dict(torch.load('model_weights.pth'))

# 3. 重要:推理前必须调用 eval(),将 Dropout 和 BatchNorm 切换到评估模式
model.eval() 

提示model.eval() 非常重要,如果不调用,推理时的结果会因为 Dropout 等层的随机性而不准确。


2. 保存检查点(Checkpoint)以恢复训练

如果你训练到一半断电了,或者想接着之前的模型继续跑,就需要保存"检查点"。这不仅包含模型参数,还包含优化器的状态(如动量)、当前的 Epoch 数和你记录的 Loss。

保存代码

通常我们将这些信息打包成一个字典(Dictionary),并建议使用 .tar 后缀。

python 复制代码
# 假设 train_loader, optimizer, epoch, loss 都是当前状态
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    # 如果有学习率调度器也可以保存
    # 'scheduler_state_dict': scheduler.state_dict(), 
}, 'checkpoint.tar')
加载代码

加载时需要先初始化模型和优化器,然后加载字典并分别赋值。

python 复制代码
# 1. 初始化模型和优化器
model = MyModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(model.parameters(), lr=0.01)

# 2. 加载检查点
checkpoint = torch.load('checkpoint.tar')

# 3. 恢复状态
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# 4. 恢复训练模式
model.train() 

3. 常见坑点:GPU 与 CPU 的兼容

如果你是在 GPU 上训练的模型,但想在没有 GPU 的机器上(CPU)加载,或者反过来,需要特别注意 map_location 参数。

场景 A:GPU 训练 -> CPU 加载
python 复制代码
# 强制将数据映射到 CPU
model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
场景 B:CPU 训练 -> GPU 加载
python 复制代码
# 1. 先加载到 CPU
model.load_state_dict(torch.load('model.pth', map_location='cpu'))
# 2. 再移动到 GPU
model.to(torch.device('cuda'))

📌 总结建议

  1. 日常训练:建议每个 Epoch 保存一次 Checkpoint(包含 optimizer),防止意外中断。
  2. 最终发布 :训练完成后,提取 model.state_dict() 单独保存一个文件,用于最终的部署和推理,这样文件更小且更安全。
  3. 永远记得 :推理前加 model.eval(),续训前加 model.train()
相关推荐
copyer_xyf7 分钟前
Prompt 组织管理
后端·python·agent
EasyCVR11 分钟前
国标GB28181视频监控平台EasyCVR夏季安防高风险场景的解决方案
人工智能·音视频
美狐美颜SDK开放平台18 分钟前
直播APP开发与第三方美颜SDK开发/集成实践分享
人工智能·美颜sdk·直播美颜sdk·视频美颜sdk·美颜api
邵宇然23 分钟前
llama.cpp 推理底座调优:从 KV Cache 到连续批处理的性能深潜
人工智能
云安全助手27 分钟前
Anthropic年度报告解读:AI重塑网络攻击形态,传统防御体系亟待升级
人工智能·安全·网络安全·ai大模型
pythonpioneer27 分钟前
PyTorch3D:基于 PyTorch 的高效 3D 深度学习工具库
pytorch·深度学习·其他·3d
谁似人间西林客36 分钟前
汽车智能制造解决方案:如何通过智能仓储物流降本提效?
人工智能·汽车·制造
shimly12345642 分钟前
python3 uvicorn 是啥?
python
jiushiapwojdap1 小时前
Antigravity Awesome Skills:1527+ AI 编程助手的可安装技能库
人工智能·其他