【昇思初学入门】第八天打卡-模型保存与加载

模型保存与加载

学习心得

  • 保存 CheckPoint 格式文件 ,在模型训练过程中,可以添加检查点(CheckPoint)用于保存模型的参数,以便进行推理及再训练使用。如果想继续在不同硬件平台上做推理,可通过网络和CheckPoint格式文件生成对应的MINDIR、AIR和ONNX格式文件。

    python 复制代码
    model = network()
    mindspore.save_checkpoint(model, "model.ckpt")

    可以通过CheckpointConfig对象可以设置CheckPoint的保存策略。

    • save_checkpoint_steps表示每隔多少个step保存一次。
    • keep_checkpoint_max表示最多保留CheckPoint文件的数量。
    • prefix表示生成CheckPoint文件的前缀名。
    • directory表示存放文件的目录。
    python 复制代码
    from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
    config_ck = CheckpointConfig(save_checkpoint_steps=32, keep_checkpoint_max=10)
    ckpoint_cb = ModelCheckpoint(prefix='resnet50', directory=None, config=config_ck)
    model.train(epoch_num, dataset, callbacks=ckpoint_cb)

    要加载模型权重,需要先创建相同模型的实例,然后使用load_checkpointload_param_into_net方法加载参数。

    python 复制代码
    	model = network()
    	param_dict = mindspore.load_checkpoint("model.ckpt")
    	param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
    	print(param_not_load)

    param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。

    cmd 复制代码
    [] 
  1. 保存和加载MindIR ,当有了CheckPoint文件后,如果想继续在MindSpore Lite端侧做推理,需要通过网络和CheckPoint生成对应的MINDIR格式模型文件。

    • 统一表示:MindIR作为MindSpore云侧(训练)和端侧(推理)的统一模型文件,同时存储了网络结构和权重参数值。这使得MindSpore能够在不同的硬件平台上实现一次训练多次部署的能力。
    • 导出MindIR:MindSpore提供了export接口,可以直接将模型保存为MindIR格式。
    • 保存模型
    python 复制代码
    model = network()
    inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))
    mindspore.export(model, inputs, file_name="model", file_format="MINDIR")
    • 加载模型
    python 复制代码
    mindspore.set_context(mode=mindspore.GRAPH_MODE)
    graph = mindspore.load("model.mindir")
    model = nn.GraphCell(graph)
    outputs = model(inputs)
    print(outputs.shape)
相关推荐
盼小辉丶1 小时前
Wasserstein GAN(WGAN)
人工智能·神经网络·生成对抗网络
XIAO·宝7 小时前
深度学习------专题《神经网络完成手写数字识别》
人工智能·深度学习·神经网络
Bugman.8 小时前
分类任务-三个重要网络模型
深度学习·机器学习·分类
小殊小殊9 小时前
超越CNN:GCN如何重塑图像处理
图像处理·人工智能·深度学习
Kaydeon12 小时前
【AIGC】50倍加速!NVIDIA蒸馏算法rCM:分数正则化连续时间一致性模型的大规模扩散蒸馏
人工智能·pytorch·python·深度学习·计算机视觉·aigc
机器学习之心13 小时前
PINN物理信息神经网络风电功率预测!引入物理先验知识嵌入学习的风电功率预测新范式!Matlab实现
神经网络·学习·matlab·风电功率预测·物理信息神经网络
三年呀13 小时前
深度剖析Mixture of Experts(MoE)架构:从原理到实践的全面指南
人工智能·深度学习·架构·模型优化·大规模模型
墨利昂14 小时前
神经网络常用激活函数公式
人工智能·深度学习·神经网络
初级炼丹师(爱说实话版)16 小时前
PGLRNet论文笔记
人工智能·深度学习·计算机视觉