diff --git a/dev/run-tests.py b/dev/run-tests.py
index d1852b95bb292..f689425ee40b6 100755
--- a/dev/run-tests.py
+++ b/dev/run-tests.py
@@ -303,6 +303,8 @@ def build_spark_sbt(hadoop_version):
"assembly/assembly",
"streaming-kafka-assembly/assembly",
"streaming-flume-assembly/assembly",
+ "streaming-mqtt-assembly/assembly",
+ "streaming-mqtt/test:assembly",
"streaming-kinesis-asl-assembly/assembly"]
profiles_and_goals = build_profiles + sbt_goals
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index a9717ff9569c7..d82c0cca37bc6 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -181,6 +181,7 @@ def contains_file(self, filename):
dependencies=[streaming],
source_file_regexes=[
"external/mqtt",
+ "external/mqtt-assembly",
],
sbt_test_goals=[
"streaming-mqtt/test",
@@ -306,6 +307,7 @@ def contains_file(self, filename):
streaming,
streaming_kafka,
streaming_flume_assembly,
+ streaming_mqtt,
streaming_kinesis_asl
],
source_file_regexes=[
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index dbfdb619f89e2..c59d936b43c88 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -683,7 +683,7 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea
{:.no_toc}
Python API As of Spark {{site.SPARK_VERSION_SHORT}},
-out of these sources, *only* Kafka and Flume are available in the Python API. We will add more advanced sources in the Python API in future.
+out of these sources, *only* Kafka, Flume and MQTT are available in the Python API. We will add more advanced sources in the Python API in future.
This category of sources require interfacing with external non-Spark libraries, some of them with
complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts
diff --git a/examples/src/main/python/streaming/mqtt_wordcount.py b/examples/src/main/python/streaming/mqtt_wordcount.py
new file mode 100644
index 0000000000000..617ce5ea6775e
--- /dev/null
+++ b/examples/src/main/python/streaming/mqtt_wordcount.py
@@ -0,0 +1,58 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+ A sample wordcount with MqttStream stream
+ Usage: mqtt_wordcount.py
+
+ To run this in your local machine, you need to setup a MQTT broker and publisher first,
+ Mosquitto is one of the open source MQTT Brokers, see
+ http://mosquitto.org/
+ Eclipse paho project provides number of clients and utilities for working with MQTT, see
+ http://www.eclipse.org/paho/#getting-started
+
+ and then run the example
+ `$ bin/spark-submit --jars external/mqtt-assembly/target/scala-*/\
+ spark-streaming-mqtt-assembly-*.jar examples/src/main/python/streaming/mqtt_wordcount.py \
+ tcp://localhost:1883 foo`
+"""
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+from pyspark.streaming.mqtt import MQTTUtils
+
+if __name__ == "__main__":
+ if len(sys.argv) != 3:
+ print >> sys.stderr, "Usage: mqtt_wordcount.py "
+ exit(-1)
+
+ sc = SparkContext(appName="PythonStreamingMQTTWordCount")
+ ssc = StreamingContext(sc, 1)
+
+ brokerUrl = sys.argv[1]
+ topic = sys.argv[2]
+
+ lines = MQTTUtils.createStream(ssc, brokerUrl, topic)
+ counts = lines.flatMap(lambda line: line.split(" ")) \
+ .map(lambda word: (word, 1)) \
+ .reduceByKey(lambda a, b: a+b)
+ counts.pprint()
+
+ ssc.start()
+ ssc.awaitTermination()
diff --git a/external/mqtt-assembly/pom.xml b/external/mqtt-assembly/pom.xml
new file mode 100644
index 0000000000000..9c94473053d96
--- /dev/null
+++ b/external/mqtt-assembly/pom.xml
@@ -0,0 +1,102 @@
+
+
+
+
+ 4.0.0
+
+ org.apache.spark
+ spark-parent_2.10
+ 1.5.0-SNAPSHOT
+ ../../pom.xml
+
+
+ org.apache.spark
+ spark-streaming-mqtt-assembly_2.10
+ jar
+ Spark Project External MQTT Assembly
+ http://spark.apache.org/
+
+
+ streaming-mqtt-assembly
+
+
+
+
+ org.apache.spark
+ spark-streaming-mqtt_${scala.binary.version}
+ ${project.version}
+
+
+ org.apache.spark
+ spark-streaming_${scala.binary.version}
+ ${project.version}
+ provided
+
+
+
+
+ target/scala-${scala.binary.version}/classes
+ target/scala-${scala.binary.version}/test-classes
+
+
+ org.apache.maven.plugins
+ maven-shade-plugin
+
+ false
+ ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-mqtt-assembly-${project.version}.jar
+
+
+ *:*
+
+
+
+
+ *:*
+
+ META-INF/*.SF
+ META-INF/*.DSA
+ META-INF/*.RSA
+
+
+
+
+
+
+ package
+
+ shade
+
+
+
+
+
+ reference.conf
+
+
+ log4j.properties
+
+
+
+
+
+
+
+
+
+
+
diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml
index 0e41e5781784b..69b309876a0db 100644
--- a/external/mqtt/pom.xml
+++ b/external/mqtt/pom.xml
@@ -78,5 +78,33 @@
target/scala-${scala.binary.version}/classes
target/scala-${scala.binary.version}/test-classes
+
+
+
+
+ org.apache.maven.plugins
+ maven-assembly-plugin
+
+
+ test-jar-with-dependencies
+ package
+
+ single
+
+
+
+ spark-streaming-mqtt-test-${project.version}
+ ${project.build.directory}/scala-${scala.binary.version}/
+ false
+
+ false
+
+ src/main/assembly/assembly.xml
+
+
+
+
+
+
diff --git a/external/mqtt/src/main/assembly/assembly.xml b/external/mqtt/src/main/assembly/assembly.xml
new file mode 100644
index 0000000000000..ecab5b360eb3e
--- /dev/null
+++ b/external/mqtt/src/main/assembly/assembly.xml
@@ -0,0 +1,44 @@
+
+
+ test-jar-with-dependencies
+
+ jar
+
+ false
+
+
+
+ ${project.build.directory}/scala-${scala.binary.version}/test-classes
+ /
+
+
+
+
+
+ true
+ test
+ true
+
+ org.apache.hadoop:*:jar
+ org.apache.zookeeper:*:jar
+ org.apache.avro:*:jar
+
+
+
+
+
diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala
index 1142d0f56ba34..38a1114863d15 100644
--- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala
+++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala
@@ -74,3 +74,19 @@ object MQTTUtils {
createStream(jssc.ssc, brokerUrl, topic, storageLevel)
}
}
+
+/**
+ * This is a helper class that wraps the methods in MQTTUtils into more Python-friendly class and
+ * function so that it can be easily instantiated and called from Python's MQTTUtils.
+ */
+private class MQTTUtilsPythonHelper {
+
+ def createStream(
+ jssc: JavaStreamingContext,
+ brokerUrl: String,
+ topic: String,
+ storageLevel: StorageLevel
+ ): JavaDStream[String] = {
+ MQTTUtils.createStream(jssc, brokerUrl, topic, storageLevel)
+ }
+}
diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
index c4bf5aa7869bb..a6a9249db8ed7 100644
--- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
+++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
@@ -17,46 +17,30 @@
package org.apache.spark.streaming.mqtt
-import java.net.{URI, ServerSocket}
-import java.util.concurrent.CountDownLatch
-import java.util.concurrent.TimeUnit
-
import scala.concurrent.duration._
import scala.language.postfixOps
-import org.apache.activemq.broker.{TransportConnector, BrokerService}
-import org.apache.commons.lang3.RandomUtils
-import org.eclipse.paho.client.mqttv3._
-import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
-
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually
-import org.apache.spark.streaming.{Milliseconds, StreamingContext}
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.dstream.ReceiverInputDStream
-import org.apache.spark.streaming.scheduler.StreamingListener
-import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted
import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.util.Utils
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.{Milliseconds, StreamingContext}
class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter {
private val batchDuration = Milliseconds(500)
private val master = "local[2]"
private val framework = this.getClass.getSimpleName
- private val freePort = findFreePort()
- private val brokerUri = "//localhost:" + freePort
private val topic = "def"
- private val persistenceDir = Utils.createTempDir()
private var ssc: StreamingContext = _
- private var broker: BrokerService = _
- private var connector: TransportConnector = _
+ private var mqttTestUtils: MQTTTestUtils = _
before {
ssc = new StreamingContext(master, framework, batchDuration)
- setupMQTT()
+ mqttTestUtils = new MQTTTestUtils
+ mqttTestUtils.setup()
}
after {
@@ -64,14 +48,17 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter
ssc.stop()
ssc = null
}
- Utils.deleteRecursively(persistenceDir)
- tearDownMQTT()
+ if (mqttTestUtils != null) {
+ mqttTestUtils.teardown()
+ mqttTestUtils = null
+ }
}
test("mqtt input stream") {
val sendMessage = "MQTT demo for spark streaming"
- val receiveStream =
- MQTTUtils.createStream(ssc, "tcp:" + brokerUri, topic, StorageLevel.MEMORY_ONLY)
+ val receiveStream = MQTTUtils.createStream(ssc, "tcp://" + mqttTestUtils.brokerUri, topic,
+ StorageLevel.MEMORY_ONLY)
+
@volatile var receiveMessage: List[String] = List()
receiveStream.foreachRDD { rdd =>
if (rdd.collect.length > 0) {
@@ -79,89 +66,14 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter
receiveMessage
}
}
- ssc.start()
- // wait for the receiver to start before publishing data, or we risk failing
- // the test nondeterministically. See SPARK-4631
- waitForReceiverToStart()
+ ssc.start()
- publishData(sendMessage)
+ // Retry it because we don't know when the receiver will start.
eventually(timeout(10000 milliseconds), interval(100 milliseconds)) {
+ mqttTestUtils.publishData(topic, sendMessage)
assert(sendMessage.equals(receiveMessage(0)))
}
ssc.stop()
}
-
- private def setupMQTT() {
- broker = new BrokerService()
- broker.setDataDirectoryFile(Utils.createTempDir())
- connector = new TransportConnector()
- connector.setName("mqtt")
- connector.setUri(new URI("mqtt:" + brokerUri))
- broker.addConnector(connector)
- broker.start()
- }
-
- private def tearDownMQTT() {
- if (broker != null) {
- broker.stop()
- broker = null
- }
- if (connector != null) {
- connector.stop()
- connector = null
- }
- }
-
- private def findFreePort(): Int = {
- val candidatePort = RandomUtils.nextInt(1024, 65536)
- Utils.startServiceOnPort(candidatePort, (trialPort: Int) => {
- val socket = new ServerSocket(trialPort)
- socket.close()
- (null, trialPort)
- }, new SparkConf())._2
- }
-
- def publishData(data: String): Unit = {
- var client: MqttClient = null
- try {
- val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath)
- client = new MqttClient("tcp:" + brokerUri, MqttClient.generateClientId(), persistence)
- client.connect()
- if (client.isConnected) {
- val msgTopic = client.getTopic(topic)
- val message = new MqttMessage(data.getBytes("utf-8"))
- message.setQos(1)
- message.setRetained(true)
-
- for (i <- 0 to 10) {
- try {
- msgTopic.publish(message)
- } catch {
- case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT =>
- // wait for Spark streaming to consume something from the message queue
- Thread.sleep(50)
- }
- }
- }
- } finally {
- client.disconnect()
- client.close()
- client = null
- }
- }
-
- /**
- * Block until at least one receiver has started or timeout occurs.
- */
- private def waitForReceiverToStart() = {
- val latch = new CountDownLatch(1)
- ssc.addStreamingListener(new StreamingListener {
- override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) {
- latch.countDown()
- }
- })
-
- assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.")
- }
}
diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala
new file mode 100644
index 0000000000000..1a371b7008824
--- /dev/null
+++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.mqtt
+
+import java.net.{ServerSocket, URI}
+
+import scala.language.postfixOps
+
+import com.google.common.base.Charsets.UTF_8
+import org.apache.activemq.broker.{BrokerService, TransportConnector}
+import org.apache.commons.lang3.RandomUtils
+import org.eclipse.paho.client.mqttv3._
+import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
+
+import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, SparkConf}
+
+/**
+ * Share codes for Scala and Python unit tests
+ */
+private class MQTTTestUtils extends Logging {
+
+ private val persistenceDir = Utils.createTempDir()
+ private val brokerHost = "localhost"
+ private val brokerPort = findFreePort()
+
+ private var broker: BrokerService = _
+ private var connector: TransportConnector = _
+
+ def brokerUri: String = {
+ s"$brokerHost:$brokerPort"
+ }
+
+ def setup(): Unit = {
+ broker = new BrokerService()
+ broker.setDataDirectoryFile(Utils.createTempDir())
+ connector = new TransportConnector()
+ connector.setName("mqtt")
+ connector.setUri(new URI("mqtt://" + brokerUri))
+ broker.addConnector(connector)
+ broker.start()
+ }
+
+ def teardown(): Unit = {
+ if (broker != null) {
+ broker.stop()
+ broker = null
+ }
+ if (connector != null) {
+ connector.stop()
+ connector = null
+ }
+ Utils.deleteRecursively(persistenceDir)
+ }
+
+ private def findFreePort(): Int = {
+ val candidatePort = RandomUtils.nextInt(1024, 65536)
+ Utils.startServiceOnPort(candidatePort, (trialPort: Int) => {
+ val socket = new ServerSocket(trialPort)
+ socket.close()
+ (null, trialPort)
+ }, new SparkConf())._2
+ }
+
+ def publishData(topic: String, data: String): Unit = {
+ var client: MqttClient = null
+ try {
+ val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath)
+ client = new MqttClient("tcp://" + brokerUri, MqttClient.generateClientId(), persistence)
+ client.connect()
+ if (client.isConnected) {
+ val msgTopic = client.getTopic(topic)
+ val message = new MqttMessage(data.getBytes(UTF_8))
+ message.setQos(1)
+ message.setRetained(true)
+
+ for (i <- 0 to 10) {
+ try {
+ msgTopic.publish(message)
+ } catch {
+ case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT =>
+ // wait for Spark streaming to consume something from the message queue
+ Thread.sleep(50)
+ }
+ }
+ }
+ } finally {
+ if (client != null) {
+ client.disconnect()
+ client.close()
+ client = null
+ }
+ }
+ }
+
+}
diff --git a/pom.xml b/pom.xml
index 2bcc55b040a26..8942836a7da16 100644
--- a/pom.xml
+++ b/pom.xml
@@ -104,6 +104,7 @@
external/flume-sink
external/flume-assembly
external/mqtt
+ external/mqtt-assembly
external/zeromq
examples
repl
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 9a33baa7c6ce1..41a85fa9de778 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -45,8 +45,8 @@ object BuildCommons {
sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl",
"kinesis-asl").map(ProjectRef(buildLocation, _))
- val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKinesisAslAssembly) =
- Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-kinesis-asl-assembly")
+ val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingMqttAssembly, streamingKinesisAslAssembly) =
+ Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-mqtt-assembly", "streaming-kinesis-asl-assembly")
.map(ProjectRef(buildLocation, _))
val tools = ProjectRef(buildLocation, "tools")
@@ -212,6 +212,9 @@ object SparkBuild extends PomBuild {
/* Enable Assembly for all assembly projects */
assemblyProjects.foreach(enable(Assembly.settings))
+ /* Enable Assembly for streamingMqtt test */
+ enable(inConfig(Test)(Assembly.settings))(streamingMqtt)
+
/* Package pyspark artifacts in a separate zip file for YARN. */
enable(PySparkAssembly.settings)(assembly)
@@ -382,13 +385,16 @@ object Assembly {
.getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String])
},
jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) =>
- if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly") || mName.contains("streaming-kinesis-asl-assembly")) {
+ if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly") || mName.contains("streaming-mqtt-assembly") || mName.contains("streaming-kinesis-asl-assembly")) {
// This must match the same name used in maven (see external/kafka-assembly/pom.xml)
s"${mName}-${v}.jar"
} else {
s"${mName}-${v}-hadoop${hv}.jar"
}
},
+ jarName in (Test, assembly) <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) =>
+ s"${mName}-test-${v}.jar"
+ },
mergeStrategy in assembly := {
case PathList("org", "datanucleus", xs @ _*) => MergeStrategy.discard
case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py
new file mode 100644
index 0000000000000..f06598971c548
--- /dev/null
+++ b/python/pyspark/streaming/mqtt.py
@@ -0,0 +1,72 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from py4j.java_gateway import Py4JJavaError
+
+from pyspark.storagelevel import StorageLevel
+from pyspark.serializers import UTF8Deserializer
+from pyspark.streaming import DStream
+
+__all__ = ['MQTTUtils']
+
+
+class MQTTUtils(object):
+
+ @staticmethod
+ def createStream(ssc, brokerUrl, topic,
+ storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2):
+ """
+ Create an input stream that pulls messages from a Mqtt Broker.
+ :param ssc: StreamingContext object
+ :param brokerUrl: Url of remote mqtt publisher
+ :param topic: topic name to subscribe to
+ :param storageLevel: RDD storage level.
+ :return: A DStream object
+ """
+ jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
+
+ try:
+ helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
+ .loadClass("org.apache.spark.streaming.mqtt.MQTTUtilsPythonHelper")
+ helper = helperClass.newInstance()
+ jstream = helper.createStream(ssc._jssc, brokerUrl, topic, jlevel)
+ except Py4JJavaError as e:
+ if 'ClassNotFoundException' in str(e.java_exception):
+ MQTTUtils._printErrorMsg(ssc.sparkContext)
+ raise e
+
+ return DStream(jstream, ssc, UTF8Deserializer())
+
+ @staticmethod
+ def _printErrorMsg(sc):
+ print("""
+________________________________________________________________________________________________
+
+ Spark Streaming's MQTT libraries not found in class path. Try one of the following.
+
+ 1. Include the MQTT library and its dependencies with in the
+ spark-submit command as
+
+ $ bin/spark-submit --packages org.apache.spark:spark-streaming-mqtt:%s ...
+
+ 2. Download the JAR of the artifact from Maven Central http://search.maven.org/,
+ Group Id = org.apache.spark, Artifact Id = spark-streaming-mqtt-assembly, Version = %s.
+ Then, include the jar in the spark-submit command as
+
+ $ bin/spark-submit --jars ...
+________________________________________________________________________________________________
+""" % (sc.version, sc.version))
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 5cd544b2144ef..66ae3345f468f 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -40,6 +40,7 @@
from pyspark.streaming.context import StreamingContext
from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition
from pyspark.streaming.flume import FlumeUtils
+from pyspark.streaming.mqtt import MQTTUtils
from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream
@@ -893,6 +894,68 @@ def test_flume_polling_multiple_hosts(self):
self._testMultipleTimes(self._testFlumePollingMultipleHosts)
+class MQTTStreamTests(PySparkStreamingTestCase):
+ timeout = 20 # seconds
+ duration = 1
+
+ def setUp(self):
+ super(MQTTStreamTests, self).setUp()
+
+ MQTTTestUtilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
+ .loadClass("org.apache.spark.streaming.mqtt.MQTTTestUtils")
+ self._MQTTTestUtils = MQTTTestUtilsClz.newInstance()
+ self._MQTTTestUtils.setup()
+
+ def tearDown(self):
+ if self._MQTTTestUtils is not None:
+ self._MQTTTestUtils.teardown()
+ self._MQTTTestUtils = None
+
+ super(MQTTStreamTests, self).tearDown()
+
+ def _randomTopic(self):
+ return "topic-%d" % random.randint(0, 10000)
+
+ def _startContext(self, topic):
+ # Start the StreamingContext and also collect the result
+ stream = MQTTUtils.createStream(self.ssc, "tcp://" + self._MQTTTestUtils.brokerUri(), topic)
+ result = []
+
+ def getOutput(_, rdd):
+ for data in rdd.collect():
+ result.append(data)
+
+ stream.foreachRDD(getOutput)
+ self.ssc.start()
+ return result
+
+ def test_mqtt_stream(self):
+ """Test the Python MQTT stream API."""
+ sendData = "MQTT demo for spark streaming"
+ topic = self._randomTopic()
+ result = self._startContext(topic)
+
+ def retry():
+ self._MQTTTestUtils.publishData(topic, sendData)
+ # Because "publishData" sends duplicate messages, here we should use > 0
+ self.assertTrue(len(result) > 0)
+ self.assertEqual(sendData, result[0])
+
+ # Retry it because we don't know when the receiver will start.
+ self._retry_or_timeout(retry)
+
+ def _retry_or_timeout(self, test_func):
+ start_time = time.time()
+ while True:
+ try:
+ test_func()
+ break
+ except:
+ if time.time() - start_time > self.timeout:
+ raise
+ time.sleep(0.01)
+
+
class KinesisStreamTests(PySparkStreamingTestCase):
def test_kinesis_stream_api(self):
@@ -985,7 +1048,42 @@ def search_flume_assembly_jar():
"'build/mvn package' before running this test")
elif len(jars) > 1:
raise Exception(("Found multiple Spark Streaming Flume assembly JARs in %s; please "
- "remove all but one") % flume_assembly_dir)
+ "remove all but one") % flume_assembly_dir)
+ else:
+ return jars[0]
+
+
+def search_mqtt_assembly_jar():
+ SPARK_HOME = os.environ["SPARK_HOME"]
+ mqtt_assembly_dir = os.path.join(SPARK_HOME, "external/mqtt-assembly")
+ jars = glob.glob(
+ os.path.join(mqtt_assembly_dir, "target/scala-*/spark-streaming-mqtt-assembly-*.jar"))
+ if not jars:
+ raise Exception(
+ ("Failed to find Spark Streaming MQTT assembly jar in %s. " % mqtt_assembly_dir) +
+ "You need to build Spark with "
+ "'build/sbt assembly/assembly streaming-mqtt-assembly/assembly' or "
+ "'build/mvn package' before running this test")
+ elif len(jars) > 1:
+ raise Exception(("Found multiple Spark Streaming MQTT assembly JARs in %s; please "
+ "remove all but one") % mqtt_assembly_dir)
+ else:
+ return jars[0]
+
+
+def search_mqtt_test_jar():
+ SPARK_HOME = os.environ["SPARK_HOME"]
+ mqtt_test_dir = os.path.join(SPARK_HOME, "external/mqtt")
+ jars = glob.glob(
+ os.path.join(mqtt_test_dir, "target/scala-*/spark-streaming-mqtt-test-*.jar"))
+ if not jars:
+ raise Exception(
+ ("Failed to find Spark Streaming MQTT test jar in %s. " % mqtt_test_dir) +
+ "You need to build Spark with "
+ "'build/sbt assembly/assembly streaming-mqtt/test:assembly'")
+ elif len(jars) > 1:
+ raise Exception(("Found multiple Spark Streaming MQTT test JARs in %s; please "
+ "remove all but one") % mqtt_test_dir)
else:
return jars[0]
@@ -1012,8 +1110,12 @@ def search_kinesis_asl_assembly_jar():
if __name__ == "__main__":
kafka_assembly_jar = search_kafka_assembly_jar()
flume_assembly_jar = search_flume_assembly_jar()
+ mqtt_assembly_jar = search_mqtt_assembly_jar()
+ mqtt_test_jar = search_mqtt_test_jar()
kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar()
- jars = "%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar)
+
+ jars = "%s,%s,%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar,
+ mqtt_assembly_jar, mqtt_test_jar)
os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars
unittest.main()