tf.nn.dynamic_rnn 接收一批(具有 minibatch 含义)不相关的序列。
-
cell 是您要使用的实际单元格(LSTM、GRU、...)
-
inputs 的形状为 batch_size x max_time x input_size,其中 max_time 是最长序列中的步数(但所有序列的长度可以相同)
-
sequence_length 是一个大小为 batch_size 的向量,其中每个元素都给出了批次中每个序列的长度(如果所有序列的大小相同,则将其保留为默认值。此参数是定义单元格展开大小。
隐藏状态处理
处理隐藏状态的常用方法是在dynamic_rnn 之前定义一个初始状态张量,例如:
hidden_state_in = cell.zero_state(batch_size, tf.float32)
output, hidden_state_out = tf.nn.dynamic_rnn(cell,
inputs,
initial_state=hidden_state_in,
...)
在上面的 sn-p 中,hidden_state_in 和 hidden_state_out 具有相同的形状 [batch_size, ...](实际形状取决于您使用的单元格类型,但重要的是第一个维度是批量大小)。
这样,dynamic_rnn 对每个序列都有一个初始隐藏状态。 它会自行将inputs 参数中每个序列的隐藏状态逐个传递,hidden_state_out 将包含批处理中每个序列的最终输出状态。同一批次的序列之间不传递隐藏状态,仅在同一序列的时间步之间传递。
什么时候需要手动反馈隐藏状态?
通常,当你训练时,每个批次都是不相关的,所以你在做session.run(output)时不必反馈隐藏状态。
但是,如果您正在测试,并且您需要每个时间步的输出,(即您必须在每个时间步执行 session.run()),您将需要使用评估和反馈输出隐藏状态像这样:
output, hidden_state = sess.run([output, hidden_state_out],
feed_dict={hidden_state_in:hidden_state})
否则 tensorflow 将在每个时间步仅使用默认的 cell.zero_state(batch_size, tf.float32),这相当于在每个时间步重新初始化隐藏状态。