模型保存与加载
学习心得
-
保存 CheckPoint 格式文件 ,在模型训练过程中,可以添加检查点(CheckPoint)用于保存模型的参数,以便进行推理及再训练使用。如果想继续在不同硬件平台上做推理,可通过网络和CheckPoint格式文件生成对应的MINDIR、AIR和ONNX格式文件。
pythonmodel = network() mindspore.save_checkpoint(model, "model.ckpt")
可以通过CheckpointConfig对象可以设置CheckPoint的保存策略。
- save_checkpoint_steps表示每隔多少个step保存一次。
- keep_checkpoint_max表示最多保留CheckPoint文件的数量。
- prefix表示生成CheckPoint文件的前缀名。
- directory表示存放文件的目录。
pythonfrom mindspore.train.callback import ModelCheckpoint, CheckpointConfig config_ck = CheckpointConfig(save_checkpoint_steps=32, keep_checkpoint_max=10) ckpoint_cb = ModelCheckpoint(prefix='resnet50', directory=None, config=config_ck) model.train(epoch_num, dataset, callbacks=ckpoint_cb)
要加载模型权重,需要先创建相同模型的实例,然后使用
load_checkpoint
和load_param_into_net
方法加载参数。pythonmodel = network() param_dict = mindspore.load_checkpoint("model.ckpt") param_not_load, _ = mindspore.load_param_into_net(model, param_dict) print(param_not_load)
param_not_load
是未被加载的参数列表,为空时代表所有参数均加载成功。cmd[]
-
保存和加载MindIR ,当有了CheckPoint文件后,如果想继续在MindSpore Lite端侧做推理,需要通过网络和CheckPoint生成对应的MINDIR格式模型文件。
- 统一表示:MindIR作为MindSpore云侧(训练)和端侧(推理)的统一模型文件,同时存储了网络结构和权重参数值。这使得MindSpore能够在不同的硬件平台上实现一次训练多次部署的能力。
- 导出MindIR:MindSpore提供了export接口,可以直接将模型保存为MindIR格式。
- 保存模型
pythonmodel = network() inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32)) mindspore.export(model, inputs, file_name="model", file_format="MINDIR")
- 加载模型
pythonmindspore.set_context(mode=mindspore.GRAPH_MODE) graph = mindspore.load("model.mindir") model = nn.GraphCell(graph) outputs = model(inputs) print(outputs.shape)