【发布时间】:2019-03-08 08:00:26
【问题描述】:
我是 TensorFlow 的新手。我关注了一些在线帖子并编写了代码以从生成器中获取数据。 代码如下所示:
def gen(my_list_of_files):
for fl in my_list_of_files:
with open(fl) as f:
for line in f.readlines():
json_line = json.loads(line)
features = json_line['features']
labels = json_line['labels']
yield features, labels
def get_dataset():
generator = lambda: gen()
return tf.data.Dataset.from_generator(generator, (tf.float32, tf.float32))
def get_input():
dataset = get_dataset()
dataset = dataset.shuffle(buffer_size=buffer_size)
dataset = dataset.repeat().unbatch(tf.contrib.data.unbatch())
dataset = dataset.batch(batch_size, drop_remainder=False)
# This is where the problem is
features, labels = dataset.make_one_shot_iterator().get_next()
return features, labels
当我运行它时,我得到了错误:
InvalidArgumentError (see above for traceback): Input element must have a non-scalar value in each component.
[[node IteratorGetNext (defined at /blah/blah/blah) ]]
我产生的值看起来像:
[1, 2, 3, 4, 5, 6] # features
7 # label
我对错误的理解是它不能遍历数据集,因为它不是向量。我的理解正确吗?我该如何解决这个问题?
【问题讨论】:
-
如果其他人遇到这种情况:当我返回列表中的标签时,这对我有用,尽管我仍然不确定为什么它一开始就不起作用。
标签: tensorflow generator tensorflow-datasets