【发布时间】:2020-02-13 17:57:07
【问题描述】:
我的输入管道在 CPU、GPU 和磁盘利用率较低的情况下性能不佳。我一直在阅读 tensorflow "Better performance with tf.data API" 文档和 Dataset 文档,但我不明白发生了什么足以将其应用于我的情况。这是我当前的设置:
img_files = sorted(tf.io.gfile.glob(...))
imgd = tf.data.FixedLengthRecordDataset(img_files, inrez*inrez)
#POINT1A
imgd = imgd.map(lambda s: tf.reshape(tf.io.decode_raw(s, tf.int8), (inrez,inrez,1)))
imgd = imgd.map(lambda x: tf.cast(x, dtype=tf.float32))
out_files = sorted(tf.io.gfile.glob(...))
outd = tf.data.FixedLengthRecordDataset(out_files, 4, compression_type="GZIP")
#POINT1B
outd = outd.map(lambda s: tf.io.decode_raw(s, tf.float32))
xsrc = tf.data.Dataset.zip((imgd, outd)).batch(batchsize)
xsrc = xsrc.repeat() # indefinitely
#POINT2
xsrc = xsrc.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
我是否应该在预取之前在末尾 (POINT2) 交错整个管道?或者在每个 FixedLengthRecordDataset (POINT1A, POINT1B) 之后分别交错 imgd 和 outd,并并行化地图? (需要保持 imgd 和 outd 同步!) Dataset.range(rvalue) 怎么了——似乎有必要但不明显使用什么右值?有没有更好的总体规划?
请注意,数据集非常大,无法放入 RAM。
【问题讨论】:
-
请注意,数据来自 SSD……它应该是嗡嗡声……
标签: python tensorflow tensorflow2.0 tensorflow-datasets