python
复制代码
# -*- coding: utf-8 -*-
"""
paddle -> onnx
"""
import os
from paddle.static import InputSpec
import paddle2onnx
def func_onnx_test_valid(onnx_model_path):
""" (1) 检查 ONNX 模型的有效性
可以使用如下脚本验证导出的 ONNX 模型是否合理,包括检查模型的版本、图的结构、节点及其输入和输出。
如下脚本的输出为 None 则表示模型转换正确。
"""
# 导入 ONNX 库
import onnx
# 载入 ONNX 模型
onnx_model = onnx.load(onnx_model_path)
# 使用 ONNX 库检查 ONNX 模型是否合理
check = onnx.checker.check_model(onnx_model)
# 打印检查结果
print('check: ', check)
pass
def func_onnx_test_match(onnx_model_path, paddle_model_path):
""" 验证模型是否匹配
验证原始的飞桨模型和导出的 ONNX 模型是否有相同的计算结果。
"""
# 导入所需的库
import numpy as np
import onnxruntime
import paddle
def input_generate():
# 准备输入数据
batch_size = 1
max_seq_length = 128 # 假设最大序列长度为 128
# 生成示例输入数据
input_ids = np.random.randint(0, 10000, (batch_size, max_seq_length)).astype('int64')
token_type_ids = np.zeros((batch_size, max_seq_length), dtype='int64')
position_ids = np.arange(max_seq_length).reshape(1, -1).repeat(batch_size, axis=0).astype('int64')
attention_mask = np.ones((batch_size, 1, max_seq_length, max_seq_length), dtype='float32')
omask_positions = np.array([[10, 20]]).astype('int64') # 假设 omask_positions 为 [10, 20]
cls_positions = np.array([0]).astype('int64') # 假设 cls_positions 为 [0]
# 准备输入字典
ort_inputs = {
'input_ids': input_ids,
'token_type_ids': token_type_ids,
'position_ids': position_ids,
'attention_mask': attention_mask,
'omask_positions': omask_positions,
'cls_positions': cls_positions
}
return ort_inputs
print("------------------------ ONNX -----------------------------")
# predict by ONNXRuntime
ort_sess = onnxruntime.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider'])
# 获取输入和输出名称
input_names = [input.name for input in ort_sess.get_inputs()]
output_names = [output.name for output in ort_sess.get_outputs()]
# 打印输入和输出名称
print("Input names:", input_names)
print("Output names:", output_names)
# 获取输入数据
ort_inputs = input_generate()
# 运行模型
ort_outs = ort_sess.run(None, ort_inputs)
print("ONNX Outputs: \n", ort_outs)
print("Exported model has been predicted by ONNXRuntime!")
print("------------------------ ONNX -----------------------------")
pass
if __name__ == '__main__':
# paddle 模型保存目录及文件路径
model_dir_paddle = '/opt/utc/models/checkpoint_utc-mini/'
# onnx 保存目录及文件路径
model_dir_onnx = model_dir_paddle + "onnx/"
os.makedirs(model_dir_onnx, exist_ok=True)
onnx_model_path = model_dir_onnx + 'model.onnx'
# (1) 检查 ONNX 模型的有效性
func_onnx_test_valid(onnx_model_path)
# (2) 验证模型是否匹配
func_onnx_test_match(onnx_model_path)
print("done.")
pass