From 0324bfd0c47851bd741f98e2a58423953df1cb07 Mon Sep 17 00:00:00 2001
From: Matthew Taylor
Date: Sat, 11 Jan 2014 18:55:03 +0000
Subject: [PATCH 1/2] streaming iterable

Conflicts:
	core/src/main/scala/org/apache/spark/rdd/RDD.scala
---
 .../main/scala/org/apache/spark/rdd/RDD.scala | 17 +++++-
 .../org/apache/spark/util/RDDiterable.scala | 60 +++++++++++++++++++
 .../scala/org/apache/spark/rdd/RDDSuite.scala | 20 +++++++
 3 files changed, 96 insertions(+), 1 deletion(-)
 create mode 100644 core/src/main/scala/org/apache/spark/util/RDDiterable.scala

diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index b529754638908..08cfc808d09e4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -41,10 +41,12 @@ import org.apache.spark.partial.CountEvaluator
 import org.apache.spark.partial.GroupedCountEvaluator
 import org.apache.spark.partial.PartialResult
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{Utils, BoundedPriorityQueue, SerializableHyperLogLog}
+import org.apache.spark.util.{RDDiterable, Utils, BoundedPriorityQueue, SerializableHyperLogLog}
 import org.apache.spark.SparkContext._
 import org.apache.spark._
 
+import scala.concurrent.duration.Duration
+import java.util.concurrent.TimeUnit
 
 /**
  * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -601,6 +603,8 @@ abstract class RDD[T: ClassTag](
     sc.runJob(this, (iter: Iterator[T]) => f(iter))
   }
 
+
+
   /**
    * Return an array that contains all of the elements in this RDD.
    */
@@ -621,6 +625,17 @@ abstract class RDD[T: ClassTag](
     filter(f.isDefinedAt).map(f)
   }
 
+  /**
+   * Return iterable that lazily fetches partitions
+   * @param prefetchPartitions How many partitions to prefetch. Larger value increases parallelism but also increases
+   * driver memory requirement
+   * @param timeOut how long to wait for each partition fetch
+   * @return Iterable of every element in this RDD
+   */
+  def toIterable(prefetchPartitions: Int = 1, timeOut: Duration = Duration(30, TimeUnit.SECONDS)) = {
+    new RDDiterable[T](this, prefetchPartitions, timeOut)
+  }
+
   /**
    * Return an RDD with the elements from `this` that are not in `other`.
    *
diff --git a/core/src/main/scala/org/apache/spark/util/RDDiterable.scala b/core/src/main/scala/org/apache/spark/util/RDDiterable.scala
new file mode 100644
index 0000000000000..a68c31f81ccd2
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/RDDiterable.scala
@@ -0,0 +1,60 @@
+package org.apache.spark.util
+
+import scala.collection.immutable.Queue
+import scala.concurrent.{Await, Future}
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.duration.Duration
+import scala.annotation.tailrec
+import scala.reflect.ClassTag
+import org.apache.spark.rdd.RDD
+
+/**Iterable whose iterator iterates over all elements of an RDD without fetching all partitions to the driver process
+ *
+ * @param rdd RDD to iterate
+ * @param prefetchPartitions The number of partitions to prefetch
+ * @param timeOut How long to wait for each partition before failing.
+ * @tparam T
+ */
+class RDDiterable[T: ClassTag](rdd: RDD[T], prefetchPartitions: Int, timeOut: Duration) extends Serializable with Iterable[T] {
+
+  def iterator = new Iterator[T] {
+    var partitions = Range(0, rdd.partitions.size)
+    var pendingFetches = Queue.empty.enqueue(partitions.take(prefetchPartitions).map(par => fetchData(par)))
+    partitions = partitions.drop(prefetchPartitions)
+    var currentIterator: Iterator[T] = Iterator.empty
+    @tailrec
+    def hasNext() = {
+      if (currentIterator.hasNext) {
+        true
+      } else {
+        pendingFetches = partitions.headOption.map {
+          partitionNo =>
+            pendingFetches.enqueue(fetchData(partitionNo))
+        }.getOrElse(pendingFetches)
+        partitions = partitions.drop(1)
+
+        if (pendingFetches.isEmpty) {
+          currentIterator = Iterator.empty
+          false
+        } else {
+          val (future, pendingFetchesN) = pendingFetches.dequeue
+          pendingFetches = pendingFetchesN
+          currentIterator = Await.result(future, timeOut).iterator
+          this.hasNext()
+        }
+      }
+    }
+    def next() = {
+      hasNext()
+      currentIterator.next()
+    }
+  }
+  private def fetchData(partitionIndex: Int): Future[Seq[T]] = {
+    val results = new ArrayBuffer[T]()
+    rdd.context.submitJob[T, Array[T], Seq[T]](rdd,
+      x => x.toArray,
+      List(partitionIndex),
+      (inx: Int, res: Array[T]) => results.appendAll(res),
+      results.toSeq)
+  }
+}
\ No newline at end of file
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index ac9df34afe6ee..6c8994b1b0ced 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -380,6 +380,26 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
   }
 
+  test("toIterable") {
+    var nums = sc.makeRDD(Range(1, 1000), 100)
+    assert(nums.toIterable(prefetchPartitions = 10).size === 999)
+    assert(nums.toIterable().toArray === (1 to 999).toArray)
+
+    nums = sc.makeRDD(Range(1000, 1, -1), 100)
+    assert(nums.toIterable(prefetchPartitions = 10).size === 999)
+    assert(nums.toIterable(prefetchPartitions = 10).toArray === Range(1000, 1, -1).toArray)
+
+    nums = sc.makeRDD(Range(1, 100), 1000)
+    assert(nums.toIterable(prefetchPartitions = 10).size === 99)
+    assert(nums.toIterable(prefetchPartitions = 10).toArray === Range(1, 100).toArray)
+
+    nums = sc.makeRDD(Range(1, 1000), 100)
+    assert(nums.toIterable(prefetchPartitions = -1).size === 999)
+    assert(nums.toIterable().toArray === (1 to 999).toArray)
+  }
+
+
+
   test("take") {
     var nums = sc.makeRDD(Range(1, 1000), 1)
     assert(nums.take(0).size === 0)

From 4f51bdf91657e7460ce71963addbcd327239d599 Mon Sep 17 00:00:00 2001
From: Matthew Taylor
Date: Mon, 20 Jan 2014 16:29:16 +0000
Subject: [PATCH 2/2] updated streaming iterable

Conflicts:
	core/src/main/scala/org/apache/spark/rdd/RDD.scala
	core/src/main/scala/org/apache/spark/util/RDDiterable.scala
---
 .../main/scala/org/apache/spark/rdd/RDD.scala | 16 ++--
 .../org/apache/spark/rdd/RDDiterator.scala | 76 +++++++++++++++++++
 .../org/apache/spark/util/RDDiterable.scala | 60 ---------------
 .../scala/org/apache/spark/rdd/RDDSuite.scala | 28 ++++---
 4 files changed, 103 insertions(+), 77 deletions(-)
 create mode 100644 core/src/main/scala/org/apache/spark/rdd/RDDiterator.scala
 delete mode 100644 core/src/main/scala/org/apache/spark/util/RDDiterable.scala

diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 08cfc808d09e4..dde38f0f74c81 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -41,7 +41,7 @@ import org.apache.spark.partial.CountEvaluator
 import org.apache.spark.partial.GroupedCountEvaluator
 import org.apache.spark.partial.PartialResult
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{RDDiterable, Utils, BoundedPriorityQueue, SerializableHyperLogLog}
+import org.apache.spark.util.{Utils, BoundedPriorityQueue, SerializableHyperLogLog}
 import org.apache.spark.SparkContext._
 import org.apache.spark._
 
@@ -603,8 +603,6 @@ abstract class RDD[T: ClassTag](
     sc.runJob(this, (iter: Iterator[T]) => f(iter))
   }
 
-
-
   /**
    * Return an array that contains all of the elements in this RDD.
    */
@@ -626,14 +624,16 @@
   }
 
   /**
-   * Return iterable that lazily fetches partitions
-   * @param prefetchPartitions How many partitions to prefetch. Larger value increases parallelism but also increases
-   * driver memory requirement
+   * Return an iterator that lazily fetches partitions
+   * @param prefetchPartitions How many partitions to prefetch. A larger value increases parallelism
+   * but also increases the driver memory requirement.
+   * @param partitionBatchSize How many partitions to fetch per job
    * @param timeOut how long to wait for each partition fetch
    * @return Iterable of every element in this RDD
    */
-  def toIterable(prefetchPartitions: Int = 1, timeOut: Duration = Duration(30, TimeUnit.SECONDS)) = {
-    new RDDiterable[T](this, prefetchPartitions, timeOut)
+  def toIterator(prefetchPartitions: Int = 1, partitionBatchSize: Int = 10,
+      timeOut: Duration = Duration(30, TimeUnit.SECONDS)):Iterator[T] = {
+    new RDDiterator[T](this, prefetchPartitions,partitionBatchSize, timeOut)
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDiterator.scala b/core/src/main/scala/org/apache/spark/rdd/RDDiterator.scala
new file mode 100644
index 0000000000000..1b423a1d7e387
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDiterator.scala
@@ -0,0 +1,76 @@
+package org.apache.spark.rdd
+
+import scala.concurrent.{Await, Future}
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.duration.Duration
+import scala.annotation.tailrec
+import scala.collection.mutable
+import org.apache.spark.rdd.RDDiterator._
+import org.apache.spark.FutureAction
+
+/**
+ * Iterator that iterates over all elements of an RDD without fetching all partitions
+ * to the driver process
+ *
+ * @param rdd RDD to iterate
+ * @param prefetchPartitions The number of partitions to prefetch.
+ *                           If <1 will not prefetch.
+ *                           partitions prefetched = min(prefetchPartitions, partitionBatchSize)
+ * @param partitionBatchSize How many partitions to fetch per job
+ * @param timeOut How long to wait for each partition before failing.
+ */
+class RDDiterator[T: ClassManifest](rdd: RDD[T], prefetchPartitions: Int, partitionBatchSize: Int,
+                  timeOut: Duration)
+  extends Iterator[T] {
+
+  val batchSize = math.max(1,partitionBatchSize)
+  var partitionsBatches: Iterator[Seq[Int]] = Range(0, rdd.partitions.size).grouped(batchSize)
+  var pendingFetchesQueue = mutable.Queue.empty[Future[Seq[Seq[T]]]]
+  //add prefetchPartitions prefetch
+  0.until(math.max(0, prefetchPartitions / batchSize)).foreach(x=>enqueueDataFetch())
+
+  var currentIterator: Iterator[T] = Iterator.empty
+  @tailrec
+  final def hasNext = {
+    if (currentIterator.hasNext) {
+      //Still values in the current partition
+      true
+    } else {
+      //Move on to the next partition
+      //Queue new prefetch of a partition
+      enqueueDataFetch()
+      if (pendingFetchesQueue.isEmpty) {
+        //No more partitions
+        currentIterator = Iterator.empty
+        false
+      } else {
+        val future = pendingFetchesQueue.dequeue()
+        currentIterator = Await.result(future, timeOut).flatMap(x => x).iterator
+        //Next partition might be empty so check again.
+        this.hasNext
+      }
+    }
+  }
+  def next() = {
+    hasNext
+    currentIterator.next()
+  }
+
+  def enqueueDataFetch() ={
+    if (partitionsBatches.hasNext) {
+      pendingFetchesQueue.enqueue(fetchData(partitionsBatches.next(), rdd))
+    }
+  }
+}
+
+object RDDiterator {
+  private def fetchData[T: ClassManifest](partitionIndexes: Seq[Int],
+                                          rdd: RDD[T]): FutureAction[Seq[Seq[T]]] = {
+    val results = new ArrayBuffer[Seq[T]]()
+    rdd.context.submitJob[T, Array[T], Seq[Seq[T]]](rdd,
+      x => x.toArray,
+      partitionIndexes,
+      (inx: Int, res: Array[T]) => results.append(res),
+      results.toSeq)
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/RDDiterable.scala b/core/src/main/scala/org/apache/spark/util/RDDiterable.scala
deleted file mode 100644
index a68c31f81ccd2..0000000000000
--- a/core/src/main/scala/org/apache/spark/util/RDDiterable.scala
+++ /dev/null
@@ -1,60 +0,0 @@
-package org.apache.spark.util
-
-import scala.collection.immutable.Queue
-import scala.concurrent.{Await, Future}
-import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.duration.Duration
-import scala.annotation.tailrec
-import scala.reflect.ClassTag
-import org.apache.spark.rdd.RDD
-
-/**Iterable whose iterator iterates over all elements of an RDD without fetching all partitions to the driver process
- *
- * @param rdd RDD to iterate
- * @param prefetchPartitions The number of partitions to prefetch
- * @param timeOut How long to wait for each partition before failing.
- * @tparam T
- */
-class RDDiterable[T: ClassTag](rdd: RDD[T], prefetchPartitions: Int, timeOut: Duration) extends Serializable with Iterable[T] {
-
-  def iterator = new Iterator[T] {
-    var partitions = Range(0, rdd.partitions.size)
-    var pendingFetches = Queue.empty.enqueue(partitions.take(prefetchPartitions).map(par => fetchData(par)))
-    partitions = partitions.drop(prefetchPartitions)
-    var currentIterator: Iterator[T] = Iterator.empty
-    @tailrec
-    def hasNext() = {
-      if (currentIterator.hasNext) {
-        true
-      } else {
-        pendingFetches = partitions.headOption.map {
-          partitionNo =>
-            pendingFetches.enqueue(fetchData(partitionNo))
-        }.getOrElse(pendingFetches)
-        partitions = partitions.drop(1)
-
-        if (pendingFetches.isEmpty) {
-          currentIterator = Iterator.empty
-          false
-        } else {
-          val (future, pendingFetchesN) = pendingFetches.dequeue
-          pendingFetches = pendingFetchesN
-          currentIterator = Await.result(future, timeOut).iterator
-          this.hasNext()
-        }
-      }
-    }
-    def next() = {
-      hasNext()
-      currentIterator.next()
-    }
-  }
-  private def fetchData(partitionIndex: Int): Future[Seq[T]] = {
-    val results = new ArrayBuffer[T]()
-    rdd.context.submitJob[T, Array[T], Seq[T]](rdd,
-      x => x.toArray,
-      List(partitionIndex),
-      (inx: Int, res: Array[T]) => results.appendAll(res),
-      results.toSeq)
-  }
-}
\ No newline at end of file
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 6c8994b1b0ced..db6690c242189 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -382,23 +382,33 @@ class RDDSuite extends FunSuite with SharedSparkContext {
   test("toIterable") {
     var nums = sc.makeRDD(Range(1, 1000), 100)
-    assert(nums.toIterable(prefetchPartitions = 10).size === 999)
-    assert(nums.toIterable().toArray === (1 to 999).toArray)
+    assert(nums.toIterator(prefetchPartitions = 10).size === 999)
+    assert(nums.toIterator().toArray === (1 to 999).toArray)
 
     nums = sc.makeRDD(Range(1000, 1, -1), 100)
-    assert(nums.toIterable(prefetchPartitions = 10).size === 999)
-    assert(nums.toIterable(prefetchPartitions = 10).toArray === Range(1000, 1, -1).toArray)
+    assert(nums.toIterator(prefetchPartitions = 10).size === 999)
+    assert(nums.toIterator(prefetchPartitions = 10).toArray === Range(1000, 1, -1).toArray)
 
     nums = sc.makeRDD(Range(1, 100), 1000)
-    assert(nums.toIterable(prefetchPartitions = 10).size === 99)
-    assert(nums.toIterable(prefetchPartitions = 10).toArray === Range(1, 100).toArray)
+    assert(nums.toIterator(prefetchPartitions = 10).size === 99)
+    assert(nums.toIterator(prefetchPartitions = 10).toArray === Range(1, 100).toArray)
 
     nums = sc.makeRDD(Range(1, 1000), 100)
-    assert(nums.toIterable(prefetchPartitions = -1).size === 999)
-    assert(nums.toIterable().toArray === (1 to 999).toArray)
-  }
+    assert(nums.toIterator(prefetchPartitions = -1).size === 999)
+    assert(nums.toIterator().toArray === (1 to 999).toArray)
+
+    nums = sc.makeRDD(Range(1, 1000), 100)
+    assert(nums.toIterator(prefetchPartitions = 3,partitionBatchSize = 10).size === 999)
+    assert(nums.toIterator().toArray === (1 to 999).toArray)
+    nums = sc.makeRDD(Range(1, 1000), 100)
+    assert(nums.toIterator(prefetchPartitions = -1,partitionBatchSize = 0).size === 999)
+    assert(nums.toIterator().toArray === (1 to 999).toArray)
+    nums = sc.makeRDD(Range(1, 1000), 100)
+    assert(nums.toIterator(prefetchPartitions = -1).size === 999)
+    assert(nums.toIterator().toArray === (1 to 999).toArray)
+  }
 
 
 
   test("take") {
     var nums = sc.makeRDD(Range(1, 1000), 1)
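
A minimal usage sketch of the toIterator API added by this series (not part of the patch itself; it assumes a live SparkContext named `sc`, and the data sizes and tuning values are illustrative only):

    import scala.concurrent.duration.Duration
    import java.util.concurrent.TimeUnit

    val rdd = sc.makeRDD(1 to 1000000, 100)
    // Consume the RDD on the driver one batch of partitions at a time,
    // rather than collect()-ing every partition into memory at once.
    val it: Iterator[Int] = rdd.toIterator(
      prefetchPartitions = 4,   // keep up to 4 partitions requested ahead of the consumer
      partitionBatchSize = 2,   // each submitted job fetches 2 partitions
      timeOut = Duration(60, TimeUnit.SECONDS))
    it.take(10).foreach(println)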