自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

推理代码

复制代码
    # text embedding
    toks = self.tokenizer([text])
    if self.debug:
        print('toks', toks)

    text_embed = self.text_model_session.run(output_names=['output'], input_feed=toks)

错误提示

复制代码
Traceback (most recent call last):
  File "/xx/workspace/model/test_onnx.py", line 90, in <module>
    res = inferencer.inference(text, img_path)
  File "/xx/workspace/model/test_onnx.py", line 58, in inference
    text_embed = self.text_model_session.run(output_names=['output'], input_feed=toks)
  File "/xx/miniconda3/envs/py39/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 220, in run
    return self._sess.run(output_names, input_feed, run_options)
TypeError: run(): incompatible function arguments. The following argument types are supported:
    1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]

Invoked with: <onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession object at 0x7f975ded1570>, ['output'], {'input_ids': array([[ 101, 3899,  102]]), 'token_type_ids': array([[0, 0, 0]]), 'attention_mask': array([[1, 1, 1]])}, None

核心错误

复制代码
TypeError: run(): incompatible function arguments. The following argument types are supported:
    1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]

解决方法

核对参数

arg0: List[str]

arg1: Dict[str, object]

对应的参数

复制代码
output_names=['output'], input_feed=toks

arg0=['output'] 参数类型正确

arg1=toks 表面看参数也正常,打印看看toks的每个值的类型

type(toks['input_ids']) 输出为 <class 'torch.Tensor'>, 实际需要输入类型为 <class 'numpy.ndarray'>

修改代码

复制代码
    # text embedding
    toks = self.tokenizer([text])
    if self.debug:
        print('toks', toks)
    
    text_input = {}
    text_input['input_ids'] = toks['input_ids'].numpy()
    text_input['token_type_ids'] = toks['token_type_ids'].numpy()
    text_input['attention_mask'] = toks['attention_mask'].numpy()
    text_embed = self.text_model_session.run(output_names=['output'], input_feed=text_input)

再次执行代码,正常运行,无报错!!

相关推荐
wei_shuo7 小时前
AI 代理框架:使用正确的工具构建更智能的系统
人工智能
独自归家的兔7 小时前
大模型通义千问3-VL-Plus - 视觉推理(图像列表)
人工智能·计算机视觉
Spring AI学习7 小时前
Spring AI深度解析(8/50):模型评估体系实战
人工智能·spring·microsoft
周名彥7 小时前
1Ω1[特殊字符]⊗雙朕周名彥|二十四芒星非硅基华夏原生AGI体系·授权绑定激活发布全维研究报告(S∅-Omega级·纯念主权终极版)
人工智能·去中心化·知识图谱·量子计算·agi
骚戴7 小时前
架构设计之道:构建高可用的大语言模型(LLM) Enterprise GenAI Gateway
java·人工智能·架构·大模型·gateway·api
周名彥7 小时前
100%纯念主动显化·无被动·无操控·无依赖·可验证·[特殊字符][特殊字符]⚜️[特殊字符]智能體工作流集群超級數據中心集群IPO集群GUI集群AGI集群
人工智能·神经网络·去中心化·知识图谱·agi
cvyoutian7 小时前
PyTorch 多卡训练常见坑:设置 CUDA_VISIBLE_DEVICES 后仍 OOM 在 GPU 0 的解决之道
人工智能·pytorch·python
美摄科技8 小时前
一键成片SDK,AI智能剪辑引擎,精准理解内容语义
人工智能
qq_386218998 小时前
Agent
人工智能·agent
测试人社区-小明8 小时前
医疗AI测试:构建安全可靠的合规体系
运维·人工智能·opencv·数据挖掘·机器人·自动化·github