【问题标题】:Replicating models in Keras and Tensorflow for a multi-threaded setting在 Keras 和 Tensorflow 中复制模型以实现多线程设置
【发布时间】:2016-10-20 12:22:00
【问题描述】:

我正在尝试在 Keras 和 TensorFlow 中实现 actor-critic 的异步版本。我将 Keras 用作构建网络层的前端(我直接使用 tensorflow 更新参数)。我有一个 global_model 和一个主要的 tensorflow 会话。但在每个线程中,我创建了一个local_model,它从global_model 复制参数。我的代码看起来像这样

def main(args):
    config=tf.ConfigProto(log_device_placement=False,allow_soft_placement=True)
    sess = tf.Session(config=config)
    K.set_session(sess) # K is keras backend
    global_model = ConvNetA3C(84,84,4,num_actions=3)

    threads = [threading.Thread(target=a3c_thread, args=(i, sess, global_model)) for i in range(NUM_THREADS)]

    for t in threads:
        t.start()

def a3c_thread(i, sess, global_model):
    K.set_session(sess) # registering a session for each thread (don't know if it matters)
    local_model = ConvNetA3C(84,84,4,num_actions=3)
    sync = local_model.get_from(global_model) # I get the error here

    #in the get_from function I do tf.assign(dest.params[i], src.params[i])

我收到来自 Keras 的用户警告

UserWarning:默认的 TensorFlow 图不是关联的图 使用当前在 Keras 注册的 TensorFlow 会话,并且作为 这样的 Keras 无法自动初始化变量。你 应该考虑通过 Keras 注册正确的会话 K.set_session(sess)

tf.assign 操作上出现 tensorflow 错误,说明操作必须在同一个图上。

ValueError: Tensor("conv1_W:0", shape=(8, 8, 4, 16), dtype=float32_ref, device=/device:CPU:0) 必须来自同一个图 as Tensor("conv1_W:0", shape=(8, 8, 4, 16), dtype=float32_ref)

我不太确定出了什么问题。

谢谢

【问题讨论】:

    标签: python multithreading tensorflow keras


    【解决方案1】:

    错误来自 Keras,因为 tf.get_default_graph() is sess.graph 正在返回 False。从 TF 文档中,我看到 tf.get_default_graph() 正在返回当前线程的默认图。当我开始一个新线程并创建一个图表时,它被构建为一个特定于该线程的单独图表。我可以通过执行以下操作来解决此问题,

    with sess.graph.as_default():
       local_model = ConvNetA3C(84,84,4,3)
    

    【讨论】:

      猜你喜欢
      • 2017-07-08
      • 1970-01-01
      • 1970-01-01
      • 2017-02-23
      • 2021-12-26
      • 2021-11-03
      • 1970-01-01
      • 1970-01-01
      • 2019-07-02
      相关资源
      最近更新 更多