不保留重复项
使用 array_intersect 即可实现所需的操作:它返回 column_b 中同时出现在目标列表里的元素。
注意,array_intersect 的结果不包含重复项:例如,如果 column_b 的值为 ["name", "name"],那么 column_c 中只会包含一个 "name",即 ["name"]。
from pyspark.sql import functions as F

# Sample rows: an identifier plus an array of string values.
data = [
    ("Aaaa", ["name", "age", "subject"]),
    ("Bbbb", ["name", "age", "country", "subject"]),
    ("Cccc", ["name", "subject", "percentage"]),
    ("Dddd", ["name", "name"]),
]
# NOTE: `spark` is assumed to be an existing SparkSession (e.g. a notebook/shell session).
df = spark.createDataFrame(data, ("column_a", "column_b"))

# Values to look for in column_b.
lst = ['name', 'age', 'country']
# One literal Column per value; F.array below combines them into one array column.
lit_lst = list(map(F.lit, lst))

# array_intersect returns the elements common to column_b and the literal
# array, without duplicates (so ["name", "name"] yields ["name"]).
wanted = F.array(lit_lst)
df.withColumn(
    "column_c",
    F.array_intersect(F.col("column_b"), wanted),
).show(truncate=False)
输出
+--------+-----------------------------+--------------------+
|column_a|column_b                     |column_c            |
+--------+-----------------------------+--------------------+
|Aaaa    |[name, age, subject]         |[name, age]         |
|Bbbb    |[name, age, country, subject]|[name, age, country]|
|Cccc    |[name, subject, percentage]  |[name]              |
|Dddd    |[name, name]                 |[name]              |
+--------+-----------------------------+--------------------+
保留重复项
如果需要保留重复项,可以改用 filter 高阶函数:它逐个检查 column_b 中的元素,保留所有包含在目标列表中的元素,重复出现的元素也会一并保留。
from pyspark.sql import functions as F

# Sample rows; the "Dddd" row contains a repeated value to show that
# duplicates are preserved by this approach.
data = [
    ("Aaaa", ["name", "age", "subject"]),
    ("Bbbb", ["name", "age", "country", "subject"]),
    ("Cccc", ["name", "subject", "percentage"]),
    ("Dddd", ["name", "name"]),
]
# NOTE: `spark` is assumed to be an existing SparkSession (e.g. a notebook/shell session).
df = spark.createDataFrame(data, ("column_a", "column_b"))

# Define the target values in this snippet as well, so it runs on its own
# instead of depending on names left over from an earlier snippet.
lst = ['name', 'age', 'country']
lit_lst = [F.lit(v) for v in lst]

# Step 1: store the target values as a temporary array column (column_c).
# Step 2: the SQL higher-order function `filter` keeps every element of
#         column_b for which array_contains(column_c, element) is true, so a
#         value that appears twice in column_b is kept twice. The result
#         overwrites the temporary column_c.
(
    df.withColumn("column_c", F.array(lit_lst))
    .withColumn(
        "column_c",
        F.expr("filter(column_b, element -> array_contains(column_c, element))"),
    )
    .show(truncate=False)
)
输出
+--------+-----------------------------+--------------------+
|column_a|column_b                     |column_c            |
+--------+-----------------------------+--------------------+
|Aaaa    |[name, age, subject]         |[name, age]         |
|Bbbb    |[name, age, country, subject]|[name, age, country]|
|Cccc    |[name, subject, percentage]  |[name]              |
|Dddd    |[name, name]                 |[name, name]        |
+--------+-----------------------------+--------------------+