【问题标题】:TensorFlow checkpoint save and readTensorFlow 检查点保存和读取
【发布时间】:2016-03-10 19:03:58
【问题描述】:

我有一个基于 TensorFlow 的神经网络和一组变量。

训练函数是这样的:

def train(load = True, step)
    """
    Defining the neural network is skipped here
    """

    train_step = tf.train.AdamOptimizer(1e-4).minimize(mse)
    # Saver
    saver = tf.train.Saver()

    if not load:
        # Initalizing variables
        sess.run(tf.initialize_all_variables())
    else:
        saver.restore(sess, 'Variables/map.ckpt')
        print 'Model Restored!'

    # Perform stochastic gradient descent
    for i in xrange(step):
        train_step.run(feed_dict = {x: train, y_: label})

    # Save model
    save_path = saver.save(sess, 'Variables/map.ckpt')
    print 'Model saved in file: ', save_path
    print 'Training Done!'

我是这样调用训练函数的:

# First train
train(False, 1)
# Following train
for i in xrange(10):
    train(True, 10)

我进行这种培训是因为我需要向我的模型提供不同的数据集。但是,如果我以这种方式调用 train 函数,TensorFlow 会生成错误消息,指出它无法从文件中读取保存的模型。

经过一些实验,我发现出现这种情况是因为检查点保存速度很慢。在文件写入磁盘之前,下一个 train 函数将开始读取,从而产生错误。

我曾尝试使用 time.sleep() 函数在每次通话之间进行一些延迟,但没有奏效。

有人知道如何解决这种写/读错误吗?非常感谢!

【问题讨论】:

    标签: python io tensorflow


    【解决方案1】:

    您的代码中有一个微妙的问题:每次调用 train() 函数时,都会将更多节点添加到同一个 TensorFlow 图中,用于所有模型变量和神经网络的其余部分。这意味着每次构造tf.train.Saver() 时,它都会包含之前调用train() 的所有变量。每次重新创建模型时,都会使用额外的 _N 后缀创建变量,以赋予它们唯一的名称:

    1. 使用变量 var_avar_b 构建的保护程序。
    2. 使用变量 var_avar_bvar_a_1var_b_1 构建的保护程序。
    3. 使用变量var_avar_bvar_a_1var_b_1var_a_2var_b_2 构建的保护程序。

    tf.train.Saver 的默认行为是将每个变量与相应操作的名称相关联。这意味着var_a_1 不会从var_a 初始化,因为它们最终具有不同的名称。

    解决方案是每次调用train() 时创建一个新图表。修复它的最简单方法是更改​​主程序,为每次调用 train() 创建一个新图形,如下所示:

    # First train
    with tf.Graph().as_default():
        train(False, 1)
    
    # Following train
    for i in xrange(10):
        with tf.Graph().as_default():
            train(True, 10)
    

    ...或者,等效地,您可以将with 块移动到train() 函数内。

    【讨论】:

    • 那么向图中添加节点的行为类似于 C++ 中的类/对象吗?每次 train() 函数完成时,图形对象都不会被破坏。如果我继续添加W1、b1等同名变量,就会切换到W1_1和b1_1,从而导致加载失败。我的理解对吗?这个问题是由于在训练过程结束时没有调用一些析构函数吗?谢谢!
    • 基本上,除非您明确地构造tf.Graph 并使用with 构造将其设置为默认值,否则所有节点都将添加到仅在进程结束时被销毁的全局图中。 (这并不理想,但它使其他一些用例更容易。)使用with 块可确保在块的末尾取消注册图形,这应该会给您所需的行为 - 并避免内存泄漏!
    猜你喜欢
    • 1970-01-01
    • 2019-11-15
    • 1970-01-01
    • 1970-01-01
    • 2021-10-24
    • 2018-11-17
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多