【问题标题】:How to save RandomForestClassifier Spark model in scala?如何在scala中保存RandomForestClassifier Spark模型?
【发布时间】:2016-09-15 01:22:33
【问题描述】:

我使用以下代码构建了一个随机森林模型:

import org.apache.spark.ml.classification.RandomForestClassificationModel
import org.apache.spark.ml.classification.RandomForestClassifier
val rf = new RandomForestClassifier().setLabelCol("indexedLabel").setFeaturesCol("features")
val labelConverter = new    IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)
val training = labelIndexer.transform(df)
val model = rf.fit(training)

现在我想保存模型以便以后使用以下代码进行预测:

val predictions: DataFrame = model.transform(testData)

我查看了 Spark 文档 here 并没有找到任何选项。任何想法? 我花了几个小时来构建模型,如果 Spark 崩溃了,我将无法恢复它。

【问题讨论】:

标签: scala apache-spark apache-spark-mllib


【解决方案1】:

可以使用 Spark 1.6 对基于管道的模型和基本模型使用 saveAsObjectFile() 在 HDFS 中保存和重新加载基于树的模型。 以下是基于管道的模型的示例。

// model
val model = pipeline.fit(trainingData)

// Create rdd using Seq 
sc.parallelize(Seq(model), 1).saveAsObjectFile("hdfs://filepath")

// Reload model by using it's class
// You can get class of object using object.getClass()
val sameModel = sc.objectFile[PipelineModel]("filepath").first()

【讨论】:

  • 很好的方法可以解决缺乏对直接保存模型的支持。
【解决方案2】:

对于 RandomForestClassifier 保存和加载模型:在 ml 中测试 spark 1.6.2 + scala(在 spark 2.0 中,您可以有模型的直接保存选项)

import org.apache.spark.ml.classification.RandomForestClassificationModel
import org.apache.spark.ml.classification.RandomForestClassifier //imports
val classifier = new RandomForestClassifier().setImpurity("gini").setMaxDepth(3).setNumTrees(20).setFeatureSubsetStrategy("auto").setSeed(5043)
val model = classifier.fit(trainingData)

sc.parallelize(Seq(model), 1).saveAsObjectFile(modelSavePath) //保存模型

val linRegModel = sc.objectFile[RandomForestClassificationModel](modelSavePath).first() //load model
`val predictions1 = linRegModel.transform(testData)` //predictions1  is dataframe 

【讨论】:

    【解决方案3】:

    它位于 MLWriter 接口中 - 可通过模型上的 writer 属性访问:

    model.asInstanceOf[MLWritable].write.save(path)
    

    界面如下:

    abstract class MLWriter extends BaseReadWrite with Logging {
    
      protected var shouldOverwrite: Boolean = false
    
      /**
       * Saves the ML instances to the input path.
       */
      @Since("1.6.0")
      @throws[IOException]("If the input path already exists but overwrite is not enabled.")
      def save(path: String): Unit = {
    

    这是对早期版本mllib/spark.ml的重构

    更新模型似乎不可可写:

    线程“主”java.lang.UnsupportedOperationException 中的异常: 此管道上的管道写入将失败,因为它包含一个阶段 它不实现可写。不可写阶段: rfc_4e467607406f 类型类 org.apache.spark.ml.classification.RandomForestClassificationModel

    因此可能没有直接的解决方案。

    【讨论】:

    • 此代码不起作用。错误:值编写器不是 org.apache.spark.ml.classification.RandomForestClassificationModel 的成员
    • @Yaeli778 这个功能——正如你在我的回答中看到的——@Since("1.6.0")——需要最新版本的 Spark。如果您使用的是 1.5.X 或更早版本 - 那么您将不会拥有它。可以升级到 1.6.X 吗?
    • 我在 1.6.1 版本上运行
    • @Yaeli778 似乎有必要添加这个: .asInstanceOf[MLWritable] 。 OP 已更新以反映它。
    • 它也不起作用,得到以下错误:java.lang.ClassCastException: org.apache.spark.ml.classification.RandomForestClassificationModel cannot be cast to org.apache.spark.ml.util .M LWritable
    【解决方案4】:

    这是一个 PySpark v1.6 实现,对应于上面的 Scala saveAsObjectFile() 答案。

    它通过 saveAsObjectFile() 将 Python 对象强制转换为 Java 对象以实现序列化。

    如果没有 Java 强制,我在序列化时遇到了奇怪的 Py4J 错误。如果有人有更简单的实现,请编辑或评论。

    保存经过训练的RandomForestClassificationModel 对象:

    # Save RandomForestClassificationModel to hdfs
    gateway = sc._gateway
    java_list = gateway.jvm.java.util.ArrayList()
    java_list.add(rfModel._java_obj)
    modelRdd = sc._jsc.parallelize(java_list)
    modelRdd.saveAsObjectFile("hdfs:///some/path/rfModel")
    

    加载经过训练的RandomForestClassificationModel 对象:

    # Load RandomForestClassificationModel from hdfs
    rfObjectFileLoaded = sc._jsc.objectFile("hdfs:///some/path/rfModel")
    rfModelLoaded_JavaObject = rfObjectFileLoaded.first()
    rfModelLoaded = RandomForestClassificationModel(rfModelLoaded_JavaObject)
    predictions = rfModelLoaded.transform(test_input_df)
    

    【讨论】:

      猜你喜欢
      • 2019-12-06
      • 1970-01-01
      • 1970-01-01
      • 2016-04-09
      • 1970-01-01
      • 2016-03-20
      • 2016-10-12
      • 2019-03-13
      • 2020-05-20
      相关资源
      最近更新 更多