【问题标题】:How to use the iterator 'make_initializable_iterator' of tensorflow within a 'input_fn'?如何在'input_fn'中使用tensorflow的迭代器'make_initializable_iterator'?
【发布时间】:2018-02-05 00:54:29
【问题描述】:

我想用 tf.estimator.Estimator 训练我的模式并通过 Dataset API 加载我的数据。因为我的数据,例如“mnist”,是一个数组(张量),所以我尝试用“tf”加载它.data.Dataset.from_tensor_slices'。但我不知道如何在'input_fn'中初始化'make_initializable_iterator'。

如果我可以使用“make_one_shot_iterator”成功训练,但它在训练前加载缓慢。而《Higher-Level APIs in TensorFlow》是在 'input_fn' 中 'make_initializable_iterator' 的一个很好的例子,但它需要从 'input_fn' 返回一个 'iterator_initializer_hook' 给其他函数。我想知道还有其他更好或更优雅的方式吗?

    def input_fn():

    mnist_data = input_data.read_data_sets('mnist_data', one_hot=False)
    images = mnist_data.train.images.reshape([-1, 28, 28, 1])
    labels = np.asarray(mnist_data.train.labels, dtype=np.int64)

    # Build dataset iterator
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset = dataset.repeat(None)  # Infinite iterations
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(100)
    iterator = dataset.make_one_shot_iterator()
    next_example = iterator.get_next()
    # Set runhook to initialize iterator

    return next_example

【问题讨论】:

    标签: python tensorflow tensorflow-datasets tensorflow-estimator


    【解决方案1】:

    在 TensorFlow 1.5 及更高版本中,当您从 input_fn 返回 tf.data.Dataset 时,tf.estimator.Estimator 将自动创建并初始化一个可初始化的迭代器。这使您可以编写以下代码,而不必担心初始化或挂钩:

    def input_fn():
        mnist_data = input_data.read_data_sets('mnist_data', one_hot=False)
        images = mnist_data.train.images.reshape([-1, 28, 28, 1])
        labels = np.asarray(mnist_data.train.labels, dtype=np.int64)
    
        # Build dataset.
        dataset = tf.data.Dataset.from_tensor_slices((images, labels))
        dataset = dataset.repeat(None)  # Infinite iterations
        dataset = dataset.shuffle(buffer_size=10000)
        dataset = dataset.batch(100)
        return dataset
    

    【讨论】:

      【解决方案2】:

      在您的代码中,添加以下内容:

            self.hooks.append(utils_hooks.DatasetHook(iter))
      

      在 run_loop.py 中,在调用你的 fn 之前,添加这个

       for hook in dataset_hooks:
              sess.run(hook.iterator().initializer)
      

      那么,应该没问题。

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 2017-12-14
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2019-01-06
        • 1970-01-01
        • 2020-10-23
        • 2018-04-05
        相关资源
        最近更新 更多