diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index b529754638908..dde38f0f74c81 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -45,6 +45,8 @@ import org.apache.spark.util.{Utils, BoundedPriorityQueue, SerializableHyperLogL
 import org.apache.spark.SparkContext._
 import org.apache.spark._
 
+import scala.concurrent.duration.Duration
+import java.util.concurrent.TimeUnit
 
 /**
  * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -621,6 +623,19 @@ abstract class RDD[T: ClassTag](
     filter(f.isDefinedAt).map(f)
   }
 
+  /**
+   * Return an iterator that lazily fetches the partitions of this RDD.
+   * @param prefetchPartitions How many partitions to prefetch. A larger value increases
+   *                           parallelism but also increases driver memory requirements.
+   * @param partitionBatchSize How many partitions to fetch per job
+   * @param timeOut How long to wait for each partition fetch
+   * @return Iterator over every element in this RDD
+   */
+  def toIterator(prefetchPartitions: Int = 1, partitionBatchSize: Int = 10,
+      timeOut: Duration = Duration(30, TimeUnit.SECONDS)): Iterator[T] = {
+    new RDDiterator[T](this, prefetchPartitions, partitionBatchSize, timeOut)
+  }
+
   /**
    * Return an RDD with the elements from `this` that are not in `other`.
    *
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDiterator.scala b/core/src/main/scala/org/apache/spark/rdd/RDDiterator.scala
new file mode 100644
index 0000000000000..1b423a1d7e387
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDiterator.scala
@@ -0,0 +1,77 @@
+package org.apache.spark.rdd
+
+import scala.concurrent.{Await, Future}
+import scala.concurrent.duration.Duration
+import scala.annotation.tailrec
+import scala.collection.mutable
+import scala.reflect.ClassTag
+import org.apache.spark.rdd.RDDiterator._
+import org.apache.spark.FutureAction
+
+/**
+ * Iterator over all elements of an RDD that fetches partitions to the driver process in
+ * batches rather than all at once.
+ *
+ * @param rdd RDD to iterate over
+ * @param prefetchPartitions The number of partitions to prefetch. If < 1, no partitions are
+ *                           prefetched. Prefetching happens in whole batches, so the count is
+ *                           rounded down to a multiple of partitionBatchSize.
+ * @param partitionBatchSize How many partitions to fetch per job
+ * @param timeOut How long to wait for each partition before failing.
+ */
+class RDDiterator[T: ClassTag](rdd: RDD[T], prefetchPartitions: Int, partitionBatchSize: Int,
+    timeOut: Duration)
+  extends Iterator[T] {
+
+  val batchSize = math.max(1, partitionBatchSize)
+  var partitionsBatches: Iterator[Seq[Int]] = Range(0, rdd.partitions.size).grouped(batchSize)
+  var pendingFetchesQueue = mutable.Queue.empty[Future[Seq[Seq[T]]]]
+  // Enqueue the initial prefetch jobs, one per batch of partitions
+  (0 until math.max(0, prefetchPartitions / batchSize)).foreach(_ => enqueueDataFetch())
+
+  var currentIterator: Iterator[T] = Iterator.empty
+  @tailrec
+  final def hasNext: Boolean = {
+    if (currentIterator.hasNext) {
+      // Still values left in the current partition
+      true
+    } else {
+      // Move on to the next partition:
+      // queue a new prefetch of a partition
+      enqueueDataFetch()
+      if (pendingFetchesQueue.isEmpty) {
+        // No more partitions
+        currentIterator = Iterator.empty
+        false
+      } else {
+        val future = pendingFetchesQueue.dequeue()
+        currentIterator = Await.result(future, timeOut).flatten.iterator
+        // The next partition might be empty, so check again.
+        this.hasNext
+      }
+    }
+  }
+  def next(): T = {
+    hasNext
+    currentIterator.next()
+  }
+
+  def enqueueDataFetch() = {
+    if (partitionsBatches.hasNext) {
+      pendingFetchesQueue.enqueue(fetchData(partitionsBatches.next(), rdd))
+    }
+  }
+}
+
+object RDDiterator {
+  private def fetchData[T: ClassTag](partitionIndexes: Seq[Int],
+      rdd: RDD[T]): FutureAction[Seq[Seq[T]]] = {
+    // Store results by partition index so out-of-order task completion preserves ordering
+    val results = new Array[Seq[T]](partitionIndexes.size)
+    rdd.context.submitJob[T, Array[T], Seq[Seq[T]]](rdd,
+      x => x.toArray,
+      partitionIndexes,
+      (index: Int, res: Array[T]) => results(index) = res,
+      results.toSeq)
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index ac9df34afe6ee..db6690c242189 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -380,6 +380,32 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
   }
 
+  test("toIterator") {
+    var nums = sc.makeRDD(Range(1, 1000), 100)
+    assert(nums.toIterator(prefetchPartitions = 10).size === 999)
+    assert(nums.toIterator().toArray === (1 to 999).toArray)
+
+    nums = sc.makeRDD(Range(1000, 1, -1), 100)
+    assert(nums.toIterator(prefetchPartitions = 10).size === 999)
+    assert(nums.toIterator(prefetchPartitions = 10).toArray === Range(1000, 1, -1).toArray)
+
+    nums = sc.makeRDD(Range(1, 100), 1000)
+    assert(nums.toIterator(prefetchPartitions = 10).size === 99)
+    assert(nums.toIterator(prefetchPartitions = 10).toArray === Range(1, 100).toArray)
+
+    nums = sc.makeRDD(Range(1, 1000), 100)
+    assert(nums.toIterator(prefetchPartitions = -1).size === 999)
+    assert(nums.toIterator().toArray === (1 to 999).toArray)
+
+    nums = sc.makeRDD(Range(1, 1000), 100)
+    assert(nums.toIterator(prefetchPartitions = 3, partitionBatchSize = 10).size === 999)
+    assert(nums.toIterator().toArray === (1 to 999).toArray)
+
+    nums = sc.makeRDD(Range(1, 1000), 100)
+    assert(nums.toIterator(prefetchPartitions = -1, partitionBatchSize = 0).size === 999)
+    assert(nums.toIterator().toArray === (1 to 999).toArray)
+  }
+
   test("take") {
     var nums = sc.makeRDD(Range(1, 1000), 1)
     assert(nums.take(0).size === 0)
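
Usage sketch (illustrative only, not part of the patch; assumes an existing SparkContext named sc). Only the current batch plus any prefetched batches of partitions are resident in driver memory at a time:

    // Stream a large RDD through the driver without collecting all partitions.
    val bigRdd = sc.parallelize(1 to 1000000, 100)
    // Prefetch up to 20 partitions (2 batches of 10) ahead of the consumer.
    val iter = bigRdd.toIterator(prefetchPartitions = 20, partitionBatchSize = 10)
    iter.take(5).foreach(println) // only the first few partition batches are ever fetched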