【问题标题】:Ensemble two tensorflow models集成两个张量流模型
【发布时间】:2019-01-27 01:28:58
【问题描述】:

我正在尝试从两个几乎相同的模型中创建一个模型,在不同的条件下进行训练,并在 在 tensorflow 内平均它们的输出。我们希望最终模型具有相同的推理接口。

我们已经保存了两个模型的检查点,下面是我们尝试解决问题的方法:

merged_graph = tf.Graph()
with merged_graph.as_default():
    saver1 = tf.train.import_meta_graph('path_to_checkpoint1_model1.meta', import_scope='g1')
    saver2 = tf.train.import_meta_graph('path_to_checkpoint1_model2.meta', import_scope='g2')

with tf.Session(graph=merged_graph) as sess:
  saver1.restore(sess, 'path_to_checkpoint1_model1')
  saver1.restore(sess, 'path_to_checkpoint1_model2')    

  sess.run(tf.global_variables_initializer())

  # export as a saved_model
  builder = tf.saved_model.builder.SavedModelBuilder(kPathToExportDir)
  builder.add_meta_graph_and_variables(sess,
                                       [tf.saved_model.tag_constants.SERVING],
                                       strip_default_attrs=True)    
  builder.save()

上述方法至少有3个缺陷,我们尝试了很多路线但无法让它发挥作用:

  1. model1 和 model2 的图有自己的主要操作。结果,模型在加载过程中失败,并出现以下错误: 失败的前提条件:

_

Expected exactly one main op in : model
Expected exactly one SavedModel main op. Found: [u'g1/group_deps', u'g2/group_deps']
  1. 两个模型都有自己的 Placeholder 节点用于输入(即合并后的 g1/Placeholder 和 g2/Placeholder)。我们找不到删除占位符节点的方法来创建一个新的节点来为两个模型提供输入(我们不想要一个需要将数据输入两个不同占位符的新界面)。

    李>
  2. 这两个图有自己的 init_all、restore_all 节点。我们无法弄清楚如何将这些 NoOp 操作组合到单个节点中。这与问题 #1 相同。

我们也无法在 tensorflow 中找到这种模式集成的示例实现。一个示例代码可能会回答上述所有问题。

注意:我的两个模型是使用 tf.estimator.Estimator 训练的,并导出为 saved_models。因此,它们包含 main_op。

【问题讨论】:

    标签: tensorflow deep-learning tensorflow-serving tensorflow-estimator ensemble-learning


    【解决方案1】:

    对于问题 1,save_model 不是必须的

    对于问题2,可以使用tf.train.import_meta_graph中的input_maparg

    对于问题 3,您真的不需要全部恢复或初始化所有操作

    此代码快照可以向您展示如何在 tensorflow 中组合两个图并平均它们的输出:

    import tensorflow as tf
    merged_graph = tf.Graph()
    with merged_graph.as_default():
        input = tf.placeholder(dtype=tf.float32, shape=WhatEverYourShape)
        saver1 = tf.train.import_meta_graph('path_to_checkpoint1_model1.meta', import_scope='g1',
                                            input_map={"YOUR/INPUT/NAME": input})
        saver2 = tf.train.import_meta_graph('path_to_checkpoint1_model2.meta', import_scope='g2',
                                            input_map={"YOUR/INPUT/NAME": input})
    
        output1 = merged_graph.get_tensor_by_name("g1/YOUR/OUTPUT/TENSOR/NAME")
        output2 = merged_graph.get_tensor_by_name("g2/YOUR/OUTPUT/TENSOR/NAME")
        final_output = (output1 + output2) / 2
    
    with tf.Session(graph=merged_graph) as sess:
        saver1.restore(sess, 'path_to_checkpoint1_model1')
        saver1.restore(sess, 'path_to_checkpoint1_model2')
        # this line should NOT run because it will initialize all variables, your restore op will have no effect
        # sess.run(tf.global_variables_initializer())
        fianl_output_numpy = sess.run(final_output, feed_dict={input: YOUR_NUMPY_INPUT})
    

    【讨论】:

    • 这解决了一些问题,但是对于 g1 和 g2 仍然有多个 init、save、...节点。应该有一种方法可以将它们合并到一个节点下!合并模型失败并出现以下错误:Session status: Failed precondition: Expected exactly one main op in : merged_model
    • 我使用构建器builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING], strip_default_attrs=True)保存了模型,但模型加载失败,同样的错误Expected exactly one SavedModel main op. Found: [u'g1/group_deps', u'g2/group_deps']
    【解决方案2】:

    我没有解决,但找到了解决上述问题的方法。

    主要问题是,每当使用 saved_model API 导出模型时,都会添加main_op node。由于我的两个模型都是使用此 API 导出的,因此它们都有 ma​​in_op 节点,该节点将被导入到新图表中。然后,新图将包含两个 ma​​in_ops,它们稍后将无法加载,因为 正是预期的一个主操作

    我选择使用的解决方法不是使用 saved_model API 导出我的最终模型,而是使用方便的旧 freeze_graph 导出到单个 .pb 文件中。

    这是我的工作代码 sn-p:

    # set some constants:
    #   INPUT_SHAPE, OUTPUT_NODE_NAME, OUTPUT_FILE_NAME, 
    #   TEMP_DIR, TEMP_NAME, SCOPE_PREPEND_NAME, EXPORT_DIR
    
    # Set path for trained models which are exported with the saved_model API
    input_model_paths = [PATH_TO_MODEL1, 
                         PATH_TO_MODEL2, 
                         PATH_TO_MODEL3, ...]
    num_model = len(input_model_paths)
    
    def load_model(sess, path, scope, input_node):
        tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], 
                                   path,
                                   import_scope=scope, 
                                   input_map={"Placeholder": input_node})  
        output_tensor = tf.get_default_graph().get_tensor_by_name(
            scope + "/" + OUTPUT_NODE_NAME + ":0")
        return output_tensor  
    
    with tf.Session(graph=tf.Graph()) as sess:
      new_input = tf.placeholder(dtype=tf.float32, 
                                 shape=INPUT_SHAPE, name="Placeholder")      
    
      output_tensors = []
      for k, path in enumerate(input_model_paths):
        output_tensors.append(load_model(sess, 
                                         path, 
                                         SCOPE_PREPEND_NAME+str(k), 
                                         new_input))
      # Mix together the outputs (e.g. sum, weighted sum, etc.)
      sum_outputs = output_tensors[0] + output_tensors[1]
      for i in range(2, num_model):
        sum_outputs = sum_outputs + output_tensors[i]
      final_output = tf.divide(sum_outputs, float(num_model), name=OUTPUT_NODE_NAME)
    
      # Save checkpoint to be loaded later by the freeze_graph!
      saver_checkpoint = tf.train.Saver()
      saver_checkpoint.save(sess, os.path.join(TEMP_DIR, TEMP_NAME))
    
      tf.train.write_graph(sess.graph_def, TEMP_DIR, TEMP_NAME + ".pbtxt")
      freeze_graph.freeze_graph(
          os.path.join(TEMP_DIR, TEMP_NAME + ".pbtxt"), 
          "", 
          False, 
          os.path.join(TEMP_DIR, TEMP_NAME),  
          OUTPUT_NODE_NAME, 
          "", # deprecated
          "", # deprecated
          os.path.join(EXPORT_DIR, OUTPUT_FILE_NAME),
          False,
          "")
    

    【讨论】:

      猜你喜欢
      • 2016-11-27
      • 1970-01-01
      • 2021-01-30
      • 2017-08-07
      • 1970-01-01
      • 2018-11-25
      • 2019-06-18
      • 2017-11-22
      • 1970-01-01
      相关资源
      最近更新 更多