yolov11-obb在rk芯片部署的onnx模型输出的剪枝处理

bash 复制代码
import onnx
from onnx import helper, checker, TensorProto, numpy_helper
from collections import deque
import warnings
import numpy as np

def fix_split_attributes(model):
    """
    将模型中所有 Split 节点的 'split' 属性转换为常量输入(兼容 opset >=13)。
    """
    graph = model.graph
    nodes = list(graph.node)
    modified = False

    # 收集所有初始化的常量名称
    initializer_names = set(init.name for init in graph.initializer)

    new_nodes = []
    for node in nodes:
        if node.op_type == "Split":
            # 检查是否有 split 属性
            attrs = {attr.name: attr for attr in node.attribute}
            if "split" in attrs:
                split_attr = attrs["split"]
                split_vals = list(split_attr.ints)  # split 属性是 ints 列表
                if len(split_vals) == 0:
                    # 没有指定 split,则保持原样(这是合法的,表示均匀分割)
                    new_nodes.append(node)
                    continue
                
                # 创建一个常量节点作为 splits 输入
                splits_tensor_name = node.name + "_splits_const" if node.name else "split_const_" + str(id(node))
                # 确保名称唯一
                counter = 0
                while splits_tensor_name in initializer_names:
                    splits_tensor_name = f"{splits_tensor_name}_{counter}"
                    counter += 1
                
                # 将 split_vals 转为 numpy 数组,再转为 initializer
                splits_array = np.array(split_vals, dtype=np.int64)
                splits_initializer = numpy_helper.from_array(splits_array, name=splits_tensor_name)
                graph.initializer.append(splits_initializer)
                initializer_names.add(splits_tensor_name)
                
                # 修改节点:添加 splits 输入,删除 split 属性
                # 原有输入保持不变,新输入加在末尾
                new_inputs = list(node.input) + [splits_tensor_name]
                new_node = helper.make_node(
                    op_type="Split",
                    inputs=new_inputs,
                    outputs=list(node.output),
                    name=node.name,
                    domain=node.domain
                )
                # 复制其他属性(排除 split)
                for attr in node.attribute:
                    if attr.name != "split":
                        new_node.attribute.append(attr)
                new_nodes.append(new_node)
                modified = True
                print(f"修复 Split 节点: {node.name or '(unnamed)'}, splits={split_vals}")
                continue
        # 非 Split 或无 split 属性的节点,直接保留
        new_nodes.append(node)

    if modified:
        graph.ClearField('node')
        graph.node.extend(new_nodes)
        # 可选:对 value_info 进行轻微清理(不必须)
        print("Split 属性转换完成。")
    else:
        print("没有发现需要修复的 Split 节点。")
    return model

def prune_model_to_outputs(input_model_path, output_model_path, target_output_names, target_opset=21):
    """
    裁剪模型并修复可能的 Split 属性不兼容问题。
    """
    # 1. 加载模型
    model = onnx.load(input_model_path)
    graph = model.graph

    # 2. 构建节点索引映射
    node_list = list(graph.node)
    nodes_by_output = {}
    for idx, node in enumerate(node_list):
        for out in node.output:
            nodes_by_output[out] = idx

    # 3. 反向收集必需节点
    required_indices = set()
    queue = deque()
    for output_name in target_output_names:
        if output_name in nodes_by_output:
            idx = nodes_by_output[output_name]
            if idx not in required_indices:
                required_indices.add(idx)
                queue.append(idx)
        else:
            warnings.warn(f"输出张量 '{output_name}' 没有找到生产者节点,可能已是模型输入。")

    while queue:
        node_idx = queue.popleft()
        node = node_list[node_idx]
        for input_name in node.input:
            if not input_name:
                continue
            if input_name in nodes_by_output:
                producer_idx = nodes_by_output[input_name]
                if producer_idx not in required_indices:
                    required_indices.add(producer_idx)
                    queue.append(producer_idx)

    # 4. 构建新节点列表
    new_nodes = [node_list[i] for i in sorted(required_indices)]

    # 5. 收集使用到的张量名
    used_tensor_names = set()
    for node in new_nodes:
        used_tensor_names.update(node.input)
        used_tensor_names.update(node.output)
    used_tensor_names.update(target_output_names)
    for inp in graph.input:
        used_tensor_names.add(inp.name)

    # 6. 保留相关 initializer 和 value_info
    new_initializers = [init for init in graph.initializer if init.name in used_tensor_names]
    new_value_info = [vi for vi in graph.value_info if vi.name in used_tensor_names]

    # 7. 构建新输出 value_info
    def get_tensor_info(name):
        for vi in graph.value_info:
            if vi.name == name:
                return vi
        for out in graph.output:
            if out.name == name:
                return out
        return helper.make_tensor_value_info(name, TensorProto.FLOAT, [None, None, None, None])

    new_outputs = [get_tensor_info(name) for name in target_output_names]

    # 8. 创建新图
    new_graph = helper.make_graph(
        nodes=new_nodes,
        name=graph.name + "_pruned",
        inputs=list(graph.input),
        outputs=new_outputs,
        initializer=new_initializers,
        value_info=new_value_info
    )

    # 9. 创建新模型
    new_model = helper.make_model(new_graph, producer_name="onnx_prune_tool")
    new_model.opset_import.extend(model.opset_import)

    # 10. 修复 Split 属性
    print("正在修复 Split 节点的 split 属性...")
    new_model = fix_split_attributes(new_model)

    # 11. 降级 opset(如果需要)
    current_opset = None
    for imp in new_model.opset_import:
        if imp.domain == "" or imp.domain == "ai.onnx":
            current_opset = imp.version
            break
    if current_opset is None:
        new_model.opset_import.append(helper.make_opsetid("", target_opset))
        current_opset = target_opset

    print(f"当前模型默认 opset: {current_opset}")
    if current_opset > target_opset:
        print(f"尝试将 opset 从 {current_opset} 降级到 {target_opset} ...")
        try:
            # 使用 version_converter 降级
            from onnx import version_converter
            converted_model = version_converter.convert_version(new_model, target_opset)
            new_model = converted_model
            print(f"成功降级到 opset {target_opset}")
        except Exception as e:
            warnings.warn(f"opset 自动降级失败: {e}\n将强制修改 opset 版本号(可能有风险)")
            for imp in new_model.opset_import:
                if imp.domain == "" or imp.domain == "ai.onnx":
                    imp.version = target_opset
    else:
        print(f"当前 opset {current_opset} <= {target_opset},无需降级。")

    # 12. 最终验证
    try:
        checker.check_model(new_model)
        print("模型检查通过。")
    except Exception as e:
        print(f"模型检查失败: {e}")
        print("精简后模型可能无效,但将尝试保存。")

    # 13. 保存
    onnx.save(new_model, output_model_path)
    print(f"模型已保存至: {output_model_path}")
    print(f"原始节点数: {len(node_list)}, 保留节点数: {len(new_nodes)}")

if __name__ == "__main__":
    input_onnx = "yolo11s-obb.onnx"       # 原始模型
    output_onnx = "yolo11s-obb_fixed.onnx"
    target_outputs = [
        "/model.23/Sigmoid_output_0",
        "/model.23/Concat_3_output_0",
        "/model.23/Concat_2_output_0",
        "/model.23/Concat_1_output_0"
    ]
    prune_model_to_outputs(input_onnx, output_onnx, target_outputs, target_opset=21)

原始导出的onnx只有一个输出,包括一些后处理,rk需要搞成4个输出,输出的节点在代码中,后处理放到外面去,不在npu做。坑人的是这个搞成4个输出的,网上没有一个人说怎么弄。

相关推荐
KaMeidebaby13 小时前
卡梅德生物技术快报|糖蛋白纯化 Sevage 法工艺优化:正交与响应面法对比实操分析
人工智能·其他·算法·百度·新浪微博
前网易架构师-高司机13 小时前
ROS2 Jazzy+Gazebo Harmonic 环境下,用 URDF 搭建机器人,配置物理属性、插件与桥接,修复车轮和激光雷达故障 (手把手保姆级教程)
开发语言·算法·golang·机器人·ros
视觉算法小姥14 小时前
YOLOV11-OBB之ONNX转RKNN并跑在模拟器上
yolo
wjcroom14 小时前
时空和电子1-平直相对论时空的构建
算法·重构·物理学
吃好睡好便好14 小时前
矩阵的求幂运算
人工智能·学习·线性代数·算法·matlab·矩阵
计算机安禾14 小时前
【算法分析与设计】第18篇:改进的最大流算法:Edmonds-Karp与Dinic
大数据·人工智能·算法
buhuizhiyuci14 小时前
【算法篇】初识双指针
算法
牧鸯人14 小时前
基于yolov8的课堂行为检测系统——主要功能检测睡觉、手机、人数
python·深度学习·yolo·学生行为统计
超梦dasgg14 小时前
归并排序 Java 实现(递归 + 非递归)
java·算法·排序算法