Lnton羚通关于PyTorch的保存和加载模型基础知识

SAVE AND LOAD THE MODEL (保存和加载模型)

PyTorch 模型存储学习到的参数在内部状态字典中,称为 state_dict, 他们的持久化通过 torch.save 方法。

复制代码
model = models.shufflenet_v2_x0_5(pretrained=True)
torch.save(model, "../../data/ShuffleNetV2_X0.5.pth")

如果要加载模型的话,首先需要实例化一个同类型的模型对象,然后用 load_state_dict() 方法加载参数。

复制代码
model = models.shufflenet_v2_x0_5()
model.load_state_dict(torch.load("../../data/ShuffleNetV2_X0.5.pth"))
model.eval()

Output exceeds the size limit. Open the full output data in a text editor
ShuffleNetV2(
  (conv1): Sequential(
    (0): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (stage2): Sequential(
    (0): InvertedResidual(
      (branch1): Sequential(
        (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=24, bias=False)
        (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): ReLU(inplace=True)
      )
      (branch2): Sequential(
        (0): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(24, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=24, bias=False)
        (4): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (6): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (7): ReLU(inplace=True)
...
    (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (fc): Linear(in_features=1024, out_features=1000, bias=True)
)

Saving and Loading Models with Shapes

当加载模型权重时,我们需要首先实例化模型类,因为类定义了网络的结构。我们可能想要保存类的结构以及模型,在这种情况下,我们可以将 model (而不是 model.state_dict() ) 传递给保存函数:

复制代码
torch.save(model, "../../data/ShuffleNetV2_X0.5_eval2.pth")

加载模型如这样:

复制代码
model = torch.load("../../data/ShuffleNetV2_X0.5_eval2.pth")
print(model)

这种方法在序列化模型时使用 Python pickle 模块,因此它依赖于加载模型时可用的实际类定义。

Lnton羚通专注于音视频算法、算力、云平台的高科技人工智能企业。 公司基于视频分析技术、视频智能传输技术、远程监测技术以及智能语音融合技术等, 拥有多款可支持ONVIF、RTSP、GB/T28181等多协议、多路数的音视频智能分析服务器/云平台。

相关推荐
2501_941982059 分钟前
结合 AI 视觉:使用 OCR 识别企业微信聊天记录中的图片信息
人工智能·ocr·企业微信
我送炭你添花11 分钟前
Pelco KBD300A 模拟器:06+2.Pelco KBD300A 模拟器项目重构指南
python·重构·自动化·运维开发
Swizard13 分钟前
别再只会算直线距离了!用“马氏距离”揪出那个伪装的数据“卧底”
python·算法·ai
站大爷IP14 分钟前
Python函数与模块化编程:局部变量与全局变量的深度解析
python
我命由我1234522 分钟前
Python Flask 开发问题:ImportError: cannot import name ‘Markup‘ from ‘flask‘
开发语言·后端·python·学习·flask·学习方法·python3.11
事变天下24 分钟前
肾尚科技完成新一轮融资,加速慢性肾脏病(CKD)精准化管理闭环渗透
大数据·人工智能
GEO AI搜索优化助手26 分钟前
范式革命——从“关键词”到“意图理解”,搜索本质的演进与重构
人工智能·搜索引擎·生成式引擎优化·ai优化·geo搜索优化
大刘讲IT27 分钟前
2025年企业级 AI Agent 标准化落地深度年度总结:从“对话”到“端到端价值闭环”的范式重构
大数据·人工智能·程序人生·ai·重构·制造
databook30 分钟前
掌握相关性分析:读懂数据间的“悄悄话”
python·数据挖掘·数据分析
2301_8234380235 分钟前
【无标题】解析《采用非对称自玩实现强健多机器人群集的深度强化学习方法》
数据库·人工智能·算法