【问题标题】:Spark - How to use QuantileDiscretizer with RandomForestClassifierSpark - 如何将 QuantileDiscretizer 与 RandomForestClassifier 一起使用
【发布时间】:2018-05-31 03:45:50
【问题描述】:

是否可以将 QuantileDiscretizerkeeping NaN 值与 RandomForestClassifier 一起使用?

我遇到了这样的错误:

18/03/23 17:38:15 ERROR Executor: Exception in task 3.0 in stage 133.0 (TID 381)
java.lang.IllegalArgumentException: DecisionTree given invalid data: Feature 1 is categorical with values in {0,...,1, but a data point gives it value 2.0.
  Bad data point: (1.0,[1.0,2.0])

示例

这里的想法是创建一个数字列并使用分位数对其进行离散化,将无效数字 (NaN) 保存在一个特殊的桶中。

import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler,
  QuantileDiscretizer}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{RandomForestClassifier}

val tseq = Seq((0, "a", 1.0), (1, "b", 0.0), (2, "c", 2.0),
               (3, "a", 1.0), (4, "a", 3.0), (5, "c", Double.NaN))
val tdf = SparkInit.ss.createDataFrame(tseq).toDF("id", "category", "class")
val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
val discr = new QuantileDiscretizer()
  .setInputCol("class")
  .setOutputCol("quant")
  .setNumBuckets(2)
  .setHandleInvalid("keep")
val assembler = new VectorAssembler()
  .setInputCols(Array("categoryIndex", "quant"))
  .setOutputCol("features")
val rf = new RandomForestClassifier()
  .setLabelCol("categoryIndex")
  .setFeaturesCol("features")
  .setNumTrees(3)
new Pipeline()
  .setStages(Array(indexer, discr, assembler, rf))
  .fit(tdf)
  .transform(tdf)
  .show()

没有尝试适应随机森林,我得到了这样的 DataFrame:

+---+--------+-----+-------------+-----+---------+
| id|category|class|categoryIndex|quant| features|
+---+--------+-----+-------------+-----+---------+
|  0|       a|  1.0|          0.0|  1.0|[0.0,1.0]|
|  1|       b|  0.0|          2.0|  0.0|[2.0,0.0]|
|  2|       c|  2.0|          1.0|  1.0|[1.0,1.0]|
|  3|       a|  1.0|          0.0|  1.0|[0.0,1.0]|
|  4|       a|  3.0|          0.0|  1.0|[0.0,1.0]|
|  5|       c|  NaN|          1.0|  2.0|[1.0,2.0]|
+---+--------+-----+-------------+-----+---------+

如果我尝试拟合模型,我会收到错误:

18/03/23 17:54:12 WARN DecisionTreeMetadata: DecisionTree reducing maxBins from 32 to 6 (= number of training instances)
18/03/23 17:54:12 WARN BlockManager: Putting block rdd_490_3 failed due to an exception
18/03/23 17:54:12 WARN BlockManager: Block rdd_490_3 could not be removed as it was not found on disk or in memory
18/03/23 17:54:12 ERROR Executor: Exception in task 3.0 in stage 143.0 (TID 414)
java.lang.IllegalArgumentException: DecisionTree given invalid data: Feature 1 is categorical with values in {0,...,1, but a data point gives it value 2.0.
  Bad data point: (1.0,[1.0,2.0])
    at org.apache.spark.ml.tree.impl.TreePoint$.findBin(TreePoint.scala:124)
    at org.apache.spark.ml.tree.impl.TreePoint$.org$apache$spark$ml$tree$impl$TreePoint$$labeledPointToTreePoint(TreePoint.scala:93)
    at org.apache.spark.ml.tree.impl.TreePoint$$anonfun$convertToTreeRDD$2.apply(TreePoint.scala:73)
    at org.apache.spark.ml.tree.impl.TreePoint$$anonfun$convertToTreeRDD$2.apply(TreePoint.scala:72)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:410)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:410)
    at org.apache.spark.storage.memory.MemoryStore.putIteratorAsValues(MemoryStore.scala:216)

QuantileDiscretizer 是否会插入一些关于特殊额外存储桶的元数据?奇怪的是,我之前能够使用具有相同值的列来构建模型,但没有强制任何离散化。

更新

是的,列确实有附加元数据,它看起来像这样:

org.apache.spark.sql.types.Metadata = {"ml_attr":
   {"ord":true,
    "vals":["-Infinity, 5.0","5.0, 10.0","10.0, Infinity"],
    "type":"nominal"}
}

现在的问题可能是:如何正确设置元数据以包含 Double.NaN 之类的值?

【问题讨论】:

  • 你解决了这个问题了吗?我在这里遇到了同样的问题,这是我谷歌后唯一相关的帖子。
  • @SpiritZhang,我找到了适合我的模型的解决方法。我来描述一下。

标签: scala apache-spark apache-spark-ml


【解决方案1】:

我使用的解决方法只是从离散列中删除关联的元数据,让决策树实现来决定如何处理数据。我认为该列实际上会变成一个数值列(例如[0, 1, 2, 2, 1]),但是,如果创建的类别过多,该列可能会再次离散化(查找参数maxBins)。

在我的例子中,删除元数据的最简单方法是在应用 QuantileDiscretizerfill DataFrame:

// Nothing is actually filled in my case, since there was no missing
// values before this operation.
df.na.fill(Double.NaN, Array("quant"))

我几乎可以肯定您也可以手动删除直接访问列对象的元数据。

更新

我们可以通过创建别名 (reference) 来更改列的元数据:

val metadata: Metadata = ...
df.select($"colA".as("colB", metadata))

This answer 描述了一种通过获取 DataFrame 架构的相应 StructField 来获取列元数据的方法。

【讨论】:

  • 我在我的 pyspark 代码上尝试了这种方法,但没有奏效。您能否解释一下为什么填充数据框会删除其元数据?
  • @SpiritZhang,您是否在上面的Array() 中添加了所有离散化列?我没有检查Spark的源代码中的确切代码行,但是我在操作后手动检查了我的列的元数据。我相信fill 会自动删除与相应列关联的所有元数据。
猜你喜欢
  • 2017-04-16
  • 2016-04-01
  • 2022-07-30
  • 1970-01-01
  • 2015-08-03
  • 2020-04-18
  • 1970-01-01
  • 2015-07-15
  • 1970-01-01
相关资源
最近更新 更多