【问题标题】:Tensorflow Estimator InvalidArgumentErrorTensorFlow Estimator InvalidArgumentError
【发布时间】:2019-01-09 04:21:12
【问题描述】:

我正在尝试找到一种方法来查找和修复我的 TF 代码中的错误。下面代码的sn-p成功训练模型,但调用最后一行(model.evaluate(input_fn))时产生如下错误:

InvalidArgumentError: Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:
/var/folders/kx/y9syv3f91b1c6tzt3fgzc7jm0000gn/T/tmp_r6c94ni/model.ckpt-667.data-00000-of-00001; Invalid argument
     [[node save/RestoreV2 (defined at ../text_to_topic/train/nn/nn_tf.py:266)  = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

Caused by op 'save/RestoreV2', defined at:
  File "/Users/foo/miniconda3/envs/tt/lib/python3.6/runpy.py", line 193, in _run_module_as_main

完全相同的代码在用于 MNIST 数据集时有效,但在用于我自己的数据集时无效。我该如何调试这个或可能是什么原因。从检查点恢复模型后,图表似乎不匹配,但我不确定如何继续解决这个问题。我试过 TF 版本 1.11 和 1.13

model = tf.estimator.Estimator(get_nn_model_fn(num_classes))

# Define the input function for training
input_fn = tf.estimator.inputs.numpy_input_fn(
    x=X_train, y=y_train,
    batch_size=batch_size,
    num_epochs=None, shuffle=True)

# Train the Model
model.train(input_fn, steps=num_steps)

# Evaluate the Model
# Define the input function for evaluating
input_fn = tf.estimator.inputs.numpy_input_fn(
    x=X_test, y=y_test,
    batch_size=batch_size, shuffle=False)

# Use the Estimator 'evaluate' method
e = model.evaluate(input_fn) 

【问题讨论】:

    标签: tensorflow tensorflow-estimator


    【解决方案1】:

    当您修改图表的某些部分时,通常会发生此错误,例如更改隐藏层的大小或删除/添加一些层,估计器会尝试加载先前的检查点。您有两种方法可以解决此问题:

    1)更改模型目录(model_dir):

    config = tf.estimator.RunConfig(model_dir='./NEW_PATH/', ) # new path
    model_estimator = tf.estimator.Estimator(model_fn=model_fn, config=config)
    

    2) 删除模型目录(model_dir)中之前保存的检查点。


    你确定图表没有被触及?

    请确保,新数据集的Data-type 与以前相同。如果您之前为输入加载浮点数,在新数据集中它们也应该是浮点数。

    【讨论】:

    • 感谢您的回复。我很确定图表没有受到影响 - 您在上面看到的代码就是我所拥有的(它在最后一行失败)。 MNIST 数据集不会出现此问题,而只会出现在另一个数据集上。它确实可能与数据类型有关,尽管到目前为止我无法准确确定问题发生的位置。我会再调查一下。
    • @foobar 你能提到“另一个数据集”吗?因为我用fashion_mnist测试过,没问题。
    • 这是一个由 sklearn.feature_extraction.text.TfidfVectorizer => 生成的 tf-idf 特征数据集(' 类型的稀疏矩阵,在 Compressed 中存储了 10352907 个元素稀疏行格式>)。然后我将它转换为一个数组(根据 tf.estimator.inputs.numpy_input_fn 的要求),就像 -data.toarray() 一样,所以 data.shape 是 (14953, 972981)。这是相当多的功能。
    • 我发现了一些功能的限制 - 低于 ~16k 的任何东西都可以,高于任何东西都会产生这个错误。当然,我的功能太多了,但有一个限制和它触发的错误类型仍然有点奇怪。
    • 我不知道那个选项,谢谢!不幸的是,它并没有解决问题(尝试了一些高达 100k 和 17k 功能的选项)。现在我只是使用较少的功能。感谢您的帮助!
    猜你喜欢
    • 2017-12-02
    • 1970-01-01
    • 2018-09-25
    • 1970-01-01
    • 2019-02-03
    • 2018-06-06
    • 2017-08-02
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多