自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

推理代码

复制代码
    # text embedding
    toks = self.tokenizer([text])
    if self.debug:
        print('toks', toks)

    text_embed = self.text_model_session.run(output_names=['output'], input_feed=toks)

错误提示

复制代码
Traceback (most recent call last):
  File "/xx/workspace/model/test_onnx.py", line 90, in <module>
    res = inferencer.inference(text, img_path)
  File "/xx/workspace/model/test_onnx.py", line 58, in inference
    text_embed = self.text_model_session.run(output_names=['output'], input_feed=toks)
  File "/xx/miniconda3/envs/py39/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 220, in run
    return self._sess.run(output_names, input_feed, run_options)
TypeError: run(): incompatible function arguments. The following argument types are supported:
    1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]

Invoked with: <onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession object at 0x7f975ded1570>, ['output'], {'input_ids': array([[ 101, 3899,  102]]), 'token_type_ids': array([[0, 0, 0]]), 'attention_mask': array([[1, 1, 1]])}, None

核心错误

复制代码
TypeError: run(): incompatible function arguments. The following argument types are supported:
    1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]

解决方法

核对参数

arg0: List[str]

arg1: Dict[str, object]

对应的参数

复制代码
output_names=['output'], input_feed=toks

arg0=['output'] 参数类型正确

arg1=toks 表面看参数也正常,打印看看toks的每个值的类型

type(toks['input_ids']) 输出为 <class 'torch.Tensor'>, 实际需要输入类型为 <class 'numpy.ndarray'>

修改代码

复制代码
    # text embedding
    toks = self.tokenizer([text])
    if self.debug:
        print('toks', toks)
    
    text_input = {}
    text_input['input_ids'] = toks['input_ids'].numpy()
    text_input['token_type_ids'] = toks['token_type_ids'].numpy()
    text_input['attention_mask'] = toks['attention_mask'].numpy()
    text_embed = self.text_model_session.run(output_names=['output'], input_feed=text_input)

再次执行代码,正常运行,无报错!!

相关推荐
Bonnie373几秒前
算力基建入门-AI时代,算力为何是数字底座
人工智能·程序人生·云原生·个人开发
前端摸鱼匠1 分钟前
面试题6:因果掩码(Causal Mask)在Decoder中的作用是什么?训练、推理阶段如何使用?
人工智能·ai·语言模型·自然语言处理·面试
这张生成的图像能检测吗2 分钟前
(论文速读)ASFRMT:基于对抗的超特征重构元传递网络弱特征增强与谐波传动故障诊断
人工智能·深度学习·计算机视觉·故障诊断
statistican_ABin6 分钟前
Python数据分析-宝马全球汽车销售数据分析(可视化分析)
大数据·人工智能·数据分析·汽车·数据可视化
ARM+FPGA+AI工业主板定制专家6 分钟前
基于ARM+FPGA+AI的船舶状态智能监测系统(一)总体设计
网络·arm开发·人工智能·机器学习·fpga开发·自动驾驶
前端摸鱼匠7 分钟前
面试题7:Encoder-only、Decoder-only、Encoder-Decoder三种架构的差异与适用场景?
人工智能·深度学习·ai·面试·职场和发展·架构·transformer
ryrhhhh7 分钟前
矩阵跃动技术创新:GEO搜索占位+AI智能体双融合,重构企业获客链路
大数据·人工智能
no_work8 分钟前
基于python的hog+svm实现混凝土裂缝目标检测
人工智能·python·目标检测·计算机视觉
小陈工8 分钟前
2026年3月21日技术资讯洞察:云原生理性回归与Python异步革命
人工智能·python·云原生·数据挖掘·回归
柯儿的天空10 分钟前
【OpenClaw 全面解析:从零到精通】第 018 篇:OpenClaw 多智能体协作系统——多 Agent 路由、任务委托与负载均衡
运维·人工智能·aigc·负载均衡·ai编程·ai写作·agi