【发布时间】:2016-03-10 19:03:58
【问题描述】:
我有一个基于 TensorFlow 的神经网络和一组变量。
训练函数是这样的:
def train(load = True, step)
"""
Defining the neural network is skipped here
"""
train_step = tf.train.AdamOptimizer(1e-4).minimize(mse)
# Saver
saver = tf.train.Saver()
if not load:
# Initalizing variables
sess.run(tf.initialize_all_variables())
else:
saver.restore(sess, 'Variables/map.ckpt')
print 'Model Restored!'
# Perform stochastic gradient descent
for i in xrange(step):
train_step.run(feed_dict = {x: train, y_: label})
# Save model
save_path = saver.save(sess, 'Variables/map.ckpt')
print 'Model saved in file: ', save_path
print 'Training Done!'
我是这样调用训练函数的:
# First train
train(False, 1)
# Following train
for i in xrange(10):
train(True, 10)
我进行这种培训是因为我需要向我的模型提供不同的数据集。但是,如果我以这种方式调用 train 函数,TensorFlow 会生成错误消息,指出它无法从文件中读取保存的模型。
经过一些实验,我发现出现这种情况是因为检查点保存速度很慢。在文件写入磁盘之前,下一个 train 函数将开始读取,从而产生错误。
我曾尝试使用 time.sleep() 函数在每次通话之间进行一些延迟,但没有奏效。
有人知道如何解决这种写/读错误吗?非常感谢!
【问题讨论】:
标签: python io tensorflow