【问题标题】:Early stopping with tf.estimator, how?使用 tf.estimator 提前停止,如何?
【发布时间】:2018-04-18 14:56:59
【问题描述】:

我在 TensorFlow 1.4 中使用 tf.estimatortf.estimator.train_and_evaluate 很棒,但我需要尽早停止。添加它的首选方式是什么?

我认为在某处有一些tf.train.SessionRunHook。我看到有一个带有ValidationMonitor 的旧 contrib 包似乎已经提前停止,但它似乎在 1.4 中不再存在。或者将来首选的方式是依靠tf.keras(提前停止真的很容易)而不是tf.estimator/tf.layers/tf.data,也许?

【问题讨论】:

    标签: python tensorflow neural-network keras tensorflow-estimator


    【解决方案1】:

    好消息! tf.estimator 现在对 master 提供了提前停止支持,看起来它将在 1.10 中。

    estimator = tf.estimator.Estimator(model_fn, model_dir)
    
    os.makedirs(estimator.eval_dir())  # TODO This should not be expected IMO.
    
    early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook(
        estimator,
        metric_name='loss',
        max_steps_without_decrease=1000,
        min_steps=100)
    
    tf.estimator.train_and_evaluate(
        estimator,
        train_spec=tf.estimator.TrainSpec(train_input_fn, hooks=[early_stopping]),
        eval_spec=tf.estimator.EvalSpec(eval_input_fn))
    

    【讨论】:

    • 这看起来很有希望,但似乎不在 r1.9 中(我相信它是今天的稳定版本)>>> tf.contrib.estimator.stop_if_no_decrease_hook Traceback (most recent call last): File "<stdin>", line 1, in <module> AttributeError: module 'tensorflow.contrib.estimator' has no attribute 'stop_if_no_decrease_hook'
    • 我在 1.10 版本中尝试过,但出现以下错误:NotFoundError: Key signal_early_stopping/STOP not found in checkpoint [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_INT64, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
    • @TasosGlrs 你确定你没有尝试添加钩子并从先前运行中创建的没有钩子的现有检查点继续,因此缺少必要的键?
    • @oens 是的,我每次测试时都会创建一个新模型目录。我不知道它是否相关,但我从这样的 Keras 模型创建了我的估算器:estimator = tf.keras.estimator.model_to_estimator(keras_model=my_model, model_dir='{0}/{1}/model_{2}'.format(def_modelpath, modeldir, my_model.name))
    • 我也有关于signal_early_stopping not found 的类似问题。似乎是因为 early_stopping 钩子只能放在 TrainSpec 钩子中。对于 EvalSpec 挂钩中的使用,会发生此错误。
    【解决方案2】:

    首先,您必须命名损失以使其可用于提前停止调用。如果您的损失变量在估计器中被命名为“损失”,则行

    copyloss = tf.identity(loss, name="loss")
    

    它的正下方会起作用。

    然后,使用此代码创建一个挂钩。

    class EarlyStopping(tf.train.SessionRunHook):
        def __init__(self,smoothing=.997,tolerance=.03):
            self.lowestloss=float("inf")
            self.currentsmoothedloss=-1
            self.tolerance=tolerance
            self.smoothing=smoothing
        def before_run(self, run_context):
            graph = ops.get_default_graph()
            #print(graph)
            self.lossop=graph.get_operation_by_name("loss")
            #print(self.lossop)
            #print(self.lossop.outputs)
            self.element = self.lossop.outputs[0]
            #print(self.element)
            return tf.train.SessionRunArgs([self.element])
        def after_run(self, run_context, run_values):
            loss=run_values.results[0]
            #print("loss "+str(loss))
            #print("running average "+str(self.currentsmoothedloss))
            #print("")
            if(self.currentsmoothedloss<0):
                self.currentsmoothedloss=loss*1.5
            self.currentsmoothedloss=self.currentsmoothedloss*self.smoothing+loss*(1-self.smoothing)
            if(self.currentsmoothedloss<self.lowestloss):
                self.lowestloss=self.currentsmoothedloss
            if(self.currentsmoothedloss>self.lowestloss+self.tolerance):
                run_context.request_stop()
                print("REQUESTED_STOP")
                raise ValueError('Model Stopping because loss is increasing from EarlyStopping hook')
    

    这会将指数平滑的损失验证与其最低值进行比较,如果它高于容差,则停止训练。如果它停止得太早,提高容差和平滑会使它停止得更晚。保持平滑低于 1,否则它永远不会停止。

    如果您想根据不同的条件停止,可以将 after_run 中的逻辑替换为其他内容。

    现在,将此挂钩添加到评估规范中。您的代码应如下所示:

    eval_spec=tf.estimator.EvalSpec(input_fn=lambda:eval_input_fn(batchsize),steps=100,hooks=[EarlyStopping()])#
    

    重要提示:函数 run_context.request_stop() 在 train_and_evaluate 调用中被破坏,并且不会停止训练。所以,我提出了一个价值错误来停止训练。因此,您必须将 train_and_evaluate 调用包装在 try catch 块中,如下所示:

    try:
        tf.estimator.train_and_evaluate(classifier,train_spec,eval_spec)
    except ValueError as e:
        print("training stopped")
    

    如果你不这样做,当训练停止时代码会崩溃并报错。

    【讨论】:

    • 这似乎不做早停?如果我正确理解了您的代码,那么您监控的是训练损失而不是验证损失。
    • 这是附加到 EvalSpec 的,因此它正在监视验证丢失。如果训练时间足够长,它会提前停止。如果停止速度不够快,您可能需要将平滑值降低到 0.99 并降低容差。
    【解决方案3】:

    另一个不使用钩子的选项是创建一个tf.contrib.learn.Experiment(即使在contrib 中,它似乎也支持新的tf.estimator.Estimator)。

    然后通过(显然是实验性的)方法continuous_train_and_eval 和适当定制的continuous_eval_predicate_fn 进行训练。

    根据 tensorflow 文档,continuous_eval_predicate_fn

    判断每次迭代后是否继续 eval 的谓词函数。

    并使用上次评估运行的eval_results 调用。对于提前停止,使用一个自定义函数,该函数将当前最佳结果和计数器保持状态,并在达到提前停止条件时返回False

    添加注意事项:此方法将使用 tensorflow 1.7 已弃用的方法(从该版本开始,所有 tf​​.contrib.learn 均已弃用:https://www.tensorflow.org/api_docs/python/tf/contrib/learn

    【讨论】:

      【解决方案4】:

      是的,有tf.train.StopAtStepHook

      此挂钩请求在执行了多个步骤或到达最后一步后停止。只能指定两个选项之一。

      您还可以扩展它并根据步骤结果实施自己的停止策略。

      class MyHook(session_run_hook.SessionRunHook):
        ...
        def after_run(self, run_context, run_values):
          if condition:
            run_context.request_stop()
      

      【讨论】:

      • tf.train.StopAtStepHook 似乎没有提前停止?但是,是的,我想我可以自己做一个评估验证集的钩子,我只是希望它从 TensorFlow 1.4 开始内置。谢谢!
      • @CarlThomé 我明白你的意思。你是对的,tensorflow 现在只捆绑了琐碎的会话钩子,并建议使用自己的钩子插入复杂的决策。
      • 什么变量可以帮助我在每个步骤中捕获 after_run 函数中的损失以实现 ealry 停止?
      猜你喜欢
      • 2020-10-12
      • 2020-01-25
      • 1970-01-01
      • 2022-01-17
      • 2020-11-16
      • 1970-01-01
      • 2020-02-06
      • 2020-05-19
      • 2017-08-17
      相关资源
      最近更新 更多