【Posted】: 2020-01-28 18:04:23
【Question】:
I am trying to implement an encoder-decoder model with TensorFlow. The encoder is bidirectional, with one LSTM cell in each direction:
import tensorflow as tf  # TensorFlow 1.x API (tf.contrib does not exist in 2.x)

def encoder(hidden_units, encoder_embedding, sequence_length):
    forward_cell = tf.contrib.rnn.LSTMCell(hidden_units)
    backward_cell = tf.contrib.rnn.LSTMCell(hidden_units)
    bi_outputs, final_states = tf.nn.bidirectional_dynamic_rnn(
        forward_cell, backward_cell, encoder_embedding,
        sequence_length=sequence_length, dtype=tf.float32)
    # Forward and backward outputs concatenated along the feature axis: depth 2 * hidden_units
    encoder_outputs = tf.concat(bi_outputs, 2)
    forward_cell_state, backward_cell_state = final_states
    # Concatenated final states: c and h each have depth 2 * hidden_units
    cell_state_final = tf.concat([forward_cell_state.c, backward_cell_state.c], 1)
    hidden_state_final = tf.concat([forward_cell_state.h, backward_cell_state.h], 1)
    encoder_final_state = tf.nn.rnn_cell.LSTMStateTuple(c=cell_state_final, h=hidden_state_final)
    return encoder_outputs, encoder_final_state
The problem occurs where the encoder connects to the decoder. I get an error like ValueError: Shapes (?, 42) and (12, 21) are not compatible ...
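The numbers in the error message fit a width mismatch. Assuming hidden_units = 21 and batchsize = 12 (neither value appears in the question), concatenating the forward and backward LSTM states gives a depth of 2 × 21 = 42, whereas zero_state of the decoder's LSTMCell(hidden_units) has shape (12, 21); clone(cell_state=encoder_state) then likely fails when it compares the two. The sketch below is plain Python with no TensorFlow: concat_last_axis and compatible are hypothetical helpers that imitate how static shapes combine, with None standing in for "?".

```python
# Hypothetical shape bookkeeping; no TensorFlow is used.

def concat_last_axis(shape_a, shape_b):
    """Shape of concatenating two 2-D tensors (batch, depth) along axis 1."""
    assert shape_a[0] == shape_b[0], "batch dimensions must agree"
    return (shape_a[0], shape_a[1] + shape_b[1])

def compatible(shape_a, shape_b):
    """Imitate static-shape compatibility: None (printed as '?') matches any size."""
    if len(shape_a) != len(shape_b):
        return False
    return all(a is None or b is None or a == b for a, b in zip(shape_a, shape_b))

# Assumed values, inferred from the error message (not stated in the question)
hidden_units = 21
batchsize = 12

forward_c = (None, hidden_units)    # forward LSTM final c: (?, 21)
backward_c = (None, hidden_units)   # backward LSTM final c: (?, 21)
encoder_c = concat_last_axis(forward_c, backward_c)  # (?, 42)

decoder_zero_c = (batchsize, hidden_units)  # zero state of LSTMCell(hidden_units)
print(encoder_c, decoder_zero_c, compatible(encoder_c, decoder_zero_c))
# prints: (None, 42) (12, 21) False

decoder_zero_c_fixed = (batchsize, 2 * hidden_units)  # decoder cell with 2 * hidden_units
print(compatible(encoder_c, decoder_zero_c_fixed))  # prints: True
```

If this reading is right, either the decoder cell needs 2 * hidden_units units or the concatenated encoder states need to be projected down to hidden_units before they are used as the initial state.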
The decoder uses an attention mechanism and looks like this:
import tensorflow as tf  # TensorFlow 1.x API
from tensorflow.python.layers.core import Dense  # assumed import; the question does not show where Dense comes from

def decoder(decoder_embedding, vocab_size, hidden_units, sequence_length, encoder_output, encoder_state, batchsize):
    projection_layer = Dense(vocab_size)
    # Teacher forcing: feed the ground-truth target embeddings at each step
    helper = tf.contrib.seq2seq.TrainingHelper(decoder_embedding, sequence_length=sequence_length)
    # Decoder
    decoder_cell = tf.contrib.rnn.LSTMCell(hidden_units)
    # Attention mechanism
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(hidden_units, encoder_output)
    attn_cell = tf.contrib.seq2seq.AttentionWrapper(decoder_cell, attention_mechanism, attention_layer_size=hidden_units)
    # Initial attention state: zero state with the cell state replaced by the encoder's final state
    attn_zero = attn_cell.zero_state(batch_size=batchsize, dtype=tf.float32)
    ini_state = attn_zero.clone(cell_state=encoder_state)
    decoder = tf.contrib.seq2seq.BasicDecoder(cell=attn_cell, initial_state=ini_state, helper=helper, output_layer=projection_layer)
    decoder_outputs, _final_state, _final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(decoder)
    return decoder_outputs
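For the attention side, a note on what LuongAttention computes in TensorFlow 1.x: it projects the memory (here encoder_output, depth 2 * hidden_units) to num_units with a bias-free dense layer and scores each projected key against the query, which is the decoder cell's output, with a dot product. The query depth must therefore equal num_units. The plain-Python sketch below illustrates that multiplicative score plus a softmax on made-up toy numbers; matvec, luong_scores and softmax are hypothetical helpers, not TensorFlow functions.

```python
import math

def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector, with no bias term."""
    if any(len(row) != len(x) for row in W):
        raise ValueError("weight rows must match the memory depth")
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def luong_scores(query, memory, W):
    """Multiplicative (Luong) scores: dot(query, W @ m) for each memory vector m."""
    if len(W) != len(query):
        raise ValueError("query depth %d != num_units %d" % (len(query), len(W)))
    keys = [matvec(W, m) for m in memory]  # each key has depth len(W) == num_units
    return [sum(q * k for q, k in zip(query, key)) for key in keys]

def softmax(xs):
    """Normalize scores into alignment weights that sum to 1."""
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Toy example: 3 encoder time steps with memory depth 4 (think 2 * hidden_units),
# projected to num_units = 2, so the decoder query must have depth 2.
memory = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]
W = [[1, 0, 1, 0],
     [0, 1, 0, 1]]
query = [2.0, 0.0]

scores = luong_scores(query, memory, W)  # scores == [2.0, 0.0, 2.0]
alignments = softmax(scores)
print(scores, alignments)
```

In this picture, widening the decoder cell to 2 * hidden_units would also require the num_units argument of LuongAttention to become 2 * hidden_units; otherwise the query and the keys no longer line up.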
How can I fix this?
【Comments】:
Tags: tensorflow lstm bidirectional seq2seq attention-model