【问题标题】:Wrapping a java function in pyspark在pyspark中包装一个java函数
【发布时间】:2016-03-08 13:08:02
【问题描述】:

我正在尝试创建一个可以从 python 调用的用户定义聚合函数。我试图关注this 问题的答案。 我基本上实现了以下(取自here):

package com.blu.bla;
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.Row;

public class MySum extends UserDefinedAggregateFunction {
    private StructType _inputDataType;
    private StructType _bufferSchema;
    private DataType _returnDataType;

    public MySum() {
        List<StructField> inputFields = new ArrayList<StructField>();
        inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
        _inputDataType = DataTypes.createStructType(inputFields);

        List<StructField> bufferFields = new ArrayList<StructField>();
        bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true));
        _bufferSchema = DataTypes.createStructType(bufferFields);

        _returnDataType = DataTypes.DoubleType;
    }

    @Override public StructType inputSchema() {
        return _inputDataType;
    }

    @Override public StructType bufferSchema() {
        return _bufferSchema;
    }

    @Override public DataType dataType() {
        return _returnDataType;
    }

    @Override public boolean deterministic() {
        return true;
    }

    @Override public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, null);
    }

    @Override public void update(MutableAggregationBuffer buffer, Row input) {
        if (!input.isNullAt(0)) {
            if (buffer.isNullAt(0)) {
                buffer.update(0, input.getDouble(0));
            } else {
                Double newValue = input.getDouble(0) + buffer.getDouble(0);
                buffer.update(0, newValue);
            }
        }
    }

    @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        if (!buffer2.isNullAt(0)) {
            if (buffer1.isNullAt(0)) {
                buffer1.update(0, buffer2.getDouble(0));
            } else {
                Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
                buffer1.update(0, newValue);
            }
        }
    }

    @Override public Object evaluate(Row buffer) {
        if (buffer.isNullAt(0)) {
            return null;
        } else {
            return buffer.getDouble(0);
        }
    }
}

然后我用所有依赖项编译它并使用 --jars myjar.jar 运行 pyspark

在 pyspark 中我做了:

df = sqlCtx.createDataFrame([(1.0, "a"), (2.0, "b"), (3.0, "C")], ["A", "B"])
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql import Row

def myCol(col):
    _f = sc._jvm.com.blu.bla.MySum.apply
    return Column(_f(_to_seq(sc,[col], _to_java_column)))
b = df.agg(myCol("A"))

我收到以下错误:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-24-f45b2a367e67> in <module>()
----> 1 b = df.agg(myCol("A"))

<ipython-input-22-afcb8884e1db> in myCol(col)
      4 def myCol(col):
      5     _f = sc._jvm.com.blu.bla.MySum.apply
----> 6     return Column(_f(_to_seq(sc,[col], _to_java_column)))

TypeError: 'JavaPackage' object is not callable

我也尝试将 --driver-class-path 添加到 pyspark 调用中,但得到了相同的结果。

也尝试通过java import来访问java类:

from py4j.java_gateway import java_import
jvm = sc._gateway.jvm
java_import(jvm, "com.bla.blu.MySum")
def myCol2(col):
    _f = jvm.bla.blu.MySum.apply
    return Column(_f(_to_seq(sc,[col], _to_java_column)))

还尝试简单地创建类(如建议的here):

a = jvm.com.bla.blu.MySum()

所有人都收到相同的错误消息。

我似乎无法弄清楚问题所在。

【问题讨论】:

    标签: java python apache-spark pyspark


    【解决方案1】:

    所以似乎主要问题是如果提供相对路径,则添加 jar 的所有选项(--jars、驱动程序类路径、SPARK_CLASSPATH)都无法正常工作。这可能是因为 ipython 中的工作目录与我运行 pyspark 的位置不同。

    一旦我把它改成绝对路径,它就可以工作了(还没有在集群上测试过,但至少它可以在本地安装上工作)。

    另外,我不确定这是否也是答案 here 中的错误,因为该答案使用 scala 实现,但是在 java 实现中我需要做

    def myCol(col):
        _f = sc._jvm.com.blu.bla.MySum().apply
        return Column(_f(_to_seq(sc,[col], _to_java_column)))
    

    这可能不是很有效,因为它每次都会创建 _f,相反我应该在函数外部定义 _f(同样,这需要在集群上进行测试)但至少现在它提供了正确的函数答案

    【讨论】:

    • 最后一件事,以供将来参考,这是在 spark 1.6.0 本地(单节点)安装上测试的
    • 在集群上进行了测试,它可以工作。一起使用 --jars 和 --driver-class-path (显然 --jars 没有将类路径设置为驱动程序)
    • 嗨!我正在尝试做类似的事情,但我遇到了同样的错误,我无法理解它,您能否提供一些关于如何运行它的见解。
    • 假设您的 jar 位于 /home/a/a.jar。我要做的是使用 --jars /home/a/a.jar --driver-class-path /home/a/a.jar 运行它,这很好
    • 嗨 Assaf,我遇到了同样的错误。我通过了两个罐子(逗号分隔)。然后我尝试将它们放在一个目录中,我得到错误我的支持 jar(一个是使用另一个的主 jar)没有主类。我该如何解决?
    猜你喜欢
    • 2021-07-10
    • 1970-01-01
    • 2013-05-23
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-04-06
    • 1970-01-01
    • 2017-03-18
    相关资源
    最近更新 更多