【Question title】: Can someone explain the train function in cifar10_train.py from the cifar10 tutorials in tensorflow
【Posted】: 2018-03-29 01:14:54
【Question】:

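I am following the cifar10 tutorial at https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10. The project has 6 classes. After searching the internet, I understand cifar10.py and cifar10_input.py, but I cannot understand the train function in cifar10_train.py. Here is the train function from cifar10_train.py: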

def train():
    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        # Force the input pipeline to CPU:0 to avoid operations sometimes
        # ending up on the GPU and causing a slowdown.
        with tf.device('/cpu:0'):
            images, labels = cifar10.distorted_inputs()

        logits = cifar10.inference(images)

        loss = cifar10.loss(logits, labels)

        train_op = cifar10.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                                  'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                       tf.train.NanTensorHook(loss),
                       _LoggerHook()],
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)

Can anyone explain what is happening in the _LoggerHook class?

【Comments】:

    Tags: machine-learning tensorflow


    【Solution 1】:

    It uses MonitoredSession and SessionRunHook to log the loss during training.

    _LoggerHook is an implementation of SessionRunHook, whose methods are called in the order described below:

      call hooks.begin()
      sess = tf.Session()
      call hooks.after_create_session()
      while not stop is requested:
        call hooks.before_run()
        try:
          results = sess.run(merged_fetches, feed_dict=merged_feeds)
        except (errors.OutOfRangeError, StopIteration):
          break
        call hooks.after_run()
      call hooks.end()
      sess.close()
    

    Source: the tf.train.SessionRunHook documentation.
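The call order in the pseudocode above can be reproduced without TensorFlow. The following is a minimal sketch, assuming a toy driver: `RecordingHook` and `run_with_hooks` are hypothetical names invented for illustration, and the real SessionRunHook methods take extra arguments (such as `run_context`) that are omitted here.

```python
class RecordingHook:
    """Hypothetical stand-in for a SessionRunHook; it only records call order."""

    def __init__(self):
        self.calls = []

    def begin(self):
        self.calls.append("begin")

    def after_create_session(self):
        self.calls.append("after_create_session")

    def before_run(self):
        self.calls.append("before_run")

    def after_run(self, result):
        self.calls.append("after_run")

    def end(self):
        self.calls.append("end")


def run_with_hooks(hook, step_fn, max_steps):
    """Toy driver mirroring the pseudocode; not TensorFlow's implementation."""
    hook.begin()
    hook.after_create_session()
    for step in range(max_steps):
        hook.before_run()
        try:
            result = step_fn(step)  # stands in for sess.run(...)
        except StopIteration:
            break
        hook.after_run(result)
    hook.end()


hook = RecordingHook()
run_with_hooks(hook, step_fn=lambda step: step * 0.5, max_steps=2)
print(hook.calls)
```

Each training step is wrapped by one `before_run` and one `after_run` call, which is why _LoggerHook can count steps in `before_run` and report in `after_run`.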

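    In before_run, which is called before each session.run, it asks for the loss tensor to be fetched by returning SessionRunArgs(loss). In after_run it reads the fetched value from run_values.results and, every FLAGS.log_frequency steps, prints the loss in a predefined format together with the throughput.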
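To make the bookkeeping concrete, here is a TensorFlow-free sketch of the same step counting and timing arithmetic. `ToyLoggerHook`, `throughput`, and the injectable `clock` are assumptions introduced for illustration; the string returned by `before_run` merely stands in for `tf.train.SessionRunArgs(loss)`, and the loss values are made up.

```python
def throughput(duration, log_frequency, batch_size):
    """Same arithmetic as after_run: duration is taken to span log_frequency steps."""
    examples_per_sec = log_frequency * batch_size / duration
    sec_per_batch = duration / log_frequency
    return examples_per_sec, sec_per_batch


class ToyLoggerHook:
    """Hypothetical, TensorFlow-free imitation of _LoggerHook."""

    def __init__(self, log_frequency, batch_size, clock):
        self.log_frequency = log_frequency
        self.batch_size = batch_size
        self.clock = clock  # injectable so the example is deterministic
        self.lines = []

    def begin(self):
        self._step = -1
        self._start_time = self.clock()

    def before_run(self):
        self._step += 1
        return "loss"  # stands in for SessionRunArgs(loss)

    def after_run(self, loss_value):
        # Only report every log_frequency steps, and reset the timer each time.
        if self._step % self.log_frequency == 0:
            now = self.clock()
            duration = now - self._start_time
            self._start_time = now
            eps, spb = throughput(duration, self.log_frequency, self.batch_size)
            self.lines.append(
                "step %d, loss = %.2f (%.1f examples/sec; %.3f sec/batch)"
                % (self._step, loss_value, eps, spb))


# Fake clock readings: one for begin(), then one per logged step.
times = iter([0.0, 1.0, 3.0])
hook = ToyLoggerHook(log_frequency=2, batch_size=128, clock=lambda: next(times))
hook.begin()
for loss_value in [2.3, 2.2, 2.1, 2.0]:  # made-up loss values
    hook.before_run()
    hook.after_run(loss_value)
for line in hook.lines:
    print(line)
```

Note that the report at step 0 measures the time since `begin()` but still divides by `log_frequency`, so the first throughput figure is not representative; later reports each cover a full `log_frequency` steps.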

    Tutorial: https://www.tensorflow.org/tutorials/layers

    Hope this helps.

    【Discussion】:
