【问题标题】:How to compute a cumulative sum under a limit with Spark?如何使用 Spark 计算限制下的累积和?
【发布时间】:2020-06-17 12:25:43
【问题描述】:

经过几次尝试和一些研究,我坚持尝试用 Spark 解决以下问题。

我有一个具有优先级和数量的元素数据框。

+------+-------+--------+---+
|family|element|priority|qty|
+------+-------+--------+---+
|    f1| elmt 1|       1| 20|
|    f1| elmt 2|       2| 40|
|    f1| elmt 3|       3| 10|
|    f1| elmt 4|       4| 50|
|    f1| elmt 5|       5| 40|
|    f1| elmt 6|       6| 10|
|    f1| elmt 7|       7| 20|
|    f1| elmt 8|       8| 10|
+------+-------+--------+---+

我有一个固定的限制数量:

+------+--------+
|family|limitQty|
+------+--------+
|    f1|     100|
+------+--------+

我想将累积和低于限制的元素标记为“ok”。这是预期的结果:

+------+-------+--------+---+---+
|family|element|priority|qty| ok|
+------+-------+--------+---+---+
|    f1| elmt 1|       1| 20|  1| -> 20 < 100   => ok
|    f1| elmt 2|       2| 40|  1| -> 20 + 40 < 100  => ok
|    f1| elmt 3|       3| 10|  1| -> 20 + 40 + 10 < 100   => ok
|    f1| elmt 4|       4| 50|  0| -> 20 + 40 + 10 + 50 > 100   => ko 
|    f1| elmt 5|       5| 40|  0| -> 20 + 40 + 10 + 40 > 100   => ko  
|    f1| elmt 6|       6| 10|  1| -> 20 + 40 + 10 + 10 < 100   => ok
|    f1| elmt 7|       7| 20|  1| -> 20 + 40 + 10 + 10 + 20 < 100   => ok
|    f1| elmt 8|       8| 10|  0| -> 20 + 40 + 10 + 10 + 20 + 10 > 100   => ko
+------+-------+--------+---+---+  

我尝试用累积和来解决:

    initDF
      .join(limitQtyDF, Seq("family"), "left_outer")
      .withColumn("cumulSum", sum($"qty").over(Window.partitionBy("family").orderBy("priority")))
      .withColumn("ok", when($"cumulSum" <= $"limitQty", 1).otherwise(0))
      .drop("cumulSum", "limitQty")

但这还不够,因为达到限制的元素之后的元素没有考虑在内。 我找不到用 Spark 解决它的方法。你有什么想法吗?

这里是对应的 Scala 代码:

    val sparkSession = SparkSession.builder()
      .master("local[*]")
      .getOrCreate()

    import sparkSession.implicits._

    val initDF = Seq(
      ("f1", "elmt 1", 1, 20),
      ("f1", "elmt 2", 2, 40),
      ("f1", "elmt 3", 3, 10),
      ("f1", "elmt 4", 4, 50),
      ("f1", "elmt 5", 5, 40),
      ("f1", "elmt 6", 6, 10),
      ("f1", "elmt 7", 7, 20),
      ("f1", "elmt 8", 8, 10)
    ).toDF("family", "element", "priority", "qty")

    val limitQtyDF = Seq(("f1", 100)).toDF("family", "limitQty")

    val expectedDF = Seq(
      ("f1", "elmt 1", 1, 20, 1),
      ("f1", "elmt 2", 2, 40, 1),
      ("f1", "elmt 3", 3, 10, 1),
      ("f1", "elmt 4", 4, 50, 0),
      ("f1", "elmt 5", 5, 40, 0),
      ("f1", "elmt 6", 6, 10, 1),
      ("f1", "elmt 7", 7, 20, 1),
      ("f1", "elmt 8", 8, 10, 0)
    ).toDF("family", "element", "priority", "qty", "ok").show()

感谢您的帮助!

【问题讨论】:

  • 你需要某种递归来做到这一点,窗口化是不够的。您只想要 scala 解决方案,还是选择 SQL?
  • 我想要一个 Spark 解决方案,因此纯 SQL 解决方案可以提供帮助并转换为 Spark SQL。我来看看 Spark 的递归可能性。

标签: sql scala apache-spark


【解决方案1】:

PFA 答案

val initDF = Seq(("f1", "elmt 1", 1, 20),("f1", "elmt 2", 2, 40),("f1", "elmt 3", 3, 10),
      ("f1", "elmt 4", 4, 50),
      ("f1", "elmt 5", 5, 40),
      ("f1", "elmt 6", 6, 10),
      ("f1", "elmt 7", 7, 20),
      ("f1", "elmt 8", 8, 10)
    ).toDF("family", "element", "priority", "qty")

val limitQtyDF = Seq(("f1", 100)).toDF("family", "limitQty")

sc.broadcast(limitQtyDF)


val joinedInitDF=initDF.join(limitQtyDF,Seq("family"),"left")

case class dataResult(family:String,element:String,priority:Int, qty:Int, comutedValue:Int, limitQty:Int,controlOut:String) 
val familyIDs=initDF.select("family").distinct.collect.map(_(0).toString).toList

def checkingUDF(inputRows:List[Row])={
var controlVarQty=0
val outputArrayBuffer=collection.mutable.ArrayBuffer[dataResult]()
val setLimit=inputRows.head.getInt(4) 
for(inputRow <- inputRows)
{
val currQty=inputRow.getInt(3) 
//val outpurForRec=
controlVarQty + currQty match {
case value if value <= setLimit => 
controlVarQty+=currQty
outputArrayBuffer+=dataResult(inputRow.getString(0),inputRow.getString(1),inputRow.getInt(2),inputRow.getInt(3),value,setLimit,"ok")
case value => 
outputArrayBuffer+=dataResult(inputRow.getString(0),inputRow.getString(1),inputRow.getInt(2),inputRow.getInt(3),value,setLimit,"ko")
}
//outputArrayBuffer+=Row(inputRow.getString(0),inputRow.getString(1),inputRow.getInt(2),inputRow.getInt(3),controlVarQty+currQty,setLimit,outpurForRec)
}
outputArrayBuffer.toList
}

val tmpAB=collection.mutable.ArrayBuffer[List[dataResult]]()
for (familyID <- familyIDs) // val familyID="f1"
{
val currentFamily=joinedInitDF.filter(s"family = '${familyID}'").orderBy("element", "priority").collect.toList
tmpAB+=checkingUDF(currentFamily)
}

tmpAB.toSeq.flatMap(x => x).toDF.show(false)

这对我有用。

+------+-------+--------+---+------------+--------+----------+
|family|element|priority|qty|comutedValue|limitQty|controlOut|
+------+-------+--------+---+------------+--------+----------+
|f1    |elmt 1 |1       |20 |20          |100     |ok        |
|f1    |elmt 2 |2       |40 |60          |100     |ok        |
|f1    |elmt 3 |3       |10 |70          |100     |ok        |
|f1    |elmt 4 |4       |50 |120         |100     |ko        |
|f1    |elmt 5 |5       |40 |110         |100     |ko        |
|f1    |elmt 6 |6       |10 |80          |100     |ok        |
|f1    |elmt 7 |7       |20 |100         |100     |ok        |
|f1    |elmt 8 |8       |10 |110         |100     |ko        |
+------+-------+--------+---+------------+--------+----------+

请务必从输出中删除不必要的列

【讨论】:

    【解决方案2】:

    每组累计数

    from pyspark.sql.window import Window as window
    from pyspark.sql.types import IntegerType,StringType,FloatType,StructType,StructField,DateType
    schema = StructType() \
            .add(StructField("empno",IntegerType(),True)) \
            .add(StructField("ename",StringType(),True)) \
            .add(StructField("job",StringType(),True)) \
            .add(StructField("mgr",StringType(),True)) \
            .add(StructField("hiredate",DateType(),True)) \
            .add(StructField("sal",FloatType(),True)) \
            .add(StructField("comm",StringType(),True)) \
            .add(StructField("deptno",IntegerType(),True))
    
    emp = spark.read.csv('data/emp.csv',schema)
    dept_partition = window.partitionBy(emp.deptno).orderBy(emp.sal)
    emp_win = emp.withColumn("dept_cum_sal", 
                             f.sum(emp.sal).over(dept_partition.rowsBetween(window.unboundedPreceding, window.currentRow)))
    emp_win.show()
    

    结果如下所示:

    +-----+------+---------+----+----------+------+-------+------+------------ 
    +
    |empno| ename|      job| mgr|  hiredate|   sal|   comm|deptno|dept_cum_sal|
    +-----+------+---------+----+----------+------+-------+------+------------ 
    +
    | 7369| SMITH|    CLERK|7902|1980-12-17| 800.0|   null|    20|       800.0|
    | 7876| ADAMS|    CLERK|7788|1983-01-12|1100.0|   null|    20|      1900.0|
    | 7566| JONES|  MANAGER|7839|1981-04-02|2975.0|   null|    20|      4875.0|
    | 7788| SCOTT|  ANALYST|7566|1982-12-09|3000.0|   null|    20|      7875.0|
    | 7902|  FORD|  ANALYST|7566|1981-12-03|3000.0|   null|    20|     10875.0|
    | 7934|MILLER|    CLERK|7782|1982-01-23|1300.0|   null|    10|      1300.0|
    | 7782| CLARK|  MANAGER|7839|1981-06-09|2450.0|   null|    10|      3750.0|
    | 7839|  KING|PRESIDENT|null|1981-11-17|5000.0|   null|    10|      8750.0|
    | 7900| JAMES|    CLERK|7698|1981-12-03| 950.0|   null|    30|       950.0|
    | 7521|  WARD| SALESMAN|7698|1981-02-22|1250.0| 500.00|    30|      2200.0|
    | 7654|MARTIN| SALESMAN|7698|1981-09-28|1250.0|1400.00|    30|      3450.0|
    | 7844|TURNER| SALESMAN|7698|1981-09-08|1500.0|   0.00|    30|      4950.0|
    | 7499| ALLEN| SALESMAN|7698|1981-02-20|1600.0| 300.00|    30|      6550.0|
    | 7698| BLAKE|  MANAGER|7839|1981-05-01|2850.0|   null|    30|      9400.0|
    +-----+------+---------+----+----------+------+-------+------+------------+
    

    【讨论】:

      【解决方案3】:

      我是 Spark 的新手,所以这个解决方案可能不是最佳的。我假设 100 的值是这里程序的输入。在这种情况下:

      case class Frame(family:String, element : String, priority : Int, qty :Int)
      
      import scala.collection.JavaConverters._
      val ans = df.as[Frame].toLocalIterator
        .asScala
        .foldLeft((Seq.empty[Int],0))((acc,a) => 
          if(acc._2 + a.qty <= 100) (acc._1 :+ a.priority, acc._2 + a.qty) else acc)._1
      
      df.withColumn("OK" , when($"priority".isin(ans :_*), 1).otherwise(0)).show
      

      结果:

      +------+-------+--------+---+--------+
      |family|element|priority|qty|OK      |
      +------+-------+--------+---+--------+
      |    f1| elmt 1|       1| 20|       1|
      |    f1| elmt 2|       2| 40|       1|
      |    f1| elmt 3|       3| 10|       1|
      |    f1| elmt 4|       4| 50|       0|
      |    f1| elmt 5|       5| 40|       0|
      |    f1| elmt 6|       6| 10|       1|
      |    f1| elmt 7|       7| 20|       1|
      |    f1| elmt 8|       8| 10|       0|
      +------+-------+--------+---+--------+
      

      这个想法只是获取一个 Scala 迭代器并从中提取参与的 priority 值,然后使用这些值过滤掉参与的行。鉴于此解决方案在一台机器上收集内存中的所有数据,如果数据帧太大而无法放入内存,它可能会遇到内存问题。

      【讨论】:

      • 感谢您的解决方案,但很遗憾我无法收集数据,因为 DataFrame 有数百万行:/
      【解决方案4】:

      另一种方法是通过逐行迭代的基于 RDD 的方法。

      var bufferRow: collection.mutable.Buffer[Row] = collection.mutable.Buffer.empty[Row]
      var tempSum: Double = 0
      val iterator = df.collect.iterator
      while(iterator.hasNext){
        val record = iterator.next()
        val y = record.getAs[Integer]("qty")
        tempSum = tempSum + y
        print(record)
        if (tempSum <= 100.0 ) {
          bufferRow = bufferRow ++ Seq(transformRow(record,1))
        }
        else{
          bufferRow = bufferRow ++ Seq(transformRow(record,0))
          tempSum = tempSum - y
        }
      }
      

      定义transformRow函数,用于在行中添加一列。

      def transformRow(row: Row,flag : Int): Row =  Row.fromSeq(row.toSeq ++ Array[Integer](flag))
      

      接下来要做的是向架构中添加一个额外的列。

      val newSchema = StructType(df.schema.fields ++ Array(StructField("C_Sum", IntegerType, false))
      

      接着创建一个新的数据框。

      val outputdf = spark.createDataFrame(spark.sparkContext.parallelize(bufferRow.toSeq),newSchema)
      

      输出数据框:

      +------+-------+--------+---+-----+
      |family|element|priority|qty|C_Sum|
      +------+-------+--------+---+-----+
      |    f1|  elmt1|       1| 20|    1|
      |    f1|  elmt2|       2| 40|    1|
      |    f1|  elmt3|       3| 10|    1|
      |    f1|  elmt4|       4| 50|    0|
      |    f1|  elmt5|       5| 40|    0|
      |    f1|  elmt6|       6| 10|    1|
      |    f1|  elmt7|       7| 20|    1|
      |    f1|  elmt8|       8| 10|    0|
      +------+-------+--------+---+-----+
      

      【讨论】:

      • 作为 jrook 解决方案,由于行数(> 1000 万),我无法在 master 上收集数据。但是您使用 RDD 而不是 Dataframe 的想法可能是一个好主意。谢谢!
      • 我明白你的意思,会调查它,让我们知道你的解决方案肯定会很有趣。
      【解决方案5】:

      解决方法如下:

      scala> initDF.show
      +------+-------+--------+---+
      |family|element|priority|qty|
      +------+-------+--------+---+
      |    f1| elmt 1|       1| 20|
      |    f1| elmt 2|       2| 40|
      |    f1| elmt 3|       3| 10|
      |    f1| elmt 4|       4| 50|
      |    f1| elmt 5|       5| 40|
      |    f1| elmt 6|       6| 10|
      |    f1| elmt 7|       7| 20|
      |    f1| elmt 8|       8| 10|
      +------+-------+--------+---+
      
      scala> val df1 = initDF.groupBy("family").agg(collect_list("qty").as("comb_qty"), collect_list("priority").as("comb_prior"), collect_list("element").as("comb_elem"))
      df1: org.apache.spark.sql.DataFrame = [family: string, comb_qty: array<int> ... 2 more fields]
      
      scala> df1.show
      +------+--------------------+--------------------+--------------------+
      |family|            comb_qty|          comb_prior|           comb_elem|
      +------+--------------------+--------------------+--------------------+
      |    f1|[20, 40, 10, 50, ...|[1, 2, 3, 4, 5, 6...|[elmt 1, elmt 2, ...|
      +------+--------------------+--------------------+--------------------+
      
      
      scala> val df2 = df1.join(limitQtyDF, df1("family") === limitQtyDF("family")).drop(limitQtyDF("family"))
      df2: org.apache.spark.sql.DataFrame = [family: string, comb_qty: array<int> ... 3 more fields]
      
      scala> df2.show
      +------+--------------------+--------------------+--------------------+--------+
      |family|            comb_qty|          comb_prior|           comb_elem|limitQty|
      +------+--------------------+--------------------+--------------------+--------+
      |    f1|[20, 40, 10, 50, ...|[1, 2, 3, 4, 5, 6...|[elmt 1, elmt 2, ...|     100|
      +------+--------------------+--------------------+--------------------+--------+
      
      
      scala> def validCheck = (qty: Seq[Int], limit: Int) => {
           | var sum = 0
           | qty.map(elem => {
           | if (elem + sum <= limit) {
           | sum = sum + elem
           | 1}else{
           | 0
           | }})}
      validCheck: (scala.collection.mutable.Seq[Int], Int) => scala.collection.mutable.Seq[Int]
      
      scala> val newUdf = udf(validCheck)
      newUdf: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function2>,ArrayType(IntegerType,false),Some(List(ArrayType(IntegerType,false), IntegerType)))
      
      val df3 = df2.withColumn("valid", newUdf(col("comb_qty"),col("limitQty"))).drop("limitQty")
      df3: org.apache.spark.sql.DataFrame = [family: string, comb_qty: array<int> ... 3 more fields]
      
      scala> df3.show
      +------+--------------------+--------------------+--------------------+--------------------+
      |family|            comb_qty|          comb_prior|           comb_elem|               valid|
      +------+--------------------+--------------------+--------------------+--------------------+
      |    f1|[20, 40, 10, 50, ...|[1, 2, 3, 4, 5, 6...|[elmt 1, elmt 2, ...|[1, 1, 1, 0, 0, 1...|
      +------+--------------------+--------------------+--------------------+--------------------+
      
      scala> val myUdf = udf((qty: Seq[Int], prior: Seq[Int], elem: Seq[String], valid: Seq[Int]) => {
           | elem zip prior zip qty zip valid map{
           | case (((a,b),c),d) => (a,b,c,d)}
           | }
           | )
      
      scala> val df4 = df3.withColumn("combined", myUdf(col("comb_qty"),col("comb_prior"),col("comb_elem"),col("valid")))
      df4: org.apache.spark.sql.DataFrame = [family: string, comb_qty: array<int> ... 4 more fields]
      
      
      
      scala> val df5 = df4.drop("comb_qty","comb_prior","comb_elem","valid")
      df5: org.apache.spark.sql.DataFrame = [family: string, combined: array<struct<_1:string,_2:int,_3:int,_4:int>>]
      
      scala> df5.show(false)
      +------+----------------------------------------------------------------------------------------------------------------------------------------------------------------+
      |family|combined                                                                                                                                                        |
      +------+----------------------------------------------------------------------------------------------------------------------------------------------------------------+
      |f1    |[[elmt 1, 1, 20, 1], [elmt 2, 2, 40, 1], [elmt 3, 3, 10, 1], [elmt 4, 4, 50, 0], [elmt 5, 5, 40, 0], [elmt 6, 6, 10, 1], [elmt 7, 7, 20, 1], [elmt 8, 8, 10, 0]]|
      +------+----------------------------------------------------------------------------------------------------------------------------------------------------------------+
      
      scala> val df6 = df5.withColumn("combined",explode(col("combined")))
      df6: org.apache.spark.sql.DataFrame = [family: string, combined: struct<_1: string, _2: int ... 2 more fields>]
      
      scala> df6.show
      +------+------------------+
      |family|          combined|
      +------+------------------+
      |    f1|[elmt 1, 1, 20, 1]|
      |    f1|[elmt 2, 2, 40, 1]|
      |    f1|[elmt 3, 3, 10, 1]|
      |    f1|[elmt 4, 4, 50, 0]|
      |    f1|[elmt 5, 5, 40, 0]|
      |    f1|[elmt 6, 6, 10, 1]|
      |    f1|[elmt 7, 7, 20, 1]|
      |    f1|[elmt 8, 8, 10, 0]|
      +------+------------------+
      
      scala> val df7 = df6.select("family", "combined._1", "combined._2", "combined._3", "combined._4").withColumnRenamed("_1","element").withColumnRenamed("_2","priority").withColumnRenamed("_3", "qty").withColumnRenamed("_4","ok")
      df7: org.apache.spark.sql.DataFrame = [family: string, element: string ... 3 more fields]
      
      scala> df7.show
      +------+-------+--------+---+---+
      |family|element|priority|qty| ok|
      +------+-------+--------+---+---+
      |    f1| elmt 1|       1| 20|  1|
      |    f1| elmt 2|       2| 40|  1|
      |    f1| elmt 3|       3| 10|  1|
      |    f1| elmt 4|       4| 50|  0|
      |    f1| elmt 5|       5| 40|  0|
      |    f1| elmt 6|       6| 10|  1|
      |    f1| elmt 7|       7| 20|  1|
      |    f1| elmt 8|       8| 10|  0|
      +------+-------+--------+---+---+
      

      如果有帮助请告诉我!!

      【讨论】:

      • 您的解决方案解决了问题!但我有两个问题。我看到数组之间存在隐式顺序(comb_qty、comb_prior、comb_elem),即数组中的索引 0 用于元素 0。Spark 是否保证顺序?还有一个关于可扩展性的问题,如果我按家庭有 10k 个元素,这个解决方案是否仍然可行?感谢您的解决方案!
      猜你喜欢
      • 2016-05-11
      • 1970-01-01
      • 2021-05-11
      • 2020-11-04
      • 2021-09-03
      • 2014-02-16
      • 2021-10-24
      • 2019-08-23
      • 2019-07-08
      相关资源
      最近更新 更多