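【Posted】: 2018-09-12 21:36:45
【Question】:
While using Apache Spark, I am trying to offload some computations from Python to Scala. I would like to use Java's class interface so that I can keep a persistent variable, like this (a nonsensical MWE based on my more complex use case):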
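package mwe

import org.apache.spark.sql.api.java.UDF1

class SomeFun extends UDF1[Int, Int] {
  // Persistent state kept across calls on the same instance.
  private var prop: Int = 0

  override def call(input: Int): Int = {
    // Store the input while no non-zero value has been stored yet.
    if (prop == 0) {
      prop = input
    }
    prop + input
  }
}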
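To make the intended behaviour of the class concrete, here is a plain-Python model of what a single `SomeFun` instance would compute if it saw the inputs 0 to 5 in order. `SomeFunModel` is a hypothetical stand-in written only for illustration and is not part of Spark. It also assumes one instance and a fixed row order, neither of which Spark guarantees across partitions, so real output may differ.

```python
class SomeFunModel:
    """Plain-Python stand-in for mwe.SomeFun (hypothetical, for illustration only)."""

    def __init__(self):
        # Mirrors `private var prop: Int = 0` in the Scala class.
        self.prop = 0

    def call(self, value):
        # Store the input while no non-zero value has been stored yet.
        if self.prop == 0:
            self.prop = value
        return self.prop + value


# Assumption: a single instance sees the rows 0..5 in this exact order.
model = SomeFunModel()
results = [model.call(i) for i in range(6)]

# The query in the question adds 3 to each UDF result.
new_num = [r + 3 for r in results]
```

Under these assumptions the stored value becomes 1 on the second row and stays there, so every later row is simply shifted by 1 before the `+ 3` is applied.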
Now I try to use this class in pyspark:
import pyspark
from pyspark.sql import SQLContext
from pyspark import SparkContext
conf = pyspark.SparkConf()
conf.set("spark.jars", "mwe.jar")
sc = SparkContext.getOrCreate(conf)
sqlContext = SQLContext.getOrCreate(sc)
sqlContext.registerJavaFunction("fun", "mwe.SomeFun")
df0 = sc.parallelize((i,) for i in range(6)).toDF(["num"])
df1 = df0.selectExpr("fun(num) + 3 as new_num")
df1.show()
and get the following exception:
pyspark.sql.utils.AnalysisException: u"cannot resolve '(UDF:fun(num) + 3)' due to data type mismatch: differing types in '(UDF:fun(num) + 3)' (struct<> and int).; line 1 pos 0;\n'Project [(UDF:fun(num#0L) + 3) AS new_num#2]\n+- AnalysisBarrier\n +- LogicalRDD [num#0L], false\n"
What is the correct way to achieve this? Do I have to resort to Java itself to write the class? Thanks a lot for any hints!
【Discussion】:
Tags: scala apache-spark pyspark apache-spark-sql user-defined-functions
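One direction that may be worth trying, sketched under assumptions rather than offered as a confirmed fix: the `struct<>` in the error suggests that Spark could not infer the UDF's return type from the Scala class, so the sketch declares the type explicitly. It assumes Spark 2.3 or later, where `registerJavaFunction` accepts a `returnType` argument. The helper names `register_and_apply` and `main` are hypothetical. The `cast(num as int)` is also an assumption: the column is a bigint (`num#0L` in the plan), while the UDF takes an `Int`.

```python
def register_and_apply(spark, return_type, df):
    """Register mwe.SomeFun with an explicit return type and apply it to df.num."""
    # Passing the return type means Spark does not have to infer it
    # from the compiled Scala class.
    spark.udf.registerJavaFunction("fun", "mwe.SomeFun", return_type)
    # Assumption: Python ints arrive as bigint, so cast to match UDF1[Int, Int].
    return df.selectExpr("fun(cast(num as int)) + 3 as new_num")


def main():
    # Imported here so the helper above can be loaded without PySpark installed.
    from pyspark.sql import SparkSession
    from pyspark.sql.types import IntegerType

    spark = SparkSession.builder.config("spark.jars", "mwe.jar").getOrCreate()
    df0 = spark.createDataFrame([(i,) for i in range(6)], ["num"])
    register_and_apply(spark, IntegerType(), df0).show()
```

Calling `main()` needs `mwe.jar` to be available to the session, as in the question. Whether the persistent variable then behaves as hoped is a separate matter, since each task may work with its own instance of the class.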