自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

推理代码

复制代码
    # text embedding
    toks = self.tokenizer([text])
    if self.debug:
        print('toks', toks)

    text_embed = self.text_model_session.run(output_names=['output'], input_feed=toks)

错误提示

复制代码
Traceback (most recent call last):
  File "/xx/workspace/model/test_onnx.py", line 90, in <module>
    res = inferencer.inference(text, img_path)
  File "/xx/workspace/model/test_onnx.py", line 58, in inference
    text_embed = self.text_model_session.run(output_names=['output'], input_feed=toks)
  File "/xx/miniconda3/envs/py39/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 220, in run
    return self._sess.run(output_names, input_feed, run_options)
TypeError: run(): incompatible function arguments. The following argument types are supported:
    1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]

Invoked with: <onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession object at 0x7f975ded1570>, ['output'], {'input_ids': array([[ 101, 3899,  102]]), 'token_type_ids': array([[0, 0, 0]]), 'attention_mask': array([[1, 1, 1]])}, None

核心错误

复制代码
TypeError: run(): incompatible function arguments. The following argument types are supported:
    1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]

解决方法

核对参数

arg0: List[str]

arg1: Dict[str, object]

对应的参数

复制代码
output_names=['output'], input_feed=toks

arg0=['output'] 参数类型正确

arg1=toks 表面看参数也正常,打印看看toks的每个值的类型

type(toks['input_ids']) 输出为 <class 'torch.Tensor'>, 实际需要输入类型为 <class 'numpy.ndarray'>

修改代码

复制代码
    # text embedding
    toks = self.tokenizer([text])
    if self.debug:
        print('toks', toks)
    
    text_input = {}
    text_input['input_ids'] = toks['input_ids'].numpy()
    text_input['token_type_ids'] = toks['token_type_ids'].numpy()
    text_input['attention_mask'] = toks['attention_mask'].numpy()
    text_embed = self.text_model_session.run(output_names=['output'], input_feed=text_input)

再次执行代码,正常运行,无报错!!

相关推荐
林_学5 分钟前
我是如何把应用上线时间从1天缩短到3分钟的
人工智能
钓了猫的鱼儿6 分钟前
农作物病虫害目标检测数据集(百度网盘地址)
人工智能·目标检测·目标跟踪
万行9 分钟前
机器人系统ros2&期末速通2
前端·人工智能·python·算法·机器学习
qwerasda12385210 分钟前
基于改进的SABL Cascade RNN的安全装备检测系统:手套护目镜安全帽防护服安全鞋识别与实现_r101_fpn_1x_coco_1
人工智能·rnn·安全
实战项目11 分钟前
基于PyTorchMobile的语音识别模型部署与调优
人工智能·语音识别
AI即插即用11 分钟前
超分辨率重建 | 2025 FIWHN:轻量级超分辨率 SOTA!基于“宽残差”与 Transformer 混合架构的高效网络(代码实践)
图像处理·人工智能·深度学习·计算机视觉·transformer·超分辨率重建
小北方城市网11 分钟前
数据库性能优化实战指南:从索引到架构,根治性能瓶颈
数据结构·数据库·人工智能·性能优化·架构·哈希算法·散列表
万行12 分钟前
机器人系统ros2&期末速通&1
人工智能·python·机器学习·机器人
轻竹办公PPT12 分钟前
AI 生成 2026 年工作计划 PPT,逻辑清晰度对比测试
人工智能·python·powerpoint
装不满的克莱因瓶13 分钟前
【cursor】前后端分离项目下的AI跨工程管理方案
java·人工智能·ai·ai编程·cursor·trae·qoder