# NOTE: the original paste carried a stray code-fence tag ("bash") here; the content below is Python.
import onnx
from onnx import helper, checker, TensorProto, numpy_helper
from collections import deque
import warnings
import numpy as np
def fix_split_attributes(model):
    """Turn the ``split`` attribute of Split nodes into a constant input.

    From opset 13 onwards the Split operator receives the split sizes as an
    optional second input instead of the ``split`` attribute. For every Split
    node that carries a non-empty ``split`` attribute, this function stores
    the sizes in a new int64 initializer, appends that initializer's name to
    the node inputs, and removes the attribute.

    The model is modified in place. If the model declares a default-domain
    opset and every such declaration is below 13, nothing is changed, because
    the attribute form is the valid one for those opsets.

    Args:
        model: The ONNX ``ModelProto`` to patch.

    Returns:
        The same ``model`` object that was passed in.
    """
    graph = model.graph

    # Split only accepts the sizes as an input starting with opset 13;
    # converting an older model would make an otherwise valid node invalid.
    default_opsets = [
        imp.version for imp in model.opset_import
        if imp.domain in ("", "ai.onnx")
    ]
    if default_opsets and max(default_opsets) < 13:
        warnings.warn(
            f"模型默认 opset 为 {max(default_opsets)} (< 13),"
            "Split 的 split 仍应为属性,跳过转换。"
        )
        return model

    # Every tensor name already present in the graph, so that the generated
    # initializer name cannot clash with an initializer, a graph input, or an
    # intermediate tensor.
    taken_names = {init.name for init in graph.initializer}
    taken_names.update(inp.name for inp in graph.input)
    for node in graph.node:
        taken_names.update(node.input)
        taken_names.update(node.output)

    modified = False
    for index, node in enumerate(graph.node):
        if node.op_type != "Split":
            continue

        attr_pos = next(
            (i for i, attr in enumerate(node.attribute) if attr.name == "split"),
            None,
        )
        if attr_pos is None:
            continue

        split_vals = list(node.attribute[attr_pos].ints)
        if not split_vals:
            # An empty list means "split evenly"; the node is left untouched.
            continue

        # Derive a deterministic name: unnamed nodes use their position in
        # the graph rather than a per-run object id.
        base_name = f"{node.name}_splits_const" if node.name else f"split_const_{index}"
        tensor_name = base_name
        suffix = 0
        while tensor_name in taken_names:
            tensor_name = f"{base_name}_{suffix}"
            suffix += 1
        taken_names.add(tensor_name)

        splits_array = np.array(split_vals, dtype=np.int64)
        graph.initializer.append(numpy_helper.from_array(splits_array, name=tensor_name))

        # Patch the node in place so every other field (outputs, domain,
        # doc_string, remaining attributes) is preserved as-is.
        node.input.append(tensor_name)
        del node.attribute[attr_pos]

        modified = True
        print(f"修复 Split 节点: {node.name or '(unnamed)'}, splits={split_vals}")

    if modified:
        print("Split 属性转换完成。")
    else:
        print("没有发现需要修复的 Split 节点。")
    return model
def prune_model_to_outputs(input_model_path, output_model_path, target_output_names, target_opset=21):
    """Prune an ONNX model to the sub-graph that produces the given tensors.

    The model is loaded from ``input_model_path``; only the nodes needed to
    compute ``target_output_names`` are kept, and those tensors become the
    graph outputs. Split attributes are then patched via
    ``fix_split_attributes``, the opset is downgraded when it is above
    ``target_opset``, the result is checked, and it is saved to
    ``output_model_path`` even if the check fails.

    Only the ``input`` lists of top-level nodes are followed when collecting
    dependencies; tensors referenced solely from inside node subgraphs are
    not traced.

    Args:
        input_model_path: Path of the ONNX model to read.
        output_model_path: Path the pruned model is written to.
        target_output_names: Iterable of tensor names to expose as outputs.
        target_opset: Highest default-domain opset allowed in the result.

    Returns:
        None. The pruned model is written to ``output_model_path``.
    """
    # Materialize once: the names are iterated several times below.
    target_output_names = list(target_output_names)

    # 1. Load the model.
    model = onnx.load(input_model_path)
    graph = model.graph

    # 2. Map every produced tensor name to the index of its producer node.
    node_list = list(graph.node)
    nodes_by_output = {}
    for idx, node in enumerate(node_list):
        for out in node.output:
            nodes_by_output[out] = idx

    # 3. Walk backwards from the requested outputs to collect required nodes.
    required_indices = set()
    queue = deque()

    def require(tensor_name):
        """Schedule the producer of *tensor_name*; return False if it has none."""
        producer_idx = nodes_by_output.get(tensor_name)
        if producer_idx is None:
            return False
        if producer_idx not in required_indices:
            required_indices.add(producer_idx)
            queue.append(producer_idx)
        return True

    for output_name in target_output_names:
        if not require(output_name):
            warnings.warn(f"输出张量 '{output_name}' 没有找到生产者节点,可能已是模型输入。")

    while queue:
        node = node_list[queue.popleft()]
        for input_name in node.input:
            if input_name:  # empty names mark omitted optional inputs
                require(input_name)

    # 4. Keep the required nodes in their original order.
    new_nodes = [node_list[i] for i in sorted(required_indices)]

    # 5. Collect every tensor name still referenced.
    used_tensor_names = set(target_output_names)
    for node in new_nodes:
        used_tensor_names.update(node.input)
        used_tensor_names.update(node.output)
    used_tensor_names.update(inp.name for inp in graph.input)

    # 6. Keep only the initializers and value_info entries still in use.
    new_initializers = [init for init in graph.initializer if init.name in used_tensor_names]
    new_value_info = [vi for vi in graph.value_info if vi.name in used_tensor_names]

    # 7. Build the value_info of the new outputs. Lookup priority is
    #    value_info, then the original outputs, then the graph inputs.
    known_infos = {}
    for collection in (graph.value_info, graph.output, graph.input):
        for info in collection:
            known_infos.setdefault(info.name, info)

    def get_tensor_info(name):
        """Return the recorded type/shape of *name*, or a rank-4 float guess."""
        info = known_infos.get(name)
        if info is not None:
            return info
        # No recorded type or shape: fall back to a float tensor with four
        # unknown dimensions.
        return helper.make_tensor_value_info(name, TensorProto.FLOAT, [None, None, None, None])

    new_outputs = [get_tensor_info(name) for name in target_output_names]

    # 8. Assemble the new graph.
    new_graph = helper.make_graph(
        nodes=new_nodes,
        name=graph.name + "_pruned",
        inputs=list(graph.input),
        outputs=new_outputs,
        initializer=new_initializers,
        value_info=new_value_info,
    )

    # 9. Assemble the new model. The original opset imports are passed
    #    explicitly: without ``opset_imports`` make_model inserts the onnx
    #    package's latest default-domain opset, which would then be
    #    duplicated by the model's own entry and misreport the real opset.
    new_model = helper.make_model(
        new_graph,
        producer_name="onnx_prune_tool",
        opset_imports=list(model.opset_import),
    )
    if model.ir_version:
        # Keep the IR version of the source model instead of the newest one
        # known to the installed onnx package.
        new_model.ir_version = model.ir_version

    # 10. Patch Split attributes.
    print("正在修复 Split 节点的 split 属性...")
    new_model = fix_split_attributes(new_model)

    # 11. Downgrade the opset if required.
    current_opset = None
    for imp in new_model.opset_import:
        if imp.domain == "" or imp.domain == "ai.onnx":
            current_opset = imp.version
            break
    if current_opset is None:
        new_model.opset_import.append(helper.make_opsetid("", target_opset))
        current_opset = target_opset
    print(f"当前模型默认 opset: {current_opset}")

    if current_opset > target_opset:
        print(f"尝试将 opset 从 {current_opset} 降级到 {target_opset} ...")
        try:
            from onnx import version_converter
            new_model = version_converter.convert_version(new_model, target_opset)
            print(f"成功降级到 opset {target_opset}")
        except Exception as e:
            # Best effort: when conversion fails only the declared version
            # number is rewritten; the nodes themselves are not converted.
            warnings.warn(f"opset 自动降级失败: {e}\n将强制修改 opset 版本号(可能有风险)")
            for imp in new_model.opset_import:
                if imp.domain == "" or imp.domain == "ai.onnx":
                    imp.version = target_opset
    else:
        print(f"当前 opset {current_opset} <= {target_opset},无需降级。")

    # 12. Final validation; a failure is reported but does not stop saving.
    try:
        checker.check_model(new_model)
        print("模型检查通过。")
    except Exception as e:
        print(f"模型检查失败: {e}")
        print("精简后模型可能无效,但将尝试保存。")

    # 13. Save.
    onnx.save(new_model, output_model_path)
    print(f"模型已保存至: {output_model_path}")
    print(f"原始节点数: {len(node_list)}, 保留节点数: {len(new_nodes)}")
if __name__ == "__main__":
    # Source model and destination file for the pruned model.
    source_path = "yolo11s-obb.onnx"
    pruned_path = "yolo11s-obb_fixed.onnx"

    # Tensors that become the outputs of the pruned model.
    wanted_outputs = [
        "/model.23/Sigmoid_output_0",
        "/model.23/Concat_3_output_0",
        "/model.23/Concat_2_output_0",
        "/model.23/Concat_1_output_0",
    ]

    prune_model_to_outputs(
        input_model_path=source_path,
        output_model_path=pruned_path,
        target_output_names=wanted_outputs,
        target_opset=21,
    )
# 说明:原始导出的 ONNX 模型只有一个输出,且其中包含了部分后处理;RK 平台部署需要改成 4 个输出(输出节点名见上面代码中的 target_outputs),后处理放到模型外部执行,不在 NPU 上做。坑人的是,怎样把模型改成 4 个输出,网上没有一个人说过。