【Question title】: TensorFlow CIFAR10: resume training from a checkpoint file
【Posted】: 2016-09-29 11:08:19
【Question】:

I am using TensorFlow and trying to resume CIFAR10 training from a checkpoint file. Following some other posts, I tried tf.train.Saver().restore without success. Can someone tell me how to proceed?

Code snippet from the TensorFlow CIFAR10 example:

def train():
  # methods to build graph from the cifar10_train.py
  global_step = tf.Variable(0, trainable=False)
  images, labels = cifar10.distorted_inputs()
  logits = cifar10.inference(images)
  loss = cifar10.loss(logits, labels)
  train_op = cifar10.train(loss, global_step)
  saver = tf.train.Saver(tf.all_variables())
  summary_op = tf.merge_all_summaries()

  init = tf.initialize_all_variables() 
  sess = tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement))
  sess.run(init)


  print("FLAGS.checkpoint_dir is %s" % FLAGS.checkpoint_dir)

  if FLAGS.checkpoint_dir is None:
    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)
    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)
  else:
    # restoring from the checkpoint file
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    tf.train.Saver().restore(sess, ckpt.model_checkpoint_path)

  # cur_step prints correctly, showing the value restored from the checkpoint
  cur_step = sess.run(global_step)
  print("current step is %s" % cur_step)

  for step in xrange(cur_step, FLAGS.max_steps):
    start_time = time.time()
    # ** Execution hangs at this call **
    _, loss_value = sess.run([train_op, loss])
    # below same as original

【Comments】:

    Tags: tensorflow restore deep-learning checkpoint


    【Answer 1】:

    The problem appears to be this line:

    tf.train.start_queue_runners(sess=sess)
    

    ...which is executed only when FLAGS.checkpoint_dir is None. If you restore from a checkpoint, you still need to start the queue runners.
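To illustrate why the training call blocks, here is a minimal analogy built only on Python's standard library. It is not TensorFlow's queue-runner implementation: the names `producer` and `train` and the `timeout` parameter are hypothetical stand-ins. The consumer loop plays the role of `sess.run([train_op, loss])`, and the producer thread plays the role of the queue runners that fill the input queue.

```python
import queue
import threading


def producer(q, n_items):
    """Hypothetical stand-in for a queue runner thread: fills the input queue."""
    for i in range(n_items):
        q.put(i)


def train(start_producers, n_steps=3, timeout=0.5):
    """Consume n_steps items; return them, or None if the queue stays empty."""
    q = queue.Queue()
    if start_producers:
        # Analogous to calling tf.train.start_queue_runners(sess=sess).
        threading.Thread(target=producer, args=(q, n_steps), daemon=True).start()
    consumed = []
    for _ in range(n_steps):
        try:
            # A real training step on a queue-fed graph would block indefinitely
            # here; the timeout exists only so that this sketch terminates.
            consumed.append(q.get(timeout=timeout))
        except queue.Empty:
            return None
    return consumed


print(train(start_producers=False))  # nothing feeds the queue, so the consumer starves
print(train(start_producers=True))   # the producer runs, so every step completes
```

Restoring variables from a checkpoint has no effect on the input queue, which is why the producers have to be started on both code paths.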

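    Note that I recommend starting the queue runners after creating the tf.train.Saver (because of a race condition in the released version of the code), so a better structure would be: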

    if FLAGS.checkpoint_dir is not None:
      # restoring from the checkpoint file
      ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
      tf.train.Saver().restore(sess, ckpt.model_checkpoint_path)
    
    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)
    
    # ...
    
    for step in xrange(cur_step, FLAGS.max_steps):
      start_time = time.time()
      _, loss_value = sess.run([train_op, loss])
      # ...
    

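    【Comments】:

    • Thanks for the answer! It solved the problem. I had thought the queue runners were only responsible for creating the input images (by distortion), which did not seem like a necessary step when restoring from a checkpoint file.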