自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

推理代码

复制代码
    # text embedding
    toks = self.tokenizer([text])
    if self.debug:
        print('toks', toks)

    text_embed = self.text_model_session.run(output_names=['output'], input_feed=toks)

错误提示

复制代码
Traceback (most recent call last):
  File "/xx/workspace/model/test_onnx.py", line 90, in <module>
    res = inferencer.inference(text, img_path)
  File "/xx/workspace/model/test_onnx.py", line 58, in inference
    text_embed = self.text_model_session.run(output_names=['output'], input_feed=toks)
  File "/xx/miniconda3/envs/py39/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 220, in run
    return self._sess.run(output_names, input_feed, run_options)
TypeError: run(): incompatible function arguments. The following argument types are supported:
    1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]

Invoked with: <onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession object at 0x7f975ded1570>, ['output'], {'input_ids': array([[ 101, 3899,  102]]), 'token_type_ids': array([[0, 0, 0]]), 'attention_mask': array([[1, 1, 1]])}, None

核心错误

复制代码
TypeError: run(): incompatible function arguments. The following argument types are supported:
    1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]

解决方法

核对参数

arg0: Liststr

arg1: Dictstr, object

对应的参数

复制代码
output_names=['output'], input_feed=toks

arg0='output' 参数类型正确

arg1=toks 表面看参数也正常,打印看看toks的每个值的类型

type(toks'input_ids') 输出为 <class 'torch.Tensor'>, 实际需要输入类型为 <class 'numpy.ndarray'>

修改代码

复制代码
    # text embedding
    toks = self.tokenizer([text])
    if self.debug:
        print('toks', toks)
    
    text_input = {}
    text_input['input_ids'] = toks['input_ids'].numpy()
    text_input['token_type_ids'] = toks['token_type_ids'].numpy()
    text_input['attention_mask'] = toks['attention_mask'].numpy()
    text_embed = self.text_model_session.run(output_names=['output'], input_feed=text_input)

再次执行代码,正常运行,无报错!!

相关推荐
岁月宁静1 小时前
RAG 文档摄入全链路,从原理到生产落地
vue.js·人工智能·python
小和尚同志1 小时前
AI 自动化测试探索(一):Playwright MCP
前端·人工智能·aigc
硅谷秋水1 小时前
面向长上下文自动驾驶的规划对齐Token压缩
人工智能·深度学习·机器学习·计算机视觉·自动驾驶
郭泽斌之心1 小时前
MQL5 EA 怎么和外部程序通信?文件三件套协议:参数热更新不重启、状态心跳、远程触发
人工智能·经验分享·深度学习·ea·fay数字人·easydeal
mit6.8241 小时前
“泄露了windows12“
人工智能
syc78901231 小时前
中文语境下AI编码工具实战对比:从迭代体验看日常开发选择
linux·人工智能·ubuntu
dualven_in_csdn2 小时前
用户点击“一键起飞“
人工智能
米核AI易山2 小时前
扣子工作流变量体系深度解析:从踩坑到精通
人工智能·coze·扣子工作流·米核ai易山
nap-joker2 小时前
用于转录组信息精确肿瘤学和药物机制分析的多模态可解释深度学习
人工智能·深度学习·药物敏感性·多层级生物网络·细胞异质性·可解释性多模态