【发布时间】:2017-11-22 01:53:06
【问题描述】:
我对 Tensorflow 很陌生,所以我的问题可能听起来很愚蠢,但我真的找不到合适的解释,所以在这里问。 我需要您的帮助来了解如何在图分布式 Tensorflow 程序中进行数据批处理或分布。
由于我们执行多个客户端,它们本质上具有相同的代码来获取下一批:
batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
我不明白这将如何确保对非常工人的独特批次。对我来说,似乎相同的数据正在发送给所有工人。
在这个示例脚本中,我们在每次迭代时都读取 next_batch,并且由于我们正在运行两个带有 job_type=worker 的客户端,因此两个 worker 将看到相同的 next_batch 代码。请帮助我了解在这种情况下数据并行性如何工作。
with sv.prepare_or_wait_for_session(server.target, config=sess_config) as sess:
print("Worker %d: Session initialization complete." % FLAGS.task_index)
# Loop until the supervisor shuts down or 1000000 steps have completed.
step = 0
while not sv.should_stop() and step < 1000000:
# Run a training step asynchronously.
batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
print("FETCHING NEXT BATCH %d" % FLAGS.batch_size)
train_feed = {x: batch_xs, y_: batch_ys}
_, step = sess.run([train_op, global_step], feed_dict=train_feed)
if step % 100 == 0:
print("Done step %d" % step)
# Ask for all the services to stop.
sv.stop()
期待您的帮助。
【问题讨论】: