move treeReduce and treeAggregate to core
mengxr committed Jan 27, 2015
1 parent ff356e2 commit d600b6c
Showing 10 changed files with 91 additions and 74 deletions.
63 changes: 63 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -900,6 +900,38 @@ abstract class RDD[T: ClassTag](
jobResult.getOrElse(throw new UnsupportedOperationException("empty collection"))
}

/**
* Reduces the elements of this RDD in a multi-level tree pattern.
*
* @param depth suggested depth of the tree (default: 2)
* @see [[org.apache.spark.rdd.RDD#reduce]]
*/
def treeReduce(f: (T, T) => T, depth: Int = 2): T = {
require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
val cleanF = context.clean(f)
val reducePartition: Iterator[T] => Option[T] = iter => {
if (iter.hasNext) {
Some(iter.reduceLeft(cleanF))
} else {
None
}
}
val partiallyReduced = mapPartitions(it => Iterator(reducePartition(it)))
val op: (Option[T], Option[T]) => Option[T] = (c, x) => {
if (c.isDefined && x.isDefined) {
Some(cleanF(c.get, x.get))
} else if (c.isDefined) {
c
} else if (x.isDefined) {
x
} else {
None
}
}
partiallyReduced.treeAggregate(Option.empty[T])(op, op, depth)
.getOrElse(throw new UnsupportedOperationException("empty collection"))
}
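
For context only (not part of the commit), a minimal usage sketch of the new core treeReduce, assuming sc is an active SparkContext; the RDD contents and partition count are illustrative assumptions:

// Illustrative sketch: treeReduce returns the same result as reduce, but merges
// per-partition results over several rounds instead of pulling them all to the driver.
val nums = sc.parallelize(1 to 1000000, 100)   // 100 partitions (assumed)
val total = nums.treeReduce(_ + _)             // default depth = 2
val deeper = nums.treeReduce(_ + _, depth = 3) // extra levels help with many partitions
assert(total == nums.reduce(_ + _) && total == deeper)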

/**
* Aggregate the elements of each partition, and then the results for all the partitions, using a
* given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
@@ -935,6 +967,37 @@ abstract class RDD[T: ClassTag](
jobResult
}

/**
* Aggregates the elements of this RDD in a multi-level tree pattern.
*
* @param depth suggested depth of the tree (default: 2)
* @see [[org.apache.spark.rdd.RDD#aggregate]]
*/
def treeAggregate[U: ClassTag](zeroValue: U)(
seqOp: (U, T) => U,
combOp: (U, U) => U,
depth: Int = 2): U = {
require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
if (partitions.size == 0) {
return Utils.clone(zeroValue, context.env.closureSerializer.newInstance())
}
val cleanSeqOp = context.clean(seqOp)
val cleanCombOp = context.clean(combOp)
val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it)))
var numPartitions = partiallyAggregated.partitions.size
val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
// If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation.
while (numPartitions > scale + numPartitions / scale) {
numPartitions /= scale
val curNumPartitions = numPartitions
partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) =>
iter.map((i % curNumPartitions, _))
}.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
}
partiallyAggregated.reduce(cleanCombOp)
}
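
To see how the level arithmetic above plays out, take 1000 partially aggregated partitions and the default depth of 2 (numbers chosen for illustration): scale = max(ceil(1000^(1/2)), 2) = 32; the loop runs once because 1000 > 32 + 1000/32 = 63; the partition count drops to 1000/32 = 31; then 31 > 32 + 0 fails and the final reduce runs over 31 partitions. A hedged usage sketch, again assuming sc is an active SparkContext:

// Illustrative sketch: compute a (sum, count) pair over an RDD of Ints with treeAggregate.
val data = sc.parallelize(1 to 100, 10)
val (sum, count) = data.treeAggregate((0L, 0L))(
  (acc, x) => (acc._1 + x, acc._2 + 1L),   // seqOp: fold an element into the pair
  (a, b) => (a._1 + b._1, a._2 + b._2),    // combOp: merge two pairs
  depth = 2)
// sum == 5050, count == 100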

/**
* Return the number of elements in the RDD.
*/
19 changes: 19 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -157,6 +157,24 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
}

test("treeAggregate") {
val rdd = sc.makeRDD(-1000 until 1000, 10)
def seqOp = (c: Long, x: Int) => c + x
def combOp = (c1: Long, c2: Long) => c1 + c2
for (depth <- 1 until 10) {
val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth)
assert(sum === -1000L)
}
}

test("treeReduce") {
val rdd = sc.makeRDD(-1000 until 1000, 10)
for (depth <- 1 until 10) {
val sum = rdd.treeReduce(_ + _, depth)
assert(sum === -1000)
}
}
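
As a side note (not part of the commit), the expected value in both new tests can be checked with plain Scala: the range -1000 until 1000 runs from -1000 to 999, so every pair k and -k cancels and only -1000 remains.

assert((-1000 until 1000).sum == -1000)  // sanity check of the tests' expected sum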

test("basic caching") {
val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
assert(rdd.collect().toList === List(1, 2, 3, 4))
@@ -967,4 +985,5 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assertFails { sc.parallelize(1 to 100) }
assertFails { sc.textFile("/nonexistent-path") }
}

}
@@ -22,7 +22,6 @@ import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.rdd.RDD

/**
@@ -20,7 +20,6 @@ package org.apache.spark.mllib.feature
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.rdd.RDD

@@ -30,7 +30,6 @@ import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary}
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom
@@ -25,7 +25,6 @@ import org.apache.spark.annotation.{Experimental, DeveloperApi}
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.rdd.RDDFunctions._

/**
* Class used to solve an optimization problem using Gradient Descent.
@@ -26,7 +26,6 @@ import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS.axpy
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.rdd.RDD

/**
59 changes: 9 additions & 50 deletions mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
@@ -21,10 +21,7 @@ import scala.language.implicitConversions
import scala.reflect.ClassTag

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.HashPartitioner
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils

/**
* Machine learning specific RDD functions.
@@ -53,69 +50,31 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable {
* Reduces the elements of this RDD in a multi-level tree pattern.
*
* @param depth suggested depth of the tree (default: 2)
* @see [[org.apache.spark.rdd.RDD#reduce]]
* @see [[org.apache.spark.rdd.RDD#treeReduce]]
* @deprecated Use [[org.apache.spark.rdd.RDD#treeReduce]] instead.
*/
def treeReduce(f: (T, T) => T, depth: Int = 2): T = {
require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
val cleanF = self.context.clean(f)
val reducePartition: Iterator[T] => Option[T] = iter => {
if (iter.hasNext) {
Some(iter.reduceLeft(cleanF))
} else {
None
}
}
val partiallyReduced = self.mapPartitions(it => Iterator(reducePartition(it)))
val op: (Option[T], Option[T]) => Option[T] = (c, x) => {
if (c.isDefined && x.isDefined) {
Some(cleanF(c.get, x.get))
} else if (c.isDefined) {
c
} else if (x.isDefined) {
x
} else {
None
}
}
RDDFunctions.fromRDD(partiallyReduced).treeAggregate(Option.empty[T])(op, op, depth)
.getOrElse(throw new UnsupportedOperationException("empty collection"))
}
@deprecated("Use RDD.treeReduce instead.", "1.3.0")
def treeReduce(f: (T, T) => T, depth: Int = 2): T = self.treeReduce(f, depth)

/**
* Aggregates the elements of this RDD in a multi-level tree pattern.
*
* @param depth suggested depth of the tree (default: 2)
* @see [[org.apache.spark.rdd.RDD#aggregate]]
* @see [[org.apache.spark.rdd.RDD#treeAggregate]]
* @deprecated Use [[org.apache.spark.rdd.RDD#treeAggregate]] instead.
*/
@deprecated("Use RDD.treeAggregate instead.", "1.3.0")
def treeAggregate[U: ClassTag](zeroValue: U)(
seqOp: (U, T) => U,
combOp: (U, U) => U,
depth: Int = 2): U = {
require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
if (self.partitions.size == 0) {
return Utils.clone(zeroValue, self.context.env.closureSerializer.newInstance())
}
val cleanSeqOp = self.context.clean(seqOp)
val cleanCombOp = self.context.clean(combOp)
val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
var partiallyAggregated = self.mapPartitions(it => Iterator(aggregatePartition(it)))
var numPartitions = partiallyAggregated.partitions.size
val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
// If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation.
while (numPartitions > scale + numPartitions / scale) {
numPartitions /= scale
val curNumPartitions = numPartitions
partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) =>
iter.map((i % curNumPartitions, _))
}.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
}
partiallyAggregated.reduce(cleanCombOp)
self.treeAggregate(zeroValue)(seqOp, combOp, depth)
}
}

@DeveloperApi
object RDDFunctions {

/** Implicit conversion from an RDD to RDDFunctions. */
implicit def fromRDD[T: ClassTag](rdd: RDD[T]) = new RDDFunctions[T](rdd)
implicit def fromRDD[T: ClassTag](rdd: RDD[T]): RDDFunctions[T] = new RDDFunctions[T](rdd)
}
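
A hedged migration sketch (the code below is illustrative, not from the commit): because Scala prefers a member method over an implicit conversion, existing MLlib callers that imported RDDFunctions._ now resolve treeReduce and treeAggregate to the core RDD methods without source changes; the deprecated wrappers above only come into play when RDDFunctions is used explicitly, as in viaWrapper here.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.rdd.RDDFunctions
import org.apache.spark.rdd.RDD

def treeSum(sc: SparkContext): Long = {
  val rdd: RDD[Int] = sc.parallelize(1 to 100, 10)
  // Resolves to the new RDD.treeAggregate; no mllib import is needed anymore.
  val viaCore = rdd.treeAggregate(0L)((c, x) => c + x, _ + _)
  // Explicit use of the deprecated wrapper still compiles and forwards to core.
  val viaWrapper = RDDFunctions.fromRDD(rdd).treeAggregate(0L)((c, x) => c + x, _ + _)
  assert(viaCore == viaWrapper)  // both return 5050
  viaCore
}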
@@ -22,7 +22,6 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer}
import org.apache.spark.rdd.RDD

@@ -46,22 +46,4 @@ class RDDFunctionsSuite extends FunSuite with MLlibTestSparkContext {
val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq)
assert(sliding === expected)
}

test("treeAggregate") {
val rdd = sc.makeRDD(-1000 until 1000, 10)
def seqOp = (c: Long, x: Int) => c + x
def combOp = (c1: Long, c2: Long) => c1 + c2
for (depth <- 1 until 10) {
val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth)
assert(sum === -1000L)
}
}

test("treeReduce") {
val rdd = sc.makeRDD(-1000 until 1000, 10)
for (depth <- 1 until 10) {
val sum = rdd.treeReduce(_ + _, depth)
assert(sum === -1000)
}
}
}
