Commit

Implemented DStream.foreachRDD in the Python API using Py4J callback server
giwa authored and Ken Takagiwa committed Aug 1, 2014
2 parents cc2092b + 54e2e8c commit 28c6620
Showing 9 changed files with 104 additions and 82 deletions.
4 changes: 1 addition & 3 deletions examples/src/main/python/streaming/network_wordcount.py
@@ -17,9 +17,7 @@
fm_lines = lines.flatMap(lambda x: x.split(" "))
mapped_lines = fm_lines.map(lambda x: (x, 1))
reduced_lines = mapped_lines.reduceByKey(add)

fm_lines.pyprint()
mapped_lines.pyprint()

reduced_lines.pyprint()
ssc.start()
ssc.awaitTermination()
Binary file modified python/lib/py4j-0.8.1-src.zip
Binary file not shown.
2 changes: 1 addition & 1 deletion python/pyspark/java_gateway.py
@@ -76,7 +76,7 @@ def run(self):
EchoOutputThread(proc.stdout).start()

# Connect to the gateway
gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False)
gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False, start_callback_server=True)

# Import the classes used by PySpark
java_import(gateway.jvm, "org.apache.spark.SparkConf")
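The single modified line in java_gateway.py adds start_callback_server=True, which starts Py4J's callback server so that the JVM can invoke methods on Python objects that declare a matching Java interface. A minimal sketch of that pattern outside of Spark (the port value and the EchoRDDFunction class are illustrative, not part of this commit):

from py4j.java_gateway import JavaGateway, GatewayClient

gateway_port = 25333  # Py4J's default; in java_gateway.py the real port comes from the launched JVM

# start_callback_server=True is what lets the JVM call back into Python
gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False,
                      start_callback_server=True)

class EchoRDDFunction(object):
    def call(self, jrdd, time):
        # Invoked from the JVM through the Py4J callback server
        print "JVM called back at time %d" % time

    class Java:
        # Tells Py4J which Java interface this Python object implements
        implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']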
44 changes: 34 additions & 10 deletions python/pyspark/streaming/dstream.py
@@ -43,14 +43,6 @@ def print_(self):
#hack to call print function in DStream
getattr(self._jdstream, "print")()

def pyprint(self):
"""
Print the first ten elements of each RDD generated in this DStream. This is an output
operator, so this DStream will be registered as an output stream and there materialized.
"""
self._jdstream.pyprint()

def filter(self, f):
"""
Return DStream containing only the elements that satisfy predicate.
@@ -190,6 +182,38 @@ def getNumPartitions(self):
# TODO: remove hardcoding. RDD has NumPartitions but DStream does not have.
return 2

def foreachRDD(self, func):
"""
"""
from utils import RDDFunction
wrapped_func = RDDFunction(self.ctx, self._jrdd_deserializer, func)
self.ctx._jvm.PythonForeachDStream(self._jdstream.dstream(), wrapped_func)

def pyprint(self):
"""
Print the first ten elements of each RDD generated in this DStream. This is an output
operator, so this DStream will be registered as an output stream and there materialized.
"""
def takeAndPrint(rdd, time):
taken = rdd.take(11)
print "-------------------------------------------"
print "Time: %s" % (str(time))
print "-------------------------------------------"
for record in taken[:10]:
print record
if len(taken) > 10:
print "..."
print

self.foreachRDD(takeAndPrint)


#def transform(self, func):
# from utils import RDDFunction
# wrapped_func = RDDFunction(self.ctx, self._jrdd_deserializer, func)
# jdstream = self.ctx._jvm.PythonTransformedDStream(self._jdstream.dstream(), wrapped_func).toJavaDStream
# return DStream(jdstream, self._ssc, ...) ## DO NOT KNOW HOW

class PipelinedDStream(DStream):
def __init__(self, prev, func, preservesPartitioning=False):
@@ -209,7 +233,6 @@ def pipeline_func(split, iterator):
self._prev_jdstream = prev._prev_jdstream # maintain the pipeline
self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer
self.is_cached = False
self.is_checkpointed = False
self._ssc = prev._ssc
self.ctx = prev.ctx
self.prev = prev
@@ -246,4 +269,5 @@ def _jdstream(self):
return self._jdstream_val

def _is_pipelinable(self):
return not (self.is_cached or self.is_checkpointed)
return not (self.is_cached)

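A usage sketch of the new foreachRDD, assuming the same StreamingContext (ssc) and reduced_lines DStream as in network_wordcount.py above; the print_counts callback is illustrative, not part of this commit. The user function receives each batch's RDD together with its batch time:

def print_counts(rdd, time):
    print "Batch time: %s" % str(time)
    for (word, count) in rdd.collect():
        print "%s: %d" % (word, count)

reduced_lines.foreachRDD(print_counts)

ssc.start()
ssc.awaitTermination()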
22 changes: 22 additions & 0 deletions python/pyspark/streaming/utils.py
@@ -15,6 +15,28 @@
# limitations under the License.
#

from pyspark.rdd import RDD


class RDDFunction():
def __init__(self, ctx, jrdd_deserializer, func):
self.ctx = ctx
self.deserializer = jrdd_deserializer
self.func = func

def call(self, jrdd, time):
# Wrap JavaRDD into python's RDD class
rdd = RDD(jrdd, self.ctx, self.deserializer)
# Call user defined RDD function
self.func(rdd, time)

def __str__(self):
return "%s, %s" % (str(self.deserializer), str(self.func))

class Java:
implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']



def msDurationToString(ms):
"""
JavaDStreamLike.scala
@@ -54,14 +54,6 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
dstream.print()
}

/**
* Print the first ten elements of each PythonRDD generated in the PythonDStream. This is an output
* operator, so this PythonDStream will be registered as an output stream and there materialized.
* This function is for PythonAPI.
*/
//TODO move this function to PythonDStream
def pyprint() = dstream.pyprint()

/**
* Return a new DStream in which each RDD has a single element generated by counting each RDD
* of this DStream.
PythonDStream.scala
@@ -56,6 +56,10 @@ class PythonDStream[T: ClassTag](
}
}

def foreachRDD(foreachFunc: PythonRDDFunction) {
new PythonForeachDStream(this, context.sparkContext.clean(foreachFunc, false)).register()
}

val asJavaDStream = JavaDStream.fromDStream(this)
}

@@ -85,5 +89,39 @@ DStream[Array[Byte]](prev.ssc){
case None => None
}
}

val asJavaDStream = JavaDStream.fromDStream(this)
}

class PythonForeachDStream(
prev: DStream[Array[Byte]],
foreachFunction: PythonRDDFunction
) extends ForEachDStream[Array[Byte]](
prev,
(rdd: RDD[Array[Byte]], time: Time) => {
foreachFunction.call(rdd.toJavaRDD(), time.milliseconds)
}
) {

this.register()
}
/*
This does not work. Ignore this for now. -TD
class PythonTransformedDStream(
prev: DStream[Array[Byte]],
transformFunction: PythonRDDFunction
) extends DStream[Array[Byte]](prev.ssc) {
override def dependencies = List(prev)
override def slideDuration: Duration = prev.slideDuration
override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
prev.getOrCompute(validTime).map(rdd => {
transformFunction.call(rdd.toJavaRDD(), validTime.milliseconds).rdd
})
}
val asJavaDStream = JavaDStream.fromDStream(this)
}
*/
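On the Scala side above, PythonForeachDStream passes each batch to the registered PythonRDDFunction as a JavaRDD of serialized bytes plus the batch time in epoch milliseconds (time.milliseconds); RDDFunction in utils.py then re-wraps the JavaRDD as a Python RDD before invoking the user function. A sketch of a callback that makes use of the millisecond timestamp (log_batch and some_dstream are illustrative names only):

import datetime

def log_batch(rdd, time_ms):
    # time_ms is epoch milliseconds, as passed from time.milliseconds in Scala
    ts = datetime.datetime.fromtimestamp(time_ms / 1000.0)
    print "Batch at %s contained %d records" % (ts, rdd.count())

# some_dstream.foreachRDD(log_batch)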
PythonRDDFunction.java (new file)
@@ -0,0 +1,8 @@
package org.apache.spark.streaming.api.python;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.streaming.Time;

public interface PythonRDDFunction {
JavaRDD<byte[]> call(JavaRDD<byte[]> rdd, long time);
}
DStream.scala
@@ -623,66 +623,6 @@ abstract class DStream[T: ClassTag] (
new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register()
}

//TODO: move pyprint to PythonDStream and executed by py4j call back function
/**
* Print the first ten elements of each PythonRDD generated in this PythonDStream. This is an output
* operator, so this PythonDStream will be registered as an output stream and there materialized.
* Since serialized Python object is readable by Python, pyprint writes out binary data to
* temporary file and run python script to deserialized and print the first ten elements
*
* Currently call python script directly. We should avoid this
*/
private[streaming] def pyprint() {
def foreachFunc = (rdd: RDD[T], time: Time) => {
val iter = rdd.take(11).iterator

// Generate a temporary file
val prefix = "spark"
val suffix = ".tmp"
val tempFile = File.createTempFile(prefix, suffix)
val tempFileStream = new DataOutputStream(new FileOutputStream(tempFile.getAbsolutePath))
// Write out serialized python object to temporary file
PythonRDD.writeIteratorToStream(iter, tempFileStream)
tempFileStream.close()

// pythonExec should be passed from python. Move pyprint to PythonDStream
val pythonExec = new ProcessBuilder().environment().get("PYSPARK_PYTHON")

val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME")
// Call python script to deserialize and print result in stdout
val pb = new ProcessBuilder(pythonExec, sparkHome + "/python/pyspark/streaming/pyprint.py", tempFile.getAbsolutePath)
val workerEnv = pb.environment()

// envVars also should be pass from python
val pythonPath = sparkHome + "/python/" + File.pathSeparator + workerEnv.get("PYTHONPATH")
workerEnv.put("PYTHONPATH", pythonPath)
val worker = pb.start()
val is = worker.getInputStream()
val isr = new InputStreamReader(is)
val br = new BufferedReader(isr)

println ("-------------------------------------------")
println ("Time: " + time)
println ("-------------------------------------------")

// Print values which is from python std out
var line = ""
breakable {
while (true) {
line = br.readLine()
if (line == null) break()
println(line)
}
}
// Delete temporary file
tempFile.delete()
println()

}
new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register()
}


/**
* Return a new DStream in which each RDD contains all the elements in seen in a
* sliding window of time over this DStream. The new DStream generates RDDs with
