【问题标题】:Given a tensor flow model graph, how to find the input node and output node names给定一个张量流模型图,如何找到输入节点和输出节点名称
【发布时间】:2017-04-20 11:14:26
【问题描述】:

我在 Tensor flow Camera Demo 中使用自定义模型进行分类。 我生成了一个 .pb 文件(序列化的 protobuf 文件),我可以显示它包含的巨大图形。 要将此图转换为优化图,如 [https://www.oreilly.com/learning/tensorflow-on-android] 中给出的,可以使用以下过程:

$ bazel-bin/tensorflow/python/tools/optimize_for_inference  \
--input=tf_files/retrained_graph.pb \
--output=tensorflow/examples/android/assets/retrained_graph.pb
--input_names=Mul \
--output_names=final_result

这里如何从图形显示中找到 input_names 和 output_names。 当我不使用专有名称时,我会遇到设备崩溃:

E/TensorFlowInferenceInterface(16821): Failed to run TensorFlow inference 
with inputs:[AvgPool], outputs:[predictions]

E/AndroidRuntime(16821): FATAL EXCEPTION: inference

E/AndroidRuntime(16821): java.lang.IllegalArgumentException: Incompatible 
shapes: [1,224,224,3] vs. [32,1,1,2048]

E/AndroidRuntime(16821):     [[Node: dropout/dropout/mul = Mul[T=DT_FLOAT, 
_device="/job:localhost/replica:0/task:0/cpu:0"](dropout/dropout/div, 
dropout/dropout/Floor)]]

【问题讨论】:

  • 嗨@Dr.SantleCamilus ,您找到解决方案了吗?
  • 是的,提及正确的输入和输出节点名称对于 android TF 演示的工作至关重要。一些较旧的 TF 训练代码可能不会在模型中包含这些名称。 JP Kim 的以下回答可以找到节点名称的存在。如果名称不存在,则需要迁移到新的 TF 训练代码以包含正确的节点名称。
  • 我得到这样的输出 *[u'image_tensor=>Placeholder'] *
  • [u'image_tensor=>Placeholder'] 表示你的输入节点名是 ''image_tensor" (/定义optimize_for_interface时可以使用--input_names=image_tensor)
  • 请使用 JP Kim 的以下答案检查模型中是否存在 softmax 节点。如果返回任何内容,请使用相同的名称作为输出名称。输出名称是生成CNN网络输出的特定节点。

标签: android tensorflow bazel


【解决方案1】:

试试这个:

运行python

>>> import tensorflow as tf
>>> gf = tf.GraphDef()
>>> gf.ParseFromString(open('/your/path/to/graphname.pb','rb').read())

然后

>>> [n.name + '=>' +  n.op for n in gf.node if n.op in ( 'Softmax','Placeholder')]

然后,你可以得到类似这样的结果:

['Mul=>Placeholder', 'final_result=>Softmax']

但我不确定这是关于错误消息的节点名称的问题。 我猜您在加载图形文件时提供了错误的论点,或者您生成的图形文件有问题?

检查这部分:

E/AndroidRuntime(16821): java.lang.IllegalArgumentException: Incompatible 
shapes: [1,224,224,3] vs. [32,1,1,2048]

更新: 对不起, 如果您使用的是(重新)训练图,那么试试这个:

[n.name + '=>' +  n.op for n in gf.node if n.op in ( 'Softmax','Mul')]

似乎(重新)训练的图将输入/输出操作名称保存为“Mul”和“Softmax”,而优化和/或量化的图将它们保存为“Placeholder”和“Softmax”。

顺便说一句, 根据 Peter Warden 的帖子:https://petewarden.com/2016/09/27/tensorflow-for-mobile-poets/,不建议在移动环境中使用重新训练的图。由于性能和文件大小问题,最好使用量化或 memmapped 图,但我不知道如何在 android 中加载 memmapped 图...:( (在android中加载优化/量化图没有问题)

【讨论】:

  • 当我为我的自定义模型执行注释时:[n.name + '=>' + n.op for n in input_graph_def.node if n.op in ('Softmax','Placeholder' )],我得到 [u'tower_0/logits/predictions=>Softmax'],显示输出层名称,而输入层名称不存在。我不明白哪里出了问题。
  • @Dr.SantleCamilus ,我认为加载图形文件时出错的原因是您尝试加载未针对移动设备优化的图形。您不应该使用重新训练的输出中的 pb 文件。它在移动设备上有 Djpeg 问题。所以只需使用 optimize_for_inference 和/或 quantize_graph 转换它。两者都很好,但量化图更好。
  • 在 optimize_for_inference 或 quantize_graph 或 transform_graph 之后 [n.name + '=>' + n.op for n in gf.node if n.op in ('Softmax','Placeholder')] 的输出操作是 [u'tower_0/logits/predictions=>Softmax']。
  • 在 optimize_for_inference 或 quantize_graph [u 'tower_0/conv0/BatchNorm/moments/normalize/shifted_mean=>Mul', u'tower_0/conv0/BatchNorm/moments/normalize/Mul=>Mul', ...... u'tower_0/ mixed_8x8x2048b/branch_pool/Conv/BatchNorm/batchnorm/mul=>Mul', u'tower_0/mixed_8x8x2048b/branch_pool/Conv/BatchNorm/batchnorm/mul_1=>Mul', u'tower_0/logits/dropout/dropout/random_uniform/mul= >Mul', u'tower_0/logits/dropout/dropout/mul=>Mul', u'tower_0/logits/predictions=>Softmax']
  • 历史在这里:张量流模型是使用 inception V3 架构创建的。:github.com/tensorflow/models/tree/master/inception 模型以检查点 (ckpt) 格式(.meta、.index 和 .data)保存。将模型转换为 .pb 文件以移植到张量流相机演示 (github.com/tensorflow/tensorflow/blob/master/tensorflow/…)
【解决方案2】:

最近我直接从 tensorflow 中遇到了这个选项:

bazel build tensorflow/tools/graph_transforms:summarize_graph    
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph
--in_graph=custom_graph_name.pb

【讨论】:

    【解决方案3】:

    我写了一个简单的脚本来分析计算图(通常是 DAG,直接无环图)中的依赖关系。很明显,输入是缺少输入的节点。但是,输出可以定义为图中的任何节点,因为在最奇怪但仍然有效的情况下,输出可以是输入,而其他节点都是虚拟的。我仍然将输出操作定义为代码中没有输出的节点。您可以随意忽略它。

    import tensorflow as tf
    
    def load_graph(frozen_graph_filename):
        with tf.io.gfile.GFile(frozen_graph_filename, "rb") as f:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(f.read())
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(graph_def)
        return graph
    
    def analyze_inputs_outputs(graph):
        ops = graph.get_operations()
        outputs_set = set(ops)
        inputs = []
        for op in ops:
            if len(op.inputs) == 0 and op.type != 'Const':
                inputs.append(op)
            else:
                for input_tensor in op.inputs:
                    if input_tensor.op in outputs_set:
                        outputs_set.remove(input_tensor.op)
        outputs = list(outputs_set)
        return (inputs, outputs)
    

    【讨论】:

      猜你喜欢
      • 2018-09-05
      • 2018-06-20
      • 2018-02-28
      • 1970-01-01
      • 1970-01-01
      • 2021-12-06
      • 2011-07-30
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多