【问题标题】:SumProduct in Spark DataFrameSpark DataFrame 中的 SumProduct
【发布时间】:2019-01-08 00:24:09
【问题描述】:

我想在 Spark DataFrame 中跨列创建一个 sumproduct。我有一个如下所示的 DataFrame:

id    val1   val2   val3   val4
123   10     5      7      5

我还有一张地图,看起来像:

val coefficents = Map("val1" -> 1, "val2" -> 2, "val3" -> 3, "val4" -> 4)

我想获取 DataFrame 每一列中的值,将其乘以映射中的相应值,然后在新列中返回结果,因此本质上是:

(10*1) + (5*2) + (7*3) + (5*4) = 61

我试过这个:

val myDF1 = myDF.withColumn("mySum", {var a:Double = 0.0; for ((k,v) <- coefficients) a + (col(k).cast(DoubleType)*coefficients(k));a})

但出现“+”方法被重载的错误。即使我解决了这个问题,我也不确定这是否可行。有任何想法吗?我总是可以动态地将 SQL 查询构建为文本字符串并以这种方式进行,但我希望能更有说服力。

感谢任何想法。

【问题讨论】:

    标签: scala apache-spark dataframe apache-spark-sql


    【解决方案1】:

    您的代码的问题是您尝试将Column 添加到Doublecast(DoubleType) 只影响一种类型的存储值,而不影响一种类型的列本身。由于Double 没有提供*(x: org.apache.spark.sql.Column): org.apache.spark.sql.Column 方法,所以一切都失败了。

    要使其正常工作,您可以执行以下操作:

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.functions.{col, lit}
    
    val df = sc.parallelize(Seq(
        (123, 10, 5, 7, 5), (456,  1, 1, 1, 1)
    )).toDF("k", "val1", "val2", "val3", "val4")
    
    val coefficients = Map("val1" -> 1, "val2" -> 2, "val3" -> 3, "val4" -> 4)
    
    val dotProduct: Column = coefficients
      // To be explicit you can replace
      // col(k) * v with col(k) * lit(v)
      // but it is not required here
      // since we use * f Column.* method not Int.*
      .map{ case (k, v) => col(k) * v }  // * -> Column.*
      .reduce(_ + _)  // + -> Column.+
    
    df.withColumn("mySum", dotProduct).show
    // +---+----+----+----+----+-----+
    // |  k|val1|val2|val3|val4|mySum|
    // +---+----+----+----+----+-----+
    // |123|  10|   5|   7|   5|   61|
    // |456|   1|   1|   1|   1|   10|
    // +---+----+----+----+----+-----+
    

    【讨论】:

      【解决方案2】:

      看起来问题是你实际上并没有对a做任何事情

      for((k, v) <- coefficients) a + ...
      

      你的意思可能是a += ...


      另外,关于清理withColumn 调用中的代码块的一些建议:

      您无需调用coefficients(k),因为您已经从for((k,v) &lt;- coefficients) 获得了v 中的值

      Scala 非常擅长制作单行代码,但如果您必须在那一行中放置分号,这有点作弊:P 我建议将求和部分拆分为每个表达式一行。

      sum 表达式可以重写为fold,从而避免使用var(惯用的Scala 通常避免使用vars),例如

      import org.apache.spark.sql.functions.lit
      
      coefficients.foldLeft(lit(0.0)){ 
        case (sumSoFar, (k,v)) => col(k).cast(DoubleType) * v + sumSoFar
      }
      

      【讨论】:

        【解决方案3】:

        我不确定这是否可以通过 DataFrame API 实现,因为您只能使用列而不是任何预定义的闭包(例如您的参数映射)。

        我在下面概述了一种使用 DataFrame 的底层 RDD 的方法:

        import org.apache.spark.sql.types._
        import org.apache.spark.sql.Row
        
        // Initializing your input example.
        val df1 = sc.parallelize(Seq((123, 10, 5, 7, 5))).toDF("id", "val1", "val2", "val3", "val4")
        
        // Return column names as an array
        val names = df1.columns
        
        // Grab underlying RDD and zip elements with column names
        val rdd1 = df1.rdd.map(row => (0 until row.length).map(row.getInt(_)).zip(names))
        
        // Tack on accumulated total to the existing row
        val rdd2 = rdd0.map { seq => Row.fromSeq(seq.map(_._1) :+ seq.map { case (value: Int, name: String) => value * coefficents.getOrElse(name, 0) }.sum) }
        
        // Create output schema (with total)
        val totalSchema = StructType(df1.schema.fields :+ StructField("total", IntegerType))
        
        // Apply schema to create output dataframe
        val df2 = sqlContext.createDataFrame(rdd1, totalSchema)
        
        // Show output:
        df2.show()
        ...
        +---+----+----+----+----+-----+
        | id|val1|val2|val3|val4|total|
        +---+----+----+----+----+-----+
        |123|  10|   5|   7|   5|   61|
        +---+----+----+----+----+-----+
        

        【讨论】:

          猜你喜欢
          • 2017-10-18
          • 2018-10-03
          • 1970-01-01
          • 2015-10-11
          • 1970-01-01
          • 1970-01-01
          • 1970-01-01
          • 1970-01-01
          • 1970-01-01
          相关资源
          最近更新 更多