【问题标题】:Can I concatenate an Embedding layer with a layer of shape (?, 5) in keras?我可以在keras中将嵌入层与形状层(?,5)连接起来吗?
【发布时间】:2019-02-25 16:53:14
【问题描述】:

我想创建一个 LSTM 内存。 LSTM 应该预测给定句子的长度为 4 的 one-hot 编码值。第一步很容易。

我想做的下一件事是向我的数据集添加其他信息。该信息是一个长度为 5 的 one-hot 编码向量。

我的想法是在将数据传递给 LSTM 之前将嵌入层与另一个输入形状连接起来。这对我来说是这样的:

main_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32', name='main_input')
embedding = Embedding(MAX_NB_WORDS, EMBEDDING_SIZE,
                    input_length=MAX_SEQUENCE_LENGTH)(main_input)

# second input model
auxiliary_input = Input(shape=(5,), name='aux_input')
x = concatenate([embedding, auxiliary_input])

lstm = LSTM(HIDDEN_LAYER_SIZE)(x)

main_output = Dense(4, activation='sigmoid', name='main_output')(lstm)

model = Model(inputs=[main_input, auxiliary_input], outputs=main_output)

但是如果我尝试进行这样的设置,我会收到以下错误:ValueError:连接层需要具有匹配形状的输入,但连接轴除外。得到输入形状:[(None, 50, 128), (None, 5)]

我创建了嵌入层的 LSTM 并将其连接到辅助输入,但之后我无法再运行 LSTM(出现错误:ValueError:输入 0 与层 lstm_2 不兼容:预期ndim=3,发现 ndim=2)

所以我的问题是:在 keras 中构建具有嵌入层输入和附加数据的 LSTM 的正确方法是什么?

【问题讨论】:

  • 第一个问题是“你想如何添加这些数据”。 “在哪里”。 “这是什么意思?”。然后你决定使用哪种方法。无法连接 5 个值,形状 (None, 5) 到一系列值,形状 (None, 50, 128)

标签: python python-3.x tensorflow keras lstm


【解决方案1】:

您似乎在这里尝试传递有关完整序列(而不是每个令牌)的附加信息,这就是您遇到不匹配问题的原因。

有几种方法可以解决这个问题,各有利弊

(1) 您可以将 aux_data 与 lstm 的最后一个输出连接起来,因此连接 concat_with_aux = concatenate([auxiliary_input,lstm]) 并将此连接向量传递给您的模型。 这里的意思是,如果你有两个不同类别的相同序列,LSTM 的输出将是相同的,那么在连接之后,密集分类器的工作就是使用这个连接的结果来产生正确的输出。

(2)如果要在LSTM的输入端直接传递信息。例如,您可以为您的类别创建新的可训练Embedding 层:

auxiliary_input = Input(shape=(1,), name='aux_input') # Now you pass the idx (0,1,2,3,4) not the one_hot encoded form
embed_categories = Embedding(5, EMBEDDING_SIZE,
                    input_length=1)(auxiliary_input)

x = concatenate([embed_categories, embedding])

通过这样做,您的 LSTM 将以您的辅助信息为条件,并且具有不同类别的两个相同句子将具有不同的最后 lstm 输出。

【讨论】:

    猜你喜欢
    • 2019-04-06
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2018-02-27
    • 2017-05-25
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多