【Question title】:How to avoid automatic cast for ArrayType in Spark (2.4) SQL - Scala 2.11
【Posted】:2020-01-23 13:04:25
【Question】:

Given the following code in Spark 2.4 and Scala 2.11:

val df = spark.sql("""select array(45, "something", 45)""")

If I print the schema with df.printSchema(), I see that Spark automatically casts the integers to string, CAST(45 AS STRING):

root
 |-- array(CAST(45 AS STRING), something, CAST(45 AS STRING)): array (nullable = false)
 |    |-- element: string (containsNull = false)

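I would like to know whether there is a way to avoid this automatic cast and instead have Spark SQL fail with an exception, assuming I call an action such as df.collect() afterwards.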

This is just one example query, but the solution should work for any query.

【Discussion】:

    Tags: scala apache-spark casting apache-spark-sql


    【Solution 1】:

    This creates an "ArrayType" column in the DataFrame.

    From the scaladocs: An ArrayType object comprises two fields, elementType: DataType and containsNull: Boolean. The field of elementType is used to specify the type of array elements. The field of containsNull is used to specify if the array has null values.

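    Therefore, an ArrayType can hold elements of only one type. If values of different types are passed to the array function, Spark first tries to cast them all to the most suitable common type. If the types are completely incompatible, Spark throws an AnalysisException. Examples below: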

    val df = spark.sql("""select array(45, 46L, 45.45)""")
    df.printSchema()
    
    root
     |-- array(CAST(45 AS DECIMAL(22,2)), CAST(46 AS DECIMAL(22,2)), CAST(45.45 AS DECIMAL(22,2))): array (nullable = false)
     |    |-- element: decimal(22,2) (containsNull = false)
    
    df: org.apache.spark.sql.DataFrame = [array(CAST(45 AS DECIMAL(22,2)), CAST(46 AS DECIMAL(22,2)), CAST(45.45 AS DECIMAL(22,2))): array<decimal(22,2)>]
    

    Next, an example that fails:

    val df = spark.sql("""select array(45, 46L, True)""")
    df.printSchema()
    
    org.apache.spark.sql.AnalysisException: cannot resolve 'array(45, 46L, true)' due to data type mismatch: input to function array should all be the same type, but it's [int, bigint, boolean]; line 1 pos 7;
    'Project [unresolvedalias(array(45, 46, true), None)]
    +- OneRowRelation
    
        at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
        at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$3.applyOrElse(CheckAnalysis.scala:126)
        at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$3.applyOrElse(CheckAnalysis.scala:111)
        at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$6.apply(TreeNode.scala:304)
        at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$6.apply(TreeNode.scala:304)
        at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:77)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:303)
        at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:301)
        at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:301)
        at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$8.apply(TreeNode.scala:354)
        at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:208)
        at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:352)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:301)
        at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$transformExpressionsUp$1.apply(QueryPlan.scala:94)
        at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$transformExpressionsUp$1.apply(QueryPlan.scala:94)
        at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$3.apply(QueryPlan.scala:106)
        at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$3.apply(QueryPlan.scala:106)
        at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:77)
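To make the rule described above more concrete, here is a toy model in plain Scala with no Spark dependency. The type names (SqlInt, SqlLong, SqlDecimal, SqlString, SqlBoolean) and the functions widerOfTwo and commonType are hypothetical; the model only approximates a few of Spark's coercion behaviours (integral widening, decimal widening, promotion to string, no promotion for boolean) and is a simplified sketch, not Spark's implementation.

```scala
// Toy model (hypothetical, not Spark code) of how array() picks a single element type.
object CoercionModel {
  // Simplified stand-ins for a few Spark SQL types.
  sealed trait SqlType
  case object SqlInt extends SqlType
  case object SqlLong extends SqlType
  final case class SqlDecimal(precision: Int, scale: Int) extends SqlType
  case object SqlString extends SqlType
  case object SqlBoolean extends SqlType

  // View numeric types as decimals: int fits in decimal(10,0), bigint in decimal(20,0).
  private def asDecimal(t: SqlType): Option[SqlDecimal] = t match {
    case SqlInt        => Some(SqlDecimal(10, 0))
    case SqlLong       => Some(SqlDecimal(20, 0))
    case d: SqlDecimal => Some(d)
    case _             => None
  }

  // Wider type for one pair of types, or None if the pair cannot be unified.
  def widerOfTwo(a: SqlType, b: SqlType): Option[SqlType] = (a, b) match {
    case (x, y) if x == y                      => Some(x)
    case (SqlInt, SqlLong) | (SqlLong, SqlInt) => Some(SqlLong)
    case (SqlBoolean, _) | (_, SqlBoolean)     => None            // boolean never widens
    case (SqlString, _) | (_, SqlString)       => Some(SqlString) // promotion to string
    case _ =>
      // Decimal widening: keep the larger scale and the larger integral part, capped at 38 digits.
      for {
        d1 <- asDecimal(a)
        d2 <- asDecimal(b)
      } yield {
        val scale    = math.max(d1.scale, d2.scale)
        val integral = math.max(d1.precision - d1.scale, d2.precision - d2.scale)
        SqlDecimal(math.min(integral + scale, 38), scale)
      }
  }

  // Fold the pairwise rule over all arguments; None stands for "analysis error".
  def commonType(types: Seq[SqlType]): Option[SqlType] =
    if (types.isEmpty) None
    else types.tail.foldLeft(Option(types.head))((acc, t) => acc.flatMap(widerOfTwo(_, t)))
}
```

With this model, commonType(Seq(SqlInt, SqlString, SqlInt)) gives Some(SqlString), as in the question; commonType(Seq(SqlInt, SqlLong, SqlDecimal(4, 2))) gives Some(SqlDecimal(22, 2)), matching the decimal(22,2) schema above; and commonType(Seq(SqlInt, SqlLong, SqlBoolean)) gives None, which corresponds to the AnalysisException above.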
    

    【Discussion】:

      【Solution 2】:

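      I assume you are creating an array from columns of some DataFrame. In that case, you can check in that DataFrame's schema that the input columns are of type StringType. In Scala, it would look like this: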

      import org.apache.spark.sql.DataFrame
      import org.apache.spark.sql.types.DataTypes
      import spark.implicits._ // enables the 'id syntax (imported automatically in spark-shell)

      // some dataframe with a long and a string
      val df = spark.range(3).select('id, 'id cast "string" as "id_str")

      // a function that throws if any of the provided columns is not a string
      def check_df(df: DataFrame, cols: Seq[String]): Unit = {
          val non_string_column = df
              .schema
              .find(field => cols.contains(field.name) &&
                             field.dataType != DataTypes.StringType)
          if (non_string_column.isDefined)
              throw new Error(s"The column ${non_string_column.get.name} has type " +
                              s"${non_string_column.get.dataType} instead of StringType")
      }

      Then let's try it:

      scala> check_df(df, Seq("id", "id_str"))
      java.lang.Error: The column id has type LongType instead of StringType
        at check_df(<console>:36)
        ... 50 elided
      
      scala> check_df(df, Seq("id_str"))
      // no exception
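The schema check above only covers arrays built from DataFrame columns, while the question asks for something that works for any query. One possible workaround, not taken from either answer and offered only as a heuristic sketch, is to compare the plan Spark parsed with the plan after analysis: explicit CASTs appear in both, so any extra Cast nodes were most likely inserted by the analyzer's type coercion. The object ImplicitCastCheck and its methods are hypothetical names; the sketch relies on Catalyst internals (Dataset.queryExecution, QueryExecution.logical/analyzed and the catalyst Cast expression), not on a Spark configuration option. It is intended for DataFrames returned directly by spark.sql: with DataFrame API chains, each step's parsed plan already embeds the analyzed parent plan, so casts from earlier steps may be missed, and casts inserted by other analyzer rules would also be reported.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Hypothetical helper: fails fast when the analyzer added casts the query did not ask for.
object ImplicitCastCheck {
  // Counts Cast expressions (including nested ones) across every node of a logical plan.
  private def countCasts(plan: LogicalPlan): Int =
    plan.flatMap(_.expressions).map(_.collect { case c: Cast => c }.size).sum

  // Heuristic: explicit CASTs are already in the parsed plan, so only the surplus counts as implicit.
  def requireNoImplicitCasts(df: DataFrame): DataFrame = {
    val explicitCasts = countCasts(df.queryExecution.logical)
    val allCasts      = countCasts(df.queryExecution.analyzed)
    if (allCasts > explicitCasts)
      throw new IllegalArgumentException(
        s"Analyzer inserted ${allCasts - explicitCasts} implicit cast(s):\n" +
          df.queryExecution.analyzed.treeString)
    df
  }
}
```

Wrapping the query as ImplicitCastCheck.requireNoImplicitCasts(spark.sql("""select array(45, "something", 45)""")) should throw before any action such as collect() runs, whereas a query whose casts are all written explicitly passes through unchanged.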
      

      【Discussion】:
