【发布时间】:2019-12-05 15:45:41
【问题描述】:
我想在我的输入管道中编写一个数据增强步骤,从概念上讲,我有两个数据集可以作为一对输入到生成器,在那里它们将产生一堆输出示例。
我已经通过以下方式实现了这样的目标:
import tensorflow as tf
def gen(a, b):
for i in range(2):
yield str(a) + " " + str(b) + " " + str(i)
a = tf.data.Dataset.range(3)
b = tf.data.Dataset.range(3)
dataset = b.interleave(lambda x: a.interleave(lambda y: tf.data.Dataset.from_generator(gen,
output_types=(tf.string),
args=(x, y)),
num_parallel_calls = None))
for d in dataset:
print (d.numpy())
这会产生:
b'0 0 0'
b'0 1 0'
b'0 2 0'
b'0 0 1'
b'0 1 1'
b'0 2 1'
b'1 0 0'
b'1 1 0'
b'1 2 0'
b'1 0 1'
b'1 1 1'
b'1 2 1'
b'2 0 0'
b'2 1 0'
b'2 2 0'
b'2 0 1'
b'2 1 1'
b'2 2 1'
正如预期的那样。我的问题是gen 是(在我的真实情况下)计算成本很高的操作,所以我想尽可能使用并行调用。到目前为止,我尝试添加 num_parallel_calls 未能产生性能提升。
另外,如果重要的话,我的输入数据集来自 TFRecordDataset,这为添加 num_parallel_calls 选项提供了更多机会,即
raw_a = tf.data.TFRecordDataset(a_tfrecord_list)
a = raw_dataset.map(some_parsing_function)
【问题讨论】:
标签: python tensorflow tfrecord data-augmentation