PyTorch导出ONNX报错(ShapeInferenceError)问题笔记(含dynamo=False作用解析)

PyTorch导出ONNX报错(ShapeInferenceError)问题笔记(含dynamo=False作用解析)

核心结论:针对BERT类模型(含chinese-roberta-wwm-ext ),torch.onnx.export() 必须添加 dynamo=False,否则必报形状错误;添加后,无论环境版本如何,均能正常导出、量化,无任何影响。

报错核心信息:InferenceError: \[ShapeInferenceError\] Inferred shape and existing shape differ in dimension 0: \(768\) vs \(4\)

关键数字解析(贴合自身BERT模型):

  • 768:BERT模型的隐藏层维度(基础版BERT固定维度,是模型中间特征的维度)

  • 4:自身微调模型的分类类别数(模型最终输出的维度,对应分类任务的结果)

  • 报错本质:模型导出时,形状推理逻辑错乱,误将「中间特征维度(768)」当作「最终输出维度(4)」,导致维度不匹配。

一、核心原理铺垫(先懂底层逻辑,再看场景)

1. 两个导出器的差异(PyTorch 2.0+ 关键区别)

导出方式 触发条件 核心特点 对BERT模型的兼容性
新动态导出器(Dynamo) 不加 dynamo=False(PyTorch 2.0+ 默认开启) 动态编译模型,强制绑定输入/权重形状,依赖onnxscript做形状推导 极差,易出现形状推理错乱、报错
传统静态导出器(TorchScript) 加 dynamo=False(强制切换) 逐节点计算形状,精准识别模型结构(特征层+分类头),不依赖onnxscript 极佳,形状推理准确,无报错

2. dynamo=False 的核心作用

强制关闭 PyTorch 2.0+ 默认的「Dynamo动态导出器」,切换回「传统TorchScript静态导出器」,彻底规避新导出器的形状推理bug,同时绕开onnxscript的依赖干扰,确保模型维度推导准确。

二、不同环境场景详细解析(贴合自身测试结果)

场景1:旧版环境(无onnxscript,numpy 1.x)

环境配置
  • onnx: 1.16.2、onnxruntime: 1.16.3

  • onnxscript: 未安装

  • numpy: 1.26.4(<2.0,兼容旧版ort)

  • PyTorch: 2.0+(默认开启Dynamo导出器)

报错情况

不加 dynamo=False:必报 ShapeInferenceError: \(768\) vs \(4\)

加 dynamo=False:无任何报错,正常导出、量化。

报错原因

PyTorch默认启用Dynamo动态导出器,动态编译时无法正确区分BERT的「中间特征层(768维)」和「分类输出层(4维)」,形状推理错乱;且旧版onnx/ort不兼容新导出器的推理逻辑,进一步加剧报错。

加dynamo=False的原理

切换回传统静态导出器,逐节点解析模型结构,精准识别768维是中间特征、4维是最终输出,形状推理准确;同时不依赖onnxscript,彻底规避无关依赖干扰。

场景2:旧版环境(有onnxscript,numpy 2.x)

环境配置
  • onnx: 1.16.2、onnxruntime: 1.16.3

  • onnxscript: 0.1.0.dev20240509(开发版,与旧版onnx不兼容)

  • numpy: 2.4.3(≥2.0,与旧版ort不兼容)

报错情况

不加 dynamo=False:先报 onnxscript 导入错误(如ImportError: cannot import name \&\#39;convenience\&\#39;),后续仍会报 ShapeInferenceError;

加 dynamo=False:仍会有numpy兼容警告(旧版ort用numpy1.x编译),但不影响导出、量化,无ShapeInferenceError。

报错原因
  1. 开发版onnxscript与旧版onnx不兼容,导致导入失败;2. Dynamo导出器依赖onnxscript做形状推导,导入失败进一步引发形状推理错乱;3. numpy2.x删除了旧版ort依赖的_ARRARY_API接口,触发底层兼容警告。
加dynamo=False的原理

切换回静态导出器,绕开onnxscript的依赖(无需调用onnxscript),解决导入错误;同时静态导出器形状推理准确,避免ShapeInferenceError;numpy警告不影响核心功能(仅底层编译兼容提示)。

场景3:新版环境(有onnxscript,numpy 2.x)

环境配置
  • onnx: 1.21.0、onnxruntime: 1.24.4

  • onnxscript: 0.6.2(正式版,与新版onnx兼容)

  • numpy: 2.4.3(≥2.0,与新版ort兼容)

报错情况

不加 dynamo=False:必报 ShapeInferenceError: \(768\) vs \(4\)

加 dynamo=False:无任何报错、无警告,正常导出、量化(如自身测试结果)。

报错原因

即便onnxscript、numpy、onnx/ort版本完全兼容,Dynamo动态导出器仍存在bug------无法正确解析BERT模型的层级结构,误将中间特征维度当作输出维度,导致形状冲突。

加dynamo=False的原理

切换回静态导出器,不依赖onnxscript的形状推导,逐节点精准识别模型输出维度,彻底解决形状错乱问题;同时新版环境各依赖兼容,无任何警告、报错。

场景4:新版环境(无onnxscript,numpy 2.x)

环境配置
  • onnx: 1.21.0、onnxruntime: 1.24.4

  • onnxscript: 未安装

  • numpy: 2.4.3(≥2.0,与新版ort兼容)

报错情况

不加 dynamo=False:必报 ShapeInferenceError: \(768\) vs \(4\)(Dynamo导出器无onnxscript支持,形状推理更错乱);

加 dynamo=False:无任何报错、无警告,正常导出、量化。

报错原因

Dynamo动态导出器默认依赖onnxscript做形状推导,未安装onnxscript时,推导逻辑缺失,进一步加剧形状识别错误,无法区分BERT的中间特征与最终输出。

加dynamo=False的原理

切换回静态导出器,无需onnxscript支持,自身就能精准推导模型形状,规避所有报错,同时新版环境兼容,运行流畅。

场景5:旧版环境(无onnxscript,numpy 2.x)

环境配置
  • onnx: 1.16.2、onnxruntime: 1.16.3

  • onnxscript: 未安装

  • numpy: 2.4.3(≥2.0,与旧版ort不兼容)

  • PyTorch: 2.0+(默认开启Dynamo导出器)

报错情况

不加 dynamo=False:先报 ModuleNotFoundError: No module named \&\#39;onnxscript\&\#39;,同时会引导安装onnxscript;即便安装onnxscript(开发版),后续仍会报 ShapeInferenceError 和 numpy兼容错误;

加 dynamo=False:会有numpy兼容警告(旧版ort用numpy1.x编译),但不影响导出、量化,无任何报错,且无需安装onnxscript。

报错原因
  1. Dynamo动态导出器默认依赖onnxscript做形状推导,未安装时直接报模块缺失错误,且会引导用户安装以尝试解决;2. 即便安装onnxscript,开发版与旧版onnx不兼容,仍会引发后续报错;3. numpy2.x与旧版ort不兼容,触发底层警告;4. 核心还是Dynamo导出器与BERT模型不兼容,即便解决onnxscript缺失问题,仍会出现形状错误。
加dynamo=False的原理

切换回传统静态导出器,完全绕开onnxscript的依赖,无需安装onnxscript就能正常推导模型形状,解决ModuleNotFoundError;同时静态导出器形状推理准确,避免ShapeInferenceError;numpy警告仅为底层编译兼容提示,不影响核心的导出、量化功能。

    1. 针对 BERT 类模型(含chinese-roberta-wwm-ext ),torch.onnx.export() 必须添加 dynamo=False,这是解决 ShapeInferenceError 的唯一方案;
    1. dynamo=False 的作用:强制关闭 PyTorch 新的 Dynamo 动态导出器,切换回传统稳定导出器,确保形状推理准确,不混淆中间特征与最终输出维度;
    1. 所有场景(新旧环境、有无onnxscript、装不装onnxsim)均适用,添加后无需修改任何环境配置;
    1. 补充备注:你当前使用的 chinese-roberta-wwm-ext 属于BERT类模型,与笔记中核心逻辑完全一致,不影响笔记所有结论和操作步骤,按原有代码执行即可。
    1. 针对 BERT 类模型(含chinese-roberta-wwm-ext 、基础BERT等),torch.onnx.export() 必须添加 dynamo=False,这是解决 ShapeInferenceError 的唯一方案;
    1. dynamo=False 的作用:强制切换到传统静态导出器,规避新导出器的形状推理bug,绕开onnxscript依赖,确保维度推导准确;
    1. 所有报错的核心根源:PyTorch 2.0+ 默认开启的 Dynamo 动态导出器,与 BERT 类模型(含chinese-roberta-wwm-ext)不兼容,导致形状推理错乱;
    1. 所有报错的核心根源:PyTorch 2.0+ 默认的 Dynamo 动态导出器,与 BERT 模型结构不兼容,且依赖onnxscript,导致形状推理错乱、模块缺失报错,与代码、模型本身无关。

四、万能代码模板(直接复制使用)

python 复制代码
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType
import os
import onnxruntime as ort

# 路径配置(根据自身环境修改)
model_path = "你的模型路径"
onnx_path = "导出的ONNX模型路径"
quantized_onnx_path = "量化后的ONNX模型路径"
os.makedirs(os.path.dirname(onnx_path), exist_ok=True)

# 加载模型和分词器
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()

# 构造示例输入
inputs = tokenizer("测试文本", return_tensors="pt", padding=True, truncation=True)

# 导出ONNX(核心:添加dynamo=False)
with torch.no_grad():
    torch.onnx.export(
        model,
        args=(inputs["input_ids"], inputs["attention_mask"]),
        f=onnx_path,
        input_names=["input_ids", "attention_mask"],
        output_names=["logits"],  # 输出名,与自身模型匹配
        dynamic_axes={
            "input_ids": {0: "batch", 1: "seq_len"},
            "attention_mask": {0: "batch", 1: "seq_len"},
            "logits": {0: "batch"}
        },
        opset_version=14,  # 适配BERT模型,14版本最稳定
        do_constant_folding=True,
        dynamo=False,  # 必加!避免形状报错
        verbose=False
    )

# 模型验证、量化、测试(可选,根据需求保留)
onnx.checker.check_model(onnx_path)
print("✅ ONNX模型导出成功")

quantize_dynamic(
    model_input=onnx_path,
    model_output=quantized_onnx_path,
    weight_type=QuantType.QInt8
)
print("✅ 模型量化成功")

# 推理测试
sess = ort.InferenceSession(quantized_onnx_path)
ort_inputs = {
    "input_ids": inputs["input_ids"].numpy(),
    "attention_mask": inputs["attention_mask"].numpy()
}
ort_outputs = sess.run(None, ort_inputs)
print("✅ 推理测试成功,输出:", ort_outputs[0])
相关推荐
Ronaldinho Gaúch4 小时前
梯度消失与梯度爆炸
人工智能·深度学习·机器学习
查无此人byebye4 小时前
硬核深度解析:KimiDeltaAttention 源码逐行精读+公式推导+复杂度优化(完整可运行)
人工智能·深度学习·神经网络·自然语言处理
丰海洋4 小时前
Transformer参数量
人工智能·深度学习·transformer
chools4 小时前
Java后端拥抱AI开发之个人学习路线 - - Spring AI【第三期】(向量数据库 + RAG检索增强生成)
java·人工智能·学习·spring·ai
tianbaolc4 小时前
Claude Code 源码剖析 模块一 · 第一节:Claude Code 宏观架构
人工智能·ai·架构·claude code
温九味闻醉4 小时前
人工智能应用作业1:PPO强化学习算法
人工智能·算法
安科士andxe4 小时前
实践指南|安科士SFP-10/25G-LR-S-I光模块部署与运维技巧
运维·人工智能·5g
AI360labs_atyun4 小时前
我在命令行里养了只电子宠物,还顺便学会了Claude Code
人工智能·科技·学习·ai·宠物
ab1237684 小时前
C++ size() 与 length() 核心笔记
开发语言·c++·笔记
dydm_131284 小时前
笔尖下的奇迹:当AI实时绘画“撞见”未来教育
人工智能