Skip to content

Commit

Permalink
Fixed SVDPlusPlusSuite in Maven build.
Browse files Browse the repository at this point in the history
  • Loading branch information
rxin committed Jan 15, 2014
1 parent 74b46ac commit dfb1524
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
16 changes: 16 additions & 0 deletions graphx/src/test/resources/als-test.data
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
1,1,5.0
1,2,1.0
1,3,5.0
1,4,1.0
2,1,5.0
2,2,1.0
2,3,5.0
2,4,1.0
3,1,1.0
3,2,5.0
3,3,1.0
3,4,5.0
4,1,1.0
4,2,5.0
4,3,1.0
4,4,5.0
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,24 @@ package org.apache.spark.graphx.lib

import org.scalatest.FunSuite

import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.graphx._
import org.apache.spark.graphx.util.GraphGenerators
import org.apache.spark.rdd._


class SVDPlusPlusSuite extends FunSuite with LocalSparkContext {

test("Test SVD++ with mean square error on training set") {
withSpark { sc =>
val svdppErr = 8.0
val edges = sc.textFile("mllib/data/als/test.data").map { line =>
val edges = sc.textFile(getClass.getResource("/als-test.data").getFile).map { line =>
val fields = line.split(",")
Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble)
}
val conf = new SVDPlusPlus.Conf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations
var (graph, u) = SVDPlusPlus.run(edges, conf)
graph.cache()
val err = graph.vertices.collect.map{ case (vid, vd) =>
val err = graph.vertices.collect().map{ case (vid, vd) =>
if (vid % 2 == 1) vd._4 else 0.0
}.reduce(_ + _) / graph.triplets.collect.size
}.reduce(_ + _) / graph.triplets.collect().size
assert(err <= svdppErr)
}
}
Expand Down

0 comments on commit dfb1524

Please sign in to comment.