【昇思初学入门】第八天打卡-模型保存与加载

模型保存与加载

学习心得

  • 保存 CheckPoint 格式文件 ,在模型训练过程中,可以添加检查点(CheckPoint)用于保存模型的参数,以便进行推理及再训练使用。如果想继续在不同硬件平台上做推理,可通过网络和CheckPoint格式文件生成对应的MINDIR、AIR和ONNX格式文件。

    python 复制代码
    model = network()
    mindspore.save_checkpoint(model, "model.ckpt")

    可以通过CheckpointConfig对象可以设置CheckPoint的保存策略。

    • save_checkpoint_steps表示每隔多少个step保存一次。
    • keep_checkpoint_max表示最多保留CheckPoint文件的数量。
    • prefix表示生成CheckPoint文件的前缀名。
    • directory表示存放文件的目录。
    python 复制代码
    from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
    config_ck = CheckpointConfig(save_checkpoint_steps=32, keep_checkpoint_max=10)
    ckpoint_cb = ModelCheckpoint(prefix='resnet50', directory=None, config=config_ck)
    model.train(epoch_num, dataset, callbacks=ckpoint_cb)

    要加载模型权重,需要先创建相同模型的实例,然后使用load_checkpointload_param_into_net方法加载参数。

    python 复制代码
    	model = network()
    	param_dict = mindspore.load_checkpoint("model.ckpt")
    	param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
    	print(param_not_load)

    param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。

    cmd 复制代码
    [] 
  1. 保存和加载MindIR ,当有了CheckPoint文件后,如果想继续在MindSpore Lite端侧做推理,需要通过网络和CheckPoint生成对应的MINDIR格式模型文件。

    • 统一表示:MindIR作为MindSpore云侧(训练)和端侧(推理)的统一模型文件,同时存储了网络结构和权重参数值。这使得MindSpore能够在不同的硬件平台上实现一次训练多次部署的能力。
    • 导出MindIR:MindSpore提供了export接口,可以直接将模型保存为MindIR格式。
    • 保存模型
    python 复制代码
    model = network()
    inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))
    mindspore.export(model, inputs, file_name="model", file_format="MINDIR")
    • 加载模型
    python 复制代码
    mindspore.set_context(mode=mindspore.GRAPH_MODE)
    graph = mindspore.load("model.mindir")
    model = nn.GraphCell(graph)
    outputs = model(inputs)
    print(outputs.shape)
相关推荐
zy_destiny5 小时前
【工业场景】用YOLOv26实现4种输电线隐患检测
人工智能·深度学习·算法·yolo·机器学习·计算机视觉·输电线隐患识别
雍凉明月夜5 小时前
深度学习之目标检测yolo算法Ⅴ-YOLOv8
深度学习·yolo·目标检测
2501_941652775 小时前
改进YOLOv5-BiFPN-SDI实现牙齿龋齿检测与分类_深度学习_计算机视觉_原创
深度学习·yolo·分类
肾透侧视攻城狮5 小时前
《PyTorch神经网络从开发到调试:实战技巧、可视化与兼容性问题解决方案》
神经网络·语言模型·二分类任务·实现前馈神经网络·可视化执行梯度下降算法·matplotlib版本兼容性·pytorch实现二分类任务
zy_destiny6 小时前
【工业场景】用YOLOv26实现8种道路隐患检测
人工智能·深度学习·算法·yolo·机器学习·计算机视觉·目标跟踪
铁手飞鹰6 小时前
[深度学习]Vision Transformer
人工智能·pytorch·python·深度学习·transformer
weixin_395448916 小时前
average_weights.py
pytorch·python·深度学习
香芋Yu6 小时前
【深度学习教程——02_优化与正则(Optimization)】09_为什么Dropout能防止过拟合?正则化的本质
人工智能·深度学习
皮肤科大白6 小时前
超轻量SAM模型部署:ONNX量化与Transformer剪枝全攻略
深度学习·transformer
Loo国昌7 小时前
【大模型应用开发】第三阶段:深度解析检索增强生成(RAG)原理
人工智能·后端·深度学习·自然语言处理·transformer