神经网络模型导出及开放标准格式ONNX

本文在上一篇文章的基础上,继续深入的分析模型格式及导出,重点分析PyTorch导出格式pt及开放标准格式ONNX。

1 模型格式

1.1 现代主流/安全优化格式(新生态)

1 safetensors(Hugging Face主导)

**背景:**目前AI绘画(Stable Diffusion、ComfyUI)和主流大模型最推荐的格式。

详细说明:

核心特点:100%安全,它只包含纯粹的张量数据(模型权重)和一小段描述结构的JSON文本,没有任何可执行代码,完全免疫了传统格式的木马病毒风险。

性能:支持零拷贝(Zero-copy)和内存映射(mmap),加载速度极快,远超旧格式。

使用场景:SD/SDXL/Flux模型的Checkpoint、LoRA、ControlNet,以及大语言模型权重。

2 gguf

**背景:**替代了早期的ggml格式,是目前端侧/本地大模型(LLM)量化的绝对霸主。

详细说明:

核心特点:将模型的所有组件(权重、分词器、超参数、元数据)封装在单个文件内。它的扩展性极强,支持未来添加新的元数据而不破坏向后兼容性。

量化优势:专门为CPU/GPU混合推理优化,支持各种强度的量化(如Q4_K_M、Q8_0),能让几百亿参数的大模型在消费级显卡甚至手机、MacBook上流畅运行。

适用场景:Llama3、Gemma、Qwen等大语言模型的本地部署,最近ComfyUI也开始流行用GGUF格式来运行轻量化的Flux或SDXL模型。

1.2 传统开发/原生框架格式(多见于训练与研究)

这类格式通常伴随官方框架诞生,虽然灵活,但在跨平台部署或安全性上存在一定局限。

1 pth/pt(PyTorch)

详细说明:

核心特点:PyTorch框架的原生存储格式,它基于PyTorch的pickle模块进行序列化。

优缺点:它极其灵活,不仅能保存模型权重,还能把训练了一半的优化器状态(Optimizer)、Epoch轮数甚至网络结构代码一起存进去,是模型训练和微调的标配。但缺点是不安全,加载恶意的pth文件,可能会在电脑上自动执行破坏性代码。

适用场景:AI绘画中的早期特定模型(如部分ControlNet、ESRGAN放大算法模型)、PyTorch算法开发与训练。

2 ckpt(Checkpoint)

**背景:**这是Stable Diffusion早期(1.5时代)最常用的格式,本质上和pth一样也是基于pickle序列化。

详细说明:

核心特点:由于存在和pth一模一样的安全漏洞(可注入恶意脚本),目前在生图领域已经全面被safetensors淘汰。如果你在网上下载到老的ckpt权重,建议用工具转换成safetensors再使用。

3 h5/keras/pb (TensorFlow/Keras)

**背景:**Google生态(TensorFlow)下的主流格式。

详细说明:

核心特点:h5(HDF5)多用于保存Keras模型的权重或完整结构;pb(Protocol Buffers)则是TensorFlow用于生产端部署的图模型格式。

适用场景:传统计算机视觉(如某些老的人脸识别、目标检测算法)和工业级TensorFlow工业管线。

1.3 夸平台推理与硬件加速格式(生产端部署)

当模型训练完成后,为了让它在手机、网页、或者不同显卡(英伟达、AMD、Intel)上跑得飞快,通常会转换成以下格式。

1 onnx(Open Neural Network Exchange)

详细说明:

核心特点:"通用翻译官",由微软、Meta等联合发起的开放标准。它把模型结构抽象成一张通用的计算图,让你可以把PyTorch训练的模型无缝转换到TensorFlow或其他推理引擎中运行。

适用场景:跨平台部署,比如你在ComfyUI里用到的某些WD14自动打标插件、LayerDiffusion背景分离模型,很多底层都在跑ONNX格式,因为它在CPU或非英伟达显卡上的兼容性极好。

2 engine(NVIDIA TensorRT)

**背景:**英伟达官方对自家显卡进行极致硬件加速的专用格式。

详细说明:

核心特点:速度榨干者,你必须在自己的显卡上,把onnx或safetensors现场"编译"成 engine 格式。它会根据你当前的显卡架构(如RTX 4090)进行剪枝、层融合和量化加速。

优缺点:速度达到物理极限(通常比原模型快 50%~100%)。但完全没有通用性,在4090上编译的engine文件拿到3080上是用不了的,甚至显卡驱动升级了都可能需要重新编译。

适用场景:追求极致生图速度的WebUI/ComfyUI TensorRT加速工作流,或工业级实时AI推理。

3 tflite/onnxruntime (移动端轻量化)

背景:tflite是TensorFlow Lite格式,专门给安卓、iOS手机或嵌入式设备(如树莓派)使用,做了极端的体积压缩和定点量化。

1.4 总结

在实际生产应用时,模型包含训练和部署两大关键环节,在训练环节模型可通过不同框架来实现,典型如PyTorch、TensorFlow等,由于框架的不同会产生不同格式的训练模型,而在部署环节由于硬件平台的不同又需要需要将不同格式的模型进行差异化修改,这种多到多的操作实施起来很繁琐。经过工业界和学术界数年的探索,模型部署有了一条流行的流水线:

如上图,为了让模型最终能够部署到某一环境上,开发者们可以使用任意一种深度学习框架来定义网络结构,并通过训练确定网络中的参数。之后,模型的结构和参数会被转换成一种只描述网络结构的中间表示,一些针对网络结构的优化会在中间表示上进行。最后,用面向硬件的高性能编程框架(如 CUDA,OpenCL)编写,能高效执行深度学习网络中算子的推理引擎会把中间表示转换成特定的文件格式,并在对应硬件平台上高效运行模型,比如中间表示ONNX转换支持华为芯片推理的OM文件。

2 模型导出及格式转换

2.1 导出pth格式模型

在上一篇文章中源码trainNN.py已经导出mnist_cnn.pth模型文件,PyTorch使用Python原生pickle序列化任意Python对象(模型、张量、字典、优化器、数字、自定义类等),这是pth文件的核心。pickle可执行任意代码,所以处于安全考虑不要加载来源不明的pt文件,Torch 2.0+提供weights_only=True安全加载模式,仅读取张量、禁止反序列化Python对象。新版PyTorch默认开启_use_new_zipfile_serialization=True,pth本质是一个zip压缩包;旧版本是纯二进制pickle文件,无压缩。对于新版pt文件可以直接用unzip model.pt解压查看内部文件。如解压后内容如下:

1 目录及文件解释

version:pickle序列化协议版本(pickle4/pickle5),用于版本兼容校验;

.format_version:Torch自定义ZIP存储格式版本,区分新旧存储结构,版本不匹配会直接加载失败;

.storage_alignment:张量二进制存储内存对齐字节数,GPU/CPU加载时用来对齐内存,提升张量读取速度,一般是1或64;

byteorder:存储硬件字节序,x86机器固定是little(小端序),用来解析.storage里的浮点/整数二进制;

data.pkl:整个模型的逻辑骨架,用pickle序列化了checkpoint /state_dict字典,但只存张量的「描述信息」,不存权重数字;

data:目录下存储张量权重二进制数据;

.data:该目录可能是特定场景(如torch.package/export)的残留或误生成;

2 data.pkl详解

直接用UE打开该二进制文件,内容如下:

在VS Code中安装PKL Viewer扩展,也一并打开该文件进行分析,内容如下:

复制代码
 1     0: \x80 PROTO      2
 2     2: c    GLOBAL     'collections OrderedDict'
 3    27: q    BINPUT     0
 4    29: )    EMPTY_TUPLE
 5    30: R    REDUCE
 6    31: q    BINPUT     1
 7    33: (    MARK
 8    34: X        BINUNICODE 'features.0.weight'
 9    56: q        BINPUT     2
10    58: c        GLOBAL     'torch._utils _rebuild_tensor_v2'
11    91: q        BINPUT     3
12    93: (        MARK
13    94: (            MARK
14    95: X                BINUNICODE 'storage'
15   107: q                BINPUT     4
16   109: c                GLOBAL     'torch FloatStorage'
17   129: q                BINPUT     5
18   131: X                BINUNICODE '0'
19   137: q                BINPUT     6
20   139: X                BINUNICODE 'cuda:0'
21   150: q                BINPUT     7
22   152: M                BININT2    288
23   155: t                TUPLE      (MARK at 94)
24   156: q            BINPUT     8
25   158: Q            BINPERSID
26   159: K            BININT1    0
27   161: (            MARK
28   162: K                BININT1    32
29   164: K                BININT1    1
30   166: K                BININT1    3
31   168: K                BININT1    3
32   170: t                TUPLE      (MARK at 161)
33   171: q            BINPUT     9
34   173: (            MARK
35   174: K                BININT1    9
36   176: K                BININT1    9
37   178: K                BININT1    3
38   180: K                BININT1    1
39   182: t                TUPLE      (MARK at 173)
40   183: q            BINPUT     10
41   185: \x89         NEWFALSE
42   186: h            BINGET     0
43   188: )            EMPTY_TUPLE
44   189: R            REDUCE
45   190: q            BINPUT     11
46   192: t            TUPLE      (MARK at 93)
47   193: q        BINPUT     12
48   195: R        REDUCE
49   196: q        BINPUT     13
50   198: X        BINUNICODE 'features.0.bias'
51   218: q        BINPUT     14
52   220: h        BINGET     3

最一开始是由Pickle规范(PEP307/PEP574)规定的0x80=PROTO头,声明了本次序列化使用的Pickle协议版本,0x80后紧跟1字节参数不是协议号本身,而是有对应的映射表:

随后的二进制流,可以通过Python标准库的pickletools模块,来查看包含所有协议版本的完整指令对照表,源码内带详细注释说明每个指令的作用,该文件路径可以通过以下代码获取:

复制代码
python -c "import pickletools; print(pickletools.__file__)"

也可以通过如下文件直接打印出相应对照表:

复制代码
import pickletools

op_list = pickletools.opcodes

def opcode_value(code):
    if isinstance(code, str):
        return ord(code)
    return code

op_list_sorted = sorted(op_list, key=lambda op: opcode_value(op.code))

print(f"{'Hex':<6} | {'Dec':<3} | {'Opcode Name':<16} | Min Proto | Description")
print("-" * 120)

for op in op_list_sorted:
    code_val = opcode_value(op.code)

    hex_byte = f"0x{code_val:02X}"

    desc = ""
    if op.doc:
        desc = op.doc.splitlines()[0].strip()

    print(
        f"{hex_byte:<6} | "
        f"{code_val:<3} | "
        f"{op.name:<16} | "
        f"{op.proto:<9} | "
        f"{desc}"
    )

printpkl.py

该文件运行输出内容如下:

接下来继续结合PKL Viewer解析结果进行分析,第2行内容等价于python代码collections.OrderedDict,把该类压入到Pickle栈,此时栈为OrderedDict,第3行内容表示将OrderedDict类缓存到memo表,即memo0 = class 'collections.OrderedDict',第4行将空元组压入栈,此时栈为OrderedDict, (),第5行REDUCE的含义是callable(*args),对应的代码是弹出参数空元组及类并调用类实现OrderedDict(*()),之后将结果及用OrderedDict类构造一个空实例对象压入到栈,此时栈为ordereddict实例,第6行保存对象到memo表,即memo1 = OrderedDict(),此时实际上已经得到:state_dict = OrderedDict(),此时在memo表中分别存储了类及其实例对象,后续可根据需要进行复用。接下来代码开始向OrderedDict插入元素,第7行向栈插入MARK分隔符(用于标记一个可变长度对象的开始位置),第8行向栈压入一个Unicode字符串对象,后续它将作为key,第9行将字符串存入到memo表,第10行向栈中压入_rebuild_tensor_v2重建函数,此时栈内容为OrderedDict(), MARK33, 'features.0.weight', '_rebuild_tensor_v2',第11行将其存入到memo表,第12,13行两次将MARK压入到栈,第14~22行依次将"storage",FloatStorage,"0","cuda:0",288入栈,最终由第23行的TUPLE和最近MARK94产生元组对象,并入栈,此时栈内容为OrderedDict(), MARK33, 'features.0.weight', '_rebuild_tensor_v2', MARK93, ("storage", FloatStorage, "0", "cuda:0", 288),第25行BINPERSID会调用unpickler.persistent_load(pid)来真正从data/0中的数据实现Storage对象,这里正是结构和数据分别存放的精华所在,例如,这里weight数据288个float数据,总大小为1152字节,和data/0文件大小正好对应上:

此时栈内容为:OrderedDict(), MARK33, 'features.0.weight', '_rebuild_tensor_v2', MARK93, FloatStorage(288),而_rebuild_tensor_v2函数的参数定义如下:

复制代码
_rebuild_tensor_v2(
    storage,
    storage_offset,
    size,
    stride,
    requires_grad,
    backward_hooks
)

可见第2个参数是storage_offset,就是说生成Tensor时不一定从Storage的开头开始,而是可以从指定storage_offset下标开始,这正好对应源码中第26行内容,这里向栈中压入下标0;接下来第27~32行给出_rebuild_tensor_v2函数的第3个参数size=(32,1,3,3),即weight形状是32x1x3x3,正好是288个元素;再接下来第34~39行给出函数第4个参数stride=(9, 9, 3, 1),即stride0=1*3*3=9,stride1=3*3=9,stride2=3,stride3=1;随后第41行向栈中压入False值,对应参数requires_grad,之后第42~44先通过BINGET 0将之前memo中预存的OrderedDict类压栈,然后再将空元组压栈,之后通过REDUCE实例化OrderedDict对象作为参数backward_hooks的值;最终通过第46行的和MARK93闭合产生参数元组,并在48行完成_rebuild_tensor_v2函数调用产生tensor实例。接下来过程类似,最终栈内容大致如下:

复制代码
OrderedDict()
MARK33
"features.0.weight"
Tensor(...)

"features.0.bias"
Tensor(...)

"features.1.weight"
Tensor(...)

各个元素初始化完毕后最终会通过SETITEMS完成OrderedDict对象赋值,该指令会把最近MARK后的所有key------value键值对插入到OrderedDict对象,最终栈变成:

复制代码
[
OrderedDict({
    "features.0.weight": tensor(...),
    "features.0.bias": tensor(...),
    "features.1.weight": tensor(...),
    ...
})
]

2.2 pth转换为onnx

1 Protobuf

Protocol Buffers(简称Protobuf)是Google开发的一种语言无关、平台无关、可扩展的机制,用于序列化结构化数据。简单来说,可以把它想象成一种更高效、更快速的"XML"或"JSON"。它的核心作用是:将内存中的复杂数据结构(如对象、结构体)转换成紧凑的二进制字节流,以便进行网络传输或持久化存储;反过来,也能将二进制数据还原为原始数据结构。

使用Protobuf通常遵循以下三个步骤:

(1)定义数据结构

在一个.proto文件中,使用Protobuf的语法定义好数据结构(即"消息"),如addressbook.proto内容如下:

复制代码
syntax = "proto3";

// 定义一个人的信息
message Person {
  string name = 1;
  int32 id = 2;
  string email = 3;
}

// 定义整个通讯录(repeated 表示可包含多个 Person)
message AddressBook {
  repeated Person people = 1;
}

(2)生成源代码

使用Protobuf编译器(protoc)根据你的 .proto 文件,自动生成你所用编程语言的代码,这些代码包含了读写该数据结构的方法,比如可以使用如下命令生成python语言的代码:

复制代码
protoc --python_out=. addressbook.proto

以上命令会生成addressbook_pb2.py文件,里面包含了读写Person和Address的所有类方法。但是由于我的电脑上是直接安装的最新版的protoc 编译器(核心版本35.0),而使用pip install protobuf安装的Protobuf运行时库版本版本是6.33.6(python"语言前缀"是6,"核心版本号"是33.6),这两个是不同的Protobuf核心版本,在使用python protobuf库去直接解析高版本编译器库生成的addressbook_pb2.py时会有版本冲突问题,这里使用Python生态自带的grpcio-tools来生成代码,它内嵌了一个与当前Python Protobuf库严格配套的protoc,使用如下命令安装并生成addressbook_pb2.py文件:

复制代码
pip install grpcio-tools

python -m grpc_tools.protoc -I. --python_out=. addressbook.proto

文件内容如下:

复制代码
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler.  DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: addressbook.proto
# Protobuf Python Version: 6.31.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC,
    6,
    31,
    1,
    '',
    'addressbook.proto'
)
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()




DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x11\x61\x64\x64ressbook.proto\"1\n\x06Person\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02id\x18\x02 \x01(\x05\x12\r\n\x05\x65mail\x18\x03 \x01(\t\"&\n\x0b\x41\x64\x64ressBook\x12\x17\n\x06people\x18\x01 \x03(\x0b\x32\x07.Personb\x06proto3')

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'addressbook_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
  DESCRIPTOR._loaded_options = None
  _globals['_PERSON']._serialized_start=21
  _globals['_PERSON']._serialized_end=70
  _globals['_ADDRESSBOOK']._serialized_start=72
  _globals['_ADDRESSBOOK']._serialized_end=110
# @@protoc_insertion_point(module_scope)

addressbook_pb2.py

虽然机器生成的代码不太好看,但它确实包含了我们定义的字段。

(3)使用生成的代码

在应用程序中,直接调用生成的代码来序列化(写)和反序列化(读)数据。现在,创建main.py文件,写入下面的代码。它做了两件事:写入(序列化)一个包含两人的通讯录到磁盘文件,再读取(反序列化)并打印出来。

复制代码
import addressbook_pb2

# ============ 1. 序列化:创建对象并写入文件 ============
# 创建一个通讯录对象
book = addressbook_pb2.AddressBook()

# 添加第一个人
person1 = book.people.add()
person1.name = "张三"
person1.id = 1001
person1.email = "zhangsan@example.com"

# 添加第二个人
person2 = book.people.add()
person2.name = "李四"
person2.id = 1002
person2.email = "lisi@example.com"

# 将通讯录序列化为二进制数据,并写入本地文件
with open("addressbook.data", "wb") as f:
    f.write(book.SerializeToString())
print("✅ 序列化完成,数据已写入 addressbook.data")

# ============ 2. 反序列化:从文件读取并解析 ============
# 创建一个空的通讯录对象
new_book = addressbook_pb2.AddressBook()

# 从文件中读取二进制数据并反序列化到 new_book
with open("addressbook.data", "rb") as f:
    new_book.ParseFromString(f.read())

# 打印解析出的数据
print("\n✅ 反序列化成功,读取到的内容如下:")
for person in new_book.people:
    print(f"姓名: {person.name}, ID: {person.id}, 邮箱: {person.email}")

main.py

运行效果如下:

相比XML和JSON等文本格式,Protobuf在性能上优势显著,可以用一个形象的比喻来理解:

JSON/XML像一封用自然语言写的书信,内容清晰易读,但有很多冗余字符(如引号、括号、标签),体积大,传输慢。

Protobuf则像一封用密文写的电报,体积小、传输快,但只有拥有"密码本"(即.proto定义文件)的人才能看懂。

Protobuf是一个为高性能、高效率和跨语言兼容性而生的数据序列化方案,非常适合对性能和带宽有严苛要求的系统内部通信。它为了追求极致的体积和速度,在二进制流中绝对不会保存任何字段名,而是采用的是Tag - Length - Value(TLV)的紧凑格式:

(1)Tag(标签/字段编号):每一个字段在.proto图纸里都有一个独一无二的数字编号,二进制里只存这个编号。

(2)Wire Type(传输类型):用来告诉解析器这个数据是个变长整数(Varint)、还是个有固定长度的字符串/嵌套块,Tag和Wire Type会被打包成一个字节(称为 Field Key)。

Field Key = (Tag << 3) | Wire Type

(3)Value(真实数据):如果是字符串,前面会用一个数字表明它的字节长度(Length),后面紧跟纯字符的 ASCII 码。

Google Protobuf官方的C++底层源码库,在头文件wire_format_lite.h里定义了6种核心数据物理传输类型(Wire Type):

复制代码
class WireFormatLite {
 public:
  enum WireType {
    WIRETYPE_VARINT           = 0,        // int32, int64, uint32, uint64, sint32, sint64, bool, enum
    WIRETYPE_FIXED64          = 1,        // fixed64, sfixed64, double(固定8字节)
    WIRETYPE_LENGTH_DELIMITED = 2,        // string, bytes, 嵌套的 message, packed repeated fields
    WIRETYPE_START_GROUP      = 3,        // 已废弃的老式分组
    WIRETYPE_END_GROUP        = 4,        // 已废弃的老式分组
    WIRETYPE_FIXED32          = 5,        // fixed32, sfixed32, float(固定 4 字节)
  };
};

int64在内存里明明是固定占8个字节(64位)的长整型,为什么把它分配给Varint(变长整数)?这正是Protobuf极度高产和聪明的核心算法:Varint (Variable-length quantity)。

传统的存法:如果你存一个很小的数字(比如数字6),在传统的二进制里,使用int64也必须雷打不动地占满8个字节的内存(也就是一堆00 00 00 00 00 00 00 06),白白浪费空间。

Protobuf的存法:当Wire Type被识别为0时,它会启动Varint编码(每个字节的最高位MSB作为标志位,后面7位存数字)。如果是数字6,它在二进制流里只占用1个字节(即06)就存完了!只有当你的数字巨大无比(比如好几十亿)时,它才会慢慢扩充,最多扩充到10个字节。

下面以上述实例序列化后的二进制数据为例来说明解析过程,二进制数据如下:

解析时结合addressbook.proto解码文件,先从最外套娃AddressBook开始,然后再进一步解析内部嵌套结构Person。

(1)0A 21:开启第一个Person容器

0A:0000 1010,00001解构为字段编号1对应AddressBook.people,Wire Type为010(WIRETYPE_LENGTH_DELIMITED=2)表示repeated fields;

21:换算成十进制是33,意味着第一个人占用了接下来的33个字节。

(2)0A 06 E5 BC A0 E4 B8 89:解析名字 "张三"

0A:0000 1010,解构为Person.name(编号1,Wire Type 2)

06:字符串长度为6个字节;

数据段:E5 BC A0 E4 B8 89,在UTF8编码中,一个汉字占3个字节,E5 BC A0 = 张,E4 B8 89 = 三;

(3)10 E9 07:解析 ID 1001

10:解构为Person.id(编号2,Wire Type 0);

E9 07:典型的Varint变长整数编码,E9对应二进制1110 1001,最高位是1,表示其后还有字节,有效数据位是后7位110 1001,07对应二进制0000 0111,最高位0,代表结束,有效数据是后7位000 0111,低位在前把000 0111和110 1001进行拼接得到0000 0011 1110 1001,换算成十进制正好是1001;

(4)1A 14 7A 68 ... 63 6F 6D:解析邮箱

1A:解构为Person.email(编号3,Wire Type 2);

14:十六进制的14换算成十进制是20,代表邮箱长20个字节;

数据段:7A 68 61 6E 67 73 61 6E 40 65 78 61 6D 70 6C 65 2E 63 6F 6D 转换成 ASCII 码对应字符串zhangsan@example.com;

按同样流程可以解析出第2个Person,此外还可以通过命令:

复制代码
protoc --decode=AddressBook addressbook.proto < addressbook.data

让protoc帮助完成解析:

2 onnx结构

onnx(Open Neural Network Exchange)作为模型界的"通用翻译官",是由微软、Meta等联合发起的开放标准。它把模型结构抽象成一张通用的计算图,让你可以把PyTorch训练的模型无缝转换到TensorFlow或其他推理引擎中运行。ONNX文件的本质是一个通过Protocol Buffers(Protobuf)序列化后的二进制文件。如果你去解构一个.onnx文件,它的核心组成部分符合严格的层次结构:

以下对主要消息体进行介绍:

(1)顶层容器:ModelProto

这是整个.onnx文件的根节点,相当于总指挥部,包含以下主要内容。

ir_version:该模型对应的ONNX中间表示版本(当前最高已演进至0x000000000000000D,即IR Version 13);

opset_import:一个列表,声明了模型依赖的算子集版本(OperatorSetIdProto);

graph:核心计算图(GraphProto);

training_info:存放训练或微调阶段的梯度和优化器信息(TrainingInfoProto);

functions:模型本地定义的函数列表(FunctionProto);

configuration:该字段专门用于多设备部署场景,它允许一个 ONNX 模型文件预先定义好多种硬件部署方案(DeviceConfigurationProto);

此外该消息体还包括producer_name,producer_version,domain,model_version等一些字符串元数据。

(2)核心大脑:GraphProto

定义了网络中所有的数据和计算节点,形成一个有向无环图(Directed Acyclic Graph,DAG),主要元素如下:

node:计算节点列表(NodeProto),必须按拓扑顺序排列,每个节点指定了它使用的算子类型、输入参数名、输出参数名以及属性(Attributes,如卷积的stride、padding);

initializer:存放静态权重(如权重矩阵、偏置)的常规张量列表(TensorProto);

sparse_initializer:稀疏矩阵格式的静态权重(SparseTensorProto);

input/output:网络的输入和输出边界定义(ValueInfoProto);

(3)计算单元:NodeProto

网络中的"层"或算子(例如一个卷积层或激活层),主要内容如下:

op_type:算子名称(例如Conv,Relu,MatMul);

input/output:字符串数组,ONNX就是通过匹配不同节点的输入输出字符串名称来确立节点间的连接线的;

attribute:该算子的静态配置参数列表(AttributeProto,如卷积的strides);

(4)属性参数:AttributeProto

由于Protobuf没有原生的联合体(Union),ONNX在这里用了一个绝妙的设计,通过一个枚举AttributeType配合一堆optional字段,实现了Union(联合体)的等价类型:

复制代码
message AttributeProto {
  // 1. 先用一个枚举来标记"当前类型"
  optional AttributeType type = 20; 

  // 2. 下面这一堆字段,在实际使用时,【有且只能有一个】有值
  optional float f = 2;               // 如果是浮点数,就存这里
  optional int64 i = 3;               // 如果是整数,就存这里
  optional bytes s = 4;               // 如果是字符串,就存这里
  optional TensorProto t = 5;         // 如果是高维张量/权重,就存这里
  optional GraphProto g = 6;          // 如果是子计算图,就存这里

  // 下面是对应的数组(复数)形式
  repeated float floats = 7;          
  repeated int64 ints = 8;            
  // ...
}

单数类型字段:f(float),i(int64),s(bytes/string),t(TensorProto),g(GraphProto/子图);

复数类型字段:floats,ints,strings,tensors,graphs;

(5)数据载体:TensorProto

专门用来存多维数组和权重数据的底层实体。

dims:形状/维度(例如 64, 3, 3, 3);

data_type:极其丰富的低精度/高精度数据类型枚举。可以看到最新版已经扩充了FLOAT8,INT4,UINT4,甚至是2位的INT2/UINT2;

数据存储位置 (data_location):

DEFAULT(0):数据直接压缩塞在当前二进制文件里(通过float_data、int32_data或原生的raw_data存储);

EXTERNAL(1):大模型专用的外部权重解耦机制。数据不塞在.onnx里,而是记录在external_data键值对中,指明外部.data文件的路径、偏移量(offset,通常推荐4096字节对齐以支持mmap)和数据长度;

6 )类型描述: TypeProto

用来规定某个张量或者变量在运行时的数据形态。

value:使用了一个oneof语法,它可以是:

tensor_type: 包含基础元素类型(elem_type)和形状描述(TensorShapeProto);

sequence_type/map_type/optional_type:用于支持传统机器学习或非张量的数据流动;

规范中还有其他一些消息体,这里不再进行详细介绍,更多内容请查看onnx.proto3文件。

3 导出onnx

可以通过如下代码将生成的mnist_cnn.pth导出为onnx格式的文件:

复制代码
  1 import os
  2 import torch
  3 import torch.nn as nn
  4 import torch.nn.functional as F
  5 import warnings
  6 from PIL import Image
  7 
  8 # 🌟 核心操作:在代码最顶层直接屏蔽掉传统的 TorchScript 弃用黄色警告
  9 # 让控制台只输出我们关心的 ✓ 和 ✗
 10 warnings.filterwarnings("ignore", category=DeprecationWarning)
 11 
 12 # ==========================================
 13 # 1. 100% 完美还原你训练源码中的 DeepMNIST 模型结构
 14 # ==========================================
 15 class DeepMNIST(nn.Module):
 16     def __init__(self):
 17         super(DeepMNIST, self).__init__()
 18         
 19         # 1. 卷积特征提取层 (使用 nn.Sequential 容器)
 20         self.features = nn.Sequential(
 21             # 第一层:卷积 -> 批归一化 -> 激活 -> 池化
 22             nn.Conv2d(1, 32, kernel_size=3, padding=1), # in_channels=1 (灰度图)
 23             nn.BatchNorm2d(32),                         # 稳定分布
 24             nn.ReLU(inplace=True),                      # 非线性激活
 25             nn.MaxPool2d(kernel_size=2, stride=2),      # 28x28 -> 14x14
 26             
 27             # 第二层:进一步提取深层特征
 28             nn.Conv2d(32, 64, kernel_size=3, padding=1),
 29             nn.BatchNorm2d(64),
 30             nn.ReLU(inplace=True),
 31             nn.MaxPool2d(kernel_size=2, stride=2),      # 14x14 -> 7x7
 32             nn.Dropout(0.25)                            # 随机丢弃,防止过拟合
 33         )
 34         
 35         # 2. 分类层 (全连接层)
 36         self.classifier = nn.Sequential(
 37             nn.Flatten(),                               # 将 64x7x7 压平为 1D 向量
 38             nn.Linear(64 * 7 * 7, 128),                 # 特征线性组合
 39             nn.ReLU(inplace=True),
 40             nn.Dropout(0.5),                            # 全连接层常用的 Dropout
 41             nn.Linear(128, 10)                          # 最终映射到 10 个类别
 42         )
 43 
 44     def forward(self, x):
 45         x = self.features(x)
 46         x = self.classifier(x)
 47         return x
 48 
 49 # ==========================================
 50 # 2. 模型核心转换逻辑
 51 # ==========================================
 52 def convert_to_onnx(pth_path, onnx_path):
 53     # 实例化完全匹配的 DeepMNIST 模型
 54     model = DeepMNIST()
 55 
 56     print(f"正在加载 PyTorch 权重文件: {pth_path}")
 57     state_dict = torch.load(pth_path, map_location=torch.device('cpu'))
 58 
 59     # 自动解析包装过的 state_dict (兼容某些训练框架)
 60     if isinstance(state_dict, dict) and 'state_dict' in state_dict:
 61         state_dict = state_dict['state_dict']
 62 
 63     model.load_state_dict(state_dict)
 64     
 65     # 开启评估模式 (非常重要:会固化 BatchNorm 的均值方差和 Dropout 的行为)
 66     model.eval()
 67 
 68     # 构建一阶段 Trace 用的高维度随机虚拟输入 (1张 28x28 的灰度图)
 69     dummy_input = torch.randn(1, 1, 28, 28, requires_grad=True)
 70 
 71     print(f"正在进行 Tracing 并导出 ONNX 模型: {onnx_path} ...")
 72     torch.onnx.export(
 73         model,
 74         dummy_input,
 75         onnx_path,
 76         export_params=True,          # 导出模型权重参数
 77         opset_version=11,            # 推荐使用 opset 11 及以上版本,对 BatchNorm 支持极佳
 78         do_constant_folding=True,    # 开启常量折叠图优化
 79         input_names=['input'],       # 计算图的输入节点自定义名称
 80         output_names=['output'],     # 计算图的输出节点自定义名称
 81         dynamic_axes={               # 开启动态 Batch Size 推理轴支持
 82             'input': {0: 'batch_size'},
 83             'output': {0: 'batch_size'}
 84         }
 85     )
 86     print("✓ ONNX 导出流程成功结束。")
 87 
 88 
 89 # ==========================================
 90 # 3. ONNX 推理一致性校对
 91 # ==========================================
 92 def verify_onnx(pth_path, onnx_path):
 93     import onnx
 94     import onnxruntime as ort
 95     import numpy as np
 96 
 97     print("\n--- 启动 ONNX Runtime 精度与对齐测试 ---")
 98 
 99     # 3.1 验证 ONNX 计算图结构合法性
100     onnx_model = onnx.load(onnx_path)
101     onnx.checker.check_model(onnx_model)
102     print("✓ [ONNX Checker] 模型图物理结构校验通过!")
103 
104     # 3.2 使用测试图片数据或者生成随机测试样本数据
105     img_path = "debug_input.png"
106     if not os.path.exists(img_path):
107         print(f"未找到测试图片 '{img_path}',将使用随机测试样本数据")
108         test_input_np = np.random.randn(1, 1, 28, 28).astype(np.float32)
109     else:
110         print(f"使用'{img_path}'进行验证")
111         from torchvision import transforms  # 🌟 引入 torchvision
112         # a. 严格复刻你前端/训练时的 1:1 黄金流水线
113         transform_pipeline = transforms.Compose([
114             transforms.Resize((28, 28)),
115             transforms.Grayscale(),
116             transforms.ToTensor(),
117             transforms.Normalize((0.1307,), (0.3081,)) # 🌟 灵魂所在,彻底对齐数据分布
118         ])
119         # b. 直接用 Pillow 打开图片(保持你之前的灰度转换)
120         img = Image.open(img_path).convert('L')
121         # c. 用流水线直接榨出标准的 PyTorch 四维张量
122         img_tensor = transform_pipeline(img)      # 出来是 (1, 28, 28)
123         img_tensor = img_tensor.unsqueeze(0)      # 升维成 (1, 1, 28, 28)
124         test_input_np = img_tensor.numpy().astype(np.float32) # 给 ONNX Runtime 用
125     test_input_torch = torch.from_numpy(test_input_np)
126 
127     # 3.3 运行 PyTorch 原生推理
128     model = DeepMNIST()
129     state_dict = torch.load(pth_path, map_location='cpu')
130     if isinstance(state_dict, dict) and 'state_dict' in state_dict:
131         state_dict = state_dict['state_dict']
132     model.load_state_dict(state_dict)
133     model.eval()
134     with torch.no_grad():
135         pytorch_output = model(test_input_torch).numpy()
136         #print(pytorch_output)
137     # 3.4 使用 ONNX Runtime 执行推理
138     ort_session = ort.InferenceSession(onnx_path)
139     ort_inputs = {ort_session.get_inputs()[0].name: test_input_np}
140     ort_outputs = ort_session.run(None, ort_inputs)
141     onnx_output = ort_outputs[0]
142     #print(onnx_output)
143     # 3.5 严格的数值对齐校验(相对误差 1e-3,绝对误差 1e-5)
144     try:
145         np.testing.assert_allclose(pytorch_output, onnx_output, rtol=1e-03, atol=1e-05)
146         print("✓ [精度校验] 恭喜!PyTorch 与 ONNX Runtime 数值完美对齐,误差精度在安全范围内。")
147         
148         # 调试打印:看看这张图片被模型预测成了数字几?
149         pred_digit = np.argmax(onnx_output)
150         confidence = np.max(onnx_output)
151         print(f"🔮 [模型最终预测] 数字是: {pred_digit} | 置信度: {confidence:.4f}")
152     except AssertionError as e:
153         print("✗ [精度校验] 两侧算子输出存在浮点对齐偏差,请排查!")
154         print(e)
155 
156 
157 if __name__ == "__main__":
158     pth_file = "mnist_cnn.pth"
159     onnx_file = "mnist_cnn.onnx"
160 
161     if not os.path.exists(pth_file):
162         print(f"致命错误:请将您训练生成的 {pth_file} 文件放置于脚本同级目录下再运行转换。")
163     else:
164         convert_to_onnx(pth_file, onnx_file)
165         try:
166             verify_onnx(pth_file, onnx_file)
167         except ImportError:
168             print("\n温馨提示: 补充安装 `pip install onnx onnxruntime` 可自动进行数值比对校验。")

pth2onnx.py

程序运行输出如下:

可见程序不仅导出了onnx模型,还使用测试图片对onnx模型和pth模型的预测结果进行了对比,从结果看onnx模型和原始pth模型的预测结果仅有很小误差,可以通过在线网站netron来图形化显示onnx的结构:

还可以通过如下命令导出完整DAG结构:

复制代码
protoc --decode=onnx.ModelProto onnx.proto3 < mnist_cnn.onnx > mnist_cnn_text.txt

当然需要提前下载onnx.proto3文件并将其和mnist_cnn.onnx模型文件放到同一目录下,下面结合二进制模型文件和onnx.proto3文件,说明下解析产生txt结构文件的关键点。

从二进制文件开头08 06 12 ... 38 2E 30的18字节数据,可以解析出ModelProto消息的ir_version、producer_name和producer_version成员,该部分解析比较简单,不再详述。接下来3A = 0011 1010高5位数值7对应GraphProto类型编号,底层类型2对应嵌套message结构,接下来B6 88 67这3个字节解析后是整个graph的大小1688630,由于Graph包含多个节点和相关权重数据,所以它的字节数较大,对应该实例来说整个onnx文件大小是1688656字节,头部18+4占22字节,尾部4个字节是opset_import算子集版本,其他所有数据都对应Graph内容。大图长度交代完后,正式跨入GraphProto内部,接下来流里碰到了0A=0000 1010,在GraphProto中,字段编号为1,WireType为2,对应repeated NodeProto node,接下来是第一个算子节点,C8 01是第一个NodeProto算子节点的总长度(Varint编码),C8=1001000,01=000 0001,拼接二进制对应十进制200,即后续200个字节对应第一个node节点内容。其他内容解析不再详解说明,这里重点说一下2A 12 0A 09 64 69 6C 61 74 69 6F 6E 73 40 01 40 01 A0 01 07对应AttributeProto message部分的解析,首先2A>>3=5是字段编号,2A&7=2对应WireType,在NodeProto中字段5对应repeated AttributeProto attribute,12是长度即18字节,接下来0A 09对应属性名"dilations",接下来的40 01 40 01对应两个ints值为1,之后A0 01 07有点特殊,在解析A0发现字段编号超过15,所以这里Tag实际上是双字节Varint编码,所以要和01一起进行解析,A0砍掉开头的1剩下7位010 0000,01砍掉开头的0,剩下7位000 0001,拼接后二进制为10100000,字段编号是20,WireType为0,对应enum AttributeType类型type,其值为7,AttributeProto在onnx.proto3文件中定义如下:

复制代码
message AttributeProto {
  reserved 12, 16 to 19;
  reserved "v";

  // Note: this enum is structurally identical to the OpSchema::AttrType
  // enum defined in schema.h.  If you rev one, you likely need to rev the other.
  enum AttributeType {
    UNDEFINED = 0;
    FLOAT = 1;
    INT = 2;
    ...
    FLOATS = 6;
    INTS = 7;
    STRINGS = 8;
    ...
  }
  ...

  // The type field MUST be present for this version of the IR.
  // For 0.0.1 versions of the IR, this field was not defined, and
  // implementations needed to use has_field heuristics to determine
  // which value field was in use.  For IR_VERSION 0.0.2 or later, this
  // field MUST be set and match the f|i|s|t|... field in use.  This
  // change was made to accommodate proto3 implementations.
  AttributeType type = 20;   // discriminator that indicates which field below is i
  ...
}

View Code

之后二进制内容解析类似,这里不再详述。

3 onnx模型预测过程详解

之前已经就onnx模型解析进行了详细分析,但是由于之前模型输入都是float型,为了方便验证分析模型预测过程,这里通过ONNX库手动构建ONNX神经网络图,并通过该网络模型来详细分析预测过程。首先给出模型生成的源码:

复制代码
import onnx
from onnx import helper
from onnx import TensorProto

inh=8
inw=8
outh=16
outw=16

#define tensor
# 1. 定义张量 (外部接口 input 使用 UINT8)
input_tensor = helper.make_tensor_value_info('input', TensorProto.UINT8, [1, 3, inh, inw])
roi = helper.make_tensor_value_info('roi', TensorProto.FLOAT, [])
scales = helper.make_tensor_value_info('scales', TensorProto.FLOAT, [4])

# 内部连线:Resize 吐出 UINT8
conv_input_u8 = helper.make_tensor_value_info('conv_input_u8', TensorProto.UINT8, [1, 3, outh, outw])
# 内部连线:经过 Cast 后变成 FLOAT32,供 Conv 消费
conv_input_f32 = helper.make_tensor_value_info('conv_input_f32', TensorProto.FLOAT, [1, 3, outh, outw])

# 权重和偏置保持 FLOAT32 (我们在推理端传 float 格式的整数,比如 10.0,模拟整型加权)
conv_weight = helper.make_tensor_value_info('conv_weight', TensorProto.FLOAT, [32, 3, 3, 3])
conv_bias = helper.make_tensor_value_info('conv_bias', TensorProto.FLOAT, [32])
conv_output = helper.make_tensor_value_info('conv_output', TensorProto.FLOAT, [1, 32, outh, outw])
add_input = helper.make_tensor_value_info('add_input', TensorProto.FLOAT, [1])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 32, outh, outw])

# 2. 构建计算节点
# 节点A: Resize (输入 UINT8 -> 输出 UINT8)
resize_node = helper.make_node("Resize", ['input', 'roi', 'scales'], ['conv_input_u8'], name='resize')

# 节点B: Cast (将匹配完形状的 UINT8 特征图 强转为 FLOAT32)
# 属性 to=1 代表目标类型是 TensorProto.FLOAT
cast_node = helper.make_node("Cast", ['conv_input_u8'], ['conv_input_f32'], to=1, name='cast_to_float')

# 节点C: Conv (接收 FLOAT32,配合 pads=[1,1,1,1] 守恒尺寸)
conv_node = helper.make_node("Conv", ['conv_input_f32', 'conv_weight', 'conv_bias'], ['conv_output'], name='conv', pads=[1, 1, 1, 1])

# 节点D: Add (最后的广播加法)
add_node = helper.make_node('Add', ['conv_output', 'add_input'], ['output'], name='add')

# 3. 组装计算图 (为了调试,把中间的 conv_output 也作为官方出口抛出)
graph = helper.make_graph(
    [resize_node, cast_node, conv_node, add_node], 
    'resize_conv_add_uint8_debug_graph', 
    inputs=[input_tensor, roi, scales, conv_weight, conv_bias, add_input], 
    outputs=[output, conv_output, conv_input_u8]
)

# 4. 创建模型并锁定版本 (满足你的 ONNX Runtime 1.19.2 环境要求)
model = helper.make_model(graph, ir_version=10)
del model.opset_import[:]
model.opset_import.extend([helper.make_opsetid("", 13)]) # 使用极其稳定的 Opset 13

# 5. 校验并保存
onnx.checker.check_model(model)
onnx.save(model, 'resize_conv_addU8.onnx')

print("✅ 模型已成功生成(IR=10, Opset=19)!")

sampleU8.py

运行该python程序即可导出模型文件resize_conv_addU8.onnx,用netron打开该模型可视化图像如下:

图中输出节点conv_input_u8和conv_output,是为了调试而故意进行的输出,正常模型保持output一个输出节点即可。首先对Resize节点进行分析,它的输入是1x3x8x8 UINT8型input,FLOAT型roi,以及Shape为4的FLOAT型scales,输出是1x3x16x16 UINT8型conv_input_u8。这里input为输入的特征图(Tensor),代表一张Batch size为1、3 通道(RGB)、大小为8x8的图像,它是被缩放的数据源;roi全称是Region of Interest,即"感兴趣区域",简单来说,它的作用是告诉算子只放大/缩小原图中的某一个"局部裁剪区域",而不是整张图,它是一个一维浮点数Tensor,格式通常是 start_r1, start_r2, ..., end_r1, end_r2, ...,用来指定一个高维边界框的起始和结束坐标(归一化到 0-1 之间),初始化为\[\]说明要针对整张图像进行全局缩放,不需要裁剪局部,即选择整张原图进行缩放;scales是缩放比例因子,是一个一维浮点数数组,它的形状是4,对应输入Tensor的4个维度N, C, H, W,由输出形状可以确定scales静态值实际上应该是1.0, 1.0, 2.0, 2.0;经过Resize放到后正好得到1, 3, 16, 16输出特征图(Tensor)。在实际运行调试时会分别对行列数据进行扩展复制,对应数据如下:

因为在onnx库中卷积对应的输入需要时FLOAT型,所以这里直接使用一个Cast节点对UINT8型输入强制类型转换为FLOAT型输出conv_input_f32,这正是Conv节点的输入,另外Conv还有两个输入,32, 3, 3, 3 FLOAT型输入权重conv_weight和32 FLOAT型输入偏置conv_bias,并产生1, 32, 16, 16 FLOAT型输出conv_output。这里conv_weight是卷积核权重(也称滤波器),32是输出通道,意味着有32个不同的滤波器同时去提取特征,3是输入通道,必须与conv_input_f32的通道数严格对齐,每个卷积核内部都有3层,最后的3, 3对应卷积核的尺寸(Kernel Size),即高3x宽3的滑窗;conv_bias偏置形状32,表示每一个输出通道(卷积核)配一个偏置常数,做完矩阵乘法后要加上这个值。此外Conv算子还有其他一些属性(attribute),如strides,pads等。其中strides表示kernel窗口滑动步长,默认值为1,即每次做完一次卷积滑动1个像素;pads表示是否对输入特征图四周进行0填充,对于二维图像(H和W)来说,它的对应顺序是:常规方向的Top, 常规方向的Left, 常规方向的Bottom, 常规方向的Right,之所以进行填充是为了保证进行完卷积后输出特征图可以保持原始尺寸,以标准的二维卷积为例,其输出尺寸公式为:

如果不进行填充,则输出高和宽都为(16-3+0+0)/1+1=14,而如果上下左右都做1像素填充,则输出高和宽为(16-3+1+1)/1+1=16,可见能保持输入尺寸。实例中正是进行了pads=1, 1, 1, 1的填充,则实际卷积过程如下所示:

图中输入特征图的3个通道都进行了填充,之后每个通道和卷积核的3个通道依次做矩阵乘法,3个结果相加再加上相应的偏置输出相应的加权和结果,之后再在W和H依次滑动kernel窗口即可获得第一个卷积核对应的16x16输出,再利用其余卷积核最终可获得32x16x16输出特征图。

最后的Add节点非常简单,就是执行数学上的加法:A + B = C,但在深度学习和ONNX的底层,它触发了一个非常经典的矩阵计算机制------广播机制(Broadcasting)。如果按照严格的线性代数规则,两个矩阵相加,它们的形状(Shape)必须完全一模一样。但是在这里,一个形状是1, 32, 16, 16,另一个是1,无法直接进行相加,这里ONNX会启动广播机制完成相加过程,ONNX Runtime看到add_input的形状是1时,它会在后台的内存流水线上把这个单元素Tensor进行"隐式疯狂复制":

(1)它发现主矩阵有4个维度;

(2)会把 1 自动脑补并对齐成 1, 1, 1, 1

(3)接着,会把这个单维度的数字沿着主矩阵的每一个轴进行拉伸扩展,直到它的形状也变成 1, 32, 16, 16

之后形状相同即可完成最终相加,完整的模型预测验证源码如下:

复制代码
import numpy as np
import onnxruntime

inh=8
inw=8
outh=16
outw=16

# 1. 载入带有 Cast 保护的全新调试模型
session = onnxruntime.InferenceSession("resize_conv_addU8.onnx")

# 2. 构造符合调试预期的"整型测试数据"
# 图像输入:构造一个 [1, 3, 8, 8] 的整型矩阵,数值限制在 0~255 的 uint8
dummy_input = np.random.randint(0, 255, (1, 3, inh, inw)).astype(np.uint8)

# ROI 和 Scales
dummy_roi = np.array([], dtype=np.float32)
dummy_scales = np.array([1.0, 1.0, 2.0, 2.0], dtype=np.float32)

# 卷积核与偏置:我们用带有 .0 的浮点数来填满它,模拟没有小数点的"纯整型乘法"
# 比如每个权重都随机填 1.0 到 5.0 之间的整数
dummy_weight = np.random.randint(1, 5, (32, 3, 3, 3)).astype(np.float32)
dummy_bias = np.random.randint(0, 10, (32,)).astype(np.float32)

# 广播加法的输入:同样传一个整数值
dummy_add_input = np.array([10.0], dtype=np.float32)

# 3. 喂入输入字典
input_feed = {
    "input": dummy_input,
    "roi": dummy_roi,
    "scales": dummy_scales,
    "conv_weight": dummy_weight,
    "conv_bias": dummy_bias,
    "add_input": dummy_add_input,
}

print("--- 📌 1. 外部输入数据调试 ------------------------------------")
print(f"【输入 input】 形状: {input_feed['input'].shape} | 数据类型: {input_feed['input'].dtype}")
print(f"【输入 input】 : {input_feed['input']}")
print(f"【模拟整型卷积核权重】 权重块:\n{input_feed['conv_weight']}")
print(f"偏置块:\n{dummy_bias}")

# 4. 申请索要最终输出和中间卷积层的输出
output_names = ["output", "conv_output", "conv_input_u8"]

# 5. 跑通推理
res_output, res_conv_output, res_conv_input_u8 = session.run(output_names, input_feed)

print("--- 📌 经过 Resize 放大后的真实卷积输入 (第0通道左上角 3x3) ---")
print(res_conv_input_u8)

print("\n--- 📌 2. 运行时中间层与最终层数据追踪 -----------------------")
# 打印中间卷积层出来的数值(此时处于 FLOAT32 状态)
print(f"【中间卷积层 conv_output】 形状: {res_conv_output.shape} | 数据类型: {res_conv_output.dtype}")
print(f"【中间卷积层 conv_output】 加权和结果: {res_conv_output}")
np.savetxt("conv_output_full.txt", res_conv_output.flatten(), fmt="%.1f")
print("💾 完整卷积输出已写入本地 conv_output_full.txt!")

# 打印经过广播加法(+10.0)后的最终输出
print(f"【最终输出层 output】 形状: {res_output.shape} | 数据类型: {res_output.dtype}")
print(f"【最终输出层 output】 最终结果: {res_output}")

print("\n💡 验证加权和对齐: 最终输出的数值应该精准比卷积层大 10.0 (验证广播加法正确)")

getoutputU8.py

4 参考

1 https://github.com/onnx/onnx

2 https://onnx.ai/onnx/operators/onnx__Conv.html

3 ONNX:从入门到精通

4 https://netron.app

相关推荐
程序猿追9 天前
那个右下角的小数字怎么“卡”住我打字——我用 HarmonyOS 自己写了一个字数限制输入框
pytorch·华为·harmonyos
闵孚龙9 天前
《PyTorch 深度修炼》Dataset 和 DataLoader:数据如何喂给模型
人工智能·pytorch·python
bryant_meng9 天前
【VAE】From Pixels to Faces: Building a VAE from Scratch
pytorch·vae·log-sigma2·重参数
装不满的克莱因瓶9 天前
了解多标签图像分类方法——从Sigmoid输出到真实世界复杂视觉理解
人工智能·pytorch·python·深度学习·机器学习·分类·数据挖掘
冷小鱼9 天前
TensorFlow 2.21 进阶实战:从训练优化到生产部署的完整指南
人工智能·pytorch·python·tensorflow
冷小鱼9 天前
PyTorch 2.12 完全指南:从动态图到编译优化的深度学习框架演进
人工智能·pytorch·深度学习
IRevers9 天前
【大模型】Gemma4在ROCm和vLLM部署
人工智能·pytorch·深度学习·大模型·datawhale·vllm·amdev
盼小辉丶9 天前
PyTorch强化学习实战(14)——优先经验回放机制
pytorch·python·深度学习·强化学习
装不满的克莱因瓶10 天前
【工业领域】了解目标检测评估指标——从mAP到IoU的完整评价体系解析
人工智能·pytorch·python·深度学习·目标检测·计算机视觉·目标跟踪