【问题标题】:Spark K-fold Cross ValidationSpark K-fold 交叉验证
【发布时间】:2016-10-21 12:23:03
【问题描述】:

我在理解 Spark 的交叉验证时遇到了一些麻烦。我见过的任何示例都使用它进行参数调整,但我认为它也只会进行常规的 K 折交叉验证?

我想做的是执行 k 折交叉验证,其中 k=5。我想获得每个结果的准确性,然后获得平均准确性。 在 scikit learn 中,这是如何完成的,其中 score 会给你每个折叠的结果,然后你可以使用 scores.mean()

scores = cross_val_score(classifier, y, x, cv=5, scoring='accuracy')

这就是我在 Spark 中的做法,paramGridBuilder 是空的,因为我不想输入任何参数。

val paramGrid = new ParamGridBuilder().build()
val evaluator = new MulticlassClassificationEvaluator()
  evaluator.setLabelCol("label")
  evaluator.setPredictionCol("prediction")
evaluator.setMetricName("precision")


val crossval = new CrossValidator()
crossval.setEstimator(classifier)
crossval.setEvaluator(evaluator) 
crossval.setEstimatorParamMaps(paramGrid)
crossval.setNumFolds(5)


val modelCV = crossval.fit(df4)
val chk = modelCV.avgMetrics

这是否与 scikit learn 实现相同?为什么这些示例在进行交叉验证时使用训练/测试数据?

How to cross validate RandomForest model?

https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala

【问题讨论】:

    标签: machine-learning classification apache-spark-mllib cross-validation


    【解决方案1】:
    1. 你正在做的看起来没问题。
    2. 基本上,是的,它的工作原理与 sklearn 的 grid search CV 相同。
      对于每个 EstimatorParamMaps(一组参数),该算法使用 CV 进行测试,因此 avgMetrics 是所有折叠的平均交叉验证准确度指标/秒。 如果一个人使用空的ParamGridBuilder(无参数搜索),这就像进行“常规”交叉验证”,我们将产生一个交叉验证的训练准确性。
    3. 每次 CV 迭代都包括K-1 训练折叠和1 测试折叠,那么为什么大多数示例在进行交叉验证之前将数据与训练/测试数据分开? 因为 CV 内的测试折叠用于参数网格搜索。 这意味着模型选择需要额外的验证数据集。 因此需要所谓的“测试数据集”来评估最终模型。阅读更多here

    【讨论】:

      猜你喜欢
      • 2019-02-19
      • 2011-10-01
      • 2018-09-11
      • 2020-10-25
      • 1970-01-01
      • 2015-09-26
      • 2020-07-30
      • 1970-01-01
      • 2021-10-11
      相关资源
      最近更新 更多