【发布时间】:2015-10-19 12:18:22
【问题描述】:
背景
我最初的问题是为什么在 map 函数中使用 DecisionTreeModel.predict 会引发异常? 并且与 How to generate tuples of (original lable, predicted label) on Spark with MLlib? 有关
当我们使用 Scala API a recommended way 获得对 RDD[LabeledPoint] 的预测时,使用 DecisionTreeModel 是简单地映射到 RDD:
val labelAndPreds = testData.map { point =>
val prediction = model.predict(point.features)
(point.label, prediction)
}
不幸的是,PySpark 中的类似方法效果不佳:
labelsAndPredictions = testData.map(
lambda lp: (lp.label, model.predict(lp.features))
labelsAndPredictions.first()
异常:您似乎正试图从广播变量、操作或转换中引用 SparkContext。 SparkContext 只能在驱动程序上使用,不能在它在工作人员上运行的代码中使用。如需更多信息,请参阅SPARK-5063。
而不是 official documentation 推荐这样的东西:
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
那么这里发生了什么?这里没有广播变量,Scala API 定义predict 如下:
/**
* Predict values for a single data point using the model trained.
*
* @param features array representing a single data point
* @return Double prediction from the trained model
*/
def predict(features: Vector): Double = {
topNode.predict(features)
}
/**
* Predict values for the given data set using the model trained.
*
* @param features RDD representing data points to be predicted
* @return RDD of predictions for each of the given data points
*/
def predict(features: RDD[Vector]): RDD[Double] = {
features.map(x => predict(x))
}
所以至少乍一看,从动作或转换调用不是问题,因为预测似乎是一种本地操作。
说明
经过一番挖掘,我发现问题的根源是从DecisionTreeModel.predict 调用的JavaModelWrapper.call 方法。调用Java函数需要accessSparkContext:
callJavaFunc(self._sc, getattr(self._java_model, name), *a)
问题
在DecisionTreeModel.predict 的情况下,有一个推荐的解决方法,并且所有必需的代码都已经是 Scala API 的一部分,但是一般来说有什么优雅的方法来处理这样的问题吗?
目前只有我能想到的比较重量级的解决方案:
- 通过隐式转换扩展 Spark 类或添加某种包装器,将所有内容推送到 JVM
- 直接使用 Py4j 网关
【问题讨论】:
-
这部分是正确的。我在将 Scala 中的相同代码实现放到 Python 中以用于决策树时遇到了同样的麻烦,并引发了相同的广播问题,因此不得不使用 .zip 函数将标签组合回来。谢谢你的解释!
标签: python scala apache-spark pyspark apache-spark-mllib