【问题标题】:How can I avoid repopulating my tf.data shuffle buffer on each tf.keras epoch when using tf.distribute?使用 tf.distribute 时,如何避免在每个 tf.keras 时期重新填充我的 tf.data 洗牌缓冲区?
【发布时间】:2019-01-17 22:09:12
【问题描述】:

我正在使用 (tf-nightly-gpu==1.13.0.dev20190116) 做一个 Keras 模型:

with tf.distribute.MirroredStrategy().scope():
    model = tf.keras.Model(...)

和一个数据集:

dataset = (tf.data.Dataset
    .list_files(...)
    .map(load_example)
    .cache()
    .shuffle(100)
    .repeat())

然后训练

model.fit(dataset, epochs=10, steps_per_epoch=1000)

效果很好,因为它会在我的单机多 GPU 设置上自动拆分我的小批量。很酷!

但是,我的 shuffle 缓冲区会在每个 epoch 重新填充。有没有一种方法可以让洗牌缓冲区保持在多个时期?我尝试使用迭代器和张量直接调用 model.fit,但 tf.distribute 不支持(还没有?)而是引发异常。

TL;DR:我如何确保我的 tf.data shuffle 缓冲区跨时期得到维护?

【问题讨论】:

  • shuffle 有一个 reshuffle_each_iteration 参数,您可以将其设置为 False。这会有帮助吗?
  • 哦,是的!当然!完全忽略了这个论点。谢谢!提供答案,我会将其标记为已接受的答案。

标签: python tensorflow keras


【解决方案1】:

shuffle 有一个参数 reshuffle_each_iteration,您可以将其设置为 False,以便仅在第一个 epoch 发生洗牌,并在未来的 epoch 中保持状态。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2017-01-05
    • 2019-08-31
    • 1970-01-01
    • 2012-02-04
    • 1970-01-01
    • 1970-01-01
    • 2021-01-29
    相关资源
    最近更新 更多