【问题标题】:Reshaping/Pivoting data in Spark RDD and/or Spark DataFrames在 Spark RDD 和/或 Spark DataFrames 中重塑/透视数据
【发布时间】:2015-07-27 09:48:04
【问题描述】:

我有一些以下格式的数据(RDD 或 Spark DataFrame):

from pyspark.sql import SQLContext
sqlContext = SQLContext(sc)

 rdd = sc.parallelize([('X01',41,'US',3),
                       ('X01',41,'UK',1),
                       ('X01',41,'CA',2),
                       ('X02',72,'US',4),
                       ('X02',72,'UK',6),
                       ('X02',72,'CA',7),
                       ('X02',72,'XX',8)])

# convert to a Spark DataFrame                    
schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True),
                     StructField('Country', StringType(), True),
                     StructField('Score', IntegerType(), True)])

df = sqlContext.createDataFrame(rdd, schema)

我想做的是“重塑”数据,将 Country(特别是 US、UK 和 CA)中的某些行转换为列:

ID    Age  US  UK  CA  
'X01'  41  3   1   2  
'X02'  72  4   6   7   

基本上,我需要类似于 Python 的 pivot 工作流:

categories = ['US', 'UK', 'CA']
new_df = df[df['Country'].isin(categories)].pivot(index = 'ID', 
                                                  columns = 'Country',
                                                  values = 'Score')

我的数据集相当大,所以我不能真正 collect() 并将数据摄取到内存中以在 Python 本身中进行重塑。有没有办法在映射 RDD 或 Spark DataFrame 时将 Python 的 .pivot() 转换为可调用函数?任何帮助将不胜感激!

【问题讨论】:

    标签: python apache-spark pyspark apache-spark-sql pivot


    【解决方案1】:

    从 Spark 1.6 开始,您可以在 GroupedData 上使用 pivot 函数并提供聚合表达式。

    pivoted = (df
        .groupBy("ID", "Age")
        .pivot(
            "Country",
            ['US', 'UK', 'CA'])  # Optional list of levels
        .sum("Score"))  # alternatively you can use .agg(expr))
    pivoted.show()
    
    ## +---+---+---+---+---+
    ## | ID|Age| US| UK| CA|
    ## +---+---+---+---+---+
    ## |X01| 41|  3|  1|  2|
    ## |X02| 72|  4|  6|  7|
    ## +---+---+---+---+---+
    

    级别可以省略,但如果提供,既可以提高性能,又可以用作内部过滤器。

    这种方法仍然比较慢,但肯定优于在 JVM 和 Python 之间手动传递数据。

    【讨论】:

      【解决方案2】:

      首先,这可能不是一个好主意,因为您没有获得任何额外的信息,但是您将自己绑定到一个固定的架构(即您必须知道您期望有多少个国家,当然,额外的国家意味着代码的变化)

      话虽如此,这是一个SQL问题,如下图所示。但是,如果您认为它不是太“类似软件”(说真的,我听说过!!),那么您可以参考第一个解决方案。

      解决方案一:

      def reshape(t):
          out = []
          out.append(t[0])
          out.append(t[1])
          for v in brc.value:
              if t[2] == v:
                  out.append(t[3])
              else:
                  out.append(0)
          return (out[0],out[1]),(out[2],out[3],out[4],out[5])
      def cntryFilter(t):
          if t[2] in brc.value:
              return t
          else:
              pass
      
      def addtup(t1,t2):
          j=()
          for k,v in enumerate(t1):
              j=j+(t1[k]+t2[k],)
          return j
      
      def seq(tIntrm,tNext):
          return addtup(tIntrm,tNext)
      
      def comb(tP,tF):
          return addtup(tP,tF)
      
      
      countries = ['CA', 'UK', 'US', 'XX']
      brc = sc.broadcast(countries)
      reshaped = calls.filter(cntryFilter).map(reshape)
      pivot = reshaped.aggregateByKey((0,0,0,0),seq,comb,1)
      for i in pivot.collect():
          print i
      

      现在,解决方案 2:当然更好,因为 SQL 是解决此问题的正确工具

      callRow = calls.map(lambda t:   
      
      Row(userid=t[0],age=int(t[1]),country=t[2],nbrCalls=t[3]))
      callsDF = ssc.createDataFrame(callRow)
      callsDF.printSchema()
      callsDF.registerTempTable("calls")
      res = ssc.sql("select userid,age,max(ca),max(uk),max(us),max(xx)\
                          from (select userid,age,\
                                        case when country='CA' then nbrCalls else 0 end ca,\
                                        case when country='UK' then nbrCalls else 0 end uk,\
                                        case when country='US' then nbrCalls else 0 end us,\
                                        case when country='XX' then nbrCalls else 0 end xx \
                                   from calls) x \
                           group by userid,age")
      res.show()
      

      数据设置:

      data=[('X01',41,'US',3),('X01',41,'UK',1),('X01',41,'CA',2),('X02',72,'US',4),('X02',72,'UK',6),('X02',72,'CA',7),('X02',72,'XX',8)]
       calls = sc.parallelize(data,1)
      countries = ['CA', 'UK', 'US', 'XX']
      

      结果:

      从第一个解决方案

      (('X02', 72), (7, 6, 4, 8)) 
      (('X01', 41), (2, 1, 3, 0))
      

      从第二个解决方案:

      root  |-- age: long (nullable = true)  
            |-- country: string (nullable = true)  
            |-- nbrCalls: long (nullable = true)  
            |-- userid: string (nullable = true)
      
      userid age ca uk us xx 
       X02    72  7  6  4  8  
       X01    41  2  1  3  0
      

      请让我知道这是否有效:)

      最好的 绫

      【讨论】:

      • 谢谢..您的解决方案有效,更重要的是它们具有可扩展性!
      • 你能把它扩展到更通用的情况吗?例如,有一次在我的数据中,我可能有 3 个国家。另一次我可能有 5 个。您上面的内容似乎被硬编码为 4 个特定国家/地区。我知道我需要提前知道我有哪些国家,但这可能会随着时间的推移而改变。我怎样才能将国家列表作为参数传递并仍然使其工作?这是处理数据时很常见的事情,所以我希望它很快就会内置在功能中。
      • 正如我所指出的,这是架构设计的问题。您“不能”只传递国家/地区列表,因为您的架构将在下游发生变化。但是,您可能只是从 reshape 返回一个广义元组并为 aggregateByKey 设置零值。在 SQL 方法中,您基本上需要按照此处描述的模式以编程方式“生成”一个 sql。
      • 这是一个非常常见的功能,存在于大多数数据语言/框架中:SAS、Scalding、Pandas 等。希望这能很快融入 Spark。
      • 我根据您上面的回答创建了一个灵活的版本。你可以在这里查看:stackoverflow.com/questions/30244910/pivot-spark-dataframe。我希望 Spark 尽快为此实施解决方案,因为它在大多数其他数据操作语言/工具(Pandas、Scalding、SAS、Excel 等)中是非常基本的功能。
      【解决方案3】:

      这是一种不硬连线列名的原生 Spark 方法。它基于aggregateByKey,并使用字典来收集每个键出现的列。然后我们收集所有列名来创建最终的数据框。 [以前的版本在为每条记录发出字典后使用 jsonRDD,但这更有效。] 限制到特定的列列表,或者排除像 XX 这样的列将是一个简单的修改。

      即使在相当大的桌子上,性能似乎也不错。我正在使用一种变体,它计算每个 ID 发生可变数量事件的次数,为每种事件类型生成一列。代码基本相同,只是它使用 collections.Counter 而不是 seqFn 中的 dict 来计算出现次数。

      from pyspark.sql.types import *
      
      rdd = sc.parallelize([('X01',41,'US',3),
                             ('X01',41,'UK',1),
                             ('X01',41,'CA',2),
                             ('X02',72,'US',4),
                             ('X02',72,'UK',6),
                             ('X02',72,'CA',7),
                             ('X02',72,'XX',8)])
      
      schema = StructType([StructField('ID', StringType(), True),
                           StructField('Age', IntegerType(), True),
                           StructField('Country', StringType(), True),
                           StructField('Score', IntegerType(), True)])
      
      df = sqlCtx.createDataFrame(rdd, schema)
      
      def seqPivot(u, v):
          if not u:
              u = {}
          u[v.Country] = v.Score
          return u
      
      def cmbPivot(u1, u2):
          u1.update(u2)
          return u1
      
      pivot = (
          df
          .rdd
          .keyBy(lambda row: row.ID)
          .aggregateByKey(None, seqPivot, cmbPivot)
      )
      columns = (
          pivot
          .values()
          .map(lambda u: set(u.keys()))
          .reduce(lambda s,t: s.union(t))
      )
      result = sqlCtx.createDataFrame(
          pivot
          .map(lambda (k, u): [k] + [u.get(c) for c in columns]),
          schema=StructType(
              [StructField('ID', StringType())] + 
              [StructField(c, IntegerType()) for c in columns]
          )
      )
      result.show()
      

      生产:

      ID  CA UK US XX  
      X02 7  6  4  8   
      X01 2  1  3  null
      

      【讨论】:

      • 不错的文章 - b.t.w spark 1.6 数据帧支持简单的枢轴github.com/apache/spark/pull/7841
      • 酷 - 火花变得越来越快。
      • 如果重整后的输出太大而无法放入内存怎么办。我怎样才能直接在磁盘上做呢?
      【解决方案4】:

      首先,我必须对您的 RDD 进行此更正(与您的实际输出相匹配):

      rdd = sc.parallelize([('X01',41,'US',3),
                            ('X01',41,'UK',1),
                            ('X01',41,'CA',2),
                            ('X02',72,'US',4),
                            ('X02',72,'UK',6),
                            ('X02',72,'CA',7),
                            ('X02',72,'XX',8)])
      

      一旦我进行了修正,这就成功了:

      df.select($"ID", $"Age").groupBy($"ID").agg($"ID", first($"Age") as "Age")
      .join(
          df.select($"ID" as "usID", $"Country" as "C1",$"Score" as "US"),
          $"ID" === $"usID" and $"C1" === "US"
      )
      .join(
          df.select($"ID" as "ukID", $"Country" as "C2",$"Score" as "UK"),
          $"ID" === $"ukID" and $"C2" === "UK"
      )
      .join(
          df.select($"ID" as "caID", $"Country" as "C3",$"Score" as "CA"), 
          $"ID" === $"caID" and $"C3" === "CA"
      )
      .select($"ID",$"Age",$"US",$"UK",$"CA")
      

      肯定没有你的支点那么优雅。

      【讨论】:

      • 大卫,我无法让它工作。首先,Spark 不接受 $ 作为引用列的方式。删除所有 $ 符号后,我仍然收到指向上述代码最后一行中的 .select 表达式的语法错误
      • 抱歉,我使用的是 Scala。它是直接从 spark-shell 剪切和粘贴的。如果你把最后一个 select() 去掉,你应该得到正确的结果,只是列太多。你能做到这一点并发布结果吗?
      【解决方案5】:

      patricksurry 的非常有帮助的回答只是一些 cmets:

      • 缺少 Age 列,因此只需将 u["Age"] = v.Age 添加到函数 seqPivot
      • 事实证明,列元素上的两个循环都以不同的顺序给出了元素。列的值是正确的,但它们的名称不正确。为避免这种行为,只需对列列表进行排序即可。

      这里是稍微修改的代码:

      from pyspark.sql.types import *
      
      rdd = sc.parallelize([('X01',41,'US',3),
                             ('X01',41,'UK',1),
                             ('X01',41,'CA',2),
                             ('X02',72,'US',4),
                             ('X02',72,'UK',6),
                             ('X02',72,'CA',7),
                             ('X02',72,'XX',8)])
      
      schema = StructType([StructField('ID', StringType(), True),
                           StructField('Age', IntegerType(), True),
                           StructField('Country', StringType(), True),
                           StructField('Score', IntegerType(), True)])
      
      df = sqlCtx.createDataFrame(rdd, schema)
      
      # u is a dictionarie
      # v is a Row
      def seqPivot(u, v):
          if not u:
              u = {}
          u[v.Country] = v.Score
          # In the original posting the Age column was not specified
          u["Age"] = v.Age
          return u
      
      # u1
      # u2
      def cmbPivot(u1, u2):
          u1.update(u2)
          return u1
      
      pivot = (
          rdd
          .map(lambda row: Row(ID=row[0], Age=row[1], Country=row[2],  Score=row[3]))
          .keyBy(lambda row: row.ID)
          .aggregateByKey(None, seqPivot, cmbPivot)
      )
      
      columns = (
          pivot
          .values()
          .map(lambda u: set(u.keys()))
          .reduce(lambda s,t: s.union(t))
      )
      
      columns_ord = sorted(columns)
      
      result = sqlCtx.createDataFrame(
          pivot
          .map(lambda (k, u): [k] + [u.get(c, None) for c in columns_ord]),
              schema=StructType(
                  [StructField('ID', StringType())] + 
                  [StructField(c, IntegerType()) for c in columns_ord]
              )
          )
      
      print result.show()
      

      最后应该是输出

      +---+---+---+---+---+----+
      | ID|Age| CA| UK| US|  XX|
      +---+---+---+---+---+----+
      |X02| 72|  7|  6|  4|   8|
      |X01| 41|  2|  1|  3|null|
      +---+---+---+---+---+----+
      

      【讨论】:

        【解决方案6】:

        Hive 中有一个 JIRA 供 PIVOT 本地执行此操作,而每个值都没有巨大的 CASE 语句:

        https://issues.apache.org/jira/browse/HIVE-3776

        请投票支持 JIRA,以便尽快实施。 一旦在 Hive SQL 中,Spark 通常不会落后太多,最终也会在 Spark 中实现。

        【讨论】:

          猜你喜欢
          • 1970-01-01
          • 2019-10-20
          • 2018-11-22
          • 2020-10-23
          • 2017-11-15
          • 2019-02-21
          • 2019-02-11
          • 1970-01-01
          相关资源
          最近更新 更多