diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index dab36cc3a9c24..95cb76a15be07 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -469,13 +469,18 @@ def setUp(self):
         self.batachDuration = Milliseconds(500)
         self.sparkHome = "SomeDir"
         self.envPair = {"key": "value"}
+        self.ssc = None
+        self.sc = None
 
     def tearDown(self):
         # Do not call pyspark.streaming.context.StreamingContext.stop directly because
         # we do not wait to shut down the py4j client.
         # We need to change this to simply call StreamingContext.stop.
-        self.ssc._jssc.stop()
-        self.ssc._sc.stop()
+        #self.ssc._jssc.stop()
+        if self.ssc is not None:
+            self.ssc.stop()
+        if self.sc is not None:
+            self.sc.stop()
         # Why does it take a long time to terminate StreamingContext and SparkContext?
         # Should we change the sleep time if this depends on machine spec?
         time.sleep(1)
@@ -486,48 +491,67 @@ def tearDownClass(cls):
         SparkContext._gateway._shutdown_callback_server()
 
     def test_from_no_conf_constructor(self):
-        ssc = StreamingContext(master=self.master, appName=self.appName, duration=batachDuration)
+        self.ssc = StreamingContext(master=self.master, appName=self.appName,
+                                    duration=self.batachDuration)
         # Alternative way to get the master: ssc.sparkContext.master
         # I try to keep the code close to the Scala API.
-        self.assertEqual(ssc.sparkContext._conf.get("spark.master"), self.master)
-        self.assertEqual(ssc.sparkContext._conf.get("spark.app.name"), self.appName)
+        self.assertEqual(self.ssc.sparkContext._conf.get("spark.master"), self.master)
+        self.assertEqual(self.ssc.sparkContext._conf.get("spark.app.name"), self.appName)
 
     def test_from_no_conf_plus_spark_home(self):
-        ssc = StreamingContext(master=self.master, appName=self.appName,
-                               sparkHome=self.sparkHome, duration=batachDuration)
-        self.assertEqual(ssc.sparkContext._conf.get("spark.home"), self.sparkHome)
+        self.ssc = StreamingContext(master=self.master, appName=self.appName,
+                                    sparkHome=self.sparkHome, duration=self.batachDuration)
+        self.assertEqual(self.ssc.sparkContext._conf.get("spark.home"), self.sparkHome)
+
+    def test_from_no_conf_plus_spark_home_plus_env(self):
+        self.ssc = StreamingContext(master=self.master, appName=self.appName,
+                                    sparkHome=self.sparkHome, environment=self.envPair,
+                                    duration=self.batachDuration)
+        self.assertEqual(self.ssc.sparkContext._conf.get("spark.executorEnv.key"), self.envPair["key"])
 
     def test_from_existing_spark_context(self):
-        sc = SparkContext(master=self.master, appName=self.appName)
-        ssc = StreamingContext(sparkContext=sc)
+        self.sc = SparkContext(master=self.master, appName=self.appName)
+        self.ssc = StreamingContext(sparkContext=self.sc, duration=self.batachDuration)
 
     def test_existing_spark_context_with_settings(self):
         conf = SparkConf()
         conf.set("spark.cleaner.ttl", "10")
-        sc = SparkContext(master=self.master, appName=self.appName, conf=conf)
-        ssc = StreamingContext(context=sc)
-        self.assertEqual(int(ssc.sparkContext._conf.get("spark.cleaner.ttl")), 10)
-
-    def _addInputStream(self, s):
-        test_inputs = map(lambda x: range(1, x), range(5, 101))
-        # make sure numSlice is 2 due to deserializer proglem in pyspark
-        s._testInputStream(test_inputs, 2)
-
-    def test_from_no_conf_plus_spark_home_plus_env(self):
-        pass
+        self.sc = SparkContext(master=self.master, appName=self.appName, conf=conf)
+        self.ssc = StreamingContext(sparkContext=self.sc, duration=self.batachDuration)
+        self.assertEqual(int(self.ssc.sparkContext._conf.get("spark.cleaner.ttl")), 10)
 
     def test_from_conf_with_settings(self):
-        pass
+        conf = SparkConf()
+        conf.set("spark.cleaner.ttl", "10")
+        conf.setMaster(self.master)
+        conf.setAppName(self.appName)
+        self.ssc = StreamingContext(conf=conf, duration=self.batachDuration)
+        self.assertEqual(int(self.ssc.sparkContext._conf.get("spark.cleaner.ttl")), 10)
 
     def test_stop_only_streaming_context(self):
-        pass
-
-    def test_await_termination(self):
-        pass
-
-
-
-
+        self.sc = SparkContext(master=self.master, appName=self.appName)
+        self.ssc = StreamingContext(sparkContext=self.sc, duration=self.batachDuration)
+        self._addInputStream(self.ssc)
+        self.ssc.start()
+        self.ssc.stop(False)
+        self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5)
+
+    def test_stop_multiple_times(self):
+        self.ssc = StreamingContext(master=self.master, appName=self.appName,
+                                    duration=self.batachDuration)
+        self._addInputStream(self.ssc)
+        self.ssc.start()
+        self.ssc.stop()
+        self.ssc.stop()
+
+    def _addInputStream(self, s):
+        # Make sure the length of each input is over 3 and that
+        # numSlice is 2, due to a deserializer problem in pyspark.streaming.
+        test_inputs = map(lambda x: range(1, x), range(5, 101))
+        test_stream = s._testInputStream(test_inputs, 2)
+        # Register a fake output operation so the stream gets computed.
+        result = list()
+        test_stream._test_output(result)
 
 if __name__ == "__main__":
     unittest.main()
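
Worth noting for reviewers: the tearDown change in the first hunk works because setUp now initializes self.ssc and self.sc to None, so a test that fails before constructing a context no longer makes tearDown raise an AttributeError and poison the tests that follow (PySpark allows only one active SparkContext per gateway). Below is a minimal, self-contained sketch of the same guard pattern. It is illustrative only, not part of the patch, and it assumes this branch's import paths (StreamingContext in pyspark.streaming.context, Milliseconds in pyspark.streaming.duration) and the constructor signature used in the tests above.

import time
import unittest

# Import paths below are assumptions based on this branch's layout.
from pyspark.streaming.context import StreamingContext
from pyspark.streaming.duration import Milliseconds


class GuardedTearDownExample(unittest.TestCase):
    # Hypothetical standalone example of the initialize-to-None /
    # stop-if-created pattern that the patch adopts.

    def setUp(self):
        # None means "no context created yet"; tearDown checks for this.
        self.ssc = None

    def tearDown(self):
        # Stop only what the test actually created; a test that raised
        # before constructing its context leaves self.ssc as None.
        if self.ssc is not None:
            self.ssc.stop()
        # Mirror the patched tearDown: give py4j and the JVM a moment
        # to release the context before the next test starts.
        time.sleep(1)

    def test_app_name_is_set(self):
        self.ssc = StreamingContext(master="local[2]", appName="guard-demo",
                                    duration=Milliseconds(500))
        self.assertEqual(self.ssc.sparkContext._conf.get("spark.app.name"),
                         "guard-demo")


if __name__ == "__main__":
    unittest.main()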