【问题标题】:Random initial state using MultiRNNCell in TensorFlow在 TensorFlow 中使用 MultiRNNCell 的随机初始状态
【发布时间】:2018-02-22 09:58:06
【问题描述】:

我有一个以这种方式创建的 MultiRNN 单元

def get_cell(cell_type, num_units, training):
    if cell_type == "RNN":
        cell = tf.contrib.rnn.BasicRNNCell(num_units)
    elif cell_type == "LSTM":
        cell = tf.contrib.rnn.BasicLSTMCell(num_units)
    else:
        cell = tf.contrib.rnn.GRUCell(num_units)

    if training:
        cell = tf.contrib.rnn.DropoutWrapper(cell,
                                input_keep_prob=params["dropout_input_keep_prob"],
                                output_keep_prob=params["dropout_output_keep_prob"],
                                state_keep_prob=params["dropout_state_keep_prob"])

    return cell

final_cell_structure = tf.contrib.rnn.MultiRNNCell([get_cell(cell_type, num_units, (mode == tf.estimator.ModeKeys.TRAIN)) for _ in range(num_layers)])

我正在尝试将其状态初始化为随机值。我试过这样做:

initial_state = state = final_cell_structure.zero_state(batch_size, tf.float32)
if mode == tf.estimator.ModeKeys.PREDICT:
    state = state + tf.random_normal(shape=tf.shape(state), mean=0.0, stddev=0.6)

但我不断收到一个错误提示

Expected state to be a tuple of length 3, but received: Tensor("Reshape:0", shape=(3, 1, 10), dtype=float32)

当我使用它时

output, state = final_cell_structure(inputs, state)

更新 我尝试使用

state = [st + tf.random_normal(shape=tf.shape(st), mean=0.0, stddev=0.6) for st in state]

按照 Pop 的建议,它适用于 Basic RNN 单元和 GRU 单元,但是当我将它与 LSTM 单元一起使用时,出现以下错误

Tensor objects are not iterable when eager execution is not enabled. To iterate over this tensor use tf.map_fn

已解决 LSTM 单元状态由一个元组组成,所以我发现这个解决方案有效

state_placeholder = tf.random_normal(shape=(num_layers, 2, batch_size, num_units), mean=0.0, stddev=1.0)
l = tf.unstack(state_placeholder, axis=0)
state = tuple([tf.nn.rnn_cell.LSTMStateTuple(l[idx][0],l[idx][1]) for idx in range(num_layers)])

【问题讨论】:

    标签: python tensorflow state rnn


    【解决方案1】:

    这个想法是state 是一个元组。

    所以需要这样更新:

    state = [st + tf.random_normal(shape=tf.shape(st), mean=0.0, stddev=0.6) for st in state]
    

    它应该可以工作。

    使用您的方法,您创建的是单个张量 f 形状 (2, b, k),而不是具有相同大小 (b,k) 的两个张量的元组

    【讨论】:

    • 经过一些测试,我发现它适用于 BasicRNNCell 和 GRUCell,但是当我将它与 BasicLSTMCell 一起使用时,我收到以下消息:Tensor 对象在未启用急切执行时不可迭代。要迭代此张量,请使用 tf.map_fn
    猜你喜欢
    • 2018-04-06
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2019-05-05
    • 1970-01-01
    • 1970-01-01
    • 2019-09-07
    • 1970-01-01
    相关资源
    最近更新 更多