【问题标题】:Pyspark UDF to return result similar to groupby().sum() between two columnsPyspark UDF 在两列之间返回类似于 groupby().sum() 的结果
【发布时间】:2019-12-09 00:45:12
【问题描述】:

我有以下示例数据框

fruit_list = ['apple', 'apple', 'orange', 'apple']
qty_list = [16, 2, 3, 1]
spark_df = spark.createDataFrame([(101, 'Mark', fruit_list, qty_list)], ['ID', 'name', 'fruit', 'qty'])

我想创建另一个列,其中包含类似于我使用熊猫groupby('fruit').sum() 实现的结果@

        qty
fruits     
apple    19
orange    3

上述结果可以任何形式(字符串、字典、元组列表...)存储在新列中。

我尝试了一种类似于以下方法但不起作用的方法

sum_cols = udf(lambda x: pd.DataFrame({'fruits': x[0], 'qty': x[1]}).groupby('fruits').sum())
spark_df.withColumn('Result', sum_cols(F.struct('fruit', 'qty'))).show()

结果数据框的一个示例可能是

+---+----+--------------------+-------------+-------------------------+
| ID|name|               fruit|          qty|                   Result|
+---+----+--------------------+-------------+-------------------------+
|101|Mark|[apple, apple, or...|[16, 2, 3, 1]|[(apple,19), (orange,3)] |
+---+----+--------------------+-------------+-------------------------+

您对我如何实现这一点有什么建议吗?

谢谢

编辑:在 Spark 2.4.3 上运行

【问题讨论】:

  • 你想要的输出是什么?从描述中看不清楚,请明确显示。
  • 感谢您的评论,完成!
  • 什么版本的火花?如果是 spark 2.4+,你可以使用array_zip。旧版本使这变得更加困难。
  • 我在 2.4.3 上运行,您能否为我提供一个示例用法?
  • 在我(有限的)经验中,我看到“本机”pyspark 代码的执行速度比 UDF(尤其是 UDAF)快 10 倍,即使在使用 explode 时也是如此。只是要记住的事情..

标签: apache-spark pyspark apache-spark-sql


【解决方案1】:

正如@pault 提到的,从 Spark 2.4+ 开始,您可以使用 Spark SQL 内置函数来处理您的任务,这是 array_distinct + 变换 + 聚合

from pyspark.sql.functions import expr

# set up data
spark_df = spark.createDataFrame([
        (101, 'Mark', ['apple', 'apple', 'orange', 'apple'], [16, 2, 3, 1])
      , (102, 'Twin', ['apple', 'banana', 'avocado', 'banana', 'avocado'], [5, 2, 11, 3, 1])
      , (103, 'Smith', ['avocado'], [10])
    ], ['ID', 'name', 'fruit', 'qty']
)

>>> spark_df.show(5,0)
+---+-----+-----------------------------------------+----------------+
|ID |name |fruit                                    |qty             |
+---+-----+-----------------------------------------+----------------+
|101|Mark |[apple, apple, orange, apple]            |[16, 2, 3, 1]   |
|102|Twin |[apple, banana, avocado, banana, avocado]|[5, 2, 11, 3, 1]|
|103|Smith|[avocado]                                |[10]            |
+---+-----+-----------------------------------------+----------------+

>>> spark_df.printSchema()
root
 |-- ID: long (nullable = true)
 |-- name: string (nullable = true)
 |-- fruit: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- qty: array (nullable = true)
 |    |-- element: long (containsNull = true)

设置 SQL 语句:

stmt = '''
    transform(array_distinct(fruit), x -> (x, aggregate(
          transform(sequence(0,size(fruit)-1), i -> IF(fruit[i] = x, qty[i], 0))
        , 0
        , (y,z) -> int(y + z)
    ))) AS sum_fruit
'''

>>> spark_df.withColumn('sum_fruit', expr(stmt)).show(10,0)
+---+-----+-----------------------------------------+----------------+----------------------------------------+
|ID |name |fruit                                    |qty             |sum_fruit                               |
+---+-----+-----------------------------------------+----------------+----------------------------------------+
|101|Mark |[apple, apple, orange, apple]            |[16, 2, 3, 1]   |[[apple, 19], [orange, 3]]              |
|102|Twin |[apple, banana, avocado, banana, avocado]|[5, 2, 11, 3, 1]|[[apple, 5], [banana, 5], [avocado, 12]]|
|103|Smith|[avocado]                                |[10]            |[[avocado, 10]]                         |
+---+-----+-----------------------------------------+----------------+----------------------------------------+

解释:

  1. 使用array_distinct(fruit) 查找数组fruit 中的所有不同条目
  2. 将这个新数组(带有元素x)从x 转换为(x, aggregate(..x..))
  3. 上述函数aggregate(..x..)采用简单的形式,将array_T中的所有元素相加

    aggregate(array_T, 0, (y,z) -> y + z)
    

    array_T 来自以下转换:

    transform(sequence(0,size(fruit)-1), i -> IF(fruit[i] = x, qty[i], 0))
    

    遍历数组fruit,如果fruit[i] = x,则返回对应的qty[i],否则返回0。例如对于ID=101,当x = 'orange'时,它返回一个数组[0, 0, 3, 0]

【讨论】:

  • 哇,这是一个非常好的答案,谢谢@jxc!你认为这比@pault 解决方案表现更好吗?
  • @crash 我想这比使用udf +1 更好
  • @pault,您是一位经验丰富的用户,您认为这应该是公认的解决方案吗?
  • 是的,但这是您的决定。
【解决方案2】:

可能有一种奇特的方法可以仅使用 Spark 2.4+ 上的 API 函数,可能是 arrays_zipaggregate 的某种组合,但我想不出任何不涉及 @ 的方法987654326@ 步骤后跟 groupBy。考虑到这一点,在这种情况下,使用 udf 实际上可能更适合您。

我认为创建一个pandas DataFrame 只是为了调用.groupby().sum() 是矫枉过正。此外,即使您确实这样做了,您也需要将最终输出转换为不同的数据结构,因为 udf 无法返回 pandas DataFrame。

这是udf 使用collections.defaultdict 的一种方法:

from collections import defaultdict
from pyspark.sql.functions import udf

def sum_cols_func(frt, qty):
    d = defaultdict(int)
    for x, y in zip(frt, map(int, qty)):
        d[x] += y
    return d.items()

sum_cols = udf(
    lambda x: sum_cols_func(*x),
    ArrayType(
        StructType([StructField("fruit", StringType()), StructField("qty", IntegerType())])
    )
)

然后通过传入fruitqty 列来调用它:

from pyspark.sql.functions import array, col

spark_df.withColumn(
    "Result",
    sum_cols(array([col("fruit"), col("qty")]))
).show(truncate=False)
#+---+----+-----------------------------+-------------+--------------------------+
#|ID |name|fruit                        |qty          |Result                    |
#+---+----+-----------------------------+-------------+--------------------------+
#|101|Mark|[apple, apple, orange, apple]|[16, 2, 3, 1]|[[orange, 3], [apple, 19]]|
#+---+----+-----------------------------+-------------+--------------------------+

【讨论】:

  • 我喜欢您的解决方案 pault,感谢您抽出宝贵时间。但是我收到了这个错误Py4JJavaError: An error occurred while calling o3564.showString....Caused by: net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.dtype).. 它告诉你什么了吗?
  • 刚刚添加from pyspark.sql.types import ArrayType, StructType, StringType, IntegerType, StructField
  • 还可以尝试在输出return [(k, int(sum(v))) for k, v in d.items()] 中强制使用 Python int 类型,这样 udf 的结果肯定是 Python 原生类型。
  • @RichardNemeth 这是一个很好的观点——这让我想知道:你确定你使用的是__builtin__.sum 而不是numpy.sumpyspark.sql.functions.sumWhy you shouldn't use import *编辑如果你得到一个 numpy 对象,它表明你正在使用numpy.sum
  • 是的,Richard 的解决方案似乎解决了问题!
【解决方案3】:

如果你有spark this answer):

df_split = (spark_df.rdd.flatMap(lambda row: [(row.ID, row.name, f, q) for f, q in zip(row.fruit, row.qty)]).toDF(["ID", "name", "fruit", "qty"]))

df_split.show()

输出:

+---+----+------+---+
| ID|name| fruit|qty|
+---+----+------+---+
|101|Mark| apple| 16|
|101|Mark| apple|  2|
|101|Mark|orange|  3|
|101|Mark| apple|  1|
+---+----+------+---+

然后准备你想要的结果。首先找到聚合的数据框:

df_aggregated = df_split.groupby('ID', 'fruit').agg(F.sum('qty').alias('qty'))
df_aggregated.show()

输出:

+---+------+---+
| ID| fruit|qty|
+---+------+---+
|101|orange|  3|
|101| apple| 19|
+---+------+---+

最后将其更改为所需的格式:

df_aggregated.groupby('ID').agg(F.collect_list(F.struct(F.col('fruit'), F.col('qty'))).alias('Result')).show()

输出:

+---+--------------------------+
|ID |Result                    |
+---+--------------------------+
|101|[[orange, 3], [apple, 19]]|
+---+--------------------------+

【讨论】:

  • udf 在这种情况下可能比explode 具有更好的性能
  • 我不知道,但是我把答案改成了RDD而不是explode。
  • 我正在尝试这个。我认为将代码分布在多行将有助于提高可读性。
  • udf 几乎肯定(肯定?)也比rdd
  • 这也可以,但 pault 解决方案更简洁,更易于理解。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 2021-11-21
  • 1970-01-01
  • 2022-11-24
  • 1970-01-01
  • 2021-11-11
  • 2021-06-17
  • 1970-01-01
相关资源
最近更新 更多