TensorFlow 2.0 模型保存与转换:.pb格式详解

版本要求:

Tensorflow 2.0.0+

tensorflow原生保存的模型不适用于多平台使用,将模型转化为.pb格式,可以更加方便的转化为别的格式,本文主要介绍如何转化.pb格式。

以mnist数据集为例进行讲解:

1. 创建网络模型,训练,保存模型:

inputs = tf.keras.Input(shape=(28,28,1), name='input')
# [28, 28, 1] => [28, 28, 64]
input = tf.keras.layers.Flatten(name="flatten")(inputs)
fc_1 = tf.keras.layers.Dense(512, activation='relu', name='fc_1')(input)
fc_2 = tf.keras.layers.Dense(256, activation='relu', name='fc_2')(fc_1)
pred = tf.keras.layers.Dense(10, activation='softmax', name='output')(fc_2)

model = tf.keras.Model(inputs=inputs, outputs=pred, name='mnist')
model.summary()

将网络搭建出来后,进行训练,然后保存模型,训练过程不再赘述,训练完保存模型,有两种保存方式:

方式1:

from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
import tensorflow as tf

# training code -----------------
tf.saved_model.save(obj=model, export_dir="./model")
# Convert Keras model to ConcreteFunction
# 注意这个Input,是自己定义的输入层名
full_model = tf.function(lambda Input: model(Input))
full_model = full_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

# Get frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()

layers = [op.name for op in frozen_func.graph.get_operations()]
print("-" * 50)
print("Frozen model layers: ")
for layer in layers:
    print(layer)

print("-" * 50)
print("Frozen model inputs: ")
print(frozen_func.inputs)
print("Frozen model outputs: ")
print(frozen_func.outputs)

# Save frozen graph from frozen ConcreteFunction to hard drive
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                  logdir="./frozen_models",
                  name="mnist.pb",
                  as_text=False)

这种方式是训练完,然后使用tf.saved_model.save()函数保存模型,然后将模型转化为.pb格式.


方式2:

# training code--------------------
model.save("./mnist.h5")

将模型保存为.h5格式,然后再另外创建一个文件将.h5模型转为.pb模型:

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

def convert_h5to_pb():
    model = tf.keras.models.load_model("../mnist.h5",compile=False)
    model.summary()
    full_model = tf.function(lambda Input: model(Input))
    full_model = full_model.get_concrete_function(tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

    # Get frozen ConcreteFunction
    frozen_func = convert_variables_to_constants_v2(full_model)
    frozen_func.graph.as_graph_def()

    layers = [op.name for op in frozen_func.graph.get_operations()]
    print("-" * 50)
    print("Frozen model layers: ")
    for layer in layers:
        print(layer)

    print("-" * 50)
    print("Frozen model inputs: ")
    print(frozen_func.inputs)
    print("Frozen model outputs: ")
    print(frozen_func.outputs)

    # Save frozen graph from frozen ConcreteFunction to hard drive
    tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                      logdir="./frozen_models",
                      name="mnist.pb",
                      as_text=False)

这种方式的好处是可以将别人训练好的.h5模型文件拿来转.pb模型


2. 测试.pb模型:

测试转化的.pb模型是否转化成功,能否加载

def wrap_frozen_graph(graph_def, inputs, outputs, print_graph=False):
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph

    print("-" * 50)
    print("Frozen model layers: ")
    layers = [op.name for op in import_graph.get_operations()]
    if print_graph == True:
        for layer in layers:
            print(layer)
    print("-" * 50)

    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs))


def main():
    # 测试数据集,
    (train_images, train_labels), (test_images,
                                   test_labels) = tf.keras.datasets.mnist.load_data()

    # Load frozen graph using TensorFlow 1.x functions
    with tf.io.gfile.GFile("./model/frozen_models/mnist.pb", "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        loaded = graph_def.ParseFromString(f.read())

    # Wrap frozen graph to ConcreteFunctions
    frozen_func = wrap_frozen_graph(graph_def=graph_def,
                                    inputs=["Input:0"],
                                    outputs=["Identity:0"],
                                    print_graph=True)

    print("-" * 50)
    print("Frozen model inputs: ")
    print(frozen_func.inputs)
    print("Frozen model outputs: ")
    print(frozen_func.outputs)

    # Get predictions for test images
    predictions = frozen_func(Input=tf.constant(test_images))[0]

    # Print the prediction for the first image
    print("-" * 50)
    print("Example prediction reference:")
    print(predictions[0].numpy())


if __name__ == "__main__":

    main()

至此 模型转化结束。

def freeze_graph(input_checkpoint, output_graph):
    '''
    :param input_checkpoint:
    :param output_graph: PB模型保存路径
    :return:
    '''
    # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
    # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径

    # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
    # output_node_names = "clone_0/block3_box/Reshape:0,clone_0/block3_box/Reshape_1:0,clone_0/block4_box/Reshape:0,clone_0/block4_box/Reshape_1:0," \
    #                     "clone_0/block5_box/Reshape:0,clone_0/block5_box/Reshape_1:0,clone_0/block6_box/Reshape:0,clone_0/block6_box/Reshape_1:0," \
    #                     "clone_0/block7_box/Reshape:0,clone_0/block7_box/Reshape_1:0,clone_0/block8_box/Reshape:0,clone_0/block8_box/Reshape_1:0"

    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    output_node_names = "block3_box/Reshape,block3_box/Reshape_1,block4_box/Reshape,block4_box/Reshape_1,block5_box/Reshape,block5_box/Reshape_1,block6_box/Reshape,block6_box/Reshape_1,block7_box/Reshape,block7_box/Reshape_1,block8_box/Reshape,block8_box/Reshape_1,seg_argmax"
    # output_node_names = "seg_argmax"
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # 恢复图并得到数据
        output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
            sess=sess,
            input_graph_def=sess.graph_def,  # 等于:sess.graph_def
            output_node_names=output_node_names.split(","))  # 如果有多个输出节点,以逗号隔开
        # for node in output_graph_def.node:
        #     if node.op == 'RefSwitch':
        #         node.op = 'Switch'
        #         for index in range(len(node.input)):
        #             if 'moving_' in node.input[index]:
        #                 node.input[index] = node.input[index] + '/read'
        #     elif node.op == 'AssignSub':
        #         node.op = 'Sub'
        #         if 'use_locking' in node.attr:
        #             del node.attr['use_locking']
        with tf.gfile.GFile(output_graph, "wb") as f:  # 保存模型
            f.write(output_graph_def.SerializeToString())  # 序列化输出
        print("%d ops in the final graph." % len(output_graph_def.node))  # 得到当前图有几个操作节点

        # for op in sess.graph.get_operations():
        #     print(op.name, op.values())

def show_Node_Name(varible_name=None):
    # ckpt = tf.train.get_checkpoint_state(model_path + '/')
    ckpt='../inference_model_rt/seg_model'
    saver = tf.train.import_meta_graph(ckpt + '.meta',clear_devices=True)
    with tf.Session() as sess:
        # saver.restore(sess, ckpt.model_checkpoint_path)
        saver.restore(sess, ckpt)
        graph = tf.get_default_graph()

        tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
        if varible_name==None:
            for tensor_name in tensor_name_list:
                print(tensor_name)
        else:
            for tensor_name in tensor_name_list:
                if varible_name in tensor_name.split("/"):
                    print(tensor_name)
相关推荐
Tianyanxiao7 分钟前
如何利用探商宝精准营销,抓住行业机遇——以AI技术与大数据推动企业信息精准筛选
大数据·人工智能·科技·数据分析·深度优先·零售
撞南墙者14 分钟前
OpenCV自学系列(1)——简介和GUI特征操作
人工智能·opencv·计算机视觉
OCR_wintone42115 分钟前
易泊车牌识别相机,助力智慧工地建设
人工智能·数码相机·ocr
王哈哈^_^36 分钟前
【数据集】【YOLO】【VOC】目标检测数据集,查找数据集,yolo目标检测算法详细实战训练步骤!
人工智能·深度学习·算法·yolo·目标检测·计算机视觉·pyqt
一者仁心43 分钟前
【AI技术】PaddleSpeech
人工智能
是瑶瑶子啦1 小时前
【深度学习】论文笔记:空间变换网络(Spatial Transformer Networks)
论文阅读·人工智能·深度学习·视觉检测·空间变换
EasyCVR1 小时前
萤石设备视频接入平台EasyCVR多品牌摄像机视频平台海康ehome平台(ISUP)接入EasyCVR不在线如何排查?
运维·服务器·网络·人工智能·ffmpeg·音视频
柳鲲鹏1 小时前
OpenCV视频防抖源码及编译脚本
人工智能·opencv·计算机视觉
西柚小萌新1 小时前
8.机器学习--决策树
人工智能·决策树·机器学习
向阳12181 小时前
Bert快速入门
人工智能·python·自然语言处理·bert