【发布时间】:2018-01-10 22:46:16
【问题描述】:
我一直在单会话环境中使用 Tensorflow 中的 cudnn_rnn 模型,它们运行良好。但是,当我尝试在 1 个 PS 主机和多个 GPU 工作人员的分布式运行中使用 cudnnLSTM 时,Tensorflow 崩溃。
from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn
with tf.device(tf.train.replica_device_setter(
worker_device = "/job:worker/task:%d" % TASK_INDEX, cluster = cluster)):
lstm = cudnn_rnn.CudnnLSTM(self.layers, self.hidden_units)
with tf.train.MonitoredTrainingSession(master = server.target,
is_chief = (TASK_INDEX == 0),
checkpoint_dir = CHECKPOINT_DIR,
hooks = hooks) as sess:
...
我的一个工作进程(可以访问 GPU)出现以下错误:
InvalidArgumentError (see above for traceback): Cannot assign a device for operation 'save/CudnnRNNCanonicalToParams': Could not satisfy explicit device specification '/job:worker/task:0/device:CPU:0' because no supported kernel for CPU devices is available.
[[Node: save/CudnnRNNCanonicalToParams = CudnnRNNCanonicalToParams[T=DT_FLOAT, direction="unidirectional", dropout=0, input_mode="linear_input", num_params=12, rnn_mode="gru", seed=0, seed2=0, _device="/job:worker/task:0/device:CPU:0"](save/CudnnRNNCanonicalToParams/num_layers, save/CudnnRNNCanonicalToParams/num_units, save/CudnnRNNCanonicalToParams/input_size, save/Reshape, save/Reshape_1, save/Reshape_2, save/Reshape_3, save/Reshape_4, save/Reshape_5, save/Reshape_6, save/Reshape_7, save/Reshape_8, save/Reshape_9, save/Reshape_10, save/Reshape_11, save/split_3, save/split_3:1, save/RestoreV2_22, save/split_4, save/split_4:1, save/RestoreV2_23, save/split_8, save/split_8:1, save/RestoreV2_25, save/split_9, save/split_9:1, save/RestoreV2_26)]]
我尝试在MonitoredTrainingSession 中设置save_checkpoint_secs = None,但仍然遇到同样的错误。
我已阅读 tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py 中提到保存参数和使用 PS 服务器的 cmets,但找不到工作示例。关于如何使分布式张量流和 cudnnLSTM 协同工作的任何想法?
更新: @Ash 关于更新 tensorflow 的回答有所帮助。另外,现在,我需要在 Saver 中指定不分片:
with tf.train.MonitoredTrainingSession(master = server.target,
is_chief = (TASK_INDEX == 0),
checkpoint_dir = CHECKPOINT_DIR,
scaffold = tf.train.Scaffold(
saver = tf.train.Saver(sharded = False, allow_empty = True)),
hooks = hooks) as sess:
【问题讨论】:
标签: tensorflow