【发布时间】:2019-04-24 22:28:09
【问题描述】:
我正在尝试将一些代码转换为新的数据集 API,以便我可以使用分发策略。以下是我正在尝试做的事情。
def dataset_generator():
while True:
features, labels = ex_lib.get_image_batch(), ex_lib.get_feature_batch()
yield features, labels
def get_ssf_input_fn():
def input_fn():
return tf.data.Dataset.from_generator(dataset_generator,
(tf.float32, tf.float32), ([None, config.image_height, config.image_width, config.image_channels], [None, 256]))
return input_fn
问题是 ex_lib.get_image_batch 和 ex_lib.get_feature_batch 给了我一个张量而不是一个 numpy 数组,我无法更改 ex_lib 中的代码。此外,我无法在此处将张量转换为 numpy 数组,因为我无法访问此处的 sess。使用此代码,它将抛出
`generator` yielded an element that could not be converted to the expected type. The expected type was float32, but the yielded element was Tensor("GetImageBatch:0", dtype=uint8)
有没有办法让我的 input_fn 返回一个数据集?
【问题讨论】:
-
凹凸。面临同样的问题。你能解决这个问题吗?
标签: tensorflow tensorflow-datasets