【问题标题】:How can I combine Pipeline with cross_val_score for a multiclass problem?如何将 Pipeline 与 cross_val_score 结合起来解决多类问题?
【发布时间】:2021-11-02 11:06:11
【问题描述】:

这是一个非常直截了当的问题,我认为我不能比直接问题添加更多内容:如何将 pipeline 与 cross_val_score 结合起来解决多类问题?

我在工作中处理一个多类问题(这就是为什么我不会分享任何数据的原因,但人们可以将这个问题视为与 iris 数据集有关的问题),我需要根据主题对一些文本进行相应的分类。这就是我正在做的:

pipe = Pipeline(
steps=[
    ("vect", CountVectorizer()),
    ("feature_selection", SelectKBest(chi2, k=10)),
    ("reg", RandomForestClassifier()),
])

pipe.fit(X_train, y_train)
y_pred = pipe.predict(X_test)

print(classification_report(y_test, y_pred))

但是,我有点担心过度拟合(即使我正在使用测试集进行评估),我想进行更严格的分析并添加交叉验证。问题是我不知道如何在管道中添加 cross_val_score,也不知道如何使用交叉验证评估多类问题。我看到了这个answer,所以我把它添加到我的脚本中:

cv = KFold(n_splits=5)
scores = cross_val_score(pipe, X_train, y_train, cv = cv)

问题是这样会导致准确率,当我们讨论分类问题时,这并不是那么好。

还有其他选择吗?是否可以进行交叉验证而不仅仅获得准确性?还是我应该坚持准确性,这不是任何原因造成的问题?

我知道这个问题太“宽泛”了,实际上不仅仅是关于交叉验证,我希望这不是问题。

提前致谢

【问题讨论】:

    标签: scikit-learn pipeline cross-validation multiclass-classification


    【解决方案1】:

    几乎总是建议use cross validation to choose your model/hyperparameters,然后使用独立的保持测试集来评估模型的性能。

    好消息是,您可以在 scikit-learn 中完成您想做的事情!像这样的:

    pipe = Pipeline(
      steps=[
        ("vect", CountVectorizer()),
        ("feature_selection", SelectKBest(chi2, k=10)),
        ("reg", RandomForestClassifier())])
    
    # Parameters of pipelines can be set using ‘__’ separated parameter names:
    param_grid = {
        'feature_selection__k': np.linspace(4, 16, 4), # Test different number of features in SelectKBest
        'reg__n_estimators': [10, 30, 50, 100, 200],  # n_estimators in RandomForestClassifier
        'reg__min_samples_leaf': [2, 5, 10, 50] # min_samples_leaf in RandomForestClassifier
    }
    
    # This defines the grid search with "Area Under the ROC Curve" as the scoring metric to use.
    # More options here: https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter
    search = GridSearchCV(pipe, param_grid, scoring='roc_auc',)
    
    search.fit(X_train, y_train)
    print("Best parameter (CV score={:3f}:".format(search.best_score_))
    print(search.best_params_)
    

    更多详情请参见here

    如果您想为多类问题定义自己的评分指标,而不是使用 AUC 或其他默认评分指标,请参阅 the documentation under the scoring parameter on this page for more,但我建议您不要知道您要优化的指标是什么.

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2020-06-09
      • 2021-08-23
      • 2013-03-05
      • 2020-06-13
      • 2018-10-13
      • 2013-04-19
      • 2017-01-22
      • 2015-07-10
      相关资源
      最近更新 更多