Question title: Convert frozen model (.pb) to SavedModel
Posted: 2020-09-16 18:35:49
Question description:

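Recently I tried to convert a model (tf1.x) to a saved_model, following the official migrate document. However, in my use case most of the models I have on hand, including those in the TensorFlow model zoo, are usually .pb files, and according to the official document: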

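There is no straightforward way to upgrade a raw Graph.pb file to TensorFlow 2.0, but if you have a "Frozen graph" (a tf.Graph where the variables have been turned into constants), then it is possible to convert this to a concrete_function using v1.wrap_function: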
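The wrap_function step the guide refers to can be sketched as below, adapted from the wrap_frozen_graph helper in the TensorFlow 2 migration guide. The toy graph (tensors "input:0" and "output:0", computing output = input * 2) is a made-up stand-in for a real frozen .pb; for a real model you would parse the file into a GraphDef and pass your own tensor names.

```python
import tensorflow as tf


def wrap_frozen_graph(graph_def, inputs, outputs):
    """Wrap a frozen tf.compat.v1 GraphDef as a TF2 concrete function."""
    def _imports_graph_def():
        # Import the frozen graph into the graph of the wrapped function.
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph
    # Prune to a function that feeds `inputs` and fetches `outputs`.
    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs))


def make_toy_frozen_graph():
    """Hypothetical stand-in for a frozen .pb: output = input * 2."""
    g = tf.Graph()
    with g.as_default():
        x = tf.compat.v1.placeholder(tf.float32, [None, 2], name="input")
        tf.identity(x * 2.0, name="output")
    return g.as_graph_def()


graph_def = make_toy_frozen_graph()
# For a real model, parse the file instead:
# graph_def = tf.compat.v1.GraphDef()
# with tf.io.gfile.GFile("frozen_graph.pb", "rb") as f:
#     graph_def.ParseFromString(f.read())
frozen_func = wrap_frozen_graph(graph_def, inputs="input:0", outputs="output:0")
print(frozen_func(tf.constant([[1.0, 2.0]])))
```

This gives a callable TF2 function, but on its own it does not write a SavedModel directory.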

But I still don't understand how to convert it to the saved_model format.

Question discussion:

    Tags: tensorflow


    Answer 1:

    In TF1 mode:

    import tensorflow as tf
    from tensorflow.python.saved_model import signature_constants
    from tensorflow.python.saved_model import tag_constants
    
    def convert_pb_to_server_model(pb_model_path, export_dir, input_name='input:0', output_name='output:0'):
        # input_name/output_name are tensor names of the form "<op_name>:<output_index>".
        graph_def = read_pb_model(pb_model_path)
        convert_pb_saved_model(graph_def, export_dir, input_name, output_name)
    
    
    def read_pb_model(pb_model_path):
        # Parse the binary frozen GraphDef from disk.
        with tf.gfile.GFile(pb_model_path, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            return graph_def
    
    
    def convert_pb_saved_model(graph_def, export_dir, input_name='input:0', output_name='output:0'):
        # export_dir must not exist yet; SavedModelBuilder refuses to overwrite it.
        builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
    
        sigs = {}
        with tf.Session(graph=tf.Graph()) as sess:
            # Import the frozen graph into the session's graph (no name prefix).
            tf.import_graph_def(graph_def, name="")
            g = tf.get_default_graph()
            inp = g.get_tensor_by_name(input_name)
            out = g.get_tensor_by_name(output_name)
    
            # Expose the tensors under the default "serving_default" signature key.
            sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
                tf.saved_model.signature_def_utils.predict_signature_def(
                    {"input": inp}, {"output": out})
    
            # The frozen graph holds only constants, so no variable values are written.
            builder.add_meta_graph_and_variables(sess,
                                                 [tag_constants.SERVING],
                                                 signature_def_map=sigs)
            builder.save()
    

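    In TF2 mode (note that these snippets go the other way: they freeze a Keras model or a SavedModel into a frozen .pb):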

    import tensorflow as tf
    # Both helpers come from TensorFlow's private packages and may move between releases.
    from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
    from tensorflow.lite.python.util import run_graph_optimizations, get_grappler_config
    
    def frozen_keras_graph(func_model):
        # Replace every captured variable with a constant holding its current value.
        frozen_func, graph_def = convert_variables_to_constants_v2_as_graph(func_model)
    
        # Resource handles are not real inputs of the frozen graph.
        input_tensors = [
            tensor for tensor in frozen_func.inputs
            if tensor.dtype != tf.resource
        ]
        output_tensors = frozen_func.outputs
        # Run Grappler constant folding and function optimization on the frozen graph.
        graph_def = run_graph_optimizations(
            graph_def,
            input_tensors,
            output_tensors,
            config=get_grappler_config(["constfold", "function"]),
            graph=frozen_func.graph)
    
        return graph_def
    
    
    def convert_keras_model_to_pb():
        keras_model = train_model()  # your own function that returns a trained tf.keras model
        func_model = tf.function(keras_model).get_concrete_function(
            tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
        graph_def = frozen_keras_graph(func_model)
        # as_text=False writes a binary GraphDef; the default (as_text=True) writes text format.
        tf.io.write_graph(graph_def, '/tmp/tf_model3', 'frozen_graph.pb', as_text=False)
    
    def convert_saved_model_to_pb():
        model_dir = '/tmp/saved_model'
        model = tf.saved_model.load(model_dir)
        func_model = model.signatures["serving_default"]
        graph_def = frozen_keras_graph(func_model)
        tf.io.write_graph(graph_def, '/tmp/tf_model3', 'frozen_graph.pb', as_text=False)
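    The core of frozen_keras_graph is the variables-to-constants conversion. Below is a minimal sketch of just that step, assuming a toy tf.Module (the ToyModel class is hypothetical) in place of a trained Keras model and skipping the TFLite Grappler pass; like the answer, it imports convert_variables_to_constants_v2 from TensorFlow's private convert_to_constants module.

```python
import os
import tempfile

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)


class ToyModel(tf.Module):
    """Hypothetical model: y = x @ w, with w stored in a tf.Variable."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable([[2.0], [3.0]])

    @tf.function(input_signature=[tf.TensorSpec([None, 2], tf.float32)])
    def __call__(self, x):
        return tf.matmul(x, self.w)


model = ToyModel()
concrete_func = model.__call__.get_concrete_function()
# Replace every captured variable with a constant holding its current value.
frozen_func = convert_variables_to_constants_v2(concrete_func)
graph_def = frozen_func.graph.as_graph_def()

# After freezing, the graph should no longer read any variables.
op_types = {node.op for node in graph_def.node}
print(sorted(op_types))

out_dir = tempfile.mkdtemp()
# as_text=False writes a binary GraphDef, which is what a .pb file normally holds.
tf.io.write_graph(graph_def, out_dir, "frozen_graph.pb", as_text=False)
```

    Writing with as_text=False matters because the resulting file can then be parsed with GraphDef.ParseFromString, as the readers in the answers do.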
    
    

    Alternatively:

    def convert_saved_model_to_pb(output_node_names, input_saved_model_dir, output_graph_path):
        # TF1 freeze_graph tool: reads the SavedModel, folds its variables into
        # constants and writes one binary GraphDef to output_graph_path (a file path).
        from tensorflow.python.tools import freeze_graph
    
        # freeze_graph expects op names (no ":0" suffix) joined by commas.
        output_node_names = ','.join(output_node_names)
    
        freeze_graph.freeze_graph(input_graph=None, input_saver=None,
                                  input_binary=None,
                                  input_checkpoint=None,
                                  output_node_names=output_node_names,
                                  restore_op_name=None,
                                  filename_tensor_name=None,
                                  output_graph=output_graph_path,
                                  clear_devices=None,
                                  initializer_nodes=None,
                                  input_saved_model_dir=input_saved_model_dir)
    
    
    def save_output_tensor_to_pb():
        output_names = ['StatefulPartitionedCall']
        save_pb_model_path = '/tmp/pb_model/freeze_graph.pb'
        model_dir = '/tmp/saved_model'
        convert_saved_model_to_pb(output_names, model_dir, save_pb_model_path)
    

    Discussion:

    • Please add some explanation.
    • You saved my life, thanks!!! For anyone using TensorFlow > 2.0, "import tensorflow as tf" should be replaced with "import tensorflow.compat.v1 as tf" followed by "tf.disable_v2_behavior()".
    • How do I find out the input and output names?
    • @secsilm You can use saved_model_cli show --dir to view the information in the .pb.
    • ValueError: The name 'input' refers to an Operation, not a Tensor. Tensor names must be of the form "<op_name>:<output_index>". (raised at inp = g.get_tensor_by_name(input_name))
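    saved_model_cli expects a SavedModel directory, so it does not help with a bare frozen GraphDef. One rough way to find candidate names there is to scan the GraphDef directly. The sketch below is a heuristic, not an official tool: guess_io_names and load_graph_def are hypothetical helpers that treat Placeholder nodes as inputs and nodes whose outputs no other node consumes as outputs, and they assume each op's first output (":0") is the one you want.

```python
import tensorflow as tf


def load_graph_def(pb_path):
    """Parse a binary frozen GraphDef from disk."""
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    return graph_def


def guess_io_names(graph_def):
    """Heuristically guess input/output tensor names of a frozen GraphDef."""
    consumed = set()
    for node in graph_def.node:
        for name in node.input:
            # Inputs look like "op", "op:1" or "^op" (control dependency).
            consumed.add(name.lstrip("^").split(":")[0])
    inputs = [node.name + ":0" for node in graph_def.node
              if node.op == "Placeholder"]
    outputs = [node.name + ":0" for node in graph_def.node
               if node.name not in consumed
               and node.op not in ("Placeholder", "Const", "NoOp")]
    return inputs, outputs


# Hypothetical toy graph standing in for a real frozen .pb.
g = tf.Graph()
with g.as_default():
    x = tf.compat.v1.placeholder(tf.float32, [None, 2], name="input")
    tf.identity(x * 2.0, name="output")

print(guess_io_names(g.as_graph_def()))
```

    For a real file, call guess_io_names(load_graph_def("frozen_graph.pb")) and sanity-check the result, since graphs with several heads or unused branches will list extra candidates.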
    Answer 2:

    To make sure my understanding is correct, I'm posting what I learned as well:

    If you want to migrate from tf1.x to tf2.x, follow the official post first.

    In TensorFlow 2.0, tf.train.Saver and freeze_graph have been superseded by the SavedModel format (tf.saved_model).
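    For comparison, a native TF2 export goes through tf.saved_model.save, with no Session, Saver or freezing step. A minimal sketch, assuming a toy tf.Module (the Doubler class and the "output" key are made up for illustration):

```python
import os
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
    """Toy model: returns its input multiplied by 2."""

    @tf.function(input_signature=[tf.TensorSpec([None, 2], tf.float32)])
    def serve(self, x):
        # Returning a dict names the output in the serving signature.
        return {"output": x * 2.0}


module = Doubler()
export_dir = os.path.join(tempfile.mkdtemp(), "doubler")
tf.saved_model.save(module, export_dir,
                    signatures={"serving_default": module.serve})

loaded = tf.saved_model.load(export_dir)
infer = loaded.signatures["serving_default"]
# Signature inputs are passed by keyword, named after the function argument.
result = infer(x=tf.constant([[1.0, 2.0]]))["output"]
print(result.numpy())
```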

    If you want to convert a tf1.x .pb model to a saved_model, you can follow @Boluoyu's answer. But if your runtime environment is tf2.0 or later, you can use the following code:

    import tensorflow.compat.v1 as tf 
    tf.disable_v2_behavior()
    from tensorflow.python.saved_model import signature_constants
    from tensorflow.python.saved_model import tag_constants
    
    def convert_pb_to_server_model(pb_model_path, export_dir, input_name='input:0', output_name='output:0'):
        # Tensor names need the output index (e.g. 'input:0'); a bare op name such as
        # 'input' raises "ValueError: The name 'input' refers to an Operation, not a Tensor".
        graph_def = read_pb_model(pb_model_path)
        convert_pb_saved_model(graph_def, export_dir, input_name, output_name)
    
    
    def read_pb_model(pb_model_path):
        # Parse the binary frozen GraphDef from disk.
        with tf.gfile.GFile(pb_model_path, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            return graph_def
    
    
    def convert_pb_saved_model(graph_def, export_dir, input_name='input:0', output_name='output:0'):
        # export_dir must not exist yet; SavedModelBuilder refuses to overwrite it.
        builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
    
        sigs = {}
        with tf.Session(graph=tf.Graph()) as sess:
            tf.import_graph_def(graph_def, name="")
            g = tf.get_default_graph()
            inp = g.get_tensor_by_name(input_name)
            out = g.get_tensor_by_name(output_name)
    
            sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
                tf.saved_model.signature_def_utils.predict_signature_def(
                    {"input": inp}, {"output": out})
    
            builder.add_meta_graph_and_variables(sess,
                                                 [tag_constants.SERVING],
                                                 signature_def_map=sigs)
            builder.save()
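    To check that an exported directory actually loads in TF2, here is a self-contained sketch that runs the same SavedModelBuilder approach through tf.compat.v1 while leaving eager mode on (entering the Session switches to graph mode for its block), then loads the result with tf.saved_model.load. The toy graph (output = input * 2), the helper name frozen_pb_to_saved_model and the temporary export path are assumptions standing in for a real frozen .pb.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1


def frozen_pb_to_saved_model(graph_def, export_dir,
                             input_name="input:0", output_name="output:0"):
    """Wrap a frozen GraphDef in a SavedModel with a serving signature."""
    builder = tf1.saved_model.builder.SavedModelBuilder(export_dir)
    # Entering the session makes its graph the default and uses graph mode.
    with tf1.Session(graph=tf.Graph()) as sess:
        tf1.import_graph_def(graph_def, name="")
        inp = sess.graph.get_tensor_by_name(input_name)
        out = sess.graph.get_tensor_by_name(output_name)
        sig = tf1.saved_model.signature_def_utils.predict_signature_def(
            inputs={"input": inp}, outputs={"output": out})
        key = tf1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        builder.add_meta_graph_and_variables(
            sess,
            [tf1.saved_model.tag_constants.SERVING],
            signature_def_map={key: sig})
    builder.save()


# Hypothetical toy frozen graph: output = input * 2.
g = tf.Graph()
with g.as_default():
    x = tf1.placeholder(tf.float32, [None, 2], name="input")
    tf.identity(x * 2.0, name="output")

export_dir = os.path.join(tempfile.mkdtemp(), "saved_model")
frozen_pb_to_saved_model(g.as_graph_def(), export_dir)

loaded = tf.saved_model.load(export_dir)
infer = loaded.signatures["serving_default"]
# Signatures loaded from a TF1 SavedModel take keyword arguments only.
print(infer(input=tf.constant([[1.0, 2.0]]))["output"].numpy())
```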
    

    Discussion:
