【问题标题】:Value Error - Error when checking target - LSTM值错误 - 检查目标时出错 - LSTM
【发布时间】:2020-05-03 11:16:01
【问题描述】:

关于数据集

以下路透社数据集包含 11228 条文本,对应于分类为 46 个类别的新闻。这些文本是在每个单词对应一个整数的意义上进行编码的。我指定我们要使用 2000 个单词。

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

num_words = 2000
(reuters_train_x, reuters_train_y), (reuters_test_x, reuters_test_y) = tf.keras.datasets.reuters.load_data(num_words=num_words)

n_labels = np.unique(reuters_train_y).shape[0]
print("labels: {}".format(n_labels))

# This is the first new
print(reuters_train_x[0])

实现 LSTM

我需要用一个具有 10 个单元的 LSTM 来实现一个网络。输入在进入 LSTM 单元之前需要一个 10 维的嵌入。最后,需要添加一个dense layer来根据类别的数量来调整输出的数量。

from keras.models import Sequential
from keras.layers import LSTM, Dense, Embedding
from from tensorflow.keras.utils import to_categorical

reuters_train_y = to_categorical(reuters_train_y, 46)
reuters_test_y = to_categorical(reuters_test_y, 46)

model = Sequential()
model.add(Embedding(input_dim = num_words, 10))
model.add(LSTM(10))
model.add(Dense(46,activation='softmax'))

培训

model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
history = model.fit(reuters_train_x,reuters_train_y,epochs=20,validation_data=(reuters_test_x,reuters_test_y))

我得到的错误信息是:

ValueError: Error when checking target: expected dense_2 to have shape (46,) but got array with shape (1,)

【问题讨论】:

    标签: machine-learning keras lstm recurrent-neural-network


    【解决方案1】:

    您需要对 y 标签进行一次热编码。

    from tensorflow.keras.utils import to_categorical
    
    reuters_train_y = to_categorical(reuters_train_y, 46)
    
    reuters_test_y = to_categorical(reuters_test_y, 46)
    
    

    我在fit 函数中看到的另一个错误,您正在传递validation_data=(reuters_test_x,reuters_train_y),但它应该是validation_data=(reuters_test_x,reuters_test_y)

    您的 x 是具有不同长度的列表的 numpy 数组。您需要填充序列以获得固定形状的 numpy 数组。

    reuters_train_x = tf.keras.preprocessing.sequence.pad_sequences(
        reuters_train_x, maxlen=50
    )
    
    reuters_test_x = tf.keras.preprocessing.sequence.pad_sequences(
        reuters_test_x, maxlen=50
    )
    

    【讨论】:

    • 我已经相应地修改了代码,但在模型训练期间仍然出现错误。 ValueError:使用序列设置数组元素。
    • 检查更新的答案,你需要使用一些填充。
    猜你喜欢
    • 2019-01-11
    • 2018-02-20
    • 1970-01-01
    • 1970-01-01
    • 2020-07-11
    • 2019-05-07
    • 2017-12-19
    • 1970-01-01
    • 2021-08-16
    相关资源
    最近更新 更多