【问题标题】:Pyspark Window function on entire data frame整个数据帧上的 Pyspark Window 函数
【发布时间】:2020-06-10 14:04:46
【问题描述】:

考虑一个 pyspark 数据框。我想按列汇总整个数据框,并为每一行附加结果。

+-----+----------+-----------+
|index|      col1| col2      |
+-----+----------+-----------+
|  0.0|0.58734024|0.085703015|
|  1.0|0.67304325| 0.17850411|

预期结果

+-----+----------+-----------+-----------+-----------+-----------+-----------+
|index|      col1| col2      |  col1_min | col1_mean |col2_min   | col2_mean
+-----+----------+-----------+-----------+-----------+-----------+-----------+
|  0.0|0.58734024|0.085703015|  -5       | 2.3       |  -2       | 1.4 |
|  1.0|0.67304325| 0.17850411|  -5       | 2.3       |  -2       | 1.4 |

据我所知,我需要将整个数据框作为 Window 的 Window 函数,以保留每一行的结果(而不是,例如,分别进行统计,然后再加入以复制每一行)

我的问题是:

  1. 如何编写没有任何分区和排序的窗口

我知道有分区和顺序的标准窗口,但不是将所有内容都视为 1 个单独分区的窗口

w = Window.partitionBy("col1", "col2").orderBy(desc("col1"))
df = df.withColumn("col1_mean", mean("col1").over(w)))

如何编写一个将所有内容都作为一个分区的窗口?

  1. 为所有列动态写入的任何方式。

假设我有500列,重复写看起来不太好。

df = df.withColumn("col1_mean", mean("col1").over(w))).withColumn("col1_min", min("col2").over(w)).withColumn("col2_mean", mean().over(w)).....

假设我希望每列有多个统计信息,因此每个colx 都会生成colx_min, colx_max, colx_mean

【问题讨论】:

    标签: dataframe apache-spark pyspark apache-spark-sql window-functions


    【解决方案1】:

    您可以通过自定义聚合结合交叉连接来实现相同的效果,而不是使用窗口:

    import pyspark.sql.functions as F
    from pyspark.sql.functions import broadcast
    from itertools import chain
    
    df = spark.createDataFrame([
      [1, 2.3, 1],
      [2, 5.3, 2],
      [3, 2.1, 4],
      [4, 1.5, 5]
    ], ["index", "col1", "col2"])
    
    agg_cols = [(
                 F.min(c).alias("min_" + c), 
                 F.max(c).alias("max_" + c), 
                 F.mean(c).alias("mean_" + c)) 
    
      for c in df.columns if c.startswith('col')]
    
    stats_df = df.agg(*list(chain(*agg_cols)))
    
    # there is no performance impact from crossJoin since we have only one row on the right table which we broadcast (most likely Spark will broadcast it anyway)
    df.crossJoin(broadcast(stats_df)).show() 
    
    # +-----+----+----+--------+--------+---------+--------+--------+---------+
    # |index|col1|col2|min_col1|max_col1|mean_col1|min_col2|max_col2|mean_col2|
    # +-----+----+----+--------+--------+---------+--------+--------+---------+
    # |    1| 2.3|   1|     1.5|     5.3|      2.8|       1|       5|      3.0|
    # |    2| 5.3|   2|     1.5|     5.3|      2.8|       1|       5|      3.0|
    # |    3| 2.1|   4|     1.5|     5.3|      2.8|       1|       5|      3.0|
    # |    4| 1.5|   5|     1.5|     5.3|      2.8|       1|       5|      3.0|
    # +-----+----+----+--------+--------+---------+--------+--------+---------+
    

    注意 1: 使用广播我们将避免洗牌,因为广播的 df 将被发送给所有的执行者。

    注意 2: 使用 chain(*agg_cols) 我们将在上一步中创建的元组列表展平。

    更新:

    这是上述程序的执行计划:

    == Physical Plan ==
    *(3) BroadcastNestedLoopJoin BuildRight, Cross
    :- *(3) Scan ExistingRDD[index#196L,col1#197,col2#198L]
    +- BroadcastExchange IdentityBroadcastMode, [id=#274]
       +- *(2) HashAggregate(keys=[], functions=[finalmerge_min(merge min#233) AS min(col1#197)#202, finalmerge_max(merge max#235) AS max(col1#197)#204, finalmerge_avg(merge sum#238, count#239L) AS avg(col1#197)#206, finalmerge_min(merge min#241L) AS min(col2#198L)#208L, finalmerge_max(merge max#243L) AS max(col2#198L)#210L, finalmerge_avg(merge sum#246, count#247L) AS avg(col2#198L)#212])
          +- Exchange SinglePartition, [id=#270]
             +- *(1) HashAggregate(keys=[], functions=[partial_min(col1#197) AS min#233, partial_max(col1#197) AS max#235, partial_avg(col1#197) AS (sum#238, count#239L), partial_min(col2#198L) AS min#241L, partial_max(col2#198L) AS max#243L, partial_avg(col2#198L) AS (sum#246, count#247L)])
                +- *(1) Project [col1#197, col2#198L]
                   +- *(1) Scan ExistingRDD[index#196L,col1#197,col2#198L]
    

    在这里,我们看到 SinglePartitionBroadcastExchange 广播单行,因为 stats_df 可以放入 SinglePartition。因此,这里被洗牌的数据只有一行(可能的最小值)。

    【讨论】:

    • 我记得阅读交叉连接很昂贵。为此创建了窗口函数:将 agg 值带到每一行。
    • 在这种情况下不是@Kenny。这里的 crossJoin 将在 df 和 stats_df 之间进行,最后一个只有一行。在这种情况下,程序会将 stats_df 视为广播值。该行将在每个执行程序上进行广播,并且连接只会将 stats_df 行附加到 df 的行
    • @Kenny 我修改了答案,添加了查询的执行计划,表明不会有洗牌,因为它只是一个广播交换。如果您需要有关交换类型的更多信息,请观看 David Vrba 的video
    • 如果您在整数/双列上查找最大值/最小值,这是最好的解决方案。如果您需要在 Timestamp 列中找到最大值,这不是最佳解决方案。 Timestamp 上的连接可能会导致大量数据帧出现问题。我个人经历过这种连接的数据拆分。所以在这种情况下,我建议使用 Window.... 或者至少在加入之前转换为较长的时间戳!
    【解决方案2】:

    我们也可以在窗口函数中不指定 orderby,partitionBy 子句 min("<col_name>").over()

    Example:

    //sample data
    val df=Seq((1,2,3),(4,5,6)).toDF("i","j","k")
    
    val df1=df.columns.foldLeft(df)((df, c) => {
      df.withColumn(s"${c}_min",min(col(s"${c}")).over()).
      withColumn(s"${c}_max",max(col(s"${c}")).over()).
      withColumn(s"${c}_mean",mean(col(s"${c}")).over())
    })
    
    df1.show()
    //+---+---+---+-----+-----+------+-----+-----+------+-----+-----+------+
    //|  i|  j|  k|i_min|i_max|i_mean|j_min|j_max|j_mean|k_min|k_max|k_mean|
    //+---+---+---+-----+-----+------+-----+-----+------+-----+-----+------+
    //|  1|  2|  3|    1|    4|   2.5|    2|    5|   3.5|    3|    6|   4.5|
    //|  4|  5|  6|    1|    4|   2.5|    2|    5|   3.5|    3|    6|   4.5|
    //+---+---+---+-----+-----+------+-----+-----+------+-----+-----+------+
    

    【讨论】:

    • 我得到 over() 缺少 1 个必需的位置参数:'window'。似乎需要一个窗口?我正在使用 Pyspark
    • 您可以这样做:1)为整个数据框创建一个具有相同值的列 .withColumn("all", lit("1")) 2)将此列用于窗口:val window= Window.partitionBy("all")
    【解决方案3】:

    环顾四周后,我意识到 Pyspark 2.0+ 还有一个很好的解决方案,其中 over 需要窗口参数:

    from pyspark.sql import Window
    from pyspark.sql.functions import min
    
    df.withColumn(f"{c}_min", min(col(f"{c}")).over(Window.partitionBy()))
    
    

    如果你把partitionBy留空,它不会做任何分区。

    【讨论】:

      猜你喜欢
      • 2018-09-11
      • 2018-06-08
      • 1970-01-01
      • 1970-01-01
      • 2017-01-16
      • 1970-01-01
      • 1970-01-01
      • 2017-11-02
      • 2016-09-16
      相关资源
      最近更新 更多