【问题标题】:Tenorflow load model errorTensorFlow 负载模型错误
【发布时间】:2017-04-04 10:02:29
【问题描述】:

我正在尝试加载以前保存的 TENSOFLOW 模型(图形和变量)。

这是我在训练期间导出模型的方式

tf.global_variables_initializer().run()
y = tf.matmul(x, W) + b

cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

for batch_index in range(batch_size):
    batch_xs, batch_ys = sample_dataframe(train_df, N=batch_size)
    #print(batch_xs.shape)
    #print(batch_ys.shape)
    sess.run(train_step, feed_dict = {x: batch_xs, y_:batch_ys})

    if batch_index % 100 == 0:
        print("Batch "+str(batch_index))
        correct_predictions = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))
        print("Accuracy: "+str(sess.run(accuracy,
                                                  feed_dict = {x: batch_xs, y_: batch_ys})))
        #print("Predictions "+str(y))
        #print("Training accuracy: %.1f%%" %accuracy())
    if batch_index + 1 == batch_size:
        #Save the trained model
        print("Exporting trained model")
        builder = saved_model_builder.SavedModelBuilder(EXPORT_DIR)
        builder.add_meta_graph_and_variables(sess, ['simple-MNIST'])
        builder.save(as_text=True)

请忽略模型是如何定义的(这只是一个玩具示例),只检查调用 save 方法的最后几行。一切顺利,模型正确保存在 FS 中。

当我尝试加载导出的模型时,我总是收到以下错误:

TypeError: 无法将 MetaGraphDef 转换为张量或操作。

这是我加载模型的方式:

with tf.Session() as sess:
  print(tf.saved_model.loader.maybe_saved_model_directory(export_dir))
  saved_model = tf.saved_model.loader.load(sess, ['simple-MNIST'], export_dir)

  sess.run(saved_model)

知道如何解决吗?似乎模型以错误的格式导出,但我不知道如何更改它。

这是一个用于加载模型并对其评分的简单脚本。

with tf.device("/cpu:0"):
  x = tf.placeholder(tf.float32, shape =(batch_size, 784))
  W = tf.Variable(tf.truncated_normal(shape=(784, 10), stddev=0.1))
  b = tf.Variable(tf.zeros([10]))
  y_ = tf.placeholder(tf.float32, shape=(batch_size, 10))

with tf.Session() as sess:
  tf.global_variables_initializer().run()

  print(tf.saved_model.loader.maybe_saved_model_directory(export_dir))
  saved_model = tf.saved_model.loader.load(sess, ['simple-MNIST'], export_dir)

  batch_xs, batch_ys = sample_dataframe(train_df, N=batch_size)
  y = tf.matmul(x, W) + b
  correct_predictions = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))    

 print("Test Accuracy: "+ str(sess.run(accuracy, feed_dict = {x: batch_xs, y_: batch_ys})))

在全新的 PYTHON 上下文中运行此脚本,会以非常低的准确度对模型进行评分(似乎加载模型方法没有正确设置图形变量)

谢谢!

【问题讨论】:

    标签: python tensorflow tensorflow-serving


    【解决方案1】:

    我认为问题在于您不能将saved_model 传递给sess.run。来自saved_model.loader.load的文档:

    返回: 在提供的会话中加载的MetaGraphDef 协议缓冲区。这 可用于进一步提取signature-defs、collection-defs等。

    那么,当saved_modelMetaGraphDef 时,您对sess.run(saved_model) 有什么期望?如果我正确理解了load 的机制,那么图表以及相关变量将在您传递给load(..) 的会话中恢复,因此您的模型在load(..) 完成后就可以使用了。因此,您应该能够像往常一样通过(默认)图访问变量、操作和张量,无需进一步处理返回的 MetaGraphDef 对象。

    以下是有关MetaGraphDef 是什么的更多信息:What is the TensorFlow checkpoint meta file?。由此应该清楚,将其与sess.run() 一起使用是没有意义的。

    编辑

    跟进您的编辑:函数tf.saved_model.loader.load 在内部调用tf.import_meta_graph,后跟saver.restore,即它恢复图形图形中存在的变量的值。因此,您不必在添加的代码 sn-p 的开头自己重新定义变量。事实上,它可能会导致未定义的行为,因为某些节点可能在默认图中存在两次。查看此 stackoverflow 帖子以获取更多信息:Restoring Tensorflow model and viewing variable value。所以我猜这里发生的事情是:推理步骤使用您手动创建的未经训练的变量W,而不是您通过saved_model.loader 加载的预训练变量,这就是您看到低准确度的原因。

    所以,我的猜测是,如果您在开头省略 xWby_ 的定义并从恢复的图形中检索它们,例如通过调用tf.get_default_graph().get_tensor_by_name('variable_name')) 应该可以正常工作。

    PS:如果您正在恢复模型,则无需运行初始化程序(尽管我认为它也没有伤害)。

    PPS:在您的脚本中,您正在“手动”计算准确度,但我认为该操作已经存在于模型中,因为它最有可能在训练期间也需要,不是吗?因此,无需再次手动计算准确度,您只需从图中获取相应的节点并使用它即可。

    【讨论】:

    • 感谢您的回答。我认为你是对的,但我仍在努力寻找解决方案。我尝试将模型加载到单独的笔记本中,但变量似乎未初始化(我也不得不重新定义它们以使代码正常工作)。如果我在相同的上下文(本例中为笔记本)中加载模型,模型会正确地对新条目进行评分,但当我尝试将其加载到全新的笔记本中时则不然。
    • 模型加载后你想对它做什么(继续训练,用它进行推理,检索一些特定的变量/张量......)?
    • 我想用它来推断新条目。当我保存模型时,我认为它是完整的。再次感谢您!
    • 您能否使用您所描述的错误的可重现示例来编辑您的帖子?
    • 一切都越来越清楚了!不幸的是,我仍然无法加载图中定义的期限。当我尝试加载张量 W (tf.get_default_graph().get_tensor_by_name('W')) 时,我收到以下错误:“名称'W'看起来像一个(无效的)操作名称,而不是张量。张量名称必须是形式为“:”。我想我必须在保存之前在图中映射张量,对吗?
    猜你喜欢
    • 1970-01-01
    • 2022-06-21
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2018-05-16
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多