diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
index 98821c71c7aee..879be775ea9f9 100644
--- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
+++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
@@ -67,17 +67,11 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter
       }
     }
 
-    MQTTTestUtils.registerStreamingListener(ssc)
-
     ssc.start()
 
-    // wait for the receiver to start before publishing data, or we risk failing
-    // the test nondeterministically. See SPARK-4631
-    MQTTTestUtils.waitForReceiverToStart(ssc)
-
-    MQTTTestUtils.publishData(topic, sendMessage)
-
+    // Retry it because we don't know when the receiver will start.
     eventually(timeout(10000 milliseconds), interval(100 milliseconds)) {
+      MQTTTestUtils.publishData(topic, sendMessage)
       assert(sendMessage.equals(receiveMessage(0)))
     }
     ssc.stop()
diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala
index 6c85019ae0723..47cc9af497778 100644
--- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala
+++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.streaming.mqtt
 
 import java.net.{ServerSocket, URI}
-import java.util.concurrent.{CountDownLatch, TimeUnit}
 
 import scala.language.postfixOps
 
@@ -28,10 +27,6 @@ import org.apache.commons.lang3.RandomUtils
 import org.eclipse.paho.client.mqttv3._
 import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
 
-import org.apache.spark.streaming.StreamingContext
-import org.apache.spark.streaming.api.java.JavaStreamingContext
-import org.apache.spark.streaming.scheduler.StreamingListener
-import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted
 import org.apache.spark.util.Utils
 import org.apache.spark.{Logging, SparkConf}
 
@@ -47,8 +42,6 @@ private class MQTTTestUtils extends Logging {
   private var broker: BrokerService = _
   private var connector: TransportConnector = _
 
-  private var receiverStartedLatch = new CountDownLatch(1)
-
   def brokerUri: String = {
     s"$brokerHost:$brokerPort"
   }
@@ -73,7 +66,6 @@ private class MQTTTestUtils extends Logging {
       connector = null
     }
     Utils.deleteRecursively(persistenceDir)
-    receiverStartedLatch = null
   }
 
   private def findFreePort(): Int = {
@@ -114,38 +106,4 @@ private class MQTTTestUtils extends Logging {
     }
   }
 
-  /**
-   * Call this one before starting StreamingContext so that we won't miss the
-   * StreamingListenerReceiverStarted event.
-   */
-  def registerStreamingListener(jssc: JavaStreamingContext): Unit = {
-    registerStreamingListener(jssc.ssc)
-  }
-
-  /**
-   * Call this one before starting StreamingContext so that we won't miss the
-   * StreamingListenerReceiverStarted event.
-   */
-  def registerStreamingListener(ssc: StreamingContext): Unit = {
-    ssc.addStreamingListener(new StreamingListener {
-      override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) {
-        receiverStartedLatch.countDown()
-      }
-    })
-  }
-
-  /**
-   * Block until at least one receiver has started or timeout occurs.
-   */
-  def waitForReceiverToStart(jssc: JavaStreamingContext): Unit = {
-    waitForReceiverToStart(jssc.ssc)
-  }
-
-  /**
-   * Block until at least one receiver has started or timeout occurs.
-   */
-  def waitForReceiverToStart(ssc: StreamingContext): Unit = {
-    assert(
-      receiverStartedLatch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.")
-  }
 }
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 0da312b89b72f..2e37aa96c614f 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -931,14 +931,27 @@ def test_mqtt_stream(self):
         """Test the Python MQTT stream API."""
         sendData = "MQTT demo for spark streaming"
         topic = self._randomTopic()
-        self._MQTTTestUtils.registerStreamingListener(self.ssc._jssc)
         result = self._startContext(topic)
-        self._MQTTTestUtils.waitForReceiverToStart(self.ssc._jssc)
-        self._MQTTTestUtils.publishData(topic, sendData)
-        self.wait_for(result, 1)
-        # Because "publishData" sends duplicate messages, here we should use > 0
-        self.assertTrue(len(result) > 0)
-        self.assertEqual(sendData, result[0])
+
+        def retry():
+            self._MQTTTestUtils.publishData(topic, sendData)
+            # Because "publishData" sends duplicate messages, here we should use > 0
+            self.assertTrue(len(result) > 0)
+            self.assertEqual(sendData, result[0])
+
+        # Retry it because we don't know when the receiver will start.
+        self._retry_or_timeout(retry)
+
+    def _retry_or_timeout(self, test_func):
+        start_time = time.time()
+        while True:
+            try:
+                test_func()
+                break
+            except:
+                if time.time() - start_time > self.timeout:
+                    raise
+                time.sleep(0.01)
 
 
 def search_kafka_assembly_jar():
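
Note: the following is a minimal standalone sketch, not part of the patch above, of the retry-until-timeout idea the Scala test relies on through ScalaTest's Eventually trait; the object name and the simulated "receiver becomes ready" condition are hypothetical stand-ins.

import org.scalatest.concurrent.Eventually
import org.scalatest.time.{Millis, Seconds, Span}

// Sketch: eventually() re-runs the block every `interval` until it stops throwing,
// and rethrows the last failure once `timeout` elapses. Publishing inside the block
// is therefore safe: the publish is simply repeated until the receiver is up.
object RetrySketch extends Eventually {
  def main(args: Array[String]): Unit = {
    // Hypothetical stand-in for "the receiver has started": ready 300 ms from now.
    val readyAt = System.currentTimeMillis() + 300
    eventually(timeout(Span(10, Seconds)), interval(Span(100, Millis))) {
      // In the real test this is publishData(...) followed by the assertion on the result.
      assert(System.currentTimeMillis() >= readyAt, "receiver not ready yet")
    }
  }
}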