【发布时间】:2020-05-27 00:55:22
【问题描述】:
我正在尝试应用 PySpark Window 函数来执行“指数衰减”。公式是
todays_score = yesterdays_score * (weight) + todays_raw_score
例如,假设我们有一个按天排序的数据框,并且每天的得分为 1:
+---+----+---------+
|day|user|raw_score|
+---+----+---------+
| 0| a| 1|
| 1| a| 1|
| 2| a| 1|
| 3| a| 1|
+---+----+---------+
如果我要计算 todays_score,应该是这样的:
+---+----+---------+------------+
|day|user|raw_score|todays_score| # Here's the math:
+---+----+---------+------------+
| 0| a| 1| 1.0| (0 * .90) + 1
| 1| a| 1| 1.9| (1.0 * .90) + 1
| 2| a| 1| 2.71| (1.9 * .90) + 1
| 3| a| 1| 3.439| (2.71 * .90) + 1
+---+----+---------+------------+
我尝试过使用窗口函数;但是根据我所见,他们只能使用原始数据框中的“静态值”,而不是我们刚刚计算的值。我什至尝试创建一个“虚拟列”来启动该过程;但是这也不起作用。
我尝试的代码:
df = sqlContext.createDataFrame([
(0, 'a', 1),
(1, 'a', 1),
(2, 'a', 1),
(3, 'a', 1)],
['day', 'user', 'raw_score']
)
df.show()
# Create a "dummy column" (weighted score) so we can use it.
df2 = df.select('*', col('raw_score').alias('todays_score'))
df2.show()
w = Window.partitionBy('user')
df2.withColumn('todays_score',
F.lag(F.col('todays_score'), count=1, default=0).over(w.orderBy('day'))* 0.9 + F.col('raw_score')) \
.show()
这个(不需要的)输出是:
+---+----+---------+------------+
|day|user|raw_score|todays_score|
+---+----+---------+------------+
| 0| a| 1| 1.0|
| 1| a| 1| 1.9|
| 2| a| 1| 1.9|
| 3| a| 1| 1.9|
+---+----+---------+------------+
它只取前一个值 * (.90),而不是刚刚计算的值。
如何访问刚刚由窗口函数计算的值?
【问题讨论】:
-
你应该使用 pandas 分组地图 udaf。计算中的 +1 是否取自 raw_score 列?还是 +1 只是一个静态值,你的 spark 版本是什么?
-
@murtihash - 我想提两件重要的事情:(1)性能对我来说是个大问题;我将与成千上万的用户打交道,数百天,数百个分数......所以我有点犹豫是否使用 udaf。如果我错了,请纠正我,但它比原生 Spark SQL 函数慢,不是吗? (2) 是的,+1 取自原始分数列。我还有另一个计算分数的步骤。对于实际值,这些值每天都会有所不同,并且使用起来不太好。
-
所以是的,你是对的,它会比 spark 内置函数慢,但它会比普通 udf 快得多,因为它是一个矢量化的 udf,可以在数据组上执行(groupby用户。地图)。我看到的唯一其他选择是按收集列表分组,使用高阶函数来获得分数,然后分解列表。除此之外,我认为任何其他 spark 函数都不能完成对每一行都是动态的任务。你也可以告诉你的火花版本,因为pandas udaf是2.3+,高阶函数是2.4+
-
@murtihash - 我在下面看到了你的答案,它远远超出了头脑。我想知道你是否可以用熊猫分组地图 udaf 解释如何做到这一点。 *此外,根据我的研究,pandas grouped-map UDF 不适用于有界窗口。我的代码也一直失败。
-
你说的“它过头了”是什么意思,是不是因为 groupby 和爆炸太慢了?还是你不明白逻辑?我也用 pandas 分组地图 udaf 更新了解决方案。请投票/接受答案以关闭线程,干杯
标签: python pyspark window-functions