自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

推理代码

复制代码
    # text embedding
    toks = self.tokenizer([text])
    if self.debug:
        print('toks', toks)

    text_embed = self.text_model_session.run(output_names=['output'], input_feed=toks)

错误提示

复制代码
Traceback (most recent call last):
  File "/xx/workspace/model/test_onnx.py", line 90, in <module>
    res = inferencer.inference(text, img_path)
  File "/xx/workspace/model/test_onnx.py", line 58, in inference
    text_embed = self.text_model_session.run(output_names=['output'], input_feed=toks)
  File "/xx/miniconda3/envs/py39/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 220, in run
    return self._sess.run(output_names, input_feed, run_options)
TypeError: run(): incompatible function arguments. The following argument types are supported:
    1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]

Invoked with: <onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession object at 0x7f975ded1570>, ['output'], {'input_ids': array([[ 101, 3899,  102]]), 'token_type_ids': array([[0, 0, 0]]), 'attention_mask': array([[1, 1, 1]])}, None

核心错误

复制代码
TypeError: run(): incompatible function arguments. The following argument types are supported:
    1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]

解决方法

核对参数

arg0: List[str]

arg1: Dict[str, object]

对应的参数

复制代码
output_names=['output'], input_feed=toks

arg0=['output'] 参数类型正确

arg1=toks 表面看参数也正常,打印看看toks的每个值的类型

type(toks['input_ids']) 输出为 <class 'torch.Tensor'>, 实际需要输入类型为 <class 'numpy.ndarray'>

修改代码

复制代码
    # text embedding
    toks = self.tokenizer([text])
    if self.debug:
        print('toks', toks)
    
    text_input = {}
    text_input['input_ids'] = toks['input_ids'].numpy()
    text_input['token_type_ids'] = toks['token_type_ids'].numpy()
    text_input['attention_mask'] = toks['attention_mask'].numpy()
    text_embed = self.text_model_session.run(output_names=['output'], input_feed=text_input)

再次执行代码,正常运行,无报错!!

相关推荐
莽撞的大地瓜几秒前
连获国内多行业认可、入选全球AI全景图谱 彰显蜜度智能校对的硬核实力
人工智能·ai·语言模型·新媒体运营·知识图谱
ATM0063 分钟前
专其利AI | 专利撰写的救赎:AI工具为何成为知识产权保护的神兵利器
人工智能·大模型·专利撰写
海天一色y7 分钟前
使用 PyTorch RNN 识别手写数字
人工智能·pytorch·rnn
百***074511 分钟前
OpenClaw+一步API实战:本地化AI自动化助手从部署到落地全指南
大数据·人工智能·python
临水逸11 分钟前
OpenClaw WebUI 的外网访问配置
人工智能·策略模式
yuanyuan2o216 分钟前
【深度学习】AlexNet
人工智能·深度学习
不惑_17 分钟前
通俗理解条件生成对抗网络(cGAN)
人工智能·生成对抗网络·计算机视觉
deephub19 分钟前
torch.compile 加速原理:kernel 融合与缓冲区复用
人工智能·pytorch·深度学习·神经网络
ydl112819 分钟前
解码AI大模型:从神经网络到落地应用的全景探索
人工智能·深度学习·神经网络
小程故事多_8020 分钟前
Elasticsearch ES 分词与关键词匹配技术方案解析
大数据·人工智能·elasticsearch·搜索引擎·aigc