Pytorch如何保存和加载模型

来自 千问

在 PyTorch 中,保存和加载模型是深度学习工程中至关重要的一步。根据你的需求(是继续训练还是仅仅推理),有不同的最佳实践。

这里为你总结了 PyTorch 模型保存与加载的三种主要方式:

1. 仅保存/加载模型参数(⭐ 强烈推荐:用于部署和推理)

这是最常用、最灵活的方式。它只保存模型的权重(state_dict),不保存模型的结构。

  • 优点:文件体积小,移植性好,可以在不同代码结构间迁移权重。
  • 缺点:加载时必须先定义好模型结构。
python 复制代码
import torch
import torch.nn as nn

# 假设这是你的模型定义
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc = nn.Linear(10, 1)
    
    def forward(self, x):
        return self.fc(x)

# --- 保存 ---
model = MyModel()
# 训练代码...
torch.save(model.state_dict(), 'model_weights.pth') # 推荐使用 .pth 或 .pt 后缀

# --- 加载 (用于推理) ---
# 必须先重建模型结构
model = MyModel()
# 加载权重
model.load_state_dict(torch.load('model_weights.pth'))
model.eval() # 切换到评估模式,关闭 Dropout/BatchNorm 的训练行为

2. 保存/加载整个模型(不推荐用于生产)

这种方式将模型的结构和参数一起打包保存。

  • 优点:加载时不需要重新定义模型类,一行代码搞定。
  • 缺点:文件大,依赖具体的类定义路径。如果代码重构了(比如改了类名),加载时可能会报错;且存在安全风险。
python 复制代码
# --- 保存 ---
torch.save(model, 'whole_model.pth')

# --- 加载 ---
# 直接得到模型实例
model = torch.load('whole_model.pth')
model.eval()

3. 保存检查点(用于断点续训)

如果你需要中断训练并稍后继续(比如服务器断连了),你需要保存更多的信息,包括优化器状态、当前轮数和损失值。

python 复制代码
# 假设你有优化器
optimizer = torch.optim.Adam(model.parameters())

# --- 保存检查点 ---
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    # 可选:保存随机状态以保证结果可复现
    'rng_state': torch.get_rng_state()
}
torch.save(checkpoint, 'checkpoint_epoch_{}.pth'.format(epoch))

# --- 加载检查点 (恢复训练) ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load('checkpoint_epoch_10.pth', map_location=device)

model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
loss = checkpoint['loss']

# 恢复随机状态
torch.set_rng_state(checkpoint['rng_state'])

model.train() # 别忘了切换回训练模式

💡 关键技巧与避坑指南

  1. .eval().train() 模式

    • 推理前 务必调用 model.eval()。这会关闭 DropoutBatchNorm 的训练行为,保证推理结果稳定。
    • 继续训练前 务必调用 model.train()
  2. 设备不匹配问题(CPU/GPU)

    如果你在 GPU 上训练,想在 CPU 上推理,或者反之,使用 map_location 参数:

    python 复制代码
    # 无论当前是什么设备,都先加载到 CPU 上
    state_dict = torch.load('model.pth', map_location='cpu')
    # 或者自动映射到可用设备
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    state_dict = torch.load('model.pth', map_location=device)
  3. 多卡训练模型的加载

    如果模型是在多 GPU(DataParallel)上训练的,保存的权重键名会带有 module. 前缀。如果在单卡上加载会报错。解决方法是在加载时去掉前缀:

    python 复制代码
    state_dict = torch.load('model.pth')
    # 去除 'module.' 前缀
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
  4. 文件扩展名

    PyTorch 官方推荐使用 .pt.pth 作为文件后缀。

总结建议

  • 如果是为了部署上线 ,使用方式 1(仅保存参数)。
  • 如果是为了断点续训 ,使用方式 3(保存检查点)。
相关推荐
Piar1231sdafa4 小时前
蓝莓目标检测——改进YOLO11-C2TSSA-DYT-Mona模型实现
人工智能·目标检测·计算机视觉
愚公搬代码4 小时前
【愚公系列】《AI短视频创作一本通》002-AI引爆短视频创作革命(短视频创作者必备的能力)
人工智能
数据猿视觉4 小时前
新品上市|奢音S5耳夹耳机:3.5g无感佩戴,178.8元全场景适配
人工智能
2301_790300965 小时前
Python单元测试(unittest)实战指南
jvm·数据库·python
蚁巡信息巡查系统5 小时前
网站信息发布再巡查机制怎么建立?
大数据·人工智能·数据挖掘·内容运营
AI浩5 小时前
C-RADIOv4(技术报告)
人工智能·目标检测
Purple Coder5 小时前
AI赋予超导材料预测论文初稿
人工智能
Data_Journal5 小时前
Scrapy vs. Crawlee —— 哪个更好?!
运维·人工智能·爬虫·媒体·社媒营销
云边云科技_云网融合5 小时前
AIoT智能物联网平台:架构解析与边缘应用新图景
大数据·网络·人工智能·安全
VCR__5 小时前
python第三次作业
开发语言·python