【问题标题】:Tensorflow : Trainning and test into the same graph with input queuesTensorflow:使用输入队列在同一个图中训练和测试
【发布时间】:2017-10-25 02:11:48
【问题描述】:

我遇到的问题无法通过我在互联网上找到的内容解决。

我已经构建了我的神经网络并将其连接到输入管道。 从 tfrecord 读取数据,使用 tf.train.batch 和 queueRunners、Coords 等。

我已经将我的 NN 构建到一个名为“Model”的 python 类中,我使用它:

model = Model(...这里的所有超参数...)

...

model.predict()

model.step()

所有训练阶段都运行良好。

但现在我想在训练的每 X 个时期/步骤中添加一个测试阶段。

我真的不知道该怎么做。 我有几个想法,但我没有找到最好的:

  • 将代码复制到我的类中以获得:loss_train 和 loss_test,等等,用于我的图表的每个节点? (在训练和测试之间使用共享变量)
  • 为我的模型创建 2 个实例:

model_train = 模型(reuse=false)

model_test = 模型(reuse=true)

  • 使用 tf.make_template ?我真的没有找到这个函数的任何好例子......
  • 还有其他解决方案吗?

如果有任何建议,我将不胜感激,

【问题讨论】:

    标签: testing tensorflow dataset training-data


    【解决方案1】:

    我在试验 TFRecords 数据集时遇到了同样的问题。有几种可能性。因为我想在只有一个 GPU 的计算机上执行此操作,所以我按如下方式实现它:

    # Training Dataset
    train_dataset = tf.contrib.data.TFRecordDataset(train_files)
    train_dataset = train_dataset.map(parse_function)
    train_dataset = train_dataset.shuffle(buffer_size=10000)
    train_dataset = train_dataset.batch(200)
    # Validation Dataset
    validation_dataset = tf.contrib.data.TFRecordDataset(val_files)
    validation_dataset = validation_dataset.map(parse_function)
    validation_dataset = validation_dataset.batch(200)
    
    # A feedable iterator is defined by a handle placeholder and its structure. We
    # could use the `output_types` and `output_shapes` properties of either
    # `training_dataset` or `validation_dataset` here, because they have
    # identical structure.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.contrib.data.Iterator.from_string_handle(handle,
     train_dataset.output_types, train_dataset.output_shapes)
    next_element = iterator.get_next()
    
    # Generate the Iterators
    training_iterator = train_dataset.make_initializable_iterator()
    validation_iterator = validation_dataset.make_one_shot_iterator()
    
    # The `Iterator.string_handle()` method returns a tensor that can be evaluated
    # and used to feed the `handle` placeholder.
    training_handle = sess.run(training_iterator.string_handle())
    validation_handle = sess.run(validation_iterator.string_handle())
    

    然后为了访问元素,你可以这样:

    img, lbl = sess.run(next_element, feed_dict={handle: training_handle})
    

    并根据您愿意做什么来交换句柄 ATM。

    但是请记住,这是不可并行的。通过此链接,您可以深入了解创建多个输入管道Tensorflow | Reading Data 的不同方法。

    【讨论】:

      猜你喜欢
      • 2016-10-09
      • 2016-12-31
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2019-09-02
      • 1970-01-01
      • 2018-12-23
      • 2021-11-12
      相关资源
      最近更新 更多