【问题标题】:How to grab the input and output tensors after load a tf.saved_model in python在 python 中加载 tf.saved_model 后如何获取输入和输出张量
【发布时间】:2018-09-18 04:11:34
【问题描述】:

假设我使用以下代码保存了一个模型 tf.saved_model.simple_save(sess, export_dir, in={'input_x': x, 'input_y':y}, out={'output_z':z})

现在我将保存的模型加载到另一个 python 程序中 with tf.Session() as sess: tf.saved_model.loader.load(sess, ['serve'], export_dir)

现在的问题是如何通过调用 simple_save() 方法时在输入/输出参数中指定的 'input_x'、'input_y'、'output_z' 键获取 x、y、z 张量的句柄?

我在网上找到的唯一解决方案是在创建 x、y、z 张量时显式命名它们,然后使用这些名称从图中检索它们,这似乎是相当多余的,因为我们在调用时为它们指定了键simple_save()。

【问题讨论】:

    标签: tensorflow tensorflow-serving


    【解决方案1】:

    我确实遇到了您的问题,经过一些调查(我认为 TF 文档很差),我找到了下一个解决方案:

    使用返回的 MetaGraphDef 对象查找您的输入\输出名称映射。

            graph = tf.Graph()
        with graph.as_default():
            metagraph = tf.saved_model.loader.load(sess, [tag_constants.SERVING],save_path)
    
        inputs_mapping = dict(metagraph.signature_def['serving_default'].inputs)
        outputs_mapping = dict(metagraph.signature_def['serving_default'].outputs)
    

    此代码将为您提供保存到“TensorInfo”对象时提供的名称之间的映射,您可以从他那里轻松获取映射的张量名称,例如:

        my_input = inputs_mapping['my_input_name'].name
        my_input_t = graph.get_tensor_by_name(my_input)
    

    【讨论】:

      【解决方案2】:

      tf.saved_model.loader.load 的返回值是一个MetaGraphDef 协议缓冲区,它应该具有您在保存保存的模型时设置的所有签名;那些应该包含你想要的名字。

      【讨论】:

      • 谢谢,但我不确定您的回答是否解决了我的问题。我在问如何通过我放在签名中的键来处理张量。如果你能放一些代码,可能会更清楚。
      猜你喜欢
      • 2019-04-08
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2015-08-18
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2018-12-13
      相关资源
      最近更新 更多