【发布时间】:2020-03-13 04:07:52
【问题描述】:
版本:Python 3.6.9、Tensorflow 2.0.0、CUDA 10.0、CUDNN 7.6.1、Nvidia 驱动程序版本 410.78。
我正在尝试将基于 LSTM 的 Seq2Seq tf.keras 模型移植到 tensorflow 2.0
现在,当我尝试在解码器模型上调用 predict 时遇到以下错误(有关实际推理设置代码,请参见下文)
就好像它期待一个单个单词作为参数,但我需要它来解码一个完整的句子(我的句子是单词索引的右填充序列,长度为 24) p>
P.S.:这段代码在 TF 1.15 上使用完全一样
InvalidArgumentError: [_Derived_] Inputs to operation while/body/_1/Select_2 of type Select must have the same size and shape.
Input 0: [1,100] != input 1: [24,100]
[[{{node while/body/_1/Select_2}}]]
[[lstm_1_3/StatefulPartitionedCall]] [Op:__inference_keras_scratch_graph_45160]
Function call stack:
keras_scratch_graph -> keras_scratch_graph -> keras_scratch_graph
完整模型
ENCODER 推理模型
解码器推理模型
推理设置(实际发生错误的行)
重要信息:序列右填充到 24 个元素,100 是每个词嵌入的维数。这就是错误消息(和打印)显示输入形状为(24,100) 的原因。
请注意,此代码在 CPU 上运行。在 GPU 上运行它会导致另一个错误,详细说明 here
# original_keyword is a sample text string
with tf.device("/device:CPU:0"):
# this method turns the raw string into a right-padded sequence
query_sequence = keyword_to_padded_sequence_single(original_keyword)
# no problems here
initial_state = encoder_model.predict(query_sequence)
print(initial_state[0].shape) # prints (24, 100)
print(initial_state[1].shape) # (24, 100)
empty_target_sequence = np.zeros((1,1))
empty_target_sequence[0,0] = word_dict_titles["sos"]
# ERROR HAPPENS HERE:
# InvalidArgumentError: [_Derived_] Inputs to operation while/body/_1/Select_2 of type Select
# must have the same size and shape. Input 0: [1,100] != input 1: [24,100]
decoder_outputs, h, c = decoder_model.predict([empty_target_sequence] + initial_state)
我尝试过的事情
禁用 Eager 模式(这只会使训练速度变慢,并且推理期间的错误保持不变)
在将输入提供给预测函数之前对其进行整形
在调用 LSTM 层时手动计算 (
embedding_layer.compute_mask(inputs)) 并设置掩码
【问题讨论】:
-
如何构建图层?您是否在 LSTM 层中设置了 return_sequences=True?
-
@emirc 编码器是
False。对于解码器,它是True。以下是完整代码:gist.github.com/queirozfcom/20d76e3113c649660df8dc1e59455680 -
您好,您可以尝试将
decoder_inputs形状更改为decoder_inputs = tf.keras.layers.Input(shape=(None,),name="decoder_input")。错误即将到来,因为empty_target_sequence的形状为(1,1),而您的解码器需要形状为(?,24)的输入。
标签: python lstm tensorflow2.0 tf.keras seq2seq