Question title: How to copy parameters from global model to thread-specific model
Posted: 2017-01-05 01:44:45
Question description:

Context:

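I'm new to TensorFlow, and I'm trying to implement some of the algorithms in this paper. These algorithms occasionally need to copy parameters from a global shared model to a local, thread-specific model.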
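Without TensorFlow, the copy step described above amounts to giving each worker thread its own copy of the shared values rather than a reference to them. The following is a minimal, stdlib-only sketch that assumes the parameters are plain Python lists. SharedParams, snapshot, update, and worker are hypothetical names standing in for the global model, the copy operation, and a worker thread. They do not come from TensorFlow or from the paper.

```python
import threading


class SharedParams:
    """Hypothetical stand-in for the global shared model's parameters."""

    def __init__(self, values):
        self._values = list(values)
        self._lock = threading.Lock()

    def snapshot(self):
        # Return a fresh copy taken under the lock, so the caller never
        # aliases (or half-reads) the shared list.
        with self._lock:
            return list(self._values)

    def update(self, new_values):
        # Replace the global parameters; copies handed out earlier are unaffected.
        with self._lock:
            self._values = list(new_values)


def worker(shared, results, index):
    # Thread-specific model: starts as an independent copy of the global parameters.
    local = shared.snapshot()
    local[0] += index  # local training would change only this copy
    results[index] = local


shared = SharedParams([1, 1])
results = [None] * 5
threads = [threading.Thread(target=worker, args=(shared, results, i)) for i in range(5)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(results)  # [[1, 1], [2, 1], [3, 1], [4, 1], [5, 1]]

shared.update([7, 7])
print(results[0])  # [1, 1]: earlier local copies keep their values
```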

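Question:

What is the best/correct way to do this? Below is a dummy example of how I'm currently doing it, followed by the error I get. Can someone explain why the error occurs?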

import tensorflow as tf
import threading

class ExampleModel(object):
  def __init__(self, graph):
    with graph.as_default():
      self.w = tf.Variable(tf.constant(1, shape=[1,2]))

sess = tf.InteractiveSession()
graph = tf.get_default_graph()
global_network = ExampleModel(graph)
sess.run(tf.initialize_all_variables())

def example(i):
  global global_network, graph
  local_network = ExampleModel(graph)
  sess.run(local_network.w.assign(global_network.w))

threads = []
for i in range(5):
  t = threading.Thread(target=example, args=(i,))
  threads.append(t)

for t in threads:
  t.start()

Error:

Exception in thread Thread-3:
Traceback (most recent call last):
  File "/Users/kennyhsu5/anaconda/lib/python2.7/threading.py", line 801, in __bootstrap_inner
    self.run()
  File "/Users/kennyhsu5/anaconda/lib/python2.7/threading.py", line 754, in run
    self.__target(*self.__args, **self.__kwargs)
  File "tmp.py", line 16, in example
    local_network = ExampleModel(graph)
  File "tmp.py", line 7, in __init__
    self.w = tf.Variable(tf.constant(1, shape=[1,2]))
  File "/Users/kennyhsu5/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/variables.py", line 211, in __init__
    dtype=dtype)
  File "/Users/kennyhsu5/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/variables.py", line 319, in _init_from_args
    self._snapshot = array_ops.identity(self._variable, name="read")
  File "/Users/kennyhsu5/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2976, in __exit__
    self._graph._pop_control_dependencies_controller(self)
  File "/Users/kennyhsu5/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2996, in _pop_control_dependencies_controller
    assert self._control_dependencies_stack[-1] is controller
AssertionError

Question discussion:

    Tags: python tensorflow


    Solution 1:

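    From the documentation for TensorFlow's tf.Graph class:

    Important note: This class is not thread-safe for graph construction. All operations should be created from a single thread, or external synchronization must be provided. Unless otherwise specified, all methods are not thread-safe.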

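    Both the self.w = ... declaration and the local_network.w.assign(...) call add new ops to the shared graph from worker threads. That concurrent graph construction is what causes the error: the failed assertion in the traceback means the graph's internal control-dependencies stack was left in an unexpected state.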

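    I know this largely defeats the purpose of your multithreaded implementation. Still, you can fix the code above by moving these declarations into the main thread and then using the threads only to run the operations you've defined. For example: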

    import tensorflow as tf
    import threading
    
    class ExampleModel(object):
      def __init__(self, graph):
        with graph.as_default():
          self.w = tf.Variable(tf.constant(1, shape=[1,2]))
    
    sess = tf.InteractiveSession()
    graph = tf.get_default_graph()
    global_network = ExampleModel(graph)
    sess.run(tf.global_variables_initializer())
    
    def example(sess, assign_w):
      sess.run(assign_w)
    
    threads = []
    for i in range(5):
      local_network = ExampleModel(graph)
      assign_w = local_network.w.assign(global_network.w)
      t = threading.Thread(target=example, args=(sess, assign_w))
      threads.append(t)
    
    for t in threads:
      t.start()
    for t in threads:
      t.join()
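The same build-then-run split can be shown without TensorFlow. The following is a minimal, stdlib-only sketch in which dicts play the role of variables and a zero-argument function plays the role of a pre-built assign op. make_assign, global_w, local_ws, and assign_ops are hypothetical names. The point is only that every object is created in the main thread, and the worker threads merely execute what they receive through args.

```python
import threading

# Hypothetical stand-ins: dicts act as variables, closures act as assign ops.
global_w = {"value": [1, 1]}


def make_assign(target, source):
    # Built once in the main thread; running it only copies values.
    def assign():
        target["value"] = list(source["value"])
    return assign


# Construction phase: main thread only.
local_ws = []
assign_ops = []
for i in range(5):
    local_w = {"value": None}  # like a not-yet-initialized local variable
    local_ws.append(local_w)
    assign_ops.append(make_assign(local_w, global_w))


def example(run_op):
    # Execution phase: worker threads only run what was already built.
    run_op()


threads = [threading.Thread(target=example, args=(op,)) for op in assign_ops]
for t in threads:
    t.start()
for t in threads:
    t.join()

print([w["value"] for w in local_ws])  # [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1]]
```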
    

    I'd also suggest passing variables to the threads through the args parameter instead of using global.

    Finally, consider using global_variables_initializer instead of the deprecated initialize_all_variables.
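The quoted documentation also allows a second route besides a single construction thread: external synchronization. The following is a minimal, stdlib-only sketch of that idea, assuming that every thread that adds ops first takes the same lock. ToyGraph, add_op, build_local_model, and graph_lock are hypothetical names. ToyGraph only imitates the kind of stack check visible in the traceback and is not TensorFlow's implementation. Applied to the question's code, the same idea would in principle mean holding one shared lock around the ExampleModel(graph) and assign(...) calls. Graph construction is still serialized, so only running the ops could overlap.

```python
import threading


class ToyGraph:
    """Hypothetical, deliberately non-thread-safe builder that keeps a stack
    of construction contexts (loosely imitating a graph; not TensorFlow code)."""

    def __init__(self):
        self.ops = []
        self._context_stack = []

    def add_op(self, name):
        controller = object()
        self._context_stack.append(controller)
        self.ops.append(name)
        # Same kind of check as in the traceback: the context being popped
        # must be the one this call pushed.
        assert self._context_stack[-1] is controller
        self._context_stack.pop()


graph = ToyGraph()
graph_lock = threading.Lock()  # the "external synchronization"


def build_local_model(i):
    # Only one thread at a time may add ops to the shared graph.
    with graph_lock:
        graph.add_op("local_w_%d" % i)
        graph.add_op("assign_w_%d" % i)


threads = [threading.Thread(target=build_local_model, args=(i,)) for i in range(5)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(len(graph.ops))  # 10
```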

    Discussion:

    • Hi, I'm having a similar problem with multithreading here. Is there any documentation or example that could guide me? Thanks!!