【发布时间】:2018-09-23 01:05:52
【问题描述】:
scala> spark.version
res8: String = 2.2.0
我正在使用包含 locationID 列的 spark Dataframe。我创建了一个 MLlib 管道来构建线性回归模型,当我为单个 locationID 提供数据时它就可以工作。我现在想为每个“locationID”创建许多模型(生产中可能有几千个 locationID)。我想保存每个模型的模型系数。
我不确定如何在 Scala 中做到这一点。
我的管道是这样定义的:
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.regression.LinearRegressionModel
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql
// Load the regression input data
val mydata = spark.read.format("csv")
.option("header", "true")
.option("inferSchema", "true")
.load("./inputdata.csv")
// Crate month one hot encoding
val monthIndexer = new StringIndexer()
.setInputCol("month")
.setOutputCol("monthIndex").fit(mydata)
val monthEncoder = new OneHotEncoder()
.setInputCol(monthIndexer.getOutputCol)
.setOutputCol("monthVec")
val assembler = new VectorAssembler()
.setInputCols(Array("monthVec","tran_adr"))
.setOutputCol("features")
val lr = new LinearRegression()
.setMaxIter(10)
.setRegParam(0.3)
.setElasticNetParam(0.8)
val pipeline = new Pipeline()
.setStages(Array(monthIndexer, monthEncoder, assembler, lr))
// Fit using the model pipeline
val myPipelineModel = pipeline.fit(mydata)
然后我可以像这样提取模型细节:
val modelExtract = myPipelineModel.stages(3).asInstanceOf[LinearRegressionModel]
println(s"Coefficients: ${modelExtract.coefficients} Intercept: ${modelExtract.intercept}")
// Summarize the model over the training set and print out some metrics
val trainingSummary = modelExtract.summary
println(s"numIterations: ${trainingSummary.totalIterations}")
println(s"objectiveHistory: [${trainingSummary.objectiveHistory.mkString(",")}]")
trainingSummary.residuals.show()
println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
println(s"r2: ${trainingSummary.r2}")
现在我想对mydata 中的列locationID 进行分组,并在数据的每个分区上运行管道。
我尝试过使用 groupby,但我只能聚合。
val grouped = mydata.groupBy("locationID")
我还尝试将唯一的 locationID 拉为一个列表并循环遍历它:
val locationList = mydata.select(mydata("prop_code")).distinct
locationList.foreach { printLn }
我知道 spark 不适合创建许多较小的模型,它最适合在大量数据上创建一个模型,但我的任务是这样做作为概念证明。
在 spark 中做这样的事情的正确方法是什么?
【问题讨论】:
标签: scala apache-spark spark-dataframe apache-spark-mllib