最近在研究机器学习,使用的工具是Spark。本文针对Spark最新版本1.6.0的MLlib中的logistic regression和linear regression进行源码分析,其理论部分参考:http://www.cnblogs.com/ljy2013/p/5129610.html
下面跟随我的demo一步一步剖析源码,首先来看一下这个demo:
package org.apache.spark.mllib.classification

import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.{ LogisticRegressionWithLBFGS, LogisticRegressionModel }
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.SparkConf

/**
 * Demo driver for MLlib logistic regression.
 *
 * Loads a LIBSVM-formatted data set, splits it 60/40 into training and test
 * sets, trains a multiclass model with `LogisticRegressionWithLBFGS`, prints
 * the precision measured on the test split, and finally saves the model and
 * loads it back.
 */
object MyLogisticRegression {

  /**
   * Runs the demo end to end on a local Spark master.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {

    val conf = new SparkConf().setAppName("Simple Application").setMaster("local[*]")
    val sc = new SparkContext(conf)

    // Always release the SparkContext, even when the job fails part-way.
    try {
      // Load training data in LIBSVM format:
      //   <label> <index1>:<value1> <index2>:<value2> ...
      // Feature indices in the file are one-based.
      val data = MLUtils.loadLibSVMFile(sc, "D:\\MyFile\\wine.txt")

      // Split data into training (60%) and test (40%); the fixed seed keeps
      // the split reproducible between runs.
      val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
      val training = splits(0).cache()
      val test = splits(1)

      // Run the training algorithm to build the model.
      // NOTE(review): 10 is the number of classes the model is configured for;
      // confirm it matches the label range present in wine.txt.
      val model = new LogisticRegressionWithLBFGS()
        .setNumClasses(10) // set the number of classes
        .run(training)

      // Pair every test point's predicted label with its true label.
      val predictionAndLabels = test.map {
        case LabeledPoint(label, features) =>
          val prediction = model.predict(features)
          (prediction, label)
      }

      // Get evaluation metrics.
      val metrics = new MulticlassMetrics(predictionAndLabels)
      val precision = metrics.precision
      println(s"Precision = $precision")

      // Save and load model.
      model.save(sc, "myModelPath")
      val sameModel = LogisticRegressionModel.load(sc, "myModelPath")
    } finally {
      sc.stop()
    }

  }
}