文章目录
- [onnx 模型数据结构](#onnx 模型数据结构)
-
-
-
- 一、核心字段分类
- 二、关键嵌套结构详解
-
- [1. GraphProto(计算图核心)](#1. GraphProto(计算图核心))
- [2. OperatorSetProto(算子集依赖)](#2. OperatorSetProto(算子集依赖))
- [3. TensorProto(张量定义)](#3. TensorProto(张量定义))
-
-
- onnx存储为json结构
onnx 模型数据结构
onnx是基于protobuffer的, scheme 定义在:
https://github.com/onnx/onnx/blob/main/onnx/onnx.proto3
onnx官方其实也有一篇文章在介绍onnx IR的结构: onnx IR
一、核心字段分类
ModelProto 的字段可分为基础标识 / 版本 、元数据 、计算图核心 、扩展能力四大类,以下是关键字段的详细说明:
可以看到整个模型是 ModelProto:
| 字段类别 | 核心字段 | 作用说明 |
|---|---|---|
| 版本与兼容性 | ir_version |
标识 ONNX 中间表示(IR)的版本(如 8/9/13 等),决定语法 / 语义兼容性 |
opset_import |
模型依赖的算子集(OperatorSet)列表,格式为 {domain: version}(如 "" : 19 表示默认域的 19 版算子) |
|
| 元数据 | producer_name |
生成模型的工具 / 框架名称(如 PyTorch/TensorFlow) |
producer_version |
生成工具的版本号 | |
model_version |
模型自身的版本(自定义整数,用于版本管理) | |
doc_string |
模型的可读文档(支持 Markdown) | |
metadata_props |
自定义键值对元数据(如作者、训练数据集、license 等) | |
| 计算图核心 | graph |
模型的计算逻辑核心,类型为 GraphProto(见下文拆解) |
| 扩展能力 | functions |
模型内自定义函数(FunctionProto),用于封装可复用的算子逻辑 |
training_info |
训练相关信息(TrainingInfoProto),存储训练步骤、优化器等训练态信息 |
二、关键嵌套结构详解
ModelProto 依赖多个子 Proto 结构实现完整能力,核心嵌套结构如下:
1. GraphProto(计算图核心)
graph 是 ModelProto 的核心,定义了模型的计算逻辑,其核心字段:
name:图的名称(唯一标识);input/output:图的输入 / 输出张量定义(ValueInfoProto),包含张量名称、类型(TypeProto)、形状(TensorShapeProto);node:计算节点列表(NodeProto),每个节点对应一个算子(如 Conv/Add/MatMul),包含op_type(算子类型)、input/output(张量名)、attribute(算子属性,如卷积核大小);initializer:模型的权重 / 常量张量(TensorProto),如卷积核、偏置等;sparse_initializer:稀疏张量类型的权重(可选)。
| 字段名 | 数据类型 | 核心作用 | 补充说明 |
|---|---|---|---|
name |
string |
计算图的唯一标识名称 | 如 "cspresnet50_graph",用于区分不同子图(如循环 / 条件子图) |
input |
repeated ValueInfoProto |
图的输入张量元信息(不含权重) | 每个 ValueInfoProto 包含张量名、类型(TypeProto)、形状(TensorShapeProto),如模型输入 input.1 |
output |
repeated ValueInfoProto |
图的输出张量元信息 | 如模型输出 output.1,定义输出张量的名称、类型、形状 |
node |
repeated NodeProto |
计算图的核心算子节点(DAG 节点) | 每个 NodeProto 对应一个算子(如 Conv/Relu),包含输入、输出、属性(AttributeProto) |
initializer |
repeated TensorProto |
模型权重 / 常量张量(如卷积核、偏置) | 存储张量的二进制数据(raw_data)、形状(dims)、数据类型(data_type),是模型参数的核心载体 |
sparse_initializer |
repeated SparseTensorProto |
稀疏权重张量(如稀疏卷积的权重) | 仅用于稀疏张量场景,包含非零值(values)、索引(indices)等 |
doc_string |
string |
计算图的文档说明(备注 / 注释) | 可选字段,用于描述图的用途、版本等 |
value_info |
repeated ValueInfoProto |
图中中间张量的元信息(非输入输出 / 初始化器) | 可选字段,用于补充中间张量(如 Conv 输出)的类型 / 形状信息 |
quantization_annotation |
repeated QuantizationAnnotationProto |
量化相关标注(如张量的量化参数) | 量化模型专用,描述张量的零点、缩放因子等 |
layout_annotation |
repeated LayoutAnnotationProto |
张量布局标注(如 NCHW/NHWC) | 可选字段,用于指定张量的维度顺序 |
metadata_props |
repeated StringStringEntryProto |
自定义元数据(键值对) | 如 "author": "alan"、"version": "1.0" |
functions |
repeated FunctionProto |
图内自定义算子函数 | 用于封装复用的算子逻辑,如自定义 Selu 函数 |
arg |
repeated AttributeProto |
图级别的属性(极少使用) | 区别于 NodeProto 的算子属性,用于全局配置 |
node
- proto
protobuf
message NodeProto {
string name = 1; // 节点名(可选)
string op_type = 2; // 算子类型(如 "Conv"/"Relu"/"MatMul")
repeated string input = 3;// 输入张量名列表(如 ["input.1", "conv1.weight"])
repeated string output = 4;// 输出张量名列表(如 ["conv1.output"])
string domain = 5; // 算子域(默认 "" 为 ONNX 标准算子,如 "com.microsoft" 为自定义算子)
repeated AttributeProto attribute = 6; // 算子属性(如 Conv 的 stride/padding)
string doc_string = 10; // 节点说明
}
- json
json
{
"input": [
"/stem/conv1/conv/Conv_output_0"
],
"output": [
"/stem/conv1/bn/act/LeakyRelu_output_0"
],
"name": "/stem/conv1/bn/act/LeakyRelu",
"op_type": "LeakyRelu",
"attribute": [
{
"name": "alpha",
"f": 0.01,
"type": "FLOAT"
}
]
},
2. OperatorSetProto(算子集依赖)
opset_import 引用的算子集定义,核心字段:
domain:算子集的域名(默认域为空字符串"",自定义算子可指定如custom_domain);version:算子集版本(如 15/19),决定算子的签名 / 语义。
3. TensorProto(张量定义)
用于描述权重、输入输出张量,核心字段:
data_type:张量元素类型(如FLOAT/INT32/STRING,对应枚举值 1/6/7);dims:张量形状(如[3, 224, 224]表示 3 通道 224x224 图像);raw_data:张量的二进制数据(序列化后的权重值);- 辅助字段:
name(张量名)、segment(稀疏张量分段信息)等。
onnx存储为json结构
python
import onnx
import json
from google.protobuf.json_format import MessageToDict
model: onnx.ModelProto = onnx.load("/Users/alanchen/workspace/cspresnet50_Opset16.onnx")
graph: onnx.GraphProto = model.graph
def strip_tensor_weights(graph_dict: dict) -> dict:
"""
剥离 Graph 字典中 Tensor 的权重数据,仅保留元信息(名称、形状、数据类型)
"""
# 处理 initializer(模型权重/常量张量)
if "initializer" in graph_dict:
for tensor in graph_dict["initializer"]:
# 移除存储权重的字段(按需扩展,覆盖所有数值存储字段)
weight_fields = [
"raw_data", # 二进制权重(最常见)
"float_data", # float 类型数值列表
"int32_data", # int32 类型数值列表
"int64_data", # int64 类型数值列表
"uint64_data", # uint64 类型数值列表
"bool_data", # bool 类型数值列表
"string_data", # string 类型数值列表
"double_data", # double 类型数值列表
"int8_data", # int8 类型数值列表
"uint8_data", # uint8 类型数值列表
"int16_data", # int16 类型数值列表
"uint16_data", # uint16 类型数值列表
]
for field in weight_fields:
if field in tensor:
del tensor[field]
# 可选:处理 sparse_initializer(稀疏张量权重,若有)
if "sparse_initializer" in graph_dict:
for sparse_tensor in graph_dict["sparse_initializer"]:
if "values" in sparse_tensor:
del sparse_tensor["values"] # 移除稀疏张量的数值
if "indices" in sparse_tensor:
del sparse_tensor["indices"] # 移除稀疏张量的索引
return graph_dict
# 步骤1:将 GraphProto 转为字典(Protobuf 原生转换)
graph_dict = MessageToDict(graph, preserving_proto_field_name=True)
# 步骤2:剥离权重数据
graph_dict_stripped = strip_tensor_weights(graph_dict)
# 步骤3:导出为 JSON 文件(无权重)
with open("cspresnet50_graph_proto.json", "w", encoding="utf-8") as f:
json.dump(graph_dict_stripped, f, ensure_ascii=False, indent=2, sort_keys=False)
print("已导出剥离权重的 Graph JSON 文件:cspresnet50_graph_proto_stripped.json")