【问题标题】:How to use feedable iterator from Tensorflow Dataset API along with MonitoredTrainingSession?如何使用来自 Tensorflow Dataset API 的 feedable 迭代器和 MonitoredTrainingSession?
【发布时间】:2018-02-17 01:27:55
【问题描述】:

Tensorflow programmer's guide 建议使用 feedable 迭代器在训练和验证数据集之间切换,而无需重新初始化迭代器。主要是需要进给手柄来选择。

如何与tf.train.MonitoredTrainingSession一起使用?

以下方法失败并显示“RuntimeError: Graph is finalized and cannot be modified。”错误。

with tf.train.MonitoredTrainingSession() as sess:
    training_handle = sess.run(training_iterator.string_handle())
    validation_handle = sess.run(validation_iterator.string_handle())

如何同时实现 MonitoredTrainingSession 的便利性和迭代训练和验证数据集?

【问题讨论】:

    标签: tensorflow tensorflow-datasets


    【解决方案1】:

    我从 Tensorflow GitHub 问题中得到了答案 - https://github.com/tensorflow/tensorflow/issues/12859

    解决方案是在创建MonitoredSession 之前调用iterator.string_handle()

    import tensorflow as tf
    from tensorflow.contrib.data import Dataset, Iterator
    
    dataset_train = Dataset.range(10)
    dataset_val = Dataset.range(90, 100)
    
    iter_train_handle = dataset_train.make_one_shot_iterator().string_handle()
    iter_val_handle = dataset_val.make_one_shot_iterator().string_handle()
    
    handle = tf.placeholder(tf.string, shape=[])
    iterator = Iterator.from_string_handle(
        handle, dataset_train.output_types, dataset_train.output_shapes)
    next_batch = iterator.get_next()
    
    with tf.train.MonitoredTrainingSession() as sess:
        handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle])
    
        for step in range(10):
            print('train', sess.run(next_batch, feed_dict={handle: handle_train}))
    
            if step % 3 == 0:
                print('val', sess.run(next_batch, feed_dict={handle: handle_val}))
    
    Output:
    ('train', 0)
    ('val', 90)
    ('train', 1)
    ('train', 2)
    ('val', 91)
    ('train', 3)
    

    【讨论】:

      【解决方案2】:

      @Michael Jaison G 的答案是正确的。但是,当您还想使用某些需要评估图形部分的 session_run_hooks 时,它不起作用,例如LoggingTensorHook 或 SummarySaverHook。 下面的例子会报错:

      import tensorflow as tf
      
      dataset_train = tf.data.Dataset.range(10)
      dataset_val = tf.data.Dataset.range(90, 100)
      
      iter_train_handle = dataset_train.make_one_shot_iterator().string_handle()
      iter_val_handle = dataset_val.make_one_shot_iterator().string_handle()
      
      handle = tf.placeholder(tf.string, shape=[])
      iterator = tf.data.Iterator.from_string_handle(
          handle, dataset_train.output_types, dataset_train.output_shapes)
      feature = iterator.get_next()
      
      pred = feature * feature
      tf.summary.scalar('pred', pred)
      global_step = tf.train.create_global_step()
      
      summary_hook = tf.train.SummarySaverHook(save_steps=5,
                                               output_dir="summaries", summary_op=tf.summary.merge_all())
      
      with tf.train.MonitoredTrainingSession(hooks=[summary_hook]) as sess: 
          handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle])
      
          for step in range(10):
              feat = sess.run(feature, feed_dict={handle: handle_train})
              pred_ = sess.run(pred, feed_dict={handle: handle_train})
              print('train: ', feat)
              print('pred: ', pred_)
      
              if step % 3 == 0:
                  print('val', sess.run(feature, feed_dict={handle: handle_val}))
      

      这将失败并出现错误:

      InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder' with dtype string
           [[Node: Placeholder = Placeholder[dtype=DT_STRING, shape=[], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
           [[Node: cond/Switch_1/_15 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_18_cond/Switch_1", tensor_type=DT_INT64, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
      

      原因是钩子会在第一个 session.run([iter_train_handle, iter_val_handle]) 时尝试评估图形,显然在 feed_dict 中还没有句柄。

      解决方法是覆盖导致问题的钩子,并将 before_run 和 after_run 中的代码更改为仅对包含 feed_dict 中的句柄的 session.run 调用进行评估(您可以访问当前 session.run 调用的 feed_dict通过 before_run 和 after_run 的 run_context 参数)

      或者您可以使用最新的 Tensorflow 大师(post-1.4),它向 MonitoredSession 添加了一个 run_step_fn 函数,它允许您指定以下 step_fn 以避免错误(以评估 if 语句 TrainingIteration 次数为代价...)

      def step_fn(step_context):
        if handle_train is None:
          handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle])
        return step_context.run_with_hooks(fetches=..., feed_dict=...)
      

      【讨论】:

        【解决方案3】:

        有一个使用 SessionRunHook 在 mot_session 中使用占位符的演示。 这个演示是关于通过输入 diff handle_string 来切换数据集。

        顺便说一句,我已经尝试了所有解决方案,但只有这个有效。

        dataset_switching

        【讨论】:

        • 这个链接使用make_one_shot_iterator,其他人使用不同的迭代器类型,这也是我唯一可以开始工作的。如果你被卡住了,这个链接可能会非常有帮助。谢谢分享!
        猜你喜欢
        • 2018-04-14
        • 1970-01-01
        • 1970-01-01
        • 2019-09-24
        • 1970-01-01
        • 2018-02-18
        • 2018-08-11
        • 1970-01-01
        • 2019-11-15
        相关资源
        最近更新 更多