【发布时间】:2016-08-29 16:02:53
【问题描述】:
我知道stack和github等上有无数关于如何在Tensorflow中恢复训练好的模型的问题。我读过其中的大部分(1,2,3)。
我的问题与 3 几乎完全相同,但是如果可能的话,我希望以不同的方式解决它,因为我的训练和测试需要在从 shell 调用的单独脚本中,我不想添加确切的我用来在测试脚本中定义图形的同一行,所以我不能使用 tensorflow FLAGS 和其他基于手动重新运行图形的答案。
我也不想 sess.run 每个变量并手动手动映射它们,因为我的图表很大(使用 import_graph_def 和参数 input_map)。
所以我运行一些图表并在特定脚本中对其进行训练。例如(但没有训练部分)
#Script 1
import tensorflow as tf
import cPickle as pickle
x=tf.Variable(42)
saver=tf.train.Saver()
sess=tf.Session()
#Saving the graph
graph_def=sess.graph_def
with open('graph.pkl','wb') as output:
pickle.dump(graph_def,output,HIGHEST_PROTOCOL)
#Training the model
sess.run(tf.initialize_all_variables())
#Saving the variables
saver.save(sess,"pretrained_model.ckpt")
我现在保存了图表和变量,因此即使我的图表中有额外的训练节点,我也应该能够从另一个脚本运行我的测试模型。
#Script 2
import tensorflow as tf
import cPickle as pickle
sess=tf.Session()
with open('graph.pkl','rb') as input:
graph_def=pickle.load(input)
tf.import_graph_def(graph_def,name='persisted')
那么显然我想使用保护程序恢复变量,但我遇到了与 3 相同的问题,因为没有找到要保存的变量甚至创建保护程序。所以我不能写:
saver=tf.train.Saver()
saver.restore(sess,"pretrained_model.ckpt")
有没有办法绕过这些限制?我认为通过导入图形可以恢复每个节点中未初始化的变量,但似乎不是。我真的需要像大多数给出的答案一样重新运行它吗?
【问题讨论】:
标签: python tensorflow restore