【问题标题】:tensorflow error: restore checkpoint filetensorflow错误:恢复检查点文件
【发布时间】:2017-07-30 13:59:48
【问题描述】:

我建立了自己的卷积神经网络,在其中我跟踪所有可训练变量的移动平均值(tensorflow 1.0):

variable_averages = tf.train.ExponentialMovingAverage(
        0.9999, global_step)
variables_averages_op = variable_averages.apply(tf.trainable_variables())
train_op = tf.group(apply_gradient_op, variables_averages_op)
saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)
summary_op = tf.summary.merge(summaries)
init = tf.global_variables_initializer()
sess = tf.Session(config=tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False))
sess.run(init)
# start queue runners
tf.train.start_queue_runners(sess=sess)

summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

# training loop
start_time = time.time()
for step in range(FLAGS.max_steps):
        _, loss_value = sess.run([train_op, loss])
        duration = time.time() - start_time
        start_time = time.time()
        assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

        if step % 1 == 0:
            # print current model status
            num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus
            examples_per_sec = num_examples_per_step/duration
            sec_per_batch = duration/FLAGS.num_gpus
            format_str = '{} step{}, loss {}, {} examples/sec, {} sec/batch'
            print(format_str.format(datetime.now(), step, loss_value, examples_per_sec, sec_per_batch))
        if step % 50 == 0:
            summary_str = sess.run(summary_op)
            summary_writer.add_summary(summary_str, step)
        if step % 10 == 0 or step == FLAGS.max_steps:
            print('save checkpoint')
            # save checkpoint file
            checkpoint_file = os.path.join(FLAGS.train_dir, 'model.ckpt')
            saver.save(sess, checkpoint_file, global_step=step)

这工作正常并且检查点文件被保存(保护程序版本 V2)。然后我尝试在另一个脚本中恢复检查点以评估模型。我有这段代码

# Restore the moving average version of the learned variables for eval.
variable_averages = tf.train.ExponentialMovingAverage(
    MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)

我在哪里收到错误“NotFoundError(请参阅上面的回溯):在检查点中找不到键 conv1/Variable/ExponentialMovingAverage”,其中 conv1/variable/ 是变量范围。

这个错误甚至在我尝试恢复变量之前就发生了。可以帮忙解决一下吗?

提前致谢

裘德

【问题讨论】:

    标签: tensorflow


    【解决方案1】:

    我是这样解决的:
    在图中创建第二个 ExponentialMovingAverage(...) 之前调用 tf.reset_default_graph()

    # reset the graph before create a new ema
    tf.reset_default_graph()
    # Restore the moving average version of the learned variables for eval.
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)
    variables_to_restore = variable_averages.variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)
    

    我花了 2 个小时...

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2019-04-03
      • 2016-09-29
      • 1970-01-01
      • 1970-01-01
      • 2018-02-16
      • 2019-08-20
      • 1970-01-01
      • 2016-06-14
      相关资源
      最近更新 更多