【问题标题】:Spark ML Pipeline throws exception for Random Forest classification: Column label must be of type DoubleType but was actually IntegerTypeSpark ML Pipeline 为随机森林分类抛出异常:列标签必须是 DoubleType 类型,但实际上是 IntegerType
【发布时间】:2016-08-05 00:25:52
【问题描述】:

我正在尝试使用随机森林分类器创建一个 Spark ML 管道来执行分类(不是回归),但是我收到一个错误消息,指出我的训练集中的预测标签应该是双精度而不是整数。我正在遵循这些页面的说明:

我有一个包含以下列的 Spark 数据框:

scala> df.show(5)
+-------+----------+----------+---------+-----+
| userId|duration60|duration30|duration1|label|
+-------+----------+----------+---------+-----+
|user000|        11|        21|       35|    3|
|user001|        28|        41|       28|    4|
|user002|        17|         6|        8|    2|
|user003|        39|        29|        0|    1|
|user004|        26|        23|       25|    3|
+-------+----------+----------+---------+-----+


scala> df.printSchema()
root
 |-- userId: string (nullable = true)
 |-- duration60: integer (nullable = true)
 |-- duration30: integer (nullable = true)
 |-- duration1: integer (nullable = true)
 |-- label: integer (nullable = true)

我正在使用特征列 duration60、duration30 和 duration1 来预测分类列标签。

然后我像这样设置我的 Spark 脚本:

import org.apache.log4j.Logger
import org.apache.log4j.Level
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.{Pipeline, PipelineModel}


Logger.getLogger("org").setLevel(Level.ERROR)
Logger.getLogger("akka").setLevel(Level.ERROR)

val sqlContext = new SQLContext(sc)
val df = sqlContext.read.
    format("com.databricks.spark.csv").
    option("header", "true"). // Use first line of all files as header
    option("inferSchema", "true"). // Automatically infer data types
    load("/tmp/features.csv").
    withColumnRenamed("satisfaction", "label").
    select("userId", "duration60", "duration30", "duration1", "label")

val assembler = new VectorAssembler().
    setInputCols(Array("duration60", "duration30", "duration1")).
    setOutputCol("features")


val randomForest = new RandomForestClassifier().
    setLabelCol("label").
    setFeaturesCol("features").
    setNumTrees(10)

var pipeline = new Pipeline().setStages(Array(assembler, randomForest))

var model = pipeline.fit(df);

转换后的数据框如下:

scala> assembler.transform(df).show(5)
+-------+----------+----------+---------+-----+----------------+
| userId|duration60|duration30|duration1|label|        features|
+-------+----------+----------+---------+-----+----------------+
|user000|        11|        21|       35|    3|[11.0,21.0,35.0]|
|user001|        28|        41|       28|    4|[28.0,41.0,28.0]|
|user002|        17|         6|        8|    2|  [17.0,6.0,8.0]|
|user003|        39|        29|        0|    1| [39.0,29.0,0.0]|
|user004|        26|        23|       25|    3|[26.0,23.0,25.0]|
+-------+----------+----------+---------+-----+----------------+

但是最后一行抛出异常:

java.lang.IllegalArgumentException:要求失败:列标签 必须是 DoubleType 类型,但实际上是 IntegerType。

这是什么意思,我该如何解决?

为什么label 列必须是双精度的?我在做预测,而不是回归,所以我认为字符串或整数是合适的。预测列的双精度值通常意味着回归。

【问题讨论】:

    标签: scala apache-spark apache-spark-ml


    【解决方案1】:

    如果您使用 pyspark 并遇到同样的问题

    from pyspark.ml.feature import StringIndexer
       stringIndexer = StringIndexer(inputCol="label", outputCol="newlabel")
       model = stringIndexer.fit(df)
       df = model.transform(df)
       df.printSchema()
    

    这是将标签列转换为“双”类型的一种方法。

    【讨论】:

    • 我投了反对票,因为您的回答具有误导性 (other methods don't seem to work.)。当您有数字标签时,其他答案确实有效!
    • 对不起,我的错,相应地编辑它。谢谢指出
    • 这仍然不准确。您可能要指出的是,如果您有非数字标签,StringIndexer converts 将它们转换为所需的格式,这不是 cast
    【解决方案2】:

    在 pyspark 中

    from pyspark.sql.types import DoubleType
    df = df.withColumn("label", df.label.cast(DoubleType()))
    

    【讨论】:

      【解决方案3】:

      执行cast DoubleType,因为这是算法期望的类型。

      import org.apache.spark.sql.types._
      df.withColumn("label", 'label cast DoubleType)
      

      所以,就在您在应用程序中 val df 之前,在序列的最后一行进行转换:

      import org.apache.spark.sql.types._
      val df = sqlContext.read.
          format("com.databricks.spark.csv").
          option("header", "true"). // Use first line of all files as header
          option("inferSchema", "true"). // Automatically infer data types
          load("/tmp/features.csv").
          withColumnRenamed("satisfaction", "label").
          select("userId", "duration60", "duration30", "duration1", "label")
          .withColumn("label", 'label cast DoubleType) // <-- HERE
      

      请注意,我使用'label 符号(单引号' 后跟名称)来引用列label(我可能也使用$"label"col("label") 或@987654331 @ 或 column("label"))。

      【讨论】:

      • 谢谢。现在我收到一个错误:RandomForestClassifier was given input with invalid label column label, without the number of classes specified. See StringIndexer. 这是什么意思?如果我在进行分类,为什么标签应该是双重的?通常连续因变量用于回归,而不是分类。
      • @stackoverflowuser2010 手动转换的列缺少 ML 估计器工作所需的元数据。您必须手动添加它。参见例如stackoverflow.com/q/36517302/1560062
      猜你喜欢
      • 2018-10-05
      • 2015-10-24
      • 2020-01-31
      • 2016-09-30
      • 1970-01-01
      • 2015-09-22
      • 2018-09-18
      • 2019-09-05
      • 2013-09-22
      相关资源
      最近更新 更多