【发布时间】:2017-04-27 16:56:42
【问题描述】:
我正在尝试从训练中保存最新的 LSTM 状态,以便稍后在预测阶段重用。我遇到的问题是,在 TF LSTM 模型中,状态通过占位符和 numpy 数组的组合从一次训练迭代传递到下一次迭代——会话时默认情况下,这两种迭代似乎都不包含在图表中已保存。
为了解决这个问题,我创建了一个专用的 TF 变量来保存最新版本的状态,以便将其添加到会话图中,如下所示:
# latest State from last training iteration:
_, y, ostate, smm = sess.run([train_step, Y, H, summaries], feed_dict=feed_dict)
# now add to TF variable:
savedState = tf.Variable(ostate, dtype=tf.float32, name='savedState')
tf.variables_initializer([savedState]).run()
save_path = saver.save(sess, pathModel + '/my_model.ckpt')
这似乎很好地将savedState 变量添加到保存的会话图中,并且可以在以后使用会话的其余部分轻松恢复。
但问题是,我设法在稍后恢复的会话中实际使用该变量的唯一方法是,如果我在恢复会话后初始化会话中的所有变量(这似乎重置了所有经过训练的变量,包括权重/偏差/等等!)。如果我首先初始化变量,然后恢复会话(这在保留经过训练的变量方面效果很好),那么我会收到一个错误,我正在尝试访问一个未初始化的变量。
我知道有一种方法可以初始化特定的单个变量(我在最初保存它时正在使用它)但问题是当我们恢复它们时,我们按名称将它们称为字符串,我们不只是传递变量本身?!
# This produces an error 'trying to use an uninitialized varialbe
gInit = tf.global_variables_initializer().run()
new_saver = tf.train.import_meta_graph(pathModel + 'my_model.ckpt.meta')
new_saver.restore(sess, pathModel + 'my_model.ckpt')
fullState = sess.run('savedState:0')
完成这项工作的正确方法是什么?作为一种解决方法,我目前将状态保存为 CSV,就像一个 numpy 数组一样,然后以同样的方式恢复它。它工作正常,但显然不是最干净的解决方案,因为保存/恢复 TF 会话的所有其他方面都可以完美运行。
任何建议表示赞赏!
**编辑: 这是运行良好的代码,如下面接受的答案中所述:
# make sure to define the State variable before the Saver variable:
savedState = tf.get_variable('savedState', shape=[BATCHSIZE, CELL_SIZE * LAYERS])
saver = tf.train.Saver(max_to_keep=1)
# last training iteration:
_, y, ostate, smm = sess.run([train_step, Y, H, summaries], feed_dict=feed_dict)
# now save the State and the whole model:
assignOp = tf.assign(savedState, ostate)
sess.run(assignOp)
save_path = saver.save(sess, pathModel + '/my_model.ckpt')
# later on, in some other program, recover the model and the State:
# make sure to initialize all variables BEFORE recovering the model!
gInit = tf.global_variables_initializer().run()
local_saver = tf.train.import_meta_graph(pathModel + 'my_model.ckpt.meta')
local_saver.restore(sess, pathModel + 'my_model.ckpt')
# recover the state from training and get its last dimension
fullState = sess.run('savedState:0')
h = fullState[-1]
h = np.reshape(h, [1, -1])
我还没有测试过这种方法是否会无意中初始化保存的 Session 中的任何其他变量,但不明白为什么会这样,因为我们只运行特定的变量。
【问题讨论】:
标签: tensorflow lstm