[Posted]: 2016-06-02 11:12:21
[Question]:
I am trying to find a way to calculate the median of a given DataFrame.
    val df = sc.parallelize(Seq(("a", 1.0), ("a", 2.0), ("a", 3.0), ("b", 6.0), ("b", 8.0))).toDF("col1", "col2")

    +----+----+
    |col1|col2|
    +----+----+
    |   a| 1.0|
    |   a| 2.0|
    |   a| 3.0|
    |   b| 6.0|
    |   b| 8.0|
    +----+----+
Now I want to do something like this: df.groupBy("col1").agg(calcmedian("col2"))
The result should look like this:
    +----+------+
    |col1|median|
    +----+------+
    |   a|   2.0|
    |   b|   7.0|
    +----+------+
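So calcmedian() has to be a UDAF, but the problem is that a UDAF's evaluate method only receives a Row, whereas I need the whole table in order to sort the values and return the median...

    // Once all entries for a group are exhausted, Spark will call evaluate to get the final result
    def evaluate(buffer: Row) = {...}

Is this possible at all? Or is there another good workaround? I want to stress that I know how to compute the median of a dataset with a single group. But I do not want to run that algorithm inside a "foreach" loop over the groups, because that is inefficient!
Thanks!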
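To pin down what "median per group" means for the sample data, here is a minimal sketch on a plain Scala collection, with no Spark involved. The names GroupMedianSketch, median and medianByKey are made up for this illustration; it only shows the expected semantics (middle value, or the mean of the two middle values) and says nothing about how to do this efficiently in a distributed setting.

```scala
object GroupMedianSketch {
  // Exact median of a non-empty sequence:
  // the middle value, or the mean of the two middle values for an even count.
  def median(values: Seq[Double]): Double = {
    require(values.nonEmpty, "median of an empty sequence is undefined")
    val sorted = values.sorted
    val n = sorted.length
    if (n % 2 == 1) sorted(n / 2)
    else (sorted(n / 2 - 1) + sorted(n / 2)) / 2.0
  }

  // Group (key, value) pairs by key and take the median of each group's values.
  def medianByKey(rows: Seq[(String, Double)]): Map[String, Double] =
    rows.groupBy(_._1).map { case (key, pairs) => key -> median(pairs.map(_._2)) }

  def main(args: Array[String]): Unit = {
    val rows = Seq(("a", 1.0), ("a", 2.0), ("a", 3.0), ("b", 6.0), ("b", 8.0))
    // Sort by key only to make the printed order stable.
    medianByKey(rows).toSeq.sortBy(_._1).foreach { case (k, m) => println(s"$k -> $m") }
  }
}
```

For the sample rows this yields 2.0 for group a and 7.0 for group b, matching the table above.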
Edit:
This is what I have tried so far:
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types._

    object calcMedian extends UserDefinedAggregateFunction {
      // Schema you get as an input
      def inputSchema = new StructType().add("col2", DoubleType)
      // Schema of the row which is used for aggregation
      def bufferSchema = new StructType().add("col2", DoubleType)
      // Returned type
      def dataType = DoubleType
      // Self-explanatory
      def deterministic = true
      // initialize - called once for each group
      def initialize(buffer: MutableAggregationBuffer) = {
        buffer(0) = 0.0
      }
      // called for each input record of that group
      def update(buffer: MutableAggregationBuffer, input: Row) = {
        buffer(0) = input.getDouble(0)
      }
      // if the function supports partial aggregates, Spark might (as an optimization) compute partial results and combine them
      def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
        buffer1(0) = buffer2.getDouble(0)
      }
      // Once all entries for a group are exhausted, Spark will call evaluate to get the final result
      def evaluate(buffer: Row) = {
        val tile = 50
        var median = 0.0
        // PROBLEM: buffer is a Row --> I need a DataFrame here???
        val rdd_sorted = buffer.sortBy(x => x)
        val c = rdd_sorted.count()
        if (c == 1) {
          median = rdd_sorted.first()
        } else {
          val index = rdd_sorted.zipWithIndex().map(_.swap)
          val last = c
          val n = (tile / 100d) * (c * 1d)
          val k = math.floor(n).toLong
          val d = n - k
          if (k <= 0) {
            median = rdd_sorted.first()
          } else {
            if (k <= c) {
              median = index.lookup(last - 1).head
            } else {
              if (k >= c) {
                median = index.lookup(last - 1).head
              } else {
                median = index.lookup(k - 1).head + d * (index.lookup(k).head - index.lookup(k - 1).head)
              }
            }
          }
        }
        median
      } // end of evaluate
    }
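One way around the "evaluate only sees a Row" limitation is to let the aggregation buffer itself hold all values of the group, so that evaluate can sort them. The following is only a sketch under assumptions: the object name CollectMedian and the field names are invented here, it targets the UserDefinedAggregateFunction API of Spark 1.5 to 2.x (deprecated since Spark 3.0 in favor of Aggregator), and it keeps every value of a group in memory, so it is not a scalable solution for very large groups.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// Hypothetical UDAF: collects every value of a group into an array buffer
// and computes the exact median once the group is complete.
object CollectMedian extends UserDefinedAggregateFunction {
  def inputSchema: StructType = new StructType().add("value", DoubleType)
  // The buffer is a single array column holding all values seen so far.
  def bufferSchema: StructType =
    new StructType().add("values", ArrayType(DoubleType, containsNull = false))
  def dataType: DataType = DoubleType
  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer(0) = Seq.empty[Double]

  // Append the incoming value; null inputs are skipped.
  def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    if (!input.isNullAt(0)) buffer(0) = buffer.getSeq[Double](0) :+ input.getDouble(0)

  // Concatenate two partial buffers of the same group.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1(0) = buffer1.getSeq[Double](0) ++ buffer2.getSeq[Double](0)

  // Sort the collected values and pick the middle (or the mean of the two middle values).
  def evaluate(buffer: Row): Any = {
    val sorted = buffer.getSeq[Double](0).sorted
    val n = sorted.length
    if (n == 0) null
    else if (n % 2 == 1) sorted(n / 2)
    else (sorted(n / 2 - 1) + sorted(n / 2)) / 2.0
  }
}
```

Assuming the df from the question, it would be applied as df.groupBy("col1").agg(CollectMedian(col("col2")).as("median")), with col imported from org.apache.spark.sql.functions. Rebuilding the sequence on every update is costly, which is one reason the built-in percentile functions are usually preferable.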
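[Comments]:
- You need groupByKey to turn the aggregated data into a Buffer, along with some UDFs to get there; then you create a UDF that computes the median.
- The UserDefinedAggregateFunction base class has many more members than evaluate that need to be implemented. The Row buffer passed to evaluate is only the final step. Have you attempted an implementation, and if so, can you show the code you have so far?
- @mattinbits: I have added the code I have been working on so far....
- a) There are already built-in functions to compute an approximate or exact median. b) It is not possible to access a DataFrame inside a UDAF. c) Computing an exact median in a distributed environment is extremely inefficient, simply by definition.
- Then percentile_approx / percentile it is. groupBy on a DataFrame does not physically move the data (which has implications for the shuffle here). It is the equivalent of aggregate(ByKey), and the API clearly reflects that. One way or another, you cannot access the DataFrame inside a UDAF.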
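The built-in route mentioned in the comments can be sketched as follows. This assumes Spark 2.1 or later, where percentile and percentile_approx are available as SQL aggregate functions and can be reached through expr (on Spark 1.x they were Hive UDAFs and needed a HiveContext). The object name BuiltInMedian and the helper mediansPerGroup are invented for this example.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.expr

object BuiltInMedian {
  // Exact and approximate median of col2 for each value of col1,
  // using the built-in SQL aggregates instead of a custom UDAF.
  def mediansPerGroup(df: DataFrame): DataFrame =
    df.groupBy("col1").agg(
      expr("percentile(col2, 0.5)").as("median"),
      expr("percentile_approx(col2, 0.5)").as("approx_median"))

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("median-per-group")
      .getOrCreate()
    import spark.implicits._

    val df = Seq(("a", 1.0), ("a", 2.0), ("a", 3.0), ("b", 6.0), ("b", 8.0))
      .toDF("col1", "col2")

    mediansPerGroup(df).orderBy("col1").show()
    spark.stop()
  }
}
```

The exact percentile interpolates between neighboring values, so the median column should match the table in the question. The approximate variant returns a value drawn from the data rather than interpolating, so its result for an even-sized group can differ from the exact median.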
Tags: scala apache-spark group-by median user-defined-aggregate