【发布时间】:2018-02-05 00:54:29
【问题描述】:
我想用 tf.estimator.Estimator 训练我的模式并通过 Dataset API 加载我的数据。因为我的数据,例如“mnist”,是一个数组(张量),所以我尝试用“tf”加载它.data.Dataset.from_tensor_slices'。但我不知道如何在'input_fn'中初始化'make_initializable_iterator'。
如果我可以使用“make_one_shot_iterator”成功训练,但它在训练前加载缓慢。而《Higher-Level APIs in TensorFlow》是在 'input_fn' 中 'make_initializable_iterator' 的一个很好的例子,但它需要从 'input_fn' 返回一个 'iterator_initializer_hook' 给其他函数。我想知道还有其他更好或更优雅的方式吗?
def input_fn():
mnist_data = input_data.read_data_sets('mnist_data', one_hot=False)
images = mnist_data.train.images.reshape([-1, 28, 28, 1])
labels = np.asarray(mnist_data.train.labels, dtype=np.int64)
# Build dataset iterator
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.repeat(None) # Infinite iterations
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(100)
iterator = dataset.make_one_shot_iterator()
next_example = iterator.get_next()
# Set runhook to initialize iterator
return next_example
【问题讨论】:
标签: python tensorflow tensorflow-datasets tensorflow-estimator