onnx模型数据结构分析,用于解析onnx模型

文章目录

onnx 模型数据结构

onnx是基于protobuffer的, scheme 定义在:

https://github.com/onnx/onnx/blob/main/onnx/onnx.proto3

onnx官方其实也有一篇文章在介绍onnx IR的结构: onnx IR

一、核心字段分类

ModelProto 的字段可分为基础标识 / 版本元数据计算图核心扩展能力四大类,以下是关键字段的详细说明:

可以看到整个模型是 ModelProto:

字段类别 核心字段 作用说明
版本与兼容性 ir_version 标识 ONNX 中间表示(IR)的版本(如 8/9/13 等),决定语法 / 语义兼容性
opset_import 模型依赖的算子集(OperatorSet)列表,格式为 {domain: version}(如 "" : 19 表示默认域的 19 版算子)
元数据 producer_name 生成模型的工具 / 框架名称(如 PyTorch/TensorFlow)
producer_version 生成工具的版本号
model_version 模型自身的版本(自定义整数,用于版本管理)
doc_string 模型的可读文档(支持 Markdown)
metadata_props 自定义键值对元数据(如作者、训练数据集、license 等)
计算图核心 graph 模型的计算逻辑核心,类型为 GraphProto(见下文拆解)
扩展能力 functions 模型内自定义函数(FunctionProto),用于封装可复用的算子逻辑
training_info 训练相关信息(TrainingInfoProto),存储训练步骤、优化器等训练态信息
二、关键嵌套结构详解

ModelProto 依赖多个子 Proto 结构实现完整能力,核心嵌套结构如下:

1. GraphProto(计算图核心)

graph 是 ModelProto 的核心,定义了模型的计算逻辑,其核心字段:

  • name:图的名称(唯一标识);
  • input/output:图的输入 / 输出张量定义(ValueInfoProto),包含张量名称、类型(TypeProto)、形状(TensorShapeProto);
  • node:计算节点列表(NodeProto),每个节点对应一个算子(如 Conv/Add/MatMul),包含 op_type(算子类型)、input/output(张量名)、attribute(算子属性,如卷积核大小);
  • initializer:模型的权重 / 常量张量(TensorProto),如卷积核、偏置等;
  • sparse_initializer:稀疏张量类型的权重(可选)。
字段名 数据类型 核心作用 补充说明
name string 计算图的唯一标识名称 "cspresnet50_graph",用于区分不同子图(如循环 / 条件子图)
input repeated ValueInfoProto 图的输入张量元信息(不含权重) 每个 ValueInfoProto 包含张量名、类型(TypeProto)、形状(TensorShapeProto),如模型输入 input.1
output repeated ValueInfoProto 图的输出张量元信息 如模型输出 output.1,定义输出张量的名称、类型、形状
node repeated NodeProto 计算图的核心算子节点(DAG 节点) 每个 NodeProto 对应一个算子(如 Conv/Relu),包含输入、输出、属性(AttributeProto
initializer repeated TensorProto 模型权重 / 常量张量(如卷积核、偏置) 存储张量的二进制数据(raw_data)、形状(dims)、数据类型(data_type),是模型参数的核心载体
sparse_initializer repeated SparseTensorProto 稀疏权重张量(如稀疏卷积的权重) 仅用于稀疏张量场景,包含非零值(values)、索引(indices)等
doc_string string 计算图的文档说明(备注 / 注释) 可选字段,用于描述图的用途、版本等
value_info repeated ValueInfoProto 图中中间张量的元信息(非输入输出 / 初始化器) 可选字段,用于补充中间张量(如 Conv 输出)的类型 / 形状信息
quantization_annotation repeated QuantizationAnnotationProto 量化相关标注(如张量的量化参数) 量化模型专用,描述张量的零点、缩放因子等
layout_annotation repeated LayoutAnnotationProto 张量布局标注(如 NCHW/NHWC) 可选字段,用于指定张量的维度顺序
metadata_props repeated StringStringEntryProto 自定义元数据(键值对) "author": "alan""version": "1.0"
functions repeated FunctionProto 图内自定义算子函数 用于封装复用的算子逻辑,如自定义 Selu 函数
arg repeated AttributeProto 图级别的属性(极少使用) 区别于 NodeProto 的算子属性,用于全局配置
node
  • proto
protobuf 复制代码
message NodeProto {
  string name = 1;          // 节点名(可选)
  string op_type = 2;       // 算子类型(如 "Conv"/"Relu"/"MatMul")
  repeated string input = 3;// 输入张量名列表(如 ["input.1", "conv1.weight"])
  repeated string output = 4;// 输出张量名列表(如 ["conv1.output"])
  string domain = 5;        // 算子域(默认 "" 为 ONNX 标准算子,如 "com.microsoft" 为自定义算子)
  repeated AttributeProto attribute = 6; // 算子属性(如 Conv 的 stride/padding)
  string doc_string = 10;   // 节点说明
}
  • json
json 复制代码
    {
      "input": [
        "/stem/conv1/conv/Conv_output_0"
      ],
      "output": [
        "/stem/conv1/bn/act/LeakyRelu_output_0"
      ],
      "name": "/stem/conv1/bn/act/LeakyRelu",
      "op_type": "LeakyRelu",
      "attribute": [
        {
          "name": "alpha",
          "f": 0.01,
          "type": "FLOAT"
        }
      ]
    },
2. OperatorSetProto(算子集依赖)

opset_import 引用的算子集定义,核心字段:

  • domain:算子集的域名(默认域为空字符串 "",自定义算子可指定如 custom_domain);
  • version:算子集版本(如 15/19),决定算子的签名 / 语义。
3. TensorProto(张量定义)

用于描述权重、输入输出张量,核心字段:

  • data_type:张量元素类型(如 FLOAT/INT32/STRING,对应枚举值 1/6/7);
  • dims:张量形状(如 [3, 224, 224] 表示 3 通道 224x224 图像);
  • raw_data:张量的二进制数据(序列化后的权重值);
  • 辅助字段:name(张量名)、segment(稀疏张量分段信息)等。

onnx存储为json结构

python 复制代码
import onnx
import json
from google.protobuf.json_format import MessageToDict

model: onnx.ModelProto = onnx.load("/Users/alanchen/workspace/cspresnet50_Opset16.onnx")
graph: onnx.GraphProto = model.graph


def strip_tensor_weights(graph_dict: dict) -> dict:
    """
    剥离 Graph 字典中 Tensor 的权重数据,仅保留元信息(名称、形状、数据类型)
    """
    # 处理 initializer(模型权重/常量张量)
    if "initializer" in graph_dict:
        for tensor in graph_dict["initializer"]:
            # 移除存储权重的字段(按需扩展,覆盖所有数值存储字段)
            weight_fields = [
                "raw_data",  # 二进制权重(最常见)
                "float_data",  # float 类型数值列表
                "int32_data",  # int32 类型数值列表
                "int64_data",  # int64 类型数值列表
                "uint64_data",  # uint64 类型数值列表
                "bool_data",  # bool 类型数值列表
                "string_data",  # string 类型数值列表
                "double_data",  # double 类型数值列表
                "int8_data",  # int8 类型数值列表
                "uint8_data",  # uint8 类型数值列表
                "int16_data",  # int16 类型数值列表
                "uint16_data",  # uint16 类型数值列表
            ]
            for field in weight_fields:
                if field in tensor:
                    del tensor[field]

    # 可选:处理 sparse_initializer(稀疏张量权重,若有)
    if "sparse_initializer" in graph_dict:
        for sparse_tensor in graph_dict["sparse_initializer"]:
            if "values" in sparse_tensor:
                del sparse_tensor["values"]  # 移除稀疏张量的数值
            if "indices" in sparse_tensor:
                del sparse_tensor["indices"]  # 移除稀疏张量的索引

    return graph_dict


# 步骤1:将 GraphProto 转为字典(Protobuf 原生转换)
graph_dict = MessageToDict(graph, preserving_proto_field_name=True)

# 步骤2:剥离权重数据
graph_dict_stripped = strip_tensor_weights(graph_dict)

# 步骤3:导出为 JSON 文件(无权重)
with open("cspresnet50_graph_proto.json", "w", encoding="utf-8") as f:
    json.dump(graph_dict_stripped, f, ensure_ascii=False, indent=2, sort_keys=False)

print("已导出剥离权重的 Graph JSON 文件:cspresnet50_graph_proto_stripped.json")
相关推荐
@atweiwei2 小时前
Go语言面试篇数据结构底层原理精讲(下)
数据结构·面试·golang
CHANG_THE_WORLD2 小时前
PDFium 处理通用 `W` 数组的方式
数据结构·算法
北顾笙9803 小时前
day18-数据结构力扣
数据结构·算法·leetcode
charliejohn3 小时前
计算机考研 408 数据结构 排序算法
数据结构
汀、人工智能4 小时前
[特殊字符] 第36课:柱状图最大矩形
数据结构·算法·数据库架构·图论·bfs·柱状图最大矩形
LG.YDX4 小时前
笔试训练48天:跳台阶
数据结构·算法
汀、人工智能4 小时前
[特殊字符] 第42课:对称二叉树
数据结构·算法·数据库架构·图论·bfs·对称二叉树
ZTL-NPU4 小时前
代码随想录-第二章:时间复杂度
数据结构
@atweiwei4 小时前
Go语言面试篇数据结构底层原理精讲(上)
数据结构·面试·golang