【问题标题】:Defining the input-function for TensorFlow pre-made estimator为 TensorFlow 预制估计器定义输入函数
【发布时间】:2019-06-15 16:55:15
【问题描述】:

我正在尝试使用预制的估算器 tf.estimator.DNNClassifier 用于 MNIST 数据集。我从tensorflow_dataset 加载数据集。

我追求以下四个步骤:首先构建数据集管道并定义输入函数:

## Step 1
mnist, info = tfds.load('mnist', with_info=True)

ds_train_orig, ds_test = mnist['train'], mnist['test']

def train_input_fn(dataset, batch_size):
    dataset = dataset.map(lambda x:({'image-pixels':tf.reshape(x['image'], (-1,))}, 
                                    x['label']))
    return dataset.shuffle(1000).repeat().batch(batch_size)

然后,在第 2 步中,我用一个键定义特征列,形状为 784:

## Step 2:
image_feature_column = tf.feature_column.numeric_column(key='image-pixels',
                                                        shape=(28*28))

image_feature_column
NumericColumn(key='image-pixels', shape=(784,), default_value=None, dtype=tf.float32, normalizer_fn=None)

第 3 步,我将估算器实例化如下:

## Step 3:
dnn_classifier = tf.estimator.DNNClassifier(
    feature_columns=image_feature_column,
    hidden_units=[16, 16],
    n_classes=10)

最后,通过调用.train() 方法使用估算器的第 4 步:

## Step 4:
dnn_classifier.train(
    input_fn=lambda:train_input_fn(ds_train_orig, batch_size=32),
    #lambda:iris_data.train_input_fn(train_x, train_y, args.batch_size),
    steps=20)

但这会导致以下错误。看起来问题出在数据集上。

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-21-95736cd65e45> in <module>
      2 dnn_classifier.train(
      3     input_fn=lambda: train_input_fn(ds_train_orig, batch_size=32),
----> 4     steps=20)

~/anaconda3/envs/tf2.0-beta/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in internal_convert_to_tensor(value, dtype, name, as_ref, preferred_dtype, ctx, accept_symbolic_tensors, accept_composite_tensors)
   1183       graph = get_default_graph()
   1184       if not graph.building_function:
-> 1185         raise RuntimeError("Attempting to capture an EagerTensor without "
   1186                            "building a function.")
   1187       return graph.capture(value, name=name)

RuntimeError: Attempting to capture an EagerTensor without building a function.

【问题讨论】:

    标签: python tensorflow tensorflow-datasets tensorflow-estimator


    【解决方案1】:

    我认为如果您在input_fn 之外加载一个 tensorflow_datasets 数据集,图形构造会变得很奇怪。我遵循了 TF2.0 迁移指南示例,这没有给出错误。请注意,我没有测试模型的正确性,您必须稍微修改input_fn 逻辑才能获得 eval 函数。

    # Define the estimator's input_fn
    def input_fn():
      datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
      mnist_train, mnist_test = datasets['train'], datasets['test']
      dataset = mnist_train
      dataset = mnist_train.map(lambda x, y:({'image-pixels':tf.reshape(x, (-1,))}, 
                                        y))
      return dataset.shuffle(1000).repeat().batch(32)
    
    
    image_feature_column = tf.feature_column.numeric_column(key='image-pixels',
                                                            shape=(28*28))
    
    
    dnn_classifier = tf.estimator.DNNClassifier(
        feature_columns=[image_feature_column],
        hidden_units=[16, 16],
        n_classes=10)
    
    
    dnn_classifier.train(
        input_fn=input_fn,
        steps=200)
    

    此时我收到了一堆弃用警告,但似乎估计器已经过训练。

    【讨论】:

    • > 我认为如果您在 input_fn 之外加载 tensorflow_datasets 数据集,图形构造会变得很奇怪。是的,你成功了。估算器构建自己的图,这就是为什么所有东西都必须包含在函数中以进行“惰性”评估。
    【解决方案2】:

    @dgumo 的回答是正确的。我只是想添加一个基本示例。

    输入函数返回的所有张量必须在输入函数内创建。

    #Raw data can be outside
    data_x = [0.0, 1.0, 2.0, 3.0, 4.0]
    data_y = [3.0, 4.9, 7.3, 8.65, 10.75]
    
    def supply_input():
      #Tensors must be created inside the function
      train_x = tf.constant(data_x)
      train_y = tf.constant(data_y)
    
      feature = {
          'x': train_x
      }
    
      return feature, train_y
    

    【讨论】:

      猜你喜欢
      • 2019-04-16
      • 2018-08-24
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2018-10-05
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多