自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

推理代码

复制代码
    # text embedding
    toks = self.tokenizer([text])
    if self.debug:
        print('toks', toks)

    text_embed = self.text_model_session.run(output_names=['output'], input_feed=toks)

错误提示

复制代码
Traceback (most recent call last):
  File "/xx/workspace/model/test_onnx.py", line 90, in <module>
    res = inferencer.inference(text, img_path)
  File "/xx/workspace/model/test_onnx.py", line 58, in inference
    text_embed = self.text_model_session.run(output_names=['output'], input_feed=toks)
  File "/xx/miniconda3/envs/py39/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 220, in run
    return self._sess.run(output_names, input_feed, run_options)
TypeError: run(): incompatible function arguments. The following argument types are supported:
    1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]

Invoked with: <onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession object at 0x7f975ded1570>, ['output'], {'input_ids': array([[ 101, 3899,  102]]), 'token_type_ids': array([[0, 0, 0]]), 'attention_mask': array([[1, 1, 1]])}, None

核心错误

复制代码
TypeError: run(): incompatible function arguments. The following argument types are supported:
    1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]

解决方法

核对参数

arg0: List[str]

arg1: Dict[str, object]

对应的参数

复制代码
output_names=['output'], input_feed=toks

arg0=['output'] 参数类型正确

arg1=toks 表面看参数也正常,打印看看toks的每个值的类型

type(toks['input_ids']) 输出为 <class 'torch.Tensor'>, 实际需要输入类型为 <class 'numpy.ndarray'>

修改代码

复制代码
    # text embedding
    toks = self.tokenizer([text])
    if self.debug:
        print('toks', toks)
    
    text_input = {}
    text_input['input_ids'] = toks['input_ids'].numpy()
    text_input['token_type_ids'] = toks['token_type_ids'].numpy()
    text_input['attention_mask'] = toks['attention_mask'].numpy()
    text_embed = self.text_model_session.run(output_names=['output'], input_feed=text_input)

再次执行代码,正常运行,无报错!!

相关推荐
大模型服务器厂商1 分钟前
适配的 GPU 服务器能让 AI 模型充分发挥算力优势
人工智能
AscendKing7 分钟前
LandPPT - AI驱动的PPT生成平台
人工智能·好好学电脑·hhxdn.com
FreeCode9 分钟前
LangChain1.0智能体开发:流输出组件
人工智能·langchain·agent
故作春风14 分钟前
手把手实现一个前端 AI 编程助手:从 MCP 思想到 VS Code 插件实战
前端·人工智能
人工智能训练24 分钟前
在ubuntu系统中如何将docker安装在指定目录
linux·运维·服务器·人工智能·ubuntu·docker·ai编程
掘金一周26 分钟前
没开玩笑,全框架支持的 dialog 组件,支持响应式| 掘金一周 11.6
前端·人工智能
电鱼智能的电小鱼1 小时前
基于电鱼 ARM 边缘网关的智慧工地数据可靠传输方案——断点续传 + 4G/5G冗余通信,保障数据完整上传
arm开发·人工智能·嵌入式硬件·深度学习·5g·机器学习
Juchecar1 小时前
翻译:Agentic AI:面向企业应用的智能
人工智能
武子康1 小时前
AI研究-121 DeepSeek-OCR 研究路线:无限上下文、跨模态抽取、未来创意点、项目创意点
人工智能·深度学习·机器学习·ai·ocr·deepseek·deepseek-ocr
半臻(火白)1 小时前
从“看见文字”到“理解内容”:DeepSeek-OCR重构OCR 2.0时代的效率革命
人工智能