【问题标题】:How to wrap a frozen Tensoflow graph in a Keras Lambda layer in TF2?如何在 TF2 的 Keras Lambda 层中包装冻结的 Tensorflow 图?
【发布时间】:2020-07-09 11:18:08
【问题描述】:

这个问题与this question有关,它提供了一个在Tensorflow 1.15 中有效,但在TF2 中不再有效的解决方案

我正在从该问题中提取部分代码并稍微调整它(删除了冻结模型的多个输入,并随之消除了对 nest 的需求)。

注意:我将代码分隔成块,但它们应该作为文件运行(即,我不会在每个块中重复不必要的导入)

首先,我们生成一个冻结图用作虚拟测试网络:

import numpy as np
import tensorflow.compat.v1 as tf

def dump_model():
    with tf.Graph().as_default() as gf:
        x = tf.placeholder(tf.float32, shape=(None, 123), name='x')
        c = tf.constant(100, dtype=tf.float32, name='C')
        y = tf.multiply(x, c, name='y')
        z = tf.add(y, x, name='z')
        with tf.gfile.GFile("tmp_net.pb", "wb") as f:
            raw = gf.as_graph_def().SerializeToString()
            print(type(raw), len(raw))
            f.write(raw)

dump_model()

然后,我们加载冻结的模型并将其包装在 Keras 模型中:

persisted_sess = tf.Session()
with tf.Session().as_default() as session:
    with tf.gfile.FastGFile("./tmp_net.pb",'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        persisted_sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')
        print(persisted_sess.graph.get_name_scope())
        for i, op in enumerate(persisted_sess.graph.get_operations()):
            tensor = persisted_sess.graph.get_tensor_by_name(op.name + ':0')
            print(i, '\t', op.name, op.type, tensor)
        x_tensor = persisted_sess.graph.get_tensor_by_name('x:0')
        y_tensor = persisted_sess.graph.get_tensor_by_name('y:0')
        z_tensor = persisted_sess.graph.get_tensor_by_name('z:0')

from tensorflow.compat.v1.keras.layers import Lambda, InputLayer
from tensorflow.compat.v1.keras import Model
from tensorflow.python.keras.utils import layer_utils

input_x = InputLayer(name='x', input_tensor=x_tensor)
input_x.is_placeholder = True
output_y = Lambda(lambda x: y_tensor, name='output_y')(input_x.output)
output_z = Lambda(lambda x_b: z_tensor, name='output_z')(input_x.output)

base_model_inputs = layer_utils.get_source_inputs(input_x.output)
base_model = Model(base_model_inputs, [output_y, output_z])

最后,我们在一些随机数据上运行模型,并验证它运行时没有错误:

y_out, z_out = base_model.predict(np.ones((3, 123), dtype=np.float32))
y_out.shape, z_out.shape

在 TensorFlow 1.15.3 中,上面的输出是((3, 123), (3, 123)),但是,如果我在 TensorFlow 2.1.0 中运行相同的代码,前两个块运行没有问题,但第三个块失败:

TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
  @tf.function
  def has_init_scope():
    my_constant = tf.constant(1.)
    with tf.init_scope():
      added = my_constant * 2
The graph tensor has name: y:0

该错误似乎与Tensorflow的自动“编译”和功能优化有关,但我不知道如何解释,错误的根源是什么,或者如何解决。

在 Tensorflow 2 中包装冻结模型的正确方法是什么?

【问题讨论】:

    标签: python tensorflow keras tensorflow2.0


    【解决方案1】:

    我可以像这样在 2.2.0 中运行你的整个示例。

    import tensorflow as tf
    from tensorflow.core.framework.graph_pb2 import GraphDef
    import numpy as np
    
    with tf.Graph().as_default() as gf:
        x = tf.compat.v1.placeholder(tf.float32, shape=(None, 123), name='x')
        c = tf.constant(100, dtype=tf.float32, name='c')
        y = tf.multiply(x, c, name='y')
        z = tf.add(y, x, name='z')
        with open('tmp_net.pb', 'wb') as f:
            f.write(gf.as_graph_def().SerializeToString())
    
    with tf.Graph().as_default():
        gd = GraphDef()
        with open('tmp_net.pb', 'rb') as f:
            gd.ParseFromString(f.read())
        x, y, z = tf.graph_util.import_graph_def(
            gd, name='', return_elements=['x:0', 'y:0', 'z:0'])
        del gd
        input_x = tf.keras.layers.InputLayer(name='x', input_tensor=x)
        input_x.is_placeholder = True
        output_y = tf.keras.layers.Lambda(lambda x: y, name='output_y')(input_x.output)
        output_z = tf.keras.layers.Lambda(lambda x: z, name='output_z')(input_x.output)
    
        base_model_inputs = tf.keras.utils.get_source_inputs(input_x.output)
        base_model = tf.keras.Model(base_model_inputs, [output_y, output_z])
    
        y_out, z_out = base_model.predict(np.ones((3, 123), dtype=np.float32))
        print(y_out.shape, z_out.shape)
        # (3, 123) (3, 123)
    

    “诀窍”是将模型构造包装在 with tf.Graph().as_default(): 块中,这将确保在同一个图形对象中以图形模式创建所有内容。

    但是,将图形加载和计算包装在 @tf.function 中可能更简单,这样可以避免此类错误并使模型构建更加透明:

    import tensorflow as tf
    from tensorflow.core.framework.graph_pb2 import GraphDef
    import numpy as np
    
    @tf.function
    def my_model(x):
        gd = GraphDef()
        with open('tmp_net.pb', 'rb') as f:
            gd.ParseFromString(f.read())
        y, z = tf.graph_util.import_graph_def(
            gd, name='', input_map={'x:0': x}, return_elements=['y:0', 'z:0'])
        return [y, z]
    
    x = tf.keras.Input(shape=123)
    y, z = tf.keras.layers.Lambda(my_model)(x)
    model = tf.keras.Model(x, [y, z])
    y_out, z_out = model.predict(np.ones((3, 123), dtype=np.float32))
    print(y_out.shape, z_out.shape)
    # (3, 123) (3, 123)
    

    【讨论】:

    • 谢谢!似乎这里的关键是让该图成为模型创建和推理时的默认图。在我的实际代码中,我在一个函数中加载和构建模型。如果我同时返回图形和模型,并在调用 predict(), it works! If I don't, though, I still have the same error as above. Thank you very much for this answer, it gave me just the right information I was missing (like the fact that I don't need Session`s 之前使用 with graph.as_default(): 处理图形...)。
    • 作为一个可能的改进点,您能否澄清一下您对我的代码所做的更改,以便未来的观众也可以从这个解决方案中受益?
    • @GPhilo 是的,感谢您的建议。我还添加了另一种可能的方法来对@tf.function 做同样的事情,我认为这可能更干净。唯一的小缺点是您必须提前指定输入的大小,尽管您可能应该知道您正在加载的模型的输入和输出的大小。
    • 您的第二个选项看起来非常像我希望的那样!当您说您需要知道输入的大小时,是否必须完全知道,或者它也可以有Nones? (我正在使用一个检测网络,不幸的是它可以并且需要接受可变形状的输入,所以输入形状是[None, None, None, 3]
    • @GPhilo 啊,是的,你可以有None(事实上,如果你在示例中传递shape=(None,),它也可以工作)。我的意思是说,你需要给模型Input一个与模型中输入的形状兼容的形状。
    【解决方案2】:

    另一种可能的方法是

    import tensorflow as tf
    
    input_layer = tf.keras.Input(shape=[123])
    keras_graph = input_layer.graph
    
    with keras_graph.as_default():
        with tf.io.gfile.GFile('tmp_net.pb', 'rb') as f:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(f.read())
    
        tf.graph_util.import_graph_def(graph_def, name='', input_map={'x:0': input_layer})
        
        
    y_tensor = keras_graph.get_tensor_by_name('y:0')
    z_tensor = keras_graph.get_tensor_by_name('z:0')
    
    base_model = tf.keras.Model(input_layer, [y_tensor, z_tensor])
    

    然后

    y_out, z_out = base_model.predict(tf.ones((3, 123), dtype=tf.float32))
    print(y_out.shape, z_out.shape)
    # (3, 123) (3, 123)
    

    【讨论】:

    • 我不知道Input 层有graph 属性,谢谢!这看起来也很有趣!
    猜你喜欢
    • 2022-07-25
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-08-07
    • 2020-03-25
    相关资源
    最近更新 更多