【发布时间】:2018-04-18 04:19:22
【问题描述】:
我正在使用 tensorflow 数据集作为输入数据管道。我想知道如何在第一个时期不进行数据洗牌的情况下进行训练,并从第二个时期开始洗牌数据。
图表通常是在迭代训练开始之前构建的,并且在训练期间,如何更改 DataSet 改组行为似乎并不直接,因为在我看来这有点像更改图表。
有什么想法吗?
谢谢, 哈利
【问题讨论】:
标签: tensorflow tensorflow-datasets
我正在使用 tensorflow 数据集作为输入数据管道。我想知道如何在第一个时期不进行数据洗牌的情况下进行训练,并从第二个时期开始洗牌数据。
图表通常是在迭代训练开始之前构建的,并且在训练期间,如何更改 DataSet 改组行为似乎并不直接,因为在我看来这有点像更改图表。
有什么想法吗?
谢谢, 哈利
【问题讨论】:
标签: tensorflow tensorflow-datasets
Dataset.shuffle() 的 buffer_size 参数可以是计算得出的 tf.Tensor,因此您可以使用以下代码使用 Dataset.range(NUM_EPOCHS).flat_map(...) 将纪元数序列转换为 @ 的(混洗或其他)元素987654326@:
NUM_EPOCHS = ... # The total number of epochs.
BUFFER_SIZE = ... # The shuffle buffer size to use from the second epoch on.
per_epoch_dataset = ... # A `Dataset` representing the elements of a single epoch.
def shuffle_after_first_epoch(epoch):
# Set `epoch_buffer_size` to 1 (i.e. no shuffling) in the 0th epoch,
# and `BUFFER_SIZE` thereafter.
epoch_buffer_size = tf.cond(tf.equal(epoch, 0),
lambda: tf.constant(1, tf.int64),
lambda: tf.constant(BUFFER_SIZE, tf.int64))
return per_epoch_dataset.shuffle(epoch_buffer_size)
dataset = tf.data.Dataset.range(NUM_EPOCHS).flat_map(shuffle_after_first_epoch)
【讨论】:
flat_map() 就像一个嵌套的 for 循环,它围绕在 per_epoch_dataset 上的一个循环,它产生每个元素。
.batch() 直接放在flat_map() 之前(即执行Dataset.range(...).batch(...).flat_map(...))会批量处理纪元ID,我怀疑这不是您想要做的。放置.batch() 的最合适位置是return per_epoch_dataset.shuffle(epoch_buffer_size).batch(...),然后您将在第一个时期获得非混洗批次,并在所有后续时期获得混洗批次。