【问题标题】:Tensorflow 2.1/Keras - "output_node is not in graph" error when trying to freeze graphTensorflow 2.1/Keras - 尝试冻结图时出现“输出节点不在图中”错误
【发布时间】:2020-01-30 09:36:18
【问题描述】:

我正在尝试保存使用 Keras 创建并保存为 .h5 文件的模型,但每次尝试运行 freeze_session 函数时都会收到此错误消息:output_node/Identity is not in graph

这是我的代码(我使用的是 Tensorflow 2.1.0):

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.
    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.compat.v1.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.compat.v1.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = tf.compat.v1.graph_util.convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
        return frozen_graph
model=kr.models.load_model("model.h5")
model.summary()
# inputs:
print('inputs: ', model.input.op.name)
# outputs: 
print('outputs: ', model.output.op.name)
#layers:
layer_names=[layer.name for layer in model.layers]
print(layer_names)

哪些打印:

inputs: input_node outputs: output_node/Identity ['input_node', 'conv2d_6', 'max_pooling2d_6', 'conv2d_7', 'max_pooling2d_7', 'conv2d_8', 'max_pooling2d_8', 'flatten_2', 'dense_4', 'dense_5', 'output_node'] 正如预期的那样(与我训练后保存的模型中的层名称和输出相同)。

然后我尝试调用 freeze_session 函数并保存生成的冻结图:

frozen_graph = freeze_session(K.get_session(), output_names=[out.op.name for out in model.outputs])
write_graph(frozen_graph, './', 'graph.pbtxt', as_text=True)
write_graph(frozen_graph, './', 'graph.pb', as_text=False)

但我收到此错误:

AssertionError                            Traceback (most recent call last)
<ipython-input-4-1848000e99b7> in <module>
----> 1 frozen_graph = freeze_session(K.get_session(), output_names=[out.op.name for out in model.outputs])
      2 write_graph(frozen_graph, './', 'graph.pbtxt', as_text=True)
      3 write_graph(frozen_graph, './', 'graph.pb', as_text=False)

<ipython-input-2-3214992381a9> in freeze_session(session, keep_var_names, output_names, clear_devices)
     24                 node.device = ""
     25         frozen_graph = tf.compat.v1.graph_util.convert_variables_to_constants(
---> 26             session, input_graph_def, output_names, freeze_var_names)
     27         return frozen_graph

c:\users\marco\anaconda3\envs\tfv2\lib\site-packages\tensorflow_core\python\util\deprecation.py in new_func(*args, **kwargs)
    322               'in a future version' if date is None else ('after %s' % date),
    323               instructions)
--> 324       return func(*args, **kwargs)
    325     return tf_decorator.make_decorator(
    326         func, new_func, 'deprecated',

c:\users\marco\anaconda3\envs\tfv2\lib\site-packages\tensorflow_core\python\framework\graph_util_impl.py in convert_variables_to_constants(sess, input_graph_def, output_node_names, variable_names_whitelist, variable_names_blacklist)
    275   # This graph only includes the nodes needed to evaluate the output nodes, and
    276   # removes unneeded nodes like those involved in saving and assignment.
--> 277   inference_graph = extract_sub_graph(input_graph_def, output_node_names)
    278 
    279   # Identify the ops in the graph.

c:\users\marco\anaconda3\envs\tfv2\lib\site-packages\tensorflow_core\python\util\deprecation.py in new_func(*args, **kwargs)
    322               'in a future version' if date is None else ('after %s' % date),
    323               instructions)
--> 324       return func(*args, **kwargs)
    325     return tf_decorator.make_decorator(
    326         func, new_func, 'deprecated',

c:\users\marco\anaconda3\envs\tfv2\lib\site-packages\tensorflow_core\python\framework\graph_util_impl.py in extract_sub_graph(graph_def, dest_nodes)
    195   name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
    196       graph_def)
--> 197   _assert_nodes_are_present(name_to_node, dest_nodes)
    198 
    199   nodes_to_keep = _bfs_for_reachable_nodes(dest_nodes, name_to_input_name)

c:\users\marco\anaconda3\envs\tfv2\lib\site-packages\tensorflow_core\python\framework\graph_util_impl.py in _assert_nodes_are_present(name_to_node, nodes)
    150   """Assert that nodes are present in the graph."""
    151   for d in nodes:
--> 152     assert d in name_to_node, "%s is not in graph" % d
    153 
    154 

**AssertionError: output_node/Identity is not in graph** 

我已经尝试过,但我真的不知道如何解决这个问题,所以任何帮助将不胜感激。

【问题讨论】:

  • 看看this question and its answer:在 TF2 中不支持冻结图形。您需要改为导出 SavedModel
  • 另一方面,对于(不受支持的)方法,请查看this blog post。但不能保证它有效,也不能长期有效
  • @GPhilo 感谢您的回答。我认为调用 Tensorflow 1 函数仍然可以工作,但后来我猜它不行。不过,我需要得到一个冻结图,因为我必须在 Google Vision Bonnet 上上传模型,并且编译器需要一个 .pb 文件。那么此时直接使用TF1可能是最好的吗?附带问题:我在使用 model.save l 函数时得到的 SavedModel 包含一个目录,该目录还包含一个 saved_model .fb 文件。它与旧的 (TF1) .pb 冻结图相同吗?我认为不是,但我想知道是否有办法使用它。
  • 如果您不需要 TF2 中的特定内容,可能使用 TF1 对您来说更容易,是的。根据另一个问题,不,它们是不同的文件。 savedmodel 格式需要整个目录。 “Pb”只是二进制protobuf消息的通用扩展,消息格式在SavedModel和冻结图之间变化
  • @GPhilo 完全切换到 Tensorflow 1.15 帮助我正确保存了文件!虽然 Vision Bonnet 编译器仍然无法解析 .pb 文件,但我将开始另一个关于这个问题的线程。

标签: python python-3.x tensorflow machine-learning keras


【解决方案1】:

如果您使用 Tensorflow 2.x 版,请添加:

tf.compat.v1.disable_eager_execution()

这应该可行。 我没有检查生成的 pb 文件,但它应该可以工作。

感谢您的反馈。

edit:但是,在this thread 之后,TF1 和 TF2 pb 文件根本不同。我的解决方案可能无法正常工作或实际上创建了一个 TF1 pb 文件。


如果你再遇到

RuntimeError: 尝试使用已关闭的会话。

这可以通过重启内核来解决。使用上面的线,您只有一枪。

【讨论】:

  • 嗨!最后,我通过完全切换到 TF1 解决了这个问题,因为我不得不将它用于需要冻结图的 Google Vision Kit,而 TF2 只处理保存的模型。无论如何感谢您的回答。
  • 很高兴听到它有效。保持储蓄和健康。
  • @MarcoEsposito 你能分享一下你是怎么做到的吗?
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 2018-11-29
  • 2019-05-16
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多