【问题标题】:Generic UDAF in Spark 3.0 using AggregatorSpark 3.0 中使用聚合器的通用 UDAF
【发布时间】:2020-11-30 02:00:00
【问题描述】:

Spark 3.0 已弃用 UserDefinedAggregateFunction,我试图使用 Aggregator 重写我的 udaf。 Aggregator 的基本用法很简单,但是,我很难使用更通用的函数版本。

我将尝试用这个例子解释我的问题,collect_set 的实现。这不是我的实际情况,但更容易解释问题:

class CollectSetDemoAgg(name: String) extends Aggregator[Row, Set[Int], Set[Int]] {
  override def zero = Set.empty
  override def reduce(b: Set[Int], a: Row) = b + a.getInt(a.fieldIndex(name))
  override def merge(b1: Set[Int], b2: Set[Int]) = b1 ++ b2
  override def finish(reduction: Set[Int]) = reduction
  override def bufferEncoder = Encoders.kryo[Set[Int]]
  override def outputEncoder = ExpressionEncoder()
}

// using it:
df.agg(new CollectSetDemoAgg("rank").toColumn as "result").show()

我更喜欢.toColumn.udf.register,但这不是重点。

问题: 我不能制作这个聚合器的通用版本,它只适用于整数。

我尝试过:

class CollectSetDemo(name: String) extends Aggregator[Row, Set[Any], Set[Any]] 

它因错误而崩溃:

No Encoder found for Any
- array element class: "java.lang.Object"
- root class: "scala.collection.immutable.Set"
java.lang.UnsupportedOperationException: No Encoder found for Any
- array element class: "java.lang.Object"
- root class: "scala.collection.immutable.Set"
    at org.apache.spark.sql.catalyst.ScalaReflection$.$anonfun$serializerFor$1(ScalaReflection.scala:567)

我无法使用CollectSetDemo[T],如果我无法正确使用outputEncoder。另外,在使用 udaf 时,我只能使用 Spark 数据类型、列等。

【问题讨论】:

    标签: scala apache-spark generics aggregator


    【解决方案1】:

    还没有找到解决这种情况的好方法,但我能够在某种程度上解决它。部分代码是从RowEncoder借来的:

    class CollectSetDemoAgg(name: String, fieldType: DataType) extends Aggregator[Row, Set[Any], Any] {
      override def zero = Set.empty
      override def reduce(b: Set[Any], a: Row) = b + a.get(a.fieldIndex(name))
      override def merge(b1: Set[Any], b2: Set[Any]) = b1 ++ b2
      override def finish(reduction: Set[Any]) = reduction.toSeq
      override def bufferEncoder = Encoders.kryo[Set[Any]]
    
      // now
      override def outputEncoder = {
        val mirror = ScalaReflection.mirror
        val tt = fieldType match {
          case ArrayType(LongType, _) => typeTag[Seq[Long]]
          case ArrayType(IntegerType, _) => typeTag[Seq[Int]]
          case ArrayType(StringType, _) => typeTag[Seq[String]]
          // .. etc etc
          case _ => throw new RuntimeException(s"Could not create encoder for ${name} column (${fieldType})")
        }
        val tpe = tt.in(mirror).tpe
    
        val cls = mirror.runtimeClass(tpe)
        val serializer = ScalaReflection.serializerForType(tpe)
        val deserializer = ScalaReflection.deserializerForType(tpe)
    
        new ExpressionEncoder[Any](serializer, deserializer, ClassTag[Any](cls))
      }
    }
    

    我必须添加的一件事是聚合器中的结果数据类型参数。然后用法改为:

    df.agg(new CollectSetDemoAgg("rank", new ArrayType(IntegerType, true)).toColumn as "result").show()
    

    我真的不喜欢它的结果,但它确实有效。我也欢迎任何关于如何改进它的建议。

    【讨论】:

      【解决方案2】:

      用泛型修改@Ramunas答案:

      class CollectSetDemoAgg[T: TypeTag](name: String) extends Aggregator[Row, Set[T], Seq[T]] {
        override def zero = Set.empty
        override def reduce(b: Set[T], a: Row) = b + a.getAs[T](a.fieldIndex(name))
        override def merge(b1: Set[T], b2: Set[T]) = b1 ++ b2
        override def finish(reduction: Set[T]) = reduction.toSeq
        override def bufferEncoder = Encoders.kryo[Set[T]]
        
        override def outputEncoder = {
          val tt = typeTag[Seq[T]]
          val tpe = tt.in(mirror).tpe
      
          val cls = mirror.runtimeClass(tpe)
          val serializer = serializerForType(tpe)
          val deserializer = deserializerForType(tpe)
      
          new ExpressionEncoder[Seq[T]](serializer, deserializer, ClassTag[Seq[T]](cls))
        }
      }
      

      【讨论】:

        猜你喜欢
        • 2018-01-03
        • 1970-01-01
        • 1970-01-01
        • 2016-02-29
        • 2021-11-28
        • 1970-01-01
        • 2015-10-02
        • 2015-12-19
        • 2017-08-02
        相关资源
        最近更新 更多