diff --git a/dev/run-tests.py b/dev/run-tests.py
index 90535fd3b7b03..237fb76c9b3d9 100755
--- a/dev/run-tests.py
+++ b/dev/run-tests.py
@@ -295,7 +295,8 @@ def build_spark_sbt(hadoop_version):
"assembly/assembly",
"streaming-kafka-assembly/assembly",
"streaming-flume-assembly/assembly",
- "streaming-mqtt-assembly/assembly"]
+ "streaming-mqtt-assembly/assembly",
+ "streaming-mqtt/test:assembly"]
profiles_and_goals = build_profiles + sbt_goals
print("[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: ",
diff --git a/external/mqtt-assembly/pom.xml b/external/mqtt-assembly/pom.xml
index 7c5ba7051ac15..e216a9676abcc 100644
--- a/external/mqtt-assembly/pom.xml
+++ b/external/mqtt-assembly/pom.xml
@@ -58,6 +58,7 @@
        <artifactId>maven-shade-plugin</artifactId>
        <configuration>
          <shadedArtifactAttached>false</shadedArtifactAttached>
+         <outputFile>${project.build.directory}/scala-${scala.binary.version}/spark-streaming-mqtt-assembly-${project.version}.jar</outputFile>
          <artifactSet>
            <includes>
              <include>*:*</include>
diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml
index a28dd3603503a..0e41e5781784b 100644
--- a/external/mqtt/pom.xml
+++ b/external/mqtt/pom.xml
@@ -72,6 +72,7 @@
      <groupId>org.apache.activemq</groupId>
      <artifactId>activemq-core</artifactId>
      <version>5.7.0</version>
+     <scope>test</scope>
    </dependency>
diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala
index de8f5650fbe55..22dabb36efa11 100644
--- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala
+++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala
@@ -87,12 +87,7 @@ private class MQTTUtilsPythonHelper {
brokerUrl: String,
topic: String,
storageLevel: StorageLevel
- ): JavaDStream[Array[Byte]] = {
- val dstream = MQTTUtils.createStream(jssc, brokerUrl, topic, storageLevel)
- dstream.map(new Function[String, Array[Byte]] {
- override def call(data: String): Array[Byte] = {
- data.getBytes("UTF-8")
- }
- })
+ ): JavaDStream[String] = {
+ MQTTUtils.createStream(jssc, brokerUrl, topic, storageLevel)
}
}
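
With the helper above returning JavaDStream[String] instead of re-encoding every record to bytes, decoding moves to the Python wrapper. A minimal sketch of how pyspark/streaming/mqtt.py can consume this helper is shown below; the wrapper itself is not part of this diff, and the choice of UTF8Deserializer is an assumption based on the String return type.

# Sketch only (not part of this diff): wrapping the JavaDStream[String] returned by
# MQTTUtilsPythonHelper.createStream on the Python side. UTF8Deserializer is an
# assumption here, chosen because the helper now yields plain strings.
from pyspark.serializers import UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.streaming import DStream


class MQTTUtils(object):

    @staticmethod
    def createStream(ssc, brokerUrl, topic,
                     storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2):
        jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
        helper = ssc._jvm.org.apache.spark.streaming.mqtt.MQTTUtilsPythonHelper()
        jstream = helper.createStream(ssc._jssc, brokerUrl, topic, jlevel)
        # Each record arrives as a UTF-8 string, so no extra byte-level mapping is needed.
        return DStream(jstream, ssc, UTF8Deserializer())
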
diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala
similarity index 95%
rename from external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala
rename to external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala
index e5036fbc6d626..34e81b3f0f84f 100644
--- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala
+++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala
@@ -18,7 +18,7 @@
package org.apache.spark.streaming.mqtt
import java.net.{ServerSocket, URI}
-import java.util.concurrent.{TimeUnit, CountDownLatch}
+import java.util.concurrent.{CountDownLatch, TimeUnit}
import scala.language.postfixOps
@@ -27,7 +27,7 @@ import org.apache.commons.lang3.RandomUtils
import org.eclipse.paho.client.mqttv3._
import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
-import org.apache.spark.streaming.{StreamingContext, Milliseconds}
+import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.scheduler.StreamingListener
import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted
import org.apache.spark.util.Utils
@@ -40,7 +40,7 @@ private class MQTTTestUtils extends Logging {
private val persistenceDir = Utils.createTempDir()
private val brokerHost = "localhost"
- private var brokerPort = findFreePort()
+ private val brokerPort = findFreePort()
private var broker: BrokerService = _
private var connector: TransportConnector = _
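
Moving MQTTTestUtils from src/main to src/test is what makes the new streaming-mqtt/test:assembly jar necessary: the Python test drives the embedded broker through this Scala class over py4j, so it has to be on the driver classpath at test time. A rough sketch of how the test's setUp can obtain it is given below; the setUp itself is not part of this diff, and the test class name is assumed for illustration.

# Sketch only (not shown in this diff): instantiating the Scala MQTTTestUtils from the
# Python test via py4j. Loading through the context class loader keeps classes added
# via --jars visible; "MQTTStreamTests" is an assumed class name.
def setUp(self):
    super(MQTTStreamTests, self).setUp()
    MQTTTestUtilsClz = self.ssc._jvm.java.lang.Thread.currentThread() \
        .getContextClassLoader() \
        .loadClass("org.apache.spark.streaming.mqtt.MQTTTestUtils")
    self._MQTTTestUtils = MQTTTestUtilsClz.newInstance()
    self._MQTTTestUtils.setup()
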
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 970594c7fa10c..d3592f0a1f7c2 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -850,28 +850,43 @@ def tearDown(self):
def _randomTopic(self):
return "topic-%d" % random.randint(0, 10000)
- def _validateStreamResult(self, sendData, dstream):
+ def _startContext(self, topic):
+ # Start the StreamingContext and also collect the result
+ stream = MQTTUtils.createStream(self.ssc, "tcp://" + self._MQTTTestUtils.brokerUri(), topic)
result = []
- def get_output(_, rdd):
+ def getOutput(_, rdd):
for data in rdd.collect():
result.append(data)
- dstream.foreachRDD(get_output)
- receiveData = ' '.join(result[0])
+ stream.foreachRDD(getOutput)
+ self.ssc.start()
+ return result
+
+ def _publishData(self, topic, data):
+ start_time = time.time()
+ while True:
+ try:
+ self._MQTTTestUtils.publishData(topic, data)
+ break
+ except:
+ if time.time() - start_time < self.timeout:
+ time.sleep(0.01)
+ else:
+ raise
+
+ def _validateStreamResult(self, sendData, result):
+ receiveData = ''.join(result[0])
self.assertEqual(sendData, receiveData)
def test_mqtt_stream(self):
"""Test the Python MQTT stream API."""
- topic = self._randomTopic()
sendData = "MQTT demo for spark streaming"
- ssc = self.ssc
-
- self._MQTTTestUtils.waitForReceiverToStart(ssc)
- self._MQTTTestUtils.publishData(topic, sendData)
-
- stream = MQTTUtils.createStream(ssc, "tcp://" + self._MQTTTestUtils.brokerUri(), topic)
- self._validateStreamResult(sendData, stream)
+ topic = self._randomTopic()
+ result = self._startContext(topic)
+ self._publishData(topic, sendData)
+ self.wait_for(result, len(sendData))
+ self._validateStreamResult(sendData, result)
def search_kafka_assembly_jar():
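
The rewritten test waits for output with wait_for, a polling helper defined on the base streaming test case rather than in this diff. Roughly, it sleeps in small increments until the collected list reaches the expected length or self.timeout expires, along these lines:

# Sketch of the wait_for polling helper the test relies on; it lives on the base
# PySparkStreamingTestCase and is not part of this change.
def wait_for(self, result, n):
    start_time = time.time()
    while len(result) < n and time.time() - start_time < self.timeout:
        time.sleep(0.01)
    if len(result) < n:
        print("timeout after", self.timeout)
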
@@ -928,11 +943,29 @@ def search_mqtt_assembly_jar():
return jars[0]
+def search_mqtt_test_jar():
+ SPARK_HOME = os.environ["SPARK_HOME"]
+ mqtt_test_dir = os.path.join(SPARK_HOME, "external/mqtt")
+ jars = glob.glob(
+ os.path.join(mqtt_test_dir, "target/scala-*/spark-streaming-mqtt-test-*.jar"))
+ if not jars:
+ raise Exception(
+ ("Failed to find Spark Streaming MQTT test jar in %s. " % mqtt_test_dir) +
+ "You need to build Spark with "
+ "'build/sbt assembly/assembly streaming-mqtt/test:assembly'")
+ elif len(jars) > 1:
+ raise Exception(("Found multiple Spark Streaming MQTT test JARs in %s; please "
+ "remove all but one") % mqtt_test_dir)
+ else:
+ return jars[0]
+
if __name__ == "__main__":
kafka_assembly_jar = search_kafka_assembly_jar()
flume_assembly_jar = search_flume_assembly_jar()
mqtt_assembly_jar = search_mqtt_assembly_jar()
- jars = "%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar)
+ mqtt_test_jar = search_mqtt_test_jar()
+ jars = "%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar,
+ mqtt_assembly_jar, mqtt_test_jar)
os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars
unittest.main()