【Posted】: 2020-01-13 10:53:37
【Question】:
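I'm having trouble speeding up a training data pipeline built with `tf.data.Dataset`, and I think I'm missing something here. Despite the various options `Dataset` offers for preloading data, the pipeline is still slow.
My real data pipeline is complex, but I've reduced it to the small example below. I've tried tuning `num_parallel_calls`, `cycle_length`, `prefetch`, etc., but I can't get the dataset to produce batches smoothly. What am I missing? Any suggestions?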
```python
import tensorflow as tf
tf.enable_eager_execution()  # TF 1.x call; eager mode is the default in TF 2.x
from timeit import default_timer as timer

feature_count = 400
batch_size = 1024
look_back = 100
target_groups = 21

def random_data_generator(x=0):
    # Endlessly yields one random (features, labels) batch; x is unused
    while True:
        x_data = tf.random.uniform(
            shape=(batch_size, look_back, feature_count),
            minval=-1.0,
            maxval=5,
            dtype=tf.dtypes.float32)
        Y_data = tf.random.uniform(
            shape=(batch_size, target_groups),
            minval=1,
            maxval=21,
            dtype=tf.dtypes.int32)
        yield x_data, Y_data

def get_simple_Dataset_generator():
    # Three source generators, interleaved in blocks of 3 elements each
    dataset = tf.data.Dataset.from_tensor_slices([0, 1, 2])
    dataset = dataset.interleave(
        lambda x: tf.data.Dataset.from_generator(
            random_data_generator,
            output_types=(tf.float32, tf.int32),  # Y_data is generated as int32
            args=(x,)),
        cycle_length=3,
        block_length=3,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    #dataset = dataset.prefetch(2)
    # Wrap the dataset in a plain Python generator that never ends
    while True:
        for x, Y in dataset:
            yield x, Y

def test_speed():
    generator = get_simple_Dataset_generator()
    print("Testing generator speed ")
    for i in range(1, 100):
        start_time = timer()
        next(generator)  # time how long one batch takes to arrive
        lap_time = timer() - start_time
        print("%s Time - %fsec " % (i, lap_time))

if __name__ == '__main__':
    test_speed()
```
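For reference, with `cycle_length=3` and `block_length=3`, `interleave` (deterministic by default) takes 3 consecutive elements from each of the 3 sources in round-robin order, so each source is revisited only once every 9 elements. The pure-Python sketch below mimics that element order with a hypothetical `simulate_interleave` helper; it is an illustration under that assumption and models neither `num_parallel_calls` nor any timing.

```python
from itertools import count, islice

def simulate_interleave(sources, cycle_length, block_length):
    """Hypothetical helper: yield (source_id, item) pairs in the order a
    deterministic interleave would emit them. Sources are assumed infinite."""
    iterators = [iter(s) for s in sources[:cycle_length]]
    while True:
        # Round-robin over the open sources, block_length items at a time
        for source_id, it in enumerate(iterators):
            for _ in range(block_length):
                yield source_id, next(it)

# Three endless sources, mirroring from_tensor_slices([0, 1, 2])
sources = [count() for _ in range(3)]
order = [sid for sid, _ in islice(simulate_interleave(sources, 3, 3), 18)]
print(order)
```

The printed source ids repeat with a period of `cycle_length * block_length = 9`.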
I was hoping to see a consistent generator speed, but it is still very erratic.
Output:
```
1 Time - 3.417578sec
2 Time - 1.257846sec
3 Time - 1.286210sec
4 Time - 0.000456sec
5 Time - 0.027772sec
6 Time - 0.058985sec
7 Time - 0.000416sec
8 Time - 0.026721sec
9 Time - 0.027316sec
10 Time - 0.777332sec
11 Time - 1.379266sec
12 Time - 1.172304sec
13 Time - 0.000365sec
14 Time - 0.026909sec
15 Time - 0.045409sec
16 Time - 0.000708sec
17 Time - 0.025682sec
18 Time - 0.027223sec
19 Time - 0.577131sec
20 Time - 1.220682sec
21 Time - 1.189601sec
22 Time - 0.000573sec
23 Time - 0.079531sec
24 Time - 0.624080sec
25 Time - 0.038932sec
```
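The erratic timings do have some structure: the laps above 0.5 s are iterations 1–3, 10–12, 19–21 and 24, so the first three slow runs each last three laps and start 9 iterations apart, which happens to equal `cycle_length * block_length` (lap 24 breaks the pattern). The sketch below simply re-reads the posted numbers to find those runs; the 0.5 s threshold is an arbitrary choice, and the code does not by itself establish the cause of the stalls.

```python
# Lap times copied from the output above (seconds), iterations 1..25
laps = [3.417578, 1.257846, 1.286210, 0.000456, 0.027772, 0.058985,
        0.000416, 0.026721, 0.027316, 0.777332, 1.379266, 1.172304,
        0.000365, 0.026909, 0.045409, 0.000708, 0.025682, 0.027223,
        0.577131, 1.220682, 1.189601, 0.000573, 0.079531, 0.624080,
        0.038932]

THRESHOLD = 0.5  # seconds; arbitrary cut-off for a "slow" lap

# 1-based iteration numbers whose lap time exceeds the threshold
slow = [i for i, t in enumerate(laps, start=1) if t > THRESHOLD]
print("slow iterations:", slow)

# First iteration of each run of consecutive slow laps, and the spacing between runs
run_starts = [i for i in slow if i - 1 not in slow]
gaps = [b - a for a, b in zip(run_starts, run_starts[1:])]
print("run starts:", run_starts, "gaps:", gaps)
```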
【Discussion】:
Tags: tensorflow dataset pipeline