From 3c2954b1efb7bcf7766013c849f671d464a2abff Mon Sep 17 00:00:00 2001 From: martinzapletal Date: Mon, 12 Jan 2015 00:55:05 +0000 Subject: [PATCH] SPARK-3278 Isotonic regression java api --- .../spark/mllib/regression/IsotonicRegression.scala | 4 ++-- .../spark/mllib/util/IsotonicDataGenerator.scala | 11 ++++++++++- .../mllib/regression/JavaIsotonicRegressionSuite.java | 6 +++--- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index dda666c4c2c1b..238bfe5e3fab7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -46,8 +46,8 @@ class IsotonicRegressionModel ( * @param testData features to be labeled * @return predicted labels */ - def predict(testData: JavaRDD[java.lang.Double]): RDD[java.lang.Double] = - testData.rdd.map(x => x.doubleValue()).map(predict) + def predict(testData: JavaRDD[java.lang.Double]): JavaRDD[java.lang.Double] = + testData.rdd.map(_.doubleValue()).map(predict).map(new java.lang.Double(_)) /** * Predict a single label diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/IsotonicDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/IsotonicDataGenerator.scala index f5c8c97ff8dca..d976f3a5965e4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/IsotonicDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/IsotonicDataGenerator.scala @@ -21,20 +21,28 @@ import org.apache.spark.annotation.DeveloperApi import scala.collection.JavaConversions._ import java.lang.{Double => JDouble} +/** + * :: DeveloperApi :: + * Generate test data for Isotonic regresision. + */ @DeveloperApi object IsotonicDataGenerator { /** * Return a Java List of ordered labeled points + * * @param labels list of labels for the data points * @return Java List of input. */ def generateIsotonicInputAsList(labels: Array[Double]): java.util.List[(JDouble, JDouble)] = { - seqAsJavaList(generateIsotonicInput(wrapDoubleArray(labels):_*).map(x => (new JDouble(x._1), new JDouble(x._2)))) + seqAsJavaList( + generateIsotonicInput( + wrapDoubleArray(labels):_*).map(x => (new JDouble(x._1), new JDouble(x._2)))) } /** * Return an ordered sequence of labeled data points with default weights + * * @param labels list of labels for the data points * @return sequence of data points */ @@ -45,6 +53,7 @@ object IsotonicDataGenerator { /** * Return an ordered sequence of labeled weighted data points + * * @param labels list of labels for the data points * @param weights list of weights for the data points * @return sequence of data points diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java index ee072ef601e1e..f127b9b717a39 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java @@ -86,10 +86,10 @@ public Double call(Tuple2 v) throws Exception { } }); - Double[] predictions = model.predict(testRDD).collect(); + List predictions = model.predict(testRDD).collect(); - Assert.assertTrue(predictions[0] == 1d); - Assert.assertTrue(predictions[11] == 12d); + Assert.assertTrue(predictions.get(0) == 1d); + Assert.assertTrue(predictions.get(11) == 12d); } }