【问题标题】:How to switch model iterator between train and validate datasets?如何在训练和验证数据集之间切换模型迭代器?
【发布时间】:2019-01-30 22:34:48
【问题描述】:

我正在学习 TensorFlow“底层 API”,您可以在其中使用 tf.layers 手动指定层、创建数据集和迭代器,然后运行循环来训练和验证模型。我正在尝试进行培训和验证。不幸的是,我在尝试在训练和验证数据集之间切换时遇到了错误:

这是我所拥有的:

self.train_it = \
    train_dataset.batch(self.batch_size).make_initializable_iterator()
self.validate_it = \
    train_dataset.batch(self.batch_size).make_initializable_iterator()

...

input_layer = self.train_it.get_next()[0]
hidden1 = tf.layers.dense(
    input_layer,
    ... )

...

with tf.name_scope('train'):
  self.train_op = \
        tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(self.loss)

...

for epo in range(epochs):
  # Train using self.train_it iterator.
  self.sess.run(self.train_it.initializer)
  total_loss = 0
  for iteration in range(n_batches):
    summary, _, batch_loss = self.sess.run([self.merged_summary, \
        self.train_op, self.loss])
    total_loss += batch_loss
  print('   Epoch : {}/{}, Training loss = {:.4f}'. \
            format(epo+1, epochs, total_loss / n_batches))
  # Validate using self.valid_it iterator.
  self.sess.run(self.validate_it.initializer)
  # HOW DO I TELL THE MODEL TO USE self.valid_it INSTEAD OF self.train_it ???

这里的问题是,一开始我已经告诉模型使用 train_it : input_layer = self.train_it.get_next()[0] ,现在我必须告诉它每个 epoch 在 train_itvalidate_it 之间切换。我必须在 API 中遗漏一些关于如何做到这一点的内容。

【问题讨论】:

    标签: tensorflow iterator tensorflow-datasets


    【解决方案1】:

    我会使用可重新初始化的迭代器并执行以下操作。

    train_dataset = train_dataset.batch(batch_size_train)
    val_dataset = validation_dataset.batch(batch_size_val)
    
    iterator = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
    
    train_init_op = iterator.make_initializer(train_dataset)
    val_init_op = iterator.make_initializer(val_dataset)
    
    data, labels = iterator.get_next()
    

    然后链接模型中的数据和标签。之后在训练时执行以下操作:

    for e in range(epochs):
        sess.run(train_init_op)
        for iteration in range(n_batches_val):
            ....
        sess.run(val_init_op)
        for iteration in range(n_batches_val):
            ....
    

    如果您发现有什么令人困惑的地方,请告诉我。

    【讨论】:

      猜你喜欢
      • 2017-04-30
      • 2017-01-29
      • 2019-05-01
      • 2019-06-05
      • 1970-01-01
      • 2018-06-16
      • 2021-01-05
      • 2016-09-13
      • 2015-05-18
      相关资源
      最近更新 更多