自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

推理代码

复制代码
    # text embedding
    toks = self.tokenizer([text])
    if self.debug:
        print('toks', toks)

    text_embed = self.text_model_session.run(output_names=['output'], input_feed=toks)

错误提示

复制代码
Traceback (most recent call last):
  File "/xx/workspace/model/test_onnx.py", line 90, in <module>
    res = inferencer.inference(text, img_path)
  File "/xx/workspace/model/test_onnx.py", line 58, in inference
    text_embed = self.text_model_session.run(output_names=['output'], input_feed=toks)
  File "/xx/miniconda3/envs/py39/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 220, in run
    return self._sess.run(output_names, input_feed, run_options)
TypeError: run(): incompatible function arguments. The following argument types are supported:
    1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]

Invoked with: <onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession object at 0x7f975ded1570>, ['output'], {'input_ids': array([[ 101, 3899,  102]]), 'token_type_ids': array([[0, 0, 0]]), 'attention_mask': array([[1, 1, 1]])}, None

核心错误

复制代码
TypeError: run(): incompatible function arguments. The following argument types are supported:
    1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]

解决方法

核对参数

arg0: List[str]

arg1: Dict[str, object]

对应的参数

复制代码
output_names=['output'], input_feed=toks

arg0=['output'] 参数类型正确

arg1=toks 表面看参数也正常,打印看看toks的每个值的类型

type(toks['input_ids']) 输出为 <class 'torch.Tensor'>, 实际需要输入类型为 <class 'numpy.ndarray'>

修改代码

复制代码
    # text embedding
    toks = self.tokenizer([text])
    if self.debug:
        print('toks', toks)
    
    text_input = {}
    text_input['input_ids'] = toks['input_ids'].numpy()
    text_input['token_type_ids'] = toks['token_type_ids'].numpy()
    text_input['attention_mask'] = toks['attention_mask'].numpy()
    text_embed = self.text_model_session.run(output_names=['output'], input_feed=text_input)

再次执行代码,正常运行,无报错!!

相关推荐
zzc9213 分钟前
时频图数据集更正程序,去除坐标轴白边及调整对应的标签值
人工智能·深度学习·数据集·标签·时频图·更正·白边
isNotNullX4 分钟前
什么是数据分析?常见方法全解析
大数据·数据库·数据仓库·人工智能·数据分析
riveting14 分钟前
明远智睿H618:开启多场景智慧生活新时代
人工智能·嵌入式硬件·智能硬件·lga封装·3506
夜阑卧听风吹雨,铁马冰河入梦来28 分钟前
Spring AI 阿里巴巴学习
人工智能·学习·spring
c76939 分钟前
【文献笔记】Automatic Chain of Thought Prompting in Large Language Models
人工智能·笔记·语言模型·论文笔记
Blossom.1181 小时前
机器学习在智能供应链中的应用:需求预测与物流优化
人工智能·深度学习·神经网络·机器学习·计算机视觉·机器人·语音识别
Gyoku Mint1 小时前
深度学习×第4卷:Pytorch实战——她第一次用张量去拟合你的轨迹
人工智能·pytorch·python·深度学习·神经网络·算法·聚类
zzywxc7871 小时前
AI大模型的技术演进、流程重构、行业影响三个维度的系统性分析
人工智能·重构
点控云1 小时前
智能私域运营中枢:从客户视角看 SCRM 的体验革新与价值重构
大数据·人工智能·科技·重构·外呼系统·呼叫中心
zhaoyi_he1 小时前
多模态大模型的技术应用与未来展望:重构AI交互范式的新引擎
人工智能·重构