【发布时间】:2016-05-12 16:57:33
【问题描述】:
我正在尝试开发一个用户定义的聚合函数,该函数计算一行数字的线性回归。我已经成功完成了计算均值置信区间的 UDAF(经过大量试验和错误以及 SO!)。
这就是我已经实际运行的内容:
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{StructType, StructField, DoubleType, LongType, DataType, ArrayType}
case class RegressionData(intercept: Double, slope: Double)
class Regression {
import org.apache.commons.math3.stat.regression.SimpleRegression
def roundAt(p: Int)(n: Double): Double = { val s = math pow (10, p); (math round n * s) / s }
def getRegression(data: List[Long]): RegressionData = {
val regression: SimpleRegression = new SimpleRegression()
data.view.zipWithIndex.foreach { d =>
regression.addData(d._2.toDouble, d._1.toDouble)
}
RegressionData(roundAt(3)(regression.getIntercept()), roundAt(3)(regression.getSlope()))
}
}
class UDAFRegression extends UserDefinedAggregateFunction {
import java.util.ArrayList
def deterministic = true
def inputSchema: StructType =
new StructType().add("units", LongType)
def bufferSchema: StructType =
new StructType().add("buff", ArrayType(LongType))
def dataType: DataType =
new StructType()
.add("intercept", DoubleType)
.add("slope", DoubleType)
def initialize(buffer: MutableAggregationBuffer) = {
buffer.update(0, new ArrayList[Long]())
}
def update(buffer: MutableAggregationBuffer, input: Row) = {
val longList: ArrayList[Long] = new ArrayList[Long](buffer.getList(0))
longList.add(input.getLong(0));
buffer.update(0, longList);
}
def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
val longList: ArrayList[Long] = new ArrayList[Long](buffer1.getList(0))
longList.addAll(buffer2.getList(0))
buffer1.update(0, longList)
}
def evaluate(buffer: Row) = {
import scala.collection.JavaConverters._
val list = buffer.getList(0).asScala.toList
val regression = new Regression
regression.getRegression(list)
}
}
但是数据集不是按顺序排列的,这在这里显然非常重要。因此,我需要第二个参数 regression($"longValue", $"created_day") 而不是 regression($"longValue")。 created_day 是 sql.types.DateType。
我对 DataTypes、StructTypes 和诸如此类的东西感到很困惑,并且由于网络上缺乏示例,我在这里的试用和订购尝试被卡住了。
我的bufferSchema 会是什么样子?
在我的情况下,这些 StructTypes 是开销吗? (可变的)Map 不会做吗? MapType 实际上是不可变的吗?作为缓冲区类型,这不是毫无意义吗?
我的inputSchema 会是什么样子?
这是否必须与我在update() 中通过input.getLong(0) 检索到的类型相匹配?
有没有标准的方法来重置initialize()中的缓冲区
我见过buffer.update(0, 0.0)(显然,它包含双打),buffer(0) = new WhatEver(),我认为甚至是buffer = Nil。这些有什么不同吗?
如何更新数据?
上面的例子似乎过于复杂。我期待能够做某事。喜欢buffer += input.getLong(0) -> input.getDate(1)。
我可以期望以这种方式访问输入吗
如何合并数据?
我可以把功能块留空吗
def merge(…) = {}?
在evaluate() 中对缓冲区进行排序的挑战是……。我应该能够弄清楚,尽管我仍然对你们如何做到这一点的最优雅的方式感兴趣(在很短的时间内)。
额外问题:dataType 扮演什么角色?
我返回一个案例类,而不是 dataType 中定义的 StructType,这似乎不是问题。还是因为它恰好与我的案例类匹配而有效?
【问题讨论】:
标签: scala apache-spark user-defined-functions