【发布时间】:2016-10-21 12:23:03
【问题描述】:
我在理解 Spark 的交叉验证时遇到了一些麻烦。我见过的任何示例都使用它进行参数调整,但我认为它也只会进行常规的 K 折交叉验证?
我想做的是执行 k 折交叉验证,其中 k=5。我想获得每个结果的准确性,然后获得平均准确性。 在 scikit learn 中,这是如何完成的,其中 score 会给你每个折叠的结果,然后你可以使用 scores.mean()
scores = cross_val_score(classifier, y, x, cv=5, scoring='accuracy')
这就是我在 Spark 中的做法,paramGridBuilder 是空的,因为我不想输入任何参数。
val paramGrid = new ParamGridBuilder().build()
val evaluator = new MulticlassClassificationEvaluator()
evaluator.setLabelCol("label")
evaluator.setPredictionCol("prediction")
evaluator.setMetricName("precision")
val crossval = new CrossValidator()
crossval.setEstimator(classifier)
crossval.setEvaluator(evaluator)
crossval.setEstimatorParamMaps(paramGrid)
crossval.setNumFolds(5)
val modelCV = crossval.fit(df4)
val chk = modelCV.avgMetrics
这是否与 scikit learn 实现相同?为什么这些示例在进行交叉验证时使用训练/测试数据?
【问题讨论】:
标签: machine-learning classification apache-spark-mllib cross-validation