Skip to content

Commit

Permalink
Graph primitives2
Browse files Browse the repository at this point in the history
Hi guys,

I'm following Joey and Ankur's suggestions to add collectEdges and pickRandomVertex. I'm also adding the tests for collectEdges and refactoring one method getCycleGraph in GraphOpsSuite.scala.

Thank you,

semih

Author: Semih Salihoglu <semihsalihoglu@gmail.com>

Closes apache#580 from semihsalihoglu/GraphPrimitives2 and squashes the following commits:

937d3ec [Semih Salihoglu] - Fixed the scalastyle errors.
a69a152 [Semih Salihoglu] - Adding collectEdges and pickRandomVertices. - Adding tests for collectEdges. - Refactoring a getCycle utility function for GraphOpsSuite.scala.
41265a6 [Semih Salihoglu] - Adding collectEdges and pickRandomVertex. - Adding tests for collectEdges. - Recycling a getCycle utility test file.
  • Loading branch information
semihsalihoglu authored and rxin committed Feb 25, 2014
1 parent a4f4fbc commit 1f4c7f7
Show file tree
Hide file tree
Showing 2 changed files with 183 additions and 10 deletions.
59 changes: 58 additions & 1 deletion graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
package org.apache.spark.graphx

import scala.reflect.ClassTag

import org.apache.spark.SparkContext._
import org.apache.spark.SparkException
import org.apache.spark.graphx.lib._
import org.apache.spark.rdd.RDD
import scala.util.Random

/**
* Contains additional functionality for [[Graph]]. All operations are expressed in terms of the
Expand Down Expand Up @@ -137,6 +137,42 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
}
} // end of collectNeighbor

/**
* Returns an RDD that contains for each vertex v its local edges,
* i.e., the edges that are incident on v, in the user-specified direction.
* Warning: note that singleton vertices, those with no edges in the given
* direction will not be part of the return value.
*
* @note This function could be highly inefficient on power-law
* graphs where high degree vertices may force a large amount of
* information to be collected to a single location.
*
* @param edgeDirection the direction along which to collect
* the local edges of vertices
*
* @return the local edges for each vertex
*/
def collectEdges(edgeDirection: EdgeDirection): VertexRDD[Array[Edge[ED]]] = {
edgeDirection match {
case EdgeDirection.Either =>
graph.mapReduceTriplets[Array[Edge[ED]]](
edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr))),
(edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
(a, b) => a ++ b)
case EdgeDirection.In =>
graph.mapReduceTriplets[Array[Edge[ED]]](
edge => Iterator((edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
(a, b) => a ++ b)
case EdgeDirection.Out =>
graph.mapReduceTriplets[Array[Edge[ED]]](
edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
(a, b) => a ++ b)
case EdgeDirection.Both =>
throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" +
"EdgeDirection.Either instead.")
}
}

/**
* Join the vertices with an RDD and then apply a function from the
* the vertex and RDD entry to a new vertex value. The input table
Expand Down Expand Up @@ -209,6 +245,27 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
graph.mask(preprocess(graph).subgraph(epred, vpred))
}

/**
* Picks a random vertex from the graph and returns its ID.
*/
def pickRandomVertex(): VertexId = {
val probability = 50 / graph.numVertices
var found = false
var retVal: VertexId = null.asInstanceOf[VertexId]
while (!found) {
val selectedVertices = graph.vertices.flatMap { vidVvals =>
if (Random.nextDouble() < probability) { Some(vidVvals._1) }
else { None }
}
if (selectedVertices.count > 1) {
found = true
val collectedVertices = selectedVertices.collect()
retVal = collectedVertices(Random.nextInt(collectedVertices.size))
}
}
retVal
}

/**
* Execute a Pregel-like iterative vertex-parallel abstraction. The
* user-defined vertex-program `vprog` is executed in parallel on
Expand Down
134 changes: 125 additions & 9 deletions graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,20 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext {

test("collectNeighborIds") {
withSpark { sc =>
val chain = (0 until 100).map(x => (x, (x+1)%100) )
val rawEdges = sc.parallelize(chain, 3).map { case (s,d) => (s.toLong, d.toLong) }
val graph = Graph.fromEdgeTuples(rawEdges, 1.0).cache()
val graph = getCycleGraph(sc, 100)
val nbrs = graph.collectNeighborIds(EdgeDirection.Either).cache()
assert(nbrs.count === chain.size)
assert(nbrs.count === 100)
assert(graph.numVertices === nbrs.count)
nbrs.collect.foreach { case (vid, nbrs) => assert(nbrs.size === 2) }
nbrs.collect.foreach { case (vid, nbrs) =>
val s = nbrs.toSet
assert(s.contains((vid + 1) % 100))
assert(s.contains(if (vid > 0) vid - 1 else 99 ))
nbrs.collect.foreach {
case (vid, nbrs) =>
val s = nbrs.toSet
assert(s.contains((vid + 1) % 100))
assert(s.contains(if (vid > 0) vid - 1 else 99))
}
}
}

test ("filter") {
withSpark { sc =>
val n = 5
Expand All @@ -80,4 +79,121 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext {
}
}

test("collectEdgesCycleDirectionOut") {
withSpark { sc =>
val graph = getCycleGraph(sc, 100)
val edges = graph.collectEdges(EdgeDirection.Out).cache()
assert(edges.count == 100)
edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) }
edges.collect.foreach {
case (vid, edges) =>
val s = edges.toSet
val edgeDstIds = s.map(e => e.dstId)
assert(edgeDstIds.contains((vid + 1) % 100))
}
}
}

test("collectEdgesCycleDirectionIn") {
withSpark { sc =>
val graph = getCycleGraph(sc, 100)
val edges = graph.collectEdges(EdgeDirection.In).cache()
assert(edges.count == 100)
edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) }
edges.collect.foreach {
case (vid, edges) =>
val s = edges.toSet
val edgeSrcIds = s.map(e => e.srcId)
assert(edgeSrcIds.contains(if (vid > 0) vid - 1 else 99))
}
}
}

test("collectEdgesCycleDirectionEither") {
withSpark { sc =>
val graph = getCycleGraph(sc, 100)
val edges = graph.collectEdges(EdgeDirection.Either).cache()
assert(edges.count == 100)
edges.collect.foreach { case (vid, edges) => assert(edges.size == 2) }
edges.collect.foreach {
case (vid, edges) =>
val s = edges.toSet
val edgeIds = s.map(e => if (vid != e.srcId) e.srcId else e.dstId)
assert(edgeIds.contains((vid + 1) % 100))
assert(edgeIds.contains(if (vid > 0) vid - 1 else 99))
}
}
}

test("collectEdgesChainDirectionOut") {
withSpark { sc =>
val graph = getChainGraph(sc, 50)
val edges = graph.collectEdges(EdgeDirection.Out).cache()
assert(edges.count == 49)
edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) }
edges.collect.foreach {
case (vid, edges) =>
val s = edges.toSet
val edgeDstIds = s.map(e => e.dstId)
assert(edgeDstIds.contains(vid + 1))
}
}
}

test("collectEdgesChainDirectionIn") {
withSpark { sc =>
val graph = getChainGraph(sc, 50)
val edges = graph.collectEdges(EdgeDirection.In).cache()
// We expect only 49 because collectEdges does not return vertices that do
// not have any edges in the specified direction.
assert(edges.count == 49)
edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) }
edges.collect.foreach {
case (vid, edges) =>
val s = edges.toSet
val edgeDstIds = s.map(e => e.srcId)
assert(edgeDstIds.contains((vid - 1) % 100))
}
}
}

test("collectEdgesChainDirectionEither") {
withSpark { sc =>
val graph = getChainGraph(sc, 50)
val edges = graph.collectEdges(EdgeDirection.Either).cache()
// We expect only 49 because collectEdges does not return vertices that do
// not have any edges in the specified direction.
assert(edges.count === 50)
edges.collect.foreach {
case (vid, edges) => if (vid > 0 && vid < 49) assert(edges.size == 2)
else assert(edges.size == 1)
}
edges.collect.foreach {
case (vid, edges) =>
val s = edges.toSet
val edgeIds = s.map(e => if (vid != e.srcId) e.srcId else e.dstId)
if (vid == 0) { assert(edgeIds.contains(1)) }
else if (vid == 49) { assert(edgeIds.contains(48)) }
else {
assert(edgeIds.contains(vid + 1))
assert(edgeIds.contains(vid - 1))
}
}
}
}

private def getCycleGraph(sc: SparkContext, numVertices: Int): Graph[Double, Int] = {
val cycle = (0 until numVertices).map(x => (x, (x + 1) % numVertices))
getGraphFromSeq(sc, cycle)
}

private def getChainGraph(sc: SparkContext, numVertices: Int): Graph[Double, Int] = {
val chain = (0 until numVertices - 1).map(x => (x, (x + 1)))
getGraphFromSeq(sc, chain)
}

private def getGraphFromSeq(sc: SparkContext, seq: IndexedSeq[(Int, Int)]): Graph[Double, Int] = {
val rawEdges = sc.parallelize(seq, 3).map { case (s, d) => (s.toLong, d.toLong) }
Graph.fromEdgeTuples(rawEdges, 1.0).cache()
}
}

0 comments on commit 1f4c7f7

Please sign in to comment.