自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

推理代码

复制代码
    # text embedding
    toks = self.tokenizer([text])
    if self.debug:
        print('toks', toks)

    text_embed = self.text_model_session.run(output_names=['output'], input_feed=toks)

错误提示

复制代码
Traceback (most recent call last):
  File "/xx/workspace/model/test_onnx.py", line 90, in <module>
    res = inferencer.inference(text, img_path)
  File "/xx/workspace/model/test_onnx.py", line 58, in inference
    text_embed = self.text_model_session.run(output_names=['output'], input_feed=toks)
  File "/xx/miniconda3/envs/py39/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 220, in run
    return self._sess.run(output_names, input_feed, run_options)
TypeError: run(): incompatible function arguments. The following argument types are supported:
    1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]

Invoked with: <onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession object at 0x7f975ded1570>, ['output'], {'input_ids': array([[ 101, 3899,  102]]), 'token_type_ids': array([[0, 0, 0]]), 'attention_mask': array([[1, 1, 1]])}, None

核心错误

复制代码
TypeError: run(): incompatible function arguments. The following argument types are supported:
    1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]

解决方法

核对参数

arg0: List[str]

arg1: Dict[str, object]

对应的参数

复制代码
output_names=['output'], input_feed=toks

arg0=['output'] 参数类型正确

arg1=toks 表面看参数也正常,打印看看toks的每个值的类型

type(toks['input_ids']) 输出为 <class 'torch.Tensor'>, 实际需要输入类型为 <class 'numpy.ndarray'>

修改代码

复制代码
    # text embedding
    toks = self.tokenizer([text])
    if self.debug:
        print('toks', toks)
    
    text_input = {}
    text_input['input_ids'] = toks['input_ids'].numpy()
    text_input['token_type_ids'] = toks['token_type_ids'].numpy()
    text_input['attention_mask'] = toks['attention_mask'].numpy()
    text_embed = self.text_model_session.run(output_names=['output'], input_feed=text_input)

再次执行代码,正常运行,无报错!!

相关推荐
学行库小秘4 分钟前
ANN神经网络回归预测模型
人工智能·python·深度学习·神经网络·算法·机器学习·回归
文弱_书生7 分钟前
为什么神经网络在长时间训练过程中会存在稠密特征图退化的问题
人工智能·深度学习·神经网络
爱写代码的小朋友18 分钟前
数字化与人工智能的崛起及其社会影响研究报告
人工智能
martinzh1 小时前
提示词工程师到底是干什么的?
人工智能
Coovally AI模型快速验证1 小时前
SOD-YOLO:基于YOLO的无人机图像小目标检测增强方法
人工智能·yolo·目标检测·机器学习·计算机视觉·目标跟踪·无人机
黎燃1 小时前
无人机+AI:精准农业的“降维打击”实践
人工智能
茫茫人海一粒沙1 小时前
RAG 分块中表格填补简明示例:Markdown、HTML、Excel、Doc
人工智能
爱喝奶茶的企鹅1 小时前
Ethan开发者创新项目日报 | 2025-08-18
人工智能
却道天凉_好个秋1 小时前
计算机视觉(一):nvidia与cuda介绍
人工智能·计算机视觉
爱喝奶茶的企鹅1 小时前
Ethan独立开发新品速递 | 2025-08-18
人工智能·程序员·开源