【发布时间】:2019-01-17 22:09:12
【问题描述】:
我正在使用 (tf-nightly-gpu==1.13.0.dev20190116) 做一个 Keras 模型:
with tf.distribute.MirroredStrategy().scope():
model = tf.keras.Model(...)
和一个数据集:
dataset = (tf.data.Dataset
.list_files(...)
.map(load_example)
.cache()
.shuffle(100)
.repeat())
然后训练
model.fit(dataset, epochs=10, steps_per_epoch=1000)
效果很好,因为它会在我的单机多 GPU 设置上自动拆分我的小批量。很酷!
但是,我的 shuffle 缓冲区会在每个 epoch 重新填充。有没有一种方法可以让洗牌缓冲区保持在多个时期?我尝试使用迭代器和张量直接调用 model.fit,但 tf.distribute 不支持(还没有?)而是引发异常。
TL;DR:我如何确保我的 tf.data shuffle 缓冲区跨时期得到维护?
【问题讨论】:
-
shuffle 有一个 reshuffle_each_iteration 参数,您可以将其设置为 False。这会有帮助吗?
-
哦,是的!当然!完全忽略了这个论点。谢谢!提供答案,我会将其标记为已接受的答案。
标签: python tensorflow keras