Skip to content

Commit

Permalink
Merge pull request #188 from vladak/user_data_public
Browse files Browse the repository at this point in the history
make user_data "public"
  • Loading branch information
FoamyGuy authored Nov 27, 2023
2 parents 1c25441 + 66309c1 commit 6270110
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 8 deletions.
25 changes: 17 additions & 8 deletions adafruit_minimqtt/adafruit_minimqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,10 @@ class MQTT:
in seconds.
:param int connect_retries: How many times to try to connect to the broker before giving up
on connect or reconnect. Exponential backoff will be used for the retries.
:param class user_data: arbitrary data to pass as a second argument to the callbacks.
:param class user_data: arbitrary data to pass as a second argument to most of the callbacks.
This works with all callbacks but the "on_message" and those added via add_topic_callback();
for those, to get access to the user_data use the 'user_data' member of the MQTT object
passed as 1st argument.
"""

Expand Down Expand Up @@ -205,7 +208,7 @@ def __init__(
self._recv_timeout = recv_timeout

self.keep_alive = keep_alive
self._user_data = user_data
self.user_data = user_data
self._is_connected = False
self._msg_size_lim = MQTT_MSG_SZ_LIM
self._pid = 0
Expand Down Expand Up @@ -413,6 +416,11 @@ def add_topic_callback(self, mqtt_topic: str, callback_method) -> None:
:param str mqtt_topic: MQTT topic identifier.
:param function callback_method: The callback method.
Expected method signature is ``on_message(client, topic, message)``
To get access to the user_data, use the client argument.
If a callback is called for the topic, then any "on_message" callback will not be called.
"""
if mqtt_topic is None or callback_method is None:
raise ValueError("MQTT topic and callback method must both be defined.")
Expand All @@ -437,6 +445,7 @@ def on_message(self):
"""Called when a new message has been received on a subscribed topic.
Expected method signature is ``on_message(client, topic, message)``
To get access to the user_data, use the client argument.
"""
return self._on_message

Expand Down Expand Up @@ -638,7 +647,7 @@ def _connect(
self._is_connected = True
result = rc[0] & 1
if self.on_connect is not None:
self.on_connect(self, self._user_data, result, rc[2])
self.on_connect(self, self.user_data, result, rc[2])

return result

Expand All @@ -661,7 +670,7 @@ def disconnect(self) -> None:
self._is_connected = False
self._subscribed_topics = []
if self.on_disconnect is not None:
self.on_disconnect(self, self._user_data, 0)
self.on_disconnect(self, self.user_data, 0)

def ping(self) -> list[int]:
"""Pings the MQTT Broker to confirm if the broker is alive or if
Expand Down Expand Up @@ -757,7 +766,7 @@ def publish(
self._sock.send(pub_hdr_var)
self._sock.send(msg)
if qos == 0 and self.on_publish is not None:
self.on_publish(self, self._user_data, topic, self._pid)
self.on_publish(self, self.user_data, topic, self._pid)
if qos == 1:
stamp = time.monotonic()
while True:
Expand All @@ -769,7 +778,7 @@ def publish(
rcv_pid = rcv_pid_buf[0] << 0x08 | rcv_pid_buf[1]
if self._pid == rcv_pid:
if self.on_publish is not None:
self.on_publish(self, self._user_data, topic, rcv_pid)
self.on_publish(self, self.user_data, topic, rcv_pid)
return

if op is None:
Expand Down Expand Up @@ -849,7 +858,7 @@ def subscribe(self, topic: str, qos: int = 0) -> None:

for t, q in topics:
if self.on_subscribe is not None:
self.on_subscribe(self, self._user_data, t, q)
self.on_subscribe(self, self.user_data, t, q)
self._subscribed_topics.append(t)
return

Expand Down Expand Up @@ -907,7 +916,7 @@ def unsubscribe(self, topic: str) -> None:
assert rc[1] == packet_id_bytes[0] and rc[2] == packet_id_bytes[1]
for t in topics:
if self.on_unsubscribe is not None:
self.on_unsubscribe(self, self._user_data, t, self._pid)
self.on_unsubscribe(self, self.user_data, t, self._pid)
self._subscribed_topics.remove(t)
return

Expand Down
99 changes: 99 additions & 0 deletions examples/cpython/user_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# SPDX-FileCopyrightText: 2023 Vladimír Kotal
# SPDX-License-Identifier: Unlicense

# pylint: disable=logging-fstring-interpolation

"""
Demonstrate on how to use user_data for various callbacks.
"""

import logging
import socket
import ssl
import sys

import adafruit_minimqtt.adafruit_minimqtt as MQTT


# pylint: disable=unused-argument
def on_connect(mqtt_client, user_data, flags, ret_code):
"""
connect callback
"""
logger = logging.getLogger(__name__)
logger.debug("Connected to MQTT Broker!")
logger.debug(f"Flags: {flags}\n RC: {ret_code}")


# pylint: disable=unused-argument
def on_subscribe(mqtt_client, user_data, topic, granted_qos):
"""
subscribe callback
"""
logger = logging.getLogger(__name__)
logger.debug(f"Subscribed to {topic} with QOS level {granted_qos}")


def on_message(client, topic, message):
"""
received message callback
"""
logger = logging.getLogger(__name__)
logger.debug(f"New message on topic {topic}: {message}")

messages = client.user_data
if not messages.get(topic):
messages[topic] = []
messages[topic].append(message)


# pylint: disable=too-many-statements,too-many-locals
def main():
"""
Main loop.
"""

logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# dictionary/map of topic to list of messages
messages = {}

# connect to MQTT broker
mqtt = MQTT.MQTT(
broker="172.40.0.3",
port=1883,
socket_pool=socket,
ssl_context=ssl.create_default_context(),
user_data=messages,
)

mqtt.on_connect = on_connect
mqtt.on_subscribe = on_subscribe
mqtt.on_message = on_message

logger.info("Connecting to MQTT broker")
mqtt.connect()
logger.info("Subscribing")
mqtt.subscribe("foo/#", qos=0)
mqtt.add_topic_callback("foo/bar", on_message)

i = 0
while True:
i += 1
logger.debug(f"Loop {i}")
# Make sure to stay connected to the broker e.g. in case of keep alive.
mqtt.loop(1)

for topic, msg_list in messages.items():
logger.info(f"Got {len(msg_list)} messages from topic {topic}")
for msg_cnt, msg in enumerate(msg_list):
logger.debug(f"#{msg_cnt}: {msg}")


if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
sys.exit(0)

0 comments on commit 6270110

Please sign in to comment.