paddle模型转onnx介绍(以utc-mini为例)

1、paddle到onnx转换命令

python 复制代码
paddle2onnx --model_dir /opt/utc/models/checkpoint_utc-mini/ --model_filename model.pdmodel --params_filename model.pdiparams --save_file /opt/utc/models/checkpoint_utc-mini/onnx/model.onnx --enable_dev_version True --opset_version 13 --enable_onnx_checker True

2、测试验证

python 复制代码
# -*- coding: utf-8 -*-
"""
    paddle -> onnx
"""
import os

from paddle.static import InputSpec

import paddle2onnx


def func_onnx_test_valid(onnx_model_path):
    """ (1) 检查 ONNX 模型的有效性
        可以使用如下脚本验证导出的 ONNX 模型是否合理,包括检查模型的版本、图的结构、节点及其输入和输出。
        如下脚本的输出为 None 则表示模型转换正确。
    """
    # 导入 ONNX 库
    import onnx

    # 载入 ONNX 模型
    onnx_model = onnx.load(onnx_model_path)

    # 使用 ONNX 库检查 ONNX 模型是否合理
    check = onnx.checker.check_model(onnx_model)

    # 打印检查结果
    print('check: ', check)

    pass


def func_onnx_test_match(onnx_model_path, paddle_model_path):
    """ 验证模型是否匹配
        验证原始的飞桨模型和导出的 ONNX 模型是否有相同的计算结果。
    """
    # 导入所需的库
    import numpy as np
    import onnxruntime
    import paddle

    def input_generate():
        # 准备输入数据
        batch_size = 1
        max_seq_length = 128  # 假设最大序列长度为 128

        # 生成示例输入数据
        input_ids = np.random.randint(0, 10000, (batch_size, max_seq_length)).astype('int64')
        token_type_ids = np.zeros((batch_size, max_seq_length), dtype='int64')
        position_ids = np.arange(max_seq_length).reshape(1, -1).repeat(batch_size, axis=0).astype('int64')
        attention_mask = np.ones((batch_size, 1, max_seq_length, max_seq_length), dtype='float32')
        omask_positions = np.array([[10, 20]]).astype('int64')  # 假设 omask_positions 为 [10, 20]
        cls_positions = np.array([0]).astype('int64')  # 假设 cls_positions 为 [0]

        # 准备输入字典
        ort_inputs = {
            'input_ids': input_ids,
            'token_type_ids': token_type_ids,
            'position_ids': position_ids,
            'attention_mask': attention_mask,
            'omask_positions': omask_positions,
            'cls_positions': cls_positions
        }

        return ort_inputs

    print("------------------------ ONNX -----------------------------")
    # predict by ONNXRuntime
    ort_sess = onnxruntime.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider'])

    # 获取输入和输出名称
    input_names = [input.name for input in ort_sess.get_inputs()]
    output_names = [output.name for output in ort_sess.get_outputs()]
    # 打印输入和输出名称
    print("Input names:", input_names)
    print("Output names:", output_names)

    # 获取输入数据
    ort_inputs = input_generate()

    # 运行模型
    ort_outs = ort_sess.run(None, ort_inputs)
    print("ONNX Outputs: \n", ort_outs)
    print("Exported model has been predicted by ONNXRuntime!")
    print("------------------------ ONNX -----------------------------")

    pass


if __name__ == '__main__':

    # paddle 模型保存目录及文件路径
    model_dir_paddle = '/opt/utc/models/checkpoint_utc-mini/'

    # onnx 保存目录及文件路径
    model_dir_onnx = model_dir_paddle + "onnx/"
    os.makedirs(model_dir_onnx, exist_ok=True)
    onnx_model_path = model_dir_onnx + 'model.onnx'

    # (1) 检查 ONNX 模型的有效性
    func_onnx_test_valid(onnx_model_path)
    # (2) 验证模型是否匹配
    func_onnx_test_match(onnx_model_path)

    print("done.")
    pass
相关推荐
ARM+FPGA+AI工业主板定制专家10 分钟前
基于Jetson+FPGA+GMSL+AI的自动驾驶数据采集解决方案
人工智能·机器学习·自动驾驶
聊聊MES那点事41 分钟前
汽车零部件MES系统实施案例介绍
人工智能·信息可视化·汽车·数据可视化
软件算法开发1 小时前
基于螳螂虾优化的LSTM深度学习网络模型(MShOA-LSTM)的一维时间序列预测算法matlab仿真
深度学习·lstm·一维时间序列预测·螳螂虾优化·mshoa·mshoa-lstm
星期天要睡觉1 小时前
计算机视觉(opencv)——仿射变换(Affine Transformation)
人工智能·opencv·计算机视觉
Phoenixtree_DongZhao1 小时前
面向单步生成建模的均值流方法: MeanFlow, 一步生成高清图像(何恺明 [NeurIPS 2025 Oral] )
人工智能
hazy1k2 小时前
K230基础-录放视频
网络·人工智能·stm32·单片机·嵌入式硬件·音视频·k230
陈敬雷-充电了么-CEO兼CTO2 小时前
DeepSeek vs ChatGPT 技术架构、成本与场景全解析
人工智能·chatgpt·架构
MarvinP2 小时前
《Seq2Time: Sequential Knowledge Transfer for Video LLMTemporal Grounding》
人工智能·计算机视觉
AORO20252 小时前
适合户外探险、物流、应急、工业,五款三防智能手机深度解析
网络·人工智能·5g·智能手机·制造·信息与通信
铉铉这波能秀3 小时前
如何在Android Studio中使用Gemini进行AI Coding
android·java·人工智能·ai·kotlin·app·android studio