【昇思初学入门】第八天打卡-模型保存与加载

模型保存与加载

学习心得

  • 保存 CheckPoint 格式文件 ,在模型训练过程中,可以添加检查点(CheckPoint)用于保存模型的参数,以便进行推理及再训练使用。如果想继续在不同硬件平台上做推理,可通过网络和CheckPoint格式文件生成对应的MINDIR、AIR和ONNX格式文件。

    python 复制代码
    model = network()
    mindspore.save_checkpoint(model, "model.ckpt")

    可以通过CheckpointConfig对象可以设置CheckPoint的保存策略。

    • save_checkpoint_steps表示每隔多少个step保存一次。
    • keep_checkpoint_max表示最多保留CheckPoint文件的数量。
    • prefix表示生成CheckPoint文件的前缀名。
    • directory表示存放文件的目录。
    python 复制代码
    from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
    config_ck = CheckpointConfig(save_checkpoint_steps=32, keep_checkpoint_max=10)
    ckpoint_cb = ModelCheckpoint(prefix='resnet50', directory=None, config=config_ck)
    model.train(epoch_num, dataset, callbacks=ckpoint_cb)

    要加载模型权重,需要先创建相同模型的实例,然后使用load_checkpointload_param_into_net方法加载参数。

    python 复制代码
    	model = network()
    	param_dict = mindspore.load_checkpoint("model.ckpt")
    	param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
    	print(param_not_load)

    param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。

    cmd 复制代码
    [] 
  1. 保存和加载MindIR ,当有了CheckPoint文件后,如果想继续在MindSpore Lite端侧做推理,需要通过网络和CheckPoint生成对应的MINDIR格式模型文件。

    • 统一表示:MindIR作为MindSpore云侧(训练)和端侧(推理)的统一模型文件,同时存储了网络结构和权重参数值。这使得MindSpore能够在不同的硬件平台上实现一次训练多次部署的能力。
    • 导出MindIR:MindSpore提供了export接口,可以直接将模型保存为MindIR格式。
    • 保存模型
    python 复制代码
    model = network()
    inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))
    mindspore.export(model, inputs, file_name="model", file_format="MINDIR")
    • 加载模型
    python 复制代码
    mindspore.set_context(mode=mindspore.GRAPH_MODE)
    graph = mindspore.load("model.mindir")
    model = nn.GraphCell(graph)
    outputs = model(inputs)
    print(outputs.shape)
相关推荐
shangyingying_17 小时前
关于小波降噪、小波增强、小波去雾的原理区分
人工智能·深度学习·计算机视觉
书玮嘎8 小时前
【WIP】【VLA&VLM——InternVL系列】
人工智能·深度学习
要努力啊啊啊8 小时前
YOLOv2 正负样本分配机制详解
人工智能·深度学习·yolo·计算机视觉·目标跟踪
Blossom.1189 小时前
机器学习在智能建筑中的应用:能源管理与环境优化
人工智能·python·深度学习·神经网络·机器学习·机器人·sklearn
m0_6786933310 小时前
深度学习笔记29-RNN实现阿尔茨海默病诊断(Pytorch)
笔记·rnn·深度学习
胡耀超11 小时前
标签体系设计与管理:从理论基础到智能化实践的综合指南
人工智能·python·深度学习·数据挖掘·大模型·用户画像·语义分析
fzyz12312 小时前
Windows系统下WSL从C盘迁移方案
人工智能·windows·深度学习·wsl
FF-Studio14 小时前
【硬核数学 · LLM篇】3.1 Transformer之心:自注意力机制的线性代数解构《从零构建机器学习、深度学习到LLM的数学认知》
人工智能·pytorch·深度学习·线性代数·机器学习·数学建模·transformer
云渚钓月梦未杳15 小时前
深度学习03 人工神经网络ANN
人工智能·深度学习
贾全15 小时前
第十章:HIL-SERL 真实机器人训练实战
人工智能·深度学习·算法·机器学习·机器人