Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Mar 9, 2015
1 parent 24c92a4 commit d730286
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,8 @@ private[spark] object PythonRDD extends Logging {
* This method will serve an iterator of an array that contains all elements in the RDD
* (effectively a collect()), but allows you to run on a certain subset of partitions,
* or to enable local execution.
*
* @return the port number of a local socket which serves the data collected from this job.
*/
def runJob(
sc: SparkContext,
Expand All @@ -356,14 +358,17 @@ private[spark] object PythonRDD extends Logging {
val allPartitions: Array[UnrolledPartition] =
sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
serveIterator(flattenedPartition.iterator)
serveIterator(flattenedPartition.iterator,
s"serve RDD ${rdd.id} with partitions ${partitions.mkString(",")}")
}

/**
* A helper function to collect an RDD as an iterator, then serve it via socket
* A helper function to collect an RDD as an iterator, then serve it via socket.
*
* @return the port number of a local socket which serves the data collected from this job.
*/
def collectAndServe[T](rdd: RDD[T]): Int = {
serveIterator(rdd.collect().iterator)
serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
}

def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
Expand Down Expand Up @@ -583,12 +588,24 @@ private[spark] object PythonRDD extends Logging {
dataOut.write(bytes)
}

private def serveIterator[T](items: Iterator[T]): Int = {
/**
 * Create a socket server and a background thread to serve the data in `items`.
*
 * The socket server accepts only one connection, and closes if no connection
 * is made within 3 seconds.
*
* Once a connection comes in, it tries to serialize all the data in `items`
* and send them into this connection.
*
 * The thread will terminate after all the data are sent, or if any exception occurs.
*/
private def serveIterator[T](items: Iterator[T], threadName: String): Int = {
val serverSocket = new ServerSocket(0, 1)
serverSocket.setReuseAddress(true)
// Close the socket if no connection in 3 seconds
serverSocket.setSoTimeout(3000)

new Thread("serve iterator") {
new Thread(threadName) {
setDaemon(true)
override def run() {
try {
Expand All @@ -601,7 +618,7 @@ private[spark] object PythonRDD extends Logging {
}
} catch {
case NonFatal(e) =>
logError(s"Error while sending iterator: $e")
logError(s"Error while sending iterator", e)
} finally {
serverSocket.close()
}
Expand Down

0 comments on commit d730286

Please sign in to comment.