Skip to content

Commit

Permalink
added some StreamingContextTestSuite
Browse files Browse the repository at this point in the history
  • Loading branch information
giwa committed Sep 1, 2014
1 parent f7bc8f9 commit 150b94c
Showing 1 changed file with 52 additions and 28 deletions.
80 changes: 52 additions & 28 deletions python/pyspark/streaming/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,13 +469,18 @@ def setUp(self):
self.batachDuration = Milliseconds(500)
self.sparkHome = "SomeDir"
self.envPair = {"key": "value"}
self.ssc = None
self.sc = None

def tearDown(self):
# Do not call pyspark.streaming.context.StreamingContext.stop directly because
# we do not wait to shutdown py4j client.
# We need change this simply calll streamingConxt.Stop
self.ssc._jssc.stop()
self.ssc._sc.stop()
#self.ssc._jssc.stop()
if self.ssc is not None:
self.ssc.stop()
if self.sc is not None:
self.sc.stop()
# Why does it long time to terminate StremaingContext and SparkContext?
# Should we change the sleep time if this depends on machine spec?
time.sleep(1)
Expand All @@ -486,48 +491,67 @@ def tearDownClass(cls):
SparkContext._gateway._shutdown_callback_server()

def test_from_no_conf_constructor(self):
ssc = StreamingContext(master=self.master, appName=self.appName, duration=batachDuration)
self.ssc = StreamingContext(master=self.master, appName=self.appName,
duration=self.batachDuration)
# Alternative call master: ssc.sparkContext.master
# I try to make code close to Scala.
self.assertEqual(ssc.sparkContext._conf.get("spark.master"), self.master)
self.assertEqual(ssc.sparkContext._conf.get("spark.app.name"), self.appName)
self.assertEqual(self.ssc.sparkContext._conf.get("spark.master"), self.master)
self.assertEqual(self.ssc.sparkContext._conf.get("spark.app.name"), self.appName)

def test_from_no_conf_plus_spark_home(self):
ssc = StreamingContext(master=self.master, appName=self.appName,
sparkHome=self.sparkHome, duration=batachDuration)
self.assertEqual(ssc.sparkContext._conf.get("spark.home"), self.sparkHome)
self.ssc = StreamingContext(master=self.master, appName=self.appName,
sparkHome=self.sparkHome, duration=self.batachDuration)
self.assertEqual(self.ssc.sparkContext._conf.get("spark.home"), self.sparkHome)

def test_from_no_conf_plus_spark_home_plus_env(self):
self.ssc = StreamingContext(master=self.master, appName=self.appName,
sparkHome=self.sparkHome, environment=self.envPair,
duration=self.batachDuration)
self.assertEqual(self.ssc.sparkContext._conf.get("spark.executorEnv.key"), self.envPair["key"])

def test_from_existing_spark_context(self):
sc = SparkContext(master=self.master, appName=self.appName)
ssc = StreamingContext(sparkContext=sc)
self.sc = SparkContext(master=self.master, appName=self.appName)
self.ssc = StreamingContext(sparkContext=self.sc, duration=self.batachDuration)

def test_existing_spark_context_with_settings(self):
conf = SparkConf()
conf.set("spark.cleaner.ttl", "10")
sc = SparkContext(master=self.master, appName=self.appName, conf=conf)
ssc = StreamingContext(context=sc)
self.assertEqual(int(ssc.sparkContext._conf.get("spark.cleaner.ttl")), 10)

def _addInputStream(self, s):
test_inputs = map(lambda x: range(1, x), range(5, 101))
# make sure numSlice is 2 due to deserializer proglem in pyspark
s._testInputStream(test_inputs, 2)

def test_from_no_conf_plus_spark_home_plus_env(self):
pass
self.sc = SparkContext(master=self.master, appName=self.appName, conf=conf)
self.ssc = StreamingContext(sparkContext=self.sc, duration=self.batachDuration)
self.assertEqual(int(self.ssc.sparkContext._conf.get("spark.cleaner.ttl")), 10)

def test_from_conf_with_settings(self):
pass
conf = SparkConf()
conf.set("spark.cleaner.ttl", "10")
conf.setMaster(self.master)
conf.setAppName(self.appName)
self.ssc = StreamingContext(conf=conf, duration=self.batachDuration)
self.assertEqual(int(self.ssc.sparkContext._conf.get("spark.cleaner.ttl")), 10)

def test_stop_only_streaming_context(self):
pass

def test_await_termination(self):
pass


self.sc = SparkContext(master=self.master, appName=self.appName)
self.ssc = StreamingContext(sparkContext=self.sc, duration=self.batachDuration)
self._addInputStream(self.ssc)
self.ssc.start()
self.ssc.stop(False)
self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5)

def test_stop_multiple_times(self):
self.ssc = StreamingContext(master=self.master, appName=self.appName,
duration=self.batachDuration)
self._addInputStream(self.ssc)
self.ssc.start()
self.ssc.stop()
self.ssc.stop()

def _addInputStream(self, s):
# Make sure each length of input is over 3 and
# numSlice is 2 due to deserializer problem in pyspark.streaming
test_inputs = map(lambda x: range(1, x), range(5, 101))
test_stream = s._testInputStream(test_inputs, 2)
# Register fake output operation
result = list()
test_stream._test_output(result)

if __name__ == "__main__":
unittest.main()

0 comments on commit 150b94c

Please sign in to comment.