【问题标题】:How to iterate a dataset several times using TensorFlow's Dataset API?如何使用 TensorFlow 的 Dataset API 多次迭代数据集?
【发布时间】:2018-04-14 12:15:53
【问题描述】:

如何多次输出数据集中的值? (数据集由TensorFlow的Dataset API创建)

import tensorflow as tf

dataset = tf.contrib.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()
epoch = 10

for i in range(epoch):
   for j in range(100):
      value = sess.run(next_element)
      assert j == value
      print(j)

错误信息:

tensorflow.python.framework.errors_impl.OutOfRangeError: End of sequence
 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[]], output_types=[DT_INT64], _device="/job:localhost/replica:0/task:0/cpu:0"](OneShotIterator)]]

如何做到这一点?

【问题讨论】:

    标签: tensorflow tensorflow-datasets


    【解决方案1】:

    首先我建议你阅读Data Set Guide。描述了DataSet API的所有细节。

    您的问题是关于多次迭代数据。这里有两种解决方案:

    1. 一次迭代所有 epoch,没有关于单个 epoch 结束的信息
    import tensorflow as tf
    
    epoch   = 10
    dataset = tf.data.Dataset.range(100)
    dataset = dataset.repeat(epoch)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    sess = tf.Session()
    
    num_batch = 0
    j = 0
    while True:
        try:
            value = sess.run(next_element)
            assert j == value
            j += 1
            num_batch += 1
            if j > 99: # new epoch
                j = 0
        except tf.errors.OutOfRangeError:
            break
    
    print ("Num Batch: ", num_batch)
    
    1. 第二个选项通知你关于结束每个纪元,所以你可以前。检查验证损失:
    import tensorflow as tf
    
    epoch = 10
    dataset = tf.data.Dataset.range(100)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()
    sess = tf.Session()
    
    num_batch = 0
    
    for e in range(epoch):
        print ("Epoch: ", e)
        j = 0
        sess.run(iterator.initializer)
        while True:
            try:
                value = sess.run(next_element)
                assert j == value
                j += 1
                num_batch += 1
            except tf.errors.OutOfRangeError:
                break
    
    print ("Num Batch: ", num_batch)
    

    【讨论】:

      【解决方案2】:

      如果你的 tensorflow 版本是 1.3+,我推荐高级 API tf.train.MonitoredTrainingSession。这个API创建的sess可以自动检测tf.errors.OutOfRangeErrorsess.should_stop()。对于大多数训练情况,您需要打乱数据并在每一步获取一个批次,我在以下代码中添加了这些。

      import tensorflow as tf
      
      epoch = 10
      dataset = tf.data.Dataset.range(100)
      dataset = dataset.shuffle(buffer_size=100) # comment this line if you don't want to shuffle data
      dataset = dataset.batch(batch_size=32)     # batch_size=1 if you want to get only one element per step
      dataset = dataset.repeat(epoch)
      iterator = dataset.make_one_shot_iterator()
      next_element = iterator.get_next()
      
      num_batch = 0
      with tf.train.MonitoredTrainingSession() as sess:
          while not sess.should_stop():
              value = sess.run(next_element)
              num_batch += 1
              print("Num Batch: ", num_batch)
      

      【讨论】:

        【解决方案3】:

        试试这个

        while True:
          try:
            print(sess.run(value))
          except tf.errors.OutOfRangeError:
            break
        

        当数据集迭代器到达数据末尾时,它会引发 tf.errors.OutOfRangeError,你可以用 except 捕获它并从头开始数据集。

        【讨论】:

        • 你应该解释你的代码或者包括 cmets
        【解决方案4】:

        类似于 Toms 的回答,对于 tensorflow 2+,您可以使用以下高级 API 调用(他的回答中提出的代码在 tensorflow 2+ 中已弃用):

        epoch = 10
        batch_size = 32
        dataset = tf.data.Dataset.range(100) 
        
        dataset = dataset.shuffle(buffer_size=100) # comment this line if you don't want to shuffle data
        dataset = dataset.batch(batch_size=batch_size)
        dataset = dataset.repeat(epoch)
        
        num_batch = 0
        for batch in dataset:
                num_batch += 1
                print("Num Batch: ", num_batch)
        

        跟踪进度的一个有用调用是将迭代的批次总数(在batchrepeat 调用之后使用):

        num_batches = tf.data.experimental.cardinality(dataset)
        

        请注意,目前(tensorflow 2.1),cardinality 方法仍处于试验阶段。

        【讨论】:

          猜你喜欢
          • 2018-02-17
          • 2020-10-03
          • 2020-02-21
          • 2018-04-08
          • 2018-08-27
          • 2018-02-07
          • 1970-01-01
          • 2018-02-23
          • 1970-01-01
          相关资源
          最近更新 更多