Skip to content

Commit

Permalink
Merge pull request #6 from markhamstra/streamingIterable
Browse files Browse the repository at this point in the history
SPY-287 updated streaming iterable
  • Loading branch information
jhartlaub committed Feb 21, 2014
2 parents 397f05f + ede65ec commit 12280b5
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 76 deletions.
16 changes: 8 additions & 8 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import org.apache.spark.partial.CountEvaluator
import org.apache.spark.partial.GroupedCountEvaluator
import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.{RDDiterable, Utils, BoundedPriorityQueue}
import org.apache.spark.util.{Utils, BoundedPriorityQueue}

import org.apache.spark.SparkContext._
import org.apache.spark._
Expand Down Expand Up @@ -576,8 +576,6 @@ abstract class RDD[T: ClassManifest](
sc.runJob(this, (iter: Iterator[T]) => f(iter))
}



/**
* Return an array that contains all of the elements in this RDD.
*/
Expand All @@ -599,14 +597,16 @@ abstract class RDD[T: ClassManifest](
}

/**
* Return iterable that lazily fetches partitions
* @param prefetchPartitions How many partitions to prefetch. Larger value increases parallelism but also increases
* driver memory requirement
* Return iterator that lazily fetches partitions
* @param prefetchPartitions How many partitions to prefetch. Larger value increases parallelism
* but also increases driver memory requirement.
* @param partitionBatchSize How many partitions fetch per job
* @param timeOut how long to wait for each partition fetch
* @return Iterable of every element in this RDD
*/
def toIterable(prefetchPartitions: Int = 1, timeOut: Duration = Duration(30, TimeUnit.SECONDS)) = {
new RDDiterable[T](this, prefetchPartitions, timeOut)
def toIterator(prefetchPartitions: Int = 1, partitionBatchSize: Int = 10,
timeOut: Duration = Duration(30, TimeUnit.SECONDS)):Iterator[T] = {
new RDDiterator[T](this, prefetchPartitions,partitionBatchSize, timeOut)
}

/**
Expand Down
76 changes: 76 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/RDDiterator.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package org.apache.spark.rdd

import scala.concurrent.{Await, Future}
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration.Duration
import scala.annotation.tailrec
import scala.collection.mutable
import org.apache.spark.rdd.RDDiterator._
import org.apache.spark.FutureAction

/**
* Iterable whose iterator iterates over all elements of an RDD without fetching all partitions
* to the driver process
*
* @param rdd RDD to iterate
* @param prefetchPartitions The number of partitions to prefetch.
* If <1 will not prefetch.
* partitions prefetched = min(prefetchPartitions, partitionBatchSize)
* @param partitionBatchSize How many partitions to fetch per job
* @param timeOut How long to wait for each partition before failing.
*/
class RDDiterator[T: ClassManifest](rdd: RDD[T], prefetchPartitions: Int, partitionBatchSize: Int,
timeOut: Duration)
extends Iterator[T] {

val batchSize = math.max(1,partitionBatchSize)
var partitionsBatches: Iterator[Seq[Int]] = Range(0, rdd.partitions.size).grouped(batchSize)
var pendingFetchesQueue = mutable.Queue.empty[Future[Seq[Seq[T]]]]
//add prefetchPartitions prefetch
0.until(math.max(0, prefetchPartitions / batchSize)).foreach(x=>enqueueDataFetch())

var currentIterator: Iterator[T] = Iterator.empty
@tailrec
final def hasNext = {
if (currentIterator.hasNext) {
//Still values in the current partition
true
} else {
//Move on to the next partition
//Queue new prefetch of a partition
enqueueDataFetch()
if (pendingFetchesQueue.isEmpty) {
//No more partitions
currentIterator = Iterator.empty
false
} else {
val future = pendingFetchesQueue.dequeue()
currentIterator = Await.result(future, timeOut).flatMap(x => x).iterator
//Next partition might be empty so check again.
this.hasNext
}
}
}
def next() = {
hasNext
currentIterator.next()
}

def enqueueDataFetch() ={
if (partitionsBatches.hasNext) {
pendingFetchesQueue.enqueue(fetchData(partitionsBatches.next(), rdd))
}
}
}

object RDDiterator {
private def fetchData[T: ClassManifest](partitionIndexes: Seq[Int],
rdd: RDD[T]): FutureAction[Seq[Seq[T]]] = {
val results = new ArrayBuffer[Seq[T]]()
rdd.context.submitJob[T, Array[T], Seq[Seq[T]]](rdd,
x => x.toArray,
partitionIndexes,
(inx: Int, res: Array[T]) => results.append(res),
results.toSeq)
}
}
59 changes: 0 additions & 59 deletions core/src/main/scala/org/apache/spark/util/RDDiterable.scala

This file was deleted.

28 changes: 19 additions & 9 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -342,23 +342,33 @@ class RDDSuite extends FunSuite with SharedSparkContext {

test("toIterable") {
var nums = sc.makeRDD(Range(1, 1000), 100)
assert(nums.toIterable(prefetchPartitions = 10).size === 999)
assert(nums.toIterable().toArray === (1 to 999).toArray)
assert(nums.toIterator(prefetchPartitions = 10).size === 999)
assert(nums.toIterator().toArray === (1 to 999).toArray)

nums = sc.makeRDD(Range(1000, 1, -1), 100)
assert(nums.toIterable(prefetchPartitions = 10).size === 999)
assert(nums.toIterable(prefetchPartitions = 10).toArray === Range(1000, 1, -1).toArray)
assert(nums.toIterator(prefetchPartitions = 10).size === 999)
assert(nums.toIterator(prefetchPartitions = 10).toArray === Range(1000, 1, -1).toArray)

nums = sc.makeRDD(Range(1, 100), 1000)
assert(nums.toIterable(prefetchPartitions = 10).size === 99)
assert(nums.toIterable(prefetchPartitions = 10).toArray === Range(1, 100).toArray)
assert(nums.toIterator(prefetchPartitions = 10).size === 99)
assert(nums.toIterator(prefetchPartitions = 10).toArray === Range(1, 100).toArray)

nums = sc.makeRDD(Range(1, 1000), 100)
assert(nums.toIterable(prefetchPartitions = -1).size === 999)
assert(nums.toIterable().toArray === (1 to 999).toArray)
}
assert(nums.toIterator(prefetchPartitions = -1).size === 999)
assert(nums.toIterator().toArray === (1 to 999).toArray)

nums = sc.makeRDD(Range(1, 1000), 100)
assert(nums.toIterator(prefetchPartitions = 3,partitionBatchSize = 10).size === 999)
assert(nums.toIterator().toArray === (1 to 999).toArray)

nums = sc.makeRDD(Range(1, 1000), 100)
assert(nums.toIterator(prefetchPartitions = -1,partitionBatchSize = 0).size === 999)
assert(nums.toIterator().toArray === (1 to 999).toArray)

nums = sc.makeRDD(Range(1, 1000), 100)
assert(nums.toIterator(prefetchPartitions = -1).size === 999)
assert(nums.toIterator().toArray === (1 to 999).toArray)
}

test("take") {
var nums = sc.makeRDD(Range(1, 1000), 1)
Expand Down

0 comments on commit 12280b5

Please sign in to comment.