【问题标题】:How to restore my loss from a saved meta graph?如何从保存的元图中恢复我的损失?
【发布时间】:2017-06-12 23:15:36
【问题描述】:

我已经建立了一个运行良好的简单 tensorflow 模型。 在训练时,我会保存 meta_graph 以及不同步骤的一些参数。

之后(在新脚本中)我想恢复保存的 meta_graph 并恢复变量和操作。

一切正常,但只有

with tf.name_scope('MSE'):
    error = tf.losses.mean_squared_error(Y, yhat, scope="error")

不会恢复。用下面这行

mse_error = graph.get_tensor_by_name("MSE/error:0")

"名称 'MSE/error:0' 指的是一个不存在的张量。 图中不存在操作“MSE/错误”。”

出现此错误消息。

由于我对其他变量和操作进行了完全相同的过程,这些变量和操作恢复时没有任何错误,我不知道如何处理。唯一不同的是 tf.losses.mean_squared_error 函数中只有 scope 属性,没有 name 属性。

那么如何用范围恢复损失操作呢?

这里是我如何保存和加载模型的代码。

保存:

# define network ...
saver = tf.train.Saver(max_to_keep=10)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

for i in range(NUM_EPOCHS):
    # do training ..., save model all 1000 optimization steps
    if (i + 1) % 1000 == 0:
        saver.save(sess, "L:/model/mlp_model", global_step=(i+1))

恢复:

# start a session
sess=tf.Session()
# load meta graph
saver = tf.train.import_meta_graph('L:\\model\\mlp_model-1000.meta')
# restore weights
saver.restore(sess, tf.train.latest_checkpoint('L:\\model\\'))

# access network nodes
graph = tf.get_default_graph()
X = graph.get_tensor_by_name("Input/X:0")
Y = graph.get_tensor_by_name("Input/Y:0")

# restore output-generating operation used for prediction
yhat_op = graph.get_tensor_by_name("OutputLayer/yhat:0")
mse_error = graph.get_tensor_by_name("MSE/error:0") # this one doesn't work

【问题讨论】:

  • 感谢您的回复。我添加了两个脚本的草图。

标签: tensorflow scope restore


【解决方案1】:

为了让您的训练退后一步,documentation 建议您在将其保存之前将其添加到集合中,以便在恢复图表后能够指向它。

保存:

saver = tf.train.Saver(max_to_keep=10)
# put op in collection
tf.add_to_collection('train_op', train_op)
...

恢复:

saver = tf.train.import_meta_graph('L:\\model\\mlp_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('L:\\model\\'))
# recover op through collection
train_op = tf.get_collection('train_op')[0]

为什么您尝试按名称恢复张量失败了?

您确实可以通过名称获取张量——关键是您需要正确的名称。请注意,tf.losses.mean_squared_errorerror 参数是 范围 名称,而不是返回操作的名称。这可能会造成混淆,因为其他操作(例如 tf.nn.l2_loss)接受 name 参数。

最后,你的error操作的名字是MSE/error/value:0,你可以通过名字来获取它。

也就是说,直到将来您更新 tensorflow 时它再次中断。 tf.losses.mean_squared_error 不保证它的输出名称,所以它很可能会因为某种原因而改变。

我认为这就是使用集合的动机:无法保证您无法控制自己的运算符的名称。

或者,如果出于某种原因您真的想使用名称,您可以像这样重命名您的运算符:

with tf.name_scope('MSE'):
  error = tf.losses.mean_squared_error(Y, yhat, scope='error')
  # let me stick my own name on it
  error = tf.identity(error, 'my_error')

那么你就可以放心地依赖graph.get_tensor_by_name('MSE/my_error:0')了。

【讨论】:

  • 感谢您指出这一点。它现在正在工作。但是你知道为什么我不能从范围恢复吗?
  • 据我了解,“范围”一词指的是一个更大的层,它可以由一个或多个带有“名称”的组件组成。因此,要按名称识别张量,您需要“范围/层”
  • 当心,我发现将输入 x,y_ 恢复为 tf.get_collection 而不是 tf.get_tensor_by_name,有时会使 feed_dict 无法散列。在后台有一些非常难以理解的累积字典。
【解决方案2】:

tf.losses.mean_squared_error 是一个操作不是张量,你应该加载它 get_operation_by_name:

mse_error = graph.get_operation_by_name("MSE/error")

应该可以,注意不需要 ":0"

【讨论】:

    猜你喜欢
    • 2010-09-17
    • 2016-05-22
    • 1970-01-01
    • 1970-01-01
    • 2018-09-16
    • 2016-10-12
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多