Skip to content

Commit

Permalink
Removed WeightedLabeledPoint. Replaced by tuple of doubles
Browse files Browse the repository at this point in the history
  • Loading branch information
zapletal-martin committed Dec 27, 2014
1 parent 34760d5 commit b8b1620
Showing 1 changed file with 10 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ object MonotonicityConstraint {
* @param monotonicityConstraint specifies if the sequence is increasing or decreasing
*/
class IsotonicRegressionModel(
val predictions: Seq[WeightedLabeledPoint],
val predictions: Seq[(Double, Double, Double)],
val monotonicityConstraint: MonotonicityConstraint)
extends RegressionModel {

Expand All @@ -76,7 +76,7 @@ class IsotonicRegressionModel(
override def predict(testData: Vector): Double = {
// Take the highest of data points smaller than our feature or data point with lowest feature
(predictions.head +:
predictions.filter(y => y.features.toArray.head <= testData.toArray.head)).last.label
predictions.filter(y => y._2 <= testData.toArray.head)).last._1
}
}

Expand All @@ -95,7 +95,7 @@ trait IsotonicRegressionAlgorithm
* @return isotonic regression model
*/
protected def createModel(
predictions: Seq[WeightedLabeledPoint],
predictions: Seq[(Double, Double, Double)],
monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel

/**
Expand All @@ -106,7 +106,7 @@ trait IsotonicRegressionAlgorithm
* @return isotonic regression model
*/
def run(
input: RDD[WeightedLabeledPoint],
input: RDD[(Double, Double, Double)],
monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel
}

Expand All @@ -117,15 +117,15 @@ class PoolAdjacentViolators private [mllib]
extends IsotonicRegressionAlgorithm {

override def run(
input: RDD[WeightedLabeledPoint],
input: RDD[(Double, Double, Double)],
monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel = {
createModel(
parallelPoolAdjacentViolators(input, monotonicityConstraint),
monotonicityConstraint)
}

override protected def createModel(
predictions: Seq[WeightedLabeledPoint],
predictions: Seq[(Double, Double, Double)],
monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel = {
new IsotonicRegressionModel(predictions, monotonicityConstraint)
}
Expand Down Expand Up @@ -194,12 +194,12 @@ class PoolAdjacentViolators private [mllib]
* @return result
*/
private def parallelPoolAdjacentViolators(
testData: RDD[WeightedLabeledPoint],
monotonicityConstraint: MonotonicityConstraint): Seq[WeightedLabeledPoint] = {
testData: RDD[(Double, Double, Double)],
monotonicityConstraint: MonotonicityConstraint): Seq[(Double, Double, Double)] = {

poolAdjacentViolators(
testData
.sortBy(_.features.toArray.head)
.sortBy(_._2)
.cache()
.mapPartitions(it => poolAdjacentViolators(it.toArray, monotonicityConstraint).toIterator)
.collect(), monotonicityConstraint)
Expand All @@ -224,7 +224,7 @@ object IsotonicRegression {
* @param monotonicityConstraint Isotonic (increasing) or Antitonic (decreasing) sequence
*/
def train(
input: RDD[WeightedLabeledPoint],
input: RDD[(Double, Double, Double)],
monotonicityConstraint: MonotonicityConstraint = Isotonic): IsotonicRegressionModel = {
new PoolAdjacentViolators().run(input, monotonicityConstraint)
}
Expand Down

0 comments on commit b8b1620

Please sign in to comment.