【Question title】: Data generated with Tensorflow Dataset.from_generator results in error when iterator.get_next() is called on it
【Posted】: 2019-03-08 08:00:26
【Question】:

I am new to TensorFlow. I followed some online posts and wrote code to read data from a generator. The code looks like this:

import json

import tensorflow as tf

def gen(my_list_of_files):
    # Each line of each file is one JSON record with 'features' and 'labels'.
    for fl in my_list_of_files:
        with open(fl) as f:
            for line in f:
                json_line = json.loads(line)
                features = json_line['features']
                labels = json_line['labels']
                yield features, labels

def get_dataset(my_list_of_files):
    # from_generator needs a zero-argument callable, so bind the file list here.
    generator = lambda: gen(my_list_of_files)
    return tf.data.Dataset.from_generator(generator, (tf.float32, tf.float32))

def get_input(my_list_of_files, buffer_size, batch_size):
    dataset = get_dataset(my_list_of_files)
    dataset = dataset.shuffle(buffer_size=buffer_size)
    # In TF 1.x, unbatch is applied through Dataset.apply.
    dataset = dataset.repeat().apply(tf.contrib.data.unbatch())
    dataset = dataset.batch(batch_size, drop_remainder=False)

    # This is where the problem is
    features, labels = dataset.make_one_shot_iterator().get_next()

    return features, labels

When I run it, I get this error:

InvalidArgumentError (see above for traceback): Input element must have a non-scalar value in each component.
     [[node IteratorGetNext (defined at /blah/blah/blah) ]]

The values I yield look like this:

[1, 2, 3, 4, 5, 6] # features
7 # label

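My understanding of the error is that it cannot iterate over the dataset because the label is not a vector. Is that correct? How can I fix this?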

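【Comments】:

  • In case anyone else runs into this: it worked for me once I returned the labels in a list, though I am still not sure why it did not work in the first place.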
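
On why wrapping the label in a list helps: the error text appears to come from the unbatch step, which splits every component of an element along its first dimension. A scalar label such as 7 has no dimension to split, while [7] does. The block below is a toy pure-Python model of that rule, not TensorFlow's implementation. The names `rank` and `toy_unbatch` are hypothetical, and the check that all components share the same leading size is an assumption about how such a split has to behave.

```python
def rank(value):
    """Nesting depth of a plain Python value: 0 for a scalar, 1 for a flat list."""
    depth = 0
    while isinstance(value, (list, tuple)):
        depth += 1
        if not value:  # empty list: cannot look any deeper
            break
        value = value[0]
    return depth


def toy_unbatch(element):
    """Split one (features, labels) element along its leading dimension.

    Toy model only (assumption): it mimics the rule, not TensorFlow's code.
    """
    for component in element:
        if rank(component) == 0:
            raise ValueError("every component must be non-scalar")
    sizes = {len(component) for component in element}
    if len(sizes) != 1:
        raise ValueError("components must share the same leading size")
    # Pair up the i-th slice of every component.
    return list(zip(*element))


# One row of features paired with one label, both wrapped in an outer list.
print(toy_unbatch(([[1, 2, 3, 4, 5, 6]], [7])))  # [([1, 2, 3, 4, 5, 6], 7)]

# The shape from the question: a flat feature list and a bare scalar label.
try:
    toy_unbatch(([1, 2, 3, 4, 5, 6], 7))
except ValueError as err:
    print("rejected:", err)  # rejected: every component must be non-scalar
```

Under this model, a list-wrapped label only passes if the features have a matching leading size, for example one row of features per label. Leaving out the unbatch step, as the answer below does, avoids the check entirely.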

Tags: tensorflow generator tensorflow-datasets


【Answer 1】:
Using this as the content of the jsondataset file:

{
   "features": ["1","2"],
   "labels": "2"
}

I do not see your error when I run the following code.

import json

import tensorflow as tf

def gen():
    # Reads a single JSON object from the file and yields it once.
    with open('jsondataset') as f:
        data = json.load(f)
        features = data['features']
        labels = data['labels']
        print(features)
        yield features, labels

def get_dataset():
    generator = lambda: gen()
    return tf.data.Dataset.from_generator(generator, (tf.float32, tf.float32))

def get_input():
    dataset = get_dataset()
    dataset = dataset.shuffle(buffer_size=5)
    dataset = dataset.batch(5, drop_remainder=False)

    # This is where the problem is
    iterator = dataset.make_one_shot_iterator()  # avoid shadowing the built-in iter
    features, labels = iterator.get_next()

    with tf.Session() as sess:
        print(sess.run([features, labels]))


def main():
    get_input()

if __name__ == "__main__":
    main()

[array([[1., 2.]], dtype=float32), array([2.], dtype=float32)]

【Comments】:
