【Posted】: 2017-01-12 16:57:53
【Question】:
Update: I tried to generate the confidence score as shown here, but it throws an exception. This is the code snippet I used:
double point = BLAS.dot(logisticregressionmodel.weights(), datavector);
double confScore = 1.0 / (1.0 + Math.exp(-point));
The exception I get:
Caused by: java.lang.IllegalArgumentException: requirement failed: BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes: x.size = 198, y.size = 18
at scala.Predef$.require(Predef.scala:233)
at org.apache.spark.mllib.linalg.BLAS$.dot(BLAS.scala:99)
at org.apache.spark.mllib.linalg.BLAS.dot(BLAS.scala)
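Can you help? It looks like the weights vector has more elements (198) than the data vector (I generate 18 features). Both vectors must have the same length for dot().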
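For context, here is a minimal sketch of how per-class confidence could be computed by hand when the weights vector is longer than the feature vector. It rests on an assumption that should be checked against the Spark version in use: that a multiclass model stores its weights as one flat array of (numClasses - 1) consecutive blocks, one per non-pivot class, with class 0 as the pivot (margin 0), and that each block optionally ends with an intercept. The class name MultinomialConfidence and the method classProbabilities are hypothetical and are not part of MLlib.

```java
import java.util.Arrays;

/**
 * Sketch only: softmax probabilities for a multinomial logistic regression model
 * whose weights are stored as one flat array (ASSUMED layout, see the text above).
 */
public class MultinomialConfidence {

    /**
     * @param weights    flattened weights: (numClasses - 1) blocks, one per non-pivot class
     * @param features   feature values of a single data point
     * @param numClasses total number of classes, including pivot class 0
     * @return probabilities indexed by class label (0 .. numClasses - 1)
     */
    public static double[] classProbabilities(double[] weights, double[] features, int numClasses) {
        int blocks = numClasses - 1;
        if (blocks < 1 || weights.length % blocks != 0) {
            throw new IllegalArgumentException("weights length " + weights.length
                    + " is not a multiple of numClasses - 1 = " + blocks);
        }
        int blockSize = weights.length / blocks;
        // A block that is one element longer than the features carries an intercept.
        boolean withIntercept = blockSize == features.length + 1;
        if (!withIntercept && blockSize != features.length) {
            throw new IllegalArgumentException("block size " + blockSize
                    + " does not match feature count " + features.length);
        }

        // margins[0] stays 0.0 for the pivot class.
        double[] margins = new double[numClasses];
        double maxMargin = 0.0;
        for (int k = 0; k < blocks; k++) {
            int offset = k * blockSize;
            double margin = 0.0;
            for (int j = 0; j < features.length; j++) {
                margin += weights[offset + j] * features[j];
            }
            if (withIntercept) {
                margin += weights[offset + features.length];
            }
            margins[k + 1] = margin;
            maxMargin = Math.max(maxMargin, margin);
        }

        // Softmax over the margins; subtracting the maximum avoids overflow in exp().
        double[] probabilities = new double[numClasses];
        double sum = 0.0;
        for (int c = 0; c < numClasses; c++) {
            probabilities[c] = Math.exp(margins[c] - maxMargin);
            sum += probabilities[c];
        }
        for (int c = 0; c < numClasses; c++) {
            probabilities[c] /= sum;
        }
        return probabilities;
    }

    public static void main(String[] args) {
        // Toy model: 3 classes, 2 features, no intercept, so 2 blocks of 2 weights.
        double[] weights = {1.0, 0.0, 0.0, 1.0};
        double[] features = {2.0, 1.0};
        System.out.println(Arrays.toString(classProbabilities(weights, features, 3)));
    }
}
```

With a real model the inputs would presumably come from model.weights().toArray(), vector.toArray() and model.numClasses(). The highest probability is the confidence of the predicted label, provided the layout assumption holds. If the weights length is not a multiple of numClasses - 1, the sketch throws instead of guessing.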
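I am trying to implement a Java program that trains on an existing dataset and predicts on a new dataset, using the logistic regression algorithm provided in Spark MLlib (1.5.0). My training and prediction programs are below; I am using the multiclass implementation. The problem is that when I call model.predict(vector) (note lrModel.predict() in the prediction program), I only get the predicted label. What if I need a confidence score? How do I get it? I have gone through the API but could not find anything that returns a confidence score. Can anyone help me?
Training program (generates the .model file)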
public static void main(final String[] args) throws Exception {
    JavaSparkContext jsc = null;
    int salesIndex = 1;
    try {
        ...
        SparkConf sparkConf =
            new SparkConf().setAppName("Hackathon Train").setMaster(sparkMaster);
        jsc = new JavaSparkContext(sparkConf);
        ...
        JavaRDD<String> trainRDD = jsc.textFile(basePath + "old-leads.csv").cache();
        // Drop the header: filter out every line equal to the first line
        final String firstRdd = trainRDD.first().trim();
        JavaRDD<String> tempRddFilter =
            trainRDD.filter(new org.apache.spark.api.java.function.Function<String, Boolean>() {
                private static final long serialVersionUID = 11111111111111111L;
                public Boolean call(final String arg0) {
                    return !arg0.trim().equalsIgnoreCase(firstRdd);
                }
            });
        ...
        // Turn each row into a line of index:value pairs
        JavaRDD<String> featureRDD =
            tempRddFilter.map(new org.apache.spark.api.java.function.Function() {
                private static final long serialVersionUID = 6948900080648474074L;
                public Object call(final Object arg0) throws Exception {
                    ...
                    StringBuilder featureSet = new StringBuilder();
                    ...
                        featureSet.append(i - 2);
                        featureSet.append(COLON);
                        featureSet.append(strVal);
                        featureSet.append(SPACE);
                    }
                    return featureSet.toString().trim();
                }
            });
        // Write the lines to a file, then load that file back as LIBSVM data
        List<String> featureList = featureRDD.collect();
        String featureOutput = StringUtils.join(featureList, NEW_LINE);
        String filePath = basePath + "lr.arff";
        FileUtils.writeStringToFile(new File(filePath), featureOutput, "UTF-8");
        JavaRDD<LabeledPoint> trainingData =
            MLUtils.loadLibSVMFile(jsc.sc(), filePath).toJavaRDD().cache();
        // Train a multiclass model with 18 classes
        final LogisticRegressionModel model =
            new LogisticRegressionWithLBFGS().setNumClasses(18).run(trainingData.rdd());
        // Persist the model with Java serialization
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        ObjectOutputStream oos = new ObjectOutputStream(baos);
        oos.writeObject(model);
        oos.flush();
        oos.close();
        FileUtils.writeByteArrayToFile(new File(basePath + "lr.model"), baos.toByteArray());
        baos.close();
    } catch (Exception e) {
        e.printStackTrace();
    } finally {
        if (jsc != null) {
            jsc.close();
        }
    }
}
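Since the training program writes an intermediate text file and reads it back with MLUtils.loadLibSVMFile, here is a minimal sketch of what one line of that file is expected to look like, assuming the usual LIBSVM convention: the label first, then index:value pairs with 1-based indices in ascending order. The class name LibSvmLine and its format method are hypothetical helpers for illustration, not the code elided above.

```java
/**
 * Sketch only: formats one labeled row as a LIBSVM text line,
 * "<label> <index>:<value> ...", with 1-based, ascending indices.
 */
public class LibSvmLine {

    public static String format(int label, double[] features) {
        StringBuilder line = new StringBuilder();
        line.append(label);
        for (int i = 0; i < features.length; i++) {
            // Every feature is written, including zeros, so each line
            // carries the same highest index.
            line.append(' ').append(i + 1).append(':').append(features[i]);
        }
        return line.toString();
    }

    public static void main(String[] args) {
        // prints "3 1:1.5 2:0.0 3:2.0"
        System.out.println(format(3, new double[] {1.5, 0.0, 2.0}));
    }
}
```

Writing every index, zeros included, is a deliberate choice in this sketch: when the feature count is inferred from the highest index seen in the file, keeping that index constant makes it easier to compare the trained model's dimension with the vectors built at prediction time.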
Prediction program (uses the lr.model generated by the training program)
public static void main(final String[] args) throws Exception {
    JavaSparkContext jsc = null;
    int salesIndex = 1;
    try {
        ...
        SparkConf sparkConf =
            new SparkConf().setAppName("Hackathon Predict").setMaster(sparkMaster);
        jsc = new JavaSparkContext(sparkConf);
        // Load the model that the training program serialized
        ObjectInputStream objectInputStream =
            new ObjectInputStream(new FileInputStream(basePath + "lr.model"));
        LogisticRegressionModel lrmodel =
            (LogisticRegressionModel) objectInputStream.readObject();
        objectInputStream.close();
        ...
        JavaRDD<String> trainRDD = jsc.textFile(basePath + "new-leads.csv").cache();
        // Drop the header: filter out every line equal to the first line
        final String firstRdd = trainRDD.first().trim();
        JavaRDD<String> tempRddFilter =
            trainRDD.filter(new org.apache.spark.api.java.function.Function<String, Boolean>() {
                private static final long serialVersionUID = 11111111111111111L;
                public Boolean call(final String arg0) {
                    return !arg0.trim().equalsIgnoreCase(firstRdd);
                }
            });
        ...
        // Broadcast the model so that each task can read it
        final Broadcast<LogisticRegressionModel> broadcastModel = jsc.broadcast(lrmodel);
        JavaRDD<String> featureRDD =
            tempRddFilter.map(new org.apache.spark.api.java.function.Function() {
                private static final long serialVersionUID = 6948900080648474074L;
                public Object call(final Object arg0) throws Exception {
                    ...
                    LogisticRegressionModel lrModel = broadcastModel.value();
                    String row = ((String) arg0);
                    String[] featureSetArray = row.split(CSV_SPLITTER);
                    ...
                    final Vector vector = Vectors.dense(doubleArr);
                    // predict() returns the predicted class label, not a probability
                    double score = lrModel.predict(vector);
                    ...
                    return csvString;
                }
            });
        String outputContent =
            featureRDD.reduce(new org.apache.spark.api.java.function.Function2() {
                private static final long serialVersionUID = 1212970144641935082L;
                public Object call(Object arg0, Object arg1) throws Exception {
                    ...
                }
            });
        ...
        FileUtils.writeStringToFile(new File(basePath + "predicted-sales-data.csv"), sb.toString());
    } catch (Exception e) {
        e.printStackTrace();
    } finally {
        if (jsc != null) {
            jsc.close();
        }
    }
}
}
【Discussion】:
Tags: java apache-spark-mllib logistic-regression