Skip to content

Commit

Permalink
add cache back
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Jan 30, 2015
1 parent 0b35c15 commit 4dfe136
Showing 1 changed file with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with M
}

test("isotonic regression with unordered input") {
val trainRDD = sc.parallelize(generateIsotonicInput(Seq(1, 2, 3, 4, 5)).reverse, 2)
val trainRDD = sc.parallelize(generateIsotonicInput(Seq(1, 2, 3, 4, 5)).reverse, 2).cache()

val model = new IsotonicRegression().run(trainRDD)
assert(model.predictions === Array(1, 2, 3, 4, 5))
Expand Down Expand Up @@ -169,7 +169,7 @@ class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with M
test("isotonic regression prediction with duplicate features") {
val trainRDD = sc.parallelize(
Seq[(Double, Double, Double)](
(2, 1, 1), (1, 1, 1), (4, 2, 1), (2, 2, 1), (6, 3, 1), (5, 3, 1)), 2)
(2, 1, 1), (1, 1, 1), (4, 2, 1), (2, 2, 1), (6, 3, 1), (5, 3, 1)), 2).cache()
val model = new IsotonicRegression().run(trainRDD)

assert(model.predict(0) === 1)
Expand All @@ -181,7 +181,7 @@ class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with M
test("antitonic regression prediction with duplicate features") {
val trainRDD = sc.parallelize(
Seq[(Double, Double, Double)](
(5, 1, 1), (6, 1, 1), (2, 2, 1), (4, 2, 1), (1, 3, 1), (2, 3, 1)), 2)
(5, 1, 1), (6, 1, 1), (2, 2, 1), (4, 2, 1), (1, 3, 1), (2, 3, 1)), 2).cache()
val model = new IsotonicRegression().setIsotonic(false).run(trainRDD)

assert(model.predict(0) === 6)
Expand All @@ -193,7 +193,7 @@ class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with M
test("isotonic regression RDD prediction") {
val model = runIsotonicRegression(Seq(1, 2, 7, 1, 2), true)

val testRDD = sc.parallelize(List(-2.0, -1.0, 0.5, 0.75, 1.0, 2.0, 9.0), 2)
val testRDD = sc.parallelize(List(-2.0, -1.0, 0.5, 0.75, 1.0, 2.0, 9.0), 2).cache()
val predictions = testRDD.map(x => (x, model.predict(x))).collect().sortBy(_._1).map(_._2)
assert(predictions === Array(1, 1, 1.5, 1.75, 2, 10.0/3, 10.0/3))
}
Expand Down

0 comments on commit 4dfe136

Please sign in to comment.