【问题标题】:TensorFlow 'global_step' variable not getting updated for exponential decayTensorFlow“global_step”变量未针对指数衰减进行更新
【发布时间】:2018-01-09 16:45:16
【问题描述】:

我有一个网络,我正在使用学习率的指数衰减。为此,我正在跟踪一个 'global_step' TF 变量,该变量在处理的每个批次中都增加 1。然而,看起来实际上,它并没有真正得到更新。这是代码。

...
global_step = tf.Variable(0, trainable=False, name='global_step')
starter_learning_rate = 0.01
learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step, 1000, 0.50)

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    optm = tf.train.AdamOptimizer(learning_rate).minimize(cost, global_step=global_step)
init = tf.global_variables_initializer()


def train(file):
    global global_step
    for batch in batches:
        global_step += 1
        ...        
    return loss

...
global_step = 0
for epoch in EPOCHS:
    for f in files:
        loss = train(f)

函数内部和外部的 global_step 正在更新。但是我的学习率没有改变。当我将摘要附加到我的 TF global_step 变量时,我看到它保持在 0 不变。

这里有什么问题?

【问题讨论】:

    标签: python tensorflow neural-network


    【解决方案1】:

    其实我没看到你在哪里设置learning_rate变量,但这是如何使用它的方式:

    定义全局步进变量

    global_step = tf.Variable(0)
    

    定义不同params的学习率变化方式

    learning_rate = tf.train.exponential_decay(0.1, global_step, 500, 0.7, staircase=True)
    

    将它们传递给优化器

    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    

    【讨论】:

    • 我实际上是在以类似的方式定义学习率。我已经更新了我的问题以反映相同的情况。
    • 这可能是因为您创建了两次 global_step 变量:global_step = tf.Variable(0, trainable=False, name='global_step')global_step = 0
    【解决方案2】:

    这里有两个问题。

    1. 您不应该自己增加变量
    2. 您实际上并没有手动增加 global_step,即使它看起来好像是。

    问题 1

    根据tf.train.AdamOptimizer 的文档,调用minimize() 是调用compute_gradients()apply_gradients() 的简写。您实际返回到 optm 变量中的是:

    应用 [...] 渐变的 Operation。如果 global_step 不是 None,则该操作也会增加 global_step

    这意味着为global_step (tf.Variable) 存储的值将在您调用sess.run(optm) 时递增。

    问题 2

    在给定代码的第一行之后,您有一个名为global_step 的变量,它是一个tf.Variable 对象。重要的是,它没有数值。它只是对一个对象的引用,当您运行sess.run提供一个数值

    为了方便图的构建,Tensorflow 允许这样的操作:

    a = tf.constant(1)
    b = a + 2
    

    此时,变量b 将是一个新的张量对象。我们可以运行sess.run(b) 并获取实际值(当然是在初始化之后),但b 是一个对象,而不是一个值。当你运行 global_step += 1 时,它会创建一个新的张量对象,当你 sess.run 它时,它会进行一些计算并返回一个数字。

    所以,在global global_step,您仍然可以引用tf.Variable 张量,但在第一个循环之后,您的global_step 将引用: 一个张量,当通过sess.run 时,将为您提供将 1 添加到原始 tf.Variable 对象的结果。

    在第二个循环之后,您的 global_step 引用了一个张量,它会给您在原始 tf.Variable 对象加 1 的结果中加 1 的结果。

    在循环过程中,您正在添加操作并引用新结果,但从未真正更改为 tf.Variable 对象存储的值。这就是为什么当你运行sess.run(global_step) 时,你会得到你期望的数字,而实际的变量值永远不会改变。

    【讨论】:

      猜你喜欢
      • 2020-10-02
      • 2011-11-14
      • 1970-01-01
      • 2021-07-17
      • 1970-01-01
      • 2021-12-01
      • 2021-05-06
      相关资源
      最近更新 更多