【Question title】: Spark: How to flatten nested arrays with different shapes
【Posted】: 2021-08-27 09:06:43
【Question】:

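How can I flatten nested arrays with different shapes in PySpark? The question "How to flatten nested arrays by merging values in spark" answers this for arrays of the same shape. For arrays with different shapes, I get the error described below.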

Data structure:

  • Static names: id, date, val, num (can be hard-coded)
  • Dynamic names: name_1_a, name_10000_xvz (cannot be hard-coded, because the DataFrame has up to 10000 columns/arrays)

Input df:

root
 |-- id: long (nullable = true)
 |-- name_10000_xvz: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- date: long (nullable = true)
 |    |    |-- num: long (nullable = true)  **NOTE: additional `num` field **
 |    |    |-- val: long (nullable = true)
 |-- name_1_a: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- date: long (nullable = true)
 |    |    |-- val: long (nullable = true)



df.show(truncate=False)
+---+---------------------------------------------------------------------+---------------------------------+
|id |name_10000_xvz                                                       |name_1_a                         |
+---+---------------------------------------------------------------------+---------------------------------+
|1  |[{2000, null, 30}, {2001, null, 31}, {2002, null, 32}, {2003, 1, 33}]|[{2001, 1}, {2002, 2}, {2003, 3}]|
+---+---------------------------------------------------------------------+---------------------------------+

Desired output df:

+---+--------------+----+---+---+
| id|          name|date|val|num|
+---+--------------+----+---+---+
|  1|      name_1_a|2001|  1|   |
|  1|      name_1_a|2002|  2|   |
|  1|      name_1_a|2003|  3|   |
|  1|name_10000_xvz|2000| 30|   |
|  1|name_10000_xvz|2001| 31|   |
|  1|name_10000_xvz|2002| 32|   |
|  1|name_10000_xvz|2003| 33| 1 |
+---+--------------+----+---+---+

Code to reproduce:

Note: when I add el.num to TRANSFORM({name}, el -> STRUCT("{name}" AS name, el.date, el.val, el.num)), I get the error shown after the code.

import pyspark.sql.functions as f


df = spark.read.json(
    sc.parallelize(
        [
            """{"id":1,"name_1_a":[{"date":2001,"val":1},{"date":2002,"val":2},{"date":2003,"val":3}],"name_10000_xvz":[{"date":2000,"val":30},{"date":2001,"val":31},{"date":2002,"val":32},{"date":2003,"val":33, "num":1}]}"""
        ]
    )
).select("id", "name_1_a", "name_10000_xvz")

names = [column for column in df.columns if column.startswith("name_")]

expressions = []
for name in names:
    expressions.append(
        f.expr(
            'TRANSFORM({name}, el -> STRUCT("{name}" AS name, el.date, el.val, el.num))'.format(
                name=name
            )
        )
    )

flatten_df = df.withColumn("flatten", f.flatten(f.array(*expressions))).selectExpr(
    "id", "inline(flatten)"
)

Output:

AnalysisException: No such struct field num in date, val

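【Comments】:

  • Is it normal that "name_1_a" does not contain a "num" field? That is the reason for your exception. The question is: when the field value is "missing", is the value set to NULL, or is the field missing entirely?
  • Yes, the field num is missing; that is my problem.
  • I can solve the problem in pandas, as shown in this example stackoverflow.com/questions/68941232/…, but not in PySpark.
  • Is your input JSON? Can you show a bit more of it, in its raw format, with more rows and more columns?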
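
As the first comment notes, the exception comes from `name_1_a` having no `num` field. Below is a minimal sketch of one way to avoid it without a UDF: inspect each array's element struct and substitute a typed NULL for any field that is absent. It assumes every field is a long (BIGINT), as in the schema above, and that the column names need no quoting. The helper names `build_transform_expr` and `flatten_named_arrays` are hypothetical and not part of the original post.

```python
def build_transform_expr(name, present_fields, expected_fields=("date", "val", "num")):
    """Build a Spark SQL TRANSFORM expression that pads missing struct fields with NULL.

    Assumption: all expected fields are BIGINT, matching the schema in the question.
    """
    parts = ["'{0}' AS name".format(name)]
    for field in expected_fields:
        if field in present_fields:
            # Field exists in this array's element struct: read it directly.
            parts.append("el.{0}".format(field))
        else:
            # Field is absent: emit a typed NULL so all structs share one type.
            parts.append("CAST(NULL AS BIGINT) AS {0}".format(field))
    return "TRANSFORM({0}, el -> STRUCT({1}))".format(name, ", ".join(parts))


def flatten_named_arrays(df, prefix="name_"):
    """Apply the expression to every array column whose name starts with `prefix`."""
    # Imported lazily so the string builder above can be used without Spark.
    from pyspark.sql import functions as F

    expressions = []
    for field in df.schema.fields:
        if not field.name.startswith(prefix):
            continue
        # Names of the fields actually present in this array's element struct.
        present = set(field.dataType.elementType.fieldNames())
        expressions.append(F.expr(build_transform_expr(field.name, present)))
    return df.withColumn("flatten", F.flatten(F.array(*expressions))).selectExpr(
        "id", "inline(flatten)"
    )
```

Calling `flatten_named_arrays(df)` on the DataFrame from the question should then resolve for both columns, because every generated struct has the same four fields. This has not been tested at the 10000-column scale mentioned in the question.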

Tags: python apache-spark pyspark apache-spark-sql


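【Solution 1】:

You need to explode each array individually, probably use a UDF to fill in the missing values, and unionAll each newly created DataFrame. That is the PySpark part. For the Python part, you just need to loop over the different columns and let the magic happen: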

from functools import reduce
from pyspark.sql import functions as F, types as T


@F.udf(T.MapType(T.StringType(), T.LongType()))
def add_missing_fields(name_col):
    out = {}
    expected_fields = ["date", "num", "val"]
    for field in expected_fields:
        if field in name_col:
            out[field] = name_col[field]
        else:
            out[field] = None
    return out


flatten_df = reduce(
    lambda a, b: a.unionAll(b),
    (
        df.withColumn(col, F.explode(col))
        .withColumn(col, add_missing_fields(F.col(col)))
        .select(
            "id",
            F.lit(col).alias("name"),
            F.col(col).getItem("date").alias("date"),
            F.col(col).getItem("val").alias("val"),
            F.col(col).getItem("num").alias("num"),
        )
        for col in df.columns
        if col != "id"
    ),
)

The result:

flatten_df.show()
+---+--------------+----+---+----+
| id|          name|date|val| num|
+---+--------------+----+---+----+
|  1|      name_1_a|2001|  1|null|
|  1|      name_1_a|2002|  2|null|
|  1|      name_1_a|2003|  3|null|
|  1|name_10000_xvz|2000| 30|null|
|  1|name_10000_xvz|2001| 31|null|
|  1|name_10000_xvz|2002| 32|null|
|  1|name_10000_xvz|2003| 33|   1|
+---+--------------+----+---+----+

Another solution that does not use unionAll:

c = [col for col in df.columns if col != "id"]

@F.udf(T.ArrayType(T.MapType(T.StringType(), T.LongType())))
def add_missing_fields(name_col):
    out = []
    expected_fields = ["date", "num", "val"]
    for elt in name_col:
        new_map = {}
        for field in expected_fields:
            if field in elt:
                new_map[field] = elt[field]
            else:
                new_map[field] = None
        out.append(new_map)
    return out

df1 = reduce(
    lambda a, b: a.withColumn(
        b, F.struct(F.lit(b).alias("name"), add_missing_fields(b).alias("values"))
    ),
    c,
    df,
)

df2 = (
    df1.withColumn("names", F.explode(F.array(*(F.col(col) for col in c))))
    .withColumn("value", F.explode("names.values"))
    .select(
        "id",
        F.col("names.name").alias("name"),
        F.col("value").getItem("date").alias("date"),
        F.col("value").getItem("val").alias("val"),
        F.col("value").getItem("num").alias("num"),
    )
)

Result:

df2.show()
+---+--------------+----+---+----+
| id|          name|date|val| num|
+---+--------------+----+---+----+
|  1|      name_1_a|2001|  1|null|
|  1|      name_1_a|2002|  2|null|
|  1|      name_1_a|2003|  3|null|
|  1|name_10000_xvz|2000| 30|null|
|  1|name_10000_xvz|2001| 31|null|
|  1|name_10000_xvz|2002| 32|null|
|  1|name_10000_xvz|2003| 33|   1|
+---+--------------+----+---+----+

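【Comments】:

  • Thank you very much for your help. I hadn't considered this solution; I'll test it.
  • @dan Sorry, I forgot to apply the "missing_field" function to the DataFrame. I've edited my answer.
  • @Steven Thanks for your solution. It works well on small datasets, but on larger datasets unionAll just gets stuck when combining a large number of DataFrames. Any idea how to avoid unionAll?
  • @dan I added another solution without unionAll. The best option for you would be to fix the schema in the input, so that you don't have to use a UDF.
  • @dan Yes, that's right, but I think it is simple: c = [col for col in df.columns if col != "id"]. I've edited the code.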
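
The suggestion to fix the schema at input time can be sketched as follows. The idea is to hand the JSON reader an explicit schema in which every array has the same element struct, so that fields absent from the data come back as null. This is only a sketch under stated assumptions: the array column names are known before the read (for example from an earlier inferred read), every field is a long, and the helper name `build_uniform_schema_ddl` is hypothetical.

```python
def build_uniform_schema_ddl(array_names, element_fields=("date", "num", "val")):
    """Return a DDL schema string giving every array column the same element struct.

    Assumption: all element fields are BIGINT, as in the schema shown in the question.
    """
    element = ", ".join("`{0}`: BIGINT".format(f) for f in element_fields)
    columns = ["`id` BIGINT"]
    columns += ["`{0}` ARRAY<STRUCT<{1}>>".format(name, element) for name in array_names]
    return ", ".join(columns)


# Usage sketch (requires a SparkSession `spark` and the raw JSON in an RDD `rdd`):
#   names = [c for c in spark.read.json(rdd).columns if c.startswith("name_")]
#   df = spark.read.schema(build_uniform_schema_ddl(names)).json(rdd)
```

With a uniform schema like this, the TRANSFORM expression from the question, including el.num, should resolve for every column, at the cost of an extra pass to discover the column names.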
【Solution 2】:

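You can simply use the explode function, then read the values of the individual columns, create the extra column where needed for each array column, and finally union the two DataFrames to get the desired output.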

# Create and read the sample data
import pyspark.sql.functions as f
df = spark.read.json(sc.parallelize([
    """{"id":1,"name_1_a":[{"date":2001,"val":1},{"date":2002,"val":2},{"date":2003,"val":3}],"name_10000_xvz":[{"date":2000,"val":30},{"date":2001,"val":31},{"date":2002,"val":32},{"date":2003,"val":33, "num":1}]}"""
])).select('id','name_1_a', 'name_10000_xvz')

# Explode the first array column and build a DataFrame from it
df1 = (
    df.withColumn("name_1_a_array", f.explode(f.col("name_1_a")))
    .select("id", "name_1_a_array")
    .withColumn("date", f.col("name_1_a_array.date"))
    .withColumn("val", f.col("name_1_a_array.val"))
    # name_1_a has no `num` field, so add a real null (not the string "null")
    .withColumn("num", f.lit(None).cast("long"))
    .drop("name_1_a_array")
)

# Explode the second array column and build a DataFrame from it
df2 = (
    df.withColumn("name_10000_xvz_array", f.explode(f.col("name_10000_xvz")))
    .select("id", "name_10000_xvz_array")
    .withColumn("date", f.col("name_10000_xvz_array.date"))
    .withColumn("val", f.col("name_10000_xvz_array.val"))
    .withColumn("num", f.col("name_10000_xvz_array.num"))
    .drop("name_10000_xvz_array")
)

# Union of both DataFrames (columns are matched by position)
df3 = df1.union(df2)
display(df3)  # Databricks notebook helper; use df3.show() elsewhere

This gives the output you need.

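【Comments】:

  • Thanks for your help. I'm not sure I made it clear, but the array names can be any string and I have about 10000 of them, so with this solution I would need to hard-code all the names many times...
  • You only have to hard-code each column once, and it would be part of the schema, so if I understand your question correctly, this should not be a problem in my view.
  • Apologies, I will update the question. I have a DataFrame with 10000 columns, name_1_a ... name_10000_xvz; I'll try to make that clearer.
  • It wasn't me; I'd rather have different answers on this. But I think the reason is that your solution can't be used in my case, because I would need to hard-code 10000 different columns...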