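【Posted】:2019-05-23 04:43:24
【Question】:
The goal is to run cross-validation on a Keras model with multiple inputs. This works fine for a normal Sequential model with a single input. However, when using the functional API and extending the model to two inputs, sklearn's cross_val_score does not seem to work as expected.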
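# Imports assumed; the original post does not show them.
import tensorflow as tf
from tensorflow.keras.layers import Concatenate, Dense, Input, Lambda
from tensorflow.keras.models import Model

def create_model():
    # Text branch. UniversalEmbedding is defined elsewhere in the asker's script (not shown).
    input_text = Input(shape=(1,), dtype=tf.string)
    embedding = Lambda(UniversalEmbedding, output_shape=(512, ))(input_text)
    dense = Dense(256, activation='relu')(embedding)
    # Title branch, same structure as the text branch.
    input_title = Input(shape=(1,), dtype=tf.string)
    embedding_title = Lambda(UniversalEmbedding, output_shape=(512, ))(input_title)
    dense_title = Dense(256, activation='relu')(embedding_title)
    # Merge both branches and classify into two classes.
    out = Concatenate()([dense, dense_title])
    pred = Dense(2, activation='softmax')(out)
    model = Model(inputs=[input_text, input_title], outputs=pred)
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model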
The part that fails:
keras_classifier = KerasClassifier(build_fn=create_model, epochs=10, batch_size=10, verbose=1)
cv = StratifiedKFold(n_splits=10, random_state=0)
results = cross_val_score(keras_classifier, [X1, X2], y, cv=cv, scoring='f1_weighted')
The error:
Traceback (most recent call last):
  File "func.py", line 73, in <module>
    results = cross_val_score(keras_classifier, [X1, X2], y, cv=cv, scoring='f1_weighted')
  File "/home/timisb/.local/lib/python3.6/site-packages/sklearn/model_selection/_validation.py", line 402, in cross_val_score
    error_score=error_score)
  File "/home/timisb/.local/lib/python3.6/site-packages/sklearn/model_selection/_validation.py", line 225, in cross_validate
    X, y, groups = indexable(X, y, groups)
  File "/home/timisb/.local/lib/python3.6/site-packages/sklearn/utils/validation.py", line 260, in indexable
    check_consistent_length(*result)
  File "/home/timisb/.local/lib/python3.6/site-packages/sklearn/utils/validation.py", line 235, in check_consistent_length
    " samples: %r" % [int(l) for l in lengths])
ValueError: Found input variables with inconsistent numbers of samples: [2, 643]
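The `[2, 643]` in the message hints at the cause: the consistency check compares the length of each argument, and a Python list holding two inputs has length 2, while `y` has 643 entries. A minimal illustration with made-up placeholder data (the names `X1`, `X2` and `y` mirror the snippet above; their contents here are assumptions):

```python
# Placeholder data standing in for the real inputs (contents are made up).
n_samples = 643
X1 = ["some text"] * n_samples        # first input: one string per sample
X2 = ["some title"] * n_samples       # second input: one string per sample
y = [0, 1] * (n_samples // 2) + [0]   # 643 labels

# The list [X1, X2] is a sequence of two elements, so its length is 2,
# not 643. These are the two numbers reported in the ValueError.
lengths = [len([X1, X2]), len(y)]
print(lengths)  # prints [2, 643]
```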
Does anyone have suggestions for an alternative approach or a fix? Thanks!
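One possible alternative, sketched here under stated assumptions rather than as a confirmed fix, is to skip `cross_val_score` and run the folds by hand, indexing each input array separately. The helper `cross_val_multi_input` and the stand-in `CentroidModel` are hypothetical names introduced for illustration; the sketch assumes integer labels `0..k-1` and a model whose `predict` returns class probabilities, as a softmax output would.

```python
import numpy as np
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold


def cross_val_multi_input(build_model, inputs, y, n_splits=10, fit_kwargs=None):
    """Stratified K-fold CV for a model that takes a list of input arrays.

    build_model: callable returning a fresh model (e.g. create_model).
    inputs:      list of arrays with one row per sample (e.g. [X1, X2]).
    y:           1-D integer class labels, assumed to be 0..k-1.
    """
    fit_kwargs = fit_kwargs or {}
    y = np.asarray(y)
    inputs = [np.asarray(x) for x in inputs]
    n_classes = len(np.unique(y))
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0)
    scores = []
    # Only the labels matter for the split; a dummy X of the right length is enough.
    for train_idx, test_idx in cv.split(np.zeros(len(y)), y):
        model = build_model()  # fresh weights for every fold
        train_inputs = [x[train_idx] for x in inputs]
        test_inputs = [x[test_idx] for x in inputs]
        # One-hot targets, matching categorical_crossentropy.
        y_train = np.eye(n_classes)[y[train_idx]]
        model.fit(train_inputs, y_train, **fit_kwargs)
        # Turn predicted probabilities into class indices.
        y_pred = np.argmax(model.predict(test_inputs), axis=1)
        scores.append(f1_score(y[test_idx], y_pred, average="weighted"))
    return np.array(scores)


class CentroidModel:
    """Hypothetical stand-in with a Keras-like fit/predict, so the demo runs without TensorFlow."""

    def fit(self, inputs, y_onehot, **kwargs):
        X = np.hstack(inputs)
        # Per-class mean of the concatenated features: (n_classes, n_features).
        self.centroids_ = (y_onehot.T @ X) / y_onehot.sum(axis=0)[:, None]

    def predict(self, inputs):
        X = np.hstack(inputs)
        # Squared distance of every sample to every centroid.
        d = ((X[:, None, :] - self.centroids_[None, :, :]) ** 2).sum(axis=2)
        z = -d
        z -= z.max(axis=1, keepdims=True)  # stabilise the softmax
        p = np.exp(z)
        return p / p.sum(axis=1, keepdims=True)


# Demo on synthetic numeric data (made up; not the data from the question).
rng = np.random.default_rng(0)
y_demo = np.array([0] * 60 + [1] * 40)
X1_demo = rng.normal(size=(100, 2)) + 3 * y_demo[:, None]
X2_demo = rng.normal(size=(100, 2))
scores = cross_val_multi_input(CentroidModel, [X1_demo, X2_demo], y_demo, n_splits=5)
print(scores.mean())
```

With the Keras model from the question, the call would presumably look like `cross_val_multi_input(create_model, [X1, X2], y, fit_kwargs={"epochs": 10, "batch_size": 10})`, with `X1` and `X2` as NumPy arrays so that they can be indexed by the fold indices.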
【Comments】:
- A workaround for passing multiple inputs is described here: stackoverflow.com/questions/56824968/…
Tags: tensorflow scikit-learn keras