自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

推理代码

    # text embedding
    toks = self.tokenizer([text])
    if self.debug:
        print('toks', toks)

    text_embed = self.text_model_session.run(output_names=['output'], input_feed=toks)

错误提示

Traceback (most recent call last):
  File "/xx/workspace/model/test_onnx.py", line 90, in <module>
    res = inferencer.inference(text, img_path)
  File "/xx/workspace/model/test_onnx.py", line 58, in inference
    text_embed = self.text_model_session.run(output_names=['output'], input_feed=toks)
  File "/xx/miniconda3/envs/py39/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 220, in run
    return self._sess.run(output_names, input_feed, run_options)
TypeError: run(): incompatible function arguments. The following argument types are supported:
    1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]

Invoked with: <onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession object at 0x7f975ded1570>, ['output'], {'input_ids': array([[ 101, 3899,  102]]), 'token_type_ids': array([[0, 0, 0]]), 'attention_mask': array([[1, 1, 1]])}, None

核心错误

TypeError: run(): incompatible function arguments. The following argument types are supported:
    1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]

解决方法

核对参数

arg0: List[str]

arg1: Dict[str, object]

对应的参数

output_names=['output'], input_feed=toks

arg0=['output'] 参数类型正确

arg1=toks 表面看参数也正常,打印看看toks的每个值的类型

type(toks['input_ids']) 输出为 <class 'torch.Tensor'>, 实际需要输入类型为 <class 'numpy.ndarray'>

修改代码

    # text embedding
    toks = self.tokenizer([text])
    if self.debug:
        print('toks', toks)
    
    text_input = {}
    text_input['input_ids'] = toks['input_ids'].numpy()
    text_input['token_type_ids'] = toks['token_type_ids'].numpy()
    text_input['attention_mask'] = toks['attention_mask'].numpy()
    text_embed = self.text_model_session.run(output_names=['output'], input_feed=text_input)

再次执行代码,正常运行,无报错!!

相关推荐
大彬聊编程9 分钟前
清华大学102页PPT 《deepseek从入门到精通》
人工智能
明月与玄武14 分钟前
Apifox 增强 AI 接口调试功能:自动合并 SSE 响应、展示DeepSeek思考过程
人工智能·apifox·增强 ai 接口调试功能
虚假程序设计27 分钟前
opencv 自适应阈值
人工智能·opencv·计算机视觉
沐欣工作室_lvyiyi39 分钟前
基于物联网的家庭版防疫面罩设计与实现(论文+源码)
人工智能·stm32·单片机·物联网·目标跟踪
xzzd_jokelin1 小时前
Spring AI 接入 DeepSeek:开启智能应用的新篇章
java·人工智能·spring·ai·大模型·rag·deepseek
简简单单做算法1 小时前
基于WOA鲸鱼优化的BiLSTM双向长短期记忆网络序列预测算法matlab仿真,对比BiLSTM和LSTM
人工智能·lstm·bilstm·woa-bilstm·双向长短期记忆网络·woa鲸鱼优化·序列预测
星霜旅人1 小时前
开源机器学习框架
人工智能·机器学习·开源
资源大全免费分享1 小时前
清华大学第五版《DeepSeek与AI幻觉》附五版合集下载方法
人工智能
龚大龙1 小时前
机器学习(李宏毅)——RL(强化学习)
人工智能·机器学习
LaughingZhu1 小时前
PH热榜 | 2025-02-23
前端·人工智能·经验分享·搜索引擎·产品运营