PyTorch|保存及加载模型、nn.Sequential、ModuleList和ModuleDict

系列文章目录

PyTorch|Dataset与DataLoader使用、构建自定义数据集
PyTorch|搭建分类网络实例、nn.Module源码学习
pytorch|autograd使用、训练模型

文章目录


一、保存及加载模型

通过torch.save可以将该模型的参数、优化器状态、batch normalization、dropout、buffer变量等信息。

python 复制代码
import torch
import torchvision.models as models

(一)保存及加载模型的权重

模型取自torchvision.models里的vgg16,权重为IMAGENET1K_V1。

model.state_dict()是模型的权重。state_dict状态字典:一般包含当前model的参数及buffer变量

python 复制代码
model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')

推理时可以实现模型的加载:

  • 创建模型实例
  • 将实现保存的模型信息通过torch.load导入进来
  • 采用load_state_dict函数将模型信息载入模型实例
  • model.eval()使得模型进入推理模式
python 复制代码
model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

(二)保存及加载优化器的权重

保存优化器权重:

加载优化器权重:

(三)保存及加载整个模型

保存整个模型:

python 复制代码
torch.save(model, 'model.pth')

加载整个模型:

python 复制代码
model = torch.load('model.pth')

(四)保存及加载更具一般性的checkpoint

保存并加载用于推理或恢复训练的一般性checkpoint有助于从上次中断的地方重新开始。在保存一般检查点时,不仅仅是保存模型的state_dict,还包括保存优化器的state_dict、停止使用的时间,最近记录的训练损失,外部的torch.nn.Embedding层等等。

python 复制代码
# Additional information
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4

torch.save({
            'epoch': EPOCH,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS,
            }, PATH)

加载:

python 复制代码
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

(五)保存多个模型

保存多个模型时可以将其直接合并到一个大字典中保存。

python 复制代码
# Specify a path to save to
PATH = "model.pt"

torch.save({
            'modelA_state_dict': netA.state_dict(),
            'modelB_state_dict': netB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            }, PATH)

二、nn.Sequential源码分析

nn.Sequential是有序的,当实例化nn.Sequential时,传入的模块顺序就是神经网络前向传播的顺序

在使用nn.Sequential时,可以按顺序传入模块,也可以输入一个字典。

(一)init函数

如果输入的是一个字典,init函数会采用遍历字典的方式,如果是一个一个的模块,init函数也会针对性的采取其他遍历方法。

(二)forward函数

对于一个模型的输入,nn.Sequential会依次的过其中的子模块。

nn.Sequential相比于ModuleList和ModuleDict来说,优势在于具有forward的功能。

三、ModuleList和ModuleDict

(一)ModuleList

pytorch允许我们把很多子模块放到一个列表中。ModuleList就是用于存放多个子模块的一个列表,在使用时可以对其进行遍历。ModuleList不单纯是一个列表,它本身就是一个module。

(二)ModuleDict

ModuleDict是用于存放多个子模块的一个字典,在使用时可以根据索引获得对应的子模块。ModuleDict不单纯是一个字典,它本身也是一个module。

除此之外,还有ParameterList、ParameterDict等,这些与ModuleList和ModuleDict的作用及使用方式类似。

参考:
8、深入剖析PyTorch的state_dict、parameters、modules源码
9、深入剖析PyTorch的nn.Sequential及ModuleList源码

相关推荐
轻口味27 分钟前
【每日学点鸿蒙知识】沙箱目录、图片压缩、characteristicsArray、gm-crypto 国密加解密、通知权限
pytorch·华为·harmonyos
bryant_meng31 分钟前
【python】OpenCV—Image Moments
开发语言·python·opencv·moments·图片矩
车载诊断技术1 小时前
电子电气架构 --- 什么是EPS?
网络·人工智能·安全·架构·汽车·需求分析
KevinRay_1 小时前
Python超能力:高级技巧让你的代码飞起来
网络·人工智能·python·lambda表达式·列表推导式·python高级技巧
跃跃欲试-迪之1 小时前
animatediff 模型网盘分享
人工智能·stable diffusion
Captain823Jack1 小时前
nlp新词发现——浅析 TF·IDF
人工智能·python·深度学习·神经网络·算法·自然语言处理
被制作时长两年半的个人练习生1 小时前
【AscendC】ReduceSum中指定workLocal大小时如何计算
人工智能·算子开发·ascendc
资源补给站2 小时前
大恒相机开发(2)—Python软触发调用采集图像
开发语言·python·数码相机
Captain823Jack2 小时前
w04_nlp大模型训练·中文分词
人工智能·python·深度学习·神经网络·算法·自然语言处理·中文分词
Black_mario2 小时前
链原生 Web3 AI 网络 Chainbase 推出 AVS 主网, 拓展 EigenLayer AVS 应用场景
网络·人工智能·web3