【问题标题】:Why we should call split() function during passing StratifiedKFold() as a parameter of GridSearchCV?为什么我们应该在传递 StratifiedKFold() 作为 GridSearchCV 的参数时调用 split() 函数?
【发布时间】:2020-09-22 05:26:19
【问题描述】:

我想做什么?

我正在尝试在GridSearchCV() 中使用StratifiedKFold()

那么,我有什么困惑呢?

当我们使用 K 折交叉验证时,我们只需在 GridSearchCV() 中传递 CV 的数量,如下所示。

grid_search_m = GridSearchCV(rdm_forest_clf, param_grid, cv=5, scoring='f1', return_train_score=True, n_jobs=2)

然后,当我需要使用StratifiedKFold() 时,我认为程序应该保持不变。即,仅设置拆分数 - StratifiedKFold(n_splits=5)cv

grid_search_m = GridSearchCV(rdm_forest_clf, param_grid, cv=StratifiedKFold(n_splits=5), scoring='f1', return_train_score=True, n_jobs=2)

但是this answer

无论使用什么cross validation strategy,只需要 按照建议使用函数 split 提供生成器:

kfolds = StratifiedKFold(5)
clf = GridSearchCV(estimator, parameters, scoring=qwk, cv=kfolds.split(xtrain,ytrain))
clf.fit(xtrain, ytrain)

另外,this question 的回答之一也建议这样做。这意味着,他们建议在使用GridSearchCV() 期间调用拆分函数:StratifiedKFold(n_splits=5).split(xtrain,ytrain)。但是,我发现调用split() 和不调用split() 给我相同的f1 分数。

因此,我的问题

  • 我不明白为什么我们需要在分层 K 折叠过程中调用 split() 函数为 我们不需要在 K Fold CV 期间做这种事情。

  • 如果调用split() 函数,GridSearchCV() 将如何作为Split() 函数returns training and testing data set indices 工作?也就是说,我想知道GridSearchCV() 将如何使用这些索引?

【问题讨论】:

    标签: split cross-validation grid-search gridsearchcv k-fold


    【解决方案1】:

    基本上 GridSearchCV 很聪明,可以为 cv 参数采用多个选项 - 一个数字、拆分索引的迭代器或具有拆分功能的对象。可以看下代码here,复制如下。

    cv = 5 if cv is None else cv
    if isinstance(cv, numbers.Integral):
        if (classifier and (y is not None) and
                (type_of_target(y) in ('binary', 'multiclass'))):
            return StratifiedKFold(cv)
        else:
            return KFold(cv)
    
    if not hasattr(cv, 'split') or isinstance(cv, str):
        if not isinstance(cv, Iterable) or isinstance(cv, str):
            raise ValueError("Expected cv as an integer, cross-validation "
                             "object (from sklearn.model_selection) "
                             "or an iterable. Got %s." % cv)
        return _CVIterableWrapper(cv)
    
    return cv  # New style cv objects are passed without any modification
    

    基本上,如果你什么都不通过,它会使用一个 5 的 KFold。如果它是一个分类问题并且目标是二元/多类,它也足够聪明地自动使用 StratifedKFold。

    如果你传递一个带有 split 函数的对象,它只会使用它。如果您不传递其中任何一个,但传递了一个可迭代对象,它会假定这是拆分索引的可迭代对象并为您包装它。

    因此,在您的情况下,假设这是一个具有二进制/多类目标的分类问题,以下所有内容将给出完全相同的结果/拆分 - 您使用哪一个并不重要!

    cv=5
    cv=StratifiedKFold(5)
    cv=StratifiedKFold(5).split(xtrain,ytrain)
    

    【讨论】:

    • 感谢您的回复。你说 “如果你传递一个带有拆分函数的对象,它只会使用它。” 但我不明白 “GridSearchCV() 将如何使用那些通过拆分找到的索引?” 你能描述一下吗?
    • 因此对于网格搜索中的每个参数集,它将使用拆分来运行交叉验证 - 因此,如果您在参数网格中有 2 个参数的 3 个选项(6 个集),并且 5 折交叉验证,然后你真的训练和验证 30 个模型。然后在交叉验证运行中具有最高平均验证分数的参数集被选为“最佳”
    猜你喜欢
    • 2019-09-29
    • 2016-05-04
    • 1970-01-01
    • 1970-01-01
    • 2019-09-26
    • 2020-12-27
    • 2021-05-29
    • 1970-01-01
    • 2019-04-03
    相关资源
    最近更新 更多