【问题标题】:How to specify input_shape for Keras Sequential model如何为 Keras Sequential 模型指定 input_shape
【发布时间】:2018-12-27 02:54:48
【问题描述】:

你如何处理这个错误?

检查目标时出错:预期dense_3 的形状为(1,),但得到的数组的形状为(398,)

我尝试更改 input_shape=(14,),这是 train_samples 中的列数,但我仍然得到错误。

set = pd.read_csv('NHL_DATA.csv')
set.head()

train_labels = [set['Won/Lost']] 
train_samples = [set['team'], set['blocked'],set['faceOffWinPercentage'],set['giveaways'],set['goals'],set['hits'],
            set['pim'], set['powerPlayGoals'], set['powerPlayOpportunities'], set['powerPlayPercentage'],
           set['shots'], set['takeaways'], set['homeaway_away'],set['homeaway_home']]

train_labels = np.array(train_labels)
train_samples = np.array(train_samples)

scaler = MinMaxScaler(feature_range=(0,1))
scaled_train_samples = scaler.fit_transform(train_samples).reshape(-1,1)

model = Sequential()

model.add(Dense(16, input_shape=(14,), activation='relu'))
model.add(Dense(32, activation='relu'))
model.add(Dense(2, activation='softmax'))

model.compile(Adam(lr=.0001), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(scaled_train_samples, train_labels, batch_size=1, epochs=20, shuffle=True, verbose=2)

【问题讨论】:

    标签: python python-3.x keras sequential


    【解决方案1】:

    1) 你用.reshape(-1,1) 重塑你的训练样本,这意味着所有的训练样本都是一维的。但是,您将网络的输入形状定义为 input_shape=(14,),表示输入维度为 14。我想这是您的模型存在的一个问题。

    2) 你使用了sparse_categorical_crossentropy,这意味着真实标签是稀疏的(train_labels 应该是稀疏的),但我想不是。

    以下是您的输入应该如何的示例:

    import numpy as np
    from tensorflow.python.keras.engine.sequential import Sequential
    from tensorflow.python.keras.layers import Dense
    
    x = np.zeros([1000, 14])
    y = np.zeros([1000, 2])
    
    model = Sequential()
    
    model.add(Dense(16, input_shape=(14,), activation='relu'))
    model.add(Dense(32, activation='relu'))
    model.add(Dense(2, activation='softmax'))
    
    model.compile('adam', 'categorical_crossentropy')
    model.fit(x, y, batch_size=1, epochs=1)
    

    【讨论】:

    • 是的,抱歉,忘记将它改回我收到错误时的状态,即 input_shape=(1, )。您向我展示的示例,它如何合并我的 NHL_DATA.csv ?
    • @MisterButter 上面的代码适用于 Tensorflow 1.12,您可以适当调整输入以适合您的模型。
    • 哦,谢谢!会尝试,我不是特别擅长这样做,主要是应对其他人的代码并在某种程度上对其进行调整。谢谢你的例子!
    猜你喜欢
    • 1970-01-01
    • 2019-03-23
    • 1970-01-01
    • 2019-03-31
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-01-02
    相关资源
    最近更新 更多