Skip to content

Commit

Permalink
Fix the flaky MQTT tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zsxwing committed Jul 31, 2015
1 parent 47278c5 commit 935615c
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,11 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter
}
}

MQTTTestUtils.registerStreamingListener(ssc)

ssc.start()

// wait for the receiver to start before publishing data, or we risk failing
// the test nondeterministically. See SPARK-4631
MQTTTestUtils.waitForReceiverToStart(ssc)

MQTTTestUtils.publishData(topic, sendMessage)

// Retry it because we don't know when the receiver will start.
eventually(timeout(10000 milliseconds), interval(100 milliseconds)) {
MQTTTestUtils.publishData(topic, sendMessage)
assert(sendMessage.equals(receiveMessage(0)))
}
ssc.stop()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.spark.streaming.mqtt

import java.net.{ServerSocket, URI}
import java.util.concurrent.{CountDownLatch, TimeUnit}

import scala.language.postfixOps

Expand All @@ -28,10 +27,6 @@ import org.apache.commons.lang3.RandomUtils
import org.eclipse.paho.client.mqttv3._
import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence

import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.api.java.JavaStreamingContext
import org.apache.spark.streaming.scheduler.StreamingListener
import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted
import org.apache.spark.util.Utils
import org.apache.spark.{Logging, SparkConf}

Expand All @@ -47,8 +42,6 @@ private class MQTTTestUtils extends Logging {
private var broker: BrokerService = _
private var connector: TransportConnector = _

private var receiverStartedLatch = new CountDownLatch(1)

/** Connection URI for the embedded broker, formed from the configured host and port. */
def brokerUri: String = s"$brokerHost:$brokerPort"
Expand All @@ -73,7 +66,6 @@ private class MQTTTestUtils extends Logging {
connector = null
}
Utils.deleteRecursively(persistenceDir)
receiverStartedLatch = null
}

private def findFreePort(): Int = {
Expand Down Expand Up @@ -114,38 +106,4 @@ private class MQTTTestUtils extends Logging {
}
}

/**
 * Java-friendly overload. Call this one before starting StreamingContext so that we won't
 * miss the StreamingListenerReceiverStarted event.
 *
 * @param jssc the Java streaming context whose underlying Scala context gets the listener
 */
def registerStreamingListener(jssc: JavaStreamingContext): Unit = {
  val underlying = jssc.ssc
  registerStreamingListener(underlying)
}

/**
 * Registers a listener that counts down `receiverStartedLatch` on the first
 * StreamingListenerReceiverStarted event. Call this one before starting the
 * StreamingContext so that we won't miss the event.
 *
 * @param ssc the streaming context to observe
 */
def registerStreamingListener(ssc: StreamingContext): Unit = {
  val startupListener = new StreamingListener {
    override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = {
      receiverStartedLatch.countDown()
    }
  }
  ssc.addStreamingListener(startupListener)
}

/**
 * Java-friendly overload. Blocks until at least one receiver has started or the
 * timeout occurs.
 *
 * @param jssc the Java streaming context whose underlying Scala context is awaited
 */
def waitForReceiverToStart(jssc: JavaStreamingContext): Unit = {
  val underlying = jssc.ssc
  waitForReceiverToStart(underlying)
}

/**
 * Block until at least one receiver has started or the 10-second timeout occurs.
 * Fails the calling test via `assert` when the latch is not released in time.
 *
 * @param ssc the streaming context being awaited (listener must have been registered first)
 */
def waitForReceiverToStart(ssc: StreamingContext): Unit = {
  val started = receiverStartedLatch.await(10, TimeUnit.SECONDS)
  assert(started, "Timeout waiting for receiver to start.")
}
}
27 changes: 20 additions & 7 deletions python/pyspark/streaming/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,14 +931,27 @@ def test_mqtt_stream(self):
"""Test the Python MQTT stream API."""
sendData = "MQTT demo for spark streaming"
topic = self._randomTopic()
self._MQTTTestUtils.registerStreamingListener(self.ssc._jssc)
result = self._startContext(topic)
self._MQTTTestUtils.waitForReceiverToStart(self.ssc._jssc)
self._MQTTTestUtils.publishData(topic, sendData)
self.wait_for(result, 1)
# Because "publishData" sends duplicate messages, here we should use > 0
self.assertTrue(len(result) > 0)
self.assertEqual(sendData, result[0])

def retry():
self._MQTTTestUtils.publishData(topic, sendData)
# Because "publishData" sends duplicate messages, here we should use > 0
self.assertTrue(len(result) > 0)
self.assertEqual(sendData, result[0])

# Retry it because we don't know when the receiver will start.
self._retry_or_timeout(retry)

def _retry_or_timeout(self, test_func):
start_time = time.time()
while True:
try:
test_func()
break
except:
if time.time() - start_time > self.timeout:
raise
time.sleep(0.01)


def search_kafka_assembly_jar():
Expand Down

0 comments on commit 935615c

Please sign in to comment.