【问题标题】:Should GridSearchCV score results be equal to score of cross_validate using same input?GridSearchCV 得分结果是否应该等于使用相同输入的 cross_validate 得分?
【发布时间】:2019-11-14 13:42:14
【问题描述】:

我正在玩一些 scikit-learn,并希望重现执行网格搜索的一个特定超参数组合的交叉验证分数。

对于网格搜索,我使用了GridSearchCV 类,为了重现一个特定超参数组合的结果,我使用了具有完全相同拆分和分类器设置的cross_validate 函数。

我的问题是我没有得到预期的分数结果,据我了解,这应该与执行相同的计算以获得两种方法的分数完全相同。

通过修复训练数据上使用的拆分,我确保从脚本中排除任何随机源。

在下面的代码 sn-p 中,给出了所述问题的示例。

import numpy as np
from sklearn.model_selection import cross_validate, StratifiedKFold, GridSearchCV
from sklearn.svm import NuSVC

np.random.seed(2018)

# generate random training features
X = np.random.random((100, 10))

# class labels
y = np.random.randint(2, size=100)

clf = NuSVC(nu=0.4, gamma='auto')

# Compute score for one parameter combination
grid = GridSearchCV(clf,
                    cv=StratifiedKFold(n_splits=10, random_state=2018),
                    param_grid={'nu': [0.4]},
                    scoring=['f1_macro'],
                    refit=False)

grid.fit(X, y)
print(grid.cv_results_['mean_test_f1_macro'][0])

# Recompute score for exact same input
result = cross_validate(clf,
                        X,
                        y,
                        cv=StratifiedKFold(n_splits=10, random_state=2018),
                        scoring=['f1_macro'])

print(result['test_f1_macro'].mean())

执行给定的sn-p会得到输出:

0.38414468864468865
0.3848840048840049

我本来希望这些分数完全相同,因为它们是在相同的拆分上计算的,使用相同的训练数据和相同的分类器。

【问题讨论】:

    标签: python machine-learning scikit-learn cross-validation grid-search


    【解决方案1】:

    这是因为mean_test_f1_macro 不是所有折叠组合的简单平均,它是权重平均,权重是测试折叠的大小。了解更多参考this答案的实际实现。

    现在,要复制 GridSearchCV 结果,试试这个!

    print('grid search cv result',grid.cv_results_['mean_test_f1_macro'][0])
    
    # grid search cv result 0.38414468864468865
    
    print('simple mean: ', result['test_f1_macro'].mean())
    
    # simple mean:  0.3848840048840049
    
    weights= [len(test) for (_, test) in StratifiedKFold(n_splits=10, random_state=2018).split(X,y)]
    print('weighted mean: {}'.format(np.average(result['test_f1_macro'], axis=0, weights=weights)))
    
    # weighted mean: 0.38414468864468865
    

    【讨论】:

      猜你喜欢
      • 2020-05-03
      • 1970-01-01
      • 2021-07-17
      • 2013-08-22
      • 2019-09-07
      • 2021-07-19
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多