WIP: added PythonTestInputStream
giwa committed Aug 18, 2014
1 parent 1fd12ae commit c05922c
Showing 5 changed files with 34 additions and 10 deletions.
14 changes: 4 additions & 10 deletions examples/src/main/python/streaming/test_oprations.py
@@ -6,20 +6,14 @@
from pyspark.streaming.duration import *

if __name__ == "__main__":
    if len(sys.argv) != 3:
        print >> sys.stderr, "Usage: wordcount <hostname> <port>"
        exit(-1)
    conf = SparkConf()
    conf.setAppName("PythonStreamingNetworkWordCount")
    ssc = StreamingContext(conf=conf, duration=Seconds(1))

    lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2]))
    words = lines.flatMap(lambda line: line.split(" "))
    # ssc.checkpoint("checkpoint")
    mapped_words = words.map(lambda word: (word, 1))
    count = mapped_words.reduceByKey(add)
    test_input = ssc._testInputStream([1,1,1,1])
    mapped = test_input.map(lambda x: (x, 1))
    mapped.pyprint()

    count.pyprint()
    ssc.start()
    ssc.awaitTermination()
    # ssc.awaitTermination()
    # ssc.stop()
25 changes: 25 additions & 0 deletions python/pyspark/streaming/context.py
@@ -17,6 +17,7 @@

import sys
from signal import signal, SIGTERM, SIGINT
from tempfile import NamedTemporaryFile

from pyspark.conf import SparkConf
from pyspark.files import SparkFiles
@@ -138,3 +139,27 @@ def checkpoint(self, directory):
        """
        Set the directory used for periodic checkpointing of DStream operations.
        """
        self._jssc.checkpoint(directory)

    def _testInputStream(self, test_input, numSlices=None):
        """
        Generate an input stream from a fixed list of test data (test helper).
        """
        numSlices = numSlices or self._sc.defaultParallelism
        # Calling the Java parallelize() method with an ArrayList is too slow,
        # because it sends O(n) Py4J commands. As an alternative, serialized
        # objects are written to a file and loaded through textFile().
        tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
        # Make sure we distribute data evenly if it's smaller than self.batchSize
        if "__len__" not in dir(test_input):
            test_input = list(test_input)  # Make it a list so we can compute its length
        batchSize = min(len(test_input) // numSlices, self._sc._batchSize)
        # BatchedSerializer and UTF8Deserializer are assumed to be imported from
        # pyspark.serializers elsewhere in this module.
        if batchSize > 1:
            serializer = BatchedSerializer(self._sc._unbatched_serializer,
                                           batchSize)
        else:
            serializer = self._sc._unbatched_serializer
        serializer.dump_stream(test_input, tempFile)
        tempFile.close()
        print tempFile.name  # WIP: debug print of the serialized input file
        jinput_stream = self._jvm.PythonTestInputStream(self._jssc,
                                                        tempFile.name,
                                                        numSlices).asJavaDStream()
        return DStream(jinput_stream, self, UTF8Deserializer())
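
For reference, a minimal usage sketch of the new helper, assuming a StreamingContext named ssc; this mirrors the test_oprations.py example in this commit, with illustrative data values:

    test_input = ssc._testInputStream([1, 2, 3, 4])  # DStream backed by the JVM-side PythonTestInputStream
    pairs = test_input.map(lambda x: (x, 1))         # pair each test element with a count of 1
    pairs.pyprint()                                  # print each batch as it is produced
    ssc.start()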
1 change: 1 addition & 0 deletions python/pyspark/streaming/dstream.py
@@ -141,6 +141,7 @@ def _mergeCombiners(iterator):
                    combiners[k] = v
                else:
                    combiners[k] = mergeCombiners(combiners[k], v)
            return combiners.iteritems()

        return shuffled._mapPartitions(_mergeCombiners)
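
The single added line matters because the nested _mergeCombiners otherwise falls off the end and returns None, so the downstream _mapPartitions stage would see no records. A small standalone sketch of the same merge step (hypothetical function name, Python 2 as used in this codebase):

    def merge_partition(iterator, mergeCombiners):
        # Merge values that share a key within one partition.
        combiners = {}
        for (k, v) in iterator:
            if k not in combiners:
                combiners[k] = v
            else:
                combiners[k] = mergeCombiners(combiners[k], v)
        return combiners.iteritems()  # the added return: emit merged (key, combiner) pairs

    # Example: merging partial counts with operator.add.
    from operator import add
    print sorted(merge_partition(iter([("a", 1), ("b", 2), ("a", 3)]), add))
    # [('a', 4), ('b', 2)]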

@@ -546,6 +546,9 @@ class JavaStreamingContext(val ssc: StreamingContext) {
 * JavaStreamingContext object contains a number of utility functions.
 */
object JavaStreamingContext {
  // Implicit conversions in both directions between StreamingContext and its
  // Java wrapper, likely so the new Python-related stream classes can accept
  // either form of the context.
  implicit def fromStreamingContext(ssc: StreamingContext): JavaStreamingContext = new JavaStreamingContext(ssc)

  implicit def toStreamingContext(jssc: JavaStreamingContext): StreamingContext = jssc.ssc

  /**
   * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
@@ -23,6 +23,7 @@ import scala.reflect.ClassTag

import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.api.java._
import org.apache.spark.api.python._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.streaming.{StreamingContext, Duration, Time}
