【问题标题】:Pyspark checking if any of the rows is greater then zeroPyspark 检查是否有任何行大于零
【发布时间】:2020-04-27 17:44:13
【问题描述】:

我想过滤掉列表中所有列的值为零的行。

假设例如我们有以下df,

df = spark.createDataFrame([(0, 1, 1, 2,1), (0, 0, 1, 0, 1), (1, 0, 1, 1 ,1)], ['a', 'b', 'c', 'd', 'e'])
+---+---+---+---+---+                                                           
|  a|  b|  c|  d|  e|
+---+---+---+---+---+
|  0|  1|  1|  2|  1|
|  0|  0|  1|  0|  1|
|  1|  0|  1|  1|  1|
+---+---+---+---+---+

列的列表是 ['a', 'b', 'd'] 所以过滤后的数据框应该是,

+---+---+---+---+---+                                                           
|  a|  b|  c|  d|  e|
+---+---+---+---+---+
|  0|  1|  1|  2|  1|
|  1|  0|  1|  1|  1|
+---+---+---+---+---+

这是我尝试过的,

df = df.withColumn('total', sum(df[col] for col in ['a', 'b', 'd']))
df = df.filter(df.total > 0).drop('total')

这适用于小型数据集,但如果 col_list 很长并出现以下错误,则会失败并出现以下错误。

ava.lang.StackOverflowErrorat org.apache.spark.sql.catalyst.analysis.ResolveLambdaVariables.org$apache$spark$sql$catalyst$analysis$ResolveLambdaVariables$$resolve(更高...

我可以想到一个 pandas udf 解决方案,但我的 df 非常大,这可能是一个瓶颈。

编辑:

使用@Psidom 的答案时,我收到以下错误

py4j.protocol.Py4JJavaError:调用 o2508.filter 时出错。 : java.lang.StackOverflowError 在 org.apache.spark.sql.catalyst.expressions.Expression.references(Expression.scala:88) 在 org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$references$1.apply(Expression.scala:88) 在 org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$references$1.apply(Expression.scala:88) 在 scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241) 在 scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241) 在 scala.collection.immutable.List.foreach(List.scala:392) 在 scala.collection.TraversableLike$class.flatMap(TraversableLike.scala:241) 在 scala.collection.immutable.List.flatMap(List.scala:355)

【问题讨论】:

    标签: python dataframe apache-spark pyspark apache-spark-sql


    【解决方案1】:

    您可以将列作为 array 传递给 UDF,然后检查所有值是否为零,然后应用过滤器:

    from pyspark.sql.types import BooleanType
    from pyspark.sql.functions import udf, array, col
    
    all_zeros_udf = udf(lambda arr: arr.count(0) == len(arr), BooleanType())
    
    df = spark.createDataFrame([(0, 1, 1, 2,1), (0, 0, 1, 0, 1), (1, 0, 1, 1 ,1)], ['a', 'b', 'c', 'd', 'e'])
    
    df
    .withColumn('all_zeros', all_zeros_udf(array('a', 'b', 'd'))) # pass the columns as array
    .filter(~col('all_zeros')) # Filter the columns where all values are NOT zeros
    .drop('all_zeros')  # Drop the column
    .show()
    

    结果:

    +---+---+---+---+---+
    |  a|  b|  c|  d|  e|
    +---+---+---+---+---+
    |  0|  1|  1|  2|  1|
    |  1|  0|  1|  1|  1|
    +---+---+---+---+---+
    

    【讨论】:

    • 谢谢,当我们有数百万行和数千列时,您认为使用 udf 是个好主意吗?
    • 取决于你如何使用它。我们每天使用许多 udf 来处理 TB 的数据。
    【解决方案2】:

    functools.reduce 可能在这里有用:

    df = spark.createDataFrame([(0, 1, 1, 2,1), (0, 0, 1, 0, 1), (1, 0, 1, 1 ,1)], 
         ['a', 'b', 'c', 'd', 'e'])
    cols = ['a', 'b', 'd']
    

    使用reduce 创建过滤器表达式:

    from functools import reduce
    predicate = reduce(lambda a, b: a | b, [df[x] != 0 for x in cols])
    
    print(predicate)
    # Column<b'(((NOT (a = 0)) OR (NOT (b = 0))) OR (NOT (d = 0)))'>
    

    然后filterpredicate

    df.where(predicate).show()
    +---+---+---+---+---+
    |  a|  b|  c|  d|  e|
    +---+---+---+---+---+
    |  0|  1|  1|  2|  1|
    |  1|  0|  1|  1|  1|
    +---+---+---+---+---+
    

    【讨论】:

    • 这仅适用于列数较少的情况,在我的列数为数千的情况下,我得到一个错误,错误日志被添加到问题中
    【解决方案3】:

    这是一个不同的解决方案。尚未尝试过大量列,如果可行,请告诉我。

    df = spark.createDataFrame([(0, 1, 1, 2,1), (0, 0, 1, 0, 1), (1, 0, 1, 1 ,1)], ['a', 'b', 'c', 'd', 'e'])
    df.show()
    
    +---+---+---+---+---+
    |  a|  b|  c|  d|  e|
    +---+---+---+---+---+
    |  0|  1|  1|  2|  1|
    |  0|  0|  1|  0|  1|
    |  1|  0|  1|  1|  1|
    +---+---+---+---+---+
    
    df = df.withColumn("Concat_cols" , F.concat(*list_of_cols)) # concat the list of columns 
    df.show()
    
    +---+---+---+---+---+-----------+
    |  a|  b|  c|  d|  e|Concat_cols|
    +---+---+---+---+---+-----------+
    |  0|  1|  1|  2|  1|        012|
    |  0|  0|  1|  0|  1|        000|
    |  1|  0|  1|  1|  1|        101|
    +---+---+---+---+---+-----------+
    
    pattern =  '0' * len(list_of_cols) 
    
    df1 = df.where(df['Concat_cols'] != pattern) # pattern will be 0's and the number will be equal to length of the columns list.
    df1.show()
    
        +---+---+---+---+---+-----------+
        |  a|  b|  c|  d|  e|Concat_cols|
        +---+---+---+---+---+-----------+
        |  0|  1|  1|  2|  1|        012|
        |  1|  0|  1|  1|  1|        101|
        +---+---+---+---+---+-----------+
    

    【讨论】:

    • 像魅力一样工作,但不确定将这么长的字符串存储在具有数百万行的 df 中是否是个好主意。我认为在性能方面,UDF 会更好
    • 当然,如果您不想拥有包含这些字符串数据的新数据框,那么您可以将这两者组合成一行代码。 df1 = df.where(F.concat(*list_of_cols) != pattern)
    【解决方案4】:

    如果目的只是检查所有列中出现的0 并且列表导致问题,则可能一次将它们组合1000,然后测试非零出现。

    from pyspark.sql import functions as F
    
    # all or whatever columns you would like to test.
    columns = df.columns 
    
    # Columns required to be concatenated at a time.
    split = 1000 
    
    # list of 1000 columns concatenated into a single column
    blocks = [F.concat(*columns[i*split:(i+1)*split]) 
                for i in range((len(columns)+split-1)//split)]
    
    # where expression here replaces zeroes to check if the resultant string is blank or not.
    (df.select("*")
        .where(F.regexp_replace(F.concat(*blocks).alias("concat"), "0", "") != "" )
        .show(10, False))
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2020-02-10
      • 2018-01-31
      • 2014-12-19
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多