Version requirement:
TensorFlow 2.0.0+
Models saved in TensorFlow's native format are inconvenient to deploy across platforms. Converting a model to the .pb (frozen graph) format makes it much easier to convert into other formats later, and this article focuses on how to produce such a .pb file.
The MNIST dataset is used as the running example.
1. Build the network, train it, and save the model:
import tensorflow as tf

inputs = tf.keras.Input(shape=(28, 28, 1), name='input')
# [28, 28, 1] => [784]
x = tf.keras.layers.Flatten(name="flatten")(inputs)
fc_1 = tf.keras.layers.Dense(512, activation='relu', name='fc_1')(x)
fc_2 = tf.keras.layers.Dense(256, activation='relu', name='fc_2')(fc_1)
pred = tf.keras.layers.Dense(10, activation='softmax', name='output')(fc_2)
model = tf.keras.Model(inputs=inputs, outputs=pred, name='mnist')
model.summary()
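The training step itself is not the focus of this article; for completeness, here is a minimal training sketch (the optimizer, loss, epoch count, and preprocessing below are illustrative choices only):

(train_images, train_labels), _ = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(-1, 28, 28, 1).astype("float32") / 255.0  # scale pixels to [0, 1]

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(train_images, train_labels, batch_size=128, epochs=5)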
After the network is built and trained, save the model. There are two ways to save it and convert it to .pb:
Method 1:
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
import tensorflow as tf

# training code -----------------

tf.saved_model.save(obj=model, export_dir="./model")

# Convert the Keras model to a ConcreteFunction.
# Note: the lambda argument name "Input" becomes the input tensor name
# ("Input:0") in the frozen graph, so use the same name when loading the .pb later.
full_model = tf.function(lambda Input: model(Input))
full_model = full_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

# Get frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()

layers = [op.name for op in frozen_func.graph.get_operations()]
print("-" * 50)
print("Frozen model layers: ")
for layer in layers:
    print(layer)

print("-" * 50)
print("Frozen model inputs: ")
print(frozen_func.inputs)
print("Frozen model outputs: ")
print(frozen_func.outputs)

# Save the frozen graph (a GraphDef) from the frozen ConcreteFunction to disk
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                  logdir="./frozen_models",
                  name="mnist.pb",
                  as_text=False)
With this method, the model is saved right after training with tf.saved_model.save(), and then frozen into a .pb file in the same script.
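If you prefer to do the freezing in a separate script, the SavedModel written by tf.saved_model.save() can be loaded back first. A small hedged sketch (the export directory "./model" matches the call above; the 'serving_default' signature name is the usual default, adjust if yours differs):

import tensorflow as tf

# Reload the exported SavedModel; this returns the traced object, not the original Keras model.
loaded = tf.saved_model.load("./model")
print(list(loaded.signatures.keys()))        # typically ['serving_default']
infer = loaded.signatures["serving_default"]
print(infer.structured_outputs)              # output spec of the serving signature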
Method 2:
# training code--------------------
model.save("./mnist.h5")
Save the model in .h5 format, then create a separate script that converts the .h5 model to a .pb model:
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

def convert_h5to_pb():
    # Path of the .h5 file saved in Method 2; adjust it to where your .h5 actually lives.
    model = tf.keras.models.load_model("./mnist.h5", compile=False)
    model.summary()

    full_model = tf.function(lambda Input: model(Input))
    full_model = full_model.get_concrete_function(
        tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

    # Get frozen ConcreteFunction
    frozen_func = convert_variables_to_constants_v2(full_model)
    frozen_func.graph.as_graph_def()

    layers = [op.name for op in frozen_func.graph.get_operations()]
    print("-" * 50)
    print("Frozen model layers: ")
    for layer in layers:
        print(layer)

    print("-" * 50)
    print("Frozen model inputs: ")
    print(frozen_func.inputs)
    print("Frozen model outputs: ")
    print(frozen_func.outputs)

    # Save the frozen graph from the frozen ConcreteFunction to disk
    tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                      logdir="./frozen_models",
                      name="mnist.pb",
                      as_text=False)
The advantage of this method is that a .h5 model file trained by someone else can also be taken and converted to a .pb model.
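A minimal usage sketch for the conversion script above (paths follow the ones used in this article):

if __name__ == "__main__":
    # Produces ./frozen_models/mnist.pb from ./mnist.h5
    convert_h5to_pb()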
2. Test the .pb model:
Check that the conversion succeeded and that the .pb model can be loaded and run:
import tensorflow as tf

def wrap_frozen_graph(graph_def, inputs, outputs, print_graph=False):
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph

    print("-" * 50)
    print("Frozen model layers: ")
    layers = [op.name for op in import_graph.get_operations()]
    if print_graph:
        for layer in layers:
            print(layer)
    print("-" * 50)

    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs))

def main():
    # Test data; reshape and cast it to match the (None, 28, 28, 1) float32 input
    # the graph was traced with (apply the same preprocessing you used for training).
    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
    test_images = test_images.reshape(-1, 28, 28, 1).astype("float32")

    # Load the frozen graph using TensorFlow 1.x functions
    with tf.io.gfile.GFile("./frozen_models/mnist.pb", "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

    # Wrap the frozen graph into a ConcreteFunction
    frozen_func = wrap_frozen_graph(graph_def=graph_def,
                                    inputs=["Input:0"],
                                    outputs=["Identity:0"],
                                    print_graph=True)
    print("-" * 50)
    print("Frozen model inputs: ")
    print(frozen_func.inputs)
    print("Frozen model outputs: ")
    print(frozen_func.outputs)

    # Get predictions for the test images
    predictions = frozen_func(Input=tf.constant(test_images))[0]

    # Print the prediction for the first image
    print("-" * 50)
    print("Example prediction:")
    print(predictions[0].numpy())

if __name__ == "__main__":
    main()
At this point the model conversion is complete.
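As one concrete example of "easier to convert to other formats", here is a hedged sketch that converts the frozen graph to TFLite through the TF1-compat converter (the input/output node names and the input shape are the ones produced above; adjust them if your graph differs):

import tensorflow as tf

# Convert the frozen GraphDef to a .tflite model; node names carry no ":0" suffix here.
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file="./frozen_models/mnist.pb",
    input_arrays=["Input"],
    output_arrays=["Identity"],
    input_shapes={"Input": [1, 28, 28, 1]})
tflite_model = converter.convert()
with open("./frozen_models/mnist.tflite", "wb") as f:
    f.write(tflite_model)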
As a supplement: if what you have is a TensorFlow 1.x checkpoint (.ckpt/.meta files) rather than a Keras model, the graph can be frozen into a .pb with the TF1-style APIs below (run through tf.compat.v1 when using TensorFlow 2.x).

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

def freeze_graph(input_checkpoint, output_graph):
    '''
    :param input_checkpoint: path prefix of the checkpoint (.meta/.index/.data files)
    :param output_graph: path where the frozen .pb model is written
    :return:
    '''
    # checkpoint = tf.train.get_checkpoint_state(model_folder)  # check whether a usable ckpt exists in the folder
    # input_checkpoint = checkpoint.model_checkpoint_path       # get the ckpt file path
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    # The output nodes must exist in the original graph and must be given as node
    # names (no ":0" tensor suffix). The names below belong to the author's own
    # detection/segmentation model; replace them with your graph's output nodes.
    output_node_names = "block3_box/Reshape,block3_box/Reshape_1,block4_box/Reshape,block4_box/Reshape_1," \
                        "block5_box/Reshape,block5_box/Reshape_1,block6_box/Reshape,block6_box/Reshape_1," \
                        "block7_box/Reshape,block7_box/Reshape_1,block8_box/Reshape,block8_box/Reshape_1,seg_argmax"
    # output_node_names = "seg_argmax"

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # restore the graph and its weights
        output_graph_def = tf.graph_util.convert_variables_to_constants(  # freeze variables into constants
            sess=sess,
            input_graph_def=sess.graph_def,
            output_node_names=output_node_names.split(","))  # multiple output nodes are comma-separated

        # Workaround occasionally needed for batch-norm / moving-average ops:
        # for node in output_graph_def.node:
        #     if node.op == 'RefSwitch':
        #         node.op = 'Switch'
        #         for index in range(len(node.input)):
        #             if 'moving_' in node.input[index]:
        #                 node.input[index] = node.input[index] + '/read'
        #     elif node.op == 'AssignSub':
        #         node.op = 'Sub'
        #         if 'use_locking' in node.attr:
        #             del node.attr['use_locking']

        with tf.gfile.GFile(output_graph, "wb") as f:  # save the model
            f.write(output_graph_def.SerializeToString())  # serialize the GraphDef
        print("%d ops in the final graph." % len(output_graph_def.node))  # number of ops in the frozen graph

        # for op in sess.graph.get_operations():
        #     print(op.name, op.values())
def show_Node_Name(variable_name=None):
    """Print every node name in a checkpoint graph, optionally filtered by a name scope."""
    # ckpt = tf.train.get_checkpoint_state(model_path + '/')
    ckpt = '../inference_model_rt/seg_model'  # checkpoint path prefix; replace with your own
    saver = tf.train.import_meta_graph(ckpt + '.meta', clear_devices=True)
    with tf.Session() as sess:
        # saver.restore(sess, ckpt.model_checkpoint_path)
        saver.restore(sess, ckpt)
        tensor_name_list = [node.name for node in tf.get_default_graph().as_graph_def().node]
        if variable_name is None:
            for tensor_name in tensor_name_list:
                print(tensor_name)
        else:
            for tensor_name in tensor_name_list:
                if variable_name in tensor_name.split("/"):
                    print(tensor_name)
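A hedged usage sketch for the two helpers above (the checkpoint prefix and output path are placeholders; point them at your own files):

if __name__ == "__main__":
    ckpt_prefix = "../inference_model_rt/seg_model"   # placeholder checkpoint prefix
    pb_path = "./frozen_models/seg_model.pb"          # placeholder output path

    show_Node_Name()            # first inspect node names to choose the output nodes
    tf.reset_default_graph()    # clear the imported meta graph before importing it again
    freeze_graph(ckpt_prefix, pb_path)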