【问题标题】:Can I get the tensorflow session from the estimator?我可以从估计器中获得 tensorflow 会话吗?
【发布时间】:2019-06-02 19:56:03
【问题描述】:

我正在使用 tf.estimator 的 LinearRegressor 并希望将我的学习率衰减(最初是指数衰减)更改为使用损失的衰减。但要做到这一点,我需要将评估损失传递给学习率衰减张量的一些占位符,并且在这一步中,我需要 tf.session。

我尝试tf.get_default_session() 获取估算器创建的会话,但此会话具有估算器使用的不同图表。


    def my_decay(learning_rate, global_step, decay_step, loss, decay_rate):
      # If loss is not reduced, than decay with decay_rate.

    loss = tf.placeholder(tf.float32)
    estimator = tf.estimator.LinearRegressor(
    feature_columns=feature_columns,
    optimizer==lambda: tf.train.FtrlOptimizer(
        learning_rate=my_decay(learning_rate=0.1,
        global_step=tf.get_global_step(), decay_step=10000,
        loss=loss, decay_rate=0.96)),
      config=sess_config
    )

    for _ in range(n_epoches):
      metrics = tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
      session.run(loss.assign(metrics['loss']))

使用上面的代码,我需要从估算器中获取session。 有没有办法得到这个?

提前谢谢你!

【问题讨论】:

  • 简短的回答是你不能。如果您绝对需要访问会话,请使用受监控的会话。但我很确定你可以在没有会话的情况下定义自定义损失衰减,只要你能更具体一点
  • @Sharky,谢谢你的回答。似乎一般不建议使用受监控的会话。但是是否有可能从估计器中提取损失并将其传递给my_decay
  • 我觉得你可以用estimator.get_variable_value,把名字传给它

标签: tensorflow tensorflow-estimator


【解决方案1】:

此类问题的预期解决方案是将tf.train.SessionRunHook 子类化并覆盖before_run 方法以返回合适的tf.train.SessionRunArgs。这将允许您在训练时提供值并将提取添加到 session.run 调用。您的类必须在调用之间携带对占位符和 loss 状态的引用。

然后您只需实例化该类并将钩子添加到您的estimator.train 调用中的hooks 参数,或者在本例中为您的train_spec。如果您希望使用评估损失而不是训练损失,则可以通过向eval_spec 添加另一个挂钩来读取after_run 方法中的值。

【讨论】:

  • 没听清楚,能给点代码示例吗?
猜你喜欢
  • 2018-11-05
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2012-08-25
  • 1970-01-01
  • 2020-01-23
相关资源
最近更新 更多