【问题标题】:how to use the linear regression of MLlib of apache spark?如何使用apache spark的MLlib的线性回归?
【发布时间】:2014-07-19 11:20:27
【问题描述】:

我是 Apache Spark 的新手。在 MLlib 的文档中我找到了一个 Scala 示例,但我并不熟悉 Scala。有人知道对应的 Java 示例吗?谢谢!示例代码如下:

import org.apache.spark.mllib.regression.LinearRegressionWithSGD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors

// Load and parse the data.
// Each line has the form "label,f1 f2 ... fn": a comma separates the label
// from the space-separated feature values. Since Spark 1.0, LabeledPoint
// takes an MLlib Vector, so the features are wrapped with Vectors.dense.
val data = sc.textFile("mllib/data/ridge-data/lpsa.data")
val parsedData = data.map { line =>
  val parts = line.split(',')
  LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
}.cache() // SGD makes one pass over the data per iteration, so keep it in memory.

// Building the model
val numIterations = 20
val model = LinearRegressionWithSGD.train(parsedData, numIterations)

// Evaluate model on training examples and compute training error
val valuesAndPreds = parsedData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
// mean() replaces the manual sum / count and yields NaN instead of throwing on empty input.
val MSE = valuesAndPreds.map { case (v, p) => math.pow(v - p, 2) }.mean()
println(s"training Mean Squared Error = $MSE")

(以上代码来自 MLlib 官方文档。)谢谢!

【问题讨论】:

    标签: java apache-spark apache-spark-mllib


    【解决方案1】:

    如文档所示:

    MLlib 的所有方法都使用 Java 友好的类型,因此可以像在 Scala 中一样导入并调用它们。唯一需要注意的是,这些方法接受 Scala RDD 对象,而 Spark Java API 使用单独的 JavaRDD 类。在 JavaRDD 对象上调用 .rdd() 即可将 Java RDD 转换为 Scala RDD。

    这并不省事,因为你仍然需要用 Java 重写 Scala 代码,但它是可行的(至少在这个例子中如此)。

    话虽如此,这是一个java实现:

    /**
     * Java port of the MLlib linear-regression example: trains a
     * LinearRegressionWithSGD model on "mllib/data/ridge-data/lpsa.data" and
     * prints the mean squared error on the training set.
     *
     * Each input line is expected to look like "label,f1 f2 ... fn".
     * The Spark context is always stopped, even if training fails.
     */
    public void linReg() {
        String master = "local";
        SparkConf conf = new SparkConf().setAppName("csvParser").setMaster(master);
        JavaSparkContext sc = new JavaSparkContext(conf);
        try {
            JavaRDD<String> data = sc.textFile("mllib/data/ridge-data/lpsa.data");
            // Spark's Java Function interfaces accept Java 8 lambdas, so no anonymous class is needed.
            // The result is cached because SGD makes one pass over the data per iteration.
            JavaRDD<LabeledPoint> parseddata = data.map(line -> {
                String[] parts = line.split(",");
                String[] pointsStr = parts[1].split(" ");
                double[] points = new double[pointsStr.length];
                for (int i = 0; i < pointsStr.length; i++) {
                    points[i] = Double.parseDouble(pointsStr[i]);
                }
                return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(points));
            }).cache();

            // Building the model; MLlib's trainers take a Scala RDD, hence .rdd().
            int numIterations = 20;
            LinearRegressionModel model = LinearRegressionWithSGD.train(
                    parseddata.rdd(), numIterations);

            // Evaluate model on training examples; the Tuple2 must be created explicitly in Java.
            JavaRDD<Tuple2<Double, Double>> valuesAndPred = parseddata.map(
                    point -> new Tuple2<Double, Double>(point.label(),
                            model.predict(point.features())));

            // mapToDouble gives a JavaDoubleRDD, whose mean() computes the average directly.
            double MSE = valuesAndPred.mapToDouble(
                    tuple -> Math.pow(tuple._1() - tuple._2(), 2)).mean();
            System.out.println("training Mean Squared Error = " + MSE);
        } finally {
            sc.stop();
        }
    }
    

    它远非完美,但希望它能帮助你理解如何把 MLlib 文档中的 Scala 示例改写成 Java。

    【讨论】:

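      【解决方案2】:
      (注:下面是 Spark 自带的 org.apache.spark.examples.JavaHdfsLR 示例。它手工实现的是逻辑回归,并未使用 MLlib 的 LinearRegressionWithSGD;如需线性回归,请参考解决方案1。)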
      package org.apache.spark.examples;
      
      import org.apache.spark.SparkConf;
      import org.apache.spark.api.java.JavaRDD;
      import org.apache.spark.api.java.JavaSparkContext;
      import org.apache.spark.api.java.function.Function;
      import org.apache.spark.api.java.function.Function2;
      
      import java.io.Serializable;
      import java.util.Arrays;
      import java.util.Random;
      import java.util.regex.Pattern;
      
      /**
       * Logistic regression based classification.
       *
       * This is an example implementation for learning how to use Spark. For more conventional use,
       * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
       * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs.
       *
       * <p>Input: one point per line, "y x1 x2 ... xD" separated by single spaces, with D = 10.
       * The gradient formula in {@link ComputeGradient} assumes labels y are -1 or +1. Weights are
       * updated by full-batch gradient descent with a fixed step size of 1.
       */
      public final class JavaHdfsLR {

        private static final int D = 10;   // Number of dimensions
        private static final Random rand = new Random(42);

        /** Prints a notice to stderr that this is a teaching example, not the MLlib implementation. */
        static void showWarning() {
          String warning = "WARN: This is a naive implementation of Logistic Regression " +
                  "and is given as an example!\n" +
                  "Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD " +
                  "or org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS " +
                  "for more conventional use.";
          System.err.println(warning);
        }

        /** One training example: feature vector x (length D) and label y. */
        static class DataPoint implements Serializable {
          DataPoint(double[] x, double y) {
            this.x = x;
            this.y = y;
          }

          double[] x;
          double y;
        }

        /**
         * Parses a line "y x1 ... xD" into a DataPoint. Throws (NumberFormatException or
         * ArrayIndexOutOfBoundsException) on malformed lines or lines with fewer than D features.
         */
        static class ParsePoint implements Function<String, DataPoint> {
          private static final Pattern SPACE = Pattern.compile(" ");

          @Override
          public DataPoint call(String line) {
            String[] tok = SPACE.split(line);
            double y = Double.parseDouble(tok[0]);
            double[] x = new double[D];
            for (int i = 0; i < D; i++) {
              x[i] = Double.parseDouble(tok[i + 1]);
            }
            return new DataPoint(x, y);
          }
        }

        /** Element-wise sum of two length-D vectors; used to reduce per-point gradients. */
        static class VectorSum implements Function2<double[], double[], double[]> {
          @Override
          public double[] call(double[] a, double[] b) {
            double[] result = new double[D];
            for (int j = 0; j < D; j++) {
              result[j] = a[j] + b[j];
            }
            return result;
          }
        }

        /**
         * Logistic-loss gradient of a single point for fixed weights:
         * (sigmoid(y * w.x) - 1) * y * x.
         */
        static class ComputeGradient implements Function<DataPoint, double[]> {
          private final double[] weights;

          ComputeGradient(double[] weights) {
            this.weights = weights;
          }

          @Override
          public double[] call(DataPoint p) {
            // The margin w.x does not depend on the coordinate, so compute it once, not D times.
            double dot = dot(weights, p.x);
            double scale = (1 / (1 + Math.exp(-p.y * dot)) - 1) * p.y;
            double[] gradient = new double[D];
            for (int i = 0; i < D; i++) {
              gradient[i] = scale * p.x[i];
            }
            return gradient;
          }
        }

        /** Dot product of the first D entries of a and b. */
        public static double dot(double[] a, double[] b) {
          double x = 0;
          for (int i = 0; i < D; i++) {
            x += a[i] * b[i];
          }
          return x;
        }

        /** Prints the weight vector to stdout. */
        public static void printWeights(double[] a) {
          System.out.println(Arrays.toString(a));
        }

        /**
         * Entry point: {@code JavaHdfsLR <file> <iters>}. Trains the model on the given file for the
         * given number of iterations and prints the initial and final weights.
         */
        public static void main(String[] args) {

          if (args.length < 2) {
            System.err.println("Usage: JavaHdfsLR <file> <iters>");
            System.exit(1);
          }

          showWarning();

          // Parse the iteration count before starting Spark so a bad argument fails fast.
          int iterations = Integer.parseInt(args[1]);

          SparkConf sparkConf = new SparkConf().setAppName("JavaHdfsLR");
          JavaSparkContext sc = new JavaSparkContext(sparkConf);
          try {
            JavaRDD<String> lines = sc.textFile(args[0]);
            // Cached because every iteration makes a full pass over the points.
            JavaRDD<DataPoint> points = lines.map(new ParsePoint()).cache();

            // Initialize w to a random value in [-1, 1)
            double[] w = new double[D];
            for (int i = 0; i < D; i++) {
              w[i] = 2 * rand.nextDouble() - 1;
            }

            System.out.print("Initial w: ");
            printWeights(w);

            for (int i = 1; i <= iterations; i++) {
              System.out.println("On iteration " + i);

              double[] gradient = points.map(
                new ComputeGradient(w)
              ).reduce(new VectorSum());

              for (int j = 0; j < D; j++) {
                w[j] -= gradient[j];
              }
            }

            System.out.print("Final w: ");
            printWeights(w);
          } finally {
            // Release the Spark context even if parsing or an iteration throws.
            sc.stop();
          }
        }
      }
      

      【讨论】:

        猜你喜欢
        • 2018-04-06
        • 2015-01-08
        • 2016-12-12
        • 2017-10-29
        • 2014-12-03
        • 1970-01-01
        • 2016-10-09
        • 2016-03-07
        • 2014-07-10
        相关资源
        最近更新 更多