【问题标题】:TensorFlow Dataset Shuffle Each EpochTensorFlow 数据集 Shuffle 每个 Epoch
【发布时间】:2017-10-22 18:58:45
【问题描述】:

在Tensorflow 中Dataset 类的manual 中,它显示了如何对数据进行混洗以及如何对其进行批处理。但是,如何在每个 epoch 中对数据进行洗牌并不明显。我已经尝试过以下方法,但是第二个时期的数据顺序与第一个时期完全相同。有人知道如何使用数据集在不同时期之间进行洗牌吗?

n_epochs = 2
batch_size = 3

data = tf.contrib.data.Dataset.range(12)

data = data.repeat(n_epochs)
data = data.batch(batch_size)
next_batch = data.make_one_shot_iterator().get_next()

sess = tf.Session()
for _ in range(4):
    print(sess.run(next_batch))

print("new epoch")
data = data.shuffle(12)
for _ in range(4):
    print(sess.run(next_batch))

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    在我看来,您在这两种情况下都使用了相同的next_batch。因此,根据您真正想要的,您可能需要在第二次调用sess.run 之前重新创建next_batch,如下所示,否则data = data.shuffle(12) 对您之前创建的next_batch 没有任何影响代码。

    n_epochs = 2
    batch_size = 3
    
    data = tf.contrib.data.Dataset.range(12)
    
    data = data.repeat(n_epochs)
    data = data.batch(batch_size)
    next_batch = data.make_one_shot_iterator().get_next()
    
    sess = tf.Session()
    for _ in range(4):
        print(sess.run(next_batch))
    
    print("new epoch")
    data = data.shuffle(12)
    
    """See how I recreate next_batch after the data has been shuffled"""
    next_batch = data.make_one_shot_iterator().get_next()
    for _ in range(4):
        print(sess.run(next_batch))
    

    请让我知道这是否有帮助。谢谢。

    【讨论】:

    • 您如何将其与您使用此数据的图形的定义结合起来?在我看来,您的解决方案似乎每次重新创建数据集时都必须重新定义图表,因为get_next() 返回的变量将用作图表的输入。
    【解决方案2】:

    我的环境:Python 3.6、TensorFlow 1.4。

    TensorFlow 已将 Dataset 添加到 tf.data 中。

    data.shuffle的位置要谨慎。在您的代码中,数据的时代已在您的shuffle 之前放入dataset 的缓冲区中。以下是洗牌数据集的两个可用示例。

    随机播放所有元素

    # shuffle all elements
    import tensorflow as tf
    
    n_epochs = 2
    batch_size = 3
    buffer_size = 5
    
    dataset = tf.data.Dataset.range(12)
    dataset = dataset.shuffle(buffer_size=buffer_size)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(n_epochs)
    iterator = dataset.make_one_shot_iterator()
    next_batch = iterator.get_next()
    
    sess = tf.Session()
    print("epoch 1")
    for _ in range(4):
        print(sess.run(next_batch))
    print("epoch 2")
    for _ in range(4):
        print(sess.run(next_batch))
    

    输出:

    epoch 1
    [1 4 5]
    [3 0 7]
    [6 9 8]
    [10  2 11]
    epoch 2
    [2 0 6]
    [1 7 4]
    [5 3 8]
    [11  9 10]
    

    批次之间随机播放,而不是批次随机播放

    # shuffle between batches, not shuffle in a batch
    import tensorflow as tf
    
    n_epochs = 2
    batch_size = 3
    buffer_size = 5
    
    dataset = tf.data.Dataset.range(12)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(n_epochs)
    dataset = dataset.shuffle(buffer_size=buffer_size)
    iterator = dataset.make_one_shot_iterator()
    next_batch = iterator.get_next()
    
    sess = tf.Session()
    print("epoch 1")
    for _ in range(4):
        print(sess.run(next_batch))
    print("epoch 2")
    for _ in range(4):
        print(sess.run(next_batch))
    

    输出:

    epoch 1
    [0 1 2]
    [6 7 8]
    [3 4 5]
    [6 7 8]
    epoch 2
    [3 4 5]
    [0 1 2]
    [ 9 10 11]
    [ 9 10 11]
    

    【讨论】:

    • 批次之间的混洗是指批次订单要混洗,而不是批次内的元素?
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2021-07-06
    • 2018-10-30
    • 2019-11-18
    • 2018-08-24
    • 1970-01-01
    • 2018-12-15
    • 1970-01-01
    相关资源
    最近更新 更多