【Question Title】: PySpark DataFrame: take the mean of a list within a column and create a new column with 1 and 0 depending on a condition
【Posted】: 2021-11-27 12:00:21
【Question】:

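I am trying to compute the mean of each list in the cost column of a PySpark DataFrame, and then flag every value in that list: values below the mean get 1, values above the mean get 0.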

This is the current DataFrame:

+----------+--------------------+--------------------+
|        id|  collect_list(p_id)|collect_list(cost)  |
+----------+--------------------+--------------------+
|         7|[10, 987, 872]      |[12.0, 124.6, 197.0]|
|         6|[11, 858, 299]      |[15.0, 167.16, 50.0]|
|        17|                 [2]|           [65.4785]|
|         1|[34359738369, 343...|[16.023384, 104.9...|
|         3|[17179869185, 0, ...|[48.3255, 132.025...|
+----------+--------------------+--------------------+

This is the desired output:

+----------+--------------------+--------------------+-----------+
|        id|    p_id            |cost                | result    |
+----------+--------------------+--------------------+-----------+
|         7|10                  |12.0                |  1        |
|         7|987                 |124.6               |  0        |
|         7|872                 |197.0               |  0        |
|         6|11                  |15.0                |  1        |
|         6|858                 |167.16              |  0        |
|         6|299                 |50.0                |  1        |
|        17|2                   |65.4785             |  1        |
+----------+--------------------+--------------------+-----------+
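The result column in the desired output can be checked by hand. The sketch below is plain Python, not Spark, and the helper name `flag_at_or_below_mean` is made up for illustration. It assumes the comparison is "less than or equal", since that is the only reading under which the single-element row (id 17) gets 1 as shown in the table.

```python
from statistics import fmean

def flag_at_or_below_mean(costs):
    """Return 1 for each cost that is <= the mean of the list, else 0."""
    mean_cost = fmean(costs)
    return [1 if cost <= mean_cost else 0 for cost in costs]

# id 6: the mean of [15.0, 167.16, 50.0] is about 77.39
print(flag_at_or_below_mean([15.0, 167.16, 50.0]))  # [1, 0, 1]

# id 17: a single value equals its own mean, so "<=" gives 1
# (a strict "<" would give 0 here)
print(flag_at_or_below_mean([65.4785]))  # [1]
```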

【Comments】:

    Tags: python pyspark apache-spark-sql


    【Solution 1】:
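    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col, mean

    # Reuse the active session (or create one); `sc` is its SparkContext
    spark = SparkSession.builder.getOrCreate()
    sc = spark.sparkContext

    # Sample data
    df = sc.parallelize([(7, [10, 987, 872], [12.0, 124.6, 197.0]),
                         (6, [11, 858, 299], [15.0, 167.16, 50.0]),
                         (17, [2], [65.4785])]).toDF(["id", "collect_list(p_id)", "collect_list(cost)"])

    # Unpack the collected lists into one (id, p_id, cost) row per element
    df = df.rdd.flatMap(lambda row: [(row[0], x, y) for x, y in zip(row[1], row[2])]).toDF(["id", "p_id", "cost"])

    # Join each row with its per-id mean cost, then flag costs at or below that mean
    df1 = (df
           .join(df.groupBy("id").agg(mean("cost").alias("mean_cost")), "id", "left")
           .withColumn("result", (col("cost") <= col("mean_cost")).cast("int"))
           .drop("mean_cost"))
    df1.show()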
    

    The output is:

    +---+----+-------+------+
    | id|p_id|   cost|result|
    +---+----+-------+------+
    |  7|  10|   12.0|     1|
    |  7| 987|  124.6|     0|
    |  7| 872|  197.0|     0|
    |  6|  11|   15.0|     1|
    |  6| 858| 167.16|     0|
    |  6| 299|   50.0|     1|
    | 17|   2|65.4785|     1|
    +---+----+-------+------+
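To see what the three Spark steps in Solution 1 do (flatten, per-id mean, compare), the same data flow can be traced in plain Python. This is only an illustrative sketch with no Spark involved; the variable names `flat`, `mean_by_id`, and `result` are made up here and do not appear in the answer.

```python
from collections import defaultdict

rows = [
    (7, [10, 987, 872], [12.0, 124.6, 197.0]),
    (6, [11, 858, 299], [15.0, 167.16, 50.0]),
    (17, [2], [65.4785]),
]

# Step 1 (the flatMap): one (id, p_id, cost) tuple per list element
flat = [(row_id, p_id, cost)
        for row_id, p_ids, costs in rows
        for p_id, cost in zip(p_ids, costs)]

# Step 2 (the groupBy + agg): mean cost per id
costs_by_id = defaultdict(list)
for row_id, _, cost in flat:
    costs_by_id[row_id].append(cost)
mean_by_id = {row_id: sum(costs) / len(costs)
              for row_id, costs in costs_by_id.items()}

# Step 3 (the join + withColumn): compare each cost with the mean of its id
result = [(row_id, p_id, cost, int(cost <= mean_by_id[row_id]))
          for row_id, p_id, cost in flat]

for r in result:
    print(r)
```

The join in the Spark version plays the role of the `mean_by_id[row_id]` lookup: it attaches the group mean to every flattened row so the comparison can be done row by row.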
    

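    【Comments】:

      【Solution 2】:

      You can build a result list for each row, then zip the pid, cost, and result lists together. After that, use explode on the zipped column.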

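      import numpy as np
      from pyspark.sql.functions import col, explode, udf
      from pyspark.sql.types import (ArrayType, DoubleType, IntegerType,
                                     LongType, StructField, StructType)

      def zip_cols(pid_list, cost_list):
          # Flag each cost at or below the row's mean, then zip the three lists
          mean_cost = np.mean(cost_list)
          res_list = [1 if mean_cost >= cost else 0 for cost in cost_list]
          return list(zip(pid_list, cost_list, res_list))

      # LongType for pid: some p_id values in the question (e.g. 34359738369) exceed the 32-bit range
      udf_zip = udf(zip_cols, ArrayType(StructType([StructField("pid", LongType()),
                                                    StructField("cost", DoubleType()),
                                                    StructField("result", IntegerType())])))

      # df is the original DataFrame that still has the two collect_list columns
      df1 = (df.withColumn("temp", udf_zip("collect_list(p_id)", "collect_list(cost)"))
               .drop("collect_list(p_id)", "collect_list(cost)"))

      # One output row per zipped struct, with the struct fields pulled out as columns
      df2 = (df1.withColumn("temp", explode(df1.temp))
                .select("id",
                        col("temp.pid").alias("pid"),
                        col("temp.cost").alias("cost"),
                        col("temp.result").alias("result")))
      df2.show()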
      

      Output:

      +---+---+-------+------+
      | id|pid|   cost|result|
      +---+---+-------+------+
      |  7| 10|   12.0|     1|
      |  7|987|  124.6|     0|
      |  7|872|  197.0|     0|
      |  6| 11|   15.0|     1|
      |  6|858| 167.16|     0|
      |  6|299|   50.0|     1|
      | 17|  2|65.4785|     1|
      +---+---+-------+------+
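The per-row logic of the UDF approach can be tried without a Spark session. The sketch below is a plain-Python stand-in, not the answer's code: it replaces `np.mean` with `sum(...) / len(...)`, and the `exploded` list comprehension only imitates what `explode` does to the zipped column.

```python
def zip_cols(pid_list, cost_list):
    """Return one (pid, cost, result) triple per element of the two lists."""
    mean_cost = sum(cost_list) / len(cost_list)
    return [(pid, cost, 1 if cost <= mean_cost else 0)
            for pid, cost in zip(pid_list, cost_list)]

rows = [
    (7, [10, 987, 872], [12.0, 124.6, 197.0]),
    (6, [11, 858, 299], [15.0, 167.16, 50.0]),
    (17, [2], [65.4785]),
]

# Imitate explode: one output row per triple, repeating the id
exploded = [(row_id, *triple)
            for row_id, pids, costs in rows
            for triple in zip_cols(pids, costs)]

for r in exploded:
    print(r)
```

Two things to keep in mind with this shape of function: an empty cost list would raise a `ZeroDivisionError`, and `zip` stops at the shorter input, so lists of unequal length would silently lose elements.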
      

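      【Comments】:

      • Thank you ashwinids! This was really helpful, much appreciated :) I just want to point out two very minor errors I ran into, for the benefit of other readers: 1) the necessary types need to be imported from pyspark.sql.types; 2) .drop() threw an error when given multiple columns.
      • @mur, I did not get an error when calling drop on multiple columns.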