前言
对 ONNX 模型结构中的权重（initializer）进行修改。
例如：某个权重原本是一个标量，需要将其修改为一个一维数组（shape 为 [1]）。
code
python
import onnx
import numpy as np
import torch
import argparse
from onnx import TensorProto, helper, numpy_helper
# Validate the ONNX computation graph.
def check_onnx(model):
    """Run the ONNX structural checker on *model*.

    Delegates to ``onnx.checker.check_model``, which raises if the model is
    invalid. Nothing is returned.
    """
    validate = onnx.checker.check_model
    validate(model)
onnx_path = "./bs16_seq397.onnx"
save_path = "./bs16_seq397_m.onnx"

# Versions stamped on the output model.
TARGET_IR_VERSION = 8
TARGET_OPSET = 11

model = onnx.load(onnx_path)
graph = model.graph

# Replacement initializers: each becomes a 1-D FLOAT tensor of shape [1].
# The dtype is spelled out so the integer literals are stored as float32.
initializer1 = helper.make_tensor(
    "537", TensorProto.FLOAT, [1], np.array([397], dtype=np.float32)
)
initializer2 = helper.make_tensor(
    "540", TensorProto.FLOAT, [1], np.array([0], dtype=np.float32)
)
initializer3 = helper.make_tensor(
    "1707", TensorProto.FLOAT, [1], np.array([2], dtype=np.float32)
)
replacements = {t.name: t for t in (initializer1, initializer2, initializer3)}


def _replace_initializers(target_graph, new_tensors):
    """Overwrite initializers of *target_graph* in place, matched by name.

    Each match is updated with ``CopyFrom`` rather than remove + append.
    Removing from the repeated field while iterating it would shift later
    entries down one slot, skipping the entry after each removal. Appending
    would also make the loop revisit the new tensors.

    Returns the set of requested names that matched no initializer.
    """
    missing = set(new_tensors)
    for init in target_graph.initializer:
        new_tensor = new_tensors.get(init.name)
        if new_tensor is not None:
            init.CopyFrom(new_tensor)
            missing.discard(init.name)
    return missing


missing_names = _replace_initializers(graph, replacements)
if missing_names:
    # Don't fail, but make it visible that part of the edit did not happen.
    print(f"warning: initializer(s) not found, left unchanged: {sorted(missing_names)}")

# Rebuild the model around the edited graph. make_graph keeps only nodes,
# inputs, outputs and initializers, so value_info and model metadata from the
# source are dropped. That is presumably intended, since initializer shapes
# changed. NOTE(review): confirm.
# NOTE(review): if the source model also lists these initializers in
# graph.input (IR < 4 style), the shapes declared there are not updated.
graph = helper.make_graph(graph.node, graph.name, graph.input, graph.output, graph.initializer)

# Pin the default-domain opset; otherwise make_model stamps the installed onnx
# package's latest opset. Opset imports for any other domains are carried over
# from the source model instead of being dropped.
opset_imports = [helper.make_opsetid("", TARGET_OPSET)] + [
    op for op in model.opset_import if op.domain not in ("", "ai.onnx")
]
info_model = helper.make_model(graph, opset_imports=opset_imports)
info_model.ir_version = TARGET_IR_VERSION

check_onnx(info_model)
onnx.save_model(info_model, save_path)
print(f"modify onnx done, save path:{save_path} \n")