【问题标题】:TF LSTM: Save State from training session for prediction session laterTF LSTM:从训练会话中保存状态以供以后预测会话
【发布时间】:2017-04-27 16:56:42
【问题描述】:

我正在尝试从训练中保存最新的 LSTM 状态,以便稍后在预测阶段重用。我遇到的问题是,在 TF LSTM 模型中,状态通过占位符和 numpy 数组的组合从一次训练迭代传递到下一次迭代——会话时默认情况下,这两种迭代似乎都不包含在图表中已保存。

为了解决这个问题,我创建了一个专用的 TF 变量来保存最新版本的状态,以便将其添加到会话图中,如下所示:

# latest State from last training iteration:
_, y, ostate, smm = sess.run([train_step, Y, H, summaries], feed_dict=feed_dict)
# now add to TF variable:
savedState = tf.Variable(ostate, dtype=tf.float32, name='savedState')
tf.variables_initializer([savedState]).run()
save_path = saver.save(sess, pathModel + '/my_model.ckpt')

这似乎很好地将savedState 变量添加到保存的会话图中,并且可以在以后使用会话的其余部分轻松恢复。

但问题是,我设法在稍后恢复的会话中实际使用该变量的唯一方法是,如果我在恢复会话后初始化会话中的所有变量(这似乎重置了所有经过训练的变量,包括权重/偏差/等等!)。如果我首先初始化变量,然后恢复会话(这在保留经过训练的变量方面效果很好),那么我会收到一个错误,我正在尝试访问一个未初始化的变量。

我知道有一种方法可以初始化特定的单个变量(我在最初保存它时正在使用它)但问题是当我们恢复它们时,我们按名称将它们称为字符串,我们不只是传递变量本身?!

# This produces an error 'trying to use an uninitialized varialbe
gInit = tf.global_variables_initializer().run()
new_saver = tf.train.import_meta_graph(pathModel + 'my_model.ckpt.meta')
new_saver.restore(sess, pathModel + 'my_model.ckpt')
fullState = sess.run('savedState:0')

完成这项工作的正确方法是什么?作为一种解决方法,我目前将状态保存为 CSV,就像一个 numpy 数组一样,然后以同样的方式恢复它。它工作正常,但显然不是最干净的解决方案,因为保存/恢复 TF 会话的所有其他方面都可以完美运行。

任何建议表示赞赏!

**编辑: 这是运行良好的代码,如下面接受的答案中所述:

# make sure to define the State variable before the Saver variable:
savedState = tf.get_variable('savedState', shape=[BATCHSIZE, CELL_SIZE * LAYERS])
saver = tf.train.Saver(max_to_keep=1)
# last training iteration:
_, y, ostate, smm = sess.run([train_step, Y, H, summaries], feed_dict=feed_dict)
# now save the State and the whole model:
assignOp = tf.assign(savedState, ostate)
sess.run(assignOp)
save_path = saver.save(sess, pathModel + '/my_model.ckpt')


# later on, in some other program, recover the model and the State:
# make sure to initialize all variables BEFORE recovering the model!
gInit = tf.global_variables_initializer().run()
local_saver = tf.train.import_meta_graph(pathModel + 'my_model.ckpt.meta')
local_saver.restore(sess, pathModel + 'my_model.ckpt')
# recover the state from training and get its last dimension
fullState = sess.run('savedState:0')
h = fullState[-1]
h = np.reshape(h, [1, -1])

我还没有测试过这种方法是否会无意中初始化保存的 Session 中的任何其他变量,但不明白为什么会这样,因为我们只运行特定的变量。

【问题讨论】:

    标签: tensorflow lstm


    【解决方案1】:

    问题是在构造Saver 之后创建新的tf.Variable 意味着Saver 不知道新变量。它仍然保存在元图中,但没有保存在检查点中:

    import tensorflow as tf
    with tf.Graph().as_default():
      var_a = tf.get_variable("a", shape=[])
      saver = tf.train.Saver()
      var_b = tf.get_variable("b", shape=[])
      print(saver._var_list) # [<tf.Variable 'a:0' shape=() dtype=float32_ref>]
      initializer = tf.global_variables_initializer()
      with tf.Session() as session:
        session.run([initializer])
        saver.save(session, "/tmp/model", global_step=0)
    with tf.Graph().as_default():
      new_saver = tf.train.import_meta_graph("/tmp/model-0.meta")
      print(saver._var_list) # [<tf.Variable 'a:0' shape=() dtype=float32_ref>]
      with tf.Session() as session:
        new_saver.restore(session, "/tmp/model-0") # Only var_a gets restored!
    

    我已经用Saver 知道的变量对上述问题的快速重现进行了注释。

    现在,解决方案相对容易。我建议在Saver 之前创建Variable,然后使用tf.assign 更新其值(确保运行 tf.assign 返回的操作)。分配的值将保存在检查点中并像其他变量一样恢复。

    None 被传递给它的var_list 构造函数参数时,Saver 可以更好地处理这种情况(即它可以自动获取新变量)。对此请随时 open a feature request on Github

    【讨论】:

    • 谢谢,只需在声明 Saver 之前声明变量就可以了。不确定这是否需要一个全新的功能请求。为了其他人,我用似乎有效的代码更新了我的问题..
    猜你喜欢
    • 2019-08-11
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2017-05-28
    • 1970-01-01
    • 2022-01-23
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多