【问题标题】:pyspark how do we check if a column value is contained in a list [duplicate]pyspark我们如何检查列值是否包含在列表中[重复]
【发布时间】:2019-07-12 12:48:56
【问题描述】:

我试图弄清楚是否有一个函数可以检查 spark DataFrame 的列是否包含列表中的任何值:

# define a dataframe
rdd = sc.parallelize([(0,100), (0,1), (0,2), (1,2), (1,10), (1,20), (3,18), (3,18), (3,18)])
df = sqlContext.createDataFrame(rdd, ["id", "score"])

# define a list of scores
l = [1]

# filter out records by scores by list l
records = df.filter(~df.score.contains(l))

# expected: (0,100), (0,1), (1,10), (3,18)

运行此代码时遇到问题:

java.lang.RuntimeException: Unsupported literal type class java.util.ArrayList [1]

有没有办法做到这一点,还是我们必须遍历列表才能传递包含?

【问题讨论】:

  • 你能解释一下这背后的逻辑是什么吗?为什么(0, 1) 在结果中,而(0,2) 在结果中?
  • @Psidom .. 我试图找出分数是否包含值 1 所以 (0, 1) 是 1 的分数和 (0,2) 是 2 的分数.. 所以 (0 ,2) 不包括在内.. 就像循环每个值一样,但我需要将其作为包含来执行,因为它不是相等性检查.. 这有意义吗
  • 那么为什么要包括 100、10 和 18?

标签: apache-spark pyspark apache-spark-sql


【解决方案1】:

我看到了一些方法来做到这一点without using a udf

您可以对pyspark.sql.functions.regexp_extract 使用列表推导,利用如果没有匹配项则返回空字符串这一事实。

尝试提取列表l 中的所有值并连接结果。如果生成的连接字符串是空字符串,则表示没有匹配的值。

例如:

from pyspark.sql.functions import concat, regexp_extract

records = df.where(concat(*[regexp_extract("score", str(val), 0) for val in l]) != "")
records.show()
#+---+-----+
#| id|score|
#+---+-----+
#|  0|  100|
#|  0|    1|
#|  1|   10|
#|  3|   18|
#|  3|   18|
#|  3|   18|
#+---+-----+

如果您查看执行计划,您会发现它足够聪明地将score 列隐式转换为string

records.explain()
#== Physical Plan ==
#*Filter NOT (concat(regexp_extract(cast(score#11L as string), 1, 0)) = )
#+- Scan ExistingRDD[id#10L,score#11L]

另一种方法是使用pyspark.sql.Column.like(或与rlike类似):

from functools import reduce
from pyspark.sql.functions import col

records = df.where(
    reduce(
        lambda a, b: a|b, 
        map(
            lambda val: col("score").like(val.join(["%", "%"])), 
            map(str, l)
        )
    )
)

产生与上面相同的输出并具有以下执行计划:

#== Physical Plan ==
#*Filter Contains(cast(score#11L as string), 1)
#+- Scan ExistingRDD[id#10L,score#11L]

如果您只想要不同的记录,您可以这样做:

records.distinct().show()
#+---+-----+
#| id|score|
#+---+-----+
#|  0|    1|
#|  0|  100|
#|  3|   18|
#|  1|   10|
#+---+-----+

【讨论】:

  • @gaw.. 两种解决方案都有效,但我没有使用 udf 解决方案来提高性能..
【解决方案2】:

如果我对您的理解正确,您希望在您的情况下拥有一个包含元素的列表,它只有 1。您要检查此元素是否出现在乐谱中的位置。在这种情况下,使用字符串而不是直接使用数字更容易。

您可以使用自定义地图功能执行此操作,并通过 udf 应用此功能(直接应用会导致一些奇怪的行为并且仅在某些时候有效)。

找到下面的代码:

rdd = sc.parallelize([(0,100), (0,1), (0,2), (1,2), (1,10), (1,20), (3,18), (3,18), (3,18)])
df = sqlContext.createDataFrame(rdd, ["id", "score"])
l = [1]

def filter_list(score, l):
    found = True
    for e in l:
        if str(e) not in str(score):  #The filter that checks if an Element e
            found = False             #does not appear in the score
    if found:
        return True                   #boolean value if the all elements were found
    else:
        return False

def udf_filter(l):
    return udf(lambda score: filter_list(score, l)) #make a udf function out of the filter list
df.withColumn("filtered", udf_filter(l)(col("score"))).filter(col("filtered")==True).drop("filtered").show()
#apply the function and store results in "filtered" column afterwards 
#only select the successful filtered rows and drop the column

输出:

+---+-----+
| id|score|
+---+-----+
|  0|  100|
|  0|    1|
|  1|   10|
|  3|   18|
|  3|   18|
|  3|   18|
+---+-----+

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2018-06-20
    • 2018-03-13
    • 1970-01-01
    • 2021-11-01
    • 2019-12-01
    • 1970-01-01
    • 2020-07-21
    • 1970-01-01
    相关资源
    最近更新 更多