diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index ba33762fba754..a78f47b416a90 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -66,7 +66,7 @@ object MonotonicityConstraint { * @param monotonicityConstraint specifies if the sequence is increasing or decreasing */ class IsotonicRegressionModel( - val predictions: Seq[WeightedLabeledPoint], + val predictions: Seq[(Double, Double, Double)], val monotonicityConstraint: MonotonicityConstraint) extends RegressionModel { @@ -76,7 +76,7 @@ class IsotonicRegressionModel( override def predict(testData: Vector): Double = { // Take the highest of data points smaller than our feature or data point with lowest feature (predictions.head +: - predictions.filter(y => y.features.toArray.head <= testData.toArray.head)).last.label + predictions.filter(y => y._2 <= testData.toArray.head)).last._1 } } @@ -95,7 +95,7 @@ trait IsotonicRegressionAlgorithm * @return isotonic regression model */ protected def createModel( - predictions: Seq[WeightedLabeledPoint], + predictions: Seq[(Double, Double, Double)], monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel /** @@ -106,7 +106,7 @@ trait IsotonicRegressionAlgorithm * @return isotonic regression model */ def run( - input: RDD[WeightedLabeledPoint], + input: RDD[(Double, Double, Double)], monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel } @@ -117,7 +117,7 @@ class PoolAdjacentViolators private [mllib] extends IsotonicRegressionAlgorithm { override def run( - input: RDD[WeightedLabeledPoint], + input: RDD[(Double, Double, Double)], monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel = { createModel( parallelPoolAdjacentViolators(input, monotonicityConstraint), @@ -125,7 +125,7 @@ class PoolAdjacentViolators private [mllib] } override protected def createModel( - predictions: Seq[WeightedLabeledPoint], + predictions: Seq[(Double, Double, Double)], monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel = { new IsotonicRegressionModel(predictions, monotonicityConstraint) } @@ -194,12 +194,12 @@ class PoolAdjacentViolators private [mllib] * @return result */ private def parallelPoolAdjacentViolators( - testData: RDD[WeightedLabeledPoint], - monotonicityConstraint: MonotonicityConstraint): Seq[WeightedLabeledPoint] = { + testData: RDD[(Double, Double, Double)], + monotonicityConstraint: MonotonicityConstraint): Seq[(Double, Double, Double)] = { poolAdjacentViolators( testData - .sortBy(_.features.toArray.head) + .sortBy(_._2) .cache() .mapPartitions(it => poolAdjacentViolators(it.toArray, monotonicityConstraint).toIterator) .collect(), monotonicityConstraint) @@ -224,7 +224,7 @@ object IsotonicRegression { * @param monotonicityConstraint Isotonic (increasing) or Antitonic (decreasing) sequence */ def train( - input: RDD[WeightedLabeledPoint], + input: RDD[(Double, Double, Double)], monotonicityConstraint: MonotonicityConstraint = Isotonic): IsotonicRegressionModel = { new PoolAdjacentViolators().run(input, monotonicityConstraint) }