【发布时间】:2017-01-17 12:48:50
【问题描述】:
以下是来自 Tensorflow 网站的简单 mnist 教程(即单层 softmax),我尝试通过多线程训练步骤对其进行扩展:
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import threading
# Training loop executed in each thread
def training_func():
while True:
batch = mnist.train.next_batch(100)
global_step_val,_ = sess.run([global_step, train_step], feed_dict={x: batch[0], y_: batch[1]})
print("global step: %d" % global_step_val)
if global_step_val >= 4000:
break
# create session and graph
sess = tf.Session()
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
global_step = tf.Variable(0, name="global_step")
y = tf.matmul(x,W) + b
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y, y_))
inc = global_step.assign_add(1)
with tf.control_dependencies([inc]):
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# initialize graph and create mnist loader
sess.run(tf.global_variables_initializer())
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# create workers and execute threads
workers = []
for _ in range(8):
t = threading.Thread(target=training_func)
t.start()
workers.append(t)
for t in workers:
t.join()
# evaluate accuracy of the model
print(accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels},
session=sess))
我一定遗漏了一些东西,因为下面的 8 个线程会产生不一致的结果(精度大约 = 0.1),而使用 1 个线程只能获得预期的精度(大约 0.92)。有人知道我的错误吗?谢谢!
【问题讨论】:
-
您确实意识到 TF 图是由高度并行的引擎编译和执行的。如果您查看单线程训练期间的 CPU 利用率,您会看到所有内核都在接收负载,而不仅仅是一个。你想通过线程化训练来完成什么?我希望您看到的问题来自多个线程在没有任何控制的情况下更新权重并覆盖彼此的更改。
-
我的目标是加速昂贵的培训。我知道 TF 是真正并行的,但也可以通过多线程获得加速 - 例如在上面的示例中,range(1) 为所有内核产生 15-20% 的使用率,而 range(16) 导致 60-80% 的使用率。
-
我怀疑我的问题确实来自不受控制的并发体重更新。然而this TF tutorial code 做了一些类似于我的示例代码(l.319 到 l.340)的事情,但我不明白为什么这适用于他们的情况。也许他们的训练操作(word2vec.neg_train)在内部管理这些并发更新?
标签: python multithreading tensorflow