Posted: 2018-02-22 09:58:06
Question:
I have a MultiRNN cell created like this:
import tensorflow as tf  # TensorFlow 1.x; tf.contrib no longer exists in 2.x

def get_cell(cell_type, num_units, training):
    # Pick the base cell type for one layer
    if cell_type == "RNN":
        cell = tf.contrib.rnn.BasicRNNCell(num_units)
    elif cell_type == "LSTM":
        cell = tf.contrib.rnn.BasicLSTMCell(num_units)
    else:
        cell = tf.contrib.rnn.GRUCell(num_units)
    if training:
        # Dropout only during training; `params` is presumably the hyperparameter
        # dict passed to the Estimator model_fn
        cell = tf.contrib.rnn.DropoutWrapper(cell,
                                             input_keep_prob=params["dropout_input_keep_prob"],
                                             output_keep_prob=params["dropout_output_keep_prob"],
                                             state_keep_prob=params["dropout_state_keep_prob"])
    return cell

# Stack num_layers cells; dropout is enabled only in TRAIN mode
final_cell_structure = tf.contrib.rnn.MultiRNNCell(
    [get_cell(cell_type, num_units, mode == tf.estimator.ModeKeys.TRAIN)
     for _ in range(num_layers)])
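For reference, `zero_state()` on a `MultiRNNCell` returns a tuple with one entry per layer, and what each entry looks like depends on the cell type. The sketch below mocks that nesting in plain Python so it can run without TensorFlow. Assumptions: nested lists stand in for tensors, the `LSTMStateTuple` namedtuple here is a hypothetical stand-in for `tf.nn.rnn_cell.LSTMStateTuple`, and 3 layers, batch size 1 and 10 units are read off the shape in the error message below.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.nn.rnn_cell.LSTMStateTuple: same field names (c, h),
# but holding plain nested lists instead of tensors.
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


def zeros(batch_size, num_units):
    """A batch_size x num_units block of zeros, standing in for one state tensor."""
    return [[0.0] * num_units for _ in range(batch_size)]


def mock_zero_state(cell_type, num_layers, batch_size, num_units):
    """Mimic the nesting of MultiRNNCell.zero_state(): one entry per layer."""
    if cell_type == "LSTM":
        # Each LSTM layer carries two tensors: the cell state c and the hidden state h.
        return tuple(
            LSTMStateTuple(zeros(batch_size, num_units), zeros(batch_size, num_units))
            for _ in range(num_layers)
        )
    # BasicRNNCell and GRUCell carry a single tensor per layer.
    return tuple(zeros(batch_size, num_units) for _ in range(num_layers))


gru_state = mock_zero_state("GRU", num_layers=3, batch_size=1, num_units=10)
lstm_state = mock_zero_state("LSTM", num_layers=3, batch_size=1, num_units=10)
print(len(gru_state), type(gru_state[0]).__name__)    # 3 list
print(len(lstm_state), type(lstm_state[0]).__name__)  # 3 LSTMStateTuple
```

The extra `(c, h)` level for LSTM is what makes a one-level loop over the state behave differently from the RNN and GRU cases.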
I'm trying to initialize its state to random values. I tried this:
initial_state = state = final_cell_structure.zero_state(batch_size, tf.float32)
if mode == tf.estimator.ModeKeys.PREDICT:
    state = state + tf.random_normal(shape=tf.shape(state), mean=0.0, stddev=0.6)
but I keep getting the error
Expected state to be a tuple of length 3, but received: Tensor("Reshape:0", shape=(3, 1, 10), dtype=float32)
when I use it in
output, state = final_cell_structure(inputs, state)
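Why this first attempt fails: `zero_state()` returns a Python tuple with one state per layer, but `tf.shape(state)` and the `+` operator both convert that tuple into a single tensor, stacking the per-layer states along a new leading axis. The result is one tensor of shape `(num_layers, batch_size, num_units)`, which `MultiRNNCell` rejects because it expects a tuple of length `num_layers`. A plain-Python illustration of the shapes involved, assuming nested lists stand in for tensors and using a hypothetical `nested_shape` helper:

```python
def nested_shape(x):
    """Shape of a rectangular nested list/tuple, similar to np.shape for plain data."""
    shape = []
    while isinstance(x, (list, tuple)):
        shape.append(len(x))
        x = x[0] if x else None
    return tuple(shape)


num_layers, batch_size, num_units = 3, 1, 10
# One (batch_size, num_units) block per layer, as zero_state() gives for RNN/GRU cells.
state = tuple([[0.0] * num_units for _ in range(batch_size)] for _ in range(num_layers))

print(nested_shape(state[0]))  # (1, 10): the shape each layer expects
# Treating the whole tuple as one tensor stacks the layers on a new leading axis,
# which matches the (3, 1, 10) shape reported in the error message.
print(nested_shape(state))     # (3, 1, 10)
```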
Update: I tried using
state = [st + tf.random_normal(shape=tf.shape(st), mean=0.0, stddev=0.6) for st in state]
as Pop suggested. It works for the basic RNN cell and the GRU cell, but with the LSTM cell I get the following error:
Tensor objects are not iterable when eager execution is not enabled. To iterate over this tensor use tf.map_fn
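Why the list comprehension breaks for LSTM: iterating over the `MultiRNNCell` state only goes one level deep. For RNN and GRU cells each item is a single tensor, but for an LSTM cell each item is an `LSTMStateTuple(c, h)`, so `st + tf.random_normal(...)` converts that pair into one stacked tensor of shape `(2, batch_size, num_units)`. `BasicLSTMCell` then tries to unpack it as `c, h = state`, which is presumably where the "not iterable" error comes from. The general fix is to add noise at every leaf of the nested structure. The sketch below shows the idea in plain Python; assumptions: nested lists stand in for tensors, `LSTMStateTuple` is a hypothetical stand-in for `tf.nn.rnn_cell.LSTMStateTuple`, and `map_structure` is a simplified hand-written analogue of TensorFlow's `tf.nest.map_structure`.

```python
import random
from collections import namedtuple

# Hypothetical stand-in for tf.nn.rnn_cell.LSTMStateTuple (plain lists instead of tensors).
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


def map_structure(fn, structure):
    """Apply fn to every leaf, rebuilding tuples and namedtuples around the results.

    Simplified analogue of tf.nest.map_structure: tuples are treated as structure,
    and anything else (here, a batch_size x num_units nested list) is a leaf.
    """
    if isinstance(structure, tuple):
        mapped = [map_structure(fn, s) for s in structure]
        # namedtuples must be rebuilt from positional arguments to keep their type.
        if hasattr(structure, "_fields"):
            return type(structure)(*mapped)
        return tuple(mapped)
    return fn(structure)


def add_noise(block, stddev=0.6):
    """Return a copy of one batch_size x num_units block with Gaussian noise added."""
    return [[v + random.gauss(0.0, stddev) for v in row] for row in block]


def zeros(batch_size, num_units):
    return [[0.0] * num_units for _ in range(batch_size)]


num_layers, batch_size, num_units = 3, 1, 10
lstm_state = tuple(
    LSTMStateTuple(zeros(batch_size, num_units), zeros(batch_size, num_units))
    for _ in range(num_layers)
)

noisy = map_structure(add_noise, lstm_state)
print(type(noisy).__name__, len(noisy), type(noisy[0]).__name__)  # tuple 3 LSTMStateTuple
```

In TensorFlow releases that provide `tf.nest`, the same idea would presumably reduce to `state = tf.nest.map_structure(lambda s: s + tf.random_normal(tf.shape(s), stddev=0.6), state)`, which keeps the nesting intact for all three cell types; treat that as an untested suggestion for the TF 1.x code above.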
Solved: for LSTM cells, each layer's state is itself a tuple, an LSTMStateTuple(c, h), so I found that this solution works:
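# Random values for one (c, h) pair per layer: shape (num_layers, 2, batch_size, num_units)
state_placeholder = tf.random_normal(shape=(num_layers, 2, batch_size, num_units), mean=0.0, stddev=1.0)
# Split into one (2, batch_size, num_units) tensor per layer
l = tf.unstack(state_placeholder, axis=0)
# Wrap each layer as LSTMStateTuple(c, h) and collect them in the tuple MultiRNNCell expects
state = tuple([tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1]) for idx in range(num_layers)])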
Tags: python tensorflow state rnn