【问题标题】:Output differences when changing order of batch(), shuffle() and repeat()更改 batch()、shuffle() 和 repeat() 顺序时的输出差异
【发布时间】:2018-09-29 15:40:25
【问题描述】:

我创建了一个 tensorflow 数据集,使其可重复,对其进行洗牌,将其分成批次,并构建了一个迭代器来获取下一批。但是当我这样做时,有时元素是重复的(在批次内和批次之间),尤其是对于小型数据集。为什么?

【问题讨论】:

    标签: tensorflow tensorflow-datasets


    【解决方案1】:

    例如,如果您想要与 Keras 的 .fit() 函数相同的行为,您可以使用:

    dataset = dataset.shuffle(10000, reshuffle_each_iteration=True)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.repeat(EPOCHS)
    

    这将以与.fit(epochs=EPOCHS, batch_size=BATCH_SIZE, shuffle=True) 相同的方式遍历数据集。一个简单的示例(仅出于可读性而启用渴望执行,在图形模式下的行为相同):

    import numpy as np
    import tensorflow as tf
    tf.enable_eager_execution()
    
    NUM_SAMPLES = 7
    BATCH_SIZE = 3
    EPOCHS = 2
    
    # Create the dataset
    x = np.array([[2 * i, 2 * i + 1] for i in range(NUM_SAMPLES)])
    dataset = tf.data.Dataset.from_tensor_slices(x)
    
    # Shuffle, batch and repeat the dataset
    dataset = dataset.shuffle(10000, reshuffle_each_iteration=True)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.repeat(EPOCHS)
    
    # Iterate through the dataset
    iterator = dataset.make_one_shot_iterator()
    for batch in dataset:
        print(batch.numpy(), end='\n\n')
    

    打印

    [[ 8  9]
     [12 13]
     [10 11]]
    
    [[0 1]
     [2 3]
     [4 5]]
    
    [[6 7]]
    
    [[ 4  5]
     [10 11]
     [12 13]]
    
    [[6 7]
     [0 1]
     [2 3]]
    
    [[8 9]]
    

    您可以看到,即使.batch() 被调用之后 .shuffle(),批次在每个时期仍然不同。这就是为什么我们需要使用reshuffle_each_iteration=True。如果我们不在每次迭代中重新洗牌,我们将在每个 epoch 中获得相同的批次:

    [[12 13]
     [ 4  5]
     [10 11]]
    
    [[6 7]
     [8 9]
     [0 1]]
    
    [[2 3]]
    
    [[12 13]
     [ 4  5]
     [10 11]]
    
    [[6 7]
     [8 9]
     [0 1]]
    
    [[2 3]]
    

    这在对小型数据集进行训练时可能是有害的。

    【讨论】:

      【解决方案2】:

      与您自己的回答中所说的不同,不,随机播放然后重复不会解决您的问题

      问题的关键来源是您批处理,然后随机/重复。这样,您的批次中的项目将始终取自输入数据集中的连续样本。 批处理应该是您在输入管道中执行的最后一项操作

      稍微扩展问题。

      现在,随机播放、重复和批处理的顺序有所不同,但这不是你想的那样。引用自input pipeline performance guide

      如果在随机播放之前应用了重复转换 转换,那么时代边界就模糊了。那是, 某些元素甚至可以在其他元素出现之前重复 一次。另一方面,如果应用了混洗变换 在重复转换之前,性能可能会减慢 每个时期的开始与内部的初始化有关 洗牌转换的状态。换言之,前者 (repeat before shuffle) 提供更好的性能,而后者 (重复前洗牌)提供更强的排序保证。

      回顾

      • 重复,然后随机播放:您失去了在一个 epoch 中处理所有样本的保证。
      • 随机播放,然后重复:保证在下一次重复开始之前将处理所有样本,但您会(略微)损失性能。

      无论您选择哪种方式,都请在之前进行批处理。

      【讨论】:

      • 注意,shuffle时只要使用reshuffle_each_iteration=True就可以使用顺序[shuffle,batch,repeat]。
      【解决方案3】:

      你必须先洗牌,然后重复!

      如以下两个代码所示,洗牌和重复的顺序很重要。

      最差排序:

      import tensorflow as tf
      
      ds = tf.data.Dataset.range(10)
      ds = ds.batch(2)
      ds = ds.repeat()
      ds = ds.shuffle(100000)
      iterator   = ds.make_one_shot_iterator()
      next_batch = iterator.get_next()
      
      with tf.Session() as sess:
          for i in range(15):
              if i % (10//2) == 0:
                  print("------------")
              print("{:02d}:".format(i), next_batch.eval())
      

      输出:

      ------------
      00: [6 7]
      01: [2 3]
      02: [6 7]
      03: [0 1]
      04: [8 9]
      ------------
      05: [6 7]
      06: [4 5]
      07: [6 7]
      08: [4 5]
      09: [0 1]
      ------------
      10: [2 3]
      11: [0 1]
      12: [0 1]
      13: [2 3]
      14: [4 5]
      

      错误排序:

      import tensorflow as tf
      
      ds = tf.data.Dataset.range(10)
      ds = ds.batch(2)
      ds = ds.shuffle(100000)
      ds = ds.repeat()
      iterator   = ds.make_one_shot_iterator()
      next_batch = iterator.get_next()
      
      with tf.Session() as sess:
          for i in range(15):
              if i % (10//2) == 0:
                  print("------------")
              print("{:02d}:".format(i), next_batch.eval())
      

      输出:

      ------------
      00: [4 5]
      01: [6 7]
      02: [8 9]
      03: [0 1]
      04: [2 3]
      ------------
      05: [0 1]
      06: [4 5]
      07: [8 9]
      08: [2 3]
      09: [6 7]
      ------------
      10: [0 1]
      11: [4 5]
      12: [8 9]
      13: [2 3]
      14: [6 7]
      

      最佳排序:

      受 GPhilo 回答的启发,批处理的顺序也很重要。对于每个时期的批次不同,必须先洗牌,然后重复,最后是批次。从输出中可以看出,所有批次都是唯一的,与其他批次不同。

      import tensorflow as tf
      
      ds = tf.data.Dataset.range(10)
      
      ds = ds.shuffle(100000)
      ds = ds.repeat()
      ds = ds.batch(2)
      
      iterator   = ds.make_one_shot_iterator()
      next_batch = iterator.get_next()
      
      with tf.Session() as sess:
          for i in range(15):
              if i % (10//2) == 0:
                  print("------------")
              print("{:02d}:".format(i), next_batch.eval())
      

      输出:

      ------------
      00: [2 5]
      01: [1 8]
      02: [9 6]
      03: [3 4]
      04: [7 0]
      ------------
      05: [4 3]
      06: [0 2]
      07: [1 9]
      08: [6 5]
      09: [8 7]
      ------------
      10: [7 3]
      11: [5 9]
      12: [4 1]
      13: [8 6]
      14: [0 2]
      

      【讨论】:

      • ..你确实意识到“错误”和“正确”命令的输出是一样的,对吧?
      • @GPhilo,对不起,复制了错误的代码。感谢您的指正。现在已经修复了。
      • 查看您的新输出,我认为您完全符合他们在我引用的性能指南中所做的区分。您的时期没有错,但是由于您从中获取样本的缓冲区在重复发生之前被随机打乱,因此您最终可能会在耗尽前一次迭代的所有样本之前两次选择相同的样本(因为缓冲区永远不会完全清空)。请注意,虽然这对于小型数据集更为明显,但实际上在大型数据集的训练过程中这几乎不是问题
      • @GPhilo,根据您的回答(谢谢!),我添加了一个新订单。无论数据集有多大或多小,都应该认真避免我在前两个中所做的事情,因为批次保持相同。
      • 是的,批处理总是在洗牌之后进行。理想情况下,批处理是您在管道中执行的最后一个操作(可能后跟prefetch 以获得更好的性能),因为它与样本的预处理没有真正的关系,它只是将东西放在一起以同时提供更多样本网络时间。
      猜你喜欢
      • 2019-11-18
      • 2020-09-04
      • 2020-03-05
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多