【问题标题】:Apply Udf on an array field(variable length) and split it into columns in pyspark在数组字段(可变长度)上应用 Udf 并将其拆分为 pyspark 中的列
【发布时间】:2021-02-17 20:17:01
【问题描述】:

我想对长度为 可变 (0-4000) 的数组字段应用逻辑并将其拆分为列。具有爆炸、创建新列和重命名列的 udf 可以完成这项工作,但我不确定如何将其作为 udf 迭代应用。 UDF 将采用可变长度数组字段并将一组新列 (0-4000) 返回到数据帧。如下所示的示例输入数据框

+--------------------+--------------------+
|             hashval|    dec_spec (array|
+--------------------+--------------------+
|3c65252a67546832d...|[8.02337424829602...|
|f5448c29403c80ea7...|[7.50372884795069...|
|94ff32cd2cfab9919...|[5.85195317398756...|
+--------------------+--------------------+

输出应该是这样的

+--------------------+--------------------+
    |             hashval|    dec_spec (array|   ftr_1    |    ftr_2 | ftr_3 |...
    +--------------------+--------------------+-----------+---------+--------+
    |3c65252a67546832d...|[8.02337424829602...|  8.023   | 3.21       | 4.23.....
    |f5448c29403c80ea7...|[7.50372884795069...| 7.502    | 8.23       |2.125
    |94ff32cd2cfab9919...|[5.85195317398756...|
    +--------------------+--------------------+

udf 可以采用如下的一些逻辑

df_grp = df2.withColumn("explode_col", F.explode_outer("dec_spec"))
df_grp = df_grp.groupBy("hashval").pivot("explode_col").agg(F.avg("explode_col"))

下面用于重命名列

count = 1
for col in df_grp.columns:
  if col != "hashval":
    df_grp = df_grp.withColumnRenamed(col, "ftr"+str(count))
    count = count+1

感谢任何帮助。

PS 上面的代码,在这里得到了论坛其他人的帮助。

【问题讨论】:

  • 您的列 dec_spec 长度不同?但想爆炸到 4000 列?
  • 数组长度不同。因此,如果特定行只有 20 个元素,那么它将是 20 列。某些行中的数组字段可以有 2048 个元素,那么将有 2048 列。但我需要处理的最大值是 4000。超过此值的任何东西,我都可以扔掉
  • 所以其余列将为空,对吧?
  • 是的;如果第一行只有 20 个元素,第二行数组字段有 40 个元素,那么对于第一行,第 20 列之后将为空表字段将包含长度不等的数组,最大值为 4000。我可以在之后丢弃那个。
  • 查看答案,在我的情况下,我使用 3 作为我的输入长度,请更改为 4000 并适用于您的问题

标签: arrays dataframe pyspark


【解决方案1】:
模拟样本数据
from pyspark.sql import functions as sf
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType, StructType, StructField, StringType
sdf1 = sc.parallelize([["aaa", "1,2,3"],["bbb", "1,2,3,4,5"]]).toDF(["hash_val", "arr_str"])
sdf2 = sdf1.withColumn("arr", sf.split("arr_str", ","))
sdf2.show()

+--------+---------+---------------+
|hash_val|  arr_str|            arr|
+--------+---------+---------------+
|     aaa|    1,2,3|      [1, 2, 3]|
|     bbb|1,2,3,4,5|[1, 2, 3, 4, 5]|
+--------+---------+---------------+
udf 使所有数组长度相同
schema = ArrayType(StringType())

def fill_list(input_list, input_length):
    fill_len = input_length - len(input_list)
    if fill_len > 0:
        input_list += [None]*(fill_len)
    
    return input_list[0:input_length]

fill_list_udf = udf(fill_list, schema)
sdf3 = sdf2.withColumn("arr1", fill_list_udf(sf.col("arr"), sf.lit(3)))
sdf3.show()

+--------+---------+---------------+---------+
|hash_val|  arr_str|            arr|     arr1|
+--------+---------+---------------+---------+
|     aaa|    1,2,3|      [1, 2, 3]|[1, 2, 3]|
|     bbb|1,2,3,4,5|[1, 2, 3, 4, 5]|[1, 2, 3]|
+--------+---------+---------------+---------+
展开它们
sdf3.select("hash_val", *[sf.col("arr1")[i] for i in range(3)]).show()
+--------+-------+-------+-------+
|hash_val|arr1[0]|arr1[1]|arr1[2]|
+--------+-------+-------+-------+
|     aaa|      1|      2|      3|
|     bbb|      1|      2|      3|
+--------+-------+-------+-------+

【讨论】:

  • 谢谢。当索引变高时,我会遇到一些错误。想知道如果 fill_len > 0: input_list += [None]*(fill_len) 会做什么,我知道它会在指定长度后填充结尾部分。
  • 当长度达到 90 左右时出现错误。调用 o19712.showString 时出错。 org.apache.spark.SparkException:作业因阶段故障而中止:阶段 753.0 中的任务 0 失败 4 次,最近一次失败:阶段 753.0 中丢失任务 0.3(TID 6414,hpcb08r05d02.hpc.ford.com,执行程序 174): java.util.concurrent.ExecutionException:org.codehaus.commons.compiler.CompileException:文件'generated.java',第1220行,第14列:编译失败:org.codehaus.commons.compiler.CompileException:文件'generated.java ',第 1220 行,第 14 列:表达式“isNull_6​​”不是右值
  • 对,它将小于90的列表扩展到90长度
  • 我将数据更改为更大的值,仍然适用于我的示例,我想知道您的环境中有任何数据差异
  • 是的。可能有一些数据问题。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 2021-12-13
  • 2013-10-14
  • 1970-01-01
  • 2016-01-31
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多