【问题标题】:GroupBy column and filter rows with maximum value in Pyspark在 Pyspark 中 GroupBy 列和筛选具有最大值的行
【发布时间】:2018-07-27 13:16:12
【问题描述】:

我几乎可以肯定以前有人问过这个问题,但a search through stackoverflow 没有回答我的问题。不是[2] 的重复项,因为我想要最大值,而不是最常见的项目。我是 pyspark 的新手,并试图做一些非常简单的事情:我想对列“A”进行分组,然后只保留在“B”列中具有最大值的每个组的行。像这样:

df_cleaned = df.groupBy("A").agg(F.max("B"))

不幸的是,这会丢弃所有其他列 - df_cleaned 仅包含列“A”和 B 的最大值。我该如何保留这些行? (“A”、“B”、“C”...)

【问题讨论】:

    标签: python apache-spark pyspark apache-spark-sql


    【解决方案1】:

    您可以在没有udf 的情况下使用Window 来执行此操作。

    考虑以下示例:

    import pyspark.sql.functions as f
    data = [
        ('a', 5),
        ('a', 8),
        ('a', 7),
        ('b', 1),
        ('b', 3)
    ]
    df = sqlCtx.createDataFrame(data, ["A", "B"])
    df.show()
    #+---+---+
    #|  A|  B|
    #+---+---+
    #|  a|  5|
    #|  a|  8|
    #|  a|  7|
    #|  b|  1|
    #|  b|  3|
    #+---+---+
    

    创建一个Window 以按列A 进行分区,并使用它来计算每个组的最大值。然后过滤掉行,使得B 列中的值等于最大值。

    from pyspark.sql import Window
    w = Window.partitionBy('A')
    df.withColumn('maxB', f.max('B').over(w))\
        .where(f.col('B') == f.col('maxB'))\
        .drop('maxB')\
        .show()
    #+---+---+
    #|  A|  B|
    #+---+---+
    #|  a|  8|
    #|  b|  3|
    #+---+---+
    

    或者等效地使用pyspark-sql:

    df.registerTempTable('table')
    q = "SELECT A, B FROM (SELECT *, MAX(B) OVER (PARTITION BY A) AS maxB FROM table) M WHERE B = maxB"
    sqlCtx.sql(q).show()
    #+---+---+
    #|  A|  B|
    #+---+---+
    #|  b|  3|
    #|  a|  8|
    #+---+---+
    

    【讨论】:

    • 我无法重现此解决方案 (Spark 2.4)。我得到:java.lang.UnsupportedOperationException: Cannot evaluate expression: max(input[1, bigint, false]) windowspecdefinition(input[0, string, true], specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$()))
    • 谢谢@pault 我刚刚收到来自 Databricks 的建议,这是 Spark 2.4 的问题。当他们带着他们的最终分析回来找我时,我应该为社区创建一个问题和答案。
    • @AltShift;作为遇到相同错误的人,无论如何创建问题是否有意义,以便我们其他人有一个可以监控此问题进展的地方?
    • @Jeroen:现已记录:stackoverflow.com/questions/54508608/…
    • @ZaneDufour 不是 spark-sql,但我相信 this answers your question
    【解决方案2】:

    另一种可能的方法是应用加入数据框,其自身指定“leftsemi”。 这种连接包括左侧数据框中的所有列,右侧没有列。

    例如:

    import pyspark.sql.functions as f
    data = [
        ('a', 5, 'c'),
        ('a', 8, 'd'),
        ('a', 7, 'e'),
        ('b', 1, 'f'),
        ('b', 3, 'g')
    ]
    df = sqlContext.createDataFrame(data, ["A", "B", "C"])
    df.show()
    +---+---+---+
    |  A|  B|  C|
    +---+---+---+
    |  a|  5|  c|
    |  a|  8|  d|
    |  a|  7|  e|
    |  b|  1|  f|
    |  b|  3|  g|
    +---+---+---+
    

    可以通过A列选择B列的最大值:

    df.groupBy('A').agg(f.max('B')
    +---+---+
    |  A|  B|
    +---+---+
    |  a|  8|
    |  b|  3|
    +---+---+
    

    将此表达式作为左半联接的右侧,并将得到的列max(B)重命名为原来的名称B,就可以得到需要的结果:

    df.join(df.groupBy('A').agg(f.max('B').alias('B')),on='B',how='leftsemi').show()
    +---+---+---+
    |  B|  A|  C|
    +---+---+---+
    |  3|  b|  g|
    |  8|  a|  d|
    +---+---+---+
    

    此解决方案背后的物理计划和接受答案的计划不同,我仍然不清楚哪一个在大型数据帧上表现更好。

    使用 spark SQL 语法做同样的结果可以得到:

    df.registerTempTable('table')
    q = '''SELECT *
    FROM table a LEFT SEMI
    JOIN (
        SELECT 
            A,
            max(B) as max_B
        FROM table
        GROUP BY A
        ) t
    ON a.A=t.A AND a.B=t.max_B
    '''
    sqlContext.sql(q).show()
    +---+---+---+
    |  A|  B|  C|
    +---+---+---+
    |  b|  3|  g|
    |  a|  8|  d|
    +---+---+---+
    

    【讨论】:

    • 问题是关于获得最大值,而不是只保留一行。所以实际上这在不考虑 B 列中的唯一值的情况下有效。无论如何,如果您只想为 A 列的每个值保留一行,您应该选择 df.select("A","B",F.row_number().over(Window.partitionBy("A").orderBy("B", ascending=False)).alias("rn")).filter("rn = 1")
    【解决方案3】:

    只想添加 @ndricca 答案的 scala spark 版本,以防万一有人需要:

    val data = Seq(("a", 5,"c"), ("a",8,"d"),("a",7,"e"),("b",1,"f"),("b",3,"g"))
    val df = data.toDF("A","B","C")
    df.show()
    +---+---+---+
    |  A|  B|  C|
    +---+---+---+
    |  a|  5|  c|
    |  a|  8|  d|
    |  a|  7|  e|
    |  b|  1|  f|
    |  b|  3|  g|
    +---+---+---+
    
    val rightdf = df.groupBy("A").max("B")
    rightdf.show()
    +---+------+
    |  A|max(B)|
    +---+------+
    |  b|     3|
    |  a|     8|
    +---+------+
    
    val resdf = df.join(rightdf, df("B") === rightdf("max(B)"), "leftsemi")
    resdf.show()
    +---+---+---+
    |  A|  B|  C|
    +---+---+---+
    |  a|  8|  d|
    |  b|  3|  g|
    +---+---+---+
    
    

    【讨论】:

    • 有趣,但可能与 pyspark 问题不相关?
    • 我自己正在寻找一种在 scala spark 中实现这一目标的方法,并遇到了这个问题。我确定有人像我一样,只是希望可以节省他们的时间
    • 我同意,如果有更多人发布替代语言解决方案(明确免责声明它适用于另一种语言),我确实会很高兴,因为 Google 的搜索算法经常将一个人带到错误的语言或问题仅在 PySpark / Scala 中回答。
    【解决方案4】:

    有两个很好的解决方案,所以我决定对它们进行基准测试。首先让我定义一个更大的数据框:

    N_SAMPLES = 600000
    N_PARTITIONS = 1000
    MAX_VALUE = 100
    data = zip([random.randint(0, N_PARTITIONS-1) for i in range(N_SAMPLES)],
              [random.randint(0, MAX_VALUE) for i in range(N_SAMPLES)],
              list(range(N_SAMPLES))
              )
    df = spark.createDataFrame(data, ["A", "B", "C"])
    df.show()
    +---+---+---+
    |  A|  B|  C|
    +---+---+---+
    |118| 91|  0|
    |439| 80|  1|
    |779| 77|  2|
    |444| 14|  3|
    ...
    

    基准测试@pault 的解决方案:

    %%timeit
    w = Window.partitionBy('A')
    df_collect = df.withColumn('maxB', f.max('B').over(w))\
        .where(f.col('B') == f.col('maxB'))\
        .drop('maxB')\
        .collect()
    

    给予

    655 ms ± 70.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    

    基准测试@ndricca 的解决方案:

    %%timeit
    df_collect = df.join(df.groupBy('A').agg(f.max('B').alias('B')),on='B',how='leftsemi').collect()
    

    给予

    1 s ± 49.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    

    所以,@pault 的解决方案似乎快了 1.5 倍。非常欢迎对此基准提供反馈。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2021-07-23
      • 1970-01-01
      • 1970-01-01
      • 2019-06-25
      相关资源
      最近更新 更多