【问题标题】:Tensorflow (tf-slim) Model with is_training True and False带有 is_training True 和 False 的 Tensorflow (tf-slim) 模型
【发布时间】:2017-01-14 04:14:33
【问题描述】:

我想在训练集 (is_training=True) 和验证集 (is_training=False) 上运行给定模型,特别是如何应用 dropout。现在prebuilt models 公开了一个参数is_training,它在构建网络时传递给dropout 层。问题是如果我用不同的is_training 值调用该方法两次,我将得到两个不共享权重的不同网络(我认为?)。如何让两个网络共享相同的权重,以便我可以运行我在验证集上训练过的网络?

【问题讨论】:

  • 我认为默认行为是在两种情况之间共享权重,因此您无需做任何事情。 tf-slim 使用 tf.get_variable() 在调用之间重用变量。
  • 好的,我认为这很有效。您需要确保设置了scope,然后为了安全起见最好也设置reuse=True

标签: tensorflow tf-slim


【解决方案1】:

我用评论写了一个解决方案,以在火车和测试模式下使用过度污染。 (我无法测试它,以便您可以检查它是否有效?)

首先导入和参数:

import tensorflow as tf
slim = tf.contrib.slim
overfeat = tf.contrib.slim.nets.overfeat

batch_size = 32
inputs = tf.placeholder(tf.float32, [batch_size, 231, 231, 3])
dropout_keep_prob = 0.5
num_classes = 1000

在火车模式中,我们通过正常范围到函数overfeat

scope = 'overfeat'
is_training = True

output = overfeat.overfeat(inputs, num_classes, is_training,         
                           dropout_keep_prob, scope=scope)

然后在测试模式下,我们创建相同的范围,但使用reuse=True

scope = tf.VariableScope(reuse=True, name='overfeat')
is_training = False

output = overfeat.overfeat(inputs, num_classes, is_training,         
                           dropout_keep_prob, scope=scope)

【讨论】:

    【解决方案2】:

    您可以只为 is_training 使用占位符:

    isTraining = tf.placeholder(tf.bool)
    
    # create nn
    net = ...
    net = slim.dropout(net,
                       keep_prob=0.5,
                       is_training=isTraining)
    net = ...
    
    # training
    sess.run([net], feed_dict={isTraining: True})
    
    # testing
    sess.run([net], feed_dict={isTraining: False})
    

    【讨论】:

    • 我试过这个并遇到了问题,因为变量没有被重用。我还遇到了我无法解释的内存限制。
    【解决方案3】:

    视情况而定,解决方案不同。

    我的第一个选择是使用不同的流程进行评估。您只需要检查是否有一个新的检查点并将该权重加载到评估网络中(使用is_training=False):

    checkpoint = tf.train.latest_checkpoint(self.checkpoints_path)
    # wait until a new check point is available
    while self.lastest_checkpoint == checkpoint:
        time.sleep(30)  # sleep 30 seconds waiting for a new checkpoint
        checkpoint = tf.train.latest_checkpoint(self.checkpoints_path)
    logging.info('Restoring model from {}'.format(checkpoint))
    self.saver.restore(session, checkpoint)
    self.lastest_checkpoint = checkpoint
    

    第二个选项是在每个 epoch 之后卸载图表并创建一个新的评估图表。此解决方案浪费了大量时间加载和卸载图表。

    第三种选择是共享权重。但是为这些网络提供队列或数据集可能会导致问题,因此您必须非常小心。我只将它用于连体网络。

    with tf.variable_scope('the_scope') as scope:
        your_model(is_training=True)
        scope.reuse_variables()
        your_model(is_training=False)
    

    【讨论】:

      猜你喜欢
      • 2018-03-11
      • 2017-11-06
      • 2018-07-08
      • 1970-01-01
      • 2018-10-04
      • 1970-01-01
      • 2017-01-22
      • 2017-08-03
      • 1970-01-01
      相关资源
      最近更新 更多