【昇思初学入门】第八天打卡-模型保存与加载

模型保存与加载

学习心得

  • 保存 CheckPoint 格式文件 ,在模型训练过程中,可以添加检查点(CheckPoint)用于保存模型的参数,以便进行推理及再训练使用。如果想继续在不同硬件平台上做推理,可通过网络和CheckPoint格式文件生成对应的MINDIR、AIR和ONNX格式文件。

    python 复制代码
    model = network()
    mindspore.save_checkpoint(model, "model.ckpt")

    可以通过CheckpointConfig对象可以设置CheckPoint的保存策略。

    • save_checkpoint_steps表示每隔多少个step保存一次。
    • keep_checkpoint_max表示最多保留CheckPoint文件的数量。
    • prefix表示生成CheckPoint文件的前缀名。
    • directory表示存放文件的目录。
    python 复制代码
    from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
    config_ck = CheckpointConfig(save_checkpoint_steps=32, keep_checkpoint_max=10)
    ckpoint_cb = ModelCheckpoint(prefix='resnet50', directory=None, config=config_ck)
    model.train(epoch_num, dataset, callbacks=ckpoint_cb)

    要加载模型权重,需要先创建相同模型的实例,然后使用load_checkpointload_param_into_net方法加载参数。

    python 复制代码
    	model = network()
    	param_dict = mindspore.load_checkpoint("model.ckpt")
    	param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
    	print(param_not_load)

    param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。

    cmd 复制代码
    [] 
  1. 保存和加载MindIR ,当有了CheckPoint文件后,如果想继续在MindSpore Lite端侧做推理,需要通过网络和CheckPoint生成对应的MINDIR格式模型文件。

    • 统一表示:MindIR作为MindSpore云侧(训练)和端侧(推理)的统一模型文件,同时存储了网络结构和权重参数值。这使得MindSpore能够在不同的硬件平台上实现一次训练多次部署的能力。
    • 导出MindIR:MindSpore提供了export接口,可以直接将模型保存为MindIR格式。
    • 保存模型
    python 复制代码
    model = network()
    inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))
    mindspore.export(model, inputs, file_name="model", file_format="MINDIR")
    • 加载模型
    python 复制代码
    mindspore.set_context(mode=mindspore.GRAPH_MODE)
    graph = mindspore.load("model.mindir")
    model = nn.GraphCell(graph)
    outputs = model(inputs)
    print(outputs.shape)
相关推荐
盼小辉丶6 小时前
TensorFlow深度学习实战——胶囊网络
深度学习·tensorflow·keras
doubao369 小时前
如何有效降低AIGC生成内容被识别的概率?
人工智能·深度学习·自然语言处理·aigc·ai写作
Danceful_YJ11 小时前
31.注意力评分函数
pytorch·python·深度学习
xier_ran12 小时前
深度学习:神经网络中的参数和超参数
人工智能·深度学习
8Qi812 小时前
伪装图像生成之——GAN与Diffusion
人工智能·深度学习·神经网络·生成对抗网络·图像生成·伪装图像生成
思通数科多模态大模型13 小时前
扑灭斗殴的火苗:AI智能守护如何为校园安全保驾护航
大数据·人工智能·深度学习·安全·目标检测·计算机视觉·数据挖掘
xixixi7777714 小时前
了解一下LSTM:长短期记忆网络(改进的RNN)
人工智能·深度学习·机器学习
能来帮帮蒟蒻吗14 小时前
深度学习(1)—— 基本概念
人工智能·深度学习
LeonDL16814 小时前
基于YOLO11深度学习的电动车头盔检测系统【Python源码+Pyqt5界面+数据集+安装使用教程+训练代码】【附下载链接】
人工智能·python·深度学习·pyqt5·yolo数据集·电动车头盔检测系统·yolo11深度学习
carver w14 小时前
彻底理解传统卷积,深度可分离卷积
人工智能·深度学习·计算机视觉