深度学习Pytorch中的模型保存与加载方法

深度学习:Pytorch中的模型保存与加载方法

在 PyTorch 中,模型的保存和加载对于模型的持久化和后续应用至关重要。这里详细介绍了两种主要方法:保存整个模型(包括架构和参数)和仅保存模型的状态字典。以下内容进一步完善了加载模型时使用 load_state_dict 方法的细节和参数设置。

1. 保存和加载整个模型

方法描述

这种方法涉及使用 torch.save 函数直接保存模型的整个对象,并通过 torch.load 来恢复整个模型,包括其架构和所有参数。

参数和设置
  • 保存时的参数和设置

    • model: 需要保存的模型实例。
    • filepath: 模型保存的文件路径,如 'vgg16_complete_model.pth'
  • 加载时的参数和设置

    • filepath: 保存的模型文件路径,用于加载模型。
示例代码
python 复制代码
import torch
import torchvision.models as models

# 初始化模型
model = models.vgg16(pretrained=True)

# 保存整个模型
torch.save(model, 'vgg16_complete_model.pth')

# 加载整个模型
loaded_model = torch.load('vgg16_complete_model.pth')
print("Loaded Model:", loaded_model)
优点
  • 快速便捷,可以直接恢复整个模型,包括其架构和参数。
缺点
  • 安全风险较高,因为使用了 pickle 序列化,可能在加载时执行未知代码。
  • 文件体积较大,包含了整个模型的详细信息。

2. 仅保存和加载模型的状态字典

方法描述

此方法仅保存模型的参数,使用 state_dict 获取所有参数的字典,并通过先定义模型架构后使用 load_state_dict 方法加载参数。

参数和设置
  • 保存时的参数和设置

    • model.state_dict(): 调用模型的 state_dict() 方法获取所有参数的字典。
    • filepath: 参数字典保存的文件路径,如 'vgg16_model_state_dict.pth'
  • 加载时的参数和设置

    • filepath: 保存的状态字典文件路径,用于加载参数。
    • 需要先重新定义模型结构以匹配状态字典。
    • 使用 load_state_dict 方法将参数加载到模型。
示例代码
python 复制代码
import torch
import torchvision.models as models

# 初始化预训练模型
model = models.vgg16(pretrained=True)

# 保存模型的状态字典
torch.save(model.state_dict(), 'vgg16_model_state_dict.pth')

# 重新定义模型架构
model = models.vgg16()  # 重新初始化模型

# 加载模型参数
model.load_state_dict(torch.load('vgg16_model_state_dict.pth'))
print("Loaded State Dict:", model)
load_state_dict 方法详解
  • load_state_dict 是一个 PyTorch 方法,用于将保存的状态字典加载到模型的参数中。
  • 此方法确保所有保存的参数能正确映射到模型的相应参数上。
  • 如果状态字典中的参数与模型架构不匹配,将抛出错误。
优点
  • 安全性高,不涉及 pickle,避免潜在的代码执行风险。
  • 文件体积小,因为仅包含参数。
缺点
  • 需要在加载参数前重新定义模型架构,可能在结构复杂或难以重建时带来不便。

比较异同

  • 相同点

    • 使用 torch.save()torch.load() 函数来保存和加载。
    • 都能有效地保存和恢复模型数据,方便模型的迁移和再使用。
  • 不同点

    • 安全性:保存整个模型可能面临执行恶意代码的风险,而保存状态字典则相对安全。
    • 便利性:保存整个模型可以直接加载使用,适合快速部署;而保存状态字典需要先定义架构,但提供更高的灵活性和安全性。
    • 文件大小:整个模型文件通常较大;状态字典文件体积较小,仅包含参数。

这些详细的说明帮助开发者根据具体需求选择合适的方法来保存和加载 PyTorch 模型,确保模型的有效部署和应用,同时最大化效率和安全性。

相关推荐
方见华Richard几秒前
整数阶时间重参数化:基于自适应豪斯多夫维数的偏微分方程正则化新框架
人工智能·笔记·交互·原型模式·空间计算
盼小辉丶12 分钟前
PyTorch实战(27)——自动混合精度训练
pytorch·深度学习·混合精度训练
aihuangwu14 分钟前
如何把豆包的回答导出
人工智能·ai·deepseek·ds随心转
好奇龙猫16 分钟前
【人工智能学习-AI入试相关题目练习-第十六次】
人工智能·学习
bing.shao19 分钟前
Golang 开发者视角:解读《“人工智能 + 制造” 专项行动》的技术落地机遇
人工智能·golang·制造
LOnghas121119 分钟前
玉米目标检测实战:基于YOLO13-C3k2-RFAConv的优化方案_1
人工智能·目标检测·计算机视觉
量子-Alex30 分钟前
【大模型课程笔记】斯坦福大学CS336 课程环境配置与讲座生成完整指南
人工智能·笔记
冬奇Lab34 分钟前
一天一个开源项目(第9篇):NexaSDK - 跨平台设备端 AI 运行时,让前沿模型在本地运行
人工智能·开源
量子-Alex44 分钟前
【大模型技术报告】Qwen2-VL大模型训练过程理解
人工智能
java1234_小锋1 小时前
【AI大模型舆情分析】微博舆情分析可视化系统(pytorch2+基于BERT大模型训练微调+flask+pandas+echarts) 实战(上)
人工智能·flask·大模型·bert