【问题标题】:How to calculate product of columns followed by sum over all columns?如何计算列的乘积,然后是所有列的总和?
【发布时间】:2017-06-24 21:40:03
【问题描述】:

表1--Spark DataFrame 表

表1中有一个名为“productMe”的列;还有其他列,如 a、b、c 等,其模式名称包含在模式数组 T 中。

我想要的是模式数组 T 中列(两列的每一行的乘积)与列 productMe 的内积(表 2)。并将表2的每一列相加得到表3。

如果您有一个好主意可以一步获得表 3,则不需要表 2。

表2——内积表

例如,列“a·productMe”为(3*0.2, 6*0.6, 5*0.4)得到(0.6, 3.6, 2)

表3——总和表

例如,列“sum(a·productMe)”为0.6+3.6+2=6.2。

Table 1 是 Spark 的 DataFrame,如何获取 Table 3?

【问题讨论】:

  • 不知道a、b、c的名称,列数不详。名称 a、b、c 等包含在模式数组 T 中。
  • 请停止发布您的数据快照并输入实际值!在完成你的“任务”时也要表现出一些努力!
  • 无论它是否是“作业”,这对我来说都是一个有用的问题。我需要一个工作解决方案,但我没有时间做“家庭作业”。

标签: scala apache-spark apache-spark-sql


【解决方案1】:

您可以尝试以下方法:

val df = Seq(
  (3,0.2,0.5,0.4),
  (6,0.6,0.3,0.1),
  (5,0.4,0.6,0.5)).toDF("productMe", "a", "b", "c")
import org.apache.spark.sql.functions.col
val columnsToSum = df.
  columns.  // <-- grab all the columns by their name
  tail.     // <-- skip productMe
  map(col). // <-- create Column objects
  map(c => round(sum(c * col("productMe")), 3).as(s"sum_${c}_productMe"))
val df2 = df.select(columnsToSum: _*)
df2.show()
# +---------------+---------------+---------------+
# |sum_a_productMe|sum_b_productMe|sum_c_productMe|
# +---------------+---------------+---------------+
# |            6.2|            6.3|            4.3|
# +---------------+---------------+---------------+

诀窍是使用df.select(columnsToSum: _*),这意味着您要选择我们对其进行列总和乘以productMe 列的所有列。 :_* 是一种 Scala 特定的语法,用于指定我们传递重复的参数,因为我们没有固定数量的参数。

【讨论】:

  • 这是一个聪明的解决方案。一个问题。您的地图功能不会按照您为我的方法建议的那样创建 OOM 吗?
  • @RameshMaharjan 在这里不会有问题,因为我们只聚合一次。
  • DataFrames 是不可变的集合。因此,您需要将转换结果分配给新的 DataFrame 值。我已经更新了我的答案,以便您更容易理解。
【解决方案2】:

我们可以用简单的 SparkSql 做到这一点

   val table1 = Seq(
   (3,0.2,0.5,0.4),
   (6,0.6,0.3,0.1),
   (5,0.4,0.6,0.5)
 ).toDF("productMe", "a", "b", "c")

table1.show
table1.createOrReplaceTempView("table1") 

val table2 = spark.sql("select a*productMe, b*productMe, c*productMe  from table1")   //spark is sparkSession here
table2.show

val table3 = spark.sql("select sum(a*productMe), sum(b*productMe), sum(c*productMe) from table1")
table3.show

【讨论】:

    【解决方案3】:

    所有其他答案都使用 sum 聚合,在后台使用 groupBy

    groupBy 总是引入 shuffle 阶段并且通常 (always?) 比相应的窗口聚合慢。

    在这种特殊情况下,我也相信窗口聚合可以提供更好的性能,正如您在他们的物理计划和他们唯一一项工作的详细信息中看到的那样。

    注意

    任何一种解决方案都使用单个分区进行计算,这反过来又使它们不适合用于大型数据集,因为它们的大小加在一起可能很容易超过单个 JVM 的内存大小。

    窗口聚合

    接下来是基于窗口聚合的计算,在这种特殊情况下,我们对数据集中的所有行进行分组,不幸的是给出了相同的物理计划。这让我的回答只是一个(希望)很好的学习经历。

    val df = Seq(
      (3,0.2,0.5,0.4),
      (6,0.6,0.3,0.1),
      (5,0.4,0.6,0.5)).toDF("productMe", "a", "b", "c")
    
    // yes, I did borrow this trick with columns from @eliasah's answer
    import org.apache.spark.sql.functions.col
    val columns = df.columns.tail.map(col).map(c => c * col("productMe") as s"${c}_productMe")
    val multiplies = df.select(columns: _*)
    scala> multiplies.show
    +------------------+------------------+------------------+
    |       a_productMe|       b_productMe|       c_productMe|
    +------------------+------------------+------------------+
    |0.6000000000000001|               1.5|1.2000000000000002|
    |3.5999999999999996|1.7999999999999998|0.6000000000000001|
    |               2.0|               3.0|               2.5|
    +------------------+------------------+------------------+
    
    def sumOverRows(name: String) = sum(name) over ()
    val multipliesCols = multiplies.
      columns.
      map(c => sumOverRows(c) as s"sum_${c}")
    val answer = multiplies.
      select(multipliesCols: _*).
      limit(1)  // <-- don't use distinct or dropDuplicates here
    scala> answer.show
    +-----------------+---------------+-----------------+
    |  sum_a_productMe|sum_b_productMe|  sum_c_productMe|
    +-----------------+---------------+-----------------+
    |6.199999999999999|            6.3|4.300000000000001|
    +-----------------+---------------+-----------------+
    

    物理计划

    然后让我们看看物理计划(因为这是我们想看看如何使用窗口聚合进行查询的唯一原因,不是吗?)

    以下是唯一作业0的详细信息。

    【讨论】:

      【解决方案4】:

      如果我正确理解了您的问题,那么以下可能是您的解决方案

         val df = Seq(
             (3,0.2,0.5,0.4),
             (6,0.6,0.3,0.1),
             (5,0.4,0.6,0.5)
           ).toDF("productMe", "a", "b", "c")
      

      这会提供您所拥有的输入数据框(您可以添加更多)

      +---------+---+---+---+
      |productMe|a  |b  |c  |
      +---------+---+---+---+
      |3        |0.2|0.5|0.4|
      |6        |0.6|0.3|0.1|
      |5        |0.4|0.6|0.5|
      +---------+---+---+---+
      

      val productMe = df.columns.head
      val colNames = df.columns.tail
      var tempdf = df
      for(column <- colNames){
        tempdf = tempdf.withColumn(column, col(column)*col(productMe))
      }
      

      以上步骤应该给你Table2

      +---------+------------------+------------------+------------------+
      |productMe|a                 |b                 |c                 |
      +---------+------------------+------------------+------------------+
      |3        |0.6000000000000001|1.5               |1.2000000000000002|
      |6        |3.5999999999999996|1.7999999999999998|0.6000000000000001|
      |5        |2.0               |3.0               |2.5               |
      +---------+------------------+------------------+------------------+
      

      Table3可以如下实现

      tempdf.select(sum("a").as("sum(a.productMe)"), sum("b").as("sum(b.productMe)"), sum("c").as("sum(c.productMe)")).show(false)
      

      表3是

      +-----------------+----------------+-----------------+
      |sum(a.productMe) |sum(b.productMe)|sum(c.productMe) |
      +-----------------+----------------+-----------------+
      |6.199999999999999|6.3             |4.300000000000001|
      +-----------------+----------------+-----------------+
      

      Table2 可以为您拥有的任意数量的列实现,但 Table3 需要您明确定义列

      【讨论】:

      • 迭代以在 spark 数据帧上添加列是一种不好的做法,因为它会使线性更长并导致 OOME。
      • @eliasah 你能帮我理解上述方法在哪里产生更长的线性吗?
      猜你喜欢
      • 2017-04-07
      • 2020-12-02
      • 1970-01-01
      • 2015-01-22
      • 2020-09-26
      • 2017-10-05
      • 2020-07-06
      • 2018-10-10
      • 1970-01-01
      相关资源
      最近更新 更多