【发布时间】:2018-07-04 15:19:49
【问题描述】:
我偶然发现了一个我无法解决的错误。我正在尝试做的是以下事情:
我想训练一个(虚拟)模型,在每次迭代时将 a 添加到 b。完成后,我想将变量保存为检查点。我第一次运行它时,它会从头开始构建模型。每次我重新运行模型时,它应该从最后一个检查点开始并再次添加。因此,我从 .meta 文件加载完整的图表。全局 step 变量用于跟踪我训练的总步数。
import tensorflow as tf
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
# List ALL tensors.
print_tensors_in_checkpoint_file(tf.train.latest_checkpoint('./'), all_tensors=True, tensor_name='')
tf.reset_default_graph()
global_step = tf.get_variable('global_step', shape=[], dtype=tf.int32, initializer=tf.constant_initializer(0), trainable=False)
def model(a, b):
b = tf.assign_add(b, a)
return b
with tf.Session() as sess:
ckpt = tf.train.latest_checkpoint('./')
if ckpt:
saver = tf.train.import_meta_graph('./my_test_model-1.meta')
saver.restore(sess, ckpt)
else:
a = tf.Variable(3.0, name='a')
b = tf.Variable(5.0, name='b')
b = model(a, b)
### before EDIT
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
###
### after EDIT
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
###
for step in range(5):
global_step.assign_add(1).eval()
print(global_step.eval())
print(b.eval())
saver.save(sess, './my_test_model', global_step=global_step)
脚本第一次运行正常,输出如下:
1 # step
8.0 # value of b
2
11.0
3
14.0
4
17.0
5
20.0
我第二次运行程序时,我得到这个输出,然后是一个错误:
tensor_name: a
3.0
tensor_name: b
20.0
tensor_name: global_step
0
tensor_name: global_step_1
5
INFO:tensorflow:Restoring parameters from ./my_test_model-5
Traceback(最近一次调用最后一次):... FailedPreconditionError: 尝试使用未初始化的值 global_step [[Node: AssignAdd_2 = AssignAdd[T=DT_INT32, use_locking=false, _device="/job:localhost/replica:0/task:0/device:CPU:0"](global_step, AssignAdd_2/value)]] ...
第一次,很明显它不会在我为所有变量运行初始化程序时引发错误。但我认为恢复模型算作某种初始化?我真的无法理解这个概念。我也尝试在定义a和b之后定义global_step,但这导致第一次加载时出现另一个错误:
ValueError:无法使用默认会话来评估张量: 张量的图与会话的图不同。通过显式 会话到
eval(session=sess)。 错误是指递增global_step(global_step.assign_add(1).eval()) 的行。
我做错了什么?我应该在哪里定义变量?
感谢您对此问题的任何帮助!感谢您阅读本文。
编辑: 感谢@Diana,前提条件错误消失了。不幸的是,发生了另一个错误。每当运行加载检查点的脚本时,都会引发名称错误:
NameError:名称“global_step”未定义。
这也发生在变量“b”上。恢复检查点时不应该是加载的名称吗?当我检查检查点文件中的张量时,张量似乎具有正确的名称和值。
【问题讨论】:
标签: tensorflow loading