【问题标题】:Keras LSTM for converting sentences to document context vectorKeras LSTM 用于将句子转换为文档上下文向量
【发布时间】:2019-12-24 16:12:17
【问题描述】:

我阅读了以下博客文章并尝试通过 Keras 实现它: https://andriymulyar.com/blog/bert-document-classification

现在,我对 Keras 很陌生,我不明白如何使用“seq2seq 神经网络”将一系列子块(句子)压缩成全局上下文向量(文档向量)。 -通过 LSTM ..

例如:我有 10 个文档,每个文档包含 100 个句子,每个句子由一个 1x500 向量表示。 所以数组看起来像这样:

X = np.array(Matrix).reshape(10, 100, 500) # reshape to 10 documents with 100 sequence of 500 features

所以我知道我想训练我的网络并采用最后一个隐藏层,因为这个隐藏层代表我的文档向量/全局上下文向量。

然而,对我来说最困难的部分是想象输出向量.. 我只是枚举我的文档

y = [1,2,3,4,5,6,7,8,9,10]
y = np.array(y)

还是我必须使用 one-hot-encoded 输出向量:

yy = to_categorical(y)

甚至是别的什么......?

据我了解,最终模型应如下所示:

model = Sequential()
model.add(LSTM(50, input_shape=(100,500)))
model.add(Dense(1))
model.compile(loss='categorical_crossentropy',optimizer='rmsprop')
model.fit(X, yy, epochs=100, validation_split=0.2, verbose=1)

【问题讨论】:

    标签: python keras nltk lstm seq2seq


    【解决方案1】:

    这仅取决于您使用的数据:

    对于 one-hot 编码,使用 Categorical Crossentropy Loss

    model.compile(loss='categorical_crossentropy',optimizer='rmsprop')
    

    对于标签编码,使用 Sparse Categorical Crossentropy Loss

    model.compile(loss='sparse_categorical_crossentropy',optimizer='rmsprop')
    

    两个版本的基本方法相同。 因此,如果您有 目标数据y,例如:

    Class1 Class2 Class3
    0      0      1
    1      0      0
    1      0      0
    0      1      0
    

    你应该像这样编译你的模型:

    model.compile(loss='categorical_crossentropy',optimizer='rmsprop')
    

    相反,如果您有一个目标数据y,例如:

    labels
    2
    0
    0
    1
    

    你应该像这样编译你的模型:

    model.compile(loss='sparse_categorical_crossentropy',optimizer='rmsprop')
    

    您的模型的结果和性能将相同,只是内存使用会受到影响。

    【讨论】:

    • 谢谢,我明白了.. 对于单热编码标签数据,它现在可以工作了,但我也忘记调整为 model.add(Dense(10) 以反映 10 个标签。但它仍然不适用于只有一个输出向量 y=[1,2,...10] 以反映所有文档的目标数据。我收到以下错误(如果为 5 个文档完成 y=[1,2,3,4,5]InvalidArgumentError: Received a label value of 5 which is outside the valid range of [0, 5). Label values: 5 [[{{node loss_15/dense_15_loss/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits}}]]
    • @Felix 如果输出密集层有 5 个输出,它将除了 [0,1,2,3,4] 中的标签之外,因为 Python 的索引来自 0。所以你的标签应该有这些值,这就是抱怨的原因。如果您满意,请不要忘记用复选标记接受我的回答,
    猜你喜欢
    • 1970-01-01
    • 2015-08-28
    • 1970-01-01
    • 2019-09-28
    • 1970-01-01
    • 2020-08-20
    • 2018-06-29
    • 1970-01-01
    • 2021-05-16
    相关资源
    最近更新 更多