昇思25天学习打卡营第8天|MindSpore保存与加载(保存和加载MindIR)

在MindIR中,一个函数图(FuncGraph)表示一个普通函数的定义,函数图一般由ParameterNode、ValueNode和CNode组成有向无环图,可以清晰地表达出从参数到返回值的计算过程。在上图中可以看出,python代码中两个函数test_f和func转换成了两个函数图,其参数x和y转换为函数图的ParameterNode,每一个表达式转换为一个CNode。CNode的第一个输入链接着调用的函数,例如图中的add、func、return。值得注意的是这些节点均是ValueNode,因为它们被理解为常数函数值。CNode的其他输入链接这调用的参数,参数值可以来自于ParameterNode、ValueNode和其他CNode。

在ANF中每个表达式都用let表达式绑定为一个变量,通过对变量的引用来表示对表达式输出的依赖,而在MindIR中每个表达式都绑定为一个节点,通过节点与节点之间的有向边表示依赖关系。

介绍了如何调整超参数,并进行网络模型训练。在训练网络模型的过程中,实际上我们希望保存中间和最后的结果,用于微调(fine-tune)和后续的模型推理与部署,本章节我们将介绍如何保存与加载模型。

js 复制代码
%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
import numpy as np
import mindspore
from mindspore import nn
from mindspore import Tensor
def network():
    model = nn.SequentialCell(
                nn.Flatten(),
                nn.Dense(28*28, 512),
                nn.ReLU(),
                nn.Dense(512, 512),
                nn.ReLU(),
                nn.Dense(512, 10))
    return model

保存和加载模型权重

保存模型使用save_checkpoint接口,传入网络和指定的保存路径:

js 复制代码
model = network()
mindspore.save_checkpoint(model, "model.ckpt")

要加载模型权重,需要先创建相同模型的实例,然后使用load_checkpoint和load_param_into_net方法加载参数。

js 复制代码
model = network()
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)
[]

param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。

保存和加载MindIR

除Checkpoint外,MindSpore提供了云侧(训练)和端侧(推理)统一的中间表示(Intermediate Representation,IR)。可使用export接口直接将模型保存为MindIR。

js 复制代码
model = network()
inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))
mindspore.export(model, inputs, file_name="model", file_format="MINDIR")

MindIR同时保存了Checkpoint和模型结构,因此需要定义输入Tensor来获取输入shape。

已有的MindIR模型可以方便地通过load接口加载,传入nn.GraphCell即可进行推理。

nn.GraphCell仅支持图模式。

js 复制代码
mindspore.set_context(mode=mindspore.GRAPH_MODE)
​
graph = mindspore.load("model.mindir")
model = nn.GraphCell(graph)
outputs = model(inputs)
print(outputs.shape)
(1, 10)
相关推荐
helloyaren3 小时前
Docker Desktop里搭建RabbitMq 4.1.3集群的保姆级教程
学习·rabbitmq·集群
艾莉丝努力练剑3 小时前
【C语言16天强化训练】从基础入门到进阶:Day 6
c语言·数据结构·学习·算法
Insist7534 小时前
k8s----学习站点搭建
学习
月盈缺4 小时前
学习嵌入式第二十三天——数据结构——栈
数据结构·学习
mysla5 小时前
嵌入式学习day34-网络-tcp/udp
服务器·网络·学习
Moonnnn.5 小时前
【51单片机学习】AT24C02(I2C)、DS18B20(单总线)、LCD1602(液晶显示屏)
笔记·单片机·学习·51单片机
牛奶咖啡136 小时前
学习设计模式《二十三》——桥接模式
学习·设计模式·桥接模式·认识桥接模式·桥接模式的优点·何时选用桥接模式·桥接模式的使用示例
Jenkinscao7 小时前
我从零开始学习C语言(13)- 循环语句 PART2
c语言·开发语言·学习
新子y8 小时前
【操作记录】我的 MNN Android LLM 编译学习笔记记录(一)
android·学习·mnn
✎ ﹏梦醒͜ღ҉繁华落℘9 小时前
单片机学习---字节对齐
单片机·嵌入式硬件·学习