【Question title】: How to get confidence scores from Spark MLLib Logistic Regression in Java
【Posted】: 2017-01-12 16:57:53
【Question】:

Update: I tried generating the confidence score as follows, but it throws an exception. This is the code snippet I used:

double point = BLAS.dot(logisticregressionmodel.weights(), datavector);
double confScore = 1.0 / (1.0 + Math.exp(-point)); 

The exception I get:

Caused by: java.lang.IllegalArgumentException: requirement failed: BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes: x.size = 198, y.size = 18
    at scala.Predef$.require(Predef.scala:233)
    at org.apache.spark.mllib.linalg.BLAS$.dot(BLAS.scala:99)
    at org.apache.spark.mllib.linalg.BLAS.dot(BLAS.scala)

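Can you help? The weight vector appears to have more elements (198) than the data vector (I am generating 18 features), and dot() requires both vectors to have the same length.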
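The size mismatch is consistent with how a multiclass model stores its weights: judging from the Spark 1.5 `predictPoint` source, `weights()` is not one per-feature vector but a flattened matrix with `numClasses - 1` rows, each as wide as the feature vector (plus one trailing slot when an intercept is used). A minimal sketch of slicing one row out of such an array follows. It uses plain `double[]` arrays so it runs without Spark. The class name `FlattenedWeights` and the method `marginForRow` are hypothetical, and the layout is an assumption based on that source file.

```java
final class FlattenedWeights {

    /**
     * Computes the margin for one row of a flattened weight matrix.
     * Assumed layout: (numClasses - 1) rows, each of width features.length
     * (no intercept) or features.length + 1 (intercept stored last).
     */
    public static double marginForRow(double[] weights, double[] features,
                                      int numClasses, int row) {
        int rows = numClasses - 1;
        if (rows < 1 || weights.length % rows != 0) {
            throw new IllegalArgumentException("weights length " + weights.length
                    + " is not a multiple of numClasses - 1 = " + rows);
        }
        if (row < 0 || row >= rows) {
            throw new IllegalArgumentException("row out of range: " + row);
        }
        int rowWidth = weights.length / rows;
        boolean withBias = rowWidth == features.length + 1;
        if (!withBias && rowWidth != features.length) {
            throw new IllegalArgumentException("row width " + rowWidth
                    + " does not match feature count " + features.length);
        }
        int offset = row * rowWidth;
        double margin = 0.0;
        for (int k = 0; k < features.length; k++) {
            margin += weights[offset + k] * features[k];
        }
        if (withBias) {
            // The intercept sits in the last slot of the row.
            margin += weights[offset + features.length];
        }
        return margin;
    }

    public static void main(String[] args) {
        // 3 classes, 2 features, intercept included: 2 rows of width 3.
        double[] weights = {1.0, 2.0, 0.5, -1.0, 0.0, 0.25};
        double[] features = {3.0, 4.0};
        System.out.println(marginForRow(weights, features, 3, 0));
        System.out.println(marginForRow(weights, features, 3, 1));
    }
}
```

With a real model the two arrays would presumably come from `logisticregressionmodel.weights().toArray()` and `datavector.toArray()`. A single `BLAS.dot` over the whole weight vector cannot work for more than two classes, because the dot product has to be taken row by row.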

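I am trying to implement a Java program that trains on an existing dataset and predicts on a new dataset, using the logistic regression algorithm provided in Spark MLLib (1.5.0). My training and prediction programs are below; I am using the multiclass implementation. The problem is that when I call model.predict(vector) (note lrmodel.predict() in the prediction program), I only get the predicted label. What if I need a confidence score? How do I get it? I have gone through the API but could not find anything that returns a confidence score. Can anyone help?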

Training program (generates the lr.model file)

public static void main(final String[] args) throws Exception {
        JavaSparkContext jsc = null;
        int salesIndex = 1;

        try {
           ...
       SparkConf sparkConf =
                    new SparkConf().setAppName("Hackathon Train").setMaster(
                            sparkMaster);
            jsc = new JavaSparkContext(sparkConf);
            ...

            JavaRDD<String> trainRDD = jsc.textFile(basePath + "old-leads.csv").cache();

            final String firstRdd = trainRDD.first().trim();
            JavaRDD<String> tempRddFilter =
                    trainRDD.filter(new org.apache.spark.api.java.function.Function<String, Boolean>() {
                        private static final long serialVersionUID =
                                11111111111111111L;

                        public Boolean call(final String arg0) {
                            return !arg0.trim().equalsIgnoreCase(firstRdd);
                        }
                    });

           ...
            JavaRDD<String> featureRDD =
                    tempRddFilter
                            .map(new org.apache.spark.api.java.function.Function() {
                                private static final long serialVersionUID =
                                        6948900080648474074L;

                                public Object call(final Object arg0)
                                        throws Exception {
                                   ...
                                    StringBuilder featureSet =
                                            new StringBuilder();
                                   ...
                                        featureSet.append(i - 2);
                                        featureSet.append(COLON);
                                        featureSet.append(strVal);
                                        featureSet.append(SPACE);
                                    }

                                    return featureSet.toString().trim();
                                }
                            });

            List<String> featureList = featureRDD.collect();
            String featureOutput = StringUtils.join(featureList, NEW_LINE);
            String filePath = basePath + "lr.arff";
            FileUtils.writeStringToFile(new File(filePath), featureOutput,
                    "UTF-8");

            JavaRDD<LabeledPoint> trainingData =
                    MLUtils.loadLibSVMFile(jsc.sc(), filePath).toJavaRDD().cache();

            final LogisticRegressionModel model =
                    new LogisticRegressionWithLBFGS().setNumClasses(18).run(
                            trainingData.rdd());
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            ObjectOutputStream oos = new ObjectOutputStream(baos);
            oos.writeObject(model);
            oos.flush();
            oos.close();
            FileUtils.writeByteArrayToFile(new File(basePath + "lr.model"),
                    baos.toByteArray());
            baos.close();

        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            if (jsc != null) {
                jsc.close();
            }
        }

Prediction program (uses the lr.model file generated by the training program)

    public static void main(final String[] args) throws Exception {
        JavaSparkContext jsc = null;
        int salesIndex = 1;
        try {
            ...
        SparkConf sparkConf =
                    new SparkConf().setAppName("Hackathon Predict").setMaster(sparkMaster);
            jsc = new JavaSparkContext(sparkConf);

            ObjectInputStream objectInputStream =
                    new ObjectInputStream(new FileInputStream(basePath
                            + "lr.model"));
            LogisticRegressionModel lrmodel =
                    (LogisticRegressionModel) objectInputStream.readObject();
            objectInputStream.close();

            ...

            JavaRDD<String> trainRDD = jsc.textFile(basePath + "new-leads.csv").cache();

            final String firstRdd = trainRDD.first().trim();
            JavaRDD<String> tempRddFilter =
                    trainRDD.filter(new org.apache.spark.api.java.function.Function<String, Boolean>() {
                        private static final long serialVersionUID =
                                11111111111111111L;

                        public Boolean call(final String arg0) {
                            return !arg0.trim().equalsIgnoreCase(firstRdd);
                        }
                    });

            ...
            final Broadcast<LogisticRegressionModel> broadcastModel =
                    jsc.broadcast(lrmodel);

            JavaRDD<String> featureRDD =
                    tempRddFilter
                            .map(new org.apache.spark.api.java.function.Function() {
                                private static final long serialVersionUID =
                                        6948900080648474074L;

                                public Object call(final Object arg0)
                                        throws Exception {
                                   ...
                                    LogisticRegressionModel lrModel =
                                            broadcastModel.value();
                                    String row = ((String) arg0);
                                    String[] featureSetArray =
                                            row.split(CSV_SPLITTER);
                                   ...
                                    final Vector vector =
                                            Vectors.dense(doubleArr);
                                    double score = lrModel.predict(vector);
                                   ...
                                    return csvString;
                                }
                            });

            String outputContent =
                    featureRDD
                            .reduce(new org.apache.spark.api.java.function.Function2() {

                                private static final long serialVersionUID =
                                        1212970144641935082L;

                                public Object call(Object arg0, Object arg1)
                                        throws Exception {
                                    ...
                                }

                            });
            ...
            FileUtils.writeStringToFile(new File(basePath
                    + "predicted-sales-data.csv"), sb.toString());
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            if (jsc != null) {
                jsc.close();
            }
        }
    }
}

【Comments】:

    Tags: java apache-spark-mllib logistic-regression


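    【Solution 1】:

    After many attempts, I finally managed to write a custom function that generates a confidence score. It is far from perfect, but it works for me for now!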

    private static double getConfidenceScore(
                final LogisticRegressionModel lrModel, final Vector vector) {
            /* Approach to get confidence scores starts */
            // Copy both vectors to arrays once instead of on every loop iteration.
            double[] weights = lrModel.weights().toArray();
            double[] features = vector.toArray();
            int numClasses = lrModel.numClasses();
            // The weights form a flattened (numClasses - 1) x dataWithBiasSize matrix.
            int dataWithBiasSize = weights.length / (numClasses - 1);
            boolean withBias = (features.length + 1) == dataWithBiasSize;
            double maxMargin = 0.0;
            for (int j = 0; j < (numClasses - 1); j++) {
                double margin = 0.0;
                for (int k = 0; k < features.length; k++) {
                    double value = features[k];
                    if (value != 0.0) {
                        margin += value * weights[(j * dataWithBiasSize) + k];
                    }
                }
                if (withBias) {
                    // The intercept is stored in the last slot of each row.
                    margin += weights[(j * dataWithBiasSize) + features.length];
                }
                if (margin > maxMargin) {
                    maxMargin = margin;
                }
            }
            double conf = 1.0 / (1.0 + Math.exp(-maxMargin));
            // Round the percentage to two decimals without locale-dependent
            // string formatting (DecimalFormat may emit a decimal comma).
            double confidenceScore = Math.round(conf * 10000.0) / 100.0;
            /* Approach to get confidence scores ends */
            return confidenceScore;
        }
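    The function above applies a sigmoid to the largest margin, so the result is always at least 50% and is not normalized across classes. One possible refinement is to turn all the margins into a softmax distribution. This assumes the model follows the usual pivot-class formulation of multinomial logistic regression, in which class 0 has an implicit margin of 0 (consistent with `maxMargin` starting at 0.0 above). The sketch below is an illustration under that assumption, not Spark's own API; `PivotSoftmax` and its methods are hypothetical names, and the input is the array of per-row margins computed as in the loop above.

```java
final class PivotSoftmax {

    /**
     * Converts the margins of classes 1..K-1 into probabilities for
     * classes 0..K-1, treating class 0 as the pivot with margin 0.
     */
    public static double[] probabilities(double[] margins) {
        // Shift by the largest margin (the pivot's 0 included) so that
        // Math.exp never overflows.
        double max = 0.0;
        for (double m : margins) {
            max = Math.max(max, m);
        }
        double[] probs = new double[margins.length + 1];
        probs[0] = Math.exp(-max);
        double sum = probs[0];
        for (int j = 0; j < margins.length; j++) {
            probs[j + 1] = Math.exp(margins[j] - max);
            sum += probs[j + 1];
        }
        for (int c = 0; c < probs.length; c++) {
            probs[c] /= sum;
        }
        return probs;
    }

    /** Index of the first largest entry, i.e. the predicted class. */
    public static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] probs = probabilities(new double[] {2.0, -1.0, 0.5});
        int predicted = argMax(probs);
        System.out.println("class " + predicted + " with probability " + probs[predicted]);
    }
}
```

    The probability of the predicted class can then serve as the confidence score, and the whole array sums to 1, which makes scores comparable across rows.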
    

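    【Comments】:

      【Solution 2】:

      Indeed, this does not seem to be possible. Looking at the source code, you could probably extend it to return these probabilities.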

      https://github.com/apache/spark/blob/branch-1.5/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala

      if (numClasses == 2) {
        val margin = dot(weightMatrix, dataMatrix) + intercept
        val score = 1.0 / (1.0 + math.exp(-margin))
        threshold match {
          case Some(t) => if (score > t) 1.0 else 0.0
          case None => score
        }
      

      I hope this helps you get started on a workaround.
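      For readers working in Java, the binary branch quoted above can be mirrored roughly as follows. This is a standalone sketch on plain arrays, not Spark code; the class `BinaryScoreSketch` and its `predict` signature are hypothetical. It shows why the raw score is only returned when no threshold is set. In the mllib API, calling `clearThreshold()` on the model is what leaves the threshold unset, and that only helps in the two-class case, not the multiclass one asked about here.

```java
import java.util.OptionalDouble;

final class BinaryScoreSketch {

    /**
     * Mirrors the quoted binary branch: returns a 0/1 label when a
     * threshold is present, otherwise the raw sigmoid score.
     */
    public static double predict(double[] weights, double[] features,
                                 double intercept, OptionalDouble threshold) {
        if (weights.length != features.length) {
            throw new IllegalArgumentException("size mismatch: "
                    + weights.length + " vs " + features.length);
        }
        double margin = intercept;
        for (int i = 0; i < weights.length; i++) {
            margin += weights[i] * features[i];
        }
        double score = 1.0 / (1.0 + Math.exp(-margin));
        if (threshold.isPresent()) {
            return score > threshold.getAsDouble() ? 1.0 : 0.0;
        }
        return score;
    }

    public static void main(String[] args) {
        double[] weights = {1.0};
        double[] features = {2.0};
        // With a threshold the result is a label; without one it is the score.
        System.out.println(predict(weights, features, 0.0, OptionalDouble.of(0.5)));
        System.out.println(predict(weights, features, 0.0, OptionalDouble.empty()));
    }
}
```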

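      【Comments】:

      • Can you give an example? I went through LogisticRegression.java in the Spark docs and could not find that method.
      • I cannot find the raw2probabilityInPlace or raw2prediction functions. Can you help?
      • They are in the org.apache.spark.ml.classification.LogisticRegressionModel class. If it is simpler, you can also copy the class under a different name and make those functions public.
      • Note that I am using org.apache.spark.mllib.classification.LogisticRegression and not the one in the ml package. The two differ: the former supports multiclass, but the latter does not (I evaluated them earlier, and the one in the ml package supports only binary classification). Can you help?
      • Updated the first comment. Hope that helps.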