export_onnx.py_0130

import argparse

import os

from pathlib import Path

import torch

from models.yolo import Model

from utils.general import check_file, set_logging

from utils.torch_utils import select_device, intersect_dicts

def load_model(weights, cfg, device):

ckpt = torch.load(weights, map_location=device) if weights else None

if cfg:

model_yaml = cfg

elif ckpt and isinstance(ckpt, dict) and "model" in ckpt and hasattr(ckpt["model"], "yaml"):

model_yaml = ckpt["model"].yaml

else:

raise ValueError("cfg 未指定,且权重中没有可用的 yaml")

model = Model(model_yaml, ch=3, nc=None, anchors=None).to(device)

if ckpt:

if isinstance(ckpt, dict) and "model" in ckpt and hasattr(ckpt["model"], "state_dict"):

state = ckpt["model"].float().state_dict()

elif isinstance(ckpt, dict) and "state_dict" in ckpt:

state = ckpt["state_dict"]

else:

state = ckpt.state_dict() if hasattr(ckpt, "state_dict") else None

if state:

state = intersect_dicts(state, model.state_dict(), exclude=[])

model.load_state_dict(state, strict=False)

model.eval()

return model

def export_onnx(model, onnx_path, img_size, opset, simplify):

h, w = img_size

dummy = torch.zeros(1, 3, h, w, device=next(model.parameters()).device)

onnx_path = str(onnx_path)

torch.onnx.export(

model,

dummy,

onnx_path,

export_params=True,

opset_version=opset,

do_constant_folding=True,

input_names=["images"],

output_names=["det", "seg"],

dynamic_axes=None,

)

if simplify:

try:

import onnx

from onnxsim import simplify as onnx_simplify

onnx_model = onnx.load(onnx_path)

onnx_model, check = onnx_simplify(

onnx_model, input_shapes={"images": [1, 3, h, w]}

)

if not check:

raise RuntimeError("onnxsim 简化失败")

onnx.save(onnx_model, onnx_path)

except Exception as e:

print(f"[WARN] onnxsim 失败: {e}")

def parse_args():

parser = argparse.ArgumentParser()

parser.add_argument("--weights", type=str, default="", help="weights path (.pt)")

parser.add_argument("--cfg", type=str, default="", help="model.yaml path")

parser.add_argument("--img-size", nargs="+", type=int, default=[544, 480], help="input size [h w]")

parser.add_argument("--opset", type=int, default=11, help="onnx opset version")

parser.add_argument("--device", default="0", help="cuda device, i.e. 0 or cpu")

parser.add_argument("--output", type=str, default="", help="output onnx path")

parser.add_argument("--simplify", action="store_true", help="run onnxsim simplify")

return parser.parse_args()

def main():

opt = parse_args()

set_logging()

if opt.weights:

opt.weights = check_file(opt.weights)

if opt.cfg:

opt.cfg = check_file(opt.cfg)

if len(opt.img_size) == 1:

opt.img_size = [opt.img_size[0], opt.img_size[0]]

elif len(opt.img_size) > 2:

opt.img_size = opt.img_size[:2]

device = select_device(opt.device)

model = load_model(opt.weights, opt.cfg, device)

out_path = opt.output or (Path(opt.weights).with_suffix(".onnx") if opt.weights else Path("model.onnx"))

out_path = Path(out_path)

out_path.parent.mkdir(parents=True, exist_ok=True)

export_onnx(model, out_path, opt.img_size, opt.opset, opt.simplify)

print(f"ONNX 导出完成: {out_path}")

if name == "main":

main()

======================

onnx_postprocess_fix_slice.py

import argparse

import sys

from pathlib import Path

import numpy as np

import onnx

from onnx import helper, numpy_helper, shape_inference

def _const_attr_to_tensor(node):

"""Convert Constant node attributes to TensorProto if possible."""

attrs = {a.name: a for a in node.attribute}

if "value" in attrs:

return attrs["value"].t

if "value_int" in attrs:

return helper.make_tensor(

name=node.output[0],

data_type=onnx.TensorProto.INT64,

dims=[],

vals=[attrs["value_int"].i],

)

if "value_float" in attrs:

return helper.make_tensor(

name=node.output[0],

data_type=onnx.TensorProto.FLOAT,

dims=[],

vals=[attrs["value_float"].f],

)

if "value_ints" in attrs:

vals = list(attrs["value_ints"].ints)

return helper.make_tensor(

name=node.output[0],

data_type=onnx.TensorProto.INT64,

dims=[len(vals)],

vals=vals,

)

if "value_floats" in attrs:

vals = list(attrs["value_floats"].floats)

return helper.make_tensor(

name=node.output[0],

data_type=onnx.TensorProto.FLOAT,

dims=[len(vals)],

vals=vals,

)

if "value_string" in attrs:

vals = [attrs["value_string"].s]

return helper.make_tensor(

name=node.output[0],

data_type=onnx.TensorProto.STRING,

dims=[1],

vals=vals,

)

return None

def hoist_constant_nodes_to_initializers(model):

"""Replace Constant nodes with initializers."""

graph = model.graph

const_nodes = []

new_initializers = []

for node in graph.node:

if node.op_type != "Constant":

continue

t = _const_attr_to_tensor(node)

if t is None:

continue

t.name = node.output[0]

new_initializers.append(t)

const_nodes.append(node)

if const_nodes:

graph.initializer.extend(new_initializers)

for n in const_nodes:

graph.node.remove(n)

return len(const_nodes)

def ensure_slice_inputs_are_const(model):

"""Report Slice nodes with non-constant inputs."""

graph = model.graph

init_names = {i.name for i in graph.initializer}

producer = {}

for node in graph.node:

for out in node.output:

producer[out] = (node.name or "", node.op_type)

bad = []

for node in graph.node:

if node.op_type != "Slice":

continue

for idx in range(1, len(node.input)):

name = node.input[idx]

if name and name not in init_names:

prod = producer.get(name, ("", ""))

bad.append((node.name or node.output[0], idx, name, prod[0], prod[1]))

return bad

def try_optimizer(model):

try:

import onnxoptimizer

passes = [

"eliminate_deadend",

"eliminate_identity",

"eliminate_nop_transpose",

"eliminate_unused_initializer",

"fuse_consecutive_transposes",

"fold_constant",

]

return onnxoptimizer.optimize(model, passes)

except Exception as e:

print(f"[WARN] onnxoptimizer 失败: {e}")

return model

def try_simplify(model, input_shape):

try:

from onnxsim import simplify

model_s, check = simplify(model, input_shapes={"images": input_shape})

if not check:

raise RuntimeError("onnxsim simplify check failed")

return model_s

except Exception as e:

print(f"[WARN] onnxsim 失败: {e}")

return model

def main():

parser = argparse.ArgumentParser()

parser.add_argument("--input", required=True, help="input onnx path")

parser.add_argument("--output", required=True, help="output onnx path")

parser.add_argument("--img-size", nargs="+", type=int, default=[544, 480], help="input size [h w]")

parser.add_argument("--simplify", action="store_true", help="run onnxsim simplify")

parser.add_argument("--no-optimizer", action="store_true", help="skip onnxoptimizer passes")

args = parser.parse_args()

in_path = Path(args.input)

out_path = Path(args.output)

out_path.parent.mkdir(parents=True, exist_ok=True)

model = onnx.load(str(in_path))

try:

model = shape_inference.infer_shapes(model)

except Exception as e:

print(f"[WARN] shape_inference 失败: {e}")

if not args.no_optimizer:

model = try_optimizer(model)

const_count = hoist_constant_nodes_to_initializers(model)

print(f"[INFO] Hoisted Constant nodes: {const_count}")

if args.simplify:

h, w = args.img_size[:2] if len(args.img_size) > 1 else (args.img_size[0], args.img_size[0])

model = try_simplify(model, [1, 3, h, w])

bad_slices = ensure_slice_inputs_are_const(model)

if bad_slices:

print("[WARN] 仍存在 Slice 非常量输入:")

for name, idx, inp, prod_name, prod_type in bad_slices[:20]:

prod_info = f"{prod_type}:{prod_name}" if (prod_type or prod_name) else "Unknown"

print(f" - Slice={name} input[{idx}]={inp} (from {prod_info})")

if len(bad_slices) > 20:

print(f" ... and {len(bad_slices) - 20} more")

else:

print("[INFO] 所有 Slice 参数已为常量/initializer")

onnx.save(model, str(out_path))

print(f"[OK] 保存完成: {out_path}")

if name == "main":

main()

========================

原因是 TIDL 对 ONNX 里的 Slice 有严格限制:

它要求 Slice 的 starts/ends/axes/steps 必须是常量 initializer 且类型为 INT64。

而你当前导出的 YOLO11 ONNX 里出现了非常量 Slice(或 Slice 参数不是 INT64),所以提示:

  • Could not find const or initializer of layer Slice_47
  • Only float and INT64 tensor is supported
  • 最后 tidl_model_import 崩溃

这通常由 YOLO11 新增的 PSA/Attention(C2PSA) 或 动态 shape 导出引入的切片导致。


解决思路(推荐顺序)

方案 A:重新导出 ONNX(不改模型)

  1. 关闭动态 shape

导出 ONNX 时不要设置 dynamic_axes,固定输入尺寸 544x480。

  1. 开启常量折叠

torch.onnx.export(..., do_constant_folding=True)

  1. 使用 onnxsim 简化 python -m onnxsim input.onnx output.onnx --input-shape 1,3,544,480

这样能把很多 Slice 的参数折叠成常量。

> 如果按这套流程后 still 报 Slice,说明模型结构里确实有 TIDL 不支持的 Slice。


方案 B:让 backbone 更 "TIDL 友好"

YOLO11 的 C2PSA 很可能触发 Slice(注意力结构里常见 split/slice)。

你可以把 yolov11_custom_seg_big.yaml 里这行:

-1, 2, C2PSA, \[1024\]

换成更简单的结构,比如:

-1, 2, C3k2, \[1024, True\]

这样能显著减少 Slice 相关算子,TIDL 通常就能吃下。


建议

先试 方案 A(最小改动),如果仍然报 Slice,就走 方案 B。

如果你愿意,我可以直接帮你生成一个 yolov11_custom_seg_big_tidl.yaml(去掉 C2PSA 的版本),或者帮你写一套 ONNX 导出脚本确保常量折叠。

已新增导出脚本 tda4/export_onnx.py,默认开启 常量折叠(do_constant_folding=True),并可选 onnxsim 简化。

用法示例(固定输入 544x480,opset11)

python export_onnx.py \

--weights runs/train/exp/weights/last.pt \

--cfg models/yolov11_custom_seg_big.yaml \

--img-size 544 480 \

--opset 11 \

--output ../output_yolov11_8/yolov11_last_8.onnx \

--simplify

说明:

  • --simplify 会尝试用 onnxsim 进一步折叠常量(没装 onnxsim 会提示但不影响导出)
  • 未传 --output 时默认生成 weights 同名 .onnx
  • 默认固定输入尺寸,不使用动态 shape,减少 Slice 相关问题

你先用这套导出再跑 TIDL,看是否还报 Slice。如果还有,我再帮你做"去掉 C2PSA 的 TIDL 兼容版本"配置。

万一方案A失效,因为我的模型已经训练完了,有没有什么成本尽可能小的方法,可以让我顺利导出模型?

可以,最小成本的一般做法是不改训练结果、不重训,而是对导出的 ONNX 做"后处理",把不被 TIDL 接受的 Slice 变成常量参数。常见有效路径如下(从成本最低到高):

低成本优先方案

  1. ONNX 后处理 + 常量折叠
  • onnxsim + onnxoptimizer + shape_inference
  • 目标:把 Slice 的 starts/ends/axes/steps 变成 initializer(常量)
  • 不改模型、只改图,成本最低
  1. 用 onnx-graphsurgeon 直接改 Slice
  • 找到非 const 的 Slice 输入,按固定输入形状把它们"烘焙"为常量
  • 这类脚本通常对 TIDL 最有效
  1. 固定输入尺寸重新导出 + 再做优化
  • 你已经固定输入尺寸了,但仍可加上更强的后处理链
  • 这属于不重训、只再导出

仍失败时的"次低成本"

  1. 导出时替换问题模块为等价简化版本
  • 例如把 C2PSA 替换为轻量 C3k2(只在导出时替换,训练权重不变)
  • 会有轻微精度变化,但比重训成本小

如果你愿意,我可以直接给你一个"ONNX 后处理脚本"(自动修复 Slice 参数为常量,并保存新 onnx),基本不改你训练结果。

你只需要提供导出的 .onnx 路径即可。

已新增 ONNX 后处理脚本 tda4/onnx_postprocess_fix_slice.py,会自动把 Constant 节点提升为 initializer,并尽量让 Slice 的参数变成常量,同时可选运行 onnxoptimizer 和 onnxsim。

用法

python onnx_postprocess_fix_slice.py \

--input ../output_yolov11_8/yolov11_last_8.onnx \

--output ../output_yolov11_8/yolov11_last_8_fixed.onnx \

--img-size 544 480 \

--simplify

说明

  • 默认会运行 constant folding + Constant 节点提升
  • --simplify 会调用 onnxsim(没装也不会中断,只会提示)
  • 如果仍有 Slice 参数非常量,会在日志中列出前 20 个

如果修完仍然报 Slice,我再给你加一版"强制替换 C2PSA 的导出替身",保证 TIDL 能过。

请问现在的Slice问题是出在yolov11的backbone的哪一层?

目前无法直接从训练代码判断 具体是哪一层 导致 Slice 非常量,但在 YOLO11 里最常见的来源是 C2PSA/PSA 注意力模块(内部会用到 split/chunk,导出为 Slice)。

我已经把脚本增强了:现在会输出 Slice 的输入是谁产生的(op_type + 节点名),你一跑就能定位到具体来源,再去对应到 backbone 层。

用法(会打印 Slice 来源)

python onnx_postprocess_fix_slice.py \

--input ../output_yolov11_8/yolov11_last_8.onnx \

--output ../output_yolov11_8/yolov11_last_8_fixed.onnx \

--img-size 544 480

输出示例会类似:

Slice=Slice_47 input[1]=... (from Split:xxx)

你把这段输出发我,我可以帮你精确定位到 具体 backbone 层序号(例如"C2PSA 这一层"或"某个 C3k2 内部")。

万一我考虑方案B,准备利用已经训练好的模型继续训练,有什么低成本的方法来解决吗

可以,低成本方式是"结构小改 + 权重最大化继承 + 小步微调"。方案如下:

低成本做法(推荐)

  1. 只在配置里把 C2PSA 换成 C3k2

训练权重不丢,其他层结构不变。

  1. 用当前训练好的权重继续训练

加载方式还是:

  • --weights 旧模型全量
  • --backbone-weights 如果需要可继续覆盖 backbone(可选)
  1. 短期微调

先冻结 backbone 5~10 epoch,再全量解冻训练 20~50 epoch。

这样能稳定适应新 backbone 小改动。

> 这是"改结构最小 + 训练成本最小"的路径。


如果你愿意,我可以直接:

  • 给你生成一个 yolov11_custom_seg_big_tidl.yaml(只改 C2PSA→C3k2)
  • 以及把 train2.py 增加一个 --freeze-backbone-epochs 的选项,自动分阶段训练

你告诉我是否需要,我就直接改。

为什么先冻结 backbone 5~10 epoch,再全量解冻训练 20~50 epoch,能稳定适应新 backbone 小改动?如果我希望训练时间尽可能短,能否不冻结,直接全部参数一起训练?

先冻结再解冻的作用主要是稳定收敛,原因有三点:

  1. 新旧结构不完全一致

你替换 C2PSA 后,backbone 的部分层是"新初始化/部分对不上"的,直接全量训练时,随机权重会对 head/分割分支的特征分布造成冲击,容易出现损失不稳定或收敛慢。

  1. 保护已训练好的 head/分割

你的 head/分割权重来自旧模型,先冻结 backbone,让 head/分割先"适配新特征分布",能减少大幅震荡。

  1. 降低 early-stage 震荡

冻结的前几轮相当于"让上层先稳定",再逐步解冻让整体联合优化。


如果你想训练时间尽可能短

可以不冻结,直接全量训练,是允许的。

但建议加一个低成本保险:

  • 把学习率降低一些(例如 lr0 乘 0.5)
  • 或者前 2~3 epoch 用更小的 lr,再恢复

这样可以减少因为新 backbone 结构变化带来的梯度冲击。


总结:

  • 冻结:更稳,但总训练时间略长
  • 不冻结:更快,但波动可能大,需降低初期学习率

你如果要我帮你改成"全量训练 + 自动降低初期 lr",我可以直接改代码。

好的,给你生成一个 yolov11_custom_seg_big_tidl.yaml(只改 C2PSA→C3k2)

以及把 train2.py 增加一个 --freeze-backbone-epochs 的选项,自动分阶段训练

已完成两件事:

  1. 新增 tda4/models/yolov11_custom_seg_big_tidl.yaml
  • 仅把 backbone 中的 C2PSA 替换为 C3k2,其余结构保持一致,兼容 TIDL。
  1. train2.py 新增 --freeze-backbone-epochs
  • 自动分阶段训练:前 N 个 epoch 冻结 backbone,之后自动解冻。
  • 支持 resume 且旧 opt.yaml 不会报错(缺失字段会自动补默认值)。

用法示例

python train2.py \

--cfg models/yolov11_custom_seg_big_tidl.yaml \

--weights /path/to/old_multitask.pt \

--backbone-weights /path/to/yolov11s.pt \

--freeze-backbone-epochs 10

  • --freeze-backbone-epochs 0 表示不冻结(默认 0)
  • 如果只想导出、减少 Slice 问题,直接用这个新 yaml 训练/导出即可

如果你也希望把 train2_baseon_train608736.py 同步增加该参数,我可以一起补上。

我想基于已经训练好的替换了yolov11backbone的模型减少 Slice 问题,我应该是传入我的模型,然后重新训练而不是resume,对吗

对的。

你现在是结构有变化(C2PSA → C3k2),因此应该用 --weights 加载已有模型权重 进行继续训练/微调,而不是 --resume。

  • --resume 只适合结构完全一致、要恢复优化器/学习率/epoch 的断点续训
  • 结构改了,用 --weights 只加载能匹配的权重,更稳妥

建议命令(示例):

python train2.py \

--cfg models/yolov11_custom_seg_big_tidl.yaml \

--weights /path/to/your_trained_yolov11.pt \

--freeze-backbone-epochs 5

如果你想最短训练时间,可以把 --freeze-backbone-epochs 0 去掉或设为 0。

为了解决Slice问题,你提出的 export_onnx.py 方法和 onnx_postprocess_fix_slice.py 方法,会不会影响模型的性能,会不会影响模型的推理结果

一般来说,这两种方法不应该改变模型的推理结果,前提是它们只是做"图优化/常量折叠",没有改算子逻辑。

  • export_onnx.py:只是固定输入尺寸 + 常量折叠,属于 ONNX 导出的标准优化,输出应数值等价。
  • onnx_postprocess_fix_slice.py:主要是把 Constant 节点提升为 initializer,再配合 onnxoptimizer/onnxsim 做折叠,不改变计算图的数学含义。

不过要注意两点:

  1. 如果 onnxsim/optimizer 失败或不完整,可能仍然有 Slice 问题,但不会主动改变精度。

  2. 如果后处理过程中出现精度变化,通常是因为:

  • 输入动态形状被固定(你已经固定输入尺寸,这点OK)
  • 或者工具错误地折叠了与 shape 相关的动态逻辑(少见)

建议验证方式

导出后做一次数值对齐验证:

  • 用 PyTorch 和 ONNX 各跑一张图,比较输出差异
  • 如果差异很小(比如平均误差 <1e-4),可以放心用

如果你愿意,我可以给你一个简单的"PyTorch vs ONNX 输出一致性检查脚本"。

相关推荐
s1hiyu2 小时前
使用Scrapy框架构建分布式爬虫
jvm·数据库·python
2301_763472462 小时前
使用Seaborn绘制统计图形:更美更简单
jvm·数据库·python
无垠的广袤2 小时前
【VisionFive 2 Lite 单板计算机】边缘AI视觉应用部署:缺陷检测
linux·人工智能·python·opencv·开发板
Duang007_2 小时前
【LeetCodeHot100 超详细Agent启发版本】字母异位词分组 (Group Anagrams)
开发语言·javascript·人工智能·python
工程师老罗3 小时前
反向传播及其用法
pytorch
浒畔居3 小时前
机器学习模型部署:将模型转化为Web API
jvm·数据库·python
抠头专注python环境配置3 小时前
基于Pytorch ResNet50 的珍稀野生动物识别系统(Python源码 + PyQt5 + 数据集)
pytorch·python
百***78753 小时前
Kimi K2.5开源模型实战指南:核心能力拆解+一步API接入(Python版,避坑全覆盖)
python·microsoft·开源
喵手3 小时前
Python爬虫实战:针对天文历法网站(以 TimeandDate 或类似的静态历法页为例),构建高精度二十四节气天文数据采集器(附xlsx导出)!
爬虫·python·爬虫实战·零基础python爬虫教学·采集天文历法网站数据·构建二十四节气天文数据