【发布时间】:2018-10-31 03:30:54
【问题描述】:
在 Tensorflow 指南中,指南有两个不同的地方描述了 Iris Data 示例的输入函数。一个输入函数只返回数据集本身,而另一个返回带有迭代器的数据集。
来自预制的 Estimator 指南:https://www.tensorflow.org/guide/premade_estimators
def train_input_fn(features, labels, batch_size):
"""An input function for training"""
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
# Shuffle, repeat, and batch the examples.
return dataset.shuffle(1000).repeat().batch(batch_size)
来自自定义估算器指南:https://www.tensorflow.org/guide/custom_estimators
def train_input_fn(features, labels, batch_size):
"""An input function for training"""
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
# Shuffle, repeat, and batch the examples.
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
# Return the read end of the pipeline.
return dataset.make_one_shot_iterator().get_next()
我很困惑哪一个是正确的,如果它们都用于不同的情况,什么时候使用迭代器返回数据集是正确的?
【问题讨论】:
标签: tensorflow tensorflow-datasets tensorflow-estimator