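【Question title】: How to pivot Spark DataFrame?
【Posted】: 2015-07-26 12:24:09
【Question description】:

I am starting to use Spark DataFrames, and I need to be able to pivot the data to create multiple columns out of 1 column with multiple rows. There is built-in functionality for that in Scalding, and I believe in Pandas in Python, but I can't find anything for the new Spark DataFrame.

I assume I can write a custom function of some sort that will do this, but I'm not even sure how to start, especially since I am a novice with Spark. If anyone knows how to do this with built-in functionality, or has suggestions for how to write something in Scala, it would be greatly appreciated.

【Question discussion】:

  • See the similar question, where I posted a native Spark approach that doesn't need the column/category names to be known ahead of time.

Tags: dataframe apache-spark pyspark apache-spark-sql pivot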


【Solution 1】:

As mentioned by David Anderson, Spark has provided a pivot function since version 1.6. The general syntax looks as follows:

df
  .groupBy(grouping_columns)
  .pivot(pivot_column, [values]) 
  .agg(aggregate_expressions)
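
To make the shape of the result concrete, here is a minimal plain-Python sketch of what a groupBy / pivot / agg chain computes. It is only a model of the semantics, not Spark code: the function name `pivot_rows` and its parameters are hypothetical, and rows are assumed to be plain dictionaries.

```python
from collections import defaultdict
from statistics import mean


def pivot_rows(rows, group_cols, pivot_col, value_col, agg=mean, values=None):
    """Plain-Python model of groupBy(...).pivot(...).agg(...); not Spark code."""
    # Without an explicit list, an extra pass collects the distinct pivot values.
    if values is None:
        values = sorted({row[pivot_col] for row in rows})
    wanted = set(values)

    # Bucket the value column by (grouping key, pivot value).
    buckets = defaultdict(list)
    group_keys = set()
    for row in rows:
        key = tuple(row[c] for c in group_cols)
        group_keys.add(key)
        if row[pivot_col] in wanted:
            buckets[(key, row[pivot_col])].append(row[value_col])

    # One output row per grouping key, one output column per pivot value.
    result = []
    for key in sorted(group_keys):
        out = dict(zip(group_cols, key))
        for v in values:
            cell = buckets.get((key, v))
            out[v] = agg(cell) if cell else None  # empty cells stay empty
        result.append(out)
    return result


rows = [
    {"origin": "EWR", "hour": 5, "arr_delay": 11},
    {"origin": "EWR", "hour": 5, "arr_delay": 12},
    {"origin": "EWR", "hour": 6, "arr_delay": 19},
    {"origin": "JFK", "hour": 5, "arr_delay": 33},
]
pivoted = pivot_rows(rows, ["origin"], "hour", "arr_delay")
```

Passing `values=[6]` to the sketch keeps only the column for hour 6 and skips the distinct-values pass, which mirrors the role of the optional `[values]` argument in the syntax above.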

Usage examples using nycflights13 in csv format:

Python

from pyspark.sql.functions import avg

flights = (sqlContext
    .read
    .format("csv")
    .options(inferSchema="true", header="true")
    .load("flights.csv")
    .na.drop())

flights.registerTempTable("flights")
sqlContext.cacheTable("flights")

gexprs = ("origin", "dest", "carrier")
aggexpr = avg("arr_delay")

flights.count()
## 336776

%timeit -n10 flights.groupBy(*gexprs ).pivot("hour").agg(aggexpr).count()
## 10 loops, best of 3: 1.03 s per loop

Scala

val flights = sqlContext
  .read
  .format("csv")
  .options(Map("inferSchema" -> "true", "header" -> "true"))
  .load("flights.csv")

flights
  .groupBy($"origin", $"dest", $"carrier")
  .pivot("hour")
  .agg(avg($"arr_delay"))

Java

import static org.apache.spark.sql.functions.*;
import org.apache.spark.sql.*;

Dataset<Row> df = spark.read().format("csv")
        .option("inferSchema", "true")
        .option("header", "true")
        .load("flights.csv");

df.groupBy(col("origin"), col("dest"), col("carrier"))
        .pivot("hour")
        .agg(avg(col("arr_delay")));

R / SparkR

library(magrittr)

flights <- read.df("flights.csv", source="csv", header=TRUE, inferSchema=TRUE)

flights %>% 
  groupBy("origin", "dest", "carrier") %>% 
  pivot("hour") %>% 
  agg(avg(column("arr_delay")))

R / sparklyr

library(dplyr)

flights <- spark_read_csv(sc, "flights", "flights.csv")

avg.arr.delay <- function(gdf) {
   expr <- invoke_static(
      sc,
      "org.apache.spark.sql.functions",
      "avg",
      "arr_delay"
    )
    gdf %>% invoke("agg", expr, list())
}

flights %>% 
  sdf_pivot(origin + dest + carrier ~  hour, fun.aggregate=avg.arr.delay)

SQL

Note that the PIVOT keyword in Spark SQL is supported starting from version 2.4.

CREATE TEMPORARY VIEW flights 
USING csv 
OPTIONS (header 'true', path 'flights.csv', inferSchema 'true') ;

 SELECT * FROM (
   SELECT origin, dest, carrier, arr_delay, hour FROM flights
 ) PIVOT (
   avg(arr_delay)
   FOR hour IN (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23)
 );
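
The IN list in the statement above is written out by hand. One possible workaround, sketched here in plain Python, is to generate the statement text from a list of values obtained beforehand (for example from a separate distinct query). The helper name `build_pivot_sql` is hypothetical, and the sketch assumes trusted table and column names and string values without quote characters.

```python
def build_pivot_sql(table, group_cols, pivot_col, value_col, agg, values):
    """Render a PIVOT statement like the one above for the given values."""

    def literal(v):
        # Assumption: string values contain no quote or backslash characters.
        if isinstance(v, str):
            if "'" in v or "\\" in v:
                raise ValueError(f"unsupported character in pivot value: {v!r}")
            return f"'{v}'"
        return str(v)

    select_cols = ", ".join(list(group_cols) + [value_col, pivot_col])
    in_list = ", ".join(literal(v) for v in values)
    return (
        f"SELECT * FROM (SELECT {select_cols} FROM {table}) "
        f"PIVOT ({agg}({value_col}) FOR {pivot_col} IN ({in_list}))"
    )


# In practice the list would come from the data; here it is simply 0..23.
hours = list(range(24))
sql = build_pivot_sql(
    "flights", ["origin", "dest", "carrier"], "hour", "arr_delay", "avg", hours
)
```

The resulting string can then be passed to whatever runs the SQL, at the cost of one extra query to collect the values.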

Example data

"year","month","day","dep_time","sched_dep_time","dep_delay","arr_time","sched_arr_time","arr_delay","carrier","flight","tailnum","origin","dest","air_time","distance","hour","minute","time_hour"
2013,1,1,517,515,2,830,819,11,"UA",1545,"N14228","EWR","IAH",227,1400,5,15,2013-01-01 05:00:00
2013,1,1,533,529,4,850,830,20,"UA",1714,"N24211","LGA","IAH",227,1416,5,29,2013-01-01 05:00:00
2013,1,1,542,540,2,923,850,33,"AA",1141,"N619AA","JFK","MIA",160,1089,5,40,2013-01-01 05:00:00
2013,1,1,544,545,-1,1004,1022,-18,"B6",725,"N804JB","JFK","BQN",183,1576,5,45,2013-01-01 05:00:00
2013,1,1,554,600,-6,812,837,-25,"DL",461,"N668DN","LGA","ATL",116,762,6,0,2013-01-01 06:00:00
2013,1,1,554,558,-4,740,728,12,"UA",1696,"N39463","EWR","ORD",150,719,5,58,2013-01-01 05:00:00
2013,1,1,555,600,-5,913,854,19,"B6",507,"N516JB","EWR","FLL",158,1065,6,0,2013-01-01 06:00:00
2013,1,1,557,600,-3,709,723,-14,"EV",5708,"N829AS","LGA","IAD",53,229,6,0,2013-01-01 06:00:00
2013,1,1,557,600,-3,838,846,-8,"B6",79,"N593JB","JFK","MCO",140,944,6,0,2013-01-01 06:00:00
2013,1,1,558,600,-2,753,745,8,"AA",301,"N3ALAA","LGA","ORD",138,733,6,0,2013-01-01 06:00:00

Performance considerations

Generally speaking, pivoting is an expensive operation.

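【Discussion】:

  • What if the pivoted dataframe is too big to fit in memory? How can I do it directly on disk?
  • How should aggexpr = avg("arr_delay") be changed in order to pivot more columns, not just 1?
  • In the SQL solution (not Scala), I can see that you use a hardcoded list '(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23)'. Is there any way to use all the values taken from another column? I searched the internet and this site but couldn't find anything.
  • Same question as @Windoze. The SQL solution is not really equivalent to the others if the column list has to be provided manually. Is it possible to get the list through a select subquery?
【Solution 2】:

I overcame this by writing a for loop to dynamically create a SQL query. Say I have: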

id  tag  value
1   US    50
1   UK    100
1   Can   125
2   US    75
2   UK    150
2   Can   175

I want:

id  US  UK   Can
1   50  100  125
2   75  150  175

I can create a list with the values I want to pivot on, and then create a string containing the SQL query I need.

val countries = List("US", "UK", "Can")

// One "case when" column per value to pivot on
val caseColumns = countries
  .map(c => s"""case when tag = "$c" then value else 0 end as $c""")
  .mkString(", ")
val query = s"select *, $caseColumns from myTable"

myDataFrame.registerTempTable("myTable")
val myDF1 = sqlContext.sql(query)

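I can create a similar query to then do the aggregation. It is not a very elegant solution, but it works and is flexible for any list of values, which can also be passed in as an argument when the code is called.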
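
As a rough illustration of the "similar query to then do the aggregation" step, the sketch below generates the CASE WHEN columns, wraps each one in sum(), and groups by id. To stay self-contained it runs against Python's built-in sqlite3 module as a stand-in for Spark SQL, so the helper name `build_case_pivot_sql` and the single-quoted string literals are assumptions of this sketch, not part of the answer.

```python
import sqlite3


def build_case_pivot_sql(table, id_col, key_col, value_col, keys):
    """Build a pivot query from generated CASE WHEN columns, summed per id."""
    # The keys are interpolated into the SQL text, so they must be trusted
    # values that are also valid column names.
    cases = ", ".join(
        f"sum(case when {key_col} = '{k}' then {value_col} else 0 end) as {k}"
        for k in keys
    )
    return f"select {id_col}, {cases} from {table} group by {id_col} order by {id_col}"


# sqlite3 stands in for Spark SQL here, using the sample table from the answer.
conn = sqlite3.connect(":memory:")
conn.execute("create table myTable (id integer, tag text, value integer)")
conn.executemany(
    "insert into myTable values (?, ?, ?)",
    [(1, "US", 50), (1, "UK", 100), (1, "Can", 125),
     (2, "US", 75), (2, "UK", 150), (2, "Can", 175)],
)
query = build_case_pivot_sql("myTable", "id", "tag", "value", ["US", "UK", "Can"])
result = conn.execute(query).fetchall()
```

Each tuple in `result` holds the id followed by the US, UK and Can totals, which is the layout shown under "I want" above.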

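【Discussion】:

  • I'm trying to reproduce your example, but I get an "org.apache.spark.sql.AnalysisException: cannot resolve 'US' given input columns id, tag, value"
  • That has to do with the quotes. If you look at the resulting text string, what you get is 'case when tag = US', so Spark thinks that is a column name rather than a text value. What you really want to see is 'case when tag = "US"'. I have edited the above answer to set up the quotes correctly.
  • But as also mentioned, this is now native Spark functionality using the pivot command.
【Solution 3】:

A pivot operator has been added to the Spark DataFrame API, and is part of Spark 1.6.

See https://github.com/apache/spark/pull/7841 for details.

【Discussion】: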

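    【Solution 4】:

    I have solved a similar problem using dataframes with the following steps:

    Create columns for all your countries, with 'value' as the value: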

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions._
    val countries = List("US", "UK", "Can")
    // Returns the row's value when the row's tag matches the country, otherwise 0
    val countryValue = udf{(countryToCheck: String, countryInRow: String, value: Long) =>
      if(countryToCheck == countryInRow) value else 0
    }
    // One function per country, each adding that country's column
    val countryFuncs = countries.map{country => (dataFrame: DataFrame) => dataFrame.withColumn(country, countryValue(lit(country), dataFrame("tag"), dataFrame("value"))) }
    val dfWithCountries = Function.chain(countryFuncs)(df).drop("tag").drop("value")
    

    Your dataframe 'dfWithCountries' will look like this:

    +--+--+---+---+
    |id|US| UK|Can|
    +--+--+---+---+
    | 1|50|  0|  0|
    | 1| 0|100|  0|
    | 1| 0|  0|125|
    | 2|75|  0|  0|
    | 2| 0|150|  0|
    | 2| 0|  0|175|
    +--+--+---+---+
    

    Now you can sum together all the values for your desired result:

    dfWithCountries.groupBy("id").sum(countries: _*).show
    

    Result:

    +--+-------+-------+--------+
    |id|SUM(US)|SUM(UK)|SUM(Can)|
    +--+-------+-------+--------+
    | 1|     50|    100|     125|
    | 2|     75|    150|     175|
    +--+-------+-------+--------+
    

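    It's not a very elegant solution though. I had to create a chain of functions to add in all the columns. Also, if I have lots of countries, I will expand my temporary data set to a very wide set with lots of zeros.

    【Discussion】:

      【Solution 5】:

      There is a simple and elegant solution.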

      scala> spark.sql("select * from k_tags limit 10").show()
      +---------------+-------------+------+
      |           imsi|         name| value|
      +---------------+-------------+------+
      |246021000000000|          age|    37|
      |246021000000000|       gender|Female|
      |246021000000000|         arpu|    22|
      |246021000000000|   DeviceType| Phone|
      |246021000000000|DataAllowance|   6GB|
      +---------------+-------------+------+
      
      scala> spark.sql("select * from k_tags limit 10").groupBy($"imsi").pivot("name").agg(min($"value")).show()
      +---------------+-------------+----------+---+----+------+
      |           imsi|DataAllowance|DeviceType|age|arpu|gender|
      +---------------+-------------+----------+---+----+------+
      |246021000000000|          6GB|     Phone| 37|  22|Female|
      |246021000000001|          1GB|     Phone| 72|  10|  Male|
      +---------------+-------------+----------+---+----+------+
      

      【Discussion】:

        【Solution 6】:

        There is a simple way to pivot:

          id  tag  value
          1   US    50
          1   UK    100
          1   Can   125
          2   US    75
          2   UK    150
          2   Can   175
        
          import sparkSession.implicits._
        
          val data = Seq(
            (1,"US",50),
            (1,"UK",100),
            (1,"Can",125),
            (2,"US",75),
            (2,"UK",150),
            (2,"Can",175)
          )
        
          val dataFrame = data.toDF("id","tag","value")
        
          val df2 = dataFrame
                            .groupBy("id")
                            .pivot("tag")
                            .max("value")
          df2.show()
        
        +---+---+---+---+
        | id|Can| UK| US|
        +---+---+---+---+
        |  1|125|100| 50|
        |  2|175|150| 75|
        +---+---+---+---+
        

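        【Discussion】:

          【Solution 7】:

          There are plenty of examples of the pivot operation on datasets/dataframes, but I could not find many using SQL. Here is an example that worked for me.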

          create or replace temporary view faang 
          as SELECT stock.date AS `Date`,
              stock.adj_close AS `Price`,
              stock.symbol as `Symbol` 
          FROM stock  
          WHERE (stock.symbol rlike '^(FB|AAPL|GOOG|AMZN)$') and year(date) > 2010;
          
          
          SELECT * from faang 
          
          PIVOT (max(price) for symbol in ('AAPL', 'FB', 'GOOG', 'AMZN')) order by date; 
          
          

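          【Discussion】:

            【Solution 8】:

            Initially I adopted Al M's solution. Later I took the same idea and rewrote this function as a transpose function.

            This method transposes any df rows to columns, for any data format, using key and value columns.

            For the input csv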

            id,tag,value
            1,US,50a
            1,UK,100
            1,Can,125
            2,US,75
            2,UK,150
            2,Can,175
            

            Output

            +--+---+---+---+
            |id| UK| US|Can|
            +--+---+---+---+
            | 2|150| 75|175|
            | 1|100|50a|125|
            +--+---+---+---+
            

            Transpose method:

            def transpose(hc : HiveContext , df: DataFrame,compositeId: List[String], key: String, value: String) = {
            
            val distinctCols =   df.select(key).distinct.map { r => r(0) }.collect().toList
            
            val rdd = df.map { row =>
            (compositeId.collect { case id => row.getAs(id).asInstanceOf[Any] },
            scala.collection.mutable.Map(row.getAs(key).asInstanceOf[Any] -> row.getAs(value).asInstanceOf[Any]))
            }
            val pairRdd = rdd.reduceByKey(_ ++ _)
            val rowRdd = pairRdd.map(r => dynamicRow(r, distinctCols))
            hc.createDataFrame(rowRdd, getSchema(df.schema, compositeId, (key, distinctCols)))
            
            }
            
            private def dynamicRow(r: (List[Any], scala.collection.mutable.Map[Any, Any]), colNames: List[Any]) = {
            val cols = colNames.collect { case col => r._2.getOrElse(col.toString(), null) }
            val array = r._1 ++ cols
            Row(array: _*)
            }
            
            private  def getSchema(srcSchema: StructType, idCols: List[String], distinctCols: (String, List[Any])): StructType = {
            val idSchema = idCols.map { idCol => srcSchema.apply(idCol) }
            val colSchema = srcSchema.apply(distinctCols._1)
            val colsSchema = distinctCols._2.map { col => StructField(col.asInstanceOf[String], colSchema.dataType, colSchema.nullable) }
            StructType(idSchema ++ colsSchema)
            }
            

            Main snippet

            import java.util.Date
            import org.apache.spark.SparkConf
            import org.apache.spark.SparkContext
            import org.apache.spark.sql.Row
            import org.apache.spark.sql.DataFrame
            import org.apache.spark.sql.types.StructType
            import org.apache.spark.sql.hive.HiveContext
            import org.apache.spark.sql.types.StructField
            
            
            ...
            ...
            def main(args: Array[String]): Unit = {
            
                val sc = new SparkContext(conf)
                val sqlContext = new org.apache.spark.sql.SQLContext(sc)
                val dfdata1 = sqlContext.read.format("com.databricks.spark.csv").option("header", "true").option("inferSchema", "true")
                .load("data.csv")
                dfdata1.show()  
                val dfOutput = transpose(new HiveContext(sc), dfdata1, List("id"), "tag", "value")
                dfOutput.show
            
            }
            

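            【Discussion】:

              【Solution 9】:

              The built-in Spark pivot function is inefficient. The implementation below works on Spark 2.4+. The idea is to aggregate a map and extract the values as columns. The only limitation is that it does not handle aggregate functions in the pivoted columns, only column(s).

              On an 8M table, these functions take 3 seconds, versus 40 minutes with the built-in Spark version: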

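              # Python
              from pyspark.sql.functions import col, collect_list, count, expr, map_from_entries, struct

              # pass an optional list of strings to avoid computing the columns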
              def pivot(df, group_by, key, aggFunction, levels=[]):
                  if not levels:
                      levels = [row[key] for row in df.filter(col(key).isNotNull()).groupBy(col(key)).agg(count(key)).select(key).collect()]
                  return df.filter(col(key).isin(*levels) == True).groupBy(group_by).agg(map_from_entries(collect_list(struct(key, expr(aggFunction)))).alias("group_map")).select([group_by] + ["group_map." + l for l in levels])
              
              # Usage
              pivot(df, "id", "key", "value")
              pivot(df, "id", "key", "array(value)")
              
              // Scala
              import org.apache.spark.sql.{Column, DataFrame}
              import org.apache.spark.sql.functions._

              // pass an optional list of strings to avoid computing the columns
                def pivot(df: DataFrame, groupBy: Column, key: Column, aggFunct: String, _levels: List[String] = Nil): DataFrame = {
                  val levels =
                    if (_levels.isEmpty) df.filter(key.isNotNull).select(key).distinct().collect().map(row => row.getString(0)).toList
                    else _levels
              
                  df
                    .filter(key.isInCollection(levels))
                    .groupBy(groupBy)
                    .agg(map_from_entries(collect_list(struct(key, expr(aggFunct)))).alias("group_map"))
                    .select(groupBy.toString, levels.map(f => "group_map." + f): _*)
                }
              
              // Usage:
              pivot(df, col("id"), col("key"), "value")
              pivot(df, col("id"), col("key"), "array(value)")
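
The map-based idea can be modeled in a few lines of plain Python, which may make the two steps easier to follow. This is only a sketch of the technique, not Spark code: the function name `pivot_via_maps` is hypothetical, rows are assumed to be dictionaries, and a repeated key within a group simply keeps the last value seen.

```python
def pivot_via_maps(rows, group_by, key, value, levels=None):
    """Plain-Python model of the map-based pivot; not Spark code."""
    if levels is None:
        # Same fallback as the functions above: the distinct non-null keys.
        levels = sorted({row[key] for row in rows if row[key] is not None})
    wanted = set(levels)

    # Step 1: build one map per group from that group's (key, value) entries.
    group_maps = {}
    for row in rows:
        if row[key] in wanted:
            group_maps.setdefault(row[group_by], {})[row[key]] = row[value]

    # Step 2: read each level out of the map as its own column.
    return [
        {group_by: g, **{lvl: m.get(lvl) for lvl in levels}}
        for g, m in group_maps.items()
    ]


rows = [
    {"id": 1, "tag": "US", "value": 50},
    {"id": 1, "tag": "UK", "value": 100},
    {"id": 1, "tag": "Can", "value": 125},
    {"id": 2, "tag": "US", "value": 75},
    {"id": 2, "tag": "UK", "value": 150},
    {"id": 2, "tag": "Can", "value": 175},
]
wide = pivot_via_maps(rows, "id", "tag", "value")
```

Because each group carries a single map instead of one aggregation per output column, the number of pivot values only matters in the final column-extraction step.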
              

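              【Discussion】:

                【Solution 10】:

                Spark has been improving pivoting of the Spark DataFrame. A pivot function was added to the Spark DataFrame API in Spark 1.6; it had a performance issue, which was corrected in Spark 2.0.

                However, if you are using a lower version, note that pivot is a very expensive operation; hence, it is recommended to provide the column data (if known) as an argument to the function, as shown below.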

                val countries = Seq("USA","China","Canada","Mexico")
                val pivotDF = df.groupBy("Product").pivot("Country", countries).sum("Amount")
                pivotDF.show()
                

                This is explained in detail at Pivoting and Unpivoting Spark DataFrame

                Happy learning!!

                【Discussion】:
