【问题标题】:TypeError: __init__() got multiple values for argument 'n_splits'TypeError:__init__() 为参数“n_splits”获得了多个值
【发布时间】:2019-07-05 11:35:36
【问题描述】:

我正在使用 SKLearn 版本 (0.20.2),如下:

from sklearn.model_selection import StratifiedKFold


grid = GridSearchCV(
    pipeline,  # pipeline from above
    params,  # parameters to tune via cross validation
    refit=True,  # fit using all available data at the end, on the best found param combination
    scoring='accuracy',  # what score are we optimizing?
    cv=StratifiedKFold(label_train, n_splits=5),  # what type of cross validation to use
)

但我不明白为什么我会收到这个错误:


TypeError                                 Traceback (most recent call last)
<ipython-input-26-03a56044cb82> in <module>()
     10     refit=True,  # fit using all available data at the end, on the best found param combination
     11     scoring='accuracy',  # what score are we optimizing?
---> 12     cv=StratifiedKFold(label_train, n_splits=5),  # what type of cross validation to use
     13 )

TypeError: __init__() got multiple values for argument 'n_splits'

我已经尝试过n_fold,但出现了相同的错误结果。并且也厌倦了更新我的 scikit 版本和我的 conda。有什么办法解决这个问题吗?非常感谢!

【问题讨论】:

  • 删除label_train; first 参数被命名为n_splits

标签: python scikit-learn


【解决方案1】:

StratifiedKFold 在初始化时正好有 3 个参数,它们都不是训练数据:

StratifiedKFold(n_splits=’warn’, shuffle=False, random_state=None)

所以当你调用StratifiedKFold(label_train, n_splits=5) 时,它认为你通过了n_splits 两次。

相反,创建对象,然后使用 sklearn 文档页面上示例中描述的方法来使用对象拆分数据:

get_n_splits([X, y, groups]) 返回分割数 交叉验证器中的迭代 split(X, y[, groups]) 生成 将数据拆分为训练集和测试集的索引。

【讨论】:

    【解决方案2】:

    StratifiedKFold 接受三个参数,但您传递了两个参数。在 sklearn 中查看更多信息documentation

    创建 StratifiedKFold 对象并将其传递给 GridSearchCV,如下所示。

    skf = StratifiedKFold(n_splits=5)
    skf.get_n_splits(X_train, Y_train)
    
    grid = GridSearchCV(
    pipeline,  # pipeline from above
    params,  # parameters to tune via cross validation
    refit=True,  # fit using all available data at the end, on the best found param combination
    scoring='accuracy',  # what score are we optimizing?
    cv=skf,  # what type of cross validation to use
    )
    

    【讨论】:

      猜你喜欢
      • 2020-10-07
      • 1970-01-01
      • 1970-01-01
      • 2018-07-28
      • 1970-01-01
      • 2022-11-17
      • 2022-08-22
      • 2019-01-03
      • 1970-01-01
      相关资源
      最近更新 更多