【发布时间】:2017-07-26 05:24:56
【问题描述】:
我在 PyCharm 中使用 Tensorflow 1.0 和 python 3.5
执行this 代码后,我在每 500 次迭代时保存模型(索引、元和 ckpt 文件)。现在加载模型我们需要指向哪个文件?
我写了下面的代码来加载 ckpt(weights) 文件(没有对上面的 github 代码进行任何更改)
w1 = tf.Variable(tf.zeros([5, 5, 1, 32]), name="conv1/W")
b1 = tf.Variable(tf.zeros(shape=[32]), name="conv1/B")
w2 = tf.Variable(tf.zeros([5, 5, 32, 64]), name="conv2/W")
b2 = tf.Variable(tf.zeros(shape=[64]), name="conv2/B")
w3 = tf.Variable(tf.zeros([3136, 1024]), name="fc1/W")
b3 = tf.Variable(tf.zeros(shape=[1024]), name="fc1/B")
w4 = tf.Variable(tf.zeros([1024, 10]), name="fc2/W")
b4 = tf.Variable(tf.zeros(shape=[10]), name="fc2/B")
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, "D:/tmp/mnist_tutorial/model.ckpt")
出现以下错误
W c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\framework\op_kernel.cc:975] 未找到:TensorSliceReader 构造函数不成功:找不到 D:/ 的任何匹配文件tmp/mnist_tutorial/model.ckpt
回溯(最近一次通话最后一次):
_do_call 中的文件“C:\Users\Admin\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\client\session.py”,第 1021 行 返回 fn(*args)
_run_fn 中的文件“C:\Users\Admin\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\client\session.py”,第 1003 行 状态,运行元数据)
退出中的文件“C:\Users\Admin\AppData\Local\Programs\Python\Python35\lib\contextlib.py”,第 66 行 下一个(self.gen)
文件“C:\Users\Admin\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\framework\errors_impl.py”,第 469 行,在 raise_exception_on_not_ok_statuspywrap_tensorflow.TF_GetCode(status)) tensorflow.python.framework.errors_impl.NotFoundError:不成功的 TensorSliceReader 构造函数:未能找到 D:/tmp/mnist_tutorial/model.ckpt 的任何匹配文件
[[节点:save/RestoreV2_1 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save/Const_0, save/RestoreV2_1/tensor_names, save/恢复V2_1/shape_and_slices)]]
在训练开始之前(在卷积层的函数定义中),我们可以通过以下方式打印各个层的权重:
w = tf.Variable(tf.truncated_normal([5, 5, 1, 64], stddev=0.1), name="W")
b = tf.Variable(tf.constant(0.1, shape=[64]), name="B")
init = tf.global_variables_initializer()
with tf.Session()as sess:
sess.run(init)
print("weight type is ", w)
print('bias type is', b)
print("random generated weights are: ")
x = tf.Print('conv/W:0', [w],summarize=1600)
sess.run(x)
print("Generated Biases are: ")
y = tf.Print(b, [b],summarize=64)
sess.run(y)
如果有很多卷积层和全连接层,如何从*.ckpt 文件中加载和打印任何特定层的权重和偏差,因为上述方法不起作用
更新:对代码进行了更改并更新了错误消息
【问题讨论】:
-
不重命名 ckpt 文件。假设文件是“model.ckpt-2000.index”和“model.cpkt-2000.data-00000-of-00001”。您应该使用保护程序恢复的路径是“model.ckpt-2000”,尽管实际上没有这样的文件。对于读取变量,NewCheckpointReader 是一个不错的选择。详情可以查看代码。
标签: python tensorflow neural-network deep-learning conv-neural-network