【发布时间】:2021-03-09 17:34:06
【问题描述】:
考虑简单的DataFrame:
from pyspark import SparkContext
import pyspark
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import *
from pyspark.sql.functions import pandas_udf, PandasUDFType
spark = SparkSession.builder.appName('Trial').getOrCreate()
simpleData = (("2000-04-17", "144", 1), \
("2000-07-06", "015", 1), \
("2001-01-23", "015", -1), \
("2001-01-18", "144", -1), \
("2001-04-17", "198", 1), \
("2001-04-18", "036", -1), \
("2001-04-19", "012", -1), \
("2001-04-19", "188", 1), \
("2001-04-25", "188", 1),\
("2001-04-27", "015", 1) \
)
columns= ["dates", "id", "eps"]
df = spark.createDataFrame(data = simpleData, schema = columns)
df.printSchema()
df.show(truncate=False)
输出:
root
|-- dates: string (nullable = true)
|-- id: string (nullable = true)
|-- eps: long (nullable = true)
+----------+---+---+
|dates |id |eps|
+----------+---+---+
|2000-04-17|144|1 |
|2000-07-06|015|1 |
|2001-01-23|015|-1 |
|2001-01-18|144|-1 |
|2001-04-17|198|1 |
|2001-04-18|036|-1 |
|2001-04-19|012|-1 |
|2001-04-19|188|1 |
|2001-04-25|188|1 |
|2001-04-27|015|1 |
+----------+---+---+
我想对滚动窗口中的 eps 列中的值求和,仅保留 id 列中任何给定 ID 的最后一个值。例如,定义一个 5 行的窗口并假设我们在 2001 年 4 月 17 日,我只想总结每个给定唯一 ID 的最后一个 eps 值。在 5 行中,我们只有 3 个不同的 ID,因此总和必须是 3 个元素:-1 表示 ID 144(第四行),-1 表示 ID 015(第三行)和 1 表示 ID 198(第五行) 总共为 -1。
在我看来,在滚动窗口内我应该做类似F.sum(groupBy('id').agg(F.last('eps'))) 之类的事情,这在滚动窗口中当然是不可能实现的。
我使用 UDF 获得了想要的结果。
@pandas_udf(IntegerType(), PandasUDFType.GROUPEDAGG)
def fun_sum(id, eps):
df = pd.DataFrame()
df['id'] = id
df['eps'] = eps
value = df.groupby('id').last().sum()
return value
然后:
w = Window.orderBy('dates').rowsBetween(-5,0)
df = df.withColumn('sum', fun_sum(F.col('id'), F.col('eps')).over(w))
问题是我的数据集包含超过 800 万行,使用此 UDF 执行此任务大约需要 2 小时。
我在想是否有办法通过内置 PySpark 函数避免使用 UDF 来实现相同的结果,或者至少是否有办法提高我的 UDF 的性能。
为了完整起见,期望的输出应该是:
+----------+---+---+----+
|dates |id |eps|sum |
+----------+---+---+----+
|2000-04-17|144|1 |1 |
|2000-07-06|015|1 |2 |
|2001-01-23|015|-1 |0 |
|2001-01-18|144|-1 |-2 |
|2001-04-17|198|1 |-1 |
|2001-04-18|036|-1 |-2 |
|2001-04-19|012|-1 |-3 |
|2001-04-19|188|1 |-1 |
|2001-04-25|188|1 |0 |
|2001-04-27|015|1 |0 |
+----------+---+---+----+
编辑:结果也必须可以使用.rangeBetween() 窗口来实现。
【问题讨论】:
标签: python apache-spark pyspark apache-spark-sql window