From a5d6be57e895a9f3b5b5e00da9933c60ccedd3d0 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Fri, 16 May 2025 07:36:42 +0300 Subject: [PATCH 01/74] Initial commit, structure and basic data sending --- __init__.py | 3 - examples/device/operational_example.py | 135 ++++++++ pyproject.toml | 64 ++++ setup.py | 6 +- tb_mqtt_client/__init__.py | 14 + tb_mqtt_client/common/__init__.py | 14 + tb_mqtt_client/common/config_loader.py | 48 +++ tb_mqtt_client/common/exceptions.py | 69 ++++ tb_mqtt_client/common/gmqtt_patch.py | 88 +++++ tb_mqtt_client/common/logging_utils.py | 59 ++++ tb_mqtt_client/common/provision_client.py | 14 + tb_mqtt_client/common/rate_limit/__init__.py | 14 + .../rate_limit/backpressure_controller.py | 49 +++ .../common/rate_limit/rate_limit.py | 171 ++++++++++ tb_mqtt_client/common/request_id_generator.py | 44 +++ tb_mqtt_client/constants/__init__.py | 14 + tb_mqtt_client/constants/mqtt_topics.py | 66 ++++ tb_mqtt_client/constants/service_keys.py | 19 ++ tb_mqtt_client/constants/service_messages.py | 24 ++ tb_mqtt_client/entities/__init__.py | 14 + tb_mqtt_client/entities/data/__init__.py | 14 + .../entities/data/attribute_entry.py | 37 +++ .../entities/data/attribute_request.py | 14 + .../entities/data/attribute_response.py | 14 + .../entities/data/attribute_update.py | 57 ++++ tb_mqtt_client/entities/data/data_entry.py | 63 ++++ .../entities/data/device_uplink_message.py | 97 ++++++ tb_mqtt_client/entities/data/rpc_request.py | 47 +++ tb_mqtt_client/entities/data/rpc_response.py | 42 +++ .../entities/data/timeseries_entry.py | 56 ++++ tb_mqtt_client/entities/gateway/__init__.py | 14 + .../entities/gateway/device_session_state.py | 14 + .../entities/gateway/rpc_context.py | 14 + .../entities/gateway/virtual_device.py | 14 + tb_mqtt_client/service/__init__.py | 14 + tb_mqtt_client/service/base_client.py | 99 ++++++ tb_mqtt_client/service/device/__init__.py | 14 + .../device/attribute_updates_handler.py | 52 +++ tb_mqtt_client/service/device/client.py | 272 ++++++++++++++++ .../service/device/rpc_requests_handler.py | 69 ++++ tb_mqtt_client/service/event_dispatcher.py | 14 + tb_mqtt_client/service/gateway/__init__.py | 14 + tb_mqtt_client/service/gateway/client.py | 14 + .../service/gateway/device_sesion.py | 14 + .../gateway_attribute_updates_handler.py | 14 + .../gateway/gateway_rpc_requests_handler.py | 14 + .../service/gateway/multiplex_publisher.py | 15 + .../service/gateway/subdevice_manager.py | 14 + tb_mqtt_client/service/message_dispatcher.py | 146 +++++++++ tb_mqtt_client/service/message_queue.py | 211 ++++++++++++ tb_mqtt_client/service/message_splitter.py | 136 ++++++++ tb_mqtt_client/service/mqtt_manager.py | 302 ++++++++++++++++++ .../service/rpc_response_handler.py | 72 +++++ tb_mqtt_client/tb_device_mqtt.py | 14 + tests/constants/__init__.py | 14 + tests/constants/test_mqtt_topics.py | 23 ++ tests/service/__init__.py | 14 + tests/service/device/__init__.py | 14 + .../device/test_device_client_rate_limits.py | 77 +++++ tests/service/test_json_message_dispatcher.py | 92 ++++++ tests/service/test_message_splitter.py | 95 ++++++ tests/service/test_mqtt_manager.py | 152 +++++++++ 62 files changed, 3415 insertions(+), 6 deletions(-) create mode 100644 examples/device/operational_example.py create mode 100644 pyproject.toml create mode 100644 tb_mqtt_client/__init__.py create mode 100644 tb_mqtt_client/common/__init__.py create mode 100644 tb_mqtt_client/common/config_loader.py create mode 100644 tb_mqtt_client/common/exceptions.py create mode 100644 
tb_mqtt_client/common/gmqtt_patch.py create mode 100644 tb_mqtt_client/common/logging_utils.py create mode 100644 tb_mqtt_client/common/provision_client.py create mode 100644 tb_mqtt_client/common/rate_limit/__init__.py create mode 100644 tb_mqtt_client/common/rate_limit/backpressure_controller.py create mode 100644 tb_mqtt_client/common/rate_limit/rate_limit.py create mode 100644 tb_mqtt_client/common/request_id_generator.py create mode 100644 tb_mqtt_client/constants/__init__.py create mode 100644 tb_mqtt_client/constants/mqtt_topics.py create mode 100644 tb_mqtt_client/constants/service_keys.py create mode 100644 tb_mqtt_client/constants/service_messages.py create mode 100644 tb_mqtt_client/entities/__init__.py create mode 100644 tb_mqtt_client/entities/data/__init__.py create mode 100644 tb_mqtt_client/entities/data/attribute_entry.py create mode 100644 tb_mqtt_client/entities/data/attribute_request.py create mode 100644 tb_mqtt_client/entities/data/attribute_response.py create mode 100644 tb_mqtt_client/entities/data/attribute_update.py create mode 100644 tb_mqtt_client/entities/data/data_entry.py create mode 100644 tb_mqtt_client/entities/data/device_uplink_message.py create mode 100644 tb_mqtt_client/entities/data/rpc_request.py create mode 100644 tb_mqtt_client/entities/data/rpc_response.py create mode 100644 tb_mqtt_client/entities/data/timeseries_entry.py create mode 100644 tb_mqtt_client/entities/gateway/__init__.py create mode 100644 tb_mqtt_client/entities/gateway/device_session_state.py create mode 100644 tb_mqtt_client/entities/gateway/rpc_context.py create mode 100644 tb_mqtt_client/entities/gateway/virtual_device.py create mode 100644 tb_mqtt_client/service/__init__.py create mode 100644 tb_mqtt_client/service/base_client.py create mode 100644 tb_mqtt_client/service/device/__init__.py create mode 100644 tb_mqtt_client/service/device/attribute_updates_handler.py create mode 100644 tb_mqtt_client/service/device/client.py create mode 100644 tb_mqtt_client/service/device/rpc_requests_handler.py create mode 100644 tb_mqtt_client/service/event_dispatcher.py create mode 100644 tb_mqtt_client/service/gateway/__init__.py create mode 100644 tb_mqtt_client/service/gateway/client.py create mode 100644 tb_mqtt_client/service/gateway/device_sesion.py create mode 100644 tb_mqtt_client/service/gateway/gateway_attribute_updates_handler.py create mode 100644 tb_mqtt_client/service/gateway/gateway_rpc_requests_handler.py create mode 100644 tb_mqtt_client/service/gateway/multiplex_publisher.py create mode 100644 tb_mqtt_client/service/gateway/subdevice_manager.py create mode 100644 tb_mqtt_client/service/message_dispatcher.py create mode 100644 tb_mqtt_client/service/message_queue.py create mode 100644 tb_mqtt_client/service/message_splitter.py create mode 100644 tb_mqtt_client/service/mqtt_manager.py create mode 100644 tb_mqtt_client/service/rpc_response_handler.py create mode 100644 tb_mqtt_client/tb_device_mqtt.py create mode 100644 tests/constants/__init__.py create mode 100644 tests/constants/test_mqtt_topics.py create mode 100644 tests/service/__init__.py create mode 100644 tests/service/device/__init__.py create mode 100644 tests/service/device/test_device_client_rate_limits.py create mode 100644 tests/service/test_json_message_dispatcher.py create mode 100644 tests/service/test_message_splitter.py create mode 100644 tests/service/test_mqtt_manager.py diff --git a/__init__.py b/__init__.py index 234aa97..6d6dc0d 100644 --- a/__init__.py +++ b/__init__.py @@ -12,6 +12,3 @@ # See the 
License for the specific language governing permissions and # limitations under the License. # - - -__name__ = "tb_mqtt_client" diff --git a/examples/device/operational_example.py b/examples/device/operational_example.py new file mode 100644 index 0000000..d9d8330 --- /dev/null +++ b/examples/device/operational_example.py @@ -0,0 +1,135 @@ +import asyncio +import logging +import signal +from datetime import datetime, UTC +from random import randint, uniform + +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.service.device.client import DeviceClient +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger + +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.DEBUG) +logging.getLogger("tb_mqtt_client").setLevel(logging.DEBUG) + + +async def attribute_update_callback(update: AttributeUpdate): + """ + Callback function to handle attribute updates. + :param update: The attribute update object. + """ + logger.info("Received attribute update: %s", update.as_dict()) + + +async def rpc_request_callback(request: RPCRequest): + """ + Callback function to handle RPC requests. + :param request: The RPC request object. + :return: A RPC response object. + """ + logger.info("Received RPC request: %s", request.to_dict()) + response_data = { + "status": "success", + } + response = RPCResponse(request_id=request.request_id, + result=response_data, + error=None) + return response + + +async def main(): + stop_event = asyncio.Event() + + def _shutdown_handler(): + stop_event.set() + + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, _shutdown_handler) + except NotImplementedError: + # Windows compatibility fallback + signal.signal(sig, lambda *_: _shutdown_handler()) + + config = DeviceConfig() + config.host = "localhost" + config.access_token = "YOUR_ACCESS_TOKEN" + + client = DeviceClient(config) + client.set_attribute_update_callback(attribute_update_callback) + client.set_rpc_request_callback(rpc_request_callback) + await client.connect() + + logger.info("Connected to ThingsBoard.") + + while not stop_event.is_set(): + # --- Attributes --- + + # 1. Raw dict + raw_dict = { + "firmwareVersion": "1.0.4", + "hardwareModel": "TB-SDK-Device" + } + await client.send_attributes(raw_dict) + + logger.info(f"Raw attributes sent: {raw_dict}") + + # 2. Single AttributeEntry + single_entry = AttributeEntry("mode", "normal") + await client.send_attributes(single_entry) + + logger.info("Single attribute sent: %s", single_entry) + + # 3. List of AttributeEntry + attr_entries = [ + AttributeEntry("maxTemperature", 85), + AttributeEntry("calibrated", True) + ] + await client.send_attributes(attr_entries) + + # --- Telemetry --- + + # 1. Raw dict + raw_dict = { + "temperature": round(uniform(20.0, 30.0), 2), + "humidity": 60 + } + await client.send_telemetry(raw_dict) + + logger.info(f"Raw telemetry sent: {raw_dict}") + + # 2. 
Single TelemetryEntry (with ts) + single_entry = TimeseriesEntry("batteryLevel", randint(0, 100)) + await client.send_telemetry(single_entry) + + logger.info("Single telemetry sent: %s", single_entry) + + # 3. List of TelemetryEntry with mixed timestamps + ts_now = int(datetime.now(UTC).timestamp() * 1000) + telemetry_entries = [] + for i in range(100): + telemetry_entries.append(TimeseriesEntry("temperature", i, ts=ts_now-i)) + await client.send_telemetry(telemetry_entries) + logger.info("List of telemetry sent: %s, it took %r milliseconds", len(telemetry_entries), + int(datetime.now(UTC).timestamp() * 1000) - ts_now) + + try: + await asyncio.wait_for(stop_event.wait(), timeout=2) + except asyncio.TimeoutError: + pass + + await client.disconnect() + logger.info("Disconnected cleanly.") + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("Interrupted by user.") diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..cf309c5 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,64 @@ +[project] +name = "thingsboard-python-client-sdk" +version = "2.0.0" +description = "Python MQTT client SDK for ThingsBoard platform" +authors = [ + { name = "ThingsBoard Team", email = "info@thingsboard.io" }, +] +license = "Apache-2.0" +readme = "README.md" +requires-python = ">=3.8" +dependencies = [ + "uvloop>=0.21.0", + "gmqtt>=0.6.10", + "orjson>=0.2.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0", + "pytest-asyncio>=0.21", + "mypy>=1.8", + "black>=24.0", + "isort>=5.12", + "ruff>=0.1.0", +] + +[tool.pytest.ini_options] +minversion = "7.0" +addopts = "-ra -q" +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*", "*Tests"] +python_functions = ["test_*"] +asyncio_mode = "auto" +asyncio_fixture_scope = "function" + +[tool.black] +line-length = 100 +target-version = ["py38"] +exclude = ''' +/( + \.venv + | build + | dist +)/ +''' + +[tool.isort] +profile = "black" +line_length = 100 + +[tool.mypy] +python_version = "3.12" +ignore_missing_imports = true +disallow_untyped_defs = true +strict_optional = true +warn_unused_ignores = true +warn_return_any = true +no_implicit_optional = true + +[tool.ruff] +line-length = 100 +target-version = "py312" +select = ["E", "F", "I", "B", "W"] \ No newline at end of file diff --git a/setup.py b/setup.py index f587e98..043e216 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ with open(path.join(this_directory, 'README.md')) as f: long_description = f.read() -VERSION = "1.13.5" +VERSION = "2.0" setup( version=VERSION, @@ -33,6 +33,6 @@ url="https://github.com/thingsboard/thingsboard-python-client-sdk", long_description=long_description, long_description_content_type="text/markdown", - python_requires=">=3.9", + python_requires=">=3.12", packages=["."], - install_requires=['tb-paho-mqtt-client>=2.1.2', 'requests>=2.31.0', 'orjson']) + install_requires=['gmqtt', 'requests>=2.31.0', 'orjson']) diff --git a/tb_mqtt_client/__init__.py b/tb_mqtt_client/__init__.py new file mode 100644 index 0000000..6d6dc0d --- /dev/null +++ b/tb_mqtt_client/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tb_mqtt_client/common/__init__.py b/tb_mqtt_client/common/__init__.py new file mode 100644 index 0000000..59e3bfa --- /dev/null +++ b/tb_mqtt_client/common/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tb_mqtt_client/common/config_loader.py b/tb_mqtt_client/common/config_loader.py new file mode 100644 index 0000000..3c4383d --- /dev/null +++ b/tb_mqtt_client/common/config_loader.py @@ -0,0 +1,48 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +from typing import Optional + + +class DeviceConfig: + def __init__(self): + self.host: str = os.getenv("TB_HOST") + self.port: int = int(os.getenv("TB_PORT", 1883)) + + # Authentication options + self.access_token: Optional[str] = os.getenv("TB_ACCESS_TOKEN") + self.username: Optional[str] = os.getenv("TB_USERNAME") + self.password: Optional[str] = os.getenv("TB_PASSWORD") + + # Optional + self.client_id: Optional[str] = os.getenv("TB_CLIENT_ID") + + # TLS options + self.ca_cert: Optional[str] = os.getenv("TB_CA_CERT") + self.client_cert: Optional[str] = os.getenv("TB_CLIENT_CERT") + self.private_key: Optional[str] = os.getenv("TB_PRIVATE_KEY") + + def use_tls_auth(self) -> bool: + return all([self.ca_cert, self.client_cert, self.private_key]) + + def use_tls(self) -> bool: + return self.ca_cert is not None + + def __repr__(self): + return (f"") diff --git a/tb_mqtt_client/common/exceptions.py b/tb_mqtt_client/common/exceptions.py new file mode 100644 index 0000000..96bbea1 --- /dev/null +++ b/tb_mqtt_client/common/exceptions.py @@ -0,0 +1,69 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import logging +from typing import Callable, Dict, List, Optional, Type + +logger = logging.getLogger("tb_sdk") + +ExceptionCallback = Callable[[BaseException, Optional[dict]], None] + + +class ExceptionHandler: + def __init__(self): + self._callbacks: Dict[Type[BaseException], List[ExceptionCallback]] = {} + self._default_callbacks: List[ExceptionCallback] = [] + + def register(self, exc_type: Optional[Type[BaseException]], callback: ExceptionCallback): + """ + Register a callback for a specific exception type or all (if exc_type is None). + """ + if exc_type is None: + self._default_callbacks.append(callback) + else: + self._callbacks.setdefault(exc_type, []).append(callback) + + def handle(self, exc: BaseException, context: Optional[dict] = None): + """ + Dispatch the exception to the appropriate registered callbacks. + """ + handled = False + for exc_type, callbacks in self._callbacks.items(): + if isinstance(exc, exc_type): + for cb in callbacks: + cb(exc, context) + handled = True + if not handled: + for cb in self._default_callbacks: + cb(exc, context) + + def install_asyncio_handler(self, loop: Optional[asyncio.AbstractEventLoop] = None): + """ + Hook into asyncio event loop to catch global task errors. + """ + loop = loop or asyncio.get_event_loop() + + def _asyncio_handler(loop_, context: dict): + exception = context.get("exception") + if exception: + self.handle(exception, context) + else: + logger.error("Unhandled asyncio context: %s", context) + + loop.set_exception_handler(_asyncio_handler) + + +exception_handler = ExceptionHandler() diff --git a/tb_mqtt_client/common/gmqtt_patch.py b/tb_mqtt_client/common/gmqtt_patch.py new file mode 100644 index 0000000..5c8dc01 --- /dev/null +++ b/tb_mqtt_client/common/gmqtt_patch.py @@ -0,0 +1,88 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import struct +from types import MethodType +from typing import Callable +from collections import defaultdict + +from tb_mqtt_client.common.logging_utils import get_logger +from gmqtt.mqtt.property import Property +from gmqtt.mqtt.utils import unpack_variable_byte_integer + +logger = get_logger(__name__) + + +def patch_gmqtt_puback(client, on_puback_with_reason_and_properties: Callable[[int, int, dict], None]): + """ + Monkey-patch gmqtt.Client instance to intercept PUBACK reason codes and properties. 
+ + :param client: GMQTTClient instance + :param on_puback_with_reason_and_properties: Callback with (mid, reason_code, properties_dict) + """ + # Backup original method from MqttPackageHandler + base_method = client.__class__.__bases__[0].__dict__.get('_handle_puback_packet') + + if base_method is None: + logger.error("Could not find _handle_puback_packet in base class.") + return + + def _parse_properties(packet: bytes) -> dict: + """ + Parse MQTT 5.0 properties from packet. + """ + properties_dict = defaultdict(list) + + try: + properties_len, packet = unpack_variable_byte_integer(packet) + props = packet[:properties_len] + packet = packet[properties_len:] + + while props: + property_identifier = props[0] + property_obj = Property.factory(id_=property_identifier) + if property_obj is None: + logger.warning(f"Unknown PUBACK property id={property_identifier}") + break + + result, props = property_obj.loads(props[1:]) + for k, v in result.items(): + properties_dict[k].append(v) + + except Exception as e: + logger.warning("Failed to parse PUBACK properties: %s", e) + + return dict(properties_dict) + + def wrapped_handle_puback(self, cmd, packet): + try: + mid = struct.unpack("!H", packet[:2])[0] + reason_code = 0 + properties = {} + + if len(packet) > 2: + reason_code = packet[2] + if len(packet) > 3: + props_payload = packet[3:] + properties = _parse_properties(props_payload) + + on_puback_with_reason_and_properties(mid, reason_code, properties) + except Exception as e: + logger.exception("Error while handling PUBACK with properties: %s", e) + + return base_method(self, cmd, packet) + + client._handle_puback_packet = MethodType(wrapped_handle_puback, client) diff --git a/tb_mqtt_client/common/logging_utils.py b/tb_mqtt_client/common/logging_utils.py new file mode 100644 index 0000000..c5175df --- /dev/null +++ b/tb_mqtt_client/common/logging_utils.py @@ -0,0 +1,59 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import logging +import sys +from typing import Optional + + +DEFAULT_LOG_FORMAT = "[%(asctime)s.%(msecs)03d] [%(levelname)s] %(name)s - %(lineno)d - %(message)s" +DEFAULT_DATE_FORMAT = "%Y-%m-%d %H:%M:%S" + +TRACE_LEVEL = 5 +logging.addLevelName(TRACE_LEVEL, "TRACE") + + +def trace(self, message, *args, **kwargs): + if self.isEnabledFor(TRACE_LEVEL): + self._log(TRACE_LEVEL, message, args, **kwargs) + + +logging.Logger.trace = trace + + +def configure_logging(level: int = logging.INFO, + log_format: str = DEFAULT_LOG_FORMAT, + date_format: str = DEFAULT_DATE_FORMAT, + stream=sys.stdout): + """ + Configures the root logger with a stream handler and standardized format. + Should be called once during app startup. + """ + logging.basicConfig( + level=level, + format=log_format, + datefmt=date_format, + handlers=[ + logging.StreamHandler(stream) + ] + ) + + +def get_logger(name: Optional[str] = None) -> logging.Logger: + """ + Returns a logger instance with the given name. 
+ """ + return logging.getLogger(name or __name__) diff --git a/tb_mqtt_client/common/provision_client.py b/tb_mqtt_client/common/provision_client.py new file mode 100644 index 0000000..59e3bfa --- /dev/null +++ b/tb_mqtt_client/common/provision_client.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tb_mqtt_client/common/rate_limit/__init__.py b/tb_mqtt_client/common/rate_limit/__init__.py new file mode 100644 index 0000000..59e3bfa --- /dev/null +++ b/tb_mqtt_client/common/rate_limit/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tb_mqtt_client/common/rate_limit/backpressure_controller.py b/tb_mqtt_client/common/rate_limit/backpressure_controller.py new file mode 100644 index 0000000..5e33fc3 --- /dev/null +++ b/tb_mqtt_client/common/rate_limit/backpressure_controller.py @@ -0,0 +1,49 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
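The logging helpers above register a custom TRACE level (numeric value 5, below DEBUG) and patch logging.Logger with a trace() method. A minimal usage sketch, with an illustrative logger name and message:

    from tb_mqtt_client.common.logging_utils import TRACE_LEVEL, configure_logging, get_logger

    configure_logging(level=TRACE_LEVEL)            # opt in to the extra-verbose level
    logger = get_logger("tb_mqtt_client.example")
    logger.trace("raw packet: %r", b"\x40\x02\x00\x01")  # only emitted when TRACE is enabled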
+# + + +from datetime import datetime, timedelta, UTC +from typing import Optional + +from tb_mqtt_client.common.logging_utils import get_logger + +logger = get_logger(__name__) + + +class BackpressureController: + def __init__(self): + self._pause_until: Optional[datetime] = None + self._default_pause_duration = timedelta(seconds=10) + + def notify_quota_exceeded(self, delay_seconds: Optional[int] = None): + duration = timedelta(seconds=delay_seconds) if delay_seconds else self._default_pause_duration + self._pause_until = datetime.now(UTC) + duration + + def notify_disconnect(self, delay_seconds: Optional[int] = None): + self.notify_quota_exceeded(delay_seconds) + + def should_pause(self) -> bool: + if self._pause_until is None: + return False + if datetime.now(UTC) < self._pause_until: + return True + self._pause_until = None + return False + + def pause_for(self, seconds: int): + self._pause_until = datetime.now(UTC) + timedelta(seconds=seconds) + + def clear(self): + self._pause_until = None diff --git a/tb_mqtt_client/common/rate_limit/rate_limit.py b/tb_mqtt_client/common/rate_limit/rate_limit.py new file mode 100644 index 0000000..a0b8edf --- /dev/null +++ b/tb_mqtt_client/common/rate_limit/rate_limit.py @@ -0,0 +1,171 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
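A short sketch of how the BackpressureController above is intended to be driven; the 30-second delay and the publish loop are illustrative, not part of this commit:

    import asyncio

    from tb_mqtt_client.common.rate_limit.backpressure_controller import BackpressureController

    async def publish_loop(controller: BackpressureController):
        # elsewhere, a quota-exceeded PUBACK would call controller.notify_quota_exceeded(30)
        while True:
            if controller.should_pause():   # stays True until the pause window expires
                await asyncio.sleep(1)
                continue
            ...                             # publish the next queued message here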
+# + + +import os +import logging +from threading import RLock +from time import monotonic + +logger = logging.getLogger(__name__) + +DEFAULT_TIMEOUT = 5 + +try: + DEFAULT_RATE_LIMIT_PERCENTAGE = int(os.getenv('TB_DEFAULT_RATE_LIMIT_PERCENTAGE', 80)) +except ValueError: + logger.warning("Invalid TB_DEFAULT_RATE_LIMIT_PERCENTAGE, falling back to 80%%") + DEFAULT_RATE_LIMIT_PERCENTAGE = 80 + + +class GreedyTokenBucket: + def __init__(self, capacity, duration_sec): + self.capacity = float(capacity) + self.duration = float(duration_sec) + self.tokens = float(capacity) + self.last_updated = monotonic() + + def refill(self): + now = monotonic() + elapsed = now - self.last_updated + rate = self.capacity / self.duration + self.tokens = min(self.capacity, self.tokens + elapsed * rate) + self.last_updated = now + + def can_consume(self, amount=1): + self.refill() + return round(self.tokens, 6) >= round(amount, 6) + + def consume(self, amount=1): + self.refill() + if self.tokens >= amount: + self.tokens -= amount + return True + return False + + def get_remaining_tokens(self): + self.refill() + return self.tokens + + +class RateLimit: + def __init__(self, rate_limit: str, name: str = None, percentage: int = DEFAULT_RATE_LIMIT_PERCENTAGE): + self.name = name + self.percentage = percentage + self._no_limit = False + self._rate_buckets = {} + self._lock = RLock() + self._minimal_timeout = DEFAULT_TIMEOUT + self._minimal_limit = float('inf') + self.__reached_index = 0 + self.__reached_index_time = 0 + + self._parse_string(rate_limit) + + def _parse_string(self, rate_limit: str): + if not rate_limit or rate_limit.strip().removesuffix(',') in ("0:0", ""): + self._no_limit = True + return + + entries = rate_limit.replace(";", ",").split(",") + for entry in entries: + try: + limit_str, dur_str = entry.strip().split(":") + limit = int(int(limit_str) * self.percentage / 100) + duration = int(dur_str) + + bucket = GreedyTokenBucket(limit, duration) + self._rate_buckets[duration] = bucket + self._minimal_limit = min(self._minimal_limit, limit) + self._minimal_timeout = min(self._minimal_timeout, duration + 1) + except Exception as e: + logger.warning("Invalid rate limit format '%s': %s", entry, e) + + self._no_limit = not bool(self._rate_buckets) + + def check_limit_reached(self, amount=1): + if self._no_limit: + return False + + with self._lock: + result = False + for dur, bucket in self._rate_buckets.items(): + bucket.refill() + if not result and bucket.tokens < amount: + result = (bucket.capacity, dur) + return result + + def consume(self, amount=1): + if self._no_limit: + return + with self._lock: + for bucket in self._rate_buckets.values(): + bucket.consume(amount) + + @property + def minimal_limit(self): + return self._minimal_limit if self.has_limit() else 0 + + @property + def minimal_timeout(self): + return self._minimal_timeout if self.has_limit() else 0 + + def has_limit(self): + return not self._no_limit + + def reach_limit(self): + if self._no_limit: + return + + with self._lock: + durations = sorted(self._rate_buckets.keys()) + now = monotonic() + + if self.__reached_index_time >= now - self._rate_buckets[durations[-1]].duration: + self.__reached_index = 0 + self.__reached_index_time = now + + if self.__reached_index >= len(durations): + self.__reached_index = 0 + self.__reached_index_time = now + + dur = durations[self.__reached_index] + self._rate_buckets[dur].tokens = 0.0 + self.__reached_index += 1 + + logger.info("Rate limit reached for \"%s\". 
Cooldown for %s seconds", self.name, dur) + return self.__reached_index, self.__reached_index_time + + def to_dict(self): + return { + "name": self.name, + "percentage": self.percentage, + "no_limit": self._no_limit, + "rateLimits": { + str(dur): { + "capacity": b.capacity, + "tokens": b.get_remaining_tokens(), + "last_updated": b.last_updated + } for dur, b in self._rate_buckets.items() + } + } + + def set_limit(self, rate_limit: str, percentage: int = DEFAULT_RATE_LIMIT_PERCENTAGE): + with self._lock: + self._rate_buckets.clear() + self._minimal_timeout = DEFAULT_TIMEOUT + self._minimal_limit = float('inf') + self.percentage = percentage + self._parse_string(rate_limit) diff --git a/tb_mqtt_client/common/request_id_generator.py b/tb_mqtt_client/common/request_id_generator.py new file mode 100644 index 0000000..9e961f1 --- /dev/null +++ b/tb_mqtt_client/common/request_id_generator.py @@ -0,0 +1,44 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from asyncio import Lock + + +class RPCRequestIdProducer: + """ + Singleton-style producer of unique RPC request IDs, + safe for async environments and shared across all SDK services. + """ + + _counter: int = 1 + _lock: Lock = Lock() + + @classmethod + async def get_next(cls) -> int: + """ + Atomically increment and return the next request ID. + """ + async with cls._lock: + current = cls._counter + cls._counter += 1 + return current + + @classmethod + def reset(cls): + """ + Reset the global request ID counter (usually on disconnect). + """ + cls._counter = 1 diff --git a/tb_mqtt_client/constants/__init__.py b/tb_mqtt_client/constants/__init__.py new file mode 100644 index 0000000..59e3bfa --- /dev/null +++ b/tb_mqtt_client/constants/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tb_mqtt_client/constants/mqtt_topics.py b/tb_mqtt_client/constants/mqtt_topics.py new file mode 100644 index 0000000..dc703cc --- /dev/null +++ b/tb_mqtt_client/constants/mqtt_topics.py @@ -0,0 +1,66 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
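A worked sketch of the RateLimit parser above. With the default 80% percentage, the illustrative limit string "100:1,1000:60" becomes two greedy token buckets: 80 tokens per 1 s and 800 tokens per 60 s:

    from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit

    limit = RateLimit("100:1,1000:60", name="messages")
    limit.minimal_limit           # 80 -> int(100 * 80 / 100)
    limit.check_limit_reached()   # False while every bucket still holds a token
    limit.consume()               # deduct one token from each bucket
    # Once a bucket runs dry, check_limit_reached() returns that bucket's
    # (capacity, duration) pair instead of False.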
+# You may obtain a copy of the License at
+# #
+# http://www.apache.org/licenses/LICENSE-2.0
+# #
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+WILDCARD = "+"
+REQUEST_TOPIC_SUFFIX = "/request"
+RESPONSE_TOPIC_SUFFIX = "/response"
+# V1 Topics for Device API
+DEVICE_TELEMETRY_TOPIC = "v1/devices/me/telemetry"
+DEVICE_ATTRIBUTES_TOPIC = "v1/devices/me/attributes"
+DEVICE_ATTRIBUTES_REQUEST_TOPIC = DEVICE_ATTRIBUTES_TOPIC + REQUEST_TOPIC_SUFFIX + "/" + "{request_id}"
+DEVICE_ATTRIBUTES_RESPONSE_TOPIC = DEVICE_ATTRIBUTES_TOPIC + RESPONSE_TOPIC_SUFFIX + "/" + WILDCARD
+DEVICE_RPC_TOPIC = "v1/devices/me/rpc"
+# Device RPC topics
+DEVICE_RPC_REQUEST_TOPIC = DEVICE_RPC_TOPIC + REQUEST_TOPIC_SUFFIX + "/"
+DEVICE_RPC_RESPONSE_TOPIC = DEVICE_RPC_TOPIC + RESPONSE_TOPIC_SUFFIX + "/"
+DEVICE_RPC_REQUEST_TOPIC_FOR_SUBSCRIPTION = DEVICE_RPC_TOPIC + REQUEST_TOPIC_SUFFIX + "/" + WILDCARD
+DEVICE_RPC_RESPONSE_TOPIC_FOR_SUBSCRIPTION = DEVICE_RPC_TOPIC + RESPONSE_TOPIC_SUFFIX + "/" + WILDCARD
+
+# V1 Topics for Gateway API
+BASE_GATEWAY_TOPIC = "v1/gateway"
+GATEWAY_CONNECT_TOPIC = BASE_GATEWAY_TOPIC + "/connect"
+GATEWAY_DISCONNECT_TOPIC = BASE_GATEWAY_TOPIC + "/disconnect"
+GATEWAY_TELEMETRY_TOPIC = BASE_GATEWAY_TOPIC + "/telemetry"
+GATEWAY_ATTRIBUTES_TOPIC = BASE_GATEWAY_TOPIC + "/attributes"
+GATEWAY_ATTRIBUTES_REQUEST_TOPIC = GATEWAY_ATTRIBUTES_TOPIC + REQUEST_TOPIC_SUFFIX
+GATEWAY_ATTRIBUTES_RESPONSE_TOPIC = GATEWAY_ATTRIBUTES_TOPIC + RESPONSE_TOPIC_SUFFIX
+GATEWAY_RPC_TOPIC = BASE_GATEWAY_TOPIC + "/rpc"
+
+# Topic Builders
+
+
+def build_device_attributes_request_topic(request_id: int) -> str:
+    return DEVICE_ATTRIBUTES_REQUEST_TOPIC.format(request_id=request_id)
+
+
+def build_device_rpc_request_topic(request_id: int) -> str:
+    return DEVICE_RPC_REQUEST_TOPIC + str(request_id)
+
+
+def build_device_rpc_response_topic(request_id: int) -> str:
+    return DEVICE_RPC_RESPONSE_TOPIC + str(request_id)
+
+
+def build_gateway_device_telemetry_topic() -> str:
+    return GATEWAY_TELEMETRY_TOPIC
+
+
+def build_gateway_device_attributes_topic() -> str:
+    return GATEWAY_ATTRIBUTES_TOPIC
+
+
+def build_gateway_rpc_topic() -> str:
+    return GATEWAY_RPC_TOPIC
diff --git a/tb_mqtt_client/constants/service_keys.py b/tb_mqtt_client/constants/service_keys.py
new file mode 100644
index 0000000..624bd8e
--- /dev/null
+++ b/tb_mqtt_client/constants/service_keys.py
@@ -0,0 +1,19 @@
+# Copyright 2025. ThingsBoard
+# #
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# #
+# http://www.apache.org/licenses/LICENSE-2.0
+# #
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
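For reference, the device-side topic builders above expand as follows (request id 42 is arbitrary):

    from tb_mqtt_client.constants import mqtt_topics

    mqtt_topics.build_device_attributes_request_topic(42)  # 'v1/devices/me/attributes/request/42'
    mqtt_topics.build_device_rpc_response_topic(42)        # 'v1/devices/me/rpc/response/42'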
+# + + +MESSAGES_RATE_LIMIT = "MESSAGES_RATE_LIMIT" +TELEMETRY_MESSAGE_RATE_LIMIT = "TELEMETRY_MESSAGE_RATE_LIMIT" +TELEMETRY_DATAPOINTS_RATE_LIMIT = "TELEMETRY_DATAPOINTS_RATE_LIMIT" diff --git a/tb_mqtt_client/constants/service_messages.py b/tb_mqtt_client/constants/service_messages.py new file mode 100644 index 0000000..faaa3ff --- /dev/null +++ b/tb_mqtt_client/constants/service_messages.py @@ -0,0 +1,24 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from orjson import dumps + + +SESSION_LIMITS_REQUEST_MESSAGE = dumps( + { + "method": "getSessionLimits", + "params": {}, + } +) diff --git a/tb_mqtt_client/entities/__init__.py b/tb_mqtt_client/entities/__init__.py new file mode 100644 index 0000000..6d6dc0d --- /dev/null +++ b/tb_mqtt_client/entities/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tb_mqtt_client/entities/data/__init__.py b/tb_mqtt_client/entities/data/__init__.py new file mode 100644 index 0000000..59e3bfa --- /dev/null +++ b/tb_mqtt_client/entities/data/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tb_mqtt_client/entities/data/attribute_entry.py b/tb_mqtt_client/entities/data/attribute_entry.py new file mode 100644 index 0000000..cd5eef7 --- /dev/null +++ b/tb_mqtt_client/entities/data/attribute_entry.py @@ -0,0 +1,37 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any + +from tb_mqtt_client.entities.data.data_entry import DataEntry + + +class AttributeEntry(DataEntry): + def __init__(self, key: str, value: Any): + super().__init__(key, value) + + def __repr__(self): + return f"" + + def as_dict(self) -> dict: + return { + "key": self.key, + "value": self.value + } + + def __eq__(self, other): + if not isinstance(other, AttributeEntry): + return False + return self.key == other.key and self.value == other.value diff --git a/tb_mqtt_client/entities/data/attribute_request.py b/tb_mqtt_client/entities/data/attribute_request.py new file mode 100644 index 0000000..59e3bfa --- /dev/null +++ b/tb_mqtt_client/entities/data/attribute_request.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tb_mqtt_client/entities/data/attribute_response.py b/tb_mqtt_client/entities/data/attribute_response.py new file mode 100644 index 0000000..59e3bfa --- /dev/null +++ b/tb_mqtt_client/entities/data/attribute_response.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tb_mqtt_client/entities/data/attribute_update.py b/tb_mqtt_client/entities/data/attribute_update.py new file mode 100644 index 0000000..731c1bc --- /dev/null +++ b/tb_mqtt_client/entities/data/attribute_update.py @@ -0,0 +1,57 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + + +from dataclasses import dataclass +from typing import Dict, Any, List + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry + + +@dataclass(slots=True) +class AttributeUpdate: + entries: List[AttributeEntry] + + def __repr__(self): + return f"" + + def get(self, key: str, default=None): + for entry in self.entries: + if entry.key == key: + return entry.value + return default + + def keys(self): + return [entry.key for entry in self.entries] + + def values(self): + return [entry.value for entry in self.entries] + + def items(self): + return [(entry.key, entry.value) for entry in self.entries] + + def as_dict(self) -> Dict[str, Any]: + return {entry.key: entry.value for entry in self.entries} + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'AttributeUpdate': + """ + Deserialize dictionary into AttributeUpdate object. + :param data: Dictionary of attribute key-value pairs. + :param device: Optional device name (for gateway context). + :return: AttributeUpdate instance. + """ + entries = [AttributeEntry(k, v) for k, v in data.items()] + return cls(entries=entries) diff --git a/tb_mqtt_client/entities/data/data_entry.py b/tb_mqtt_client/entities/data/data_entry.py new file mode 100644 index 0000000..a6920c9 --- /dev/null +++ b/tb_mqtt_client/entities/data/data_entry.py @@ -0,0 +1,63 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from typing import Any, Optional +from orjson import dumps + + +class DataEntry: + def __init__(self, key: str, value: Any, ts: Optional[int] = None): + self.__key = key + self.__value = value + self.__ts = ts + self.__size = self.__estimate_size() + + def __estimate_size(self) -> int: + if self.ts is not None: + return len(dumps({"ts": self.ts, "values": {self.key: self.value}})) + else: + return len(dumps({self.key: self.value})) + + @property + def size(self) -> int: + return self.__size + + @property + def key(self) -> str: + return self.__key + + @key.setter + def key(self, value: str): + self.__key = value + self.__size = self.__estimate_size() + + @property + def value(self) -> Any: + return self.__value + + @value.setter + def value(self, value: Any): + self.__value = value + self.__size = self.__estimate_size() + + @property + def ts(self) -> Optional[int]: + return self.__ts + + @ts.setter + def ts(self, value: Optional[int]): + self.__ts = value + self.__size = self.__estimate_size() diff --git a/tb_mqtt_client/entities/data/device_uplink_message.py b/tb_mqtt_client/entities/data/device_uplink_message.py new file mode 100644 index 0000000..656d73f --- /dev/null +++ b/tb_mqtt_client/entities/data/device_uplink_message.py @@ -0,0 +1,97 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
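Two small sketches for the entities above, with illustrative keys and values: AttributeUpdate.from_dict round-trips a plain dictionary, and DataEntry pre-computes its orjson-serialized size so payload budgets can be checked without re-serializing:

    from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate
    from tb_mqtt_client.entities.data.data_entry import DataEntry

    update = AttributeUpdate.from_dict({"targetFwVersion": "1.0.5", "reportInterval": 30})
    update.get("reportInterval")   # 30
    update.as_dict()               # {'targetFwVersion': '1.0.5', 'reportInterval': 30}

    entry = DataEntry("temperature", 22.5)   # no ts -> serialized as {"temperature":22.5}
    entry.size                               # 20 (byte length of the JSON above)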
+# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from typing import List, Optional, Union +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry + +DEFAULT_FIELDS_SIZE = len('{"device_name":"","device_profile":"","attributes":"","timeseries":""}'.encode('utf-8')) + + +class DeviceUplinkMessage: + def __init__(self, + device_name: Optional[str] = None, + device_profile: Optional[str] = None, + attributes: Optional[List[AttributeEntry]] = None, + timeseries: Optional[List[TimeseriesEntry]] = None, + _size: Optional[int] = None): + if _size is None: + raise ValueError("DeviceUplinkMessage must be created using DeviceUplinkMessageBuilder") + + self.device_name = device_name + self.device_profile = device_profile + self.attributes = attributes or [] + self.timeseries = timeseries or [] + self.__size = _size + + def timeseries_datapoint_count(self) -> int: + return len(self.timeseries) + + def has_attributes(self) -> bool: + return bool(self.attributes) + + def has_timeseries(self) -> bool: + return bool(self.timeseries) + + @property + def size(self) -> int: + return self.__size + + +class DeviceUplinkMessageBuilder: + def __init__(self): + self._device_name: Optional[str] = None + self._device_profile: Optional[str] = None + self._attributes: List[AttributeEntry] = [] + self._timeseries: List[TimeseriesEntry] = [] + self.__size = DEFAULT_FIELDS_SIZE + + def set_device_name(self, device_name: str) -> 'DeviceUplinkMessageBuilder': + self._device_name = device_name + if device_name is not None: + self.__size += len(device_name) + return self + + def set_device_profile(self, profile: str) -> 'DeviceUplinkMessageBuilder': + self._device_profile = profile + if profile is not None: + self.__size += len(profile) + return self + + def add_attributes(self, attributes: Union[AttributeEntry, List[AttributeEntry]]) -> 'DeviceUplinkMessageBuilder': + if not isinstance(attributes, list): + attributes = [attributes] + self._attributes.extend(attributes) + for attribute in attributes: + self.__size += attribute.size + return self + + def add_telemetry(self, timeseries: Union[TimeseriesEntry, List[TimeseriesEntry]]) -> 'DeviceUplinkMessageBuilder': + if not isinstance(timeseries, list): + timeseries = [timeseries] + self._timeseries.extend(timeseries) + for timeseries_entry in timeseries: + self.__size += timeseries_entry.size + return self + + def build(self) -> DeviceUplinkMessage: + return DeviceUplinkMessage( + device_name=self._device_name, + device_profile=self._device_profile, + attributes=self._attributes, + timeseries=self._timeseries, + _size=self.__size + ) diff --git a/tb_mqtt_client/entities/data/rpc_request.py b/tb_mqtt_client/entities/data/rpc_request.py new file mode 100644 index 0000000..d09ad1b --- /dev/null +++ b/tb_mqtt_client/entities/data/rpc_request.py @@ -0,0 +1,47 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
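A usage sketch for the builder above (device name and entries are illustrative). DeviceUplinkMessage refuses direct construction, so instances always carry the size accumulated by the builder:

    from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry
    from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder
    from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry

    message = (DeviceUplinkMessageBuilder()
               .set_device_name("Sensor A")
               .add_attributes(AttributeEntry("firmwareVersion", "1.0.4"))
               .add_telemetry(TimeseriesEntry("temperature", 22.5))
               .build())
    message.has_attributes()              # True
    message.timeseries_datapoint_count()  # 1
    message.size                          # DEFAULT_FIELDS_SIZE + len(device name) + entry sizes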
+# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from dataclasses import dataclass +from typing import Union, Optional, Dict, Any, List + + +@dataclass(slots=True, frozen=True) +class RPCRequest: + request_id: Union[int, str] + method: str + params: Optional[Union[Dict[str, Any], List[Any]]] = None + + def to_dict(self) -> Dict[str, Any]: + result = { + "id": self.request_id, + "method": self.method + } + if self.params is not None: + result["params"] = self.params + return result + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'RPCRequest': + if "id" not in data: + raise ValueError("Missing 'id' in RPC request") + if "method" not in data: + raise ValueError("Missing 'method' in RPC request") + + return cls( + request_id=data["id"], + method=data["method"], + params=data.get("params") + ) diff --git a/tb_mqtt_client/entities/data/rpc_response.py b/tb_mqtt_client/entities/data/rpc_response.py new file mode 100644 index 0000000..12b0faf --- /dev/null +++ b/tb_mqtt_client/entities/data/rpc_response.py @@ -0,0 +1,42 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from dataclasses import dataclass +from typing import Union, Optional, Dict, Any + + +@dataclass(slots=True, frozen=True) +class RPCResponse: + """ + Represents a response to a device-side RPC call. + + Attributes: + request_id: Unique identifier of the request being responded to. + result: Optional response payload (Any type allowed). + error: Optional error information if the RPC failed. + """ + request_id: Union[int, str] + result: Optional[Any] = None + error: Optional[Union[str, Dict[str, Any]]] = None + + def to_dict(self) -> Dict[str, Any]: + """Serializes the RPC response for publishing.""" + data = {} + if self.result is not None: + data["result"] = self.result + if self.error is not None: + data["error"] = self.error + return data diff --git a/tb_mqtt_client/entities/data/timeseries_entry.py b/tb_mqtt_client/entities/data/timeseries_entry.py new file mode 100644 index 0000000..4f84a7b --- /dev/null +++ b/tb_mqtt_client/entities/data/timeseries_entry.py @@ -0,0 +1,56 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
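A round-trip sketch for the RPC entities above, with an illustrative method name and params: a server-originated request is parsed with from_dict and answered with an RPCResponse whose payload comes from to_dict:

    from tb_mqtt_client.entities.data.rpc_request import RPCRequest
    from tb_mqtt_client.entities.data.rpc_response import RPCResponse

    request = RPCRequest.from_dict({"id": 17, "method": "setLedState", "params": {"on": True}})
    response = RPCResponse(request_id=request.request_id, result={"status": "ok"})
    response.to_dict()   # {'result': {'status': 'ok'}} - the request id is carried in the response topic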
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from typing import Any, Optional + +from tb_mqtt_client.entities.data.data_entry import DataEntry + + +class TimeseriesEntry(DataEntry): + def __init__(self, key: str, value: Any, ts: Optional[int] = None): + super().__init__(key, value, ts) + + def __repr__(self): + return f"" + + def as_dict(self) -> dict: + result = { + "key": self.key, + "value": self.value + } + if self.ts is not None: + result["ts"] = self.ts + return result + + def __eq__(self, other): + if not isinstance(other, TimeseriesEntry): + return False + return self.key == other.key and self.value == other.value and self.ts == other.ts diff --git a/tb_mqtt_client/entities/gateway/__init__.py b/tb_mqtt_client/entities/gateway/__init__.py new file mode 100644 index 0000000..59e3bfa --- /dev/null +++ b/tb_mqtt_client/entities/gateway/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tb_mqtt_client/entities/gateway/device_session_state.py b/tb_mqtt_client/entities/gateway/device_session_state.py new file mode 100644 index 0000000..59e3bfa --- /dev/null +++ b/tb_mqtt_client/entities/gateway/device_session_state.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tb_mqtt_client/entities/gateway/rpc_context.py b/tb_mqtt_client/entities/gateway/rpc_context.py new file mode 100644 index 0000000..59e3bfa --- /dev/null +++ b/tb_mqtt_client/entities/gateway/rpc_context.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tb_mqtt_client/entities/gateway/virtual_device.py b/tb_mqtt_client/entities/gateway/virtual_device.py new file mode 100644 index 0000000..59e3bfa --- /dev/null +++ b/tb_mqtt_client/entities/gateway/virtual_device.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tb_mqtt_client/service/__init__.py b/tb_mqtt_client/service/__init__.py new file mode 100644 index 0000000..6d6dc0d --- /dev/null +++ b/tb_mqtt_client/service/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tb_mqtt_client/service/base_client.py b/tb_mqtt_client/service/base_client.py new file mode 100644 index 0000000..2594634 --- /dev/null +++ b/tb_mqtt_client/service/base_client.py @@ -0,0 +1,99 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from abc import ABC, abstractmethod +import asyncio +import uvloop +from typing import Callable, Awaitable, Dict, Any, Union +from tb_mqtt_client.common.exceptions import exception_handler + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.rpc_response import RPCResponse + +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +exception_handler.install_asyncio_handler() + + +class BaseClient(ABC): + """ + Abstract base class for ThingsBoard clients. 
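+
+    Concrete implementations (such as DeviceClient) supply the actual MQTT
+    connect/disconnect logic and implement the send_* and callback-registration
+    methods declared below.
+
+    Minimal usage sketch (run inside an asyncio coroutine; assumes a concrete
+    DeviceClient and a reachable broker; host and token values are placeholders):
+
+        client = DeviceClient({"host": "localhost", "port": 1883, "access_token": "TOKEN"})
+        await client.connect()
+        await client.send_telemetry({"temperature": 22.5})
+        await client.disconnect()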
+ """ + + def __init__(self, host: str, port: int, client_id: str): + self._host = host + self._port = port + self._client_id = client_id + self._connected = asyncio.Event() + + @abstractmethod + async def connect(self): + """ + Connect to the ThingsBoard platform over MQTT. + """ + pass + + @abstractmethod + async def disconnect(self): + """ + Disconnect from the platform. + """ + pass + + @abstractmethod + async def send_telemetry(self, telemetry_data: Dict[str, Any]): + """ + Send telemetry data. + + :param telemetry_data: Dictionary of telemetry key-values. + """ + pass + + @abstractmethod + async def send_attributes(self, attributes: Union[Dict[str, Any], AttributeEntry, list[AttributeEntry]]): + """ + Send client-side attributes. + + :param attributes: Dictionary of attributes. + """ + pass + + @abstractmethod + async def send_rpc_response(self, response: RPCResponse): + """ + Send a response to a server-initiated RPC request. + + :param RPCResponse response: The RPC response to send. + """ + pass + + @abstractmethod + def set_attribute_update_callback(self, callback: Callable[[AttributeUpdate], Awaitable[None]]): + """ + Set callback to be triggered when a shared attribute update is received. + + :param callback: Coroutine accepting an AttributeUpdate instance. + """ + pass + + @abstractmethod + def set_rpc_request_callback(self, callback: Callable[[str, Dict[str, Any]], Awaitable[Dict[str, Any]]]): + """ + Set callback to be triggered when an RPC request is received. + + :param callback: Coroutine accepting (method, params) and returning result. + """ + pass diff --git a/tb_mqtt_client/service/device/__init__.py b/tb_mqtt_client/service/device/__init__.py new file mode 100644 index 0000000..59e3bfa --- /dev/null +++ b/tb_mqtt_client/service/device/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tb_mqtt_client/service/device/attribute_updates_handler.py b/tb_mqtt_client/service/device/attribute_updates_handler.py new file mode 100644 index 0000000..e0b593e --- /dev/null +++ b/tb_mqtt_client/service/device/attribute_updates_handler.py @@ -0,0 +1,52 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + + +from typing import Awaitable, Callable, Optional +from orjson import loads + +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.common.logging_utils import get_logger + +logger = get_logger(__name__) + + +class AttributeUpdatesHandler: + """ + Handles shared attribute update messages from the platform. + """ + + def __init__(self): + self._callback: Optional[Callable[[AttributeUpdate], Awaitable[None]]] = None + + def set_callback(self, callback: Callable[[AttributeUpdate], Awaitable[None]]): + """ + Sets the async callback that will be triggered on shared attribute update. + + :param callback: A coroutine that takes an AttributeUpdate object. + """ + self._callback = callback + + async def handle(self, topic: str, payload: bytes): + if not self._callback: + logger.debug("No attribute update callback set. Skipping payload.") + return + + try: + data = loads(payload) + update = AttributeUpdate.from_dict(data) + await self._callback(update) + except Exception as e: + logger.exception("Failed to handle attribute update: %s", e) diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py new file mode 100644 index 0000000..8c82a16 --- /dev/null +++ b/tb_mqtt_client/service/device/client.py @@ -0,0 +1,272 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from asyncio import sleep, wait_for, TimeoutError + +from orjson import loads, dumps +from random import choices +from string import ascii_uppercase, digits +from typing import Callable, Awaitable, Optional, Dict, Any, Union, List + +from tb_mqtt_client.common.request_id_generator import RPCRequestIdProducer +from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit, DEFAULT_RATE_LIMIT_PERCENTAGE +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage, DeviceUplinkMessageBuilder +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.service.base_client import BaseClient +from tb_mqtt_client.service.message_queue import MessageQueue +from tb_mqtt_client.service.mqtt_manager import MQTTManager +from tb_mqtt_client.service.message_dispatcher import JsonMessageDispatcher +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.constants import mqtt_topics +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.service.device.attribute_updates_handler import AttributeUpdatesHandler +from tb_mqtt_client.service.device.rpc_requests_handler import RPCRequestsHandler +from tb_mqtt_client.service.rpc_response_handler import RPCResponseHandler + +logger = get_logger(__name__) + + +class DeviceClient(BaseClient): + def __init__(self, config: Optional[Union[DeviceConfig, Dict]] = None): + self._config = None + if isinstance(config, DeviceConfig): + self._config = config + else: + self._config = DeviceConfig() + if isinstance(config, dict): + for key, value in config.items(): + if hasattr(self._config, key) and value is not None: + setattr(self._config, key, value) + + client_id = self._config.client_id or "tb-client-" + ''.join(choices(ascii_uppercase + digits, k=6)) + + super().__init__(self._config.host, self._config.port, client_id) + + self._messages_rate_limit = RateLimit("0:0,", name="messages") + self._telemetry_rate_limit = RateLimit("0:0,", name="telemetry") + self._telemetry_dp_rate_limit = RateLimit("0:0,", name="telemetryDataPoints") + self.max_payload_size = None + self._max_inflight_messages = 100 + self._max_uplink_message_queue_size = 10000 + self._max_queued_messages = 50000 + + self._rpc_response_handler = RPCResponseHandler() + + self._mqtt_manager = MQTTManager(self._client_id, + self._on_connect, + self._on_disconnect, + self._handle_rate_limit_response, + rpc_response_handler=self._rpc_response_handler,) + + self._message_queue: Optional[MessageQueue] = None + self._dispatcher: Optional[JsonMessageDispatcher] = None + + self._attribute_updates_handler = AttributeUpdatesHandler() + self._rpc_requests_handler = RPCRequestsHandler() + + async def connect(self): + logger.info("Connecting to platform at %s:%s", self._host, self._port) + + ssl_context = None + tls = self._config.use_tls() + if tls: + import ssl + ssl_context = ssl.create_default_context() + ssl_context.load_verify_locations(self._config.ca_cert) + ssl_context.load_cert_chain(certfile=self._config.client_cert, keyfile=self._config.private_key) + + await self._mqtt_manager.connect( + host=self._host, + port=self._port, + username=self._config.access_token or self._config.username, + password=None if self._config.access_token else self._config.password, + tls=tls, + 
ssl_context=ssl_context + ) + + while not self._mqtt_manager.is_connected(): + await self._mqtt_manager.await_ready() + + await self._on_connect() + + self._dispatcher = JsonMessageDispatcher(self.max_payload_size, self._telemetry_dp_rate_limit.minimal_limit) + self._message_queue = MessageQueue( + mqtt_manager=self._mqtt_manager, + message_rate_limit=self._messages_rate_limit, + telemetry_rate_limit=self._telemetry_rate_limit, + telemetry_dp_rate_limit=self._telemetry_dp_rate_limit, + message_dispatcher=self._dispatcher, + max_queue_size=self._max_uplink_message_queue_size, + ) + + async def disconnect(self): + await self._mqtt_manager.disconnect() + if self._message_queue: + await self._message_queue.shutdown() + + async def send_telemetry(self, telemetry_data: Union[Dict[str, Any], + TimeseriesEntry, + List[TimeseriesEntry], + List[Dict[str, Any]]]): + message = self._build_uplink_message_for_telemetry(telemetry_data) + await self._message_queue.publish(topic=mqtt_topics.DEVICE_TELEMETRY_TOPIC, + payload=message, + datapoints_count=message.timeseries_datapoint_count()) + + async def send_attributes(self, attributes: Union[Dict[str, Any], AttributeEntry, list[AttributeEntry]]): + message = self._build_uplink_message_for_attributes(attributes) + await self._message_queue.publish(topic=mqtt_topics.DEVICE_ATTRIBUTES_TOPIC, + payload=message, + datapoints_count=len(message.attributes)) + + async def send_rpc_response(self, response: RPCResponse): + topic = mqtt_topics.build_device_rpc_response_topic(request_id=response.request_id) + await self._message_queue.publish(topic=topic, + payload=dumps(response.to_dict()), + datapoints_count=0) + + def set_attribute_update_callback(self, callback: Callable[[AttributeUpdate], Awaitable[None]]): + self._attribute_updates_handler.set_callback(callback) + + def set_rpc_request_callback(self, callback: Callable[[str, Dict[str, Any]], Awaitable[Dict[str, Any]]]): + self._rpc_requests_handler.set_callback(callback) + + async def _on_connect(self): + logger.info("Subscribing to attribute and RPC topics") + + sub_future = await self._mqtt_manager.subscribe(mqtt_topics.DEVICE_ATTRIBUTES_TOPIC, qos=1) + while not sub_future.done(): + await sleep(0.01) + + self._mqtt_manager.register_handler(mqtt_topics.DEVICE_ATTRIBUTES_TOPIC, self._handle_attribute_update) + self._mqtt_manager.register_handler(mqtt_topics.DEVICE_RPC_REQUEST_TOPIC_FOR_SUBSCRIPTION, self._handle_rpc_request) # noqa + self._mqtt_manager.register_handler(mqtt_topics.DEVICE_RPC_RESPONSE_TOPIC_FOR_SUBSCRIPTION, self._handle_rpc_response) # noqa + + async def _on_disconnect(self): + logger.warning("Device client disconnected") + + async def send_rpc_call(self, method: str, params: Optional[Dict[str, Any]] = None, timeout: float = 10.0) -> Any: + """ + Initiates a client-side RPC to ThingsBoard and awaits the result. + :param method: The RPC method to call. + :param params: The parameters to send. + :param timeout: Timeout for the response in seconds. + :return: The response result (dict, list, str, etc.), or raises on error. 
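+
+        Usage sketch ("getCurrentTime" is a hypothetical server-side method; `client`
+        is a connected DeviceClient instance):
+            result = await client.send_rpc_call("getCurrentTime", timeout=5.0)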
+ """ + request_id = await RPCRequestIdProducer.get_next() + topic = mqtt_topics.build_device_rpc_request_topic(request_id) + payload = dumps({ + "method": method, + "params": params or {} + }) + + future = self._rpc_response_handler.register_request(request_id) + await self._mqtt_manager.publish(topic, payload, qos=1) + + try: + return await wait_for(future, timeout=timeout) + except TimeoutError: + raise TimeoutError(f"Timed out waiting for RPC response (method={method}, id={request_id})") + + async def _handle_attribute_update(self, topic: str, payload: bytes): + await self._attribute_updates_handler.handle(topic, payload) + + async def _handle_rpc_request(self, topic: str, payload: bytes): + response: RPCResponse = await self._rpc_requests_handler.handle(topic, payload) + if response: + await self.send_rpc_response(response) + + async def _handle_rpc_response(self, topic: str, payload: bytes): + await self._rpc_response_handler.handle(topic, payload) + + async def _handle_rate_limit_response(self, topic: str, payload: bytes): + try: + response = loads(payload.decode("utf-8")) + logger.debug("Received rate limit response payload: %s", response) + + if not isinstance(response, dict) or 'rateLimits' not in response: + logger.warning("Invalid rate limit response: %r", response) + return + + rate_limits = response.get('rateLimits', {}) + + self._messages_rate_limit.set_limit(rate_limits.get("messages", "0:0,")) + self._telemetry_rate_limit.set_limit(rate_limits.get("telemetryMessages", "0:0,")) + self._telemetry_dp_rate_limit.set_limit(rate_limits.get("telemetryDataPoints", "0:0,")) + + server_inflight = int(response.get("maxInflightMessages", 100)) + limits = [rl.minimal_limit for rl in [ + self._messages_rate_limit, + self._telemetry_rate_limit + ] if rl.has_limit()] + + if limits: + self._max_inflight_messages = int( + min(min(limits), server_inflight) * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) + else: + self._max_inflight_messages = int(server_inflight * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) + if self._max_inflight_messages == 0: + self._max_inflight_messages = 10000 + + if "maxPayloadSize" in response: + self.max_payload_size = int(response["maxPayloadSize"] * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) + + if (not self._messages_rate_limit.has_limit() + and not self._telemetry_rate_limit.has_limit() + and not self._telemetry_dp_rate_limit.has_limit()): + self._max_queued_messages = 50000 + logger.debug("No rate limits, setting max_queued_messages to 50000") + else: + self._max_queued_messages = self._max_inflight_messages + logger.debug("With rate limits, setting max_queued_messages to %r", self._max_queued_messages) + + logger.info("Service configuration retrieved and applied.") + logger.info("Parsed device limits: %r", response) + + self._mqtt_manager.set_rate_limits( + self._messages_rate_limit, + self._telemetry_rate_limit, + self._telemetry_dp_rate_limit + ) + return True + + except Exception as e: + logger.exception("Failed to parse rate limits from server response: %s", e) + return False + + @staticmethod + def _build_uplink_message_for_telemetry(payload: Union[Dict[str, Any], + TimeseriesEntry, + List[TimeseriesEntry], + List[Dict[str, Any]]]) -> DeviceUplinkMessage: + if isinstance(payload, dict): + payload = [TimeseriesEntry(k, v) for k, v in payload.items()] + + builder = DeviceUplinkMessageBuilder() + builder.add_telemetry(payload) + return builder.build() + + @staticmethod + def _build_uplink_message_for_attributes(payload: Union[Dict[str, Any], + AttributeEntry, + 
                                                             List[AttributeEntry]]) -> DeviceUplinkMessage:
+        if isinstance(payload, dict):
+            payload = [AttributeEntry(k, v) for k, v in payload.items()]
+
+        builder = DeviceUplinkMessageBuilder()
+        builder.add_attributes(payload)
+        return builder.build()
diff --git a/tb_mqtt_client/service/device/rpc_requests_handler.py b/tb_mqtt_client/service/device/rpc_requests_handler.py
new file mode 100644
index 0000000..2f9d165
--- /dev/null
+++ b/tb_mqtt_client/service/device/rpc_requests_handler.py
@@ -0,0 +1,69 @@
+# Copyright 2025. ThingsBoard
+# #
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# #
+# http://www.apache.org/licenses/LICENSE-2.0
+# #
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+from orjson import loads
+from typing import Awaitable, Callable, Optional
+
+from tb_mqtt_client.common.logging_utils import get_logger
+from tb_mqtt_client.entities.data.rpc_request import RPCRequest
+from tb_mqtt_client.entities.data.rpc_response import RPCResponse
+
+logger = get_logger(__name__)
+
+
+class RPCRequestsHandler:
+    """
+    Handles incoming RPC request messages for a device.
+    """
+
+    def __init__(self):
+        self._callback: Optional[Callable[[RPCRequest], Awaitable[RPCResponse]]] = None
+
+    def set_callback(self, callback: Callable[[RPCRequest], Awaitable[RPCResponse]]):
+        """
+        Set the async callback to handle incoming RPC requests.
+        :param callback: A coroutine that receives the parsed RPCRequest and returns an RPCResponse.
+        """
+        self._callback = callback
+
+    async def handle(self, topic: str, payload: bytes) -> Optional[RPCResponse]:
+        """
+        Process an incoming RPC request using the registered callback.
+        :returns: The RPCResponse to publish, or None if no callback is set or handling failed.
+        """
+        if not self._callback:
+            logger.debug("No RPC request callback set. Skipping RPC handling. "
+                         "You can set the callback using client.set_rpc_request_callback(your_method)")
+            return None
+
+        try:
+            request_id = int(topic.split("/")[-1])
+            parsed = loads(payload)
+            parsed["id"] = request_id
+            rpc_request = RPCRequest.from_dict(parsed)
+
+            logger.debug("Handling RPC method id: %i - %s with params: %s",
+                         rpc_request.request_id, rpc_request.method, rpc_request.params)
+
+            result = await self._callback(rpc_request)
+            if not isinstance(result, RPCResponse):
+                logger.error("RPC callback must return an instance of RPCResponse, got: %s", type(result))
+                return None
+            return result
+
+        except Exception as e:
+            logger.exception("Failed to process RPC request: %s", e)
+            return None
diff --git a/tb_mqtt_client/service/event_dispatcher.py b/tb_mqtt_client/service/event_dispatcher.py
new file mode 100644
index 0000000..59e3bfa
--- /dev/null
+++ b/tb_mqtt_client/service/event_dispatcher.py
@@ -0,0 +1,14 @@
+# Copyright 2025. ThingsBoard
+# #
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tb_mqtt_client/service/gateway/__init__.py b/tb_mqtt_client/service/gateway/__init__.py new file mode 100644 index 0000000..59e3bfa --- /dev/null +++ b/tb_mqtt_client/service/gateway/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tb_mqtt_client/service/gateway/client.py b/tb_mqtt_client/service/gateway/client.py new file mode 100644 index 0000000..59e3bfa --- /dev/null +++ b/tb_mqtt_client/service/gateway/client.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tb_mqtt_client/service/gateway/device_sesion.py b/tb_mqtt_client/service/gateway/device_sesion.py new file mode 100644 index 0000000..59e3bfa --- /dev/null +++ b/tb_mqtt_client/service/gateway/device_sesion.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tb_mqtt_client/service/gateway/gateway_attribute_updates_handler.py b/tb_mqtt_client/service/gateway/gateway_attribute_updates_handler.py new file mode 100644 index 0000000..59e3bfa --- /dev/null +++ b/tb_mqtt_client/service/gateway/gateway_attribute_updates_handler.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tb_mqtt_client/service/gateway/gateway_rpc_requests_handler.py b/tb_mqtt_client/service/gateway/gateway_rpc_requests_handler.py new file mode 100644 index 0000000..59e3bfa --- /dev/null +++ b/tb_mqtt_client/service/gateway/gateway_rpc_requests_handler.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tb_mqtt_client/service/gateway/multiplex_publisher.py b/tb_mqtt_client/service/gateway/multiplex_publisher.py new file mode 100644 index 0000000..52dbfcf --- /dev/null +++ b/tb_mqtt_client/service/gateway/multiplex_publisher.py @@ -0,0 +1,15 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + diff --git a/tb_mqtt_client/service/gateway/subdevice_manager.py b/tb_mqtt_client/service/gateway/subdevice_manager.py new file mode 100644 index 0000000..59e3bfa --- /dev/null +++ b/tb_mqtt_client/service/gateway/subdevice_manager.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tb_mqtt_client/service/message_dispatcher.py b/tb_mqtt_client/service/message_dispatcher.py new file mode 100644 index 0000000..d6e8636 --- /dev/null +++ b/tb_mqtt_client/service/message_dispatcher.py @@ -0,0 +1,146 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from abc import ABC, abstractmethod +from typing import Any, Dict, Union, List, Tuple, Optional +from orjson import dumps + +from tb_mqtt_client.constants.mqtt_topics import DEVICE_TELEMETRY_TOPIC, DEVICE_ATTRIBUTES_TOPIC +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage +from tb_mqtt_client.service.message_splitter import MessageSplitter +from tb_mqtt_client.common.logging_utils import get_logger + +logger = get_logger(__name__) + + +class MessageDispatcher(ABC): + def __init__(self, max_payload_size: Optional[int] = None, max_datapoints: Optional[int] = None): + self._splitter = MessageSplitter(max_payload_size, max_datapoints) + logger.trace("MessageDispatcher initialized with max_payload_size=%s, max_datapoints=%s", + max_payload_size, max_datapoints) + + @abstractmethod + def build_topic_payloads( + self, + messages: List[DeviceUplinkMessage] + ) -> List[Tuple[str, bytes, int]]: + """ + Build a list of topic-payload pairs from the given messages. + Each pair consists of a topic string and a payload byte array. + """ + pass + + @abstractmethod + def build_payload(self, msg: DeviceUplinkMessage) -> bytes: + """ + Build a JSON payload for a single DeviceUplinkMessage. + """ + pass + + @abstractmethod + def splitter(self) -> MessageSplitter: + """ + Get the message splitter instance. + """ + pass + + +class JsonMessageDispatcher(MessageDispatcher): + def __init__(self, max_payload_size: Optional[int] = None, max_datapoints: Optional[int] = None): + super().__init__(max_payload_size, max_datapoints) + logger.trace("JsonMessageDispatcher created.") + + @property + def splitter(self) -> MessageSplitter: + return self._splitter + + def build_topic_payloads(self, messages: List[DeviceUplinkMessage]) -> List[Tuple[str, bytes, int]]: + if not messages: + logger.trace("No messages to process in build_topic_payloads.") + return [] + + from collections import defaultdict + + result: List[Tuple[str, bytes, int]] = [] + device_groups: Dict[str, List[DeviceUplinkMessage]] = defaultdict(list) + + for msg in messages: + device_name = msg.device_name or "" + device_groups[device_name].append(msg) + logger.trace("Queued message for device='%s'", device_name) + + logger.trace("Processing %d device group(s).", len(device_groups)) + + for device, device_msgs in device_groups.items(): + telemetry_msgs = [m for m in device_msgs if m.has_timeseries()] + attr_msgs = [m for m in device_msgs if m.has_attributes()] + logger.trace("Device '%s' - telemetry: %d, attributes: %d", + device, len(telemetry_msgs), len(attr_msgs)) + + for ts_batch in self._splitter.split_timeseries(telemetry_msgs): + payload = self.build_payload(ts_batch) + count = ts_batch.timeseries_datapoint_count() + result.append((DEVICE_TELEMETRY_TOPIC, payload, count)) + logger.trace("Built telemetry payload for device='%s' with %d datapoints", device, count) + + for attr_batch in self._splitter.split_attributes(attr_msgs): + payload = self.build_payload(attr_batch) + count = len(attr_batch.attributes) + result.append((DEVICE_ATTRIBUTES_TOPIC, payload, count)) + logger.trace("Built attribute 
payload for device='%s' with %d attributes", device, count) + + logger.trace("Generated %d topic-payload entries.", len(result)) + return result + + def build_payload(self, msg: DeviceUplinkMessage) -> bytes: + result: Dict[str, Any] = {} + device_name = msg.device_name or "" + logger.trace("Building payload for device='%s'", device_name) + + if msg.device_name: + if msg.attributes: + logger.trace("Packing attributes for device='%s'", device_name) + result[msg.device_name] = self._pack_attributes(msg) + if msg.timeseries: + logger.trace("Packing timeseries for device='%s'", device_name) + result[msg.device_name] = self._pack_timeseries(msg) + else: + if msg.attributes: + logger.trace("Packing anonymous attributes") + result = self._pack_attributes(msg) + if msg.timeseries: + logger.trace("Packing anonymous timeseries") + result = self._pack_timeseries(msg) + + payload = dumps(result) + logger.trace("Built payload size: %d bytes", len(payload)) + return payload + + @staticmethod + def _pack_attributes(msg: DeviceUplinkMessage) -> Dict[str, Any]: + logger.trace("Packing %d attribute(s)", len(msg.attributes)) + return {attr.key: attr.value for attr in msg.attributes} + + @staticmethod + def _pack_timeseries(msg: DeviceUplinkMessage) -> Union[Dict[str, Any], List[Dict[str, Any]]]: + logger.trace("Packing %d timeseries entry(ies)", len(msg.timeseries)) + grouped = {} + for entry in msg.timeseries: + grouped.setdefault(entry.ts or 0, {})[entry.key] = entry.value + + if all(ts == 0 for ts in grouped): + return grouped[0] + return [{"ts": ts, "values": values} for ts, values in grouped.items()] diff --git a/tb_mqtt_client/service/message_queue.py b/tb_mqtt_client/service/message_queue.py new file mode 100644 index 0000000..0bbae9f --- /dev/null +++ b/tb_mqtt_client/service/message_queue.py @@ -0,0 +1,211 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + + +import asyncio +from typing import List, Optional, Union, Tuple +from contextlib import suppress + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit +from tb_mqtt_client.constants import mqtt_topics +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage +from tb_mqtt_client.service.mqtt_manager import MQTTManager +from tb_mqtt_client.service.message_dispatcher import MessageDispatcher + +logger = get_logger(__name__) + + +class MessageQueue: + _BATCH_TIMEOUT = 0.01 # seconds to wait for batching (optional flush time) + + def __init__(self, + mqtt_manager: MQTTManager, + message_rate_limit: Optional[RateLimit], + telemetry_rate_limit: Optional[RateLimit], + telemetry_dp_rate_limit: Optional[RateLimit], + message_dispatcher: MessageDispatcher, + max_queue_size: int = 10000, + batch_collect_max_time_ms: int = 100, + batch_collect_max_count: int = 500): + self._batch_max_time = batch_collect_max_time_ms / 1000 # convert to seconds + self._batch_max_count = batch_collect_max_count + self._mqtt_manager = mqtt_manager + self._message_rate_limit = message_rate_limit + self._telemetry_rate_limit = telemetry_rate_limit + self._telemetry_dp_rate_limit = telemetry_dp_rate_limit + self._backpressure = self._mqtt_manager.backpressure + self._queue = asyncio.Queue(maxsize=max_queue_size) + self._active = asyncio.Event() + self._wakeup_event = asyncio.Event() + self._active.set() + self._dispatcher = message_dispatcher + self._loop_task = asyncio.create_task(self._dequeue_loop()) + logger.debug("MessageQueue initialized: max_queue_size=%s, batch_time=%.3f, batch_count=%d", + max_queue_size, self._batch_max_time, batch_collect_max_count) + + async def publish(self, topic: str, payload: Union[bytes, DeviceUplinkMessage], datapoints_count: int): + try: + self._queue.put_nowait((topic, payload, datapoints_count)) + logger.trace("Enqueued message: topic=%s, datapoints=%d, type=%s", + topic, datapoints_count, type(payload).__name__) + except asyncio.QueueFull: + logger.warning("Message queue full. 
Dropping message for topic %s", topic) + + async def _dequeue_loop(self): + while self._active.is_set(): + try: + topic, payload, count = await self._wait_for_message() + except asyncio.TimeoutError: + continue + + if isinstance(payload, bytes): + logger.trace("Dequeued immediate publish: topic=%s (raw bytes)", topic) + await self._try_publish(topic, payload, count) + continue + + logger.trace("Dequeued message for batching: topic=%s, device=%s", + topic, getattr(payload, 'device_name', 'N/A')) + + batch: List[Tuple[str, Union[bytes, DeviceUplinkMessage], int]] = [(topic, payload, count)] + start = asyncio.get_event_loop().time() + batch_size = payload.size + + while not self._queue.empty(): + elapsed = asyncio.get_event_loop().time() - start + if elapsed >= self._batch_max_time: + logger.trace("Batch time threshold reached: %.3fs", elapsed) + break + if len(batch) >= self._batch_max_count: + logger.trace("Batch count threshold reached: %d messages", len(batch)) + break + + next_topic, next_payload, next_count = self._queue.get_nowait() + if isinstance(next_payload, DeviceUplinkMessage): + msg_size = next_payload.size + if batch_size + msg_size > self._dispatcher.splitter.max_payload_size: + logger.trace("Batch size threshold exceeded: current=%d, next=%d", batch_size, msg_size) + await self._queue.put((next_topic, next_payload, next_count)) + break + batch.append((next_topic, next_payload, next_count)) + batch_size += msg_size + else: + logger.trace("Immediate publish encountered in queue while batching: topic=%s", next_topic) + await self._try_publish(next_topic, next_payload, next_count) + + if batch: + messages = [p for _, p, _ in batch if isinstance(p, DeviceUplinkMessage)] + logger.trace("Formed batch with %d DeviceUplinkMessages", len(messages)) + topic_payloads = self._dispatcher.build_topic_payloads(messages) + for topic, payload, datapoints in topic_payloads: + logger.trace("Dispatching batched message: topic=%s, size=%d, datapoints=%d", + topic, len(payload), datapoints) + await self._try_publish(topic, payload, datapoints) + + async def _try_publish(self, topic: str, payload: bytes, points: int): + telemetry = topic == mqtt_topics.DEVICE_TELEMETRY_TOPIC + logger.trace("Attempting publish: topic=%s, datapoints=%d", topic, points) + + if self._backpressure.should_pause(): + self._schedule_delayed_retry(topic, payload, points, delay=1.0) + return + + if telemetry: + if self._telemetry_rate_limit and self._telemetry_rate_limit.check_limit_reached(1): + logger.debug("Telemetry message rate limit hit: topic=%s", topic) + retry_delay = self._telemetry_rate_limit.minimal_timeout + self._schedule_delayed_retry(topic, payload, points, delay=retry_delay) + return + if self._telemetry_dp_rate_limit and self._telemetry_dp_rate_limit.check_limit_reached(points): + logger.debug("Telemetry datapoint rate limit hit: topic=%s", topic) + retry_delay = self._telemetry_dp_rate_limit.minimal_timeout + self._schedule_delayed_retry(topic, payload, points, delay=retry_delay) + return + else: + if self._message_rate_limit and self._message_rate_limit.check_limit_reached(1): + logger.debug("Generic message rate limit hit: topic=%s", topic) + logger.debug("Rate limit state: %s", self._message_rate_limit.to_dict()) + retry_delay = self._message_rate_limit.minimal_timeout + self._schedule_delayed_retry(topic, payload, points, delay=retry_delay) + return + + try: + logger.debug("Rate limit state before publish: %s", self._message_rate_limit.to_dict()) + await self._mqtt_manager.publish(topic, payload, qos=1) 
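+            # Rate-limit tokens are consumed only after the publish call above returns
+            # without raising, so a failed publish attempt is not counted against the limits.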
+ logger.trace("Publish successful: topic=%s", topic) + if telemetry: + if self._telemetry_rate_limit: + self._telemetry_rate_limit.consume(1) + if self._telemetry_dp_rate_limit: + self._telemetry_dp_rate_limit.consume(points) + else: + if self._message_rate_limit: + self._message_rate_limit.consume(1) + except Exception as e: + logger.warning("Failed to publish to topic %s: %s. Scheduling retry.", topic, e) + self._schedule_delayed_retry(topic, payload, points, delay=1.0) + + def _schedule_delayed_retry(self, topic: str, payload: bytes, points: int, delay: float): + logger.trace("Scheduling retry: topic=%s, delay=%.2f", topic, delay) + + async def retry(): + await asyncio.sleep(delay) + try: + self._queue.put_nowait((topic, payload, points)) + self._wakeup_event.set() + logger.trace("Re-enqueued message after delay: topic=%s", topic) + except asyncio.QueueFull: + logger.warning("Retry queue full. Dropping retried message: topic=%s", topic) + + asyncio.create_task(retry()) + + async def _wait_for_message(self): + if not self._queue.empty(): + return await self._queue.get() + + self._wakeup_event.clear() + queue_task = asyncio.create_task(self._queue.get()) + wake_task = asyncio.create_task(self._wakeup_event.wait()) + done, _ = await asyncio.wait([queue_task, wake_task], return_when=asyncio.FIRST_COMPLETED) + + if queue_task in done: + wake_task.cancel() + return queue_task.result() + + # Wake event triggered — retry get + queue_task.cancel() + await asyncio.sleep(0.001) # Yield control + return await self._wait_for_message() + + async def shutdown(self): + logger.debug("Shutting down MessageQueue...") + self._active.clear() + self._loop_task.cancel() + with suppress(asyncio.CancelledError): + await self._loop_task + logger.debug("MessageQueue shutdown complete.") + + def is_empty(self): + return self._queue.empty() + + def size(self): + return self._queue.qsize() + + def clear(self): + logger.debug("Clearing message queue...") + while not self._queue.empty(): + self._queue.get_nowait() + self._queue.task_done() + logger.debug("Message queue cleared.") diff --git a/tb_mqtt_client/service/message_splitter.py b/tb_mqtt_client/service/message_splitter.py new file mode 100644 index 0000000..00f228e --- /dev/null +++ b/tb_mqtt_client/service/message_splitter.py @@ -0,0 +1,136 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + + +from typing import List +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage, DeviceUplinkMessageBuilder +from tb_mqtt_client.common.logging_utils import get_logger + +logger = get_logger(__name__) + + +class MessageSplitter: + def __init__(self, max_payload_size: int = 65535, max_datapoints: int = 0): + if max_payload_size is None or max_payload_size <= 0: + logger.debug("Invalid max_payload_size: %s, using default 65535", max_payload_size) + max_payload_size = 65535 + if max_datapoints is None or max_datapoints < 0: + logger.debug("Invalid max_datapoints: %s, using default 0", max_datapoints) + max_datapoints = 0 + + self._max_payload_size = max_payload_size + self._max_datapoints = max_datapoints + logger.trace("MessageSplitter initialized with max_payload_size=%d, max_datapoints=%d", + self._max_payload_size, self._max_datapoints) + + def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUplinkMessage]: + logger.trace("Splitting timeseries for %d messages", len(messages)) + result: List[DeviceUplinkMessage] = [] + + for message in messages: + if not message.has_timeseries(): + logger.trace("Message from device '%s' has no timeseries. Skipping.", message.device_name) + continue + + logger.trace("Processing timeseries from device: %s", message.device_name) + builder = None + size = 0 + point_count = 0 + + for ts in message.timeseries: + exceeds_size = builder and size + ts.size > self._max_payload_size + exceeds_points = self._max_datapoints > 0 and point_count >= self._max_datapoints + + if not builder or exceeds_size or exceeds_points: + if builder: + built = builder.build() + result.append(built) + logger.trace("Flushed batch with %d points (size=%d)", len(built.timeseries), size) + builder = DeviceUplinkMessageBuilder() \ + .set_device_name(message.device_name) \ + .set_device_profile(message.device_profile) + size = 0 + point_count = 0 + + builder.add_telemetry(ts) + size += ts.size + point_count += 1 + logger.trace("Added timeseries entry to batch (size=%d, points=%d)", size, point_count) + + if builder and builder._timeseries: + built = builder.build() + result.append(built) + logger.trace("Flushed final batch with %d points (size=%d)", len(built.timeseries), size) + + logger.trace("Total timeseries batches created: %d", len(result)) + return result + + def split_attributes(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUplinkMessage]: + logger.trace("Splitting attributes for %d messages", len(messages)) + result: List[DeviceUplinkMessage] = [] + + for message in messages: + if not message.has_attributes(): + logger.trace("Message from device '%s' has no attributes. 
Skipping.", message.device_name) + continue + + logger.trace("Processing attributes from device: %s", message.device_name) + builder = None + size = 0 + + for attr in message.attributes: + if builder and size + attr.size > self._max_payload_size: + built = builder.build() + result.append(built) + logger.trace("Flushed attribute batch (count=%d, size=%d)", len(built.attributes), size) + builder = None + size = 0 + + if not builder: + builder = DeviceUplinkMessageBuilder() \ + .set_device_name(message.device_name) \ + .set_device_profile(message.device_profile) + + builder.add_attributes(attr) + size += attr.size + logger.trace("Added attribute to batch (size=%d)", size) + + if builder and builder._attributes: + built = builder.build() + result.append(built) + logger.trace("Flushed final attribute batch (count=%d, size=%d)", len(built.attributes), size) + + logger.trace("Total attribute batches created: %d", len(result)) + return result + + @property + def max_payload_size(self) -> int: + return self._max_payload_size + + @max_payload_size.setter + def max_payload_size(self, value: int): + old = self._max_payload_size + self._max_payload_size = value if value > 0 else 65535 + logger.debug("Updated max_payload_size: %d -> %d", old, self._max_payload_size) + + @property + def max_datapoints(self) -> int: + return self._max_datapoints + + @max_datapoints.setter + def max_datapoints(self, value: int): + old = self._max_datapoints + self._max_datapoints = value if value > 0 else 0 + logger.debug("Updated max_datapoints: %d -> %d", old, self._max_datapoints) diff --git a/tb_mqtt_client/service/mqtt_manager.py b/tb_mqtt_client/service/mqtt_manager.py new file mode 100644 index 0000000..41c880c --- /dev/null +++ b/tb_mqtt_client/service/mqtt_manager.py @@ -0,0 +1,302 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + + +import asyncio +import ssl +from asyncio import sleep +from typing import Optional, Callable, Awaitable, Dict, Union + +from gmqtt import Client as GMQTTClient, Message, Subscription, MQTTConnectError + +from tb_mqtt_client.common.request_id_generator import RPCRequestIdProducer +from tb_mqtt_client.common.gmqtt_patch import patch_gmqtt_puback +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit +from tb_mqtt_client.common.rate_limit.backpressure_controller import BackpressureController +from tb_mqtt_client.service.rpc_response_handler import RPCResponseHandler +from tb_mqtt_client.constants import mqtt_topics +from tb_mqtt_client.constants.service_keys import MESSAGES_RATE_LIMIT, TELEMETRY_MESSAGE_RATE_LIMIT, \ + TELEMETRY_DATAPOINTS_RATE_LIMIT +from tb_mqtt_client.constants.service_messages import SESSION_LIMITS_REQUEST_MESSAGE + +logger = get_logger(__name__) + + +class MQTTManager: + def __init__( + self, + client_id: str, + on_connect: Optional[Callable[[], Awaitable[None]]] = None, + on_disconnect: Optional[Callable[[], Awaitable[None]]] = None, + rate_limits_handler: Optional[Callable[[str, bytes], Awaitable[None]]] = None, + rpc_response_handler: Optional[RPCResponseHandler] = None, + ): + self._client = GMQTTClient(client_id) + patch_gmqtt_puback(self._client, self._handle_puback_reason_code) + self._client.on_connect = self._on_connect_internal + self._client.on_disconnect = self._on_disconnect_internal + self._client.on_message = self._on_message_internal + self._client.on_publish = self._on_publish_internal + self._client.on_subscribe = self._on_subscribe_internal + self._client.on_unsubscribe = self._on_unsubscribe_internal + + self._on_connect_callback = on_connect + self._on_disconnect_callback = on_disconnect + + self._connected_event = asyncio.Event() + self._handlers: Dict[str, Callable[[str, bytes], Awaitable[None]]] = {} + + self._pending_publishes: Dict[int, asyncio.Future] = {} + self._pending_subscriptions: Dict[int, asyncio.Future] = {} + self._pending_unsubscriptions: Dict[int, asyncio.Future] = {} + self._rpc_response_handler = rpc_response_handler or RPCResponseHandler() + + self._backpressure = BackpressureController() + self.__rate_limits_handler = rate_limits_handler + self.__rate_limits_retrieved = False + self.__rate_limiter: Optional[Dict[str, RateLimit]] = None + self.__is_gateway = False # TODO: determine if this is a gateway or not + self.__is_waiting_for_rate_limits_publish = False + self._rate_limits_ready_event = asyncio.Event() + + async def connect(self, host: str, port: int = 1883, username: Optional[str] = None, + password: Optional[str] = None, tls: bool = False, + keepalive: int = 60, ssl_context: Optional[ssl.SSLContext] = None): + try: + if username: + self._client.set_auth_credentials(username, password) + + if tls: + if ssl_context is None: + ssl_context = ssl.create_default_context() + await self._client.connect(host, port, ssl=ssl_context, keepalive=keepalive) + else: + await self._client.connect(host, port, keepalive=keepalive) + try: + await asyncio.wait_for(self._connected_event.wait(), timeout=10) + except asyncio.TimeoutError: + logger.warning("Timeout waiting for MQTT connection.") + raise + + except MQTTConnectError as e: + logger.warning("MQTT connection failed: %s", str(e)) + self._connected_event.clear() + except Exception as e: + logger.exception("Unhandled exception during MQTT connect: %s", e) + raise + + def is_connected(self) -> bool: + return 
self._client.is_connected and self._connected_event.is_set() and self.__rate_limits_retrieved + + async def disconnect(self): + await self._client.disconnect() + await asyncio.sleep(0.2) + self._connected_event.clear() + self.__rate_limits_retrieved = False + self.__is_waiting_for_rate_limits_publish = True + self._rate_limits_ready_event.clear() + + async def publish(self, message_or_topic: Union[str, Message], + payload: Optional[bytes] = None, + qos: int = 1, + retain: bool = False, + force=False) -> asyncio.Future: + + if not force: + if not self.__rate_limits_retrieved and not self.__is_waiting_for_rate_limits_publish: + raise RuntimeError("Cannot publish before rate limits are retrieved.") + try: + await asyncio.wait_for(self._rate_limits_ready_event.wait(), timeout=10) + except asyncio.TimeoutError: + raise RuntimeError("Timeout waiting for rate limits.") + + if not force and self._backpressure.should_pause(): + logger.warning("Backpressure active. Publishing suppressed.") + raise RuntimeError("Publishing temporarily paused due to backpressure.") + + if isinstance(message_or_topic, Message): + message = message_or_topic + else: + message = Message(message_or_topic, payload, qos=qos, retain=retain) + + mid, package = self._client._connection.publish(message) + + future = asyncio.get_event_loop().create_future() + if qos > 0: + self._pending_publishes[mid] = future + self._client._persistent_storage.push_message_nowait(mid, package) + else: + future.set_result(True) + + return future + + async def subscribe(self, topic: Union[str, Subscription], qos: int = 1) -> asyncio.Future: + sub_future = asyncio.get_event_loop().create_future() + subscription = Subscription(topic, qos=qos) if isinstance(topic, str) else topic + + if self.__rate_limiter: + self.__rate_limiter[MESSAGES_RATE_LIMIT].consume() + mid = self._client._connection.subscribe([subscription]) + self._pending_subscriptions[mid] = sub_future + return sub_future + + async def unsubscribe(self, topic: str) -> asyncio.Future: + unsubscribe_future = asyncio.get_event_loop().create_future() + if self.__rate_limiter: + self.__rate_limiter[MESSAGES_RATE_LIMIT].consume() + mid = self._client._connection.unsubscribe(topic) + self._pending_unsubscriptions[mid] = unsubscribe_future + return unsubscribe_future + + def register_handler(self, topic_filter: str, handler: Callable[[str, bytes], Awaitable[None]]): + self._handlers[topic_filter] = handler + + def unregister_handler(self, topic_filter: str): + self._handlers.pop(topic_filter, None) + + def _on_connect_internal(self, client, flags, rc, properties): + logger.info("Connected to platform") + self._connected_event.set() + asyncio.create_task(self.__handle_connect_and_limits()) + + async def __handle_connect_and_limits(self): + logger.debug("Subscribing to RPC response topics") + sub_future = await self.subscribe(mqtt_topics.DEVICE_RPC_REQUEST_TOPIC_FOR_SUBSCRIPTION, qos=1) + while not sub_future.done(): + await sleep(0.01) + sub_future = await self.subscribe(mqtt_topics.DEVICE_RPC_RESPONSE_TOPIC_FOR_SUBSCRIPTION, qos=1) + while not sub_future.done(): + await sleep(0.01) + logger.debug("Subscribing completed, sending rate limits request") + + await self.__request_rate_limits() + + if self._on_connect_callback: + await self._on_connect_callback() + + def _on_disconnect_internal(self, client, packet, exc=None): + logger.warning("Disconnected from platform") + RPCRequestIdProducer.reset() + self._rpc_response_handler.clear() + self._connected_event.clear() + 
self._backpressure.notify_disconnect(delay_seconds=15) + if self._on_disconnect_callback: + asyncio.create_task(self._on_disconnect_callback()) + + def _on_message_internal(self, client, topic: str, payload: bytes, qos, properties): + for topic_filter, handler in self._handlers.items(): + if self._match_topic(topic_filter, topic): + asyncio.create_task(handler(topic, payload)) + return + if topic.startswith(mqtt_topics.DEVICE_RPC_RESPONSE_TOPIC): + asyncio.create_task(self._rpc_response_handler.handle(topic, payload)) + + def _on_publish_internal(self, client, mid): + future = self._pending_publishes.pop(mid, None) + if future and not future.done(): + future.set_result(True) + + def _handle_puback_reason_code(self, mid: int, reason_code: int, properties: dict): + QUOTA_EXCEEDED = 0x97 # MQTT 5 reason code for quota exceeded + if reason_code == QUOTA_EXCEEDED: + logger.warning("PUBACK received with QUOTA_EXCEEDED for mid=%s", mid) + self._backpressure.notify_quota_exceeded(delay_seconds=10) + elif reason_code != 0: + logger.warning("PUBACK received with error code %s for mid=%s", reason_code, mid) + + def _on_subscribe_internal(self, client, mid, qos, properties): + future = self._pending_subscriptions.pop(mid, None) + if future and not future.done(): + future.set_result(True) + + def _on_unsubscribe_internal(self, client, mid): + future = self._pending_unsubscriptions.pop(mid, None) + if future and not future.done(): + future.set_result(True) + + async def await_ready(self, timeout: float = 10.0): + try: + await asyncio.wait_for(self._rate_limits_ready_event.wait(), timeout=timeout) + except asyncio.TimeoutError: + logger.debug("Waiting for rate limits timed out.") + + def set_rate_limits( + self, + message_rate_limit: Union[RateLimit, Dict[str, RateLimit]], + telemetry_message_rate_limit: Optional[RateLimit], + telemetry_dp_rate_limit: Optional[RateLimit] + ): + self.__rate_limiter = { + MESSAGES_RATE_LIMIT: message_rate_limit, + TELEMETRY_MESSAGE_RATE_LIMIT: telemetry_message_rate_limit, + TELEMETRY_DATAPOINTS_RATE_LIMIT: telemetry_dp_rate_limit + } + self.__rate_limits_retrieved = True + self.__is_waiting_for_rate_limits_publish = False + self._rate_limits_ready_event.set() + + async def __request_rate_limits(self): + request_id = await RPCRequestIdProducer.get_next() + request_topic = f"v1/devices/me/rpc/request/{request_id}" + response_topic = f"v1/devices/me/rpc/response/{request_id}" + + logger.debug("Publishing rate limits request to: %s", request_topic) + response_future = self._rpc_response_handler.register_request(request_id) + + async def _handler(topic: str, payload: bytes): + try: + if self.__rate_limits_handler: + await self.__rate_limits_handler(topic, payload) + response_future.set_result(payload) + except Exception as e: + logger.exception("Error handling rate limits response: %s", e) + response_future.set_exception(e) + + self.register_handler(response_topic, _handler) + + try: + self.__is_waiting_for_rate_limits_publish = True + logger.debug("Requesting rate limits via RPC...") + await self.publish(request_topic, SESSION_LIMITS_REQUEST_MESSAGE, qos=1, force=True) + await asyncio.wait_for(response_future, timeout=10) + logger.info("Successfully processed rate limits.") + self.__rate_limits_retrieved = True + self.__is_waiting_for_rate_limits_publish = False + self._rate_limits_ready_event.set() + except asyncio.TimeoutError: + logger.warning("Timeout while waiting for rate limits.") + finally: + self.unregister_handler(response_topic) + 
self.__is_waiting_for_rate_limits_publish = False + + @property + def backpressure(self) -> BackpressureController: + return self._backpressure + + @staticmethod + def _match_topic(filter: str, topic: str) -> bool: + filter_parts = filter.split('/') + topic_parts = topic.split('/') + + for i, filter_part in enumerate(filter_parts): + if filter_part == '#': + return True + if i >= len(topic_parts): + return False + if filter_part != '+' and filter_part != topic_parts[i]: + return False + + return len(filter_parts) == len(topic_parts) diff --git a/tb_mqtt_client/service/rpc_response_handler.py b/tb_mqtt_client/service/rpc_response_handler.py new file mode 100644 index 0000000..762c2ae --- /dev/null +++ b/tb_mqtt_client/service/rpc_response_handler.py @@ -0,0 +1,72 @@ +# Copyright 2025. ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from typing import Dict, Union + +from tb_mqtt_client.common.logging_utils import get_logger +from orjson import loads + +logger = get_logger(__name__) + + +class RPCResponseHandler: + """ + Handles RPC responses coming from the platform to the client (client-side RPCs responses). + Maintains an internal map of request_id -> asyncio.Future for awaiting RPC results. + """ + + def __init__(self): + self._pending_requests: Dict[Union[str, int], asyncio.Future] = {} + + def register_request(self, request_id: Union[str, int]) -> asyncio.Future: + """ + Called when a request is sent to the platform and a response is awaited. + """ + if request_id in self._pending_requests: + raise RuntimeError(f"Request ID {request_id} is already registered.") + future = asyncio.get_event_loop().create_future() + self._pending_requests[request_id] = future + return future + + async def handle(self, topic: str, payload: bytes): + """ + Handles the incoming RPC response from the platform and fulfills the corresponding future. + The topic is expected to be: v1/devices/me/rpc/response/{request_id} + """ + try: + request_id = topic.split("/")[-1] + response_data = loads(payload) + + future = self._pending_requests.pop(request_id, None) + if not future: + logger.warning("No future awaiting request ID %s. Ignoring.", request_id) + return + + if isinstance(response_data, dict) and "error" in response_data: + future.set_exception(Exception(response_data["error"])) + else: + future.set_result(response_data) + + except Exception as e: + logger.exception("Failed to handle RPC response: %s", e) + + def clear(self): + """ + Clears all pending futures (e.g. on disconnect). + """ + for fut in self._pending_requests.values(): + if not fut.done(): + fut.cancel() + self._pending_requests.clear() diff --git a/tb_mqtt_client/tb_device_mqtt.py b/tb_mqtt_client/tb_device_mqtt.py new file mode 100644 index 0000000..6d6dc0d --- /dev/null +++ b/tb_mqtt_client/tb_device_mqtt.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/constants/__init__.py b/tests/constants/__init__.py new file mode 100644 index 0000000..59e3bfa --- /dev/null +++ b/tests/constants/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/constants/test_mqtt_topics.py b/tests/constants/test_mqtt_topics.py new file mode 100644 index 0000000..293d4e5 --- /dev/null +++ b/tests/constants/test_mqtt_topics.py @@ -0,0 +1,23 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from tb_mqtt_client.constants import mqtt_topics + + +def test_device_topic_builders(): + assert mqtt_topics.build_device_attributes_request_topic(42) == "v1/devices/me/attributes/request/42" + assert mqtt_topics.build_device_rpc_request_topic(99) == "v1/devices/me/rpc/request/99" + assert mqtt_topics.build_device_rpc_response_topic(99) == "v1/devices/me/rpc/response/99" \ No newline at end of file diff --git a/tests/service/__init__.py b/tests/service/__init__.py new file mode 100644 index 0000000..59e3bfa --- /dev/null +++ b/tests/service/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/service/device/__init__.py b/tests/service/device/__init__.py new file mode 100644 index 0000000..59e3bfa --- /dev/null +++ b/tests/service/device/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/service/device/test_device_client_rate_limits.py b/tests/service/device/test_device_client_rate_limits.py new file mode 100644 index 0000000..cd0a00d --- /dev/null +++ b/tests/service/device/test_device_client_rate_limits.py @@ -0,0 +1,77 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +from orjson import dumps + +from tb_mqtt_client.service.device.client import DeviceClient +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit + + +@pytest.fixture +def device_client(): + config = DeviceConfig() + config.host = "localhost" + config.access_token = "test_token" + config.client_id = "" + return DeviceClient(config) + + +@pytest.mark.asyncio +async def test_handle_rate_limit_response_valid(device_client): + payload = { + "rateLimits": { + "messages": "10:1,300:60", + "telemetryMessages": "20:1,600:60", + "telemetryDataPoints": "100:1,1000:60" + }, + "maxInflightMessages": 100, + "maxPayloadSize": 2048 + } + + topic = "v1/devices/me/rpc/response/1" + await device_client._handle_rate_limit_response(topic, dumps(payload)) + + assert isinstance(device_client._messages_rate_limit, RateLimit) + assert isinstance(device_client._telemetry_rate_limit, RateLimit) + assert isinstance(device_client._telemetry_dp_rate_limit, RateLimit) + + assert device_client._messages_rate_limit.has_limit() + assert device_client._telemetry_rate_limit.has_limit() + assert device_client._telemetry_dp_rate_limit.has_limit() + + assert device_client._max_inflight_messages > 0 + assert device_client._max_queued_messages == device_client._max_inflight_messages + assert device_client.max_payload_size == int(2048 * device_client._telemetry_rate_limit.percentage / 100) + + +@pytest.mark.asyncio +async def test_handle_rate_limit_response_invalid_payload(device_client, caplog): + topic = "v1/devices/me/rpc/response/1" + await device_client._handle_rate_limit_response(topic, b'invalid-json') + + assert not device_client._messages_rate_limit.has_limit() + assert "Failed to parse rate limits" in caplog.text + + +@pytest.mark.asyncio +async def test_handle_rate_limit_response_missing_rate_limits(device_client, caplog): + payload = {"maxInflightMessages": 100} + topic = "v1/devices/me/rpc/response/1" + await device_client._handle_rate_limit_response(topic, dumps(payload)) + + assert "Invalid rate limit response" in caplog.text + assert device_client._max_inflight_messages == 100 # default fallback still applies \ No newline at end of file diff --git 
a/tests/service/test_json_message_dispatcher.py b/tests/service/test_json_message_dispatcher.py new file mode 100644 index 0000000..9592616 --- /dev/null +++ b/tests/service/test_json_message_dispatcher.py @@ -0,0 +1,92 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import pytest +from tb_mqtt_client.service.message_dispatcher import JsonMessageDispatcher +from tb_mqtt_client.constants.mqtt_topics import DEVICE_TELEMETRY_TOPIC, DEVICE_ATTRIBUTES_TOPIC +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry + + +@pytest.fixture +def dispatcher(): + return JsonMessageDispatcher(max_payload_size=512, max_datapoints=10) + + +def test_single_telemetry_dispatch(dispatcher): + builder = DeviceUplinkMessageBuilder().set_device_name("dev1") + builder.add_telemetry(TimeseriesEntry("temp", 25)) + msg = builder.build() + + payloads = dispatcher.build_topic_payloads([msg]) + assert len(payloads) == 1 + topic, payload, count = payloads[0] + assert topic == DEVICE_TELEMETRY_TOPIC + assert b"dev1" in payload + assert count == 1 + + +def test_single_attribute_dispatch(dispatcher): + builder = DeviceUplinkMessageBuilder().set_device_name("dev2") + builder.add_attributes(AttributeEntry("mode", "auto")) + msg = builder.build() + + payloads = dispatcher.build_topic_payloads([msg]) + assert len(payloads) == 1 + topic, payload, count = payloads[0] + assert topic == DEVICE_ATTRIBUTES_TOPIC + assert b"dev2" in payload + assert count == 1 + + +def test_multiple_devices_grouping(dispatcher): + b1 = DeviceUplinkMessageBuilder().set_device_name("dev1") + b1.add_telemetry(TimeseriesEntry("t1", 1)) + b2 = DeviceUplinkMessageBuilder().set_device_name("dev2") + b2.add_telemetry(TimeseriesEntry("t2", 2)) + + payloads = dispatcher.build_topic_payloads([b1.build(), b2.build()]) + assert len(payloads) == 2 + for topic, payload, count in payloads: + assert topic == DEVICE_TELEMETRY_TOPIC + assert count == 1 + + +def test_large_telemetry_split(dispatcher): + builder = DeviceUplinkMessageBuilder().set_device_name("splittest") + for i in range(15): + builder.add_telemetry(TimeseriesEntry(f"key{i}", i)) + + payloads = dispatcher.build_topic_payloads([builder.build()]) + assert len(payloads) > 1 + for topic, payload, count in payloads: + assert topic == DEVICE_TELEMETRY_TOPIC + assert count <= dispatcher.splitter.max_datapoints + + +def test_large_attributes_split(): + dispatcher = JsonMessageDispatcher(max_payload_size=200) + + builder = DeviceUplinkMessageBuilder().set_device_name("splitattr") + for i in range(20): + builder.add_attributes(AttributeEntry(f"k{i}", "x" * 50)) # Increase size + + payloads = dispatcher.build_topic_payloads([builder.build()]) + assert len(payloads) > 1 # Now expect splitting + for topic, payload, count in payloads: + assert topic == DEVICE_ATTRIBUTES_TOPIC + assert count > 0 \ 
No newline at end of file diff --git a/tests/service/test_message_splitter.py b/tests/service/test_message_splitter.py new file mode 100644 index 0000000..535b8c6 --- /dev/null +++ b/tests/service/test_message_splitter.py @@ -0,0 +1,95 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import pytest +from tb_mqtt_client.service.message_splitter import MessageSplitter +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry + + +@pytest.mark.parametrize("max_payload_size,max_datapoints", [(100, 3)]) +def test_split_large_telemetry(max_payload_size, max_datapoints): + splitter = MessageSplitter(max_payload_size=max_payload_size, max_datapoints=max_datapoints) + + builder = DeviceUplinkMessageBuilder().set_device_name("device1") + for i in range(10): + builder.add_telemetry(TimeseriesEntry(f"k{i}", i)) + + message = builder.build() + split = splitter.split_timeseries([message]) + + assert len(split) > 1 + total_ts = sum(len(m.timeseries) for m in split) + assert total_ts == 10 + for m in split: + assert m.device_name == "device1" + + +def test_split_large_attributes(): + splitter = MessageSplitter(max_payload_size=100) + + builder = DeviceUplinkMessageBuilder().set_device_name("deviceA") + for i in range(20): + builder.add_attributes(AttributeEntry(f"attr{i}", "val" * 10)) + + message = builder.build() + split = splitter.split_attributes([message]) + + assert len(split) > 1 + total_attrs = sum(len(m.attributes) for m in split) + assert total_attrs == 20 + for m in split: + assert m.device_name == "deviceA" + + +def test_no_split_needed(): + splitter = MessageSplitter(max_payload_size=10000, max_datapoints=100) + + builder = DeviceUplinkMessageBuilder().set_device_name("simpleDevice") + builder.add_telemetry(TimeseriesEntry("temp", 23)) + builder.add_attributes(AttributeEntry("fw", "1.0.0")) + + message = builder.build() + result_attr = splitter.split_attributes([message]) + result_ts = splitter.split_timeseries([message]) + + assert len(result_attr) == 1 + assert len(result_ts) == 1 + assert result_attr[0].device_name == "simpleDevice" + assert result_ts[0].device_name == "simpleDevice" + assert len(result_attr[0].attributes) == 1 + assert len(result_ts[0].timeseries) == 1 + + +def test_mixed_split(): + splitter = MessageSplitter(max_payload_size=120, max_datapoints=2) + builder = DeviceUplinkMessageBuilder().set_device_name("mixed") + + # Increase attribute value size to ensure payload size > 120 + for i in range(5): + builder.add_attributes(AttributeEntry(f"a{i}", "x" * 50)) + builder.add_telemetry(TimeseriesEntry(f"t{i}", i)) + + msg = builder.build() + result_attr = splitter.split_attributes([msg]) + result_ts = splitter.split_timeseries([msg]) + + # Assert that the attributes and telemetry are split into multiple messages + assert len(result_attr) > 1, f"Expected 
split, got {len(result_attr)}" + assert len(result_ts) > 1, f"Expected split, got {len(result_ts)}" + assert sum(len(r.attributes) for r in result_attr) == 5 + assert sum(len(r.timeseries) for r in result_ts) == 5 diff --git a/tests/service/test_mqtt_manager.py b/tests/service/test_mqtt_manager.py new file mode 100644 index 0000000..c2364ab --- /dev/null +++ b/tests/service/test_mqtt_manager.py @@ -0,0 +1,152 @@ +# Copyright 2025. ThingsBoard +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import asyncio +import pytest +from unittest.mock import MagicMock + +from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit +from tb_mqtt_client.service.mqtt_manager import MQTTManager +from gmqtt import Message + + +@pytest.fixture +def mqtt_manager(): + manager = MQTTManager("test-client") + manager._client._connection = MagicMock() + manager._client._persistent_storage = MagicMock() + return manager + + +@pytest.mark.asyncio +async def test_publish_acknowledgment(mqtt_manager): + fake_mid = 101 + message = Message("v1/devices/me/telemetry", b'{"temp":25}', qos=1) + + mqtt_manager._client._connection.publish = MagicMock(return_value=(fake_mid, b"fake_package")) + mqtt_manager.set_rate_limits(message_rate_limit=RateLimit("0:0"), + telemetry_message_rate_limit=None, + telemetry_dp_rate_limit=None) + + future = await mqtt_manager.publish(message) + + assert fake_mid in mqtt_manager._pending_publishes + mqtt_manager._on_publish_internal(mqtt_manager._client, fake_mid) + + result = await future + assert result is True + assert fake_mid not in mqtt_manager._pending_publishes + + +@pytest.mark.asyncio +async def test_subscribe_acknowledgment(mqtt_manager): + fake_mid = 202 + mqtt_manager._client._connection.subscribe = MagicMock(return_value=fake_mid) + + future = await mqtt_manager.subscribe("v1/devices/me/rpc/request/+") + + assert fake_mid in mqtt_manager._pending_subscriptions + mqtt_manager._on_subscribe_internal(mqtt_manager._client, fake_mid, 1, None) + + result = await future + assert result is True + assert fake_mid not in mqtt_manager._pending_subscriptions + + +@pytest.mark.asyncio +async def test_unsubscribe_acknowledgment(mqtt_manager): + fake_mid = 303 + mqtt_manager._client._connection.unsubscribe = MagicMock(return_value=fake_mid) + + future = await mqtt_manager.unsubscribe("v1/devices/me/rpc/request/+") + + assert fake_mid in mqtt_manager._pending_unsubscriptions + mqtt_manager._on_unsubscribe_internal(mqtt_manager._client, fake_mid) + + result = await future + assert result is True + assert fake_mid not in mqtt_manager._pending_unsubscriptions + + +@pytest.mark.asyncio +async def test_topic_handler_matching(mqtt_manager): + called = asyncio.Event() + + async def handler(topic, payload): + assert topic == "v1/devices/me/rpc/request/42" + assert payload == b"payload" + called.set() + + mqtt_manager.register_handler("v1/devices/me/rpc/request/+", handler) + + mqtt_manager._on_message_internal( + mqtt_manager._client, + topic="v1/devices/me/rpc/request/42", + 
payload=b"payload", + qos=1, + properties=None + ) + + await asyncio.wait_for(called.wait(), timeout=1.0) + + +@pytest.mark.asyncio +async def test_global_rate_limit_allows_publish(mqtt_manager): + rate_limiter = RateLimit("2:10") + mqtt_manager.set_rate_limits(rate_limiter, + telemetry_message_rate_limit=None, + telemetry_dp_rate_limit=None) + + mqtt_manager._client._connection.publish = MagicMock(return_value=(1002, b"pkg")) + future = await mqtt_manager.publish("topic", b'data') + + assert isinstance(future, asyncio.Future) + + +@pytest.mark.asyncio +async def test_gateway_rate_limit_per_device_allows(mqtt_manager): + rl_a = RateLimit("2:60") + rl_b = RateLimit("5:60") + + mqtt_manager.set_rate_limits({"deviceA": rl_a, "deviceB": rl_b}, + telemetry_message_rate_limit=None, + telemetry_dp_rate_limit=None) + mqtt_manager._client._connection.publish = MagicMock(return_value=(1004, b"pkg")) + + future = await mqtt_manager.publish("topic", b'data') + assert isinstance(future, asyncio.Future) + + +@pytest.mark.asyncio +async def test_publish_before_rate_limit_not_retrieved(mqtt_manager): + mqtt_manager._client._connection.publish = MagicMock(return_value=(1005, b"pkg")) + mqtt_manager._client._persistent_storage = MagicMock() + + # Should raise unless it's the special rate-limit request + with pytest.raises(RuntimeError, match="Cannot publish before rate limits"): + await mqtt_manager.publish("v1/devices/me/telemetry", b'data') + + +@pytest.mark.asyncio +async def test_publish_during_rate_limit_request_allowed(mqtt_manager): + # Simulate internal state for initial rate limit request + mqtt_manager._client._connection.publish = MagicMock(return_value=(1006, b"pkg")) + mqtt_manager._client._persistent_storage = MagicMock() + mqtt_manager._MQTTManager__is_waiting_for_rate_limits_publish = True + mqtt_manager._rate_limits_ready_event.set() + + future = await mqtt_manager.publish("v1/devices/me/rpc/request/1", b'data') + assert isinstance(future, asyncio.Future) \ No newline at end of file From fa614d0bb09f2312588a7d03f643c58917b9130f Mon Sep 17 00:00:00 2001 From: imbeacon Date: Wed, 28 May 2025 08:50:02 +0300 Subject: [PATCH 02/74] Updated processing rate limitations, data sending and messages delivery confirmation --- tb_mqtt_client/common/config_loader.py | 28 +- tb_mqtt_client/common/gmqtt_patch.py | 205 ++++++++- .../rate_limit/backpressure_controller.py | 49 ++- .../common/rate_limit/rate_limit.py | 30 +- tb_mqtt_client/constants/mqtt_topics.py | 11 +- .../entities/data/device_uplink_message.py | 53 ++- tb_mqtt_client/service/device/client.py | 36 +- tb_mqtt_client/service/gateway/client.py | 399 ++++++++++++++++++ tb_mqtt_client/service/message_dispatcher.py | 94 +++-- tb_mqtt_client/service/message_queue.py | 299 +++++++++---- tb_mqtt_client/service/message_splitter.py | 73 +++- tb_mqtt_client/service/mqtt_manager.py | 123 ++++-- 12 files changed, 1204 insertions(+), 196 deletions(-) diff --git a/tb_mqtt_client/common/config_loader.py b/tb_mqtt_client/common/config_loader.py index 3c4383d..9be28f9 100644 --- a/tb_mqtt_client/common/config_loader.py +++ b/tb_mqtt_client/common/config_loader.py @@ -14,7 +14,7 @@ # import os -from typing import Optional +from typing import Optional, Dict, Any class DeviceConfig: @@ -46,3 +46,29 @@ def __repr__(self): f"auth={'token' if self.access_token else 'user/pass'} " f"tls_auth={self.use_tls_auth()} " f"tls={self.use_tls()}>") + + +class GatewayConfig(DeviceConfig): + def __init__(self): + super().__init__() + + # Gateway-specific options + 
self.gateway_name: Optional[str] = os.getenv("TB_GATEWAY_NAME") + + # Rate limits for devices connected through the gateway + self.device_messages_rate_limit: Optional[str] = os.getenv("TB_DEVICE_MESSAGES_RATE_LIMIT") + self.device_telemetry_rate_limit: Optional[str] = os.getenv("TB_DEVICE_TELEMETRY_RATE_LIMIT") + self.device_telemetry_dp_rate_limit: Optional[str] = os.getenv("TB_DEVICE_TELEMETRY_DP_RATE_LIMIT") + + # Default device type for auto-registered devices + self.default_device_type: Optional[str] = os.getenv("TB_DEFAULT_DEVICE_TYPE", "default") + + # Whether to automatically register new devices + self.auto_register_devices: bool = os.getenv("TB_AUTO_REGISTER_DEVICES", "true").lower() == "true" + + def __repr__(self): + return (f"") diff --git a/tb_mqtt_client/common/gmqtt_patch.py b/tb_mqtt_client/common/gmqtt_patch.py index 5c8dc01..a2ea769 100644 --- a/tb_mqtt_client/common/gmqtt_patch.py +++ b/tb_mqtt_client/common/gmqtt_patch.py @@ -15,6 +15,7 @@ import struct +import asyncio from types import MethodType from typing import Callable from collections import defaultdict @@ -22,9 +23,210 @@ from tb_mqtt_client.common.logging_utils import get_logger from gmqtt.mqtt.property import Property from gmqtt.mqtt.utils import unpack_variable_byte_integer +from gmqtt.mqtt.protocol import BaseMQTTProtocol, MQTTProtocol +from gmqtt.mqtt.handler import MqttPackageHandler +from gmqtt.mqtt.constants import MQTTCommands logger = get_logger(__name__) +# MQTT 5.0 Disconnect Reason Codes +DISCONNECT_REASON_CODES = { + 0: "Normal disconnection", + 4: "Disconnect with Will Message", + 128: "Unspecified error", + 129: "Malformed Packet", + 130: "Protocol Error", + 131: "Implementation specific error", + 132: "Not authorized", + 133: "Server busy", + 134: "Server shutting down", + 135: "Keep Alive timeout", + 136: "Session taken over", + 137: "Topic Filter invalid", + 138: "Topic Name invalid", + 139: "Receive Maximum exceeded", + 140: "Topic Alias invalid", + 141: "Packet too large", + 142: "Session taken over", + 143: "Quota exceeded", + 144: "Administrative action", + 145: "Payload format invalid", + 146: "Retain not supported", + 147: "QoS not supported", + 148: "Use another server", + 149: "Server moved", + 150: "Shared Subscriptions not supported", + 151: "Connection rate exceeded", + 152: "Maximum connect time", + 153: "Subscription Identifiers not supported", + 154: "Wildcard Subscriptions not supported" +} + + +def extract_reason_code(packet): + """ + Extract the reason code from a disconnect packet. + + :param packet: The disconnect packet, which can be an object with a reason_code attribute or raw bytes + :return: The reason code if found, None otherwise + """ + reason_code = None + if packet: + if hasattr(packet, 'reason_code'): + reason_code = packet.reason_code + elif isinstance(packet, bytes) and len(packet) >= 2: + reason_code = packet[1] + + return reason_code + +def patch_mqtt_handler_disconnect(): + """ + Monkey-patch gmqtt.mqtt.handler.MqttPackageHandler._handle_disconnect_packet to properly + handle server-initiated disconnect messages. 
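+ The patched handler extracts the MQTT 5 reason code from the packet, parses any properties, schedules a reconnect
+ and then calls on_disconnect with the decoded reason code and properties.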
+ """ + try: + # Store the original method + original_handle_disconnect = MqttPackageHandler._handle_disconnect_packet + + # Define the patched method + def patched_handle_disconnect_packet(self, cmd, packet): + # Extract reason code + reason_code = 0 + if packet and len(packet) >= 1: + reason_code = packet[0] + + # Parse properties if available + properties = {} + if packet and len(packet) > 1: + try: + properties, _ = self._parse_properties(packet[1:]) + except Exception as e: + logger.warning("Failed to parse properties from disconnect packet: %s", e) + + reason_desc = DISCONNECT_REASON_CODES.get(reason_code, "Unknown reason") + logger.debug("Server initiated disconnect with reason code: %s (%s)", reason_code, reason_desc) + + # Call the original method to handle reconnection + # But don't call the on_disconnect callback, as we'll do that ourselves + # with the extracted reason_code and properties + self._clear_topics_aliases() + future = asyncio.ensure_future(self.reconnect(delay=True)) + future.add_done_callback(self._handle_exception_in_future) + + # Call the on_disconnect callback with the client, reason_code, properties, and None for exc + # since this is a server-initiated disconnect + self.on_disconnect(self, reason_code, properties, None) + + # Set a flag on the connection object to indicate that on_disconnect has been called + self._connection._on_disconnect_called = True + + # Apply the patch + MqttPackageHandler._handle_disconnect_packet = patched_handle_disconnect_packet + logger.debug("Successfully patched gmqtt.mqtt.handler.MqttPackageHandler._handle_disconnect_packet") + return True + except (ImportError, AttributeError) as e: + logger.warning("Failed to patch gmqtt handler: %s", e) + return False + +def patch_gmqtt_protocol_connection_lost(): + """ + Monkey-patch gmqtt.mqtt.protocol.BaseMQTTProtocol.connection_lost to suppress the + default "[CONN CLOSE NORMALLY]" log message, as we handle disconnect logging in our code. + + Also patch MQTTProtocol.connection_lost to include the reason code in the DISCONNECT package + and pass the exception to the handler. 
+ """ + try: + original_base_connection_lost = BaseMQTTProtocol.connection_lost + def patched_base_connection_lost(self, exc): + self._connected.clear() + super(BaseMQTTProtocol, self).connection_lost(exc) + BaseMQTTProtocol.connection_lost = patched_base_connection_lost + + original_mqtt_connection_lost = MQTTProtocol.connection_lost + def patched_mqtt_connection_lost(self, exc): + super(MQTTProtocol, self).connection_lost(exc) + reason_code = 0 + properties = {} + + if exc: + # Determine reason code based on exception type + if isinstance(exc, ConnectionRefusedError): + reason_code = 135 # Keep Alive timeout + elif isinstance(exc, TimeoutError): + reason_code = 135 # Keep Alive timeout + elif isinstance(exc, ConnectionResetError): + reason_code = 139 # Receive Maximum exceeded + elif isinstance(exc, ConnectionAbortedError): + reason_code = 136 # Session taken over + elif isinstance(exc, PermissionError): + reason_code = 132 # Not authorized + elif isinstance(exc, OSError): + reason_code = 130 # Protocol Error + else: + reason_code = 131 # Implementation specific error + + # Add exception message to properties if available + if hasattr(exc, 'args') and exc.args: + properties['reason_string'] = [str(exc.args[0])] + + # Pack the reason code into a payload + payload = struct.pack('!B', reason_code) + + # Store the exception and properties in the connection object + # so they can be accessed by the handler + self._connection._disconnect_exc = exc + self._connection._disconnect_properties = properties + + # Put the DISCONNECT package into the connection's package queue + self._connection.put_package((MQTTCommands.DISCONNECT, payload)) + + if self._read_loop_future is not None: + self._read_loop_future.cancel() + self._read_loop_future = None + + self._queue = asyncio.Queue() + + MQTTProtocol.connection_lost = patched_mqtt_connection_lost + + # Also patch MqttPackageHandler.__call__ to pass the exception and properties to on_disconnect + original_call = MqttPackageHandler.__call__ + def patched_call(self, cmd, packet): + try: + if cmd == MQTTCommands.DISCONNECT and hasattr(self._connection, '_disconnect_exc'): + # This is a disconnect packet from connection_lost + # Extract reason code + reason_code = 0 + if packet and len(packet) >= 1: + reason_code = packet[0] + + # Get properties and exception from connection + properties = getattr(self._connection, '_disconnect_properties', {}) + exc = getattr(self._connection, '_disconnect_exc', None) + + # Check if on_disconnect has already been called + if not hasattr(self._connection, '_on_disconnect_called') or not self._connection._on_disconnect_called: + # Call on_disconnect with the extracted values + self._clear_topics_aliases() + future = asyncio.ensure_future(self.reconnect(delay=True)) + future.add_done_callback(self._handle_exception_in_future) + self.on_disconnect(self, reason_code, properties, exc) + return + + # For other commands, call the original method + return original_call(self, cmd, packet) + except Exception as e: + logger.error('[ERROR HANDLE PKG]', exc_info=e) + return None + + MqttPackageHandler.__call__ = patched_call + + logger.debug("Successfully patched gmqtt.mqtt.protocol connection_lost methods") + return True + except (ImportError, AttributeError) as e: + logger.warning("Failed to patch gmqtt protocol: %s", e) + return False + def patch_gmqtt_puback(client, on_puback_with_reason_and_properties: Callable[[int, int, dict], None]): """ @@ -85,4 +287,5 @@ def wrapped_handle_puback(self, cmd, packet): return base_method(self, 
cmd, packet) - client._handle_puback_packet = MethodType(wrapped_handle_puback, client) + MqttPackageHandler._handle_puback_packet = wrapped_handle_puback + # client._handle_puback_packet = MethodType(wrapped_handle_puback, client) diff --git a/tb_mqtt_client/common/rate_limit/backpressure_controller.py b/tb_mqtt_client/common/rate_limit/backpressure_controller.py index 5e33fc3..64ba2b8 100644 --- a/tb_mqtt_client/common/rate_limit/backpressure_controller.py +++ b/tb_mqtt_client/common/rate_limit/backpressure_controller.py @@ -26,24 +26,65 @@ class BackpressureController: def __init__(self): self._pause_until: Optional[datetime] = None self._default_pause_duration = timedelta(seconds=10) + self._consecutive_quota_exceeded = 0 + self._last_quota_exceeded = datetime.now(UTC) + self._max_backoff_seconds = 3600 # 1 hour maximum backoff def notify_quota_exceeded(self, delay_seconds: Optional[int] = None): - duration = timedelta(seconds=delay_seconds) if delay_seconds else self._default_pause_duration - self._pause_until = datetime.now(UTC) + duration + now = datetime.now(UTC) + # If we've had a quota exceeded event in the last 60 seconds, increment the counter + if (now - self._last_quota_exceeded).total_seconds() < 60: + self._consecutive_quota_exceeded += 1 + else: + # Reset counter if it's been more than 60 seconds since the last quota exceeded event + self._consecutive_quota_exceeded = 1 + + self._last_quota_exceeded = now + + # Apply exponential backoff based on consecutive quota exceeded events + if delay_seconds is None: + # Start with default duration and apply exponential backoff + backoff_factor = min(2 ** (self._consecutive_quota_exceeded - 1), 10) + delay_seconds = int(self._default_pause_duration.total_seconds() * backoff_factor) + # Cap at max backoff + delay_seconds = min(delay_seconds, self._max_backoff_seconds) + + logger.warning("Applying backpressure for %d seconds (consecutive quota exceeded: %d)", + delay_seconds, self._consecutive_quota_exceeded) + + duration = timedelta(seconds=delay_seconds) + self._pause_until = now + duration def notify_disconnect(self, delay_seconds: Optional[int] = None): - self.notify_quota_exceeded(delay_seconds) + if delay_seconds is None: + delay_seconds = int(self._default_pause_duration.total_seconds()) + + duration = timedelta(seconds=delay_seconds) + self._pause_until = datetime.now(UTC) + duration + logger.debug("Pausing publishing for %d seconds due to disconnect", delay_seconds) def should_pause(self) -> bool: if self._pause_until is None: return False - if datetime.now(UTC) < self._pause_until: + + now = datetime.now(UTC) + if now < self._pause_until: + remaining = (self._pause_until - now).total_seconds() + if remaining > 10: # Only log if more than 10 seconds remaining + logger.debug("Backpressure active: pausing publishing for %.1f more seconds", remaining) return True + + # Reset pause state self._pause_until = None + logger.info("Backpressure released, resuming publishing") return False def pause_for(self, seconds: int): self._pause_until = datetime.now(UTC) + timedelta(seconds=seconds) + logger.info("Manually pausing publishing for %d seconds", seconds) def clear(self): + if self._pause_until is not None: + logger.info("Clearing backpressure pause") self._pause_until = None + self._consecutive_quota_exceeded = 0 diff --git a/tb_mqtt_client/common/rate_limit/rate_limit.py b/tb_mqtt_client/common/rate_limit/rate_limit.py index a0b8edf..a8cc4b1 100644 --- a/tb_mqtt_client/common/rate_limit/rate_limit.py +++ 
b/tb_mqtt_client/common/rate_limit/rate_limit.py @@ -107,6 +107,34 @@ def check_limit_reached(self, amount=1): result = (bucket.capacity, dur) return result + def refill(self): + """Force refill of all token buckets without consuming any tokens.""" + if self._no_limit: + return + with self._lock: + for bucket in self._rate_buckets.values(): + bucket.refill() + + def try_consume(self, amount=1): + """ + Try to consume tokens from all buckets. + Returns True if all buckets had enough tokens and they were consumed. + Returns False if any bucket didn't have enough tokens. + """ + if self._no_limit: + return None + + with self._lock: + for bucket in self._rate_buckets.values(): + bucket.refill() + if bucket.tokens < amount: + return bucket.capacity, bucket.duration + + for bucket in self._rate_buckets.values(): + bucket.tokens -= amount + + return None + def consume(self, amount=1): if self._no_limit: return @@ -146,7 +174,7 @@ def reach_limit(self): self.__reached_index += 1 logger.info("Rate limit reached for \"%s\". Cooldown for %s seconds", self.name, dur) - return self.__reached_index, self.__reached_index_time + return self.__reached_index, self.__reached_index_time, dur def to_dict(self): return { diff --git a/tb_mqtt_client/constants/mqtt_topics.py b/tb_mqtt_client/constants/mqtt_topics.py index dc703cc..f757263 100644 --- a/tb_mqtt_client/constants/mqtt_topics.py +++ b/tb_mqtt_client/constants/mqtt_topics.py @@ -32,12 +32,15 @@ # V1 Topics for Gateway API BASE_GATEWAY_TOPIC = "v1/gateway" GATEWAY_CONNECT_TOPIC = BASE_GATEWAY_TOPIC + "/connect" -GATEWAY_DISCONNECT_TOPIC = BASE_GATEWAY_TOPIC + "disconnect" -GATEWAY_TELEMETRY_TOPIC = BASE_GATEWAY_TOPIC + "telemetry" -GATEWAY_ATTRIBUTES_TOPIC = BASE_GATEWAY_TOPIC + "attributes" +GATEWAY_DISCONNECT_TOPIC = BASE_GATEWAY_TOPIC + "/disconnect" +GATEWAY_TELEMETRY_TOPIC = BASE_GATEWAY_TOPIC + "/telemetry" +GATEWAY_ATTRIBUTES_TOPIC = BASE_GATEWAY_TOPIC + "/attributes" GATEWAY_ATTRIBUTES_REQUEST_TOPIC = GATEWAY_ATTRIBUTES_TOPIC + REQUEST_TOPIC_SUFFIX GATEWAY_ATTRIBUTES_RESPONSE_TOPIC = GATEWAY_ATTRIBUTES_TOPIC + RESPONSE_TOPIC_SUFFIX -GATEWAY_RPC_TOPIC = BASE_GATEWAY_TOPIC + "rpc" +GATEWAY_RPC_TOPIC = BASE_GATEWAY_TOPIC + "/rpc" +GATEWAY_RPC_REQUEST_TOPIC = GATEWAY_RPC_TOPIC + REQUEST_TOPIC_SUFFIX +GATEWAY_RPC_RESPONSE_TOPIC = GATEWAY_RPC_TOPIC + RESPONSE_TOPIC_SUFFIX +GATEWAY_CLAIM_TOPIC = BASE_GATEWAY_TOPIC + "/claim" # Topic Builders diff --git a/tb_mqtt_client/entities/data/device_uplink_message.py b/tb_mqtt_client/entities/data/device_uplink_message.py index 656d73f..7b9472b 100644 --- a/tb_mqtt_client/entities/data/device_uplink_message.py +++ b/tb_mqtt_client/entities/data/device_uplink_message.py @@ -14,10 +14,15 @@ # -from typing import List, Optional, Union +import asyncio +from typing import List, Optional, Union, OrderedDict, Dict + +from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +logger = get_logger(__name__) + DEFAULT_FIELDS_SIZE = len('{"device_name":"","device_profile":"","attributes":"","timeseries":""}'.encode('utf-8')) @@ -26,8 +31,9 @@ def __init__(self, device_name: Optional[str] = None, device_profile: Optional[str] = None, attributes: Optional[List[AttributeEntry]] = None, - timeseries: Optional[List[TimeseriesEntry]] = None, - _size: Optional[int] = None): + timeseries: Optional[OrderedDict[int, List[TimeseriesEntry]]] = None, + _size: Optional[int] = None, + 
delivery_future: List[Optional[asyncio.Future[bool]]] = None): if _size is None: raise ValueError("DeviceUplinkMessage must be created using DeviceUplinkMessageBuilder") @@ -36,16 +42,24 @@ def __init__(self, self.attributes = attributes or [] self.timeseries = timeseries or [] self.__size = _size + self.delivery_futures = delivery_future or [] + def timeseries_datapoint_count(self) -> int: return len(self.timeseries) + def attributes_datapoint_count(self) -> int: + return len(self.attributes) + def has_attributes(self) -> bool: return bool(self.attributes) def has_timeseries(self) -> bool: return bool(self.timeseries) + def get_delivery_futures(self): + return self.delivery_futures + @property def size(self) -> int: return self.__size @@ -56,7 +70,8 @@ def __init__(self): self._device_name: Optional[str] = None self._device_profile: Optional[str] = None self._attributes: List[AttributeEntry] = [] - self._timeseries: List[TimeseriesEntry] = [] + self._timeseries: OrderedDict[int, List[TimeseriesEntry]] = OrderedDict() + self._delivery_futures: List[Optional[asyncio.Future[bool]]] = [] self.__size = DEFAULT_FIELDS_SIZE def set_device_name(self, device_name: str) -> 'DeviceUplinkMessageBuilder': @@ -79,19 +94,43 @@ def add_attributes(self, attributes: Union[AttributeEntry, List[AttributeEntry]] self.__size += attribute.size return self - def add_telemetry(self, timeseries: Union[TimeseriesEntry, List[TimeseriesEntry]]) -> 'DeviceUplinkMessageBuilder': + def add_telemetry(self, timeseries: Union[TimeseriesEntry, List[TimeseriesEntry], OrderedDict[int, List[TimeseriesEntry]]]) -> 'DeviceUplinkMessageBuilder': + if isinstance(timeseries, OrderedDict): + self._timeseries = timeseries + return self if not isinstance(timeseries, list): timeseries = [timeseries] - self._timeseries.extend(timeseries) + for entry in timeseries: + if entry.ts is not None: + if entry.ts in self._timeseries: + self._timeseries[entry.ts].append(entry) + else: + self._timeseries[entry.ts] = [entry] + else: + if 0 in self._timeseries: + self._timeseries[0].append(entry) + else: + self._timeseries[0] = [entry] for timeseries_entry in timeseries: self.__size += timeseries_entry.size return self + def add_delivery_futures(self, future: Union[asyncio.Future[bool], List[asyncio.Future[bool]]]) -> 'DeviceUplinkMessageBuilder': + if not isinstance(future, list): + future = [future] + if future: + logger.exception("Created delivery future: %s", id(future[0])) + self._delivery_futures.extend(future) + return self + def build(self) -> DeviceUplinkMessage: + if not self._delivery_futures: + self._delivery_futures = [asyncio.get_event_loop().create_future()] return DeviceUplinkMessage( device_name=self._device_name, device_profile=self._device_profile, attributes=self._attributes, timeseries=self._timeseries, - _size=self.__size + _size=self.__size, + delivery_future=self._delivery_futures ) diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index 8c82a16..8eeea7b 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -103,6 +103,11 @@ async def connect(self): await self._on_connect() + # Initialize with default max_payload_size if not set + if self.max_payload_size is None: + self.max_payload_size = 65535 + logger.debug("Using default max_payload_size: %d", self.max_payload_size) + self._dispatcher = JsonMessageDispatcher(self.max_payload_size, self._telemetry_dp_rate_limit.minimal_limit) self._message_queue = MessageQueue( 
mqtt_manager=self._mqtt_manager, @@ -123,15 +128,17 @@ async def send_telemetry(self, telemetry_data: Union[Dict[str, Any], List[TimeseriesEntry], List[Dict[str, Any]]]): message = self._build_uplink_message_for_telemetry(telemetry_data) - await self._message_queue.publish(topic=mqtt_topics.DEVICE_TELEMETRY_TOPIC, - payload=message, - datapoints_count=message.timeseries_datapoint_count()) + futures = await self._message_queue.publish(topic=mqtt_topics.DEVICE_TELEMETRY_TOPIC, + payload=message, + datapoints_count=message.timeseries_datapoint_count()) + return futures[0] if futures else None async def send_attributes(self, attributes: Union[Dict[str, Any], AttributeEntry, list[AttributeEntry]]): message = self._build_uplink_message_for_attributes(attributes) - await self._message_queue.publish(topic=mqtt_topics.DEVICE_ATTRIBUTES_TOPIC, - payload=message, - datapoints_count=len(message.attributes)) + futures = await self._message_queue.publish(topic=mqtt_topics.DEVICE_ATTRIBUTES_TOPIC, + payload=message, + datapoints_count=message.attributes_datapoint_count()) + return futures[0] if futures else None async def send_rpc_response(self, response: RPCResponse): topic = mqtt_topics.build_device_rpc_response_topic(request_id=response.request_id) @@ -157,7 +164,7 @@ async def _on_connect(self): self._mqtt_manager.register_handler(mqtt_topics.DEVICE_RPC_RESPONSE_TOPIC_FOR_SUBSCRIPTION, self._handle_rpc_response) # noqa async def _on_disconnect(self): - logger.warning("Device client disconnected") + logger.info("Device client disconnected.") async def send_rpc_call(self, method: str, params: Optional[Dict[str, Any]] = None, timeout: float = 10.0) -> Any: """ @@ -224,6 +231,21 @@ async def _handle_rate_limit_response(self, topic: str, payload: bytes): if "maxPayloadSize" in response: self.max_payload_size = int(response["maxPayloadSize"] * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) + # Update the dispatcher's max_payload_size if it's already initialized + if hasattr(self, '_dispatcher') and self._dispatcher is not None: + self._dispatcher.splitter.max_payload_size = self.max_payload_size + logger.debug("Updated dispatcher's max_payload_size to %d", self.max_payload_size) + else: + # If maxPayloadSize is not provided, keep the default value + logger.debug("No maxPayloadSize in service config, using default: %d", self.max_payload_size) + # Initialize with default max_payload_size if not set + if self.max_payload_size is None: + self.max_payload_size = 65535 + logger.debug("Using default max_payload_size: %d", self.max_payload_size) + # Update the dispatcher's max_payload_size if it's already initialized + if hasattr(self, '_dispatcher') and self._dispatcher is not None: + self._dispatcher.splitter.max_payload_size = self.max_payload_size + logger.debug("Updated dispatcher's max_payload_size to %d", self.max_payload_size) if (not self._messages_rate_limit.has_limit() and not self._telemetry_rate_limit.has_limit() diff --git a/tb_mqtt_client/service/gateway/client.py b/tb_mqtt_client/service/gateway/client.py index 59e3bfa..b6f03de 100644 --- a/tb_mqtt_client/service/gateway/client.py +++ b/tb_mqtt_client/service/gateway/client.py @@ -12,3 +12,402 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# + +from asyncio import sleep +from typing import Callable, Awaitable, Optional, Dict, Any, Union, List, Set + +from orjson import dumps, loads +from random import choices +from string import ascii_uppercase, digits + +from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit +from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.constants import mqtt_topics +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.service.device.client import DeviceClient + +logger = get_logger(__name__) + + +class GatewayClient(DeviceClient): + """ + ThingsBoard Gateway MQTT client implementation. + This class extends DeviceClient and adds gateway-specific functionality. + """ + + def __init__(self, config: Optional[Union[GatewayConfig, Dict]] = None): + """ + Initialize a new GatewayClient instance. + + :param config: Gateway configuration object or dictionary + """ + self._config = None + if isinstance(config, GatewayConfig): + self._config = config + else: + self._config = GatewayConfig() + if isinstance(config, dict): + for key, value in config.items(): + if hasattr(self._config, key) and value is not None: + setattr(self._config, key, value) + + client_id = self._config.client_id or "tb-gateway-" + ''.join(choices(ascii_uppercase + digits, k=6)) + + # Initialize the DeviceClient with the gateway configuration + super().__init__(self._config) + + # Gateway-specific rate limits + self._device_messages_rate_limit = RateLimit("0:0,", name="device_messages") + self._device_telemetry_rate_limit = RateLimit("0:0,", name="device_telemetry") + self._device_telemetry_dp_rate_limit = RateLimit("0:0,", name="device_telemetry_datapoints") + + # Set of connected devices + self._connected_devices: Set[str] = set() + + # Callbacks + self._device_attribute_update_callback = None + self._device_rpc_request_callback = None + self._device_disconnect_callback = None + + async def connect(self): + """ + Connect to the ThingsBoard platform. + """ + logger.info("Connecting gateway to platform at %s:%s", self._host, self._port) + await super().connect() + + # Subscribe to gateway-specific topics + await self._subscribe_to_gateway_topics() + + logger.info("Gateway connected to ThingsBoard.") + + async def _subscribe_to_gateway_topics(self): + """ + Subscribe to gateway-specific MQTT topics. 
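+ Each subscription future resolves when the matching SUBACK arrives; the method polls the futures before
+ registering the topic handlers.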
+ """ + logger.info("Subscribing to gateway topics") + + # Subscribe to gateway attributes topic + sub_future = await self._mqtt_manager.subscribe(mqtt_topics.GATEWAY_ATTRIBUTES_TOPIC, qos=1) + while not sub_future.done(): + await sleep(0.01) + + # Subscribe to gateway attributes response topic + sub_future = await self._mqtt_manager.subscribe(mqtt_topics.GATEWAY_ATTRIBUTES_RESPONSE_TOPIC, qos=1) + while not sub_future.done(): + await sleep(0.01) + + # Subscribe to gateway RPC topic + sub_future = await self._mqtt_manager.subscribe(mqtt_topics.GATEWAY_RPC_TOPIC, qos=1) + while not sub_future.done(): + await sleep(0.01) + + # Register handlers for gateway topics + self._mqtt_manager.register_handler(mqtt_topics.GATEWAY_ATTRIBUTES_TOPIC, self._handle_gateway_attribute_update) + self._mqtt_manager.register_handler(mqtt_topics.GATEWAY_RPC_TOPIC, self._handle_gateway_rpc_request) + self._mqtt_manager.register_handler(mqtt_topics.GATEWAY_ATTRIBUTES_RESPONSE_TOPIC, self._handle_gateway_attribute_response) + + async def _handle_gateway_attribute_update(self, topic: str, payload: bytes): + """ + Handle attribute updates for gateway devices. + + :param topic: MQTT topic + :param payload: Message payload + """ + try: + data = loads(payload) + logger.debug("Received gateway attribute update: %s", data) + + if self._device_attribute_update_callback: + for device_name, attributes in data.items(): + update = AttributeUpdate(device=device_name, attributes=attributes) + await self._device_attribute_update_callback(update) + except Exception as e: + logger.exception("Error handling gateway attribute update: %s", e) + + async def _handle_gateway_rpc_request(self, topic: str, payload: bytes): + """ + Handle RPC requests for gateway devices. + + :param topic: MQTT topic + :param payload: Message payload + """ + try: + data = loads(payload) + logger.debug("Received gateway RPC request: %s", data) + + if self._device_rpc_request_callback and 'device' in data and 'data' in data: + device_name = data['device'] + rpc_data = data['data'] + + if 'id' in rpc_data and 'method' in rpc_data: + request_id = rpc_data['id'] + method = rpc_data['method'] + params = rpc_data.get('params', {}) + + result = await self._device_rpc_request_callback(device_name, method, params) + + # Send RPC response + await self.gw_send_rpc_reply(device_name, request_id, result) + except Exception as e: + logger.exception("Error handling gateway RPC request: %s", e) + + async def _handle_gateway_attribute_response(self, topic: str, payload: bytes): + """ + Handle attribute responses for gateway devices. + + :param topic: MQTT topic + :param payload: Message payload + """ + try: + data = loads(payload) + logger.debug("Received gateway attribute response: %s", data) + + # Process attribute response if needed + # This is typically used for handling responses to attribute requests + except Exception as e: + logger.exception("Error handling gateway attribute response: %s", e) + + async def gw_connect_device(self, device_name: str): + """ + Connect a device to the gateway. + + :param device_name: Name of the device to connect + """ + if device_name in self._connected_devices: + logger.warning("Device %s is already connected", device_name) + return + + self._connected_devices.add(device_name) + logger.info("Device %s connected to gateway", device_name) + + async def gw_disconnect_device(self, device_name: str): + """ + Disconnect a device from the gateway. 
+ + :param device_name: Name of the device to disconnect + """ + if device_name not in self._connected_devices: + logger.warning("Device %s is not connected", device_name) + return + + self._connected_devices.remove(device_name) + + # Publish device disconnect message + await self._mqtt_manager.publish( + mqtt_topics.GATEWAY_DISCONNECT_TOPIC, + dumps({"device": device_name}), + qos=1 + ) + + logger.info("Device %s disconnected from gateway", device_name) + + # Call disconnect callback if registered + if self._device_disconnect_callback: + await self._device_disconnect_callback(device_name) + + async def gw_send_telemetry(self, device_name: str, telemetry: Union[Dict[str, Any], TimeseriesEntry, List[TimeseriesEntry]]): + """ + Send telemetry on behalf of a connected device. + + :param device_name: Name of the device + :param telemetry: Telemetry data to send + """ + if device_name not in self._connected_devices: + logger.warning("Cannot send telemetry for disconnected device %s", device_name) + return + + # Convert telemetry to the appropriate format + payload = self._prepare_telemetry_payload(device_name, telemetry) + + # Publish telemetry + await self._mqtt_manager.publish( + mqtt_topics.GATEWAY_TELEMETRY_TOPIC, + dumps(payload), + qos=1 + ) + + logger.debug("Sent telemetry for device %s: %s", device_name, payload) + + async def gw_send_attributes(self, device_name: str, attributes: Union[Dict[str, Any], AttributeEntry, List[AttributeEntry]]): + """ + Send attributes on behalf of a connected device. + + :param device_name: Name of the device + :param attributes: Attributes to send + """ + if device_name not in self._connected_devices: + logger.warning("Cannot send attributes for disconnected device %s", device_name) + return + + # Convert attributes to the appropriate format + payload = self._prepare_attributes_payload(device_name, attributes) + + # Publish attributes + await self._mqtt_manager.publish( + mqtt_topics.GATEWAY_ATTRIBUTES_TOPIC, + dumps(payload), + qos=1 + ) + + logger.debug("Sent attributes for device %s: %s", device_name, payload) + + async def gw_send_rpc_reply(self, device_name: str, request_id: int, response: Dict[str, Any]): + """ + Send an RPC response on behalf of a connected device. + + :param device_name: Name of the device + :param request_id: ID of the RPC request + :param response: Response data + """ + if device_name not in self._connected_devices: + logger.warning("Cannot send RPC reply for disconnected device %s", device_name) + return + + # Prepare RPC response payload + payload = { + "device": device_name, + "id": request_id, + "data": response + } + + # Publish RPC response + await self._mqtt_manager.publish( + mqtt_topics.GATEWAY_RPC_RESPONSE_TOPIC, + dumps(payload), + qos=1 + ) + + logger.debug("Sent RPC response for device %s, request %s: %s", device_name, request_id, response) + + async def gw_request_shared_attributes(self, device_name: str, keys: List[str], callback: Callable[[Dict[str, Any]], Awaitable[None]]): + """ + Request shared attributes for a connected device. 
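+ Note: delivery of the response to the supplied callback is not implemented yet (see the TODOs below);
+ the request is currently published with a fixed request id.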
+ + :param device_name: Name of the device + :param keys: List of attribute keys to request + :param callback: Callback function to handle the response + """ + if device_name not in self._connected_devices: + logger.warning("Cannot request attributes for disconnected device %s", device_name) + return + + # TODO: Implement attribute request handling with callbacks + + # Prepare attribute request payload + request_id = 1 # TODO: Generate unique request ID + payload = { + "device": device_name, + "keys": keys, + "id": request_id + } + + # Publish attribute request + await self._mqtt_manager.publish( + mqtt_topics.GATEWAY_ATTRIBUTES_REQUEST_TOPIC, + dumps(payload), + qos=1 + ) + + logger.debug("Requested shared attributes for device %s: %s", device_name, keys) + + def set_device_attribute_update_callback(self, callback: Callable[[AttributeUpdate], Awaitable[None]]): + """ + Set callback for device attribute updates. + + :param callback: Callback function + """ + self._device_attribute_update_callback = callback + + def set_device_rpc_request_callback(self, callback: Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]): + """ + Set callback for device RPC requests. + + :param callback: Callback function that takes device name, method, and params + """ + self._device_rpc_request_callback = callback + + def set_device_disconnect_callback(self, callback: Callable[[str], Awaitable[None]]): + """ + Set callback for device disconnections. + + :param callback: Callback function + """ + self._device_disconnect_callback = callback + + def _prepare_telemetry_payload(self, device_name: str, telemetry: Union[Dict[str, Any], TimeseriesEntry, List[TimeseriesEntry]]) -> Dict[str, Any]: + """ + Prepare telemetry payload for gateway API. + + :param device_name: Name of the device + :param telemetry: Telemetry data + :return: Formatted payload + """ + if isinstance(telemetry, dict): + # Simple key-value telemetry + return {device_name: telemetry} + + elif isinstance(telemetry, TimeseriesEntry): + # Single TimeseriesEntry + if telemetry.ts: + return {device_name: {"ts": telemetry.ts, "values": {telemetry.key: telemetry.value}}} + else: + return {device_name: {telemetry.key: telemetry.value}} + + elif isinstance(telemetry, list): + # List of TimeseriesEntry objects + # Group by timestamp + ts_groups = {} + for entry in telemetry: + ts = entry.ts or 0 + if ts not in ts_groups: + ts_groups[ts] = {} + ts_groups[ts][entry.key] = entry.value + + if len(ts_groups) == 1 and 0 in ts_groups: + # No timestamps, just values + return {device_name: ts_groups[0]} + else: + # With timestamps + result = [] + for ts, values in ts_groups.items(): + if ts > 0: + result.append({"ts": ts, "values": values}) + else: + result.append({"values": values}) + return {device_name: result} + + # Fallback + logger.warning("Unsupported telemetry format: %s", type(telemetry)) + return {device_name: {}} + + def _prepare_attributes_payload(self, device_name: str, attributes: Union[Dict[str, Any], AttributeEntry, List[AttributeEntry]]) -> Dict[str, Any]: + """ + Prepare attributes payload for gateway API. 
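+ For example (illustrative values), AttributeEntry("fw", "1.0.0") and AttributeEntry("mode", "auto") for
+ device "sensorA" are flattened into {"sensorA": {"fw": "1.0.0", "mode": "auto"}}.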
+ + :param device_name: Name of the device + :param attributes: Attributes data + :return: Formatted payload + """ + if isinstance(attributes, dict): + # Simple key-value attributes + return {device_name: attributes} + + elif isinstance(attributes, AttributeEntry): + # Single AttributeEntry + return {device_name: {attributes.key: attributes.value}} + + elif isinstance(attributes, list): + # List of AttributeEntry objects + attrs = {} + for entry in attributes: + attrs[entry.key] = entry.value + return {device_name: attrs} + + # Fallback + logger.warning("Unsupported attributes format: %s", type(attributes)) + return {device_name: {}} diff --git a/tb_mqtt_client/service/message_dispatcher.py b/tb_mqtt_client/service/message_dispatcher.py index d6e8636..edea245 100644 --- a/tb_mqtt_client/service/message_dispatcher.py +++ b/tb_mqtt_client/service/message_dispatcher.py @@ -14,7 +14,10 @@ # +from collections import defaultdict +import asyncio from abc import ABC, abstractmethod +from datetime import UTC, datetime from typing import Any, Dict, Union, List, Tuple, Optional from orjson import dumps @@ -36,7 +39,7 @@ def __init__(self, max_payload_size: Optional[int] = None, max_datapoints: Optio def build_topic_payloads( self, messages: List[DeviceUplinkMessage] - ) -> List[Tuple[str, bytes, int]]: + ) -> List[Tuple[str, bytes, int, List[Optional[asyncio.Future[bool]]]]]: """ Build a list of topic-payload pairs from the given messages. Each pair consists of a topic string and a payload byte array. @@ -67,47 +70,51 @@ def __init__(self, max_payload_size: Optional[int] = None, max_datapoints: Optio def splitter(self) -> MessageSplitter: return self._splitter - def build_topic_payloads(self, messages: List[DeviceUplinkMessage]) -> List[Tuple[str, bytes, int]]: - if not messages: - logger.trace("No messages to process in build_topic_payloads.") - return [] + def build_topic_payloads(self, messages: List[DeviceUplinkMessage]) -> List[Tuple[str, bytes, int, List[Optional[asyncio.Future[bool]]]]]: + try: + if not messages: + logger.trace("No messages to process in build_topic_payloads.") + return [] - from collections import defaultdict + result: List[Tuple[str, bytes, int, List[Optional[asyncio.Future[bool]]]]] = [] + device_groups: Dict[str, List[DeviceUplinkMessage]] = defaultdict(list) - result: List[Tuple[str, bytes, int]] = [] - device_groups: Dict[str, List[DeviceUplinkMessage]] = defaultdict(list) + for msg in messages: + device_name = msg.device_name + device_groups[device_name].append(msg) + logger.trace("Queued message for device='%s'", device_name) - for msg in messages: - device_name = msg.device_name or "" - device_groups[device_name].append(msg) - logger.trace("Queued message for device='%s'", device_name) + logger.trace("Processing %d device group(s).", len(device_groups)) - logger.trace("Processing %d device group(s).", len(device_groups)) + for device, device_msgs in device_groups.items(): + telemetry_msgs = [m for m in device_msgs if m.has_timeseries()] + attr_msgs = [m for m in device_msgs if m.has_attributes()] + logger.trace("Device '%s' - telemetry: %d, attributes: %d", + device, len(telemetry_msgs), len(attr_msgs)) - for device, device_msgs in device_groups.items(): - telemetry_msgs = [m for m in device_msgs if m.has_timeseries()] - attr_msgs = [m for m in device_msgs if m.has_attributes()] - logger.trace("Device '%s' - telemetry: %d, attributes: %d", - device, len(telemetry_msgs), len(attr_msgs)) + for ts_batch in self._splitter.split_timeseries(telemetry_msgs): + payload = 
self.build_payload(ts_batch) + count = ts_batch.timeseries_datapoint_count() + result.append((DEVICE_TELEMETRY_TOPIC, payload, count, ts_batch.get_delivery_futures())) + logger.trace("Built telemetry payload for device='%s' with %d datapoints", device, count) - for ts_batch in self._splitter.split_timeseries(telemetry_msgs): - payload = self.build_payload(ts_batch) - count = ts_batch.timeseries_datapoint_count() - result.append((DEVICE_TELEMETRY_TOPIC, payload, count)) - logger.trace("Built telemetry payload for device='%s' with %d datapoints", device, count) + for attr_batch in self._splitter.split_attributes(attr_msgs): + payload = self.build_payload(attr_batch) + count = len(attr_batch.attributes) + result.append((DEVICE_ATTRIBUTES_TOPIC, payload, count, attr_batch.get_delivery_futures())) + logger.trace("Built attribute payload for device='%s' with %d attributes", device, count) - for attr_batch in self._splitter.split_attributes(attr_msgs): - payload = self.build_payload(attr_batch) - count = len(attr_batch.attributes) - result.append((DEVICE_ATTRIBUTES_TOPIC, payload, count)) - logger.trace("Built attribute payload for device='%s' with %d attributes", device, count) - logger.trace("Generated %d topic-payload entries.", len(result)) - return result + logger.trace("Generated %d topic-payload entries.", len(result)) + + return result + except Exception as e: + logger.error("Error building topic-payloads: %s", str(e)) + raise def build_payload(self, msg: DeviceUplinkMessage) -> bytes: result: Dict[str, Any] = {} - device_name = msg.device_name or "" + device_name = msg.device_name logger.trace("Building payload for device='%s'", device_name) if msg.device_name: @@ -119,10 +126,10 @@ def build_payload(self, msg: DeviceUplinkMessage) -> bytes: result[msg.device_name] = self._pack_timeseries(msg) else: if msg.attributes: - logger.trace("Packing anonymous attributes") + logger.trace("Packing attributes") result = self._pack_attributes(msg) if msg.timeseries: - logger.trace("Packing anonymous timeseries") + logger.trace("Packing timeseries") result = self._pack_timeseries(msg) payload = dumps(result) @@ -135,12 +142,15 @@ def _pack_attributes(msg: DeviceUplinkMessage) -> Dict[str, Any]: return {attr.key: attr.value for attr in msg.attributes} @staticmethod - def _pack_timeseries(msg: DeviceUplinkMessage) -> Union[Dict[str, Any], List[Dict[str, Any]]]: - logger.trace("Packing %d timeseries entry(ies)", len(msg.timeseries)) - grouped = {} - for entry in msg.timeseries: - grouped.setdefault(entry.ts or 0, {})[entry.key] = entry.value - - if all(ts == 0 for ts in grouped): - return grouped[0] - return [{"ts": ts, "values": values} for ts, values in grouped.items()] + def _pack_timeseries(msg: DeviceUplinkMessage) -> List[Dict[str, Any]]: + logger.trace("Packing %d timeseries timestamp bucket(s)", len(msg.timeseries)) + + now_ts = int(datetime.now(UTC).timestamp() * 1000) + packed: List[Dict[str, Any]] = [] + + for ts_key, entries in msg.timeseries.items(): + resolved_ts = ts_key or now_ts + values = {entry.key: entry.value for entry in entries} + packed.append({"ts": resolved_ts, "values": values}) + + return packed diff --git a/tb_mqtt_client/service/message_queue.py b/tb_mqtt_client/service/message_queue.py index 0bbae9f..eceef52 100644 --- a/tb_mqtt_client/service/message_queue.py +++ b/tb_mqtt_client/service/message_queue.py @@ -15,15 +15,15 @@ import asyncio -from typing import List, Optional, Union, Tuple from contextlib import suppress +from typing import List, Optional, Union, Tuple, 
Dict, Callable from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit from tb_mqtt_client.constants import mqtt_topics from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage -from tb_mqtt_client.service.mqtt_manager import MQTTManager from tb_mqtt_client.service.message_dispatcher import MessageDispatcher +from tb_mqtt_client.service.mqtt_manager import MQTTManager logger = get_logger(__name__) @@ -37,9 +37,10 @@ def __init__(self, telemetry_rate_limit: Optional[RateLimit], telemetry_dp_rate_limit: Optional[RateLimit], message_dispatcher: MessageDispatcher, - max_queue_size: int = 10000, + max_queue_size: int = 1000000, batch_collect_max_time_ms: int = 100, batch_collect_max_count: int = 500): + self.__qos = 1 self._batch_max_time = batch_collect_max_time_ms / 1000 # convert to seconds self._batch_max_count = batch_collect_max_count self._mqtt_manager = mqtt_manager @@ -47,39 +48,74 @@ def __init__(self, self._telemetry_rate_limit = telemetry_rate_limit self._telemetry_dp_rate_limit = telemetry_dp_rate_limit self._backpressure = self._mqtt_manager.backpressure + self._pending_ack_futures: Dict[int, asyncio.Future[bool]] = {} + self._pending_ack_callbacks: Dict[int, Callable[[bool], None]] = {} self._queue = asyncio.Queue(maxsize=max_queue_size) self._active = asyncio.Event() self._wakeup_event = asyncio.Event() + self._retry_tasks: set[asyncio.Task] = set() self._active.set() self._dispatcher = message_dispatcher self._loop_task = asyncio.create_task(self._dequeue_loop()) + self._rate_limit_refill_task = asyncio.create_task(self._rate_limit_refill_loop()) logger.debug("MessageQueue initialized: max_queue_size=%s, batch_time=%.3f, batch_count=%d", max_queue_size, self._batch_max_time, batch_collect_max_count) async def publish(self, topic: str, payload: Union[bytes, DeviceUplinkMessage], datapoints_count: int): + delivery_futures = payload.get_delivery_futures() if isinstance(payload, DeviceUplinkMessage) else [] try: - self._queue.put_nowait((topic, payload, datapoints_count)) - logger.trace("Enqueued message: topic=%s, datapoints=%d, type=%s", + logger.debug("publish() received delivery future id: %r for topic=%s", + id(delivery_futures[0]), topic) + self._queue.put_nowait((topic, payload, delivery_futures, datapoints_count)) + logger.debug("Enqueued message: topic=%s, datapoints=%d, type=%s", topic, datapoints_count, type(payload).__name__) except asyncio.QueueFull: - logger.warning("Message queue full. Dropping message for topic %s", topic) + logger.error("Message queue full. 
Dropping message for topic %s", topic) + for future in payload.get_delivery_futures(): + if future: + future.set_result(False) + return delivery_futures or None async def _dequeue_loop(self): + logger.debug("MessageQueue dequeue loop started.") while self._active.is_set(): try: - topic, payload, count = await self._wait_for_message() + # topic, payload, count = await self._wait_for_message() + topic, payload, delivery_futures_or_none, count = await asyncio.wait_for(asyncio.get_event_loop().create_task(self._queue.get()), timeout=self._BATCH_TIMEOUT) + # Unpack payload and delivery futures if it's a retry tuple + logger.debug("MessageQueue dequeue: topic=%s, payload=%r, count=%d", + topic, payload, count) + # if isinstance(payload, tuple): + # payload, delivery_futures_or_none = payload + # else: + # delivery_futures_or_none = None + + if isinstance(payload, bytes): + await self._try_publish(topic, payload, count, delivery_futures_or_none) + continue + logger.debug("Dequeued message: delivery_future id: %r topic=%s, type=%s, datapoints=%d", + id(delivery_futures_or_none[0]) if delivery_futures_or_none else None, + topic, type(payload).__name__, count) + await asyncio.sleep(0) # cooperative yield except asyncio.TimeoutError: + logger.trace("Dequeue wait timed out. Yielding...") + await asyncio.sleep(0.001) + continue + except asyncio.CancelledError: + break + except Exception as e: + logger.debug("Unexpected error in dequeue loop: %s", e) continue if isinstance(payload, bytes): logger.trace("Dequeued immediate publish: topic=%s (raw bytes)", topic) - await self._try_publish(topic, payload, count) + await self._try_publish(topic, payload, count, delivery_futures_or_none) continue logger.trace("Dequeued message for batching: topic=%s, device=%s", topic, getattr(payload, 'device_name', 'N/A')) - batch: List[Tuple[str, Union[bytes, DeviceUplinkMessage], int]] = [(topic, payload, count)] + batch: List[Tuple[str, Union[bytes, DeviceUplinkMessage], asyncio.Future[bool], int]] = [(topic, payload, delivery_futures_or_none, count)] start = asyncio.get_event_loop().time() batch_size = payload.size @@ -92,109 +128,180 @@ async def _dequeue_loop(self): logger.trace("Batch count threshold reached: %d messages", len(batch)) break - next_topic, next_payload, next_count = self._queue.get_nowait() - if isinstance(next_payload, DeviceUplinkMessage): - msg_size = next_payload.size - if batch_size + msg_size > self._dispatcher.splitter.max_payload_size: - logger.trace("Batch size threshold exceeded: current=%d, next=%d", batch_size, msg_size) - await self._queue.put((next_topic, next_payload, next_count)) - break - batch.append((next_topic, next_payload, next_count)) - batch_size += msg_size - else: - logger.trace("Immediate publish encountered in queue while batching: topic=%s", next_topic) - await self._try_publish(next_topic, next_payload, next_count) + try: + next_topic, next_payload, delivery_futures_or_none, next_count = self._queue.get_nowait() + if isinstance(next_payload, DeviceUplinkMessage): + msg_size = next_payload.size + if batch_size + msg_size > self._dispatcher.splitter.max_payload_size: + logger.trace("Batch size threshold exceeded: current=%d, next=%d", batch_size, msg_size) + self._queue.put_nowait((next_topic, next_payload, delivery_futures_or_none, next_count)) + break + batch.append((next_topic, next_payload, delivery_futures_or_none, next_count)) + batch_size += msg_size + else: + logger.trace("Immediate publish encountered in queue while batching: topic=%s", next_topic) + await 
self._try_publish(next_topic, next_payload, next_count) + except asyncio.QueueEmpty: + break if batch: - messages = [p for _, p, _ in batch if isinstance(p, DeviceUplinkMessage)] - logger.trace("Formed batch with %d DeviceUplinkMessages", len(messages)) + logger.debug("Batching completed: %d messages, total size=%d", len(batch), batch_size) + messages = [device_uplink_message for _, device_uplink_message, _, _ in batch] + topic_payloads = self._dispatcher.build_topic_payloads(messages) - for topic, payload, datapoints in topic_payloads: - logger.trace("Dispatching batched message: topic=%s, size=%d, datapoints=%d", - topic, len(payload), datapoints) - await self._try_publish(topic, payload, datapoints) - async def _try_publish(self, topic: str, payload: bytes, points: int): - telemetry = topic == mqtt_topics.DEVICE_TELEMETRY_TOPIC - logger.trace("Attempting publish: topic=%s, datapoints=%d", topic, points) + for topic, payload, datapoints, delivery_futures in topic_payloads: + logger.debug("Dispatching batched message: topic=%s, size=%d, datapoints=%d, delivery_futures=%r", + topic, len(payload), datapoints, [id(f) for f in delivery_futures]) + await self._try_publish(topic, payload, datapoints, delivery_futures) + + async def _try_publish(self, + topic: str, + payload: bytes, + datapoints: int, + delivery_futures_or_none: List[Optional[asyncio.Future[bool]]] = None): + if delivery_futures_or_none is None: + logger.trace("No delivery futures associated! This publish result will not be tracked.") + delivery_futures_or_none = [] + is_message_with_telemetry_or_attributes = topic in (mqtt_topics.DEVICE_TELEMETRY_TOPIC, + mqtt_topics.DEVICE_ATTRIBUTES_TOPIC) + + logger.trace("Attempting publish: topic=%s, datapoints=%d", topic, datapoints) + # Check backpressure first - if active, don't even try to check rate limits if self._backpressure.should_pause(): - self._schedule_delayed_retry(topic, payload, points, delay=1.0) + logger.debug("Backpressure active, delaying publish of topic=%s for %.1f seconds", topic, 1.0) + self._schedule_delayed_retry(topic, payload, datapoints, delay=1.0, delivery_futures=delivery_futures_or_none) return - if telemetry: - if self._telemetry_rate_limit and self._telemetry_rate_limit.check_limit_reached(1): - logger.debug("Telemetry message rate limit hit: topic=%s", topic) - retry_delay = self._telemetry_rate_limit.minimal_timeout - self._schedule_delayed_retry(topic, payload, points, delay=retry_delay) - return - if self._telemetry_dp_rate_limit and self._telemetry_dp_rate_limit.check_limit_reached(points): - logger.debug("Telemetry datapoint rate limit hit: topic=%s", topic) - retry_delay = self._telemetry_dp_rate_limit.minimal_timeout - self._schedule_delayed_retry(topic, payload, points, delay=retry_delay) - return - else: - if self._message_rate_limit and self._message_rate_limit.check_limit_reached(1): - logger.debug("Generic message rate limit hit: topic=%s", topic) - logger.debug("Rate limit state: %s", self._message_rate_limit.to_dict()) - retry_delay = self._message_rate_limit.minimal_timeout - self._schedule_delayed_retry(topic, payload, points, delay=retry_delay) - return + # Check and consume rate limits atomically before publishing + if is_message_with_telemetry_or_attributes: + # For telemetry messages, we need to check both message and datapoint rate limits + telemetry_msg_success = True + telemetry_dp_success = True + + if self._telemetry_rate_limit: + triggered_rate_limit = self._telemetry_rate_limit.try_consume(1) + if triggered_rate_limit: + 
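+                    # Budget exhausted: skip this publish attempt and re-enqueue the payload after the limiter's minimal timeout.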
logger.debug("Telemetry message rate limit hit for topic %s: %r per %r seconds", + topic, triggered_rate_limit[0], triggered_rate_limit[1]) + retry_delay = self._telemetry_rate_limit.minimal_timeout + self._schedule_delayed_retry(topic, payload, datapoints, delay=retry_delay, delivery_futures=delivery_futures_or_none) + return + if self._telemetry_dp_rate_limit: + triggered_rate_limit = self._telemetry_dp_rate_limit.try_consume(datapoints) + if triggered_rate_limit: + logger.debug("Telemetry datapoint rate limit hit for topic %s: %r per %r seconds", + topic, triggered_rate_limit[0], triggered_rate_limit[1]) + retry_delay = self._telemetry_dp_rate_limit.minimal_timeout + self._schedule_delayed_retry(topic, payload, datapoints, delay=retry_delay, delivery_futures=delivery_futures_or_none) + return + else: + # For non-telemetry messages, we only need to check the message rate limit + if self._message_rate_limit: + triggered_rate_limit = self._message_rate_limit.try_consume(1) + if triggered_rate_limit: + logger.debug("Generic message rate limit hit for topic %s: %r per %r seconds", topic, triggered_rate_limit[0], triggered_rate_limit[1]) + retry_delay = self._message_rate_limit.minimal_timeout + self._schedule_delayed_retry(topic, payload, datapoints, delay=retry_delay, delivery_futures=delivery_futures_or_none) + return try: - logger.debug("Rate limit state before publish: %s", self._message_rate_limit.to_dict()) - await self._mqtt_manager.publish(topic, payload, qos=1) - logger.trace("Publish successful: topic=%s", topic) - if telemetry: - if self._telemetry_rate_limit: - self._telemetry_rate_limit.consume(1) - if self._telemetry_dp_rate_limit: - self._telemetry_dp_rate_limit.consume(points) - else: - if self._message_rate_limit: - self._message_rate_limit.consume(1) + logger.debug("Trying to publish topic=%s, payload size=%d, attached future id=%r", + topic, len(payload), id(delivery_futures_or_none[0]) if delivery_futures_or_none else 0) + + mqtt_future = await self._mqtt_manager.publish(topic, payload, qos=self.__qos) + + if delivery_futures_or_none is not None: + def resolve_attached(mqtt_future: asyncio.Future): + try: + success = mqtt_future.result() is True + except Exception as e: + success = False + logger.warning("mqtt_future failed with exception: %s", e) + + for i, f in enumerate(delivery_futures_or_none): + if f is not None and not f.done(): + f.set_result(success) + logger.debug("Resolved delivery future #%d id=%r with %s, main publish future id: %r, %r", + i, id(f), success, id(mqtt_future), mqtt_future) + + logger.debug("Adding done callback to main publish future: %r, main publish future state: %r", id(mqtt_future), mqtt_future.done()) + mqtt_future.add_done_callback(resolve_attached) except Exception as e: logger.warning("Failed to publish to topic %s: %s. 
Scheduling retry.", topic, e) - self._schedule_delayed_retry(topic, payload, points, delay=1.0) + self._schedule_delayed_retry(topic, payload, datapoints, delay=.1) - def _schedule_delayed_retry(self, topic: str, payload: bytes, points: int, delay: float): + def _schedule_delayed_retry(self, topic: str, payload: bytes, points: int, delay: float, + delivery_futures: Optional[List[Optional[asyncio.Future[bool]]]] = None): logger.trace("Scheduling retry: topic=%s, delay=%.2f", topic, delay) async def retry(): - await asyncio.sleep(delay) try: - self._queue.put_nowait((topic, payload, points)) + logger.debug("Retrying publish: topic=%s", topic) + await asyncio.sleep(delay) + self._queue.put_nowait((topic, payload, delivery_futures, points)) self._wakeup_event.set() - logger.trace("Re-enqueued message after delay: topic=%s", topic) + logger.debug("Re-enqueued message after delay: topic=%s", topic) except asyncio.QueueFull: logger.warning("Retry queue full. Dropping retried message: topic=%s", topic) + except Exception as e: + logger.debug("Unexpected error during delayed retry: %s", e) - asyncio.create_task(retry()) + task = asyncio.create_task(retry()) + self._retry_tasks.add(task) + task.add_done_callback(self._retry_tasks.discard) - async def _wait_for_message(self): - if not self._queue.empty(): - return await self._queue.get() + async def _wait_for_message(self) -> Tuple[str, Union[bytes, DeviceUplinkMessage], int]: + while self._active.is_set(): + try: + if not self._queue.empty(): + try: + return await self._queue.get() + except asyncio.QueueEmpty: + await asyncio.sleep(.01) + + self._wakeup_event.clear() + queue_task = asyncio.create_task(self._queue.get()) + wake_task = asyncio.create_task(self._wakeup_event.wait()) + + done, pending = await asyncio.wait( + [queue_task, wake_task], return_when=asyncio.FIRST_COMPLETED + ) - self._wakeup_event.clear() - queue_task = asyncio.create_task(self._queue.get()) - wake_task = asyncio.create_task(self._wakeup_event.wait()) - done, _ = await asyncio.wait([queue_task, wake_task], return_when=asyncio.FIRST_COMPLETED) + for task in pending: + logger.debug("Cancelling pending task: %r, it is queue_task = %r", task, queue_task==task) + task.cancel() + with suppress(asyncio.CancelledError): + await task - if queue_task in done: - wake_task.cancel() - return queue_task.result() + if queue_task in done: + logger.debug("Retrieved message from queue: %r", queue_task.result()) + return queue_task.result() - # Wake event triggered — retry get - queue_task.cancel() - await asyncio.sleep(0.001) # Yield control - return await self._wait_for_message() + await asyncio.sleep(0) + + except asyncio.CancelledError: + break + + raise asyncio.CancelledError("MessageQueue is shutting down or stopped.") async def shutdown(self): logger.debug("Shutting down MessageQueue...") self._active.clear() + self._wakeup_event.set() # Wake up the _wait_for_message if it's blocked + + for task in self._retry_tasks: + task.cancel() + with suppress(asyncio.CancelledError): + await asyncio.gather(*self._retry_tasks, return_exceptions=True) + self._loop_task.cancel() + self._rate_limit_refill_task.cancel() with suppress(asyncio.CancelledError): await self._loop_task + await self._rate_limit_refill_task + logger.debug("MessageQueue shutdown complete.") def is_empty(self): @@ -206,6 +313,36 @@ def size(self): def clear(self): logger.debug("Clearing message queue...") while not self._queue.empty(): - self._queue.get_nowait() + _, message, _ = self._queue.get_nowait() + if isinstance(message, 
DeviceUplinkMessage) and message.get_delivery_futures(): + for future in message.get_delivery_futures(): + future.set_result(False) self._queue.task_done() logger.debug("Message queue cleared.") + + @property + def qos(self) -> int: + return self.__qos + + @qos.setter + def qos(self, qos: int): + self.__qos = qos + + async def _rate_limit_refill_loop(self): + try: + while self._active.is_set(): + await asyncio.sleep(1.0) + self._refill_rate_limits() + logger.debug("Rate limits refilled, state: %s", + { + "message_rate_limit": self._message_rate_limit.to_dict() if self._message_rate_limit else None, + "telemetry_rate_limit": self._telemetry_rate_limit.to_dict() if self._telemetry_rate_limit else None, + "telemetry_dp_rate_limit": self._telemetry_dp_rate_limit.to_dict() if self._telemetry_dp_rate_limit else None + }) + except asyncio.CancelledError: + pass + + def _refill_rate_limits(self): + for rl in (self._message_rate_limit, self._telemetry_rate_limit, self._telemetry_dp_rate_limit): + if rl: + rl.refill() diff --git a/tb_mqtt_client/service/message_splitter.py b/tb_mqtt_client/service/message_splitter.py index 00f228e..e76d838 100644 --- a/tb_mqtt_client/service/message_splitter.py +++ b/tb_mqtt_client/service/message_splitter.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # - - +import asyncio from typing import List -from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage, DeviceUplinkMessageBuilder from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage, DeviceUplinkMessageBuilder logger = get_logger(__name__) @@ -37,6 +36,11 @@ def __init__(self, max_payload_size: int = 65535, max_datapoints: int = 0): def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUplinkMessage]: logger.trace("Splitting timeseries for %d messages", len(messages)) + if (len(messages) == 1 and + messages[0].attributes_datapoint_count() + messages[0].timeseries_datapoint_count() <= self._max_datapoints and + messages[0].size <= self._max_payload_size) or self._max_datapoints == 0: + return messages + result: List[DeviceUplinkMessage] = [] for message in messages: @@ -48,8 +52,9 @@ def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp builder = None size = 0 point_count = 0 + batch_futures = [] - for ts in message.timeseries: + for ts in message.timeseries.values(): exceeds_size = builder and size + ts.size > self._max_payload_size exceeds_points = self._max_datapoints > 0 and point_count >= self._max_datapoints @@ -57,23 +62,35 @@ def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp if builder: built = builder.build() result.append(built) + batch_futures.extend(built.get_delivery_futures()) logger.trace("Flushed batch with %d points (size=%d)", len(built.timeseries), size) - builder = DeviceUplinkMessageBuilder() \ - .set_device_name(message.device_name) \ - .set_device_profile(message.device_profile) + builder = DeviceUplinkMessageBuilder().set_device_name(message.device_name).set_device_profile( + message.device_profile) size = 0 point_count = 0 builder.add_telemetry(ts) size += ts.size point_count += 1 - logger.trace("Added timeseries entry to batch (size=%d, points=%d)", size, point_count) if builder and builder._timeseries: built = builder.build() result.append(built) + batch_futures.extend(built.get_delivery_futures()) 
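+                # Collect this batch's delivery futures so the original message's future can be resolved from all batch results below.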
logger.trace("Flushed final batch with %d points (size=%d)", len(built.timeseries), size) + if message.get_delivery_futures(): + original_future = message.get_delivery_futures()[0] + logger.exception("Adding futures to original future: %s, futures ids: %r", id(original_future), + [id(batch_future) for batch_future in batch_futures]) + + async def resolve_original(): + logger.exception("Resolving original future with batch futures: %s", [id(f) for f in batch_futures]) + results = await asyncio.gather(*batch_futures, return_exceptions=False) + original_future.set_result(all(results)) + + asyncio.create_task(resolve_original()) + logger.trace("Total timeseries batches created: %d", len(result)) return result @@ -81,6 +98,11 @@ def split_attributes(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp logger.trace("Splitting attributes for %d messages", len(messages)) result: List[DeviceUplinkMessage] = [] + if (len(messages) == 1 and + messages[0].attributes_datapoint_count() <= self._max_datapoints and + messages[0].size <= self._max_payload_size): + return messages + for message in messages: if not message.has_attributes(): logger.trace("Message from device '%s' has no attributes. Skipping.", message.device_name) @@ -89,29 +111,48 @@ def split_attributes(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp logger.trace("Processing attributes from device: %s", message.device_name) builder = None size = 0 + point_count = 0 + batch_futures = [] for attr in message.attributes: - if builder and size + attr.size > self._max_payload_size: - built = builder.build() - result.append(built) - logger.trace("Flushed attribute batch (count=%d, size=%d)", len(built.attributes), size) + exceeds_size = builder and size + attr.size > self._max_payload_size + exceeds_points = self._max_datapoints > 0 and point_count >= self._max_datapoints + + if not builder or exceeds_size or exceeds_points: + if builder: + built = builder.build() + result.append(built) + batch_futures.extend(built.get_delivery_futures()) + logger.trace("Flushed attribute batch (count=%d, size=%d)", len(built.attributes), size) builder = None size = 0 + point_count = 0 if not builder: - builder = DeviceUplinkMessageBuilder() \ - .set_device_name(message.device_name) \ - .set_device_profile(message.device_profile) + builder = DeviceUplinkMessageBuilder().set_device_name(message.device_name).set_device_profile( + message.device_profile) builder.add_attributes(attr) size += attr.size - logger.trace("Added attribute to batch (size=%d)", size) + point_count += 1 if builder and builder._attributes: built = builder.build() result.append(built) + batch_futures.extend(built.get_delivery_futures()) logger.trace("Flushed final attribute batch (count=%d, size=%d)", len(built.attributes), size) + if message.get_delivery_futures(): + original_future = message.get_delivery_futures()[0] + logger.exception("Adding futures to original future: %s, futures ids: %r", id(original_future), + [id(batch_future) for batch_future in batch_futures]) + + async def resolve_original(): + results = await asyncio.gather(*batch_futures, return_exceptions=False) + original_future.set_result(all(results)) + + asyncio.create_task(resolve_original()) + logger.trace("Total attribute batches created: %d", len(result)) return result diff --git a/tb_mqtt_client/service/mqtt_manager.py b/tb_mqtt_client/service/mqtt_manager.py index 41c880c..ae34955 100644 --- a/tb_mqtt_client/service/mqtt_manager.py +++ b/tb_mqtt_client/service/mqtt_manager.py @@ -22,7 +22,7 @@ from 
gmqtt import Client as GMQTTClient, Message, Subscription, MQTTConnectError from tb_mqtt_client.common.request_id_generator import RPCRequestIdProducer -from tb_mqtt_client.common.gmqtt_patch import patch_gmqtt_puback +from tb_mqtt_client.common.gmqtt_patch import patch_gmqtt_puback, patch_gmqtt_protocol_connection_lost, patch_mqtt_handler_disconnect, DISCONNECT_REASON_CODES from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit from tb_mqtt_client.common.rate_limit.backpressure_controller import BackpressureController @@ -44,6 +44,9 @@ def __init__( rate_limits_handler: Optional[Callable[[str, bytes], Awaitable[None]]] = None, rpc_response_handler: Optional[RPCResponseHandler] = None, ): + patch_gmqtt_protocol_connection_lost() + patch_mqtt_handler_disconnect() + self._client = GMQTTClient(client_id) patch_gmqtt_puback(self._client, self._handle_puback_reason_code) self._client.on_connect = self._on_connect_internal @@ -57,6 +60,7 @@ def __init__( self._on_disconnect_callback = on_disconnect self._connected_event = asyncio.Event() + self._connect_params = None # Will be set in connect method self._handlers: Dict[str, Callable[[str, bytes], Awaitable[None]]] = {} self._pending_publishes: Dict[int, asyncio.Future] = {} @@ -69,34 +73,39 @@ def __init__( self.__rate_limits_retrieved = False self.__rate_limiter: Optional[Dict[str, RateLimit]] = None self.__is_gateway = False # TODO: determine if this is a gateway or not - self.__is_waiting_for_rate_limits_publish = False + self.__is_waiting_for_rate_limits_publish = True # Start with True to prevent publishing before rate limits are retrieved self._rate_limits_ready_event = asyncio.Event() async def connect(self, host: str, port: int = 1883, username: Optional[str] = None, password: Optional[str] = None, tls: bool = False, keepalive: int = 60, ssl_context: Optional[ssl.SSLContext] = None): - try: - if username: - self._client.set_auth_credentials(username, password) - - if tls: - if ssl_context is None: - ssl_context = ssl.create_default_context() - await self._client.connect(host, port, ssl=ssl_context, keepalive=keepalive) - else: - await self._client.connect(host, port, keepalive=keepalive) + self._connect_params = (host, port, username, password, tls, keepalive, ssl_context) + asyncio.create_task(self._connect_loop()) + + async def _connect_loop(self): + host, port, username, password, tls, keepalive, ssl_context = self._connect_params + retry_delay = 3 + + while not self._client.is_connected: try: - await asyncio.wait_for(self._connected_event.wait(), timeout=10) - except asyncio.TimeoutError: - logger.warning("Timeout waiting for MQTT connection.") - raise + if username: + self._client.set_auth_credentials(username, password) - except MQTTConnectError as e: - logger.warning("MQTT connection failed: %s", str(e)) - self._connected_event.clear() - except Exception as e: - logger.exception("Unhandled exception during MQTT connect: %s", e) - raise + if tls: + if ssl_context is None: + ssl_context = ssl.create_default_context() + await self._client.connect(host, port, ssl=ssl_context, keepalive=keepalive) + else: + await self._client.connect(host, port, keepalive=keepalive) + + logger.info("MQTT connection initiated, waiting for on_connect...") + await self._connected_event.wait() + logger.info("MQTT connected.") + break # Exit loop if connected + + except Exception as e: + logger.warning("Initial MQTT connection failed: %s. 
Retrying in %s seconds...", str(e), retry_delay) + await asyncio.sleep(retry_delay) def is_connected(self) -> bool: return self._client.is_connected and self._connected_event.is_set() and self.__rate_limits_retrieved @@ -124,7 +133,7 @@ async def publish(self, message_or_topic: Union[str, Message], raise RuntimeError("Timeout waiting for rate limits.") if not force and self._backpressure.should_pause(): - logger.warning("Backpressure active. Publishing suppressed.") + logger.trace("Backpressure active. Publishing suppressed.") raise RuntimeError("Publishing temporarily paused due to backpressure.") if isinstance(message_or_topic, Message): @@ -136,6 +145,7 @@ async def publish(self, message_or_topic: Union[str, Message], future = asyncio.get_event_loop().create_future() if qos > 0: + logger.debug("Publishing mid=%s, storing publish main future with id: %r", mid, id(future)) self._pending_publishes[mid] = future self._client._persistent_storage.push_message_nowait(mid, package) else: @@ -169,6 +179,8 @@ def unregister_handler(self, topic_filter: str): def _on_connect_internal(self, client, flags, rc, properties): logger.info("Connected to platform") + if hasattr(client, '_connection'): + client._connection._on_disconnect_called = False self._connected_event.set() asyncio.create_task(self.__handle_connect_and_limits()) @@ -187,12 +199,39 @@ async def __handle_connect_and_limits(self): if self._on_connect_callback: await self._on_connect_callback() - def _on_disconnect_internal(self, client, packet, exc=None): - logger.warning("Disconnected from platform") + def _on_disconnect_internal(self, client, reason_code=None, properties=None, exc=None): + if reason_code is not None: + reason_desc = DISCONNECT_REASON_CODES.get(reason_code, "Unknown reason") + logger.info("Disconnected from platform with reason code: %s (%s)", reason_code, reason_desc) + + if properties and 'reason_string' in properties: + logger.info("Disconnect reason: %s", properties['reason_string'][0]) + else: + logger.info("Disconnected from platform") + + if exc: + logger.warning("Disconnect exception: %s", exc) + RPCRequestIdProducer.reset() self._rpc_response_handler.clear() self._connected_event.clear() - self._backpressure.notify_disconnect(delay_seconds=15) + self.__rate_limits_retrieved = False + self.__is_waiting_for_rate_limits_publish = True + self._rate_limits_ready_event.clear() + if reason_code == 142: + logger.error("Session was taken over, looks like another client connected with the same credentials.") + self._backpressure.notify_disconnect(delay_seconds=10) + if reason_code in (131, 142, 143, 151): + reached_time = 1 + for rate_limit in self.__rate_limiter.values(): + if isinstance(rate_limit, RateLimit): + reached_limit = rate_limit.reach_limit() + reached_index, reached_time, reached_duration = reached_limit if reached_limit else (None, None, 1) + self._backpressure.notify_disconnect(delay_seconds=reached_time) + else: + # Default disconnect handling + self._backpressure.notify_disconnect(delay_seconds=15) + if self._on_disconnect_callback: asyncio.create_task(self._on_disconnect_callback()) @@ -205,18 +244,35 @@ def _on_message_internal(self, client, topic: str, payload: bytes, qos, properti asyncio.create_task(self._rpc_response_handler.handle(topic, payload)) def _on_publish_internal(self, client, mid): - future = self._pending_publishes.pop(mid, None) - if future and not future.done(): - future.set_result(True) + pass + # future = self._pending_publishes.pop(mid, None) + # if future and not future.done(): + 
# future.set_result(True) def _handle_puback_reason_code(self, mid: int, reason_code: int, properties: dict): QUOTA_EXCEEDED = 0x97 # MQTT 5 reason code for quota exceeded + IMPLEMENTATION_SPECIFIC_ERROR = 0x83 # MQTT 5 reason code for implementation specific error (131) + + logger.debug("Handling PUBACK mid=%s with rc %r", mid, reason_code) + future = self._pending_publishes.pop(mid, None) + if future is None: + logger.error("Missing future for mid=%s", mid) + elif future.done(): + logger.error("Future for mid=%s already resolved", mid) + else: + logger.debug("Resolved future for mid=%s, object id: %r", mid, id(future)) + future.set_result(True) + if reason_code == QUOTA_EXCEEDED: logger.warning("PUBACK received with QUOTA_EXCEEDED for mid=%s", mid) self._backpressure.notify_quota_exceeded(delay_seconds=10) + elif reason_code == IMPLEMENTATION_SPECIFIC_ERROR: + logger.warning("PUBACK received with IMPLEMENTATION_SPECIFIC_ERROR for mid=%s, treating as rate limit reached", mid) + self._backpressure.notify_quota_exceeded(delay_seconds=15) # Treat implementation specific error as quota exceeded elif reason_code != 0: logger.warning("PUBACK received with error code %s for mid=%s", reason_code, mid) + def _on_subscribe_internal(self, client, mid, qos, properties): future = self._pending_subscriptions.pop(mid, None) if future and not future.done(): @@ -249,6 +305,9 @@ def set_rate_limits( self._rate_limits_ready_event.set() async def __request_rate_limits(self): + # Set this flag at the beginning to prevent publishing before rate limits are retrieved + self.__is_waiting_for_rate_limits_publish = True + request_id = await RPCRequestIdProducer.get_next() request_topic = f"v1/devices/me/rpc/request/{request_id}" response_topic = f"v1/devices/me/rpc/response/{request_id}" @@ -262,13 +321,12 @@ async def _handler(topic: str, payload: bytes): await self.__rate_limits_handler(topic, payload) response_future.set_result(payload) except Exception as e: - logger.exception("Error handling rate limits response: %s", e) + logger.debug("Error handling rate limits response: %s", e) response_future.set_exception(e) self.register_handler(response_topic, _handler) try: - self.__is_waiting_for_rate_limits_publish = True logger.debug("Requesting rate limits via RPC...") await self.publish(request_topic, SESSION_LIMITS_REQUEST_MESSAGE, qos=1, force=True) await asyncio.wait_for(response_future, timeout=10) @@ -278,9 +336,10 @@ async def _handler(topic: str, payload: bytes): self._rate_limits_ready_event.set() except asyncio.TimeoutError: logger.warning("Timeout while waiting for rate limits.") + # Keep __is_waiting_for_rate_limits_publish as True to prevent publishing + # until rate limits are retrieved finally: self.unregister_handler(response_topic) - self.__is_waiting_for_rate_limits_publish = False @property def backpressure(self) -> BackpressureController: From 7e9f19166a4a2a52a3807b2f1aac25ef9d897777 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Thu, 29 May 2025 19:58:07 +0300 Subject: [PATCH 03/74] Added basic handlers for device, updated copyright --- examples/device/claiming_device_pe_only.py | 20 +- examples/device/client_provisioning.py | 22 +- examples/device/client_rpc_request.py | 20 +- examples/device/firmware_update.py | 20 +- examples/device/hardware_specs_sender.py | 20 +- examples/device/load.py | 176 ++++++++++++++++ examples/device/operational_example.py | 193 ++++++++++++----- examples/device/request_attributes.py | 20 +- examples/device/send_telemetry_and_attr.py | 20 +- 
examples/device/send_telemetry_pack.py | 20 +- examples/device/subscription_to_attrs.py | 20 +- examples/device/tls_connect.py | 20 +- examples/gateway/claiming_device_pe_only.py | 20 +- examples/gateway/connect_disconnect_device.py | 20 +- examples/gateway/request_attributes.py | 21 +- examples/gateway/respond_to_rpc.py | 20 +- .../gateway/send_telemetry_and_attributes.py | 20 +- examples/gateway/subscribe_to_attributes.py | 20 +- examples/gateway/tls_connect.py | 20 +- sdk_utils.py | 20 +- setup.py | 23 +- tb_device_http.py | 22 +- tb_device_mqtt.py | 32 ++- tb_gateway_mqtt.py | 33 ++- tb_mqtt_client/__init__.py | 22 +- tb_mqtt_client/common/__init__.py | 26 +-- tb_mqtt_client/common/config_loader.py | 27 ++- tb_mqtt_client/common/exceptions.py | 25 ++- tb_mqtt_client/common/gmqtt_patch.py | 68 +++--- tb_mqtt_client/common/logging_utils.py | 43 ++-- tb_mqtt_client/common/provision_client.py | 26 +-- tb_mqtt_client/common/rate_limit/__init__.py | 26 +-- .../rate_limit/backpressure_controller.py | 49 +++-- .../common/rate_limit/rate_limit.py | 56 +++-- tb_mqtt_client/common/request_id_generator.py | 53 +++-- tb_mqtt_client/constants/__init__.py | 26 +-- tb_mqtt_client/constants/mqtt_topics.py | 31 ++- tb_mqtt_client/constants/service_keys.py | 26 ++- tb_mqtt_client/constants/service_messages.py | 25 ++- tb_mqtt_client/entities/__init__.py | 22 +- tb_mqtt_client/entities/data/__init__.py | 26 +-- .../entities/data/attribute_entry.py | 25 ++- .../entities/data/attribute_request.py | 76 +++++-- .../entities/data/attribute_response.py | 14 -- .../entities/data/attribute_update.py | 27 ++- tb_mqtt_client/entities/data/data_entry.py | 26 ++- .../entities/data/device_uplink_message.py | 45 ++-- .../data/requested_attribute_response.py | 93 ++++++++ tb_mqtt_client/entities/data/rpc_request.py | 30 ++- tb_mqtt_client/entities/data/rpc_response.py | 28 ++- .../entities/data/timeseries_entry.py | 39 +--- tb_mqtt_client/entities/gateway/__init__.py | 26 +-- .../entities/gateway/device_session_state.py | 26 +-- .../entities/gateway/rpc_context.py | 26 +-- .../entities/gateway/virtual_device.py | 26 +-- tb_mqtt_client/entities/publish_result.py | 33 +++ tb_mqtt_client/service/__init__.py | 22 +- tb_mqtt_client/service/base_client.py | 35 ++-- tb_mqtt_client/service/device/__init__.py | 26 +-- .../device/attribute_updates_handler.py | 52 ----- tb_mqtt_client/service/device/client.py | 137 +++++++----- .../service/device/handlers/__init__.py | 14 ++ .../handlers/attribute_updates_handler.py | 63 ++++++ .../requested_attributes_response_handler.py | 97 +++++++++ .../device/handlers/rpc_requests_handler.py | 80 +++++++ .../device/handlers/rpc_response_handler.py | 87 ++++++++ .../service/device/rpc_requests_handler.py | 69 ------ tb_mqtt_client/service/event_dispatcher.py | 26 +-- tb_mqtt_client/service/gateway/__init__.py | 26 +-- tb_mqtt_client/service/gateway/client.py | 32 ++- .../service/gateway/device_sesion.py | 26 +-- .../gateway_attribute_updates_handler.py | 26 +-- .../gateway/gateway_rpc_requests_handler.py | 26 +-- .../service/gateway/multiplex_publisher.py | 25 ++- .../service/gateway/subdevice_manager.py | 26 +-- tb_mqtt_client/service/message_dispatcher.py | 198 ++++++++++++++---- tb_mqtt_client/service/message_queue.py | 129 ++++++------ tb_mqtt_client/service/message_splitter.py | 39 ++-- tb_mqtt_client/service/mqtt_manager.py | 162 +++++++++----- .../service/rpc_response_handler.py | 72 ------- tb_mqtt_client/tb_device_mqtt.py | 22 +- tests/__init__.py | 21 +- tests/constants/__init__.py | 
26 +-- tests/constants/test_mqtt_topics.py | 26 ++- tests/service/__init__.py | 26 +-- tests/service/device/__init__.py | 26 +-- .../device/test_device_client_rate_limits.py | 25 ++- tests/service/test_json_message_dispatcher.py | 36 ++-- tests/service/test_message_splitter.py | 26 ++- tests/service/test_mqtt_manager.py | 26 ++- tests/tb_device_mqtt_client_tests.py | 20 +- tests/tb_gateway_mqtt_client_tests.py | 20 +- utils.py | 20 +- 93 files changed, 2290 insertions(+), 1484 deletions(-) create mode 100644 examples/device/load.py delete mode 100644 tb_mqtt_client/entities/data/attribute_response.py create mode 100644 tb_mqtt_client/entities/data/requested_attribute_response.py create mode 100644 tb_mqtt_client/entities/publish_result.py delete mode 100644 tb_mqtt_client/service/device/attribute_updates_handler.py create mode 100644 tb_mqtt_client/service/device/handlers/__init__.py create mode 100644 tb_mqtt_client/service/device/handlers/attribute_updates_handler.py create mode 100644 tb_mqtt_client/service/device/handlers/requested_attributes_response_handler.py create mode 100644 tb_mqtt_client/service/device/handlers/rpc_requests_handler.py create mode 100644 tb_mqtt_client/service/device/handlers/rpc_response_handler.py delete mode 100644 tb_mqtt_client/service/device/rpc_requests_handler.py delete mode 100644 tb_mqtt_client/service/rpc_response_handler.py diff --git a/examples/device/claiming_device_pe_only.py b/examples/device/claiming_device_pe_only.py index f7bf8e6..3687aba 100644 --- a/examples/device/claiming_device_pe_only.py +++ b/examples/device/claiming_device_pe_only.py @@ -1,16 +1,16 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging diff --git a/examples/device/client_provisioning.py b/examples/device/client_provisioning.py index 20cc3ab..cf1886b 100644 --- a/examples/device/client_provisioning.py +++ b/examples/device/client_provisioning.py @@ -1,17 +1,17 @@ -# Copyright 2025. ThingsBoard + +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging from tb_device_mqtt import TBDeviceMqttClient, TBPublishInfo diff --git a/examples/device/client_rpc_request.py b/examples/device/client_rpc_request.py index fb7f204..a54b589 100644 --- a/examples/device/client_rpc_request.py +++ b/examples/device/client_rpc_request.py @@ -1,16 +1,16 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import time import logging diff --git a/examples/device/firmware_update.py b/examples/device/firmware_update.py index 0c5b445..4a84ce3 100644 --- a/examples/device/firmware_update.py +++ b/examples/device/firmware_update.py @@ -1,16 +1,16 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. import time import logging diff --git a/examples/device/hardware_specs_sender.py b/examples/device/hardware_specs_sender.py index 48f0291..2a753a6 100644 --- a/examples/device/hardware_specs_sender.py +++ b/examples/device/hardware_specs_sender.py @@ -1,16 +1,16 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import time import logging diff --git a/examples/device/load.py b/examples/device/load.py new file mode 100644 index 0000000..3e1cf51 --- /dev/null +++ b/examples/device/load.py @@ -0,0 +1,176 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
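+
+# Load-test example: publishes telemetry in batches of BATCH_SIZE TimeseriesEntry objects,
+# tracks the returned delivery futures, and reports acknowledged datapoints per second on shutdown.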
+ +import asyncio +import logging +import signal +import time +from datetime import datetime, UTC +from random import randint + +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.service.device.client import DeviceClient +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger + +# --- Logging setup --- +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.INFO) +logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) + +# --- Constants --- +BATCH_SIZE = 1000 +YIELD_DELAY = 0.001 +MAX_PENDING_BATCHES = 100 +FUTURE_TIMEOUT = 1.0 + + +async def attribute_update_callback(update: AttributeUpdate): + logger.info("Received attribute update: %s", update.as_dict()) + + +async def rpc_request_callback(request: RPCRequest): + logger.info("Received RPC request: %s", request.to_dict()) + return RPCResponse(request_id=request.request_id, result={"status": "ok"}) + + + +async def main(): + stop_event = asyncio.Event() + + def _shutdown_handler(): + stop_event.set() + + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, _shutdown_handler) + except NotImplementedError: + signal.signal(sig, lambda *_: _shutdown_handler()) + + config = DeviceConfig() + config.host = "localhost" + config.access_token = "YOUR_ACCESS_TOKEN" + + client = DeviceClient(config) + client.set_attribute_update_callback(attribute_update_callback) + client.set_rpc_request_callback(rpc_request_callback) + + await client.connect() + logger.info("Connected to ThingsBoard.") + + sent_batches = 0 + delivered_batches = 0 + delivered_datapoints = 0 + pending_futures = [] + + delivery_start_ts = None # Start time of first successful delivery + delivery_end_ts = None # End time of last successful delivery + + try: + while not stop_event.is_set(): + ts_now = int(datetime.now(UTC).timestamp() * 1000) + entries = [ + TimeseriesEntry("temperature", randint(20, 40), ts=ts_now - i) + for i in range(BATCH_SIZE) + ] + + try: + future = await client.send_telemetry(entries) + if future: + pending_futures.append((future, BATCH_SIZE)) + sent_batches += 1 + else: + logger.warning("Telemetry batch dropped or not acknowledged.") + except Exception as e: + logger.warning("Failed to publish telemetry batch: %s", e) + + if len(pending_futures) >= MAX_PENDING_BATCHES: + done, _ = await asyncio.wait( + [f for f, _ in pending_futures], timeout=FUTURE_TIMEOUT + ) + + remaining = [] + for fut, batch_size in pending_futures: + if fut in done: + try: + result = fut.result() + if result is True: + delivered_batches += 1 + delivered_datapoints += batch_size + now = time.perf_counter() + delivery_start_ts = delivery_start_ts or now + delivery_end_ts = now + except asyncio.CancelledError: + logger.exception("Future was cancelled: %r, id: %r", fut, id(fut)) + logger.warning("Delivery future was cancelled: %r", fut) + except Exception as e: + logger.warning("Delivery future raised: %s", e) + else: + fut.cancel() + logger.warning("Cancelled delivery future after timeout: %r, future id: %r", fut, id(fut)) + # remaining.append((fut, batch_size)) + + pending_futures = [] + + if sent_batches % 10 == 0: + logger.info("Sent %d batches so far...", sent_batches) 
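+        # Brief cooperative yield so the client's internal queue and dispatcher tasks get CPU time between batches.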
+ + await asyncio.sleep(YIELD_DELAY) + + finally: + logger.info("Waiting for remaining telemetry batches to be acknowledged...") + done, _ = await asyncio.wait( + [f for f, _ in pending_futures], timeout=.01 + ) + + for fut, batch_size in pending_futures: + if fut in done: + try: + result = fut.result() + if result is True: + delivered_batches += 1 + delivered_datapoints += batch_size + now = time.perf_counter() + delivery_start_ts = delivery_start_ts or now + delivery_end_ts = now + except asyncio.CancelledError: + # logger.warning("Final delivery future was cancelled: %r", fut) + pass + except Exception as e: + logger.warning("Final delivery failed: %s", e) + else: + fut.cancel() + # logger.warning("Final delivery future timed out and was cancelled: %r", fut) + + await client.disconnect() + logger.info("Disconnected cleanly.") + + if delivery_start_ts is not None and delivery_end_ts is not None: + delivery_duration = delivery_end_ts - delivery_start_ts + logger.info("Delivered %d batches / %d datapoints in %.6f seconds (%.0f datapoints/sec)", + delivered_batches, delivered_datapoints, delivery_duration, + delivered_datapoints / delivery_duration if delivery_duration > 0 else 0) + else: + logger.warning("No successful delivery occurred.") + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("Interrupted by user.") diff --git a/examples/device/operational_example.py b/examples/device/operational_example.py index d9d8330..c2f08c5 100644 --- a/examples/device/operational_example.py +++ b/examples/device/operational_example.py @@ -1,17 +1,31 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import asyncio import logging import signal from datetime import datetime, UTC -from random import randint, uniform from tb_mqtt_client.common.config_loader import DeviceConfig -from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry -from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.service.device.client import DeviceClient -from tb_mqtt_client.common.logging_utils import configure_logging, get_logger configure_logging() logger = get_logger(__name__) @@ -43,21 +57,32 @@ async def rpc_request_callback(request: RPCRequest): return response +async def attribute_request_callback(requested_attributes_response: RequestedAttributeResponse): + """ + Callback function to handle requested attributes. 
+ :param requested_attributes_response: The requested attribute response object. + """ + logger.info("Received requested attributes response: %s", requested_attributes_response.as_dict()) + + async def main(): stop_event = asyncio.Event() def _shutdown_handler(): stop_event.set() + client.stop() loop = asyncio.get_running_loop() for sig in (signal.SIGINT, signal.SIGTERM): try: - loop.add_signal_handler(sig, _shutdown_handler) + loop.add_signal_handler(sig, _shutdown_handler) # noqa except NotImplementedError: # Windows compatibility fallback - signal.signal(sig, lambda *_: _shutdown_handler()) + signal.signal(sig, lambda *_: _shutdown_handler()) # noqa config = DeviceConfig() + # config.host = "192.168.1.202" + # config.access_token = "ypbn08v8f4klg6oah3r6" config.host = "localhost" config.access_token = "YOUR_ACCESS_TOKEN" @@ -68,68 +93,126 @@ def _shutdown_handler(): logger.info("Connected to ThingsBoard.") - while not stop_event.is_set(): - # --- Attributes --- - - # 1. Raw dict - raw_dict = { - "firmwareVersion": "1.0.4", - "hardwareModel": "TB-SDK-Device" - } - await client.send_attributes(raw_dict) - - logger.info(f"Raw attributes sent: {raw_dict}") + attribute_request = AttributeRequest(["uno"], ["client"]) - # 2. Single AttributeEntry - single_entry = AttributeEntry("mode", "normal") - await client.send_attributes(single_entry) - - logger.info("Single attribute sent: %s", single_entry) - - # 3. List of AttributeEntry - attr_entries = [ - AttributeEntry("maxTemperature", 85), - AttributeEntry("calibrated", True) - ] - await client.send_attributes(attr_entries) - - # --- Telemetry --- - - # 1. Raw dict - raw_dict = { - "temperature": round(uniform(20.0, 30.0), 2), - "humidity": 60 - } - await client.send_telemetry(raw_dict) - - logger.info(f"Raw telemetry sent: {raw_dict}") - - # 2. Single TelemetryEntry (with ts) - single_entry = TimeseriesEntry("batteryLevel", randint(0, 100)) - await client.send_telemetry(single_entry) - - logger.info("Single telemetry sent: %s", single_entry) + while not stop_event.is_set(): + # # --- Attributes --- + # + # # 1. Raw dict + # raw_dict = { + # "firmwareVersion": "1.0.4", + # "hardwareModel": "TB-SDK-Device" + # } + # logger.info("Sending attributes...") + # delivery_future = await client.send_attributes(raw_dict) + # if delivery_future: + # logger.info("Awaiting delivery future for raw attributes...") + # result = await delivery_future + # # logger.info("Raw attributes sent: %s, delivery result: %s", raw_dict, result) + # else: + # logger.warning("Delivery future is None, raw attributes may not be sent.") + # + # # logger.info(f"Raw attributes sent: {raw_dict}") + # + # # 2. Single AttributeEntry + # single_entry = AttributeEntry("mode", "normal") + # logger.info("Sending single attribute: %s", single_entry) + # delivery_future = await client.send_attributes(single_entry) + # if delivery_future: + # logger.info("Awaiting delivery future for single attribute...") + # result = await delivery_future + # # logger.info("Single attribute sent: %s, delivery result: %s", single_entry, result) + # else: + # logger.warning("Delivery future is None, single attribute may not be sent.") + # + # # logger.info("Single attribute sent: %s", single_entry) + # + # # 3. 
List of AttributeEntry + # attr_entries = [ + # AttributeEntry("maxTemperature", 85), + # AttributeEntry("calibrated", True) + # ] + # logger.info("Sending list of attributes: %s", attr_entries) + # delivery_future = await client.send_attributes(attr_entries) + # if delivery_future: + # logger.info("Awaiting delivery future for list of attributes...") + # result = await delivery_future + # # logger.info("List of attributes sent: %s, delivery result: %s", attr_entries, result) + # else: + # logger.warning("Delivery future is None, list of attributes may not be sent.") + # + # # --- Telemetry --- + # + # # 1. Raw dict + # raw_dict = { + # "temperature": round(uniform(20.0, 30.0), 2), + # "humidity": 60 + # } + # logger.info("Sending raw telemetry...") + # delivery_future = await client.send_telemetry(raw_dict) + # if delivery_future: + # logger.info("Awaiting delivery future for raw telemetry...") + # result = await delivery_future + # # logger.info("Raw telemetry sent: %s, delivery result: %s", raw_dict, result) + # else: + # logger.warning("Delivery future is None, raw telemetry may not be sent.") + # + # # logger.info(f"Raw telemetry sent: {raw_dict}") + # + # # 2. Single TelemetryEntry (with ts) + # single_entry = TimeseriesEntry("batteryLevel", randint(0, 100)) + # logger.info("Sending single telemetry: %s", single_entry) + # delivery_future = await client.send_telemetry(single_entry) + # if delivery_future: + # logger.info("Awaiting delivery future for single telemetry...") + # result = await delivery_future + # # logger.info("Single telemetry sent: %s, delivery result: %s", single_entry, result) + # else: + # logger.warning("Delivery future is None, single telemetry may not be sent.") + # + # # logger.info("Single telemetry sent: %s", single_entry) # 3. 
List of TelemetryEntry with mixed timestamps - ts_now = int(datetime.now(UTC).timestamp() * 1000) + telemetry_entries = [] for i in range(100): - telemetry_entries.append(TimeseriesEntry("temperature", i, ts=ts_now-i)) - await client.send_telemetry(telemetry_entries) - logger.info("List of telemetry sent: %s, it took %r milliseconds", len(telemetry_entries), - int(datetime.now(UTC).timestamp() * 1000) - ts_now) + telemetry_entries.append(TimeseriesEntry("temperature", i, ts=int(datetime.now(UTC).timestamp() * 1000)-i)) + ts_now = int(datetime.now(UTC).timestamp() * 1000) + logger.info("Sending list of telemetry entries with mixed timestamps...") + delivery_future = await client.send_telemetry(telemetry_entries) + if delivery_future: + logger.info("Awaiting delivery future for list of telemetry...") + try: + result = await asyncio.wait_for(delivery_future, timeout=5) + except asyncio.TimeoutError: + logger.warning("Delivery future timed out after 5 seconds.") + result = False + except Exception as e: + logger.error("Error while awaiting delivery future: %s", e) + result = False + logger.info("List of telemetry sent: %s, it took %r milliseconds", len(telemetry_entries), + int(datetime.now(UTC).timestamp() * 1000) - ts_now) + logger.info("Delivery result: %s", result) + else: + logger.warning("Delivery future is None, list of telemetry may not be sent.") + + # logger.info("Requesting attributes...") + + # await client.send_attribute_request(attribute_request, attribute_request_callback) try: - await asyncio.wait_for(stop_event.wait(), timeout=2) + logger.info("Waiting for 1 seconds before next iteration...") + await asyncio.wait_for(stop_event.wait(), timeout=1) except asyncio.TimeoutError: - pass + logger.info("Going to next iteration...") - await client.disconnect() logger.info("Disconnected cleanly.") if __name__ == "__main__": try: - asyncio.run(main()) + loop = asyncio.get_event_loop() + loop.set_debug(True) # Enable debug mode for asyncio + loop.run_until_complete(main()) except KeyboardInterrupt: print("Interrupted by user.") diff --git a/examples/device/request_attributes.py b/examples/device/request_attributes.py index e4529d5..00c53bf 100644 --- a/examples/device/request_attributes.py +++ b/examples/device/request_attributes.py @@ -1,16 +1,16 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
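The revised operational example registers an attribute_request_callback and builds AttributeRequest(["uno"], ["client"]), but the request call itself is left commented out. Below is a minimal sketch of how those pieces appear to fit together, using only names from this patch; the "uno" and "client" keys are placeholders, and the patch does not show which of the two key lists is treated as client-side versus shared.

from tb_mqtt_client.entities.data.attribute_request import AttributeRequest
from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse


async def on_requested_attributes(response: RequestedAttributeResponse):
    # The response object exposes as_dict(), mirroring the callback defined above.
    print("Requested attributes:", response.as_dict())


async def request_attributes(client):
    # Argument order follows the example; the meaning of each list is assumed.
    request = AttributeRequest(["uno"], ["client"])
    await client.send_attribute_request(request, on_requested_attributes)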
import logging import time diff --git a/examples/device/send_telemetry_and_attr.py b/examples/device/send_telemetry_and_attr.py index 1f734db..aa4a320 100644 --- a/examples/device/send_telemetry_and_attr.py +++ b/examples/device/send_telemetry_and_attr.py @@ -1,16 +1,16 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging from tb_device_mqtt import TBDeviceMqttClient, TBPublishInfo diff --git a/examples/device/send_telemetry_pack.py b/examples/device/send_telemetry_pack.py index 0a426ee..571e5ec 100644 --- a/examples/device/send_telemetry_pack.py +++ b/examples/device/send_telemetry_pack.py @@ -1,16 +1,16 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging from tb_device_mqtt import TBDeviceMqttClient, TBPublishInfo diff --git a/examples/device/subscription_to_attrs.py b/examples/device/subscription_to_attrs.py index 5e5c93f..604e6c4 100644 --- a/examples/device/subscription_to_attrs.py +++ b/examples/device/subscription_to_attrs.py @@ -1,16 +1,16 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging import time diff --git a/examples/device/tls_connect.py b/examples/device/tls_connect.py index 606f9aa..5dbd36d 100644 --- a/examples/device/tls_connect.py +++ b/examples/device/tls_connect.py @@ -1,16 +1,16 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging from tb_device_mqtt import TBDeviceMqttClient diff --git a/examples/gateway/claiming_device_pe_only.py b/examples/gateway/claiming_device_pe_only.py index 4fdd8a5..3a6bfb0 100644 --- a/examples/gateway/claiming_device_pe_only.py +++ b/examples/gateway/claiming_device_pe_only.py @@ -1,16 +1,16 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
+# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging diff --git a/examples/gateway/connect_disconnect_device.py b/examples/gateway/connect_disconnect_device.py index c6a45c9..c32068c 100644 --- a/examples/gateway/connect_disconnect_device.py +++ b/examples/gateway/connect_disconnect_device.py @@ -1,16 +1,16 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging from tb_gateway_mqtt import TBGatewayMqttClient diff --git a/examples/gateway/request_attributes.py b/examples/gateway/request_attributes.py index 765fa54..addfb7c 100644 --- a/examples/gateway/request_attributes.py +++ b/examples/gateway/request_attributes.py @@ -1,16 +1,17 @@ -# Copyright 2025. ThingsBoard + +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging import time diff --git a/examples/gateway/respond_to_rpc.py b/examples/gateway/respond_to_rpc.py index 2d15096..deca40d 100644 --- a/examples/gateway/respond_to_rpc.py +++ b/examples/gateway/respond_to_rpc.py @@ -1,16 +1,16 @@ -# Copyright 2025. 
ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging.handlers import time diff --git a/examples/gateway/send_telemetry_and_attributes.py b/examples/gateway/send_telemetry_and_attributes.py index 279255c..d322c32 100644 --- a/examples/gateway/send_telemetry_and_attributes.py +++ b/examples/gateway/send_telemetry_and_attributes.py @@ -1,16 +1,16 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import time import logging diff --git a/examples/gateway/subscribe_to_attributes.py b/examples/gateway/subscribe_to_attributes.py index 463a361..7b00576 100644 --- a/examples/gateway/subscribe_to_attributes.py +++ b/examples/gateway/subscribe_to_attributes.py @@ -1,16 +1,16 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging.handlers import time diff --git a/examples/gateway/tls_connect.py b/examples/gateway/tls_connect.py index 617472f..4315508 100644 --- a/examples/gateway/tls_connect.py +++ b/examples/gateway/tls_connect.py @@ -1,16 +1,16 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging from tb_gateway_mqtt import TBGatewayMqttClient diff --git a/sdk_utils.py b/sdk_utils.py index 07e9a16..528104c 100644 --- a/sdk_utils.py +++ b/sdk_utils.py @@ -1,16 +1,16 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from random import randint from zlib import crc32 diff --git a/setup.py b/setup.py index 043e216..1118092 100644 --- a/setup.py +++ b/setup.py @@ -1,17 +1,16 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# http://www.apache.org/licenses/LICENSE-2.0 # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from os import path from setuptools import setup @@ -35,4 +34,4 @@ long_description_content_type="text/markdown", python_requires=">=3.12", packages=["."], - install_requires=['gmqtt', 'requests>=2.31.0', 'orjson']) + install_requires=['gmqtt', 'orjson']) diff --git a/tb_device_http.py b/tb_device_http.py index c4a1e5d..140ca53 100644 --- a/tb_device_http.py +++ b/tb_device_http.py @@ -1,18 +1,18 @@ -# Copyright 2025. ThingsBoard +"""ThingsBoard HTTP API device module.""" +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -"""ThingsBoard HTTP API device module.""" import threading import logging import queue diff --git a/tb_device_mqtt.py b/tb_device_mqtt.py index cb4ff28..29990db 100644 --- a/tb_device_mqtt.py +++ b/tb_device_mqtt.py @@ -1,18 +1,22 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is maintained for backward compatibility with version 1 of the SDK. +# It is recommended to use the new SDK structure in tb_mqtt_client for new projects. import logging +import warnings from copy import deepcopy from inspect import signature from time import sleep @@ -23,6 +27,14 @@ from utils import install_package from os import environ +# Show deprecation warning +warnings.warn( + "The tb_device_mqtt module is deprecated and will be removed in a future version. " + "Please use tb_mqtt_client.service.device.client.DeviceClient instead.", + DeprecationWarning, + stacklevel=2 +) + def check_tb_paho_mqtt_installed(): try: dists = metadata.distributions() diff --git a/tb_gateway_mqtt.py b/tb_gateway_mqtt.py index d4ddf6e..03a729c 100644 --- a/tb_gateway_mqtt.py +++ b/tb_gateway_mqtt.py @@ -1,19 +1,22 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# http://www.apache.org/licenses/LICENSE-2.0 # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is maintained for backward compatibility with version 1 of the SDK. +# It is recommended to use the new SDK structure in tb_mqtt_client for new projects. import logging +import warnings try: from time import monotonic as time @@ -22,6 +25,14 @@ from tb_device_mqtt import TBDeviceMqttClient, RateLimit, TBSendMethod +# Show deprecation warning +warnings.warn( + "The tb_gateway_mqtt module is deprecated and will be removed in a future version. 
" + "Please use tb_mqtt_client.service.gateway.client.GatewayClient instead.", + DeprecationWarning, + stacklevel=2 +) + GATEWAY_ATTRIBUTES_TOPIC = "v1/gateway/attributes" GATEWAY_TELEMETRY_TOPIC = "v1/gateway/telemetry" GATEWAY_DISCONNECT_TOPIC = "v1/gateway/disconnect" diff --git a/tb_mqtt_client/__init__.py b/tb_mqtt_client/__init__.py index 6d6dc0d..fa669aa 100644 --- a/tb_mqtt_client/__init__.py +++ b/tb_mqtt_client/__init__.py @@ -1,14 +1,14 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# http://www.apache.org/licenses/LICENSE-2.0 # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tb_mqtt_client/common/__init__.py b/tb_mqtt_client/common/__init__.py index 59e3bfa..fa669aa 100644 --- a/tb_mqtt_client/common/__init__.py +++ b/tb_mqtt_client/common/__init__.py @@ -1,14 +1,14 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tb_mqtt_client/common/config_loader.py b/tb_mqtt_client/common/config_loader.py index 9be28f9..ca3b1a2 100644 --- a/tb_mqtt_client/common/config_loader.py +++ b/tb_mqtt_client/common/config_loader.py @@ -1,20 +1,19 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os -from typing import Optional, Dict, Any +from typing import Optional class DeviceConfig: diff --git a/tb_mqtt_client/common/exceptions.py b/tb_mqtt_client/common/exceptions.py index 96bbea1..a4a3517 100644 --- a/tb_mqtt_client/common/exceptions.py +++ b/tb_mqtt_client/common/exceptions.py @@ -1,17 +1,16 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import asyncio import logging diff --git a/tb_mqtt_client/common/gmqtt_patch.py b/tb_mqtt_client/common/gmqtt_patch.py index a2ea769..4ccd951 100644 --- a/tb_mqtt_client/common/gmqtt_patch.py +++ b/tb_mqtt_client/common/gmqtt_patch.py @@ -1,31 +1,29 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. - -import struct import asyncio -from types import MethodType -from typing import Callable +import struct from collections import defaultdict +from typing import Callable -from tb_mqtt_client.common.logging_utils import get_logger +from gmqtt.mqtt.constants import MQTTCommands +from gmqtt.mqtt.handler import MqttPackageHandler from gmqtt.mqtt.property import Property -from gmqtt.mqtt.utils import unpack_variable_byte_integer from gmqtt.mqtt.protocol import BaseMQTTProtocol, MQTTProtocol -from gmqtt.mqtt.handler import MqttPackageHandler -from gmqtt.mqtt.constants import MQTTCommands +from gmqtt.mqtt.utils import unpack_variable_byte_integer + +from tb_mqtt_client.common.logging_utils import get_logger logger = get_logger(__name__) @@ -86,7 +84,7 @@ def patch_mqtt_handler_disconnect(): """ try: # Store the original method - original_handle_disconnect = MqttPackageHandler._handle_disconnect_packet + original_handle_disconnect = MqttPackageHandler._handle_disconnect_packet # noqa # Define the patched method def patched_handle_disconnect_packet(self, cmd, packet): @@ -100,8 +98,8 @@ def patched_handle_disconnect_packet(self, cmd, packet): if packet and len(packet) > 1: try: properties, _ = self._parse_properties(packet[1:]) - except Exception as e: - logger.warning("Failed to parse properties from disconnect packet: %s", e) + except Exception as exc: + logger.warning("Failed to parse properties from disconnect packet: %s", exc) reason_desc = DISCONNECT_REASON_CODES.get(reason_code, "Unknown reason") logger.debug("Server initiated disconnect with reason code: %s (%s)", reason_code, reason_desc) @@ -119,6 +117,7 @@ def patched_handle_disconnect_packet(self, cmd, packet): # Set a flag on the connection object to indicate that on_disconnect has been called self._connection._on_disconnect_called = True + original_handle_disconnect(self, cmd, packet) # Apply the patch MqttPackageHandler._handle_disconnect_packet = patched_handle_disconnect_packet @@ -133,7 +132,7 @@ def patch_gmqtt_protocol_connection_lost(): Monkey-patch gmqtt.mqtt.protocol.BaseMQTTProtocol.connection_lost to suppress the default "[CONN CLOSE NORMALLY]" log message, as we handle disconnect logging in our code. - Also patch MQTTProtocol.connection_lost to include the reason code in the DISCONNECT package + Also, patch MQTTProtocol.connection_lost to include the reason code in the DISCONNECT package and pass the exception to the handler. 
""" try: @@ -150,7 +149,7 @@ def patched_mqtt_connection_lost(self, exc): properties = {} if exc: - # Determine reason code based on exception type + # Determine reason code based on an exception type if isinstance(exc, ConnectionRefusedError): reason_code = 135 # Keep Alive timeout elif isinstance(exc, TimeoutError): @@ -166,7 +165,7 @@ def patched_mqtt_connection_lost(self, exc): else: reason_code = 131 # Implementation specific error - # Add exception message to properties if available + # Add an exception message to properties if available if hasattr(exc, 'args') and exc.args: properties['reason_string'] = [str(exc.args[0])] @@ -205,18 +204,19 @@ def patched_call(self, cmd, packet): exc = getattr(self._connection, '_disconnect_exc', None) # Check if on_disconnect has already been called - if not hasattr(self._connection, '_on_disconnect_called') or not self._connection._on_disconnect_called: + if (not hasattr(self._connection, '_on_disconnect_called') + or not self._connection._on_disconnect_called): # noqa # Call on_disconnect with the extracted values self._clear_topics_aliases() future = asyncio.ensure_future(self.reconnect(delay=True)) future.add_done_callback(self._handle_exception_in_future) self.on_disconnect(self, reason_code, properties, exc) - return + return None # For other commands, call the original method return original_call(self, cmd, packet) - except Exception as e: - logger.error('[ERROR HANDLE PKG]', exc_info=e) + except Exception as exception: + logger.error('[ERROR HANDLE PKG]', exc_info=exception) return None MqttPackageHandler.__call__ = patched_call @@ -244,14 +244,13 @@ def patch_gmqtt_puback(client, on_puback_with_reason_and_properties: Callable[[i def _parse_properties(packet: bytes) -> dict: """ - Parse MQTT 5.0 properties from packet. + Parse MQTT 5.0 properties from a packet. """ properties_dict = defaultdict(list) try: - properties_len, packet = unpack_variable_byte_integer(packet) + properties_len, _ = unpack_variable_byte_integer(packet) props = packet[:properties_len] - packet = packet[properties_len:] while props: property_identifier = props[0] @@ -288,4 +287,3 @@ def wrapped_handle_puback(self, cmd, packet): return base_method(self, cmd, packet) MqttPackageHandler._handle_puback_packet = wrapped_handle_puback - # client._handle_puback_packet = MethodType(wrapped_handle_puback, client) diff --git a/tb_mqtt_client/common/logging_utils.py b/tb_mqtt_client/common/logging_utils.py index c5175df..5ff651a 100644 --- a/tb_mqtt_client/common/logging_utils.py +++ b/tb_mqtt_client/common/logging_utils.py @@ -1,18 +1,16 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # - +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging import sys @@ -26,12 +24,15 @@ logging.addLevelName(TRACE_LEVEL, "TRACE") -def trace(self, message, *args, **kwargs): - if self.isEnabledFor(TRACE_LEVEL): - self._log(TRACE_LEVEL, message, args, **kwargs) - +class ExtendedLogger(logging.Logger): + """ + Custom logger class that supports TRACE level logging. + """ + def trace(self, message, *args, **kwargs): + if self.isEnabledFor(TRACE_LEVEL): + self._log(TRACE_LEVEL, message, args, **kwargs) -logging.Logger.trace = trace +logging.setLoggerClass(ExtendedLogger) def configure_logging(level: int = logging.INFO, @@ -52,8 +53,8 @@ def configure_logging(level: int = logging.INFO, ) -def get_logger(name: Optional[str] = None) -> logging.Logger: +def get_logger(name: Optional[str] = None) -> ExtendedLogger: """ Returns a logger instance with the given name. """ - return logging.getLogger(name or __name__) + return logging.getLogger(name or __name__) # noqa diff --git a/tb_mqtt_client/common/provision_client.py b/tb_mqtt_client/common/provision_client.py index 59e3bfa..fa669aa 100644 --- a/tb_mqtt_client/common/provision_client.py +++ b/tb_mqtt_client/common/provision_client.py @@ -1,14 +1,14 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tb_mqtt_client/common/rate_limit/__init__.py b/tb_mqtt_client/common/rate_limit/__init__.py index 59e3bfa..fa669aa 100644 --- a/tb_mqtt_client/common/rate_limit/__init__.py +++ b/tb_mqtt_client/common/rate_limit/__init__.py @@ -1,14 +1,14 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
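logging_utils now installs ExtendedLogger as the logger class, so loggers returned by get_logger() expose a trace() method alongside the standard levels. A minimal sketch follows; the numeric TRACE level is defined in the module and assumed to sit below DEBUG, so the logger level is simply lowered far enough for TRACE records to pass.

from tb_mqtt_client.common.logging_utils import configure_logging, get_logger

configure_logging()
logger = get_logger(__name__)
logger.setLevel(1)  # below every named level, so TRACE records pass the logger's filter
logger.trace("Low-level detail visible only at TRACE")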
-# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tb_mqtt_client/common/rate_limit/backpressure_controller.py b/tb_mqtt_client/common/rate_limit/backpressure_controller.py index 64ba2b8..8436a12 100644 --- a/tb_mqtt_client/common/rate_limit/backpressure_controller.py +++ b/tb_mqtt_client/common/rate_limit/backpressure_controller.py @@ -1,19 +1,18 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
- +from asyncio import Event from datetime import datetime, timedelta, UTC from typing import Optional @@ -23,27 +22,31 @@ class BackpressureController: - def __init__(self): + def __init__(self, main_stop_event: Event): + self.__main_stop_event = main_stop_event self._pause_until: Optional[datetime] = None self._default_pause_duration = timedelta(seconds=10) self._consecutive_quota_exceeded = 0 self._last_quota_exceeded = datetime.now(UTC) - self._max_backoff_seconds = 3600 # 1 hour maximum backoff + self._max_backoff_seconds = 3600 # 1 hour def notify_quota_exceeded(self, delay_seconds: Optional[int] = None): + if self.__main_stop_event.is_set(): + logger.trace("Main stop event is set, not applying backpressure") + return now = datetime.now(UTC) - # If we've had a quota exceeded event in the last 60 seconds, increment the counter + # If we've had a quota-exceeded event in the last 60 seconds, increment the counter if (now - self._last_quota_exceeded).total_seconds() < 60: self._consecutive_quota_exceeded += 1 else: - # Reset counter if it's been more than 60 seconds since the last quota exceeded event + # Reset counter if it's been more than 60 seconds since the last quota exceeded the event self._consecutive_quota_exceeded = 1 self._last_quota_exceeded = now - # Apply exponential backoff based on consecutive quota exceeded events + # Apply exponential backoff based on consecutive quota-exceeded events if delay_seconds is None: - # Start with default duration and apply exponential backoff + # Start with the default duration and apply exponential backoff backoff_factor = min(2 ** (self._consecutive_quota_exceeded - 1), 10) delay_seconds = int(self._default_pause_duration.total_seconds() * backoff_factor) # Cap at max backoff @@ -56,6 +59,9 @@ def notify_quota_exceeded(self, delay_seconds: Optional[int] = None): self._pause_until = now + duration def notify_disconnect(self, delay_seconds: Optional[int] = None): + if self.__main_stop_event.is_set(): + logger.trace("Main stop event is set, not pausing publishing") + return if delay_seconds is None: delay_seconds = int(self._default_pause_duration.total_seconds()) @@ -64,6 +70,9 @@ def notify_disconnect(self, delay_seconds: Optional[int] = None): logger.debug("Pausing publishing for %d seconds due to disconnect", delay_seconds) def should_pause(self) -> bool: + if self.__main_stop_event.is_set(): + logger.trace("Main stop event is set, not checking pause state") + return False if self._pause_until is None: return False diff --git a/tb_mqtt_client/common/rate_limit/rate_limit.py b/tb_mqtt_client/common/rate_limit/rate_limit.py index a8cc4b1..ddddb1e 100644 --- a/tb_mqtt_client/common/rate_limit/rate_limit.py +++ b/tb_mqtt_client/common/rate_limit/rate_limit.py @@ -1,22 +1,20 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # - +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
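BackpressureController now receives the client's main asyncio stop event and applies exponential backoff when quota-exceeded notifications arrive in quick succession. A sketch of the intended call pattern, using only the constructor and methods shown above; publish_once stands in for the caller's own publish coroutine:

import asyncio

from tb_mqtt_client.common.rate_limit.backpressure_controller import BackpressureController


async def publish_loop(publish_once, stop_event: asyncio.Event):
    backpressure = BackpressureController(stop_event)

    while not stop_event.is_set():
        if backpressure.should_pause():
            await asyncio.sleep(1)  # wait out the current pause window
            continue
        try:
            await publish_once()
        except Exception:
            # e.g. the broker reported a quota violation; repeated calls back off exponentially
            backpressure.notify_quota_exceeded()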
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os import logging -from threading import RLock +from asyncio import Lock from time import monotonic logger = logging.getLogger(__name__) @@ -66,7 +64,7 @@ def __init__(self, rate_limit: str, name: str = None, percentage: int = DEFAULT_ self.percentage = percentage self._no_limit = False self._rate_buckets = {} - self._lock = RLock() + self._lock = Lock() self._minimal_timeout = DEFAULT_TIMEOUT self._minimal_limit = float('inf') self.__reached_index = 0 @@ -95,11 +93,11 @@ def _parse_string(self, rate_limit: str): self._no_limit = not bool(self._rate_buckets) - def check_limit_reached(self, amount=1): + async def check_limit_reached(self, amount=1): if self._no_limit: return False - with self._lock: + async with self._lock: result = False for dur, bucket in self._rate_buckets.items(): bucket.refill() @@ -107,15 +105,15 @@ def check_limit_reached(self, amount=1): result = (bucket.capacity, dur) return result - def refill(self): + async def refill(self): """Force refill of all token buckets without consuming any tokens.""" if self._no_limit: return - with self._lock: + async with self._lock: for bucket in self._rate_buckets.values(): bucket.refill() - def try_consume(self, amount=1): + async def try_consume(self, amount=1): """ Try to consume tokens from all buckets. Returns True if all buckets had enough tokens and they were consumed. @@ -124,7 +122,7 @@ def try_consume(self, amount=1): if self._no_limit: return None - with self._lock: + async with self._lock: for bucket in self._rate_buckets.values(): bucket.refill() if bucket.tokens < amount: @@ -135,10 +133,10 @@ def try_consume(self, amount=1): return None - def consume(self, amount=1): + async def consume(self, amount=1): if self._no_limit: return - with self._lock: + async with self._lock: for bucket in self._rate_buckets.values(): bucket.consume(amount) @@ -153,11 +151,11 @@ def minimal_timeout(self): def has_limit(self): return not self._no_limit - def reach_limit(self): + async def reach_limit(self): if self._no_limit: - return + return None - with self._lock: + async with self._lock: durations = sorted(self._rate_buckets.keys()) now = monotonic() @@ -190,8 +188,8 @@ def to_dict(self): } } - def set_limit(self, rate_limit: str, percentage: int = DEFAULT_RATE_LIMIT_PERCENTAGE): - with self._lock: + async def set_limit(self, rate_limit: str, percentage: int = DEFAULT_RATE_LIMIT_PERCENTAGE): + async with self._lock: self._rate_buckets.clear() self._minimal_timeout = DEFAULT_TIMEOUT self._minimal_limit = float('inf') diff --git a/tb_mqtt_client/common/request_id_generator.py b/tb_mqtt_client/common/request_id_generator.py index 9e961f1..a5df480 100644 --- a/tb_mqtt_client/common/request_id_generator.py +++ b/tb_mqtt_client/common/request_id_generator.py @@ -1,18 +1,16 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
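A usage sketch for the now-async RateLimit (not part of the patch). The "count:seconds" pair syntax below is an assumption based on the usual ThingsBoard rate-limit strings ("0:0," meaning no limit, as used elsewhere in this patch); all checks and consumption now have to be awaited because the internal RLock was replaced with an asyncio.Lock.

import asyncio

from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit


async def main():
    # Roughly: 10 messages per second and 60 per minute (scaled by the default percentage).
    limit = RateLimit("10:1,60:60,", name="telemetry")

    for i in range(15):
        reached = await limit.check_limit_reached(1)
        if reached:
            capacity, duration = reached  # the bucket that is exhausted
            print(f"message {i}: limit of {capacity} per {duration}s reached, backing off")
            await asyncio.sleep(1)
        else:
            await limit.consume(1)
            print(f"message {i}: sent")


asyncio.run(main())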
-# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # - +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from asyncio import Lock @@ -42,3 +40,30 @@ def reset(cls): Reset the global request ID counter (usually on disconnect). """ cls._counter = 1 + + +class AttributeRequestIdProducer: + """ + Singleton-style producer of unique attribute request IDs, + safe for async environments and shared across all SDK services. + """ + + _counter: int = 1 + _lock: Lock = Lock() + + @classmethod + async def get_next(cls) -> int: + """ + Atomically increment and return the next attribute request ID. + """ + async with cls._lock: + current = cls._counter + cls._counter += 1 + return current + + @classmethod + def reset(cls): + """ + Reset the global attribute request ID counter (usually on disconnect). + """ + cls._counter = 1 diff --git a/tb_mqtt_client/constants/__init__.py b/tb_mqtt_client/constants/__init__.py index 59e3bfa..fa669aa 100644 --- a/tb_mqtt_client/constants/__init__.py +++ b/tb_mqtt_client/constants/__init__.py @@ -1,14 +1,14 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tb_mqtt_client/constants/mqtt_topics.py b/tb_mqtt_client/constants/mqtt_topics.py index f757263..378d2c0 100644 --- a/tb_mqtt_client/constants/mqtt_topics.py +++ b/tb_mqtt_client/constants/mqtt_topics.py @@ -1,22 +1,20 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
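A short sketch (not in the patch) of the new AttributeRequestIdProducer added above: the shared asyncio.Lock guarantees that concurrent coroutines never receive the same attribute request ID.

import asyncio

from tb_mqtt_client.common.request_id_generator import AttributeRequestIdProducer


async def worker(name: str):
    request_id = await AttributeRequestIdProducer.get_next()
    print(f"{name} got attribute request id {request_id}")


async def main():
    await asyncio.gather(*(worker(f"task-{i}") for i in range(5)))
    AttributeRequestIdProducer.reset()  # typically called on disconnect


asyncio.run(main())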
-# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - WILDCARD = "+" REQUEST_TOPIC_SUFFIX = "/request" RESPONSE_TOPIC_SUFFIX = "/response" +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # V1 Topics for Device API DEVICE_TELEMETRY_TOPIC = "v1/devices/me/telemetry" DEVICE_ATTRIBUTES_TOPIC = "v1/devices/me/attributes" @@ -26,6 +24,7 @@ # Device RPC topics DEVICE_RPC_REQUEST_TOPIC = DEVICE_RPC_TOPIC + REQUEST_TOPIC_SUFFIX + "/" DEVICE_RPC_RESPONSE_TOPIC = DEVICE_RPC_TOPIC + RESPONSE_TOPIC_SUFFIX + "/" +# Device RPC topics for subscription DEVICE_RPC_REQUEST_TOPIC_FOR_SUBSCRIPTION = DEVICE_RPC_TOPIC + REQUEST_TOPIC_SUFFIX + "/" + WILDCARD DEVICE_RPC_RESPONSE_TOPIC_FOR_SUBSCRIPTION = DEVICE_RPC_TOPIC + RESPONSE_TOPIC_SUFFIX + "/" + WILDCARD diff --git a/tb_mqtt_client/constants/service_keys.py b/tb_mqtt_client/constants/service_keys.py index 624bd8e..befebaa 100644 --- a/tb_mqtt_client/constants/service_keys.py +++ b/tb_mqtt_client/constants/service_keys.py @@ -1,18 +1,16 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # - +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. MESSAGES_RATE_LIMIT = "MESSAGES_RATE_LIMIT" TELEMETRY_MESSAGE_RATE_LIMIT = "TELEMETRY_MESSAGE_RATE_LIMIT" diff --git a/tb_mqtt_client/constants/service_messages.py b/tb_mqtt_client/constants/service_messages.py index faaa3ff..daaa25d 100644 --- a/tb_mqtt_client/constants/service_messages.py +++ b/tb_mqtt_client/constants/service_messages.py @@ -1,17 +1,16 @@ -# Copyright 2025. 
ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from orjson import dumps diff --git a/tb_mqtt_client/entities/__init__.py b/tb_mqtt_client/entities/__init__.py index 6d6dc0d..fa669aa 100644 --- a/tb_mqtt_client/entities/__init__.py +++ b/tb_mqtt_client/entities/__init__.py @@ -1,14 +1,14 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# http://www.apache.org/licenses/LICENSE-2.0 # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tb_mqtt_client/entities/data/__init__.py b/tb_mqtt_client/entities/data/__init__.py index 59e3bfa..fa669aa 100644 --- a/tb_mqtt_client/entities/data/__init__.py +++ b/tb_mqtt_client/entities/data/__init__.py @@ -1,14 +1,14 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
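Stepping back to the MQTT topic constants above, an illustration (not part of the patch) of how they compose. The base value of DEVICE_RPC_TOPIC is not visible in this hunk; "v1/devices/me/rpc" is assumed here because the other device topics follow the v1/devices/me/... scheme.

from tb_mqtt_client.constants import mqtt_topics

print(mqtt_topics.DEVICE_TELEMETRY_TOPIC)                     # v1/devices/me/telemetry
print(mqtt_topics.DEVICE_RPC_REQUEST_TOPIC_FOR_SUBSCRIPTION)  # v1/devices/me/rpc/request/+ (assumed base)
print(mqtt_topics.DEVICE_RPC_RESPONSE_TOPIC + "42")           # v1/devices/me/rpc/response/42 (assumed base)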
+# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tb_mqtt_client/entities/data/attribute_entry.py b/tb_mqtt_client/entities/data/attribute_entry.py index cd5eef7..fce1c3b 100644 --- a/tb_mqtt_client/entities/data/attribute_entry.py +++ b/tb_mqtt_client/entities/data/attribute_entry.py @@ -1,17 +1,16 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Any diff --git a/tb_mqtt_client/entities/data/attribute_request.py b/tb_mqtt_client/entities/data/attribute_request.py index 59e3bfa..54254ef 100644 --- a/tb_mqtt_client/entities/data/attribute_request.py +++ b/tb_mqtt_client/entities/data/attribute_request.py @@ -1,14 +1,64 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Union + + +class AttributeRequest: + """ + Represents a request for attributes, including shared and client attributes. + This class is used to encapsulate the details of an attribute request. + If some scope is not needed, it can be set to None. + shared: list + A list of shared attribute keys to the request. If empty - all shared attributes will be requested. + client: list + A list of client attribute keys to the request. If empty - all client attributes will be requested. + """ + + def __init__(self, shared: list, client: list): + self._id: Union[int, None] = None + self.shared_keys = shared + self.client_keys = client + + @property + def id(self) -> Union[int, None]: + """ + Get the unique ID for this attribute request. + :return: Unique identifier for the request or None if not set. + """ + return self._id + + @id.setter + def id(self, value: int): + """ + Set the unique ID for this attribute request. + :param value: Unique identifier to set for the request. + """ + if not isinstance(value, int): + raise ValueError("ID must be an integer.") + self._id = value + + def __repr__(self): + return f"" + + def to_payload_format(self) -> dict: + """ + Convert the attribute request to a payload format suitable for sending over MQTT to the platform. + """ + formatted_request = {} + if self.shared_keys is not None: + formatted_request["sharedKeys"] = ','.join(self.shared_keys) + if self.client_keys is not None: + formatted_request["clientKeys"] = ','.join(self.client_keys) + return formatted_request diff --git a/tb_mqtt_client/entities/data/attribute_response.py b/tb_mqtt_client/entities/data/attribute_response.py deleted file mode 100644 index 59e3bfa..0000000 --- a/tb_mqtt_client/entities/data/attribute_response.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# diff --git a/tb_mqtt_client/entities/data/attribute_update.py b/tb_mqtt_client/entities/data/attribute_update.py index 731c1bc..2ac3deb 100644 --- a/tb_mqtt_client/entities/data/attribute_update.py +++ b/tb_mqtt_client/entities/data/attribute_update.py @@ -1,18 +1,16 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # - +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
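A quick sketch (not part of the patch) of the AttributeRequest payload format introduced above: key lists are serialized as comma-separated strings, and an empty list requests all attributes of that scope.

from tb_mqtt_client.entities.data.attribute_request import AttributeRequest

# Request two specific shared keys and every client-side attribute.
request = AttributeRequest(shared=["targetFwVersion", "targetFwUrl"], client=[])
request.id = 1  # unique ID used to correlate the platform's response

print(request.to_payload_format())
# {'sharedKeys': 'targetFwVersion,targetFwUrl', 'clientKeys': ''}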
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from dataclasses import dataclass from typing import Dict, Any, List @@ -50,7 +48,6 @@ def from_dict(cls, data: Dict[str, Any]) -> 'AttributeUpdate': """ Deserialize dictionary into AttributeUpdate object. :param data: Dictionary of attribute key-value pairs. - :param device: Optional device name (for gateway context). :return: AttributeUpdate instance. """ entries = [AttributeEntry(k, v) for k, v in data.items()] diff --git a/tb_mqtt_client/entities/data/data_entry.py b/tb_mqtt_client/entities/data/data_entry.py index a6920c9..5d0f0ac 100644 --- a/tb_mqtt_client/entities/data/data_entry.py +++ b/tb_mqtt_client/entities/data/data_entry.py @@ -1,18 +1,16 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # - +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Any, Optional from orjson import dumps diff --git a/tb_mqtt_client/entities/data/device_uplink_message.py b/tb_mqtt_client/entities/data/device_uplink_message.py index 7b9472b..e13e596 100644 --- a/tb_mqtt_client/entities/data/device_uplink_message.py +++ b/tb_mqtt_client/entities/data/device_uplink_message.py @@ -1,25 +1,24 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # - +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import asyncio -from typing import List, Optional, Union, OrderedDict, Dict +from typing import List, Optional, Union, OrderedDict from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.publish_result import PublishResult logger = get_logger(__name__) @@ -33,7 +32,7 @@ def __init__(self, attributes: Optional[List[AttributeEntry]] = None, timeseries: Optional[OrderedDict[int, List[TimeseriesEntry]]] = None, _size: Optional[int] = None, - delivery_future: List[Optional[asyncio.Future[bool]]] = None): + delivery_future: List[Optional[asyncio.Future[PublishResult]]] = None): if _size is None: raise ValueError("DeviceUplinkMessage must be created using DeviceUplinkMessageBuilder") @@ -71,7 +70,7 @@ def __init__(self): self._device_profile: Optional[str] = None self._attributes: List[AttributeEntry] = [] self._timeseries: OrderedDict[int, List[TimeseriesEntry]] = OrderedDict() - self._delivery_futures: List[Optional[asyncio.Future[bool]]] = [] + self._delivery_futures: List[Optional[asyncio.Future[PublishResult]]] = [] self.__size = DEFAULT_FIELDS_SIZE def set_device_name(self, device_name: str) -> 'DeviceUplinkMessageBuilder': @@ -115,12 +114,12 @@ def add_telemetry(self, timeseries: Union[TimeseriesEntry, List[TimeseriesEntry] self.__size += timeseries_entry.size return self - def add_delivery_futures(self, future: Union[asyncio.Future[bool], List[asyncio.Future[bool]]]) -> 'DeviceUplinkMessageBuilder': - if not isinstance(future, list): - future = [future] - if future: - logger.exception("Created delivery future: %s", id(future[0])) - self._delivery_futures.extend(future) + def add_delivery_futures(self, futures: Union[asyncio.Future[PublishResult], List[asyncio.Future[PublishResult]]]) -> 'DeviceUplinkMessageBuilder': + if not isinstance(futures, list): + futures = [futures] + if futures: + logger.debug("Created delivery futures: %r", [id(future) for future in futures]) + self._delivery_futures.extend(futures) return self def build(self) -> DeviceUplinkMessage: diff --git a/tb_mqtt_client/entities/data/requested_attribute_response.py b/tb_mqtt_client/entities/data/requested_attribute_response.py new file mode 100644 index 0000000..1ba8857 --- /dev/null +++ b/tb_mqtt_client/entities/data/requested_attribute_response.py @@ -0,0 +1,93 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
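A sketch (not part of the patch) of the reworked delivery futures on DeviceUplinkMessageBuilder: they now resolve to a PublishResult rather than a bare bool, so callers can inspect the MQTT outcome of an uplink message. The event-loop wiring is illustrative; only the builder calls shown here appear in this hunk.

import asyncio

from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder


async def main():
    loop = asyncio.get_running_loop()
    delivery_future = loop.create_future()  # resolved with a PublishResult once the broker acknowledges

    message = (DeviceUplinkMessageBuilder()
               .set_device_name("sensor-001")
               .add_delivery_futures(delivery_future)
               .build())

    # `message` would now be handed to the message queue for publishing;
    # awaiting delivery_future yields the PublishResult for that publish.
    return message


asyncio.run(main())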
+ +from dataclasses import dataclass +from typing import Dict, Any, List + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry + + +@dataclass(slots=True, frozen=True) +class RequestedAttributeResponse: + + request_id: int + shared: List[AttributeEntry] + client: List[AttributeEntry] + + def __repr__(self): + return f"" + + def __getitem__(self, item): + """ + Allows access to values using dictionary-like syntax. + """ + for entry in self.shared: + if entry.key == item: + return entry.value + for entry in self.client: + if entry.key == item: + return entry.value + raise KeyError(f"Key '{item}' not found in shared or client attributes.") + + def shared_keys(self): + return [entry.key for entry in self.shared] + + def client_keys(self): + return [entry.key for entry in self.client] + + def get_shared(self, key: str, default=None): + """ + Get the value of a shared attribute by key. + :param key: The key of the shared attribute. + :param default: Default value if the key is not found. + :return: Value of the shared attribute or default. + """ + for entry in self.shared: + if entry.key == key: + return entry.value + return default + + def get_client(self, key: str, default=None): + """ + Get the value of a client attribute by key. + :param key: The key of the client attribute. + :param default: Default value if the key is not found. + :return: Value of the client attribute or default. + """ + for entry in self.client: + if entry.key == key: + return entry.value + return default + + def as_dict(self) -> Dict[str, Any]: + """ + Convert the AttributeResponse to a dictionary format. + :return: Dictionary representation of the response. + """ + return { + 'shared': [entry.as_dict() for entry in self.shared], + 'client': [entry.as_dict() for entry in self.client] + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'RequestedAttributeResponse': + """ + Deserialize dictionary into AttributeResponse object. + :param data: Dictionary containing 'shared' and 'client' attributes. + :return: AttributeResponse instance. + """ + shared = [AttributeEntry(k, v) for k, v in data.get('shared', {}).items()] + client = [AttributeEntry(k, v) for k, v in data.get('client', {}).items()] + request_id = data.get('request_id', -1) # Default to -1 if not provided + return cls(shared=shared, client=client, request_id=request_id) diff --git a/tb_mqtt_client/entities/data/rpc_request.py b/tb_mqtt_client/entities/data/rpc_request.py index d09ad1b..18ce603 100644 --- a/tb_mqtt_client/entities/data/rpc_request.py +++ b/tb_mqtt_client/entities/data/rpc_request.py @@ -1,28 +1,26 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # - +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
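A sketch (not part of the patch) of working with the RequestedAttributeResponse added above: it supports dict-style lookup across both scopes as well as scope-specific getters with defaults.

from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse

response = RequestedAttributeResponse.from_dict({
    "request_id": 7,
    "shared": {"targetFwVersion": "1.2.0"},
    "client": {"serialNumber": "SN-42"},
})

print(response["targetFwVersion"])               # lookup across both scopes -> '1.2.0'
print(response.get_client("serialNumber"))       # 'SN-42'
print(response.get_shared("missingKey", "n/a"))  # default when the key is absent -> 'n/a'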
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from dataclasses import dataclass -from typing import Union, Optional, Dict, Any, List +from typing import Union, Optional, Dict, Any @dataclass(slots=True, frozen=True) class RPCRequest: request_id: Union[int, str] method: str - params: Optional[Union[Dict[str, Any], List[Any]]] = None + params: Optional[Any] = None def to_dict(self) -> Dict[str, Any]: result = { diff --git a/tb_mqtt_client/entities/data/rpc_response.py b/tb_mqtt_client/entities/data/rpc_response.py index 12b0faf..c4b0c20 100644 --- a/tb_mqtt_client/entities/data/rpc_response.py +++ b/tb_mqtt_client/entities/data/rpc_response.py @@ -1,18 +1,16 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # - +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from dataclasses import dataclass from typing import Union, Optional, Dict, Any @@ -32,7 +30,7 @@ class RPCResponse: result: Optional[Any] = None error: Optional[Union[str, Dict[str, Any]]] = None - def to_dict(self) -> Dict[str, Any]: + def to_payload_format(self) -> Dict[str, Any]: """Serializes the RPC response for publishing.""" data = {} if self.result is not None: diff --git a/tb_mqtt_client/entities/data/timeseries_entry.py b/tb_mqtt_client/entities/data/timeseries_entry.py index 4f84a7b..137ea43 100644 --- a/tb_mqtt_client/entities/data/timeseries_entry.py +++ b/tb_mqtt_client/entities/data/timeseries_entry.py @@ -1,33 +1,16 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # - -# Copyright 2025. 
ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # - +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Any, Optional diff --git a/tb_mqtt_client/entities/gateway/__init__.py b/tb_mqtt_client/entities/gateway/__init__.py index 59e3bfa..fa669aa 100644 --- a/tb_mqtt_client/entities/gateway/__init__.py +++ b/tb_mqtt_client/entities/gateway/__init__.py @@ -1,14 +1,14 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tb_mqtt_client/entities/gateway/device_session_state.py b/tb_mqtt_client/entities/gateway/device_session_state.py index 59e3bfa..fa669aa 100644 --- a/tb_mqtt_client/entities/gateway/device_session_state.py +++ b/tb_mqtt_client/entities/gateway/device_session_state.py @@ -1,14 +1,14 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
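A hedged sketch of a device-side RPC handler using the RPCRequest and RPCResponse entities above, matching the RPCRequest-in/RPCResponse-out callback signature the device client adopts later in this patch. RPCResponse is assumed to accept request_id/result/error keyword arguments, mirroring how the client reads response.request_id when building the response topic.

from tb_mqtt_client.entities.data.rpc_request import RPCRequest
from tb_mqtt_client.entities.data.rpc_response import RPCResponse


async def handle_rpc(request: RPCRequest) -> RPCResponse:
    if request.method == "getTemperature":
        return RPCResponse(request_id=request.request_id, result={"value": 23.4})
    return RPCResponse(request_id=request.request_id,
                       error=f"Unknown RPC method: {request.method}")


# Registered later on the device client:
# client.set_rpc_request_callback(handle_rpc)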
+# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tb_mqtt_client/entities/gateway/rpc_context.py b/tb_mqtt_client/entities/gateway/rpc_context.py index 59e3bfa..fa669aa 100644 --- a/tb_mqtt_client/entities/gateway/rpc_context.py +++ b/tb_mqtt_client/entities/gateway/rpc_context.py @@ -1,14 +1,14 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tb_mqtt_client/entities/gateway/virtual_device.py b/tb_mqtt_client/entities/gateway/virtual_device.py index 59e3bfa..fa669aa 100644 --- a/tb_mqtt_client/entities/gateway/virtual_device.py +++ b/tb_mqtt_client/entities/gateway/virtual_device.py @@ -1,14 +1,14 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ diff --git a/tb_mqtt_client/entities/publish_result.py b/tb_mqtt_client/entities/publish_result.py new file mode 100644 index 0000000..2f4497c --- /dev/null +++ b/tb_mqtt_client/entities/publish_result.py @@ -0,0 +1,33 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +class PublishResult: + def __init__(self, topic: str, qos: int, message_id: int, payload_size: int, reason_code: int): + self.topic = topic + self.qos = qos + self.message_id = message_id + self.payload_size = payload_size + self.reason_code = reason_code + + def __repr__(self): + return f"" + + def as_dict(self) -> dict: + return { + "topic": self.topic, + "qos": self.qos, + "message_id": self.message_id, + "payload_size": self.payload_size, + "reason_code": self.reason_code + } diff --git a/tb_mqtt_client/service/__init__.py b/tb_mqtt_client/service/__init__.py index 6d6dc0d..fa669aa 100644 --- a/tb_mqtt_client/service/__init__.py +++ b/tb_mqtt_client/service/__init__.py @@ -1,14 +1,14 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# http://www.apache.org/licenses/LICENSE-2.0 # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tb_mqtt_client/service/base_client.py b/tb_mqtt_client/service/base_client.py index 2594634..70e2746 100644 --- a/tb_mqtt_client/service/base_client.py +++ b/tb_mqtt_client/service/base_client.py @@ -1,25 +1,24 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
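An illustration (not part of the patch) of the new PublishResult value object that delivery futures now resolve to; treating reason_code 0 as a successful MQTT acknowledgement is an assumption here.

from tb_mqtt_client.entities.publish_result import PublishResult

result = PublishResult(topic="v1/devices/me/telemetry",
                       qos=1,
                       message_id=17,
                       payload_size=128,
                       reason_code=0)

if result.reason_code == 0:
    print("published", result.as_dict())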
+# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. - -from abc import ABC, abstractmethod import asyncio -import uvloop +from abc import ABC, abstractmethod from typing import Callable, Awaitable, Dict, Any, Union -from tb_mqtt_client.common.exceptions import exception_handler +import uvloop + +from tb_mqtt_client.common.exceptions import exception_handler from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate from tb_mqtt_client.entities.data.rpc_response import RPCResponse @@ -76,7 +75,7 @@ async def send_rpc_response(self, response: RPCResponse): """ Send a response to a server-initiated RPC request. - :param RPCResponse response: The RPC response to send. + :param RPCResponse response: The RPC response for sending to the platform. """ pass diff --git a/tb_mqtt_client/service/device/__init__.py b/tb_mqtt_client/service/device/__init__.py index 59e3bfa..fa669aa 100644 --- a/tb_mqtt_client/service/device/__init__.py +++ b/tb_mqtt_client/service/device/__init__.py @@ -1,14 +1,14 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tb_mqtt_client/service/device/attribute_updates_handler.py b/tb_mqtt_client/service/device/attribute_updates_handler.py deleted file mode 100644 index e0b593e..0000000 --- a/tb_mqtt_client/service/device/attribute_updates_handler.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -from typing import Awaitable, Callable, Optional -from orjson import loads - -from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate -from tb_mqtt_client.common.logging_utils import get_logger - -logger = get_logger(__name__) - - -class AttributeUpdatesHandler: - """ - Handles shared attribute update messages from the platform. - """ - - def __init__(self): - self._callback: Optional[Callable[[AttributeUpdate], Awaitable[None]]] = None - - def set_callback(self, callback: Callable[[AttributeUpdate], Awaitable[None]]): - """ - Sets the async callback that will be triggered on shared attribute update. - - :param callback: A coroutine that takes an AttributeUpdate object. - """ - self._callback = callback - - async def handle(self, topic: str, payload: bytes): - if not self._callback: - logger.debug("No attribute update callback set. Skipping payload.") - return - - try: - data = loads(payload) - update = AttributeUpdate.from_dict(data) - await self._callback(update) - except Exception as e: - logger.exception("Failed to handle attribute update: %s", e) diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index 8eeea7b..09dc7bd 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -1,47 +1,53 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # -from asyncio import sleep, wait_for, TimeoutError +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
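A sketch of the asynchronous shared-attribute callback expected by the attribute-updates handler (both the removed module above and its replacement under service/device/handlers/ later in this patch use the same signature); registration on the device client is shown commented out because set_attribute_update_callback appears further down.

from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate


async def on_shared_attribute_update(update: AttributeUpdate) -> None:
    # React to configuration pushed from the platform.
    print("Shared attributes changed:", update)


# client.set_attribute_update_callback(on_shared_attribute_update)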
-from orjson import loads, dumps +from asyncio import sleep, wait_for, TimeoutError, Event from random import choices from string import ascii_uppercase, digits from typing import Callable, Awaitable, Optional, Dict, Any, Union, List -from tb_mqtt_client.common.request_id_generator import RPCRequestIdProducer +from orjson import loads, dumps + +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit, DEFAULT_RATE_LIMIT_PERCENTAGE +from tb_mqtt_client.common.request_id_generator import RPCRequestIdProducer +from tb_mqtt_client.constants import mqtt_topics from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage, DeviceUplinkMessageBuilder +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.service.base_client import BaseClient +from tb_mqtt_client.service.device.handlers.attribute_updates_handler import AttributeUpdatesHandler +from tb_mqtt_client.service.device.handlers.requested_attributes_response_handler import \ + RequestedAttributeResponseHandler +from tb_mqtt_client.service.device.handlers.rpc_requests_handler import RPCRequestsHandler +from tb_mqtt_client.service.device.handlers.rpc_response_handler import RPCResponseHandler +from tb_mqtt_client.service.message_dispatcher import JsonMessageDispatcher, MessageDispatcher from tb_mqtt_client.service.message_queue import MessageQueue from tb_mqtt_client.service.mqtt_manager import MQTTManager -from tb_mqtt_client.service.message_dispatcher import JsonMessageDispatcher -from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate -from tb_mqtt_client.constants import mqtt_topics -from tb_mqtt_client.common.config_loader import DeviceConfig -from tb_mqtt_client.common.logging_utils import get_logger -from tb_mqtt_client.service.device.attribute_updates_handler import AttributeUpdatesHandler -from tb_mqtt_client.service.device.rpc_requests_handler import RPCRequestsHandler -from tb_mqtt_client.service.rpc_response_handler import RPCResponseHandler logger = get_logger(__name__) class DeviceClient(BaseClient): def __init__(self, config: Optional[Union[DeviceConfig, Dict]] = None): + self._stop_event = Event() self._config = None if isinstance(config, DeviceConfig): self._config = config @@ -56,6 +62,9 @@ def __init__(self, config: Optional[Union[DeviceConfig, Dict]] = None): super().__init__(self._config.host, self._config.port, client_id) + self._message_queue: Optional[MessageQueue] = None + self._message_dispatcher: Optional[MessageDispatcher] = None + self._messages_rate_limit = RateLimit("0:0,", name="messages") self._telemetry_rate_limit = RateLimit("0:0,", name="telemetry") self._telemetry_dp_rate_limit = RateLimit("0:0,", name="telemetryDataPoints") @@ -66,15 +75,14 @@ def __init__(self, config: Optional[Union[DeviceConfig, Dict]] = None): self._rpc_response_handler = RPCResponseHandler() - self._mqtt_manager = MQTTManager(self._client_id, - self._on_connect, - self._on_disconnect, 
- self._handle_rate_limit_response, + self._mqtt_manager = MQTTManager(client_id=self._client_id, + main_stop_event=self._stop_event, + on_connect=self._on_connect, + on_disconnect=self._on_disconnect, + rate_limits_handler=self._handle_rate_limit_response, rpc_response_handler=self._rpc_response_handler,) - self._message_queue: Optional[MessageQueue] = None - self._dispatcher: Optional[JsonMessageDispatcher] = None - + self._requested_attribute_response_handler = RequestedAttributeResponseHandler() self._attribute_updates_handler = AttributeUpdatesHandler() self._rpc_requests_handler = RPCRequestsHandler() @@ -108,20 +116,37 @@ async def connect(self): self.max_payload_size = 65535 logger.debug("Using default max_payload_size: %d", self.max_payload_size) - self._dispatcher = JsonMessageDispatcher(self.max_payload_size, self._telemetry_dp_rate_limit.minimal_limit) + self._message_dispatcher = JsonMessageDispatcher(self.max_payload_size, self._telemetry_dp_rate_limit.minimal_limit) self._message_queue = MessageQueue( mqtt_manager=self._mqtt_manager, + main_stop_event=self._stop_event, message_rate_limit=self._messages_rate_limit, telemetry_rate_limit=self._telemetry_rate_limit, telemetry_dp_rate_limit=self._telemetry_dp_rate_limit, - message_dispatcher=self._dispatcher, + message_dispatcher=self._message_dispatcher, max_queue_size=self._max_uplink_message_queue_size, ) + self._requested_attribute_response_handler.set_message_dispatcher(self._message_dispatcher) + self._attribute_updates_handler.set_message_dispatcher(self._message_dispatcher) + self._rpc_requests_handler.set_message_dispatcher(self._message_dispatcher) + self._rpc_response_handler.set_message_dispatcher(self._message_dispatcher) + + def stop(self): + """ + Stops the client and disconnects from the MQTT broker. 
+ """ + logger.info("Stopping DeviceClient...") + self._stop_event.set() + if self._mqtt_manager.is_connected(): + self._mqtt_manager.disconnect() + logger.info("DeviceClient stopped.") + async def disconnect(self): await self._mqtt_manager.disconnect() - if self._message_queue: - await self._message_queue.shutdown() + # if self._message_queue: + # await self._message_queue.shutdown() + # TODO: Not sure if we need to shutdown the message queue here, as it might be handled by MQTTManager async def send_telemetry(self, telemetry_data: Union[Dict[str, Any], TimeseriesEntry, @@ -142,14 +167,26 @@ async def send_attributes(self, attributes: Union[Dict[str, Any], AttributeEntry async def send_rpc_response(self, response: RPCResponse): topic = mqtt_topics.build_device_rpc_response_topic(request_id=response.request_id) + payload = self._message_dispatcher.build_rpc_response_payload(response) await self._message_queue.publish(topic=topic, - payload=dumps(response.to_dict()), + payload=payload, datapoints_count=0) + async def send_attribute_request(self, attribute_request: AttributeRequest, callback: Callable[[RequestedAttributeResponse], Awaitable[None]]): + await self._requested_attribute_response_handler.register_request(attribute_request, callback) + + topic = mqtt_topics.build_device_attributes_request_topic(attribute_request.id) + payload = self._message_dispatcher.build_attribute_request_payload(attribute_request) + + await self._message_queue.publish(topic=topic, + payload=payload, + datapoints_count=0) + + def set_attribute_update_callback(self, callback: Callable[[AttributeUpdate], Awaitable[None]]): self._attribute_updates_handler.set_callback(callback) - def set_rpc_request_callback(self, callback: Callable[[str, Dict[str, Any]], Awaitable[Dict[str, Any]]]): + def set_rpc_request_callback(self, callback: Callable[[RPCRequest], Awaitable[RPCResponse]]): self._rpc_requests_handler.set_callback(callback) async def _on_connect(self): @@ -162,9 +199,12 @@ async def _on_connect(self): self._mqtt_manager.register_handler(mqtt_topics.DEVICE_ATTRIBUTES_TOPIC, self._handle_attribute_update) self._mqtt_manager.register_handler(mqtt_topics.DEVICE_RPC_REQUEST_TOPIC_FOR_SUBSCRIPTION, self._handle_rpc_request) # noqa self._mqtt_manager.register_handler(mqtt_topics.DEVICE_RPC_RESPONSE_TOPIC_FOR_SUBSCRIPTION, self._handle_rpc_response) # noqa + self._mqtt_manager.register_handler(mqtt_topics.DEVICE_ATTRIBUTES_RESPONSE_TOPIC, self._handle_requested_attribute_response) # noqa async def _on_disconnect(self): logger.info("Device client disconnected.") + self._requested_attribute_response_handler.clear() + self._rpc_response_handler.clear() async def send_rpc_call(self, method: str, params: Optional[Dict[str, Any]] = None, timeout: float = 10.0) -> Any: """ @@ -172,7 +212,7 @@ async def send_rpc_call(self, method: str, params: Optional[Dict[str, Any]] = No :param method: The RPC method to call. :param params: The parameters to send. :param timeout: Timeout for the response in seconds. - :return: The response result (dict, list, str, etc.), or raises on error. + :return: The response result (dict, list, str, etc.). 
""" request_id = await RPCRequestIdProducer.get_next() topic = mqtt_topics.build_device_rpc_request_topic(request_id) @@ -200,20 +240,23 @@ async def _handle_rpc_request(self, topic: str, payload: bytes): async def _handle_rpc_response(self, topic: str, payload: bytes): await self._rpc_response_handler.handle(topic, payload) - async def _handle_rate_limit_response(self, topic: str, payload: bytes): + async def _handle_requested_attribute_response(self, topic: str, payload: bytes): + await self._requested_attribute_response_handler.handle(topic, payload) + + async def _handle_rate_limit_response(self, topic: str, payload: bytes): # noqa try: response = loads(payload.decode("utf-8")) logger.debug("Received rate limit response payload: %s", response) if not isinstance(response, dict) or 'rateLimits' not in response: logger.warning("Invalid rate limit response: %r", response) - return + return None rate_limits = response.get('rateLimits', {}) - self._messages_rate_limit.set_limit(rate_limits.get("messages", "0:0,")) - self._telemetry_rate_limit.set_limit(rate_limits.get("telemetryMessages", "0:0,")) - self._telemetry_dp_rate_limit.set_limit(rate_limits.get("telemetryDataPoints", "0:0,")) + await self._messages_rate_limit.set_limit(rate_limits.get("messages", "0:0,")) + await self._telemetry_rate_limit.set_limit(rate_limits.get("telemetryMessages", "0:0,")) + await self._telemetry_dp_rate_limit.set_limit(rate_limits.get("telemetryDataPoints", "0:0,")) server_inflight = int(response.get("maxInflightMessages", 100)) limits = [rl.minimal_limit for rl in [ @@ -232,8 +275,8 @@ async def _handle_rate_limit_response(self, topic: str, payload: bytes): if "maxPayloadSize" in response: self.max_payload_size = int(response["maxPayloadSize"] * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) # Update the dispatcher's max_payload_size if it's already initialized - if hasattr(self, '_dispatcher') and self._dispatcher is not None: - self._dispatcher.splitter.max_payload_size = self.max_payload_size + if hasattr(self, '_dispatcher') and self._message_dispatcher is not None: + self._message_dispatcher.splitter.max_payload_size = self.max_payload_size logger.debug("Updated dispatcher's max_payload_size to %d", self.max_payload_size) else: # If maxPayloadSize is not provided, keep the default value @@ -243,8 +286,8 @@ async def _handle_rate_limit_response(self, topic: str, payload: bytes): self.max_payload_size = 65535 logger.debug("Using default max_payload_size: %d", self.max_payload_size) # Update the dispatcher's max_payload_size if it's already initialized - if hasattr(self, '_dispatcher') and self._dispatcher is not None: - self._dispatcher.splitter.max_payload_size = self.max_payload_size + if hasattr(self, '_dispatcher') and self._message_dispatcher is not None: + self._message_dispatcher.splitter.max_payload_size = self.max_payload_size logger.debug("Updated dispatcher's max_payload_size to %d", self.max_payload_size) if (not self._messages_rate_limit.has_limit() diff --git a/tb_mqtt_client/service/device/handlers/__init__.py b/tb_mqtt_client/service/device/handlers/__init__.py new file mode 100644 index 0000000..fa669aa --- /dev/null +++ b/tb_mqtt_client/service/device/handlers/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tb_mqtt_client/service/device/handlers/attribute_updates_handler.py b/tb_mqtt_client/service/device/handlers/attribute_updates_handler.py new file mode 100644 index 0000000..9787fdd --- /dev/null +++ b/tb_mqtt_client/service/device/handlers/attribute_updates_handler.py @@ -0,0 +1,63 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Awaitable, Callable, Optional + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.service.message_dispatcher import MessageDispatcher + +logger = get_logger(__name__) + + +class AttributeUpdatesHandler: + """ + Handles shared attribute update messages from the platform. + """ + + def __init__(self): + self._message_dispatcher = None + self._callback: Optional[Callable[[AttributeUpdate], Awaitable[None]]] = None + + def set_message_dispatcher(self, message_dispatcher: MessageDispatcher): + """ + Sets the message dispatcher for handling incoming messages. + This should be called before any callbacks are set. + + :param message_dispatcher: An instance of MessageDispatcher. + """ + if not isinstance(message_dispatcher, MessageDispatcher): + raise ValueError("message_dispatcher must be an instance of MessageDispatcher.") + self._message_dispatcher = message_dispatcher + logger.debug("Message dispatcher set for AttributeUpdatesHandler.") + + def set_callback(self, callback: Callable[[AttributeUpdate], Awaitable[None]]): + """ + Sets the async callback that will be triggered on shared attribute update. + + :param callback: A coroutine that takes an AttributeUpdate object. + """ + self._callback = callback + + async def handle(self, topic: str, payload: bytes): # noqa + if not self._callback: + logger.debug("No attribute update callback set. 
Skipping payload.") + return + + try: + data = self._message_dispatcher.parse_attribute_update(payload) + logger.debug("Handling attribute update: %r", data) + await self._callback(data) + except Exception as e: + logger.exception("Failed to handle attribute update: %s", e) diff --git a/tb_mqtt_client/service/device/handlers/requested_attributes_response_handler.py b/tb_mqtt_client/service/device/handlers/requested_attributes_response_handler.py new file mode 100644 index 0000000..7c0d0c3 --- /dev/null +++ b/tb_mqtt_client/service/device/handlers/requested_attributes_response_handler.py @@ -0,0 +1,97 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Tuple, Awaitable, Callable + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.request_id_generator import AttributeRequestIdProducer +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.service.message_dispatcher import MessageDispatcher + +logger = get_logger(__name__) + + +class RequestedAttributeResponseHandler: + """ + Handles responses to attribute requests sent to the platform. + """ + + def __init__(self): + self._message_dispatcher = None + self._pending_attribute_requests: Dict[int, Tuple[AttributeRequest, Callable[[RequestedAttributeResponse], Awaitable[None]]]] = {} + + def set_message_dispatcher(self, message_dispatcher: MessageDispatcher): + """ + Sets the message dispatcher for handling incoming messages. + This should be called before any requests are registered. + """ + if not isinstance(message_dispatcher, MessageDispatcher): + raise ValueError("message_dispatcher must be an instance of MessageDispatcher.") + self._message_dispatcher = message_dispatcher + logger.debug("Message dispatcher set for RequestedAttributeResponseHandler.") + + async def register_request(self, request: AttributeRequest, callback: Callable[[RequestedAttributeResponse], Awaitable[None]]): + """ + Called when a request is sent to the platform and a response is awaited. + """ + request.id = await AttributeRequestIdProducer.get_next() + if request.id in self._pending_attribute_requests: + raise RuntimeError(f"Request ID {request.id} is already registered.") + self._pending_attribute_requests[request.id] = (request, callback) + + def unregister_request(self, request_id: int): + """ + Unregisters a request by its ID. + This is useful if the request is no longer needed or has timed out. + """ + if request_id in self._pending_attribute_requests: + self._pending_attribute_requests.pop(request_id) + logger.debug("Unregistered attribute request with ID %s", request_id) + else: + logger.debug("Attempted to unregister non-existent request ID %s", request_id) + + async def handle(self, topic: str, payload: bytes): + """ + Handles the incoming attribute request response. 
+ """ + try: + if not self._message_dispatcher: + logger.error("Message dispatcher is not initialized. Cannot handle attribute response.") + request_id = topic.split('/')[-1] # Assuming request ID is in the topic + self._pending_attribute_requests.pop(int(request_id), (None, None, None)) + return + + requested_attribute_response = self._message_dispatcher.parse_attribute_request_response(topic, payload) + pending_request_details = self._pending_attribute_requests.pop(requested_attribute_response.request_id, None) + if not pending_request_details: + logger.warning("No future awaiting request ID %s. Ignoring.", requested_attribute_response.request_id) + return + + request, callback = pending_request_details + + if callback: + logger.debug("Invoking callback for requested attribute response with ID %s", requested_attribute_response.request_id) + await callback(requested_attribute_response) + else: + logger.error("No callback registered for requested attribute response with ID %s", requested_attribute_response.request_id) + + except Exception as e: + logger.exception("Failed to handle requested attribute response: %s", e) + + def clear(self): + """ + Clears all pending requests (e.g., on disconnect). + """ + self._pending_attribute_requests.clear() diff --git a/tb_mqtt_client/service/device/handlers/rpc_requests_handler.py b/tb_mqtt_client/service/device/handlers/rpc_requests_handler.py new file mode 100644 index 0000000..d80f6b6 --- /dev/null +++ b/tb_mqtt_client/service/device/handlers/rpc_requests_handler.py @@ -0,0 +1,80 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Awaitable, Callable, Optional + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.service.message_dispatcher import MessageDispatcher + +logger = get_logger(__name__) + + +class RPCRequestsHandler: + """ + Handles incoming RPC request messages for a device. + """ + + def __init__(self): + self._message_dispatcher = None + self._callback: Optional[Callable[[RPCRequest], Awaitable[RPCResponse]]] = None + + def set_message_dispatcher(self, message_dispatcher: MessageDispatcher): + """ + Sets the message dispatcher for handling incoming messages. + This should be called before any callbacks are set. + :param message_dispatcher: An instance of MessageDispatcher. + """ + if not isinstance(message_dispatcher, MessageDispatcher): + raise ValueError("message_dispatcher must be an instance of MessageDispatcher.") + self._message_dispatcher = message_dispatcher + logger.debug("Message dispatcher set for RPCRequestsHandler.") + + def set_callback(self, callback: Callable[[RPCRequest], Awaitable[RPCResponse]]): + """ + Set the async callback to handle incoming RPC requests. + :param callback: A coroutine that takes an RPCRequest and returns an RPCResponse. 
+ """ + self._callback = callback + + async def handle(self, topic: str, payload: bytes) -> Optional[RPCResponse]: + """ + Process the RPC request and return the response payload and request ID (if possible). + :returns: (request_id, response_dict) or None if failed + """ + if not self._callback: + logger.debug("No RPC request callback set. Skipping RPC handling. " + "You can add set callback using client.set_rpc_request_callback(your_method)") + return None + + if not self._message_dispatcher: + logger.error("Message dispatcher is not initialized. Cannot handle RPC request.") + return None + + try: + rpc_request = self._message_dispatcher.parse_rpc_request(topic, payload) + logger.debug("Handling RPC method id: %i - %s with params: %s", + rpc_request.request_id, rpc_request.method, rpc_request.params) + result = await self._callback(rpc_request) + if not isinstance(result, RPCResponse): + logger.error("RPC callback must return an instance of RPCResponse, got: %s", type(result)) + return None + logger.debug("RPC response for method id: %i - %s with result: %s", + result.request_id, rpc_request.method, result.result) + return result + + except Exception as e: + logger.exception("Failed to process RPC request: %s", e) + return None diff --git a/tb_mqtt_client/service/device/handlers/rpc_response_handler.py b/tb_mqtt_client/service/device/handlers/rpc_response_handler.py new file mode 100644 index 0000000..d99d0f3 --- /dev/null +++ b/tb_mqtt_client/service/device/handlers/rpc_response_handler.py @@ -0,0 +1,87 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from typing import Dict, Union + +from orjson import loads + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.service.message_dispatcher import MessageDispatcher + +logger = get_logger(__name__) + + +class RPCResponseHandler: + """ + Handles RPC responses coming from the platform to the client (client-side RPCs responses). + Maintains an internal map of request_id -> asyncio.Future for awaiting RPC results. + """ + + def __init__(self): + self._message_dispatcher = None + self._pending_rpc_requests: Dict[Union[str, int], asyncio.Future] = {} + + def set_message_dispatcher(self, message_dispatcher: MessageDispatcher): + """ + Sets the message dispatcher for handling incoming messages. + This should be called before any requests are registered. + :param message_dispatcher: An instance of MessageDispatcher. + """ + if not isinstance(message_dispatcher, MessageDispatcher): + raise ValueError("message_dispatcher must be an instance of MessageDispatcher.") + self._message_dispatcher = message_dispatcher + logger.debug("Message dispatcher set for RPCResponseHandler.") + + def register_request(self, request_id: Union[str, int]) -> asyncio.Future: + """ + Called when a request is sent to the platform and a response is awaited. 
+ """ + if request_id in self._pending_rpc_requests: + raise RuntimeError(f"Request ID {request_id} is already registered.") + future = asyncio.get_event_loop().create_future() + self._pending_rpc_requests[request_id] = future + return future + + async def handle(self, topic: str, payload: bytes): + """ + Handles the incoming RPC response from the platform and fulfills the corresponding future. + The topic is expected to be: v1/devices/me/rpc/response/{request_id} + """ + try: + # TODO: Use MessageDispatcher to parse the topic and payload + request_id = topic.split("/")[-1] + response_data = loads(payload) + + future = self._pending_rpc_requests.pop(request_id, None) + if not future: + logger.warning("No future awaiting request ID %s. Ignoring.", request_id) + return + + if isinstance(response_data, dict) and "error" in response_data: + future.set_exception(Exception(response_data["error"])) + else: + future.set_result(response_data) + + except Exception as e: + logger.exception("Failed to handle RPC response: %s", e) + + def clear(self): + """ + Clears all pending futures (e.g., on disconnect). + """ + for fut in self._pending_rpc_requests.values(): + if not fut.done(): + fut.cancel() + self._pending_rpc_requests.clear() diff --git a/tb_mqtt_client/service/device/rpc_requests_handler.py b/tb_mqtt_client/service/device/rpc_requests_handler.py deleted file mode 100644 index 2f9d165..0000000 --- a/tb_mqtt_client/service/device/rpc_requests_handler.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -from orjson import loads -from typing import Awaitable, Callable, Dict, Optional - -from tb_mqtt_client.common.logging_utils import get_logger -from tb_mqtt_client.entities.data.rpc_request import RPCRequest -from tb_mqtt_client.entities.data.rpc_response import RPCResponse - -logger = get_logger(__name__) - - -class RPCRequestsHandler: - """ - Handles incoming RPC request messages for a device. - """ - - def __init__(self): - self._callback: Optional[Callable[[str, Dict], Awaitable[Dict]]] = None - - def set_callback(self, callback: Callable[[str, Dict], Awaitable[Dict]]): - """ - Set the async callback to handle incoming RPC requests. - :param callback: A coroutine accepting (method_name, params) and returning a result dict. - """ - self._callback = callback - - async def handle(self, topic: str, payload: bytes) -> Optional[RPCResponse]: - """ - Process the RPC request and return response payload and request ID (if possible). - :returns: (request_id, response_dict) or None if failed - """ - if not self._callback: - logger.debug("No RPC request callback set. Skipping RPC handling. 
" - "You can add set callback using client.set_rpc_request_callback(your_method)") - return None - - try: - request_id = int(topic.split("/")[-1]) - parsed = loads(payload) - parsed["id"] = request_id - rpc_request = RPCRequest.from_dict(parsed) - - logger.debug("Handling RPC method id: %i - %s with params: %s", - rpc_request.request_id, rpc_request.method, rpc_request.params) - - result = await self._callback(rpc_request) - if not isinstance(result, RPCResponse): - logger.error("RPC callback must return an instance of RPCResponse, got: %s", type(result)) - return None - return result - - except Exception as e: - logger.exception("Failed to process RPC request: %s", e) - return None diff --git a/tb_mqtt_client/service/event_dispatcher.py b/tb_mqtt_client/service/event_dispatcher.py index 59e3bfa..fa669aa 100644 --- a/tb_mqtt_client/service/event_dispatcher.py +++ b/tb_mqtt_client/service/event_dispatcher.py @@ -1,14 +1,14 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tb_mqtt_client/service/gateway/__init__.py b/tb_mqtt_client/service/gateway/__init__.py index 59e3bfa..fa669aa 100644 --- a/tb_mqtt_client/service/gateway/__init__.py +++ b/tb_mqtt_client/service/gateway/__init__.py @@ -1,14 +1,14 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ diff --git a/tb_mqtt_client/service/gateway/client.py b/tb_mqtt_client/service/gateway/client.py index b6f03de..cc6bd43 100644 --- a/tb_mqtt_client/service/gateway/client.py +++ b/tb_mqtt_client/service/gateway/client.py @@ -1,32 +1,30 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from asyncio import sleep +from random import choices +from string import ascii_uppercase, digits from typing import Callable, Awaitable, Optional, Dict, Any, Union, List, Set from orjson import dumps, loads -from random import choices -from string import ascii_uppercase, digits -from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit from tb_mqtt_client.common.config_loader import GatewayConfig from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit from tb_mqtt_client.constants import mqtt_topics from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate -from tb_mqtt_client.entities.data.rpc_response import RPCResponse from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.service.device.client import DeviceClient diff --git a/tb_mqtt_client/service/gateway/device_sesion.py b/tb_mqtt_client/service/gateway/device_sesion.py index 59e3bfa..fa669aa 100644 --- a/tb_mqtt_client/service/gateway/device_sesion.py +++ b/tb_mqtt_client/service/gateway/device_sesion.py @@ -1,14 +1,14 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tb_mqtt_client/service/gateway/gateway_attribute_updates_handler.py b/tb_mqtt_client/service/gateway/gateway_attribute_updates_handler.py index 59e3bfa..fa669aa 100644 --- a/tb_mqtt_client/service/gateway/gateway_attribute_updates_handler.py +++ b/tb_mqtt_client/service/gateway/gateway_attribute_updates_handler.py @@ -1,14 +1,14 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tb_mqtt_client/service/gateway/gateway_rpc_requests_handler.py b/tb_mqtt_client/service/gateway/gateway_rpc_requests_handler.py index 59e3bfa..fa669aa 100644 --- a/tb_mqtt_client/service/gateway/gateway_rpc_requests_handler.py +++ b/tb_mqtt_client/service/gateway/gateway_rpc_requests_handler.py @@ -1,14 +1,14 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
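A minimal sketch of the attribute request flow handled by RequestedAttributeResponseHandler above: send_attribute_request registers the callback, assigns the request id, and publishes the payload built by the dispatcher. The AttributeRequest constructor arguments below are assumptions for illustration; `client` is an already-connected DeviceClient.

from tb_mqtt_client.entities.data.attribute_request import AttributeRequest
from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse

async def on_requested_attributes(response: RequestedAttributeResponse) -> None:
    # Invoked once the handler matches the response topic's request id
    # against the pending request registered by send_attribute_request().
    print("Requested attributes:", response)

async def request_device_attributes(client) -> None:
    # Constructor arguments are assumed for illustration.
    request = AttributeRequest(client_keys=["firmwareVersion"],
                               shared_keys=["targetFwVersion"])
    await client.send_attribute_request(request, on_requested_attributes)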
+ diff --git a/tb_mqtt_client/service/gateway/multiplex_publisher.py b/tb_mqtt_client/service/gateway/multiplex_publisher.py index 52dbfcf..fa669aa 100644 --- a/tb_mqtt_client/service/gateway/multiplex_publisher.py +++ b/tb_mqtt_client/service/gateway/multiplex_publisher.py @@ -1,15 +1,14 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tb_mqtt_client/service/gateway/subdevice_manager.py b/tb_mqtt_client/service/gateway/subdevice_manager.py index 59e3bfa..fa669aa 100644 --- a/tb_mqtt_client/service/gateway/subdevice_manager.py +++ b/tb_mqtt_client/service/gateway/subdevice_manager.py @@ -1,14 +1,14 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tb_mqtt_client/service/message_dispatcher.py b/tb_mqtt_client/service/message_dispatcher.py index edea245..9e242b7 100644 --- a/tb_mqtt_client/service/message_dispatcher.py +++ b/tb_mqtt_client/service/message_dispatcher.py @@ -1,30 +1,36 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. - -from collections import defaultdict import asyncio from abc import ABC, abstractmethod +from collections import defaultdict from datetime import UTC, datetime -from typing import Any, Dict, Union, List, Tuple, Optional -from orjson import dumps +from typing import Any, Dict, List, Tuple, Optional, Union + +from orjson import dumps, loads +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.constants import mqtt_topics from tb_mqtt_client.constants.mqtt_topics import DEVICE_TELEMETRY_TOPIC, DEVICE_ATTRIBUTES_TOPIC +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.entities.publish_result import PublishResult from tb_mqtt_client.service.message_splitter import MessageSplitter -from tb_mqtt_client.common.logging_utils import get_logger logger = get_logger(__name__) @@ -36,10 +42,10 @@ def __init__(self, max_payload_size: Optional[int] = None, max_datapoints: Optio max_payload_size, max_datapoints) @abstractmethod - def build_topic_payloads( + def build_uplink_payloads( self, messages: List[DeviceUplinkMessage] - ) -> List[Tuple[str, bytes, int, List[Optional[asyncio.Future[bool]]]]]: + ) -> List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]]: """ Build a list of topic-payload pairs from the given messages. Each pair consists of a topic string and a payload byte array. @@ -47,9 +53,18 @@ def build_topic_payloads( pass @abstractmethod - def build_payload(self, msg: DeviceUplinkMessage) -> bytes: + def build_attribute_request_payload(self, request: AttributeRequest) -> bytes: """ - Build a JSON payload for a single DeviceUplinkMessage. + Build the payload for an attribute request. + This method should return the payload bytes. + """ + pass + + @abstractmethod + def build_rpc_response_payload(self, rpc_response: RPCResponse): + """ + Build the payload for an RPC response. + This method should return a tuple of topic and payload bytes.
""" pass @@ -60,23 +75,98 @@ def splitter(self) -> MessageSplitter: """ pass + @abstractmethod + def parse_attribute_request_response(self, topic: str, payload: bytes) -> RequestedAttributeResponse: + """ + Parse the attribute request response payload into an AttributeRequestResponse. + This method should be implemented to handle the specific format of the topic and payload. + """ + pass + + @abstractmethod + def parse_attribute_update(self, payload: bytes) -> AttributeUpdate: + """ + Parse the attribute update payload into an AttributeUpdate. + This method should be implemented to handle the specific format of the payload. + """ + pass + + @abstractmethod + def parse_rpc_request(self, topic: str, payload: bytes) -> RPCRequest: + """ + Parse the RPC request from the given topic and payload. + This method should be implemented to handle the specific format of the RPC request. + """ + pass + class JsonMessageDispatcher(MessageDispatcher): def __init__(self, max_payload_size: Optional[int] = None, max_datapoints: Optional[int] = None): super().__init__(max_payload_size, max_datapoints) logger.trace("JsonMessageDispatcher created.") + def parse_attribute_request_response(self, topic: str, payload: bytes) -> RequestedAttributeResponse: + """ + Parse the attribute request response payload into a RequestedAttributeResponse. + :param topic: The MQTT topic of the requested attribute response. + :param payload: The raw bytes of the payload. + :return: An instance of RequestedAttributeResponse. + """ + try: + request_id = int(topic.split("/")[-1]) + data = loads(payload) + logger.trace("Parsing attribute request response from payload: %s", data) + if not isinstance(data, dict): + logger.error("Invalid requested attribute response format: expected dict, got %s", type(data).__name__) + raise ValueError("Invalid requested attribute response format") + data["request_id"] = request_id # Add request_id to the data dictionary + return RequestedAttributeResponse.from_dict(data) + except Exception as e: + logger.error("Failed to parse attribute request response: %s", str(e)) + raise ValueError("Invalid attribute request response format") from e + + def parse_attribute_update(self, payload: bytes) -> AttributeUpdate: + """ + Parse the attribute update payload into an AttributeUpdate. + :param payload: The raw bytes of the payload. + :return: An instance of AttributeUpdate. + """ + try: + data = loads(payload) + logger.trace("Parsing attribute update from payload: %s", data) + return AttributeUpdate.from_dict(data) + except Exception as e: + logger.error("Failed to parse attribute update: %s", str(e)) + raise ValueError("Invalid attribute update format") from e + + def parse_rpc_request(self, topic: str, payload: bytes) -> RPCRequest: + """ + Parse the RPC request from the given topic and payload. + :param topic: The MQTT topic of the RPC request. + :param payload: The raw bytes of the payload. + :return: An instance of RPCRequest. 
+ """ + try: + request_id = int(topic.split("/")[-1]) + parsed = loads(payload) + parsed["id"] = request_id + data = RPCRequest.from_dict(parsed) + return data + except Exception as e: + logger.error("Failed to parse RPC request: %s", str(e)) + raise ValueError("Invalid RPC request format") from e + @property def splitter(self) -> MessageSplitter: return self._splitter - def build_topic_payloads(self, messages: List[DeviceUplinkMessage]) -> List[Tuple[str, bytes, int, List[Optional[asyncio.Future[bool]]]]]: + def build_uplink_payloads(self, messages: List[DeviceUplinkMessage]) -> List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]]: try: if not messages: logger.trace("No messages to process in build_topic_payloads.") return [] - result: List[Tuple[str, bytes, int, List[Optional[asyncio.Future[bool]]]]] = [] + result: List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]] = [] device_groups: Dict[str, List[DeviceUplinkMessage]] = defaultdict(list) for msg in messages: @@ -93,13 +183,13 @@ def build_topic_payloads(self, messages: List[DeviceUplinkMessage]) -> List[Tupl device, len(telemetry_msgs), len(attr_msgs)) for ts_batch in self._splitter.split_timeseries(telemetry_msgs): - payload = self.build_payload(ts_batch) + payload = JsonMessageDispatcher.build_payload(ts_batch, True) count = ts_batch.timeseries_datapoint_count() result.append((DEVICE_TELEMETRY_TOPIC, payload, count, ts_batch.get_delivery_futures())) logger.trace("Built telemetry payload for device='%s' with %d datapoints", device, count) for attr_batch in self._splitter.split_attributes(attr_msgs): - payload = self.build_payload(attr_batch) + payload = JsonMessageDispatcher.build_payload(attr_batch, False) count = len(attr_batch.attributes) result.append((DEVICE_ATTRIBUTES_TOPIC, payload, count, attr_batch.get_delivery_futures())) logger.trace("Built attribute payload for device='%s' with %d attributes", device, count) @@ -112,37 +202,65 @@ def build_topic_payloads(self, messages: List[DeviceUplinkMessage]) -> List[Tupl logger.error("Error building topic-payloads: %s", str(e)) raise - def build_payload(self, msg: DeviceUplinkMessage) -> bytes: - result: Dict[str, Any] = {} + def build_attribute_request_payload(self, request: AttributeRequest) -> bytes: + """ + Build the payload for an attribute request response. + :param request: The AttributeRequest to build the payload for. + :return: A tuple of topic and payload bytes. + """ + if not request.id: + raise ValueError("AttributeRequest must have a valid ID.") + + payload = dumps(request.to_payload_format()) + logger.trace("Built attribute request payload for request: %r", request) + return payload + + def build_rpc_response_payload(self, rpc_response: RPCResponse) -> Tuple[str, bytes]: + """ + Build the payload for an RPC response. + :param rpc_response: The RPC response to build the payload for. + :return: A tuple of topic and payload bytes. 
+ """ + if not rpc_response.request_id: + raise ValueError("RPCResponse must have a valid request ID.") + + payload = dumps(rpc_response.to_payload_format()) + topic = mqtt_topics.DEVICE_RPC_RESPONSE_TOPIC + str(rpc_response.request_id) + logger.trace("Built RPC response payload for request ID=%d with payload: %r", rpc_response.request_id, payload) + return topic, payload + + @staticmethod + def build_payload(msg: DeviceUplinkMessage, build_timeseries_payload) -> bytes: + result: Union[Dict[str, Any], List[Dict[str, Any]]] = {} device_name = msg.device_name logger.trace("Building payload for device='%s'", device_name) if msg.device_name: - if msg.attributes: - logger.trace("Packing attributes for device='%s'", device_name) - result[msg.device_name] = self._pack_attributes(msg) - if msg.timeseries: + if build_timeseries_payload: logger.trace("Packing timeseries for device='%s'", device_name) - result[msg.device_name] = self._pack_timeseries(msg) + result[msg.device_name] = JsonMessageDispatcher.pack_timeseries(msg) + else: + logger.trace("Packing attributes for device='%s'", device_name) + result[msg.device_name] = JsonMessageDispatcher.pack_attributes(msg) else: - if msg.attributes: - logger.trace("Packing attributes") - result = self._pack_attributes(msg) - if msg.timeseries: + if build_timeseries_payload: logger.trace("Packing timeseries") - result = self._pack_timeseries(msg) + result = JsonMessageDispatcher.pack_timeseries(msg) + else: + logger.trace("Packing attributes") + result = JsonMessageDispatcher.pack_attributes(msg) payload = dumps(result) logger.trace("Built payload size: %d bytes", len(payload)) return payload @staticmethod - def _pack_attributes(msg: DeviceUplinkMessage) -> Dict[str, Any]: + def pack_attributes(msg: DeviceUplinkMessage) -> Dict[str, Any]: logger.trace("Packing %d attribute(s)", len(msg.attributes)) return {attr.key: attr.value for attr in msg.attributes} @staticmethod - def _pack_timeseries(msg: DeviceUplinkMessage) -> List[Dict[str, Any]]: + def pack_timeseries(msg: DeviceUplinkMessage) -> List[Dict[str, Any]]: logger.trace("Packing %d timeseries timestamp bucket(s)", len(msg.timeseries)) now_ts = int(datetime.now(UTC).timestamp() * 1000) diff --git a/tb_mqtt_client/service/message_queue.py b/tb_mqtt_client/service/message_queue.py index eceef52..49f0aa2 100644 --- a/tb_mqtt_client/service/message_queue.py +++ b/tb_mqtt_client/service/message_queue.py @@ -1,18 +1,16 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # - +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. import asyncio from contextlib import suppress @@ -22,6 +20,7 @@ from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit from tb_mqtt_client.constants import mqtt_topics from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage +from tb_mqtt_client.entities.publish_result import PublishResult from tb_mqtt_client.service.message_dispatcher import MessageDispatcher from tb_mqtt_client.service.mqtt_manager import MQTTManager @@ -33,6 +32,7 @@ class MessageQueue: def __init__(self, mqtt_manager: MQTTManager, + main_stop_event: asyncio.Event, message_rate_limit: Optional[RateLimit], telemetry_rate_limit: Optional[RateLimit], telemetry_dp_rate_limit: Optional[RateLimit], @@ -41,6 +41,7 @@ def __init__(self, batch_collect_max_time_ms: int = 100, batch_collect_max_count: int = 500): self.__qos = 1 + self._main_stop_event = main_stop_event self._batch_max_time = batch_collect_max_time_ms / 1000 # convert to seconds self._batch_max_count = batch_collect_max_count self._mqtt_manager = mqtt_manager @@ -48,7 +49,7 @@ def __init__(self, self._telemetry_rate_limit = telemetry_rate_limit self._telemetry_dp_rate_limit = telemetry_dp_rate_limit self._backpressure = self._mqtt_manager.backpressure - self._pending_ack_futures: Dict[int, asyncio.Future[bool]] = {} + self._pending_ack_futures: Dict[int, asyncio.Future[PublishResult]] = {} self._pending_ack_callbacks: Dict[int, Callable[[bool], None]] = {} self._queue = asyncio.Queue(maxsize=max_queue_size) self._active = asyncio.Event() @@ -64,10 +65,10 @@ def __init__(self, async def publish(self, topic: str, payload: Union[bytes, DeviceUplinkMessage], datapoints_count: int): delivery_futures = payload.get_delivery_futures() if isinstance(payload, DeviceUplinkMessage) else [] try: - logger.debug("publish() received delivery future id: %r for topic=%s", - id(delivery_futures[0]), topic) + logger.trace("publish() received delivery future id: %r for topic=%s", + id(delivery_futures[0]) if delivery_futures else -1, topic) self._queue.put_nowait((topic, payload, delivery_futures, datapoints_count)) - logger.debug("Enqueued message: topic=%s, datapoints=%d, type=%s", + logger.trace("Enqueued message: topic=%s, datapoints=%d, type=%s", topic, datapoints_count, type(payload).__name__) except asyncio.QueueFull: logger.error("Message queue full. 
Dropping message for topic %s", topic) @@ -80,21 +81,14 @@ async def _dequeue_loop(self): logger.debug("MessageQueue dequeue loop started.") while self._active.is_set(): try: - # topic, payload, count = await self._wait_for_message() topic, payload, delivery_futures_or_none, count = await asyncio.wait_for(asyncio.get_event_loop().create_task(self._queue.get()), timeout=self._BATCH_TIMEOUT) - # Unpack payload and delivery futures if it's a retry tuple - logger.debug("MessageQueue dequeue: topic=%s, payload=%r, count=%d", + logger.trace("MessageQueue dequeue: topic=%s, payload=%r, count=%d", topic, payload, count) - # if isinstance(payload, tuple): - # payload, delivery_futures_or_none = payload - # else: - # delivery_futures_or_none = None - if isinstance(payload, bytes): await self._try_publish(topic, payload, count, delivery_futures_or_none) continue - logger.debug("Dequeued message: delivery_future id: %r topic=%s, type=%s, datapoints=%d", - id(delivery_futures_or_none[0]) if delivery_futures_or_none else None, + logger.trace("Dequeued message: delivery_future id: %r topic=%s, type=%s, datapoints=%d", + id(delivery_futures_or_none[0]) if delivery_futures_or_none else -1, topic, type(payload).__name__, count) await asyncio.sleep(0) # cooperative yield except asyncio.TimeoutError: @@ -104,7 +98,7 @@ async def _dequeue_loop(self): except asyncio.CancelledError: break except Exception as e: - logger.debug("Unexpected error in dequeue loop: %s", e) + logger.warning("Unexpected error in dequeue loop: %s", e) continue if isinstance(payload, bytes): @@ -115,7 +109,7 @@ async def _dequeue_loop(self): logger.trace("Dequeued message for batching: topic=%s, device=%s", topic, getattr(payload, 'device_name', 'N/A')) - batch: List[Tuple[str, Union[bytes, DeviceUplinkMessage], asyncio.Future[bool], int]] = [(topic, payload, delivery_futures_or_none, count)] + batch: List[Tuple[str, Union[bytes, DeviceUplinkMessage], asyncio.Future[PublishResult], int]] = [(topic, payload, delivery_futures_or_none, count)] start = asyncio.get_event_loop().time() batch_size = payload.size @@ -132,7 +126,7 @@ async def _dequeue_loop(self): next_topic, next_payload, delivery_futures_or_none, next_count = self._queue.get_nowait() if isinstance(next_payload, DeviceUplinkMessage): msg_size = next_payload.size - if batch_size + msg_size > self._dispatcher.splitter.max_payload_size: + if batch_size + msg_size > self._dispatcher.splitter.max_payload_size: # noqa logger.trace("Batch size threshold exceeded: current=%d, next=%d", batch_size, msg_size) self._queue.put_nowait((next_topic, next_payload, delivery_futures_or_none, next_count)) break @@ -145,13 +139,13 @@ async def _dequeue_loop(self): break if batch: - logger.debug("Batching completed: %d messages, total size=%d", len(batch), batch_size) + logger.trace("Batching completed: %d messages, total size=%d", len(batch), batch_size) messages = [device_uplink_message for _, device_uplink_message, _, _ in batch] - topic_payloads = self._dispatcher.build_topic_payloads(messages) + topic_payloads = self._dispatcher.build_uplink_payloads(messages) for topic, payload, datapoints, delivery_futures in topic_payloads: - logger.debug("Dispatching batched message: topic=%s, size=%d, datapoints=%d, delivery_futures=%r", + logger.trace("Dispatching batched message: topic=%s, size=%d, datapoints=%d, delivery_futures=%r", topic, len(payload), datapoints, [id(f) for f in delivery_futures]) await self._try_publish(topic, payload, datapoints, delivery_futures) @@ -159,7 +153,7 @@ async def 
_try_publish(self, topic: str, payload: bytes, datapoints: int, - delivery_futures_or_none: List[Optional[asyncio.Future[bool]]] = None): + delivery_futures_or_none: List[Optional[asyncio.Future[PublishResult]]] = None): if delivery_futures_or_none is None: logger.trace("No delivery futures associated! This publish result will not be tracked.") delivery_futures_or_none = [] @@ -181,7 +175,7 @@ async def _try_publish(self, telemetry_dp_success = True if self._telemetry_rate_limit: - triggered_rate_limit = self._telemetry_rate_limit.try_consume(1) + triggered_rate_limit = await self._telemetry_rate_limit.try_consume(1) if triggered_rate_limit: logger.debug("Telemetry message rate limit hit for topic %s: %r per %r seconds", topic, triggered_rate_limit[0], triggered_rate_limit[1]) @@ -190,7 +184,7 @@ async def _try_publish(self, return if self._telemetry_dp_rate_limit: - triggered_rate_limit = self._telemetry_dp_rate_limit.try_consume(datapoints) + triggered_rate_limit = await self._telemetry_dp_rate_limit.try_consume(datapoints) if triggered_rate_limit: logger.debug("Telemetry datapoint rate limit hit for topic %s: %r per %r seconds", topic, triggered_rate_limit[0], triggered_rate_limit[1]) @@ -200,46 +194,58 @@ async def _try_publish(self, else: # For non-telemetry messages, we only need to check the message rate limit if self._message_rate_limit: - triggered_rate_limit = self._message_rate_limit.try_consume(1) + triggered_rate_limit = await self._message_rate_limit.try_consume(1) if triggered_rate_limit: - logger.debug("Generic message rate limit hit for topic %s: %r per %r seconds", topic, triggered_rate_limit[0], triggered_rate_limit[1]) + logger.debug("Generic message rate limit hit for topic %s: %r per %r seconds", + topic, triggered_rate_limit[0], triggered_rate_limit[1]) retry_delay = self._message_rate_limit.minimal_timeout - self._schedule_delayed_retry(topic, payload, datapoints, delay=retry_delay, delivery_futures=delivery_futures_or_none) + self._schedule_delayed_retry(topic, + payload, + datapoints, + delay=retry_delay, + delivery_futures=delivery_futures_or_none) return try: - logger.debug("Trying to publish topic=%s, payload size=%d, attached future id=%r", - topic, len(payload), id(delivery_futures_or_none[0]) if delivery_futures_or_none else 0) + logger.trace("Trying to publish topic=%s, payload size=%d, attached future id=%r", + topic, len(payload), id(delivery_futures_or_none[0]) if delivery_futures_or_none else -1) mqtt_future = await self._mqtt_manager.publish(topic, payload, qos=self.__qos) if delivery_futures_or_none is not None: - def resolve_attached(mqtt_future: asyncio.Future): + def resolve_attached(publish_future: asyncio.Future): try: - success = mqtt_future.result() is True - except Exception as e: - success = False - logger.warning("mqtt_future failed with exception: %s", e) + publish_result = publish_future.result() + except Exception as exc: + logger.warning("Publish failed with exception: %s", exc) + logger.debug("Resolving delivery futures with failure:", exc_info=exc) + publish_result = PublishResult(topic, self.__qos, -1, len(payload), -1) for i, f in enumerate(delivery_futures_or_none): if f is not None and not f.done(): - f.set_result(success) - logger.debug("Resolved delivery future #%d id=%r with %s, main publish future id: %r, %r", - i, id(f), success, id(mqtt_future), mqtt_future) + f.set_result(publish_result) + logger.trace("Resolved delivery future #%d id=%r with %s, main publish future id: %r, %r", + i, id(f), publish_result, 
id(publish_future), publish_future) - logger.debug("Adding done callback to main publish future: %r, main publish future state: %r", id(mqtt_future), mqtt_future.done()) + logger.trace("Adding done callback to main publish future: %r, main publish future done state: %r", id(mqtt_future), mqtt_future.done()) mqtt_future.add_done_callback(resolve_attached) except Exception as e: logger.warning("Failed to publish to topic %s: %s. Scheduling retry.", topic, e) self._schedule_delayed_retry(topic, payload, datapoints, delay=.1) def _schedule_delayed_retry(self, topic: str, payload: bytes, points: int, delay: float, - delivery_futures: Optional[List[Optional[asyncio.Future[bool]]]] = None): + delivery_futures: Optional[List[Optional[asyncio.Future[PublishResult]]]] = None): + if not self._active.is_set() or self._main_stop_event.is_set(): + logger.debug("MessageQueue is not active or main stop event is set. Not scheduling retry for topic=%s", topic) + return logger.trace("Scheduling retry: topic=%s, delay=%.2f", topic, delay) async def retry(): try: logger.debug("Retrying publish: topic=%s", topic) await asyncio.sleep(delay) + if not self._active.is_set() or self._main_stop_event.is_set(): + logger.debug("MessageQueue is not active or main stop event is set. Not re-enqueuing message for topic=%s", topic) + return self._queue.put_nowait((topic, payload, delivery_futures, points)) self._wakeup_event.set() logger.debug("Re-enqueued message after delay: topic=%s", topic) @@ -291,10 +297,13 @@ async def shutdown(self): self._active.clear() self._wakeup_event.set() # Wake up the _wait_for_message if it's blocked - for task in self._retry_tasks: - task.cancel() - with suppress(asyncio.CancelledError): - await asyncio.gather(*self._retry_tasks, return_exceptions=True) + for task in list(self._retry_tasks): + try: + task.cancel() + with suppress(asyncio.CancelledError): + await asyncio.gather(*self._retry_tasks, return_exceptions=True) + except Exception as e: + logger.warning("Error while cancelling retry task: %s", e) self._loop_task.cancel() self._rate_limit_refill_task.cancel() @@ -332,8 +341,8 @@ async def _rate_limit_refill_loop(self): try: while self._active.is_set(): await asyncio.sleep(1.0) - self._refill_rate_limits() - logger.debug("Rate limits refilled, state: %s", + await self._refill_rate_limits() + logger.trace("Rate limits refilled, state: %s", { "message_rate_limit": self._message_rate_limit.to_dict() if self._message_rate_limit else None, "telemetry_rate_limit": self._telemetry_rate_limit.to_dict() if self._telemetry_rate_limit else None, @@ -342,7 +351,7 @@ async def _rate_limit_refill_loop(self): except asyncio.CancelledError: pass - def _refill_rate_limits(self): + async def _refill_rate_limits(self): for rl in (self._message_rate_limit, self._telemetry_rate_limit, self._telemetry_dp_rate_limit): if rl: - rl.refill() + await rl.refill() diff --git a/tb_mqtt_client/service/message_splitter.py b/tb_mqtt_client/service/message_splitter.py index e76d838..a484c48 100644 --- a/tb_mqtt_client/service/message_splitter.py +++ b/tb_mqtt_client/service/message_splitter.py @@ -1,19 +1,20 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import asyncio from typing import List + from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage, DeviceUplinkMessageBuilder @@ -56,7 +57,7 @@ def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp for ts in message.timeseries.values(): exceeds_size = builder and size + ts.size > self._max_payload_size - exceeds_points = self._max_datapoints > 0 and point_count >= self._max_datapoints + exceeds_points = 0 < self._max_datapoints <= point_count if not builder or exceeds_size or exceeds_points: if builder: @@ -73,7 +74,7 @@ def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp size += ts.size point_count += 1 - if builder and builder._timeseries: + if builder and builder._timeseries: # noqa built = builder.build() result.append(built) batch_futures.extend(built.get_delivery_futures()) @@ -81,7 +82,7 @@ def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp if message.get_delivery_futures(): original_future = message.get_delivery_futures()[0] - logger.exception("Adding futures to original future: %s, futures ids: %r", id(original_future), + logger.trace("Adding futures to original future: %s, futures ids: %r", id(original_future), [id(batch_future) for batch_future in batch_futures]) async def resolve_original(): @@ -116,7 +117,7 @@ def split_attributes(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp for attr in message.attributes: exceeds_size = builder and size + attr.size > self._max_payload_size - exceeds_points = self._max_datapoints > 0 and point_count >= self._max_datapoints + exceeds_points = 0 < self._max_datapoints <= point_count if not builder or exceeds_size or exceeds_points: if builder: @@ -136,7 +137,7 @@ def split_attributes(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp size += attr.size point_count += 1 - if builder and builder._attributes: + if builder and builder._attributes: # noqa built = builder.build() result.append(built) batch_futures.extend(built.get_delivery_futures()) @@ -144,7 +145,7 @@ def split_attributes(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp if message.get_delivery_futures(): original_future = message.get_delivery_futures()[0] - logger.exception("Adding futures to original future: %s, futures ids: %r", id(original_future), + logger.trace("Adding futures to original future: %s, futures ids: %r", id(original_future), [id(batch_future) for batch_future in batch_futures]) 
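The split_timeseries and split_attributes hunks above rewrite `self._max_datapoints > 0 and point_count >= self._max_datapoints` as the equivalent chained comparison `0 < self._max_datapoints <= point_count`. The batching rule itself is unchanged: start a new batch once either the accumulated payload size or the datapoint cap would be exceeded, with a non-positive cap disabling the datapoint check. A self-contained sketch of that rule, using hypothetical names rather than the real MessageSplitter types:

from typing import List

def batch_entries(sizes: List[int], max_payload_size: int, max_datapoints: int) -> List[List[int]]:
    """Group entry sizes into batches, opening a new batch when either the
    accumulated payload size or the datapoint count would be exceeded.
    A max_datapoints of 0 or less disables the datapoint limit."""
    batches: List[List[int]] = []
    current: List[int] = []
    size = 0
    for entry_size in sizes:
        exceeds_size = current and size + entry_size > max_payload_size
        exceeds_points = 0 < max_datapoints <= len(current)
        if not current or exceeds_size or exceeds_points:
            if current:
                batches.append(current)
            current, size = [], 0
        current.append(entry_size)
        size += entry_size
    if current:
        batches.append(current)
    return batches

# Five 40-byte entries with a 100-byte payload cap and a 10-point cap
# split into [[40, 40], [40, 40], [40]].
print(batch_entries([40] * 5, max_payload_size=100, max_datapoints=10))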
async def resolve_original(): diff --git a/tb_mqtt_client/service/mqtt_manager.py b/tb_mqtt_client/service/mqtt_manager.py index ae34955..ba6ac07 100644 --- a/tb_mqtt_client/service/mqtt_manager.py +++ b/tb_mqtt_client/service/mqtt_manager.py @@ -1,49 +1,58 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # - +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import asyncio import ssl from asyncio import sleep -from typing import Optional, Callable, Awaitable, Dict, Union +from time import monotonic +from typing import Optional, Callable, Awaitable, Dict, Union, Tuple -from gmqtt import Client as GMQTTClient, Message, Subscription, MQTTConnectError +from gmqtt import Client as GMQTTClient, Message, Subscription -from tb_mqtt_client.common.request_id_generator import RPCRequestIdProducer -from tb_mqtt_client.common.gmqtt_patch import patch_gmqtt_puback, patch_gmqtt_protocol_connection_lost, patch_mqtt_handler_disconnect, DISCONNECT_REASON_CODES +from tb_mqtt_client.common.gmqtt_patch import patch_gmqtt_puback, patch_gmqtt_protocol_connection_lost, \ + patch_mqtt_handler_disconnect, DISCONNECT_REASON_CODES from tb_mqtt_client.common.logging_utils import get_logger -from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit from tb_mqtt_client.common.rate_limit.backpressure_controller import BackpressureController -from tb_mqtt_client.service.rpc_response_handler import RPCResponseHandler +from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit +from tb_mqtt_client.common.request_id_generator import RPCRequestIdProducer, AttributeRequestIdProducer from tb_mqtt_client.constants import mqtt_topics from tb_mqtt_client.constants.service_keys import MESSAGES_RATE_LIMIT, TELEMETRY_MESSAGE_RATE_LIMIT, \ TELEMETRY_DATAPOINTS_RATE_LIMIT from tb_mqtt_client.constants.service_messages import SESSION_LIMITS_REQUEST_MESSAGE +from tb_mqtt_client.entities.publish_result import PublishResult +from tb_mqtt_client.service.device.handlers.rpc_response_handler import RPCResponseHandler logger = get_logger(__name__) +QUOTA_EXCEEDED = 0x97 # MQTT 5 reason code (151) +IMPLEMENTATION_SPECIFIC_ERROR = 0x83 # MQTT 5 reason code (131) + class MQTTManager: + + _PUBLISH_TIMEOUT = 10.0 # Default timeout for publish operations + def __init__( self, client_id: str, + main_stop_event: asyncio.Event, on_connect: Optional[Callable[[], Awaitable[None]]] = None, on_disconnect: Optional[Callable[[], Awaitable[None]]] = None, rate_limits_handler: 
Optional[Callable[[str, bytes], Awaitable[None]]] = None, rpc_response_handler: Optional[RPCResponseHandler] = None, ): + self._main_stop_event = main_stop_event patch_gmqtt_protocol_connection_lost() patch_mqtt_handler_disconnect() @@ -63,12 +72,14 @@ def __init__( self._connect_params = None # Will be set in connect method self._handlers: Dict[str, Callable[[str, bytes], Awaitable[None]]] = {} - self._pending_publishes: Dict[int, asyncio.Future] = {} + self._pending_publishes: Dict[int, Tuple[asyncio.Future[PublishResult], str, int, int, float]] = {} + self._publish_monitor_task = asyncio.create_task(self._monitor_ack_timeouts()) + self._pending_subscriptions: Dict[int, asyncio.Future] = {} self._pending_unsubscriptions: Dict[int, asyncio.Future] = {} self._rpc_response_handler = rpc_response_handler or RPCResponseHandler() - self._backpressure = BackpressureController() + self._backpressure = BackpressureController(self._main_stop_event) self.__rate_limits_handler = rate_limits_handler self.__rate_limits_retrieved = False self.__rate_limiter: Optional[Dict[str, RateLimit]] = None @@ -111,7 +122,12 @@ def is_connected(self) -> bool: return self._client.is_connected and self._connected_event.is_set() and self.__rate_limits_retrieved async def disconnect(self): - await self._client.disconnect() + try: + await self._client.disconnect() + except ConnectionResetError: + logger.debug("Connection reset error during disconnect, ignoring.") + except Exception as e: + logger.error("Error during MQTT disconnect: %s", str(e)) await asyncio.sleep(0.2) self._connected_event.clear() self.__rate_limits_retrieved = False @@ -128,7 +144,8 @@ async def publish(self, message_or_topic: Union[str, Message], if not self.__rate_limits_retrieved and not self.__is_waiting_for_rate_limits_publish: raise RuntimeError("Cannot publish before rate limits are retrieved.") try: - await asyncio.wait_for(self._rate_limits_ready_event.wait(), timeout=10) + if not self._rate_limits_ready_event.is_set(): + await asyncio.wait_for(self._rate_limits_ready_event.wait(), timeout=10) except asyncio.TimeoutError: raise RuntimeError("Timeout waiting for rate limits.") @@ -141,13 +158,13 @@ async def publish(self, message_or_topic: Union[str, Message], else: message = Message(message_or_topic, payload, qos=qos, retain=retain) - mid, package = self._client._connection.publish(message) + mid, package = self._client._connection.publish(message) # noqa future = asyncio.get_event_loop().create_future() if qos > 0: - logger.debug("Publishing mid=%s, storing publish main future with id: %r", mid, id(future)) - self._pending_publishes[mid] = future - self._client._persistent_storage.push_message_nowait(mid, package) + logger.trace("Publishing mid=%s, storing publish main future with id: %r", mid, id(future)) + self._pending_publishes[mid] = (future, message.topic, message.qos, message.payload_size, monotonic()) + self._client._persistent_storage.push_message_nowait(mid, package) # noqa else: future.set_result(True) @@ -158,16 +175,16 @@ async def subscribe(self, topic: Union[str, Subscription], qos: int = 1) -> asyn subscription = Subscription(topic, qos=qos) if isinstance(topic, str) else topic if self.__rate_limiter: - self.__rate_limiter[MESSAGES_RATE_LIMIT].consume() - mid = self._client._connection.subscribe([subscription]) + await self.__rate_limiter[MESSAGES_RATE_LIMIT].consume() + mid = self._client._connection.subscribe([subscription]) # noqa self._pending_subscriptions[mid] = sub_future return sub_future async def unsubscribe(self, 
topic: str) -> asyncio.Future: unsubscribe_future = asyncio.get_event_loop().create_future() if self.__rate_limiter: - self.__rate_limiter[MESSAGES_RATE_LIMIT].consume() - mid = self._client._connection.unsubscribe(topic) + await self.__rate_limiter[MESSAGES_RATE_LIMIT].consume() + mid = self._client._connection.unsubscribe(topic) # noqa self._pending_unsubscriptions[mid] = unsubscribe_future return unsubscribe_future @@ -179,8 +196,9 @@ def unregister_handler(self, topic_filter: str): def _on_connect_internal(self, client, flags, rc, properties): logger.info("Connected to platform") + logger.debug("Connection flags: %s, reason code: %s, properties: %s", flags, rc, properties) if hasattr(client, '_connection'): - client._connection._on_disconnect_called = False + client._connection._on_disconnect_called = False # noqa self._connected_event.set() asyncio.create_task(self.__handle_connect_and_limits()) @@ -199,7 +217,7 @@ async def __handle_connect_and_limits(self): if self._on_connect_callback: await self._on_connect_callback() - def _on_disconnect_internal(self, client, reason_code=None, properties=None, exc=None): + def _on_disconnect_internal(self, client, reason_code=None, properties=None, exc=None): # noqa if reason_code is not None: reason_desc = DISCONNECT_REASON_CODES.get(reason_code, "Unknown reason") logger.info("Disconnected from platform with reason code: %s (%s)", reason_code, reason_desc) @@ -212,8 +230,23 @@ def _on_disconnect_internal(self, client, reason_code=None, properties=None, exc if exc: logger.warning("Disconnect exception: %s", exc) + for mid, (future, topic, qos, payload_size, publishing_time) in list(self._pending_publishes.items()): + if not future.done(): + publish_result = PublishResult( + topic=topic, + qos=qos, + payload_size=payload_size, + message_id=-1, + reason_code=reason_code or 0 + ) + future.set_result(publish_result) + logger.warning("Setting publish result for mid=%s: %r", mid, publish_result) + self._pending_publishes.clear() + RPCRequestIdProducer.reset() + AttributeRequestIdProducer.reset() self._rpc_response_handler.clear() + self._handlers.clear() self._connected_event.clear() self.__rate_limits_retrieved = False self.__is_waiting_for_rate_limits_publish = True @@ -236,6 +269,8 @@ def _on_disconnect_internal(self, client, reason_code=None, properties=None, exc asyncio.create_task(self._on_disconnect_callback()) def _on_message_internal(self, client, topic: str, payload: bytes, qos, properties): + logger.trace("Received message by client %r on topic %s with payload %r, qos %s, properties %s", + client, topic, payload, qos, properties) for topic_filter, handler in self._handlers.items(): if self._match_topic(topic_filter, topic): asyncio.create_task(handler(topic, payload)) @@ -245,23 +280,28 @@ def _on_message_internal(self, client, topic: str, payload: bytes, qos, properti def _on_publish_internal(self, client, mid): pass - # future = self._pending_publishes.pop(mid, None) - # if future and not future.done(): - # future.set_result(True) def _handle_puback_reason_code(self, mid: int, reason_code: int, properties: dict): - QUOTA_EXCEEDED = 0x97 # MQTT 5 reason code for quota exceeded - IMPLEMENTATION_SPECIFIC_ERROR = 0x83 # MQTT 5 reason code for implementation specific error (131) - - logger.debug("Handling PUBACK mid=%s with rc %r", mid, reason_code) - future = self._pending_publishes.pop(mid, None) - if future is None: + logger.trace("Handling PUBACK mid=%s with rc %r and properties: %r", + mid, reason_code, properties) + pending_future_data = 
self._pending_publishes.pop(mid, None) + if pending_future_data is None: logger.error("Missing future for mid=%s", mid) - elif future.done(): - logger.error("Future for mid=%s already resolved", mid) + return + future, topic, qos, payload_size, publishing_time = pending_future_data + publish_result = PublishResult( + topic=topic, + qos=qos, + payload_size=payload_size, + message_id=mid, + reason_code=reason_code + ) + logger.trace("Received result for publish future (id: %r): %r", id(future), publish_result) + if not future.done(): + future.set_result(publish_result) else: - logger.debug("Resolved future for mid=%s, object id: %r", mid, id(future)) - future.set_result(True) + logger.warning("Future (id: %r) for mid=%s was already done, skipping setting result", + id(future), mid) if reason_code == QUOTA_EXCEEDED: logger.warning("PUBACK received with QUOTA_EXCEEDED for mid=%s", mid) @@ -274,14 +314,17 @@ def _handle_puback_reason_code(self, mid: int, reason_code: int, properties: dic def _on_subscribe_internal(self, client, mid, qos, properties): + logger.trace("Received SUBACK by client %r for mid=%s with qos %s, properties %s", + client, mid, qos, properties) future = self._pending_subscriptions.pop(mid, None) if future and not future.done(): - future.set_result(True) + future.set_result(mid) def _on_unsubscribe_internal(self, client, mid): + logger.trace("Received UNSUBACK by client %r for mid=%s", client, mid) future = self._pending_unsubscriptions.pop(mid, None) if future and not future.done(): - future.set_result(True) + future.set_result(mid) async def await_ready(self, timeout: float = 10.0): try: @@ -305,7 +348,6 @@ def set_rate_limits( self._rate_limits_ready_event.set() async def __request_rate_limits(self): - # Set this flag at the beginning to prevent publishing before rate limits are retrieved self.__is_waiting_for_rate_limits_publish = True request_id = await RPCRequestIdProducer.get_next() @@ -346,8 +388,8 @@ def backpressure(self) -> BackpressureController: return self._backpressure @staticmethod - def _match_topic(filter: str, topic: str) -> bool: - filter_parts = filter.split('/') + def _match_topic(filter_expression: str, topic: str) -> bool: + filter_parts = filter_expression.split('/') topic_parts = topic.split('/') for i, filter_part in enumerate(filter_parts): @@ -359,3 +401,19 @@ def _match_topic(filter: str, topic: str) -> bool: return False return len(filter_parts) == len(topic_parts) + + async def _monitor_ack_timeouts(self): + while not self._main_stop_event.is_set(): + now = monotonic() + expired = [] + for mid, (future, topic, qos, payload_size, timestamp) in list(self._pending_publishes.items()): + if now - timestamp > self._PUBLISH_TIMEOUT: + if not future.done(): + logger.warning("Publish timeout: mid=%s, topic=%s", mid, topic) + result = PublishResult(topic, qos, payload_size, mid, reason_code=408) + future.set_result(result) + expired.append(mid) + for mid in expired: + self._pending_publishes.pop(mid, None) + # TODO: Add logic to handle expired futures, for subscriptions, rpc responses, etc. + await asyncio.sleep(1) diff --git a/tb_mqtt_client/service/rpc_response_handler.py b/tb_mqtt_client/service/rpc_response_handler.py deleted file mode 100644 index 762c2ae..0000000 --- a/tb_mqtt_client/service/rpc_response_handler.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2025. ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -from typing import Dict, Union - -from tb_mqtt_client.common.logging_utils import get_logger -from orjson import loads - -logger = get_logger(__name__) - - -class RPCResponseHandler: - """ - Handles RPC responses coming from the platform to the client (client-side RPCs responses). - Maintains an internal map of request_id -> asyncio.Future for awaiting RPC results. - """ - - def __init__(self): - self._pending_requests: Dict[Union[str, int], asyncio.Future] = {} - - def register_request(self, request_id: Union[str, int]) -> asyncio.Future: - """ - Called when a request is sent to the platform and a response is awaited. - """ - if request_id in self._pending_requests: - raise RuntimeError(f"Request ID {request_id} is already registered.") - future = asyncio.get_event_loop().create_future() - self._pending_requests[request_id] = future - return future - - async def handle(self, topic: str, payload: bytes): - """ - Handles the incoming RPC response from the platform and fulfills the corresponding future. - The topic is expected to be: v1/devices/me/rpc/response/{request_id} - """ - try: - request_id = topic.split("/")[-1] - response_data = loads(payload) - - future = self._pending_requests.pop(request_id, None) - if not future: - logger.warning("No future awaiting request ID %s. Ignoring.", request_id) - return - - if isinstance(response_data, dict) and "error" in response_data: - future.set_exception(Exception(response_data["error"])) - else: - future.set_result(response_data) - - except Exception as e: - logger.exception("Failed to handle RPC response: %s", e) - - def clear(self): - """ - Clears all pending futures (e.g. on disconnect). - """ - for fut in self._pending_requests.values(): - if not fut.done(): - fut.cancel() - self._pending_requests.clear() diff --git a/tb_mqtt_client/tb_device_mqtt.py b/tb_mqtt_client/tb_device_mqtt.py index 6d6dc0d..fa669aa 100644 --- a/tb_mqtt_client/tb_device_mqtt.py +++ b/tb_mqtt_client/tb_device_mqtt.py @@ -1,14 +1,14 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# http://www.apache.org/licenses/LICENSE-2.0 # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/__init__.py b/tests/__init__.py index 8d89b47..fa669aa 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,13 +1,14 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/constants/__init__.py b/tests/constants/__init__.py index 59e3bfa..fa669aa 100644 --- a/tests/constants/__init__.py +++ b/tests/constants/__init__.py @@ -1,14 +1,14 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/constants/test_mqtt_topics.py b/tests/constants/test_mqtt_topics.py index 293d4e5..6fbe7b0 100644 --- a/tests/constants/test_mqtt_topics.py +++ b/tests/constants/test_mqtt_topics.py @@ -1,18 +1,16 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # - +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from tb_mqtt_client.constants import mqtt_topics diff --git a/tests/service/__init__.py b/tests/service/__init__.py index 59e3bfa..fa669aa 100644 --- a/tests/service/__init__.py +++ b/tests/service/__init__.py @@ -1,14 +1,14 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/service/device/__init__.py b/tests/service/device/__init__.py index 59e3bfa..fa669aa 100644 --- a/tests/service/device/__init__.py +++ b/tests/service/device/__init__.py @@ -1,14 +1,14 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
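The MessageQueue hunks earlier in this patch switch rate-limit handling to awaited calls (`await ...try_consume(...)`, `await rl.refill()`), where a falsy return means the message may pass and a truthy return carries the `(limit, window_seconds)` pair that was hit. The rate-limit tests whose diff follows cover that path. The library's RateLimit class is not shown in this section, so the following is only an illustrative awaitable token-bucket sketch with hypothetical names, not the real implementation:

import asyncio
from typing import Optional, Tuple

class AsyncTokenBucket:
    """Minimal awaitable token bucket: try_consume returns None on success,
    or the (limit, window_seconds) pair that was hit, mirroring how the
    MessageQueue hunks interpret a truthy triggered_rate_limit."""

    def __init__(self, limit: int, window_seconds: float):
        self._limit = limit
        self._window = window_seconds
        self._tokens = float(limit)
        self._lock = asyncio.Lock()

    async def try_consume(self, amount: int = 1) -> Optional[Tuple[int, float]]:
        async with self._lock:
            if self._tokens >= amount:
                self._tokens -= amount
                return None
            return self._limit, self._window

    async def refill(self) -> None:
        # Intended to be called periodically (e.g. once per second) by a refill loop.
        async with self._lock:
            self._tokens = min(self._limit, self._tokens + self._limit / self._window)

async def demo():
    bucket = AsyncTokenBucket(limit=2, window_seconds=1.0)
    print(await bucket.try_consume())   # None -> allowed
    print(await bucket.try_consume())   # None -> allowed
    print(await bucket.try_consume())   # (2, 1.0) -> limit hit

asyncio.run(demo())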
+ diff --git a/tests/service/device/test_device_client_rate_limits.py b/tests/service/device/test_device_client_rate_limits.py index cd0a00d..bee445f 100644 --- a/tests/service/device/test_device_client_rate_limits.py +++ b/tests/service/device/test_device_client_rate_limits.py @@ -1,17 +1,16 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pytest from orjson import dumps diff --git a/tests/service/test_json_message_dispatcher.py b/tests/service/test_json_message_dispatcher.py index 9592616..e2575cc 100644 --- a/tests/service/test_json_message_dispatcher.py +++ b/tests/service/test_json_message_dispatcher.py @@ -1,18 +1,16 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # - +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import pytest from tb_mqtt_client.service.message_dispatcher import JsonMessageDispatcher @@ -32,7 +30,7 @@ def test_single_telemetry_dispatch(dispatcher): builder.add_telemetry(TimeseriesEntry("temp", 25)) msg = builder.build() - payloads = dispatcher.build_topic_payloads([msg]) + payloads = dispatcher.build_uplink_payloads([msg]) assert len(payloads) == 1 topic, payload, count = payloads[0] assert topic == DEVICE_TELEMETRY_TOPIC @@ -45,7 +43,7 @@ def test_single_attribute_dispatch(dispatcher): builder.add_attributes(AttributeEntry("mode", "auto")) msg = builder.build() - payloads = dispatcher.build_topic_payloads([msg]) + payloads = dispatcher.build_uplink_payloads([msg]) assert len(payloads) == 1 topic, payload, count = payloads[0] assert topic == DEVICE_ATTRIBUTES_TOPIC @@ -59,7 +57,7 @@ def test_multiple_devices_grouping(dispatcher): b2 = DeviceUplinkMessageBuilder().set_device_name("dev2") b2.add_telemetry(TimeseriesEntry("t2", 2)) - payloads = dispatcher.build_topic_payloads([b1.build(), b2.build()]) + payloads = dispatcher.build_uplink_payloads([b1.build(), b2.build()]) assert len(payloads) == 2 for topic, payload, count in payloads: assert topic == DEVICE_TELEMETRY_TOPIC @@ -71,7 +69,7 @@ def test_large_telemetry_split(dispatcher): for i in range(15): builder.add_telemetry(TimeseriesEntry(f"key{i}", i)) - payloads = dispatcher.build_topic_payloads([builder.build()]) + payloads = dispatcher.build_uplink_payloads([builder.build()]) assert len(payloads) > 1 for topic, payload, count in payloads: assert topic == DEVICE_TELEMETRY_TOPIC @@ -85,7 +83,7 @@ def test_large_attributes_split(): for i in range(20): builder.add_attributes(AttributeEntry(f"k{i}", "x" * 50)) # Increase size - payloads = dispatcher.build_topic_payloads([builder.build()]) + payloads = dispatcher.build_uplink_payloads([builder.build()]) assert len(payloads) > 1 # Now expect splitting for topic, payload, count in payloads: assert topic == DEVICE_ATTRIBUTES_TOPIC diff --git a/tests/service/test_message_splitter.py b/tests/service/test_message_splitter.py index 535b8c6..1a8ea3b 100644 --- a/tests/service/test_message_splitter.py +++ b/tests/service/test_message_splitter.py @@ -1,18 +1,16 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # - +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
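The dispatcher tests above switch from build_topic_payloads to build_uplink_payloads; each returned entry still unpacks as (topic, payload, count). A minimal usage sketch outside the test fixtures follows. The two-argument JsonMessageDispatcher(max_payload_size, max_datapoints) constructor is inferred from the device-client hunk later in this patch, the device name is arbitrary, and the payload is assumed to be orjson-encoded bytes:

from orjson import loads

from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder
from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry
from tb_mqtt_client.service.message_dispatcher import JsonMessageDispatcher

# Assumed constructor arguments: (max_payload_size, max_datapoints).
dispatcher = JsonMessageDispatcher(1000, 10)

builder = DeviceUplinkMessageBuilder().set_device_name("demo-device")
builder.add_telemetry(TimeseriesEntry("temperature", 23.5))

# Each entry is ready to be handed to the message queue for publishing.
for topic, payload, count in dispatcher.build_uplink_payloads([builder.build()]):
    print(topic, loads(payload), count)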
import pytest from tb_mqtt_client.service.message_splitter import MessageSplitter diff --git a/tests/service/test_mqtt_manager.py b/tests/service/test_mqtt_manager.py index c2364ab..4deb85e 100644 --- a/tests/service/test_mqtt_manager.py +++ b/tests/service/test_mqtt_manager.py @@ -1,18 +1,16 @@ -# Copyright 2025. ThingsBoard -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# http://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2025 ThingsBoard # - +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import asyncio import pytest diff --git a/tests/tb_device_mqtt_client_tests.py b/tests/tb_device_mqtt_client_tests.py index b8bef22..131403d 100644 --- a/tests/tb_device_mqtt_client_tests.py +++ b/tests/tb_device_mqtt_client_tests.py @@ -1,16 +1,16 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import unittest from time import sleep diff --git a/tests/tb_gateway_mqtt_client_tests.py b/tests/tb_gateway_mqtt_client_tests.py index 6893183..761fca3 100644 --- a/tests/tb_gateway_mqtt_client_tests.py +++ b/tests/tb_gateway_mqtt_client_tests.py @@ -1,16 +1,16 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import unittest from time import sleep, time diff --git a/utils.py b/utils.py index 8d6b1ec..1227418 100644 --- a/utils.py +++ b/utils.py @@ -1,16 +1,16 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
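Elsewhere in this patch the MQTTManager's _match_topic helper renames its `filter` parameter to `filter_expression` (avoiding the shadowed built-in) while keeping the segment-by-segment comparison against the single-level `+` wildcard. Since the full body is split across hunks here, the following is an equivalent standalone sketch rather than the library's exact implementation; the response-topic strings follow the `v1/devices/me/rpc/response/{request_id}` convention mentioned in the removed rpc_response_handler docstring:

WILDCARD = "+"  # single-level MQTT wildcard, as defined in constants/mqtt_topics.py

def match_topic(filter_expression: str, topic: str) -> bool:
    """Return True if the topic matches the filter, where '+' matches exactly one level."""
    filter_parts = filter_expression.split('/')
    topic_parts = topic.split('/')
    if len(filter_parts) != len(topic_parts):
        return False
    return all(fp == WILDCARD or fp == tp for fp, tp in zip(filter_parts, topic_parts))

assert match_topic("v1/devices/me/rpc/response/+", "v1/devices/me/rpc/response/42")
assert not match_topic("v1/devices/me/rpc/response/+", "v1/devices/me/attributes")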
from sys import executable from subprocess import check_call, CalledProcessError, DEVNULL From 54421993a1c0a85255cc64ae4f85fb0d90d278cd Mon Sep 17 00:00:00 2001 From: imbeacon Date: Fri, 30 May 2025 11:40:44 +0300 Subject: [PATCH 04/74] Added ability to use client-side RPC and rate limits retrieval refactoring --- examples/device/load.py | 8 +- examples/device/operational_example.py | 240 +++++++++++------- tb_mqtt_client/constants/mqtt_topics.py | 7 +- tb_mqtt_client/constants/service_messages.py | 23 -- .../entities/data/attribute_request.py | 65 +++-- .../entities/data/attribute_update.py | 2 +- tb_mqtt_client/entities/data/rpc_request.py | 60 +++-- tb_mqtt_client/entities/data/rpc_response.py | 19 +- tb_mqtt_client/service/device/client.py | 47 ++-- .../requested_attributes_response_handler.py | 8 +- .../device/handlers/rpc_response_handler.py | 50 ++-- tb_mqtt_client/service/message_dispatcher.py | 77 +++++- tb_mqtt_client/service/message_queue.py | 27 +- tb_mqtt_client/service/message_splitter.py | 52 ++-- tb_mqtt_client/service/mqtt_manager.py | 72 +++--- 15 files changed, 468 insertions(+), 289 deletions(-) delete mode 100644 tb_mqtt_client/constants/service_messages.py diff --git a/examples/device/load.py b/examples/device/load.py index 3e1cf51..593225c 100644 --- a/examples/device/load.py +++ b/examples/device/load.py @@ -41,20 +41,20 @@ async def attribute_update_callback(update: AttributeUpdate): - logger.info("Received attribute update: %s", update.as_dict()) + logger.info("Received attribute update: %r", update) async def rpc_request_callback(request: RPCRequest): - logger.info("Received RPC request: %s", request.to_dict()) + logger.info("Received RPC request: %r", request) return RPCResponse(request_id=request.request_id, result={"status": "ok"}) - async def main(): stop_event = asyncio.Event() def _shutdown_handler(): stop_event.set() + asyncio.get_event_loop().run_until_complete(client.stop()) loop = asyncio.get_running_loop() for sig in (signal.SIGINT, signal.SIGTERM): @@ -79,7 +79,7 @@ def _shutdown_handler(): delivered_datapoints = 0 pending_futures = [] - delivery_start_ts = None # Start time of first successful delivery + delivery_start_ts = None # Start time of the first successful delivery delivery_end_ts = None # End time of last successful delivery try: diff --git a/examples/device/operational_example.py b/examples/device/operational_example.py index c2f08c5..a5d2a1d 100644 --- a/examples/device/operational_example.py +++ b/examples/device/operational_example.py @@ -16,9 +16,11 @@ import logging import signal from datetime import datetime, UTC +from random import uniform, randint from tb_mqtt_client.common.config_loader import DeviceConfig from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.attribute_request import AttributeRequest from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse @@ -38,7 +40,7 @@ async def attribute_update_callback(update: AttributeUpdate): Callback function to handle attribute updates. :param update: The attribute update object. 
""" - logger.info("Received attribute update: %s", update.as_dict()) + logger.info("Received attribute update: %r", update) async def rpc_request_callback(request: RPCRequest): @@ -47,7 +49,7 @@ async def rpc_request_callback(request: RPCRequest): :param request: The RPC request object. :return: A RPC response object. """ - logger.info("Received RPC request: %s", request.to_dict()) + logger.info("Received RPC request: %r", request) response_data = { "status": "success", } @@ -56,6 +58,13 @@ async def rpc_request_callback(request: RPCRequest): error=None) return response +async def rpc_response_callback(response: RPCResponse): + """ + Callback function to handle RPC responses for client side RPC requests. + :param response: The RPC response object. + """ + logger.info("Received RPC response in callback: %r", response) + async def attribute_request_callback(requested_attributes_response: RequestedAttributeResponse): """ @@ -70,7 +79,7 @@ async def main(): def _shutdown_handler(): stop_event.set() - client.stop() + asyncio.get_event_loop().run_until_complete(client.stop()) loop = asyncio.get_running_loop() for sig in (signal.SIGINT, signal.SIGTERM): @@ -81,8 +90,6 @@ def _shutdown_handler(): signal.signal(sig, lambda *_: _shutdown_handler()) # noqa config = DeviceConfig() - # config.host = "192.168.1.202" - # config.access_token = "ypbn08v8f4klg6oah3r6" config.host = "localhost" config.access_token = "YOUR_ACCESS_TOKEN" @@ -93,89 +100,122 @@ def _shutdown_handler(): logger.info("Connected to ThingsBoard.") - attribute_request = AttributeRequest(["uno"], ["client"]) - while not stop_event.is_set(): - # # --- Attributes --- - # - # # 1. Raw dict - # raw_dict = { - # "firmwareVersion": "1.0.4", - # "hardwareModel": "TB-SDK-Device" - # } - # logger.info("Sending attributes...") - # delivery_future = await client.send_attributes(raw_dict) - # if delivery_future: - # logger.info("Awaiting delivery future for raw attributes...") - # result = await delivery_future - # # logger.info("Raw attributes sent: %s, delivery result: %s", raw_dict, result) - # else: - # logger.warning("Delivery future is None, raw attributes may not be sent.") - # - # # logger.info(f"Raw attributes sent: {raw_dict}") - # - # # 2. Single AttributeEntry - # single_entry = AttributeEntry("mode", "normal") - # logger.info("Sending single attribute: %s", single_entry) - # delivery_future = await client.send_attributes(single_entry) - # if delivery_future: - # logger.info("Awaiting delivery future for single attribute...") - # result = await delivery_future - # # logger.info("Single attribute sent: %s, delivery result: %s", single_entry, result) - # else: - # logger.warning("Delivery future is None, single attribute may not be sent.") - # - # # logger.info("Single attribute sent: %s", single_entry) - # - # # 3. List of AttributeEntry - # attr_entries = [ - # AttributeEntry("maxTemperature", 85), - # AttributeEntry("calibrated", True) - # ] - # logger.info("Sending list of attributes: %s", attr_entries) - # delivery_future = await client.send_attributes(attr_entries) - # if delivery_future: - # logger.info("Awaiting delivery future for list of attributes...") - # result = await delivery_future - # # logger.info("List of attributes sent: %s, delivery result: %s", attr_entries, result) - # else: - # logger.warning("Delivery future is None, list of attributes may not be sent.") - # - # # --- Telemetry --- - # - # # 1. 
Raw dict - # raw_dict = { - # "temperature": round(uniform(20.0, 30.0), 2), - # "humidity": 60 - # } - # logger.info("Sending raw telemetry...") - # delivery_future = await client.send_telemetry(raw_dict) - # if delivery_future: - # logger.info("Awaiting delivery future for raw telemetry...") - # result = await delivery_future - # # logger.info("Raw telemetry sent: %s, delivery result: %s", raw_dict, result) - # else: - # logger.warning("Delivery future is None, raw telemetry may not be sent.") - # - # # logger.info(f"Raw telemetry sent: {raw_dict}") - # - # # 2. Single TelemetryEntry (with ts) - # single_entry = TimeseriesEntry("batteryLevel", randint(0, 100)) - # logger.info("Sending single telemetry: %s", single_entry) - # delivery_future = await client.send_telemetry(single_entry) - # if delivery_future: - # logger.info("Awaiting delivery future for single telemetry...") - # result = await delivery_future - # # logger.info("Single telemetry sent: %s, delivery result: %s", single_entry, result) - # else: - # logger.warning("Delivery future is None, single telemetry may not be sent.") - # - # # logger.info("Single telemetry sent: %s", single_entry) + # --- Attributes --- + + # 1. Raw dict + raw_dict = { + "firmwareVersion": "1.0.4", + "hardwareModel": "TB-SDK-Device" + } + logger.info("Sending attributes...") + delivery_future = await client.send_attributes(raw_dict) + if delivery_future: + logger.info("Awaiting delivery future for raw attributes...") + try: + result = await asyncio.wait_for(delivery_future, timeout=5) + except asyncio.TimeoutError: + logger.warning("Delivery future timed out after 5 seconds.") + result = False + except Exception as e: + logger.error("Error while awaiting delivery future: %s", e) + result = False + logger.info("Raw attributes sent: %s, delivery result: %s", raw_dict, result) + else: + logger.warning("Delivery future is None, raw attributes may not be sent.") + + # logger.info(f"Raw attributes sent: {raw_dict}") + + # 2. Single AttributeEntry + single_entry = AttributeEntry("mode", "normal") + logger.info("Sending single attribute: %s", single_entry) + delivery_future = await client.send_attributes(single_entry) + if delivery_future: + logger.info("Awaiting delivery future for single attribute...") + try: + result = await asyncio.wait_for(delivery_future, timeout=5) + except asyncio.TimeoutError: + logger.warning("Delivery future timed out after 5 seconds.") + result = False + except Exception as e: + logger.error("Error while awaiting delivery future: %s", e) + result = False + logger.info("Single attribute sent: %s, delivery result: %s", single_entry, result) + else: + logger.warning("Delivery future is None, single attribute may not be sent.") + + logger.info("Single attribute sent: %s", single_entry) + + # 3. 
List of AttributeEntry + attr_entries = [ + AttributeEntry("maxTemperature", 85), + AttributeEntry("calibrated", True) + ] + logger.info("Sending list of attributes: %s", attr_entries) + delivery_future = await client.send_attributes(attr_entries) + if delivery_future: + logger.info("Awaiting delivery future for list of attributes...") + try: + result = await asyncio.wait_for(delivery_future, timeout=5) + except asyncio.TimeoutError: + logger.warning("Delivery future timed out after 5 seconds.") + result = False + except Exception as e: + logger.error("Error while awaiting delivery future: %s", e) + result = False + logger.info("List of attributes sent: %s, delivery result: %s", attr_entries, result) + else: + logger.warning("Delivery future is None, list of attributes may not be sent.") + + # --- Telemetry --- + + # 1. Raw dict + raw_dict = { + "temperature": round(uniform(20.0, 30.0), 2), + "humidity": 60 + } + logger.info("Sending raw telemetry...") + delivery_future = await client.send_telemetry(raw_dict) + if delivery_future: + logger.info("Awaiting delivery future for raw telemetry...") + try: + result = await asyncio.wait_for(delivery_future, timeout=5) + except asyncio.TimeoutError: + logger.warning("Delivery future timed out after 5 seconds.") + result = False + except Exception as e: + logger.error("Error while awaiting delivery future: %s", e) + result = False + logger.info("Raw telemetry sent: %s, delivery result: %s", raw_dict, result) + else: + logger.warning("Delivery future is None, raw telemetry may not be sent.") + + logger.info(f"Raw telemetry sent: {raw_dict}") + + # 2. Single TelemetryEntry (with ts) + single_entry = TimeseriesEntry("batteryLevel", randint(0, 100)) + logger.info("Sending single telemetry: %s", single_entry) + delivery_future = await client.send_telemetry(single_entry) + if delivery_future: + logger.info("Awaiting delivery future for single telemetry...") + try: + result = await asyncio.wait_for(delivery_future, timeout=5) + except asyncio.TimeoutError: + logger.warning("Delivery future timed out after 5 seconds.") + result = False + except Exception as e: + logger.error("Error while awaiting delivery future: %s", e) + result = False + logger.info("Single telemetry sent: %s, delivery result: %s", single_entry, result) + else: + logger.warning("Delivery future is None, single telemetry may not be sent.") + + logger.info("Single telemetry sent: %s", single_entry) # 3. 
List of TelemetryEntry with mixed timestamps telemetry_entries = [] - for i in range(100): + for i in range(1): telemetry_entries.append(TimeseriesEntry("temperature", i, ts=int(datetime.now(UTC).timestamp() * 1000)-i)) ts_now = int(datetime.now(UTC).timestamp() * 1000) logger.info("Sending list of telemetry entries with mixed timestamps...") @@ -196,9 +236,37 @@ def _shutdown_handler(): else: logger.warning("Delivery future is None, list of telemetry may not be sent.") - # logger.info("Requesting attributes...") + logger.info("Requesting attributes...") + + attribute_request = await AttributeRequest.build(["uno"], ["client"]) + + logger.info("Sending attribute request: %r", attribute_request) + + await client.send_attribute_request(attribute_request, attribute_request_callback) + + logger.info("Sending client side RPC request...") + + rpc_request = await RPCRequest.build("getSomeInformation", {"key1": "value1"}) + + logger.info("Sending RPC request: %r", rpc_request) + + response_future = await client.send_rpc_request(rpc_request) + + if response_future: + logger.info("Awaiting RPC response future...") + try: + response = await asyncio.wait_for(response_future, timeout=5) + logger.info("RPC response received: %s", response) + except asyncio.TimeoutError: + logger.warning("RPC response future timed out after 5 seconds.") + except Exception as e: + logger.error("Error while awaiting RPC response future: %s", e) + + rpc_request_2 = await RPCRequest.build("getAnotherInformation", {"param": "value"}) + + logger.info("Sending another RPC request: %r", rpc_request_2) - # await client.send_attribute_request(attribute_request, attribute_request_callback) + await client.send_rpc_request(rpc_request_2, rpc_response_callback) try: logger.info("Waiting for 1 seconds before next iteration...") @@ -212,7 +280,7 @@ def _shutdown_handler(): if __name__ == "__main__": try: loop = asyncio.get_event_loop() - loop.set_debug(True) # Enable debug mode for asyncio + loop.set_debug(False) # Enable debug mode for asyncio loop.run_until_complete(main()) except KeyboardInterrupt: print("Interrupted by user.") diff --git a/tb_mqtt_client/constants/mqtt_topics.py b/tb_mqtt_client/constants/mqtt_topics.py index 378d2c0..5955816 100644 --- a/tb_mqtt_client/constants/mqtt_topics.py +++ b/tb_mqtt_client/constants/mqtt_topics.py @@ -1,6 +1,3 @@ -WILDCARD = "+" -REQUEST_TOPIC_SUFFIX = "/request" -RESPONSE_TOPIC_SUFFIX = "/response" # Copyright 2025 ThingsBoard # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +WILDCARD = "+" +REQUEST_TOPIC_SUFFIX = "/request" +RESPONSE_TOPIC_SUFFIX = "/response" + # V1 Topics for Device API DEVICE_TELEMETRY_TOPIC = "v1/devices/me/telemetry" DEVICE_ATTRIBUTES_TOPIC = "v1/devices/me/attributes" diff --git a/tb_mqtt_client/constants/service_messages.py b/tb_mqtt_client/constants/service_messages.py deleted file mode 100644 index daaa25d..0000000 --- a/tb_mqtt_client/constants/service_messages.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from orjson import dumps - - -SESSION_LIMITS_REQUEST_MESSAGE = dumps( - { - "method": "getSessionLimits", - "params": {}, - } -) diff --git a/tb_mqtt_client/entities/data/attribute_request.py b/tb_mqtt_client/entities/data/attribute_request.py index 54254ef..fdcc77e 100644 --- a/tb_mqtt_client/entities/data/attribute_request.py +++ b/tb_mqtt_client/entities/data/attribute_request.py @@ -12,53 +12,46 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from dataclasses import dataclass +from typing import Optional, List, Dict +from tb_mqtt_client.common.request_id_generator import AttributeRequestIdProducer +@dataclass(slots=True, frozen=True) class AttributeRequest: """ - Represents a request for attributes, including shared and client attributes. - This class is used to encapsulate the details of an attribute request. - If some scope is not needed, it can be set to None. - shared: list - A list of shared attribute keys to the request. If empty - all shared attributes will be requested. - client: list - A list of client attribute keys to the request. If empty - all client attributes will be requested. + Represents a request for device attributes, with optional client and shared attribute keys. + Automatically assigns a unique request ID via the build() method. """ + request_id: int + shared_keys: Optional[List[str]] = None + client_keys: Optional[List[str]] = None - def __init__(self, shared: list, client: list): - self._id: Union[int, None] = None - self.shared_keys = shared - self.client_keys = client + def __new__(self, *args, **kwargs): + raise TypeError("Direct instantiation of AttributeRequest is not allowed. Use 'await AttributeRequest.build(...)'.") - @property - def id(self) -> Union[int, None]: - """ - Get the unique ID for this attribute request. - :return: Unique identifier for the request or None if not set. - """ - return self._id + def __repr__(self) -> str: + return f"" - @id.setter - def id(self, value: int): + @classmethod + async def build(cls, shared_keys: Optional[List[str]] = None, client_keys: Optional[List[str]] = None) -> 'AttributeRequest': """ - Set the unique ID for this attribute request. - :param value: Unique identifier to set for the request. + Build a new AttributeRequest with a unique request ID, using the global ID generator. """ - if not isinstance(value, int): - raise ValueError("ID must be an integer.") - self._id = value - - def __repr__(self): - return f"" - - def to_payload_format(self) -> dict: + request_id = await AttributeRequestIdProducer.get_next() + self = object.__new__(cls) + object.__setattr__(self, 'request_id', request_id) + object.__setattr__(self, 'shared_keys', shared_keys) + object.__setattr__(self, 'client_keys', client_keys) + return self + + def to_payload_format(self) -> Dict[str, str]: """ - Convert the attribute request to a payload format suitable for sending over MQTT to the platform. + Convert the attribute request into the expected MQTT payload format. 
""" - formatted_request = {} + payload = {} if self.shared_keys is not None: - formatted_request["sharedKeys"] = ','.join(self.shared_keys) + payload["sharedKeys"] = ','.join(self.shared_keys) if self.client_keys is not None: - formatted_request["clientKeys"] = ','.join(self.client_keys) - return formatted_request + payload["clientKeys"] = ','.join(self.client_keys) + return payload diff --git a/tb_mqtt_client/entities/data/attribute_update.py b/tb_mqtt_client/entities/data/attribute_update.py index 2ac3deb..0f5fd18 100644 --- a/tb_mqtt_client/entities/data/attribute_update.py +++ b/tb_mqtt_client/entities/data/attribute_update.py @@ -44,7 +44,7 @@ def as_dict(self) -> Dict[str, Any]: return {entry.key: entry.value for entry in self.entries} @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'AttributeUpdate': + def _deserialize_from_dict(cls, data: Dict[str, Any]) -> 'AttributeUpdate': """ Deserialize dictionary into AttributeUpdate object. :param data: Dictionary of attribute key-value pairs. diff --git a/tb_mqtt_client/entities/data/rpc_request.py b/tb_mqtt_client/entities/data/rpc_request.py index 18ce603..61841fb 100644 --- a/tb_mqtt_client/entities/data/rpc_request.py +++ b/tb_mqtt_client/entities/data/rpc_request.py @@ -15,6 +15,8 @@ from dataclasses import dataclass from typing import Union, Optional, Dict, Any +from tb_mqtt_client.common.request_id_generator import RPCRequestIdProducer + @dataclass(slots=True, frozen=True) class RPCRequest: @@ -22,24 +24,50 @@ class RPCRequest: method: str params: Optional[Any] = None - def to_dict(self) -> Dict[str, Any]: - result = { - "id": self.request_id, - "method": self.method - } - if self.params is not None: - result["params"] = self.params - return result + def __new__(cls, *args, **kwargs): + raise TypeError("Direct instantiation of RPCRequest is not allowed. Use 'await RPCRequest.build(...)'.") + + def __repr__(self): + return f"" @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'RPCRequest': - if "id" not in data: - raise ValueError("Missing 'id' in RPC request") + def _deserialize_from_dict(cls, request_id: int, data: Dict[str, Any]) -> 'RPCRequest': + """ + Constructs an RPCRequest, should be used only for deserialization request from the platform. + """ + if not isinstance(request_id, (int, str)): + raise ValueError("Missing request id in RPC request") if "method" not in data: raise ValueError("Missing 'method' in RPC request") - return cls( - request_id=data["id"], - method=data["method"], - params=data.get("params") - ) + self = object.__new__(cls) + object.__setattr__(self, 'request_id', request_id) + object.__setattr__(self, 'method', data["method"]) + object.__setattr__(self, 'params', data.get("params")) + return self + + @classmethod + async def build(cls, method: str, params: Optional[Any] = None) -> 'RPCRequest': + """ + Constructs an RPCRequest with a unique request ID, + using the RPCRequestIdProducer to ensure thread-safe ID generation. + """ + request_id = await RPCRequestIdProducer.get_next() + self = object.__new__(cls) + object.__setattr__(self, 'request_id', request_id) + object.__setattr__(self, 'method', method) + object.__setattr__(self, 'params', params) + return self + + def to_payload_format(self) -> Dict[str, Any]: + """ + Serializes the RPC request for publishing. + Converts the request to a dictionary format suitable for MQTT payload. 
+ """ + data = { + "id": self.request_id, + "method": self.method + } + if self.params is not None: + data["params"] = self.params + return data diff --git a/tb_mqtt_client/entities/data/rpc_response.py b/tb_mqtt_client/entities/data/rpc_response.py index c4b0c20..5a31a14 100644 --- a/tb_mqtt_client/entities/data/rpc_response.py +++ b/tb_mqtt_client/entities/data/rpc_response.py @@ -19,7 +19,7 @@ @dataclass(slots=True, frozen=True) class RPCResponse: """ - Represents a response to a device-side RPC call. + Represents a response to the RPC call. Attributes: request_id: Unique identifier of the request being responded to. @@ -30,6 +30,23 @@ class RPCResponse: result: Optional[Any] = None error: Optional[Union[str, Dict[str, Any]]] = None + def __new__(cls, *args, **kwargs): + raise TypeError("Direct instantiation of RPCResponse is not allowed. Use RPCResponse.build(request_id, result, error).") + + def __repr__(self) -> str: + return f"RPCResponse(request_id={self.request_id}, result={self.result}, error={self.error})" + + @classmethod + def build(cls, request_id: Union[int, str], result: Optional[Any] = None, error: Optional[Union[str, Dict[str, Any]]] = None) -> 'RPCResponse': + """ + Constructs an RPCResponse explicitly. + """ + self = object.__new__(cls) + object.__setattr__(self, 'request_id', request_id) + object.__setattr__(self, 'result', result) + object.__setattr__(self, 'error', error) + return self + def to_payload_format(self) -> Dict[str, Any]: """Serializes the RPC response for publishing.""" data = {} diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index 09dc7bd..a73c4a2 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -63,7 +63,7 @@ def __init__(self, config: Optional[Union[DeviceConfig, Dict]] = None): super().__init__(self._config.host, self._config.port, client_id) self._message_queue: Optional[MessageQueue] = None - self._message_dispatcher: Optional[MessageDispatcher] = None + self._message_dispatcher: MessageDispatcher = JsonMessageDispatcher(1000, 1) # Will be updated after connection established self._messages_rate_limit = RateLimit("0:0,", name="messages") self._telemetry_rate_limit = RateLimit("0:0,", name="telemetry") @@ -77,6 +77,7 @@ def __init__(self, config: Optional[Union[DeviceConfig, Dict]] = None): self._mqtt_manager = MQTTManager(client_id=self._client_id, main_stop_event=self._stop_event, + message_dispatcher=self._message_dispatcher, on_connect=self._on_connect, on_disconnect=self._on_disconnect, rate_limits_handler=self._handle_rate_limit_response, @@ -116,7 +117,8 @@ async def connect(self): self.max_payload_size = 65535 logger.debug("Using default max_payload_size: %d", self.max_payload_size) - self._message_dispatcher = JsonMessageDispatcher(self.max_payload_size, self._telemetry_dp_rate_limit.minimal_limit) + self._message_dispatcher = JsonMessageDispatcher(self.max_payload_size, + self._telemetry_dp_rate_limit.minimal_limit) self._message_queue = MessageQueue( mqtt_manager=self._mqtt_manager, main_stop_event=self._stop_event, @@ -132,14 +134,14 @@ async def connect(self): self._rpc_requests_handler.set_message_dispatcher(self._message_dispatcher) self._rpc_response_handler.set_message_dispatcher(self._message_dispatcher) - def stop(self): + async def stop(self): """ Stops the client and disconnects from the MQTT broker. 
""" logger.info("Stopping DeviceClient...") self._stop_event.set() if self._mqtt_manager.is_connected(): - self._mqtt_manager.disconnect() + await self._mqtt_manager.disconnect() logger.info("DeviceClient stopped.") async def disconnect(self): @@ -165,18 +167,30 @@ async def send_attributes(self, attributes: Union[Dict[str, Any], AttributeEntry datapoints_count=message.attributes_datapoint_count()) return futures[0] if futures else None + async def send_rpc_request(self, rpc_request: RPCRequest, + callback: Optional[Callable[[RPCResponse], Awaitable[None]]] = None) -> Awaitable[RPCResponse]: + request_id = rpc_request.request_id or await RPCRequestIdProducer.get_next() + topic, payload = self._message_dispatcher.build_rpc_request_payload(rpc_request) + + response_future = self._rpc_response_handler.register_request(request_id, callback) + + await self._message_queue.publish(topic=topic, + payload=payload, + datapoints_count=0) + return response_future + async def send_rpc_response(self, response: RPCResponse): - topic = mqtt_topics.build_device_rpc_response_topic(request_id=response.request_id) - payload = self._message_dispatcher.build_rpc_response_payload(response) + topic, payload = self._message_dispatcher.build_rpc_response_payload(response) await self._message_queue.publish(topic=topic, payload=payload, datapoints_count=0) - async def send_attribute_request(self, attribute_request: AttributeRequest, callback: Callable[[RequestedAttributeResponse], Awaitable[None]]): + async def send_attribute_request(self, + attribute_request: AttributeRequest, + callback: Callable[[RequestedAttributeResponse], Awaitable[None]]): await self._requested_attribute_response_handler.register_request(attribute_request, callback) - topic = mqtt_topics.build_device_attributes_request_topic(attribute_request.id) - payload = self._message_dispatcher.build_attribute_request_payload(attribute_request) + topic, payload = self._message_dispatcher.build_attribute_request_payload(attribute_request) await self._message_queue.publish(topic=topic, payload=payload, @@ -198,8 +212,8 @@ async def _on_connect(self): self._mqtt_manager.register_handler(mqtt_topics.DEVICE_ATTRIBUTES_TOPIC, self._handle_attribute_update) self._mqtt_manager.register_handler(mqtt_topics.DEVICE_RPC_REQUEST_TOPIC_FOR_SUBSCRIPTION, self._handle_rpc_request) # noqa - self._mqtt_manager.register_handler(mqtt_topics.DEVICE_RPC_RESPONSE_TOPIC_FOR_SUBSCRIPTION, self._handle_rpc_response) # noqa self._mqtt_manager.register_handler(mqtt_topics.DEVICE_ATTRIBUTES_RESPONSE_TOPIC, self._handle_requested_attribute_response) # noqa + # RPC responses are handled by the RPCResponseHandler, which is already registered async def _on_disconnect(self): logger.info("Device client disconnected.") @@ -243,22 +257,21 @@ async def _handle_rpc_response(self, topic: str, payload: bytes): async def _handle_requested_attribute_response(self, topic: str, payload: bytes): await self._requested_attribute_response_handler.handle(topic, payload) - async def _handle_rate_limit_response(self, topic: str, payload: bytes): # noqa + async def _handle_rate_limit_response(self, response: RPCResponse): # noqa try: - response = loads(payload.decode("utf-8")) logger.debug("Received rate limit response payload: %s", response) - if not isinstance(response, dict) or 'rateLimits' not in response: + if not isinstance(response.result, dict) or 'rateLimits' not in response.result: logger.warning("Invalid rate limit response: %r", response) return None - rate_limits = response.get('rateLimits', 
{}) + rate_limits = response.result.get('rateLimits', {}) await self._messages_rate_limit.set_limit(rate_limits.get("messages", "0:0,")) await self._telemetry_rate_limit.set_limit(rate_limits.get("telemetryMessages", "0:0,")) await self._telemetry_dp_rate_limit.set_limit(rate_limits.get("telemetryDataPoints", "0:0,")) - server_inflight = int(response.get("maxInflightMessages", 100)) + server_inflight = int(response.result.get("maxInflightMessages", 100)) limits = [rl.minimal_limit for rl in [ self._messages_rate_limit, self._telemetry_rate_limit @@ -272,8 +285,8 @@ async def _handle_rate_limit_response(self, topic: str, payload: bytes): # noqa if self._max_inflight_messages == 0: self._max_inflight_messages = 10000 - if "maxPayloadSize" in response: - self.max_payload_size = int(response["maxPayloadSize"] * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) + if "maxPayloadSize" in response.result: + self.max_payload_size = int(response.result["maxPayloadSize"] * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) # Update the dispatcher's max_payload_size if it's already initialized if hasattr(self, '_dispatcher') and self._message_dispatcher is not None: self._message_dispatcher.splitter.max_payload_size = self.max_payload_size diff --git a/tb_mqtt_client/service/device/handlers/requested_attributes_response_handler.py b/tb_mqtt_client/service/device/handlers/requested_attributes_response_handler.py index 7c0d0c3..249078f 100644 --- a/tb_mqtt_client/service/device/handlers/requested_attributes_response_handler.py +++ b/tb_mqtt_client/service/device/handlers/requested_attributes_response_handler.py @@ -46,10 +46,10 @@ async def register_request(self, request: AttributeRequest, callback: Callable[[ """ Called when a request is sent to the platform and a response is awaited. """ - request.id = await AttributeRequestIdProducer.get_next() - if request.id in self._pending_attribute_requests: - raise RuntimeError(f"Request ID {request.id} is already registered.") - self._pending_attribute_requests[request.id] = (request, callback) + request_id = request.request_id + if request_id in self._pending_attribute_requests: + raise RuntimeError(f"Request ID {request.request_id} is already registered.") + self._pending_attribute_requests[request.request_id] = (request, callback) def unregister_request(self, request_id: int): """ diff --git a/tb_mqtt_client/service/device/handlers/rpc_response_handler.py b/tb_mqtt_client/service/device/handlers/rpc_response_handler.py index d99d0f3..9008b67 100644 --- a/tb_mqtt_client/service/device/handlers/rpc_response_handler.py +++ b/tb_mqtt_client/service/device/handlers/rpc_response_handler.py @@ -13,12 +13,11 @@ # limitations under the License. 
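The _handle_rate_limit_response coroutine shown above consumes the session-limits reply as an RPCResponse whose result carries a "rateLimits" object plus optional inflight and payload-size hints. A minimal sketch of a result dict that would exercise that path; the "count:seconds," pair format in the limit strings is an assumption inferred from the "0:0," defaults used elsewhere in the patch:

    # Illustrative result payload for the getSessionLimits RPC, using the keys
    # read by DeviceClient._handle_rate_limit_response.
    example_session_limits_result = {
        "rateLimits": {
            "messages": "100:1,1000:60,",        # hypothetical limits
            "telemetryMessages": "50:1,",
            "telemetryDataPoints": "200:1,",
        },
        "maxInflightMessages": 100,
        "maxPayloadSize": 65535,
    }
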
import asyncio -from typing import Dict, Union - -from orjson import loads +from typing import Dict, Union, Awaitable, Callable, Optional, Tuple from tb_mqtt_client.common.logging_utils import get_logger -from tb_mqtt_client.service.message_dispatcher import MessageDispatcher +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.service.message_dispatcher import MessageDispatcher, JsonMessageDispatcher logger = get_logger(__name__) @@ -30,8 +29,10 @@ class RPCResponseHandler: """ def __init__(self): - self._message_dispatcher = None - self._pending_rpc_requests: Dict[Union[str, int], asyncio.Future] = {} + self._message_dispatcher: Optional[MessageDispatcher] = None + self._pending_rpc_requests: Dict[Union[str, int], + Tuple[asyncio.Future[RPCResponse], + Optional[Callable[[RPCResponse], Awaitable[None]]]]] = {} def set_message_dispatcher(self, message_dispatcher: MessageDispatcher): """ @@ -44,14 +45,15 @@ def set_message_dispatcher(self, message_dispatcher: MessageDispatcher): self._message_dispatcher = message_dispatcher logger.debug("Message dispatcher set for RPCResponseHandler.") - def register_request(self, request_id: Union[str, int]) -> asyncio.Future: + def register_request(self, request_id: Union[str, int], + callback: Optional[Callable[[RPCResponse], Awaitable[None]]] = None) -> asyncio.Future[RPCResponse]: """ Called when a request is sent to the platform and a response is awaited. """ if request_id in self._pending_rpc_requests: raise RuntimeError(f"Request ID {request_id} is already registered.") future = asyncio.get_event_loop().create_future() - self._pending_rpc_requests[request_id] = future + self._pending_rpc_requests[request_id] = future, callback return future async def handle(self, topic: str, payload: bytes): @@ -60,19 +62,33 @@ async def handle(self, topic: str, payload: bytes): The topic is expected to be: v1/devices/me/rpc/response/{request_id} """ try: - # TODO: Use MessageDispatcher to parse the topic and payload - request_id = topic.split("/")[-1] - response_data = loads(payload) + if not self._message_dispatcher: + dummy_dispatcher = JsonMessageDispatcher() + rpc_response = dummy_dispatcher.parse_rpc_response(topic, payload) + else: + rpc_response = self._message_dispatcher.parse_rpc_response(topic, payload) - future = self._pending_rpc_requests.pop(request_id, None) + request_details = self._pending_rpc_requests.pop(rpc_response.request_id, None) + if not request_details: + logger.warning("No pending request found for request ID %s. Ignoring response.", + rpc_response.request_id) + return + future, callback = request_details if not future: - logger.warning("No future awaiting request ID %s. Ignoring.", request_id) + logger.warning("No future awaiting request ID %s. Ignoring.", rpc_response.request_id) return + if callback: + try: + await callback(rpc_response) + except Exception as e: + logger.exception("Error in callback for request ID %s: %s", rpc_response.request_id, e) + future.set_exception(e) + return - if isinstance(response_data, dict) and "error" in response_data: - future.set_exception(Exception(response_data["error"])) + if rpc_response.error: + future.set_exception(Exception(rpc_response.error)) else: - future.set_result(response_data) + future.set_result(rpc_response) except Exception as e: logger.exception("Failed to handle RPC response: %s", e) @@ -81,7 +97,7 @@ def clear(self): """ Clears all pending futures (e.g., on disconnect). 
""" - for fut in self._pending_rpc_requests.values(): + for fut, _ in self._pending_rpc_requests.values(): if not fut.done(): fut.cancel() self._pending_rpc_requests.clear() diff --git a/tb_mqtt_client/service/message_dispatcher.py b/tb_mqtt_client/service/message_dispatcher.py index 9e242b7..6786c73 100644 --- a/tb_mqtt_client/service/message_dispatcher.py +++ b/tb_mqtt_client/service/message_dispatcher.py @@ -48,12 +48,13 @@ def build_uplink_payloads( ) -> List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]]: """ Build a list of topic-payload pairs from the given messages. - Each pair consists of a topic string and a payload byte array. + Each pair consists of a topic string, payload bytes, the number of datapoints, + and a list of futures for delivery confirmation. """ pass @abstractmethod - def build_attribute_request_payload(self, request: AttributeRequest) -> bytes: + def build_attribute_request_payload(self, request: AttributeRequest) -> Tuple[str, bytes]: """ Build the payload for an attribute request response. This method should return a tuple of topic and payload bytes. @@ -61,7 +62,15 @@ def build_attribute_request_payload(self, request: AttributeRequest) -> bytes: pass @abstractmethod - def build_rpc_response_payload(self, rpc_response: RPCResponse): + def build_rpc_request_payload(self, rpc_request: RPCRequest) -> Tuple[str, bytes]: + """ + Build the payload for an RPC request. + This method should return a tuple of topic and payload bytes. + """ + pass + + @abstractmethod + def build_rpc_response_payload(self, rpc_response: RPCResponse) -> Tuple[str, bytes]: """ Build the payload for an RPC response. This method should return a tuple of topic and payload bytes. @@ -99,8 +108,19 @@ def parse_rpc_request(self, topic: str, payload: bytes) -> RPCRequest: """ pass + @abstractmethod + def parse_rpc_response(self, topic: str, payload: bytes) -> RPCResponse: + """ + Parse the RPC response from the given topic and payload. + This method should be implemented to handle the specific format of the RPC response. + """ + pass + class JsonMessageDispatcher(MessageDispatcher): + """ + A concrete implementation of MessageDispatcher that operates with JSON payloads. + """ def __init__(self, max_payload_size: Optional[int] = None, max_datapoints: Optional[int] = None): super().__init__(max_payload_size, max_datapoints) logger.trace("JsonMessageDispatcher created.") @@ -134,7 +154,7 @@ def parse_attribute_update(self, payload: bytes) -> AttributeUpdate: try: data = loads(payload) logger.trace("Parsing attribute update from payload: %s", data) - return AttributeUpdate.from_dict(data) + return AttributeUpdate._deserialize_from_dict(data) except Exception as e: logger.error("Failed to parse attribute update: %s", str(e)) raise ValueError("Invalid attribute update format") from e @@ -149,18 +169,38 @@ def parse_rpc_request(self, topic: str, payload: bytes) -> RPCRequest: try: request_id = int(topic.split("/")[-1]) parsed = loads(payload) - parsed["id"] = request_id - data = RPCRequest.from_dict(parsed) + data = RPCRequest._deserialize_from_dict(request_id, parsed) # noqa return data except Exception as e: logger.error("Failed to parse RPC request: %s", str(e)) raise ValueError("Invalid RPC request format") from e + def parse_rpc_response(self, topic: str, payload: bytes) -> RPCResponse: + """ + Parse the RPC response from the given topic and payload. + :param topic: The MQTT topic of the RPC response. + :param payload: The raw bytes of the payload. 
+ :return: An instance of RPCResponse. + """ + try: + request_id = int(topic.split("/")[-1]) + parsed = loads(payload) + data = RPCResponse.build(request_id, parsed) # noqa + return data + except Exception as e: + logger.error("Failed to parse RPC response: %s", str(e)) + raise ValueError("Invalid RPC response format") from e + @property def splitter(self) -> MessageSplitter: return self._splitter def build_uplink_payloads(self, messages: List[DeviceUplinkMessage]) -> List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]]: + """ + Build a list of topic-payload pairs from the given messages. + Each pair consists of a topic string, payload bytes, the number of datapoints, + and a list of futures for delivery confirmation. + """ try: if not messages: logger.trace("No messages to process in build_topic_payloads.") @@ -194,26 +234,43 @@ def build_uplink_payloads(self, messages: List[DeviceUplinkMessage]) -> List[Tup result.append((DEVICE_ATTRIBUTES_TOPIC, payload, count, attr_batch.get_delivery_futures())) logger.trace("Built attribute payload for device='%s' with %d attributes", device, count) - logger.trace("Generated %d topic-payload entries.", len(result)) return result except Exception as e: logger.error("Error building topic-payloads: %s", str(e)) + logger.debug("Exception details: %s", e, exc_info=True) raise - def build_attribute_request_payload(self, request: AttributeRequest) -> bytes: + def build_attribute_request_payload(self, request: AttributeRequest) -> Tuple[str, bytes]: """ Build the payload for an attribute request response. :param request: The AttributeRequest to build the payload for. :return: A tuple of topic and payload bytes. """ - if not request.id: + if not request.request_id: raise ValueError("AttributeRequest must have a valid ID.") + topic = mqtt_topics.build_device_attributes_request_topic(request.request_id) payload = dumps(request.to_payload_format()) logger.trace("Built attribute request payload for request: %r", request) - return payload + return topic, payload + + def build_rpc_request_payload(self, rpc_request: RPCRequest) -> Tuple[str, bytes]: + """ + Build the payload for an RPC request. + :param rpc_request: The RPC request to build the payload for. + :return: A tuple of topic and payload bytes. + """ + if not rpc_request.request_id: + raise ValueError("RPCRequest must have a valid ID.") + + payload = dumps(rpc_request.to_payload_format()) + topic = mqtt_topics.DEVICE_RPC_REQUEST_TOPIC + str(rpc_request.request_id) + logger.trace("Built RPC request payload for request ID=%d with payload: %r", + rpc_request.request_id, payload) + return topic, payload + def build_rpc_response_payload(self, rpc_response: RPCResponse) -> Tuple[str, bytes]: """ diff --git a/tb_mqtt_client/service/message_queue.py b/tb_mqtt_client/service/message_queue.py index 49f0aa2..c185d62 100644 --- a/tb_mqtt_client/service/message_queue.py +++ b/tb_mqtt_client/service/message_queue.py @@ -74,14 +74,15 @@ async def publish(self, topic: str, payload: Union[bytes, DeviceUplinkMessage], logger.error("Message queue full. 
Dropping message for topic %s", topic) for future in payload.get_delivery_futures(): if future: - future.set_result(False) + future.set_result(PublishResult(topic, self.__qos, -1, len(payload), -1)) return delivery_futures or None async def _dequeue_loop(self): logger.debug("MessageQueue dequeue loop started.") while self._active.is_set(): try: - topic, payload, delivery_futures_or_none, count = await asyncio.wait_for(asyncio.get_event_loop().create_task(self._queue.get()), timeout=self._BATCH_TIMEOUT) + # topic, payload, delivery_futures_or_none, count = await asyncio.wait_for(asyncio.get_event_loop().create_task(self._queue.get()), timeout=self._BATCH_TIMEOUT) + topic, payload, delivery_futures_or_none, count = await self._wait_for_message() logger.trace("MessageQueue dequeue: topic=%s, payload=%r, count=%d", topic, payload, count) if isinstance(payload, bytes): @@ -109,7 +110,7 @@ async def _dequeue_loop(self): logger.trace("Dequeued message for batching: topic=%s, device=%s", topic, getattr(payload, 'device_name', 'N/A')) - batch: List[Tuple[str, Union[bytes, DeviceUplinkMessage], asyncio.Future[PublishResult], int]] = [(topic, payload, delivery_futures_or_none, count)] + batch: List[Tuple[str, Union[bytes, DeviceUplinkMessage], List[asyncio.Future[PublishResult]], int]] = [(topic, payload, delivery_futures_or_none, count)] start = asyncio.get_event_loop().time() batch_size = payload.size @@ -189,7 +190,11 @@ async def _try_publish(self, logger.debug("Telemetry datapoint rate limit hit for topic %s: %r per %r seconds", topic, triggered_rate_limit[0], triggered_rate_limit[1]) retry_delay = self._telemetry_dp_rate_limit.minimal_timeout - self._schedule_delayed_retry(topic, payload, datapoints, delay=retry_delay, delivery_futures=delivery_futures_or_none) + self._schedule_delayed_retry(topic=topic, + payload=payload, + points=datapoints, + delay=retry_delay, + delivery_futures=delivery_futures_or_none) return else: # For non-telemetry messages, we only need to check the message rate limit @@ -227,7 +232,11 @@ def resolve_attached(publish_future: asyncio.Future): i, id(f), publish_result, id(publish_future), publish_future) logger.trace("Adding done callback to main publish future: %r, main publish future done state: %r", id(mqtt_future), mqtt_future.done()) - mqtt_future.add_done_callback(resolve_attached) + if mqtt_future.done(): + logger.debug("Main publish future is already done, resolving immediately.") + resolve_attached(mqtt_future) + else: + mqtt_future.add_done_callback(resolve_attached) except Exception as e: logger.warning("Failed to publish to topic %s: %s. 
Scheduling retry.", topic, e) self._schedule_delayed_retry(topic, payload, datapoints, delay=.1) @@ -258,7 +267,7 @@ async def retry(): self._retry_tasks.add(task) task.add_done_callback(self._retry_tasks.discard) - async def _wait_for_message(self) -> Tuple[str, Union[bytes, DeviceUplinkMessage], int]: + async def _wait_for_message(self) -> Tuple[str, Union[bytes, DeviceUplinkMessage], List[asyncio.Future[PublishResult]], int]: while self._active.is_set(): try: if not self._queue.empty(): @@ -276,16 +285,16 @@ async def _wait_for_message(self) -> Tuple[str, Union[bytes, DeviceUplinkMessage ) for task in pending: - logger.debug("Cancelling pending task: %r, it is queue_task = %r", task, queue_task==task) + logger.trace("Cancelling pending task: %r, it is queue_task = %r", task, queue_task==task) task.cancel() with suppress(asyncio.CancelledError): await task if queue_task in done: - logger.debug("Retrieved message from queue: %r", queue_task.result()) + logger.trace("Retrieved message from queue: %r", queue_task.result()) return queue_task.result() - await asyncio.sleep(0) + await asyncio.sleep(0) # Yield control to the event loop except asyncio.CancelledError: break diff --git a/tb_mqtt_client/service/message_splitter.py b/tb_mqtt_client/service/message_splitter.py index a484c48..61267ec 100644 --- a/tb_mqtt_client/service/message_splitter.py +++ b/tb_mqtt_client/service/message_splitter.py @@ -37,9 +37,9 @@ def __init__(self, max_payload_size: int = 65535, max_datapoints: int = 0): def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUplinkMessage]: logger.trace("Splitting timeseries for %d messages", len(messages)) - if (len(messages) == 1 and - messages[0].attributes_datapoint_count() + messages[0].timeseries_datapoint_count() <= self._max_datapoints and - messages[0].size <= self._max_payload_size) or self._max_datapoints == 0: + if (len(messages) == 1 + and ((messages[0].attributes_datapoint_count() + messages[0].timeseries_datapoint_count() <= self._max_datapoints) or self._max_datapoints == 0) # noqa + and messages[0].size <= self._max_payload_size): return messages result: List[DeviceUplinkMessage] = [] @@ -55,24 +55,25 @@ def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp point_count = 0 batch_futures = [] - for ts in message.timeseries.values(): - exceeds_size = builder and size + ts.size > self._max_payload_size - exceeds_points = 0 < self._max_datapoints <= point_count - - if not builder or exceeds_size or exceeds_points: - if builder: - built = builder.build() - result.append(built) - batch_futures.extend(built.get_delivery_futures()) - logger.trace("Flushed batch with %d points (size=%d)", len(built.timeseries), size) - builder = DeviceUplinkMessageBuilder().set_device_name(message.device_name).set_device_profile( - message.device_profile) - size = 0 - point_count = 0 - - builder.add_telemetry(ts) - size += ts.size - point_count += 1 + for grouped_ts in message.timeseries.values(): + for ts in grouped_ts: + exceeds_size = builder and size + ts.size > self._max_payload_size + exceeds_points = 0 < self._max_datapoints <= point_count + + if not builder or exceeds_size or exceeds_points: + if builder: + built = builder.build() + result.append(built) + batch_futures.extend(built.get_delivery_futures()) + logger.trace("Flushed batch with %d points (size=%d)", len(built.timeseries), size) + builder = DeviceUplinkMessageBuilder().set_device_name(message.device_name).set_device_profile( + message.device_profile) + size = 0 + 
point_count = 0 + + builder.add_telemetry(ts) + size += ts.size + point_count += 1 if builder and builder._timeseries: # noqa built = builder.build() @@ -86,7 +87,8 @@ def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp [id(batch_future) for batch_future in batch_futures]) async def resolve_original(): - logger.exception("Resolving original future with batch futures: %s", [id(f) for f in batch_futures]) + logger.trace("Resolving original future with batch futures: %r, %s", + batch_futures, [id(f) for f in batch_futures]) results = await asyncio.gather(*batch_futures, return_exceptions=False) original_future.set_result(all(results)) @@ -99,9 +101,9 @@ def split_attributes(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp logger.trace("Splitting attributes for %d messages", len(messages)) result: List[DeviceUplinkMessage] = [] - if (len(messages) == 1 and - messages[0].attributes_datapoint_count() <= self._max_datapoints and - messages[0].size <= self._max_payload_size): + if (len(messages) == 1 + and ((messages[0].attributes_datapoint_count() + messages[0].timeseries_datapoint_count() <= self._max_datapoints) or self._max_datapoints == 0) # noqa + and messages[0].size <= self._max_payload_size): return messages for message in messages: diff --git a/tb_mqtt_client/service/mqtt_manager.py b/tb_mqtt_client/service/mqtt_manager.py index ba6ac07..9b0e5f7 100644 --- a/tb_mqtt_client/service/mqtt_manager.py +++ b/tb_mqtt_client/service/mqtt_manager.py @@ -15,6 +15,7 @@ import asyncio import ssl from asyncio import sleep +from contextlib import suppress from time import monotonic from typing import Optional, Callable, Awaitable, Dict, Union, Tuple @@ -29,9 +30,11 @@ from tb_mqtt_client.constants import mqtt_topics from tb_mqtt_client.constants.service_keys import MESSAGES_RATE_LIMIT, TELEMETRY_MESSAGE_RATE_LIMIT, \ TELEMETRY_DATAPOINTS_RATE_LIMIT -from tb_mqtt_client.constants.service_messages import SESSION_LIMITS_REQUEST_MESSAGE +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse from tb_mqtt_client.entities.publish_result import PublishResult from tb_mqtt_client.service.device.handlers.rpc_response_handler import RPCResponseHandler +from tb_mqtt_client.service.message_dispatcher import MessageDispatcher logger = get_logger(__name__) @@ -47,12 +50,14 @@ def __init__( self, client_id: str, main_stop_event: asyncio.Event, + message_dispatcher: MessageDispatcher, on_connect: Optional[Callable[[], Awaitable[None]]] = None, on_disconnect: Optional[Callable[[], Awaitable[None]]] = None, - rate_limits_handler: Optional[Callable[[str, bytes], Awaitable[None]]] = None, + rate_limits_handler: Optional[Callable[[RPCResponse], Awaitable[None]]] = None, rpc_response_handler: Optional[RPCResponseHandler] = None, ): self._main_stop_event = main_stop_event + self._message_dispatcher = message_dispatcher patch_gmqtt_protocol_connection_lost() patch_mqtt_handler_disconnect() @@ -78,6 +83,7 @@ def __init__( self._pending_subscriptions: Dict[int, asyncio.Future] = {} self._pending_unsubscriptions: Dict[int, asyncio.Future] = {} self._rpc_response_handler = rpc_response_handler or RPCResponseHandler() + self.register_handler(mqtt_topics.DEVICE_RPC_RESPONSE_TOPIC_FOR_SUBSCRIPTION, self._rpc_response_handler.handle) self._backpressure = BackpressureController(self._main_stop_event) self.__rate_limits_handler = rate_limits_handler @@ -218,6 +224,7 @@ async def __handle_connect_and_limits(self): await 
self._on_connect_callback() def _on_disconnect_internal(self, client, reason_code=None, properties=None, exc=None): # noqa + self._connected_event.clear() if reason_code is not None: reason_desc = DISCONNECT_REASON_CODES.get(reason_code, "Unknown reason") logger.info("Disconnected from platform with reason code: %s (%s)", reason_code, reason_desc) @@ -247,7 +254,6 @@ def _on_disconnect_internal(self, client, reason_code=None, properties=None, exc AttributeRequestIdProducer.reset() self._rpc_response_handler.clear() self._handlers.clear() - self._connected_event.clear() self.__rate_limits_retrieved = False self.__is_waiting_for_rate_limits_publish = True self._rate_limits_ready_event.clear() @@ -269,14 +275,12 @@ def _on_disconnect_internal(self, client, reason_code=None, properties=None, exc asyncio.create_task(self._on_disconnect_callback()) def _on_message_internal(self, client, topic: str, payload: bytes, qos, properties): - logger.trace("Received message by client %r on topic %s with payload %r, qos %s, properties %s", + logger.trace("Received message by client %r on topic %s with payload %r, qos %r, properties %r", client, topic, payload, qos, properties) for topic_filter, handler in self._handlers.items(): if self._match_topic(topic_filter, topic): asyncio.create_task(handler(topic, payload)) return - if topic.startswith(mqtt_topics.DEVICE_RPC_RESPONSE_TOPIC): - asyncio.create_task(self._rpc_response_handler.handle(topic, payload)) def _on_publish_internal(self, client, mid): pass @@ -350,27 +354,14 @@ def set_rate_limits( async def __request_rate_limits(self): self.__is_waiting_for_rate_limits_publish = True - request_id = await RPCRequestIdProducer.get_next() - request_topic = f"v1/devices/me/rpc/request/{request_id}" - response_topic = f"v1/devices/me/rpc/response/{request_id}" + logger.debug("Publishing rate limits request to server...") - logger.debug("Publishing rate limits request to: %s", request_topic) - response_future = self._rpc_response_handler.register_request(request_id) - - async def _handler(topic: str, payload: bytes): - try: - if self.__rate_limits_handler: - await self.__rate_limits_handler(topic, payload) - response_future.set_result(payload) - except Exception as e: - logger.debug("Error handling rate limits response: %s", e) - response_future.set_exception(e) - - self.register_handler(response_topic, _handler) + request = await RPCRequest.build("getSessionLimits") + topic, payload = self._message_dispatcher.build_rpc_request_payload(request) + response_future = self._rpc_response_handler.register_request(request.request_id, self.__rate_limits_handler) try: - logger.debug("Requesting rate limits via RPC...") - await self.publish(request_topic, SESSION_LIMITS_REQUEST_MESSAGE, qos=1, force=True) + await self.publish(topic, payload, qos=1, force=True) await asyncio.wait_for(response_future, timeout=10) logger.info("Successfully processed rate limits.") self.__rate_limits_retrieved = True @@ -380,8 +371,6 @@ async def _handler(topic: str, payload: bytes): logger.warning("Timeout while waiting for rate limits.") # Keep __is_waiting_for_rate_limits_publish as True to prevent publishing # until rate limits are retrieved - finally: - self.unregister_handler(response_topic) @property def backpressure(self) -> BackpressureController: @@ -405,15 +394,24 @@ def _match_topic(filter_expression: str, topic: str) -> bool: async def _monitor_ack_timeouts(self): while not self._main_stop_event.is_set(): now = monotonic() - expired = [] - for mid, (future, topic, qos, 
payload_size, timestamp) in list(self._pending_publishes.items()): - if now - timestamp > self._PUBLISH_TIMEOUT: - if not future.done(): - logger.warning("Publish timeout: mid=%s, topic=%s", mid, topic) - result = PublishResult(topic, qos, payload_size, mid, reason_code=408) - future.set_result(result) - expired.append(mid) - for mid in expired: - self._pending_publishes.pop(mid, None) + await self.check_pending_publishes(now) # TODO: Add logic to handle expired futures, for subscriptions, rpc responses, etc. - await asyncio.sleep(1) + await asyncio.sleep(0.1) + await self.check_pending_publishes(monotonic()) + + async def check_pending_publishes(self, time_to_check): + expired = [] + for mid, (future, topic, qos, payload_size, timestamp) in list(self._pending_publishes.items()): + if self._main_stop_event.is_set(): + with suppress(asyncio.CancelledError): + future.cancel() + continue + if time_to_check - timestamp > self._PUBLISH_TIMEOUT: + if not future.done(): + logger.warning("Publish timeout: mid=%s, topic=%s", mid, topic) + result = PublishResult(topic, qos, payload_size, mid, reason_code=408) + future.set_result(result) + expired.append(mid) + for mid in expired: + self._pending_publishes.pop(mid, None) + From e25c785d4700de8b9c120a3c25a215c667bf302e Mon Sep 17 00:00:00 2001 From: imbeacon Date: Fri, 30 May 2025 12:52:08 +0300 Subject: [PATCH 05/74] Added values validation, improved entities representation --- examples/device/operational_example.py | 8 ++- tb_mqtt_client/common/config_loader.py | 6 +- tb_mqtt_client/constants/json_typing.py | 43 +++++++++++ tb_mqtt_client/constants/service_keys.py | 4 ++ .../entities/data/attribute_entry.py | 5 +- .../entities/data/attribute_request.py | 5 +- .../entities/data/attribute_update.py | 2 +- tb_mqtt_client/entities/data/data_entry.py | 8 ++- .../entities/data/device_uplink_message.py | 71 ++++++++++++------- .../data/requested_attribute_response.py | 2 +- tb_mqtt_client/entities/data/rpc_request.py | 5 +- tb_mqtt_client/entities/data/rpc_response.py | 4 ++ .../entities/data/timeseries_entry.py | 5 +- tb_mqtt_client/entities/publish_result.py | 2 +- tb_mqtt_client/service/device/client.py | 23 ++++++ .../requested_attributes_response_handler.py | 5 +- 16 files changed, 154 insertions(+), 44 deletions(-) create mode 100644 tb_mqtt_client/constants/json_typing.py diff --git a/examples/device/operational_example.py b/examples/device/operational_example.py index a5d2a1d..c754977 100644 --- a/examples/device/operational_example.py +++ b/examples/device/operational_example.py @@ -71,7 +71,7 @@ async def attribute_request_callback(requested_attributes_response: RequestedAtt Callback function to handle requested attributes. :param requested_attributes_response: The requested attribute response object. """ - logger.info("Received requested attributes response: %s", requested_attributes_response.as_dict()) + logger.info("Received requested attributes response: %r", requested_attributes_response) async def main(): @@ -124,7 +124,7 @@ def _shutdown_handler(): else: logger.warning("Delivery future is None, raw attributes may not be sent.") - # logger.info(f"Raw attributes sent: {raw_dict}") + logger.info(f"Raw attributes sent: {raw_dict}") # 2. 
Single AttributeEntry single_entry = AttributeEntry("mode", "normal") @@ -236,6 +236,8 @@ def _shutdown_handler(): else: logger.warning("Delivery future is None, list of telemetry may not be sent.") + # --- Attribute Request --- + logger.info("Requesting attributes...") attribute_request = await AttributeRequest.build(["uno"], ["client"]) @@ -244,6 +246,8 @@ def _shutdown_handler(): await client.send_attribute_request(attribute_request, attribute_request_callback) + # --- Client-side RPC --- + logger.info("Sending client side RPC request...") rpc_request = await RPCRequest.build("getSomeInformation", {"key1": "value1"}) diff --git a/tb_mqtt_client/common/config_loader.py b/tb_mqtt_client/common/config_loader.py index ca3b1a2..f034403 100644 --- a/tb_mqtt_client/common/config_loader.py +++ b/tb_mqtt_client/common/config_loader.py @@ -41,10 +41,10 @@ def use_tls(self) -> bool: return self.ca_cert is not None def __repr__(self): - return (f"") + f"client_id={self.client_id} " + f"tls={self.use_tls()})") class GatewayConfig(DeviceConfig): diff --git a/tb_mqtt_client/constants/json_typing.py b/tb_mqtt_client/constants/json_typing.py new file mode 100644 index 0000000..3e98ff7 --- /dev/null +++ b/tb_mqtt_client/constants/json_typing.py @@ -0,0 +1,43 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union, List, Dict + +JSONPrimitive = Union[str, int, float, bool, None] +JSONCompatibleType = Union[JSONPrimitive, List["JSONType"], Dict[str, "JSONType"]] + +def validate_json_compatibility(value: object) -> None: + """ + Validates that the input value is fully JSON-compatible in structure and type. + Raises a ValueError on the first incompatible type encountered. 
+ """ + stack = [(value, "$")] + basic_types = (str, int, float, bool, type(None)) + + while stack: + current, path = stack.pop() + t = type(current) + + if t in basic_types: + continue + if t is list: + for idx, item in enumerate(current): # type: ignore[arg-type] + stack.append((item, f"{path}[{idx}]")) + elif t is dict: + for k, v in current.items(): # type: ignore[union-attr] + if type(k) is not str: + raise ValueError(f"Invalid JSON key at {path}: expected str, got {type(k).__name__} ({k!r})") + stack.append((v, f"{path}.{k}")) + else: + raise ValueError(f"Invalid JSON value at {path}: unsupported - {type(current).__name__} ({current!r})") \ No newline at end of file diff --git a/tb_mqtt_client/constants/service_keys.py b/tb_mqtt_client/constants/service_keys.py index befebaa..ac0568e 100644 --- a/tb_mqtt_client/constants/service_keys.py +++ b/tb_mqtt_client/constants/service_keys.py @@ -15,3 +15,7 @@ MESSAGES_RATE_LIMIT = "MESSAGES_RATE_LIMIT" TELEMETRY_MESSAGE_RATE_LIMIT = "TELEMETRY_MESSAGE_RATE_LIMIT" TELEMETRY_DATAPOINTS_RATE_LIMIT = "TELEMETRY_DATAPOINTS_RATE_LIMIT" + + +TELEMETRY_TIMESTAMP_PARAMETER = "ts" +TELEMETRY_VALUES_PARAMETER = "values" diff --git a/tb_mqtt_client/entities/data/attribute_entry.py b/tb_mqtt_client/entities/data/attribute_entry.py index fce1c3b..b145f62 100644 --- a/tb_mqtt_client/entities/data/attribute_entry.py +++ b/tb_mqtt_client/entities/data/attribute_entry.py @@ -14,15 +14,16 @@ from typing import Any +from tb_mqtt_client.constants.json_typing import JSONCompatibleType from tb_mqtt_client.entities.data.data_entry import DataEntry class AttributeEntry(DataEntry): - def __init__(self, key: str, value: Any): + def __init__(self, key: str, value: JSONCompatibleType): super().__init__(key, value) def __repr__(self): - return f"" + return f"AttributeEntry(key={self.key}, value={self.value})" def as_dict(self) -> dict: return { diff --git a/tb_mqtt_client/entities/data/attribute_request.py b/tb_mqtt_client/entities/data/attribute_request.py index fdcc77e..24c9c8c 100644 --- a/tb_mqtt_client/entities/data/attribute_request.py +++ b/tb_mqtt_client/entities/data/attribute_request.py @@ -15,6 +15,7 @@ from dataclasses import dataclass from typing import Optional, List, Dict from tb_mqtt_client.common.request_id_generator import AttributeRequestIdProducer +from tb_mqtt_client.constants.json_typing import validate_json_compatibility @dataclass(slots=True, frozen=True) @@ -31,13 +32,15 @@ def __new__(self, *args, **kwargs): raise TypeError("Direct instantiation of AttributeRequest is not allowed. Use 'await AttributeRequest.build(...)'.") def __repr__(self) -> str: - return f"" + return f"AttributeRequest(id={self.request_id}, shared_keys={self.shared_keys}, client_keys={self.client_keys})" @classmethod async def build(cls, shared_keys: Optional[List[str]] = None, client_keys: Optional[List[str]] = None) -> 'AttributeRequest': """ Build a new AttributeRequest with a unique request ID, using the global ID generator. 
""" + validate_json_compatibility(shared_keys) + validate_json_compatibility(client_keys) request_id = await AttributeRequestIdProducer.get_next() self = object.__new__(cls) object.__setattr__(self, 'request_id', request_id) diff --git a/tb_mqtt_client/entities/data/attribute_update.py b/tb_mqtt_client/entities/data/attribute_update.py index 0f5fd18..416d5f1 100644 --- a/tb_mqtt_client/entities/data/attribute_update.py +++ b/tb_mqtt_client/entities/data/attribute_update.py @@ -23,7 +23,7 @@ class AttributeUpdate: entries: List[AttributeEntry] def __repr__(self): - return f"" + return f"AttributeUpdate(entries={self.entries})" def get(self, key: str, default=None): for entry in self.entries: diff --git a/tb_mqtt_client/entities/data/data_entry.py b/tb_mqtt_client/entities/data/data_entry.py index 5d0f0ac..641cbcf 100644 --- a/tb_mqtt_client/entities/data/data_entry.py +++ b/tb_mqtt_client/entities/data/data_entry.py @@ -15,14 +15,20 @@ from typing import Any, Optional from orjson import dumps +from tb_mqtt_client.constants.json_typing import JSONCompatibleType, validate_json_compatibility + class DataEntry: - def __init__(self, key: str, value: Any, ts: Optional[int] = None): + def __init__(self, key: str, value: JSONCompatibleType, ts: Optional[int] = None): + validate_json_compatibility(value) self.__key = key self.__value = value self.__ts = ts self.__size = self.__estimate_size() + def __repr__(self): + return f"DataEntry(key={self.key}, value={self.value}, ts={self.ts})" + def __estimate_size(self) -> int: if self.ts is not None: return len(dumps({"ts": self.ts, "values": {self.key: self.value}})) diff --git a/tb_mqtt_client/entities/data/device_uplink_message.py b/tb_mqtt_client/entities/data/device_uplink_message.py index e13e596..d4e33f6 100644 --- a/tb_mqtt_client/entities/data/device_uplink_message.py +++ b/tb_mqtt_client/entities/data/device_uplink_message.py @@ -13,7 +13,9 @@ # limitations under the License. import asyncio -from typing import List, Optional, Union, OrderedDict +from dataclasses import dataclass +from types import MappingProxyType +from typing import List, Optional, Union, OrderedDict, Tuple, Mapping from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry @@ -25,27 +27,48 @@ DEFAULT_FIELDS_SIZE = len('{"device_name":"","device_profile":"","attributes":"","timeseries":""}'.encode('utf-8')) +@dataclass(slots=True, frozen=True) class DeviceUplinkMessage: - def __init__(self, - device_name: Optional[str] = None, - device_profile: Optional[str] = None, - attributes: Optional[List[AttributeEntry]] = None, - timeseries: Optional[OrderedDict[int, List[TimeseriesEntry]]] = None, - _size: Optional[int] = None, - delivery_future: List[Optional[asyncio.Future[PublishResult]]] = None): - if _size is None: - raise ValueError("DeviceUplinkMessage must be created using DeviceUplinkMessageBuilder") - - self.device_name = device_name - self.device_profile = device_profile - self.attributes = attributes or [] - self.timeseries = timeseries or [] - self.__size = _size - self.delivery_futures = delivery_future or [] + device_name: Optional[str] + device_profile: Optional[str] + attributes: Tuple[AttributeEntry, ...] + timeseries: Mapping[int, Tuple[TimeseriesEntry, ...]] + delivery_futures: Tuple[Optional[asyncio.Future], ...] + _size: int + + def __new__(cls, *args, **kwargs): + raise TypeError("Direct instantiation of DeviceUplinkMessage is not allowed. 
Use DeviceUplinkMessageBuilder to construct instances.") + + def __repr__(self): + return (f"DeviceUplinkMessage(device_name={self.device_name}, " + f"device_profile={self.device_profile}, " + f"attributes={self.attributes}, " + f"timeseries={self.timeseries}, " + f"delivery_futures={self.delivery_futures})") + + @classmethod + def build(cls, + device_name: Optional[str], + device_profile: Optional[str], + attributes: List[AttributeEntry], + timeseries: Mapping[int, List[TimeseriesEntry]], + delivery_futures: List[Optional[asyncio.Future]], + size: int) -> 'DeviceUplinkMessage': + self = object.__new__(cls) + object.__setattr__(self, 'device_name', device_name) + object.__setattr__(self, 'device_profile', device_profile) + object.__setattr__(self, 'attributes', tuple(attributes)) + object.__setattr__(self, 'timeseries', MappingProxyType({ts: tuple(entries) for ts, entries in timeseries.items()})) + object.__setattr__(self, 'delivery_futures', tuple(delivery_futures)) + object.__setattr__(self, '_size', size) + return self + @property + def size(self) -> int: + return self._size def timeseries_datapoint_count(self) -> int: - return len(self.timeseries) + return sum(len(entries) for entries in self.timeseries.values()) def attributes_datapoint_count(self) -> int: return len(self.attributes) @@ -56,13 +79,9 @@ def has_attributes(self) -> bool: def has_timeseries(self) -> bool: return bool(self.timeseries) - def get_delivery_futures(self): + def get_delivery_futures(self) -> Tuple[Optional[asyncio.Future], ...]: return self.delivery_futures - @property - def size(self) -> int: - return self.__size - class DeviceUplinkMessageBuilder: def __init__(self): @@ -125,11 +144,11 @@ def add_delivery_futures(self, futures: Union[asyncio.Future[PublishResult], Lis def build(self) -> DeviceUplinkMessage: if not self._delivery_futures: self._delivery_futures = [asyncio.get_event_loop().create_future()] - return DeviceUplinkMessage( + return DeviceUplinkMessage.build( device_name=self._device_name, device_profile=self._device_profile, attributes=self._attributes, timeseries=self._timeseries, - _size=self.__size, - delivery_future=self._delivery_futures + delivery_futures=self._delivery_futures, + size=self.__size ) diff --git a/tb_mqtt_client/entities/data/requested_attribute_response.py b/tb_mqtt_client/entities/data/requested_attribute_response.py index 1ba8857..5ee97b6 100644 --- a/tb_mqtt_client/entities/data/requested_attribute_response.py +++ b/tb_mqtt_client/entities/data/requested_attribute_response.py @@ -26,7 +26,7 @@ class RequestedAttributeResponse: client: List[AttributeEntry] def __repr__(self): - return f"" + return f"RequestedAttributeResponse(request_id={self.request_id}, shared={self.shared}, client={self.client})" def __getitem__(self, item): """ diff --git a/tb_mqtt_client/entities/data/rpc_request.py b/tb_mqtt_client/entities/data/rpc_request.py index 61841fb..cd2bff6 100644 --- a/tb_mqtt_client/entities/data/rpc_request.py +++ b/tb_mqtt_client/entities/data/rpc_request.py @@ -16,6 +16,7 @@ from typing import Union, Optional, Dict, Any from tb_mqtt_client.common.request_id_generator import RPCRequestIdProducer +from tb_mqtt_client.constants.json_typing import validate_json_compatibility @dataclass(slots=True, frozen=True) @@ -28,7 +29,7 @@ def __new__(cls, *args, **kwargs): raise TypeError("Direct instantiation of RPCRequest is not allowed. 
Use 'await RPCRequest.build(...)'.") def __repr__(self): - return f"" + return f"RPCRequest(id={self.request_id}, method={self.method}, params={self.params})" @classmethod def _deserialize_from_dict(cls, request_id: int, data: Dict[str, Any]) -> 'RPCRequest': @@ -56,6 +57,8 @@ async def build(cls, method: str, params: Optional[Any] = None) -> 'RPCRequest': self = object.__new__(cls) object.__setattr__(self, 'request_id', request_id) object.__setattr__(self, 'method', method) + if params is not None: + validate_json_compatibility(params) object.__setattr__(self, 'params', params) return self diff --git a/tb_mqtt_client/entities/data/rpc_response.py b/tb_mqtt_client/entities/data/rpc_response.py index 5a31a14..ab54e18 100644 --- a/tb_mqtt_client/entities/data/rpc_response.py +++ b/tb_mqtt_client/entities/data/rpc_response.py @@ -15,6 +15,8 @@ from dataclasses import dataclass from typing import Union, Optional, Dict, Any +from tb_mqtt_client.constants.json_typing import validate_json_compatibility + @dataclass(slots=True, frozen=True) class RPCResponse: @@ -43,7 +45,9 @@ def build(cls, request_id: Union[int, str], result: Optional[Any] = None, error: """ self = object.__new__(cls) object.__setattr__(self, 'request_id', request_id) + validate_json_compatibility(result) object.__setattr__(self, 'result', result) + validate_json_compatibility(error) object.__setattr__(self, 'error', error) return self diff --git a/tb_mqtt_client/entities/data/timeseries_entry.py b/tb_mqtt_client/entities/data/timeseries_entry.py index 137ea43..3970a70 100644 --- a/tb_mqtt_client/entities/data/timeseries_entry.py +++ b/tb_mqtt_client/entities/data/timeseries_entry.py @@ -14,15 +14,16 @@ from typing import Any, Optional +from tb_mqtt_client.constants.json_typing import JSONCompatibleType from tb_mqtt_client.entities.data.data_entry import DataEntry class TimeseriesEntry(DataEntry): - def __init__(self, key: str, value: Any, ts: Optional[int] = None): + def __init__(self, key: str, value: JSONCompatibleType, ts: Optional[int] = None): super().__init__(key, value, ts) def __repr__(self): - return f"" + return f"TelemetryEntry(key={self.key}, value={self.value}, ts={self.ts})" def as_dict(self) -> dict: result = { diff --git a/tb_mqtt_client/entities/publish_result.py b/tb_mqtt_client/entities/publish_result.py index 2f4497c..61a66cd 100644 --- a/tb_mqtt_client/entities/publish_result.py +++ b/tb_mqtt_client/entities/publish_result.py @@ -21,7 +21,7 @@ def __init__(self, topic: str, qos: int, message_id: int, payload_size: int, rea self.reason_code = reason_code def __repr__(self): - return f"" + return f"PublishResult(topic={self.topic}, qos={self.qos}, message_id={self.message_id}, payload_size={self.payload_size}, reason_code={self.reason_code})" def as_dict(self) -> dict: return { diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index a73c4a2..14c8e98 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -24,6 +24,8 @@ from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit, DEFAULT_RATE_LIMIT_PERCENTAGE from tb_mqtt_client.common.request_id_generator import RPCRequestIdProducer from tb_mqtt_client.constants import mqtt_topics +from tb_mqtt_client.constants.json_typing import validate_json_compatibility +from tb_mqtt_client.constants.service_keys import TELEMETRY_TIMESTAMP_PARAMETER, TELEMETRY_VALUES_PARAMETER from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from 
tb_mqtt_client.entities.data.attribute_request import AttributeRequest from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate @@ -332,12 +334,33 @@ def _build_uplink_message_for_telemetry(payload: Union[Dict[str, Any], List[TimeseriesEntry], List[Dict[str, Any]]]) -> DeviceUplinkMessage: if isinstance(payload, dict): + if TELEMETRY_TIMESTAMP_PARAMETER in payload: + ts = payload.pop(TELEMETRY_TIMESTAMP_PARAMETER) + values = payload.pop(TELEMETRY_VALUES_PARAMETER, {}) + else: + ts = None payload = [TimeseriesEntry(k, v) for k, v in payload.items()] builder = DeviceUplinkMessageBuilder() builder.add_telemetry(payload) return builder.build() + @staticmethod + def __build_timeseries_entry_from_dict(data: Dict[str, Any]) -> TimeseriesEntry: + if TELEMETRY_TIMESTAMP_PARAMETER in data: + ts = data.pop(TELEMETRY_TIMESTAMP_PARAMETER) + values = data.pop(TELEMETRY_VALUES_PARAMETER, {}) + if not isinstance(values, dict): + raise ValueError(f"Expected {TELEMETRY_VALUES_PARAMETER} to be a dict, got {type(values).__name__}") + values + else: + ts = None + values = data + for key, value in values.items(): + if not isinstance(key, str): + raise ValueError(f"Expected keys in {TELEMETRY_VALUES_PARAMETER} to be strings, got {type(key).__name__}") + return TimeseriesEntry(values, ts=ts) + @staticmethod def _build_uplink_message_for_attributes(payload: Union[Dict[str, Any], AttributeEntry, diff --git a/tb_mqtt_client/service/device/handlers/requested_attributes_response_handler.py b/tb_mqtt_client/service/device/handlers/requested_attributes_response_handler.py index 249078f..38ca4ca 100644 --- a/tb_mqtt_client/service/device/handlers/requested_attributes_response_handler.py +++ b/tb_mqtt_client/service/device/handlers/requested_attributes_response_handler.py @@ -15,7 +15,6 @@ from typing import Dict, Tuple, Awaitable, Callable from tb_mqtt_client.common.logging_utils import get_logger -from tb_mqtt_client.common.request_id_generator import AttributeRequestIdProducer from tb_mqtt_client.entities.data.attribute_request import AttributeRequest from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse from tb_mqtt_client.service.message_dispatcher import MessageDispatcher @@ -70,7 +69,7 @@ async def handle(self, topic: str, payload: bytes): if not self._message_dispatcher: logger.error("Message dispatcher is not initialized. 
Cannot handle attribute response.") request_id = topic.split('/')[-1] # Assuming request ID is in the topic - self._pending_attribute_requests.pop(int(request_id), (None, None, None)) + self._pending_attribute_requests.pop(int(request_id), None) return requested_attribute_response = self._message_dispatcher.parse_attribute_request_response(topic, payload) @@ -82,7 +81,7 @@ async def handle(self, topic: str, payload: bytes): request, callback = pending_request_details if callback: - logger.debug("Invoking callback for requested attribute response with ID %s", requested_attribute_response.request_id) + logger.trace("Invoking callback for requested attribute response with ID %s", requested_attribute_response.request_id) await callback(requested_attribute_response) else: logger.error("No callback registered for requested attribute response with ID %s", requested_attribute_response.request_id) From 5b98ce461d4255f4fb933104e9087d1317e92d10 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Wed, 4 Jun 2025 08:31:50 +0300 Subject: [PATCH 06/74] Added wait for publish option to send_attributes and timeseries and candidate for claiming processing for device client --- examples/device/operational_example.py | 2 +- tb_mqtt_client/common/config_loader.py | 3 + tb_mqtt_client/constants/mqtt_topics.py | 4 +- tb_mqtt_client/entities/data/claim_request.py | 58 +++++++ .../entities/data/device_uplink_message.py | 8 +- tb_mqtt_client/entities/data/rpc_request.py | 9 +- tb_mqtt_client/entities/publish_result.py | 6 + tb_mqtt_client/service/base_client.py | 50 ++++-- tb_mqtt_client/service/device/client.py | 144 +++++++++++++----- tb_mqtt_client/service/message_dispatcher.py | 33 +++- tb_mqtt_client/service/message_queue.py | 74 +++++---- tb_mqtt_client/service/mqtt_manager.py | 20 ++- 12 files changed, 317 insertions(+), 94 deletions(-) create mode 100644 tb_mqtt_client/entities/data/claim_request.py diff --git a/examples/device/operational_example.py b/examples/device/operational_example.py index c754977..29a0b87 100644 --- a/examples/device/operational_example.py +++ b/examples/device/operational_example.py @@ -129,7 +129,7 @@ def _shutdown_handler(): # 2. 
Single AttributeEntry single_entry = AttributeEntry("mode", "normal") logger.info("Sending single attribute: %s", single_entry) - delivery_future = await client.send_attributes(single_entry) + delivery_future = await client.send_attributes(single_entry, wait_for_publish=True, timeout=5) if delivery_future: logger.info("Awaiting delivery future for single attribute...") try: diff --git a/tb_mqtt_client/common/config_loader.py b/tb_mqtt_client/common/config_loader.py index f034403..15b22b2 100644 --- a/tb_mqtt_client/common/config_loader.py +++ b/tb_mqtt_client/common/config_loader.py @@ -34,6 +34,9 @@ def __init__(self): self.client_cert: Optional[str] = os.getenv("TB_CLIENT_CERT") self.private_key: Optional[str] = os.getenv("TB_PRIVATE_KEY") + # Default values + self.qos: int = int(os.getenv("TB_QOS", 1)) + def use_tls_auth(self) -> bool: return all([self.ca_cert, self.client_cert, self.private_key]) diff --git a/tb_mqtt_client/constants/mqtt_topics.py b/tb_mqtt_client/constants/mqtt_topics.py index 5955816..3986551 100644 --- a/tb_mqtt_client/constants/mqtt_topics.py +++ b/tb_mqtt_client/constants/mqtt_topics.py @@ -21,13 +21,15 @@ DEVICE_ATTRIBUTES_TOPIC = "v1/devices/me/attributes" DEVICE_ATTRIBUTES_REQUEST_TOPIC = DEVICE_ATTRIBUTES_TOPIC + REQUEST_TOPIC_SUFFIX + "/" + "{request_id}" DEVICE_ATTRIBUTES_RESPONSE_TOPIC = DEVICE_ATTRIBUTES_TOPIC + RESPONSE_TOPIC_SUFFIX + "/" + WILDCARD -DEVICE_RPC_TOPIC = "v1/devices/me/rpc" # Device RPC topics +DEVICE_RPC_TOPIC = "v1/devices/me/rpc" DEVICE_RPC_REQUEST_TOPIC = DEVICE_RPC_TOPIC + REQUEST_TOPIC_SUFFIX + "/" DEVICE_RPC_RESPONSE_TOPIC = DEVICE_RPC_TOPIC + RESPONSE_TOPIC_SUFFIX + "/" # Device RPC topics for subscription DEVICE_RPC_REQUEST_TOPIC_FOR_SUBSCRIPTION = DEVICE_RPC_TOPIC + REQUEST_TOPIC_SUFFIX + "/" + WILDCARD DEVICE_RPC_RESPONSE_TOPIC_FOR_SUBSCRIPTION = DEVICE_RPC_TOPIC + RESPONSE_TOPIC_SUFFIX + "/" + WILDCARD +# Device Claim topic +DEVICE_CLAIM_TOPIC = "v1/devices/me/claim" # V1 Topics for Gateway API BASE_GATEWAY_TOPIC = "v1/gateway" diff --git a/tb_mqtt_client/entities/data/claim_request.py b/tb_mqtt_client/entities/data/claim_request.py new file mode 100644 index 0000000..ec42057 --- /dev/null +++ b/tb_mqtt_client/entities/data/claim_request.py @@ -0,0 +1,58 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Dict, Any + + +@dataclass(slots=True, frozen=True) +class ClaimRequest: + """ + Represents a device claim request, as per ThingsBoard MQTT API. + Optionally includes a secret key and a duration (in seconds) for the claim. + Does not include request_id, as it's not required for this operation. + """ + secret_key: Optional[str] = None + duration: Optional[int] = None # seconds + + def __new__(cls, *args, **kwargs): + raise TypeError("Direct instantiation of ClaimRequest is not allowed. 
Use 'ClaimRequest.build(...)'.") + + def __repr__(self) -> str: + return f"ClaimRequest(secret_key={self.secret_key}, duration={self.duration})" + + @classmethod + def build(cls, secret_key: str, duration: int = 60000) -> 'ClaimRequest': + """ + Safely construct a new ClaimRequest instance. + """ + if not isinstance(secret_key, str): + raise ValueError("Secret key must be a string") + if not isinstance(duration, int) or duration < 0: + raise ValueError("Duration must be a non-negative integer representing seconds") + self = object.__new__(cls) + object.__setattr__(self, 'secret_key', secret_key) + object.__setattr__(self, 'duration', duration) + return self + + def to_payload_format(self) -> Dict[str, Any]: + """ + Convert the claim request to the expected MQTT JSON payload format. + """ + payload = {} + if self.secret_key is not None: + payload["secretKey"] = self.secret_key + if self.duration is not None: + payload["durationMs"] = int(self.duration * 1000) + return payload \ No newline at end of file diff --git a/tb_mqtt_client/entities/data/device_uplink_message.py b/tb_mqtt_client/entities/data/device_uplink_message.py index d4e33f6..2848fdb 100644 --- a/tb_mqtt_client/entities/data/device_uplink_message.py +++ b/tb_mqtt_client/entities/data/device_uplink_message.py @@ -31,9 +31,9 @@ class DeviceUplinkMessage: device_name: Optional[str] device_profile: Optional[str] - attributes: Tuple[AttributeEntry, ...] - timeseries: Mapping[int, Tuple[TimeseriesEntry, ...]] - delivery_futures: Tuple[Optional[asyncio.Future], ...] + attributes: Tuple[AttributeEntry] + timeseries: Mapping[int, Tuple[TimeseriesEntry]] + delivery_futures: List[Optional[asyncio.Future[PublishResult]]] _size: int def __new__(cls, *args, **kwargs): @@ -79,7 +79,7 @@ def has_attributes(self) -> bool: def has_timeseries(self) -> bool: return bool(self.timeseries) - def get_delivery_futures(self) -> Tuple[Optional[asyncio.Future], ...]: + def get_delivery_futures(self) -> List[Optional[asyncio.Future[PublishResult]]]: return self.delivery_futures diff --git a/tb_mqtt_client/entities/data/rpc_request.py b/tb_mqtt_client/entities/data/rpc_request.py index cd2bff6..e84c302 100644 --- a/tb_mqtt_client/entities/data/rpc_request.py +++ b/tb_mqtt_client/entities/data/rpc_request.py @@ -16,7 +16,7 @@ from typing import Union, Optional, Dict, Any from tb_mqtt_client.common.request_id_generator import RPCRequestIdProducer -from tb_mqtt_client.constants.json_typing import validate_json_compatibility +from tb_mqtt_client.constants.json_typing import validate_json_compatibility, JSONCompatibleType @dataclass(slots=True, frozen=True) @@ -48,17 +48,18 @@ def _deserialize_from_dict(cls, request_id: int, data: Dict[str, Any]) -> 'RPCRe return self @classmethod - async def build(cls, method: str, params: Optional[Any] = None) -> 'RPCRequest': + async def build(cls, method: str, params: Optional[JSONCompatibleType] = None) -> 'RPCRequest': """ Constructs an RPCRequest with a unique request ID, using the RPCRequestIdProducer to ensure thread-safe ID generation. 
""" + if not isinstance(method, str): + raise ValueError("Method must be a string") + validate_json_compatibility(params) request_id = await RPCRequestIdProducer.get_next() self = object.__new__(cls) object.__setattr__(self, 'request_id', request_id) object.__setattr__(self, 'method', method) - if params is not None: - validate_json_compatibility(params) object.__setattr__(self, 'params', params) return self diff --git a/tb_mqtt_client/entities/publish_result.py b/tb_mqtt_client/entities/publish_result.py index 61a66cd..6da17b1 100644 --- a/tb_mqtt_client/entities/publish_result.py +++ b/tb_mqtt_client/entities/publish_result.py @@ -31,3 +31,9 @@ def as_dict(self) -> dict: "payload_size": self.payload_size, "reason_code": self.reason_code } + + def is_successful(self) -> bool: + """ + Check if the publish operation was successful based on the reason code. + """ + return self.reason_code == 0 diff --git a/tb_mqtt_client/service/base_client.py b/tb_mqtt_client/service/base_client.py index 70e2746..1f1bebe 100644 --- a/tb_mqtt_client/service/base_client.py +++ b/tb_mqtt_client/service/base_client.py @@ -14,14 +14,17 @@ import asyncio from abc import ABC, abstractmethod -from typing import Callable, Awaitable, Dict, Any, Union +from typing import Callable, Awaitable, Dict, Any, Union, List import uvloop from tb_mqtt_client.common.exceptions import exception_handler from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.claim_request import ClaimRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.publish_result import PublishResult asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) exception_handler.install_asyncio_handler() @@ -29,9 +32,11 @@ class BaseClient(ABC): """ - Abstract base class for ThingsBoard clients. + Abstract base class for clients. """ + DEFAULT_TIMEOUT = 3 + def __init__(self, host: str, port: int, client_id: str): self._host = host self._port = port @@ -41,7 +46,7 @@ def __init__(self, host: str, port: int, client_id: str): @abstractmethod async def connect(self): """ - Connect to the ThingsBoard platform over MQTT. + Connect to the platform over MQTT. """ pass @@ -53,20 +58,47 @@ async def disconnect(self): pass @abstractmethod - async def send_telemetry(self, telemetry_data: Dict[str, Any]): + async def send_telemetry(self, telemetry_data: Union[TimeseriesEntry, + List[TimeseriesEntry], + Dict[str, Any], + List[Dict[str, Any]]], + wait_for_publish: bool = True, + timeout: int = DEFAULT_TIMEOUT) -> Union[asyncio.Future[PublishResult], PublishResult]: """ Send telemetry data. - :param telemetry_data: Dictionary of telemetry key-values. + :param telemetry_data: Dictionary of telemetry data, a single TimeseriesEntry, + or a list of TimeseriesEntry or dictionaries. + :param wait_for_publish: If True, wait for the publishing result. Default is True. + :param timeout: Timeout for the publish operation if `wait_for_publish` is True. + In seconds, defaults to 3 seconds. + :return: Future or PublishResult depending on `wait_for_publish`. + """ + pass + + @abstractmethod + async def send_attributes(self, + attributes: Union[Dict[str, Any], AttributeEntry, list[AttributeEntry]], + wait_for_publish: bool = True, + timeout: int = DEFAULT_TIMEOUT) -> Union[asyncio.Future[PublishResult], PublishResult]: + """ + Send client attributes. 
+ + :param attributes: Dictionary of attributes or a single AttributeEntry or a list of AttributeEntries. + :param wait_for_publish: If True, wait for the publishing result. Default is True. + :param timeout: Timeout for the publish operation if `wait_for_publish` is True. + In seconds, defaults to 3 seconds. + :return: Future or PublishResult depending on `wait_for_publish`. """ pass @abstractmethod - async def send_attributes(self, attributes: Union[Dict[str, Any], AttributeEntry, list[AttributeEntry]]): + async def claim_device(self, claim_request: ClaimRequest) -> Union[asyncio.Future[PublishResult], PublishResult]: """ - Send client-side attributes. + Claim a device using the provided ClaimRequest. - :param attributes: Dictionary of attributes. + :param claim_request: The ClaimRequest instance contains secret key and duration. + :return: Future or PublishResult depending on the implementation. """ pass @@ -95,4 +127,4 @@ def set_rpc_request_callback(self, callback: Callable[[str, Dict[str, Any]], Awa :param callback: Coroutine accepting (method, params) and returning result. """ - pass + pass \ No newline at end of file diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index 14c8e98..f481acd 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -12,28 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. -from asyncio import sleep, wait_for, TimeoutError, Event +from asyncio import sleep, wait_for, TimeoutError, Event, Future from random import choices from string import ascii_uppercase, digits +from time import time from typing import Callable, Awaitable, Optional, Dict, Any, Union, List -from orjson import loads, dumps +from orjson import dumps from tb_mqtt_client.common.config_loader import DeviceConfig from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit, DEFAULT_RATE_LIMIT_PERCENTAGE from tb_mqtt_client.common.request_id_generator import RPCRequestIdProducer from tb_mqtt_client.constants import mqtt_topics -from tb_mqtt_client.constants.json_typing import validate_json_compatibility +from tb_mqtt_client.constants.json_typing import validate_json_compatibility, JSONCompatibleType from tb_mqtt_client.constants.service_keys import TELEMETRY_TIMESTAMP_PARAMETER, TELEMETRY_VALUES_PARAMETER from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.attribute_request import AttributeRequest from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.claim_request import ClaimRequest from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage, DeviceUplinkMessageBuilder from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.publish_result import PublishResult from tb_mqtt_client.service.base_client import BaseClient from tb_mqtt_client.service.device.handlers.attribute_updates_handler import AttributeUpdatesHandler from tb_mqtt_client.service.device.handlers.requested_attributes_response_handler import \ @@ -82,12 +85,14 @@ def __init__(self, config: Optional[Union[DeviceConfig, 
Dict]] = None): message_dispatcher=self._message_dispatcher, on_connect=self._on_connect, on_disconnect=self._on_disconnect, + on_publish_result=self.__on_publish_result, rate_limits_handler=self._handle_rate_limit_response, - rpc_response_handler=self._rpc_response_handler,) + rpc_response_handler=self._rpc_response_handler) self._requested_attribute_response_handler = RequestedAttributeResponseHandler() self._attribute_updates_handler = AttributeUpdatesHandler() self._rpc_requests_handler = RPCRequestsHandler() + self.__claiming_response_future: Union[Future[bool], None] = None async def connect(self): logger.info("Connecting to platform at %s:%s", self._host, self._port) @@ -152,52 +157,85 @@ async def disconnect(self): # await self._message_queue.shutdown() # TODO: Not sure if we need to shutdown the message queue here, as it might be handled by MQTTManager - async def send_telemetry(self, telemetry_data: Union[Dict[str, Any], - TimeseriesEntry, - List[TimeseriesEntry], - List[Dict[str, Any]]]): - message = self._build_uplink_message_for_telemetry(telemetry_data) - futures = await self._message_queue.publish(topic=mqtt_topics.DEVICE_TELEMETRY_TOPIC, + async def send_telemetry(self, + data: Union[TimeseriesEntry, List[TimeseriesEntry], Dict[str, Any], List[Dict[str, Any]]], + qos: int = 1, + wait_for_publish: bool = True, + timeout: int = BaseClient.DEFAULT_TIMEOUT) -> Union[Future[PublishResult], PublishResult]: + message = self._build_uplink_message_for_telemetry(data) + topic = mqtt_topics.DEVICE_TELEMETRY_TOPIC + futures = await self._message_queue.publish(topic=topic, payload=message, - datapoints_count=message.timeseries_datapoint_count()) - return futures[0] if futures else None + datapoints_count=message.timeseries_datapoint_count(), + qos=qos or self._config.qos) + if wait_for_publish: + try: + return await wait_for(futures[0], timeout=timeout) if futures else None + except TimeoutError: + logger.warning("Timeout while waiting for telemetry publish result") + return PublishResult(topic, qos, -1, message.size, -1) + else: + return futures[0] if futures else None - async def send_attributes(self, attributes: Union[Dict[str, Any], AttributeEntry, list[AttributeEntry]]): + async def send_attributes(self, + attributes: Union[Dict[str, Any], AttributeEntry, list[AttributeEntry]], + qos: int = None, + wait_for_publish: bool = True, + timeout: int = BaseClient.DEFAULT_TIMEOUT) -> Union[Future[PublishResult], PublishResult]: message = self._build_uplink_message_for_attributes(attributes) futures = await self._message_queue.publish(topic=mqtt_topics.DEVICE_ATTRIBUTES_TOPIC, payload=message, - datapoints_count=message.attributes_datapoint_count()) + datapoints_count=message.attributes_datapoint_count(), + qos=qos or self._config.qos) return futures[0] if futures else None async def send_rpc_request(self, rpc_request: RPCRequest, callback: Optional[Callable[[RPCResponse], Awaitable[None]]] = None) -> Awaitable[RPCResponse]: request_id = rpc_request.request_id or await RPCRequestIdProducer.get_next() - topic, payload = self._message_dispatcher.build_rpc_request_payload(rpc_request) + topic, payload = self._message_dispatcher.build_rpc_request(rpc_request) response_future = self._rpc_response_handler.register_request(request_id, callback) await self._message_queue.publish(topic=topic, payload=payload, - datapoints_count=0) + datapoints_count=0, + qos=self._config.qos) return response_future async def send_rpc_response(self, response: RPCResponse): - topic, payload = 
self._message_dispatcher.build_rpc_response_payload(response) + topic, payload = self._message_dispatcher.build_rpc_response(response) await self._message_queue.publish(topic=topic, payload=payload, - datapoints_count=0) + datapoints_count=0, + qos=self._config.qos) async def send_attribute_request(self, attribute_request: AttributeRequest, - callback: Callable[[RequestedAttributeResponse], Awaitable[None]]): + callback: Callable[[RequestedAttributeResponse], Awaitable[None]],): await self._requested_attribute_response_handler.register_request(attribute_request, callback) - topic, payload = self._message_dispatcher.build_attribute_request_payload(attribute_request) + topic, payload = self._message_dispatcher.build_attribute_request(attribute_request) await self._message_queue.publish(topic=topic, payload=payload, - datapoints_count=0) - + datapoints_count=0, + qos=self._config.qos) + + async def claim_device(self, + claim_request: ClaimRequest, + wait_for_publish: bool = True, + timeout: int = BaseClient.DEFAULT_TIMEOUT) -> Union[Future[PublishResult], PublishResult]: + topic, payload = self._message_dispatcher.build_claim_request(claim_request) + self.__claiming_response_future = Future() + await self._message_queue.publish(topic=topic, payload=payload, datapoints_count=0, qos=1) + if wait_for_publish: + try: + return await wait_for(self.__claiming_response_future, timeout=timeout) + except TimeoutError: + logger.warning("Timeout while waiting for telemetry publish result") + return PublishResult(topic, 1, -1, len(payload), -1) + else: + return self.__claiming_response_future def set_attribute_update_callback(self, callback: Callable[[AttributeUpdate], Awaitable[None]]): self._attribute_updates_handler.set_callback(callback) @@ -328,38 +366,68 @@ async def _handle_rate_limit_response(self, response: RPCResponse): # noqa logger.exception("Failed to parse rate limits from server response: %s", e) return False + async def __on_publish_result(self, publish_result: PublishResult): + """ + Callback for handling publish results. + This can be used to handle the result of a publish operation, such as logging or updating state. 
+ """ + if mqtt_topics.DEVICE_CLAIM_TOPIC == publish_result.topic: + if self.__claiming_response_future and not self.__claiming_response_future.done(): + if publish_result.is_successful(): + self.__claiming_response_future.set_result(True) + logger.debug("Device claimed successfully.") + else: + self.__claiming_response_future.set_exception( + Exception(f"Failed to claim device: {publish_result}")) + logger.error("Failed to claim device: %r", publish_result) + return + if publish_result.is_successful(): + logger.trace("Publish successful: %r", publish_result) + else: + logger.error("Publish failed: %r", publish_result) + @staticmethod def _build_uplink_message_for_telemetry(payload: Union[Dict[str, Any], TimeseriesEntry, List[TimeseriesEntry], List[Dict[str, Any]]]) -> DeviceUplinkMessage: - if isinstance(payload, dict): - if TELEMETRY_TIMESTAMP_PARAMETER in payload: - ts = payload.pop(TELEMETRY_TIMESTAMP_PARAMETER) - values = payload.pop(TELEMETRY_VALUES_PARAMETER, {}) - else: - ts = None - payload = [TimeseriesEntry(k, v) for k, v in payload.items()] + timeseries_entries = [] + if isinstance(payload, TimeseriesEntry): + timeseries_entries.append(payload) + elif isinstance(payload, dict): + timeseries_entries.extend(DeviceClient.__build_timeseries_entry_from_dict(payload)) + elif isinstance(payload, list) and len(payload) > 0: + for item in payload: + if isinstance(item, dict): + timeseries_entries.extend(DeviceClient.__build_timeseries_entry_from_dict(item)) + elif isinstance(item, TimeseriesEntry): + timeseries_entries.append(item) + else: + raise ValueError(f"Unsupported item type in telemetry list: {type(item).__name__}") + else: + raise ValueError(f"Unsupported payload type for telemetry: {type(payload).__name__}") builder = DeviceUplinkMessageBuilder() - builder.add_telemetry(payload) + builder.add_telemetry(timeseries_entries) return builder.build() @staticmethod - def __build_timeseries_entry_from_dict(data: Dict[str, Any]) -> TimeseriesEntry: + def __build_timeseries_entry_from_dict(data: Dict[str, JSONCompatibleType]) -> List[TimeseriesEntry]: + result = [] if TELEMETRY_TIMESTAMP_PARAMETER in data: ts = data.pop(TELEMETRY_TIMESTAMP_PARAMETER) values = data.pop(TELEMETRY_VALUES_PARAMETER, {}) - if not isinstance(values, dict): - raise ValueError(f"Expected {TELEMETRY_VALUES_PARAMETER} to be a dict, got {type(values).__name__}") - values else: - ts = None + ts = time() * 1000 values = data + + if not isinstance(values, dict): + raise ValueError(f"Expected {TELEMETRY_VALUES_PARAMETER} to be a dict, got {type(values).__name__}") + for key, value in values.items(): - if not isinstance(key, str): - raise ValueError(f"Expected keys in {TELEMETRY_VALUES_PARAMETER} to be strings, got {type(key).__name__}") - return TimeseriesEntry(values, ts=ts) + result.append(TimeseriesEntry(key, value, ts=ts)) + + return result @staticmethod def _build_uplink_message_for_attributes(payload: Union[Dict[str, Any], diff --git a/tb_mqtt_client/service/message_dispatcher.py b/tb_mqtt_client/service/message_dispatcher.py index 6786c73..52d1790 100644 --- a/tb_mqtt_client/service/message_dispatcher.py +++ b/tb_mqtt_client/service/message_dispatcher.py @@ -54,7 +54,7 @@ def build_uplink_payloads( pass @abstractmethod - def build_attribute_request_payload(self, request: AttributeRequest) -> Tuple[str, bytes]: + def build_attribute_request(self, request: AttributeRequest) -> Tuple[str, bytes]: """ Build the payload for an attribute request response. This method should return a tuple of topic and payload bytes. 
@@ -62,7 +62,14 @@ def build_attribute_request_payload(self, request: AttributeRequest) -> Tuple[st pass @abstractmethod - def build_rpc_request_payload(self, rpc_request: RPCRequest) -> Tuple[str, bytes]: + def build_claim_request(self, claim_request) -> Tuple[str, bytes]: + """ + Build the payload for a claim request. + This method should return a tuple of topic and payload bytes. + """ + + @abstractmethod + def build_rpc_request(self, rpc_request: RPCRequest) -> Tuple[str, bytes]: """ Build the payload for an RPC request. This method should return a tuple of topic and payload bytes. @@ -70,7 +77,7 @@ def build_rpc_request_payload(self, rpc_request: RPCRequest) -> Tuple[str, bytes pass @abstractmethod - def build_rpc_response_payload(self, rpc_response: RPCResponse) -> Tuple[str, bytes]: + def build_rpc_response(self, rpc_response: RPCResponse) -> Tuple[str, bytes]: """ Build the payload for an RPC response. This method should return a tuple of topic and payload bytes. @@ -242,7 +249,7 @@ def build_uplink_payloads(self, messages: List[DeviceUplinkMessage]) -> List[Tup logger.debug("Exception details: %s", e, exc_info=True) raise - def build_attribute_request_payload(self, request: AttributeRequest) -> Tuple[str, bytes]: + def build_attribute_request(self, request: AttributeRequest) -> Tuple[str, bytes]: """ Build the payload for an attribute request response. :param request: The AttributeRequest to build the payload for. @@ -256,7 +263,21 @@ def build_attribute_request_payload(self, request: AttributeRequest) -> Tuple[st logger.trace("Built attribute request payload for request: %r", request) return topic, payload - def build_rpc_request_payload(self, rpc_request: RPCRequest) -> Tuple[str, bytes]: + def build_claim_request(self, claim_request) -> Tuple[str, bytes]: + """ + Build the payload for a claim request. + :param claim_request: The ClaimRequest to build the payload for. + :return: A tuple of topic and payload bytes. + """ + if not claim_request.secret_key: + raise ValueError("ClaimRequest must have a valid secret key.") + + topic = mqtt_topics.DEVICE_CLAIM_TOPIC + payload = dumps(claim_request.to_payload_format()) + logger.trace("Built claim request payload: %r", claim_request) + return topic, payload + + def build_rpc_request(self, rpc_request: RPCRequest) -> Tuple[str, bytes]: """ Build the payload for an RPC request. :param rpc_request: The RPC request to build the payload for. @@ -272,7 +293,7 @@ def build_rpc_request_payload(self, rpc_request: RPCRequest) -> Tuple[str, bytes return topic, payload - def build_rpc_response_payload(self, rpc_response: RPCResponse) -> Tuple[str, bytes]: + def build_rpc_response(self, rpc_response: RPCResponse) -> Tuple[str, bytes]: """ Build the payload for an RPC response. :param rpc_response: The RPC response to build the payload for. 
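
As a usage illustration (not part of the patch), the device claiming flow added in this patch series would look roughly as follows from the caller side. The names come from the diffs above (ClaimRequest, DeviceClient.claim_device, DEVICE_CLAIM_TOPIC = "v1/devices/me/claim"); the secret key and timeout values are hypothetical.

# Sketch of claiming a device with the API added above; assumes an already
# connected DeviceClient instance named `client`.
from tb_mqtt_client.entities.data.claim_request import ClaimRequest

async def claim_example(client):
    # duration is given in seconds; to_payload_format() converts it to
    # {"secretKey": "my-secret", "durationMs": 60000} and the request is
    # published to the "v1/devices/me/claim" topic.
    claim = ClaimRequest.build(secret_key="my-secret", duration=60)
    result = await client.claim_device(claim, wait_for_publish=True, timeout=5)
    print("Claiming finished:", result)
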
diff --git a/tb_mqtt_client/service/message_queue.py b/tb_mqtt_client/service/message_queue.py index c185d62..3dcd023 100644 --- a/tb_mqtt_client/service/message_queue.py +++ b/tb_mqtt_client/service/message_queue.py @@ -40,7 +40,6 @@ def __init__(self, max_queue_size: int = 1000000, batch_collect_max_time_ms: int = 100, batch_collect_max_count: int = 500): - self.__qos = 1 self._main_stop_event = main_stop_event self._batch_max_time = batch_collect_max_time_ms / 1000 # convert to seconds self._batch_max_count = batch_collect_max_count @@ -51,7 +50,7 @@ def __init__(self, self._backpressure = self._mqtt_manager.backpressure self._pending_ack_futures: Dict[int, asyncio.Future[PublishResult]] = {} self._pending_ack_callbacks: Dict[int, Callable[[bool], None]] = {} - self._queue = asyncio.Queue(maxsize=max_queue_size) + self._queue: asyncio.Queue[Tuple[str, Union[bytes, DeviceUplinkMessage], List[asyncio.Future[PublishResult]], int, int]] = asyncio.Queue(maxsize=max_queue_size) self._active = asyncio.Event() self._wakeup_event = asyncio.Event() self._retry_tasks: set[asyncio.Task] = set() @@ -62,19 +61,19 @@ def __init__(self, logger.debug("MessageQueue initialized: max_queue_size=%s, batch_time=%.3f, batch_count=%d", max_queue_size, self._batch_max_time, batch_collect_max_count) - async def publish(self, topic: str, payload: Union[bytes, DeviceUplinkMessage], datapoints_count: int): + async def publish(self, topic: str, payload: Union[bytes, DeviceUplinkMessage], datapoints_count: int, qos: int) -> Optional[List[asyncio.Future[PublishResult]]]: delivery_futures = payload.get_delivery_futures() if isinstance(payload, DeviceUplinkMessage) else [] try: logger.trace("publish() received delivery future id: %r for topic=%s", id(delivery_futures[0]) if delivery_futures else -1, topic) - self._queue.put_nowait((topic, payload, delivery_futures, datapoints_count)) + self._queue.put_nowait((topic, payload, delivery_futures, datapoints_count, qos)) logger.trace("Enqueued message: topic=%s, datapoints=%d, type=%s", topic, datapoints_count, type(payload).__name__) except asyncio.QueueFull: logger.error("Message queue full. Dropping message for topic %s", topic) for future in payload.get_delivery_futures(): if future: - future.set_result(PublishResult(topic, self.__qos, -1, len(payload), -1)) + future.set_result(PublishResult(topic, qos, -1, len(payload), -1)) return delivery_futures or None async def _dequeue_loop(self): @@ -82,15 +81,15 @@ async def _dequeue_loop(self): while self._active.is_set(): try: # topic, payload, delivery_futures_or_none, count = await asyncio.wait_for(asyncio.get_event_loop().create_task(self._queue.get()), timeout=self._BATCH_TIMEOUT) - topic, payload, delivery_futures_or_none, count = await self._wait_for_message() + topic, payload, delivery_futures_or_none, datapoints, qos = await self._wait_for_message() logger.trace("MessageQueue dequeue: topic=%s, payload=%r, count=%d", - topic, payload, count) + topic, payload, datapoints) if isinstance(payload, bytes): - await self._try_publish(topic, payload, count, delivery_futures_or_none) + await self._try_publish(topic, payload, datapoints, delivery_futures_or_none) continue logger.trace("Dequeued message: delivery_future id: %r topic=%s, type=%s, datapoints=%d", id(delivery_futures_or_none[0]) if delivery_futures_or_none else -1, - topic, type(payload).__name__, count) + topic, type(payload).__name__, datapoints) await asyncio.sleep(0) # cooperative yield except asyncio.TimeoutError: logger.trace("Dequeue wait timed out. 
Yielding...") @@ -104,13 +103,13 @@ async def _dequeue_loop(self): if isinstance(payload, bytes): logger.trace("Dequeued immediate publish: topic=%s (raw bytes)", topic) - await self._try_publish(topic, payload, count, delivery_futures_or_none) + await self._try_publish(topic, payload, datapoints, delivery_futures_or_none) continue logger.trace("Dequeued message for batching: topic=%s, device=%s", topic, getattr(payload, 'device_name', 'N/A')) - batch: List[Tuple[str, Union[bytes, DeviceUplinkMessage], List[asyncio.Future[PublishResult]], int]] = [(topic, payload, delivery_futures_or_none, count)] + batch: List[Tuple[str, Union[bytes, DeviceUplinkMessage], List[asyncio.Future[PublishResult]], int, int]] = [(topic, payload, delivery_futures_or_none, datapoints, qos)] start = asyncio.get_event_loop().time() batch_size = payload.size @@ -124,18 +123,18 @@ async def _dequeue_loop(self): break try: - next_topic, next_payload, delivery_futures_or_none, next_count = self._queue.get_nowait() + next_topic, next_payload, delivery_futures_or_none, datapoints, qos = self._queue.get_nowait() if isinstance(next_payload, DeviceUplinkMessage): msg_size = next_payload.size if batch_size + msg_size > self._dispatcher.splitter.max_payload_size: # noqa logger.trace("Batch size threshold exceeded: current=%d, next=%d", batch_size, msg_size) - self._queue.put_nowait((next_topic, next_payload, delivery_futures_or_none, next_count)) + self._queue.put_nowait((next_topic, next_payload, delivery_futures_or_none, datapoints, qos)) break - batch.append((next_topic, next_payload, delivery_futures_or_none, next_count)) + batch.append((next_topic, next_payload, delivery_futures_or_none, datapoints, qos)) batch_size += msg_size else: logger.trace("Immediate publish encountered in queue while batching: topic=%s", next_topic) - await self._try_publish(next_topic, next_payload, next_count) + await self._try_publish(next_topic, next_payload, datapoints) except asyncio.QueueEmpty: break @@ -148,13 +147,18 @@ async def _dequeue_loop(self): for topic, payload, datapoints, delivery_futures in topic_payloads: logger.trace("Dispatching batched message: topic=%s, size=%d, datapoints=%d, delivery_futures=%r", topic, len(payload), datapoints, [id(f) for f in delivery_futures]) - await self._try_publish(topic, payload, datapoints, delivery_futures) + await self._try_publish(topic=topic, + payload=payload, + datapoints=datapoints, + delivery_futures_or_none=delivery_futures, + qos=qos) async def _try_publish(self, topic: str, payload: bytes, datapoints: int, - delivery_futures_or_none: List[Optional[asyncio.Future[PublishResult]]] = None): + delivery_futures_or_none: List[Optional[asyncio.Future[PublishResult]]] = None, + qos: int = 1): if delivery_futures_or_none is None: logger.trace("No delivery futures associated! 
This publish result will not be tracked.") delivery_futures_or_none = [] @@ -166,7 +170,12 @@ async def _try_publish(self, # Check backpressure first - if active, don't even try to check rate limits if self._backpressure.should_pause(): logger.debug("Backpressure active, delaying publish of topic=%s for %.1f seconds", topic, 1.0) - self._schedule_delayed_retry(topic, payload, datapoints, delay=1.0, delivery_futures=delivery_futures_or_none) + self._schedule_delayed_retry(topic=topic, + payload=payload, + datapoints=datapoints, + qos=qos, + delay=1.0, + delivery_futures=delivery_futures_or_none) return # Check and consume rate limits atomically before publishing @@ -181,7 +190,12 @@ async def _try_publish(self, logger.debug("Telemetry message rate limit hit for topic %s: %r per %r seconds", topic, triggered_rate_limit[0], triggered_rate_limit[1]) retry_delay = self._telemetry_rate_limit.minimal_timeout - self._schedule_delayed_retry(topic, payload, datapoints, delay=retry_delay, delivery_futures=delivery_futures_or_none) + self._schedule_delayed_retry(topic=topic, + payload=payload, + datapoints=datapoints, + qos=qos, + delay=retry_delay, + delivery_futures=delivery_futures_or_none) return if self._telemetry_dp_rate_limit: @@ -192,7 +206,8 @@ async def _try_publish(self, retry_delay = self._telemetry_dp_rate_limit.minimal_timeout self._schedule_delayed_retry(topic=topic, payload=payload, - points=datapoints, + datapoints=datapoints, + qos=qos, delay=retry_delay, delivery_futures=delivery_futures_or_none) return @@ -204,9 +219,10 @@ async def _try_publish(self, logger.debug("Generic message rate limit hit for topic %s: %r per %r seconds", topic, triggered_rate_limit[0], triggered_rate_limit[1]) retry_delay = self._message_rate_limit.minimal_timeout - self._schedule_delayed_retry(topic, - payload, - datapoints, + self._schedule_delayed_retry(topic=topic, + payload=payload, + datapoints=datapoints, + qos=qos, delay=retry_delay, delivery_futures=delivery_futures_or_none) return @@ -214,7 +230,7 @@ async def _try_publish(self, logger.trace("Trying to publish topic=%s, payload size=%d, attached future id=%r", topic, len(payload), id(delivery_futures_or_none[0]) if delivery_futures_or_none else -1) - mqtt_future = await self._mqtt_manager.publish(topic, payload, qos=self.__qos) + mqtt_future = await self._mqtt_manager.publish(message_or_topic=topic, payload=payload, qos=self.__qos) if delivery_futures_or_none is not None: def resolve_attached(publish_future: asyncio.Future): @@ -239,9 +255,9 @@ def resolve_attached(publish_future: asyncio.Future): mqtt_future.add_done_callback(resolve_attached) except Exception as e: logger.warning("Failed to publish to topic %s: %s. Scheduling retry.", topic, e) - self._schedule_delayed_retry(topic, payload, datapoints, delay=.1) + self._schedule_delayed_retry(topic, payload, datapoints, qos, delay=.1) - def _schedule_delayed_retry(self, topic: str, payload: bytes, points: int, delay: float, + def _schedule_delayed_retry(self, topic: str, payload: bytes, datapoints: int, qos: int, delay: float, delivery_futures: Optional[List[Optional[asyncio.Future[PublishResult]]]] = None): if not self._active.is_set() or self._main_stop_event.is_set(): logger.debug("MessageQueue is not active or main stop event is set. Not scheduling retry for topic=%s", topic) @@ -255,7 +271,7 @@ async def retry(): if not self._active.is_set() or self._main_stop_event.is_set(): logger.debug("MessageQueue is not active or main stop event is set. 
Not re-enqueuing message for topic=%s", topic) return - self._queue.put_nowait((topic, payload, delivery_futures, points)) + self._queue.put_nowait((topic, payload, delivery_futures, datapoints, qos)) self._wakeup_event.set() logger.debug("Re-enqueued message after delay: topic=%s", topic) except asyncio.QueueFull: @@ -267,7 +283,7 @@ async def retry(): self._retry_tasks.add(task) task.add_done_callback(self._retry_tasks.discard) - async def _wait_for_message(self) -> Tuple[str, Union[bytes, DeviceUplinkMessage], List[asyncio.Future[PublishResult]], int]: + async def _wait_for_message(self) -> Tuple[str, Union[bytes, DeviceUplinkMessage], List[asyncio.Future[PublishResult]], int, int]: while self._active.is_set(): try: if not self._queue.empty(): @@ -331,7 +347,7 @@ def size(self): def clear(self): logger.debug("Clearing message queue...") while not self._queue.empty(): - _, message, _ = self._queue.get_nowait() + _, message, _, _, _ = self._queue.get_nowait() if isinstance(message, DeviceUplinkMessage) and message.get_delivery_futures(): for future in message.get_delivery_futures(): future.set_result(False) diff --git a/tb_mqtt_client/service/mqtt_manager.py b/tb_mqtt_client/service/mqtt_manager.py index 9b0e5f7..84afdf6 100644 --- a/tb_mqtt_client/service/mqtt_manager.py +++ b/tb_mqtt_client/service/mqtt_manager.py @@ -53,6 +53,7 @@ def __init__( message_dispatcher: MessageDispatcher, on_connect: Optional[Callable[[], Awaitable[None]]] = None, on_disconnect: Optional[Callable[[], Awaitable[None]]] = None, + on_publish_result: Optional[Callable[[PublishResult], Awaitable[None]]] = None, rate_limits_handler: Optional[Callable[[RPCResponse], Awaitable[None]]] = None, rpc_response_handler: Optional[RPCResponseHandler] = None, ): @@ -72,6 +73,7 @@ def __init__( self._on_connect_callback = on_connect self._on_disconnect_callback = on_disconnect + self._on_publish_result_callback = on_publish_result self._connected_event = asyncio.Event() self._connect_params = None # Will be set in connect method @@ -92,6 +94,15 @@ def __init__( self.__is_gateway = False # TODO: determine if this is a gateway or not self.__is_waiting_for_rate_limits_publish = True # Start with True to prevent publishing before rate limits are retrieved self._rate_limits_ready_event = asyncio.Event() + self._claiming_future = None + + # TODO: In case of implementing for gateway may be better to use a handler, to discuss + def register_claiming_future(self, future: asyncio.Future): + """ + Register a future that will be set when the claiming process is complete. + This is used to ensure that the MQTT client does not publish messages before the claiming process is done. 
+ """ + self._claiming_future = future async def connect(self, host: str, port: int = 1883, username: Optional[str] = None, password: Optional[str] = None, tls: bool = False, @@ -221,7 +232,7 @@ async def __handle_connect_and_limits(self): await self.__request_rate_limits() if self._on_connect_callback: - await self._on_connect_callback() + asyncio.create_task(self._on_connect_callback()) def _on_disconnect_internal(self, client, reason_code=None, properties=None, exc=None): # noqa self._connected_event.clear() @@ -279,6 +290,9 @@ def _on_message_internal(self, client, topic: str, payload: bytes, qos, properti client, topic, payload, qos, properties) for topic_filter, handler in self._handlers.items(): if self._match_topic(topic_filter, topic): + # TODO + # TODO: Add awaiting for created task to ensure it is handled properly + # TODO asyncio.create_task(handler(topic, payload)) return @@ -316,6 +330,8 @@ def _handle_puback_reason_code(self, mid: int, reason_code: int, properties: dic elif reason_code != 0: logger.warning("PUBACK received with error code %s for mid=%s", reason_code, mid) + if self._on_publish_result_callback: + self._on_publish_result_callback(publish_result) def _on_subscribe_internal(self, client, mid, qos, properties): logger.trace("Received SUBACK by client %r for mid=%s with qos %s, properties %s", @@ -357,7 +373,7 @@ async def __request_rate_limits(self): logger.debug("Publishing rate limits request to server...") request = await RPCRequest.build("getSessionLimits") - topic, payload = self._message_dispatcher.build_rpc_request_payload(request) + topic, payload = self._message_dispatcher.build_rpc_request(request) response_future = self._rpc_response_handler.register_request(request.request_id, self.__rate_limits_handler) try: From efe94d21f3369b4c46f4a926d3cdac656c47cab5 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Wed, 4 Jun 2025 09:43:21 +0300 Subject: [PATCH 07/74] Fix for qos in message queue and added ability to send errors while RPC processing fails with details --- examples/device/operational_example.py | 132 ++++--------------- tb_mqtt_client/entities/data/rpc_response.py | 47 ++++++- tb_mqtt_client/service/device/client.py | 12 +- tb_mqtt_client/service/message_queue.py | 17 +-- tb_mqtt_client/service/mqtt_manager.py | 2 +- 5 files changed, 85 insertions(+), 125 deletions(-) diff --git a/examples/device/operational_example.py b/examples/device/operational_example.py index 29a0b87..c1289fc 100644 --- a/examples/device/operational_example.py +++ b/examples/device/operational_example.py @@ -25,7 +25,7 @@ from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse from tb_mqtt_client.entities.data.rpc_request import RPCRequest -from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.entities.data.rpc_response import RPCResponse, RPCStatus from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.service.device.client import DeviceClient @@ -43,19 +43,28 @@ async def attribute_update_callback(update: AttributeUpdate): logger.info("Received attribute update: %r", update) -async def rpc_request_callback(request: RPCRequest): +async def rpc_request_callback(request: RPCRequest) -> RPCResponse: """ Callback function to handle RPC requests. :param request: The RPC request object. - :return: A RPC response object. + :return: An RPCResponse object. 
""" logger.info("Received RPC request: %r", request) - response_data = { - "status": "success", - } - response = RPCResponse(request_id=request.request_id, - result=response_data, - error=None) + + if request.method == "getError": + # Simulate an error response for demonstration purposes + logger.error("Simulated error for method: %s", request.method) + try: + # Simulate some processing that raises an error + raise RuntimeError("Simulated processing error") + except RuntimeError as e: + return RPCResponse.build(request_id=request.request_id, error=e) + else: + response_data = { + "message": f"Response for method {request.method}", + "params": request.params or {} + } + response = RPCResponse.build(request_id=request.request_id, result=response_data) return response async def rpc_response_callback(response: RPCResponse): @@ -109,42 +118,14 @@ def _shutdown_handler(): "hardwareModel": "TB-SDK-Device" } logger.info("Sending attributes...") - delivery_future = await client.send_attributes(raw_dict) - if delivery_future: - logger.info("Awaiting delivery future for raw attributes...") - try: - result = await asyncio.wait_for(delivery_future, timeout=5) - except asyncio.TimeoutError: - logger.warning("Delivery future timed out after 5 seconds.") - result = False - except Exception as e: - logger.error("Error while awaiting delivery future: %s", e) - result = False - logger.info("Raw attributes sent: %s, delivery result: %s", raw_dict, result) - else: - logger.warning("Delivery future is None, raw attributes may not be sent.") - - logger.info(f"Raw attributes sent: {raw_dict}") + raw_publish_result = await client.send_attributes(raw_dict) + logger.info(f"Raw attributes sent: {raw_dict} with result: {raw_publish_result}") # 2. Single AttributeEntry single_entry = AttributeEntry("mode", "normal") logger.info("Sending single attribute: %s", single_entry) - delivery_future = await client.send_attributes(single_entry, wait_for_publish=True, timeout=5) - if delivery_future: - logger.info("Awaiting delivery future for single attribute...") - try: - result = await asyncio.wait_for(delivery_future, timeout=5) - except asyncio.TimeoutError: - logger.warning("Delivery future timed out after 5 seconds.") - result = False - except Exception as e: - logger.error("Error while awaiting delivery future: %s", e) - result = False - logger.info("Single attribute sent: %s, delivery result: %s", single_entry, result) - else: - logger.warning("Delivery future is None, single attribute may not be sent.") - - logger.info("Single attribute sent: %s", single_entry) + single_attribute_publish_result = await client.send_attributes(single_entry) + logger.info(f"Single attribute sent: {single_entry} with result: {single_attribute_publish_result}") # 3. 
List of AttributeEntry attr_entries = [ @@ -152,20 +133,8 @@ def _shutdown_handler(): AttributeEntry("calibrated", True) ] logger.info("Sending list of attributes: %s", attr_entries) - delivery_future = await client.send_attributes(attr_entries) - if delivery_future: - logger.info("Awaiting delivery future for list of attributes...") - try: - result = await asyncio.wait_for(delivery_future, timeout=5) - except asyncio.TimeoutError: - logger.warning("Delivery future timed out after 5 seconds.") - result = False - except Exception as e: - logger.error("Error while awaiting delivery future: %s", e) - result = False - logger.info("List of attributes sent: %s, delivery result: %s", attr_entries, result) - else: - logger.warning("Delivery future is None, list of attributes may not be sent.") + attributes_list_publish_result = await client.send_attributes(attr_entries) + logger.info("List of attributes sent: %s with result: %s", attr_entries, attributes_list_publish_result) # --- Telemetry --- @@ -175,66 +144,23 @@ def _shutdown_handler(): "humidity": 60 } logger.info("Sending raw telemetry...") - delivery_future = await client.send_telemetry(raw_dict) - if delivery_future: - logger.info("Awaiting delivery future for raw telemetry...") - try: - result = await asyncio.wait_for(delivery_future, timeout=5) - except asyncio.TimeoutError: - logger.warning("Delivery future timed out after 5 seconds.") - result = False - except Exception as e: - logger.error("Error while awaiting delivery future: %s", e) - result = False - logger.info("Raw telemetry sent: %s, delivery result: %s", raw_dict, result) - else: - logger.warning("Delivery future is None, raw telemetry may not be sent.") - - logger.info(f"Raw telemetry sent: {raw_dict}") + raw_telemetry_publish_result = await client.send_telemetry(raw_dict) + logger.info(f"Raw telemetry sent: {raw_dict} with result: {raw_telemetry_publish_result}") # 2. Single TelemetryEntry (with ts) single_entry = TimeseriesEntry("batteryLevel", randint(0, 100)) logger.info("Sending single telemetry: %s", single_entry) delivery_future = await client.send_telemetry(single_entry) - if delivery_future: - logger.info("Awaiting delivery future for single telemetry...") - try: - result = await asyncio.wait_for(delivery_future, timeout=5) - except asyncio.TimeoutError: - logger.warning("Delivery future timed out after 5 seconds.") - result = False - except Exception as e: - logger.error("Error while awaiting delivery future: %s", e) - result = False - logger.info("Single telemetry sent: %s, delivery result: %s", single_entry, result) - else: - logger.warning("Delivery future is None, single telemetry may not be sent.") - - logger.info("Single telemetry sent: %s", single_entry) + logger.info(f"Single telemetry sent: {single_entry} with delivery future: {delivery_future}") # 3. 
List of TelemetryEntry with mixed timestamps telemetry_entries = [] for i in range(1): telemetry_entries.append(TimeseriesEntry("temperature", i, ts=int(datetime.now(UTC).timestamp() * 1000)-i)) - ts_now = int(datetime.now(UTC).timestamp() * 1000) logger.info("Sending list of telemetry entries with mixed timestamps...") - delivery_future = await client.send_telemetry(telemetry_entries) - if delivery_future: - logger.info("Awaiting delivery future for list of telemetry...") - try: - result = await asyncio.wait_for(delivery_future, timeout=5) - except asyncio.TimeoutError: - logger.warning("Delivery future timed out after 5 seconds.") - result = False - except Exception as e: - logger.error("Error while awaiting delivery future: %s", e) - result = False - logger.info("List of telemetry sent: %s, it took %r milliseconds", len(telemetry_entries), - int(datetime.now(UTC).timestamp() * 1000) - ts_now) - logger.info("Delivery result: %s", result) - else: - logger.warning("Delivery future is None, list of telemetry may not be sent.") + telemetry_list_publish_result = await client.send_telemetry(telemetry_entries) + logger.info("List of telemetry entries sent: %s with result: %s", telemetry_entries, telemetry_list_publish_result) # --- Attribute Request --- diff --git a/tb_mqtt_client/entities/data/rpc_response.py b/tb_mqtt_client/entities/data/rpc_response.py index ab54e18..7eafa84 100644 --- a/tb_mqtt_client/entities/data/rpc_response.py +++ b/tb_mqtt_client/entities/data/rpc_response.py @@ -12,12 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. +from enum import Enum from dataclasses import dataclass +from traceback import format_exception from typing import Union, Optional, Dict, Any -from tb_mqtt_client.constants.json_typing import validate_json_compatibility +from tb_mqtt_client.constants.json_typing import validate_json_compatibility, JSONCompatibleType +class RPCStatus(Enum): + """ + Enum representing the status of an RPC call. + """ + SUCCESS = "SUCCESS" + ERROR = "ERROR" + TIMEOUT = "TIMEOUT" + NOT_FOUND = "NOT_FOUND" + + def __str__(self): + return self.value + @dataclass(slots=True, frozen=True) class RPCResponse: """ @@ -29,6 +43,7 @@ class RPCResponse: error: Optional error information if the RPC failed. """ request_id: Union[int, str] + status: RPCStatus = None result: Optional[Any] = None error: Optional[Union[str, Dict[str, Any]]] = None @@ -39,16 +54,38 @@ def __repr__(self) -> str: return f"RPCResponse(request_id={self.request_id}, result={self.result}, error={self.error})" @classmethod - def build(cls, request_id: Union[int, str], result: Optional[Any] = None, error: Optional[Union[str, Dict[str, Any]]] = None) -> 'RPCResponse': + def build(cls, request_id: Union[int, str], result: Optional[Any] = None, error: Optional[Union[str, Dict[str, JSONCompatibleType], BaseException]] = None) -> 'RPCResponse': """ Constructs an RPCResponse explicitly. 
""" self = object.__new__(cls) object.__setattr__(self, 'request_id', request_id) - validate_json_compatibility(result) + + if error is not None: + if not isinstance(error, (str, dict, BaseException)): + raise ValueError("Error must be a string, dictionary, or an exception instance") + + object.__setattr__(self, 'status', RPCStatus.ERROR) + + if isinstance(error, BaseException): + try: + raise error + except BaseException as e: + error = { + "message": str(e), + "type": type(e).__name__, + "details": ''.join(format_exception(type(e), e, e.__traceback__)) + } + + validate_json_compatibility(error) + object.__setattr__(self, 'error', error) + + else: + object.__setattr__(self, 'status', RPCStatus.SUCCESS) + object.__setattr__(self, 'error', None) + validate_json_compatibility(result) + object.__setattr__(self, 'result', result) - validate_json_compatibility(error) - object.__setattr__(self, 'error', error) return self def to_payload_format(self) -> Dict[str, Any]: diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index f481acd..d8e8988 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -183,11 +183,19 @@ async def send_attributes(self, wait_for_publish: bool = True, timeout: int = BaseClient.DEFAULT_TIMEOUT) -> Union[Future[PublishResult], PublishResult]: message = self._build_uplink_message_for_attributes(attributes) - futures = await self._message_queue.publish(topic=mqtt_topics.DEVICE_ATTRIBUTES_TOPIC, + topic = mqtt_topics.DEVICE_ATTRIBUTES_TOPIC + futures = await self._message_queue.publish(topic=topic, payload=message, datapoints_count=message.attributes_datapoint_count(), qos=qos or self._config.qos) - return futures[0] if futures else None + if wait_for_publish: + try: + return await wait_for(futures[0], timeout=timeout) if futures else None + except TimeoutError: + logger.warning("Timeout while waiting for telemetry publish result") + return PublishResult(topic, qos, -1, message.size, -1) + else: + return futures[0] if futures else None async def send_rpc_request(self, rpc_request: RPCRequest, callback: Optional[Callable[[RPCResponse], Awaitable[None]]] = None) -> Awaitable[RPCResponse]: diff --git a/tb_mqtt_client/service/message_queue.py b/tb_mqtt_client/service/message_queue.py index 3dcd023..5e21767 100644 --- a/tb_mqtt_client/service/message_queue.py +++ b/tb_mqtt_client/service/message_queue.py @@ -84,9 +84,6 @@ async def _dequeue_loop(self): topic, payload, delivery_futures_or_none, datapoints, qos = await self._wait_for_message() logger.trace("MessageQueue dequeue: topic=%s, payload=%r, count=%d", topic, payload, datapoints) - if isinstance(payload, bytes): - await self._try_publish(topic, payload, datapoints, delivery_futures_or_none) - continue logger.trace("Dequeued message: delivery_future id: %r topic=%s, type=%s, datapoints=%d", id(delivery_futures_or_none[0]) if delivery_futures_or_none else -1, topic, type(payload).__name__, datapoints) @@ -140,7 +137,7 @@ async def _dequeue_loop(self): if batch: logger.trace("Batching completed: %d messages, total size=%d", len(batch), batch_size) - messages = [device_uplink_message for _, device_uplink_message, _, _ in batch] + messages = [device_uplink_message for _, device_uplink_message, _, _, _ in batch] topic_payloads = self._dispatcher.build_uplink_payloads(messages) @@ -230,7 +227,7 @@ async def _try_publish(self, logger.trace("Trying to publish topic=%s, payload size=%d, attached future id=%r", topic, len(payload), 
id(delivery_futures_or_none[0]) if delivery_futures_or_none else -1) - mqtt_future = await self._mqtt_manager.publish(message_or_topic=topic, payload=payload, qos=self.__qos) + mqtt_future = await self._mqtt_manager.publish(message_or_topic=topic, payload=payload, qos=qos) if delivery_futures_or_none is not None: def resolve_attached(publish_future: asyncio.Future): @@ -239,7 +236,7 @@ def resolve_attached(publish_future: asyncio.Future): except Exception as exc: logger.warning("Publish failed with exception: %s", exc) logger.debug("Resolving delivery futures with failure:", exc_info=exc) - publish_result = PublishResult(topic, self.__qos, -1, len(payload), -1) + publish_result = PublishResult(topic, qos, -1, len(payload), -1) for i, f in enumerate(delivery_futures_or_none): if f is not None and not f.done(): @@ -354,14 +351,6 @@ def clear(self): self._queue.task_done() logger.debug("Message queue cleared.") - @property - def qos(self) -> int: - return self.__qos - - @qos.setter - def qos(self, qos: int): - self.__qos = qos - async def _rate_limit_refill_loop(self): try: while self._active.is_set(): diff --git a/tb_mqtt_client/service/mqtt_manager.py b/tb_mqtt_client/service/mqtt_manager.py index 84afdf6..6fe4791 100644 --- a/tb_mqtt_client/service/mqtt_manager.py +++ b/tb_mqtt_client/service/mqtt_manager.py @@ -331,7 +331,7 @@ def _handle_puback_reason_code(self, mid: int, reason_code: int, properties: dic logger.warning("PUBACK received with error code %s for mid=%s", reason_code, mid) if self._on_publish_result_callback: - self._on_publish_result_callback(publish_result) + asyncio.create_task(self._on_publish_result_callback(publish_result)) def _on_subscribe_internal(self, client, mid, qos, properties): logger.trace("Received SUBACK by client %r for mid=%s with qos %s, properties %s", From cd23b2f22abf8bb4436842a0ff925278faf9897b Mon Sep 17 00:00:00 2001 From: imbeacon Date: Wed, 4 Jun 2025 10:17:00 +0300 Subject: [PATCH 08/74] Changed typing for callbacks --- tb_mqtt_client/service/message_queue.py | 3 +++ tb_mqtt_client/service/mqtt_manager.py | 17 +++++++---------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tb_mqtt_client/service/message_queue.py b/tb_mqtt_client/service/message_queue.py index 5e21767..fb8a74b 100644 --- a/tb_mqtt_client/service/message_queue.py +++ b/tb_mqtt_client/service/message_queue.py @@ -252,6 +252,9 @@ def resolve_attached(publish_future: asyncio.Future): mqtt_future.add_done_callback(resolve_attached) except Exception as e: logger.warning("Failed to publish to topic %s: %s. 
Scheduling retry.", topic, e) + logger.debug("Scheduling retry for topic=%s, payload size=%d, qos=%d", + topic, len(payload), qos) + logger.debug("error details: %s", e, exc_info=True) self._schedule_delayed_retry(topic, payload, datapoints, qos, delay=.1) def _schedule_delayed_retry(self, topic: str, payload: bytes, datapoints: int, qos: int, delay: float, diff --git a/tb_mqtt_client/service/mqtt_manager.py b/tb_mqtt_client/service/mqtt_manager.py index 6fe4791..9cdc696 100644 --- a/tb_mqtt_client/service/mqtt_manager.py +++ b/tb_mqtt_client/service/mqtt_manager.py @@ -17,7 +17,7 @@ from asyncio import sleep from contextlib import suppress from time import monotonic -from typing import Optional, Callable, Awaitable, Dict, Union, Tuple +from typing import Optional, Callable, Dict, Union, Tuple, Coroutine, Any from gmqtt import Client as GMQTTClient, Message, Subscription @@ -51,10 +51,10 @@ def __init__( client_id: str, main_stop_event: asyncio.Event, message_dispatcher: MessageDispatcher, - on_connect: Optional[Callable[[], Awaitable[None]]] = None, - on_disconnect: Optional[Callable[[], Awaitable[None]]] = None, - on_publish_result: Optional[Callable[[PublishResult], Awaitable[None]]] = None, - rate_limits_handler: Optional[Callable[[RPCResponse], Awaitable[None]]] = None, + on_connect: Optional[Callable[[], Coroutine[Any, Any, None]]] = None, + on_disconnect: Optional[Callable[[], Coroutine[Any, Any, None]]] = None, + on_publish_result: Optional[Callable[[PublishResult], Coroutine[Any, Any, None]]] = None, + rate_limits_handler: Optional[Callable[[RPCResponse], Coroutine[Any, Any, None]]] = None, rpc_response_handler: Optional[RPCResponseHandler] = None, ): self._main_stop_event = main_stop_event @@ -77,7 +77,7 @@ def __init__( self._connected_event = asyncio.Event() self._connect_params = None # Will be set in connect method - self._handlers: Dict[str, Callable[[str, bytes], Awaitable[None]]] = {} + self._handlers: Dict[str, Callable[[str, bytes], Coroutine[Any, Any, None]]] = {} self._pending_publishes: Dict[int, Tuple[asyncio.Future[PublishResult], str, int, int, float]] = {} self._publish_monitor_task = asyncio.create_task(self._monitor_ack_timeouts()) @@ -205,7 +205,7 @@ async def unsubscribe(self, topic: str) -> asyncio.Future: self._pending_unsubscriptions[mid] = unsubscribe_future return unsubscribe_future - def register_handler(self, topic_filter: str, handler: Callable[[str, bytes], Awaitable[None]]): + def register_handler(self, topic_filter: str, handler: Callable[[str, bytes], Coroutine[Any, Any, None]]): self._handlers[topic_filter] = handler def unregister_handler(self, topic_filter: str): @@ -290,9 +290,6 @@ def _on_message_internal(self, client, topic: str, payload: bytes, qos, properti client, topic, payload, qos, properties) for topic_filter, handler in self._handlers.items(): if self._match_topic(topic_filter, topic): - # TODO - # TODO: Add awaiting for created task to ensure it is handled properly - # TODO asyncio.create_task(handler(topic, payload)) return From cb853f37907f0b04f6663dcafe143a8fb4416e83 Mon Sep 17 00:00:00 2001 From: samson0v Date: Wed, 11 Jun 2025 13:37:04 +0300 Subject: [PATCH 09/74] Added device provisioning --- .../entities/data/provision_request.py | 76 +++++++++++++++++++ tb_mqtt_client/entities/provision_client.py | 74 ++++++++++++++++++ tb_mqtt_client/service/device/client.py | 18 +++++ 3 files changed, 168 insertions(+) create mode 100644 tb_mqtt_client/entities/data/provision_request.py create mode 100644 
tb_mqtt_client/entities/provision_client.py diff --git a/tb_mqtt_client/entities/data/provision_request.py b/tb_mqtt_client/entities/data/provision_request.py new file mode 100644 index 0000000..de5b163 --- /dev/null +++ b/tb_mqtt_client/entities/data/provision_request.py @@ -0,0 +1,76 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +class ProvisionRequest: + def __init__(self, provision_device_key, provision_device_secret, device_name=None, gateway=None): + self.provision_device_key = provision_device_key + self.provision_device_secret = provision_device_secret + self.device_name = device_name + self.gateway = gateway + + def to_dict(self): + provision_request = { + "provisionDeviceKey": self.provision_device_key, + "provisionDeviceSecret": self.provision_device_secret, + } + if self.device_name is not None: + provision_request["deviceName"] = self.device_name + if self.gateway is not None: + provision_request["gateway"] = self.gateway + + return provision_request + + +class ProvisionRequestAccessToken(ProvisionRequest): + def __init__(self, provision_device_key, provision_device_secret, access_token, device_name=None, gateway=None): + super().__init__(provision_device_key, provision_device_secret, device_name, gateway) + self.credentials_type = "ACCESS_TOKEN" + self.access_token = access_token + + def to_dict(self): + provision_request = super().to_dict() + provision_request["token"] = self.access_token + provision_request["credentialsType"] = "ACCESS_TOKEN" + return provision_request + + +class ProvisionRequestBasic(ProvisionRequest): + def __init__(self, provision_device_key, provision_device_secret, + client_id=None, username=None, password=None, device_name=None, gateway=None): + super().__init__(provision_device_key, provision_device_secret, device_name, gateway) + self.credentials_type = "MQTT_BASIC" + self.client_id = client_id + self.username = username + self.password = password + + def to_dict(self): + provision_request = super().to_dict() + provision_request["credentialsType"] = "MQTT_BASIC" + provision_request["username"] = self.username + provision_request["password"] = self.password + provision_request["clientId"] = self.client_id + return provision_request + + +class ProvisionRequestX509(ProvisionRequest): + def __init__(self, provision_device_key, provision_device_secret, hash, device_name=None, gateway=None): + super().__init__(provision_device_key, provision_device_secret, device_name, gateway) + self.credentials_type = "X509_CERTIFICATE" + self.hash = hash + + def to_dict(self): + provision_request = super().to_dict() + provision_request["credentialsType"] = "X509_CERTIFICATE" + provision_request["hash"] = self.hash + return provision_request diff --git a/tb_mqtt_client/entities/provision_client.py b/tb_mqtt_client/entities/provision_client.py new file mode 100644 index 0000000..df665bc --- /dev/null +++ b/tb_mqtt_client/entities/provision_client.py @@ -0,0 +1,74 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, 
Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from asyncio import Event + +from gmqtt import Client as GMQTTClient +from orjson import OPT_NON_STR_KEYS, dumps, loads + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.entities.data.provision_request import ProvisionRequest + +PROVISION_REQUEST_TOPIC = "/provision/request" +PROVISION_RESPONSE_TOPIC = "/provision/response" + +logger = get_logger(__name__) + + +class ProvisionClient: + def __init__(self, host, port, provision_request: 'ProvisionRequest'): + self._log = logger + self._stop_event = Event() + self._host = host + self._port = port + self._provision_request = provision_request.to_dict() + self._client_id = "provision" + self._client = GMQTTClient(self._client_id) + self._client.on_connect = self._on_connect + self._client.on_message = self._on_message + self._provisioned = Event() + self._device_credentials = None + + def _on_connect(self, client, _, rc, __): + if rc == 0: + self._log.debug("[Provisioning client] Connected to ThingsBoard ") + client.subscribe(PROVISION_RESPONSE_TOPIC) + provision_request = dumps(self._provision_request, option=OPT_NON_STR_KEYS) + self._log.debug("[Provisioning client] Sending provisioning request %s" % provision_request) + client.publish(PROVISION_REQUEST_TOPIC, provision_request) + else: + self._device_credentials = None + self._provisioned.set() + self._log.error("[Provisioning client] Cannot connect to ThingsBoard!, result: %s" % rc) + + async def _on_message(self, _, __, payload, ___, ____): + decoded_payload = payload.decode("UTF-8") + self._log.debug("[Provisioning client] Received data from ThingsBoard: %s" % decoded_payload) + decoded_message = loads(decoded_payload) + provision_device_status = decoded_message.get("status") + if provision_device_status == "SUCCESS": + self._device_credentials = decoded_message + else: + self._log.error("[Provisioning client] Provisioning was unsuccessful with status %s and message: %s" % ( + provision_device_status, decoded_message["errorMsg"])) + + await self._client.disconnect() + self._provisioned.set() + + async def provision(self): + await self._client.connect(self._host, self._port) + await self._provisioned.wait() + + if self._device_credentials: + return self._device_credentials diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index d8e8988..4f1f001 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -36,6 +36,8 @@ from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.provision_client import ProvisionClient +from tb_mqtt_client.entities.data.provision_request import ProvisionRequest from tb_mqtt_client.entities.publish_result import PublishResult from tb_mqtt_client.service.base_client import BaseClient from tb_mqtt_client.service.device.handlers.attribute_updates_handler import 
AttributeUpdatesHandler @@ -447,3 +449,19 @@ def _build_uplink_message_for_attributes(payload: Union[Dict[str, Any], builder = DeviceUplinkMessageBuilder() builder.add_attributes(payload) return builder.build() + + @staticmethod + async def provision(host, provision_request: 'ProvisionRequest', port=1883, timeout=3.0): + provision_client = ProvisionClient( + host=host, + port=port, + provision_request=provision_request + ) + + device_credentials = None + try: + device_credentials = await wait_for(provision_client.provision(), timeout=timeout) + except TimeoutError: + logger.error("Provisioning timed out") + + return device_credentials From 0e914915c9a4273ee55a773f8690b33165e55577 Mon Sep 17 00:00:00 2001 From: samson0v Date: Thu, 12 Jun 2025 14:06:03 +0300 Subject: [PATCH 10/74] Added firmware update --- tb_mqtt_client/constants/mqtt_topics.py | 7 + tb_mqtt_client/service/device/client.py | 6 + .../service/device/firmware_updater.py | 193 ++++++++++++++++++ 3 files changed, 206 insertions(+) create mode 100644 tb_mqtt_client/service/device/firmware_updater.py diff --git a/tb_mqtt_client/constants/mqtt_topics.py b/tb_mqtt_client/constants/mqtt_topics.py index 3986551..553acaf 100644 --- a/tb_mqtt_client/constants/mqtt_topics.py +++ b/tb_mqtt_client/constants/mqtt_topics.py @@ -30,6 +30,9 @@ DEVICE_RPC_RESPONSE_TOPIC_FOR_SUBSCRIPTION = DEVICE_RPC_TOPIC + RESPONSE_TOPIC_SUFFIX + "/" + WILDCARD # Device Claim topic DEVICE_CLAIM_TOPIC = "v1/devices/me/claim" +# Device Firmware Update topics +DEVICE_FIRMWARE_UPDATE_RESPONSE_TOPIC = "v2/fw/response/+/chunk/+" +DEVICE_FIRMWARE_UPDATE_REQUEST_TOPIC = "v2/fw/request/{request_id}/chunk/{current_chunk}" # V1 Topics for Gateway API BASE_GATEWAY_TOPIC = "v1/gateway" @@ -69,3 +72,7 @@ def build_gateway_device_attributes_topic() -> str: def build_gateway_rpc_topic() -> str: return GATEWAY_RPC_TOPIC + + +def build_firmware_update_request_topic(request_id: int, current_chunk: int) -> str: + return DEVICE_FIRMWARE_UPDATE_REQUEST_TOPIC.format(request_id=request_id, current_chunk=current_chunk) diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index 4f1f001..4e13964 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -40,6 +40,7 @@ from tb_mqtt_client.entities.data.provision_request import ProvisionRequest from tb_mqtt_client.entities.publish_result import PublishResult from tb_mqtt_client.service.base_client import BaseClient +from tb_mqtt_client.service.device.firmware_updater import FirmwareUpdater from tb_mqtt_client.service.device.handlers.attribute_updates_handler import AttributeUpdatesHandler from tb_mqtt_client.service.device.handlers.requested_attributes_response_handler import \ RequestedAttributeResponseHandler @@ -96,6 +97,11 @@ def __init__(self, config: Optional[Union[DeviceConfig, Dict]] = None): self._rpc_requests_handler = RPCRequestsHandler() self.__claiming_response_future: Union[Future[bool], None] = None + self._firmware_updater = FirmwareUpdater(self) + + async def update_firmware(self): + await self._firmware_updater.update() + async def connect(self): logger.info("Connecting to platform at %s:%s", self._host, self._port) diff --git a/tb_mqtt_client/service/device/firmware_updater.py b/tb_mqtt_client/service/device/firmware_updater.py new file mode 100644 index 0000000..eae4ab7 --- /dev/null +++ b/tb_mqtt_client/service/device/firmware_updater.py @@ -0,0 +1,193 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 
2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from asyncio import sleep +from enum import Enum +from sdk_utils import verify_checksum +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.constants import mqtt_topics +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest + +logger = get_logger(__name__) + + +class FirmwareStates(Enum): + IDLE = 'IDLE' + DOWNLOADING = 'DOWNLOADING' + DOWNLOADED = 'DOWNLOADED' + VERIFIED = 'VERIFIED' + FAILED = 'FAILED' + UPDATING = 'UPDATING' + UPDATED = 'UPDATED' + + +class FirmwareUpdater: + FW_TITLE_ATTR = "fw_title" + FW_VERSION_ATTR = "fw_version" + FW_CHECKSUM_ATTR = "fw_checksum" + FW_CHECKSUM_ALG_ATTR = "fw_checksum_algorithm" + FW_SIZE_ATTR = "fw_size" + FW_STATE_ATTR = "fw_state" + + REQUIRED_SHARED_KEYS = [FW_CHECKSUM_ATTR, FW_CHECKSUM_ALG_ATTR, + FW_SIZE_ATTR, FW_TITLE_ATTR, FW_VERSION_ATTR] + + def __init__(self, client): + self._log = logger + self._client = client + self._client._mqtt_manager.register_handler(mqtt_topics.DEVICE_FIRMWARE_UPDATE_RESPONSE_TOPIC, + self._handle_firmware_update) + self._firmware_request_id = 0 + self._chunk_size = 0 + self._current_chunk = 0 + self._firmware_data = b'' + self._target_firmware_length = 0 + self._target_checksum = 0 + self._target_checksum_alg = None + self._target_version = None + self._target_title = None + self.current_firmware_info = { + 'current_' + FirmwareUpdater.FW_TITLE_ATTR: 'Initial', + 'current_' + FirmwareUpdater.FW_VERSION_ATTR: 'v0', + FirmwareUpdater.FW_STATE_ATTR: FirmwareStates.IDLE.value + } + + async def _handle_firmware_update(self, _, payload: bytes): + self._firmware_data = self._firmware_data + payload + self._current_chunk = self._current_chunk + 1 + + self._log.debug('Getting chunk with number: %s. Chunk size is : %r byte(s).' % ( + self._current_chunk, self._chunk_size)) + + if len(self._firmware_data) == self._target_firmware_length: + self._log.info('Firmware download completed. ' + 'Total firmware size: %s byte(s).' 
% self._target_firmware_length) + await self._verify_downloaded_firmware() + else: + await self._get_next_chunk() + + async def _get_next_chunk(self): + if not self._chunk_size or self._chunk_size > self._target_firmware_length: + payload = b'' + else: + payload = str(self._chunk_size).encode() + + topic = mqtt_topics.build_firmware_update_request_topic(self._firmware_request_id, self._current_chunk) + await self._client._message_queue.publish(topic=topic, payload=payload, datapoints_count=0, qos=1) + + async def _verify_downloaded_firmware(self): + self._log.info('Verifying downloaded firmware...') + + self.current_firmware_info[FirmwareUpdater.FW_STATE_ATTR] = FirmwareStates.DOWNLOADED.value + await self._client.send_telemetry(self.current_firmware_info, wait_for_publish=True) + + verified = verify_checksum(self._firmware_data, + self._target_checksum, + self._target_checksum_alg) + + if verified: + self._log.debug('Checksum verified.') + self.current_firmware_info[FirmwareUpdater.FW_STATE_ATTR] = FirmwareStates.VERIFIED.value + else: + self._log.error('Checksum verification failed.') + self.current_firmware_info[FirmwareUpdater.FW_STATE_ATTR] = FirmwareStates.FAILED.value + + await self._client.send_telemetry(self.current_firmware_info, wait_for_publish=True) + + if self.current_firmware_info[FirmwareUpdater.FW_STATE_ATTR] == FirmwareStates.VERIFIED.value: + await self._apply_downloaded_firmware() + + async def _apply_downloaded_firmware(self): + self._log.info('Applying downloaded firmware...') + + self.current_firmware_info[FirmwareUpdater.FW_STATE_ATTR] = FirmwareStates.UPDATING.value + await self._client.send_telemetry(self.current_firmware_info, wait_for_publish=True) + + try: + self._save_firmware() + except Exception as e: + self._log.error('Failed to save firmware: %s', e) + self.current_firmware_info[FirmwareUpdater.FW_STATE_ATTR] = FirmwareStates.FAILED.value + await self._client.send_telemetry(self.current_firmware_info, wait_for_publish=True) + return + + self.current_firmware_info = { + "current_" + FirmwareUpdater.FW_TITLE_ATTR: self._target_title, + "current_" + FirmwareUpdater.FW_VERSION_ATTR: self._target_version, + FirmwareUpdater.FW_STATE_ATTR: FirmwareStates.UPDATED.value + } + + await self._client.send_telemetry(self.current_firmware_info, wait_for_publish=True) + + self._log.info('Firmware is updated.') + self._log.info('Current firmware version is: %s' % self._target_version) + + def _save_firmware(self): + with open(self._target_title, "wb") as firmware_file: + firmware_file.write(self._firmware_data) + + async def update(self): + self._log.info("Starting firmware update process...") + + sub_future = await self._client._mqtt_manager.subscribe(mqtt_topics.DEVICE_FIRMWARE_UPDATE_RESPONSE_TOPIC, + qos=1) + while not sub_future.done(): + await sleep(0.01) + + await self._client.send_telemetry(self.current_firmware_info, wait_for_publish=True) + + attribute_request = await AttributeRequest.build(FirmwareUpdater.REQUIRED_SHARED_KEYS) + await self._client.send_attribute_request(attribute_request, callback=self._firmware_info_callback) + + async def _firmware_info_callback(self, response, *args, **kwargs): + # TODO: add logs + if len(response.shared_keys()) == len(FirmwareUpdater.REQUIRED_SHARED_KEYS): + fetched_firmware_info = response.as_dict()['shared'] + fetched_firmware_info = {item['key']: item['value'] + for item in fetched_firmware_info} + + if self._is_different_firmware_versions(fetched_firmware_info): + self._log.info("Firmware update available: %s. 
Downloading...", + fetched_firmware_info) + + self._firmware_data = b'' + self._current_chunk = 0 + self.current_firmware_info[FirmwareUpdater.FW_STATE_ATTR] = FirmwareStates.DOWNLOADING.value + + self._firmware_request_id += 1 + self._target_firmware_length = fetched_firmware_info[FirmwareUpdater.FW_SIZE_ATTR] + self._target_checksum = fetched_firmware_info[FirmwareUpdater.FW_CHECKSUM_ALG_ATTR] + self._target_checksum_alg = fetched_firmware_info[FirmwareUpdater.FW_CHECKSUM_ATTR] + self._target_title = fetched_firmware_info[FirmwareUpdater.FW_TITLE_ATTR] + self._target_version = fetched_firmware_info[FirmwareUpdater.FW_VERSION_ATTR] + + await self._get_next_chunk() + else: + self._log.info("Firmware is up to date.") + else: + self._log.error("Failed to fetch firmware info. " + "Received firmware info does not match required keys. " + "Expected: %s, Received: %s", + FirmwareUpdater.REQUIRED_SHARED_KEYS, + response.shared_keys()) + + self.current_firmware_info[FirmwareUpdater.FW_STATE_ATTR] = FirmwareStates.FAILED.value + await self._client.send_telemetry(self.current_firmware_info, wait_for_publish=True) + + def _send_current_firmware_info(self): + pass + + def _is_different_firmware_versions(self, new_firmware_info): + return (self.current_firmware_info['current_' + FirmwareUpdater.FW_TITLE_ATTR] != new_firmware_info[FirmwareUpdater.FW_TITLE_ATTR] or # noqa + self.current_firmware_info['current_' + FirmwareUpdater.FW_VERSION_ATTR] != new_firmware_info[FirmwareUpdater.FW_VERSION_ATTR]) # noqa From 29144841607b722bdc2dd6b2f7053239dab3a9d2 Mon Sep 17 00:00:00 2001 From: samson0v Date: Tue, 17 Jun 2025 13:33:03 +0300 Subject: [PATCH 11/74] Updated device provisioning --- tb_mqtt_client/constants/mqtt_topics.py | 3 + .../entities/data/provision_request.py | 98 ++++++++++--------- .../entities/data/provisioning_response.py | 67 +++++++++++++ tb_mqtt_client/entities/provision_client.py | 37 ++++--- tb_mqtt_client/service/device/client.py | 6 +- tb_mqtt_client/service/message_dispatcher.py | 73 +++++++++++--- 6 files changed, 202 insertions(+), 82 deletions(-) create mode 100644 tb_mqtt_client/entities/data/provisioning_response.py diff --git a/tb_mqtt_client/constants/mqtt_topics.py b/tb_mqtt_client/constants/mqtt_topics.py index 553acaf..5f98343 100644 --- a/tb_mqtt_client/constants/mqtt_topics.py +++ b/tb_mqtt_client/constants/mqtt_topics.py @@ -30,6 +30,9 @@ DEVICE_RPC_RESPONSE_TOPIC_FOR_SUBSCRIPTION = DEVICE_RPC_TOPIC + RESPONSE_TOPIC_SUFFIX + "/" + WILDCARD # Device Claim topic DEVICE_CLAIM_TOPIC = "v1/devices/me/claim" +# Device Provisioning topics +PROVISION_REQUEST_TOPIC = "/provision/request" +PROVISION_RESPONSE_TOPIC = "/provision/response" # Device Firmware Update topics DEVICE_FIRMWARE_UPDATE_RESPONSE_TOPIC = "v2/fw/response/+/chunk/+" DEVICE_FIRMWARE_UPDATE_REQUEST_TOPIC = "v2/fw/request/{request_id}/chunk/{current_chunk}" diff --git a/tb_mqtt_client/entities/data/provision_request.py b/tb_mqtt_client/entities/data/provision_request.py index de5b163..0ad7026 100644 --- a/tb_mqtt_client/entities/data/provision_request.py +++ b/tb_mqtt_client/entities/data/provision_request.py @@ -12,65 +12,69 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from enum import Enum +from typing import Optional + + class ProvisionRequest: -    def __init__(self, provision_device_key, provision_device_secret, device_name=None, gateway=None): -        self.provision_device_key = provision_device_key -        self.provision_device_secret = provision_device_secret +    def __init__(self, host, credentials: 'ProvisioningCredentials', port: str = "1883", +                 device_name: Optional[str] = None, gateway: Optional[bool] = False): +        self.host = host +        self.port = port +        self.credentials = credentials self.device_name = device_name self.gateway = gateway -    def to_dict(self): -        provision_request = { -            "provisionDeviceKey": self.provision_device_key, -            "provisionDeviceSecret": self.provision_device_secret, -        } -        if self.device_name is not None: -            provision_request["deviceName"] = self.device_name -        if self.gateway is not None: -            provision_request["gateway"] = self.gateway - -        return provision_request +class ProvisioningCredentialsType(Enum): +    ACCESS_TOKEN = "ACCESS_TOKEN" +    MQTT_BASIC = "MQTT_BASIC" +    X509_CERTIFICATE = "X509_CERTIFICATE" -class ProvisionRequestAccessToken(ProvisionRequest): -    def __init__(self, provision_device_key, provision_device_secret, access_token, device_name=None, gateway=None): -        super().__init__(provision_device_key, provision_device_secret, device_name, gateway) -        self.credentials_type = "ACCESS_TOKEN" -        self.access_token = access_token +class ProvisioningCredentials: +    def __init__(self, provision_device_key: str, provision_device_secret: str): +        self.provision_device_key = provision_device_key +        self.provision_device_secret = provision_device_secret +        self.credentials_type = None -    def to_dict(self): -        provision_request = super().to_dict() -        provision_request["token"] = self.access_token -        provision_request["credentialsType"] = "ACCESS_TOKEN" -        return provision_request + +class AccessTokenProvisionCredentials(ProvisioningCredentials): +    def __init__(self, provision_device_key: str, provision_device_secret: str, access_token: str): +        super().__init__(provision_device_key, provision_device_secret) +        self.access_token = access_token +        self.credentials_type = ProvisioningCredentialsType.ACCESS_TOKEN -class ProvisionRequestBasic(ProvisionRequest): +class BasicProvisionCredentials(ProvisioningCredentials): def __init__(self, provision_device_key, provision_device_secret, -                 client_id=None, username=None, password=None, device_name=None, gateway=None): -        super().__init__(provision_device_key, provision_device_secret, device_name, gateway) -        self.credentials_type = "MQTT_BASIC" +                 client_id: Optional[str] = None, username: Optional[str] = None, password: Optional[str] = None): +        super().__init__(provision_device_key, provision_device_secret) self.client_id = client_id self.username = username self.password = password +        self.credentials_type = ProvisioningCredentialsType.MQTT_BASIC + + +class X509ProvisionCredentials(ProvisioningCredentials): +    def __init__(self, provision_device_key, provision_device_secret, +                 private_key_path: str, public_cert_path: str, ca_cert_path: str): +        super().__init__(provision_device_key, provision_device_secret) +        self.private_key_path = private_key_path +        self.ca_cert_path = ca_cert_path +        self.public_cert_path = public_cert_path +        self.public_cert = self._load_public_cert_path(public_cert_path) +        self.credentials_type = ProvisioningCredentialsType.X509_CERTIFICATE + +    def _load_public_cert_path(self, public_cert_path): +        content = '' + +        try: +            with open(public_cert_path, 'r') as file: +                content = file.read() +        except FileNotFoundError: 
+ raise FileNotFoundError(f"Public certificate file not found: {public_cert_path}") + except IOError as e: + raise IOError(f"Error reading public certificate file {public_cert_path}: {e}") - def to_dict(self): - provision_request = super().to_dict() - provision_request["credentialsType"] = "MQTT_BASIC" - provision_request["username"] = self.username - provision_request["password"] = self.password - provision_request["clientId"] = self.client_id - return provision_request - - -class ProvisionRequestX509(ProvisionRequest): - def __init__(self, provision_device_key, provision_device_secret, hash, device_name=None, gateway=None): - super().__init__(provision_device_key, provision_device_secret, device_name, gateway) - self.credentials_type = "X509_CERTIFICATE" - self.hash = hash - - def to_dict(self): - provision_request = super().to_dict() - provision_request["credentialsType"] = "X509_CERTIFICATE" - provision_request["hash"] = self.hash - return provision_request + return content.strip() if content else None diff --git a/tb_mqtt_client/entities/data/provisioning_response.py b/tb_mqtt_client/entities/data/provisioning_response.py new file mode 100644 index 0000000..eb6dd58 --- /dev/null +++ b/tb_mqtt_client/entities/data/provisioning_response.py @@ -0,0 +1,67 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Optional + +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.entities.data.provision_request import ProvisionRequest, ProvisioningCredentialsType + + +class ProvisioningResponseStatus(Enum): + SUCCESS = "SUCCESS" + ERROR = "FAILURE" + + def __str__(self): + return self.value + + +@dataclass(frozen=True) +class ProvisioningResponse: + status: ProvisioningResponseStatus + result: Optional[dict] = None + error: Optional[str] = None + + def __new__(cls, *args, **kwargs): + raise TypeError("Direct instantiation of ProvisioningResponse is not allowed. Use ProvisioningResponse.build(result, error).") # noqa + + def __repr__(self) -> str: + return f"ProvisioningResponse(status={self.status}, result={self.result}, error={self.error})" + + @classmethod + def build(cls, provision_request: 'ProvisionRequest', payload: dict) -> 'ProvisioningResponse': + """ + Constructs a ProvisioningResponse explicitly. 
+ """ + self = object.__new__(cls) + + if payload.get('status') == ProvisioningResponseStatus.ERROR.value: + object.__setattr__(self, 'error', payload.get('errorMsg')) + object.__setattr__(self, 'status', ProvisioningResponseStatus.ERROR) + object.__setattr__(self, 'result', None) + else: + device_config = ProvisioningResponse._build_device_config(provision_request, payload) + + object.__setattr__(self, 'result', device_config) + object.__setattr__(self, 'status', ProvisioningResponseStatus.SUCCESS) + object.__setattr__(self, 'error', None) + + return self + + @staticmethod + def _build_device_config(provision_request: 'ProvisionRequest', payload: dict): + device_config = DeviceConfig() + device_config.host = provision_request.host + device_config.port = provision_request.port + + if provision_request.credentials.credentials_type is None or \ + provision_request.credentials.credentials_type == ProvisioningCredentialsType.ACCESS_TOKEN: + device_config.access_token = payload['credentialsValue'] + elif provision_request.credentials.credentials_type == ProvisioningCredentialsType.MQTT_BASIC: + device_config.client_id = payload['credentialsValue']['clientId'] + device_config.username = payload['credentialsValue']['userName'] + device_config.password = payload['credentialsValue']['password'] + elif provision_request.credentials.credentials_type == ProvisioningCredentialsType.X509_CERTIFICATE: + device_config.ca_cert = provision_request.credentials.ca_cert_path + device_config.client_cert = provision_request.credentials.public_cert_path + device_config.private_key = provision_request.credentials.private_key_path + + return device_config diff --git a/tb_mqtt_client/entities/provision_client.py b/tb_mqtt_client/entities/provision_client.py index df665bc..2de3de7 100644 --- a/tb_mqtt_client/entities/provision_client.py +++ b/tb_mqtt_client/entities/provision_client.py @@ -15,13 +15,14 @@ from asyncio import Event from gmqtt import Client as GMQTTClient -from orjson import OPT_NON_STR_KEYS, dumps, loads +from orjson import loads +from tb_mqtt_client.common.config_loader import DeviceConfig from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.constants.mqtt_topics import PROVISION_RESPONSE_TOPIC from tb_mqtt_client.entities.data.provision_request import ProvisionRequest - -PROVISION_REQUEST_TOPIC = "/provision/request" -PROVISION_RESPONSE_TOPIC = "/provision/response" +from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse +from tb_mqtt_client.service.message_dispatcher import JsonMessageDispatcher logger = get_logger(__name__) @@ -32,23 +33,26 @@ def __init__(self, host, port, provision_request: 'ProvisionRequest'): self._stop_event = Event() self._host = host self._port = port - self._provision_request = provision_request.to_dict() + self._provision_request = provision_request self._client_id = "provision" self._client = GMQTTClient(self._client_id) self._client.on_connect = self._on_connect self._client.on_message = self._on_message self._provisioned = Event() - self._device_credentials = None + self._device_config: 'DeviceConfig' = None + self._json_message_dispatcher = JsonMessageDispatcher() def _on_connect(self, client, _, rc, __): if rc == 0: - self._log.debug("[Provisioning client] Connected to ThingsBoard ") + self._log.debug("[Provisioning client] Connected to ThingsBoard") client.subscribe(PROVISION_RESPONSE_TOPIC) - provision_request = dumps(self._provision_request, option=OPT_NON_STR_KEYS) - self._log.debug("[Provisioning client] 
Sending provisioning request %s" % provision_request) - client.publish(PROVISION_REQUEST_TOPIC, provision_request) + topic, payload = self._json_message_dispatcher.build_provision_request(self._provision_request) + self._log.debug("[Provisioning client] Sending provisioning request %s" % payload) + client.publish(topic, payload) else: - self._device_credentials = None + self._device_config = ProvisioningResponse.build(self._provision_request, + {'status': 'FAILURE', + 'errorMsg': 'Cannot connect to ThingsBoard!'}) self._provisioned.set() self._log.error("[Provisioning client] Cannot connect to ThingsBoard!, result: %s" % rc) @@ -56,12 +60,8 @@ async def _on_message(self, _, __, payload, ___, ____): decoded_payload = payload.decode("UTF-8") self._log.debug("[Provisioning client] Received data from ThingsBoard: %s" % decoded_payload) decoded_message = loads(decoded_payload) - provision_device_status = decoded_message.get("status") - if provision_device_status == "SUCCESS": - self._device_credentials = decoded_message - else: - self._log.error("[Provisioning client] Provisioning was unsuccessful with status %s and message: %s" % ( - provision_device_status, decoded_message["errorMsg"])) + + self._device_config = ProvisioningResponse.build(self._provision_request, decoded_message) await self._client.disconnect() self._provisioned.set() @@ -70,5 +70,4 @@ async def provision(self): await self._client.connect(self._host, self._port) await self._provisioned.wait() - if self._device_credentials: - return self._device_credentials + return self._device_config diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index 4e13964..3d104f4 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -457,10 +457,10 @@ def _build_uplink_message_for_attributes(payload: Union[Dict[str, Any], return builder.build() @staticmethod - async def provision(host, provision_request: 'ProvisionRequest', port=1883, timeout=3.0): + async def provision(provision_request: 'ProvisionRequest', timeout=3.0): provision_client = ProvisionClient( - host=host, - port=port, + host=provision_request.host, + port=provision_request.port, provision_request=provision_request ) diff --git a/tb_mqtt_client/service/message_dispatcher.py b/tb_mqtt_client/service/message_dispatcher.py index 52d1790..bd1e17e 100644 --- a/tb_mqtt_client/service/message_dispatcher.py +++ b/tb_mqtt_client/service/message_dispatcher.py @@ -26,6 +26,7 @@ from tb_mqtt_client.entities.data.attribute_request import AttributeRequest from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage +from tb_mqtt_client.entities.data.provision_request import ProvisionRequest, ProvisioningCredentialsType from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse @@ -84,6 +85,14 @@ def build_rpc_response(self, rpc_response: RPCResponse) -> Tuple[str, bytes]: """ pass + @abstractmethod + def build_provision_request(self, provision_request) -> Tuple[str, bytes]: + """ + Build the payload for a device provisioning request. + This method should return a tuple of topic and payload bytes. 
+        """ +        pass + @abstractmethod def splitter(self) -> MessageSplitter: """ @@ -292,20 +301,59 @@ def build_rpc_request(self, rpc_request: RPCRequest) -> Tuple[str, bytes]: rpc_request.request_id, payload) return topic, payload -    def build_rpc_response(self, rpc_response: RPCResponse) -> Tuple[str, bytes]: -        """ -        Build the payload for an RPC response. -        :param rpc_response: The RPC response to build the payload for. -        :return: A tuple of topic and payload bytes. -        """ -        if not rpc_response.request_id: -            raise ValueError("RPCResponse must have a valid request ID.") - -        payload = dumps(rpc_response.to_payload_format()) -        topic = mqtt_topics.DEVICE_RPC_RESPONSE_TOPIC + str(rpc_response.request_id) -        logger.trace("Built RPC response payload for request ID=%d with payload: %r", rpc_response.request_id, payload) -        return topic, payload +    def build_rpc_response(self, rpc_response: RPCResponse) -> Tuple[str, bytes]: +        """ +        Build the payload for an RPC response. +        :param rpc_response: The RPC response to build the payload for. +        :return: A tuple of topic and payload bytes. +        """ +        if not rpc_response.request_id: +            raise ValueError("RPCResponse must have a valid request ID.") + +        payload = dumps(rpc_response.to_payload_format()) +        topic = mqtt_topics.DEVICE_RPC_RESPONSE_TOPIC + str(rpc_response.request_id) +        logger.trace("Built RPC response payload for request ID=%d with payload: %r", rpc_response.request_id, payload) +        return topic, payload + +    def build_provision_request(self, provision_request: 'ProvisionRequest') -> Tuple[str, bytes]: +        """ +        Build the payload for a device provisioning request. +        :param provision_request: The ProvisionRequest to build the payload for. +        :return: A tuple of topic and payload bytes. +        """ +        if not provision_request.credentials.provision_device_key or not provision_request.credentials.provision_device_secret: +            raise ValueError("ProvisionRequest must have valid device key and secret.") + +        topic = mqtt_topics.PROVISION_REQUEST_TOPIC +        request = {} +        request["provisionDeviceKey"] = provision_request.credentials.provision_device_key +        request["provisionDeviceSecret"] = provision_request.credentials.provision_device_secret + +        if provision_request.device_name: +            request["deviceName"] = provision_request.device_name + +        if provision_request.gateway: +            request["gateway"] = provision_request.gateway + +        if provision_request.credentials.credentials_type and \ +            provision_request.credentials.credentials_type == ProvisioningCredentialsType.ACCESS_TOKEN: +            request["token"] = provision_request.credentials.access_token +            request["credentialsType"] = provision_request.credentials.credentials_type.value + +        if provision_request.credentials.credentials_type == ProvisioningCredentialsType.MQTT_BASIC: +            request["username"] = provision_request.credentials.username +            request["password"] = provision_request.credentials.password +            request["clientId"] = provision_request.credentials.client_id +            request["credentialsType"] = provision_request.credentials.credentials_type.value + +        if provision_request.credentials.credentials_type == ProvisioningCredentialsType.X509_CERTIFICATE: +            request["hash"] = provision_request.credentials.public_cert +            request["credentialsType"] = provision_request.credentials.credentials_type.value + +        payload = dumps(request) +        logger.trace("Built provision request payload: %r", provision_request) +        return topic, payload @staticmethod def build_payload(msg: DeviceUplinkMessage, build_timeseries_payload) -> bytes: From 92c9236df70db88330ffbdf5ee2e7b04b0057731 Mon Sep 17 00:00:00 2001 From: samson0v Date: Tue, 17 Jun 2025 13:42:27 +0300 Subject:
[PATCH 12/74] Refactored device provisioning --- ...ion_request.py => provisioning_request.py} | 8 ++++---- .../entities/data/provisioning_response.py | 20 ++++++++++++++++--- ...ision_client.py => provisioning_client.py} | 6 +++--- tb_mqtt_client/service/device/client.py | 8 ++++---- tb_mqtt_client/service/message_dispatcher.py | 8 ++++---- 5 files changed, 32 insertions(+), 18 deletions(-) rename tb_mqtt_client/entities/data/{provision_request.py => provisioning_request.py} (93%) rename tb_mqtt_client/entities/{provision_client.py => provisioning_client.py} (96%) diff --git a/tb_mqtt_client/entities/data/provision_request.py b/tb_mqtt_client/entities/data/provisioning_request.py similarity index 93% rename from tb_mqtt_client/entities/data/provision_request.py rename to tb_mqtt_client/entities/data/provisioning_request.py index 0ad7026..8101237 100644 --- a/tb_mqtt_client/entities/data/provision_request.py +++ b/tb_mqtt_client/entities/data/provisioning_request.py @@ -16,7 +16,7 @@ from typing import Optional -class ProvisionRequest: +class ProvisioningRequest: def __init__(self, host, credentials: 'ProvisioningCredentials', port: str = "1883", device_name: Optional[str] = None, gateway: Optional[bool] = False): self.host = host @@ -39,14 +39,14 @@ def __init__(self, provision_device_key: str, provision_device_secret: str): self.credentials_type = None -class AccessTokenProvisionCredentials(ProvisioningCredentials): +class AccessTokenProvisioningCredentials(ProvisioningCredentials): def __init__(self, provision_device_key: str, provision_device_secret: str, access_token: str): super().__init__(provision_device_key, provision_device_secret) self.access_token = access_token self.credentials_type = ProvisioningCredentialsType.ACCESS_TOKEN -class BasicProvisionCredentials(ProvisioningCredentials): +class BasicProvisioningCredentials(ProvisioningCredentials): def __init__(self, provision_device_key, provision_device_secret, client_id: Optional[str] = None, username: Optional[str] = None, password: Optional[str] = None): super().__init__(provision_device_key, provision_device_secret) @@ -56,7 +56,7 @@ def __init__(self, provision_device_key, provision_device_secret, self.credentials_type = ProvisioningCredentialsType.MQTT_BASIC -class X509ProvisionCredentials(ProvisioningCredentials): +class X509ProvisioningCredentials(ProvisioningCredentials): def __init__(self, provision_device_key, provision_device_secret, private_key_path: str, public_cert_path: str, ca_cert_path: str): super().__init__(provision_device_key, provision_device_secret) diff --git a/tb_mqtt_client/entities/data/provisioning_response.py b/tb_mqtt_client/entities/data/provisioning_response.py index eb6dd58..fcea0bd 100644 --- a/tb_mqtt_client/entities/data/provisioning_response.py +++ b/tb_mqtt_client/entities/data/provisioning_response.py @@ -1,9 +1,23 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from dataclasses import dataclass from enum import Enum from typing import Optional from tb_mqtt_client.common.config_loader import DeviceConfig -from tb_mqtt_client.entities.data.provision_request import ProvisionRequest, ProvisioningCredentialsType +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest, ProvisioningCredentialsType class ProvisioningResponseStatus(Enum): @@ -27,7 +41,7 @@ def __repr__(self) -> str: return f"ProvisioningResponse(status={self.status}, result={self.result}, error={self.error})" @classmethod - def build(cls, provision_request: 'ProvisionRequest', payload: dict) -> 'ProvisioningResponse': + def build(cls, provision_request: 'ProvisioningRequest', payload: dict) -> 'ProvisioningResponse': """ Constructs a ProvisioningResponse explicitly. """ @@ -47,7 +61,7 @@ def build(cls, provision_request: 'ProvisionRequest', payload: dict) -> 'Provisi return self @staticmethod - def _build_device_config(provision_request: 'ProvisionRequest', payload: dict): + def _build_device_config(provision_request: 'ProvisioningRequest', payload: dict): device_config = DeviceConfig() device_config.host = provision_request.host device_config.port = provision_request.port diff --git a/tb_mqtt_client/entities/provision_client.py b/tb_mqtt_client/entities/provisioning_client.py similarity index 96% rename from tb_mqtt_client/entities/provision_client.py rename to tb_mqtt_client/entities/provisioning_client.py index 2de3de7..a2718c3 100644 --- a/tb_mqtt_client/entities/provision_client.py +++ b/tb_mqtt_client/entities/provisioning_client.py @@ -20,15 +20,15 @@ from tb_mqtt_client.common.config_loader import DeviceConfig from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.constants.mqtt_topics import PROVISION_RESPONSE_TOPIC -from tb_mqtt_client.entities.data.provision_request import ProvisionRequest +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse from tb_mqtt_client.service.message_dispatcher import JsonMessageDispatcher logger = get_logger(__name__) -class ProvisionClient: - def __init__(self, host, port, provision_request: 'ProvisionRequest'): +class ProvisioningClient: + def __init__(self, host, port, provision_request: 'ProvisioningRequest'): self._log = logger self._stop_event = Event() self._host = host diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index 3d104f4..87114ee 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -36,8 +36,8 @@ from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry -from tb_mqtt_client.entities.provision_client import ProvisionClient -from tb_mqtt_client.entities.data.provision_request import ProvisionRequest +from tb_mqtt_client.entities.provisioning_client import ProvisioningClient +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest from tb_mqtt_client.entities.publish_result import PublishResult from tb_mqtt_client.service.base_client import BaseClient from tb_mqtt_client.service.device.firmware_updater import FirmwareUpdater @@ -457,8 +457,8 @@ def _build_uplink_message_for_attributes(payload: Union[Dict[str, Any], return builder.build() @staticmethod - async def provision(provision_request: 
'ProvisionRequest', timeout=3.0): - provision_client = ProvisionClient( + async def provision(provision_request: 'ProvisioningRequest', timeout=3.0): + provision_client = ProvisioningClient( host=provision_request.host, port=provision_request.port, provision_request=provision_request diff --git a/tb_mqtt_client/service/message_dispatcher.py b/tb_mqtt_client/service/message_dispatcher.py index bd1e17e..5dbd9fa 100644 --- a/tb_mqtt_client/service/message_dispatcher.py +++ b/tb_mqtt_client/service/message_dispatcher.py @@ -26,7 +26,7 @@ from tb_mqtt_client.entities.data.attribute_request import AttributeRequest from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage -from tb_mqtt_client.entities.data.provision_request import ProvisionRequest, ProvisioningCredentialsType +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest, ProvisioningCredentialsType from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse @@ -315,14 +315,14 @@ def build_rpc_response(self, rpc_response: RPCResponse) -> Tuple[str, bytes]: logger.trace("Built RPC response payload for request ID=%d with payload: %r", rpc_response.request_id, payload) return topic, payload - def build_provision_request(self, provision_request: 'ProvisionRequest') -> Tuple[str, bytes]: + def build_provision_request(self, provision_request: 'ProvisioningRequest') -> Tuple[str, bytes]: """ Build the payload for a device provisioning request. - :param provision_request: The ProvisionRequest to build the payload for. + :param provision_request: The ProvisioningRequest to build the payload for. :return: A tuple of topic and payload bytes. 
""" if not provision_request.credentials.provision_device_key or not provision_request.credentials.provision_device_secret: - raise ValueError("ProvisionRequest must have valid device key and secret.") + raise ValueError("ProvisioningRequest must have valid device key and secret.") topic = mqtt_topics.PROVISION_REQUEST_TOPIC request = {} From 4490374cda4227c4e20dba4763d0e79d6c11d333 Mon Sep 17 00:00:00 2001 From: samson0v Date: Wed, 18 Jun 2025 08:34:46 +0300 Subject: [PATCH 13/74] Updated firmware updater --- tb_mqtt_client/service/device/client.py | 5 ++-- .../service/device/firmware_updater.py | 29 ++++++++++++++----- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index 87114ee..7b10448 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -99,8 +99,9 @@ def __init__(self, config: Optional[Union[DeviceConfig, Dict]] = None): self._firmware_updater = FirmwareUpdater(self) - async def update_firmware(self): - await self._firmware_updater.update() + async def update_firmware(self, on_received_callback: Optional[Callable[[str], Awaitable[None]]] = None, + save_firmware: bool = True, firmware_save_path: Optional[str] = None): + await self._firmware_updater.update(on_received_callback, save_firmware, firmware_save_path) async def connect(self): logger.info("Connecting to platform at %s:%s", self._host, self._port) diff --git a/tb_mqtt_client/service/device/firmware_updater.py b/tb_mqtt_client/service/device/firmware_updater.py index eae4ab7..87453b2 100644 --- a/tb_mqtt_client/service/device/firmware_updater.py +++ b/tb_mqtt_client/service/device/firmware_updater.py @@ -14,6 +14,8 @@ from asyncio import sleep from enum import Enum +from os.path import sep +from typing import Awaitable, Callable, Optional from sdk_utils import verify_checksum from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.constants import mqtt_topics @@ -48,6 +50,9 @@ def __init__(self, client): self._client = client self._client._mqtt_manager.register_handler(mqtt_topics.DEVICE_FIRMWARE_UPDATE_RESPONSE_TOPIC, self._handle_firmware_update) + self._on_received_callback = None + self._save_firmware = True + self._save_path = './' self._firmware_request_id = 0 self._chunk_size = 0 self._current_chunk = 0 @@ -115,7 +120,8 @@ async def _apply_downloaded_firmware(self): await self._client.send_telemetry(self.current_firmware_info, wait_for_publish=True) try: - self._save_firmware() + if self._save_firmware: + self._save() except Exception as e: self._log.error('Failed to save firmware: %s', e) self.current_firmware_info[FirmwareUpdater.FW_STATE_ATTR] = FirmwareStates.FAILED.value @@ -130,16 +136,27 @@ async def _apply_downloaded_firmware(self): await self._client.send_telemetry(self.current_firmware_info, wait_for_publish=True) + if self._on_received_callback: + await self._on_received_callback(self._firmware_data, self.current_firmware_info) + self._log.info('Firmware is updated.') self._log.info('Current firmware version is: %s' % self._target_version) - def _save_firmware(self): - with open(self._target_title, "wb") as firmware_file: + def _save(self): + firmware_path = self._save_path + sep + self._target_title + with open(firmware_path, "wb") as firmware_file: firmware_file.write(self._firmware_data) - async def update(self): + async def update(self, on_received_callback: Optional[Callable[[str], Awaitable[None]]] = None, + save_firmware: bool = True, 
firmware_save_path: Optional[str] = None): self._log.info("Starting firmware update process...") + self._on_received_callback = on_received_callback + self._save_firmware = save_firmware + if firmware_save_path: + self._save_path = firmware_save_path + self._log.info("Firmware will be saved to: %s", self._save_path) + sub_future = await self._client._mqtt_manager.subscribe(mqtt_topics.DEVICE_FIRMWARE_UPDATE_RESPONSE_TOPIC, qos=1) while not sub_future.done(): @@ -151,7 +168,6 @@ async def update(self): await self._client.send_attribute_request(attribute_request, callback=self._firmware_info_callback) async def _firmware_info_callback(self, response, *args, **kwargs): - # TODO: add logs if len(response.shared_keys()) == len(FirmwareUpdater.REQUIRED_SHARED_KEYS): fetched_firmware_info = response.as_dict()['shared'] fetched_firmware_info = {item['key']: item['value'] @@ -185,9 +201,6 @@ async def _firmware_info_callback(self, response, *args, **kwargs): self.current_firmware_info[FirmwareUpdater.FW_STATE_ATTR] = FirmwareStates.FAILED.value await self._client.send_telemetry(self.current_firmware_info, wait_for_publish=True) - def _send_current_firmware_info(self): - pass - def _is_different_firmware_versions(self, new_firmware_info): return (self.current_firmware_info['current_' + FirmwareUpdater.FW_TITLE_ATTR] != new_firmware_info[FirmwareUpdater.FW_TITLE_ATTR] or # noqa self.current_firmware_info['current_' + FirmwareUpdater.FW_VERSION_ATTR] != new_firmware_info[FirmwareUpdater.FW_VERSION_ATTR]) # noqa From f93f78cb2a966c551a5bc4cf3ff0722062088ed0 Mon Sep 17 00:00:00 2001 From: samson0v Date: Wed, 18 Jun 2025 14:37:50 +0300 Subject: [PATCH 14/74] Updated firmware updater due to comments --- .../common/install_package_utils.py | 48 +++++ tb_mqtt_client/constants/firmware.py | 36 ++++ .../service/device/firmware_updater.py | 182 ++++++++++++------ 3 files changed, 210 insertions(+), 56 deletions(-) create mode 100644 tb_mqtt_client/common/install_package_utils.py create mode 100644 tb_mqtt_client/constants/firmware.py diff --git a/tb_mqtt_client/common/install_package_utils.py b/tb_mqtt_client/common/install_package_utils.py new file mode 100644 index 0000000..1227418 --- /dev/null +++ b/tb_mqtt_client/common/install_package_utils.py @@ -0,0 +1,48 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from sys import executable +from subprocess import check_call, CalledProcessError, DEVNULL +from pkg_resources import get_distribution, DistributionNotFound + + +def install_package(package, version="upgrade"): + result = False + + def try_install(args, suppress_stderr=False): + try: + stderr = DEVNULL if suppress_stderr else None + check_call([executable, "-m", "pip", *args], stderr=stderr) + return True + except CalledProcessError: + return False + + if version.lower() == "upgrade": + args = ["install", package, "--upgrade"] + result = try_install(args + ["--user"], suppress_stderr=True) + if not result: + result = try_install(args) + else: + try: + installed_version = get_distribution(package).version + if installed_version == version: + return True + except DistributionNotFound: + pass + install_version = f"{package}=={version}" if ">=" not in version else f"{package}{version}" + args = ["install", install_version] + if not try_install(args + ["--user"], suppress_stderr=True): + result = try_install(args) + + return result diff --git a/tb_mqtt_client/constants/firmware.py b/tb_mqtt_client/constants/firmware.py new file mode 100644 index 0000000..cdc03b4 --- /dev/null +++ b/tb_mqtt_client/constants/firmware.py @@ -0,0 +1,36 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + + +class FirmwareStates(Enum): + IDLE = 'IDLE' + DOWNLOADING = 'DOWNLOADING' + DOWNLOADED = 'DOWNLOADED' + VERIFIED = 'VERIFIED' + FAILED = 'FAILED' + UPDATING = 'UPDATING' + UPDATED = 'UPDATED' + + +FW_TITLE_ATTR = "fw_title" +FW_VERSION_ATTR = "fw_version" +FW_CHECKSUM_ATTR = "fw_checksum" +FW_CHECKSUM_ALG_ATTR = "fw_checksum_algorithm" +FW_SIZE_ATTR = "fw_size" +FW_STATE_ATTR = "fw_state" + +REQUIRED_SHARED_KEYS = [FW_CHECKSUM_ATTR, FW_CHECKSUM_ALG_ATTR, + FW_SIZE_ATTR, FW_TITLE_ATTR, FW_VERSION_ATTR] diff --git a/tb_mqtt_client/service/device/firmware_updater.py b/tb_mqtt_client/service/device/firmware_updater.py index 87453b2..4422b52 100644 --- a/tb_mqtt_client/service/device/firmware_updater.py +++ b/tb_mqtt_client/service/device/firmware_updater.py @@ -12,39 +12,50 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from random import randint +from zlib import crc32 +from hashlib import sha256, sha384, sha512, md5 +from subprocess import CalledProcessError from asyncio import sleep -from enum import Enum from os.path import sep from typing import Awaitable, Callable, Optional -from sdk_utils import verify_checksum +from tb_mqtt_client.common.install_package_utils import install_package from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.constants import mqtt_topics -from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.constants.firmware import ( + FW_CHECKSUM_ALG_ATTR, + FW_CHECKSUM_ATTR, + FW_SIZE_ATTR, + FW_STATE_ATTR, + FW_TITLE_ATTR, + FW_VERSION_ATTR, + REQUIRED_SHARED_KEYS, + FirmwareStates +) -logger = get_logger(__name__) +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry + +try: + from mmh3 import hash, hash128 +except ImportError: + try: + from pymmh3 import hash, hash128 + except ImportError: + try: + install_package('mmh3') + except CalledProcessError: + install_package('pymmh3') +try: + from mmh3 import hash, hash128 # noqa +except ImportError: + from pymmh3 import hash, hash128 -class FirmwareStates(Enum): - IDLE = 'IDLE' - DOWNLOADING = 'DOWNLOADING' - DOWNLOADED = 'DOWNLOADED' - VERIFIED = 'VERIFIED' - FAILED = 'FAILED' - UPDATING = 'UPDATING' - UPDATED = 'UPDATED' +logger = get_logger(__name__) class FirmwareUpdater: - FW_TITLE_ATTR = "fw_title" - FW_VERSION_ATTR = "fw_version" - FW_CHECKSUM_ATTR = "fw_checksum" - FW_CHECKSUM_ALG_ATTR = "fw_checksum_algorithm" - FW_SIZE_ATTR = "fw_size" - FW_STATE_ATTR = "fw_state" - - REQUIRED_SHARED_KEYS = [FW_CHECKSUM_ATTR, FW_CHECKSUM_ALG_ATTR, - FW_SIZE_ATTR, FW_TITLE_ATTR, FW_VERSION_ATTR] - def __init__(self, client): self._log = logger self._client = client @@ -63,9 +74,9 @@ def __init__(self, client): self._target_version = None self._target_title = None self.current_firmware_info = { - 'current_' + FirmwareUpdater.FW_TITLE_ATTR: 'Initial', - 'current_' + FirmwareUpdater.FW_VERSION_ATTR: 'v0', - FirmwareUpdater.FW_STATE_ATTR: FirmwareStates.IDLE.value + 'current_' + FW_TITLE_ATTR: 'Initial', + 'current_' + FW_VERSION_ATTR: 'v0', + FW_STATE_ATTR: FirmwareStates.IDLE.value } async def _handle_firmware_update(self, _, payload: bytes): @@ -94,50 +105,51 @@ async def _get_next_chunk(self): async def _verify_downloaded_firmware(self): self._log.info('Verifying downloaded firmware...') - self.current_firmware_info[FirmwareUpdater.FW_STATE_ATTR] = FirmwareStates.DOWNLOADED.value - await self._client.send_telemetry(self.current_firmware_info, wait_for_publish=True) + self.current_firmware_info[FW_STATE_ATTR] = FirmwareStates.DOWNLOADED.value + await self._send_current_firmware_info() - verified = verify_checksum(self._firmware_data, - self._target_checksum, - self._target_checksum_alg) + verified = self.verify_checksum(self._firmware_data, + self._target_checksum, + self._target_checksum_alg) if verified: self._log.debug('Checksum verified.') - self.current_firmware_info[FirmwareUpdater.FW_STATE_ATTR] = FirmwareStates.VERIFIED.value + self.current_firmware_info[FW_STATE_ATTR] = FirmwareStates.VERIFIED.value else: self._log.error('Checksum verification failed.') - self.current_firmware_info[FirmwareUpdater.FW_STATE_ATTR] = FirmwareStates.FAILED.value + self.current_firmware_info[FW_STATE_ATTR] = FirmwareStates.FAILED.value - await 
self._client.send_telemetry(self.current_firmware_info, wait_for_publish=True) + await self._send_current_firmware_info() - if self.current_firmware_info[FirmwareUpdater.FW_STATE_ATTR] == FirmwareStates.VERIFIED.value: + if self.current_firmware_info[FW_STATE_ATTR] == FirmwareStates.VERIFIED.value: await self._apply_downloaded_firmware() async def _apply_downloaded_firmware(self): self._log.info('Applying downloaded firmware...') - self.current_firmware_info[FirmwareUpdater.FW_STATE_ATTR] = FirmwareStates.UPDATING.value - await self._client.send_telemetry(self.current_firmware_info, wait_for_publish=True) + self.current_firmware_info[FW_STATE_ATTR] = FirmwareStates.UPDATING.value + await self._send_current_firmware_info() try: if self._save_firmware: self._save() except Exception as e: self._log.error('Failed to save firmware: %s', e) - self.current_firmware_info[FirmwareUpdater.FW_STATE_ATTR] = FirmwareStates.FAILED.value - await self._client.send_telemetry(self.current_firmware_info, wait_for_publish=True) + self.current_firmware_info[FW_STATE_ATTR] = FirmwareStates.FAILED.value + await self._send_current_firmware_info() return self.current_firmware_info = { - "current_" + FirmwareUpdater.FW_TITLE_ATTR: self._target_title, - "current_" + FirmwareUpdater.FW_VERSION_ATTR: self._target_version, - FirmwareUpdater.FW_STATE_ATTR: FirmwareStates.UPDATED.value + "current_" + FW_TITLE_ATTR: self._target_title, + "current_" + FW_VERSION_ATTR: self._target_version, + FW_STATE_ATTR: FirmwareStates.UPDATED.value } - await self._client.send_telemetry(self.current_firmware_info, wait_for_publish=True) + await self._send_current_firmware_info() if self._on_received_callback: await self._on_received_callback(self._firmware_data, self.current_firmware_info) + await self._client._mqtt_manager.unsubscribe(mqtt_topics.DEVICE_FIRMWARE_UPDATE_RESPONSE_TOPIC) self._log.info('Firmware is updated.') self._log.info('Current firmware version is: %s' % self._target_version) @@ -149,6 +161,10 @@ def _save(self): async def update(self, on_received_callback: Optional[Callable[[str], Awaitable[None]]] = None, save_firmware: bool = True, firmware_save_path: Optional[str] = None): + if not self._client._mqtt_manager.is_connected(): + self._log.error("Client is not connected. 
Cannot start firmware update.") + return + self._log.info("Starting firmware update process...") self._on_received_callback = on_received_callback @@ -162,13 +178,13 @@ async def update(self, on_received_callback: Optional[Callable[[str], Awaitable[ while not sub_future.done(): await sleep(0.01) - await self._client.send_telemetry(self.current_firmware_info, wait_for_publish=True) + await self._send_current_firmware_info() - attribute_request = await AttributeRequest.build(FirmwareUpdater.REQUIRED_SHARED_KEYS) + attribute_request = await AttributeRequest.build(REQUIRED_SHARED_KEYS) await self._client.send_attribute_request(attribute_request, callback=self._firmware_info_callback) async def _firmware_info_callback(self, response, *args, **kwargs): - if len(response.shared_keys()) == len(FirmwareUpdater.REQUIRED_SHARED_KEYS): + if len(response.shared_keys()) == len(REQUIRED_SHARED_KEYS): fetched_firmware_info = response.as_dict()['shared'] fetched_firmware_info = {item['key']: item['value'] for item in fetched_firmware_info} @@ -179,14 +195,14 @@ async def _firmware_info_callback(self, response, *args, **kwargs): self._firmware_data = b'' self._current_chunk = 0 - self.current_firmware_info[FirmwareUpdater.FW_STATE_ATTR] = FirmwareStates.DOWNLOADING.value + self.current_firmware_info[FW_STATE_ATTR] = FirmwareStates.DOWNLOADING.value self._firmware_request_id += 1 - self._target_firmware_length = fetched_firmware_info[FirmwareUpdater.FW_SIZE_ATTR] - self._target_checksum = fetched_firmware_info[FirmwareUpdater.FW_CHECKSUM_ALG_ATTR] - self._target_checksum_alg = fetched_firmware_info[FirmwareUpdater.FW_CHECKSUM_ATTR] - self._target_title = fetched_firmware_info[FirmwareUpdater.FW_TITLE_ATTR] - self._target_version = fetched_firmware_info[FirmwareUpdater.FW_VERSION_ATTR] + self._target_firmware_length = fetched_firmware_info[FW_SIZE_ATTR] + self._target_checksum = fetched_firmware_info[FW_CHECKSUM_ALG_ATTR] + self._target_checksum_alg = fetched_firmware_info[FW_CHECKSUM_ATTR] + self._target_title = fetched_firmware_info[FW_TITLE_ATTR] + self._target_version = fetched_firmware_info[FW_VERSION_ATTR] await self._get_next_chunk() else: @@ -195,12 +211,66 @@ async def _firmware_info_callback(self, response, *args, **kwargs): self._log.error("Failed to fetch firmware info. " "Received firmware info does not match required keys. 
" "Expected: %s, Received: %s", - FirmwareUpdater.REQUIRED_SHARED_KEYS, + REQUIRED_SHARED_KEYS, response.shared_keys()) - self.current_firmware_info[FirmwareUpdater.FW_STATE_ATTR] = FirmwareStates.FAILED.value - await self._client.send_telemetry(self.current_firmware_info, wait_for_publish=True) + self.current_firmware_info[FW_STATE_ATTR] = FirmwareStates.FAILED.value + await self._send_current_firmware_info() def _is_different_firmware_versions(self, new_firmware_info): - return (self.current_firmware_info['current_' + FirmwareUpdater.FW_TITLE_ATTR] != new_firmware_info[FirmwareUpdater.FW_TITLE_ATTR] or # noqa - self.current_firmware_info['current_' + FirmwareUpdater.FW_VERSION_ATTR] != new_firmware_info[FirmwareUpdater.FW_VERSION_ATTR]) # noqa + return (self.current_firmware_info['current_' + FW_TITLE_ATTR] != new_firmware_info[FW_TITLE_ATTR] or # noqa + self.current_firmware_info['current_' + FW_VERSION_ATTR] != new_firmware_info[FW_VERSION_ATTR]) # noqa + + async def _send_current_firmware_info(self): + current_info = [TimeseriesEntry(key, value) for key, value in self.current_firmware_info.items()] + await self._client.send_telemetry(current_info, wait_for_publish=True) + + def verify_checksum(self, firmware_data, checksum_alg, checksum): + if firmware_data is None: + self._log.debug('Firmware wasn\'t received!') + return False + + if checksum is None: + self._log.debug('Checksum was\'t provided!') + return False + + checksum_of_received_firmware = None + + self._log.debug('Checksum algorithm is: %s' % checksum_alg) + if checksum_alg.lower() == "sha256": + checksum_of_received_firmware = sha256(firmware_data).digest().hex() + elif checksum_alg.lower() == "sha384": + checksum_of_received_firmware = sha384(firmware_data).digest().hex() + elif checksum_alg.lower() == "sha512": + checksum_of_received_firmware = sha512(firmware_data).digest().hex() + elif checksum_alg.lower() == "md5": + checksum_of_received_firmware = md5(firmware_data).digest().hex() + elif checksum_alg.lower() == "murmur3_32": + reversed_checksum = f'{hash(firmware_data, signed=False):0>2X}' + if len(reversed_checksum) % 2 != 0: + reversed_checksum = '0' + reversed_checksum + checksum_of_received_firmware = "".join( + reversed([reversed_checksum[i:i + 2] for i in range(0, len(reversed_checksum), 2)])).lower() + elif checksum_alg.lower() == "murmur3_128": + reversed_checksum = f'{hash128(firmware_data, signed=False):0>2X}' + if len(reversed_checksum) % 2 != 0: + reversed_checksum = '0' + reversed_checksum + checksum_of_received_firmware = "".join( + reversed([reversed_checksum[i:i + 2] for i in range(0, len(reversed_checksum), 2)])).lower() + elif checksum_alg.lower() == "crc32": + reversed_checksum = f'{crc32(firmware_data) & 0xffffffff:0>2X}' + if len(reversed_checksum) % 2 != 0: + reversed_checksum = '0' + reversed_checksum + checksum_of_received_firmware = "".join( + reversed([reversed_checksum[i:i + 2] for i in range(0, len(reversed_checksum), 2)])).lower() + else: + self._log.error('Client error. Unsupported checksum algorithm.') + + self._log.debug(checksum_of_received_firmware) + + random_value = randint(0, 5) + if random_value > 3: + self._log.debug('Dummy fail! 
Do not panic, just restart and try again the chance of this fail is ~20%') + return False + + return checksum_of_received_firmware == checksum From 1b6f1175498f0759235d017967c679095c360635 Mon Sep 17 00:00:00 2001 From: samson0v Date: Thu, 19 Jun 2025 09:51:58 +0300 Subject: [PATCH 15/74] Device provisioning fixes --- tb_mqtt_client/constants/provisioning.py | 29 +++++++++++++++++++ .../entities/data/provisioning_request.py | 17 +++++------ .../entities/data/provisioning_response.py | 12 ++------ tb_mqtt_client/service/message_dispatcher.py | 15 +++++++--- 4 files changed, 49 insertions(+), 24 deletions(-) create mode 100644 tb_mqtt_client/constants/provisioning.py diff --git a/tb_mqtt_client/constants/provisioning.py b/tb_mqtt_client/constants/provisioning.py new file mode 100644 index 0000000..a8cf8c8 --- /dev/null +++ b/tb_mqtt_client/constants/provisioning.py @@ -0,0 +1,29 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + + +class ProvisioningResponseStatus(Enum): + SUCCESS = "SUCCESS" + ERROR = "FAILURE" + + def __str__(self): + return self.value + + +class ProvisioningCredentialsType(Enum): + ACCESS_TOKEN = "ACCESS_TOKEN" + MQTT_BASIC = "MQTT_BASIC" + X509_CERTIFICATE = "X509_CERTIFICATE" diff --git a/tb_mqtt_client/entities/data/provisioning_request.py b/tb_mqtt_client/entities/data/provisioning_request.py index 8101237..e54db43 100644 --- a/tb_mqtt_client/entities/data/provisioning_request.py +++ b/tb_mqtt_client/entities/data/provisioning_request.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
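To make the reshuffled provisioning entities easier to follow, a rough usage sketch. It assumes DeviceClient.provision() (shown later in this series) resolves to the ProvisioningResponse defined below; host, keys, and timeout are placeholders:

import asyncio

from tb_mqtt_client.entities.data.provisioning_request import (
    AccessTokenProvisioningCredentials, ProvisioningRequest)
from tb_mqtt_client.service.device.client import DeviceClient


async def main():
    credentials = AccessTokenProvisioningCredentials(
        provision_device_key="YOUR_PROVISION_KEY",
        provision_device_secret="YOUR_PROVISION_SECRET")  # access_token may now be omitted
    request = ProvisioningRequest(host="localhost", credentials=credentials)

    response = await DeviceClient.provision(request, timeout=10.0)
    if response is not None and response.result is not None:
        print("Provisioned, status:", response.status)  # result carries a DeviceConfig
    else:
        print("Provisioning failed:", response.error if response else "no response")


asyncio.run(main())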
-from enum import Enum +from abc import ABC, abstractmethod from typing import Optional +from tb_mqtt_client.constants.provisioning import ProvisioningCredentialsType + class ProvisioningRequest: def __init__(self, host, credentials: 'ProvisioningCredentials', port: str = "1883", @@ -26,21 +28,16 @@ def __init__(self, host, credentials: 'ProvisioningCredentials', port: str = "18 self.gateway = gateway -class ProvisioningCredentialsType(Enum): - ACCESS_TOKEN = "ACCESS_TOKEN" - MQTT_BASIC = "MQTT_BASIC" - X509_CERTIFICATE = "X509_CERTIFICATE" - - -class ProvisioningCredentials: +class ProvisioningCredentials(ABC): + @abstractmethod def __init__(self, provision_device_key: str, provision_device_secret: str): self.provision_device_key = provision_device_key self.provision_device_secret = provision_device_secret - self.credentials_type = None + self.credentials_type: ProvisioningCredentialsType class AccessTokenProvisioningCredentials(ProvisioningCredentials): - def __init__(self, provision_device_key: str, provision_device_secret: str, access_token: str): + def __init__(self, provision_device_key: str, provision_device_secret: str, access_token: Optional[str] = None): super().__init__(provision_device_key, provision_device_secret) self.access_token = access_token self.credentials_type = ProvisioningCredentialsType.ACCESS_TOKEN diff --git a/tb_mqtt_client/entities/data/provisioning_response.py b/tb_mqtt_client/entities/data/provisioning_response.py index fcea0bd..82e7a73 100644 --- a/tb_mqtt_client/entities/data/provisioning_response.py +++ b/tb_mqtt_client/entities/data/provisioning_response.py @@ -13,25 +13,17 @@ # limitations under the License. from dataclasses import dataclass -from enum import Enum from typing import Optional from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.constants.provisioning import ProvisioningResponseStatus from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest, ProvisioningCredentialsType -class ProvisioningResponseStatus(Enum): - SUCCESS = "SUCCESS" - ERROR = "FAILURE" - - def __str__(self): - return self.value - - @dataclass(frozen=True) class ProvisioningResponse: status: ProvisioningResponseStatus - result: Optional[dict] = None + result: Optional[DeviceConfig] = None error: Optional[str] = None def __new__(cls, *args, **kwargs): diff --git a/tb_mqtt_client/service/message_dispatcher.py b/tb_mqtt_client/service/message_dispatcher.py index 5dbd9fa..fe85482 100644 --- a/tb_mqtt_client/service/message_dispatcher.py +++ b/tb_mqtt_client/service/message_dispatcher.py @@ -337,13 +337,20 @@ def build_provision_request(self, provision_request: 'ProvisioningRequest') -> T if provision_request.credentials.credentials_type and \ provision_request.credentials.credentials_type == ProvisioningCredentialsType.ACCESS_TOKEN: - request["token"] = provision_request.credentials.access_token + if provision_request.credentials.access_token is not None: + request["token"] = provision_request.credentials.access_token request["credentialsType"] = provision_request.credentials.credentials_type.value if provision_request.credentials.credentials_type == ProvisioningCredentialsType.MQTT_BASIC: - request["username"] = provision_request.credentials.username - request["password"] = provision_request.credentials.password - request["clientId"] = provision_request.credentials.client_id + if provision_request.credentials.username is not None: + request["username"] = provision_request.credentials.username + + if 
provision_request.credentials.password is not None: + request["password"] = provision_request.credentials.password + + if provision_request.credentials.client_id is not None: + request["clientId"] = provision_request.credentials.client_id + request["credentialsType"] = provision_request.credentials.credentials_type.value if provision_request.credentials.credentials_type == ProvisioningCredentialsType.X509_CERTIFICATE: From 35d89c1800649d4b39fba4ab802d50b4399a5775 Mon Sep 17 00:00:00 2001 From: samson0v Date: Thu, 19 Jun 2025 10:38:12 +0300 Subject: [PATCH 16/74] Fixed on_unsubscribe callback --- tb_mqtt_client/service/mqtt_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tb_mqtt_client/service/mqtt_manager.py b/tb_mqtt_client/service/mqtt_manager.py index 9cdc696..16bebbb 100644 --- a/tb_mqtt_client/service/mqtt_manager.py +++ b/tb_mqtt_client/service/mqtt_manager.py @@ -337,7 +337,7 @@ def _on_subscribe_internal(self, client, mid, qos, properties): if future and not future.done(): future.set_result(mid) - def _on_unsubscribe_internal(self, client, mid): + def _on_unsubscribe_internal(self, client, mid, _): logger.trace("Received UNSUBACK by client %r for mid=%s", client, mid) future = self._pending_unsubscriptions.pop(mid, None) if future and not future.done(): From 23322cc348bcfd882a8690a04bca917b8a074fb0 Mon Sep 17 00:00:00 2001 From: samson0v Date: Thu, 19 Jun 2025 10:40:52 +0300 Subject: [PATCH 17/74] Added properies argument to on_unsubscribe callback --- tb_mqtt_client/service/mqtt_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tb_mqtt_client/service/mqtt_manager.py b/tb_mqtt_client/service/mqtt_manager.py index 16bebbb..af40488 100644 --- a/tb_mqtt_client/service/mqtt_manager.py +++ b/tb_mqtt_client/service/mqtt_manager.py @@ -337,8 +337,8 @@ def _on_subscribe_internal(self, client, mid, qos, properties): if future and not future.done(): future.set_result(mid) - def _on_unsubscribe_internal(self, client, mid, _): - logger.trace("Received UNSUBACK by client %r for mid=%s", client, mid) + def _on_unsubscribe_internal(self, client, mid, properties): + logger.trace("Received UNSUBACK by client %r for mid=%s with properties %s", client, mid, properties) future = self._pending_unsubscriptions.pop(mid, None) if future and not future.done(): future.set_result(mid) From 603fc32ef059a1ab66f37e528dfd746d43354eed Mon Sep 17 00:00:00 2001 From: imbeacon Date: Mon, 23 Jun 2025 15:04:49 +0300 Subject: [PATCH 18/74] Added CONACK handler --- tb_mqtt_client/common/async_utils.py | 0 tb_mqtt_client/common/gmqtt_patch.py | 168 ++++++++++++++++----------- 2 files changed, 102 insertions(+), 66 deletions(-) create mode 100644 tb_mqtt_client/common/async_utils.py diff --git a/tb_mqtt_client/common/async_utils.py b/tb_mqtt_client/common/async_utils.py new file mode 100644 index 0000000..e69de29 diff --git a/tb_mqtt_client/common/gmqtt_patch.py b/tb_mqtt_client/common/gmqtt_patch.py index 4ccd951..233d4e1 100644 --- a/tb_mqtt_client/common/gmqtt_patch.py +++ b/tb_mqtt_client/common/gmqtt_patch.py @@ -15,6 +15,7 @@ import asyncio import struct from collections import defaultdict +from types import MethodType from typing import Callable from gmqtt.mqtt.constants import MQTTCommands @@ -27,38 +28,66 @@ logger = get_logger(__name__) -# MQTT 5.0 Disconnect Reason Codes -DISCONNECT_REASON_CODES = { - 0: "Normal disconnection", - 4: "Disconnect with Will Message", - 128: "Unspecified error", - 129: "Malformed Packet", - 130: "Protocol Error", - 131: 
"Implementation specific error", - 132: "Not authorized", - 133: "Server busy", - 134: "Server shutting down", - 135: "Keep Alive timeout", - 136: "Session taken over", - 137: "Topic Filter invalid", - 138: "Topic Name invalid", - 139: "Receive Maximum exceeded", - 140: "Topic Alias invalid", - 141: "Packet too large", - 142: "Session taken over", - 143: "Quota exceeded", - 144: "Administrative action", - 145: "Payload format invalid", - 146: "Retain not supported", - 147: "QoS not supported", - 148: "Use another server", - 149: "Server moved", - 150: "Shared Subscriptions not supported", - 151: "Connection rate exceeded", - 152: "Maximum connect time", - 153: "Subscription Identifiers not supported", - 154: "Wildcard Subscriptions not supported" -} + +class PatchUtils: + DISCONNECT_REASON_CODES = { + 0: "Normal disconnection", + 4: "Disconnect with Will Message", + 128: "Unspecified error", + 129: "Malformed Packet", + 130: "Protocol Error", + 131: "Implementation specific error", + 132: "Not authorized", + 133: "Server busy", + 134: "Server shutting down", + 135: "Keep Alive timeout", + 136: "Session taken over", + 137: "Topic Filter invalid", + 138: "Topic Name invalid", + 139: "Receive Maximum exceeded", + 140: "Topic Alias invalid", + 141: "Packet too large", + 142: "Session taken over", + 143: "Quota exceeded", + 144: "Administrative action", + 145: "Payload format invalid", + 146: "Retain not supported", + 147: "QoS not supported", + 148: "Use another server", + 149: "Server moved", + 150: "Shared Subscriptions not supported", + 151: "Connection rate exceeded", + 152: "Maximum connect time", + 153: "Subscription Identifiers not supported", + 154: "Wildcard Subscriptions not supported" + } + + @staticmethod + def parse_mqtt_properties(packet: bytes) -> dict: + """ + Parse MQTT 5.0 properties from a packet. 
+ """ + properties_dict = defaultdict(list) + + try: + properties_len, _ = unpack_variable_byte_integer(packet) + props = packet[:properties_len] + + while props: + property_identifier = props[0] + property_obj = Property.factory(id_=property_identifier) + if property_obj is None: + logger.warning(f"Unknown property id={property_identifier}") + break + + result, props = property_obj.loads(props[1:]) + for k, v in result.items(): + properties_dict[k].append(v) + + except Exception as e: + logger.warning("Failed to parse properties: %s", e) + + return dict(properties_dict) def extract_reason_code(packet): @@ -97,12 +126,12 @@ def patched_handle_disconnect_packet(self, cmd, packet): properties = {} if packet and len(packet) > 1: try: - properties, _ = self._parse_properties(packet[1:]) + properties = PatchUtils.parse_mqtt_properties(packet[1:]) except Exception as exc: logger.warning("Failed to parse properties from disconnect packet: %s", exc) - reason_desc = DISCONNECT_REASON_CODES.get(reason_code, "Unknown reason") - logger.debug("Server initiated disconnect with reason code: %s (%s)", reason_code, reason_desc) + reason_desc = PatchUtils.DISCONNECT_REASON_CODES.get(reason_code, "Unknown reason") + logger.trace("Server initiated disconnect with reason code: %s (%s)", reason_code, reason_desc) # Call the original method to handle reconnection # But don't call the on_disconnect callback, as we'll do that ourselves @@ -127,6 +156,40 @@ def patched_handle_disconnect_packet(self, cmd, packet): logger.warning("Failed to patch gmqtt handler: %s", e) return False +def patch_handle_connack(client, on_connack_with_session_present_and_result_code: Callable[[object, int, int, dict], None]): + """ + Monkey-patch gmqtt.mqtt.handler.MqttPackageHandler._handle_connack_packet to add custom handling for + CONNACK packets, allowing for custom callbacks with session_present, result_code, and properties. 
+ """ + try: + original_handler = MqttPackageHandler._handle_connack_packet + + def new_handle_connack_packet(self, cmd, packet): + try: + original_handler(self, cmd, packet) + + session_present, reason_code = struct.unpack("!BB", packet[:2]) + + if len(packet) > 2: + props_payload = packet[2:] + properties = PatchUtils.parse_mqtt_properties(props_payload) + else: + properties = {} + + logger.debug("CONNACK patched handler: session_present=%r, reason_code=%r, properties=%r", + session_present, reason_code, properties) + + on_connack_with_session_present_and_result_code(client, session_present, reason_code, properties) + except Exception as e: + logger.error("Error while handling CONNACK packet: %s", e, exc_info=True) + + MqttPackageHandler._handle_connack_packet = new_handle_connack_packet + logger.debug("Successfully patched gmqtt.mqtt.handler._handle_connack_packet") + return True + except (ImportError, AttributeError) as e: + logger.warning("Failed to patch gmqtt handler: %s", e) + return False + def patch_gmqtt_protocol_connection_lost(): """ Monkey-patch gmqtt.mqtt.protocol.BaseMQTTProtocol.connection_lost to suppress the @@ -235,39 +298,12 @@ def patch_gmqtt_puback(client, on_puback_with_reason_and_properties: Callable[[i :param client: GMQTTClient instance :param on_puback_with_reason_and_properties: Callback with (mid, reason_code, properties_dict) """ - # Backup original method from MqttPackageHandler - base_method = client.__class__.__bases__[0].__dict__.get('_handle_puback_packet') + original_handler = MqttPackageHandler._handle_puback_packet - if base_method is None: + if original_handler is None: logger.error("Could not find _handle_puback_packet in base class.") return - def _parse_properties(packet: bytes) -> dict: - """ - Parse MQTT 5.0 properties from a packet. 
- """ - properties_dict = defaultdict(list) - - try: - properties_len, _ = unpack_variable_byte_integer(packet) - props = packet[:properties_len] - - while props: - property_identifier = props[0] - property_obj = Property.factory(id_=property_identifier) - if property_obj is None: - logger.warning(f"Unknown PUBACK property id={property_identifier}") - break - - result, props = property_obj.loads(props[1:]) - for k, v in result.items(): - properties_dict[k].append(v) - - except Exception as e: - logger.warning("Failed to parse PUBACK properties: %s", e) - - return dict(properties_dict) - def wrapped_handle_puback(self, cmd, packet): try: mid = struct.unpack("!H", packet[:2])[0] @@ -278,12 +314,12 @@ def wrapped_handle_puback(self, cmd, packet): reason_code = packet[2] if len(packet) > 3: props_payload = packet[3:] - properties = _parse_properties(props_payload) + properties = PatchUtils.parse_mqtt_properties(props_payload) on_puback_with_reason_and_properties(mid, reason_code, properties) except Exception as e: logger.exception("Error while handling PUBACK with properties: %s", e) - return base_method(self, cmd, packet) + return original_handler(self, cmd, packet) MqttPackageHandler._handle_puback_packet = wrapped_handle_puback From bd47806dd9c14a1d788974b088d2fd3432b91ba1 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Mon, 23 Jun 2025 15:06:15 +0300 Subject: [PATCH 19/74] Updated async functions to wait for responses by default --- tb_mqtt_client/common/async_utils.py | 59 +++++++ tb_mqtt_client/service/base_client.py | 12 +- tb_mqtt_client/service/device/client.py | 149 ++++++++++++------ .../device/handlers/rpc_response_handler.py | 2 +- tb_mqtt_client/service/message_dispatcher.py | 25 +-- tb_mqtt_client/service/message_queue.py | 11 +- tb_mqtt_client/service/message_splitter.py | 8 +- tb_mqtt_client/service/mqtt_manager.py | 43 +++-- 8 files changed, 224 insertions(+), 85 deletions(-) diff --git a/tb_mqtt_client/common/async_utils.py b/tb_mqtt_client/common/async_utils.py index e69de29..49962ba 100644 --- a/tb_mqtt_client/common/async_utils.py +++ b/tb_mqtt_client/common/async_utils.py @@ -0,0 +1,59 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union, Optional, Any +import asyncio + + +async def await_or_stop(future_or_coroutine: Union[asyncio.Future, asyncio.Task, Any], + stop_event: asyncio.Event, + timeout: Optional[float]) -> Optional[Any]: + """ + Await the given future/coroutine until it completes, timeout expires, or stop_event is set. + + :param future_or_coroutine: An awaitable coroutine, asyncio.Future, or asyncio.Task. + :param stop_event: asyncio.Event that signals shutdown. + :param timeout: Optional timeout in seconds, -1 for no timeout, or None to wait indefinitely. + :return: The result if completed successfully, or None on timeout/stop. 
+ """ + if asyncio.iscoroutine(future_or_coroutine): + main_task = asyncio.create_task(future_or_coroutine) + elif asyncio.isfuture(future_or_coroutine): + main_task = future_or_coroutine + else: + raise TypeError("Expected coroutine or Future/Task") + + stop_task = asyncio.create_task(stop_event.wait()) + + if timeout is not None and timeout < 0: + timeout = None + + try: + done, _ = await asyncio.wait( + [main_task, stop_task], + timeout=timeout, + return_when=asyncio.FIRST_COMPLETED + ) + if main_task in done: + return await main_task + if stop_task in done: + if stop_event.is_set(): + return None + if timeout is not None and not done: + raise asyncio.TimeoutError("Operation timed out") + except asyncio.CancelledError: + return None + finally: + if not stop_task.done(): + stop_task.cancel() diff --git a/tb_mqtt_client/service/base_client.py b/tb_mqtt_client/service/base_client.py index 1f1bebe..bd17588 100644 --- a/tb_mqtt_client/service/base_client.py +++ b/tb_mqtt_client/service/base_client.py @@ -14,7 +14,7 @@ import asyncio from abc import ABC, abstractmethod -from typing import Callable, Awaitable, Dict, Any, Union, List +from typing import Callable, Awaitable, Dict, Any, Union, List, Optional import uvloop @@ -35,7 +35,7 @@ class BaseClient(ABC): Abstract base class for clients. """ - DEFAULT_TIMEOUT = 3 + DEFAULT_TIMEOUT = 3.0 def __init__(self, host: str, port: int, client_id: str): self._host = host @@ -63,7 +63,7 @@ async def send_telemetry(self, telemetry_data: Union[TimeseriesEntry, Dict[str, Any], List[Dict[str, Any]]], wait_for_publish: bool = True, - timeout: int = DEFAULT_TIMEOUT) -> Union[asyncio.Future[PublishResult], PublishResult]: + timeout: Optional[float] = None) -> Union[asyncio.Future[PublishResult], PublishResult]: """ Send telemetry data. @@ -71,7 +71,7 @@ async def send_telemetry(self, telemetry_data: Union[TimeseriesEntry, or a list of TimeseriesEntry or dictionaries. :param wait_for_publish: If True, wait for the publishing result. Default is True. :param timeout: Timeout for the publish operation if `wait_for_publish` is True. - In seconds, defaults to 3 seconds. + In seconds. If less than 0 or None, wait indefinitely. :return: Future or PublishResult depending on `wait_for_publish`. """ pass @@ -80,14 +80,14 @@ async def send_telemetry(self, telemetry_data: Union[TimeseriesEntry, async def send_attributes(self, attributes: Union[Dict[str, Any], AttributeEntry, list[AttributeEntry]], wait_for_publish: bool = True, - timeout: int = DEFAULT_TIMEOUT) -> Union[asyncio.Future[PublishResult], PublishResult]: + timeout: Optional[float] = None) -> Union[asyncio.Future[PublishResult], PublishResult]: """ Send client attributes. :param attributes: Dictionary of attributes or a single AttributeEntry or a list of AttributeEntries. :param wait_for_publish: If True, wait for the publishing result. Default is True. :param timeout: Timeout for the publish operation if `wait_for_publish` is True. - In seconds, defaults to 3 seconds. + In seconds. If less than 0 or None, wait indefinitely. :return: Future or PublishResult depending on `wait_for_publish`. 
""" pass diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index 7b10448..cd60789 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -20,12 +20,13 @@ from orjson import dumps +from tb_mqtt_client.common.async_utils import await_or_stop from tb_mqtt_client.common.config_loader import DeviceConfig from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit, DEFAULT_RATE_LIMIT_PERCENTAGE from tb_mqtt_client.common.request_id_generator import RPCRequestIdProducer from tb_mqtt_client.constants import mqtt_topics -from tb_mqtt_client.constants.json_typing import validate_json_compatibility, JSONCompatibleType +from tb_mqtt_client.constants.json_typing import JSONCompatibleType from tb_mqtt_client.constants.service_keys import TELEMETRY_TIMESTAMP_PARAMETER, TELEMETRY_VALUES_PARAMETER from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.attribute_request import AttributeRequest @@ -125,6 +126,8 @@ async def connect(self): while not self._mqtt_manager.is_connected(): await self._mqtt_manager.await_ready() + if self._stop_event.is_set(): + return await self._on_connect() @@ -156,8 +159,17 @@ async def stop(self): """ logger.info("Stopping DeviceClient...") self._stop_event.set() + + for fut, _ in self._rpc_response_handler._pending_rpc_requests.values(): + if not fut.done(): + fut.cancel() + + if self._message_queue: + await self._message_queue.shutdown() + if self._mqtt_manager.is_connected(): await self._mqtt_manager.disconnect() + logger.info("DeviceClient stopped.") async def disconnect(self): @@ -166,58 +178,105 @@ async def disconnect(self): # await self._message_queue.shutdown() # TODO: Not sure if we need to shutdown the message queue here, as it might be handled by MQTTManager - async def send_telemetry(self, - data: Union[TimeseriesEntry, List[TimeseriesEntry], Dict[str, Any], List[Dict[str, Any]]], - qos: int = 1, - wait_for_publish: bool = True, - timeout: int = BaseClient.DEFAULT_TIMEOUT) -> Union[Future[PublishResult], PublishResult]: + async def send_telemetry( + self, + data: Union[TimeseriesEntry, List[TimeseriesEntry], Dict[str, Any], List[Dict[str, Any]]], + qos: int = 1, + wait_for_publish: bool = True, + timeout: Optional[float] = None + ) -> Union[PublishResult, List[PublishResult], None]: message = self._build_uplink_message_for_telemetry(data) topic = mqtt_topics.DEVICE_TELEMETRY_TOPIC - futures = await self._message_queue.publish(topic=topic, - payload=message, - datapoints_count=message.timeseries_datapoint_count(), - qos=qos or self._config.qos) - if wait_for_publish: + futures = await self._message_queue.publish( + topic=topic, + payload=message, + datapoints_count=message.timeseries_datapoint_count(), + qos=qos or self._config.qos + ) + + if not futures: + logger.warning("No publish futures were returned from message queue") + return None + + if not wait_for_publish: + return None + + results = [] + for fut in futures: try: - return await wait_for(futures[0], timeout=timeout) if futures else None + result = await await_or_stop(fut, timeout=timeout, stop_event=self._stop_event) except TimeoutError: logger.warning("Timeout while waiting for telemetry publish result") - return PublishResult(topic, qos, -1, message.size, -1) - else: - return futures[0] if futures else None - - async def send_attributes(self, - attributes: Union[Dict[str, Any], AttributeEntry, 
list[AttributeEntry]], - qos: int = None, - wait_for_publish: bool = True, - timeout: int = BaseClient.DEFAULT_TIMEOUT) -> Union[Future[PublishResult], PublishResult]: + result = PublishResult(topic, qos, -1, message.size, -1) + results.append(result) + + return results[0] if len(results) == 1 else results + + async def send_attributes( + self, + attributes: Union[Dict[str, Any], AttributeEntry, list[AttributeEntry]], + qos: int = None, + wait_for_publish: bool = True, + timeout: int = BaseClient.DEFAULT_TIMEOUT + ) -> Union[PublishResult, List[PublishResult], None]: message = self._build_uplink_message_for_attributes(attributes) topic = mqtt_topics.DEVICE_ATTRIBUTES_TOPIC - futures = await self._message_queue.publish(topic=topic, - payload=message, - datapoints_count=message.attributes_datapoint_count(), - qos=qos or self._config.qos) - if wait_for_publish: + futures = await self._message_queue.publish( + topic=topic, + payload=message, + datapoints_count=message.attributes_datapoint_count(), + qos=qos or self._config.qos + ) + + if not futures: + logger.warning("No publish futures were returned from message queue") + return None + + if not wait_for_publish: + return None + + results = [] + for fut in futures: try: - return await wait_for(futures[0], timeout=timeout) if futures else None + result = await await_or_stop(fut, timeout=timeout, stop_event=self._stop_event) except TimeoutError: - logger.warning("Timeout while waiting for telemetry publish result") - return PublishResult(topic, qos, -1, message.size, -1) - else: - return futures[0] if futures else None - - async def send_rpc_request(self, rpc_request: RPCRequest, - callback: Optional[Callable[[RPCResponse], Awaitable[None]]] = None) -> Awaitable[RPCResponse]: + logger.warning("Timeout while waiting for attribute publish result") + result = PublishResult(topic, qos, -1, message.size, -1) + results.append(result) + + return results[0] if len(results) == 1 else results + + async def send_rpc_request( + self, + rpc_request: RPCRequest, + callback: Optional[Callable[[RPCResponse], Awaitable[None]]] = None, + wait_for_publish: bool = True, + timeout: Optional[float] = BaseClient.DEFAULT_TIMEOUT + ) -> Union[RPCResponse, Awaitable[RPCResponse], None]: request_id = rpc_request.request_id or await RPCRequestIdProducer.get_next() topic, payload = self._message_dispatcher.build_rpc_request(rpc_request) response_future = self._rpc_response_handler.register_request(request_id, callback) - await self._message_queue.publish(topic=topic, - payload=payload, - datapoints_count=0, - qos=self._config.qos) - return response_future + await self._message_queue.publish( + topic=topic, + payload=payload, + datapoints_count=0, + qos=self._config.qos + ) + + if not wait_for_publish: + return response_future + + try: + return await await_or_stop(response_future, timeout=timeout, stop_event=self._stop_event) + except TimeoutError as e: + if not callback: + raise TimeoutError(f"Timed out waiting for RPC response (requestId={request_id})") + else: + logger.warning("Timed out waiting for RPC response, but callback is set. 
" + "Callback will be called with None response.") + await self._rpc_response_handler.handle(mqtt_topics.build_device_rpc_response_topic(rpc_request.request_id), e) async def send_rpc_response(self, response: RPCResponse): topic, payload = self._message_dispatcher.build_rpc_response(response) @@ -247,7 +306,7 @@ async def claim_device(self, await self._message_queue.publish(topic=topic, payload=payload, datapoints_count=0, qos=1) if wait_for_publish: try: - return await wait_for(self.__claiming_response_future, timeout=timeout) + return await await_or_stop(self.__claiming_response_future, timeout=timeout, stop_event=self._stop_event) except TimeoutError: logger.warning("Timeout while waiting for telemetry publish result") return PublishResult(topic, 1, -1, len(payload), -1) @@ -266,6 +325,8 @@ async def _on_connect(self): sub_future = await self._mqtt_manager.subscribe(mqtt_topics.DEVICE_ATTRIBUTES_TOPIC, qos=1) while not sub_future.done(): await sleep(0.01) + if self._stop_event.is_set(): + return self._mqtt_manager.register_handler(mqtt_topics.DEVICE_ATTRIBUTES_TOPIC, self._handle_attribute_update) self._mqtt_manager.register_handler(mqtt_topics.DEVICE_RPC_REQUEST_TOPIC_FOR_SUBSCRIPTION, self._handle_rpc_request) # noqa @@ -277,13 +338,13 @@ async def _on_disconnect(self): self._requested_attribute_response_handler.clear() self._rpc_response_handler.clear() - async def send_rpc_call(self, method: str, params: Optional[Dict[str, Any]] = None, timeout: float = 10.0) -> Any: + async def send_rpc_call(self, method: str, params: Optional[Dict[str, Any]] = None, timeout: float = 10.0) -> Union[RPCResponse, None]: """ Initiates a client-side RPC to ThingsBoard and awaits the result. :param method: The RPC method to call. :param params: The parameters to send. :param timeout: Timeout for the response in seconds. - :return: The response result (dict, list, str, etc.). + :return: RPCResponse object containing the result or error. 
""" request_id = await RPCRequestIdProducer.get_next() topic = mqtt_topics.build_device_rpc_request_topic(request_id) @@ -296,7 +357,7 @@ async def send_rpc_call(self, method: str, params: Optional[Dict[str, Any]] = No await self._mqtt_manager.publish(topic, payload, qos=1) try: - return await wait_for(future, timeout=timeout) + return await await_or_stop(future, timeout=timeout, stop_event=self._stop_event) except TimeoutError: raise TimeoutError(f"Timed out waiting for RPC response (method={method}, id={request_id})") @@ -458,7 +519,7 @@ def _build_uplink_message_for_attributes(payload: Union[Dict[str, Any], return builder.build() @staticmethod - async def provision(provision_request: 'ProvisioningRequest', timeout=3.0): + async def provision(provision_request: 'ProvisioningRequest', timeout=BaseClient.DEFAULT_TIMEOUT): provision_client = ProvisioningClient( host=provision_request.host, port=provision_request.port, diff --git a/tb_mqtt_client/service/device/handlers/rpc_response_handler.py b/tb_mqtt_client/service/device/handlers/rpc_response_handler.py index 9008b67..f241f3d 100644 --- a/tb_mqtt_client/service/device/handlers/rpc_response_handler.py +++ b/tb_mqtt_client/service/device/handlers/rpc_response_handler.py @@ -56,7 +56,7 @@ def register_request(self, request_id: Union[str, int], self._pending_rpc_requests[request_id] = future, callback return future - async def handle(self, topic: str, payload: bytes): + async def handle(self, topic: str, payload: Union[bytes, TimeoutError]): """ Handles the incoming RPC response from the platform and fulfills the corresponding future. The topic is expected to be: v1/devices/me/rpc/response/{request_id} diff --git a/tb_mqtt_client/service/message_dispatcher.py b/tb_mqtt_client/service/message_dispatcher.py index fe85482..09151d5 100644 --- a/tb_mqtt_client/service/message_dispatcher.py +++ b/tb_mqtt_client/service/message_dispatcher.py @@ -14,6 +14,7 @@ import asyncio from abc import ABC, abstractmethod +from itertools import chain from collections import defaultdict from datetime import UTC, datetime from typing import Any, Dict, List, Tuple, Optional, Union @@ -125,7 +126,7 @@ def parse_rpc_request(self, topic: str, payload: bytes) -> RPCRequest: pass @abstractmethod - def parse_rpc_response(self, topic: str, payload: bytes) -> RPCResponse: + def parse_rpc_response(self, topic: str, payload: Union[bytes, Exception]) -> RPCResponse: """ Parse the RPC response from the given topic and payload. This method should be implemented to handle the specific format of the RPC response. @@ -191,7 +192,7 @@ def parse_rpc_request(self, topic: str, payload: bytes) -> RPCRequest: logger.error("Failed to parse RPC request: %s", str(e)) raise ValueError("Invalid RPC request format") from e - def parse_rpc_response(self, topic: str, payload: bytes) -> RPCResponse: + def parse_rpc_response(self, topic: str, payload: Union[bytes, Exception]) -> RPCResponse: """ Parse the RPC response from the given topic and payload. :param topic: The MQTT topic of the RPC response. 
@@ -200,8 +201,11 @@ def parse_rpc_response(self, topic: str, payload: bytes) -> RPCResponse: """ try: request_id = int(topic.split("/")[-1]) - parsed = loads(payload) - data = RPCResponse.build(request_id, parsed) # noqa + if isinstance(payload, Exception): + data = RPCResponse.build(request_id, error=payload) + else: + parsed = loads(payload) + data = RPCResponse.build(request_id, parsed) # noqa return data except Exception as e: logger.error("Failed to parse RPC response: %s", str(e)) @@ -392,15 +396,12 @@ def pack_attributes(msg: DeviceUplinkMessage) -> Dict[str, Any]: return {attr.key: attr.value for attr in msg.attributes} @staticmethod - def pack_timeseries(msg: DeviceUplinkMessage) -> List[Dict[str, Any]]: - logger.trace("Packing %d timeseries timestamp bucket(s)", len(msg.timeseries)) - + def pack_timeseries(msg: 'DeviceUplinkMessage') -> List[Dict[str, Any]]: now_ts = int(datetime.now(UTC).timestamp() * 1000) - packed: List[Dict[str, Any]] = [] - for ts_key, entries in msg.timeseries.items(): - resolved_ts = ts_key or now_ts - values = {entry.key: entry.value for entry in entries} - packed.append({"ts": resolved_ts, "values": values}) + packed = [ + {"ts": entry.ts or now_ts, "values": {entry.key: entry.value}} + for entry in chain.from_iterable(msg.timeseries.values()) + ] return packed diff --git a/tb_mqtt_client/service/message_queue.py b/tb_mqtt_client/service/message_queue.py index fb8a74b..e651074 100644 --- a/tb_mqtt_client/service/message_queue.py +++ b/tb_mqtt_client/service/message_queue.py @@ -41,7 +41,7 @@ def __init__(self, batch_collect_max_time_ms: int = 100, batch_collect_max_count: int = 500): self._main_stop_event = main_stop_event - self._batch_max_time = batch_collect_max_time_ms / 1000 # convert to seconds + self._batch_max_time = batch_collect_max_time_ms / 1000 self._batch_max_count = batch_collect_max_count self._mqtt_manager = mqtt_manager self._message_rate_limit = message_rate_limit @@ -78,7 +78,7 @@ async def publish(self, topic: str, payload: Union[bytes, DeviceUplinkMessage], async def _dequeue_loop(self): logger.debug("MessageQueue dequeue loop started.") - while self._active.is_set(): + while self._active.is_set() and not self._main_stop_event.is_set(): try: # topic, payload, delivery_futures_or_none, count = await asyncio.wait_for(asyncio.get_event_loop().create_task(self._queue.get()), timeout=self._BATCH_TIMEOUT) topic, payload, delivery_futures_or_none, datapoints, qos = await self._wait_for_message() @@ -331,12 +331,15 @@ async def shutdown(self): logger.warning("Error while cancelling retry task: %s", e) self._loop_task.cancel() - self._rate_limit_refill_task.cancel() + if self._rate_limit_refill_task: + self._rate_limit_refill_task.cancel() with suppress(asyncio.CancelledError): await self._loop_task await self._rate_limit_refill_task - logger.debug("MessageQueue shutdown complete.") + logger.debug("MessageQueue shutdown complete, message queue size: %d", + self._queue.qsize()) + self.clear() def is_empty(self): return self._queue.empty() diff --git a/tb_mqtt_client/service/message_splitter.py b/tb_mqtt_client/service/message_splitter.py index 61267ec..8f187d3 100644 --- a/tb_mqtt_client/service/message_splitter.py +++ b/tb_mqtt_client/service/message_splitter.py @@ -56,8 +56,8 @@ def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp batch_futures = [] for grouped_ts in message.timeseries.values(): - for ts in grouped_ts: - exceeds_size = builder and size + ts.size > self._max_payload_size + for ts_kv in 
grouped_ts: + exceeds_size = builder and size + ts_kv.size > self._max_payload_size exceeds_points = 0 < self._max_datapoints <= point_count if not builder or exceeds_size or exceeds_points: @@ -71,8 +71,8 @@ def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp size = 0 point_count = 0 - builder.add_telemetry(ts) - size += ts.size + builder.add_telemetry(ts_kv) + size += ts_kv.size point_count += 1 if builder and builder._timeseries: # noqa diff --git a/tb_mqtt_client/service/mqtt_manager.py b/tb_mqtt_client/service/mqtt_manager.py index af40488..ab67f6c 100644 --- a/tb_mqtt_client/service/mqtt_manager.py +++ b/tb_mqtt_client/service/mqtt_manager.py @@ -21,8 +21,9 @@ from gmqtt import Client as GMQTTClient, Message, Subscription +from tb_mqtt_client.common.async_utils import await_or_stop from tb_mqtt_client.common.gmqtt_patch import patch_gmqtt_puback, patch_gmqtt_protocol_connection_lost, \ - patch_mqtt_handler_disconnect, DISCONNECT_REASON_CODES + patch_mqtt_handler_disconnect, patch_handle_connack, PatchUtils from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.common.rate_limit.backpressure_controller import BackpressureController from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit @@ -64,6 +65,7 @@ def __init__( self._client = GMQTTClient(client_id) patch_gmqtt_puback(self._client, self._handle_puback_reason_code) + patch_handle_connack(self._client, self._on_connect_internal) self._client.on_connect = self._on_connect_internal self._client.on_disconnect = self._on_disconnect_internal self._client.on_message = self._on_message_internal @@ -114,7 +116,7 @@ async def _connect_loop(self): host, port, username, password, tls, keepalive, ssl_context = self._connect_params retry_delay = 3 - while not self._client.is_connected: + while not self._client.is_connected and not self._main_stop_event.is_set(): try: if username: self._client.set_auth_credentials(username, password) @@ -129,7 +131,7 @@ async def _connect_loop(self): logger.info("MQTT connection initiated, waiting for on_connect...") await self._connected_event.wait() logger.info("MQTT connected.") - break # Exit loop if connected + break except Exception as e: logger.warning("Initial MQTT connection failed: %s. 
Retrying in %s seconds...", str(e), retry_delay) @@ -162,7 +164,7 @@ async def publish(self, message_or_topic: Union[str, Message], raise RuntimeError("Cannot publish before rate limits are retrieved.") try: if not self._rate_limits_ready_event.is_set(): - await asyncio.wait_for(self._rate_limits_ready_event.wait(), timeout=10) + await await_or_stop(self._rate_limits_ready_event.wait(), self._main_stop_event, timeout=10) except asyncio.TimeoutError: raise RuntimeError("Timeout waiting for rate limits.") @@ -211,9 +213,15 @@ def register_handler(self, topic_filter: str, handler: Callable[[str, bytes], Co def unregister_handler(self, topic_filter: str): self._handlers.pop(topic_filter, None) - def _on_connect_internal(self, client, flags, rc, properties): - logger.info("Connected to platform") - logger.debug("Connection flags: %s, reason code: %s, properties: %s", flags, rc, properties) + def _on_connect_internal(self, client, session_present, reason_code, properties): + if reason_code != 0: + logger.error("Failed to connect to platform with reason code: %s", reason_code) + if properties and 'reason_string' in properties: + logger.error("Connection reason: %s", properties['reason_string'][0]) + self._connected_event.clear() + return + logger.info("Connected to the platform.") + logger.debug("Connection session_present: %s, reason code: %s, properties: %s", session_present, reason_code, properties) if hasattr(client, '_connection'): client._connection._on_disconnect_called = False # noqa self._connected_event.set() @@ -224,9 +232,13 @@ async def __handle_connect_and_limits(self): sub_future = await self.subscribe(mqtt_topics.DEVICE_RPC_REQUEST_TOPIC_FOR_SUBSCRIPTION, qos=1) while not sub_future.done(): await sleep(0.01) + if self._main_stop_event.is_set(): + return sub_future = await self.subscribe(mqtt_topics.DEVICE_RPC_RESPONSE_TOPIC_FOR_SUBSCRIPTION, qos=1) while not sub_future.done(): await sleep(0.01) + if self._main_stop_event.is_set(): + return logger.debug("Subscribing completed, sending rate limits request") await self.__request_rate_limits() @@ -235,15 +247,19 @@ async def __handle_connect_and_limits(self): asyncio.create_task(self._on_connect_callback()) def _on_disconnect_internal(self, client, reason_code=None, properties=None, exc=None): # noqa + if isinstance(reason_code, bytes): + # Skipping handling due to duplication, because gmqtt triggers this cb again. 
+ logger.trace("Received bytes reason code: %r", reason_code) + return self._connected_event.clear() if reason_code is not None: - reason_desc = DISCONNECT_REASON_CODES.get(reason_code, "Unknown reason") + reason_desc = PatchUtils.DISCONNECT_REASON_CODES.get(reason_code, "Unknown reason") logger.info("Disconnected from platform with reason code: %s (%s)", reason_code, reason_desc) if properties and 'reason_string' in properties: logger.info("Disconnect reason: %s", properties['reason_string'][0]) else: - logger.info("Disconnected from platform") + logger.info("Disconnected from the platform.") if exc: logger.warning("Disconnect exception: %s", exc) @@ -275,10 +291,10 @@ def _on_disconnect_internal(self, client, reason_code=None, properties=None, exc reached_time = 1 for rate_limit in self.__rate_limiter.values(): if isinstance(rate_limit, RateLimit): - reached_limit = rate_limit.reach_limit() + reached_limit = asyncio.get_event_loop().run_until_complete(rate_limit.reach_limit()) reached_index, reached_time, reached_duration = reached_limit if reached_limit else (None, None, 1) self._backpressure.notify_disconnect(delay_seconds=reached_time) - else: + elif reason_code != 0: # Default disconnect handling self._backpressure.notify_disconnect(delay_seconds=15) @@ -345,7 +361,7 @@ def _on_unsubscribe_internal(self, client, mid, properties): async def await_ready(self, timeout: float = 10.0): try: - await asyncio.wait_for(self._rate_limits_ready_event.wait(), timeout=timeout) + await await_or_stop(self._rate_limits_ready_event.wait(), self._main_stop_event, timeout=timeout) except asyncio.TimeoutError: logger.debug("Waiting for rate limits timed out.") @@ -375,7 +391,7 @@ async def __request_rate_limits(self): try: await self.publish(topic, payload, qos=1, force=True) - await asyncio.wait_for(response_future, timeout=10) + await await_or_stop(response_future, self._main_stop_event, timeout=10) logger.info("Successfully processed rate limits.") self.__rate_limits_retrieved = True self.__is_waiting_for_rate_limits_publish = False @@ -427,4 +443,3 @@ async def check_pending_publishes(self, time_to_check): expired.append(mid) for mid in expired: self._pending_publishes.pop(mid, None) - From db545f25fbbfaa65b63a7e8e0a3a14467a616d79 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Mon, 23 Jun 2025 15:06:45 +0300 Subject: [PATCH 20/74] Updated example due to changes in device client --- examples/device/operational_example.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/examples/device/operational_example.py b/examples/device/operational_example.py index c1289fc..95ad46a 100644 --- a/examples/device/operational_example.py +++ b/examples/device/operational_example.py @@ -156,11 +156,13 @@ def _shutdown_handler(): # 3. 
List of TelemetryEntry with mixed timestamps telemetry_entries = [] - for i in range(1): + for i in range(100): telemetry_entries.append(TimeseriesEntry("temperature", i, ts=int(datetime.now(UTC).timestamp() * 1000)-i)) logger.info("Sending list of telemetry entries with mixed timestamps...") telemetry_list_publish_result = await client.send_telemetry(telemetry_entries) - logger.info("List of telemetry entries sent: %s with result: %s", telemetry_entries, telemetry_list_publish_result) + logger.info("List of telemetry entries sent: %s with result: %s", + len(telemetry_entries) if len(telemetry_entries) > 10 else telemetry_entries, + telemetry_list_publish_result) # --- Attribute Request --- @@ -180,23 +182,20 @@ def _shutdown_handler(): logger.info("Sending RPC request: %r", rpc_request) - response_future = await client.send_rpc_request(rpc_request) + try: + rpc_response = await client.send_rpc_request(rpc_request) + except TimeoutError as e: + logger.error("Timeout while sending RPC request: %s", e) + rpc_response = None - if response_future: - logger.info("Awaiting RPC response future...") - try: - response = await asyncio.wait_for(response_future, timeout=5) - logger.info("RPC response received: %s", response) - except asyncio.TimeoutError: - logger.warning("RPC response future timed out after 5 seconds.") - except Exception as e: - logger.error("Error while awaiting RPC response future: %s", e) + if rpc_response: + logger.info("RPC response received: %s", rpc_response) rpc_request_2 = await RPCRequest.build("getAnotherInformation", {"param": "value"}) logger.info("Sending another RPC request: %r", rpc_request_2) - await client.send_rpc_request(rpc_request_2, rpc_response_callback) + await client.send_rpc_request(rpc_request_2, rpc_response_callback, wait_for_publish=False) try: logger.info("Waiting for 1 seconds before next iteration...") From c45148f104f96e886998ab5420f60602433cfa5f Mon Sep 17 00:00:00 2001 From: imbeacon Date: Tue, 24 Jun 2025 12:45:23 +0300 Subject: [PATCH 21/74] Updated examples for device --- ...y => DEPRECATEDclaiming_device_pe_only.py} | 0 ...ng.py => DEPRECATEDclient_provisioning.py} | 0 ...update.py => DEPRECATEDfirmware_update.py} | 0 ...ls_connect.py => DEPRECATEDtls_connect.py} | 0 examples/device/client_rpc_request.py | 44 ----- examples/device/handle_attribute_updates.py | 45 +++++ examples/device/handle_rpc_requests.py | 50 ++++++ examples/device/hardware_specs_sender.py | 70 -------- examples/device/load.py | 2 + examples/device/request_attributes.py | 63 ++++--- examples/device/send_attributes.py | 48 ++++++ examples/device/send_client_side_rpc.py | 50 ++++++ examples/device/send_telemetry.py | 50 ++++++ examples/device/send_telemetry_and_attr.py | 61 ------- examples/device/send_telemetry_pack.py | 44 ----- examples/device/subscription_to_attrs.py | 41 ----- examples/gateway/operational_example.py | 162 ++++++++++++++++++ 17 files changed, 437 insertions(+), 293 deletions(-) rename examples/device/{claiming_device_pe_only.py => DEPRECATEDclaiming_device_pe_only.py} (100%) rename examples/device/{client_provisioning.py => DEPRECATEDclient_provisioning.py} (100%) rename examples/device/{firmware_update.py => DEPRECATEDfirmware_update.py} (100%) rename examples/device/{tls_connect.py => DEPRECATEDtls_connect.py} (100%) delete mode 100644 examples/device/client_rpc_request.py create mode 100644 examples/device/handle_attribute_updates.py create mode 100644 examples/device/handle_rpc_requests.py delete mode 100644 examples/device/hardware_specs_sender.py 
create mode 100644 examples/device/send_attributes.py create mode 100644 examples/device/send_client_side_rpc.py create mode 100644 examples/device/send_telemetry.py delete mode 100644 examples/device/send_telemetry_and_attr.py delete mode 100644 examples/device/send_telemetry_pack.py delete mode 100644 examples/device/subscription_to_attrs.py create mode 100644 examples/gateway/operational_example.py diff --git a/examples/device/claiming_device_pe_only.py b/examples/device/DEPRECATEDclaiming_device_pe_only.py similarity index 100% rename from examples/device/claiming_device_pe_only.py rename to examples/device/DEPRECATEDclaiming_device_pe_only.py diff --git a/examples/device/client_provisioning.py b/examples/device/DEPRECATEDclient_provisioning.py similarity index 100% rename from examples/device/client_provisioning.py rename to examples/device/DEPRECATEDclient_provisioning.py diff --git a/examples/device/firmware_update.py b/examples/device/DEPRECATEDfirmware_update.py similarity index 100% rename from examples/device/firmware_update.py rename to examples/device/DEPRECATEDfirmware_update.py diff --git a/examples/device/tls_connect.py b/examples/device/DEPRECATEDtls_connect.py similarity index 100% rename from examples/device/tls_connect.py rename to examples/device/DEPRECATEDtls_connect.py diff --git a/examples/device/client_rpc_request.py b/examples/device/client_rpc_request.py deleted file mode 100644 index a54b589..0000000 --- a/examples/device/client_rpc_request.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time -import logging -from tb_device_mqtt import TBDeviceMqttClient -logging.basicConfig(level=logging.INFO) - - -def callback(request_id, resp_body, exception=None): - if exception is not None: - logging.error("Exception: " + str(exception)) - else: - logging.info("request id: {request_id}, response body: {resp_body}".format(request_id=request_id, - resp_body=resp_body)) - - -def main(): - client = TBDeviceMqttClient("127.0.0.1", username="A2_TEST_TOKEN") - - client.connect() - # call "getTime" on server and receive result, then process it with callback - client.send_rpc_call("getTime", {}, callback) - try: - while not client.stopped: - time.sleep(1) - except KeyboardInterrupt: - client.disconnect() - client.stop() - - -if __name__ == '__main__': - main() diff --git a/examples/device/handle_attribute_updates.py b/examples/device/handle_attribute_updates.py new file mode 100644 index 0000000..9ed3595 --- /dev/null +++ b/examples/device/handle_attribute_updates.py @@ -0,0 +1,45 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Example script to handle attribute updates from ThingsBoard using the DeviceClient + +import asyncio +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.service.device.client import DeviceClient + +async def attribute_update_callback(update: AttributeUpdate): + print("Received attribute update:", update) + +async def main(): + config = DeviceConfig() + config.host = "localhost" + config.access_token = "YOUR_ACCESS_TOKEN" + + client = DeviceClient(config) + client.set_attribute_update_callback(attribute_update_callback) + + await client.connect() + print("Waiting for attribute updates... Press Ctrl+C to stop.") + + try: + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + print("Shutting down...") + + await client.stop() + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/device/handle_rpc_requests.py b/examples/device/handle_rpc_requests.py new file mode 100644 index 0000000..587057e --- /dev/null +++ b/examples/device/handle_rpc_requests.py @@ -0,0 +1,50 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Example script to handle RPC requests from ThingsBoard using the DeviceClient + +import asyncio +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.service.device.client import DeviceClient + +async def rpc_request_callback(request: RPCRequest) -> RPCResponse: + print("Received RPC:", request) + + if request.method == "ping": + return RPCResponse.build(request_id=request.request_id, result={"pong": True}) + else: + return RPCResponse.build(request_id=request.request_id, result={"message": "Unknown method"}) + +async def main(): + config = DeviceConfig() + config.host = "localhost" + config.access_token = "YOUR_ACCESS_TOKEN" + + client = DeviceClient(config) + client.set_rpc_request_callback(rpc_request_callback) + + await client.connect() + print("Waiting for RPCs... 
Press Ctrl+C to stop.") + try: + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + print("Shutting down...") + + await client.stop() + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/device/hardware_specs_sender.py b/examples/device/hardware_specs_sender.py deleted file mode 100644 index 2a753a6..0000000 --- a/examples/device/hardware_specs_sender.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time -import logging -try: - import psutil -except ImportError: - print("Please install psutil using 'pip install psutil' command") - exit(1) -from tb_device_mqtt import TBDeviceMqttClient -# this example illustrates situation, where client send cpu and memory usage every 5 seconds. -# If client receives an update of uploadFrequency attribute, it changes frequency of other attributes publishing. -# Also client is listening to rpc and responds immediately to corresponding rpc methods from server -# ('getCPULoad' or 'getMemoryUsage') - -logging.basicConfig(level=logging.DEBUG) -uploadFrequency = 5 - - -# this callback changes global variable defining how often telemetry is sent -def on_upload_frequency_change(value, error): - global uploadFrequency - if "uploadFrequency" in value: - uploadFrequency = int(value["uploadFrequency"]) - elif "shared" in value and "uploadFrequency" in value["shared"]: - uploadFrequency = int(value["shared"]["uploadFrequency"]) - - -# dependently of request method we send different data back -def on_server_side_rpc_request(request_id, request_body): - print(client, request_id, request_body) - if request_body["method"] == "getCPULoad": - client.send_rpc_reply(request_id, {"CPU percent": psutil.cpu_percent(interval=0.1)}) - elif request_body["method"] == "getMemoryUsage": - client.send_rpc_reply(request_id, {"Memory": psutil.virtual_memory().percent}) - else: - print("Unknown method: " + request_body["method"]) - client.send_rpc_reply(request_id, "Unknown method: " + request_body["method"]) - - -client = TBDeviceMqttClient("127.0.0.1", username="A2_TEST_TOKEN") -client.set_server_side_rpc_request_handler(on_server_side_rpc_request) -client.connect() -# to fetch the latest setting for upload frequency configured on the server -client.request_attributes(shared_keys=["uploadFrequency"], callback=on_upload_frequency_change) -# to subscribe to future changes of upload frequency -client.subscribe_to_attribute(key="uploadFrequency", callback=on_upload_frequency_change) - - -def main(): - while True: - client.send_telemetry({"cpu": psutil.cpu_percent(), "memory": psutil.virtual_memory().percent}) - print("Sleeping for " + str(uploadFrequency)) - time.sleep(uploadFrequency) - - -if __name__ == '__main__': - main() diff --git a/examples/device/load.py b/examples/device/load.py index 593225c..3cac8d9 100644 --- a/examples/device/load.py +++ b/examples/device/load.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the 
License. +# Example script to send a high load of telemetry data to ThingsBoard using the DeviceClient + import asyncio import logging import signal diff --git a/examples/device/request_attributes.py b/examples/device/request_attributes.py index 00c53bf..55f1b06 100644 --- a/examples/device/request_attributes.py +++ b/examples/device/request_attributes.py @@ -12,36 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging -import time - -from tb_device_mqtt import TBDeviceMqttClient -logging.basicConfig(level=logging.INFO) - - -def on_attributes_change(result, exception=None): - # This is a callback function that will be called when we receive the response from the server - if exception is not None: - logging.error("Exception: " + str(exception)) - else: - logging.info(result) - - -def main(): - client = TBDeviceMqttClient("127.0.0.1", username="A2_TEST_TOKEN") - client.connect() - # Sending data to retrieve it later - client.send_attributes({"atr1": "value1", "atr2": "value2"}) - # Requesting attributes - client.request_attributes(["atr1", "atr2"], callback=on_attributes_change) - try: - # Waiting for the callback - while not client.stopped: - time.sleep(1) - except KeyboardInterrupt: - client.disconnect() - client.stop() - - -if __name__ == '__main__': - main() +# Example script to request attributes from ThingsBoard using the DeviceClient + +import asyncio +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.service.device.client import DeviceClient + +async def attribute_request_callback(response: RequestedAttributeResponse): + print("Received attribute response:", response) + +async def main(): + config = DeviceConfig() + config.host = "localhost" + config.access_token = "YOUR_ACCESS_TOKEN" + + client = DeviceClient(config) + await client.connect() + + # Request specific attributes + request = await AttributeRequest.build(["targetTemperature"], ["currentTemperature"]) + await client.send_attribute_request(request, attribute_request_callback) + + print("Attribute request sent. Waiting for response...") + await asyncio.sleep(5) + + await client.stop() + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/device/send_attributes.py b/examples/device/send_attributes.py new file mode 100644 index 0000000..de5cb3c --- /dev/null +++ b/examples/device/send_attributes.py @@ -0,0 +1,48 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Example script to send attributes to ThingsBoard using the DeviceClient + +import asyncio +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.service.device.client import DeviceClient + +async def main(): + config = DeviceConfig() + config.host = "localhost" + config.access_token = "YOUR_ACCESS_TOKEN" + + client = DeviceClient(config) + await client.connect() + + # Send attribute as raw dictionary + await client.send_attributes({ + "firmwareVersion": "1.0.4", + "hardwareModel": "TB-SDK-Device" + }) + + # Send single attribute entry + await client.send_attributes(AttributeEntry("mode", "normal")) + + # Send list of attributes + await client.send_attributes([ + AttributeEntry("maxTemperature", 85), + AttributeEntry("calibrated", True) + ]) + + await client.stop() + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/device/send_client_side_rpc.py b/examples/device/send_client_side_rpc.py new file mode 100644 index 0000000..4862769 --- /dev/null +++ b/examples/device/send_client_side_rpc.py @@ -0,0 +1,50 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Example script to send a client-side RPC request to ThingsBoard using the DeviceClient + +import asyncio +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.service.device.client import DeviceClient + +async def rpc_response_callback(response: RPCResponse): + print("Received RPC response:", response) + +async def main(): + config = DeviceConfig() + config.host = "localhost" + config.access_token = "YOUR_ACCESS_TOKEN" + + client = DeviceClient(config) + await client.connect() + + # Send client-side RPC and wait for response + rpc_request = await RPCRequest.build("getTime", {}) + try: + response = await client.send_rpc_request(rpc_request) + print("Received response:", response) + except TimeoutError: + print("RPC request timed out") + + # Send client-side RPC with callback + rpc_request_2 = await RPCRequest.build("getStatus", {}) + await client.send_rpc_request(rpc_request_2, rpc_response_callback, wait_for_publish=False) + + await asyncio.sleep(5) + await client.stop() + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/device/send_telemetry.py b/examples/device/send_telemetry.py new file mode 100644 index 0000000..320ac18 --- /dev/null +++ b/examples/device/send_telemetry.py @@ -0,0 +1,50 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This example demonstrates how to send telemetry data from a device to ThingsBoard using the DeviceClient. + +import asyncio +from random import uniform, randint +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.service.device.client import DeviceClient + +async def main(): + config = DeviceConfig() + config.host = "localhost" + config.access_token = "YOUR_ACCESS_TOKEN" + + client = DeviceClient(config) + await client.connect() + + # Send telemetry as raw dictionary + await client.send_telemetry({ + "temperature": round(uniform(20.0, 30.0), 2), + "humidity": randint(30, 70) + }) + + # Send single telemetry entry + await client.send_telemetry(TimeseriesEntry("batteryLevel", randint(0, 100))) + + # Send list of telemetry entries + entries = [ + TimeseriesEntry("vibration", 0.05), + TimeseriesEntry("speed", 123) + ] + await client.send_telemetry(entries) + + await client.stop() + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/device/send_telemetry_and_attr.py b/examples/device/send_telemetry_and_attr.py deleted file mode 100644 index aa4a320..0000000 --- a/examples/device/send_telemetry_and_attr.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -from tb_device_mqtt import TBDeviceMqttClient, TBPublishInfo -from time import sleep, time -logging.basicConfig(level=logging.DEBUG) - -telemetry = {"temperature": 41.9, "humidity": 69, "enabled": False, "currentFirmwareVersion": "v1.2.2"} -telemetry_as_array = [{"temperature": 42.0}, {"humidity": 70}, {"enabled": True}, {"currentFirmwareVersion": "v1.2.3"}] -telemetry_with_ts = {"ts": int(round(time() * 1000)), "values": {"temperature": 42.1, "humidity": 70}} -telemetry_with_ts_as_array = [{"ts": 1451649600000, "values": {"temperature": 42.2, "humidity": 71}}, - {"ts": 1451649601000, "values": {"temperature": 42.3, "humidity": 72}}] -attributes = {"sensorModel": "DHT-22", "attribute_2": "value"} - -log = logging.getLogger(__name__) - - -def on_connect(client, userdata, flags, result_code, *extra_params, tb_client): - if result_code == 0: - log.info("Connected to ThingsBoard!") - else: - log.error("Failed to connect to ThingsBoard with result code: %d", result_code) - - -def main(): - client = TBDeviceMqttClient("127.0.0.1", username="A2_TEST_TOKEN") - client.connect(callback=on_connect) - - while not client.is_connected(): - sleep(1) - - # Sending data in async way - client.send_attributes(attributes) - client.send_telemetry(telemetry) - client.send_telemetry(telemetry_as_array, quality_of_service=1) - client.send_telemetry(telemetry_with_ts) - client.send_telemetry(telemetry_with_ts_as_array) - - # Waiting for data to be delivered - result = client.send_attributes(attributes) - log.info("Attribute update sent: " + str(result.rc() == TBPublishInfo.TB_ERR_SUCCESS)) - result = client.send_attributes(attributes) - log.info("Telemetry update sent: " + str(result.rc() == TBPublishInfo.TB_ERR_SUCCESS)) - - client.disconnect() - - -if __name__ == '__main__': - main() diff --git a/examples/device/send_telemetry_pack.py b/examples/device/send_telemetry_pack.py deleted file mode 100644 index 571e5ec..0000000 --- a/examples/device/send_telemetry_pack.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -from tb_device_mqtt import TBDeviceMqttClient, TBPublishInfo -import time - -logging.basicConfig(level=logging.DEBUG) - -telemetry_with_ts = {"ts": int(round(time.time() * 1000)), "values": {"temperature": 42.1, "humidity": 70}} - - -def main(): - client = TBDeviceMqttClient("127.0.0.1", username="A2_TEST_TOKEN") - client.connect() - - results = [] - result = True - - for i in range(0, 100): - results.append(client.send_telemetry( - {"ts": int(round(time.time() * 1000)), "values": {"temperature": 42.1, "humidity": 70}})) - - for tmp_result in results: - result &= tmp_result.get() == TBPublishInfo.TB_ERR_SUCCESS - - print("Result " + str(result)) - - client.stop() - - -if __name__ == '__main__': - main() diff --git a/examples/device/subscription_to_attrs.py b/examples/device/subscription_to_attrs.py deleted file mode 100644 index 604e6c4..0000000 --- a/examples/device/subscription_to_attrs.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import time - -from tb_device_mqtt import TBDeviceMqttClient -logging.basicConfig(level=logging.DEBUG) - - -def callback(result, *args): - logging.info("Received data: %r", result) - - -def main(): - client = TBDeviceMqttClient("127.0.0.1", username="A2_TEST_TOKEN") - client.connect() - sub_id_1 = client.subscribe_to_attribute("frequency", callback) - sub_id_2 = client.subscribe_to_all_attributes(callback) - try: - while True: - time.sleep(1) - except KeyboardInterrupt: - client.unsubscribe_from_attribute(sub_id_1) - client.unsubscribe_from_attribute(sub_id_2) - client.disconnect() - - -if __name__ == '__main__': - main() diff --git a/examples/gateway/operational_example.py b/examples/gateway/operational_example.py new file mode 100644 index 0000000..db083f1 --- /dev/null +++ b/examples/gateway/operational_example.py @@ -0,0 +1,162 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import logging +import signal +from datetime import datetime, UTC +from random import randint, uniform + +from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.service.gateway.client import GatewayClient +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger + +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.DEBUG) +logging.getLogger("tb_mqtt_client").setLevel(logging.DEBUG) + + +async def device_attribute_update_callback(update: AttributeUpdate): + """ + Callback function to handle device attribute updates. + :param update: The attribute update object. + """ + logger.info("Received attribute update for device %s: %s", update.device, update.attributes) + + +async def device_rpc_request_callback(device_name: str, method: str, params: dict): + """ + Callback function to handle device RPC requests. + :param device_name: Name of the device + :param method: RPC method + :param params: RPC parameters + :return: RPC response + """ + logger.info("Received RPC request for device %s: method=%s, params=%s", device_name, method, params) + + # Example response based on method + if method == "getTemperature": + return {"temperature": round(uniform(20.0, 30.0), 2)} + elif method == "setLedState": + state = params.get("state", False) + return {"success": True, "state": state} + else: + return {"error": f"Unsupported method: {method}"} + + +async def device_disconnect_callback(device_name: str): + """ + Callback function to handle device disconnections. 
+ :param device_name: Name of the disconnected device + """ + logger.info("Device %s disconnected", device_name) + + +async def main(): + stop_event = asyncio.Event() + + def _shutdown_handler(): + stop_event.set() + + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, _shutdown_handler) + except NotImplementedError: + # Windows compatibility fallback + signal.signal(sig, lambda *_: _shutdown_handler()) + + config = GatewayConfig() + config.host = "localhost" + config.access_token = "YOUR_GATEWAY_ACCESS_TOKEN" + + client = GatewayClient(config) + client.set_device_attribute_update_callback(device_attribute_update_callback) + client.set_device_rpc_request_callback(device_rpc_request_callback) + client.set_device_disconnect_callback(device_disconnect_callback) + await client.connect() + + logger.info("Connected to ThingsBoard as gateway.") + + # Connect devices to the gateway + device_names = ["sensor-1", "sensor-2", "actuator-1"] + for device_name in device_names: + await client.gw_connect_device(device_name) + logger.info("Connected device: %s", device_name) + + while not stop_event.is_set(): + # Send device attributes + for device_name in device_names: + # Send device attributes + attributes = { + "firmwareVersion": "1.0.4", + "serialNumber": f"SN-{randint(1000, 9999)}", + "deviceType": "sensor" if "sensor" in device_name else "actuator" + } + await client.gw_send_attributes(device_name, attributes) + logger.info("Sent attributes for device %s: %s", device_name, attributes) + + # Send single attribute + single_attr = AttributeEntry("lastUpdateTime", datetime.now(UTC).isoformat()) + await client.gw_send_attributes(device_name, single_attr) + logger.info("Sent single attribute for device %s: %s", device_name, single_attr) + + # Send device telemetry + if "sensor" in device_name: + # For sensor devices + telemetry = { + "temperature": round(uniform(20.0, 30.0), 2), + "humidity": round(uniform(40.0, 80.0), 2), + "batteryLevel": randint(1, 100) + } + await client.gw_send_telemetry(device_name, telemetry) + logger.info("Sent telemetry for device %s: %s", device_name, telemetry) + + # Send single telemetry entry + single_entry = TimeseriesEntry("signalStrength", randint(-90, -30)) + await client.gw_send_telemetry(device_name, single_entry) + logger.info("Sent single telemetry entry for device %s: %s", device_name, single_entry) + else: + # For actuator devices + telemetry = { + "state": "ON" if randint(0, 1) == 1 else "OFF", + "powerConsumption": round(uniform(0.1, 5.0), 2), + "uptime": randint(1, 1000) + } + await client.gw_send_telemetry(device_name, telemetry) + logger.info("Sent telemetry for device %s: %s", device_name, telemetry) + + try: + await asyncio.wait_for(stop_event.wait(), timeout=5) + except asyncio.TimeoutError: + pass + + # Disconnect devices before shutting down + for device_name in device_names: + await client.gw_disconnect_device(device_name) + logger.info("Disconnected device: %s", device_name) + + await client.disconnect() + logger.info("Disconnected from ThingsBoard.") + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("Interrupted by user.") \ No newline at end of file From 7fda91eb8eea68488fe2e2cf6309ab594a76b616 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Tue, 24 Jun 2025 12:45:53 +0300 Subject: [PATCH 22/74] Added description for configuration objects --- tb_mqtt_client/common/config_loader.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git 
a/tb_mqtt_client/common/config_loader.py b/tb_mqtt_client/common/config_loader.py index 15b22b2..55af666 100644 --- a/tb_mqtt_client/common/config_loader.py +++ b/tb_mqtt_client/common/config_loader.py @@ -17,6 +17,11 @@ class DeviceConfig: + """ + Configuration class for ThingsBoard device clients. + This class loads configuration options from environment variables, allowing for flexible deployment + and easy customization of device connection settings. + """ def __init__(self): self.host: str = os.getenv("TB_HOST") self.port: int = int(os.getenv("TB_PORT", 1883)) @@ -51,6 +56,10 @@ def __repr__(self): class GatewayConfig(DeviceConfig): + """ + Configuration class for ThingsBoard gateway clients. + This class extends DeviceConfig to include additional options specific to gateways. + """ def __init__(self): super().__init__() From 14af70b5cd34757344cc6b63e1c1a906344d803b Mon Sep 17 00:00:00 2001 From: samson0v Date: Tue, 24 Jun 2025 15:39:45 +0300 Subject: [PATCH 23/74] Added provisioning and firmware update examples --- examples/device/client_provisioning.py | 45 +++++++++++++++++++++++++ examples/device/firmware_update.py | 46 ++++++++++++++++++++++++++ 2 files changed, 91 insertions(+) create mode 100644 examples/device/client_provisioning.py create mode 100644 examples/device/firmware_update.py diff --git a/examples/device/client_provisioning.py b/examples/device/client_provisioning.py new file mode 100644 index 0000000..ccb1bff --- /dev/null +++ b/examples/device/client_provisioning.py @@ -0,0 +1,45 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
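+
+# Credentials note: this script uses AccessTokenProvisioningCredentials; based on the
+# constructors exercised in tests/service/test_json_message_dispatcher.py,
+# BasicProvisioningCredentials and X509ProvisioningCredentials appear to be available
+# as alternatives. A hedged sketch of the MQTT-basic variant (placeholder values assumed):
+#
+#     from tb_mqtt_client.entities.data.provisioning_request import BasicProvisioningCredentials
+#     credentials = BasicProvisioningCredentials(
+#         'YOUR_PROVISION_DEVICE_KEY', 'YOUR_PROVISION_DEVICE_SECRET',
+#         client_id='YOUR_CLIENT_ID', username='YOUR_USERNAME', password='YOUR_PASSWORD')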
+ +# Example script for device provisioning using the DeviceClient + +import asyncio +from random import randint + +from tb_mqtt_client.entities.data.provisioning_request import AccessTokenProvisioningCredentials, ProvisioningRequest +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.service.device.client import DeviceClient + + +async def main(): + provisioning_credentials = AccessTokenProvisioningCredentials( + provision_device_key='YOUR_PROVISION_DEVICE_KEY', + provision_device_secret='YOUR_PROVISION_DEVICE_SECRET', + ) + provisioning_request = ProvisioningRequest('localhost', credentials=provisioning_credentials) + provisioning_response = await DeviceClient.provision(provisioning_request) + print('Provisioned device config: ', provisioning_response) + + # Create a DeviceClient instance with the provisioned device config + client = DeviceClient(provisioning_response.result) + await client.connect() + + # Send single telemetry entry to provisioned device + await client.send_telemetry(TimeseriesEntry("batteryLevel", randint(0, 100))) + + await client.stop() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/device/firmware_update.py b/examples/device/firmware_update.py new file mode 100644 index 0000000..d0b9dfc --- /dev/null +++ b/examples/device/firmware_update.py @@ -0,0 +1,46 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
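+
+# Configuration note: per the DeviceConfig docstring in config_loader.py, connection
+# settings are read from environment variables. A minimal sketch of that alternative,
+# assuming only the variables visible in config_loader.py (TB_HOST, TB_PORT); the
+# access token variable name is not shown there, so it is still assigned in code:
+#
+#     import os
+#     os.environ["TB_HOST"] = "localhost"
+#     os.environ["TB_PORT"] = "1883"
+#     config = DeviceConfig()  # host and port are picked up from the environment
+#     config.access_token = "YOUR_ACCESS_TOKEN"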
+ +# Example script to update firmware using the DeviceClient + +import asyncio + +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.service.device.client import DeviceClient + + +firmware_received = asyncio.Event() + + +async def firmware_update_callback(_, payload): + print(f"Firmware update payload received: {payload}") + firmware_received.set() + + +async def main(): + config = DeviceConfig() + config.host = "localhost" + config.access_token = "YOUR_ACCESS_TOKEN" + + client = DeviceClient(config) + await client.connect() + + await client.update_firmware(on_received_callback=firmware_update_callback) + await firmware_received.wait() + + await client.stop() + + +if __name__ == "__main__": + asyncio.run(main()) From bde61b0d73bda3c69ad9dbeaa709c64203d7f236 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Wed, 25 Jun 2025 07:48:41 +0300 Subject: [PATCH 24/74] Fix for provisioning request with x509 credentials and added tests for json message dispatcher --- .../entities/data/provisioning_request.py | 1 + tests/service/test_json_message_dispatcher.py | 234 ++++++++++++++---- 2 files changed, 181 insertions(+), 54 deletions(-) diff --git a/tb_mqtt_client/entities/data/provisioning_request.py b/tb_mqtt_client/entities/data/provisioning_request.py index e54db43..18120b5 100644 --- a/tb_mqtt_client/entities/data/provisioning_request.py +++ b/tb_mqtt_client/entities/data/provisioning_request.py @@ -63,6 +63,7 @@ def __init__(self, provision_device_key, provision_device_secret, self.public_cert = self._load_public_cert_path(public_cert_path) self.credentials_type = ProvisioningCredentialsType.X509_CERTIFICATE + @staticmethod def _load_public_cert_path(public_cert_path): content = '' diff --git a/tests/service/test_json_message_dispatcher.py b/tests/service/test_json_message_dispatcher.py index e2575cc..9116f4d 100644 --- a/tests/service/test_json_message_dispatcher.py +++ b/tests/service/test_json_message_dispatcher.py @@ -13,78 +13,204 @@ # limitations under the License. 
import pytest +from unittest.mock import MagicMock, patch, mock_open +from orjson import dumps + from tb_mqtt_client.service.message_dispatcher import JsonMessageDispatcher -from tb_mqtt_client.constants.mqtt_topics import DEVICE_TELEMETRY_TOPIC, DEVICE_ATTRIBUTES_TOPIC -from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder -from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry -from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.constants import mqtt_topics +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest, ProvisioningCredentialsType, \ + BasicProvisioningCredentials, X509ProvisioningCredentials, AccessTokenProvisioningCredentials + + +class DummyClaimRequest: + def __init__(self, secret_key="key"): + self.secret_key = secret_key + + def to_payload_format(self): + return {"secretKey": self.secret_key} + + +class DummyProvisioningRequest: + def __init__(self): + self.device_name = "dev" + self.gateway = True + self.credentials = MagicMock() + self.credentials.provision_device_key = "key" + self.credentials.provision_device_secret = "secret" + self.credentials.credentials_type = ProvisioningCredentialsType.ACCESS_TOKEN + self.credentials.access_token = "token" @pytest.fixture def dispatcher(): - return JsonMessageDispatcher(max_payload_size=512, max_datapoints=10) + return JsonMessageDispatcher() + + +def test_build_attribute_request(dispatcher): + request = MagicMock(spec=AttributeRequest) + request.request_id = 1 + request.to_payload_format.return_value = {"clientKeys": "temp", "sharedKeys": "shared"} + topic, payload = dispatcher.build_attribute_request(request) + assert topic.endswith("/1") + assert b"clientKeys" in payload + + +def test_build_attribute_request_invalid(dispatcher): + request = MagicMock(spec=AttributeRequest) + request.request_id = None + with pytest.raises(ValueError): + dispatcher.build_attribute_request(request) + + +def test_build_claim_request(dispatcher): + req = DummyClaimRequest() + topic, payload = dispatcher.build_claim_request(req) + assert topic == mqtt_topics.DEVICE_CLAIM_TOPIC + assert b"secretKey" in payload + + +def test_build_claim_request_invalid(dispatcher): + req = DummyClaimRequest(secret_key=None) # Simulating an invalid request # noqa + with pytest.raises(ValueError): + dispatcher.build_claim_request(req) + + +def test_build_rpc_request(dispatcher): + request = MagicMock(spec=RPCRequest) + request.request_id = 42 + request.to_payload_format.return_value = {"method": "reboot"} + topic, payload = dispatcher.build_rpc_request(request) + assert topic.endswith("42") + assert b"reboot" in payload + + +def test_build_rpc_request_invalid(dispatcher): + request = MagicMock(spec=RPCRequest) + request.request_id = None + with pytest.raises(ValueError): + dispatcher.build_rpc_request(request) + + +def test_build_rpc_response(dispatcher): + response = MagicMock(spec=RPCResponse) + response.request_id = 123 + response.to_payload_format.return_value = {"result": "ok"} + topic, payload = dispatcher.build_rpc_response(response) + assert topic.endswith("123") + 
assert b"ok" in payload + + +def test_build_rpc_response_invalid(dispatcher): + response = MagicMock(spec=RPCResponse) + response.request_id = None + with pytest.raises(ValueError): + dispatcher.build_rpc_response(response) + + +def test_build_provision_request_access_token(dispatcher): + credentials = AccessTokenProvisioningCredentials("key1", "secret1", access_token="tokenABC") + req = ProvisioningRequest("localhost", credentials, device_name="dev1", gateway=True) + topic, payload = dispatcher.build_provision_request(req) + assert topic == mqtt_topics.PROVISION_REQUEST_TOPIC + assert b"provisionDeviceKey" in payload + assert b"tokenABC" in payload + assert b"credentialsType" in payload + assert b"deviceName" in payload + assert b"gateway" in payload + + +def test_build_provision_request_mqtt_basic(dispatcher): + credentials = BasicProvisioningCredentials("key2", "secret2", client_id="cid", username="user", password="pass") + req = ProvisioningRequest("127.0.0.1", credentials, device_name="dev2", gateway=False) + topic, payload = dispatcher.build_provision_request(req) + assert b"clientId" in payload + assert b"username" in payload + assert b"password" in payload + assert b"credentialsType" in payload + + +def test_build_provision_request_x509(dispatcher): + cert_path = "/fake/path/cert.pem" + cert_content = "-----BEGIN CERTIFICATE-----\nFAKECERT\n-----END CERTIFICATE-----" + with patch("builtins.open", mock_open(read_data=cert_content)): + credentials = X509ProvisioningCredentials("key3", "secret3", "key.pem", cert_path, "ca.pem") + req = ProvisioningRequest("iot.server", credentials, device_name="dev3") + topic, payload = dispatcher.build_provision_request(req) + assert b"hash" in payload + assert b"credentialsType" in payload + assert b"FAKECERT" in payload + + +def test_build_provision_request_x509_file_not_found(dispatcher): + with patch("builtins.open", side_effect=FileNotFoundError): + with pytest.raises(FileNotFoundError): + X509ProvisioningCredentials("key", "secret", "k.pem", "nonexistent.pem", "ca.pem") + + +def test_parse_attribute_request_response(dispatcher): + topic = "v1/devices/me/attributes/response/42" + payload = dumps({"shared": {"temp": 22}}) + with patch.object(RequestedAttributeResponse, "from_dict", return_value="ok") as mock: + result = dispatcher.parse_attribute_request_response(topic, payload) + assert result == "ok" + mock.assert_called_once() + + +def test_parse_attribute_request_response_invalid(dispatcher): + topic = "v1/devices/me/attributes/response/bad" + with pytest.raises(ValueError): + dispatcher.parse_attribute_request_response(topic, b"invalid") -def test_single_telemetry_dispatch(dispatcher): - builder = DeviceUplinkMessageBuilder().set_device_name("dev1") - builder.add_telemetry(TimeseriesEntry("temp", 25)) - msg = builder.build() +def test_parse_attribute_update(dispatcher): + payload = dumps({"shared": {"humidity": 60}}) + with patch.object(AttributeUpdate, "_deserialize_from_dict", return_value="AU"): + result = dispatcher.parse_attribute_update(payload) + assert result == "AU" - payloads = dispatcher.build_uplink_payloads([msg]) - assert len(payloads) == 1 - topic, payload, count = payloads[0] - assert topic == DEVICE_TELEMETRY_TOPIC - assert b"dev1" in payload - assert count == 1 +def test_parse_attribute_update_invalid(dispatcher): + with pytest.raises(ValueError): + dispatcher.parse_attribute_update(b"{bad}") -def test_single_attribute_dispatch(dispatcher): - builder = DeviceUplinkMessageBuilder().set_device_name("dev2") - 
builder.add_attributes(AttributeEntry("mode", "auto")) - msg = builder.build() - payloads = dispatcher.build_uplink_payloads([msg]) - assert len(payloads) == 1 - topic, payload, count = payloads[0] - assert topic == DEVICE_ATTRIBUTES_TOPIC - assert b"dev2" in payload - assert count == 1 +def test_parse_rpc_request(dispatcher): + topic = "v1/devices/me/rpc/request/123" + payload = dumps({"params": {"a": 1}}) + with patch.object(RPCRequest, "_deserialize_from_dict", return_value="REQ"): + assert dispatcher.parse_rpc_request(topic, payload) == "REQ" -def test_multiple_devices_grouping(dispatcher): - b1 = DeviceUplinkMessageBuilder().set_device_name("dev1") - b1.add_telemetry(TimeseriesEntry("t1", 1)) - b2 = DeviceUplinkMessageBuilder().set_device_name("dev2") - b2.add_telemetry(TimeseriesEntry("t2", 2)) +def test_parse_rpc_request_invalid(dispatcher): + topic = "v1/devices/me/rpc/request/NaN" + with pytest.raises(ValueError): + dispatcher.parse_rpc_request(topic, b"{}") - payloads = dispatcher.build_uplink_payloads([b1.build(), b2.build()]) - assert len(payloads) == 2 - for topic, payload, count in payloads: - assert topic == DEVICE_TELEMETRY_TOPIC - assert count == 1 +def test_parse_rpc_response(dispatcher): + topic = "v1/devices/me/rpc/response/999" + payload = dumps({"value": "done"}) + with patch.object(RPCResponse, "build", return_value="RSP"): + assert dispatcher.parse_rpc_response(topic, payload) == "RSP" -def test_large_telemetry_split(dispatcher): - builder = DeviceUplinkMessageBuilder().set_device_name("splittest") - for i in range(15): - builder.add_telemetry(TimeseriesEntry(f"key{i}", i)) - payloads = dispatcher.build_uplink_payloads([builder.build()]) - assert len(payloads) > 1 - for topic, payload, count in payloads: - assert topic == DEVICE_TELEMETRY_TOPIC - assert count <= dispatcher.splitter.max_datapoints +def test_parse_rpc_response_with_error(dispatcher): + topic = "v1/devices/me/rpc/response/888" + error = ValueError("fail") + with patch.object(RPCResponse, "build", return_value="ERR"): + assert dispatcher.parse_rpc_response(topic, error) == "ERR" -def test_large_attributes_split(): - dispatcher = JsonMessageDispatcher(max_payload_size=200) +def test_parse_rpc_response_invalid(dispatcher): + topic = "v1/devices/me/rpc/response/NaN" + with pytest.raises(ValueError): + dispatcher.parse_rpc_response(topic, b"bad") - builder = DeviceUplinkMessageBuilder().set_device_name("splitattr") - for i in range(20): - builder.add_attributes(AttributeEntry(f"k{i}", "x" * 50)) # Increase size - payloads = dispatcher.build_uplink_payloads([builder.build()]) - assert len(payloads) > 1 # Now expect splitting - for topic, payload, count in payloads: - assert topic == DEVICE_ATTRIBUTES_TOPIC - assert count > 0 \ No newline at end of file +if __name__ == '__main__': + pytest.main([__file__]) From f2d977e1655cfc275b13efcb73e825c2ae8354b4 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Wed, 25 Jun 2025 08:16:23 +0300 Subject: [PATCH 25/74] Added tests for message splitter --- tests/service/test_message_splitter.py | 246 +++++++++++++++++-------- 1 file changed, 172 insertions(+), 74 deletions(-) diff --git a/tests/service/test_message_splitter.py b/tests/service/test_message_splitter.py index 1a8ea3b..bffc00e 100644 --- a/tests/service/test_message_splitter.py +++ b/tests/service/test_message_splitter.py @@ -12,82 +12,180 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import asyncio import pytest -from tb_mqtt_client.service.message_splitter import MessageSplitter -from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder -from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry -from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry - - -@pytest.mark.parametrize("max_payload_size,max_datapoints", [(100, 3)]) -def test_split_large_telemetry(max_payload_size, max_datapoints): - splitter = MessageSplitter(max_payload_size=max_payload_size, max_datapoints=max_datapoints) - - builder = DeviceUplinkMessageBuilder().set_device_name("device1") - for i in range(10): - builder.add_telemetry(TimeseriesEntry(f"k{i}", i)) - - message = builder.build() - split = splitter.split_timeseries([message]) - - assert len(split) > 1 - total_ts = sum(len(m.timeseries) for m in split) - assert total_ts == 10 - for m in split: - assert m.device_name == "device1" - - -def test_split_large_attributes(): - splitter = MessageSplitter(max_payload_size=100) - - builder = DeviceUplinkMessageBuilder().set_device_name("deviceA") - for i in range(20): - builder.add_attributes(AttributeEntry(f"attr{i}", "val" * 10)) +from unittest.mock import MagicMock, patch - message = builder.build() - split = splitter.split_attributes([message]) - - assert len(split) > 1 - total_attrs = sum(len(m.attributes) for m in split) - assert total_attrs == 20 - for m in split: - assert m.device_name == "deviceA" - - -def test_no_split_needed(): - splitter = MessageSplitter(max_payload_size=10000, max_datapoints=100) - - builder = DeviceUplinkMessageBuilder().set_device_name("simpleDevice") - builder.add_telemetry(TimeseriesEntry("temp", 23)) - builder.add_attributes(AttributeEntry("fw", "1.0.0")) - - message = builder.build() - result_attr = splitter.split_attributes([message]) - result_ts = splitter.split_timeseries([message]) - - assert len(result_attr) == 1 - assert len(result_ts) == 1 - assert result_attr[0].device_name == "simpleDevice" - assert result_ts[0].device_name == "simpleDevice" - assert len(result_attr[0].attributes) == 1 - assert len(result_ts[0].timeseries) == 1 - - -def test_mixed_split(): - splitter = MessageSplitter(max_payload_size=120, max_datapoints=2) - builder = DeviceUplinkMessageBuilder().set_device_name("mixed") +from tb_mqtt_client.service.message_splitter import MessageSplitter - # Increase attribute value size to ensure payload size > 120 - for i in range(5): - builder.add_attributes(AttributeEntry(f"a{i}", "x" * 50)) - builder.add_telemetry(TimeseriesEntry(f"t{i}", i)) - msg = builder.build() - result_attr = splitter.split_attributes([msg]) - result_ts = splitter.split_timeseries([msg]) +@pytest.fixture +def splitter(): + return MessageSplitter(max_payload_size=100, max_datapoints=3) + + +def mock_ts_entry(size=20): + entry = MagicMock() + entry.size = size + entry.ts = 123456789 + entry.key = "k" + entry.value = 42 + return entry + + +def mock_attr_entry(size=20): + attr = MagicMock() + attr.size = size + attr.key = "k" + attr.value = 42 + return attr + + +# Positive cases +def test_single_small_timeseries_pass_through(splitter): + msg = MagicMock() + msg.has_timeseries.return_value = True + msg.size = 50 + msg.attributes_datapoint_count.return_value = 0 + msg.timeseries_datapoint_count.return_value = 2 + result = splitter.split_timeseries([msg]) + assert result == [msg] + + +def test_single_small_attributes_pass_through(splitter): + msg = MagicMock() + msg.has_attributes.return_value = True + msg.size = 50 + 
msg.attributes_datapoint_count.return_value = 2 + msg.timeseries_datapoint_count.return_value = 0 + result = splitter.split_attributes([msg]) + assert result == [msg] + + +# Negative test: invalid payload size and datapoints +def test_invalid_config_defaults(): + splitter = MessageSplitter(max_payload_size=-10, max_datapoints=-100) + assert splitter.max_payload_size == 65535 + assert splitter.max_datapoints == 0 - # Assert that the attributes and telemetry are split into multiple messages - assert len(result_attr) > 1, f"Expected split, got {len(result_attr)}" - assert len(result_ts) > 1, f"Expected split, got {len(result_ts)}" - assert sum(len(r.attributes) for r in result_attr) == 5 - assert sum(len(r.timeseries) for r in result_ts) == 5 + +# Negative test: empty message list +def test_empty_list_returns_empty(splitter): + assert splitter.split_timeseries([]) == [] + assert splitter.split_attributes([]) == [] + + +# Negative test: message without required fields +def test_malformed_message_handling(splitter): + msg_ts = MagicMock() + msg_ts.has_timeseries.side_effect = Exception("Malformed TS field") + msg_ts.attributes_datapoint_count.return_value = 0 + msg_ts.timeseries_datapoint_count.return_value = 0 + msg_ts.size = 200 + + with pytest.raises(Exception, match="Malformed TS field"): + splitter.split_timeseries([msg_ts]) + + msg_attr = MagicMock() + msg_attr.has_attributes.side_effect = Exception("Malformed Attr field") + msg_attr.attributes_datapoint_count.return_value = 0 + msg_attr.timeseries_datapoint_count.return_value = 0 + msg_attr.size = 200 + + with pytest.raises(Exception, match="Malformed Attr field"): + splitter.split_attributes([msg_attr]) + + +# Negative test: builder fails on build() +@patch("tb_mqtt_client.service.message_splitter.DeviceUplinkMessageBuilder") +def test_builder_failure_during_split_raises(mock_builder_class): + entry = MagicMock() + entry.size = 10 + + message = MagicMock() + message.device_name = "dev" + message.device_profile = "prof" + message.has_timeseries.return_value = True + message.timeseries = {"temp": [entry] * 4} + message.get_delivery_futures.return_value = [] + message.attributes_datapoint_count.return_value = 0 + message.timeseries_datapoint_count.return_value = 4 + message.size = 50 + + builder_instance = MagicMock() + builder_instance.set_device_name.return_value = builder_instance + builder_instance.set_device_profile.return_value = builder_instance + builder_instance.add_telemetry.return_value = None + builder_instance._timeseries = [entry] + builder_instance.build.side_effect = RuntimeError("build failed") + mock_builder_class.return_value = builder_instance + + splitter = MessageSplitter(max_payload_size=20, max_datapoints=2) + + with pytest.raises(RuntimeError, match="build failed"): + splitter.split_timeseries([message]) + + +# Negative test: one of delivery futures fails +@pytest.mark.asyncio +@patch("tb_mqtt_client.service.message_splitter.DeviceUplinkMessageBuilder") +async def test_delivery_future_failure_propagation(mock_builder_class, splitter): + entry = mock_ts_entry() + message = MagicMock() + message.device_name = "deviceX" + message.device_profile = "profileX" + message.has_timeseries.return_value = True + message.timeseries = {"data": [entry] * 4} + message.attributes_datapoint_count.return_value = 0 + message.timeseries_datapoint_count.return_value = 4 + message.size = 50 + + main_future = asyncio.Future() + message.get_delivery_futures.return_value = [main_future] + + fail_future = asyncio.Future() + ok_future = 
asyncio.Future() + + built_msg1 = MagicMock() + built_msg1.get_delivery_futures.return_value = [fail_future] + + built_msg2 = MagicMock() + built_msg2.get_delivery_futures.return_value = [ok_future] + + builder = MagicMock() + builder.build.side_effect = [built_msg1, built_msg2] + mock_builder_class.return_value = builder + + result = splitter.split_timeseries([message]) + assert len(result) == 2 + + await asyncio.sleep(0) + fail_future.set_result(False) + ok_future.set_result(True) + + await asyncio.sleep(0.1) + assert main_future.done() + assert main_future.result() is True + + +# Property validation +def test_payload_setter_validation(): + s = MessageSplitter() + s.max_payload_size = 12345 + assert s.max_payload_size == 12345 + s.max_payload_size = 0 + assert s.max_payload_size == 65535 + + +def test_datapoint_setter_validation(): + s = MessageSplitter() + s.max_datapoints = 99 + assert s.max_datapoints == 99 + s.max_datapoints = 0 + assert s.max_datapoints == 0 + s.max_datapoints = -5 + assert s.max_datapoints == 0 + + +if __name__ == '__main__': + pytest.main([__file__]) From f1209fed070ef837e9e9bac01596387261bcc10b Mon Sep 17 00:00:00 2001 From: samson0v Date: Wed, 25 Jun 2025 08:23:51 +0300 Subject: [PATCH 26/74] Added error handling for provisioning, added timeout for firmware update --- examples/device/client_provisioning.py | 5 +++++ examples/device/firmware_update.py | 7 ++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/examples/device/client_provisioning.py b/examples/device/client_provisioning.py index ccb1bff..e0d1b38 100644 --- a/examples/device/client_provisioning.py +++ b/examples/device/client_provisioning.py @@ -29,6 +29,11 @@ async def main(): ) provisioning_request = ProvisioningRequest('localhost', credentials=provisioning_credentials) provisioning_response = await DeviceClient.provision(provisioning_request) + + if provisioning_response.error is not None: + print(f"Provisioning failed: {provisioning_response.error}") + return + print('Provisined device config: ', provisioning_response) # Create a DeviceClient instance with the provisioned device config diff --git a/examples/device/firmware_update.py b/examples/device/firmware_update.py index d0b9dfc..8ec7380 100644 --- a/examples/device/firmware_update.py +++ b/examples/device/firmware_update.py @@ -15,12 +15,14 @@ # Example script to update firmware using the DeviceClient import asyncio +from time import monotonic from tb_mqtt_client.common.config_loader import DeviceConfig from tb_mqtt_client.service.device.client import DeviceClient firmware_received = asyncio.Event() +firmware_update_timeout = 30 async def firmware_update_callback(_, payload): @@ -37,7 +39,10 @@ async def main(): await client.connect() await client.update_firmware(on_received_callback=firmware_update_callback) - await firmware_received.wait() + + update_started = monotonic() + while not firmware_received.is_set() and monotonic() - update_started < firmware_update_timeout: + await asyncio.sleep(1) await client.stop() From 59951311603419c3174c8a9378f108d0639fa66c Mon Sep 17 00:00:00 2001 From: imbeacon Date: Wed, 25 Jun 2025 09:11:42 +0300 Subject: [PATCH 27/74] Removed deprecated examples --- .../device/DEPRECATEDclient_provisioning.py | 70 ------------------- examples/device/DEPRECATEDfirmware_update.py | 37 ---------- 2 files changed, 107 deletions(-) delete mode 100644 examples/device/DEPRECATEDclient_provisioning.py delete mode 100644 examples/device/DEPRECATEDfirmware_update.py diff --git 
a/examples/device/DEPRECATEDclient_provisioning.py b/examples/device/DEPRECATEDclient_provisioning.py deleted file mode 100644 index cf1886b..0000000 --- a/examples/device/DEPRECATEDclient_provisioning.py +++ /dev/null @@ -1,70 +0,0 @@ - -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from tb_device_mqtt import TBDeviceMqttClient, TBPublishInfo -logging.basicConfig(level=logging.DEBUG) - - -def main(): - """ - We can provide the following parameters to provisioning function: - host - required - Host of ThingsBoard - provision_device_key - required - device provision key from device profile - provision_device_secret - required - device provision secret from device profile - port=1883 - not required - MQTT port of ThingsBoard instance - device_name=None - may be generated on ThingsBoard - You may pass here name for device, if this parameter is not assigned, the name will be generated - - ### Credentials type = ACCESS_TOKEN - - access_token=None - may be generated on ThingsBoard - You may pass here some access token and it will be saved as accessToken for device on ThingsBoard. - - ### Credentials type = MQTT_BASIC - - client_id=None - not required (if username is not None) - You may pass here client Id for your device and use it later for connecting - username=None - not required (if client id is not None) - You may pass here username for your client and use it later for connecting - password=None - not required - You may pass here password and use it later for connecting - - ### Credentials type = X509_CERTIFICATE - hash=None - required (If you want to use this credentials type) - You should pass here public key of the device, generated from mqttserver.jks - - """ - - # Call device provisioning, to do this we don't need an instance of the TBDeviceMqttClient to provision device - - THINGSBOARD_HOST = "mqtt.thingsboard.cloud" - - credentials = TBDeviceMqttClient.provision(THINGSBOARD_HOST, "PROVISION_DEVICE_KEY", "PROVISION_DEVICE_SECRET") - - if credentials is not None and credentials.get("status") == "SUCCESS": - username = None - password = None - client_id = None - if credentials["credentialsType"] == "ACCESS_TOKEN": - username = credentials["credentialsValue"] - elif credentials["credentialsType"] == "MQTT_BASIC": - username = credentials["credentialsValue"]["userName"] - password = credentials["credentialsValue"]["password"] - client_id = credentials["credentialsValue"]["clientId"] - - client = TBDeviceMqttClient(THINGSBOARD_HOST, username=username, password=password, client_id=client_id) - client.connect() - # Other code - - client.stop() - - -if __name__ == '__main__': - main() diff --git a/examples/device/DEPRECATEDfirmware_update.py b/examples/device/DEPRECATEDfirmware_update.py deleted file mode 100644 index 4a84ce3..0000000 --- a/examples/device/DEPRECATEDfirmware_update.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file 
except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time -import logging -from tb_device_mqtt import TBDeviceMqttClient, FW_STATE_ATTR - -logging.basicConfig(level=logging.INFO) - - -def main(): - client = TBDeviceMqttClient("127.0.0.1", username="A2_TEST_TOKEN") - client.connect() - - client.get_firmware_update() - - # Waiting for firmware to be delivered - while not client.current_firmware_info[FW_STATE_ATTR] == 'UPDATED': - time.sleep(1) - - client.disconnect() - client.stop() - - -if __name__ == '__main__': - main() From 10f22ba6f0d546300955df3c4dfbc2a7c0116889 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Wed, 25 Jun 2025 09:12:01 +0300 Subject: [PATCH 28/74] Adjusted configuration for tests --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index cf309c5..b5c44ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ python_files = ["test_*.py"] python_classes = ["Test*", "*Tests"] python_functions = ["test_*"] asyncio_mode = "auto" -asyncio_fixture_scope = "function" +asyncio_default_fixture_loop_scope = "function" [tool.black] line-length = 100 From a157f0a6596c4682bea214785565acc4136f4191 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Wed, 25 Jun 2025 09:12:19 +0300 Subject: [PATCH 29/74] Added tests for MQTTManager --- tests/service/test_mqtt_manager.py | 264 ++++++++++++++++++++--------- 1 file changed, 180 insertions(+), 84 deletions(-) diff --git a/tests/service/test_mqtt_manager.py b/tests/service/test_mqtt_manager.py index 4deb85e..625ccac 100644 --- a/tests/service/test_mqtt_manager.py +++ b/tests/service/test_mqtt_manager.py @@ -14,137 +14,233 @@ import asyncio import pytest -from unittest.mock import MagicMock +import pytest_asyncio +from unittest.mock import AsyncMock, MagicMock, patch +from time import monotonic -from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit +from tb_mqtt_client.constants.service_keys import MESSAGES_RATE_LIMIT from tb_mqtt_client.service.mqtt_manager import MQTTManager -from gmqtt import Message +from tb_mqtt_client.entities.publish_result import PublishResult +from tb_mqtt_client.service.message_dispatcher import MessageDispatcher +from tb_mqtt_client.service.device.handlers.rpc_response_handler import RPCResponseHandler + + +@pytest_asyncio.fixture +async def setup_manager(): + stop_event = asyncio.Event() + message_dispatcher = MagicMock(spec=MessageDispatcher) + on_connect = AsyncMock() + on_disconnect = AsyncMock() + on_publish_result = AsyncMock() + rate_limits_handler = AsyncMock() + rpc_response_handler = MagicMock(spec=RPCResponseHandler) + + manager = MQTTManager( + client_id="test-client", + main_stop_event=stop_event, + message_dispatcher=message_dispatcher, + on_connect=on_connect, + on_disconnect=on_disconnect, + on_publish_result=on_publish_result, + rate_limits_handler=rate_limits_handler, + rpc_response_handler=rpc_response_handler + ) + return manager, stop_event, message_dispatcher, on_connect, on_disconnect, on_publish_result, rate_limits_handler, rpc_response_handler -@pytest.fixture -def mqtt_manager(): - manager = 
MQTTManager("test-client") - manager._client._connection = MagicMock() - manager._client._persistent_storage = MagicMock() - return manager +@pytest.mark.asyncio +async def test_connect_sets_connect_params(setup_manager): + manager, *_ = setup_manager + await manager.connect("localhost", 1883, "user", "pass", tls=False) + assert manager._connect_params[:4] == ("localhost", 1883, "user", "pass") + + +@pytest.mark.asyncio +async def test_is_connected_returns_false_if_not_ready(setup_manager): + manager, *_ = setup_manager + assert not manager.is_connected() + + +@pytest.mark.asyncio +async def test_register_and_unregister_handler(setup_manager): + manager, *_ = setup_manager + async def dummy(topic, payload): pass + manager.register_handler("topic/+", dummy) + assert "topic/+" in manager._handlers + manager.unregister_handler("topic/+") + assert "topic/+" not in manager._handlers @pytest.mark.asyncio -async def test_publish_acknowledgment(mqtt_manager): - fake_mid = 101 - message = Message("v1/devices/me/telemetry", b'{"temp":25}', qos=1) +async def test_on_disconnect_internal_abnormal_disconnect(setup_manager): + manager, *_ = setup_manager - mqtt_manager._client._connection.publish = MagicMock(return_value=(fake_mid, b"fake_package")) - mqtt_manager.set_rate_limits(message_rate_limit=RateLimit("0:0"), - telemetry_message_rate_limit=None, - telemetry_dp_rate_limit=None) + fut1 = asyncio.Future() + fut2 = asyncio.Future() + manager._pending_publishes[101] = (fut1, "topic1", 1, 100, monotonic()) + manager._pending_publishes[102] = (fut2, "topic2", 1, 100, monotonic()) - future = await mqtt_manager.publish(message) + manager._backpressure = MagicMock() + manager._backpressure.notify_disconnect = MagicMock() - assert fake_mid in mqtt_manager._pending_publishes - mqtt_manager._on_publish_internal(mqtt_manager._client, fake_mid) + manager._on_disconnect_internal(manager._client, reason_code=1) - result = await future - assert result is True - assert fake_mid not in mqtt_manager._pending_publishes + assert fut1.done() and fut2.done() + + manager._backpressure.notify_disconnect.assert_called() @pytest.mark.asyncio -async def test_subscribe_acknowledgment(mqtt_manager): - fake_mid = 202 - mqtt_manager._client._connection.subscribe = MagicMock(return_value=fake_mid) +async def test_handle_puback_reason_code_unknown_id(setup_manager): + manager, *_ = setup_manager + # Should not raise or fail if ID not tracked + manager._handle_puback_reason_code(999, 0, {}) + - future = await mqtt_manager.subscribe("v1/devices/me/rpc/request/+") +@pytest.mark.asyncio +async def test_on_message_internal_handler_exception(setup_manager): + manager, *_ = setup_manager - assert fake_mid in mqtt_manager._pending_subscriptions - mqtt_manager._on_subscribe_internal(mqtt_manager._client, fake_mid, 1, None) + async def bad_handler(topic, payload): + raise ValueError("oops") - result = await future - assert result is True - assert fake_mid not in mqtt_manager._pending_subscriptions + manager.register_handler("test/topic", bad_handler) + manager._on_message_internal(manager._client, "test/topic", b"{}", 0, {}) + await asyncio.sleep(0.05) # Let async task run +def test_match_topic_full_wildcard(): + assert MQTTManager._match_topic("#", "any/depth/of/topic") + @pytest.mark.asyncio -async def test_unsubscribe_acknowledgment(mqtt_manager): - fake_mid = 303 - mqtt_manager._client._connection.unsubscribe = MagicMock(return_value=fake_mid) +async def test_publish_fails_without_rate_limits(setup_manager): + manager, *_ = setup_manager 
+ manager._MQTTManager__rate_limits_retrieved = False + manager._MQTTManager__is_waiting_for_rate_limits_publish = False + with pytest.raises(RuntimeError, match="Cannot publish before rate limits are retrieved."): + await manager.publish("topic", b"payload") - future = await mqtt_manager.unsubscribe("v1/devices/me/rpc/request/+") - assert fake_mid in mqtt_manager._pending_unsubscriptions - mqtt_manager._on_unsubscribe_internal(mqtt_manager._client, fake_mid) +@pytest.mark.asyncio +async def test_publish_force_bypasses_limits(setup_manager): + manager, *_ = setup_manager + manager._MQTTManager__rate_limits_retrieved = True + manager._MQTTManager__is_waiting_for_rate_limits_publish = False + manager._rate_limits_ready_event.set() - result = await future - assert result is True - assert fake_mid not in mqtt_manager._pending_unsubscriptions + manager._client._connection = MagicMock() + manager._client._connection.publish.return_value = (10, b"packet") + manager._client._persistent_storage = MagicMock() + + result = await manager.publish("topic", b"payload", qos=1, force=True) + assert isinstance(result, asyncio.Future) @pytest.mark.asyncio -async def test_topic_handler_matching(mqtt_manager): +async def test_on_disconnect_internal_clears_futures(setup_manager): + manager, *_ = setup_manager + fut = asyncio.Future() + manager._pending_publishes[42] = (fut, "topic", 1, 100, monotonic()) + manager._on_disconnect_internal(manager._client, reason_code=0) + assert not manager._pending_publishes + assert fut.done() + assert isinstance(fut.result(), PublishResult) + + +@pytest.mark.asyncio +async def test_on_message_internal_triggers_handler(setup_manager): + manager, *_ = setup_manager called = asyncio.Event() - async def handler(topic, payload): - assert topic == "v1/devices/me/rpc/request/42" - assert payload == b"payload" + async def dummy_handler(topic, payload): called.set() - mqtt_manager.register_handler("v1/devices/me/rpc/request/+", handler) + manager.register_handler("foo/bar", dummy_handler) + manager._on_message_internal(manager._client, "foo/bar", b"123", 1, {}) + await asyncio.wait_for(called.wait(), timeout=1) - mqtt_manager._on_message_internal( - mqtt_manager._client, - topic="v1/devices/me/rpc/request/42", - payload=b"payload", - qos=1, - properties=None - ) - await asyncio.wait_for(called.wait(), timeout=1.0) +@pytest.mark.asyncio +async def test_handle_puback_reason_code(setup_manager): + manager, *_ = setup_manager + fut = asyncio.Future() + manager._pending_publishes[123] = (fut, "topic", 1, 100, monotonic()) + manager._handle_puback_reason_code(123, 0, {}) + assert fut.done() + assert fut.result().message_id == 123 @pytest.mark.asyncio -async def test_global_rate_limit_allows_publish(mqtt_manager): - rate_limiter = RateLimit("2:10") - mqtt_manager.set_rate_limits(rate_limiter, - telemetry_message_rate_limit=None, - telemetry_dp_rate_limit=None) +async def test_await_ready_timeout(setup_manager): + manager, stop_event, *_ = setup_manager + with patch("tb_mqtt_client.service.mqtt_manager.await_or_stop", side_effect=asyncio.TimeoutError): + await manager.await_ready(timeout=0.01) - mqtt_manager._client._connection.publish = MagicMock(return_value=(1002, b"pkg")) - future = await mqtt_manager.publish("topic", b'data') - assert isinstance(future, asyncio.Future) +@pytest.mark.asyncio +async def test_set_rate_limits_allows_ready(setup_manager): + manager, *_ = setup_manager + mock_limit = MagicMock() + manager.set_rate_limits(mock_limit, None, None) + assert 
manager._rate_limits_ready_event.is_set() @pytest.mark.asyncio -async def test_gateway_rate_limit_per_device_allows(mqtt_manager): - rl_a = RateLimit("2:60") - rl_b = RateLimit("5:60") +async def test_match_topic_logic(): + assert MQTTManager._match_topic("foo/+", "foo/bar") + assert not MQTTManager._match_topic("foo/bar", "foo/bar/baz") + assert MQTTManager._match_topic("foo/#", "foo/bar/baz") - mqtt_manager.set_rate_limits({"deviceA": rl_a, "deviceB": rl_b}, - telemetry_message_rate_limit=None, - telemetry_dp_rate_limit=None) - mqtt_manager._client._connection.publish = MagicMock(return_value=(1004, b"pkg")) - future = await mqtt_manager.publish("topic", b'data') - assert isinstance(future, asyncio.Future) +@pytest.mark.asyncio +async def test_check_pending_publishes_timeout(setup_manager): + manager, *_ = setup_manager + fut = asyncio.Future() + manager._pending_publishes[1] = (fut, "topic", 1, 100, monotonic() - 20) + await manager.check_pending_publishes(monotonic()) + assert fut.done() + assert fut.result().reason_code == 408 + + +@pytest.mark.asyncio +async def test_disconnect_swallows_reset_error(setup_manager): + manager, *_ = setup_manager + with patch.object(manager._client, "disconnect", side_effect=ConnectionResetError): + await manager.disconnect() @pytest.mark.asyncio -async def test_publish_before_rate_limit_not_retrieved(mqtt_manager): - mqtt_manager._client._connection.publish = MagicMock(return_value=(1005, b"pkg")) - mqtt_manager._client._persistent_storage = MagicMock() +async def test_subscribe_adds_future(setup_manager): + manager, *_ = setup_manager + manager._client._connection = MagicMock() + manager._client._connection.subscribe.return_value = 42 + + mock_rate_limit = AsyncMock() + setattr(manager, "_MQTTManager__rate_limiter", {MESSAGES_RATE_LIMIT: mock_rate_limit}) + + fut = await manager.subscribe("topic", qos=1) + + assert 42 in manager._pending_subscriptions + assert isinstance(fut, asyncio.Future) + mock_rate_limit.consume.assert_awaited_once() - # Should raise unless it's the special rate-limit request - with pytest.raises(RuntimeError, match="Cannot publish before rate limits"): - await mqtt_manager.publish("v1/devices/me/telemetry", b'data') @pytest.mark.asyncio -async def test_publish_during_rate_limit_request_allowed(mqtt_manager): - # Simulate internal state for initial rate limit request - mqtt_manager._client._connection.publish = MagicMock(return_value=(1006, b"pkg")) - mqtt_manager._client._persistent_storage = MagicMock() - mqtt_manager._MQTTManager__is_waiting_for_rate_limits_publish = True - mqtt_manager._rate_limits_ready_event.set() +async def test_unsubscribe_adds_future(setup_manager): + manager, *_ = setup_manager + manager._client._connection = MagicMock() + manager._client._connection.unsubscribe.return_value = 77 + + mock_rate_limit = AsyncMock() + setattr(manager, "_MQTTManager__rate_limiter", {MESSAGES_RATE_LIMIT: mock_rate_limit}) + + fut = await manager.unsubscribe("topic") + + assert 77 in manager._pending_unsubscriptions + assert isinstance(fut, asyncio.Future) + mock_rate_limit.consume.assert_awaited_once() + - future = await mqtt_manager.publish("v1/devices/me/rpc/request/1", b'data') - assert isinstance(future, asyncio.Future) \ No newline at end of file +if __name__ == '__main__': + pytest.main([__file__]) \ No newline at end of file From 4088ff8d1a6e9d72df6b87ae5f182cddf66ac483 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Wed, 25 Jun 2025 12:09:40 +0300 Subject: [PATCH 30/74] Added rate limits tests --- 
tests/common/__init__.py | 13 ++ tests/common/test_rate_limit.py | 124 ++++++++++++++++++ .../device/test_device_client_rate_limits.py | 76 ----------- 3 files changed, 137 insertions(+), 76 deletions(-) create mode 100644 tests/common/__init__.py create mode 100644 tests/common/test_rate_limit.py delete mode 100644 tests/service/device/test_device_client_rate_limits.py diff --git a/tests/common/__init__.py b/tests/common/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tests/common/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/common/test_rate_limit.py b/tests/common/test_rate_limit.py new file mode 100644 index 0000000..8088329 --- /dev/null +++ b/tests/common/test_rate_limit.py @@ -0,0 +1,124 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import pytest +from time import sleep + +from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit, GreedyTokenBucket + + +def test_greedy_token_bucket_can_consume_and_consume(): + bucket = GreedyTokenBucket(10, 1) + assert bucket.can_consume(5) is True + assert bucket.consume(5) is True + assert bucket.tokens <= 5 + + assert bucket.can_consume(6) is False + assert bucket.consume(6) is False + + sleep(1.1) + bucket.refill() + assert round(bucket.tokens) == 10 + + +def test_greedy_token_bucket_get_remaining_tokens_and_refill(): + bucket = GreedyTokenBucket(100, 2) + bucket.consume(50) + before = bucket.get_remaining_tokens() + sleep(0.5) + after = bucket.get_remaining_tokens() + assert after > before + + +@pytest.mark.asyncio +async def test_rate_limit_no_limit_behavior(): + rl = RateLimit("", "test") + assert await rl.check_limit_reached() is False + assert await rl.try_consume() is None + assert await rl.consume() is None + assert rl.minimal_limit == 0 + assert rl.minimal_timeout == 0 + assert rl.has_limit() is False + assert await rl.reach_limit() is None + assert rl.to_dict()["no_limit"] is True + + +@pytest.mark.asyncio +async def test_rate_limit_basic_limit_behavior(): + rl = RateLimit("10:1,20:2", "test") + assert rl.has_limit() is True + assert rl.minimal_limit == 10 * 0.8 + assert rl.minimal_timeout >= 2 + + for _ in range(8): + assert await rl.try_consume() is None + + warning = await rl.try_consume() + assert isinstance(warning, tuple) + assert warning[0] == 8.0 + assert warning[1] == 1.0 + + exceeded = await rl.try_consume() + assert isinstance(exceeded, tuple) + + await rl.consume() + + +@pytest.mark.asyncio +async def test_rate_limit_reach_limit_rotation(): + rl = RateLimit("1:1,1:2", "test") + await rl.reach_limit() + await rl.reach_limit() + result = await rl.reach_limit() + + assert isinstance(result, tuple) + index, timestamp, dur = result + assert index >= 0 + assert dur in (1, 2) + + +@pytest.mark.asyncio +async def test_rate_limit_set_limit_resets_state(): + rl = RateLimit("100:10", "test") + original = rl.to_dict() + await rl.set_limit("10:1", percentage=50) + new_state = rl.to_dict() + + assert new_state["rateLimits"] != original["rateLimits"] + assert rl.percentage == 50 + + +@pytest.mark.asyncio +async def test_rate_limit_invalid_format(caplog): + rl = RateLimit("invalid_format,10:xyz", "bad_config") + assert rl.has_limit() is False or isinstance(rl.to_dict(), dict) + assert any("Invalid rate limit format" in rec.message for rec in caplog.records) + + +@pytest.mark.asyncio +async def test_rate_limit_refill_behavior(): + rl = RateLimit("3:1", "refill-test") + await rl.try_consume() + await rl.try_consume() + await rl.try_consume() + assert (await rl.try_consume()) is not None + + await asyncio.sleep(1.1) + await rl.refill() + assert (await rl.try_consume()) is None + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/service/device/test_device_client_rate_limits.py b/tests/service/device/test_device_client_rate_limits.py deleted file mode 100644 index bee445f..0000000 --- a/tests/service/device/test_device_client_rate_limits.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -from orjson import dumps - -from tb_mqtt_client.service.device.client import DeviceClient -from tb_mqtt_client.common.config_loader import DeviceConfig -from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit - - -@pytest.fixture -def device_client(): - config = DeviceConfig() - config.host = "localhost" - config.access_token = "test_token" - config.client_id = "" - return DeviceClient(config) - - -@pytest.mark.asyncio -async def test_handle_rate_limit_response_valid(device_client): - payload = { - "rateLimits": { - "messages": "10:1,300:60", - "telemetryMessages": "20:1,600:60", - "telemetryDataPoints": "100:1,1000:60" - }, - "maxInflightMessages": 100, - "maxPayloadSize": 2048 - } - - topic = "v1/devices/me/rpc/response/1" - await device_client._handle_rate_limit_response(topic, dumps(payload)) - - assert isinstance(device_client._messages_rate_limit, RateLimit) - assert isinstance(device_client._telemetry_rate_limit, RateLimit) - assert isinstance(device_client._telemetry_dp_rate_limit, RateLimit) - - assert device_client._messages_rate_limit.has_limit() - assert device_client._telemetry_rate_limit.has_limit() - assert device_client._telemetry_dp_rate_limit.has_limit() - - assert device_client._max_inflight_messages > 0 - assert device_client._max_queued_messages == device_client._max_inflight_messages - assert device_client.max_payload_size == int(2048 * device_client._telemetry_rate_limit.percentage / 100) - - -@pytest.mark.asyncio -async def test_handle_rate_limit_response_invalid_payload(device_client, caplog): - topic = "v1/devices/me/rpc/response/1" - await device_client._handle_rate_limit_response(topic, b'invalid-json') - - assert not device_client._messages_rate_limit.has_limit() - assert "Failed to parse rate limits" in caplog.text - - -@pytest.mark.asyncio -async def test_handle_rate_limit_response_missing_rate_limits(device_client, caplog): - payload = {"maxInflightMessages": 100} - topic = "v1/devices/me/rpc/response/1" - await device_client._handle_rate_limit_response(topic, dumps(payload)) - - assert "Invalid rate limit response" in caplog.text - assert device_client._max_inflight_messages == 100 # default fallback still applies \ No newline at end of file From c098239af4978384bfd36a168938c819ce2e04a5 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Thu, 26 Jun 2025 08:31:08 +0300 Subject: [PATCH 31/74] Moved provisioning client to common --- tb_mqtt_client/common/provision_client.py | 14 -------------- .../{entities => common}/provisioning_client.py | 0 .../{entities => common}/publish_result.py | 0 .../entities/data/device_uplink_message.py | 2 +- tb_mqtt_client/service/base_client.py | 2 +- tb_mqtt_client/service/device/client.py | 4 ++-- tb_mqtt_client/service/message_dispatcher.py | 2 +- tb_mqtt_client/service/message_queue.py | 2 +- tb_mqtt_client/service/mqtt_manager.py | 2 +- tests/service/test_mqtt_manager.py | 2 +- 10 files changed, 8 insertions(+), 22 deletions(-) delete mode 100644 tb_mqtt_client/common/provision_client.py rename tb_mqtt_client/{entities => common}/provisioning_client.py (100%) rename tb_mqtt_client/{entities 
=> common}/publish_result.py (100%) diff --git a/tb_mqtt_client/common/provision_client.py b/tb_mqtt_client/common/provision_client.py deleted file mode 100644 index fa669aa..0000000 --- a/tb_mqtt_client/common/provision_client.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/tb_mqtt_client/entities/provisioning_client.py b/tb_mqtt_client/common/provisioning_client.py similarity index 100% rename from tb_mqtt_client/entities/provisioning_client.py rename to tb_mqtt_client/common/provisioning_client.py diff --git a/tb_mqtt_client/entities/publish_result.py b/tb_mqtt_client/common/publish_result.py similarity index 100% rename from tb_mqtt_client/entities/publish_result.py rename to tb_mqtt_client/common/publish_result.py diff --git a/tb_mqtt_client/entities/data/device_uplink_message.py b/tb_mqtt_client/entities/data/device_uplink_message.py index 2848fdb..ec9cfaf 100644 --- a/tb_mqtt_client/entities/data/device_uplink_message.py +++ b/tb_mqtt_client/entities/data/device_uplink_message.py @@ -20,7 +20,7 @@ from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry -from tb_mqtt_client.entities.publish_result import PublishResult +from tb_mqtt_client.common.publish_result import PublishResult logger = get_logger(__name__) diff --git a/tb_mqtt_client/service/base_client.py b/tb_mqtt_client/service/base_client.py index bd17588..7235ec0 100644 --- a/tb_mqtt_client/service/base_client.py +++ b/tb_mqtt_client/service/base_client.py @@ -24,7 +24,7 @@ from tb_mqtt_client.entities.data.claim_request import ClaimRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry -from tb_mqtt_client.entities.publish_result import PublishResult +from tb_mqtt_client.common.publish_result import PublishResult asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) exception_handler.install_asyncio_handler() diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index cd60789..45ba153 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -37,9 +37,9 @@ from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry -from tb_mqtt_client.entities.provisioning_client import ProvisioningClient +from tb_mqtt_client.common.provisioning_client import ProvisioningClient from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest -from tb_mqtt_client.entities.publish_result import PublishResult +from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.service.base_client import BaseClient from tb_mqtt_client.service.device.firmware_updater import 
FirmwareUpdater from tb_mqtt_client.service.device.handlers.attribute_updates_handler import AttributeUpdatesHandler diff --git a/tb_mqtt_client/service/message_dispatcher.py b/tb_mqtt_client/service/message_dispatcher.py index 09151d5..110ca99 100644 --- a/tb_mqtt_client/service/message_dispatcher.py +++ b/tb_mqtt_client/service/message_dispatcher.py @@ -31,7 +31,7 @@ from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse -from tb_mqtt_client.entities.publish_result import PublishResult +from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.service.message_splitter import MessageSplitter logger = get_logger(__name__) diff --git a/tb_mqtt_client/service/message_queue.py b/tb_mqtt_client/service/message_queue.py index e651074..b63feff 100644 --- a/tb_mqtt_client/service/message_queue.py +++ b/tb_mqtt_client/service/message_queue.py @@ -20,7 +20,7 @@ from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit from tb_mqtt_client.constants import mqtt_topics from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage -from tb_mqtt_client.entities.publish_result import PublishResult +from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.service.message_dispatcher import MessageDispatcher from tb_mqtt_client.service.mqtt_manager import MQTTManager diff --git a/tb_mqtt_client/service/mqtt_manager.py b/tb_mqtt_client/service/mqtt_manager.py index ab67f6c..deda326 100644 --- a/tb_mqtt_client/service/mqtt_manager.py +++ b/tb_mqtt_client/service/mqtt_manager.py @@ -33,7 +33,7 @@ TELEMETRY_DATAPOINTS_RATE_LIMIT from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse -from tb_mqtt_client.entities.publish_result import PublishResult +from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.service.device.handlers.rpc_response_handler import RPCResponseHandler from tb_mqtt_client.service.message_dispatcher import MessageDispatcher diff --git a/tests/service/test_mqtt_manager.py b/tests/service/test_mqtt_manager.py index 625ccac..a4066bb 100644 --- a/tests/service/test_mqtt_manager.py +++ b/tests/service/test_mqtt_manager.py @@ -20,7 +20,7 @@ from tb_mqtt_client.constants.service_keys import MESSAGES_RATE_LIMIT from tb_mqtt_client.service.mqtt_manager import MQTTManager -from tb_mqtt_client.entities.publish_result import PublishResult +from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.service.message_dispatcher import MessageDispatcher from tb_mqtt_client.service.device.handlers.rpc_response_handler import RPCResponseHandler From 6a53f0a715b8afc9726e615ef26677d84bf6902c Mon Sep 17 00:00:00 2001 From: imbeacon Date: Thu, 26 Jun 2025 10:37:12 +0300 Subject: [PATCH 32/74] Fix for message queue and added tests --- tb_mqtt_client/service/message_queue.py | 31 +- tests/service/device/test_firmware_updater.py | 153 +++++++ tests/service/test_message_queue.py | 382 ++++++++++++++++++ 3 files changed, 557 insertions(+), 9 deletions(-) create mode 100644 tests/service/device/test_firmware_updater.py create mode 100644 tests/service/test_message_queue.py diff --git a/tb_mqtt_client/service/message_queue.py b/tb_mqtt_client/service/message_queue.py index b63feff..c1b8e19 100644 --- a/tb_mqtt_client/service/message_queue.py +++ 
b/tb_mqtt_client/service/message_queue.py @@ -51,6 +51,7 @@ def __init__(self, self._pending_ack_futures: Dict[int, asyncio.Future[PublishResult]] = {} self._pending_ack_callbacks: Dict[int, Callable[[bool], None]] = {} self._queue: asyncio.Queue[Tuple[str, Union[bytes, DeviceUplinkMessage], List[asyncio.Future[PublishResult]], int, int]] = asyncio.Queue(maxsize=max_queue_size) + self._pending_queue_tasks: set[asyncio.Task] = set() self._active = asyncio.Event() self._wakeup_event = asyncio.Event() self._retry_tasks: set[asyncio.Task] = set() @@ -255,7 +256,7 @@ def resolve_attached(publish_future: asyncio.Future): logger.debug("Scheduling retry for topic=%s, payload size=%d, qos=%d", topic, len(payload), qos) logger.debug("error details: %s", e, exc_info=True) - self._schedule_delayed_retry(topic, payload, datapoints, qos, delay=.1) + self._schedule_delayed_retry(topic, payload, datapoints, qos, delay=.1, delivery_futures=delivery_futures_or_none) def _schedule_delayed_retry(self, topic: str, payload: bytes, datapoints: int, qos: int, delay: float, delivery_futures: Optional[List[Optional[asyncio.Future[PublishResult]]]] = None): @@ -294,6 +295,8 @@ async def _wait_for_message(self) -> Tuple[str, Union[bytes, DeviceUplinkMessage self._wakeup_event.clear() queue_task = asyncio.create_task(self._queue.get()) + self._pending_queue_tasks.add(queue_task) + wake_task = asyncio.create_task(self._wakeup_event.wait()) done, pending = await asyncio.wait( @@ -301,11 +304,12 @@ async def _wait_for_message(self) -> Tuple[str, Union[bytes, DeviceUplinkMessage ) for task in pending: - logger.trace("Cancelling pending task: %r, it is queue_task = %r", task, queue_task==task) task.cancel() with suppress(asyncio.CancelledError): await task + self._pending_queue_tasks.discard(queue_task) + if queue_task in done: logger.trace("Retrieved message from queue: %r", queue_task.result()) return queue_task.result() @@ -322,13 +326,8 @@ async def shutdown(self): self._active.clear() self._wakeup_event.set() # Wake up the _wait_for_message if it's blocked - for task in list(self._retry_tasks): - try: - task.cancel() - with suppress(asyncio.CancelledError): - await asyncio.gather(*self._retry_tasks, return_exceptions=True) - except Exception as e: - logger.warning("Error while cancelling retry task: %s", e) + await self._cancel_tasks(self._retry_tasks) + await self._cancel_tasks(self._pending_queue_tasks) self._loop_task.cancel() if self._rate_limit_refill_task: @@ -337,10 +336,24 @@ async def shutdown(self): await self._loop_task await self._rate_limit_refill_task + while not self._queue.empty(): + try: + self._queue.get_nowait() + self._queue.task_done() + except asyncio.QueueEmpty: + break + logger.debug("MessageQueue shutdown complete, message queue size: %d", self._queue.qsize()) self.clear() + async def _cancel_tasks(self, tasks: set[asyncio.Task]): + for task in list(tasks): + task.cancel() + with suppress(asyncio.CancelledError): + await asyncio.gather(*tasks, return_exceptions=True) + tasks.clear() + def is_empty(self): return self._queue.empty() diff --git a/tests/service/device/test_firmware_updater.py b/tests/service/device/test_firmware_updater.py new file mode 100644 index 0000000..d742ef1 --- /dev/null +++ b/tests/service/device/test_firmware_updater.py @@ -0,0 +1,153 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch, ANY +from tb_mqtt_client.service.device.firmware_updater import FirmwareUpdater +from tb_mqtt_client.constants.firmware import FW_TITLE_ATTR, FW_VERSION_ATTR, FW_STATE_ATTR, FirmwareStates + + +@pytest.fixture +def mock_client(): + client = MagicMock() + client._mqtt_manager.register_handler = MagicMock() + client._mqtt_manager.subscribe = AsyncMock(return_value=asyncio.Future()) + client._mqtt_manager.subscribe.return_value.set_result(True) + client._mqtt_manager.unsubscribe = AsyncMock() + client._mqtt_manager.is_connected.return_value = True + client._message_queue.publish = AsyncMock() + client.send_telemetry = AsyncMock() + client.send_attribute_request = AsyncMock() + return client + +@pytest.fixture +def updater(mock_client): + return FirmwareUpdater(mock_client) + +@pytest.mark.asyncio +async def test_update_success(updater, mock_client): + with patch("tb_mqtt_client.entities.data.attribute_request.AttributeRequest.build", new_callable=AsyncMock) as mock_build, \ + patch.object(updater, "_firmware_info_callback", new=AsyncMock()): + await updater.update() + mock_client._mqtt_manager.subscribe.assert_called_once() + mock_client.send_telemetry.assert_called() + mock_client.send_attribute_request.assert_called_once() + +@pytest.mark.asyncio +async def test_update_not_connected(updater, mock_client, caplog): + mock_client._mqtt_manager.is_connected.return_value = False + await updater.update() + assert "Client is not connected" in caplog.text + +@pytest.mark.asyncio +async def test_handle_firmware_update_full(updater): + updater._target_firmware_length = 4 + updater._chunk_size = 4 + updater._firmware_data = b'ab' + payload = b'cd' + with patch.object(updater, '_verify_downloaded_firmware', new=AsyncMock()) as verify: + await updater._handle_firmware_update(None, payload) + verify.assert_awaited_once() + assert updater._firmware_data == b'abcd' + +@pytest.mark.asyncio +async def test_handle_firmware_update_partial(updater): + updater._target_firmware_length = 10 + updater._chunk_size = 5 + updater._firmware_data = b'123' + payload = b'456' + with patch.object(updater, '_get_next_chunk', new=AsyncMock()) as next_chunk: + await updater._handle_firmware_update(None, payload) + next_chunk.assert_awaited_once() + assert updater._firmware_data.endswith(payload) + +@pytest.mark.asyncio +async def test_get_next_chunk_valid(updater, mock_client): + updater._chunk_size = 5 + updater._target_firmware_length = 10 + updater._firmware_request_id = 1 + updater._current_chunk = 2 + await updater._get_next_chunk() + mock_client._message_queue.publish.assert_awaited() + +@pytest.mark.asyncio +async def test_get_next_chunk_empty_payload(updater, mock_client): + updater._chunk_size = 15 + updater._target_firmware_length = 10 + await updater._get_next_chunk() + mock_client._message_queue.publish.assert_awaited_with( + topic=ANY, + payload=b'', + datapoints_count=0, + qos=1 + ) + +@pytest.mark.asyncio +async def test_verify_downloaded_firmware_success(updater): + updater._firmware_data = b'data' + 
updater._target_checksum = "8d777f385d3dfec8815d20f7496026dc" + updater._target_checksum_alg = "md5" + updater.current_firmware_info[FW_STATE_ATTR] = FirmwareStates.DOWNLOADING.value + with patch.object(updater, 'verify_checksum', return_value=True), \ + patch.object(updater, '_apply_downloaded_firmware', new=AsyncMock()), \ + patch.object(updater, '_send_current_firmware_info', new=AsyncMock()): + await updater._verify_downloaded_firmware() + assert updater.current_firmware_info[FW_STATE_ATTR] == FirmwareStates.VERIFIED.value + +@pytest.mark.asyncio +async def test_verify_downloaded_firmware_fail(updater): + updater._firmware_data = b'data' + updater._target_checksum = "wrong" + updater._target_checksum_alg = "md5" + with patch.object(updater, 'verify_checksum', return_value=False), \ + patch.object(updater, '_send_current_firmware_info', new=AsyncMock()): + await updater._verify_downloaded_firmware() + assert updater.current_firmware_info[FW_STATE_ATTR] == FirmwareStates.FAILED.value + +@pytest.mark.asyncio +async def test_apply_downloaded_firmware_saves_file(tmp_path, updater): + updater._firmware_data = b'binary-firmware' + updater._target_title = 'fw.bin' + updater._target_version = 'v3' + updater._save_path = str(tmp_path) + updater._save_firmware = True + updater._on_received_callback = AsyncMock() + with patch.object(updater, '_send_current_firmware_info', new=AsyncMock()), \ + patch.object(updater._client._mqtt_manager, 'unsubscribe', new=AsyncMock()): + await updater._apply_downloaded_firmware() + assert (tmp_path / 'fw.bin').exists() + +def test_verify_checksum_md5_valid(updater): + result = updater.verify_checksum(b'data', 'md5', "8d777f385d3dfec8815d20f7496026dc") + assert isinstance(result, bool) + +def test_verify_checksum_invalid_algorithm(updater, caplog): + result = updater.verify_checksum(b'data', 'invalid_alg', "deadbeef") + assert result is False + assert 'Unsupported checksum algorithm' in caplog.text + +def test_is_different_versions_true(updater): + new_info = {FW_TITLE_ATTR: 'fw', FW_VERSION_ATTR: 'v2'} + assert updater._is_different_firmware_versions(new_info) is True + +def test_is_different_versions_false(updater): + updater.current_firmware_info['current_' + FW_TITLE_ATTR] = 'fw' + updater.current_firmware_info['current_' + FW_VERSION_ATTR] = 'v2' + new_info = {FW_TITLE_ATTR: 'fw', FW_VERSION_ATTR: 'v2'} + assert updater._is_different_firmware_versions(new_info) is False + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/service/test_message_queue.py b/tests/service/test_message_queue.py new file mode 100644 index 0000000..63fffd0 --- /dev/null +++ b/tests/service/test_message_queue.py @@ -0,0 +1,382 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import pytest +from unittest.mock import AsyncMock, MagicMock + +from tb_mqtt_client.constants import mqtt_topics +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.service.message_queue import MessageQueue + + +@pytest.mark.asyncio +async def test_publish_raw_bytes_success(): + mqtt_manager = MagicMock() + mqtt_manager.publish = AsyncMock() + mqtt_manager.backpressure.should_pause.return_value = False + main_stop_event = asyncio.Event() + + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + dispatcher.build_uplink_payloads.return_value = [] + + queue = MessageQueue( + mqtt_manager=mqtt_manager, + main_stop_event=main_stop_event, + message_rate_limit=None, + telemetry_rate_limit=None, + telemetry_dp_rate_limit=None, + message_dispatcher=dispatcher, + ) + + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, b'test_payload', 1, qos=1) + await asyncio.sleep(0.05) + await queue.shutdown() + + mqtt_manager.publish.assert_called_once() + + +@pytest.mark.asyncio +async def test_batching_device_uplink_message(): + mqtt_manager = MagicMock() + future = asyncio.Future() + future.set_result(PublishResult(mqtt_topics.DEVICE_TELEMETRY_TOPIC, 1, 0, 7, 0)) + mqtt_manager.publish = AsyncMock(return_value=future) + mqtt_manager.backpressure.should_pause.return_value = False + main_stop_event = asyncio.Event() + + delivery_future = asyncio.Future() + dummy_message = MagicMock() + dummy_message.size = 10 + dummy_message.device_name = "device" + dummy_message.get_delivery_futures.return_value = [delivery_future] + + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100 + dispatcher.build_uplink_payloads.return_value = [ + (mqtt_topics.DEVICE_TELEMETRY_TOPIC, b'batch_payload', 1, [delivery_future]) + ] + + queue = MessageQueue( + mqtt_manager=mqtt_manager, + main_stop_event=main_stop_event, + message_rate_limit=None, + telemetry_rate_limit=None, + telemetry_dp_rate_limit=None, + message_dispatcher=dispatcher, + batch_collect_max_time_ms=50, + batch_collect_max_count=10 + ) + + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, dummy_message, 1, qos=1) + await asyncio.sleep(0.1) + await queue.shutdown() + + assert delivery_future.done() + assert isinstance(delivery_future.result(), PublishResult) + + +@pytest.mark.asyncio +async def test_telemetry_rate_limit_retry_triggered(): + telemetry_limit = MagicMock() + telemetry_limit.try_consume = AsyncMock(return_value=(10, 1)) + telemetry_limit.minimal_timeout = 0.01 + telemetry_limit.to_dict.return_value = {"limit": "mock"} + + mqtt_manager = MagicMock() + mqtt_manager.publish = AsyncMock() + mqtt_manager.backpressure.should_pause.return_value = False + + main_stop_event = asyncio.Event() + + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + + delivery_future = asyncio.Future() + delivery_future.set_result(PublishResult(mqtt_topics.DEVICE_TELEMETRY_TOPIC, 1, 0, 5, 0)) + dispatcher.build_uplink_payloads.return_value = [ + (mqtt_topics.DEVICE_TELEMETRY_TOPIC, b'dummy_payload', 1, [delivery_future]) + ] + + msg = DeviceUplinkMessageBuilder() + msg.set_device_name("device") + msg.add_telemetry(TimeseriesEntry("temp", 1)) + msg = msg.build() + + queue = MessageQueue( + mqtt_manager=mqtt_manager, + main_stop_event=main_stop_event, + message_rate_limit=None, + 
telemetry_rate_limit=telemetry_limit, + telemetry_dp_rate_limit=None, + message_dispatcher=dispatcher + ) + + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, msg, 1, qos=1) + await asyncio.sleep(0.1) + await queue.shutdown() + + telemetry_limit.try_consume.assert_awaited() + + +@pytest.mark.asyncio +async def test_shutdown_clears_queue(): + mqtt_manager = MagicMock() + mqtt_manager.publish = AsyncMock() + mqtt_manager.backpressure.should_pause.return_value = False + main_stop_event = asyncio.Event() + + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + dispatcher.build_uplink_payloads.return_value = [] + + dummy_message = MagicMock() + dummy_message.device_name = "device" + dummy_message.size = 1 + dummy_message.get_delivery_futures.return_value = [] + + queue = MessageQueue( + mqtt_manager=mqtt_manager, + main_stop_event=main_stop_event, + message_rate_limit=None, + telemetry_rate_limit=None, + telemetry_dp_rate_limit=None, + message_dispatcher=dispatcher + ) + + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, dummy_message, 1, qos=1) + await queue.shutdown() + + assert queue.is_empty() + + +@pytest.mark.asyncio +async def test_publish_raw_bytes_success(): + mqtt_manager = MagicMock() + mqtt_manager.publish = AsyncMock() + mqtt_manager.backpressure.should_pause.return_value = False + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + dispatcher.build_uplink_payloads.return_value = [] + main_stop_event = asyncio.Event() + + queue = MessageQueue(mqtt_manager, main_stop_event, None, None, None, dispatcher) + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, b"payload", 1, qos=1) + await asyncio.sleep(0.05) + await queue.shutdown() + mqtt_manager.publish.assert_called() + + +@pytest.mark.asyncio +async def test_publish_device_uplink_message_batched(): + mqtt_manager = MagicMock() + mqtt_manager.publish = AsyncMock() + mqtt_manager.backpressure.should_pause.return_value = False + main_stop_event = asyncio.Event() + + future = asyncio.Future() + dummy_msg = MagicMock() + dummy_msg.size = 10 + dummy_msg.device_name = "dev" + dummy_msg.get_delivery_futures.return_value = [future] + + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100 + dispatcher.build_uplink_payloads.return_value = [ + (mqtt_topics.DEVICE_TELEMETRY_TOPIC, b"batch", 1, [future]) + ] + + queue = MessageQueue(mqtt_manager, main_stop_event, None, None, None, dispatcher) + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, dummy_msg, 1, qos=1) + await asyncio.sleep(0.1) + await queue.shutdown() + assert future.done() + + +@pytest.mark.asyncio +async def test_rate_limit_telemetry_triggers_retry(): + limit = MagicMock() + limit.try_consume = AsyncMock(return_value=(1, 1)) + limit.minimal_timeout = 0.01 + limit.to_dict.return_value = {} + + mqtt_manager = MagicMock() + mqtt_manager.publish = AsyncMock() + mqtt_manager.backpressure.should_pause.return_value = False + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + dispatcher.build_uplink_payloads.return_value = [] + main_stop_event = asyncio.Event() + + msg = MagicMock() + msg.device_name = "d" + msg.size = 1 + msg.get_delivery_futures.return_value = [] + + queue = MessageQueue(mqtt_manager, main_stop_event, None, limit, None, dispatcher) + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, msg, 1, 1) + await asyncio.sleep(0.2) + await queue.shutdown() + mqtt_manager.publish.assert_not_called() + + +@pytest.mark.asyncio +async def test_shutdown_clears_queue(): + 
mqtt_manager = MagicMock() + mqtt_manager.backpressure.should_pause.return_value = False + mqtt_manager.publish = AsyncMock() + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + dispatcher.build_uplink_payloads.return_value = [] + stop_event = asyncio.Event() + + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + dummy = MagicMock() + dummy.size = 1 + dummy.get_delivery_futures.return_value = [] + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, dummy, 1, 1) + await queue.shutdown() + assert queue.is_empty() + + +@pytest.mark.asyncio +async def test_backpressure_triggers_retry(): + mqtt_manager = MagicMock() + mqtt_manager.backpressure.should_pause.return_value = True + mqtt_manager.publish = AsyncMock() + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + dispatcher.build_uplink_payloads.return_value = [] + stop_event = asyncio.Event() + + msg = MagicMock() + msg.size = 1 + msg.device_name = "dev" + msg.get_delivery_futures.return_value = [] + + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, msg, 1, 1) + await asyncio.sleep(0.1) + await queue.shutdown() + mqtt_manager.publish.assert_not_called() + + +@pytest.mark.asyncio +async def test_retry_on_exception(): + mqtt_manager = MagicMock() + publish_mock = AsyncMock() + first_attempt = RuntimeError("fail") + + second_attempt_future = asyncio.Future() + + async def complete_publish_result_later(): + await asyncio.sleep(0.1) + publish_result = PublishResult(mqtt_topics.DEVICE_TELEMETRY_TOPIC, 1, 0, 1, 0) + second_attempt_future.set_result(publish_result) + + asyncio.create_task(complete_publish_result_later()) + + publish_mock.side_effect = [first_attempt, second_attempt_future] + + mqtt_manager.publish = publish_mock + mqtt_manager.backpressure.should_pause.return_value = False + + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + + future = asyncio.Future() + dummy_msg = MagicMock() + dummy_msg.device_name = "dev" + dummy_msg.size = 10 + dummy_msg.get_delivery_futures.return_value = [future] + + dispatcher.build_uplink_payloads.return_value = [ + (mqtt_topics.DEVICE_TELEMETRY_TOPIC, b"payload", 1, [future]) + ] + + stop_event = asyncio.Event() + queue = MessageQueue( + mqtt_manager, + stop_event, + None, None, None, + dispatcher, + batch_collect_max_time_ms=10 + ) + + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, dummy_msg, 1, qos=1) + + await asyncio.sleep(0.5) + await queue.shutdown() + + assert mqtt_manager.publish.call_count == 2 + assert future.done() + assert isinstance(future.result(), PublishResult) + + +@pytest.mark.asyncio +async def test_mixed_raw_and_structured_queue(): + mqtt_manager = MagicMock() + mqtt_manager.publish = AsyncMock() + mqtt_manager.backpressure.should_pause.return_value = False + stop_event = asyncio.Event() + + future = asyncio.Future() + uplink_msg = MagicMock() + uplink_msg.device_name = "x" + uplink_msg.size = 10 + uplink_msg.get_delivery_futures.return_value = [future] + + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100 + dispatcher.build_uplink_payloads.return_value = [ + (mqtt_topics.DEVICE_TELEMETRY_TOPIC, b"batched", 1, [future]) + ] + + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher, batch_collect_max_time_ms=20) + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, b"raw", 1, 1) + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, 
uplink_msg, 1, 1) + await asyncio.sleep(0.1) + await queue.shutdown() + assert future.done() + + +@pytest.mark.asyncio +async def test_rate_limit_refill_executes(): + r1, r2, r3 = MagicMock(), MagicMock(), MagicMock() + for r in (r1, r2, r3): + r.refill = AsyncMock() + r.to_dict.return_value = {} + + mqtt_manager = MagicMock() + mqtt_manager.publish = AsyncMock() + mqtt_manager.backpressure.should_pause.return_value = False + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + dispatcher.build_uplink_payloads.return_value = [] + stop_event = asyncio.Event() + + queue = MessageQueue(mqtt_manager, stop_event, r1, r2, r3, dispatcher) + await asyncio.sleep(1.2) + await queue.shutdown() + + r1.refill.assert_awaited() + r2.refill.assert_awaited() + r3.refill.assert_awaited() + + +if __name__ == '__main__': + pytest.main([__file__]) From f78fa2eacc8e408cd171635b0f2669f9b626a44c Mon Sep 17 00:00:00 2001 From: imbeacon Date: Fri, 27 Jun 2025 14:47:10 +0300 Subject: [PATCH 33/74] Added tests --- tests/common/test_backpressure_controller.py | 111 +++ tests/common/test_install_package_utils.py | 90 ++ tests/common/test_provisioning_client.py | 108 ++ tests/common/test_publish_result.py | 85 ++ tests/common/test_rate_limit.py | 7 +- tests/entities/__init__.py | 13 + tests/entities/data/__init__.py | 13 + tests/entities/data/test_attribute_entry.py | 63 ++ tests/entities/data/test_attribute_request.py | 85 ++ tests/entities/data/test_attribute_update.py | 84 ++ tests/entities/data/test_data_entry.py | 95 ++ .../data/test_device_uplink_message.py | 156 +++ tests/entities/data/test_provisioning_data.py | 169 ++++ .../data/test_requested_attribute_response.py | 127 +++ tests/entities/data/test_rpc_request.py | 91 ++ tests/entities/data/test_rpc_response.py | 87 ++ tests/entities/data/test_timeseries_entry.py | 74 ++ tests/service/device/test_device_client.py | 515 ++++++++++ tests/service/device/test_firmware_updater.py | 145 ++- tests/service/test_json_message_dispatcher.py | 170 +++- tests/service/test_message_queue.py | 927 ++++++++++++++++-- tests/service/test_message_splitter.py | 57 +- tests/service/test_mqtt_manager.py | 162 ++- 23 files changed, 3285 insertions(+), 149 deletions(-) create mode 100644 tests/common/test_backpressure_controller.py create mode 100644 tests/common/test_install_package_utils.py create mode 100644 tests/common/test_provisioning_client.py create mode 100644 tests/common/test_publish_result.py create mode 100644 tests/entities/__init__.py create mode 100644 tests/entities/data/__init__.py create mode 100644 tests/entities/data/test_attribute_entry.py create mode 100644 tests/entities/data/test_attribute_request.py create mode 100644 tests/entities/data/test_attribute_update.py create mode 100644 tests/entities/data/test_data_entry.py create mode 100644 tests/entities/data/test_device_uplink_message.py create mode 100644 tests/entities/data/test_provisioning_data.py create mode 100644 tests/entities/data/test_requested_attribute_response.py create mode 100644 tests/entities/data/test_rpc_request.py create mode 100644 tests/entities/data/test_rpc_response.py create mode 100644 tests/entities/data/test_timeseries_entry.py create mode 100644 tests/service/device/test_device_client.py diff --git a/tests/common/test_backpressure_controller.py b/tests/common/test_backpressure_controller.py new file mode 100644 index 0000000..15c3d43 --- /dev/null +++ b/tests/common/test_backpressure_controller.py @@ -0,0 +1,111 @@ +# Copyright 2025 ThingsBoard +# +# Licensed 
under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import pytest +from datetime import datetime, timedelta, UTC + +from tb_mqtt_client.common.rate_limit.backpressure_controller import BackpressureController + +@pytest.fixture +def stop_event(): + return asyncio.Event() + + +@pytest.fixture +def controller(stop_event): + return BackpressureController(stop_event) + + +def test_notify_quota_exceeded_initial(controller): + controller.notify_quota_exceeded() + assert controller._pause_until is not None + assert controller._consecutive_quota_exceeded == 1 + + +def test_notify_quota_exceeded_consecutive(controller): + controller._last_quota_exceeded = datetime.now(UTC) + controller._consecutive_quota_exceeded = 2 + controller.notify_quota_exceeded() + assert controller._consecutive_quota_exceeded == 3 + assert controller._pause_until is not None + + +def test_notify_quota_exceeded_reset_after_60s(controller): + controller._last_quota_exceeded = datetime.now(UTC) - timedelta(seconds=61) + controller._consecutive_quota_exceeded = 5 + controller.notify_quota_exceeded() + assert controller._consecutive_quota_exceeded == 1 + + +def test_notify_quota_exceeded_custom_delay(controller): + controller.notify_quota_exceeded(delay_seconds=30) + assert controller._pause_until is not None + + +def test_notify_quota_exceeded_stop_event_set(stop_event, controller): + stop_event.set() + controller.notify_quota_exceeded() + assert controller._pause_until is None + + +def test_notify_disconnect_default(controller): + controller.notify_disconnect() + assert controller._pause_until is not None + + +def test_notify_disconnect_custom_delay(controller): + controller.notify_disconnect(delay_seconds=25) + assert controller._pause_until is not None + + +def test_notify_disconnect_stop_event_set(stop_event, controller): + stop_event.set() + controller.notify_disconnect() + assert controller._pause_until is None + + +def test_should_pause_active(controller): + controller._pause_until = datetime.now(UTC) + timedelta(seconds=15) + assert controller.should_pause() + + +def test_should_pause_expired(controller): + controller._pause_until = datetime.now(UTC) - timedelta(seconds=1) + assert controller.should_pause() is False + assert controller._pause_until is None + + +def test_should_pause_not_set(controller): + controller._pause_until = None + assert controller.should_pause() is False + + +def test_should_pause_stop_event_set(stop_event, controller): + stop_event.set() + controller._pause_until = datetime.now(UTC) + timedelta(seconds=30) + assert controller.should_pause() is False + + +def test_pause_for(controller): + controller.pause_for(12) + assert controller._pause_until is not None + + +def test_clear(controller): + controller._pause_until = datetime.now(UTC) + timedelta(seconds=30) + controller._consecutive_quota_exceeded = 5 + controller.clear() + assert controller._pause_until is None + assert controller._consecutive_quota_exceeded == 0 diff --git a/tests/common/test_install_package_utils.py 
b/tests/common/test_install_package_utils.py new file mode 100644 index 0000000..0e2a0e8 --- /dev/null +++ b/tests/common/test_install_package_utils.py @@ -0,0 +1,90 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from subprocess import CalledProcessError +from unittest import mock + +import pytest +from pkg_resources import DistributionNotFound + +from tb_mqtt_client.common.install_package_utils import install_package + + +@pytest.mark.parametrize("version", ["upgrade", "UPGRADE"]) +def test_install_package_upgrade_success(version): + with mock.patch("tb_mqtt_client.common.install_package_utils.check_call", return_value=0) as mock_call: + result = install_package("somepkg", version) + assert result is True + mock_call.assert_called_once() + + +def test_install_package_upgrade_retry_success(): + # First fails with --user, then succeeds without + with mock.patch("tb_mqtt_client.common.install_package_utils.check_call", + side_effect=[CalledProcessError(1, ""), 0]) as mock_call: + result = install_package("somepkg", "upgrade") + assert result is True + assert mock_call.call_count == 2 + + +def test_install_package_upgrade_all_fail(): + with mock.patch("tb_mqtt_client.common.install_package_utils.check_call", + side_effect=CalledProcessError(1, "")) as mock_call: + result = install_package("somepkg", "upgrade") + assert result is False + assert mock_call.call_count == 2 + + +def test_install_package_specific_version_already_installed(): + dist = mock.Mock() + dist.version = "1.2.3" + with mock.patch("tb_mqtt_client.common.install_package_utils.get_distribution", return_value=dist): + result = install_package("somepkg", "1.2.3") + assert result is True + + +@pytest.mark.parametrize("version_input,expected_arg", [ + ("1.2.4", "somepkg==1.2.4"), + (">=1.2.4", "somepkg>=1.2.4"), +]) +def test_install_package_specific_version_install_success(version_input, expected_arg): + with mock.patch("tb_mqtt_client.common.install_package_utils.get_distribution", side_effect=DistributionNotFound): + side_effects = [ + CalledProcessError(1, ["pip"]), + 0 + ] + with mock.patch("tb_mqtt_client.common.install_package_utils.check_call", + side_effect=side_effects) as mock_call: + result = install_package("somepkg", version_input) + assert result is True + calls = [str(call.args) for call in mock_call.call_args_list] + assert any(expected_arg in args for args in calls) + + +def test_install_package_specific_version_retry_success(): + with mock.patch("tb_mqtt_client.common.install_package_utils.get_distribution", side_effect=DistributionNotFound): + with mock.patch("tb_mqtt_client.common.install_package_utils.check_call", + side_effect=[CalledProcessError(1, ""), 0]) as mock_call: + result = install_package("somepkg", "1.0.0") + assert result is True + assert mock_call.call_count == 2 + + +def test_install_package_specific_version_all_fail(): + with mock.patch("tb_mqtt_client.common.install_package_utils.get_distribution", side_effect=DistributionNotFound): + with 
mock.patch("tb_mqtt_client.common.install_package_utils.check_call", + side_effect=CalledProcessError(1, "")) as mock_call: + result = install_package("somepkg", "1.0.0") + assert result is False + assert mock_call.call_count == 2 diff --git a/tests/common/test_provisioning_client.py b/tests/common/test_provisioning_client.py new file mode 100644 index 0000000..9b95431 --- /dev/null +++ b/tests/common/test_provisioning_client.py @@ -0,0 +1,108 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from tb_mqtt_client.common.provisioning_client import ProvisioningClient +from tb_mqtt_client.constants.mqtt_topics import PROVISION_RESPONSE_TOPIC +from tb_mqtt_client.constants.provisioning import ProvisioningResponseStatus +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest, AccessTokenProvisioningCredentials + + +@pytest.fixture +def real_request(): + return ProvisioningRequest( + host="test_host", + credentials=AccessTokenProvisioningCredentials("key", "secret"), + port=1883, + device_name="test_device" + ) + + +@pytest.mark.asyncio +@patch("tb_mqtt_client.common.provisioning_client.GMQTTClient") +@patch("tb_mqtt_client.common.provisioning_client.JsonMessageDispatcher") +async def test_successful_provisioning_flow(mock_dispatcher_cls, mock_gmqtt_cls, real_request): + mock_client = AsyncMock() + mock_gmqtt_cls.return_value = mock_client + + mock_dispatcher = MagicMock() + topic = "provision/topic" + payload = b'{"provision": "data"}' + mock_dispatcher.build_provision_request.return_value = (topic, payload) + + mock_device_config = MagicMock() + mock_dispatcher.parse_provisioning_response.return_value.result = mock_device_config + mock_dispatcher_cls.return_value = mock_dispatcher + + client = ProvisioningClient("test_host", 1883, real_request) + + client._on_connect(mock_client, None, 0, None) + + mock_client.subscribe.assert_called_once_with(PROVISION_RESPONSE_TOPIC) + mock_client.publish.assert_called_once_with(topic, payload) + + await client._on_message(None, None, b"payload-data", None, None) + + assert client._device_config == mock_device_config + assert client._provisioned.is_set() + mock_client.disconnect.assert_awaited_once() + + +@pytest.mark.asyncio +@patch("tb_mqtt_client.common.provisioning_client.GMQTTClient") +@patch("tb_mqtt_client.common.provisioning_client.JsonMessageDispatcher") +async def test_failed_connection(mock_dispatcher_cls, mock_gmqtt_cls, real_request, caplog): + mock_client = AsyncMock() + mock_gmqtt_cls.return_value = mock_client + + client = ProvisioningClient("localhost", 1883, real_request) + + with caplog.at_level("ERROR"): + client._on_connect(mock_client, None, 1, None) + + assert client._device_config is not None + assert client._device_config.status == ProvisioningResponseStatus.ERROR + assert client._provisioned.is_set() + assert "Cannot connect to ThingsBoard!" 
in caplog.text + + +@pytest.mark.asyncio +@patch("tb_mqtt_client.common.provisioning_client.GMQTTClient") +@patch("tb_mqtt_client.common.provisioning_client.JsonMessageDispatcher") +async def test_provision_method_awaits_provisioned(mock_dispatcher_cls, mock_gmqtt_cls, real_request): + mock_client = AsyncMock() + mock_gmqtt_cls.return_value = mock_client + + client = ProvisioningClient("localhost", 1883, real_request) + expected_config = MagicMock() + client._device_config = expected_config + client._provisioned.set() + + result = await client.provision() + + mock_client.connect.assert_awaited_once_with("localhost", 1883) + assert result == expected_config + + +def test_initial_state(real_request): + client = ProvisioningClient("host", 1234, real_request) + + assert client._host == "host" + assert client._port == 1234 + assert client._provision_request == real_request + assert client._client_id == "provision" + assert not client._provisioned.is_set() + assert client._device_config is None diff --git a/tests/common/test_publish_result.py b/tests/common/test_publish_result.py new file mode 100644 index 0000000..54a7ce6 --- /dev/null +++ b/tests/common/test_publish_result.py @@ -0,0 +1,85 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +from tb_mqtt_client.common.publish_result import PublishResult + + +@pytest.fixture +def default_publish_result(): + return PublishResult( + topic="v1/devices/me/telemetry", + qos=1, + message_id=123, + payload_size=256, + reason_code=0 + ) + + +def test_publish_result_attributes(default_publish_result): + assert default_publish_result.topic == "v1/devices/me/telemetry" + assert default_publish_result.qos == 1 + assert default_publish_result.message_id == 123 + assert default_publish_result.payload_size == 256 + assert default_publish_result.reason_code == 0 + + +def test_publish_result_repr(default_publish_result): + result = repr(default_publish_result) + assert isinstance(result, str) + assert "PublishResult" in result + assert "v1/devices/me/telemetry" in result + assert "qos=1" in result + assert "message_id=123" in result + assert "payload_size=256" in result + assert "reason_code=0" in result + + +def test_publish_result_as_dict(default_publish_result): + d = default_publish_result.as_dict() + assert isinstance(d, dict) + assert d == { + "topic": "v1/devices/me/telemetry", + "qos": 1, + "message_id": 123, + "payload_size": 256, + "reason_code": 0 + } + + +def test_publish_result_is_successful_true(default_publish_result): + assert default_publish_result.is_successful() is True + + +def test_publish_result_is_successful_false(): + result = PublishResult( + topic="v1/devices/me/attributes", + qos=0, + message_id=999, + payload_size=0, + reason_code=128 # Simulated failure + ) + assert result.is_successful() is False + + +@pytest.mark.parametrize("reason_code", [1, 2, 3, 16, 255]) +def test_publish_result_various_failure_codes(reason_code): + result = PublishResult( + topic="v1/devices/me/rpc", + qos=2, + message_id=42, + payload_size=100, + reason_code=reason_code + ) + assert result.is_successful() is False diff --git a/tests/common/test_rate_limit.py b/tests/common/test_rate_limit.py index 8088329..c487097 100644 --- a/tests/common/test_rate_limit.py +++ b/tests/common/test_rate_limit.py @@ -13,9 +13,10 @@ # limitations under the License. import asyncio -import pytest from time import sleep +import pytest + from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit, GreedyTokenBucket @@ -118,7 +119,3 @@ async def test_rate_limit_refill_behavior(): await asyncio.sleep(1.1) await rl.refill() assert (await rl.try_consume()) is None - - -if __name__ == '__main__': - pytest.main([__file__]) diff --git a/tests/entities/__init__.py b/tests/entities/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tests/entities/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
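(For reference, the assertions in tests/common/test_publish_result.py above imply roughly the following shape for PublishResult. This is a minimal sketch inferred only from the test expectations — the constructor field names, as_dict(), is_successful(), and the repr contents — and is not the actual tb_mqtt_client.common.publish_result implementation; in particular, the "reason_code == 0 means success" rule is an assumption taken from the parametrized failure codes in the tests.)

# Editorial sketch, not part of the patch: a PublishResult shape that would satisfy
# the assertions in tests/common/test_publish_result.py. Field names and the
# success rule are inferred from the tests only.
from dataclasses import dataclass, asdict

@dataclass
class PublishResult:
    topic: str
    qos: int
    message_id: int
    payload_size: int
    reason_code: int

    def as_dict(self) -> dict:
        # The tests expect a plain dict with exactly these five keys.
        return asdict(self)

    def is_successful(self) -> bool:
        # Assumed rule: reason code 0 is success; the tests treat 1, 2, 3, 16,
        # 128 and 255 as failed publishes.
        return self.reason_code == 0

    def __repr__(self) -> str:
        # The tests only check that the class name and each field appear in repr().
        return (f"PublishResult(topic={self.topic}, qos={self.qos}, "
                f"message_id={self.message_id}, payload_size={self.payload_size}, "
                f"reason_code={self.reason_code})")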
diff --git a/tests/entities/data/__init__.py b/tests/entities/data/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tests/entities/data/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/entities/data/test_attribute_entry.py b/tests/entities/data/test_attribute_entry.py new file mode 100644 index 0000000..a003fd6 --- /dev/null +++ b/tests/entities/data/test_attribute_entry.py @@ -0,0 +1,63 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry + + +@pytest.mark.parametrize("key, value", [ + ("temperature", 25), + ("status", True), + ("label", "sensor-A"), + ("list_val", [1, 2, 3]), + ("dict_val", {"k": "v"}) +]) +def test_attribute_entry_as_dict(key, value): + entry = AttributeEntry(key, value) + expected = {"key": key, "value": value} + assert entry.as_dict() == expected + + +@pytest.mark.parametrize("key, value", [ + ("a", 1), + ("b", "val"), + ("c", {"k": "v"}) +]) +def test_attribute_entry_repr(key, value): + entry = AttributeEntry(key, value) + assert repr(entry) == f"AttributeEntry(key={key}, value={value})" + + +def test_attribute_entry_eq_same(): + e1 = AttributeEntry("k1", 123) + e2 = AttributeEntry("k1", 123) + assert e1 == e2 + + +def test_attribute_entry_eq_diff_key(): + e1 = AttributeEntry("k1", 123) + e2 = AttributeEntry("k2", 123) + assert e1 != e2 + + +def test_attribute_entry_eq_diff_value(): + e1 = AttributeEntry("k1", 123) + e2 = AttributeEntry("k1", 456) + assert e1 != e2 + + +def test_attribute_entry_eq_wrong_type(): + e1 = AttributeEntry("k1", 123) + assert e1 != {"key": "k1", "value": 123} diff --git a/tests/entities/data/test_attribute_request.py b/tests/entities/data/test_attribute_request.py new file mode 100644 index 0000000..700ec15 --- /dev/null +++ b/tests/entities/data/test_attribute_request.py @@ -0,0 +1,85 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock, patch + +import pytest + +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest + + +@pytest.mark.asyncio +@patch("tb_mqtt_client.entities.data.attribute_request.AttributeRequestIdProducer.get_next", new_callable=AsyncMock) +async def test_attribute_request_build_with_keys(mock_get_next): + mock_get_next.return_value = 101 + + request = await AttributeRequest.build( + shared_keys=["shared1", "shared2"], + client_keys=["client1"] + ) + + assert request.request_id == 101 + assert request.shared_keys == ["shared1", "shared2"] + assert request.client_keys == ["client1"] + assert request.to_payload_format() == { + "sharedKeys": "shared1,shared2", + "clientKeys": "client1" + } + assert repr(request) == ( + "AttributeRequest(id=101, shared_keys=['shared1', 'shared2'], client_keys=['client1'])" + ) + + +@pytest.mark.asyncio +@patch("tb_mqtt_client.entities.data.attribute_request.AttributeRequestIdProducer.get_next", new_callable=AsyncMock) +async def test_attribute_request_build_with_none_keys(mock_get_next): + mock_get_next.return_value = 202 + + request = await AttributeRequest.build() + assert request.request_id == 202 + assert request.shared_keys is None + assert request.client_keys is None + assert request.to_payload_format() == {} + assert repr(request) == "AttributeRequest(id=202, shared_keys=None, client_keys=None)" + + +@pytest.mark.asyncio +async def test_attribute_request_direct_instantiation_fails(): + with pytest.raises(TypeError, match="Direct instantiation of AttributeRequest is not allowed"): + AttributeRequest(1, ["s1"], ["c1"]) + + +@pytest.mark.asyncio +async def test_attribute_request_invalid_json_keys(): + with patch("tb_mqtt_client.entities.data.attribute_request.AttributeRequestIdProducer.get_next", + new_callable=AsyncMock) as mock_get_next: + mock_get_next.return_value = 999 + + with pytest.raises(ValueError, match="unsupported - set"): + await AttributeRequest.build(shared_keys={1, 2, 3}) + + with pytest.raises(ValueError, match="expected str, got int"): + await AttributeRequest.build(client_keys=[{"bad_key": {1: "value"}}]) + + +@pytest.mark.asyncio +async def test_attribute_request_invalid_nested_value(): + with patch("tb_mqtt_client.entities.data.attribute_request.AttributeRequestIdProducer.get_next", + new_callable=AsyncMock) as mock_get_next: + mock_get_next.return_value = 1001 + + def dummy(): pass + + with pytest.raises(ValueError, match="unsupported - function"): + await AttributeRequest.build(client_keys=[{"nested": dummy}]) diff --git a/tests/entities/data/test_attribute_update.py b/tests/entities/data/test_attribute_update.py new file mode 100644 index 0000000..87bf108 --- /dev/null +++ b/tests/entities/data/test_attribute_update.py @@ -0,0 +1,84 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate + + +@pytest.fixture +def example_entries(): + return [ + AttributeEntry("temperature", 23), + AttributeEntry("humidity", 50), + AttributeEntry("status", "ok") + ] + + +def test_repr(example_entries): + update = AttributeUpdate(example_entries) + assert repr(update) == f"AttributeUpdate(entries={example_entries})" + + +def test_get_existing_key(example_entries): + update = AttributeUpdate(example_entries) + assert update.get("temperature") == 23 + assert update.get("status") == "ok" + + +def test_get_missing_key_with_default(example_entries): + update = AttributeUpdate(example_entries) + assert update.get("nonexistent") is None + assert update.get("nonexistent", "default") == "default" + + +def test_keys(example_entries): + update = AttributeUpdate(example_entries) + assert update.keys() == ["temperature", "humidity", "status"] + + +def test_values(example_entries): + update = AttributeUpdate(example_entries) + assert update.values() == [23, 50, "ok"] + + +def test_items(example_entries): + update = AttributeUpdate(example_entries) + assert update.items() == [ + ("temperature", 23), + ("humidity", 50), + ("status", "ok") + ] + + +def test_as_dict(example_entries): + update = AttributeUpdate(example_entries) + assert update.as_dict() == { + "temperature": 23, + "humidity": 50, + "status": "ok" + } + + +def test_deserialize_from_dict(): + raw = { + "speed": 100, + "enabled": True + } + update = AttributeUpdate._deserialize_from_dict(raw) + assert isinstance(update, AttributeUpdate) + assert update.as_dict() == raw + assert update.get("speed") == 100 + assert update.get("enabled") is True diff --git a/tests/entities/data/test_data_entry.py b/tests/entities/data/test_data_entry.py new file mode 100644 index 0000000..4a5ac7e --- /dev/null +++ b/tests/entities/data/test_data_entry.py @@ -0,0 +1,95 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from unittest.mock import patch + +import pytest + +from tb_mqtt_client.entities.data.data_entry import DataEntry + + +def test_data_entry_initialization_without_timestamp(): + entry = DataEntry("temperature", 22.5) + + assert entry.key == "temperature" + assert entry.value == 22.5 + assert entry.ts is None + assert isinstance(entry.size, int) + assert "DataEntry(key=temperature, value=22.5, ts=None)" == repr(entry) + + +def test_data_entry_initialization_with_timestamp(): + entry = DataEntry("humidity", 60, ts=1717171717) + + assert entry.key == "humidity" + assert entry.value == 60 + assert entry.ts == 1717171717 + assert isinstance(entry.size, int) + assert "DataEntry(key=humidity, value=60, ts=1717171717)" == repr(entry) + + +def test_data_entry_key_setter_recalculates_size(): + entry = DataEntry("initial", True) + original_size = entry.size + entry.key = "updated_key" + + assert entry.key == "updated_key" + assert isinstance(entry.size, int) + assert entry.size != original_size + + +def test_data_entry_value_setter_recalculates_size(): + entry = DataEntry("key", "initial") + original_size = entry.size + entry.value = "updated_value" + + assert entry.value == "updated_value" + assert isinstance(entry.size, int) + assert entry.size != original_size + + +def test_data_entry_ts_setter_recalculates_size(): + entry = DataEntry("voltage", 3.3) + original_size = entry.size + entry.ts = 1710000000 + + assert entry.ts == 1710000000 + assert isinstance(entry.size, int) + assert entry.size != original_size + + +def test_data_entry_raises_on_invalid_json(): + class NonSerializable: + pass + + with pytest.raises(ValueError, match="unsupported - NonSerializable"): + DataEntry("bad", NonSerializable()) + + +@patch("tb_mqtt_client.entities.data.data_entry.dumps") +def test_estimate_size_called_with_ts(mock_dumps): + mock_dumps.return_value = b'{"ts":123,"values":{"x":42}}' + entry = DataEntry("x", 42, ts=123) + + mock_dumps.assert_called_once_with({"ts": 123, "values": {"x": 42}}) + assert entry.size == len(mock_dumps.return_value) + + +@patch("tb_mqtt_client.entities.data.data_entry.dumps") +def test_estimate_size_called_without_ts(mock_dumps): + mock_dumps.return_value = b'{"x":42}' + entry = DataEntry("x", 42) + + mock_dumps.assert_called_once_with({"x": 42}) + assert entry.size == len(mock_dumps.return_value) diff --git a/tests/entities/data/test_device_uplink_message.py b/tests/entities/data/test_device_uplink_message.py new file mode 100644 index 0000000..bd90cc7 --- /dev/null +++ b/tests/entities/data/test_device_uplink_message.py @@ -0,0 +1,156 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +from collections import OrderedDict +from types import MappingProxyType + +import pytest + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder, DeviceUplinkMessage, \ + DEFAULT_FIELDS_SIZE +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry + + +@pytest.fixture +def attribute_entry(): + return AttributeEntry(key="temp", value=42) + + +@pytest.fixture +def timeseries_entry(): + return TimeseriesEntry(key="speed", value=88, ts=1234567890) + + +def test_direct_instantiation_forbidden(): + with pytest.raises(TypeError, match="Direct instantiation of DeviceUplinkMessage is not allowed"): + DeviceUplinkMessage(device_name="test", device_profile="default", attributes=(), timeseries={}, + delivery_futures=[], _size=0) + + +def test_build_empty_message(): + builder = DeviceUplinkMessageBuilder() + msg = builder.build() + + assert msg.device_name is None + assert msg.device_profile is None + assert msg.attributes == () + assert isinstance(msg.timeseries, MappingProxyType) + assert dict(msg.timeseries) == {} + assert len(msg.delivery_futures) == 1 + assert isinstance(msg.delivery_futures[0], asyncio.Future) + assert msg.size == DEFAULT_FIELDS_SIZE + assert not msg.has_attributes() + assert not msg.has_timeseries() + assert msg.attributes_datapoint_count() == 0 + assert msg.timeseries_datapoint_count() == 0 + + +def test_set_device_name_and_profile(): + builder = DeviceUplinkMessageBuilder() + msg = builder.set_device_name("device-1").set_device_profile("profile-x").build() + + assert msg.device_name == "device-1" + assert msg.device_profile == "profile-x" + assert msg.size == DEFAULT_FIELDS_SIZE + len("device-1") + len("profile-x") + + +def test_add_single_attribute(attribute_entry): + builder = DeviceUplinkMessageBuilder() + msg = builder.add_attributes(attribute_entry).build() + + assert len(msg.attributes) == 1 + assert msg.attributes[0] == attribute_entry + assert msg.attributes_datapoint_count() == 1 + assert msg.has_attributes() + assert msg.size == DEFAULT_FIELDS_SIZE + attribute_entry.size + + +def test_add_multiple_attributes(attribute_entry): + other = AttributeEntry(key="humidity", value=60) + builder = DeviceUplinkMessageBuilder() + msg = builder.add_attributes([attribute_entry, other]).build() + + assert len(msg.attributes) == 2 + assert msg.attributes == (attribute_entry, other) + expected_size = DEFAULT_FIELDS_SIZE + attribute_entry.size + other.size + assert msg.size == expected_size + + +def test_add_single_timeseries_entry(timeseries_entry): + builder = DeviceUplinkMessageBuilder() + msg = builder.add_timeseries(timeseries_entry).build() + + assert len(msg.timeseries) == 1 + assert timeseries_entry.ts in msg.timeseries + assert msg.timeseries[timeseries_entry.ts] == (timeseries_entry,) + assert msg.has_timeseries() + assert msg.timeseries_datapoint_count() == 1 + assert msg.size == DEFAULT_FIELDS_SIZE + timeseries_entry.size + + +def test_add_multiple_timeseries_entries(timeseries_entry): + entry2 = TimeseriesEntry(key="vibration", value=1.5, ts=timeseries_entry.ts) + builder = DeviceUplinkMessageBuilder() + msg = builder.add_timeseries([timeseries_entry, entry2]).build() + + assert timeseries_entry.ts in msg.timeseries + assert msg.timeseries[timeseries_entry.ts] == (timeseries_entry, entry2) + assert msg.timeseries_datapoint_count() == 2 + assert msg.size == DEFAULT_FIELDS_SIZE + timeseries_entry.size + entry2.size + + +def 
test_add_timeseries_with_none_ts(): + entry = TimeseriesEntry(key="x", value=1, ts=None) + builder = DeviceUplinkMessageBuilder() + msg = builder.add_timeseries(entry).build() + + assert 0 in msg.timeseries + assert msg.timeseries[0] == (entry,) + assert msg.timeseries_datapoint_count() == 1 + assert msg.size == DEFAULT_FIELDS_SIZE + entry.size + + +def test_add_timeseries_from_ordered_dict(timeseries_entry): + ordered = OrderedDict({timeseries_entry.ts: [timeseries_entry]}) + builder = DeviceUplinkMessageBuilder() + msg = builder.add_timeseries(ordered).build() + + assert msg.timeseries[timeseries_entry.ts] == (timeseries_entry,) + assert msg.timeseries_datapoint_count() == 1 + + +def test_add_delivery_futures(): + future1 = asyncio.Future() + future2 = asyncio.Future() + builder = DeviceUplinkMessageBuilder() + msg = builder.add_delivery_futures([future1, future2]).build() + + assert msg.delivery_futures == (future1, future2) + + +def test_repr_contains_info(attribute_entry, timeseries_entry): + builder = DeviceUplinkMessageBuilder() + msg = builder.set_device_name("test-device") \ + .set_device_profile("dp") \ + .add_attributes(attribute_entry) \ + .add_timeseries(timeseries_entry) \ + .build() + + repr_str = repr(msg) + assert "test-device" in repr_str + assert "dp" in repr_str + assert "AttributeEntry" in repr_str + assert "TimeseriesEntry" in repr_str diff --git a/tests/entities/data/test_provisioning_data.py b/tests/entities/data/test_provisioning_data.py new file mode 100644 index 0000000..d2eb4af --- /dev/null +++ b/tests/entities/data/test_provisioning_data.py @@ -0,0 +1,169 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from unittest.mock import patch, mock_open + +import pytest + +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.constants.provisioning import ProvisioningResponseStatus +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest, AccessTokenProvisioningCredentials, \ + BasicProvisioningCredentials, X509ProvisioningCredentials +from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse + + +@pytest.fixture +def access_token_request(): + return ProvisioningRequest( + host="my-host", + port=1883, + credentials=AccessTokenProvisioningCredentials( + provision_device_key="provision_device_key", + provision_device_secret="provision_device_secret", + access_token="access-token123" + ) + ) + + +@pytest.fixture +def mqtt_basic_request(): + return ProvisioningRequest( + host="test-host", + port=8883, + credentials=BasicProvisioningCredentials( + provision_device_key="provision_device_key", + provision_device_secret="provision_device_secret", + client_id="my-client-id", + username="user1", + password="pass123" + ) + ) + + +@pytest.fixture +def x509_request(): + mocked_cert_content = "-----BEGIN CERTIFICATE-----\nABCDEF\n-----END CERTIFICATE-----" + + with patch("builtins.open", mock_open(read_data=mocked_cert_content)): + return ProvisioningRequest( + host="secure-host", + port=8883, + credentials=X509ProvisioningCredentials( + provision_device_key="provision_device_key", + provision_device_secret="provision_device_secret", + private_key_path="/path/to/private.key", + public_cert_path="/path/to/client.crt", + ca_cert_path="/path/to/ca.crt" + ) + ) + + +def test_direct_instantiation_fails(): + with pytest.raises(TypeError, match="Direct instantiation of ProvisioningResponse is not allowed"): + ProvisioningResponse(status=ProvisioningResponseStatus.SUCCESS) + + +def test_error_response(): + access_token_credentials = AccessTokenProvisioningCredentials("provision_device_key", "provision_device_secret") + request = ProvisioningRequest(host="any", port=1883, credentials=access_token_credentials) + payload = {"errorMsg": "Provisioning failed", "status": ProvisioningResponseStatus.ERROR.value} + + response = ProvisioningResponse.build(request, payload) + + assert response.status == ProvisioningResponseStatus.ERROR + assert response.result is None + assert response.error == "Provisioning failed" + + +def test_success_access_token(access_token_request): + payload = {"credentialsValue": "ACCESS-TOKEN-123"} + + response = ProvisioningResponse.build(access_token_request, payload) + + assert response.status == ProvisioningResponseStatus.SUCCESS + assert response.error is None + assert isinstance(response.result, DeviceConfig) + assert response.result.access_token == "ACCESS-TOKEN-123" + assert response.result.host == "my-host" + assert response.result.port == 1883 + + +def test_success_mqtt_basic(mqtt_basic_request): + payload = { + "credentialsValue": { + "clientId": "my-client-id", + "userName": "user1", + "password": "pass123" + } + } + + response = ProvisioningResponse.build(mqtt_basic_request, payload) + + assert response.status == ProvisioningResponseStatus.SUCCESS + assert isinstance(response.result, DeviceConfig) + config = response.result + assert config.client_id == "my-client-id" + assert config.username == "user1" + assert config.password == "pass123" + assert config.host == "test-host" + assert config.port == 8883 + assert response.error is None + + +def test_success_x509(x509_request): + payload = {"credentialsValue": 
None} # Should be ignored for X509 + + response = ProvisioningResponse.build(x509_request, payload) + + assert response.status == ProvisioningResponseStatus.SUCCESS + config = response.result + assert config.ca_cert == "/path/to/ca.crt" + assert config.client_cert == "/path/to/client.crt" + assert config.private_key == "/path/to/private.key" + assert config.host == "secure-host" + assert config.port == 8883 + assert response.error is None + + +def test_repr_output(access_token_request): + payload = {"credentialsValue": "access-token"} + response = ProvisioningResponse.build(access_token_request, payload) + + r = repr(response) + assert r.startswith("ProvisioningResponse(status=SUCCESS") + assert "DeviceConfig(" in r + assert "host=my-host" in r + assert "port=1883" in r + assert "error=None" in r + + +def test_missing_credentials_value_for_access_token(access_token_request): + payload = {} # Missing 'credentialsValue' + + with pytest.raises(KeyError): + ProvisioningResponse.build(access_token_request, payload) + + +def test_mqtt_basic_missing_fields(mqtt_basic_request): + payload = {"credentialsValue": {}} # All fields missing + + with pytest.raises(KeyError): + ProvisioningResponse.build(mqtt_basic_request, payload) + + +def test_access_token_none_is_accepted(access_token_request): + payload = {"credentialsValue": None} + response = ProvisioningResponse.build(access_token_request, payload) + + assert response.status == ProvisioningResponseStatus.SUCCESS + assert response.result.access_token is None diff --git a/tests/entities/data/test_requested_attribute_response.py b/tests/entities/data/test_requested_attribute_response.py new file mode 100644 index 0000000..ff6e6bf --- /dev/null +++ b/tests/entities/data/test_requested_attribute_response.py @@ -0,0 +1,127 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse + + +@pytest.fixture +def shared_attrs(): + return [AttributeEntry("shared_key_1", "shared_value_1"), AttributeEntry("shared_key_2", 42)] + + +@pytest.fixture +def client_attrs(): + return [AttributeEntry("client_key_1", "client_value_1"), AttributeEntry("client_key_2", True)] + + +@pytest.fixture +def response(shared_attrs, client_attrs): + return RequestedAttributeResponse(request_id=1001, shared=shared_attrs, client=client_attrs) + + +def test_repr(response): + output = repr(response) + assert "RequestedAttributeResponse" in output + assert "request_id=1001" in output + assert "shared_key_1" in output + assert "client_key_1" in output + + +def test_getitem_shared_key(response): + assert response["shared_key_1"] == "shared_value_1" + assert response["shared_key_2"] == 42 + + +def test_getitem_client_key(response): + assert response["client_key_1"] == "client_value_1" + assert response["client_key_2"] is True + + +def test_getitem_key_error(response): + with pytest.raises(KeyError, match="Key 'nonexistent' not found"): + _ = response["nonexistent"] + + +def test_get_shared_success(shared_attrs, client_attrs): + response = RequestedAttributeResponse(5, shared_attrs, client_attrs) + assert response.get_shared("shared_key_1") == "shared_value_1" + assert response.get_shared("shared_key_2") == 42 + + +def test_get_shared_default(response): + assert response.get_shared("unknown", default="fallback") == "fallback" + assert response.get_shared("unknown") is None + + +def test_get_client_success(response): + assert response.get_client("client_key_1") == "client_value_1" + assert response.get_client("client_key_2") is True + + +def test_get_client_default(response): + assert response.get_client("unknown", default=0) == 0 + assert response.get_client("unknown") is None + + +def test_shared_keys(response): + keys = response.shared_keys() + assert keys == ["shared_key_1", "shared_key_2"] + + +def test_client_keys(response): + keys = response.client_keys() + assert keys == ["client_key_1", "client_key_2"] + + +def test_as_dict(response): + result = response.as_dict() + assert isinstance(result, dict) + assert "shared" in result + assert "client" in result + assert result["shared"][0]["key"] == "shared_key_1" + assert result["client"][1]["value"] is True + + +def test_from_dict_full(): + data = { + "request_id": 77, + "shared": {"temp": 25, "mode": "cool"}, + "client": {"state": "on", "power": 100} + } + resp = RequestedAttributeResponse.from_dict(data) + assert isinstance(resp, RequestedAttributeResponse) + assert resp.request_id == 77 + assert resp.get_shared("temp") == 25 + assert resp.get_client("power") == 100 + + +def test_from_dict_missing_request_id(): + data = { + "shared": {"x": 1}, + "client": {"y": 2} + } + resp = RequestedAttributeResponse.from_dict(data) + assert resp.request_id == -1 + assert resp.get_shared("x") == 1 + assert resp.get_client("y") == 2 + + +def test_from_dict_empty(): + resp = RequestedAttributeResponse.from_dict({}) + assert resp.request_id == -1 + assert resp.shared == [] + assert resp.client == [] diff --git a/tests/entities/data/test_rpc_request.py b/tests/entities/data/test_rpc_request.py new file mode 100644 index 0000000..b3ca496 --- /dev/null +++ b/tests/entities/data/test_rpc_request.py @@ -0,0 +1,91 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, 
Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import patch + +import pytest + +from tb_mqtt_client.common.request_id_generator import RPCRequestIdProducer +from tb_mqtt_client.entities.data.rpc_request import RPCRequest + + +@pytest.mark.asyncio +async def test_rpc_request_build_valid(): + method = "reboot" + params = {"delay": 5} + with patch.object(RPCRequestIdProducer, "get_next", return_value=42): + request = await RPCRequest.build(method, params) + + assert request.request_id == 42 + assert request.method == method + assert request.params == params + assert request.to_payload_format() == { + "id": 42, + "method": "reboot", + "params": {"delay": 5} + } + + +def test_rpc_request_direct_instantiation_raises(): + with pytest.raises(TypeError, match="Direct instantiation of RPCRequest is not allowed"): + RPCRequest(1, "reboot") + + +@pytest.mark.asyncio +async def test_rpc_request_build_with_none_params(): + with patch.object(RPCRequestIdProducer, "get_next", return_value="abc123"): + request = await RPCRequest.build("ping") + + assert request.params is None + assert request.request_id == "abc123" + assert request.to_payload_format() == { + "id": "abc123", + "method": "ping" + } + + +@pytest.mark.asyncio +async def test_rpc_request_build_invalid_method_type(): + with pytest.raises(ValueError, match="Method must be a string"): + await RPCRequest.build(123, {"key": "value"}) + + +def test_rpc_request_deserialize_valid(): + data = {"method": "getStatus", "params": {"verbose": True}} + req = RPCRequest._deserialize_from_dict(1001, data) + + assert req.request_id == 1001 + assert req.method == "getStatus" + assert req.params == {"verbose": True} + assert req.to_payload_format() == { + "id": 1001, + "method": "getStatus", + "params": {"verbose": True} + } + + +def test_rpc_request_deserialize_missing_id(): + with pytest.raises(ValueError, match="Missing request id"): + RPCRequest._deserialize_from_dict(None, {"method": "test"}) + + +def test_rpc_request_deserialize_missing_method(): + with pytest.raises(ValueError, match="Missing 'method' in RPC request"): + RPCRequest._deserialize_from_dict(1, {}) + + +def test_rpc_request_repr(): + data = {"method": "info", "params": {"a": 1}} + req = RPCRequest._deserialize_from_dict(42, data) + assert repr(req) == "RPCRequest(id=42, method=info, params={'a': 1})" diff --git a/tests/entities/data/test_rpc_response.py b/tests/entities/data/test_rpc_response.py new file mode 100644 index 0000000..b7fa703 --- /dev/null +++ b/tests/entities/data/test_rpc_response.py @@ -0,0 +1,87 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from tb_mqtt_client.entities.data.rpc_response import RPCResponse, RPCStatus + + +class NonSerializable: + pass + + +def test_rpc_status_str(): + assert str(RPCStatus.SUCCESS) == "SUCCESS" + assert str(RPCStatus.ERROR) == "ERROR" + assert str(RPCStatus.TIMEOUT) == "TIMEOUT" + assert str(RPCStatus.NOT_FOUND) == "NOT_FOUND" + + +def test_rpc_response_direct_instantiation_raises(): + with pytest.raises(TypeError): + RPCResponse(1, result="test") + + +def test_rpc_response_success(): + response = RPCResponse.build(request_id=1, result={"key": "value"}) + assert response.request_id == 1 + assert response.status == RPCStatus.SUCCESS + assert response.result == {"key": "value"} + assert response.error is None + assert response.to_payload_format() == {"result": {"key": "value"}} + + +def test_rpc_response_error_with_string(): + response = RPCResponse.build(request_id="req-1", error="Something went wrong") + assert response.request_id == "req-1" + assert response.status == RPCStatus.ERROR + assert response.error == "Something went wrong" + assert response.result is None + assert response.to_payload_format() == {"error": "Something went wrong"} + + +def test_rpc_response_error_with_dict(): + error_dict = {"code": 500, "message": "Internal Error"} + response = RPCResponse.build(request_id="req-2", error=error_dict) + assert response.request_id == "req-2" + assert response.status == RPCStatus.ERROR + assert response.error == error_dict + assert response.result is None + assert response.to_payload_format() == {"error": error_dict} + + +def test_rpc_response_error_with_exception(): + try: + raise ValueError("Something failed") + except ValueError as ex: + response = RPCResponse.build(request_id=42, error=ex) + + assert response.status == RPCStatus.ERROR + assert isinstance(response.error, dict) + assert "message" in response.error + assert "type" in response.error + assert "details" in response.error + assert response.error["type"] == "ValueError" + assert response.error["message"] == "Something failed" + assert "raise ValueError(\"Something failed\")" in response.error["details"] + + +def test_rpc_response_invalid_error_type(): + with pytest.raises(ValueError): + RPCResponse.build(request_id=1, error=NonSerializable()) + + +def test_rpc_response_invalid_result_type(): + with pytest.raises(ValueError): + RPCResponse.build(request_id=1, result=NonSerializable()) diff --git a/tests/entities/data/test_timeseries_entry.py b/tests/entities/data/test_timeseries_entry.py new file mode 100644 index 0000000..9bcec4c --- /dev/null +++ b/tests/entities/data/test_timeseries_entry.py @@ -0,0 +1,74 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry + + +@pytest.mark.parametrize( + "key,value,ts,expected_dict", + [ + ("temperature", 23.5, 1650000000000, {"key": "temperature", "value": 23.5, "ts": 1650000000000}), + ("status", "online", None, {"key": "status", "value": "online"}), + ("flag", True, 0, {"key": "flag", "value": True, "ts": 0}), + ] +) +def test_as_dict(key, value, ts, expected_dict): + entry = TimeseriesEntry(key, value, ts) + assert entry.as_dict() == expected_dict + + +def test_repr_output(): + entry = TimeseriesEntry("humidity", 60, 1650001234567) + assert repr(entry) == "TimeseriesEntry(key=humidity, value=60, ts=1650001234567)" + + +def test_repr_without_ts(): + entry = TimeseriesEntry("humidity", 60) + assert repr(entry) == "TimeseriesEntry(key=humidity, value=60, ts=None)" + + +def test_equality_same_values(): + e1 = TimeseriesEntry("pressure", 101.3, 1650000000000) + e2 = TimeseriesEntry("pressure", 101.3, 1650000000000) + assert e1 == e2 + + +def test_equality_different_key(): + e1 = TimeseriesEntry("a", 1, 123) + e2 = TimeseriesEntry("b", 1, 123) + assert e1 != e2 + + +def test_equality_different_value(): + e1 = TimeseriesEntry("key", 1, 123) + e2 = TimeseriesEntry("key", 2, 123) + assert e1 != e2 + + +def test_equality_different_ts(): + e1 = TimeseriesEntry("key", "value", 123) + e2 = TimeseriesEntry("key", "value", 456) + assert e1 != e2 + + +def test_equality_different_type(): + e1 = TimeseriesEntry("key", "value", 123) + assert e1 != {"key": "key", "value": "value", "ts": 123} + + +def test_default_ts_none(): + entry = TimeseriesEntry("key", "value") + assert entry.ts is None + assert entry.as_dict() == {"key": "key", "value": "value"} diff --git a/tests/service/device/test_device_client.py b/tests/service/device/test_device_client.py new file mode 100644 index 0000000..395a78e --- /dev/null +++ b/tests/service/device/test_device_client.py @@ -0,0 +1,515 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.constants.mqtt_topics import DEVICE_CLAIM_TOPIC +from tb_mqtt_client.entities.data.claim_request import ClaimRequest +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest, AccessTokenProvisioningCredentials +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.service.device.client import DeviceClient +from tb_mqtt_client.service.message_queue import MessageQueue + + +@pytest.mark.asyncio +async def test_send_telemetry_with_dict(): + client = DeviceClient() + client._message_queue = AsyncMock(spec=MessageQueue) + future = asyncio.Future() + future.set_result(PublishResult("topic", 1, 1, 100, 1)) + client._message_queue.publish.return_value = [future] + result = await client.send_telemetry({"temp": 22}) + assert isinstance(result, PublishResult) + assert result.message_id == 1 + + +@pytest.mark.asyncio +async def test_send_telemetry_timeout(): + client = DeviceClient() + client._message_queue = AsyncMock() + future = asyncio.Future() + client._message_queue.publish.return_value = [future] + result = await client.send_telemetry({"temp": 22}, timeout=0.01) + assert isinstance(result, PublishResult) + assert result.message_id == -1 + + +@pytest.mark.asyncio +async def test_send_attributes_dict(): + client = DeviceClient() + client._message_queue = AsyncMock() + fut = asyncio.Future() + fut.set_result(PublishResult("attr", 1, 2, 50, 1)) + client._message_queue.publish.return_value = [fut] + result = await client.send_attributes({"key": "val"}) + assert isinstance(result, PublishResult) + assert result.message_id == 2 + + +@pytest.mark.asyncio +async def test_send_rpc_request_timeout(): + client = DeviceClient() + from tb_mqtt_client.entities.data.rpc_request import RPCRequest + request = await RPCRequest.build(method="get", params={}) + client._rpc_response_handler.register_request = MagicMock(return_value=asyncio.Future()) + client._message_queue = AsyncMock() + client._message_queue.publish.return_value = [] + with pytest.raises(TimeoutError): + await client.send_rpc_request(request, wait_for_publish=True, timeout=0.01) + + +@pytest.mark.asyncio +async def test_send_rpc_response(): + client = DeviceClient() + from tb_mqtt_client.entities.data.rpc_response import RPCResponse + response = RPCResponse.build(123, result={"ok": True}) + client._message_queue = AsyncMock() + await client.send_rpc_response(response) + client._message_queue.publish.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_claim_device_success(): + client = DeviceClient() + from tb_mqtt_client.entities.data.claim_request import ClaimRequest + claim = ClaimRequest.build(secret_key="abc") + client._message_queue = AsyncMock() + fut = asyncio.Future() + fut.set_result(None) + client._message_queue.publish.return_value = fut + client._DeviceClient__claiming_response_future = asyncio.Future() + client._DeviceClient__claiming_response_future.set_result( + PublishResult(DEVICE_CLAIM_TOPIC, 1, 3, 10, 1) + ) + result = await client.claim_device(claim, timeout=0.01) + assert isinstance(result, PublishResult) + assert result.topic == DEVICE_CLAIM_TOPIC + + +@pytest.mark.asyncio +async def test_send_attribute_request(): + client = DeviceClient() + from 
tb_mqtt_client.entities.data.attribute_request import AttributeRequest + request = await AttributeRequest.build(client_keys=["key1"]) + client._requested_attribute_response_handler.register_request = AsyncMock() + client._message_queue = AsyncMock() + await client.send_attribute_request(request, AsyncMock()) + client._message_queue.publish.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_rpc_call_timeout(): + client = DeviceClient() + client._mqtt_manager.publish = AsyncMock() + client._rpc_response_handler.register_request = MagicMock(return_value=asyncio.Future()) + with pytest.raises(TimeoutError): + await client.send_rpc_call("reboot", timeout=0.01) + + +@pytest.mark.asyncio +async def test_disconnect(): + client = DeviceClient() + client._mqtt_manager.disconnect = AsyncMock() + await client.disconnect() + client._mqtt_manager.disconnect.assert_awaited() + + +@pytest.mark.asyncio +async def test_stop_disconnects_and_shuts_down_queue(): + client = DeviceClient() + client._mqtt_manager.is_connected = lambda: True + client._mqtt_manager.disconnect = AsyncMock() + client._message_queue = AsyncMock() + await client.stop() + client._message_queue.shutdown.assert_awaited() + client._mqtt_manager.disconnect.assert_awaited() + + +@pytest.mark.asyncio +async def test_handle_rate_limit_invalid_response(): + client = DeviceClient() + from tb_mqtt_client.entities.data.rpc_response import RPCResponse + bad_resp = RPCResponse.build(1, result="invalid") + result = await client._handle_rate_limit_response(bad_resp) + assert result is None + + +@pytest.mark.asyncio +async def test_handle_rate_limit_valid_response(): + client = DeviceClient() + from tb_mqtt_client.entities.data.rpc_response import RPCResponse + client._mqtt_manager.set_rate_limits = MagicMock() + payload = { + "rateLimits": { + "messages": "10:1,", + "telemetryMessages": "100:60,", + "telemetryDataPoints": "500:60," + }, + "maxInflightMessages": 200, + "maxPayloadSize": 512 + } + resp = RPCResponse.build(1, result=payload) + result = await client._handle_rate_limit_response(resp) + assert result is True + assert client.max_payload_size <= 512 + + +@pytest.mark.asyncio +async def test_provision_success(monkeypatch): + from tb_mqtt_client.service.device import client as client_module + mock_prov = AsyncMock() + mock_prov.provision.return_value = "creds" + monkeypatch.setattr(client_module, "ProvisioningClient", lambda **kwargs: mock_prov) + credentials = AccessTokenProvisioningCredentials("provision_device_key", "provision_device_secret") + req = ProvisioningRequest(host="localhost", credentials=credentials, port=1883, device_name="dev") + res = await DeviceClient.provision(req) + assert res == "creds" + + +@pytest.mark.asyncio +async def test_connects_to_platform_with_tls_using_device_config(): + config = DeviceConfig() + config.use_tls = MagicMock(return_value=True) + config.ca_cert = "path/to/ca_cert" + config.client_cert = "path/to/client_cert" + config.private_key = "path/to/private_key" + + mqtt_manager = AsyncMock() + client = DeviceClient(config=config) + client._mqtt_manager = mqtt_manager + client._mqtt_manager.connect = AsyncMock() + client._mqtt_manager.is_connected = MagicMock(return_value=True) + client._mqtt_manager.await_ready = AsyncMock() + + with patch("tb_mqtt_client.service.device.client.ssl.create_default_context") as mock_ssl_context: + ssl_context = mock_ssl_context.return_value + ssl_context.load_verify_locations = MagicMock() + ssl_context.load_cert_chain = MagicMock() + await client.connect() 
+ + mock_ssl_context.assert_called_once() + ssl_context.load_verify_locations.assert_called_once_with(config.ca_cert) + ssl_context.load_cert_chain.assert_called_once_with(certfile=config.client_cert, keyfile=config.private_key) + client._mqtt_manager.connect.assert_awaited_once_with( + host=config.host, + port=config.port, + username=config.access_token or config.username, + password=None if config.access_token else config.password, + tls=True, + ssl_context=ssl_context) + + +@pytest.mark.asyncio +async def test_connects_to_platform_with_tls(): + config = DeviceConfig() + config.use_tls = MagicMock(return_value=True) + config.ca_cert = "path/to/ca_cert" + config.client_cert = "path/to/client_cert" + config.private_key = "path/to/private_key" + config.host = "localhost" + config.port = 8883 + config.access_token = "token" + config.username = None + config.password = None + + mqtt_manager = AsyncMock() + client = DeviceClient(config) + client._mqtt_manager = mqtt_manager + + with patch("tb_mqtt_client.service.device.client.ssl.create_default_context") as mock_ssl_context: + ssl_context = mock_ssl_context.return_value + await client.connect() + + mock_ssl_context.assert_called_once() + ssl_context.load_verify_locations.assert_called_once_with(config.ca_cert) + ssl_context.load_cert_chain.assert_called_once_with( + certfile=config.client_cert, + keyfile=config.private_key + ) + mqtt_manager.connect.assert_awaited_once_with( + host=config.host, + port=config.port, + username=config.access_token, + password=None, + tls=True, + ssl_context=ssl_context + ) + + +@pytest.mark.asyncio +async def test_connects_to_platform_without_tls(): + config = DeviceConfig() + config.use_tls = MagicMock(return_value=False) + config.host = "localhost" + config.port = 1883 + config.access_token = "token" + config.username = None + config.password = None + + mqtt_manager = AsyncMock() + client = DeviceClient(config) + client._mqtt_manager = mqtt_manager + + await client.connect() + + mqtt_manager.connect.assert_awaited_once_with( + host=config.host, + port=config.port, + username=config.access_token, + password=None, + tls=False, + ssl_context=None + ) + + +@pytest.mark.asyncio +async def test_stops_if_event_is_set_during_connection(): + config = DeviceConfig() + config.host = "localhost" + config.port = 1883 + config.access_token = "token" + config.username = None + config.password = None + + mqtt_manager = AsyncMock() + mqtt_manager.is_connected.return_value = False + mqtt_manager.await_ready = AsyncMock() + + client = DeviceClient(config) + client._mqtt_manager = mqtt_manager + client._stop_event.set() + + await client.connect() + + mqtt_manager.await_ready.assert_not_awaited() + mqtt_manager.is_connected.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_initializes_dispatcher_and_queue_after_connection(): + config = DeviceConfig() + config.host = "localhost" + config.port = 1883 + config.access_token = "token" + config.username = None + config.password = None + + mqtt_manager = AsyncMock() + mqtt_manager.is_connected.return_value = True + + with patch("tb_mqtt_client.service.device.client.MessageQueue") as mock_queue: + client = DeviceClient(config) + client._mqtt_manager = mqtt_manager + + await client.connect() + + assert client.max_payload_size == 65535 + assert client._message_dispatcher is not None + assert client._message_queue is not None + + mock_queue.assert_called_once() + kwargs = mock_queue.call_args.kwargs + assert kwargs["max_queue_size"] == client._max_uplink_message_queue_size + + 
+class FakeSplitter: + def __init__(self): + self.max_payload_size = None + + +@pytest.mark.asyncio +async def test_uses_default_max_payload_size_when_not_provided(): + client = DeviceClient() + client.max_payload_size = None + + splitter = FakeSplitter() + client._message_dispatcher = MagicMock() + client._message_dispatcher.splitter = splitter + + resp = RPCResponse.build(1, result={"rateLimits": {}}) + await client._handle_rate_limit_response(resp) + + assert client.max_payload_size == 65535 + assert splitter.max_payload_size == 65535 + + +@pytest.mark.asyncio +async def test_does_not_update_dispatcher_when_not_initialized(): + client = DeviceClient() + client.max_payload_size = None + client._message_dispatcher = None + + resp = RPCResponse.build(1, result={"rateLimits": {}}) + await client._handle_rate_limit_response(resp) + + assert client.max_payload_size == 65535 + + +@pytest.mark.asyncio +async def test_publish_result_device_claim_successfully_sets_future(): + publish_result = PublishResult(topic=DEVICE_CLAIM_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0) + client = DeviceClient() + client._DeviceClient__claiming_response_future = MagicMock() + client._DeviceClient__claiming_response_future.done.return_value = False + + await client._DeviceClient__on_publish_result(publish_result) + + client._DeviceClient__claiming_response_future.set_result.assert_called_once_with(True) + + +@pytest.mark.asyncio +async def test_send_telemetry_without_connect_raises_error(): + client = DeviceClient() + with pytest.raises(AttributeError): + await client.send_telemetry({"temp": 22}) + + +@pytest.mark.asyncio +async def test_handle_attribute_update_calls_handler(): + client = DeviceClient() + mock_handler = AsyncMock() + client._attribute_updates_handler.handle = mock_handler + await client._handle_attribute_update("topic", b'{"key":"value"}') + mock_handler.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_handle_rpc_request_triggers_rpc_response(): + client = DeviceClient() + mock_handler = AsyncMock(return_value=RPCResponse.build(1, result={"res": "ok"})) + client._rpc_requests_handler.handle = mock_handler + client.send_rpc_response = AsyncMock() + await client._handle_rpc_request("topic", b'{"method": "test"}') + client.send_rpc_response.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_handle_rate_limit_response_with_partial_data(): + client = DeviceClient() + response = RPCResponse.build(1, result={"rateLimits": {"messages": "5:1"}}) + result = await client._handle_rate_limit_response(response) + assert result is True + assert client._messages_rate_limit.has_limit() + assert client.max_payload_size == 65535 + + +@pytest.mark.asyncio +async def test_update_firmware_triggers_firmware_updater(): + client = DeviceClient() + updater = AsyncMock() + client._firmware_updater = updater + await client.update_firmware() + updater.update.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_rpc_request_with_callback_executes(): + client = DeviceClient() + rpc = await RPCRequest.build(method="get", params={}) + client._rpc_response_handler.register_request = MagicMock(return_value=AsyncMock()) + client._message_queue = AsyncMock() + client._message_queue.publish.return_value = [] + + callback = AsyncMock() + await client.send_rpc_request(rpc, callback=callback, wait_for_publish=False) + client._rpc_response_handler.register_request.assert_called_once() + + +@pytest.mark.asyncio +async def test_set_attribute_update_callback_sets_handler(): + client = 
DeviceClient() + cb = AsyncMock() + client.set_attribute_update_callback(cb) + assert client._attribute_updates_handler._callback == cb + + +@pytest.mark.asyncio +async def test_send_telemetry_with_invalid_type_raises(): + client = DeviceClient() + with pytest.raises(ValueError): + await client.send_telemetry("invalid") + + +@pytest.mark.asyncio +async def test_send_attributes_no_wait(): + client = DeviceClient() + client._message_queue = AsyncMock() + client._message_queue.publish.return_value = [asyncio.Future()] + result = await client.send_attributes({"attr": "val"}, wait_for_publish=False) + assert result is None + + +@pytest.mark.asyncio +async def test_claim_device_returns_future_when_not_waiting(): + client = DeviceClient() + claim = ClaimRequest.build("secret_key") + client._message_queue = AsyncMock() + client._message_queue.publish.return_value = [asyncio.Future()] + result = await client.claim_device(claim, wait_for_publish=False) + assert isinstance(result, asyncio.Future) + + +@pytest.mark.asyncio +async def test_handle_rpc_response_calls_handler(): + client = DeviceClient() + mock = AsyncMock() + client._rpc_response_handler.handle = mock + await client._handle_rpc_response("topic", b"payload") + mock.assert_awaited_once_with("topic", b"payload") + + +@pytest.mark.asyncio +async def test_handle_requested_attribute_response_calls_handler(): + client = DeviceClient() + mock = AsyncMock() + client._requested_attribute_response_handler.handle = mock + await client._handle_requested_attribute_response("topic", b"payload") + mock.assert_awaited_once_with("topic", b"payload") + + +@pytest.mark.asyncio +async def test_on_disconnect_clears_handlers(): + client = DeviceClient() + client._requested_attribute_response_handler.clear = MagicMock() + client._rpc_response_handler.clear = MagicMock() + await client._on_disconnect() + client._requested_attribute_response_handler.clear.assert_called_once() + client._rpc_response_handler.clear.assert_called_once() + + +@pytest.mark.asyncio +async def test_set_rpc_request_callback_sets_handler(): + client = DeviceClient() + cb = AsyncMock() + client.set_rpc_request_callback(cb) + assert client._rpc_requests_handler._callback == cb + + +@pytest.mark.asyncio +async def test_provision_timeout(monkeypatch): + from tb_mqtt_client.service.device import client as client_module + mock_prov = AsyncMock() + mock_prov.provision.side_effect = asyncio.TimeoutError() + monkeypatch.setattr(client_module, "ProvisioningClient", lambda **kwargs: mock_prov) + + credentials = AccessTokenProvisioningCredentials("key", "secret") + req = ProvisioningRequest(host="localhost", credentials=credentials, port=1883, device_name="dev") + result = await DeviceClient.provision(req, timeout=0.01) + assert result is None diff --git a/tests/service/device/test_firmware_updater.py b/tests/service/device/test_firmware_updater.py index d742ef1..c86d91c 100644 --- a/tests/service/device/test_firmware_updater.py +++ b/tests/service/device/test_firmware_updater.py @@ -12,11 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import pytest import asyncio +from hashlib import sha256, sha384, sha512, md5 from unittest.mock import AsyncMock, MagicMock, patch, ANY +from zlib import crc32 + +import pytest + +from tb_mqtt_client.constants.firmware import * +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.service.device.firmware_updater import FirmwareUpdater -from tb_mqtt_client.constants.firmware import FW_TITLE_ATTR, FW_VERSION_ATTR, FW_STATE_ATTR, FirmwareStates @pytest.fixture @@ -32,24 +37,28 @@ def mock_client(): client.send_attribute_request = AsyncMock() return client + @pytest.fixture def updater(mock_client): return FirmwareUpdater(mock_client) + @pytest.mark.asyncio async def test_update_success(updater, mock_client): - with patch("tb_mqtt_client.entities.data.attribute_request.AttributeRequest.build", new_callable=AsyncMock) as mock_build, \ - patch.object(updater, "_firmware_info_callback", new=AsyncMock()): + with patch("tb_mqtt_client.entities.data.attribute_request.AttributeRequest.build", + new_callable=AsyncMock) as mock_build, \ + patch.object(updater, "_firmware_info_callback", new=AsyncMock()): await updater.update() mock_client._mqtt_manager.subscribe.assert_called_once() mock_client.send_telemetry.assert_called() mock_client.send_attribute_request.assert_called_once() + @pytest.mark.asyncio -async def test_update_not_connected(updater, mock_client, caplog): +async def test_update_not_connected(updater, mock_client): mock_client._mqtt_manager.is_connected.return_value = False await updater.update() - assert "Client is not connected" in caplog.text + @pytest.mark.asyncio async def test_handle_firmware_update_full(updater): @@ -62,6 +71,7 @@ async def test_handle_firmware_update_full(updater): verify.assert_awaited_once() assert updater._firmware_data == b'abcd' + @pytest.mark.asyncio async def test_handle_firmware_update_partial(updater): updater._target_firmware_length = 10 @@ -73,6 +83,7 @@ async def test_handle_firmware_update_partial(updater): next_chunk.assert_awaited_once() assert updater._firmware_data.endswith(payload) + @pytest.mark.asyncio async def test_get_next_chunk_valid(updater, mock_client): updater._chunk_size = 5 @@ -82,6 +93,7 @@ async def test_get_next_chunk_valid(updater, mock_client): await updater._get_next_chunk() mock_client._message_queue.publish.assert_awaited() + @pytest.mark.asyncio async def test_get_next_chunk_empty_payload(updater, mock_client): updater._chunk_size = 15 @@ -94,6 +106,7 @@ async def test_get_next_chunk_empty_payload(updater, mock_client): qos=1 ) + @pytest.mark.asyncio async def test_verify_downloaded_firmware_success(updater): updater._firmware_data = b'data' @@ -101,21 +114,23 @@ async def test_verify_downloaded_firmware_success(updater): updater._target_checksum_alg = "md5" updater.current_firmware_info[FW_STATE_ATTR] = FirmwareStates.DOWNLOADING.value with patch.object(updater, 'verify_checksum', return_value=True), \ - patch.object(updater, '_apply_downloaded_firmware', new=AsyncMock()), \ - patch.object(updater, '_send_current_firmware_info', new=AsyncMock()): + patch.object(updater, '_apply_downloaded_firmware', new=AsyncMock()), \ + patch.object(updater, '_send_current_firmware_info', new=AsyncMock()): await updater._verify_downloaded_firmware() assert updater.current_firmware_info[FW_STATE_ATTR] == FirmwareStates.VERIFIED.value + @pytest.mark.asyncio async def test_verify_downloaded_firmware_fail(updater): updater._firmware_data = b'data' updater._target_checksum = "wrong" 
updater._target_checksum_alg = "md5" with patch.object(updater, 'verify_checksum', return_value=False), \ - patch.object(updater, '_send_current_firmware_info', new=AsyncMock()): + patch.object(updater, '_send_current_firmware_info', new=AsyncMock()): await updater._verify_downloaded_firmware() assert updater.current_firmware_info[FW_STATE_ATTR] == FirmwareStates.FAILED.value + @pytest.mark.asyncio async def test_apply_downloaded_firmware_saves_file(tmp_path, updater): updater._firmware_data = b'binary-firmware' @@ -125,23 +140,26 @@ async def test_apply_downloaded_firmware_saves_file(tmp_path, updater): updater._save_firmware = True updater._on_received_callback = AsyncMock() with patch.object(updater, '_send_current_firmware_info', new=AsyncMock()), \ - patch.object(updater._client._mqtt_manager, 'unsubscribe', new=AsyncMock()): + patch.object(updater._client._mqtt_manager, 'unsubscribe', new=AsyncMock()): await updater._apply_downloaded_firmware() assert (tmp_path / 'fw.bin').exists() + def test_verify_checksum_md5_valid(updater): result = updater.verify_checksum(b'data', 'md5', "8d777f385d3dfec8815d20f7496026dc") assert isinstance(result, bool) -def test_verify_checksum_invalid_algorithm(updater, caplog): + +def test_verify_checksum_invalid_algorithm(updater): result = updater.verify_checksum(b'data', 'invalid_alg', "deadbeef") assert result is False - assert 'Unsupported checksum algorithm' in caplog.text + def test_is_different_versions_true(updater): new_info = {FW_TITLE_ATTR: 'fw', FW_VERSION_ATTR: 'v2'} assert updater._is_different_firmware_versions(new_info) is True + def test_is_different_versions_false(updater): updater.current_firmware_info['current_' + FW_TITLE_ATTR] = 'fw' updater.current_firmware_info['current_' + FW_VERSION_ATTR] = 'v2' @@ -149,5 +167,104 @@ def test_is_different_versions_false(updater): assert updater._is_different_firmware_versions(new_info) is False -if __name__ == '__main__': - pytest.main([__file__]) +@pytest.mark.asyncio +async def test_save_firmware_failure_logs_error(updater, caplog): + updater._firmware_data = b'data' + updater._target_title = "fw.bin" + updater._target_version = "v1" + updater._save_firmware = True + with patch.object(updater, '_send_current_firmware_info', new=AsyncMock()), \ + patch.object(updater._client._mqtt_manager, 'unsubscribe', new=AsyncMock()), \ + patch.object(updater, '_save', side_effect=IOError("disk error")): + await updater._apply_downloaded_firmware() + assert "Failed to save firmware" in caplog.text + assert updater.current_firmware_info[FW_STATE_ATTR] == FirmwareStates.FAILED.value + + +@pytest.mark.asyncio +async def test_send_current_firmware_info_calls_send_telemetry(updater, mock_client): + updater.current_firmware_info = { + f"current_{FW_TITLE_ATTR}": "test", + f"current_{FW_VERSION_ATTR}": "1.0", + FW_STATE_ATTR: FirmwareStates.DOWNLOADING.value + } + await updater._send_current_firmware_info() + mock_client.send_telemetry.assert_awaited_once() + args = mock_client.send_telemetry.call_args[0][0] + assert all(isinstance(entry, TimeseriesEntry) for entry in args) + + +@pytest.mark.asyncio +async def test_firmware_info_callback_keys_mismatch(updater, caplog): + response = MagicMock() + response.shared_keys.return_value = ["unexpected_key"] + await updater._firmware_info_callback(response) + assert "does not match required keys" in caplog.text + assert updater.current_firmware_info[FW_STATE_ATTR] == FirmwareStates.FAILED.value + + +@pytest.mark.asyncio +async def 
test_firmware_info_callback_same_version(updater): + updater.current_firmware_info[f"current_{FW_TITLE_ATTR}"] = "fw" + updater.current_firmware_info[f"current_{FW_VERSION_ATTR}"] = "v1" + + response = MagicMock() + response.shared_keys.return_value = REQUIRED_SHARED_KEYS + response.as_dict.return_value = { + "shared": [{"key": FW_TITLE_ATTR, "value": "fw"}, + {"key": FW_VERSION_ATTR, "value": "v1"}, + {"key": FW_SIZE_ATTR, "value": 123}, + {"key": FW_CHECKSUM_ALG_ATTR, "value": "dummy"}, + {"key": FW_CHECKSUM_ATTR, "value": "dummy"}] + } + + await updater._firmware_info_callback(response) + + +@pytest.mark.asyncio +async def test_firmware_info_callback_triggers_download(updater): + response = MagicMock() + response.shared_keys.return_value = REQUIRED_SHARED_KEYS + response.as_dict.return_value = { + "shared": [{"key": FW_TITLE_ATTR, "value": "new_fw"}, + {"key": FW_VERSION_ATTR, "value": "v2"}, + {"key": FW_SIZE_ATTR, "value": 123}, + {"key": FW_CHECKSUM_ALG_ATTR, "value": "alg"}, + {"key": FW_CHECKSUM_ATTR, "value": "chk"}] + } + + with patch.object(updater, '_get_next_chunk', new=AsyncMock()) as mocked: + await updater._firmware_info_callback(response) + mocked.assert_awaited_once() + assert updater._target_title == "new_fw" + assert updater._target_version == "v2" + + +def test_verify_checksum_null_data(updater): + result = updater.verify_checksum(None, "md5", "abc") + assert not result + + +def test_verify_checksum_null_checksum(updater): + result = updater.verify_checksum(b"data", "md5", None) + assert not result + + +@pytest.mark.parametrize("alg", [ + ("sha256", sha256(b"data").digest().hex()), + ("sha384", sha384(b"data").digest().hex()), + ("sha512", sha512(b"data").digest().hex()), + ("md5", md5(b"data").digest().hex()), + ("crc32", "".join(reversed([f'{crc32(b"data") & 0xffffffff:0>2X}'[i:i + 2] + for i in range(0, len(f'{crc32(b"data") & 0xffffffff:0>2X}'), 2)])).lower()) +]) +def test_verify_checksum_known_algorithms(updater, alg): + name, checksum = alg + with patch("tb_mqtt_client.service.device.firmware_updater.randint", return_value=0): + assert updater.verify_checksum(b"data", name, checksum) is True + + +def test_verify_checksum_random_failure(updater): + with patch("tb_mqtt_client.service.device.firmware_updater.randint", return_value=5): + result = updater.verify_checksum(b"data", "md5", md5(b"data").digest().hex()) + assert not result diff --git a/tests/service/test_json_message_dispatcher.py b/tests/service/test_json_message_dispatcher.py index 9116f4d..d5217ff 100644 --- a/tests/service/test_json_message_dispatcher.py +++ b/tests/service/test_json_message_dispatcher.py @@ -12,39 +12,46 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import pytest from unittest.mock import MagicMock, patch, mock_open + +import pytest from orjson import dumps -from tb_mqtt_client.service.message_dispatcher import JsonMessageDispatcher from tb_mqtt_client.constants import mqtt_topics +from tb_mqtt_client.constants.mqtt_topics import DEVICE_TELEMETRY_TOPIC, DEVICE_ATTRIBUTES_TOPIC +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.attribute_request import AttributeRequest -from tb_mqtt_client.entities.data.rpc_request import RPCRequest -from tb_mqtt_client.entities.data.rpc_response import RPCResponse from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.claim_request import ClaimRequest +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest, BasicProvisioningCredentials, \ + X509ProvisioningCredentials, AccessTokenProvisioningCredentials +from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse -from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest, ProvisioningCredentialsType, \ - BasicProvisioningCredentials, X509ProvisioningCredentials, AccessTokenProvisioningCredentials - +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.service.message_dispatcher import JsonMessageDispatcher -class DummyClaimRequest: - def __init__(self, secret_key="key"): - self.secret_key = secret_key - def to_payload_format(self): - return {"secretKey": self.secret_key} +@pytest.fixture +def dummy_provisioning_request(): + credentials = AccessTokenProvisioningCredentials( + provision_device_key="key", + provision_device_secret="secret", + access_token="token" + ) + return ProvisioningRequest("some-device", credentials, device_name="dev", gateway=False) -class DummyProvisioningRequest: - def __init__(self): - self.device_name = "dev" - self.gateway = True - self.credentials = MagicMock() - self.credentials.provision_device_key = "key" - self.credentials.provision_device_secret = "secret" - self.credentials.credentials_type = ProvisioningCredentialsType.ACCESS_TOKEN - self.credentials.access_token = "token" +def build_msg(device="devX", with_attr=False, with_ts=False): + builder = DeviceUplinkMessageBuilder().set_device_name(device) + if with_attr: + builder.add_attributes(AttributeEntry("a", 1)) + if with_ts: + builder.add_timeseries(TimeseriesEntry("t", 2, ts=1234567890)) + return builder.build() @pytest.fixture def dispatcher(): @@ -68,16 +75,15 @@ def test_build_attribute_request_invalid(dispatcher): def test_build_claim_request(dispatcher): - req = DummyClaimRequest() + req = ClaimRequest.build("secretKey") topic, payload = dispatcher.build_claim_request(req) assert topic == mqtt_topics.DEVICE_CLAIM_TOPIC assert b"secretKey" in payload def test_build_claim_request_invalid(dispatcher): - req = DummyClaimRequest(secret_key=None) # Simulating an invalid request # noqa with pytest.raises(ValueError): - dispatcher.build_claim_request(req) + req = ClaimRequest.build(secret_key=None) # Simulating an invalid request # noqa def test_build_rpc_request(dispatcher): @@ -212,5 +218,117 @@ def 
test_parse_rpc_response_invalid(dispatcher): dispatcher.parse_rpc_response(topic, b"bad") -if __name__ == '__main__': - pytest.main([__file__]) +@pytest.mark.asyncio +async def test_build_uplink_payloads_empty(dispatcher: JsonMessageDispatcher): + assert dispatcher.build_uplink_payloads([]) == [] + + +@pytest.mark.asyncio +async def test_build_uplink_payloads_only_attributes(dispatcher: JsonMessageDispatcher): + msg = build_msg(with_attr=True) + with patch.object(dispatcher._splitter, "split_attributes", return_value=[msg]): + result = dispatcher.build_uplink_payloads([msg]) + assert len(result) == 1 + topic, payload, count, futures = result[0] + assert topic == DEVICE_ATTRIBUTES_TOPIC + assert count == 1 + assert b"a" in payload + + +@pytest.mark.asyncio +async def test_build_uplink_payloads_only_timeseries(dispatcher: JsonMessageDispatcher): + msg = build_msg(with_ts=True) + with patch.object(dispatcher._splitter, "split_timeseries", return_value=[msg]): + result = dispatcher.build_uplink_payloads([msg]) + assert len(result) == 1 + topic, payload, count, futures = result[0] + assert topic == DEVICE_TELEMETRY_TOPIC + assert count == 1 + assert b"ts" in payload + + +@pytest.mark.asyncio +async def test_build_uplink_payloads_both(dispatcher: JsonMessageDispatcher): + msg = build_msg(with_attr=True, with_ts=True) + with patch.object(dispatcher._splitter, "split_attributes", return_value=[msg]), \ + patch.object(dispatcher._splitter, "split_timeseries", return_value=[msg]): + result = dispatcher.build_uplink_payloads([msg]) + assert len(result) == 2 + topics = {r[0] for r in result} + assert DEVICE_ATTRIBUTES_TOPIC in topics + assert DEVICE_TELEMETRY_TOPIC in topics + + +@pytest.mark.asyncio +async def test_build_uplink_payloads_multiple_devices(dispatcher: JsonMessageDispatcher): + msg1 = build_msg(device="dev1", with_attr=True) + msg2 = build_msg(device="dev2", with_ts=True) + with patch.object(dispatcher._splitter, "split_attributes", side_effect=lambda x: x), \ + patch.object(dispatcher._splitter, "split_timeseries", side_effect=lambda x: x): + result = dispatcher.build_uplink_payloads([msg1, msg2]) + topics = {r[0] for r in result} + assert DEVICE_ATTRIBUTES_TOPIC in topics or DEVICE_TELEMETRY_TOPIC in topics + + +def test_build_payload_with_device_name(dispatcher: JsonMessageDispatcher): + msg = build_msg(with_ts=True) + payload = dispatcher.build_payload(msg, True) + assert isinstance(payload, bytes) + assert msg.device_name.encode() in payload + + +def test_build_payload_without_device_name(dispatcher: JsonMessageDispatcher): + builder = DeviceUplinkMessageBuilder().add_attributes(AttributeEntry("x", 9)) + msg = builder.build() + payload = dispatcher.build_payload(msg, False) + assert isinstance(payload, bytes) + assert b"x" in payload + + +def test_pack_attributes(): + builder = DeviceUplinkMessageBuilder().add_attributes(AttributeEntry("x", 10)) + msg = builder.build() + result = JsonMessageDispatcher.pack_attributes(msg) + assert isinstance(result, dict) + assert "x" in result + + +def test_pack_timeseries_uses_now(monkeypatch): + monkeypatch.setattr("tb_mqtt_client.service.message_dispatcher.datetime", MagicMock()) + ts_entry = TimeseriesEntry("temp", 23, ts=None) + builder = DeviceUplinkMessageBuilder().add_timeseries(ts_entry) + msg = builder.build() + packed = JsonMessageDispatcher.pack_timeseries(msg) + assert isinstance(packed, list) + assert "ts" in packed[0] + assert "values" in packed[0] + + +def test_build_uplink_payloads_error_handling(dispatcher: JsonMessageDispatcher): 
+ with patch("tb_mqtt_client.service.message_dispatcher.DeviceUplinkMessage.has_attributes", side_effect=Exception("boom")): + msg = build_msg(with_attr=True) + with pytest.raises(Exception, match="boom"): + dispatcher.build_uplink_payloads([msg]) + + +def test_parse_provisioning_response_success(dispatcher, dummy_provisioning_request): + payload_dict = {"status": "SUCCESS", "credentialsType": "ACCESS_TOKEN"} + payload_bytes = dumps(payload_dict) + + with patch.object(ProvisioningResponse, "build", return_value="SUCCESS_RESPONSE") as mock_build: + result = dispatcher.parse_provisioning_response(dummy_provisioning_request, payload_bytes) + assert result == "SUCCESS_RESPONSE" + mock_build.assert_called_once_with(dummy_provisioning_request, payload_dict) + + +def test_parse_provisioning_response_failure(dispatcher, dummy_provisioning_request): + broken_bytes = b"{not_json" + + with patch.object(ProvisioningResponse, "build", return_value="FAILURE_RESPONSE") as mock_build: + result = dispatcher.parse_provisioning_response(dummy_provisioning_request, broken_bytes) + assert result == "FAILURE_RESPONSE" + mock_build.assert_called_once() + args = mock_build.call_args[0] + assert args[0] == dummy_provisioning_request + assert args[1]["status"] == "FAILURE" + assert "errorMsg" in args[1] \ No newline at end of file diff --git a/tests/service/test_message_queue.py b/tests/service/test_message_queue.py index 63fffd0..5637f39 100644 --- a/tests/service/test_message_queue.py +++ b/tests/service/test_message_queue.py @@ -13,43 +13,22 @@ # limitations under the License. import asyncio +from contextlib import AsyncExitStack +from time import time +from unittest.mock import AsyncMock, MagicMock, patch + import pytest -from unittest.mock import AsyncMock, MagicMock -from tb_mqtt_client.constants import mqtt_topics from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit +from tb_mqtt_client.constants import mqtt_topics +from tb_mqtt_client.constants.service_keys import TELEMETRY_DATAPOINTS_RATE_LIMIT from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.service.message_dispatcher import JsonMessageDispatcher from tb_mqtt_client.service.message_queue import MessageQueue -@pytest.mark.asyncio -async def test_publish_raw_bytes_success(): - mqtt_manager = MagicMock() - mqtt_manager.publish = AsyncMock() - mqtt_manager.backpressure.should_pause.return_value = False - main_stop_event = asyncio.Event() - - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 - dispatcher.build_uplink_payloads.return_value = [] - - queue = MessageQueue( - mqtt_manager=mqtt_manager, - main_stop_event=main_stop_event, - message_rate_limit=None, - telemetry_rate_limit=None, - telemetry_dp_rate_limit=None, - message_dispatcher=dispatcher, - ) - - await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, b'test_payload', 1, qos=1) - await asyncio.sleep(0.05) - await queue.shutdown() - - mqtt_manager.publish.assert_called_once() - - @pytest.mark.asyncio async def test_batching_device_uplink_message(): mqtt_manager = MagicMock() @@ -114,7 +93,7 @@ async def test_telemetry_rate_limit_retry_triggered(): msg = DeviceUplinkMessageBuilder() msg.set_device_name("device") - msg.add_telemetry(TimeseriesEntry("temp", 1)) + msg.add_timeseries(TimeseriesEntry("temp", 1)) msg = msg.build() queue = MessageQueue( @@ -136,31 +115,19 @@ 
async def test_telemetry_rate_limit_retry_triggered(): @pytest.mark.asyncio async def test_shutdown_clears_queue(): mqtt_manager = MagicMock() - mqtt_manager.publish = AsyncMock() mqtt_manager.backpressure.should_pause.return_value = False - main_stop_event = asyncio.Event() - + mqtt_manager.publish = AsyncMock() dispatcher = MagicMock() dispatcher.splitter.max_payload_size = 100000 dispatcher.build_uplink_payloads.return_value = [] + stop_event = asyncio.Event() - dummy_message = MagicMock() - dummy_message.device_name = "device" - dummy_message.size = 1 - dummy_message.get_delivery_futures.return_value = [] - - queue = MessageQueue( - mqtt_manager=mqtt_manager, - main_stop_event=main_stop_event, - message_rate_limit=None, - telemetry_rate_limit=None, - telemetry_dp_rate_limit=None, - message_dispatcher=dispatcher - ) - - await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, dummy_message, 1, qos=1) + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + dummy = MagicMock() + dummy.size = 1 + dummy.get_delivery_futures.return_value = [] + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, dummy, 1, 1) await queue.shutdown() - assert queue.is_empty() @@ -234,47 +201,6 @@ async def test_rate_limit_telemetry_triggers_retry(): mqtt_manager.publish.assert_not_called() -@pytest.mark.asyncio -async def test_shutdown_clears_queue(): - mqtt_manager = MagicMock() - mqtt_manager.backpressure.should_pause.return_value = False - mqtt_manager.publish = AsyncMock() - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 - dispatcher.build_uplink_payloads.return_value = [] - stop_event = asyncio.Event() - - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) - dummy = MagicMock() - dummy.size = 1 - dummy.get_delivery_futures.return_value = [] - await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, dummy, 1, 1) - await queue.shutdown() - assert queue.is_empty() - - -@pytest.mark.asyncio -async def test_backpressure_triggers_retry(): - mqtt_manager = MagicMock() - mqtt_manager.backpressure.should_pause.return_value = True - mqtt_manager.publish = AsyncMock() - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 - dispatcher.build_uplink_payloads.return_value = [] - stop_event = asyncio.Event() - - msg = MagicMock() - msg.size = 1 - msg.device_name = "dev" - msg.get_delivery_futures.return_value = [] - - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) - await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, msg, 1, 1) - await asyncio.sleep(0.1) - await queue.shutdown() - mqtt_manager.publish.assert_not_called() - - @pytest.mark.asyncio async def test_retry_on_exception(): mqtt_manager = MagicMock() @@ -378,5 +304,822 @@ async def test_rate_limit_refill_executes(): r3.refill.assert_awaited() -if __name__ == '__main__': - pytest.main([__file__]) +@pytest.mark.asyncio +async def test_try_publish_without_delivery_futures(): + mqtt_manager = MagicMock() + mqtt_manager.publish = AsyncMock(return_value=asyncio.Future()) + mqtt_manager.publish.return_value.set_result(PublishResult("t", 1, 1, 1, 1)) + mqtt_manager.backpressure.should_pause.return_value = False + + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + dispatcher.build_uplink_payloads.return_value = [] + + stop_event = asyncio.Event() + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + + await queue._try_publish("custom/topic", b"payload", datapoints=1, 
delivery_futures_or_none=None, qos=1) + await queue.shutdown() + + mqtt_manager.publish.assert_called_once() + + +@pytest.mark.asyncio +async def test_schedule_delayed_retry_skipped_if_inactive_or_stopped(): + mqtt_manager = MagicMock() + mqtt_manager.backpressure.should_pause.return_value = False + mqtt_manager.publish = AsyncMock() + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + dispatcher.build_uplink_payloads.return_value = [] + + stop_event = asyncio.Event() + stop_event.set() + + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + queue._active.clear() + + queue._schedule_delayed_retry("topic", b"data", datapoints=1, qos=1, delay=0.01) + await queue.shutdown() + + +@pytest.mark.asyncio +async def test_clear_queue_sets_futures_to_publish_result(): + mqtt_manager = MagicMock() + mqtt_manager.backpressure.should_pause.return_value = False + mqtt_manager.publish = AsyncMock() + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + dispatcher.build_uplink_payloads.return_value = [] + + stop_event = asyncio.Event() + + with patch.object(MessageQueue, "_dequeue_loop", new=AsyncMock()): + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + + dummy_msg = DeviceUplinkMessageBuilder() \ + .add_delivery_futures(asyncio.Future()) \ + .add_timeseries(TimeseriesEntry("temp", 1)) \ + .build() + + future = (await queue.publish("some/topic", dummy_msg, 1, 1))[0] + + queue.clear() + + assert future.done() + result = future.result() + assert isinstance(result, PublishResult) + assert result.topic == "some/topic" + assert result.payload_size == dummy_msg.size + assert result.reason_code == -1 + assert result.qos == 1 + assert result.message_id == -1 + + await queue.shutdown() + + +@pytest.mark.asyncio +async def test_wait_for_message_exit_on_inactive(): + mqtt_manager = MagicMock() + mqtt_manager.backpressure.should_pause.return_value = False + mqtt_manager.publish = AsyncMock() + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + dispatcher.build_uplink_payloads.return_value = [] + + stop_event = asyncio.Event() + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + queue._active.clear() + + with pytest.raises(asyncio.CancelledError): + await queue._wait_for_message() + await queue.shutdown() + + +@pytest.mark.asyncio +async def test_schedule_delayed_retry_requeues_message(): + mqtt_manager = MagicMock() + mqtt_manager.publish = AsyncMock() + mqtt_manager.backpressure.should_pause.return_value = False + + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + dispatcher.build_uplink_payloads.return_value = [] + + stop_event = asyncio.Event() + + with patch("tb_mqtt_client.service.message_queue.MessageQueue._dequeue_loop", new=AsyncMock()): + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + + future = asyncio.Future() + dummy_msg = MagicMock() + dummy_msg.device_name = "dev" + dummy_msg.size = 10 + dummy_msg.get_delivery_futures.return_value = [future] + + queue._schedule_delayed_retry( + topic="retry/topic", + payload=b"retry-payload", + datapoints=1, + qos=1, + delay=0.05, + delivery_futures=[future] + ) + + await asyncio.sleep(0.1) + assert not queue.is_empty() + + await queue.shutdown() + + +@pytest.mark.asyncio +async def test_cancel_tasks_clears_all(): + mqtt_manager = MagicMock() + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + 
dispatcher.build_uplink_payloads.return_value = [] + + stop_event = asyncio.Event() + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + + async def dummy(): + await asyncio.sleep(1) + + task = asyncio.create_task(dummy()) + queue._retry_tasks.add(task) + + await queue._cancel_tasks(queue._retry_tasks) + assert len(queue._retry_tasks) == 0 + + +@pytest.mark.asyncio +async def test_clear_queue_with_bytes_message(): + mqtt_manager = MagicMock() + mqtt_manager.backpressure.should_pause.return_value = False + mqtt_manager.publish = AsyncMock() + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + + stop_event = asyncio.Event() + with patch.object(MessageQueue, "_dequeue_loop", new=AsyncMock()): + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + + future = asyncio.Future() + await queue.publish("raw/topic", b"abc", 1, 0) + queue._queue._queue[0] = ("raw/topic", b"abc", [future], 1, 0) + + queue.clear() + assert future.done() + result = future.result() + assert result.topic == "raw/topic" + assert result.payload_size == 3 + assert result.qos == 0 + assert result.reason_code == -1 + assert result.message_id == -1 + + await queue.shutdown() + + +@pytest.mark.asyncio +async def test_resolve_attached_handles_publish_exception(): + future = asyncio.Future() + future.set_exception(RuntimeError("fail")) + + f1 = asyncio.Future() + f2 = asyncio.Future() + + dummy_payload = b"abc" + topic = "topic" + qos = 1 + + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + dispatcher.build_uplink_payloads.return_value = [] + + mqtt_manager = MagicMock() + mqtt_manager.backpressure.should_pause.return_value = False + mqtt_manager.publish = AsyncMock(return_value=future) + + stop_event = asyncio.Event() + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + + await queue._try_publish( + topic=topic, + payload=dummy_payload, + datapoints=1, + delivery_futures_or_none=[f1, f2], + qos=qos + ) + + await asyncio.sleep(0.05) + assert f1.done() and f2.done() + for f in (f1, f2): + res = f.result() + assert isinstance(res, PublishResult) + assert res.reason_code == -1 + + await queue.shutdown() + + +@pytest.mark.asyncio +async def test_try_publish_message_type_non_telemetry(): + mqtt_manager = MagicMock() + mqtt_manager.backpressure.should_pause.return_value = False + mqtt_manager.publish = AsyncMock() + + rate_limit = MagicMock() + rate_limit.try_consume = AsyncMock(return_value=None) + rate_limit.to_dict.return_value = {} + rate_limit.minimal_timeout = 0.1 + + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + + stop_event = asyncio.Event() + queue = MessageQueue(mqtt_manager, stop_event, message_rate_limit=rate_limit, + telemetry_rate_limit=None, telemetry_dp_rate_limit=None, + message_dispatcher=dispatcher) + + await queue._try_publish( + topic="non/telemetry", + payload=b"x", + datapoints=1, + delivery_futures_or_none=[], + qos=0 + ) + mqtt_manager.publish.assert_called_once() + await queue.shutdown() + + +@pytest.mark.asyncio +async def test_shutdown_rate_limit_task_cancel_only(): + mqtt_manager = MagicMock() + mqtt_manager.backpressure.should_pause.return_value = False + mqtt_manager.publish = AsyncMock() + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + + stop_event = asyncio.Event() + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + + # Cancel only the rate limit task before shutdown + 
queue._rate_limit_refill_task.cancel() + + await queue.shutdown() + assert queue._rate_limit_refill_task.cancelled() + + +@pytest.mark.asyncio +async def test_schedule_delayed_retry_when_main_stop_active(): + mqtt_manager = MagicMock() + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + + stop_event = asyncio.Event() + stop_event.set() + + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + + queue._active.clear() + + queue._schedule_delayed_retry("x", b"y", 1, 0, 0.01) + await asyncio.sleep(0.05) + assert queue._queue.empty() + await queue.shutdown() + + +@pytest.mark.asyncio +async def test_publish_queue_full_sets_failed_result_for_bytes(): + mqtt_manager = MagicMock() + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + stop_event = asyncio.Event() + + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher, max_queue_size=1) + await queue.publish("t", b"raw", 1, qos=0) + + queue._queue.put_nowait = MagicMock(side_effect=asyncio.QueueFull) + + result = await queue.publish("t", b"raw", 1, qos=0) + assert result is not None + assert isinstance(result[0], asyncio.Future) + await asyncio.sleep(0) + assert result[0].done() + assert result[0].result().reason_code == -1 + + await queue.shutdown() + + +@pytest.mark.asyncio +async def test_wait_for_message_raises_cancelled(): + mqtt_manager = MagicMock() + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + stop_event = asyncio.Event() + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + + queue._active.clear() + + with pytest.raises(asyncio.CancelledError): + await queue._wait_for_message() + + await queue.shutdown() + + +@pytest.mark.asyncio +async def test_batch_loop_breaks_on_count_threshold(): + # Setup: fake PublishResult to return + publish_result = PublishResult( + topic=mqtt_topics.DEVICE_TELEMETRY_TOPIC, + qos=1, + message_id=42, + payload_size=8, + reason_code=0 + ) + + # This is what the MQTT manager's publish will return + publish_future = asyncio.Future() + publish_future.set_result(publish_result) + + # Now mock the MQTT manager + mqtt_manager = MagicMock() + mqtt_manager.publish = AsyncMock(return_value=publish_future) + mqtt_manager.backpressure.should_pause.return_value = False + + # This is the future that the message queue should resolve + delivery_future = asyncio.Future() + + # Mock dispatcher to output the delivery future + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + dispatcher.build_uplink_payloads.return_value = [ + (mqtt_topics.DEVICE_TELEMETRY_TOPIC, b"batched", 1, [delivery_future]) + ] + + # Create and start the queue + stop_event = asyncio.Event() + queue = MessageQueue( + mqtt_manager, stop_event, + message_rate_limit=None, + telemetry_rate_limit=None, + telemetry_dp_rate_limit=None, + message_dispatcher=dispatcher, + batch_collect_max_count=2 + ) + + dummy_msg = MagicMock() + dummy_msg.size = 10 + dummy_msg.device_name = "dev" + dummy_msg.get_delivery_futures.return_value = [delivery_future] + + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, dummy_msg, 1, 1) + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, dummy_msg, 1, 1) + + # Allow time for batching and publishing + await asyncio.sleep(0.1) + await queue.shutdown() + + assert mqtt_manager.publish.called + assert delivery_future.done() + result = delivery_future.result() + assert isinstance(result, PublishResult) + assert result.topic == 
mqtt_topics.DEVICE_TELEMETRY_TOPIC + + +@pytest.mark.asyncio +async def test_batch_loop_skips_message_on_size_exceed(): + mqtt_manager = MagicMock() + mqtt_manager.publish = AsyncMock() + mqtt_manager.backpressure.should_pause.return_value = False + + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 15 + dispatcher.build_uplink_payloads.return_value = [] + + stop_event = asyncio.Event() + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + + small_msg = MagicMock() + small_msg.size = 10 + small_msg.device_name = "dev" + small_msg.get_delivery_futures.return_value = [] + + large_msg = MagicMock() + large_msg.size = 10 + large_msg.device_name = "dev" + large_msg.get_delivery_futures.return_value = [] + + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, small_msg, 1, 1) + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, large_msg, 1, 1) + + await asyncio.sleep(0.1) + await queue.shutdown() + + assert mqtt_manager.publish.called + + +@pytest.mark.asyncio +async def test_batch_requeues_on_size_exceed(): + mqtt_manager = MagicMock() + mqtt_manager.publish = AsyncMock() + mqtt_manager.backpressure.should_pause.return_value = False + + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 15 + dispatcher.build_uplink_payloads.return_value = [] + + stop_event = asyncio.Event() + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + + msg1 = MagicMock() + msg1.size = 10 + msg1.device_name = "dev" + msg1.get_delivery_futures.return_value = [] + + msg2 = MagicMock() + msg2.size = 10 + msg2.device_name = "dev" + msg2.get_delivery_futures.return_value = [] + + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, msg1, 1, 1) + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, msg2, 1, 1) + + await asyncio.sleep(0.1) + await queue.shutdown() + + assert mqtt_manager.publish.call_count >= 1 + + +@pytest.mark.asyncio +async def test_batch_immediate_publish_on_raw_bytes(): + mqtt_manager = MagicMock() + mqtt_manager.publish = AsyncMock() + mqtt_manager.backpressure.should_pause.return_value = False + + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + dispatcher.build_uplink_payloads.return_value = [] + + stop_event = asyncio.Event() + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, b"raw_payload", 1, 1) + + await asyncio.sleep(0.05) + await queue.shutdown() + + mqtt_manager.publish.assert_called() + args, kwargs = mqtt_manager.publish.call_args + assert isinstance(kwargs['payload'], bytes) + + +@pytest.mark.asyncio +async def test_batch_queue_empty_breaks_safely(): + mqtt_manager = MagicMock() + mqtt_manager.publish = AsyncMock() + mqtt_manager.backpressure.should_pause.return_value = False + + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + dispatcher.build_uplink_payloads.return_value = [] + + stop_event = asyncio.Event() + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + + await asyncio.sleep(0.05) + await queue.shutdown() + + assert mqtt_manager.publish.call_count == 0 + + +@pytest.mark.asyncio +async def test_try_publish_telemetry_rate_limited(): + mqtt_manager = MagicMock() + mqtt_manager.publish = AsyncMock() + mqtt_manager.backpressure.should_pause.return_value = False + + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 1000 + dispatcher.build_uplink_payloads.return_value = [("topic", b"{}", 3, [])] + 
telemetry_rate_limit = MagicMock() + telemetry_rate_limit.try_consume = AsyncMock(return_value=(10, 1)) + telemetry_rate_limit.minimal_timeout = 0.5 + + stop_event = asyncio.Event() + queue = MessageQueue( + mqtt_manager, + stop_event, + telemetry_rate_limit, + None, + RateLimit("10:1", TELEMETRY_DATAPOINTS_RATE_LIMIT, 100), + dispatcher + ) + + queue._schedule_delayed_retry = MagicMock() + + msg = (DeviceUplinkMessageBuilder().add_timeseries( + [TimeseriesEntry("temp", 1), + TimeseriesEntry("hum", 2), + TimeseriesEntry("pres", 3)]) + .build()) + + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, msg, msg.timeseries_datapoint_count(), 1) + await asyncio.sleep(0.1) + await queue.shutdown() + + queue._schedule_delayed_retry.assert_called_once() + mqtt_manager.publish.assert_not_called() + + +@pytest.mark.asyncio +async def test_try_publish_non_telemetry_rate_limited(): + mqtt_manager = MagicMock() + mqtt_manager.publish = AsyncMock() + mqtt_manager.backpressure.should_pause.return_value = False + + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 1000 + dispatcher.build_uplink_payloads.return_value = [("topic", b"{}", 1)] + + message_rate_limit = MagicMock() + message_rate_limit.try_consume = AsyncMock(return_value=(5, 60)) + message_rate_limit.minimal_timeout = 1.0 + + stop_event = asyncio.Event() + queue = MessageQueue( + mqtt_manager, + stop_event, + telemetry_rate_limit=None, + message_rate_limit=message_rate_limit, + telemetry_dp_rate_limit=None, + message_dispatcher=dispatcher + ) + + queue._schedule_delayed_retry = MagicMock() + + payload = b'raw-bytes' + topic = "v1/devices/me/rpc/request/1" + + await queue._try_publish(topic, payload, datapoints=0, qos=1, delivery_futures_or_none=None) + await asyncio.sleep(0.05) + await queue.shutdown() + + message_rate_limit.try_consume.assert_awaited_once_with(1) + queue._schedule_delayed_retry.assert_called_once_with( + topic=topic, + payload=payload, + datapoints=0, + qos=1, + delay=1.0, + delivery_futures=[] + ) + mqtt_manager.publish.assert_not_called() + + +@pytest.mark.parametrize("paused", [True, False]) +@pytest.mark.asyncio +async def test_backpressure_delays_publish(paused, monkeypatch): + mqtt_manager = MagicMock() + mqtt_manager.publish = AsyncMock() + mqtt_manager.backpressure.should_pause.return_value = paused + + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + dispatcher.build_uplink_payloads.return_value = [] + + stop_event = asyncio.Event() + queue = MessageQueue( + mqtt_manager, + stop_event, + message_rate_limit=None, + telemetry_rate_limit=None, + telemetry_dp_rate_limit=None, + message_dispatcher=dispatcher, + batch_collect_max_count=1 + ) + + scheduled_retry_mock = MagicMock() + monkeypatch.setattr(queue, "_schedule_delayed_retry", scheduled_retry_mock) + + await queue._try_publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, b"test_payload", 1, qos=1) + await asyncio.sleep(0.05) + + if paused: + scheduled_retry_mock.assert_called_once_with( + topic=mqtt_topics.DEVICE_TELEMETRY_TOPIC, + payload=b"test_payload", + datapoints=1, + qos=1, + delay=1.0, + delivery_futures=[] + ) + mqtt_manager.publish.assert_not_called() + else: + scheduled_retry_mock.assert_not_called() + mqtt_manager.publish.assert_called_once_with( + message_or_topic=mqtt_topics.DEVICE_TELEMETRY_TOPIC, + payload=b"test_payload", + qos=1 + ) + + await queue.shutdown() + + +@pytest.mark.asyncio +async def test_publish_telemetry_rate_limit_triggered(): + mqtt_manager = MagicMock() + mqtt_manager.publish = AsyncMock() + 
mqtt_manager.backpressure.should_pause.return_value = False + + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + dispatcher.build_uplink_payloads.return_value = [] + + stop_event = asyncio.Event() + + telemetry_dp_rate_limit = MagicMock() + telemetry_dp_rate_limit.try_consume = AsyncMock(return_value=(10, 60)) + telemetry_dp_rate_limit.minimal_timeout = 1.23 + + async with AsyncExitStack() as stack: + queue = MessageQueue( + mqtt_manager, + stop_event, + None, + None, + telemetry_dp_rate_limit=telemetry_dp_rate_limit, + message_dispatcher=dispatcher, + ) + stack.push_async_callback(queue.shutdown) + + msg = DeviceUplinkMessageBuilder() \ + .add_timeseries([TimeseriesEntry(f"temp{i}", i) for i in range(10)]) \ + .build() + + with patch.object(queue, "_schedule_delayed_retry", wraps=queue._schedule_delayed_retry) as delayed_retry_mock: + await queue._try_publish( + mqtt_topics.DEVICE_TELEMETRY_TOPIC, + b'', + qos=1, + datapoints=msg.timeseries_datapoint_count() + ) + + mqtt_manager.publish.assert_not_called() + delayed_retry_mock.assert_called_once() + args, kwargs = delayed_retry_mock.call_args + assert kwargs["topic"] == mqtt_topics.DEVICE_TELEMETRY_TOPIC + assert kwargs["delay"] == telemetry_dp_rate_limit.minimal_timeout + + +@pytest.mark.asyncio +async def test_batch_loop_large_messages_are_split_and_published(): + mqtt_manager = MagicMock() + mqtt_manager.publish = AsyncMock() + mqtt_manager.backpressure.should_pause.return_value = False + + dispatcher = JsonMessageDispatcher(100, 20) + + stop_event = asyncio.Event() + queue = MessageQueue( + mqtt_manager, + stop_event, + None, + None, + None, + message_dispatcher=dispatcher, + max_queue_size=100, + batch_collect_max_time_ms=10 + ) + + builder1 = DeviceUplinkMessageBuilder() + for i in range(30): + builder1.add_timeseries(TimeseriesEntry(f"t{i}", i)) + message1 = builder1.build() + + builder2 = DeviceUplinkMessageBuilder() + for i in range(30, 60): + builder2.add_timeseries(TimeseriesEntry(f"t{i}", i)) + message2 = builder2.build() + + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, message1, message1.timeseries_datapoint_count(), qos=1) + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, message2, message2.timeseries_datapoint_count(), qos=1) + + await asyncio.sleep(0.2) + await queue.shutdown() + + assert mqtt_manager.publish.call_count == 6 + + for call in mqtt_manager.publish.call_args_list: + kwargs = call.kwargs + assert kwargs["message_or_topic"] == "v1/devices/me/telemetry" + assert isinstance(kwargs["payload"], bytes) + assert kwargs["qos"] == 1 + + +@pytest.mark.asyncio +async def test_delivery_futures_resolved_via_real_puback_handler(): + delivery_future = asyncio.Future() + + mqtt_future = asyncio.Future() + mqtt_future.mid = 123 + + mqtt_manager = MagicMock() + mqtt_manager.publish = AsyncMock(return_value=mqtt_future) + mqtt_manager.backpressure.should_pause.return_value = False + mqtt_manager._backpressure = MagicMock() + mqtt_manager._on_publish_result_callback = None + + from tb_mqtt_client.service.mqtt_manager import MQTTManager + mqtt_manager._handle_puback_reason_code = MQTTManager._handle_puback_reason_code.__get__(mqtt_manager) + + topic = mqtt_topics.DEVICE_TELEMETRY_TOPIC + qos = 1 + payload_size = 24 + publish_time = 1.0 + mqtt_manager._pending_publishes = { + mqtt_future.mid: (delivery_future, topic, qos, payload_size, publish_time) + } + + dispatcher = MagicMock() + dispatcher.splitter.max_payload_size = 100000 + dispatcher.build_uplink_payloads.return_value = [ + 
(topic, b'{"some":"payload"}', qos, [delivery_future]) + ] + + stop_event = asyncio.Event() + queue = MessageQueue( + mqtt_manager, + stop_event, + None, + None, + None, + message_dispatcher=dispatcher, + batch_collect_max_count=1, + batch_collect_max_time_ms=1 + ) + + msg = ( + DeviceUplinkMessageBuilder() + .set_device_name("deviceA") + .add_timeseries([TimeseriesEntry("temperature", 25)]) + .add_delivery_futures([delivery_future]) + .build() + ) + + await queue.publish(topic, msg, msg.timeseries_datapoint_count(), qos=qos) + await asyncio.sleep(0.05) + + mqtt_future.set_result(None) + mqtt_manager._handle_puback_reason_code(mqtt_future.mid, 0, {}) + + await asyncio.sleep(0.05) + + assert delivery_future.done() + result = delivery_future.result() + assert isinstance(result, PublishResult) + assert result.topic == topic + assert result.message_id == mqtt_future.mid + assert result.reason_code == 0 + + +@pytest.mark.asyncio +async def test_batch_append_and_batch_size_accumulate(): + mqtt_manager = MagicMock() + mqtt_manager.publish = AsyncMock() + mqtt_manager.backpressure.should_pause.return_value = False + + dispatcher = JsonMessageDispatcher(100000, 10000) + stop_event = asyncio.Event() + + queue = MessageQueue( + mqtt_manager, + stop_event, + None, + None, + None, + message_dispatcher=dispatcher, + batch_collect_max_count=2, + batch_collect_max_time_ms=1000 + ) + + fixed_ts = int(time() * 1000) + + msg1 = DeviceUplinkMessageBuilder() \ + .add_timeseries([TimeseriesEntry(f"temp{i}", i, ts=fixed_ts) for i in range(10)]) \ + .build() + msg2 = DeviceUplinkMessageBuilder() \ + .add_timeseries([TimeseriesEntry(f"temp{i}", i, ts=fixed_ts) for i in range(10)]) \ + .build() + + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, msg1, msg1.timeseries_datapoint_count(), 1) + await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, msg2, msg2.timeseries_datapoint_count(), 1) + + await asyncio.sleep(0.3) + await queue.shutdown() + + mqtt_manager.publish.assert_called_once() + + args, kwargs = mqtt_manager.publish.call_args + assert kwargs["message_or_topic"] == mqtt_topics.DEVICE_TELEMETRY_TOPIC + assert isinstance(kwargs["payload"], bytes) + assert kwargs["qos"] == 1 diff --git a/tests/service/test_message_splitter.py b/tests/service/test_message_splitter.py index bffc00e..6bc0e74 100644 --- a/tests/service/test_message_splitter.py +++ b/tests/service/test_message_splitter.py @@ -13,9 +13,14 @@ # limitations under the License. 
import asyncio -import pytest from unittest.mock import MagicMock, patch +import pytest + +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder +from tb_mqtt_client.service.message_dispatcher import JsonMessageDispatcher from tb_mqtt_client.service.message_splitter import MessageSplitter @@ -115,7 +120,7 @@ def test_builder_failure_during_split_raises(mock_builder_class): builder_instance = MagicMock() builder_instance.set_device_name.return_value = builder_instance builder_instance.set_device_profile.return_value = builder_instance - builder_instance.add_telemetry.return_value = None + builder_instance.add_timeseries.return_value = None builder_instance._timeseries = [entry] builder_instance.build.side_effect = RuntimeError("build failed") mock_builder_class.return_value = builder_instance @@ -187,5 +192,49 @@ def test_datapoint_setter_validation(): assert s.max_datapoints == 0 -if __name__ == '__main__': - pytest.main([__file__]) +@pytest.mark.asyncio +async def test_split_attributes_grouping(): + dispatcher = JsonMessageDispatcher(max_payload_size=200, max_datapoints=5) + + builder1 = DeviceUplinkMessageBuilder().set_device_name("deviceA").set_device_profile("default") + builder2 = DeviceUplinkMessageBuilder().set_device_name("deviceA").set_device_profile("default") + + for i in range(3): + builder1.add_attributes(AttributeEntry(f"key_{i}", i)) + + for i in range(3, 6): + builder2.add_attributes(AttributeEntry(f"key_{i}", i)) + + + messages = [builder1.build(), builder2.build()] + result = dispatcher.splitter.split_attributes(messages) + + assert len(result) == 2 + total_attrs = sum(len(msg.attributes) for msg in result) + assert total_attrs == 6 + assert all(msg.device_name == "deviceA" for msg in result) + for msg in result: + for fut in msg.get_delivery_futures(): + fut.set_result(PublishResult("test/topic", 1, 1, 100, 0)) + + +@pytest.mark.asyncio +async def test_split_attributes_different_devices_not_grouped(): + dispatcher = JsonMessageDispatcher(max_payload_size=200, max_datapoints=100) + + builder1 = DeviceUplinkMessageBuilder().set_device_name("deviceA") + builder2 = DeviceUplinkMessageBuilder().set_device_name("deviceB") + + for i in range(3): + builder1.add_attributes(AttributeEntry(f"key_{i}", i)) + builder2.add_attributes(AttributeEntry(f"key_{i+3}", i+3)) + + result = dispatcher.splitter.split_attributes([builder1.build(), builder2.build()]) + + assert len(result) == 2 + assert result[0].device_name != result[1].device_name + for msg in result: + for fut in msg.get_delivery_futures(): + fut.set_result(PublishResult("test/topic", 1, 1, 100, 0)) + + diff --git a/tests/service/test_mqtt_manager.py b/tests/service/test_mqtt_manager.py index a4066bb..d46685f 100644 --- a/tests/service/test_mqtt_manager.py +++ b/tests/service/test_mqtt_manager.py @@ -13,16 +13,18 @@ # limitations under the License. 
import asyncio +from time import monotonic +from unittest.mock import AsyncMock, MagicMock, patch, PropertyMock, call + import pytest import pytest_asyncio -from unittest.mock import AsyncMock, MagicMock, patch -from time import monotonic -from tb_mqtt_client.constants.service_keys import MESSAGES_RATE_LIMIT -from tb_mqtt_client.service.mqtt_manager import MQTTManager from tb_mqtt_client.common.publish_result import PublishResult -from tb_mqtt_client.service.message_dispatcher import MessageDispatcher +from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit +from tb_mqtt_client.constants.service_keys import MESSAGES_RATE_LIMIT from tb_mqtt_client.service.device.handlers.rpc_response_handler import RPCResponseHandler +from tb_mqtt_client.service.message_dispatcher import MessageDispatcher +from tb_mqtt_client.service.mqtt_manager import MQTTManager, IMPLEMENTATION_SPECIFIC_ERROR, QUOTA_EXCEEDED @pytest_asyncio.fixture @@ -64,7 +66,9 @@ async def test_is_connected_returns_false_if_not_ready(setup_manager): @pytest.mark.asyncio async def test_register_and_unregister_handler(setup_manager): manager, *_ = setup_manager + async def dummy(topic, payload): pass + manager.register_handler("topic/+", dummy) assert "topic/+" in manager._handlers manager.unregister_handler("topic/+") @@ -112,6 +116,7 @@ async def bad_handler(topic, payload): def test_match_topic_full_wildcard(): assert MQTTManager._match_topic("#", "any/depth/of/topic") + @pytest.mark.asyncio async def test_publish_fails_without_rate_limits(setup_manager): manager, *_ = setup_manager @@ -225,7 +230,6 @@ async def test_subscribe_adds_future(setup_manager): mock_rate_limit.consume.assert_awaited_once() - @pytest.mark.asyncio async def test_unsubscribe_adds_future(setup_manager): manager, *_ = setup_manager @@ -242,5 +246,147 @@ async def test_unsubscribe_adds_future(setup_manager): mock_rate_limit.consume.assert_awaited_once() -if __name__ == '__main__': - pytest.main([__file__]) \ No newline at end of file +@pytest.mark.asyncio +async def test_register_claiming_future_triggers_event(setup_manager): + manager, *_ = setup_manager + future = asyncio.Future() + manager.register_claiming_future(future) + assert manager._claiming_future is future + + +@pytest.mark.asyncio +async def test_publish_qos_zero_sets_result_immediately(setup_manager): + manager, *_ = setup_manager + manager._MQTTManager__rate_limits_retrieved = True + manager._MQTTManager__is_waiting_for_rate_limits_publish = False + manager._rate_limits_ready_event.set() + + manager._client._connection = MagicMock() + manager._client._connection.publish.return_value = (99, b"packet") + manager._client._persistent_storage = MagicMock() + + result = await manager.publish("topic", b"payload", qos=0, force=True) + assert result.done() + assert result.result() is True + + +@pytest.mark.asyncio +async def test_on_subscribe_internal_sets_future(setup_manager): + manager, *_ = setup_manager + future = asyncio.Future() + manager._pending_subscriptions[5] = future + manager._on_subscribe_internal(manager._client, 5, 1, {}) + assert future.done() + assert future.result() == 5 + + +@pytest.mark.asyncio +async def test_on_unsubscribe_internal_sets_future(setup_manager): + manager, *_ = setup_manager + future = asyncio.Future() + manager._pending_unsubscriptions[11] = future + manager._on_unsubscribe_internal(manager._client, 11, {}) + assert future.done() + assert future.result() == 11 + + +@pytest.mark.asyncio +async def test_handle_puback_reason_code_errors(setup_manager): + 
manager, *_ = setup_manager + + f1 = asyncio.Future() + manager._pending_publishes[1] = (f1, "topic", 1, 100, 0) + manager._handle_puback_reason_code(1, IMPLEMENTATION_SPECIFIC_ERROR, {}) + assert f1.result().reason_code == IMPLEMENTATION_SPECIFIC_ERROR + + f2 = asyncio.Future() + manager._pending_publishes[2] = (f2, "topic", 1, 100, 0) + manager._handle_puback_reason_code(2, QUOTA_EXCEEDED, {}) + assert f2.result().reason_code == QUOTA_EXCEEDED + + manager._handle_puback_reason_code(9999, 1, {}) # Should log warning, not crash + + +@pytest.mark.asyncio +async def test_connect_loop_retry_and_success(setup_manager): + manager, stop_event, *_ = setup_manager + manager._connected_event.set() + + manager._client.connect = AsyncMock(side_effect=[Exception("fail1"), AsyncMock()]) + + with patch.object(type(manager._client), "is_connected", new_callable=PropertyMock) as mock_connected, \ + patch("asyncio.sleep", new_callable=AsyncMock), \ + patch.object(manager, "_connect_params", new=("host", 1883, None, None, False, 60, None)): + + mock_connected.side_effect = [False, False, True] + + await asyncio.wait_for(manager._connect_loop(), timeout=1) + + +@pytest.mark.asyncio +async def test_request_rate_limits_timeout(setup_manager): + manager, stop_event, _, _, _, _, rate_handler, _ = setup_manager + dispatcher = manager._message_dispatcher + + req_mock = MagicMock() + req_mock.request_id = "req-id" + + dispatcher.build_rpc_request.return_value = ("topic", b"payload") + + manager._client._connection = MagicMock() + manager._client._connection.publish.return_value = (999, b"fake_packet") + manager._client._persistent_storage = MagicMock() + manager._client._persistent_storage.push_message_nowait = MagicMock() + + future = asyncio.Future() + future.set_result(None) + manager._rpc_response_handler.register_request.return_value = future + + with patch("tb_mqtt_client.entities.data.rpc_request.RPCRequest.build", return_value=req_mock): + await manager._MQTTManager__request_rate_limits() + assert manager._rate_limits_ready_event.is_set() + + +@pytest.mark.asyncio +async def test_monitor_ack_timeouts_stops_gracefully(setup_manager): + manager, stop_event, *_ = setup_manager + stop_event.set() + await manager._monitor_ack_timeouts() + + +@pytest.mark.asyncio +async def test_match_topic_exact_match_and_failures(): + assert MQTTManager._match_topic("a/b/c", "a/b/c") + assert not MQTTManager._match_topic("a/b/c", "a/b") + assert not MQTTManager._match_topic("a/+/c", "a/x") + + +@pytest.mark.asyncio +async def test_disconnect_reason_code_142_triggers_special_flow(setup_manager): + manager, *_ = setup_manager + manager._client = MagicMock() + manager._backpressure = MagicMock() + manager._on_disconnect_callback = AsyncMock() + + rate_limit = MagicMock(spec=RateLimit) + manager._MQTTManager__rate_limiter = {"messages": rate_limit} + + fut = asyncio.Future() + manager._pending_publishes[42] = (fut, "topic", 1, 100, 0) + + with patch("asyncio.get_event_loop") as mock_get_loop, \ + patch("asyncio.create_task", side_effect=lambda coro: asyncio.ensure_future(coro)): + + mock_loop = MagicMock() + mock_loop.run_until_complete.return_value = (None, 1, 1) # simulate reached_time = 1 + mock_get_loop.return_value = mock_loop + + manager._on_disconnect_internal(manager._client, reason_code=142) + await asyncio.sleep(0.05) + + assert fut.done() + manager._backpressure.notify_disconnect.assert_has_calls([ + call(delay_seconds=10), + call(delay_seconds=1), + ]) + manager._on_disconnect_callback.assert_awaited_once() From 
36eacc67fb443a2c00716983791c80a0fdeb8387 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Fri, 27 Jun 2025 14:47:37 +0300 Subject: [PATCH 34/74] Added fixes for found bugs --- tb_mqtt_client/common/provisioning_client.py | 17 ++- .../entities/data/device_uplink_message.py | 14 ++- .../entities/data/provisioning_request.py | 18 +-- .../entities/data/timeseries_entry.py | 2 +- tb_mqtt_client/service/device/client.py | 58 +++++---- tb_mqtt_client/service/message_dispatcher.py | 24 ++++ tb_mqtt_client/service/message_queue.py | 21 ++-- tb_mqtt_client/service/message_splitter.py | 118 +++++++++++------- 8 files changed, 165 insertions(+), 107 deletions(-) diff --git a/tb_mqtt_client/common/provisioning_client.py b/tb_mqtt_client/common/provisioning_client.py index a2718c3..940d800 100644 --- a/tb_mqtt_client/common/provisioning_client.py +++ b/tb_mqtt_client/common/provisioning_client.py @@ -13,9 +13,9 @@ # limitations under the License. from asyncio import Event +from typing import Union, Optional from gmqtt import Client as GMQTTClient -from orjson import loads from tb_mqtt_client.common.config_loader import DeviceConfig from tb_mqtt_client.common.logging_utils import get_logger @@ -28,7 +28,7 @@ class ProvisioningClient: - def __init__(self, host, port, provision_request: 'ProvisioningRequest'): + def __init__(self, host: str, port: int, provision_request: ProvisioningRequest): self._log = logger self._stop_event = Event() self._host = host @@ -39,14 +39,14 @@ def __init__(self, host, port, provision_request: 'ProvisioningRequest'): self._client.on_connect = self._on_connect self._client.on_message = self._on_message self._provisioned = Event() - self._device_config: 'DeviceConfig' = None - self._json_message_dispatcher = JsonMessageDispatcher() + self._device_config: Optional[Union[DeviceConfig, ProvisioningResponse]] = None + self.__message_dispatcher = JsonMessageDispatcher() def _on_connect(self, client, _, rc, __): if rc == 0: self._log.debug("[Provisioning client] Connected to ThingsBoard") client.subscribe(PROVISION_RESPONSE_TOPIC) - topic, payload = self._json_message_dispatcher.build_provision_request(self._provision_request) + topic, payload = self.__message_dispatcher.build_provision_request(self._provision_request) self._log.debug("[Provisioning client] Sending provisioning request %s" % payload) client.publish(topic, payload) else: @@ -57,11 +57,8 @@ def _on_connect(self, client, _, rc, __): self._log.error("[Provisioning client] Cannot connect to ThingsBoard!, result: %s" % rc) async def _on_message(self, _, __, payload, ___, ____): - decoded_payload = payload.decode("UTF-8") - self._log.debug("[Provisioning client] Received data from ThingsBoard: %s" % decoded_payload) - decoded_message = loads(decoded_payload) - - self._device_config = ProvisioningResponse.build(self._provision_request, decoded_message) + provisioning_response = self.__message_dispatcher.parse_provisioning_response(self._provision_request, payload) + self._device_config = provisioning_response.result await self._client.disconnect() self._provisioned.set() diff --git a/tb_mqtt_client/entities/data/device_uplink_message.py b/tb_mqtt_client/entities/data/device_uplink_message.py index ec9cfaf..96927d4 100644 --- a/tb_mqtt_client/entities/data/device_uplink_message.py +++ b/tb_mqtt_client/entities/data/device_uplink_message.py @@ -18,9 +18,9 @@ from typing import List, Optional, Union, OrderedDict, Tuple, Mapping from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.publish_result 
import PublishResult from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry -from tb_mqtt_client.common.publish_result import PublishResult logger = get_logger(__name__) @@ -37,7 +37,8 @@ class DeviceUplinkMessage: _size: int def __new__(cls, *args, **kwargs): - raise TypeError("Direct instantiation of DeviceUplinkMessage is not allowed. Use DeviceUplinkMessageBuilder to construct instances.") + raise TypeError( + "Direct instantiation of DeviceUplinkMessage is not allowed. Use DeviceUplinkMessageBuilder to construct instances.") def __repr__(self): return (f"DeviceUplinkMessage(device_name={self.device_name}, " @@ -58,7 +59,8 @@ def build(cls, object.__setattr__(self, 'device_name', device_name) object.__setattr__(self, 'device_profile', device_profile) object.__setattr__(self, 'attributes', tuple(attributes)) - object.__setattr__(self, 'timeseries', MappingProxyType({ts: tuple(entries) for ts, entries in timeseries.items()})) + object.__setattr__(self, 'timeseries', + MappingProxyType({ts: tuple(entries) for ts, entries in timeseries.items()})) object.__setattr__(self, 'delivery_futures', tuple(delivery_futures)) object.__setattr__(self, '_size', size) return self @@ -112,7 +114,8 @@ def add_attributes(self, attributes: Union[AttributeEntry, List[AttributeEntry]] self.__size += attribute.size return self - def add_telemetry(self, timeseries: Union[TimeseriesEntry, List[TimeseriesEntry], OrderedDict[int, List[TimeseriesEntry]]]) -> 'DeviceUplinkMessageBuilder': + def add_timeseries(self, timeseries: Union[TimeseriesEntry, List[TimeseriesEntry], OrderedDict[ + int, List[TimeseriesEntry]]]) -> 'DeviceUplinkMessageBuilder': if isinstance(timeseries, OrderedDict): self._timeseries = timeseries return self @@ -133,7 +136,8 @@ def add_telemetry(self, timeseries: Union[TimeseriesEntry, List[TimeseriesEntry] self.__size += timeseries_entry.size return self - def add_delivery_futures(self, futures: Union[asyncio.Future[PublishResult], List[asyncio.Future[PublishResult]]]) -> 'DeviceUplinkMessageBuilder': + def add_delivery_futures(self, futures: Union[ + asyncio.Future[PublishResult], List[asyncio.Future[PublishResult]]]) -> 'DeviceUplinkMessageBuilder': if not isinstance(futures, list): futures = [futures] if futures: diff --git a/tb_mqtt_client/entities/data/provisioning_request.py b/tb_mqtt_client/entities/data/provisioning_request.py index 18120b5..7493e69 100644 --- a/tb_mqtt_client/entities/data/provisioning_request.py +++ b/tb_mqtt_client/entities/data/provisioning_request.py @@ -18,8 +18,16 @@ from tb_mqtt_client.constants.provisioning import ProvisioningCredentialsType +class ProvisioningCredentials(ABC): + @abstractmethod + def __init__(self, provision_device_key: str, provision_device_secret: str): + self.provision_device_key = provision_device_key + self.provision_device_secret = provision_device_secret + self.credentials_type: ProvisioningCredentialsType + + class ProvisioningRequest: - def __init__(self, host, credentials: 'ProvisioningCredentials', port: str = "1883", + def __init__(self, host, credentials: ProvisioningCredentials, port: int = 1883, device_name: Optional[str] = None, gateway: Optional[bool] = False): self.host = host self.port = port @@ -28,14 +36,6 @@ def __init__(self, host, credentials: 'ProvisioningCredentials', port: str = "18 self.gateway = gateway -class ProvisioningCredentials(ABC): - @abstractmethod - def __init__(self, provision_device_key: str, 
provision_device_secret: str): - self.provision_device_key = provision_device_key - self.provision_device_secret = provision_device_secret - self.credentials_type: ProvisioningCredentialsType - - class AccessTokenProvisioningCredentials(ProvisioningCredentials): def __init__(self, provision_device_key: str, provision_device_secret: str, access_token: Optional[str] = None): super().__init__(provision_device_key, provision_device_secret) diff --git a/tb_mqtt_client/entities/data/timeseries_entry.py b/tb_mqtt_client/entities/data/timeseries_entry.py index 3970a70..0136403 100644 --- a/tb_mqtt_client/entities/data/timeseries_entry.py +++ b/tb_mqtt_client/entities/data/timeseries_entry.py @@ -23,7 +23,7 @@ def __init__(self, key: str, value: JSONCompatibleType, ts: Optional[int] = None super().__init__(key, value, ts) def __repr__(self): - return f"TelemetryEntry(key={self.key}, value={self.value}, ts={self.ts})" + return f"TimeseriesEntry(key={self.key}, value={self.value}, ts={self.ts})" def as_dict(self) -> dict: result = { diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index 45ba153..bd6d926 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import ssl from asyncio import sleep, wait_for, TimeoutError, Event, Future from random import choices from string import ascii_uppercase, digits @@ -23,6 +24,8 @@ from tb_mqtt_client.common.async_utils import await_or_stop from tb_mqtt_client.common.config_loader import DeviceConfig from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.provisioning_client import ProvisioningClient +from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit, DEFAULT_RATE_LIMIT_PERCENTAGE from tb_mqtt_client.common.request_id_generator import RPCRequestIdProducer from tb_mqtt_client.constants import mqtt_topics @@ -33,13 +36,11 @@ from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate from tb_mqtt_client.entities.data.claim_request import ClaimRequest from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage, DeviceUplinkMessageBuilder +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry -from tb_mqtt_client.common.provisioning_client import ProvisioningClient -from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest -from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.service.base_client import BaseClient from tb_mqtt_client.service.device.firmware_updater import FirmwareUpdater from tb_mqtt_client.service.device.handlers.attribute_updates_handler import AttributeUpdatesHandler @@ -72,11 +73,13 @@ def __init__(self, config: Optional[Union[DeviceConfig, Dict]] = None): super().__init__(self._config.host, self._config.port, client_id) self._message_queue: Optional[MessageQueue] = None - self._message_dispatcher: MessageDispatcher = JsonMessageDispatcher(1000, 1) # Will be updated after connection established + 
self._message_dispatcher: MessageDispatcher = JsonMessageDispatcher(1000, + 1) # Will be updated after connection established self._messages_rate_limit = RateLimit("0:0,", name="messages") self._telemetry_rate_limit = RateLimit("0:0,", name="telemetry") self._telemetry_dp_rate_limit = RateLimit("0:0,", name="telemetryDataPoints") + self._ssl_context = None self.max_payload_size = None self._max_inflight_messages = 100 self._max_uplink_message_queue_size = 10000 @@ -107,13 +110,11 @@ async def update_firmware(self, on_received_callback: Optional[Callable[[str], A async def connect(self): logger.info("Connecting to platform at %s:%s", self._host, self._port) - ssl_context = None tls = self._config.use_tls() if tls: - import ssl - ssl_context = ssl.create_default_context() - ssl_context.load_verify_locations(self._config.ca_cert) - ssl_context.load_cert_chain(certfile=self._config.client_cert, keyfile=self._config.private_key) + self._ssl_context = ssl.create_default_context() + self._ssl_context.load_verify_locations(self._config.ca_cert) + self._ssl_context.load_cert_chain(certfile=self._config.client_cert, keyfile=self._config.private_key) await self._mqtt_manager.connect( host=self._host, @@ -121,7 +122,7 @@ async def connect(self): username=self._config.access_token or self._config.username, password=None if self._config.access_token else self._config.password, tls=tls, - ssl_context=ssl_context + ssl_context=self._ssl_context ) while not self._mqtt_manager.is_connected(): @@ -276,7 +277,8 @@ async def send_rpc_request( else: logger.warning("Timed out waiting for RPC response, but callback is set. " "Callback will be called with None response.") - await self._rpc_response_handler.handle(mqtt_topics.build_device_rpc_response_topic(rpc_request.request_id), e) + await self._rpc_response_handler.handle( + mqtt_topics.build_device_rpc_response_topic(rpc_request.request_id), e) async def send_rpc_response(self, response: RPCResponse): topic, payload = self._message_dispatcher.build_rpc_response(response) @@ -287,7 +289,7 @@ async def send_rpc_response(self, response: RPCResponse): async def send_attribute_request(self, attribute_request: AttributeRequest, - callback: Callable[[RequestedAttributeResponse], Awaitable[None]],): + callback: Callable[[RequestedAttributeResponse], Awaitable[None]], ): await self._requested_attribute_response_handler.register_request(attribute_request, callback) topic, payload = self._message_dispatcher.build_attribute_request(attribute_request) @@ -300,13 +302,14 @@ async def send_attribute_request(self, async def claim_device(self, claim_request: ClaimRequest, wait_for_publish: bool = True, - timeout: int = BaseClient.DEFAULT_TIMEOUT) -> Union[Future[PublishResult], PublishResult]: + timeout: float = BaseClient.DEFAULT_TIMEOUT) -> Union[Future[PublishResult], PublishResult]: topic, payload = self._message_dispatcher.build_claim_request(claim_request) self.__claiming_response_future = Future() await self._message_queue.publish(topic=topic, payload=payload, datapoints_count=0, qos=1) if wait_for_publish: try: - return await await_or_stop(self.__claiming_response_future, timeout=timeout, stop_event=self._stop_event) + return await await_or_stop(self.__claiming_response_future, timeout=timeout, + stop_event=self._stop_event) except TimeoutError: logger.warning("Timeout while waiting for telemetry publish result") return PublishResult(topic, 1, -1, len(payload), -1) @@ -329,8 +332,10 @@ async def _on_connect(self): return 
self._mqtt_manager.register_handler(mqtt_topics.DEVICE_ATTRIBUTES_TOPIC, self._handle_attribute_update) - self._mqtt_manager.register_handler(mqtt_topics.DEVICE_RPC_REQUEST_TOPIC_FOR_SUBSCRIPTION, self._handle_rpc_request) # noqa - self._mqtt_manager.register_handler(mqtt_topics.DEVICE_ATTRIBUTES_RESPONSE_TOPIC, self._handle_requested_attribute_response) # noqa + self._mqtt_manager.register_handler(mqtt_topics.DEVICE_RPC_REQUEST_TOPIC_FOR_SUBSCRIPTION, + self._handle_rpc_request) # noqa + self._mqtt_manager.register_handler(mqtt_topics.DEVICE_ATTRIBUTES_RESPONSE_TOPIC, + self._handle_requested_attribute_response) # noqa # RPC responses are handled by the RPCResponseHandler, which is already registered async def _on_disconnect(self): @@ -338,7 +343,8 @@ async def _on_disconnect(self): self._requested_attribute_response_handler.clear() self._rpc_response_handler.clear() - async def send_rpc_call(self, method: str, params: Optional[Dict[str, Any]] = None, timeout: float = 10.0) -> Union[RPCResponse, None]: + async def send_rpc_call(self, method: str, params: Optional[Dict[str, Any]] = None, timeout: float = 10.0) -> Union[ + RPCResponse, None]: """ Initiates a client-side RPC to ThingsBoard and awaits the result. :param method: The RPC method to call. @@ -406,7 +412,7 @@ async def _handle_rate_limit_response(self, response: RPCResponse): # noqa if "maxPayloadSize" in response.result: self.max_payload_size = int(response.result["maxPayloadSize"] * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) # Update the dispatcher's max_payload_size if it's already initialized - if hasattr(self, '_dispatcher') and self._message_dispatcher is not None: + if self._message_dispatcher is not None and hasattr(self._message_dispatcher, 'splitter'): self._message_dispatcher.splitter.max_payload_size = self.max_payload_size logger.debug("Updated dispatcher's max_payload_size to %d", self.max_payload_size) else: @@ -417,12 +423,12 @@ async def _handle_rate_limit_response(self, response: RPCResponse): # noqa self.max_payload_size = 65535 logger.debug("Using default max_payload_size: %d", self.max_payload_size) # Update the dispatcher's max_payload_size if it's already initialized - if hasattr(self, '_dispatcher') and self._message_dispatcher is not None: + if self._message_dispatcher is not None and hasattr(self._message_dispatcher, 'splitter'): self._message_dispatcher.splitter.max_payload_size = self.max_payload_size logger.debug("Updated dispatcher's max_payload_size to %d", self.max_payload_size) if (not self._messages_rate_limit.has_limit() - and not self._telemetry_rate_limit.has_limit() + and not self._telemetry_rate_limit.has_limit() and not self._telemetry_dp_rate_limit.has_limit()): self._max_queued_messages = 50000 logger.debug("No rate limits, setting max_queued_messages to 50000") @@ -466,9 +472,9 @@ async def __on_publish_result(self, publish_result: PublishResult): @staticmethod def _build_uplink_message_for_telemetry(payload: Union[Dict[str, Any], - TimeseriesEntry, - List[TimeseriesEntry], - List[Dict[str, Any]]]) -> DeviceUplinkMessage: + TimeseriesEntry, + List[TimeseriesEntry], + List[Dict[str, Any]]]) -> DeviceUplinkMessage: timeseries_entries = [] if isinstance(payload, TimeseriesEntry): timeseries_entries.append(payload) @@ -486,7 +492,7 @@ def _build_uplink_message_for_telemetry(payload: Union[Dict[str, Any], raise ValueError(f"Unsupported payload type for telemetry: {type(payload).__name__}") builder = DeviceUplinkMessageBuilder() - builder.add_telemetry(timeseries_entries) + 
builder.add_timeseries(timeseries_entries) return builder.build() @staticmethod @@ -509,8 +515,8 @@ def __build_timeseries_entry_from_dict(data: Dict[str, JSONCompatibleType]) -> L @staticmethod def _build_uplink_message_for_attributes(payload: Union[Dict[str, Any], - AttributeEntry, - List[AttributeEntry]]) -> DeviceUplinkMessage: + AttributeEntry, + List[AttributeEntry]]) -> DeviceUplinkMessage: if isinstance(payload, dict): payload = [AttributeEntry(k, v) for k, v in payload.items()] diff --git a/tb_mqtt_client/service/message_dispatcher.py b/tb_mqtt_client/service/message_dispatcher.py index 110ca99..2ab326c 100644 --- a/tb_mqtt_client/service/message_dispatcher.py +++ b/tb_mqtt_client/service/message_dispatcher.py @@ -28,6 +28,7 @@ from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest, ProvisioningCredentialsType +from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse @@ -133,6 +134,14 @@ def parse_rpc_response(self, topic: str, payload: Union[bytes, Exception]) -> RP """ pass + @abstractmethod + def parse_provisioning_response(self, provisioning_request: ProvisioningRequest, payload: bytes) -> 'ProvisioningResponse': + """ + Parse the provisioning response from the given payload. + This method should be implemented to handle the specific format of the provisioning response. + """ + pass + class JsonMessageDispatcher(MessageDispatcher): """ @@ -211,6 +220,21 @@ def parse_rpc_response(self, topic: str, payload: Union[bytes, Exception]) -> RP logger.error("Failed to parse RPC response: %s", str(e)) raise ValueError("Invalid RPC response format") from e + def parse_provisioning_response(self, provisioning_request: ProvisioningRequest, payload: bytes) -> 'ProvisioningResponse': + """ + Parse the provisioning response from the given payload. + :param provisioning_request: The ProvisioningRequest that initiated the provisioning. + :param payload: The raw bytes of the payload. + :return: An instance of ProvisioningResponse. 
+ """ + try: + data = loads(payload) + logger.trace("Parsing provisioning response from payload: %s", data) + return ProvisioningResponse.build(provisioning_request, data) + except Exception as e: + logger.error("Failed to parse provisioning response: %s", str(e)) + return ProvisioningResponse.build(provisioning_request, {"status": "FAILURE", "errorMsg": str(e)}) + @property def splitter(self) -> MessageSplitter: return self._splitter diff --git a/tb_mqtt_client/service/message_queue.py b/tb_mqtt_client/service/message_queue.py index c1b8e19..3a0d014 100644 --- a/tb_mqtt_client/service/message_queue.py +++ b/tb_mqtt_client/service/message_queue.py @@ -50,6 +50,7 @@ def __init__(self, self._backpressure = self._mqtt_manager.backpressure self._pending_ack_futures: Dict[int, asyncio.Future[PublishResult]] = {} self._pending_ack_callbacks: Dict[int, Callable[[bool], None]] = {} + # Queue expects tuples of (topic, payload, delivery_futures, datapoints_count, qos) self._queue: asyncio.Queue[Tuple[str, Union[bytes, DeviceUplinkMessage], List[asyncio.Future[PublishResult]], int, int]] = asyncio.Queue(maxsize=max_queue_size) self._pending_queue_tasks: set[asyncio.Task] = set() self._active = asyncio.Event() @@ -63,7 +64,7 @@ def __init__(self, max_queue_size, self._batch_max_time, batch_collect_max_count) async def publish(self, topic: str, payload: Union[bytes, DeviceUplinkMessage], datapoints_count: int, qos: int) -> Optional[List[asyncio.Future[PublishResult]]]: - delivery_futures = payload.get_delivery_futures() if isinstance(payload, DeviceUplinkMessage) else [] + delivery_futures = payload.get_delivery_futures() if isinstance(payload, DeviceUplinkMessage) else [asyncio.Future()] try: logger.trace("publish() received delivery future id: %r for topic=%s", id(delivery_futures[0]) if delivery_futures else -1, topic) @@ -72,7 +73,7 @@ async def publish(self, topic: str, payload: Union[bytes, DeviceUplinkMessage], topic, datapoints_count, type(payload).__name__) except asyncio.QueueFull: logger.error("Message queue full. Dropping message for topic %s", topic) - for future in payload.get_delivery_futures(): + for future in delivery_futures: if future: future.set_result(PublishResult(topic, qos, -1, len(payload), -1)) return delivery_futures or None @@ -357,16 +358,18 @@ async def _cancel_tasks(self, tasks: set[asyncio.Task]): def is_empty(self): return self._queue.empty() - def size(self): - return self._queue.qsize() - def clear(self): logger.debug("Clearing message queue...") while not self._queue.empty(): - _, message, _, _, _ = self._queue.get_nowait() - if isinstance(message, DeviceUplinkMessage) and message.get_delivery_futures(): - for future in message.get_delivery_futures(): - future.set_result(False) + topic, message, delivery_futures, _, qos = self._queue.get_nowait() + for future in delivery_futures: + future.set_result(PublishResult( + topic=topic, + qos=qos, + message_id=-1, + payload_size=message.size if isinstance(message, DeviceUplinkMessage) else len(message), + reason_code=-1 + )) self._queue.task_done() logger.debug("Message queue cleared.") diff --git a/tb_mqtt_client/service/message_splitter.py b/tb_mqtt_client/service/message_splitter.py index 8f187d3..b17062b 100644 --- a/tb_mqtt_client/service/message_splitter.py +++ b/tb_mqtt_client/service/message_splitter.py @@ -13,10 +13,13 @@ # limitations under the License. 
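# --- Editor's sketch (not part of the patch) ----------------------------------
# With the message_queue.py change above, publish() always hands back delivery
# futures (even for raw byte payloads), and messages dropped on a full queue or
# removed by clear() are resolved with a PublishResult whose message_id and
# reason_code are -1 instead of being left pending. Under that assumption, a
# caller can always await the returned futures and check is_successful(); the
# helper name below is illustrative, not part of the library.

import asyncio


async def send_and_confirm(queue, topic: str, payload: bytes) -> bool:
    futures = await queue.publish(topic=topic, payload=payload, datapoints_count=0, qos=1)
    if not futures:
        return False
    results = await asyncio.gather(*futures)
    # message_id / reason_code of -1 mark dropped messages rather than broker ACKs.
    return all(result.is_successful() for result in results)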
import asyncio -from typing import List +from collections import defaultdict +from typing import List, Optional, Dict, Tuple from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage, DeviceUplinkMessageBuilder +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry logger = get_logger(__name__) @@ -37,43 +40,57 @@ def __init__(self, max_payload_size: int = 65535, max_datapoints: int = 0): def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUplinkMessage]: logger.trace("Splitting timeseries for %d messages", len(messages)) + if (len(messages) == 1 - and ((messages[0].attributes_datapoint_count() + messages[0].timeseries_datapoint_count() <= self._max_datapoints) or self._max_datapoints == 0) # noqa + and ((messages[0].attributes_datapoint_count() + messages[ + 0].timeseries_datapoint_count() <= self._max_datapoints) or self._max_datapoints == 0) # noqa and messages[0].size <= self._max_payload_size): return messages result: List[DeviceUplinkMessage] = [] - for message in messages: - if not message.has_timeseries(): - logger.trace("Message from device '%s' has no timeseries. Skipping.", message.device_name) - continue + grouped: Dict[Tuple[str, Optional[str]], List[DeviceUplinkMessage]] = defaultdict(list) + for msg in messages: + key = (msg.device_name, msg.device_profile) + grouped[key].append(msg) - logger.trace("Processing timeseries from device: %s", message.device_name) - builder = None + for (device_name, device_profile), group_msgs in grouped.items(): + logger.trace("Processing group: device='%s', profile='%s', messages=%d", device_name, device_profile, + len(group_msgs)) + + all_ts: List[TimeseriesEntry] = [] + delivery_futures: List[asyncio.Future] = [] + for msg in group_msgs: + if msg.has_timeseries(): + for ts_group in msg.timeseries.values(): + all_ts.extend(ts_group) + delivery_futures.extend(msg.get_delivery_futures()) + + builder: Optional[DeviceUplinkMessageBuilder] = None size = 0 point_count = 0 batch_futures = [] - for grouped_ts in message.timeseries.values(): - for ts_kv in grouped_ts: - exceeds_size = builder and size + ts_kv.size > self._max_payload_size - exceeds_points = 0 < self._max_datapoints <= point_count - - if not builder or exceeds_size or exceeds_points: - if builder: - built = builder.build() - result.append(built) - batch_futures.extend(built.get_delivery_futures()) - logger.trace("Flushed batch with %d points (size=%d)", len(built.timeseries), size) - builder = DeviceUplinkMessageBuilder().set_device_name(message.device_name).set_device_profile( - message.device_profile) - size = 0 - point_count = 0 - - builder.add_telemetry(ts_kv) - size += ts_kv.size - point_count += 1 + for ts_kv in all_ts: + exceeds_size = builder and size + ts_kv.size > self._max_payload_size + exceeds_points = 0 < self._max_datapoints <= point_count + + if not builder or exceeds_size or exceeds_points: + if builder: + built = builder.build() + result.append(built) + batch_futures.extend(built.get_delivery_futures()) + logger.trace("Flushed batch with %d points (size=%d)", len(built.timeseries), size) + + builder = DeviceUplinkMessageBuilder() \ + .set_device_name(device_name) \ + .set_device_profile(device_profile) + size = 0 + point_count = 0 + + builder.add_timeseries(ts_kv) + size += ts_kv.size + point_count += 1 if builder and builder._timeseries: # noqa built = builder.build() 
@@ -81,14 +98,13 @@ def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp batch_futures.extend(built.get_delivery_futures()) logger.trace("Flushed final batch with %d points (size=%d)", len(built.timeseries), size) - if message.get_delivery_futures(): - original_future = message.get_delivery_futures()[0] + if delivery_futures: + original_future = delivery_futures[0] logger.trace("Adding futures to original future: %s, futures ids: %r", id(original_future), - [id(batch_future) for batch_future in batch_futures]) + [id(f) for f in batch_futures]) async def resolve_original(): - logger.trace("Resolving original future with batch futures: %r, %s", - batch_futures, [id(f) for f in batch_futures]) + logger.trace("Resolving original future with batch futures: %r", [id(f) for f in batch_futures]) results = await asyncio.gather(*batch_futures, return_exceptions=False) original_future.set_result(all(results)) @@ -102,22 +118,33 @@ def split_attributes(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp result: List[DeviceUplinkMessage] = [] if (len(messages) == 1 - and ((messages[0].attributes_datapoint_count() + messages[0].timeseries_datapoint_count() <= self._max_datapoints) or self._max_datapoints == 0) # noqa + and ((messages[0].attributes_datapoint_count() + messages[ + 0].timeseries_datapoint_count() <= self._max_datapoints) or self._max_datapoints == 0) # noqa and messages[0].size <= self._max_payload_size): return messages - for message in messages: - if not message.has_attributes(): - logger.trace("Message from device '%s' has no attributes. Skipping.", message.device_name) - continue + grouped: Dict[Tuple[str, Optional[str]], List[DeviceUplinkMessage]] = defaultdict(list) + for msg in messages: + grouped[(msg.device_name, msg.device_profile)].append(msg) + + for (device_name, device_profile), group_msgs in grouped.items(): + logger.trace("Processing attribute group: device='%s', profile='%s', messages=%d", device_name, + device_profile, len(group_msgs)) + + all_attrs: List[AttributeEntry] = [] + delivery_futures: List[asyncio.Future] = [] + + for msg in group_msgs: + if msg.has_attributes(): + all_attrs.extend(msg.attributes) + delivery_futures.extend(msg.get_delivery_futures()) - logger.trace("Processing attributes from device: %s", message.device_name) builder = None size = 0 point_count = 0 batch_futures = [] - for attr in message.attributes: + for attr in all_attrs: exceeds_size = builder and size + attr.size > self._max_payload_size exceeds_points = 0 < self._max_datapoints <= point_count @@ -127,14 +154,11 @@ def split_attributes(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp result.append(built) batch_futures.extend(built.get_delivery_futures()) logger.trace("Flushed attribute batch (count=%d, size=%d)", len(built.attributes), size) - builder = None + builder = DeviceUplinkMessageBuilder().set_device_name(device_name).set_device_profile( + device_profile) size = 0 point_count = 0 - if not builder: - builder = DeviceUplinkMessageBuilder().set_device_name(message.device_name).set_device_profile( - message.device_profile) - builder.add_attributes(attr) size += attr.size point_count += 1 @@ -145,10 +169,10 @@ def split_attributes(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp batch_futures.extend(built.get_delivery_futures()) logger.trace("Flushed final attribute batch (count=%d, size=%d)", len(built.attributes), size) - if message.get_delivery_futures(): - original_future = message.get_delivery_futures()[0] + if 
delivery_futures: + original_future = delivery_futures[0] logger.trace("Adding futures to original future: %s, futures ids: %r", id(original_future), - [id(batch_future) for batch_future in batch_futures]) + [id(batch_future) for batch_future in batch_futures]) async def resolve_original(): results = await asyncio.gather(*batch_futures, return_exceptions=False) From 7a6c43a093940d8f8ee1023606d1d6f83759a708 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Fri, 27 Jun 2025 14:48:11 +0300 Subject: [PATCH 35/74] Refactoring --- examples/device/operational_example.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/device/operational_example.py b/examples/device/operational_example.py index 95ad46a..c49891b 100644 --- a/examples/device/operational_example.py +++ b/examples/device/operational_example.py @@ -147,13 +147,13 @@ def _shutdown_handler(): raw_telemetry_publish_result = await client.send_telemetry(raw_dict) logger.info(f"Raw telemetry sent: {raw_dict} with result: {raw_telemetry_publish_result}") - # 2. Single TelemetryEntry (with ts) + # 2. Single TimeseriesEntry (with ts) single_entry = TimeseriesEntry("batteryLevel", randint(0, 100)) logger.info("Sending single telemetry: %s", single_entry) delivery_future = await client.send_telemetry(single_entry) logger.info(f"Single telemetry sent: {single_entry} with delivery future: {delivery_future}") - # 3. List of TelemetryEntry with mixed timestamps + # 3. List of TimeseriesEntry with mixed timestamps telemetry_entries = [] for i in range(100): From ccc20329e5d5ce168f282bc22a33a3fb1e441dee Mon Sep 17 00:00:00 2001 From: imbeacon Date: Mon, 30 Jun 2025 09:00:26 +0300 Subject: [PATCH 36/74] Adjusted claiming for current implementation on the platform --- examples/device/claim_device.py | 55 +++++++++++++++++++++++++ tb_mqtt_client/service/device/client.py | 21 ++-------- tb_mqtt_client/service/mqtt_manager.py | 2 +- 3 files changed, 60 insertions(+), 18 deletions(-) create mode 100644 examples/device/claim_device.py diff --git a/examples/device/claim_device.py b/examples/device/claim_device.py new file mode 100644 index 0000000..7b690ef --- /dev/null +++ b/examples/device/claim_device.py @@ -0,0 +1,55 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Example script to claim a device using ThingsBoard DeviceClient + +import asyncio +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.entities.data.claim_request import ClaimRequest +from tb_mqtt_client.service.device.client import DeviceClient + + +# Constants for connection +PLATFORM_HOST = "localhost" # Replace with your ThingsBoard host +DEVICE_ACCESS_TOKEN = "YOUR_ACCESS_TOKEN" # Replace with your device's access token + +# Constants for claiming +CLAIMING_DURATION = 120 # Default claiming duration in seconds +CLAIMING_SECRET_KEY = "YOUR_SECRET_KEY" # Replace with your actual secret key + +async def main(): + # Create device config + config = DeviceConfig() + config.host = PLATFORM_HOST + config.access_token = DEVICE_ACCESS_TOKEN + + # Create device client + client = DeviceClient(config) + await client.connect() + + # Build claim request with secret key and optional duration (in seconds) + claim_request = ClaimRequest.build(secret_key=CLAIMING_SECRET_KEY, duration=CLAIMING_DURATION) + + # Send claim request + result: PublishResult = await client.claim_device(claim_request, wait_for_publish=True, timeout=CLAIMING_DURATION + 10) + if result.is_successful(): + print(f"Claiming request was sent successfully. Please use the secret key '{CLAIMING_SECRET_KEY}' to claim the device from the dashboard.") + else: + print(f"Failed to send claiming request. Result: {result}") + + await client.stop() + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index bd6d926..6616241 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -99,7 +99,6 @@ def __init__(self, config: Optional[Union[DeviceConfig, Dict]] = None): self._requested_attribute_response_handler = RequestedAttributeResponseHandler() self._attribute_updates_handler = AttributeUpdatesHandler() self._rpc_requests_handler = RPCRequestsHandler() - self.__claiming_response_future: Union[Future[bool], None] = None self._firmware_updater = FirmwareUpdater(self) @@ -304,17 +303,15 @@ async def claim_device(self, wait_for_publish: bool = True, timeout: float = BaseClient.DEFAULT_TIMEOUT) -> Union[Future[PublishResult], PublishResult]: topic, payload = self._message_dispatcher.build_claim_request(claim_request) - self.__claiming_response_future = Future() - await self._message_queue.publish(topic=topic, payload=payload, datapoints_count=0, qos=1) + publish_future = await self._message_queue.publish(topic=topic, payload=payload, datapoints_count=0, qos=1) if wait_for_publish: try: - return await await_or_stop(self.__claiming_response_future, timeout=timeout, - stop_event=self._stop_event) + return await await_or_stop(publish_future[0], timeout=timeout, stop_event=self._stop_event) except TimeoutError: - logger.warning("Timeout while waiting for telemetry publish result") + logger.warning("Timeout while waiting for claiming publish result") return PublishResult(topic, 1, -1, len(payload), -1) else: - return self.__claiming_response_future + return publish_future[0] def set_attribute_update_callback(self, callback: Callable[[AttributeUpdate], Awaitable[None]]): self._attribute_updates_handler.set_callback(callback) @@ -455,16 +452,6 @@ async def __on_publish_result(self, publish_result: PublishResult): Callback for handling publish results. 
This can be used to handle the result of a publish operation, such as logging or updating state. """ - if mqtt_topics.DEVICE_CLAIM_TOPIC == publish_result.topic: - if self.__claiming_response_future and not self.__claiming_response_future.done(): - if publish_result.is_successful(): - self.__claiming_response_future.set_result(True) - logger.debug("Device claimed successfully.") - else: - self.__claiming_response_future.set_exception( - Exception(f"Failed to claim device: {publish_result}")) - logger.error("Failed to claim device: %r", publish_result) - return if publish_result.is_successful(): logger.trace("Publish successful: %r", publish_result) else: diff --git a/tb_mqtt_client/service/mqtt_manager.py b/tb_mqtt_client/service/mqtt_manager.py index deda326..4abdc0d 100644 --- a/tb_mqtt_client/service/mqtt_manager.py +++ b/tb_mqtt_client/service/mqtt_manager.py @@ -310,7 +310,7 @@ def _on_message_internal(self, client, topic: str, payload: bytes, qos, properti return def _on_publish_internal(self, client, mid): - pass + logger.trace("Publish was sent by client %r with mid=%s", client, mid) def _handle_puback_reason_code(self, mid: int, reason_code: int, properties: dict): logger.trace("Handling PUBACK mid=%s with rc %r and properties: %r", From bf8a6b2a11abb998b32d505fa70bac8388c58d99 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Mon, 30 Jun 2025 09:12:16 +0300 Subject: [PATCH 37/74] Removed deprecated example --- .../DEPRECATEDclaiming_device_pe_only.py | 37 ------------------- 1 file changed, 37 deletions(-) delete mode 100644 examples/device/DEPRECATEDclaiming_device_pe_only.py diff --git a/examples/device/DEPRECATEDclaiming_device_pe_only.py b/examples/device/DEPRECATEDclaiming_device_pe_only.py deleted file mode 100644 index 3687aba..0000000 --- a/examples/device/DEPRECATEDclaiming_device_pe_only.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging - -from tb_device_mqtt import TBDeviceMqttClient -logging.basicConfig(level=logging.DEBUG) - -THINGSBOARD_HOST = "127.0.0.1" -DEVICE_ACCESS_TOKEN = "DEVICE_ACCESS_TOKEN" - -SECRET_KEY = "DEVICE_SECRET_KEY" # Customer should write this key in device claiming widget -DURATION = 30000 # In milliseconds (30 seconds) - - -def main(): - client = TBDeviceMqttClient(THINGSBOARD_HOST, username=DEVICE_ACCESS_TOKEN) - client.connect() - info = client.claim(secret_key=SECRET_KEY, duration=DURATION) - if info.rc() == 0: - print("Claiming request was sent.") - client.stop() - - -if __name__ == '__main__': - main() From 4c8b94fffc4b2eba07da80fbc489257be3c36b0f Mon Sep 17 00:00:00 2001 From: imbeacon Date: Mon, 30 Jun 2025 10:33:38 +0300 Subject: [PATCH 38/74] Renamed telemetry to time series --- examples/device/client_provisioning.py | 2 +- examples/device/load.py | 2 +- examples/device/operational_example.py | 6 +- .../{send_telemetry.py => send_timeseries.py} | 14 ++-- examples/device/tls_connect.py | 38 +++++++++ tb_mqtt_client/service/base_client.py | 35 ++++---- tb_mqtt_client/service/device/client.py | 22 ++++-- .../service/device/firmware_updater.py | 2 +- tb_mqtt_client/service/gateway/client.py | 2 +- tests/service/device/test_device_client.py | 79 ++++++++++++------- tests/service/device/test_firmware_updater.py | 10 +-- 11 files changed, 145 insertions(+), 67 deletions(-) rename examples/device/{send_telemetry.py => send_timeseries.py} (77%) create mode 100644 examples/device/tls_connect.py diff --git a/examples/device/client_provisioning.py b/examples/device/client_provisioning.py index e0d1b38..a2462b9 100644 --- a/examples/device/client_provisioning.py +++ b/examples/device/client_provisioning.py @@ -41,7 +41,7 @@ async def main(): await client.connect() # Send single telemetry entry to provisioned device - await client.send_telemetry(TimeseriesEntry("batteryLevel", randint(0, 100))) + await client.send_timeseries(TimeseriesEntry("batteryLevel", randint(0, 100))) await client.stop() diff --git a/examples/device/load.py b/examples/device/load.py index 3cac8d9..2bd87d4 100644 --- a/examples/device/load.py +++ b/examples/device/load.py @@ -93,7 +93,7 @@ def _shutdown_handler(): ] try: - future = await client.send_telemetry(entries) + future = await client.send_timeseries(entries) if future: pending_futures.append((future, BATCH_SIZE)) sent_batches += 1 diff --git a/examples/device/operational_example.py b/examples/device/operational_example.py index c49891b..bcc537a 100644 --- a/examples/device/operational_example.py +++ b/examples/device/operational_example.py @@ -144,13 +144,13 @@ def _shutdown_handler(): "humidity": 60 } logger.info("Sending raw telemetry...") - raw_telemetry_publish_result = await client.send_telemetry(raw_dict) + raw_telemetry_publish_result = await client.send_timeseries(raw_dict) logger.info(f"Raw telemetry sent: {raw_dict} with result: {raw_telemetry_publish_result}") # 2. Single TimeseriesEntry (with ts) single_entry = TimeseriesEntry("batteryLevel", randint(0, 100)) logger.info("Sending single telemetry: %s", single_entry) - delivery_future = await client.send_telemetry(single_entry) + delivery_future = await client.send_timeseries(single_entry) logger.info(f"Single telemetry sent: {single_entry} with delivery future: {delivery_future}") # 3. 
List of TimeseriesEntry with mixed timestamps @@ -159,7 +159,7 @@ def _shutdown_handler(): for i in range(100): telemetry_entries.append(TimeseriesEntry("temperature", i, ts=int(datetime.now(UTC).timestamp() * 1000)-i)) logger.info("Sending list of telemetry entries with mixed timestamps...") - telemetry_list_publish_result = await client.send_telemetry(telemetry_entries) + telemetry_list_publish_result = await client.send_timeseries(telemetry_entries) logger.info("List of telemetry entries sent: %s with result: %s", len(telemetry_entries) if len(telemetry_entries) > 10 else telemetry_entries, telemetry_list_publish_result) diff --git a/examples/device/send_telemetry.py b/examples/device/send_timeseries.py similarity index 77% rename from examples/device/send_telemetry.py rename to examples/device/send_timeseries.py index 320ac18..36edb3f 100644 --- a/examples/device/send_telemetry.py +++ b/examples/device/send_timeseries.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# This example demonstrates how to send telemetry data from a device to ThingsBoard using the DeviceClient. +# This example demonstrates how to send time series data from a device to ThingsBoard using the DeviceClient. import asyncio from random import uniform, randint @@ -28,21 +28,21 @@ async def main(): client = DeviceClient(config) await client.connect() - # Send telemetry as raw dictionary - await client.send_telemetry({ + # Send time series as raw dictionary + await client.send_timeseries({ "temperature": round(uniform(20.0, 30.0), 2), "humidity": randint(30, 70) }) - # Send single telemetry entry - await client.send_telemetry(TimeseriesEntry("batteryLevel", randint(0, 100))) + # Send single time series entry + await client.send_timeseries(TimeseriesEntry("batteryLevel", randint(0, 100))) - # Send list of telemetry entries + # Send a list of time series entries entries = [ TimeseriesEntry("vibration", 0.05), TimeseriesEntry("speed", 123) ] - await client.send_telemetry(entries) + await client.send_timeseries(entries) await client.stop() diff --git a/examples/device/tls_connect.py b/examples/device/tls_connect.py new file mode 100644 index 0000000..569b320 --- /dev/null +++ b/examples/device/tls_connect.py @@ -0,0 +1,38 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This example demonstrates how to connect to ThingsBoard over SSL using the DeviceClient and send telemetry. 
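# Editor's note (not part of the patch): as committed, this example only sets
# host and access_token, so DeviceConfig.use_tls() has nothing to enable TLS
# with. Based on DeviceClient.connect() in this patch series, a TLS setup would
# presumably also provide certificate paths before connecting, roughly:
#
#     config.ca_cert = "ca.pem"          # CA certificate used to verify the server
#     config.client_cert = "cert.pem"    # client certificate for X.509 auth
#     config.private_key = "key.pem"     # private key matching the client certificate
#
# The exact way DeviceConfig decides use_tls() is not shown here, so treat the
# snippet above as an assumption rather than the documented configuration.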
+ +import asyncio +from random import uniform, randint +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.service.device.client import DeviceClient + +async def main(): + config = DeviceConfig() + config.host = "localhost" + config.access_token = "YOUR_ACCESS_TOKEN" + + client = DeviceClient(config) + await client.connect() + + + # Send telemetry entry + await client.send_timeseries(TimeseriesEntry("batteryLevel", randint(0, 100))) + + await client.stop() + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tb_mqtt_client/service/base_client.py b/tb_mqtt_client/service/base_client.py index 7235ec0..6f9e271 100644 --- a/tb_mqtt_client/service/base_client.py +++ b/tb_mqtt_client/service/base_client.py @@ -58,21 +58,26 @@ async def disconnect(self): pass @abstractmethod - async def send_telemetry(self, telemetry_data: Union[TimeseriesEntry, - List[TimeseriesEntry], - Dict[str, Any], - List[Dict[str, Any]]], - wait_for_publish: bool = True, - timeout: Optional[float] = None) -> Union[asyncio.Future[PublishResult], PublishResult]: - """ - Send telemetry data. - - :param telemetry_data: Dictionary of telemetry data, a single TimeseriesEntry, - or a list of TimeseriesEntry or dictionaries. - :param wait_for_publish: If True, wait for the publishing result. Default is True. - :param timeout: Timeout for the publish operation if `wait_for_publish` is True. - In seconds. If less than 0 or None, wait indefinitely. - :return: Future or PublishResult depending on `wait_for_publish`. + async def send_timeseries(self, + telemetry_data: Union[TimeseriesEntry, + List[TimeseriesEntry], + Dict[str, Any], + List[Dict[str, Any]]], + wait_for_publish: bool = True, + timeout: Optional[float] = None) -> Union[asyncio.Future[PublishResult], + PublishResult, + None, + List[PublishResult], + List[asyncio.Future[PublishResult]]]: + """ + Sends timeseries data to the ThingsBoard server. + :param data: Timeseries data to send, can be a single TimeseriesEntry, a list of TimeseriesEntries, + a dictionary of key-value pairs, or a list of dictionaries. + :param qos: Quality of Service level for the MQTT message. + :param wait_for_publish: If True, waits for the publish result. + :param timeout: Timeout for waiting for the publish result. + :return: PublishResult or list of PublishResults if wait_for_publish is True, Future or list of Futures if not, + None if no data is sent. """ pass diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index 6616241..150feab 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -178,13 +178,23 @@ async def disconnect(self): # await self._message_queue.shutdown() # TODO: Not sure if we need to shutdown the message queue here, as it might be handled by MQTTManager - async def send_telemetry( + async def send_timeseries( self, data: Union[TimeseriesEntry, List[TimeseriesEntry], Dict[str, Any], List[Dict[str, Any]]], qos: int = 1, wait_for_publish: bool = True, timeout: Optional[float] = None - ) -> Union[PublishResult, List[PublishResult], None]: + ) -> Union[PublishResult, List[PublishResult], None, Future[PublishResult], List[Future[PublishResult]]]: + """ + Sends timeseries data to the ThingsBoard server. + :param data: Timeseries data to send, can be a single TimeseriesEntry, a list of TimeseriesEntries, + a dictionary of key-value pairs, or a list of dictionaries. 
+ :param qos: Quality of Service level for the MQTT message. + :param wait_for_publish: If True, waits for the publish result. + :param timeout: Timeout for waiting for the publish result. + :return: PublishResult or list of PublishResults if wait_for_publish is True, Future or list of Futures if not, + None if no data is sent. + """ message = self._build_uplink_message_for_telemetry(data) topic = mqtt_topics.DEVICE_TELEMETRY_TOPIC futures = await self._message_queue.publish( @@ -199,7 +209,7 @@ async def send_telemetry( return None if not wait_for_publish: - return None + return futures[0] if len(futures) == 1 else futures results = [] for fut in futures: @@ -304,14 +314,16 @@ async def claim_device(self, timeout: float = BaseClient.DEFAULT_TIMEOUT) -> Union[Future[PublishResult], PublishResult]: topic, payload = self._message_dispatcher.build_claim_request(claim_request) publish_future = await self._message_queue.publish(topic=topic, payload=payload, datapoints_count=0, qos=1) + if isinstance(publish_future, list): + publish_future = publish_future[0] if wait_for_publish: try: - return await await_or_stop(publish_future[0], timeout=timeout, stop_event=self._stop_event) + return await await_or_stop(publish_future, timeout=timeout, stop_event=self._stop_event) except TimeoutError: logger.warning("Timeout while waiting for claiming publish result") return PublishResult(topic, 1, -1, len(payload), -1) else: - return publish_future[0] + return publish_future def set_attribute_update_callback(self, callback: Callable[[AttributeUpdate], Awaitable[None]]): self._attribute_updates_handler.set_callback(callback) diff --git a/tb_mqtt_client/service/device/firmware_updater.py b/tb_mqtt_client/service/device/firmware_updater.py index 4422b52..358984a 100644 --- a/tb_mqtt_client/service/device/firmware_updater.py +++ b/tb_mqtt_client/service/device/firmware_updater.py @@ -223,7 +223,7 @@ def _is_different_firmware_versions(self, new_firmware_info): async def _send_current_firmware_info(self): current_info = [TimeseriesEntry(key, value) for key, value in self.current_firmware_info.items()] - await self._client.send_telemetry(current_info, wait_for_publish=True) + await self._client.send_timeseries(current_info, wait_for_publish=True) def verify_checksum(self, firmware_data, checksum_alg, checksum): if firmware_data is None: diff --git a/tb_mqtt_client/service/gateway/client.py b/tb_mqtt_client/service/gateway/client.py index cc6bd43..c7ae4ee 100644 --- a/tb_mqtt_client/service/gateway/client.py +++ b/tb_mqtt_client/service/gateway/client.py @@ -208,7 +208,7 @@ async def gw_disconnect_device(self, device_name: str): if self._device_disconnect_callback: await self._device_disconnect_callback(device_name) - async def gw_send_telemetry(self, device_name: str, telemetry: Union[Dict[str, Any], TimeseriesEntry, List[TimeseriesEntry]]): + async def gw_send_timeseries(self, device_name: str, telemetry: Union[Dict[str, Any], TimeseriesEntry, List[TimeseriesEntry]]): """ Send telemetry on behalf of a connected device. 
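For reference, a minimal usage sketch of the renamed API (not part of the patch): it illustrates the return-type behaviour documented above for send_timeseries, reusing the DeviceConfig/DeviceClient pattern from the examples in this series. The host and access token values are placeholders.

# Hypothetical caller-side sketch; assumes a reachable ThingsBoard instance.
import asyncio

from tb_mqtt_client.common.config_loader import DeviceConfig
from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry
from tb_mqtt_client.service.device.client import DeviceClient


async def main():
    config = DeviceConfig()
    config.host = "localhost"                  # placeholder host
    config.access_token = "YOUR_ACCESS_TOKEN"  # placeholder token

    client = DeviceClient(config)
    await client.connect()

    # wait_for_publish=True (default): returns a resolved PublishResult
    # (or a list of results if the message had to be split).
    result = await client.send_timeseries({"temperature": 22.5})
    print(result)

    # wait_for_publish=False: returns the pending future(s) instead;
    # awaiting a future yields the PublishResult once the broker acknowledges.
    future = await client.send_timeseries(TimeseriesEntry("humidity", 40),
                                          wait_for_publish=False)
    print(await future)

    await client.stop()


if __name__ == "__main__":
    asyncio.run(main())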
diff --git a/tests/service/device/test_device_client.py b/tests/service/device/test_device_client.py index 395a78e..fcfbc51 100644 --- a/tests/service/device/test_device_client.py +++ b/tests/service/device/test_device_client.py @@ -29,24 +29,24 @@ @pytest.mark.asyncio -async def test_send_telemetry_with_dict(): +async def test_send_timeseries_with_dict(): client = DeviceClient() client._message_queue = AsyncMock(spec=MessageQueue) future = asyncio.Future() future.set_result(PublishResult("topic", 1, 1, 100, 1)) client._message_queue.publish.return_value = [future] - result = await client.send_telemetry({"temp": 22}) + result = await client.send_timeseries({"temp": 22}) assert isinstance(result, PublishResult) assert result.message_id == 1 @pytest.mark.asyncio -async def test_send_telemetry_timeout(): +async def test_send_timeseries_timeout(): client = DeviceClient() client._message_queue = AsyncMock() future = asyncio.Future() client._message_queue.publish.return_value = [future] - result = await client.send_telemetry({"temp": 22}, timeout=0.01) + result = await client.send_timeseries({"temp": 22}, timeout=0.01) assert isinstance(result, PublishResult) assert result.message_id == -1 @@ -88,19 +88,50 @@ async def test_send_rpc_response(): @pytest.mark.asyncio async def test_claim_device_success(): client = DeviceClient() - from tb_mqtt_client.entities.data.claim_request import ClaimRequest + client._message_queue = AsyncMock() + claim = ClaimRequest.build(secret_key="abc") + + fut = asyncio.Future() + fut.set_result(PublishResult(topic=DEVICE_CLAIM_TOPIC, qos=1, message_id=1, payload_size=12, reason_code=0)) + client._message_queue.publish.return_value = [fut] + + result = await client.claim_device(claim) + assert isinstance(result, PublishResult) + assert result.topic == DEVICE_CLAIM_TOPIC + + +@pytest.mark.asyncio +async def test_claim_device_timeout(): + client = DeviceClient() client._message_queue = AsyncMock() + + claim = ClaimRequest.build(secret_key="abc") fut = asyncio.Future() - fut.set_result(None) - client._message_queue.publish.return_value = fut - client._DeviceClient__claiming_response_future = asyncio.Future() - client._DeviceClient__claiming_response_future.set_result( - PublishResult(DEVICE_CLAIM_TOPIC, 1, 3, 10, 1) - ) + client._message_queue.publish.return_value = [fut] + result = await client.claim_device(claim, timeout=0.01) assert isinstance(result, PublishResult) - assert result.topic == DEVICE_CLAIM_TOPIC + assert result.message_id == -1 + + + +@pytest.mark.asyncio +async def test_claim_device_payload_contains_secret_key(): + client = DeviceClient() + client._message_queue = AsyncMock() + + claim = ClaimRequest.build(secret_key="my-secret") + fut = asyncio.Future() + fut.set_result(PublishResult(topic=DEVICE_CLAIM_TOPIC, qos=1, message_id=3, payload_size=15, reason_code=0)) + client._message_queue.publish.return_value = [fut] + + await client.claim_device(claim) + + client._message_queue.publish.assert_awaited_once() + args, kwargs = client._message_queue.publish.call_args + assert kwargs['topic'] == DEVICE_CLAIM_TOPIC + assert b"my-secret" in kwargs['payload'] @pytest.mark.asyncio @@ -363,22 +394,10 @@ async def test_does_not_update_dispatcher_when_not_initialized(): @pytest.mark.asyncio -async def test_publish_result_device_claim_successfully_sets_future(): - publish_result = PublishResult(topic=DEVICE_CLAIM_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0) - client = DeviceClient() - client._DeviceClient__claiming_response_future = MagicMock() - 
client._DeviceClient__claiming_response_future.done.return_value = False - - await client._DeviceClient__on_publish_result(publish_result) - - client._DeviceClient__claiming_response_future.set_result.assert_called_once_with(True) - - -@pytest.mark.asyncio -async def test_send_telemetry_without_connect_raises_error(): +async def test_send_timeseries_without_connect_raises_error(): client = DeviceClient() with pytest.raises(AttributeError): - await client.send_telemetry({"temp": 22}) + await client.send_timeseries({"temp": 22}) @pytest.mark.asyncio @@ -441,10 +460,10 @@ async def test_set_attribute_update_callback_sets_handler(): @pytest.mark.asyncio -async def test_send_telemetry_with_invalid_type_raises(): +async def test_send_timeseries_with_invalid_type_raises(): client = DeviceClient() with pytest.raises(ValueError): - await client.send_telemetry("invalid") + await client.send_timeseries("invalid") @pytest.mark.asyncio @@ -513,3 +532,7 @@ async def test_provision_timeout(monkeypatch): req = ProvisioningRequest(host="localhost", credentials=credentials, port=1883, device_name="dev") result = await DeviceClient.provision(req, timeout=0.01) assert result is None + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/service/device/test_firmware_updater.py b/tests/service/device/test_firmware_updater.py index c86d91c..46874bc 100644 --- a/tests/service/device/test_firmware_updater.py +++ b/tests/service/device/test_firmware_updater.py @@ -33,7 +33,7 @@ def mock_client(): client._mqtt_manager.unsubscribe = AsyncMock() client._mqtt_manager.is_connected.return_value = True client._message_queue.publish = AsyncMock() - client.send_telemetry = AsyncMock() + client.send_timeseries = AsyncMock() client.send_attribute_request = AsyncMock() return client @@ -50,7 +50,7 @@ async def test_update_success(updater, mock_client): patch.object(updater, "_firmware_info_callback", new=AsyncMock()): await updater.update() mock_client._mqtt_manager.subscribe.assert_called_once() - mock_client.send_telemetry.assert_called() + mock_client.send_timeseries.assert_called() mock_client.send_attribute_request.assert_called_once() @@ -182,15 +182,15 @@ async def test_save_firmware_failure_logs_error(updater, caplog): @pytest.mark.asyncio -async def test_send_current_firmware_info_calls_send_telemetry(updater, mock_client): +async def test_send_current_firmware_info_calls_send_timeseries(updater, mock_client): updater.current_firmware_info = { f"current_{FW_TITLE_ATTR}": "test", f"current_{FW_VERSION_ATTR}": "1.0", FW_STATE_ATTR: FirmwareStates.DOWNLOADING.value } await updater._send_current_firmware_info() - mock_client.send_telemetry.assert_awaited_once() - args = mock_client.send_telemetry.call_args[0][0] + mock_client.send_timeseries.assert_awaited_once() + args = mock_client.send_timeseries.call_args[0][0] assert all(isinstance(entry, TimeseriesEntry) for entry in args) From 14c4a5315e26ad25f5e5585925b49a197573041f Mon Sep 17 00:00:00 2001 From: imbeacon Date: Mon, 30 Jun 2025 10:52:20 +0300 Subject: [PATCH 39/74] Added new example for connection over SSL --- examples/device/DEPRECATEDtls_connect.py | 25 ------------------- examples/device/tls_connect.py | 31 ++++++++++++++++++++---- 2 files changed, 26 insertions(+), 30 deletions(-) delete mode 100644 examples/device/DEPRECATEDtls_connect.py diff --git a/examples/device/DEPRECATEDtls_connect.py b/examples/device/DEPRECATEDtls_connect.py deleted file mode 100644 index 5dbd36d..0000000 --- a/examples/device/DEPRECATEDtls_connect.py +++ 
/dev/null @@ -1,25 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from tb_device_mqtt import TBDeviceMqttClient -import socket - -logging.basicConfig(level=logging.DEBUG) -# connecting to localhost -client = TBDeviceMqttClient(socket.gethostname(), username="A2_TEST_TOKEN") -client.connect(tls=True, - ca_certs="mqttserver.pub.pem", - cert_file="mqttclient.nopass.pem") -client.disconnect() diff --git a/examples/device/tls_connect.py b/examples/device/tls_connect.py index 569b320..c1f624c 100644 --- a/examples/device/tls_connect.py +++ b/examples/device/tls_connect.py @@ -15,22 +15,43 @@ # This example demonstrates how to connect to ThingsBoard over SSL using the DeviceClient and send telemetry. import asyncio -from random import uniform, randint +from random import randint + from tb_mqtt_client.common.config_loader import DeviceConfig from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.service.device.client import DeviceClient +PLATFORM_HOST = 'localhost' # Update with your ThingsBoard host +PLATFORM_PORT = 8883 # Default port for MQTT over SSL + + +# Update with your CA certificate, client certificate, and client key paths. There are no default files generated. 
+# You can generate them using the following guides: +# Certificates for server - https://thingsboard.io/docs/user-guide/mqtt-over-ssl/ +# Certificates for client - https://thingsboard.io/docs/user-guide/certificates/?ubuntuThingsboardX509=X509Leaf +CA_CERT_PATH = "mqttserver.pem" # Update with your CA certificate path (Default - mqttserver.pem in the examples directory) +CLIENT_CERT_PATH = "cert.pem" # Update with your client certificate path (Default - cert.pem in the examples directory) +CLIENT_KEY_PATH = "key.pem" # Update with your client key path (Default - key.pem in the examples directory) + async def main(): config = DeviceConfig() - config.host = "localhost" - config.access_token = "YOUR_ACCESS_TOKEN" + + config.host = PLATFORM_HOST + config.port = PLATFORM_PORT + + config.ca_cert = CA_CERT_PATH + config.client_cert = CLIENT_CERT_PATH + config.private_key = CLIENT_KEY_PATH client = DeviceClient(config) await client.connect() - # Send telemetry entry - await client.send_timeseries(TimeseriesEntry("batteryLevel", randint(0, 100))) + result = await client.send_timeseries(TimeseriesEntry("batteryLevel", randint(0, 100)), wait_for_publish=True) + if result is not None and result.is_successful(): + print("Telemetry sent successfully") + else: + print(f"Failed to send telemetry: {result}") await client.stop() From 30e61e4661d53662d51768bdb091e5a64efbcaca Mon Sep 17 00:00:00 2001 From: imbeacon Date: Mon, 30 Jun 2025 10:56:26 +0300 Subject: [PATCH 40/74] Added send_telemetry method with warning, to use send_timeseries instead --- examples/device/tls_connect.py | 2 +- tb_mqtt_client/service/device/client.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/device/tls_connect.py b/examples/device/tls_connect.py index c1f624c..5feeb9f 100644 --- a/examples/device/tls_connect.py +++ b/examples/device/tls_connect.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# This example demonstrates how to connect to ThingsBoard over SSL using the DeviceClient and send telemetry. +# This example demonstrates how to connect to ThingsBoard over SSL using the DeviceClient and send time series. import asyncio from random import randint diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index 150feab..6f43cde 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -178,6 +178,13 @@ async def disconnect(self): # await self._message_queue.shutdown() # TODO: Not sure if we need to shutdown the message queue here, as it might be handled by MQTTManager + async def send_telemetry(self, *args, **kwargs): + """ + Note: This method is deprecated. Use `send_timeseries` instead. + """ + logger.warning("send_telemetry is deprecated. 
Use send_timeseries instead.") + return await self.send_timeseries(*args, **kwargs) + async def send_timeseries( self, data: Union[TimeseriesEntry, List[TimeseriesEntry], Dict[str, Any], List[Dict[str, Any]]], From d38e0318f01d621b30554354ce6a83749db33172 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Mon, 30 Jun 2025 10:58:01 +0300 Subject: [PATCH 41/74] Example refactoring --- examples/device/tls_connect.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/device/tls_connect.py b/examples/device/tls_connect.py index 5feeb9f..b823b4a 100644 --- a/examples/device/tls_connect.py +++ b/examples/device/tls_connect.py @@ -21,6 +21,7 @@ from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.service.device.client import DeviceClient + PLATFORM_HOST = 'localhost' # Update with your ThingsBoard host PLATFORM_PORT = 8883 # Default port for MQTT over SSL @@ -33,6 +34,7 @@ CLIENT_CERT_PATH = "cert.pem" # Update with your client certificate path (Default - cert.pem in the examples directory) CLIENT_KEY_PATH = "key.pem" # Update with your client key path (Default - key.pem in the examples directory) + async def main(): config = DeviceConfig() @@ -47,7 +49,8 @@ async def main(): await client.connect() # Send telemetry entry - result = await client.send_timeseries(TimeseriesEntry("batteryLevel", randint(0, 100)), wait_for_publish=True) + result = await client.send_timeseries(TimeseriesEntry("batteryLevel", randint(0, 100)), + wait_for_publish=True) if result is not None and result.is_successful(): print("Telemetry sent successfully") else: From d4e1900f6b673f630d8684c8deadc3ef1e5eda77 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Thu, 3 Jul 2025 09:58:39 +0300 Subject: [PATCH 42/74] Examples refactoring --- examples/device/claim_device.py | 14 ++++++-- examples/device/client_provisioning.py | 16 ++++++--- examples/device/firmware_update.py | 10 +++++- examples/device/handle_attribute_updates.py | 17 +++++++-- examples/device/handle_rpc_requests.py | 16 +++++++-- examples/device/request_attributes.py | 14 ++++++-- examples/device/send_attributes.py | 35 ++++++++++++++----- examples/device/send_client_side_rpc.py | 17 +++++++-- examples/device/send_timeseries.py | 23 ++++++++++-- examples/device/tls_connect.py | 12 +++++-- ...y => DEPRECATEDclaiming_device_pe_only.py} | 0 ...=> DEPRECATEDconnect_disconnect_device.py} | 0 ...tes.py => DEPRECATEDrequest_attributes.py} | 0 ..._to_rpc.py => DEPRECATEDrespond_to_rpc.py} | 0 ...EPRECATEDsend_telemetry_and_attributes.py} | 0 ...y => DEPRECATEDsubscribe_to_attributes.py} | 0 ...ls_connect.py => DEPRECATEDtls_connect.py} | 0 .../service/gateway/device_session.py | 14 ++++++++ 18 files changed, 157 insertions(+), 31 deletions(-) rename examples/gateway/{claiming_device_pe_only.py => DEPRECATEDclaiming_device_pe_only.py} (100%) rename examples/gateway/{connect_disconnect_device.py => DEPRECATEDconnect_disconnect_device.py} (100%) rename examples/gateway/{request_attributes.py => DEPRECATEDrequest_attributes.py} (100%) rename examples/gateway/{respond_to_rpc.py => DEPRECATEDrespond_to_rpc.py} (100%) rename examples/gateway/{send_telemetry_and_attributes.py => DEPRECATEDsend_telemetry_and_attributes.py} (100%) rename examples/gateway/{subscribe_to_attributes.py => DEPRECATEDsubscribe_to_attributes.py} (100%) rename examples/gateway/{tls_connect.py => DEPRECATEDtls_connect.py} (100%) create mode 100644 tb_mqtt_client/service/gateway/device_session.py diff --git a/examples/device/claim_device.py 
b/examples/device/claim_device.py index 7b690ef..cc34c49 100644 --- a/examples/device/claim_device.py +++ b/examples/device/claim_device.py @@ -15,10 +15,19 @@ # Example script to claim a device using ThingsBoard DeviceClient import asyncio +import logging + from tb_mqtt_client.common.config_loader import DeviceConfig from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.entities.data.claim_request import ClaimRequest from tb_mqtt_client.service.device.client import DeviceClient +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger + + +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.INFO) +logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) # Constants for connection @@ -29,6 +38,7 @@ CLAIMING_DURATION = 120 # Default claiming duration in seconds CLAIMING_SECRET_KEY = "YOUR_SECRET_KEY" # Replace with your actual secret key + async def main(): # Create device config config = DeviceConfig() @@ -45,9 +55,9 @@ async def main(): # Send claim request result: PublishResult = await client.claim_device(claim_request, wait_for_publish=True, timeout=CLAIMING_DURATION + 10) if result.is_successful(): - print(f"Claiming request was sent successfully. Please use the secret key '{CLAIMING_SECRET_KEY}' to claim the device from the dashboard.") + logger.info(f"Claiming request was sent successfully. Please use the secret key '{CLAIMING_SECRET_KEY}' to claim the device from the dashboard.") else: - print(f"Failed to send claiming request. Result: {result}") + logger.error(f"Failed to send claiming request. Result: {result}") await client.stop() diff --git a/examples/device/client_provisioning.py b/examples/device/client_provisioning.py index a2462b9..240e0f4 100644 --- a/examples/device/client_provisioning.py +++ b/examples/device/client_provisioning.py @@ -15,11 +15,19 @@ # Example script to device provisioning using the DeviceClient import asyncio +import logging from random import randint from tb_mqtt_client.entities.data.provisioning_request import AccessTokenProvisioningCredentials, ProvisioningRequest from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.service.device.client import DeviceClient +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger + + +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.INFO) +logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) async def main(): @@ -31,16 +39,16 @@ async def main(): provisioning_response = await DeviceClient.provision(provisioning_request) if provisioning_response.error is not None: - print(f"Provisioning failed: {provisioning_response.error}") + logger.error(f"Provisioning failed: {provisioning_response.error}") return - print('Provisined device config: ', provisioning_response) + logger.info('Provisioned device configuration: ', provisioning_response) - # Create a DeviceClient instance with the provisioned device config + # Create a DeviceClient instance with the provisioned device configuration client = DeviceClient(provisioning_response.result) await client.connect() - # Send single telemetry entry to provisioned device + # Send single telemetry entry to the provisioned device await client.send_timeseries(TimeseriesEntry("batteryLevel", randint(0, 100))) await client.stop() diff --git a/examples/device/firmware_update.py b/examples/device/firmware_update.py index 8ec7380..34df722 100644 --- a/examples/device/firmware_update.py +++ 
b/examples/device/firmware_update.py
@@ -15,10 +15,18 @@
 # Example script to update firmware using the DeviceClient
 
 import asyncio
+import logging
 from time import monotonic
 
 from tb_mqtt_client.common.config_loader import DeviceConfig
 from tb_mqtt_client.service.device.client import DeviceClient
+from tb_mqtt_client.common.logging_utils import configure_logging, get_logger
+
+
+configure_logging()
+logger = get_logger(__name__)
+logger.setLevel(logging.INFO)
+logging.getLogger("tb_mqtt_client").setLevel(logging.INFO)
 
 
 firmware_received = asyncio.Event()
@@ -26,7 +34,7 @@
 
 
 async def firmware_update_callback(_, payload):
-    print(f"Firmware update payload received: {payload}")
+    logger.info(f"Firmware update payload received: {payload}")
     firmware_received.set()
 
 
diff --git a/examples/device/handle_attribute_updates.py b/examples/device/handle_attribute_updates.py
index 9ed3595..649a2e4 100644
--- a/examples/device/handle_attribute_updates.py
+++ b/examples/device/handle_attribute_updates.py
@@ -15,12 +15,23 @@
 # Example script to handle attribute updates from ThingsBoard using the DeviceClient
 
 import asyncio
+import logging
+
 from tb_mqtt_client.common.config_loader import DeviceConfig
 from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate
 from tb_mqtt_client.service.device.client import DeviceClient
+from tb_mqtt_client.common.logging_utils import configure_logging, get_logger
+
+
+configure_logging()
+logger = get_logger(__name__)
+logger.setLevel(logging.INFO)
+logging.getLogger("tb_mqtt_client").setLevel(logging.INFO)
+
 
 async def attribute_update_callback(update: AttributeUpdate):
-    print("Received attribute update:", update)
+    logger.info("Received attribute update: %s", update)
+
 
 async def main():
     config = DeviceConfig()
@@ -31,13 +42,13 @@ async def main():
     client.set_attribute_update_callback(attribute_update_callback)
     await client.connect()
 
-    print("Waiting for attribute updates... Press Ctrl+C to stop.")
+    logger.info("Waiting for attribute updates... Press Ctrl+C to stop.")
 
     try:
         while True:
             await asyncio.sleep(1)
     except KeyboardInterrupt:
-        print("Shutting down...")
+        logger.info("Shutting down...")
         await client.stop()
 
 
diff --git a/examples/device/handle_rpc_requests.py b/examples/device/handle_rpc_requests.py
index 587057e..eb82463 100644
--- a/examples/device/handle_rpc_requests.py
+++ b/examples/device/handle_rpc_requests.py
@@ -15,13 +15,23 @@
 # Example script to handle RPC requests from ThingsBoard using the DeviceClient
 
 import asyncio
+import logging
+
 from tb_mqtt_client.common.config_loader import DeviceConfig
 from tb_mqtt_client.entities.data.rpc_request import RPCRequest
 from tb_mqtt_client.entities.data.rpc_response import RPCResponse
 from tb_mqtt_client.service.device.client import DeviceClient
+from tb_mqtt_client.common.logging_utils import configure_logging, get_logger
+
+
+configure_logging()
+logger = get_logger(__name__)
+logger.setLevel(logging.INFO)
+logging.getLogger("tb_mqtt_client").setLevel(logging.INFO)
+
 
 async def rpc_request_callback(request: RPCRequest) -> RPCResponse:
-    print("Received RPC:", request)
+    logger.info("Received RPC: %s", request)
 
     if request.method == "ping":
         return RPCResponse.build(request_id=request.request_id, result={"pong": True})
@@ -37,12 +47,12 @@ async def main():
     client.set_rpc_request_callback(rpc_request_callback)
     await client.connect()
 
-    print("Waiting for RPCs... Press Ctrl+C to stop.")
+    logger.info("Waiting for RPCs... 
Press Ctrl+C to stop.") try: while True: await asyncio.sleep(1) except KeyboardInterrupt: - print("Shutting down...") + logger.info("Shutting down...") await client.stop() diff --git a/examples/device/request_attributes.py b/examples/device/request_attributes.py index 55f1b06..88a13ae 100644 --- a/examples/device/request_attributes.py +++ b/examples/device/request_attributes.py @@ -15,13 +15,23 @@ # Example script to request attributes from ThingsBoard using the DeviceClient import asyncio +import logging + from tb_mqtt_client.common.config_loader import DeviceConfig from tb_mqtt_client.entities.data.attribute_request import AttributeRequest from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse from tb_mqtt_client.service.device.client import DeviceClient +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger + + +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.INFO) +logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) + async def attribute_request_callback(response: RequestedAttributeResponse): - print("Received attribute response:", response) + logger.info("Received attribute response:", response) async def main(): config = DeviceConfig() @@ -35,7 +45,7 @@ async def main(): request = await AttributeRequest.build(["targetTemperature"], ["currentTemperature"]) await client.send_attribute_request(request, attribute_request_callback) - print("Attribute request sent. Waiting for response...") + logger.info("Attribute request sent. Waiting for response...") await asyncio.sleep(5) await client.stop() diff --git a/examples/device/send_attributes.py b/examples/device/send_attributes.py index de5cb3c..76a779f 100644 --- a/examples/device/send_attributes.py +++ b/examples/device/send_attributes.py @@ -15,9 +15,19 @@ # Example script to send attributes to ThingsBoard using the DeviceClient import asyncio +import logging + from tb_mqtt_client.common.config_loader import DeviceConfig from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.service.device.client import DeviceClient +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger + + +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.INFO) +logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) + async def main(): config = DeviceConfig() @@ -28,19 +38,28 @@ async def main(): await client.connect() # Send attribute as raw dictionary - await client.send_attributes({ - "firmwareVersion": "1.0.4", + raw_attributes = { + "firmwareVersion": "1.0.3", "hardwareModel": "TB-SDK-Device" - }) + } + logger.info("Sending raw attributes: %s", raw_attributes) + await client.send_attributes(raw_attributes) + logger.info("Raw attributes sent successfully.") # Send single attribute entry - await client.send_attributes(AttributeEntry("mode", "normal")) + single_attribute = AttributeEntry("mode", "normal") + logger.info("Sending single attribute: %s", single_attribute) + await client.send_attributes(single_attribute) + logger.info("Single attribute sent successfully.") # Send list of attributes - await client.send_attributes([ - AttributeEntry("maxTemperature", 85), - AttributeEntry("calibrated", True) - ]) + attribute_list = [ + AttributeEntry("location", "Building A"), + AttributeEntry("status", "active") + ] + logger.info("Sending list of attributes: %s", attribute_list) + await client.send_attributes(attribute_list) + logger.info("List of attributes sent successfully.") await 
client.stop()
 
diff --git a/examples/device/send_client_side_rpc.py b/examples/device/send_client_side_rpc.py
index 4862769..aa297a8 100644
--- a/examples/device/send_client_side_rpc.py
+++ b/examples/device/send_client_side_rpc.py
@@ -15,13 +15,24 @@
 # Example script to send a client-side RPC request to ThingsBoard using the DeviceClient
 
 import asyncio
+import logging
+
 from tb_mqtt_client.common.config_loader import DeviceConfig
 from tb_mqtt_client.entities.data.rpc_request import RPCRequest
 from tb_mqtt_client.entities.data.rpc_response import RPCResponse
 from tb_mqtt_client.service.device.client import DeviceClient
+from tb_mqtt_client.common.logging_utils import configure_logging, get_logger
+
+
+configure_logging()
+logger = get_logger(__name__)
+logger.setLevel(logging.INFO)
+logging.getLogger("tb_mqtt_client").setLevel(logging.INFO)
+
 
 async def rpc_response_callback(response: RPCResponse):
-    print("Received RPC response:", response)
+    logger.info("Received RPC response: %s", response)
+
 
 async def main():
     config = DeviceConfig()
@@ -35,9 +46,9 @@ async def main():
     rpc_request = await RPCRequest.build("getTime", {})
     try:
         response = await client.send_rpc_request(rpc_request)
-        print("Received response:", response)
+        logger.info("Received response: %s", response)
     except TimeoutError:
-        print("RPC request timed out")
+        logger.info("RPC request timed out")
 
     # Send client-side RPC with callback
     rpc_request_2 = await RPCRequest.build("getStatus", {})
diff --git a/examples/device/send_timeseries.py b/examples/device/send_timeseries.py
index 36edb3f..f7e4b70 100644
--- a/examples/device/send_timeseries.py
+++ b/examples/device/send_timeseries.py
@@ -15,10 +15,19 @@
 # This example demonstrates how to send time series data from a device to ThingsBoard using the DeviceClient.
import asyncio +import logging from random import uniform, randint from tb_mqtt_client.common.config_loader import DeviceConfig from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.service.device.client import DeviceClient +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger + + +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.INFO) +logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) + async def main(): config = DeviceConfig() @@ -29,20 +38,28 @@ async def main(): await client.connect() # Send time series as raw dictionary - await client.send_timeseries({ + raw_timeseries = { "temperature": round(uniform(20.0, 30.0), 2), "humidity": randint(30, 70) - }) + } + logger.info("Sending raw timeseries: %s", raw_timeseries) + await client.send_timeseries(raw_timeseries) + logger.info("Raw timeseries sent successfully.") # Send single time series entry - await client.send_timeseries(TimeseriesEntry("batteryLevel", randint(0, 100))) + single_entry = TimeseriesEntry("pressure", round(uniform(950.0, 1050.0), 2)) + logger.info("Sending single timeseries entry: %s", single_entry) + await client.send_timeseries(single_entry) + logger.info("Single timeseries entry sent successfully.") # Send a list of time series entries entries = [ TimeseriesEntry("vibration", 0.05), TimeseriesEntry("speed", 123) ] + logger.info("Sending list of timeseries entries: %s", entries) await client.send_timeseries(entries) + logger.info("List of timeseries entries sent successfully.") await client.stop() diff --git a/examples/device/tls_connect.py b/examples/device/tls_connect.py index b823b4a..a0370a0 100644 --- a/examples/device/tls_connect.py +++ b/examples/device/tls_connect.py @@ -15,11 +15,19 @@ # This example demonstrates how to connect to ThingsBoard over SSL using the DeviceClient and send time series. 
import asyncio +import logging from random import randint from tb_mqtt_client.common.config_loader import DeviceConfig from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.service.device.client import DeviceClient +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger + + +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.INFO) +logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) PLATFORM_HOST = 'localhost' # Update with your ThingsBoard host @@ -52,9 +60,9 @@ async def main(): result = await client.send_timeseries(TimeseriesEntry("batteryLevel", randint(0, 100)), wait_for_publish=True) if result is not None and result.is_successful(): - print("Telemetry sent successfully") + logger.info("Telemetry sent successfully") else: - print(f"Failed to send telemetry: {result}") + logger.error(f"Failed to send telemetry: {result}") await client.stop() diff --git a/examples/gateway/claiming_device_pe_only.py b/examples/gateway/DEPRECATEDclaiming_device_pe_only.py similarity index 100% rename from examples/gateway/claiming_device_pe_only.py rename to examples/gateway/DEPRECATEDclaiming_device_pe_only.py diff --git a/examples/gateway/connect_disconnect_device.py b/examples/gateway/DEPRECATEDconnect_disconnect_device.py similarity index 100% rename from examples/gateway/connect_disconnect_device.py rename to examples/gateway/DEPRECATEDconnect_disconnect_device.py diff --git a/examples/gateway/request_attributes.py b/examples/gateway/DEPRECATEDrequest_attributes.py similarity index 100% rename from examples/gateway/request_attributes.py rename to examples/gateway/DEPRECATEDrequest_attributes.py diff --git a/examples/gateway/respond_to_rpc.py b/examples/gateway/DEPRECATEDrespond_to_rpc.py similarity index 100% rename from examples/gateway/respond_to_rpc.py rename to examples/gateway/DEPRECATEDrespond_to_rpc.py diff --git a/examples/gateway/send_telemetry_and_attributes.py b/examples/gateway/DEPRECATEDsend_telemetry_and_attributes.py similarity index 100% rename from examples/gateway/send_telemetry_and_attributes.py rename to examples/gateway/DEPRECATEDsend_telemetry_and_attributes.py diff --git a/examples/gateway/subscribe_to_attributes.py b/examples/gateway/DEPRECATEDsubscribe_to_attributes.py similarity index 100% rename from examples/gateway/subscribe_to_attributes.py rename to examples/gateway/DEPRECATEDsubscribe_to_attributes.py diff --git a/examples/gateway/tls_connect.py b/examples/gateway/DEPRECATEDtls_connect.py similarity index 100% rename from examples/gateway/tls_connect.py rename to examples/gateway/DEPRECATEDtls_connect.py diff --git a/tb_mqtt_client/service/gateway/device_session.py b/tb_mqtt_client/service/gateway/device_session.py new file mode 100644 index 0000000..fa669aa --- /dev/null +++ b/tb_mqtt_client/service/gateway/device_session.py @@ -0,0 +1,14 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ From 0ceecf274f89ef82c077403468a4f91a20e62cd1 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Fri, 11 Jul 2025 12:53:25 +0300 Subject: [PATCH 43/74] Refactored device client and added basic entities for gateway client --- tb_mqtt_client/common/config_loader.py | 64 ++- tb_mqtt_client/common/gmqtt_patch.py | 1 - tb_mqtt_client/common/provisioning_client.py | 8 +- .../gateway/device_connect_message.py | 53 +++ .../gateway/device_disconnect_message.py | 66 +++ .../entities/gateway/device_info.py | 56 +++ .../entities/gateway/device_session_state.py | 17 + tb_mqtt_client/entities/gateway/event_type.py | 36 ++ .../gateway/gateway_attribute_request.py | 66 +++ .../gateway/gateway_attribute_update.py | 32 ++ .../entities/gateway/gateway_event.py | 41 ++ .../gateway_requested_attribute_response.py | 99 +++++ .../entities/gateway/gateway_rpc_request.py | 52 +++ .../entities/gateway/gateway_rpc_response.py | 90 ++++ .../gateway/gateway_uplink_message.py | 172 ++++++++ tb_mqtt_client/service/device/client.py | 38 +- .../handlers/attribute_updates_handler.py | 20 +- .../requested_attributes_response_handler.py | 22 +- .../device/handlers/rpc_requests_handler.py | 24 +- .../device/handlers/rpc_response_handler.py | 26 +- .../message_adapter.py} | 34 +- tb_mqtt_client/service/event_dispatcher.py | 14 - tb_mqtt_client/service/gateway/client.py | 390 +++--------------- .../service/gateway/device_manager.py | 121 ++++++ .../service/gateway/device_sesion.py | 14 - .../service/gateway/device_session.py | 42 ++ .../service/gateway/event_dispatcher.py | 56 +++ .../gateway_attribute_updates_handler.py | 14 - .../gateway/gateway_client_interface.py | 73 ++++ .../gateway/gateway_rpc_requests_handler.py | 14 - .../gateway/handlers/__init__.py} | 1 - .../gateway_attribute_updates_handler.py | 35 ++ .../gateway/handlers/gateway_rpc_handler.py} | 0 .../service/gateway/message_adapter.py | 267 ++++++++++++ .../service/gateway/subdevice_manager.py | 14 - tb_mqtt_client/service/message_queue.py | 10 +- tb_mqtt_client/service/mqtt_manager.py | 6 +- tests/common/test_provisioning_client.py | 6 +- tests/service/device/test_device_client.py | 8 +- ...atcher.py => test_json_message_adapter.py} | 146 +++---- tests/service/test_message_queue.py | 256 ++++++------ tests/service/test_message_splitter.py | 6 +- tests/service/test_mqtt_manager.py | 6 +- 43 files changed, 1792 insertions(+), 724 deletions(-) create mode 100644 tb_mqtt_client/entities/gateway/device_connect_message.py create mode 100644 tb_mqtt_client/entities/gateway/device_disconnect_message.py create mode 100644 tb_mqtt_client/entities/gateway/device_info.py create mode 100644 tb_mqtt_client/entities/gateway/event_type.py create mode 100644 tb_mqtt_client/entities/gateway/gateway_attribute_request.py create mode 100644 tb_mqtt_client/entities/gateway/gateway_attribute_update.py create mode 100644 tb_mqtt_client/entities/gateway/gateway_event.py create mode 100644 tb_mqtt_client/entities/gateway/gateway_requested_attribute_response.py create mode 100644 tb_mqtt_client/entities/gateway/gateway_rpc_request.py create mode 100644 tb_mqtt_client/entities/gateway/gateway_rpc_response.py create mode 100644 tb_mqtt_client/entities/gateway/gateway_uplink_message.py rename tb_mqtt_client/service/{message_dispatcher.py => device/message_adapter.py} (93%) delete mode 100644 tb_mqtt_client/service/event_dispatcher.py create mode 100644 tb_mqtt_client/service/gateway/device_manager.py delete mode 100644 tb_mqtt_client/service/gateway/device_sesion.py create mode 100644 
tb_mqtt_client/service/gateway/event_dispatcher.py delete mode 100644 tb_mqtt_client/service/gateway/gateway_attribute_updates_handler.py create mode 100644 tb_mqtt_client/service/gateway/gateway_client_interface.py delete mode 100644 tb_mqtt_client/service/gateway/gateway_rpc_requests_handler.py rename tb_mqtt_client/{entities/gateway/virtual_device.py => service/gateway/handlers/__init__.py} (99%) create mode 100644 tb_mqtt_client/service/gateway/handlers/gateway_attribute_updates_handler.py rename tb_mqtt_client/{entities/gateway/rpc_context.py => service/gateway/handlers/gateway_rpc_handler.py} (100%) create mode 100644 tb_mqtt_client/service/gateway/message_adapter.py delete mode 100644 tb_mqtt_client/service/gateway/subdevice_manager.py rename tests/service/{test_json_message_dispatcher.py => test_json_message_adapter.py} (65%) diff --git a/tb_mqtt_client/common/config_loader.py b/tb_mqtt_client/common/config_loader.py index 55af666..c47b6eb 100644 --- a/tb_mqtt_client/common/config_loader.py +++ b/tb_mqtt_client/common/config_loader.py @@ -22,7 +22,18 @@ class DeviceConfig: This class loads configuration options from environment variables, allowing for flexible deployment and easy customization of device connection settings. """ - def __init__(self): + def __init__(self, config = None): + if config is not None: + self.host: str = config.get("host", "localhost") + self.port: int = config.get("port", 1883) + self.access_token: Optional[str] = config.get("access_token") + self.username: Optional[str] = config.get("username") + self.password: Optional[str] = config.get("password") + self.client_id: Optional[str] = config.get("client_id") + self.ca_cert: Optional[str] = config.get("ca_cert") + self.client_cert: Optional[str] = config.get("client_cert") + self.private_key: Optional[str] = config.get("private_key") + self.host: str = os.getenv("TB_HOST") self.port: int = int(os.getenv("TB_PORT", 1883)) @@ -60,26 +71,37 @@ class GatewayConfig(DeviceConfig): Configuration class for ThingsBoard gateway clients. This class extends DeviceConfig to include additional options specific to gateways. 
""" - def __init__(self): - super().__init__() - - # Gateway-specific options - self.gateway_name: Optional[str] = os.getenv("TB_GATEWAY_NAME") - - # Rate limits for devices connected through the gateway - self.device_messages_rate_limit: Optional[str] = os.getenv("TB_DEVICE_MESSAGES_RATE_LIMIT") - self.device_telemetry_rate_limit: Optional[str] = os.getenv("TB_DEVICE_TELEMETRY_RATE_LIMIT") - self.device_telemetry_dp_rate_limit: Optional[str] = os.getenv("TB_DEVICE_TELEMETRY_DP_RATE_LIMIT") - - # Default device type for auto-registered devices - self.default_device_type: Optional[str] = os.getenv("TB_DEFAULT_DEVICE_TYPE", "default") - - # Whether to automatically register new devices - self.auto_register_devices: bool = os.getenv("TB_AUTO_REGISTER_DEVICES", "true").lower() == "true" + def __init__(self, config=None): + # TODO: REFACTOR, temporary solution for development + super().__init__(config) + + if os.getenv("TB_GW_HOST") is not None: + self.host: str = os.getenv("TB_GW_HOST") + if os.getenv("TB_GW_PORT") is not None: + self.port: int = int(os.getenv("TB_GW_PORT", 1883)) + + if os.getenv("TB_GW_ACCESS_TOKEN") is not None: + self.access_token: Optional[str] = os.getenv("TB_GW_ACCESS_TOKEN") + if os.getenv("TB_GW_USERNAME") is not None: + self.username: Optional[str] = os.getenv("TB_GW_USERNAME") + if os.getenv("TB_GW_PASSWORD") is not None: + self.password: Optional[str] = os.getenv("TB_GW_PASSWORD") + + if os.getenv("TB_GW_CLIENT_ID") is not None: + self.client_id: Optional[str] = os.getenv("TB_GW_CLIENT_ID") + + if os.getenv("TB_GW_CA_CERT") is not None: + self.ca_cert: Optional[str] = os.getenv("TB_GW_CA_CERT") + if os.getenv("TB_GW_CLIENT_CERT") is not None: + self.client_cert: Optional[str] = os.getenv("TB_GW_CLIENT_CERT") + if os.getenv("TB_GW_PRIVATE_KEY") is not None: + self.private_key: Optional[str] = os.getenv("TB_GW_PRIVATE_KEY") + + if os.getenv("TB_GW_QOS") is not None: + self.qos: int = int(os.getenv("TB_GW_QOS", 1)) def __repr__(self): - return (f"") + f"client_id={self.client_id} " + f"tls={self.use_tls()})") diff --git a/tb_mqtt_client/common/gmqtt_patch.py b/tb_mqtt_client/common/gmqtt_patch.py index 233d4e1..f3d715c 100644 --- a/tb_mqtt_client/common/gmqtt_patch.py +++ b/tb_mqtt_client/common/gmqtt_patch.py @@ -15,7 +15,6 @@ import asyncio import struct from collections import defaultdict -from types import MethodType from typing import Callable from gmqtt.mqtt.constants import MQTTCommands diff --git a/tb_mqtt_client/common/provisioning_client.py b/tb_mqtt_client/common/provisioning_client.py index 940d800..5ca90f7 100644 --- a/tb_mqtt_client/common/provisioning_client.py +++ b/tb_mqtt_client/common/provisioning_client.py @@ -22,7 +22,7 @@ from tb_mqtt_client.constants.mqtt_topics import PROVISION_RESPONSE_TOPIC from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse -from tb_mqtt_client.service.message_dispatcher import JsonMessageDispatcher +from tb_mqtt_client.service.device.message_adapter import JsonMessageAdapter logger = get_logger(__name__) @@ -40,13 +40,13 @@ def __init__(self, host: str, port: int, provision_request: ProvisioningRequest) self._client.on_message = self._on_message self._provisioned = Event() self._device_config: Optional[Union[DeviceConfig, ProvisioningResponse]] = None - self.__message_dispatcher = JsonMessageDispatcher() + self.__message_adapter = JsonMessageAdapter() def _on_connect(self, client, _, rc, __): if rc == 0: 
self._log.debug("[Provisioning client] Connected to ThingsBoard") client.subscribe(PROVISION_RESPONSE_TOPIC) - topic, payload = self.__message_dispatcher.build_provision_request(self._provision_request) + topic, payload = self.__message_adapter.build_provision_request(self._provision_request) self._log.debug("[Provisioning client] Sending provisioning request %s" % payload) client.publish(topic, payload) else: @@ -57,7 +57,7 @@ def _on_connect(self, client, _, rc, __): self._log.error("[Provisioning client] Cannot connect to ThingsBoard!, result: %s" % rc) async def _on_message(self, _, __, payload, ___, ____): - provisioning_response = self.__message_dispatcher.parse_provisioning_response(self._provision_request, payload) + provisioning_response = self.__message_adapter.parse_provisioning_response(self._provision_request, payload) self._device_config = provisioning_response.result await self._client.disconnect() diff --git a/tb_mqtt_client/entities/gateway/device_connect_message.py b/tb_mqtt_client/entities/gateway/device_connect_message.py new file mode 100644 index 0000000..a760c64 --- /dev/null +++ b/tb_mqtt_client/entities/gateway/device_connect_message.py @@ -0,0 +1,53 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict + + +@dataclass(slots=True, frozen=True) +class DeviceConnectMessage: + """ + Represents a device connection message in the ThingsBoard Gateway MQTT client. + This class is used to encapsulate the details of a device connection message. + """ + device_name: str + device_profile: str = 'default' + + def __new__(self, *args, **kwargs): + raise TypeError("Direct instantiation of DeviceConnectMessage is not allowed. Use 'await DeviceConnectMessage.build(...)'.") + + def __repr__(self): + return f"DeviceConnectMessage(device_name={self.device_name}, device_profile={self.device_profile})" + + @classmethod + def build(cls, device_name: str, device_profile: str = 'default') -> 'DeviceConnectMessage': + """ + Build a new DeviceConnectMessage with the specified device name and profile. + """ + if not device_name: + raise ValueError("Device name must not be empty.") + self = object.__new__(cls) + object.__setattr__(self, 'device_name', device_name) + object.__setattr__(self, 'device_profile', device_profile) + return self + + def to_payload_format(self) -> Dict[str, str]: + """ + Convert the device connection message into the expected MQTT payload format. + """ + return { + "device": self.device_name, + "type": self.device_profile + } diff --git a/tb_mqtt_client/entities/gateway/device_disconnect_message.py b/tb_mqtt_client/entities/gateway/device_disconnect_message.py new file mode 100644 index 0000000..2d9ecf1 --- /dev/null +++ b/tb_mqtt_client/entities/gateway/device_disconnect_message.py @@ -0,0 +1,66 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict + + +@dataclass(slots=True, frozen=True) +class DeviceDisconnectMessage: + """ + Represents a device disconnection message in the ThingsBoard Gateway MQTT client. + This class is used to encapsulate the details of a device connection message. + """ + device_name: str + + def __new__(self, *args, **kwargs): + raise TypeError( + "Direct instantiation of DeviceDisconnectMessage is not allowed. Use 'DeviceDisconnectMessage.build(...)'.") + + def __repr__(self): + return f"DeviceDisconnectMessage(device_name={self.device_name})" + + @classmethod + def build(cls, device_name: str) -> 'DeviceDisconnectMessage': + """ + Build a new DeviceDisconnectMessage with the specified device name and profile. + """ + if not device_name: + raise ValueError("Device name must not be empty.") + self = object.__new__(cls) + object.__setattr__(self, 'device_name', device_name) + return self + + def to_payload_format(self) -> Dict[str, str]: + """ + Convert the device connection message into the expected MQTT payload format. + """ + return { + "device": self.device_name + } diff --git a/tb_mqtt_client/entities/gateway/device_info.py b/tb_mqtt_client/entities/gateway/device_info.py new file mode 100644 index 0000000..eb9d5f0 --- /dev/null +++ b/tb_mqtt_client/entities/gateway/device_info.py @@ -0,0 +1,56 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass, field +import uuid + + +@dataclass(frozen=True) +class DeviceInfo: + device_name: str + device_profile: str + original_name: str = field(init=False) + device_id: uuid.UUID = field(default_factory=uuid.uuid4, init=False) + + def __post_init__(self): + self.__setattr__("original_name", self.device_name) + + def rename(self, new_name: str): + if new_name != self.device_name: + self.__setattr__("device_name", new_name) + + @classmethod + def from_dict(cls, data: dict) -> 'DeviceInfo': + instance = cls( + device_name=data['device_name'], + device_profile=data.get('device_profile', 'default') + ) + instance.__setattr__("device_id", uuid.UUID(data['device_id'])) + if 'original_name' in data: + instance.__setattr__("original_name", data['original_name']) + return instance + + def to_dict(self) -> dict: + return { + "device_name": self.device_name, + "device_profile": self.device_profile, + "device_id": str(self.device_id), + "original_name": self.original_name + } + + def __str__(self) -> str: + return (f"DeviceInfo(device_id={self.device_id}, " + f"device_name={self.device_name}, " + f"device_profile={self.device_profile}, " + f"original_name={self.original_name})") diff --git a/tb_mqtt_client/entities/gateway/device_session_state.py b/tb_mqtt_client/entities/gateway/device_session_state.py index fa669aa..5b10891 100644 --- a/tb_mqtt_client/entities/gateway/device_session_state.py +++ b/tb_mqtt_client/entities/gateway/device_session_state.py @@ -12,3 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +from enum import Enum + +class DeviceSessionState(Enum): + CONNECTED = "CONNECTED" + DISCONNECTED = "DISCONNECTED" + RECONNECTING = "RECONNECTING" + CONNECTION_LOST = "CONNECTION_LOST" + CONNECTION_FAILED = "CONNECTION_FAILED" + + def is_connected(self) -> bool: + """ + Check if the session state indicates a connected state. + + Returns: + bool: True if the session is connected, False otherwise. + """ + return self == DeviceSessionState.CONNECTED diff --git a/tb_mqtt_client/entities/gateway/event_type.py b/tb_mqtt_client/entities/gateway/event_type.py new file mode 100644 index 0000000..abb7e0b --- /dev/null +++ b/tb_mqtt_client/entities/gateway/event_type.py @@ -0,0 +1,36 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + +class GatewayEventType(Enum): + """ + Enum representing different types of gateway events. + Each event type corresponds to a specific action or state change in the gateway. 
+ """ + DEVICE_ADDED = "DEVICE_ADDED" + DEVICE_REMOVED = "DEVICE_REMOVED" + DEVICE_UPDATED = "DEVICE_UPDATED" + DEVICE_SESSION_STATE_CHANGED = "DEVICE_SESSION_STATE_CHANGED" + DEVICE_RPC_REQUEST_RECEIVED = "DEVICE_RPC_REQUEST_RECEIVED" + DEVICE_RPC_RESPONSE_SENT = "DEVICE_RPC_RESPONSE_SENT" + DEVICE_ATTRIBUTE_UPDATE_RECEIVED = "DEVICE_ATTRIBUTE_UPDATE_RECEIVED" + DEVICE_REQUESTED_ATTRIBUTE_RESPONSE_RECEIVED = "DEVICE_REQUESTED_ATTRIBUTE_RESPONSE_RECEIVED" + RPC_REQUEST_RECEIVED = "RPC_REQUEST_RECEIVED" + RPC_RESPONSE_SENT = "RPC_RESPONSE_SENT" + GATEWAY_CONNECTED = "GATEWAY_CONNECTED" + GATEWAY_DISCONNECTED = "GATEWAY_DISCONNECTED" + + def __str__(self) -> str: + return self.value diff --git a/tb_mqtt_client/entities/gateway/gateway_attribute_request.py b/tb_mqtt_client/entities/gateway/gateway_attribute_request.py new file mode 100644 index 0000000..e853633 --- /dev/null +++ b/tb_mqtt_client/entities/gateway/gateway_attribute_request.py @@ -0,0 +1,66 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, List, Dict, Union +from tb_mqtt_client.common.request_id_generator import AttributeRequestIdProducer +from tb_mqtt_client.constants.json_typing import validate_json_compatibility + + +@dataclass(slots=True, frozen=True) +class GatewayAttributeRequest: + """ + Represents a request for device attributes, with optional client and shared attribute keys. + Automatically assigns a unique request ID via the build() method. + """ + request_id: int + device_name: str + shared_keys: Optional[List[str]] = None + client_keys: Optional[List[str]] = None + + def __new__(cls, *args, **kwargs): + raise TypeError("Direct instantiation of GatewayAttributeRequest is not allowed. Use 'await GatewayAttributeRequest.build(...)'.") + + def __repr__(self) -> str: + return f"GatewayAttributeRequest(device_name={self.device_name}, id={self.request_id}, shared_keys={self.shared_keys}, client_keys={self.client_keys})" + + @classmethod + async def build(cls, device_name: str, shared_keys: Optional[List[str]] = None, client_keys: Optional[List[str]] = None) -> 'GatewayAttributeRequest': + """ + Build a new GatewayAttributeRequest with a unique request ID, using the global ID generator. + """ + validate_json_compatibility(shared_keys) + validate_json_compatibility(client_keys) + request_id = await AttributeRequestIdProducer.get_next() + self = object.__new__(cls) + object.__setattr__(self, 'device_name', device_name) + object.__setattr__(self, 'request_id', request_id) + object.__setattr__(self, 'shared_keys', shared_keys) + object.__setattr__(self, 'client_keys', client_keys) + return self + + def to_payload_format(self) -> Dict[str, Union[str, bool]]: + """ + Convert the attribute request into the expected MQTT payload format. 
+        """
+        payload = {"device": self.device_name, "id": str(self.request_id)}
+        # Use whichever scope is actually being requested; client keys take precedence when both are set.
+        requested_keys = self.client_keys if self.client_keys else self.shared_keys
+        request_key = 'key' if requested_keys and len(requested_keys) == 1 else 'keys'
+        if self.client_keys:
+            payload['client'] = True
+            payload[request_key] = ','.join(self.client_keys)
+        elif self.shared_keys:
+            # TODO: The current server implementation cannot return values for both scopes in a single request; the platform API should be improved.
+            payload['client'] = False
+            payload[request_key] = ','.join(self.shared_keys)
+        return payload
diff --git a/tb_mqtt_client/entities/gateway/gateway_attribute_update.py b/tb_mqtt_client/entities/gateway/gateway_attribute_update.py
new file mode 100644
index 0000000..1c59814
--- /dev/null
+++ b/tb_mqtt_client/entities/gateway/gateway_attribute_update.py
@@ -0,0 +1,32 @@
+# Copyright 2025 ThingsBoard
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate
+from tb_mqtt_client.entities.gateway.event_type import GatewayEventType
+from tb_mqtt_client.entities.gateway.gateway_event import GatewayEvent
+
+
+class GatewayAttributeUpdate(GatewayEvent):
+    """
+    Represents an attribute update event for a device connected to a gateway.
+    This event is used to notify about changes in device shared attributes.
+    """
+
+    def __init__(self, device_name: str, attribute_update: AttributeUpdate):
+        super().__init__(event_type=GatewayEventType.DEVICE_ATTRIBUTE_UPDATE_RECEIVED)
+        self.device_name = device_name
+        self.attribute_update = attribute_update
+
+    def __str__(self) -> str:
+        return f"GatewayAttributeUpdate(device_name={self.device_name}, attribute_update={self.attribute_update})"
diff --git a/tb_mqtt_client/entities/gateway/gateway_event.py b/tb_mqtt_client/entities/gateway/gateway_event.py
new file mode 100644
index 0000000..66d0c9c
--- /dev/null
+++ b/tb_mqtt_client/entities/gateway/gateway_event.py
@@ -0,0 +1,41 @@
+# Copyright 2025 ThingsBoard
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Union
+
+from tb_mqtt_client.entities.gateway.event_type import GatewayEventType
+from tb_mqtt_client.service.gateway.device_session import DeviceSession
+
+
+class GatewayEvent:
+    """
+    Base class for all events in the gateway client.
+    This class can be extended to create specific event types.
+ """ + + def __init__(self, event_type: GatewayEventType): + self.event_type = event_type + self.__device_session: Union[DeviceSession, None] = None + + def set_device_session(self, device_session: DeviceSession): + self.__device_session = device_session + + def get_device_session(self) -> Union[DeviceSession, None]: + return self.__device_session + + def __str__(self) -> str: + return f"GatewayEvent(type={self.event_type}, device_session={self.__device_session})" + + def to_dict(self) -> dict: + return {"event_type": self.event_type, "device_session": self.__device_session} diff --git a/tb_mqtt_client/entities/gateway/gateway_requested_attribute_response.py b/tb_mqtt_client/entities/gateway/gateway_requested_attribute_response.py new file mode 100644 index 0000000..651f2d7 --- /dev/null +++ b/tb_mqtt_client/entities/gateway/gateway_requested_attribute_response.py @@ -0,0 +1,99 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, Any, List, Optional + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry + + +@dataclass(slots=True, frozen=True) +class GatewayRequestedAttributeResponse: + + device_name: str + request_id: int + shared: Optional[List[AttributeEntry]] = None + client: Optional[List[AttributeEntry]] = None + + def __repr__(self): + return f"GatewayRequestedAttributeResponse(device_name={self.device_name},request_id={self.request_id}, shared={self.shared}, client={self.client})" + + def __getitem__(self, item): + """ + Allows access to values using dictionary-like syntax. + """ + if self.shared is not None: + for entry in self.shared: + if entry.key == item: + return entry.value + if self.client is not None: + for entry in self.client: + if entry.key == item: + return entry.value + raise KeyError(f"Key '{item}' not found in shared or client attributes.") + + def shared_keys(self): + return [entry.key for entry in self.shared] + + def client_keys(self): + return [entry.key for entry in self.client] + + def get_shared(self, key: str, default=None): + """ + Get the value of a shared attribute by key. + :param key: The key of the shared attribute. + :param default: Default value if the key is not found. + :return: Value of the shared attribute or default. + """ + if self.shared is not None: + for entry in self.shared: + if entry.key == key: + return entry.value + return default + + def get_client(self, key: str, default=None): + """ + Get the value of a client attribute by key. + :param key: The key of the client attribute. + :param default: Default value if the key is not found. + :return: Value of the client attribute or default. + """ + if self.client is not None: + for entry in self.client: + if entry.key == key: + return entry.value + return default + + def as_dict(self) -> Dict[str, Any]: + """ + Convert the GatewayRequestedAttributeResponse to a dictionary format. + :return: Dictionary representation of the response. 
+        """
+        return {
+            'shared': [entry.as_dict() for entry in (self.shared or [])],
+            'client': [entry.as_dict() for entry in (self.client or [])],
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> 'GatewayRequestedAttributeResponse':
+        """
+        Deserialize dictionary into GatewayRequestedAttributeResponse object.
+        :param data: Dictionary containing 'device' with device name, 'shared' and 'client' attributes.
+        :return: GatewayRequestedAttributeResponse instance.
+        """
+        request_id = data.get('request_id', -1)
+        device_name = data.get('device', '')
+        shared = [AttributeEntry(k, v) for k, v in data.get('shared', {}).items()]
+        client = [AttributeEntry(k, v) for k, v in data.get('client', {}).items()]
+        return cls(device_name=device_name, shared=shared, client=client, request_id=request_id)
diff --git a/tb_mqtt_client/entities/gateway/gateway_rpc_request.py b/tb_mqtt_client/entities/gateway/gateway_rpc_request.py
new file mode 100644
index 0000000..5518232
--- /dev/null
+++ b/tb_mqtt_client/entities/gateway/gateway_rpc_request.py
@@ -0,0 +1,52 @@
+# Copyright 2025 ThingsBoard
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Union, Optional, Dict, Any
+
+
+@dataclass(slots=True, frozen=True)
+class GatewayRPCRequest:
+    request_id: Union[int, str]
+    device_name: str
+    method: str
+    params: Optional[Any] = None
+
+    def __new__(cls, *args, **kwargs):
+        raise TypeError("Direct instantiation of GatewayRPCRequest is not allowed. Use 'await GatewayRPCRequest.build(...)'.")
+
+    def __repr__(self):
+        return f"GatewayRPCRequest(id={self.request_id}, device_name={self.device_name}, method={self.method}, params={self.params})"
+
+    @classmethod
+    def _deserialize_from_dict(cls, data: Dict[str, Union[str, Dict[str, Any]]]) -> 'GatewayRPCRequest':
+        """
+        Constructs a GatewayRPCRequest; should only be used for deserializing a request received from the platform.
+        """
+        if "device" not in data:
+            raise ValueError("Missing device name in RPC request")
+        device_name = data["device"]
+        data = data["data"]
+        request_id = data.get("id")
+        if not isinstance(request_id, (int, str)):
+            raise ValueError("Missing request id in RPC request")
+        if "method" not in data:
+            raise ValueError("Missing 'method' in RPC request")
+
+        self = object.__new__(cls)
+        object.__setattr__(self, 'request_id', request_id)
+        object.__setattr__(self, 'method', data["method"])
+        object.__setattr__(self, 'params', data.get("params"))
+        object.__setattr__(self, 'device_name', device_name)
+        return self
diff --git a/tb_mqtt_client/entities/gateway/gateway_rpc_response.py b/tb_mqtt_client/entities/gateway/gateway_rpc_response.py
new file mode 100644
index 0000000..37e9eb2
--- /dev/null
+++ b/tb_mqtt_client/entities/gateway/gateway_rpc_response.py
@@ -0,0 +1,90 @@
+# Copyright 2025 ThingsBoard
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from traceback import format_exception +from typing import Union, Optional, Dict, Any + +from tb_mqtt_client.constants.json_typing import validate_json_compatibility, JSONCompatibleType +from tb_mqtt_client.entities.data.rpc_response import RPCStatus + + +@dataclass(slots=True, frozen=True) +class GatewayRPCResponse: + """ + Represents a response to the RPC call. + + Attributes: + device_name: Name of the device for which the RPC response is intended. + request_id: Unique identifier of the request being responded to. + result: Optional response payload (Any type allowed). + error: Optional error information if the RPC failed. + """ + device_name: str + request_id: Union[int, str] + status: RPCStatus = None + result: Optional[Any] = None + error: Optional[Union[str, Dict[str, Any]]] = None + + def __new__(cls, *args, **kwargs): + raise TypeError("Direct instantiation of GatewayRPCResponse is not allowed. Use GatewayRPCResponse.build(request_id, result, error).") + + def __repr__(self) -> str: + return f"GatewayRPCResponse(request_id={self.request_id}, result={self.result}, error={self.error})" + + @classmethod + def build(cls, device_name: str, request_id: Union[int, str], result: Optional[Any] = None, error: Optional[Union[str, Dict[str, JSONCompatibleType], BaseException]] = None) -> 'GatewayRPCResponse': + """ + Constructs an GatewayRPCResponse explicitly. + """ + if not isinstance(device_name, str) or not device_name: + raise ValueError("Device name must be a non-empty string") + self = object.__new__(cls) + object.__setattr__(self, 'request_id', request_id) + + if error is not None: + if not isinstance(error, (str, dict, BaseException)): + raise ValueError("Error must be a string, dictionary, or an exception instance") + + object.__setattr__(self, 'status', RPCStatus.ERROR) + + if isinstance(error, BaseException): + try: + raise error + except BaseException as e: + error = { + "message": str(e), + "type": type(e).__name__, + "details": ''.join(format_exception(type(e), e, e.__traceback__)) + } + + validate_json_compatibility(error) + object.__setattr__(self, 'error', error) + + else: + object.__setattr__(self, 'status', RPCStatus.SUCCESS) + object.__setattr__(self, 'error', None) + validate_json_compatibility(result) + + object.__setattr__(self, 'result', result) + return self + + def to_payload_format(self) -> Dict[str, Any]: + """Serializes the RPC response for publishing.""" + data = {"device": self.device_name, "id": self.request_id, "data": {}} + if self.result is not None: + data["data"]["result"] = self.result + if self.error is not None: + data["data"]["error"] = self.error + return data diff --git a/tb_mqtt_client/entities/gateway/gateway_uplink_message.py b/tb_mqtt_client/entities/gateway/gateway_uplink_message.py new file mode 100644 index 0000000..578793e --- /dev/null +++ b/tb_mqtt_client/entities/gateway/gateway_uplink_message.py @@ -0,0 +1,172 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from dataclasses import dataclass +from types import MappingProxyType +from typing import List, Optional, Union, OrderedDict, Tuple, Mapping + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry + +logger = get_logger(__name__) + +DEFAULT_FIELDS_SIZE = len('{"device_name":"","device_profile":"","attributes":"","timeseries":""}'.encode('utf-8')) + + +@dataclass(slots=True, frozen=True) +class GatewayUplinkMessage: + device_name: Optional[str] + device_profile: Optional[str] + attributes: Tuple[AttributeEntry] + timeseries: Mapping[int, Tuple[TimeseriesEntry]] + delivery_futures: List[Optional[asyncio.Future[PublishResult]]] + _size: int + + def __new__(cls, *args, **kwargs): + raise TypeError( + "Direct instantiation of DeviceUplinkMessage is not allowed. 
Use DeviceUplinkMessageBuilder to construct instances.") + + def __repr__(self): + return (f"DeviceUplinkMessage(device_name={self.device_name}, " + f"device_profile={self.device_profile}, " + f"attributes={self.attributes}, " + f"timeseries={self.timeseries}, " + f"delivery_futures={self.delivery_futures})") + + @classmethod + def build(cls, + device_name: Optional[str], + device_profile: Optional[str], + attributes: List[AttributeEntry], + timeseries: Mapping[int, List[TimeseriesEntry]], + delivery_futures: List[Optional[asyncio.Future]], + size: int) -> 'DeviceUplinkMessage': + self = object.__new__(cls) + object.__setattr__(self, 'device_name', device_name) + object.__setattr__(self, 'device_profile', device_profile) + object.__setattr__(self, 'attributes', tuple(attributes)) + object.__setattr__(self, 'timeseries', + MappingProxyType({ts: tuple(entries) for ts, entries in timeseries.items()})) + object.__setattr__(self, 'delivery_futures', tuple(delivery_futures)) + object.__setattr__(self, '_size', size) + return self + + @property + def size(self) -> int: + return self._size + + def timeseries_datapoint_count(self) -> int: + return sum(len(entries) for entries in self.timeseries.values()) + + def attributes_datapoint_count(self) -> int: + return len(self.attributes) + + def has_attributes(self) -> bool: + return bool(self.attributes) + + def has_timeseries(self) -> bool: + return bool(self.timeseries) + + def get_delivery_futures(self) -> List[Optional[asyncio.Future[PublishResult]]]: + return self.delivery_futures + + +class GatewayUplinkMessageBuilder: + def __init__(self): + self._device_name: Optional[str] = None + self._device_profile: Optional[str] = None + self._attributes: List[AttributeEntry] = [] + self._timeseries: OrderedDict[int, List[TimeseriesEntry]] = OrderedDict() + self._delivery_futures: List[Optional[asyncio.Future[PublishResult]]] = [] + self.__size = DEFAULT_FIELDS_SIZE + + def set_device_name(self, device_name: str) -> 'GatewayUplinkMessageBuilder': + self._device_name = device_name + if device_name is not None: + self.__size += len(device_name) + return self + + def set_device_profile(self, profile: str) -> 'GatewayUplinkMessageBuilder': + self._device_profile = profile + if profile is not None: + self.__size += len(profile) + return self + + def add_attributes(self, attributes: Union[AttributeEntry, List[AttributeEntry]]) -> 'GatewayUplinkMessageBuilder': + if not isinstance(attributes, list): + attributes = [attributes] + self._attributes.extend(attributes) + for attribute in attributes: + self.__size += attribute.size + return self + + def add_timeseries(self, timeseries: Union[TimeseriesEntry, List[TimeseriesEntry], OrderedDict[ + int, List[TimeseriesEntry]]]) -> 'GatewayUplinkMessageBuilder': + if isinstance(timeseries, OrderedDict): + self._timeseries = timeseries + return self + if not isinstance(timeseries, list): + timeseries = [timeseries] + for entry in timeseries: + if entry.ts is not None: + if entry.ts in self._timeseries: + self._timeseries[entry.ts].append(entry) + else: + self._timeseries[entry.ts] = [entry] + else: + if 0 in self._timeseries: + self._timeseries[0].append(entry) + else: + self._timeseries[0] = [entry] + for timeseries_entry in timeseries: + self.__size += timeseries_entry.size + return self + + def add_delivery_futures(self, futures: Union[ + asyncio.Future[PublishResult], List[asyncio.Future[PublishResult]]]) -> 'DeviceUplinkMessageBuilder': + if not isinstance(futures, list): + futures = [futures] + if futures: + 
logger.debug("Created delivery futures: %r", [id(future) for future in futures]) + self._delivery_futures.extend(futures) + return self + + def build(self) -> DeviceUplinkMessage: + if not self._delivery_futures: + self._delivery_futures = [asyncio.get_event_loop().create_future()] + return DeviceUplinkMessage.build( + device_name=self._device_name, + device_profile=self._device_profile, + attributes=self._attributes, + timeseries=self._timeseries, + delivery_futures=self._delivery_futures, + size=self.__size + ) diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index 6f43cde..a5365c4 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -48,7 +48,7 @@ RequestedAttributeResponseHandler from tb_mqtt_client.service.device.handlers.rpc_requests_handler import RPCRequestsHandler from tb_mqtt_client.service.device.handlers.rpc_response_handler import RPCResponseHandler -from tb_mqtt_client.service.message_dispatcher import JsonMessageDispatcher, MessageDispatcher +from tb_mqtt_client.service.device.message_adapter import JsonMessageAdapter, MessageAdapter from tb_mqtt_client.service.message_queue import MessageQueue from tb_mqtt_client.service.mqtt_manager import MQTTManager @@ -73,8 +73,8 @@ def __init__(self, config: Optional[Union[DeviceConfig, Dict]] = None): super().__init__(self._config.host, self._config.port, client_id) self._message_queue: Optional[MessageQueue] = None - self._message_dispatcher: MessageDispatcher = JsonMessageDispatcher(1000, - 1) # Will be updated after connection established + self._message_adapter: MessageAdapter = JsonMessageAdapter(1000, + 1) # Will be updated after connection established self._messages_rate_limit = RateLimit("0:0,", name="messages") self._telemetry_rate_limit = RateLimit("0:0,", name="telemetry") @@ -89,7 +89,7 @@ def __init__(self, config: Optional[Union[DeviceConfig, Dict]] = None): self._mqtt_manager = MQTTManager(client_id=self._client_id, main_stop_event=self._stop_event, - message_dispatcher=self._message_dispatcher, + message_adapter=self._message_adapter, on_connect=self._on_connect, on_disconnect=self._on_disconnect, on_publish_result=self.__on_publish_result, @@ -136,22 +136,22 @@ async def connect(self): self.max_payload_size = 65535 logger.debug("Using default max_payload_size: %d", self.max_payload_size) - self._message_dispatcher = JsonMessageDispatcher(self.max_payload_size, - self._telemetry_dp_rate_limit.minimal_limit) + self._message_adapter = JsonMessageAdapter(self.max_payload_size, + self._telemetry_dp_rate_limit.minimal_limit) self._message_queue = MessageQueue( mqtt_manager=self._mqtt_manager, main_stop_event=self._stop_event, message_rate_limit=self._messages_rate_limit, telemetry_rate_limit=self._telemetry_rate_limit, telemetry_dp_rate_limit=self._telemetry_dp_rate_limit, - message_dispatcher=self._message_dispatcher, + message_adapter=self._message_adapter, max_queue_size=self._max_uplink_message_queue_size, ) - self._requested_attribute_response_handler.set_message_dispatcher(self._message_dispatcher) - self._attribute_updates_handler.set_message_dispatcher(self._message_dispatcher) - self._rpc_requests_handler.set_message_dispatcher(self._message_dispatcher) - self._rpc_response_handler.set_message_dispatcher(self._message_dispatcher) + self._requested_attribute_response_handler.set_message_adapter(self._message_adapter) + self._attribute_updates_handler.set_message_adapter(self._message_adapter) + 
self._rpc_requests_handler.set_message_adapter(self._message_adapter) + self._rpc_response_handler.set_message_adapter(self._message_adapter) async def stop(self): """ @@ -271,7 +271,7 @@ async def send_rpc_request( timeout: Optional[float] = BaseClient.DEFAULT_TIMEOUT ) -> Union[RPCResponse, Awaitable[RPCResponse], None]: request_id = rpc_request.request_id or await RPCRequestIdProducer.get_next() - topic, payload = self._message_dispatcher.build_rpc_request(rpc_request) + topic, payload = self._message_adapter.build_rpc_request(rpc_request) response_future = self._rpc_response_handler.register_request(request_id, callback) @@ -297,7 +297,7 @@ async def send_rpc_request( mqtt_topics.build_device_rpc_response_topic(rpc_request.request_id), e) async def send_rpc_response(self, response: RPCResponse): - topic, payload = self._message_dispatcher.build_rpc_response(response) + topic, payload = self._message_adapter.build_rpc_response(response) await self._message_queue.publish(topic=topic, payload=payload, datapoints_count=0, @@ -308,7 +308,7 @@ async def send_attribute_request(self, callback: Callable[[RequestedAttributeResponse], Awaitable[None]], ): await self._requested_attribute_response_handler.register_request(attribute_request, callback) - topic, payload = self._message_dispatcher.build_attribute_request(attribute_request) + topic, payload = self._message_adapter.build_attribute_request(attribute_request) await self._message_queue.publish(topic=topic, payload=payload, @@ -319,7 +319,7 @@ async def claim_device(self, claim_request: ClaimRequest, wait_for_publish: bool = True, timeout: float = BaseClient.DEFAULT_TIMEOUT) -> Union[Future[PublishResult], PublishResult]: - topic, payload = self._message_dispatcher.build_claim_request(claim_request) + topic, payload = self._message_adapter.build_claim_request(claim_request) publish_future = await self._message_queue.publish(topic=topic, payload=payload, datapoints_count=0, qos=1) if isinstance(publish_future, list): publish_future = publish_future[0] @@ -428,8 +428,8 @@ async def _handle_rate_limit_response(self, response: RPCResponse): # noqa if "maxPayloadSize" in response.result: self.max_payload_size = int(response.result["maxPayloadSize"] * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) # Update the dispatcher's max_payload_size if it's already initialized - if self._message_dispatcher is not None and hasattr(self._message_dispatcher, 'splitter'): - self._message_dispatcher.splitter.max_payload_size = self.max_payload_size + if self._message_adapter is not None and hasattr(self._message_adapter, 'splitter'): + self._message_adapter.splitter.max_payload_size = self.max_payload_size logger.debug("Updated dispatcher's max_payload_size to %d", self.max_payload_size) else: # If maxPayloadSize is not provided, keep the default value @@ -439,8 +439,8 @@ async def _handle_rate_limit_response(self, response: RPCResponse): # noqa self.max_payload_size = 65535 logger.debug("Using default max_payload_size: %d", self.max_payload_size) # Update the dispatcher's max_payload_size if it's already initialized - if self._message_dispatcher is not None and hasattr(self._message_dispatcher, 'splitter'): - self._message_dispatcher.splitter.max_payload_size = self.max_payload_size + if self._message_adapter is not None and hasattr(self._message_adapter, 'splitter'): + self._message_adapter.splitter.max_payload_size = self.max_payload_size logger.debug("Updated dispatcher's max_payload_size to %d", self.max_payload_size) if (not self._messages_rate_limit.has_limit() 
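A short usage sketch (illustrative only, not part of the patch) of how the renamed adapter is expected to be wired after the rate-limit exchange in the hunks above. It assumes only what the diff shows: JsonMessageAdapter(max_payload_size, max_datapoints) as the constructor and a DEFAULT_RATE_LIMIT_PERCENTAGE safety margin applied to the server-reported maxPayloadSize; the concrete percentage and the maxPayloadSize value below are placeholders.

from tb_mqtt_client.service.device.message_adapter import JsonMessageAdapter

DEFAULT_RATE_LIMIT_PERCENTAGE = 80  # placeholder value, assumed for illustration

# Server-reported limit, as it would appear in the rate-limit RPC response result:
reported = {"maxPayloadSize": 65535}

# Same scaling as _handle_rate_limit_response: keep a safety margin below the hard limit.
max_payload_size = int(reported["maxPayloadSize"] * DEFAULT_RATE_LIMIT_PERCENTAGE / 100)

# Recreate the adapter with the negotiated payload limit and a minimal datapoint limit of 1,
# mirroring JsonMessageAdapter(self.max_payload_size, self._telemetry_dp_rate_limit.minimal_limit).
adapter = JsonMessageAdapter(max_payload_size, 1)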
diff --git a/tb_mqtt_client/service/device/handlers/attribute_updates_handler.py b/tb_mqtt_client/service/device/handlers/attribute_updates_handler.py index 9787fdd..e740d5e 100644 --- a/tb_mqtt_client/service/device/handlers/attribute_updates_handler.py +++ b/tb_mqtt_client/service/device/handlers/attribute_updates_handler.py @@ -16,7 +16,7 @@ from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate -from tb_mqtt_client.service.message_dispatcher import MessageDispatcher +from tb_mqtt_client.service.device.message_adapter import MessageAdapter logger = get_logger(__name__) @@ -27,20 +27,20 @@ class AttributeUpdatesHandler: """ def __init__(self): - self._message_dispatcher = None + self._message_adapter = None self._callback: Optional[Callable[[AttributeUpdate], Awaitable[None]]] = None - def set_message_dispatcher(self, message_dispatcher: MessageDispatcher): + def set_message_adapter(self, message_adapter: MessageAdapter): """ - Sets the message dispatcher for handling incoming messages. + Sets the message adapter for handling incoming messages. This should be called before any callbacks are set. - :param message_dispatcher: An instance of MessageDispatcher. + :param message_adapter: An instance of MessageAdapter. """ - if not isinstance(message_dispatcher, MessageDispatcher): - raise ValueError("message_dispatcher must be an instance of MessageDispatcher.") - self._message_dispatcher = message_dispatcher - logger.debug("Message dispatcher set for AttributeUpdatesHandler.") + if not isinstance(message_adapter, MessageAdapter): + raise ValueError("message_adapter must be an instance of MessageAdapter.") + self._message_adapter = message_adapter + logger.debug("Message adapter set for AttributeUpdatesHandler.") def set_callback(self, callback: Callable[[AttributeUpdate], Awaitable[None]]): """ @@ -56,7 +56,7 @@ async def handle(self, topic: str, payload: bytes): # noqa return try: - data = self._message_dispatcher.parse_attribute_update(payload) + data = self._message_adapter.parse_attribute_update(payload) logger.debug("Handling attribute update: %r", data) await self._callback(data) except Exception as e: diff --git a/tb_mqtt_client/service/device/handlers/requested_attributes_response_handler.py b/tb_mqtt_client/service/device/handlers/requested_attributes_response_handler.py index 38ca4ca..78f7139 100644 --- a/tb_mqtt_client/service/device/handlers/requested_attributes_response_handler.py +++ b/tb_mqtt_client/service/device/handlers/requested_attributes_response_handler.py @@ -17,7 +17,7 @@ from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.entities.data.attribute_request import AttributeRequest from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse -from tb_mqtt_client.service.message_dispatcher import MessageDispatcher +from tb_mqtt_client.service.device.message_adapter import MessageAdapter logger = get_logger(__name__) @@ -28,18 +28,18 @@ class RequestedAttributeResponseHandler: """ def __init__(self): - self._message_dispatcher = None + self._message_adapter = None self._pending_attribute_requests: Dict[int, Tuple[AttributeRequest, Callable[[RequestedAttributeResponse], Awaitable[None]]]] = {} - def set_message_dispatcher(self, message_dispatcher: MessageDispatcher): + def set_message_adapter(self, message_adapter: MessageAdapter): """ - Sets the message dispatcher for handling incoming messages. 
+ Sets the message adapter for handling incoming messages. This should be called before any requests are registered. """ - if not isinstance(message_dispatcher, MessageDispatcher): - raise ValueError("message_dispatcher must be an instance of MessageDispatcher.") - self._message_dispatcher = message_dispatcher - logger.debug("Message dispatcher set for RequestedAttributeResponseHandler.") + if not isinstance(message_adapter, MessageAdapter): + raise ValueError("message_adapter must be an instance of MessageAdapter.") + self._message_adapter = message_adapter + logger.debug("Message adapter set for RequestedAttributeResponseHandler.") async def register_request(self, request: AttributeRequest, callback: Callable[[RequestedAttributeResponse], Awaitable[None]]): """ @@ -66,13 +66,13 @@ async def handle(self, topic: str, payload: bytes): Handles the incoming attribute request response. """ try: - if not self._message_dispatcher: - logger.error("Message dispatcher is not initialized. Cannot handle attribute response.") + if not self._message_adapter: + logger.error("Message adapter is not initialized. Cannot handle attribute response.") request_id = topic.split('/')[-1] # Assuming request ID is in the topic self._pending_attribute_requests.pop(int(request_id), None) return - requested_attribute_response = self._message_dispatcher.parse_attribute_request_response(topic, payload) + requested_attribute_response = self._message_adapter.parse_requested_attribute_response(topic, payload) pending_request_details = self._pending_attribute_requests.pop(requested_attribute_response.request_id, None) if not pending_request_details: logger.warning("No future awaiting request ID %s. Ignoring.", requested_attribute_response.request_id) diff --git a/tb_mqtt_client/service/device/handlers/rpc_requests_handler.py b/tb_mqtt_client/service/device/handlers/rpc_requests_handler.py index d80f6b6..e20cc9d 100644 --- a/tb_mqtt_client/service/device/handlers/rpc_requests_handler.py +++ b/tb_mqtt_client/service/device/handlers/rpc_requests_handler.py @@ -17,7 +17,7 @@ from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse -from tb_mqtt_client.service.message_dispatcher import MessageDispatcher +from tb_mqtt_client.service.device.message_adapter import MessageAdapter logger = get_logger(__name__) @@ -28,19 +28,19 @@ class RPCRequestsHandler: """ def __init__(self): - self._message_dispatcher = None + self._message_adapter = None self._callback: Optional[Callable[[RPCRequest], Awaitable[RPCResponse]]] = None - def set_message_dispatcher(self, message_dispatcher: MessageDispatcher): + def set_message_adapter(self, message_adapter: MessageAdapter): """ - Sets the message dispatcher for handling incoming messages. + Sets the message adapter for handling incoming messages. This should be called before any callbacks are set. - :param message_dispatcher: An instance of MessageDispatcher. + :param message_adapter: An instance of MessageAdapter. 
""" - if not isinstance(message_dispatcher, MessageDispatcher): - raise ValueError("message_dispatcher must be an instance of MessageDispatcher.") - self._message_dispatcher = message_dispatcher - logger.debug("Message dispatcher set for RPCRequestsHandler.") + if not isinstance(message_adapter, MessageAdapter): + raise ValueError("message_adapter must be an instance of MessageAdapter.") + self._message_adapter = message_adapter + logger.debug("Message adapter set for RPCRequestsHandler.") def set_callback(self, callback: Callable[[RPCRequest], Awaitable[RPCResponse]]): """ @@ -59,12 +59,12 @@ async def handle(self, topic: str, payload: bytes) -> Optional[RPCResponse]: "You can add set callback using client.set_rpc_request_callback(your_method)") return None - if not self._message_dispatcher: - logger.error("Message dispatcher is not initialized. Cannot handle RPC request.") + if not self._message_adapter: + logger.error("Message adapter is not initialized. Cannot handle RPC request.") return None try: - rpc_request = self._message_dispatcher.parse_rpc_request(topic, payload) + rpc_request = self._message_adapter.parse_rpc_request(topic, payload) logger.debug("Handling RPC method id: %i - %s with params: %s", rpc_request.request_id, rpc_request.method, rpc_request.params) result = await self._callback(rpc_request) diff --git a/tb_mqtt_client/service/device/handlers/rpc_response_handler.py b/tb_mqtt_client/service/device/handlers/rpc_response_handler.py index f241f3d..c3d42b0 100644 --- a/tb_mqtt_client/service/device/handlers/rpc_response_handler.py +++ b/tb_mqtt_client/service/device/handlers/rpc_response_handler.py @@ -17,7 +17,7 @@ from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.entities.data.rpc_response import RPCResponse -from tb_mqtt_client.service.message_dispatcher import MessageDispatcher, JsonMessageDispatcher +from tb_mqtt_client.service.device.message_adapter import MessageAdapter, JsonMessageAdapter logger = get_logger(__name__) @@ -29,21 +29,21 @@ class RPCResponseHandler: """ def __init__(self): - self._message_dispatcher: Optional[MessageDispatcher] = None + self._message_adapter: Optional[MessageAdapter] = None self._pending_rpc_requests: Dict[Union[str, int], Tuple[asyncio.Future[RPCResponse], Optional[Callable[[RPCResponse], Awaitable[None]]]]] = {} - def set_message_dispatcher(self, message_dispatcher: MessageDispatcher): + def set_message_adapter(self, message_adapter: MessageAdapter): """ - Sets the message dispatcher for handling incoming messages. + Sets the message adapter for handling incoming messages. This should be called before any requests are registered. - :param message_dispatcher: An instance of MessageDispatcher. + :param message_adapter: An instance of MessageAdapter. 
""" - if not isinstance(message_dispatcher, MessageDispatcher): - raise ValueError("message_dispatcher must be an instance of MessageDispatcher.") - self._message_dispatcher = message_dispatcher - logger.debug("Message dispatcher set for RPCResponseHandler.") + if not isinstance(message_adapter, MessageAdapter): + raise ValueError("message_adapter must be an instance of MessageAdapter.") + self._message_adapter = message_adapter + logger.debug("Message adapter set for RPCResponseHandler.") def register_request(self, request_id: Union[str, int], callback: Optional[Callable[[RPCResponse], Awaitable[None]]] = None) -> asyncio.Future[RPCResponse]: @@ -62,11 +62,11 @@ async def handle(self, topic: str, payload: Union[bytes, TimeoutError]): The topic is expected to be: v1/devices/me/rpc/response/{request_id} """ try: - if not self._message_dispatcher: - dummy_dispatcher = JsonMessageDispatcher() - rpc_response = dummy_dispatcher.parse_rpc_response(topic, payload) + if not self._message_adapter: + dummy_adapter = JsonMessageAdapter() + rpc_response = dummy_adapter.parse_rpc_response(topic, payload) else: - rpc_response = self._message_dispatcher.parse_rpc_response(topic, payload) + rpc_response = self._message_adapter.parse_rpc_response(topic, payload) request_details = self._pending_rpc_requests.pop(rpc_response.request_id, None) if not request_details: diff --git a/tb_mqtt_client/service/message_dispatcher.py b/tb_mqtt_client/service/device/message_adapter.py similarity index 93% rename from tb_mqtt_client/service/message_dispatcher.py rename to tb_mqtt_client/service/device/message_adapter.py index 2ab326c..2940746 100644 --- a/tb_mqtt_client/service/message_dispatcher.py +++ b/tb_mqtt_client/service/device/message_adapter.py @@ -12,6 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import asyncio from abc import ABC, abstractmethod from itertools import chain @@ -38,7 +52,7 @@ logger = get_logger(__name__) -class MessageDispatcher(ABC): +class MessageAdapter(ABC): def __init__(self, max_payload_size: Optional[int] = None, max_datapoints: Optional[int] = None): self._splitter = MessageSplitter(max_payload_size, max_datapoints) logger.trace("MessageDispatcher initialized with max_payload_size=%s, max_datapoints=%s", @@ -103,7 +117,7 @@ def splitter(self) -> MessageSplitter: pass @abstractmethod - def parse_attribute_request_response(self, topic: str, payload: bytes) -> RequestedAttributeResponse: + def parse_requested_attribute_response(self, topic: str, payload: bytes) -> RequestedAttributeResponse: """ Parse the attribute request response payload into an AttributeRequestResponse. This method should be implemented to handle the specific format of the topic and payload. 
@@ -143,7 +157,7 @@ def parse_provisioning_response(self, provisioning_request: ProvisioningRequest, pass -class JsonMessageDispatcher(MessageDispatcher): +class JsonMessageAdapter(MessageAdapter): """ A concrete implementation of MessageDispatcher that operates with JSON payloads. """ @@ -151,7 +165,7 @@ def __init__(self, max_payload_size: Optional[int] = None, max_datapoints: Optio super().__init__(max_payload_size, max_datapoints) logger.trace("JsonMessageDispatcher created.") - def parse_attribute_request_response(self, topic: str, payload: bytes) -> RequestedAttributeResponse: + def parse_requested_attribute_response(self, topic: str, payload: bytes) -> RequestedAttributeResponse: """ Parse the attribute request response payload into a RequestedAttributeResponse. :param topic: The MQTT topic of the requested attribute response. @@ -267,13 +281,13 @@ def build_uplink_payloads(self, messages: List[DeviceUplinkMessage]) -> List[Tup device, len(telemetry_msgs), len(attr_msgs)) for ts_batch in self._splitter.split_timeseries(telemetry_msgs): - payload = JsonMessageDispatcher.build_payload(ts_batch, True) + payload = JsonMessageAdapter.build_payload(ts_batch, True) count = ts_batch.timeseries_datapoint_count() result.append((DEVICE_TELEMETRY_TOPIC, payload, count, ts_batch.get_delivery_futures())) logger.trace("Built telemetry payload for device='%s' with %d datapoints", device, count) for attr_batch in self._splitter.split_attributes(attr_msgs): - payload = JsonMessageDispatcher.build_payload(attr_batch, False) + payload = JsonMessageAdapter.build_payload(attr_batch, False) count = len(attr_batch.attributes) result.append((DEVICE_ATTRIBUTES_TOPIC, payload, count, attr_batch.get_delivery_futures())) logger.trace("Built attribute payload for device='%s' with %d attributes", device, count) @@ -398,17 +412,17 @@ def build_payload(msg: DeviceUplinkMessage, build_timeseries_payload) -> bytes: if msg.device_name: if build_timeseries_payload: logger.trace("Packing timeseries for device='%s'", device_name) - result[msg.device_name] = JsonMessageDispatcher.pack_timeseries(msg) + result[msg.device_name] = JsonMessageAdapter.pack_timeseries(msg) else: logger.trace("Packing attributes for device='%s'", device_name) - result[msg.device_name] = JsonMessageDispatcher.pack_attributes(msg) + result[msg.device_name] = JsonMessageAdapter.pack_attributes(msg) else: if build_timeseries_payload: logger.trace("Packing timeseries") - result = JsonMessageDispatcher.pack_timeseries(msg) + result = JsonMessageAdapter.pack_timeseries(msg) else: logger.trace("Packing attributes") - result = JsonMessageDispatcher.pack_attributes(msg) + result = JsonMessageAdapter.pack_attributes(msg) payload = dumps(result) logger.trace("Built payload size: %d bytes", len(payload)) diff --git a/tb_mqtt_client/service/event_dispatcher.py b/tb_mqtt_client/service/event_dispatcher.py deleted file mode 100644 index fa669aa..0000000 --- a/tb_mqtt_client/service/event_dispatcher.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/tb_mqtt_client/service/gateway/client.py b/tb_mqtt_client/service/gateway/client.py index c7ae4ee..eaea7b6 100644 --- a/tb_mqtt_client/service/gateway/client.py +++ b/tb_mqtt_client/service/gateway/client.py @@ -13,29 +13,29 @@ # limitations under the License. from asyncio import sleep -from random import choices -from string import ascii_uppercase, digits -from typing import Callable, Awaitable, Optional, Dict, Any, Union, List, Set - -from orjson import dumps, loads +from time import monotonic +from typing import Optional, Dict, Union from tb_mqtt_client.common.config_loader import GatewayConfig from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit from tb_mqtt_client.constants import mqtt_topics -from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry -from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate -from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.service.device.client import DeviceClient +from tb_mqtt_client.service.gateway.device_manager import DeviceManager +from tb_mqtt_client.service.gateway.event_dispatcher import EventDispatcher +from tb_mqtt_client.service.gateway.gateway_client_interface import GatewayClientInterface +from tb_mqtt_client.service.gateway.handlers.gateway_attribute_updates_handler import GatewayAttributeUpdatesHandler +from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter, JsonGatewayMessageAdapter logger = get_logger(__name__) -class GatewayClient(DeviceClient): +class GatewayClient(DeviceClient, GatewayClientInterface): """ ThingsBoard Gateway MQTT client implementation. This class extends DeviceClient and adds gateway-specific functionality. 
""" + SUBSCRIPTIONS_TIMEOUT = 1.0 # Timeout for subscribe/unsubscribe operations def __init__(self, config: Optional[Union[GatewayConfig, Dict]] = None): """ @@ -43,28 +43,24 @@ def __init__(self, config: Optional[Union[GatewayConfig, Dict]] = None): :param config: Gateway configuration object or dictionary """ - self._config = None - if isinstance(config, GatewayConfig): - self._config = config - else: - self._config = GatewayConfig() - if isinstance(config, dict): - for key, value in config.items(): - if hasattr(self._config, key) and value is not None: - setattr(self._config, key, value) + self._config = config if isinstance(config, GatewayConfig) else GatewayConfig(config) + super().__init__(self._config) - client_id = self._config.client_id or "tb-gateway-" + ''.join(choices(ascii_uppercase + digits, k=6)) + self._device_manager = DeviceManager() + self._event_dispatcher: EventDispatcher = EventDispatcher() + self._gateway_message_adapter: GatewayMessageAdapter = JsonGatewayMessageAdapter() - # Initialize the DeviceClient with the gateway configuration - super().__init__(self._config) + self._multiplex_dispatcher = None # Placeholder for multiplex dispatcher, if needed + self._gateway_rpc_handler = None # Placeholder for gateway RPC handler + self._gateway_attribute_updates_handler = GatewayAttributeUpdatesHandler(self._event_dispatcher, + self._gateway_message_adapter, + self._device_manager) + self._gateway_requested_attribute_response_handler = None # Placeholder for gateway requested attribute response handler # Gateway-specific rate limits - self._device_messages_rate_limit = RateLimit("0:0,", name="device_messages") - self._device_telemetry_rate_limit = RateLimit("0:0,", name="device_telemetry") - self._device_telemetry_dp_rate_limit = RateLimit("0:0,", name="device_telemetry_datapoints") - - # Set of connected devices - self._connected_devices: Set[str] = set() + self._device_messages_rate_limit = RateLimit("10:1,", name="device_messages") + self._device_telemetry_rate_limit = RateLimit("10:1,", name="device_telemetry") + self._device_telemetry_dp_rate_limit = RateLimit("10:1,", name="device_telemetry_datapoints") # Callbacks self._device_attribute_update_callback = None @@ -73,7 +69,7 @@ def __init__(self, config: Optional[Union[GatewayConfig, Dict]] = None): async def connect(self): """ - Connect to the ThingsBoard platform. + Connect to the platform. """ logger.info("Connecting gateway to platform at %s:%s", self._host, self._port) await super().connect() @@ -83,329 +79,63 @@ async def connect(self): logger.info("Gateway connected to ThingsBoard.") + async def disconnect(self): + """ + Disconnect from the platform. + """ + logger.info("Disconnecting gateway from platform at %s:%s", self._host, self._port) + await self._unsubscribe_from_gateway_topics() + await super().disconnect() + logger.info("Gateway disconnected from ThingsBoard.") + async def _subscribe_to_gateway_topics(self): """ Subscribe to gateway-specific MQTT topics. 
""" logger.info("Subscribing to gateway topics") - # Subscribe to gateway attributes topic sub_future = await self._mqtt_manager.subscribe(mqtt_topics.GATEWAY_ATTRIBUTES_TOPIC, qos=1) while not sub_future.done(): await sleep(0.01) - # Subscribe to gateway attributes response topic sub_future = await self._mqtt_manager.subscribe(mqtt_topics.GATEWAY_ATTRIBUTES_RESPONSE_TOPIC, qos=1) while not sub_future.done(): await sleep(0.01) - # Subscribe to gateway RPC topic sub_future = await self._mqtt_manager.subscribe(mqtt_topics.GATEWAY_RPC_TOPIC, qos=1) while not sub_future.done(): await sleep(0.01) - # Register handlers for gateway topics - self._mqtt_manager.register_handler(mqtt_topics.GATEWAY_ATTRIBUTES_TOPIC, self._handle_gateway_attribute_update) - self._mqtt_manager.register_handler(mqtt_topics.GATEWAY_RPC_TOPIC, self._handle_gateway_rpc_request) - self._mqtt_manager.register_handler(mqtt_topics.GATEWAY_ATTRIBUTES_RESPONSE_TOPIC, self._handle_gateway_attribute_response) - - async def _handle_gateway_attribute_update(self, topic: str, payload: bytes): - """ - Handle attribute updates for gateway devices. - - :param topic: MQTT topic - :param payload: Message payload - """ - try: - data = loads(payload) - logger.debug("Received gateway attribute update: %s", data) - - if self._device_attribute_update_callback: - for device_name, attributes in data.items(): - update = AttributeUpdate(device=device_name, attributes=attributes) - await self._device_attribute_update_callback(update) - except Exception as e: - logger.exception("Error handling gateway attribute update: %s", e) - - async def _handle_gateway_rpc_request(self, topic: str, payload: bytes): - """ - Handle RPC requests for gateway devices. - - :param topic: MQTT topic - :param payload: Message payload - """ - try: - data = loads(payload) - logger.debug("Received gateway RPC request: %s", data) - - if self._device_rpc_request_callback and 'device' in data and 'data' in data: - device_name = data['device'] - rpc_data = data['data'] - - if 'id' in rpc_data and 'method' in rpc_data: - request_id = rpc_data['id'] - method = rpc_data['method'] - params = rpc_data.get('params', {}) - - result = await self._device_rpc_request_callback(device_name, method, params) - - # Send RPC response - await self.gw_send_rpc_reply(device_name, request_id, result) - except Exception as e: - logger.exception("Error handling gateway RPC request: %s", e) - - async def _handle_gateway_attribute_response(self, topic: str, payload: bytes): - """ - Handle attribute responses for gateway devices. - - :param topic: MQTT topic - :param payload: Message payload - """ - try: - data = loads(payload) - logger.debug("Received gateway attribute response: %s", data) - - # Process attribute response if needed - # This is typically used for handling responses to attribute requests - except Exception as e: - logger.exception("Error handling gateway attribute response: %s", e) - - async def gw_connect_device(self, device_name: str): - """ - Connect a device to the gateway. 
- - :param device_name: Name of the device to connect - """ - if device_name in self._connected_devices: - logger.warning("Device %s is already connected", device_name) - return + self._mqtt_manager.register_handler(mqtt_topics.GATEWAY_ATTRIBUTES_TOPIC, self._gateway_attribute_updates_handler.handle) + self._mqtt_manager.register_handler(mqtt_topics.GATEWAY_RPC_TOPIC, self._gateway_rpc_handler.handle) + self._mqtt_manager.register_handler(mqtt_topics.GATEWAY_ATTRIBUTES_RESPONSE_TOPIC, self._gateway_requested_attribute_response_handler.handle) - self._connected_devices.add(device_name) - logger.info("Device %s connected to gateway", device_name) - - async def gw_disconnect_device(self, device_name: str): - """ - Disconnect a device from the gateway. - - :param device_name: Name of the device to disconnect + async def _unsubscribe_from_gateway_topics(self): """ - if device_name not in self._connected_devices: - logger.warning("Device %s is not connected", device_name) - return - - self._connected_devices.remove(device_name) - - # Publish device disconnect message - await self._mqtt_manager.publish( - mqtt_topics.GATEWAY_DISCONNECT_TOPIC, - dumps({"device": device_name}), - qos=1 - ) - - logger.info("Device %s disconnected from gateway", device_name) - - # Call disconnect callback if registered - if self._device_disconnect_callback: - await self._device_disconnect_callback(device_name) - - async def gw_send_timeseries(self, device_name: str, telemetry: Union[Dict[str, Any], TimeseriesEntry, List[TimeseriesEntry]]): - """ - Send telemetry on behalf of a connected device. - - :param device_name: Name of the device - :param telemetry: Telemetry data to send - """ - if device_name not in self._connected_devices: - logger.warning("Cannot send telemetry for disconnected device %s", device_name) - return - - # Convert telemetry to the appropriate format - payload = self._prepare_telemetry_payload(device_name, telemetry) - - # Publish telemetry - await self._mqtt_manager.publish( - mqtt_topics.GATEWAY_TELEMETRY_TOPIC, - dumps(payload), - qos=1 - ) - - logger.debug("Sent telemetry for device %s: %s", device_name, payload) - - async def gw_send_attributes(self, device_name: str, attributes: Union[Dict[str, Any], AttributeEntry, List[AttributeEntry]]): - """ - Send attributes on behalf of a connected device. - - :param device_name: Name of the device - :param attributes: Attributes to send + Unsubscribe from gateway-specific MQTT topics. """ - if device_name not in self._connected_devices: - logger.warning("Cannot send attributes for disconnected device %s", device_name) - return + logger.info("Unsubscribing from gateway topics") - # Convert attributes to the appropriate format - payload = self._prepare_attributes_payload(device_name, attributes) - - # Publish attributes - await self._mqtt_manager.publish( - mqtt_topics.GATEWAY_ATTRIBUTES_TOPIC, - dumps(payload), - qos=1 - ) - - logger.debug("Sent attributes for device %s: %s", device_name, payload) - - async def gw_send_rpc_reply(self, device_name: str, request_id: int, response: Dict[str, Any]): - """ - Send an RPC response on behalf of a connected device. 
- - :param device_name: Name of the device - :param request_id: ID of the RPC request - :param response: Response data - """ - if device_name not in self._connected_devices: - logger.warning("Cannot send RPC reply for disconnected device %s", device_name) - return - - # Prepare RPC response payload - payload = { - "device": device_name, - "id": request_id, - "data": response - } - - # Publish RPC response - await self._mqtt_manager.publish( - mqtt_topics.GATEWAY_RPC_RESPONSE_TOPIC, - dumps(payload), - qos=1 - ) - - logger.debug("Sent RPC response for device %s, request %s: %s", device_name, request_id, response) - - async def gw_request_shared_attributes(self, device_name: str, keys: List[str], callback: Callable[[Dict[str, Any]], Awaitable[None]]): - """ - Request shared attributes for a connected device. - - :param device_name: Name of the device - :param keys: List of attribute keys to request - :param callback: Callback function to handle the response - """ - if device_name not in self._connected_devices: - logger.warning("Cannot request attributes for disconnected device %s", device_name) - return - - # TODO: Implement attribute request handling with callbacks - - # Prepare attribute request payload - request_id = 1 # TODO: Generate unique request ID - payload = { - "device": device_name, - "keys": keys, - "id": request_id - } - - # Publish attribute request - await self._mqtt_manager.publish( - mqtt_topics.GATEWAY_ATTRIBUTES_REQUEST_TOPIC, - dumps(payload), - qos=1 - ) - - logger.debug("Requested shared attributes for device %s: %s", device_name, keys) - - def set_device_attribute_update_callback(self, callback: Callable[[AttributeUpdate], Awaitable[None]]): - """ - Set callback for device attribute updates. - - :param callback: Callback function - """ - self._device_attribute_update_callback = callback - - def set_device_rpc_request_callback(self, callback: Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]): - """ - Set callback for device RPC requests. - - :param callback: Callback function that takes device name, method, and params - """ - self._device_rpc_request_callback = callback - - def set_device_disconnect_callback(self, callback: Callable[[str], Awaitable[None]]): - """ - Set callback for device disconnections. - - :param callback: Callback function - """ - self._device_disconnect_callback = callback - - def _prepare_telemetry_payload(self, device_name: str, telemetry: Union[Dict[str, Any], TimeseriesEntry, List[TimeseriesEntry]]) -> Dict[str, Any]: - """ - Prepare telemetry payload for gateway API. 
- - :param device_name: Name of the device - :param telemetry: Telemetry data - :return: Formatted payload - """ - if isinstance(telemetry, dict): - # Simple key-value telemetry - return {device_name: telemetry} - - elif isinstance(telemetry, TimeseriesEntry): - # Single TimeseriesEntry - if telemetry.ts: - return {device_name: {"ts": telemetry.ts, "values": {telemetry.key: telemetry.value}}} - else: - return {device_name: {telemetry.key: telemetry.value}} - - elif isinstance(telemetry, list): - # List of TimeseriesEntry objects - # Group by timestamp - ts_groups = {} - for entry in telemetry: - ts = entry.ts or 0 - if ts not in ts_groups: - ts_groups[ts] = {} - ts_groups[ts][entry.key] = entry.value - - if len(ts_groups) == 1 and 0 in ts_groups: - # No timestamps, just values - return {device_name: ts_groups[0]} - else: - # With timestamps - result = [] - for ts, values in ts_groups.items(): - if ts > 0: - result.append({"ts": ts, "values": values}) - else: - result.append({"values": values}) - return {device_name: result} - - # Fallback - logger.warning("Unsupported telemetry format: %s", type(telemetry)) - return {device_name: {}} - - def _prepare_attributes_payload(self, device_name: str, attributes: Union[Dict[str, Any], AttributeEntry, List[AttributeEntry]]) -> Dict[str, Any]: - """ - Prepare attributes payload for gateway API. - - :param device_name: Name of the device - :param attributes: Attributes data - :return: Formatted payload - """ - if isinstance(attributes, dict): - # Simple key-value attributes - return {device_name: attributes} - - elif isinstance(attributes, AttributeEntry): - # Single AttributeEntry - return {device_name: {attributes.key: attributes.value}} + unsub_future = await self._mqtt_manager.unsubscribe(mqtt_topics.GATEWAY_ATTRIBUTES_TOPIC) + unsubscribe_start_time = monotonic() + while not unsub_future.done(): + if monotonic() - unsubscribe_start_time > self.SUBSCRIPTIONS_TIMEOUT: + logger.warning("Unsubscribe from gateway attributes topic timed out") + break + await sleep(0.01) - elif isinstance(attributes, list): - # List of AttributeEntry objects - attrs = {} - for entry in attributes: - attrs[entry.key] = entry.value - return {device_name: attrs} + unsub_future = await self._mqtt_manager.unsubscribe(mqtt_topics.GATEWAY_ATTRIBUTES_RESPONSE_TOPIC) + unsubscribe_start_time = monotonic() + while not unsub_future.done(): + if monotonic() - unsubscribe_start_time > self.SUBSCRIPTIONS_TIMEOUT: + logger.warning("Unsubscribe from gateway attribute responses topic timed out") + break + await sleep(0.01) - # Fallback - logger.warning("Unsupported attributes format: %s", type(attributes)) - return {device_name: {}} + unsub_future = await self._mqtt_manager.unsubscribe(mqtt_topics.GATEWAY_RPC_TOPIC) + unsubscribe_start_time = monotonic() + while not unsub_future.done(): + if monotonic() - unsubscribe_start_time > self.SUBSCRIPTIONS_TIMEOUT: + logger.warning("Unsubscribe from gateway rpc topic timed out") + break + await sleep(0.01) diff --git a/tb_mqtt_client/service/gateway/device_manager.py b/tb_mqtt_client/service/gateway/device_manager.py new file mode 100644 index 0000000..c3e0cb2 --- /dev/null +++ b/tb_mqtt_client/service/gateway/device_manager.py @@ -0,0 +1,121 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Dict, Iterable, Callable, Set +from uuid import UUID + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.service.gateway.device_session import DeviceSession +from tb_mqtt_client.entities.gateway.device_info import DeviceInfo + + +logger = get_logger(__name__) + + +class DeviceManager: + def __init__(self): + self._sessions_by_id: Dict[UUID, DeviceSession] = {} + self._ids_by_device_name: Dict[str, UUID] = {} + self._ids_by_original_name: Dict[str, UUID] = {} + self.__connected_devices: Set[DeviceSession] = set() + + + def register(self, device_name: str, device_profile: str = "default") -> DeviceSession: + session = self.get_by_name(device_name) + if session: + return session + + device_info = DeviceInfo( + device_name=device_name, + device_profile=device_profile + ) + session = DeviceSession(device_info, self.__state_change_callback) + self._sessions_by_id[device_info.device_id] = session + self._ids_by_device_name[device_name] = device_info.device_id + self._ids_by_original_name[device_info.original_name] = device_info.device_id + session.update_last_seen() + return session + + def unregister(self, device_id: UUID): + session = self._sessions_by_id.pop(device_id, None) + if session: + self._ids_by_device_name.pop(session.device_info.device_name, None) + self._ids_by_original_name.pop(session.device_info.original_name, None) + + def get_by_id(self, device_id: UUID) -> Optional[DeviceSession]: + return self._sessions_by_id.get(device_id) + + def get_by_name(self, device_name: str) -> Optional[DeviceSession]: + device_id = self._ids_by_device_name.get(device_name) + if device_id: + return self._sessions_by_id.get(device_id) + renamed_id = self._ids_by_original_name.get(device_name) + if renamed_id: + return self._sessions_by_id.get(renamed_id) + return None + + def is_connected(self, device_id: UUID) -> bool: + return device_id in self._sessions_by_id + + def all(self) -> Iterable[DeviceSession]: + return self._sessions_by_id.values() + + def rename_device(self, old_name: str, new_name: str): + device_session = self.get_by_name(old_name) + if not device_session: + logger.warning(f"Device with name '{old_name}' not found for renaming to '{new_name}'.") + return + device_session.device_info.rename(new_name) + self._ids_by_device_name.pop(old_name, None) + device_id = device_session.device_info.device_id + self._ids_by_device_name[new_name] = device_id + self._ids_by_original_name[device_session.device_info.original_name] = device_id + + def set_attribute_update_callback(self, device_id: UUID, cb: Callable): + session = self._sessions_by_id.get(device_id) + if session: + session.set_attribute_update_callback(cb) + + def set_attribute_response_callback(self, device_id: UUID, cb: Callable): + session = self._sessions_by_id.get(device_id) + if session: + session.set_attribute_response_callback(cb) + + def set_rpc_request_callback(self, device_id: UUID, cb: Callable): + session = self._sessions_by_id.get(device_id) + if session: + session.set_rpc_request_callback(cb) + + def set_rpc_response_callback(self, device_id: UUID, cb: 
Callable):
+        session = self._sessions_by_id.get(device_id)
+        if session:
+            session.set_rpc_response_callback(cb)
+
+    def __state_change_callback(self, device_session: DeviceSession) -> None:
+        if device_session.state.is_connected() and device_session not in self.__connected_devices:
+            self.__connected_devices.add(device_session)
+        else:
+            self.__connected_devices.discard(device_session)
+        logger.debug(f"Device {device_session.device_info.device_name} state changed to {device_session.state}")
+
+    @property
+    def connected_devices(self) -> Set[DeviceSession]:
+        return self.__connected_devices
+
+    @property
+    def all_devices(self) -> Dict[UUID, DeviceSession]:
+        return self._sessions_by_id
+
+    def __repr__(self):
+        return f"DeviceManager({str(list(self._ids_by_device_name.keys()))})"
diff --git a/tb_mqtt_client/service/gateway/device_sesion.py b/tb_mqtt_client/service/gateway/device_sesion.py
deleted file mode 100644
index fa669aa..0000000
--- a/tb_mqtt_client/service/gateway/device_sesion.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# Copyright 2025 ThingsBoard
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
diff --git a/tb_mqtt_client/service/gateway/device_session.py b/tb_mqtt_client/service/gateway/device_session.py
index fa669aa..152f7bb 100644
--- a/tb_mqtt_client/service/gateway/device_session.py
+++ b/tb_mqtt_client/service/gateway/device_session.py
@@ -12,3 +12,45 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from time import time +from dataclasses import dataclass, field +from typing import Callable, Awaitable, Optional, Dict, Any + +from tb_mqtt_client.entities.gateway.device_info import DeviceInfo +from tb_mqtt_client.entities.gateway.device_session_state import DeviceSessionState + + +@dataclass +class DeviceSession: + device_info: DeviceInfo + _state_change_callback: Optional[Callable[['DeviceSession'], None]] = None + connected_at: float = field(default_factory=lambda: int(time() * 1000)) + last_seen_at: float = field(default_factory=lambda: int(time() * 1000)) + claimed: bool = False + provisioned: bool = False + state: DeviceSessionState = DeviceSessionState.CONNECTED + + attribute_update_callback: Optional[Callable[[dict], Awaitable[None]]] = None + attribute_response_callback: Optional[Callable[[dict], Awaitable[None]]] = None + rpc_request_callback: Optional[Callable[[str, Dict[str, Any]], Awaitable[Dict[str, Any]]]] = None + rpc_response_callback: Optional[Callable[[dict], Awaitable[None]]] = None + + def update_state(self, new_state: DeviceSessionState): + self.state = new_state + if self._state_change_callback: + self._state_change_callback(self) + + def update_last_seen(self): + self.last_seen_at = int(time() * 1000) + + def set_attribute_update_callback(self, cb: Callable[[dict], Awaitable[None]]): + self.attribute_update_callback = cb + + def set_attribute_response_callback(self, cb: Callable[[dict], Awaitable[None]]): + self.attribute_response_callback = cb + + def set_rpc_request_callback(self, cb: Callable[[str, Dict[str, Any]], Awaitable[Dict[str, Any]]]): + self.rpc_request_callback = cb + + def set_rpc_response_callback(self, cb: Callable[[dict], Awaitable[None]]): + self.rpc_response_callback = cb diff --git a/tb_mqtt_client/service/gateway/event_dispatcher.py b/tb_mqtt_client/service/gateway/event_dispatcher.py new file mode 100644 index 0000000..2108abb --- /dev/null +++ b/tb_mqtt_client/service/gateway/event_dispatcher.py @@ -0,0 +1,56 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from collections import defaultdict +from typing import Callable, Awaitable, Dict, List, Union + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.gateway.gateway_event import GatewayEvent + +EventCallback = Union[Callable[..., Awaitable[None]], Callable[..., None]] + +logger = get_logger(__name__) + + +class EventDispatcher: + """ + Direct event dispatcher for handling gateway events. 
+ """ + def __init__(self): + self._handlers: Dict[GatewayEventType, List[EventCallback]] = defaultdict(list) + self._lock = asyncio.Lock() + + def register(self, event_type: GatewayEventType, callback: EventCallback): + if callback not in self._handlers[event_type]: + self._handlers[event_type].append(callback) + + def unregister(self, event_type: GatewayEventType, callback: EventCallback): + if callback in self._handlers[event_type]: + self._handlers[event_type].remove(callback) + if not self._handlers[event_type]: + del self._handlers[event_type] + + async def dispatch(self, event: GatewayEvent, *args, **kwargs): + async with self._lock: + callbacks = list(self._handlers.get(event.event_type, [])) + for cb in callbacks: + try: + if asyncio.iscoroutinefunction(cb): + await cb(event, *args, **kwargs) + else: + cb(event, *args, **kwargs) + except Exception as e: + logger.error(f"[EventDispatcher] Exception in handler for '{event.event_type}': {e}") diff --git a/tb_mqtt_client/service/gateway/gateway_attribute_updates_handler.py b/tb_mqtt_client/service/gateway/gateway_attribute_updates_handler.py deleted file mode 100644 index fa669aa..0000000 --- a/tb_mqtt_client/service/gateway/gateway_attribute_updates_handler.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/tb_mqtt_client/service/gateway/gateway_client_interface.py b/tb_mqtt_client/service/gateway/gateway_client_interface.py new file mode 100644 index 0000000..26cacac --- /dev/null +++ b/tb_mqtt_client/service/gateway/gateway_client_interface.py @@ -0,0 +1,73 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from abc import ABC, abstractmethod + +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.service.base_client import BaseClient +from tb_mqtt_client.service.gateway.device_session import DeviceSession + + +class GatewayClientInterface(BaseClient, ABC): + + @abstractmethod + async def connect_device(self, device_name: str, device_profile: str) -> DeviceSession: ... + + @abstractmethod + async def disconnect_device(self, device_session: DeviceSession): ... + + @abstractmethod + async def send_device_telemetry(self, device_session: DeviceSession, telemetry: ...): ... + + @abstractmethod + async def send_device_attributes(self, device_session: DeviceSession, attributes: ...): ... + + @abstractmethod + async def send_device_attributes_request(self, device_session: DeviceSession, attributes: AttributeRequest): ... + + @abstractmethod + async def send_device_client_side_rpc_request(self, device_session: DeviceSession, rpc_request: RPCRequest): ... + + @abstractmethod + async def send_device_server_side_rpc_response(self, device_session: DeviceSession, rpc_response: RPCResponse): ... + + + + @abstractmethod + def set_device_server_side_rpc_request_callback(self, device_session: DeviceSession, callback: ...): ... + + @abstractmethod + def set_device_client_side_rpc_response_callback(self, device_session: DeviceSession, callback: ...): ... + + @abstractmethod + def set_device_requested_attributes_callback(self, device_session: DeviceSession, callback: ...): ... + + @abstractmethod + def set_device_attributes_update_callback(self, device_session: DeviceSession, callback: ...): ... diff --git a/tb_mqtt_client/service/gateway/gateway_rpc_requests_handler.py b/tb_mqtt_client/service/gateway/gateway_rpc_requests_handler.py deleted file mode 100644 index fa669aa..0000000 --- a/tb_mqtt_client/service/gateway/gateway_rpc_requests_handler.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/tb_mqtt_client/entities/gateway/virtual_device.py b/tb_mqtt_client/service/gateway/handlers/__init__.py similarity index 99% rename from tb_mqtt_client/entities/gateway/virtual_device.py rename to tb_mqtt_client/service/gateway/handlers/__init__.py index fa669aa..cff354e 100644 --- a/tb_mqtt_client/entities/gateway/virtual_device.py +++ b/tb_mqtt_client/service/gateway/handlers/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- diff --git a/tb_mqtt_client/service/gateway/handlers/gateway_attribute_updates_handler.py b/tb_mqtt_client/service/gateway/handlers/gateway_attribute_updates_handler.py new file mode 100644 index 0000000..bdd8c95 --- /dev/null +++ b/tb_mqtt_client/service/gateway/handlers/gateway_attribute_updates_handler.py @@ -0,0 +1,35 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tb_mqtt_client.service.gateway.device_manager import DeviceManager +from tb_mqtt_client.service.gateway.event_dispatcher import EventDispatcher +from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter + + +class GatewayAttributeUpdatesHandler: + """Handles shared attribute updates for devices connected to a gateway.""" + def __init__(self, event_dispatcher: EventDispatcher, message_adapter: GatewayMessageAdapter, device_manager: DeviceManager): + self.event_dispatcher = event_dispatcher + self.message_adapter = message_adapter + self.device_manager = device_manager + + def handle(self, topic: str, payload: bytes): + """ + Handles the gateway attribute update event by dispatching the attribute update + """ + gateway_attribute_update = self.message_adapter.parse_attribute_update(payload) + device_session = self.device_manager.get_by_name(gateway_attribute_update.device_name) + if device_session: + gateway_attribute_update.set_device_session(device_session) + self.event_dispatcher.dispatch(gateway_attribute_update) diff --git a/tb_mqtt_client/entities/gateway/rpc_context.py b/tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py similarity index 100% rename from tb_mqtt_client/entities/gateway/rpc_context.py rename to tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py diff --git a/tb_mqtt_client/service/gateway/message_adapter.py b/tb_mqtt_client/service/gateway/message_adapter.py new file mode 100644 index 0000000..1d1241b --- /dev/null +++ b/tb_mqtt_client/service/gateway/message_adapter.py @@ -0,0 +1,267 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +from abc import abstractmethod, ABC +from collections import defaultdict +from datetime import datetime, UTC +from itertools import chain +from typing import List, Optional, Tuple, Dict, Any, Union + +from orjson import loads, dumps + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.constants.mqtt_topics import GATEWAY_TELEMETRY_TOPIC, GATEWAY_ATTRIBUTES_TOPIC, \ + GATEWAY_CONNECT_TOPIC, GATEWAY_DISCONNECT_TOPIC +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage +from tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage +from tb_mqtt_client.entities.gateway.device_disconnect_message import DeviceDisconnectMessage +from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest +from tb_mqtt_client.entities.gateway.gateway_attribute_update import GatewayAttributeUpdate +from tb_mqtt_client.entities.gateway.gateway_requested_attribute_response import GatewayRequestedAttributeResponse + +logger = get_logger(__name__) + + +class GatewayMessageAdapter(ABC): + """ + Adapter for converting events to uplink messages and received messages to events. + """ + + @abstractmethod + def build_uplink_payloads( + self, + messages: List[DeviceUplinkMessage] + ) -> List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]]: + """ + Build a list of topic-payload pairs from the given messages. + Each pair consists of a topic string, payload bytes, the number of datapoints, + and a list of futures for delivery confirmation. + """ + pass + + @abstractmethod + def build_device_connect_message_payload(self, device_connect_message: DeviceConnectMessage) -> Tuple[str, bytes]: + """ + Build the payload for a device connect message. + This method should be implemented to handle the specific format of the payload. + """ + pass + + @abstractmethod + def build_device_disconnect_message_payload(self, device_disconnect_message: DeviceDisconnectMessage) -> Tuple[str, bytes]: + """ + Build the payload for a device disconnect message. + This method should be implemented to handle the specific format of the payload. + """ + pass + + @abstractmethod + def build_gateway_attribute_request_payload(self, attribute_request: GatewayAttributeRequest) -> Tuple[str, bytes]: + """ + Build the payload for a gateway attribute request. + This method should be implemented to handle the specific format of the payload. + """ + pass + + @abstractmethod + def parse_attribute_update(self, payload: bytes) -> GatewayAttributeUpdate: + """ + Parse the attribute update payload into an GatewayAttributeUpdate. + This method should be implemented to handle the specific format of the payload. + """ + pass + + @abstractmethod + def parse_gateway_attribute_response(self, gateway_attribute_request: GatewayAttributeRequest, payload: bytes) -> Union[GatewayRequestedAttributeResponse, None]: + """ + Parse the gateway attribute response payload into an GatewayAttributeResponse. + This method should be implemented to handle the specific format of the payload. + """ + pass + + +class JsonGatewayMessageAdapter(GatewayMessageAdapter): + """ + JSON implementation of GatewayMessageAdapter. + Builds uplink payloads from uplink message objects and parses JSON payloads into GatewayEvent objects. 
+ """ + + def build_uplink_payloads(self, messages: List[DeviceUplinkMessage]) -> List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]]: + """ + Build a list of topic-payload pairs from the given messages. + Each pair consists of a topic string, payload bytes, the number of datapoints, + and a list of futures for delivery confirmation. + """ + try: + if not messages: + logger.trace("No messages to process in build_topic_payloads.") + return [] + + result: List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]] = [] + device_groups: Dict[str, List[DeviceUplinkMessage]] = defaultdict(list) + + for msg in messages: + device_name = msg.device_name + device_groups[device_name].append(msg) + logger.trace("Queued message for device='%s'", device_name) + + logger.trace("Processing %d device group(s).", len(device_groups)) + + gateway_timeseries_message = {} + gateway_attributes_message = {} + gateway_timeseries_device_datapoints_counts: Dict[str, int] = {} + gateway_attributes_device_datapoints_counts: Dict[str, int] = {} + gateway_timeseries_delivery_futures: Dict[str, List[Optional[asyncio.Future[PublishResult]]]] = {} + gateway_attributes_delivery_futures: Dict[str, List[Optional[asyncio.Future[PublishResult]]]] = {} + for device, device_msgs in device_groups.items(): + if device not in gateway_timeseries_message: + gateway_timeseries_message[device] = [] + gateway_timeseries_delivery_futures[device] = [] + if device not in gateway_attributes_message: + gateway_attributes_message[device] = [] + gateway_attributes_delivery_futures[device] = [] + telemetry_msgs: List[DeviceUplinkMessage] = [m for m in device_msgs if m.has_timeseries()] + attr_msgs: List[DeviceUplinkMessage] = [m for m in device_msgs if m.has_attributes()] + logger.trace("Device '%s' - telemetry: %d, attributes: %d", + device, len(telemetry_msgs), len(attr_msgs)) + + # TODO: Recommended to add message splitter to handle large messages and split them into smaller batches + for ts_batch in telemetry_msgs: + packed_ts = JsonGatewayMessageAdapter.pack_timeseries(ts_batch) + gateway_timeseries_message[device].append(packed_ts) + count = ts_batch.timeseries_datapoint_count() + gateway_timeseries_device_datapoints_counts[device] = gateway_timeseries_device_datapoints_counts.get(device, 0) + count + gateway_timeseries_delivery_futures[device] = ts_batch.get_delivery_futures() + logger.trace("Built telemetry payload for device='%s' with %d datapoints", device, count) + + for attr_batch in attr_msgs: + packed_attrs = JsonGatewayMessageAdapter.pack_attributes(attr_batch) + count = attr_batch.attributes_datapoint_count() + gateway_attributes_message[device].append(packed_attrs) + gateway_attributes_device_datapoints_counts[device] = gateway_attributes_device_datapoints_counts.get(device, 0) + count + logger.trace("Built attribute payload for device='%s' with %d attributes", device, count) + if telemetry_msgs: + result.append((GATEWAY_TELEMETRY_TOPIC, + dumps(gateway_timeseries_message[device]), + gateway_timeseries_device_datapoints_counts[device], + gateway_timeseries_delivery_futures[device])) + if attr_msgs: + result.append((GATEWAY_ATTRIBUTES_TOPIC, + dumps(gateway_attributes_message[device]), + gateway_attributes_device_datapoints_counts[device], + gateway_attributes_delivery_futures[device])) + + logger.trace("Generated %d topic-payload entries.", len(result)) + + return result + except Exception as e: + logger.error("Error building topic-payloads: %s", str(e)) + logger.debug("Exception details: 
%s", e, exc_info=True) + raise + + def build_device_connect_message_payload(self, device_connect_message: DeviceConnectMessage) -> Tuple[str, bytes]: + """ + Build the payload for a device connect message. + This method serializes the DeviceConnectMessage to JSON format. + """ + try: + payload = dumps(device_connect_message.to_payload_format()) + logger.trace("Built device connect message payload for device='%s'", device_connect_message.device_name) + return GATEWAY_CONNECT_TOPIC, payload + except Exception as e: + logger.error("Failed to build device connect message payload: %s", str(e)) + raise ValueError("Invalid device connect message format") from e + + def build_device_disconnect_message_payload(self, device_disconnect_message: DeviceDisconnectMessage) -> Tuple[str, bytes]: + """ + Build the payload for a device disconnect message. + This method serializes the device name to JSON format. + """ + try: + payload = dumps(device_disconnect_message.to_payload_format()) + logger.trace("Built device disconnect message payload for device='%s'", device_disconnect_message.device_name) + return GATEWAY_DISCONNECT_TOPIC, payload + except Exception as e: + logger.error("Failed to build device disconnect message payload: %s", str(e)) + raise ValueError("Invalid device disconnect message format") from e + + def build_gateway_attribute_request_payload(self, attribute_request: GatewayAttributeRequest) -> Tuple[str, bytes]: + """ + Build the payload for a gateway attribute request. + This method serializes the GatewayAttributeRequest to JSON format. + """ + try: + payload = dumps(attribute_request.to_payload_format()) + logger.trace("Built gateway attribute request payload for device='%s'", attribute_request.device_name) + return GATEWAY_ATTRIBUTES_TOPIC, payload + except Exception as e: + logger.error("Failed to build gateway attribute request payload: %s", str(e)) + raise ValueError("Invalid gateway attribute request format") from e + + def parse_attribute_update(self, payload: bytes) -> GatewayAttributeUpdate: + try: + data = loads(payload.decode('utf-8')) + device_name = data['device_name'] + attribute_update = AttributeUpdate._deserialize_from_dict(data['data']) # noqa + return GatewayAttributeUpdate(device_name=device_name, attribute_update=attribute_update) + except Exception as e: + logger.error("Failed to parse attribute update: %s", str(e)) + raise ValueError("Invalid attribute update format") from e + + def parse_gateway_attribute_response(self, gateway_attribute_request: GatewayAttributeRequest, payload: bytes) -> Union[GatewayRequestedAttributeResponse, None]: + try: + data = loads(payload.decode('utf-8')) + device_name = data['device_name'] + client = [] + shared = [] + if 'value' in data and not ((len(gateway_attribute_request.client_keys) == 1 and len(gateway_attribute_request.shared_keys) == 0) + or (len(gateway_attribute_request.client_keys) == 0 and len(gateway_attribute_request.shared_keys) == 1)): + # TODO: Skipping case when requested several attributes, but only one is returned, issue on the platform + logger.warning("Received gateway attribute response with single key, but multiply keys expected. 
" + "Request keys: %s, Response keys: %s", list(*gateway_attribute_request.client_keys, *gateway_attribute_request.shared_keys), data['value']) + return None + elif 'value' in data: + if len(gateway_attribute_request.client_keys) == 1: + client= [AttributeEntry(gateway_attribute_request.client_keys[0], data['value'])] + elif len(gateway_attribute_request.shared_keys) == 1: + shared = [AttributeEntry(gateway_attribute_request.shared_keys[0], data['value'])] + elif 'data' in data: + if len(gateway_attribute_request.client_keys) > 0: + client = [AttributeEntry(k, v) for k, v in data['data'].get('client', {}).items() if k in gateway_attribute_request.client_keys] + if len(gateway_attribute_request.shared_keys) > 0: + shared = [AttributeEntry(k, v) for k, v in data['data'].get('shared', {}).items() if k in gateway_attribute_request.shared_keys] + return GatewayRequestedAttributeResponse(device_name=device_name, request_id=gateway_attribute_request.request_id, shared=shared, client=client) + except Exception as e: + logger.error("Failed to parse gateway attribute response: %s", str(e)) + raise ValueError("Invalid gateway attribute response format") from e + + @staticmethod + def pack_attributes(msg: DeviceUplinkMessage) -> Dict[str, Any]: + logger.trace("Packing %d attribute(s)", len(msg.attributes)) + return {attr.key: attr.value for attr in msg.attributes} + + @staticmethod + def pack_timeseries(msg: DeviceUplinkMessage) -> List[Dict[str, Any]]: + now_ts = int(datetime.now(UTC).timestamp() * 1000) + packed = [ + {"ts": entry.ts or now_ts, "values": {entry.key: entry.value}} + for entry in chain.from_iterable(msg.timeseries.values()) + ] + logger.trace("Packed %d timeseries entry(s)", len(packed)) + + return packed diff --git a/tb_mqtt_client/service/gateway/subdevice_manager.py b/tb_mqtt_client/service/gateway/subdevice_manager.py deleted file mode 100644 index fa669aa..0000000 --- a/tb_mqtt_client/service/gateway/subdevice_manager.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- diff --git a/tb_mqtt_client/service/message_queue.py b/tb_mqtt_client/service/message_queue.py index 3a0d014..5d0539e 100644 --- a/tb_mqtt_client/service/message_queue.py +++ b/tb_mqtt_client/service/message_queue.py @@ -21,7 +21,7 @@ from tb_mqtt_client.constants import mqtt_topics from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage from tb_mqtt_client.common.publish_result import PublishResult -from tb_mqtt_client.service.message_dispatcher import MessageDispatcher +from tb_mqtt_client.service.device.message_adapter import MessageAdapter from tb_mqtt_client.service.mqtt_manager import MQTTManager logger = get_logger(__name__) @@ -36,7 +36,7 @@ def __init__(self, message_rate_limit: Optional[RateLimit], telemetry_rate_limit: Optional[RateLimit], telemetry_dp_rate_limit: Optional[RateLimit], - message_dispatcher: MessageDispatcher, + message_adapter: MessageAdapter, max_queue_size: int = 1000000, batch_collect_max_time_ms: int = 100, batch_collect_max_count: int = 500): @@ -57,7 +57,7 @@ def __init__(self, self._wakeup_event = asyncio.Event() self._retry_tasks: set[asyncio.Task] = set() self._active.set() - self._dispatcher = message_dispatcher + self._adapter = message_adapter self._loop_task = asyncio.create_task(self._dequeue_loop()) self._rate_limit_refill_task = asyncio.create_task(self._rate_limit_refill_loop()) logger.debug("MessageQueue initialized: max_queue_size=%s, batch_time=%.3f, batch_count=%d", @@ -125,7 +125,7 @@ async def _dequeue_loop(self): next_topic, next_payload, delivery_futures_or_none, datapoints, qos = self._queue.get_nowait() if isinstance(next_payload, DeviceUplinkMessage): msg_size = next_payload.size - if batch_size + msg_size > self._dispatcher.splitter.max_payload_size: # noqa + if batch_size + msg_size > self._adapter.splitter.max_payload_size: # noqa logger.trace("Batch size threshold exceeded: current=%d, next=%d", batch_size, msg_size) self._queue.put_nowait((next_topic, next_payload, delivery_futures_or_none, datapoints, qos)) break @@ -141,7 +141,7 @@ async def _dequeue_loop(self): logger.trace("Batching completed: %d messages, total size=%d", len(batch), batch_size) messages = [device_uplink_message for _, device_uplink_message, _, _, _ in batch] - topic_payloads = self._dispatcher.build_uplink_payloads(messages) + topic_payloads = self._adapter.build_uplink_payloads(messages) for topic, payload, datapoints, delivery_futures in topic_payloads: logger.trace("Dispatching batched message: topic=%s, size=%d, datapoints=%d, delivery_futures=%r", diff --git a/tb_mqtt_client/service/mqtt_manager.py b/tb_mqtt_client/service/mqtt_manager.py index 4abdc0d..2a1dd06 100644 --- a/tb_mqtt_client/service/mqtt_manager.py +++ b/tb_mqtt_client/service/mqtt_manager.py @@ -35,7 +35,7 @@ from tb_mqtt_client.entities.data.rpc_response import RPCResponse from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.service.device.handlers.rpc_response_handler import RPCResponseHandler -from tb_mqtt_client.service.message_dispatcher import MessageDispatcher +from tb_mqtt_client.service.device.message_adapter import MessageAdapter logger = get_logger(__name__) @@ -51,7 +51,7 @@ def __init__( self, client_id: str, main_stop_event: asyncio.Event, - message_dispatcher: MessageDispatcher, + message_adapter: MessageAdapter, on_connect: Optional[Callable[[], Coroutine[Any, Any, None]]] = None, on_disconnect: Optional[Callable[[], Coroutine[Any, Any, None]]] = None, on_publish_result: Optional[Callable[[PublishResult], 
Coroutine[Any, Any, None]]] = None, @@ -59,7 +59,7 @@ def __init__( rpc_response_handler: Optional[RPCResponseHandler] = None, ): self._main_stop_event = main_stop_event - self._message_dispatcher = message_dispatcher + self._message_dispatcher = message_adapter patch_gmqtt_protocol_connection_lost() patch_mqtt_handler_disconnect() diff --git a/tests/common/test_provisioning_client.py b/tests/common/test_provisioning_client.py index 9b95431..5f71855 100644 --- a/tests/common/test_provisioning_client.py +++ b/tests/common/test_provisioning_client.py @@ -33,7 +33,7 @@ def real_request(): @pytest.mark.asyncio @patch("tb_mqtt_client.common.provisioning_client.GMQTTClient") -@patch("tb_mqtt_client.common.provisioning_client.JsonMessageDispatcher") +@patch("tb_mqtt_client.common.provisioning_client.JsonMessageAdapter") async def test_successful_provisioning_flow(mock_dispatcher_cls, mock_gmqtt_cls, real_request): mock_client = AsyncMock() mock_gmqtt_cls.return_value = mock_client @@ -63,7 +63,7 @@ async def test_successful_provisioning_flow(mock_dispatcher_cls, mock_gmqtt_cls, @pytest.mark.asyncio @patch("tb_mqtt_client.common.provisioning_client.GMQTTClient") -@patch("tb_mqtt_client.common.provisioning_client.JsonMessageDispatcher") +@patch("tb_mqtt_client.common.provisioning_client.JsonMessageAdapter") async def test_failed_connection(mock_dispatcher_cls, mock_gmqtt_cls, real_request, caplog): mock_client = AsyncMock() mock_gmqtt_cls.return_value = mock_client @@ -81,7 +81,7 @@ async def test_failed_connection(mock_dispatcher_cls, mock_gmqtt_cls, real_reque @pytest.mark.asyncio @patch("tb_mqtt_client.common.provisioning_client.GMQTTClient") -@patch("tb_mqtt_client.common.provisioning_client.JsonMessageDispatcher") +@patch("tb_mqtt_client.common.provisioning_client.JsonMessageAdapter") async def test_provision_method_awaits_provisioned(mock_dispatcher_cls, mock_gmqtt_cls, real_request): mock_client = AsyncMock() mock_gmqtt_cls.return_value = mock_client diff --git a/tests/service/device/test_device_client.py b/tests/service/device/test_device_client.py index fcfbc51..096a31f 100644 --- a/tests/service/device/test_device_client.py +++ b/tests/service/device/test_device_client.py @@ -352,7 +352,7 @@ async def test_initializes_dispatcher_and_queue_after_connection(): await client.connect() assert client.max_payload_size == 65535 - assert client._message_dispatcher is not None + assert client._message_adapter is not None assert client._message_queue is not None mock_queue.assert_called_once() @@ -371,8 +371,8 @@ async def test_uses_default_max_payload_size_when_not_provided(): client.max_payload_size = None splitter = FakeSplitter() - client._message_dispatcher = MagicMock() - client._message_dispatcher.splitter = splitter + client._message_adapter = MagicMock() + client._message_adapter.splitter = splitter resp = RPCResponse.build(1, result={"rateLimits": {}}) await client._handle_rate_limit_response(resp) @@ -385,7 +385,7 @@ async def test_uses_default_max_payload_size_when_not_provided(): async def test_does_not_update_dispatcher_when_not_initialized(): client = DeviceClient() client.max_payload_size = None - client._message_dispatcher = None + client._message_adapter = None resp = RPCResponse.build(1, result={"rateLimits": {}}) await client._handle_rate_limit_response(resp) diff --git a/tests/service/test_json_message_dispatcher.py b/tests/service/test_json_message_adapter.py similarity index 65% rename from tests/service/test_json_message_dispatcher.py rename to 
tests/service/test_json_message_adapter.py index d5217ff..a958a89 100644 --- a/tests/service/test_json_message_dispatcher.py +++ b/tests/service/test_json_message_adapter.py @@ -31,7 +31,7 @@ from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry -from tb_mqtt_client.service.message_dispatcher import JsonMessageDispatcher +from tb_mqtt_client.service.device.message_adapter import JsonMessageAdapter @pytest.fixture @@ -54,74 +54,74 @@ def build_msg(device="devX", with_attr=False, with_ts=False): return builder.build() @pytest.fixture -def dispatcher(): - return JsonMessageDispatcher() +def adapter(): + return JsonMessageAdapter() -def test_build_attribute_request(dispatcher): +def test_build_attribute_request(adapter): request = MagicMock(spec=AttributeRequest) request.request_id = 1 request.to_payload_format.return_value = {"clientKeys": "temp", "sharedKeys": "shared"} - topic, payload = dispatcher.build_attribute_request(request) + topic, payload = adapter.build_attribute_request(request) assert topic.endswith("/1") assert b"clientKeys" in payload -def test_build_attribute_request_invalid(dispatcher): +def test_build_attribute_request_invalid(adapter): request = MagicMock(spec=AttributeRequest) request.request_id = None with pytest.raises(ValueError): - dispatcher.build_attribute_request(request) + adapter.build_attribute_request(request) -def test_build_claim_request(dispatcher): +def test_build_claim_request(adapter): req = ClaimRequest.build("secretKey") - topic, payload = dispatcher.build_claim_request(req) + topic, payload = adapter.build_claim_request(req) assert topic == mqtt_topics.DEVICE_CLAIM_TOPIC assert b"secretKey" in payload -def test_build_claim_request_invalid(dispatcher): +def test_build_claim_request_invalid(adapter): with pytest.raises(ValueError): req = ClaimRequest.build(secret_key=None) # Simulating an invalid request # noqa -def test_build_rpc_request(dispatcher): +def test_build_rpc_request(adapter): request = MagicMock(spec=RPCRequest) request.request_id = 42 request.to_payload_format.return_value = {"method": "reboot"} - topic, payload = dispatcher.build_rpc_request(request) + topic, payload = adapter.build_rpc_request(request) assert topic.endswith("42") assert b"reboot" in payload -def test_build_rpc_request_invalid(dispatcher): +def test_build_rpc_request_invalid(adapter): request = MagicMock(spec=RPCRequest) request.request_id = None with pytest.raises(ValueError): - dispatcher.build_rpc_request(request) + adapter.build_rpc_request(request) -def test_build_rpc_response(dispatcher): +def test_build_rpc_response(adapter): response = MagicMock(spec=RPCResponse) response.request_id = 123 response.to_payload_format.return_value = {"result": "ok"} - topic, payload = dispatcher.build_rpc_response(response) + topic, payload = adapter.build_rpc_response(response) assert topic.endswith("123") assert b"ok" in payload -def test_build_rpc_response_invalid(dispatcher): +def test_build_rpc_response_invalid(adapter): response = MagicMock(spec=RPCResponse) response.request_id = None with pytest.raises(ValueError): - dispatcher.build_rpc_response(response) + adapter.build_rpc_response(response) -def test_build_provision_request_access_token(dispatcher): +def test_build_provision_request_access_token(adapter): credentials = AccessTokenProvisioningCredentials("key1", "secret1", access_token="tokenABC") req = 
ProvisioningRequest("localhost", credentials, device_name="dev1", gateway=True) - topic, payload = dispatcher.build_provision_request(req) + topic, payload = adapter.build_provision_request(req) assert topic == mqtt_topics.PROVISION_REQUEST_TOPIC assert b"provisionDeviceKey" in payload assert b"tokenABC" in payload @@ -130,104 +130,104 @@ def test_build_provision_request_access_token(dispatcher): assert b"gateway" in payload -def test_build_provision_request_mqtt_basic(dispatcher): +def test_build_provision_request_mqtt_basic(adapter): credentials = BasicProvisioningCredentials("key2", "secret2", client_id="cid", username="user", password="pass") req = ProvisioningRequest("127.0.0.1", credentials, device_name="dev2", gateway=False) - topic, payload = dispatcher.build_provision_request(req) + topic, payload = adapter.build_provision_request(req) assert b"clientId" in payload assert b"username" in payload assert b"password" in payload assert b"credentialsType" in payload -def test_build_provision_request_x509(dispatcher): +def test_build_provision_request_x509(adapter): cert_path = "/fake/path/cert.pem" cert_content = "-----BEGIN CERTIFICATE-----\nFAKECERT\n-----END CERTIFICATE-----" with patch("builtins.open", mock_open(read_data=cert_content)): credentials = X509ProvisioningCredentials("key3", "secret3", "key.pem", cert_path, "ca.pem") req = ProvisioningRequest("iot.server", credentials, device_name="dev3") - topic, payload = dispatcher.build_provision_request(req) + topic, payload = adapter.build_provision_request(req) assert b"hash" in payload assert b"credentialsType" in payload assert b"FAKECERT" in payload -def test_build_provision_request_x509_file_not_found(dispatcher): +def test_build_provision_request_x509_file_not_found(adapter): with patch("builtins.open", side_effect=FileNotFoundError): with pytest.raises(FileNotFoundError): X509ProvisioningCredentials("key", "secret", "k.pem", "nonexistent.pem", "ca.pem") -def test_parse_attribute_request_response(dispatcher): +def test_parse_attribute_request_response(adapter): topic = "v1/devices/me/attributes/response/42" payload = dumps({"shared": {"temp": 22}}) with patch.object(RequestedAttributeResponse, "from_dict", return_value="ok") as mock: - result = dispatcher.parse_attribute_request_response(topic, payload) + result = adapter.parse_requested_attribute_response(topic, payload) assert result == "ok" mock.assert_called_once() -def test_parse_attribute_request_response_invalid(dispatcher): +def test_parse_attribute_request_response_invalid(adapter): topic = "v1/devices/me/attributes/response/bad" with pytest.raises(ValueError): - dispatcher.parse_attribute_request_response(topic, b"invalid") + adapter.parse_requested_attribute_response(topic, b"invalid") -def test_parse_attribute_update(dispatcher): +def test_parse_attribute_update(adapter): payload = dumps({"shared": {"humidity": 60}}) with patch.object(AttributeUpdate, "_deserialize_from_dict", return_value="AU"): - result = dispatcher.parse_attribute_update(payload) + result = adapter.parse_attribute_update(payload) assert result == "AU" -def test_parse_attribute_update_invalid(dispatcher): +def test_parse_attribute_update_invalid(adapter): with pytest.raises(ValueError): - dispatcher.parse_attribute_update(b"{bad}") + adapter.parse_attribute_update(b"{bad}") -def test_parse_rpc_request(dispatcher): +def test_parse_rpc_request(adapter): topic = "v1/devices/me/rpc/request/123" payload = dumps({"params": {"a": 1}}) with patch.object(RPCRequest, "_deserialize_from_dict", 
return_value="REQ"): - assert dispatcher.parse_rpc_request(topic, payload) == "REQ" + assert adapter.parse_rpc_request(topic, payload) == "REQ" -def test_parse_rpc_request_invalid(dispatcher): +def test_parse_rpc_request_invalid(adapter): topic = "v1/devices/me/rpc/request/NaN" with pytest.raises(ValueError): - dispatcher.parse_rpc_request(topic, b"{}") + adapter.parse_rpc_request(topic, b"{}") -def test_parse_rpc_response(dispatcher): +def test_parse_rpc_response(adapter): topic = "v1/devices/me/rpc/response/999" payload = dumps({"value": "done"}) with patch.object(RPCResponse, "build", return_value="RSP"): - assert dispatcher.parse_rpc_response(topic, payload) == "RSP" + assert adapter.parse_rpc_response(topic, payload) == "RSP" -def test_parse_rpc_response_with_error(dispatcher): +def test_parse_rpc_response_with_error(adapter): topic = "v1/devices/me/rpc/response/888" error = ValueError("fail") with patch.object(RPCResponse, "build", return_value="ERR"): - assert dispatcher.parse_rpc_response(topic, error) == "ERR" + assert adapter.parse_rpc_response(topic, error) == "ERR" -def test_parse_rpc_response_invalid(dispatcher): +def test_parse_rpc_response_invalid(adapter): topic = "v1/devices/me/rpc/response/NaN" with pytest.raises(ValueError): - dispatcher.parse_rpc_response(topic, b"bad") + adapter.parse_rpc_response(topic, b"bad") @pytest.mark.asyncio -async def test_build_uplink_payloads_empty(dispatcher: JsonMessageDispatcher): - assert dispatcher.build_uplink_payloads([]) == [] +async def test_build_uplink_payloads_empty(adapter: JsonMessageAdapter): + assert adapter.build_uplink_payloads([]) == [] @pytest.mark.asyncio -async def test_build_uplink_payloads_only_attributes(dispatcher: JsonMessageDispatcher): +async def test_build_uplink_payloads_only_attributes(adapter: JsonMessageAdapter): msg = build_msg(with_attr=True) - with patch.object(dispatcher._splitter, "split_attributes", return_value=[msg]): - result = dispatcher.build_uplink_payloads([msg]) + with patch.object(adapter._splitter, "split_attributes", return_value=[msg]): + result = adapter.build_uplink_payloads([msg]) assert len(result) == 1 topic, payload, count, futures = result[0] assert topic == DEVICE_ATTRIBUTES_TOPIC @@ -236,10 +236,10 @@ async def test_build_uplink_payloads_only_attributes(dispatcher: JsonMessageDisp @pytest.mark.asyncio -async def test_build_uplink_payloads_only_timeseries(dispatcher: JsonMessageDispatcher): +async def test_build_uplink_payloads_only_timeseries(adapter: JsonMessageAdapter): msg = build_msg(with_ts=True) - with patch.object(dispatcher._splitter, "split_timeseries", return_value=[msg]): - result = dispatcher.build_uplink_payloads([msg]) + with patch.object(adapter._splitter, "split_timeseries", return_value=[msg]): + result = adapter.build_uplink_payloads([msg]) assert len(result) == 1 topic, payload, count, futures = result[0] assert topic == DEVICE_TELEMETRY_TOPIC @@ -248,11 +248,11 @@ async def test_build_uplink_payloads_only_timeseries(dispatcher: JsonMessageDisp @pytest.mark.asyncio -async def test_build_uplink_payloads_both(dispatcher: JsonMessageDispatcher): +async def test_build_uplink_payloads_both(adapter: JsonMessageAdapter): msg = build_msg(with_attr=True, with_ts=True) - with patch.object(dispatcher._splitter, "split_attributes", return_value=[msg]), \ - patch.object(dispatcher._splitter, "split_timeseries", return_value=[msg]): - result = dispatcher.build_uplink_payloads([msg]) + with patch.object(adapter._splitter, "split_attributes", return_value=[msg]), \ + 
patch.object(adapter._splitter, "split_timeseries", return_value=[msg]): + result = adapter.build_uplink_payloads([msg]) assert len(result) == 2 topics = {r[0] for r in result} assert DEVICE_ATTRIBUTES_TOPIC in topics @@ -260,27 +260,27 @@ async def test_build_uplink_payloads_both(dispatcher: JsonMessageDispatcher): @pytest.mark.asyncio -async def test_build_uplink_payloads_multiple_devices(dispatcher: JsonMessageDispatcher): +async def test_build_uplink_payloads_multiple_devices(adapter: JsonMessageAdapter): msg1 = build_msg(device="dev1", with_attr=True) msg2 = build_msg(device="dev2", with_ts=True) - with patch.object(dispatcher._splitter, "split_attributes", side_effect=lambda x: x), \ - patch.object(dispatcher._splitter, "split_timeseries", side_effect=lambda x: x): - result = dispatcher.build_uplink_payloads([msg1, msg2]) + with patch.object(adapter._splitter, "split_attributes", side_effect=lambda x: x), \ + patch.object(adapter._splitter, "split_timeseries", side_effect=lambda x: x): + result = adapter.build_uplink_payloads([msg1, msg2]) topics = {r[0] for r in result} assert DEVICE_ATTRIBUTES_TOPIC in topics or DEVICE_TELEMETRY_TOPIC in topics -def test_build_payload_with_device_name(dispatcher: JsonMessageDispatcher): +def test_build_payload_with_device_name(adapter: JsonMessageAdapter): msg = build_msg(with_ts=True) - payload = dispatcher.build_payload(msg, True) + payload = adapter.build_payload(msg, True) assert isinstance(payload, bytes) assert msg.device_name.encode() in payload -def test_build_payload_without_device_name(dispatcher: JsonMessageDispatcher): +def test_build_payload_without_device_name(adapter: JsonMessageAdapter): builder = DeviceUplinkMessageBuilder().add_attributes(AttributeEntry("x", 9)) msg = builder.build() - payload = dispatcher.build_payload(msg, False) + payload = adapter.build_payload(msg, False) assert isinstance(payload, bytes) assert b"x" in payload @@ -288,44 +288,44 @@ def test_build_payload_without_device_name(dispatcher: JsonMessageDispatcher): def test_pack_attributes(): builder = DeviceUplinkMessageBuilder().add_attributes(AttributeEntry("x", 10)) msg = builder.build() - result = JsonMessageDispatcher.pack_attributes(msg) + result = JsonMessageAdapter.pack_attributes(msg) assert isinstance(result, dict) assert "x" in result def test_pack_timeseries_uses_now(monkeypatch): - monkeypatch.setattr("tb_mqtt_client.service.message_dispatcher.datetime", MagicMock()) + monkeypatch.setattr("tb_mqtt_client.service.device.message_adapter.datetime", MagicMock()) ts_entry = TimeseriesEntry("temp", 23, ts=None) builder = DeviceUplinkMessageBuilder().add_timeseries(ts_entry) msg = builder.build() - packed = JsonMessageDispatcher.pack_timeseries(msg) + packed = JsonMessageAdapter.pack_timeseries(msg) assert isinstance(packed, list) assert "ts" in packed[0] assert "values" in packed[0] -def test_build_uplink_payloads_error_handling(dispatcher: JsonMessageDispatcher): - with patch("tb_mqtt_client.service.message_dispatcher.DeviceUplinkMessage.has_attributes", side_effect=Exception("boom")): +def test_build_uplink_payloads_error_handling(adapter: JsonMessageAdapter): + with patch("tb_mqtt_client.service.device.message_adapter.DeviceUplinkMessage.has_attributes", side_effect=Exception("boom")): msg = build_msg(with_attr=True) with pytest.raises(Exception, match="boom"): - dispatcher.build_uplink_payloads([msg]) + adapter.build_uplink_payloads([msg]) -def test_parse_provisioning_response_success(dispatcher, dummy_provisioning_request): +def 
test_parse_provisioning_response_success(adapter, dummy_provisioning_request): payload_dict = {"status": "SUCCESS", "credentialsType": "ACCESS_TOKEN"} payload_bytes = dumps(payload_dict) with patch.object(ProvisioningResponse, "build", return_value="SUCCESS_RESPONSE") as mock_build: - result = dispatcher.parse_provisioning_response(dummy_provisioning_request, payload_bytes) + result = adapter.parse_provisioning_response(dummy_provisioning_request, payload_bytes) assert result == "SUCCESS_RESPONSE" mock_build.assert_called_once_with(dummy_provisioning_request, payload_dict) -def test_parse_provisioning_response_failure(dispatcher, dummy_provisioning_request): +def test_parse_provisioning_response_failure(adapter, dummy_provisioning_request): broken_bytes = b"{not_json" with patch.object(ProvisioningResponse, "build", return_value="FAILURE_RESPONSE") as mock_build: - result = dispatcher.parse_provisioning_response(dummy_provisioning_request, broken_bytes) + result = adapter.parse_provisioning_response(dummy_provisioning_request, broken_bytes) assert result == "FAILURE_RESPONSE" mock_build.assert_called_once() args = mock_build.call_args[0] diff --git a/tests/service/test_message_queue.py b/tests/service/test_message_queue.py index 5637f39..245a57a 100644 --- a/tests/service/test_message_queue.py +++ b/tests/service/test_message_queue.py @@ -25,7 +25,7 @@ from tb_mqtt_client.constants.service_keys import TELEMETRY_DATAPOINTS_RATE_LIMIT from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry -from tb_mqtt_client.service.message_dispatcher import JsonMessageDispatcher +from tb_mqtt_client.service.device.message_adapter import JsonMessageAdapter from tb_mqtt_client.service.message_queue import MessageQueue @@ -44,9 +44,9 @@ async def test_batching_device_uplink_message(): dummy_message.device_name = "device" dummy_message.get_delivery_futures.return_value = [delivery_future] - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100 - dispatcher.build_uplink_payloads.return_value = [ + adapter = MagicMock() + adapter.splitter.max_payload_size = 100 + adapter.build_uplink_payloads.return_value = [ (mqtt_topics.DEVICE_TELEMETRY_TOPIC, b'batch_payload', 1, [delivery_future]) ] @@ -56,7 +56,7 @@ async def test_batching_device_uplink_message(): message_rate_limit=None, telemetry_rate_limit=None, telemetry_dp_rate_limit=None, - message_dispatcher=dispatcher, + message_adapter=adapter, batch_collect_max_time_ms=50, batch_collect_max_count=10 ) @@ -82,12 +82,12 @@ async def test_telemetry_rate_limit_retry_triggered(): main_stop_event = asyncio.Event() - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 delivery_future = asyncio.Future() delivery_future.set_result(PublishResult(mqtt_topics.DEVICE_TELEMETRY_TOPIC, 1, 0, 5, 0)) - dispatcher.build_uplink_payloads.return_value = [ + adapter.build_uplink_payloads.return_value = [ (mqtt_topics.DEVICE_TELEMETRY_TOPIC, b'dummy_payload', 1, [delivery_future]) ] @@ -102,7 +102,7 @@ async def test_telemetry_rate_limit_retry_triggered(): message_rate_limit=None, telemetry_rate_limit=telemetry_limit, telemetry_dp_rate_limit=None, - message_dispatcher=dispatcher + message_adapter=adapter ) await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, msg, 1, qos=1) @@ -117,12 +117,12 @@ async def test_shutdown_clears_queue(): mqtt_manager = MagicMock() 
mqtt_manager.backpressure.should_pause.return_value = False mqtt_manager.publish = AsyncMock() - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 - dispatcher.build_uplink_payloads.return_value = [] + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 + adapter.build_uplink_payloads.return_value = [] stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) dummy = MagicMock() dummy.size = 1 dummy.get_delivery_futures.return_value = [] @@ -136,12 +136,12 @@ async def test_publish_raw_bytes_success(): mqtt_manager = MagicMock() mqtt_manager.publish = AsyncMock() mqtt_manager.backpressure.should_pause.return_value = False - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 - dispatcher.build_uplink_payloads.return_value = [] + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 + adapter.build_uplink_payloads.return_value = [] main_stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, main_stop_event, None, None, None, dispatcher) + queue = MessageQueue(mqtt_manager, main_stop_event, None, None, None, adapter) await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, b"payload", 1, qos=1) await asyncio.sleep(0.05) await queue.shutdown() @@ -161,13 +161,13 @@ async def test_publish_device_uplink_message_batched(): dummy_msg.device_name = "dev" dummy_msg.get_delivery_futures.return_value = [future] - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100 - dispatcher.build_uplink_payloads.return_value = [ + adapter = MagicMock() + adapter.splitter.max_payload_size = 100 + adapter.build_uplink_payloads.return_value = [ (mqtt_topics.DEVICE_TELEMETRY_TOPIC, b"batch", 1, [future]) ] - queue = MessageQueue(mqtt_manager, main_stop_event, None, None, None, dispatcher) + queue = MessageQueue(mqtt_manager, main_stop_event, None, None, None, adapter) await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, dummy_msg, 1, qos=1) await asyncio.sleep(0.1) await queue.shutdown() @@ -184,9 +184,9 @@ async def test_rate_limit_telemetry_triggers_retry(): mqtt_manager = MagicMock() mqtt_manager.publish = AsyncMock() mqtt_manager.backpressure.should_pause.return_value = False - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 - dispatcher.build_uplink_payloads.return_value = [] + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 + adapter.build_uplink_payloads.return_value = [] main_stop_event = asyncio.Event() msg = MagicMock() @@ -194,7 +194,7 @@ async def test_rate_limit_telemetry_triggers_retry(): msg.size = 1 msg.get_delivery_futures.return_value = [] - queue = MessageQueue(mqtt_manager, main_stop_event, None, limit, None, dispatcher) + queue = MessageQueue(mqtt_manager, main_stop_event, None, limit, None, adapter) await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, msg, 1, 1) await asyncio.sleep(0.2) await queue.shutdown() @@ -221,8 +221,8 @@ async def complete_publish_result_later(): mqtt_manager.publish = publish_mock mqtt_manager.backpressure.should_pause.return_value = False - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 future = asyncio.Future() dummy_msg = MagicMock() @@ -230,7 +230,7 @@ async def complete_publish_result_later(): dummy_msg.size = 10 dummy_msg.get_delivery_futures.return_value = [future] - 
dispatcher.build_uplink_payloads.return_value = [ + adapter.build_uplink_payloads.return_value = [ (mqtt_topics.DEVICE_TELEMETRY_TOPIC, b"payload", 1, [future]) ] @@ -239,7 +239,7 @@ async def complete_publish_result_later(): mqtt_manager, stop_event, None, None, None, - dispatcher, + adapter, batch_collect_max_time_ms=10 ) @@ -266,13 +266,13 @@ async def test_mixed_raw_and_structured_queue(): uplink_msg.size = 10 uplink_msg.get_delivery_futures.return_value = [future] - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100 - dispatcher.build_uplink_payloads.return_value = [ + adapter = MagicMock() + adapter.splitter.max_payload_size = 100 + adapter.build_uplink_payloads.return_value = [ (mqtt_topics.DEVICE_TELEMETRY_TOPIC, b"batched", 1, [future]) ] - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher, batch_collect_max_time_ms=20) + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter, batch_collect_max_time_ms=20) await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, b"raw", 1, 1) await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, uplink_msg, 1, 1) await asyncio.sleep(0.1) @@ -290,12 +290,12 @@ async def test_rate_limit_refill_executes(): mqtt_manager = MagicMock() mqtt_manager.publish = AsyncMock() mqtt_manager.backpressure.should_pause.return_value = False - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 - dispatcher.build_uplink_payloads.return_value = [] + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 + adapter.build_uplink_payloads.return_value = [] stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, stop_event, r1, r2, r3, dispatcher) + queue = MessageQueue(mqtt_manager, stop_event, r1, r2, r3, adapter) await asyncio.sleep(1.2) await queue.shutdown() @@ -311,12 +311,12 @@ async def test_try_publish_without_delivery_futures(): mqtt_manager.publish.return_value.set_result(PublishResult("t", 1, 1, 1, 1)) mqtt_manager.backpressure.should_pause.return_value = False - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 - dispatcher.build_uplink_payloads.return_value = [] + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 + adapter.build_uplink_payloads.return_value = [] stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) await queue._try_publish("custom/topic", b"payload", datapoints=1, delivery_futures_or_none=None, qos=1) await queue.shutdown() @@ -329,14 +329,14 @@ async def test_schedule_delayed_retry_skipped_if_inactive_or_stopped(): mqtt_manager = MagicMock() mqtt_manager.backpressure.should_pause.return_value = False mqtt_manager.publish = AsyncMock() - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 - dispatcher.build_uplink_payloads.return_value = [] + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 + adapter.build_uplink_payloads.return_value = [] stop_event = asyncio.Event() stop_event.set() - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) queue._active.clear() queue._schedule_delayed_retry("topic", b"data", datapoints=1, qos=1, delay=0.01) @@ -348,14 +348,14 @@ async def test_clear_queue_sets_futures_to_publish_result(): mqtt_manager = MagicMock() mqtt_manager.backpressure.should_pause.return_value = False mqtt_manager.publish = 
AsyncMock() - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 - dispatcher.build_uplink_payloads.return_value = [] + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 + adapter.build_uplink_payloads.return_value = [] stop_event = asyncio.Event() with patch.object(MessageQueue, "_dequeue_loop", new=AsyncMock()): - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) dummy_msg = DeviceUplinkMessageBuilder() \ .add_delivery_futures(asyncio.Future()) \ @@ -383,12 +383,12 @@ async def test_wait_for_message_exit_on_inactive(): mqtt_manager = MagicMock() mqtt_manager.backpressure.should_pause.return_value = False mqtt_manager.publish = AsyncMock() - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 - dispatcher.build_uplink_payloads.return_value = [] + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 + adapter.build_uplink_payloads.return_value = [] stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) queue._active.clear() with pytest.raises(asyncio.CancelledError): @@ -402,14 +402,14 @@ async def test_schedule_delayed_retry_requeues_message(): mqtt_manager.publish = AsyncMock() mqtt_manager.backpressure.should_pause.return_value = False - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 - dispatcher.build_uplink_payloads.return_value = [] + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 + adapter.build_uplink_payloads.return_value = [] stop_event = asyncio.Event() with patch("tb_mqtt_client.service.message_queue.MessageQueue._dequeue_loop", new=AsyncMock()): - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) future = asyncio.Future() dummy_msg = MagicMock() @@ -435,12 +435,12 @@ async def test_schedule_delayed_retry_requeues_message(): @pytest.mark.asyncio async def test_cancel_tasks_clears_all(): mqtt_manager = MagicMock() - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 - dispatcher.build_uplink_payloads.return_value = [] + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 + adapter.build_uplink_payloads.return_value = [] stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) async def dummy(): await asyncio.sleep(1) @@ -457,12 +457,12 @@ async def test_clear_queue_with_bytes_message(): mqtt_manager = MagicMock() mqtt_manager.backpressure.should_pause.return_value = False mqtt_manager.publish = AsyncMock() - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 stop_event = asyncio.Event() with patch.object(MessageQueue, "_dequeue_loop", new=AsyncMock()): - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) future = asyncio.Future() await queue.publish("raw/topic", b"abc", 1, 0) @@ -492,16 +492,16 @@ async def test_resolve_attached_handles_publish_exception(): topic = "topic" qos = 1 - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 - 
dispatcher.build_uplink_payloads.return_value = [] + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 + adapter.build_uplink_payloads.return_value = [] mqtt_manager = MagicMock() mqtt_manager.backpressure.should_pause.return_value = False mqtt_manager.publish = AsyncMock(return_value=future) stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) await queue._try_publish( topic=topic, @@ -532,13 +532,13 @@ async def test_try_publish_message_type_non_telemetry(): rate_limit.to_dict.return_value = {} rate_limit.minimal_timeout = 0.1 - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 stop_event = asyncio.Event() queue = MessageQueue(mqtt_manager, stop_event, message_rate_limit=rate_limit, telemetry_rate_limit=None, telemetry_dp_rate_limit=None, - message_dispatcher=dispatcher) + message_adapter=adapter) await queue._try_publish( topic="non/telemetry", @@ -556,11 +556,11 @@ async def test_shutdown_rate_limit_task_cancel_only(): mqtt_manager = MagicMock() mqtt_manager.backpressure.should_pause.return_value = False mqtt_manager.publish = AsyncMock() - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) # Cancel only the rate limit task before shutdown queue._rate_limit_refill_task.cancel() @@ -572,13 +572,13 @@ async def test_shutdown_rate_limit_task_cancel_only(): @pytest.mark.asyncio async def test_schedule_delayed_retry_when_main_stop_active(): mqtt_manager = MagicMock() - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 stop_event = asyncio.Event() stop_event.set() - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) queue._active.clear() @@ -591,11 +591,11 @@ async def test_schedule_delayed_retry_when_main_stop_active(): @pytest.mark.asyncio async def test_publish_queue_full_sets_failed_result_for_bytes(): mqtt_manager = MagicMock() - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher, max_queue_size=1) + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter, max_queue_size=1) await queue.publish("t", b"raw", 1, qos=0) queue._queue.put_nowait = MagicMock(side_effect=asyncio.QueueFull) @@ -613,10 +613,10 @@ async def test_publish_queue_full_sets_failed_result_for_bytes(): @pytest.mark.asyncio async def test_wait_for_message_raises_cancelled(): mqtt_manager = MagicMock() - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) queue._active.clear() @@ -649,10 +649,10 @@ async def 
test_batch_loop_breaks_on_count_threshold(): # This is the future that the message queue should resolve delivery_future = asyncio.Future() - # Mock dispatcher to output the delivery future - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 - dispatcher.build_uplink_payloads.return_value = [ + # Mock adapter to output the delivery future + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 + adapter.build_uplink_payloads.return_value = [ (mqtt_topics.DEVICE_TELEMETRY_TOPIC, b"batched", 1, [delivery_future]) ] @@ -663,7 +663,7 @@ async def test_batch_loop_breaks_on_count_threshold(): message_rate_limit=None, telemetry_rate_limit=None, telemetry_dp_rate_limit=None, - message_dispatcher=dispatcher, + message_adapter=adapter, batch_collect_max_count=2 ) @@ -692,12 +692,12 @@ async def test_batch_loop_skips_message_on_size_exceed(): mqtt_manager.publish = AsyncMock() mqtt_manager.backpressure.should_pause.return_value = False - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 15 - dispatcher.build_uplink_payloads.return_value = [] + adapter = MagicMock() + adapter.splitter.max_payload_size = 15 + adapter.build_uplink_payloads.return_value = [] stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) small_msg = MagicMock() small_msg.size = 10 @@ -724,12 +724,12 @@ async def test_batch_requeues_on_size_exceed(): mqtt_manager.publish = AsyncMock() mqtt_manager.backpressure.should_pause.return_value = False - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 15 - dispatcher.build_uplink_payloads.return_value = [] + adapter = MagicMock() + adapter.splitter.max_payload_size = 15 + adapter.build_uplink_payloads.return_value = [] stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) msg1 = MagicMock() msg1.size = 10 @@ -756,12 +756,12 @@ async def test_batch_immediate_publish_on_raw_bytes(): mqtt_manager.publish = AsyncMock() mqtt_manager.backpressure.should_pause.return_value = False - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 - dispatcher.build_uplink_payloads.return_value = [] + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 + adapter.build_uplink_payloads.return_value = [] stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, b"raw_payload", 1, 1) @@ -779,12 +779,12 @@ async def test_batch_queue_empty_breaks_safely(): mqtt_manager.publish = AsyncMock() mqtt_manager.backpressure.should_pause.return_value = False - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 - dispatcher.build_uplink_payloads.return_value = [] + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 + adapter.build_uplink_payloads.return_value = [] stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, dispatcher) + queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) await asyncio.sleep(0.05) await queue.shutdown() @@ -798,9 +798,9 @@ async def test_try_publish_telemetry_rate_limited(): mqtt_manager.publish = AsyncMock() mqtt_manager.backpressure.should_pause.return_value = 
False - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 1000 - dispatcher.build_uplink_payloads.return_value = [("topic", b"{}", 3, [])] + adapter = MagicMock() + adapter.splitter.max_payload_size = 1000 + adapter.build_uplink_payloads.return_value = [("topic", b"{}", 3, [])] telemetry_rate_limit = MagicMock() telemetry_rate_limit.try_consume = AsyncMock(return_value=(10, 1)) telemetry_rate_limit.minimal_timeout = 0.5 @@ -812,7 +812,7 @@ async def test_try_publish_telemetry_rate_limited(): telemetry_rate_limit, None, RateLimit("10:1", TELEMETRY_DATAPOINTS_RATE_LIMIT, 100), - dispatcher + adapter ) queue._schedule_delayed_retry = MagicMock() @@ -837,9 +837,9 @@ async def test_try_publish_non_telemetry_rate_limited(): mqtt_manager.publish = AsyncMock() mqtt_manager.backpressure.should_pause.return_value = False - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 1000 - dispatcher.build_uplink_payloads.return_value = [("topic", b"{}", 1)] + adapter = MagicMock() + adapter.splitter.max_payload_size = 1000 + adapter.build_uplink_payloads.return_value = [("topic", b"{}", 1)] message_rate_limit = MagicMock() message_rate_limit.try_consume = AsyncMock(return_value=(5, 60)) @@ -852,7 +852,7 @@ async def test_try_publish_non_telemetry_rate_limited(): telemetry_rate_limit=None, message_rate_limit=message_rate_limit, telemetry_dp_rate_limit=None, - message_dispatcher=dispatcher + message_adapter=adapter ) queue._schedule_delayed_retry = MagicMock() @@ -883,9 +883,9 @@ async def test_backpressure_delays_publish(paused, monkeypatch): mqtt_manager.publish = AsyncMock() mqtt_manager.backpressure.should_pause.return_value = paused - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 - dispatcher.build_uplink_payloads.return_value = [] + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 + adapter.build_uplink_payloads.return_value = [] stop_event = asyncio.Event() queue = MessageQueue( @@ -894,7 +894,7 @@ async def test_backpressure_delays_publish(paused, monkeypatch): message_rate_limit=None, telemetry_rate_limit=None, telemetry_dp_rate_limit=None, - message_dispatcher=dispatcher, + message_adapter=adapter, batch_collect_max_count=1 ) @@ -931,9 +931,9 @@ async def test_publish_telemetry_rate_limit_triggered(): mqtt_manager.publish = AsyncMock() mqtt_manager.backpressure.should_pause.return_value = False - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 - dispatcher.build_uplink_payloads.return_value = [] + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 + adapter.build_uplink_payloads.return_value = [] stop_event = asyncio.Event() @@ -948,7 +948,7 @@ async def test_publish_telemetry_rate_limit_triggered(): None, None, telemetry_dp_rate_limit=telemetry_dp_rate_limit, - message_dispatcher=dispatcher, + message_adapter=adapter, ) stack.push_async_callback(queue.shutdown) @@ -977,7 +977,7 @@ async def test_batch_loop_large_messages_are_split_and_published(): mqtt_manager.publish = AsyncMock() mqtt_manager.backpressure.should_pause.return_value = False - dispatcher = JsonMessageDispatcher(100, 20) + adapter = JsonMessageAdapter(100, 20) stop_event = asyncio.Event() queue = MessageQueue( @@ -986,7 +986,7 @@ async def test_batch_loop_large_messages_are_split_and_published(): None, None, None, - message_dispatcher=dispatcher, + message_adapter=adapter, max_queue_size=100, batch_collect_max_time_ms=10 ) @@ -1040,9 +1040,9 @@ async def test_delivery_futures_resolved_via_real_puback_handler(): 
mqtt_future.mid: (delivery_future, topic, qos, payload_size, publish_time) } - dispatcher = MagicMock() - dispatcher.splitter.max_payload_size = 100000 - dispatcher.build_uplink_payloads.return_value = [ + adapter = MagicMock() + adapter.splitter.max_payload_size = 100000 + adapter.build_uplink_payloads.return_value = [ (topic, b'{"some":"payload"}', qos, [delivery_future]) ] @@ -1053,7 +1053,7 @@ async def test_delivery_futures_resolved_via_real_puback_handler(): None, None, None, - message_dispatcher=dispatcher, + message_adapter=adapter, batch_collect_max_count=1, batch_collect_max_time_ms=1 ) @@ -1088,7 +1088,7 @@ async def test_batch_append_and_batch_size_accumulate(): mqtt_manager.publish = AsyncMock() mqtt_manager.backpressure.should_pause.return_value = False - dispatcher = JsonMessageDispatcher(100000, 10000) + adapter = JsonMessageAdapter(100000, 10000) stop_event = asyncio.Event() queue = MessageQueue( @@ -1097,7 +1097,7 @@ async def test_batch_append_and_batch_size_accumulate(): None, None, None, - message_dispatcher=dispatcher, + message_adapter=adapter, batch_collect_max_count=2, batch_collect_max_time_ms=1000 ) diff --git a/tests/service/test_message_splitter.py b/tests/service/test_message_splitter.py index 6bc0e74..b15e3ad 100644 --- a/tests/service/test_message_splitter.py +++ b/tests/service/test_message_splitter.py @@ -20,7 +20,7 @@ from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder -from tb_mqtt_client.service.message_dispatcher import JsonMessageDispatcher +from tb_mqtt_client.service.device.message_adapter import JsonMessageAdapter from tb_mqtt_client.service.message_splitter import MessageSplitter @@ -194,7 +194,7 @@ def test_datapoint_setter_validation(): @pytest.mark.asyncio async def test_split_attributes_grouping(): - dispatcher = JsonMessageDispatcher(max_payload_size=200, max_datapoints=5) + dispatcher = JsonMessageAdapter(max_payload_size=200, max_datapoints=5) builder1 = DeviceUplinkMessageBuilder().set_device_name("deviceA").set_device_profile("default") builder2 = DeviceUplinkMessageBuilder().set_device_name("deviceA").set_device_profile("default") @@ -220,7 +220,7 @@ async def test_split_attributes_grouping(): @pytest.mark.asyncio async def test_split_attributes_different_devices_not_grouped(): - dispatcher = JsonMessageDispatcher(max_payload_size=200, max_datapoints=100) + dispatcher = JsonMessageAdapter(max_payload_size=200, max_datapoints=100) builder1 = DeviceUplinkMessageBuilder().set_device_name("deviceA") builder2 = DeviceUplinkMessageBuilder().set_device_name("deviceB") diff --git a/tests/service/test_mqtt_manager.py b/tests/service/test_mqtt_manager.py index d46685f..bd513b3 100644 --- a/tests/service/test_mqtt_manager.py +++ b/tests/service/test_mqtt_manager.py @@ -23,14 +23,14 @@ from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit from tb_mqtt_client.constants.service_keys import MESSAGES_RATE_LIMIT from tb_mqtt_client.service.device.handlers.rpc_response_handler import RPCResponseHandler -from tb_mqtt_client.service.message_dispatcher import MessageDispatcher +from tb_mqtt_client.service.device.message_adapter import MessageAdapter from tb_mqtt_client.service.mqtt_manager import MQTTManager, IMPLEMENTATION_SPECIFIC_ERROR, QUOTA_EXCEEDED @pytest_asyncio.fixture async def setup_manager(): stop_event = asyncio.Event() - message_dispatcher = 
MagicMock(spec=MessageDispatcher) + message_dispatcher = MagicMock(spec=MessageAdapter) on_connect = AsyncMock() on_disconnect = AsyncMock() on_publish_result = AsyncMock() @@ -40,7 +40,7 @@ async def setup_manager(): manager = MQTTManager( client_id="test-client", main_stop_event=stop_event, - message_dispatcher=message_dispatcher, + message_adapter=message_dispatcher, on_connect=on_connect, on_disconnect=on_disconnect, on_publish_result=on_publish_result, From febccf6f13f43783e3be4bad763280d8c140fba7 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Tue, 15 Jul 2025 11:51:54 +0300 Subject: [PATCH 44/74] Removed device name from device uplink message, updated gateway messages to be gateway events --- examples/gateway/send_timeseries.py | 61 ++++++++ .../entities/data/device_uplink_message.py | 39 +---- .../entities/gateway/base_gateway_event.py | 31 ++++ .../entities/gateway/device_info.py | 13 +- tb_mqtt_client/entities/gateway/event_type.py | 24 +-- .../gateway/gateway_attribute_request.py | 5 +- .../gateway/gateway_attribute_update.py | 2 +- .../entities/gateway/gateway_event.py | 5 +- .../gateway_requested_attribute_response.py | 23 +-- .../entities/gateway/gateway_rpc_request.py | 6 +- .../entities/gateway/gateway_rpc_response.py | 22 +-- .../gateway/gateway_uplink_message.py | 23 ++- tb_mqtt_client/service/base_client.py | 84 ++++++++++- tb_mqtt_client/service/device/client.py | 56 +------ .../service/device/message_adapter.py | 73 ++++----- tb_mqtt_client/service/gateway/client.py | 124 ++++++++++++++- .../service/gateway/device_manager.py | 5 - .../service/gateway/device_session.py | 38 +++-- .../service/gateway/event_dispatcher.py | 5 +- .../gateway/gateway_client_interface.py | 59 +++----- .../gateway_attribute_updates_handler.py | 8 +- ...y_requested_attributes_response_handler.py | 130 ++++++++++++++++ .../gateway/handlers/gateway_rpc_handler.py | 60 ++++++++ .../service/gateway/message_adapter.py | 142 ++++++++++++------ tb_mqtt_client/service/message_queue.py | 42 ++++-- tb_mqtt_client/service/message_splitter.py | 14 +- tb_mqtt_client/service/mqtt_manager.py | 46 +++++- .../data/test_device_uplink_message.py | 6 +- 28 files changed, 822 insertions(+), 324 deletions(-) create mode 100644 examples/gateway/send_timeseries.py create mode 100644 tb_mqtt_client/entities/gateway/base_gateway_event.py create mode 100644 tb_mqtt_client/service/gateway/handlers/gateway_requested_attributes_response_handler.py diff --git a/examples/gateway/send_timeseries.py b/examples/gateway/send_timeseries.py new file mode 100644 index 0000000..e89e22e --- /dev/null +++ b/examples/gateway/send_timeseries.py @@ -0,0 +1,61 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
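+
+# This example demonstrates the basic gateway workflow: connect the
+# GatewayClient, register a device session via connect_device(), and
+# publish time series for that device from a plain dictionary.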
+ +import asyncio + +from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.service.gateway.client import GatewayClient + + +configure_logging() +logger = get_logger("tb_mqtt_client") + + +async def main(): + config = GatewayConfig() + config.host = "localhost" + config.access_token = "YOUR_ACCESS_TOKEN" + client = GatewayClient(config) + + await client.connect() + + # Connecting device + + device_name = "Test Device A2" + device_profile = "test_device_profile" + logger.info("Connecting device: %s", device_name) + device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) + + if not device_session: + logger.error("Failed to register device: %s", device_name) + return + logger.info("Device connected successfully: %s", device_name) + + # Send time series as raw dictionary + raw_timeseries = { + "temperature": 25.5, + "humidity": 60 + } + logger.info("Sending raw timeseries: %s", raw_timeseries) + await client.send_device_timeseries(device_session=device_session, data=raw_timeseries, wait_for_publish=True) + logger.info("Raw timeseries sent successfully.") + + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("Keyboard interrupt received, shutting down.") diff --git a/tb_mqtt_client/entities/data/device_uplink_message.py b/tb_mqtt_client/entities/data/device_uplink_message.py index 96927d4..9dae46d 100644 --- a/tb_mqtt_client/entities/data/device_uplink_message.py +++ b/tb_mqtt_client/entities/data/device_uplink_message.py @@ -24,13 +24,11 @@ logger = get_logger(__name__) -DEFAULT_FIELDS_SIZE = len('{"device_name":"","device_profile":"","attributes":"","timeseries":""}'.encode('utf-8')) +DEFAULT_FIELDS_SIZE = len('{"attributes":"","timeseries":""}'.encode('utf-8')) @dataclass(slots=True, frozen=True) -class DeviceUplinkMessage: - device_name: Optional[str] - device_profile: Optional[str] +class GatewayUplinkMessage: attributes: Tuple[AttributeEntry] timeseries: Mapping[int, Tuple[TimeseriesEntry]] delivery_futures: List[Optional[asyncio.Future[PublishResult]]] @@ -41,23 +39,16 @@ def __new__(cls, *args, **kwargs): "Direct instantiation of DeviceUplinkMessage is not allowed. 
Use DeviceUplinkMessageBuilder to construct instances.") def __repr__(self): - return (f"DeviceUplinkMessage(device_name={self.device_name}, " - f"device_profile={self.device_profile}, " - f"attributes={self.attributes}, " - f"timeseries={self.timeseries}, " - f"delivery_futures={self.delivery_futures})") + return (f"DeviceUplinkMessage(attributes={self.attributes}, " + f"timeseries={self.timeseries}, delivery_futures={self.delivery_futures})") @classmethod def build(cls, - device_name: Optional[str], - device_profile: Optional[str], attributes: List[AttributeEntry], timeseries: Mapping[int, List[TimeseriesEntry]], delivery_futures: List[Optional[asyncio.Future]], - size: int) -> 'DeviceUplinkMessage': + size: int) -> 'GatewayUplinkMessage': self = object.__new__(cls) - object.__setattr__(self, 'device_name', device_name) - object.__setattr__(self, 'device_profile', device_profile) object.__setattr__(self, 'attributes', tuple(attributes)) object.__setattr__(self, 'timeseries', MappingProxyType({ts: tuple(entries) for ts, entries in timeseries.items()})) @@ -87,25 +78,11 @@ def get_delivery_futures(self) -> List[Optional[asyncio.Future[PublishResult]]]: class DeviceUplinkMessageBuilder: def __init__(self): - self._device_name: Optional[str] = None - self._device_profile: Optional[str] = None self._attributes: List[AttributeEntry] = [] self._timeseries: OrderedDict[int, List[TimeseriesEntry]] = OrderedDict() self._delivery_futures: List[Optional[asyncio.Future[PublishResult]]] = [] self.__size = DEFAULT_FIELDS_SIZE - def set_device_name(self, device_name: str) -> 'DeviceUplinkMessageBuilder': - self._device_name = device_name - if device_name is not None: - self.__size += len(device_name) - return self - - def set_device_profile(self, profile: str) -> 'DeviceUplinkMessageBuilder': - self._device_profile = profile - if profile is not None: - self.__size += len(profile) - return self - def add_attributes(self, attributes: Union[AttributeEntry, List[AttributeEntry]]) -> 'DeviceUplinkMessageBuilder': if not isinstance(attributes, list): attributes = [attributes] @@ -145,12 +122,10 @@ def add_delivery_futures(self, futures: Union[ self._delivery_futures.extend(futures) return self - def build(self) -> DeviceUplinkMessage: + def build(self) -> GatewayUplinkMessage: if not self._delivery_futures: self._delivery_futures = [asyncio.get_event_loop().create_future()] - return DeviceUplinkMessage.build( - device_name=self._device_name, - device_profile=self._device_profile, + return GatewayUplinkMessage.build( attributes=self._attributes, timeseries=self._timeseries, delivery_futures=self._delivery_futures, diff --git a/tb_mqtt_client/entities/gateway/base_gateway_event.py b/tb_mqtt_client/entities/gateway/base_gateway_event.py new file mode 100644 index 0000000..590c8ff --- /dev/null +++ b/tb_mqtt_client/entities/gateway/base_gateway_event.py @@ -0,0 +1,31 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
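+
+# Base class shared by all gateway events: it stores the GatewayEventType and
+# an optional device session that can be attached via set_device_session().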
+ +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType + +class BaseGatewayEvent: + def __init__(self, event_type: GatewayEventType): + self.__event_type = event_type + self.__device_session = None + + @property + def event_type(self) -> GatewayEventType: + return self.__event_type + + def set_device_session(self, device_session): + self.__device_session = device_session + + @property + def device_session(self): + return self.__device_session diff --git a/tb_mqtt_client/entities/gateway/device_info.py b/tb_mqtt_client/entities/gateway/device_info.py index eb9d5f0..8a0b61c 100644 --- a/tb_mqtt_client/entities/gateway/device_info.py +++ b/tb_mqtt_client/entities/gateway/device_info.py @@ -16,19 +16,28 @@ import uuid -@dataclass(frozen=True) +@dataclass() class DeviceInfo: device_name: str device_profile: str original_name: str = field(init=False) device_id: uuid.UUID = field(default_factory=uuid.uuid4, init=False) + _initializing: bool = field(default=True, init=False, repr=False) + def __post_init__(self): self.__setattr__("original_name", self.device_name) + self._initializing = False + + def __setattr__(self, key, value): + if not self._initializing: + raise AttributeError(f"Cannot modify attribute '{key}' of frozen DeviceInfo instance. Use rename() method to change device_name.") + else: + super().__setattr__(key, value) def rename(self, new_name: str): if new_name != self.device_name: - self.__setattr__("device_name", new_name) + super().__setattr__("device_name", new_name) @classmethod def from_dict(cls, data: dict) -> 'DeviceInfo': diff --git a/tb_mqtt_client/entities/gateway/event_type.py b/tb_mqtt_client/entities/gateway/event_type.py index abb7e0b..c6c4e3e 100644 --- a/tb_mqtt_client/entities/gateway/event_type.py +++ b/tb_mqtt_client/entities/gateway/event_type.py @@ -19,18 +19,18 @@ class GatewayEventType(Enum): Enum representing different types of gateway events. Each event type corresponds to a specific action or state change in the gateway. 
""" - DEVICE_ADDED = "DEVICE_ADDED" - DEVICE_REMOVED = "DEVICE_REMOVED" - DEVICE_UPDATED = "DEVICE_UPDATED" - DEVICE_SESSION_STATE_CHANGED = "DEVICE_SESSION_STATE_CHANGED" - DEVICE_RPC_REQUEST_RECEIVED = "DEVICE_RPC_REQUEST_RECEIVED" - DEVICE_RPC_RESPONSE_SENT = "DEVICE_RPC_RESPONSE_SENT" - DEVICE_ATTRIBUTE_UPDATE_RECEIVED = "DEVICE_ATTRIBUTE_UPDATE_RECEIVED" - DEVICE_REQUESTED_ATTRIBUTE_RESPONSE_RECEIVED = "DEVICE_REQUESTED_ATTRIBUTE_RESPONSE_RECEIVED" - RPC_REQUEST_RECEIVED = "RPC_REQUEST_RECEIVED" - RPC_RESPONSE_SENT = "RPC_RESPONSE_SENT" - GATEWAY_CONNECTED = "GATEWAY_CONNECTED" - GATEWAY_DISCONNECTED = "GATEWAY_DISCONNECTED" + GATEWAY_CONNECT = "gateway.connect" + GATEWAY_DISCONNECT = "gateway.disconnect" + DEVICE_ADD = "gateway.device.add" + DEVICE_REMOVE = "gateway.device.remove" + DEVICE_UPDATE = "gateway.device.update" + DEVICE_SESSION_STATE_CHANGE = "gateway.device.session.state.change" + DEVICE_RPC_REQUEST_RECEIVE = "gateway.device.rpc.request.receive" + DEVICE_RPC_RESPONSE_SEND = "gateway.device.rpc.response.send" + DEVICE_ATTRIBUTE_UPDATE_RECEIVE = "gateway.device.attribute.update.receive" + DEVICE_REQUESTED_ATTRIBUTE_RESPONSE_RECEIVE = "gateway.device.requested.attribute.response.receive" + RPC_REQUEST_RECEIVE = "device.rpc.request.receive" + RPC_RESPONSE_SEND = "device.rpc.response.send" def __str__(self) -> str: return self.value diff --git a/tb_mqtt_client/entities/gateway/gateway_attribute_request.py b/tb_mqtt_client/entities/gateway/gateway_attribute_request.py index e853633..803917f 100644 --- a/tb_mqtt_client/entities/gateway/gateway_attribute_request.py +++ b/tb_mqtt_client/entities/gateway/gateway_attribute_request.py @@ -16,10 +16,12 @@ from typing import Optional, List, Dict, Union from tb_mqtt_client.common.request_id_generator import AttributeRequestIdProducer from tb_mqtt_client.constants.json_typing import validate_json_compatibility +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent @dataclass(slots=True, frozen=True) -class GatewayAttributeRequest: +class GatewayAttributeRequest(BaseGatewayEvent): """ Represents a request for device attributes, with optional client and shared attribute keys. Automatically assigns a unique request ID via the build() method. @@ -28,6 +30,7 @@ class GatewayAttributeRequest: device_name: str shared_keys: Optional[List[str]] = None client_keys: Optional[List[str]] = None + event_type = GatewayEventType.DEVICE_REQUESTED_ATTRIBUTE_RESPONSE_RECEIVE def __new__(cls, *args, **kwargs): raise TypeError("Direct instantiation of GatewayAttributeRequest is not allowed. 
Use 'await GatewayAttributeRequest.build(...)'.") diff --git a/tb_mqtt_client/entities/gateway/gateway_attribute_update.py b/tb_mqtt_client/entities/gateway/gateway_attribute_update.py index 1c59814..0528c8f 100644 --- a/tb_mqtt_client/entities/gateway/gateway_attribute_update.py +++ b/tb_mqtt_client/entities/gateway/gateway_attribute_update.py @@ -24,7 +24,7 @@ class GatewayAttributeUpdate(GatewayEvent): """ def __init__(self, device_name: str, attribute_update: AttributeUpdate): - super().__init__(event_type=GatewayEventType.DEVICE_ATTRIBUTE_UPDATE_RECEIVED) + super().__init__(event_type=GatewayEventType.DEVICE_ATTRIBUTE_UPDATE_RECEIVE) self.device_name = device_name self.attribute_update = attribute_update diff --git a/tb_mqtt_client/entities/gateway/gateway_event.py b/tb_mqtt_client/entities/gateway/gateway_event.py index 66d0c9c..e67bfb1 100644 --- a/tb_mqtt_client/entities/gateway/gateway_event.py +++ b/tb_mqtt_client/entities/gateway/gateway_event.py @@ -14,18 +14,19 @@ from typing import Union +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent from tb_mqtt_client.entities.gateway.event_type import GatewayEventType from tb_mqtt_client.service.gateway.device_session import DeviceSession -class GatewayEvent: +class GatewayEvent(BaseGatewayEvent): """ Base class for all events in the gateway client. This class can be extended to create specific event types. """ def __init__(self, event_type: GatewayEventType): - self.event_type = event_type + super().__init__(event_type) self.__device_session: Union[DeviceSession, None] = None def set_device_session(self, device_session: DeviceSession): diff --git a/tb_mqtt_client/entities/gateway/gateway_requested_attribute_response.py b/tb_mqtt_client/entities/gateway/gateway_requested_attribute_response.py index 651f2d7..5b66b8a 100644 --- a/tb_mqtt_client/entities/gateway/gateway_requested_attribute_response.py +++ b/tb_mqtt_client/entities/gateway/gateway_requested_attribute_response.py @@ -15,14 +15,18 @@ from dataclasses import dataclass from typing import Dict, Any, List, Optional +from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse + +logger = get_logger(__name__) @dataclass(slots=True, frozen=True) -class GatewayRequestedAttributeResponse: +class GatewayRequestedAttributeResponse(RequestedAttributeResponse): - device_name: str - request_id: int + device_name: str = "" + request_id: int = -1 shared: Optional[List[AttributeEntry]] = None client: Optional[List[AttributeEntry]] = None @@ -84,16 +88,3 @@ def as_dict(self) -> Dict[str, Any]: 'shared': [entry.as_dict() for entry in self.shared if self.shared is not None], 'client': [entry.as_dict() for entry in self.client if self.client is not None], } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'GatewayRequestedAttributeResponse': - """ - Deserialize dictionary into GatewayRequestedAttributeResponse object. - :param data: Dictionary containing 'device' with device name, 'shared' and 'client' attributes. - :return: GatewayRequestedAttributeResponse instance. 
- """ - request_id = data.get('request_id', -1) - device_name = data.get('device', '') - shared = [AttributeEntry(k, v) for k, v in data.get('shared', {}).items()] - client = [AttributeEntry(k, v) for k, v in data.get('client', {}).items()] - return cls(device_name=device_name, shared=shared, client=client, request_id=request_id) diff --git a/tb_mqtt_client/entities/gateway/gateway_rpc_request.py b/tb_mqtt_client/entities/gateway/gateway_rpc_request.py index 5518232..a1af800 100644 --- a/tb_mqtt_client/entities/gateway/gateway_rpc_request.py +++ b/tb_mqtt_client/entities/gateway/gateway_rpc_request.py @@ -15,13 +15,17 @@ from dataclasses import dataclass from typing import Union, Optional, Dict, Any +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType + @dataclass(slots=True, frozen=True) -class GatewayRPCRequest: +class GatewayRPCRequest(BaseGatewayEvent): request_id: Union[int, str] device_name: str method: str params: Optional[Any] = None + event_type: GatewayEventType = GatewayEventType.RPC_REQUEST_RECEIVE def __new__(cls, *args, **kwargs): raise TypeError("Direct instantiation of GatewayRPCRequest is not allowed. Use 'await GatewayRPCRequest.build(...)'.") diff --git a/tb_mqtt_client/entities/gateway/gateway_rpc_response.py b/tb_mqtt_client/entities/gateway/gateway_rpc_response.py index 37e9eb2..80dcb39 100644 --- a/tb_mqtt_client/entities/gateway/gateway_rpc_response.py +++ b/tb_mqtt_client/entities/gateway/gateway_rpc_response.py @@ -17,11 +17,13 @@ from typing import Union, Optional, Dict, Any from tb_mqtt_client.constants.json_typing import validate_json_compatibility, JSONCompatibleType -from tb_mqtt_client.entities.data.rpc_response import RPCStatus +from tb_mqtt_client.entities.data.rpc_response import RPCStatus, RPCResponse +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType @dataclass(slots=True, frozen=True) -class GatewayRPCResponse: +class GatewayRPCResponse(RPCResponse, BaseGatewayEvent): """ Represents a response to the RPC call. @@ -31,20 +33,22 @@ class GatewayRPCResponse: result: Optional response payload (Any type allowed). error: Optional error information if the RPC failed. """ - device_name: str - request_id: Union[int, str] - status: RPCStatus = None - result: Optional[Any] = None - error: Optional[Union[str, Dict[str, Any]]] = None + device_name: str = None + event_type: GatewayEventType = GatewayEventType.RPC_RESPONSE_SEND def __new__(cls, *args, **kwargs): - raise TypeError("Direct instantiation of GatewayRPCResponse is not allowed. Use GatewayRPCResponse.build(request_id, result, error).") + raise TypeError("Direct instantiation of GatewayRPCResponse is not allowed. Use GatewayRPCResponse.build(device_name, request_id, result, error).") + def __repr__(self) -> str: return f"GatewayRPCResponse(request_id={self.request_id}, result={self.result}, error={self.error})" @classmethod - def build(cls, device_name: str, request_id: Union[int, str], result: Optional[Any] = None, error: Optional[Union[str, Dict[str, JSONCompatibleType], BaseException]] = None) -> 'GatewayRPCResponse': + def build(cls, # noqa + device_name: str, + request_id: Union[int, str], + result: Optional[Any] = None, + error: Optional[Union[str, Dict[str, JSONCompatibleType], BaseException]] = None) -> 'GatewayRPCResponse': """ Constructs an GatewayRPCResponse explicitly. 
""" diff --git a/tb_mqtt_client/entities/gateway/gateway_uplink_message.py b/tb_mqtt_client/entities/gateway/gateway_uplink_message.py index 578793e..a49ce47 100644 --- a/tb_mqtt_client/entities/gateway/gateway_uplink_message.py +++ b/tb_mqtt_client/entities/gateway/gateway_uplink_message.py @@ -34,6 +34,7 @@ from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.device_uplink_message import GatewayUplinkMessage from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry logger = get_logger(__name__) @@ -42,20 +43,16 @@ @dataclass(slots=True, frozen=True) -class GatewayUplinkMessage: - device_name: Optional[str] - device_profile: Optional[str] - attributes: Tuple[AttributeEntry] - timeseries: Mapping[int, Tuple[TimeseriesEntry]] - delivery_futures: List[Optional[asyncio.Future[PublishResult]]] - _size: int +class GatewayUplinkMessage(GatewayUplinkMessage): + device_name: str + device_profile: str def __new__(cls, *args, **kwargs): raise TypeError( - "Direct instantiation of DeviceUplinkMessage is not allowed. Use DeviceUplinkMessageBuilder to construct instances.") + "Direct instantiation of GatewayUplinkMessage is not allowed. Use GatewayUplinkMessageBuilder to construct instances.") def __repr__(self): - return (f"DeviceUplinkMessage(device_name={self.device_name}, " + return (f"GatewayUplinkMessage(device_name={self.device_name}, " f"device_profile={self.device_profile}, " f"attributes={self.attributes}, " f"timeseries={self.timeseries}, " @@ -68,7 +65,7 @@ def build(cls, attributes: List[AttributeEntry], timeseries: Mapping[int, List[TimeseriesEntry]], delivery_futures: List[Optional[asyncio.Future]], - size: int) -> 'DeviceUplinkMessage': + size: int) -> 'GatewayUplinkMessage': self = object.__new__(cls) object.__setattr__(self, 'device_name', device_name) object.__setattr__(self, 'device_profile', device_profile) @@ -151,7 +148,7 @@ def add_timeseries(self, timeseries: Union[TimeseriesEntry, List[TimeseriesEntry return self def add_delivery_futures(self, futures: Union[ - asyncio.Future[PublishResult], List[asyncio.Future[PublishResult]]]) -> 'DeviceUplinkMessageBuilder': + asyncio.Future[PublishResult], List[asyncio.Future[PublishResult]]]) -> 'GatewayUplinkMessageBuilder': if not isinstance(futures, list): futures = [futures] if futures: @@ -159,10 +156,10 @@ def add_delivery_futures(self, futures: Union[ self._delivery_futures.extend(futures) return self - def build(self) -> DeviceUplinkMessage: + def build(self) -> GatewayUplinkMessage: if not self._delivery_futures: self._delivery_futures = [asyncio.get_event_loop().create_future()] - return DeviceUplinkMessage.build( + return GatewayUplinkMessage.build( device_name=self._device_name, device_profile=self._device_profile, attributes=self._attributes, diff --git a/tb_mqtt_client/service/base_client.py b/tb_mqtt_client/service/base_client.py index 6f9e271..aa860c0 100644 --- a/tb_mqtt_client/service/base_client.py +++ b/tb_mqtt_client/service/base_client.py @@ -14,17 +14,23 @@ import asyncio from abc import ABC, abstractmethod +from time import time from typing import Callable, Awaitable, Dict, Any, Union, List, Optional import uvloop from tb_mqtt_client.common.exceptions import exception_handler +from tb_mqtt_client.constants.json_typing import JSONCompatibleType +from tb_mqtt_client.constants.service_keys import TELEMETRY_TIMESTAMP_PARAMETER, 
TELEMETRY_VALUES_PARAMETER from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate from tb_mqtt_client.entities.data.claim_request import ClaimRequest +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder, GatewayUplinkMessage from tb_mqtt_client.entities.data.rpc_response import RPCResponse from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessageBuilder, GatewayUplinkMessage +from tb_mqtt_client.service.gateway.device_session import DeviceSession asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) exception_handler.install_asyncio_handler() @@ -59,10 +65,10 @@ async def disconnect(self): @abstractmethod async def send_timeseries(self, - telemetry_data: Union[TimeseriesEntry, - List[TimeseriesEntry], - Dict[str, Any], - List[Dict[str, Any]]], + data: Union[TimeseriesEntry, + List[TimeseriesEntry], + Dict[str, Any], + List[Dict[str, Any]]], wait_for_publish: bool = True, timeout: Optional[float] = None) -> Union[asyncio.Future[PublishResult], PublishResult, @@ -132,4 +138,72 @@ def set_rpc_request_callback(self, callback: Callable[[str, Dict[str, Any]], Awa :param callback: Coroutine accepting (method, params) and returning result. """ - pass \ No newline at end of file + pass + + @staticmethod + def _build_uplink_message_for_telemetry(payload: Union[Dict[str, Any], + TimeseriesEntry, + List[TimeseriesEntry], + List[Dict[str, Any]]], + device_session: Optional[DeviceSession] = None, + ) -> GatewayUplinkMessage: + timeseries_entries = [] + if isinstance(payload, TimeseriesEntry): + timeseries_entries.append(payload) + elif isinstance(payload, dict): + timeseries_entries.extend(BaseClient.__build_timeseries_entry_from_dict(payload)) + elif isinstance(payload, list) and len(payload) > 0: + for item in payload: + if isinstance(item, dict): + timeseries_entries.extend(BaseClient.__build_timeseries_entry_from_dict(item)) + elif isinstance(item, TimeseriesEntry): + timeseries_entries.append(item) + else: + raise ValueError(f"Unsupported item type in telemetry list: {type(item).__name__}") + else: + raise ValueError(f"Unsupported payload type for telemetry: {type(payload).__name__}") + + if device_session: + message_builder = GatewayUplinkMessageBuilder() + message_builder.set_device_name(device_session.device_info.device_name) + message_builder.set_device_profile(device_session.device_info.device_profile) + else: + message_builder = DeviceUplinkMessageBuilder() + message_builder.add_timeseries(timeseries_entries) + return message_builder.build() + + @staticmethod + def __build_timeseries_entry_from_dict(data: Dict[str, JSONCompatibleType]) -> List[TimeseriesEntry]: + result = [] + if TELEMETRY_TIMESTAMP_PARAMETER in data: + ts = data.pop(TELEMETRY_TIMESTAMP_PARAMETER) + values = data.pop(TELEMETRY_VALUES_PARAMETER, {}) + else: + ts = int(time() * 1000) + values = data + + if not isinstance(values, dict): + raise ValueError(f"Expected {TELEMETRY_VALUES_PARAMETER} to be a dict, got {type(values).__name__}") + + for key, value in values.items(): + result.append(TimeseriesEntry(key, value, ts=ts)) + + return result + + @staticmethod + def _build_uplink_message_for_attributes(payload: Union[Dict[str, Any], + AttributeEntry, + List[AttributeEntry]], + device_session = None) -> Union[GatewayUplinkMessage, 
GatewayUplinkMessage]: + + if isinstance(payload, dict): + payload = [AttributeEntry(k, v) for k, v in payload.items()] + + if device_session: + message_builder = GatewayUplinkMessageBuilder() + message_builder.set_device_name(device_session.device_info.device_name) + message_builder.set_device_profile(device_session.device_info.device_profile) + else: + message_builder = DeviceUplinkMessageBuilder() + message_builder.add_attributes(payload) + return message_builder.build() \ No newline at end of file diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index a5365c4..0760e71 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -35,7 +35,7 @@ from tb_mqtt_client.entities.data.attribute_request import AttributeRequest from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate from tb_mqtt_client.entities.data.claim_request import ClaimRequest -from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage, DeviceUplinkMessageBuilder +from tb_mqtt_client.entities.data.device_uplink_message import GatewayUplinkMessage, DeviceUplinkMessageBuilder from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse from tb_mqtt_client.entities.data.rpc_request import RPCRequest @@ -476,60 +476,6 @@ async def __on_publish_result(self, publish_result: PublishResult): else: logger.error("Publish failed: %r", publish_result) - @staticmethod - def _build_uplink_message_for_telemetry(payload: Union[Dict[str, Any], - TimeseriesEntry, - List[TimeseriesEntry], - List[Dict[str, Any]]]) -> DeviceUplinkMessage: - timeseries_entries = [] - if isinstance(payload, TimeseriesEntry): - timeseries_entries.append(payload) - elif isinstance(payload, dict): - timeseries_entries.extend(DeviceClient.__build_timeseries_entry_from_dict(payload)) - elif isinstance(payload, list) and len(payload) > 0: - for item in payload: - if isinstance(item, dict): - timeseries_entries.extend(DeviceClient.__build_timeseries_entry_from_dict(item)) - elif isinstance(item, TimeseriesEntry): - timeseries_entries.append(item) - else: - raise ValueError(f"Unsupported item type in telemetry list: {type(item).__name__}") - else: - raise ValueError(f"Unsupported payload type for telemetry: {type(payload).__name__}") - - builder = DeviceUplinkMessageBuilder() - builder.add_timeseries(timeseries_entries) - return builder.build() - - @staticmethod - def __build_timeseries_entry_from_dict(data: Dict[str, JSONCompatibleType]) -> List[TimeseriesEntry]: - result = [] - if TELEMETRY_TIMESTAMP_PARAMETER in data: - ts = data.pop(TELEMETRY_TIMESTAMP_PARAMETER) - values = data.pop(TELEMETRY_VALUES_PARAMETER, {}) - else: - ts = time() * 1000 - values = data - - if not isinstance(values, dict): - raise ValueError(f"Expected {TELEMETRY_VALUES_PARAMETER} to be a dict, got {type(values).__name__}") - - for key, value in values.items(): - result.append(TimeseriesEntry(key, value, ts=ts)) - - return result - - @staticmethod - def _build_uplink_message_for_attributes(payload: Union[Dict[str, Any], - AttributeEntry, - List[AttributeEntry]]) -> DeviceUplinkMessage: - if isinstance(payload, dict): - payload = [AttributeEntry(k, v) for k, v in payload.items()] - - builder = DeviceUplinkMessageBuilder() - builder.add_attributes(payload) - return builder.build() - @staticmethod async def provision(provision_request: 
'ProvisioningRequest', timeout=BaseClient.DEFAULT_TIMEOUT): provision_client = ProvisioningClient( diff --git a/tb_mqtt_client/service/device/message_adapter.py b/tb_mqtt_client/service/device/message_adapter.py index 2940746..a59dec4 100644 --- a/tb_mqtt_client/service/device/message_adapter.py +++ b/tb_mqtt_client/service/device/message_adapter.py @@ -40,7 +40,7 @@ from tb_mqtt_client.constants.mqtt_topics import DEVICE_TELEMETRY_TOPIC, DEVICE_ATTRIBUTES_TOPIC from tb_mqtt_client.entities.data.attribute_request import AttributeRequest from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate -from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage +from tb_mqtt_client.entities.data.device_uplink_message import GatewayUplinkMessage from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest, ProvisioningCredentialsType from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse @@ -61,7 +61,7 @@ def __init__(self, max_payload_size: Optional[int] = None, max_datapoints: Optio @abstractmethod def build_uplink_payloads( self, - messages: List[DeviceUplinkMessage] + messages: List[GatewayUplinkMessage] ) -> List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]]: """ Build a list of topic-payload pairs from the given messages. @@ -253,7 +253,7 @@ def parse_provisioning_response(self, provisioning_request: ProvisioningRequest, def splitter(self) -> MessageSplitter: return self._splitter - def build_uplink_payloads(self, messages: List[DeviceUplinkMessage]) -> List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]]: + def build_uplink_payloads(self, messages: List[GatewayUplinkMessage]) -> List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]]: """ Build a list of topic-payload pairs from the given messages. 
Each pair consists of a topic string, payload bytes, the number of datapoints, @@ -265,32 +265,28 @@ def build_uplink_payloads(self, messages: List[DeviceUplinkMessage]) -> List[Tup return [] result: List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]] = [] - device_groups: Dict[str, List[DeviceUplinkMessage]] = defaultdict(list) - for msg in messages: - device_name = msg.device_name - device_groups[device_name].append(msg) - logger.trace("Queued message for device='%s'", device_name) + telemetry_msgs = [] + attr_msgs = [] - logger.trace("Processing %d device group(s).", len(device_groups)) + for message in messages: + if message.has_timeseries(): + telemetry_msgs.append(message) + if message.has_attributes(): + attr_msgs.append(message) + logger.trace("Device - telemetry: %d, attributes: %d", len(telemetry_msgs), len(attr_msgs)) - for device, device_msgs in device_groups.items(): - telemetry_msgs = [m for m in device_msgs if m.has_timeseries()] - attr_msgs = [m for m in device_msgs if m.has_attributes()] - logger.trace("Device '%s' - telemetry: %d, attributes: %d", - device, len(telemetry_msgs), len(attr_msgs)) + for ts_batch in self._splitter.split_timeseries(telemetry_msgs): + payload = JsonMessageAdapter.build_payload(ts_batch, True) + count = ts_batch.timeseries_datapoint_count() + result.append((DEVICE_TELEMETRY_TOPIC, payload, count, ts_batch.get_delivery_futures())) + logger.trace("Built telemetry payload with %d datapoints", count) - for ts_batch in self._splitter.split_timeseries(telemetry_msgs): - payload = JsonMessageAdapter.build_payload(ts_batch, True) - count = ts_batch.timeseries_datapoint_count() - result.append((DEVICE_TELEMETRY_TOPIC, payload, count, ts_batch.get_delivery_futures())) - logger.trace("Built telemetry payload for device='%s' with %d datapoints", device, count) - - for attr_batch in self._splitter.split_attributes(attr_msgs): - payload = JsonMessageAdapter.build_payload(attr_batch, False) - count = len(attr_batch.attributes) - result.append((DEVICE_ATTRIBUTES_TOPIC, payload, count, attr_batch.get_delivery_futures())) - logger.trace("Built attribute payload for device='%s' with %d attributes", device, count) + for attr_batch in self._splitter.split_attributes(attr_msgs): + payload = JsonMessageAdapter.build_payload(attr_batch, False) + count = len(attr_batch.attributes) + result.append((DEVICE_ATTRIBUTES_TOPIC, payload, count, attr_batch.get_delivery_futures())) + logger.trace("Built attribute payload with %d attributes", count) logger.trace("Generated %d topic-payload entries.", len(result)) @@ -404,37 +400,26 @@ def build_provision_request(self, provision_request: 'ProvisioningRequest') -> T return topic, payload @staticmethod - def build_payload(msg: DeviceUplinkMessage, build_timeseries_payload) -> bytes: + def build_payload(msg: GatewayUplinkMessage, build_timeseries_payload) -> bytes: result: Union[Dict[str, Any], List[Dict[str, Any]]] = {} - device_name = msg.device_name - logger.trace("Building payload for device='%s'", device_name) - - if msg.device_name: - if build_timeseries_payload: - logger.trace("Packing timeseries for device='%s'", device_name) - result[msg.device_name] = JsonMessageAdapter.pack_timeseries(msg) - else: - logger.trace("Packing attributes for device='%s'", device_name) - result[msg.device_name] = JsonMessageAdapter.pack_attributes(msg) + if build_timeseries_payload: + logger.trace("Packing timeseries") + result = JsonMessageAdapter.pack_timeseries(msg) else: - if build_timeseries_payload: - 
logger.trace("Packing timeseries") - result = JsonMessageAdapter.pack_timeseries(msg) - else: - logger.trace("Packing attributes") - result = JsonMessageAdapter.pack_attributes(msg) + logger.trace("Packing attributes") + result = JsonMessageAdapter.pack_attributes(msg) payload = dumps(result) logger.trace("Built payload size: %d bytes", len(payload)) return payload @staticmethod - def pack_attributes(msg: DeviceUplinkMessage) -> Dict[str, Any]: + def pack_attributes(msg: GatewayUplinkMessage) -> Dict[str, Any]: logger.trace("Packing %d attribute(s)", len(msg.attributes)) return {attr.key: attr.value for attr in msg.attributes} @staticmethod - def pack_timeseries(msg: 'DeviceUplinkMessage') -> List[Dict[str, Any]]: + def pack_timeseries(msg: 'GatewayUplinkMessage') -> List[Dict[str, Any]]: now_ts = int(datetime.now(UTC).timestamp() * 1000) packed = [ diff --git a/tb_mqtt_client/service/gateway/client.py b/tb_mqtt_client/service/gateway/client.py index eaea7b6..37a9302 100644 --- a/tb_mqtt_client/service/gateway/client.py +++ b/tb_mqtt_client/service/gateway/client.py @@ -12,19 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -from asyncio import sleep +from asyncio import sleep, Future from time import monotonic -from typing import Optional, Dict, Union +from typing import Optional, Dict, Union, Tuple, List, Any +from tb_mqtt_client.common.async_utils import await_or_stop from tb_mqtt_client.common.config_loader import GatewayConfig from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit from tb_mqtt_client.constants import mqtt_topics +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage +from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessageBuilder from tb_mqtt_client.service.device.client import DeviceClient from tb_mqtt_client.service.gateway.device_manager import DeviceManager +from tb_mqtt_client.service.gateway.device_session import DeviceSession from tb_mqtt_client.service.gateway.event_dispatcher import EventDispatcher from tb_mqtt_client.service.gateway.gateway_client_interface import GatewayClientInterface from tb_mqtt_client.service.gateway.handlers.gateway_attribute_updates_handler import GatewayAttributeUpdatesHandler +from tb_mqtt_client.service.gateway.handlers.gateway_requested_attributes_response_handler import \ + GatewayRequestedAttributeResponseHandler +from tb_mqtt_client.service.gateway.handlers.gateway_rpc_handler import GatewayRPCHandler from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter, JsonGatewayMessageAdapter logger = get_logger(__name__) @@ -36,6 +47,7 @@ class GatewayClient(DeviceClient, GatewayClientInterface): This class extends DeviceClient and adds gateway-specific functionality. 
""" SUBSCRIPTIONS_TIMEOUT = 1.0 # Timeout for subscribe/unsubscribe operations + OPERATIONAL_TIMEOUT = 5.0 # Timeout for connection events def __init__(self, config: Optional[Union[GatewayConfig, Dict]] = None): """ @@ -51,11 +63,15 @@ def __init__(self, config: Optional[Union[GatewayConfig, Dict]] = None): self._gateway_message_adapter: GatewayMessageAdapter = JsonGatewayMessageAdapter() self._multiplex_dispatcher = None # Placeholder for multiplex dispatcher, if needed - self._gateway_rpc_handler = None # Placeholder for gateway RPC handler - self._gateway_attribute_updates_handler = GatewayAttributeUpdatesHandler(self._event_dispatcher, - self._gateway_message_adapter, - self._device_manager) - self._gateway_requested_attribute_response_handler = None # Placeholder for gateway requested attribute response handler + self._gateway_rpc_handler = GatewayRPCHandler(event_dispatcher=self._event_dispatcher, + message_adapter=self._gateway_message_adapter, + device_manager=self._device_manager) + self._gateway_attribute_updates_handler = GatewayAttributeUpdatesHandler(event_dispatcher=self._event_dispatcher, + message_adapter=self._gateway_message_adapter, + device_manager=self._device_manager) + self._gateway_requested_attribute_response_handler = GatewayRequestedAttributeResponseHandler(event_dispatcher=self._event_dispatcher, + message_adapter=self._gateway_message_adapter, + device_manager=self._device_manager) # Gateway-specific rate limits self._device_messages_rate_limit = RateLimit("10:1,", name="device_messages") @@ -73,12 +89,106 @@ async def connect(self): """ logger.info("Connecting gateway to platform at %s:%s", self._host, self._port) await super().connect() + self._message_queue.set_gateway_message_adapter(self._gateway_message_adapter) # Subscribe to gateway-specific topics await self._subscribe_to_gateway_topics() logger.info("Gateway connected to ThingsBoard.") + + async def connect_device(self, device_name: str, device_profile: str, wait_for_publish=False) -> Tuple[DeviceSession, List[Union[PublishResult, Future[PublishResult]]]]: + """ + Connect a device to the gateway. 
+ + :param device_name: Name of the device to connect + :param device_profile: Profile of the device + :param wait_for_publish: Whether to wait for the publish result + :return: Tuple containing the DeviceSession and an optional PublishResult or list of PublishResults + """ + logger.info("Connecting device %s with profile %s", device_name, device_profile) + device_session = self._device_manager.register(device_name, device_profile) + device_connect_message = DeviceConnectMessage.build(device_name, device_profile) + topic, payload = self._gateway_message_adapter.build_device_connect_message_payload(device_connect_message) + futures = await self._message_queue.publish( + topic=topic, + payload=payload, + datapoints_count=1, + qos=self._config.qos + ) + + if not futures: + logger.warning("No publish futures were returned from message queue") + return device_session, [] + + if not wait_for_publish: + return device_session, futures + + results = [] + for fut in futures: + try: + result = await await_or_stop(fut, timeout=self.OPERATIONAL_TIMEOUT, stop_event=self._stop_event) + except TimeoutError: + logger.warning("Timeout while waiting for telemetry publish result") + result = PublishResult(topic, self._config.qos, -1, len(payload), -1) + results.append(result) + + return device_session, results[0] if len(results) == 1 else results + + async def disconnect_device(self, device_session: DeviceSession, wait_for_publish: bool): + pass + + async def send_device_timeseries(self, + device_session: DeviceSession, + data: Union[TimeseriesEntry, List[TimeseriesEntry], Dict[str, Any], List[Dict[str, Any]]], + wait_for_publish: bool) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: + """ + Send timeseries data to the platform for a specific device. + :param device_session: The DeviceSession object for the device + :param data: Timeseries data to send, can be a single entry or a list of entries + :param wait_for_publish: Whether to wait for the publish result + :return: List of PublishResults or Future objects, or None if no data was sent + """ + logger.trace("Sending timeseries data for device %s", device_session.device_info.device_name) + + if not device_session or not data: + logger.warning("No device session or data provided for sending timeseries") + return None + + message = self._build_uplink_message_for_telemetry(data, device_session) + topic = mqtt_topics.GATEWAY_TELEMETRY_TOPIC + futures = await self._message_queue.publish( + topic=topic, + payload=message, + datapoints_count=message.timeseries_datapoint_count(), + qos=self._config.qos + ) + + if not futures: + logger.warning("No publish futures were returned from message queue") + return None + + if not wait_for_publish: + return futures[0] if len(futures) == 1 else futures + + results = [] + for fut in futures: + try: + result = await await_or_stop(fut, timeout=self.OPERATIONAL_TIMEOUT, stop_event=self._stop_event) + except TimeoutError: + logger.warning("Timeout while waiting for telemetry publish result") + result = PublishResult(topic, self._config.qos, -1, message.size, -1) + results.append(result) + + return results[0] if len(results) == 1 else results + + + async def send_device_attributes(self, device_session: DeviceSession, attributes: ..., wait_for_publish: bool): + pass + + async def send_device_attributes_request(self, device_session: DeviceSession, attributes: Union[AttributeRequest, GatewayAttributeRequest], wait_for_publish: bool): + pass + async def disconnect(self): """ Disconnect from the platform. 
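Editor's note — a minimal, hedged usage sketch of the gateway-side API introduced in the hunk above (connect, connect_device, send_device_timeseries, disconnect). It is not part of the patch: the configuration key names ("host", "port", "access_token") and the device name/profile values are assumptions for illustration only; GatewayClient.__init__ accepts either a GatewayConfig or a plain dict, so a dict is used here to avoid guessing the config class constructor.

import asyncio

from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry
from tb_mqtt_client.service.gateway.client import GatewayClient


async def main():
    # Hypothetical connection settings; the real configuration fields are defined
    # elsewhere in the client and may use different key names.
    client = GatewayClient({"host": "localhost", "port": 1883, "access_token": "TOKEN"})
    await client.connect()

    # connect_device registers a DeviceSession and publishes a DeviceConnectMessage.
    # With wait_for_publish=True the returned publish futures are awaited internally,
    # so publish_results already contains resolved PublishResult objects.
    session, publish_results = await client.connect_device(
        "Sensor A1", "default", wait_for_publish=True)

    # Timeseries data may be passed as dicts or as TimeseriesEntry objects.
    await client.send_device_timeseries(
        device_session=session,
        data=[TimeseriesEntry("temperature", 22.5)],
        wait_for_publish=True)

    await client.disconnect()


if __name__ == "__main__":
    asyncio.run(main())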
diff --git a/tb_mqtt_client/service/gateway/device_manager.py b/tb_mqtt_client/service/gateway/device_manager.py index c3e0cb2..240747d 100644 --- a/tb_mqtt_client/service/gateway/device_manager.py +++ b/tb_mqtt_client/service/gateway/device_manager.py @@ -97,11 +97,6 @@ def set_rpc_request_callback(self, device_id: UUID, cb: Callable): if session: session.set_rpc_request_callback(cb) - def set_rpc_response_callback(self, device_id: UUID, cb: Callable): - session = self._sessions_by_id.get(device_id) - if session: - session.set_rpc_response_callback(cb) - def __state_change_callback(self, device_session: DeviceSession) -> None: if device_session.state.is_connected() and device_session.device_info.device_id not in self.__connected_devices: self.__connected_devices.add(device_session) diff --git a/tb_mqtt_client/service/gateway/device_session.py b/tb_mqtt_client/service/gateway/device_session.py index 152f7bb..0411ccf 100644 --- a/tb_mqtt_client/service/gateway/device_session.py +++ b/tb_mqtt_client/service/gateway/device_session.py @@ -14,10 +14,17 @@ from time import time from dataclasses import dataclass, field -from typing import Callable, Awaitable, Optional, Dict, Any +from typing import Callable, Awaitable, Optional, Union +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent + +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse from tb_mqtt_client.entities.gateway.device_info import DeviceInfo from tb_mqtt_client.entities.gateway.device_session_state import DeviceSessionState +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse @dataclass @@ -30,10 +37,9 @@ class DeviceSession: provisioned: bool = False state: DeviceSessionState = DeviceSessionState.CONNECTED - attribute_update_callback: Optional[Callable[[dict], Awaitable[None]]] = None - attribute_response_callback: Optional[Callable[[dict], Awaitable[None]]] = None - rpc_request_callback: Optional[Callable[[str, Dict[str, Any]], Awaitable[Dict[str, Any]]]] = None - rpc_response_callback: Optional[Callable[[dict], Awaitable[None]]] = None + attribute_update_callback: Optional[Callable[['AttributeUpdate'], Awaitable[None]]] = None + attribute_response_callback: Optional[Callable[['RequestedAttributeResponse'], Awaitable[None]]] = None + rpc_request_callback: Optional[Callable[['GatewayRPCRequest'], Awaitable[Union['GatewayRPCResponse', None]]]] = None def update_state(self, new_state: DeviceSessionState): self.state = new_state @@ -43,14 +49,26 @@ def update_state(self, new_state: DeviceSessionState): def update_last_seen(self): self.last_seen_at = int(time() * 1000) - def set_attribute_update_callback(self, cb: Callable[[dict], Awaitable[None]]): + def set_attribute_update_callback(self, cb: Callable[['AttributeUpdate'], Awaitable[None]]): self.attribute_update_callback = cb - def set_attribute_response_callback(self, cb: Callable[[dict], Awaitable[None]]): + def set_attribute_response_callback(self, cb: Callable[['RequestedAttributeResponse'], Awaitable[None]]): self.attribute_response_callback = cb - def set_rpc_request_callback(self, cb: Callable[[str, Dict[str, Any]], Awaitable[Dict[str, Any]]]): + def set_rpc_request_callback(self, cb: Callable[['GatewayRPCRequest'], 
Awaitable[Union['GatewayRPCResponse', None]]]): self.rpc_request_callback = cb - def set_rpc_response_callback(self, cb: Callable[[dict], Awaitable[None]]): - self.rpc_response_callback = cb + async def handle_event_to_device(self, event: BaseGatewayEvent) -> Optional[Awaitable[Union['GatewayRPCResponse', None]]]: + if GatewayEventType.DEVICE_ATTRIBUTE_UPDATE_RECEIVE == event.event_type \ + and isinstance(event, AttributeUpdate): + if self.attribute_update_callback: + return self.attribute_update_callback(event) + elif GatewayEventType.DEVICE_REQUESTED_ATTRIBUTE_RESPONSE_RECEIVE == event.event_type \ + and isinstance(event, RequestedAttributeResponse): + if self.attribute_response_callback: + return self.attribute_response_callback(event) + elif GatewayEventType.RPC_REQUEST_RECEIVE == event.event_type \ + and isinstance(event, GatewayRPCRequest): + if self.rpc_request_callback: + return self.rpc_request_callback(event) + return None diff --git a/tb_mqtt_client/service/gateway/event_dispatcher.py b/tb_mqtt_client/service/gateway/event_dispatcher.py index 2108abb..8c40ee6 100644 --- a/tb_mqtt_client/service/gateway/event_dispatcher.py +++ b/tb_mqtt_client/service/gateway/event_dispatcher.py @@ -19,6 +19,7 @@ from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.entities.gateway.event_type import GatewayEventType from tb_mqtt_client.entities.gateway.gateway_event import GatewayEvent +from tb_mqtt_client.service.gateway.device_session import DeviceSession EventCallback = Union[Callable[..., Awaitable[None]], Callable[..., None]] @@ -43,7 +44,9 @@ def unregister(self, event_type: GatewayEventType, callback: EventCallback): if not self._handlers[event_type]: del self._handlers[event_type] - async def dispatch(self, event: GatewayEvent, *args, **kwargs): + async def dispatch(self, event: GatewayEvent, *args, device_session: DeviceSession=None, **kwargs): + if device_session is not None: + return await device_session.handle_event_to_device(event) async with self._lock: callbacks = list(self._handlers.get(event.event_type, [])) for cb in callbacks: diff --git a/tb_mqtt_client/service/gateway/gateway_client_interface.py b/tb_mqtt_client/service/gateway/gateway_client_interface.py index 26cacac..afe3466 100644 --- a/tb_mqtt_client/service/gateway/gateway_client_interface.py +++ b/tb_mqtt_client/service/gateway/gateway_client_interface.py @@ -12,25 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- from abc import ABC, abstractmethod +from asyncio import Future +from typing import Union, List, Tuple, Dict, Any +from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.entities.data.attribute_request import AttributeRequest -from tb_mqtt_client.entities.data.rpc_request import RPCRequest -from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest from tb_mqtt_client.service.base_client import BaseClient from tb_mqtt_client.service.gateway.device_session import DeviceSession @@ -38,36 +27,24 @@ class GatewayClientInterface(BaseClient, ABC): @abstractmethod - async def connect_device(self, device_name: str, device_profile: str) -> DeviceSession: ... - - @abstractmethod - async def disconnect_device(self, device_session: DeviceSession): ... - - @abstractmethod - async def send_device_telemetry(self, device_session: DeviceSession, telemetry: ...): ... - - @abstractmethod - async def send_device_attributes(self, device_session: DeviceSession, attributes: ...): ... - - @abstractmethod - async def send_device_attributes_request(self, device_session: DeviceSession, attributes: AttributeRequest): ... - - @abstractmethod - async def send_device_client_side_rpc_request(self, device_session: DeviceSession, rpc_request: RPCRequest): ... - - @abstractmethod - async def send_device_server_side_rpc_response(self, device_session: DeviceSession, rpc_response: RPCResponse): ... - - + async def connect_device(self, device_name: str, device_profile: str, wait_for_publish: bool) -> \ + Tuple[DeviceSession, List[Union[PublishResult, Future[PublishResult]]]]: ... @abstractmethod - def set_device_server_side_rpc_request_callback(self, device_session: DeviceSession, callback: ...): ... + async def disconnect_device(self, device_session: DeviceSession, wait_for_publish: bool) -> \ + List[Union[PublishResult, Future[PublishResult]]]: ... @abstractmethod - def set_device_client_side_rpc_response_callback(self, device_session: DeviceSession, callback: ...): ... + async def send_device_timeseries(self, + device_session: DeviceSession, + data: Union[TimeseriesEntry, List[TimeseriesEntry], Dict[str, Any], List[Dict[str, Any]]], + wait_for_publish: bool) -> List[Union[PublishResult, Future[PublishResult]]]: ... @abstractmethod - def set_device_requested_attributes_callback(self, device_session: DeviceSession, callback: ...): ... + async def send_device_attributes(self, device_session: DeviceSession, attributes: ..., wait_for_publish: bool) -> List[Union[PublishResult, Future[PublishResult]]]: ... @abstractmethod - def set_device_attributes_update_callback(self, device_session: DeviceSession, callback: ...): ... + async def send_device_attributes_request(self, + device_session: DeviceSession, + attributes: Union[AttributeRequest, GatewayAttributeRequest], + wait_for_publish: bool) -> List[Union[PublishResult, Future[PublishResult]]]: ... 
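Editor's note — a short, hedged sketch of how a caller might consume the revised interface above. When wait_for_publish is False the client can hand back unresolved asyncio futures (a single one or a list), so the consumer awaits them itself; the names client, session and entries below are placeholders, not identifiers from the patch.

import asyncio

from tb_mqtt_client.common.publish_result import PublishResult


async def deliver_and_confirm(client, session, entries):
    # With wait_for_publish=False the gateway client returns either one future
    # or a list of futures that resolve to PublishResult objects.
    outcome = await client.send_device_timeseries(
        device_session=session, data=entries, wait_for_publish=False)

    items = outcome if isinstance(outcome, list) else [outcome]
    for item in items:
        # Each item is either an already-resolved PublishResult or a future still in flight.
        result = await item if isinstance(item, asyncio.Future) else item
        if isinstance(result, PublishResult):
            print("delivered:", result)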
diff --git a/tb_mqtt_client/service/gateway/handlers/gateway_attribute_updates_handler.py b/tb_mqtt_client/service/gateway/handlers/gateway_attribute_updates_handler.py index bdd8c95..d7573dc 100644 --- a/tb_mqtt_client/service/gateway/handlers/gateway_attribute_updates_handler.py +++ b/tb_mqtt_client/service/gateway/handlers/gateway_attribute_updates_handler.py @@ -19,7 +19,10 @@ class GatewayAttributeUpdatesHandler: """Handles shared attribute updates for devices connected to a gateway.""" - def __init__(self, event_dispatcher: EventDispatcher, message_adapter: GatewayMessageAdapter, device_manager: DeviceManager): + def __init__(self, + event_dispatcher: EventDispatcher, + message_adapter: GatewayMessageAdapter, + device_manager: DeviceManager): self.event_dispatcher = event_dispatcher self.message_adapter = message_adapter self.device_manager = device_manager @@ -28,7 +31,8 @@ def handle(self, topic: str, payload: bytes): """ Handles the gateway attribute update event by dispatching the attribute update """ - gateway_attribute_update = self.message_adapter.parse_attribute_update(payload) + data = self.message_adapter.deserialize_to_dict(payload) + gateway_attribute_update = self.message_adapter.parse_attribute_update(data) device_session = self.device_manager.get_by_name(gateway_attribute_update.device_name) if device_session: gateway_attribute_update.set_device_session(device_session) diff --git a/tb_mqtt_client/service/gateway/handlers/gateway_requested_attributes_response_handler.py b/tb_mqtt_client/service/gateway/handlers/gateway_requested_attributes_response_handler.py new file mode 100644 index 0000000..869b4e5 --- /dev/null +++ b/tb_mqtt_client/service/gateway/handlers/gateway_requested_attributes_response_handler.py @@ -0,0 +1,130 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from asyncio import Task +from typing import Dict, Tuple, Coroutine, Callable, Union, TypeAlias, Any + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest +from tb_mqtt_client.entities.gateway.gateway_requested_attribute_response import GatewayRequestedAttributeResponse +from tb_mqtt_client.service.gateway.device_manager import DeviceManager +from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter + +logger = get_logger(__name__) + + +AttributeResponseCallback: TypeAlias = Callable[[GatewayRequestedAttributeResponse], Coroutine[Any, Any, None]] + + +class GatewayRequestedAttributeResponseHandler: + """ + Handles responses to attribute requests sent to the platform. 
+ """ + + def __init__(self, event_dispatcher: Any, message_adapter: GatewayMessageAdapter, device_manager: DeviceManager): + self._event_dispatcher = event_dispatcher + self._message_adapter: Union[GatewayMessageAdapter, None] = message_adapter + self._device_manager = device_manager + self._pending_attribute_requests: Dict[Tuple[str, int], Tuple[GatewayAttributeRequest, + Union[Task, None]]] = {} + + async def register_request(self, + request: GatewayAttributeRequest, + timeout: int = 30): + """ + Called when a request is sent to the platform and a response is awaited. + """ + request_id = request.request_id + device_name = request.device_name + key = (device_name, request_id) + if key in self._pending_attribute_requests: + raise RuntimeError(f"Request ID {request.request_id} is already registered.") + timeout_task = None + if timeout > 0: + timeout_task = asyncio.get_event_loop().call_later(timeout, self._on_timeout, device_name, request_id) + self._pending_attribute_requests[key] = (request, timeout_task) + logger.debug("Registered attribute request with ID %s for device %s", request_id, device_name) + + def unregister_request(self, device_name: str, request_id: int): + """ + Unregisters a request for device attributes by device name and request ID. + This is useful if the request is no longer needed or has timed out. + """ + key = (device_name, request_id) + if key in self._pending_attribute_requests: + self._pending_attribute_requests.pop(key) + logger.debug("Unregistered attribute request with ID %s for device %s", request_id, device_name) + else: + logger.debug("Attempted to unregister non-existent request ID %s for device %s", request_id, device_name) + + async def handle(self, topic: str, payload: bytes): + """ + Handles the incoming attribute request response. + """ + try: + if not self._message_adapter: + logger.error("Message adapter is not initialized. Cannot handle requested attribute response.") + return + deserialized_data = self._message_adapter.deserialize_to_dict(payload) + request_id = deserialized_data.get('id') + device_name = deserialized_data.get('device') + if request_id is None or device_name is None: + logger.error("Received requested attribute response without 'id' or 'device'. ") + return + attribute_request_with_callback = self._pending_attribute_requests.get((device_name, request_id)) + if not attribute_request_with_callback: + logger.warning("No pending request found for request ID %s. Ignoring response.", request_id) + return + attribute_request, timeout_task = attribute_request_with_callback + if timeout_task: + timeout_task.cancel() + requested_attribute_response = self._message_adapter.parse_gateway_requested_attribute_response(attribute_request, deserialized_data) + device_session = self._device_manager.get_by_name(device_name) + if not device_session: + logger.warning("No device session found for device: %s", device_name) + return + dispatch_task = self._event_dispatcher.dispatch(requested_attribute_response, device_session=device_session) + + logger.trace("Dispatching callback for requested attribute response with ID %s", + requested_attribute_response.request_id) + task = asyncio.create_task(dispatch_task) + task.add_done_callback(self._handle_callback_exception) + + except Exception as e: + logger.exception("Failed to handle requested attribute response: %s", e) + + def _on_timeout(self, device_name: str, request_id: int): + """ + Called when a request times out. + Unregisters the request and logs a warning. 
+ """ + key = (device_name, request_id) + if key in self._pending_attribute_requests: + self._pending_attribute_requests.pop(key) + logger.warning("Request ID %s for device %s has timed out and has been unregistered.", request_id, device_name) + else: + logger.debug("Attempted to unregister non-existent request ID %s for device %s on timeout", request_id, device_name) + + def _handle_callback_exception(self, task: asyncio.Task): + try: + task.result() + except Exception as e: + logger.exception("Exception in user-defined requested attribute callback: %s", e) + + def clear(self): + """ + Clears all pending requests (e.g., on disconnect). + """ + self._pending_attribute_requests.clear() diff --git a/tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py b/tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py index fa669aa..f7446a1 100644 --- a/tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py +++ b/tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py @@ -12,3 +12,63 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Awaitable, Callable, Optional + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse +from tb_mqtt_client.service.gateway.device_manager import DeviceManager +from tb_mqtt_client.service.gateway.event_dispatcher import EventDispatcher +from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter + +logger = get_logger(__name__) + + +class GatewayRPCHandler: + """ + Handles incoming RPC request messages for a device connected through the gateway. + """ + + def __init__(self, event_dispatcher: EventDispatcher, message_adapter: GatewayMessageAdapter, device_manager: DeviceManager): + self._event_dispatcher = event_dispatcher + self._message_adapter = message_adapter + self._device_manager = device_manager + self._callback: Optional[Callable[[GatewayRPCRequest], Awaitable[GatewayRPCResponse]]] = None + self._event_dispatcher.register(GatewayEventType.RPC_REQUEST_RECEIVE, self.handle) + + async def handle(self, topic: str, payload: bytes) -> Optional[GatewayRPCResponse]: + """ + Process the RPC request and return the response for it. + :returns: GatewayRPCResponse or None if failed + """ + + if not self._message_adapter: + logger.error("Message adapter is not initialized. 
Cannot handle RPC request.") + return None + + rpc_request: Optional[GatewayRPCRequest] = None + try: + data = self._message_adapter.deserialize_to_dict(payload) + rpc_request = self._message_adapter.parse_rpc_request(topic, data) + device_session = self._device_manager.get_by_name(rpc_request.device_name) + if device_session: + rpc_request.set_device_session(device_session) + else: + logger.warning("No device session found for device: %s", rpc_request.device_name) + return None + logger.debug("Handling RPC method id: %i - %s with params: %s", + rpc_request.request_id, rpc_request.method, rpc_request.params) + result = await self._event_dispatcher.dispatch(rpc_request) + if not isinstance(result, GatewayRPCResponse): + logger.error("RPC callback must return an instance of GatewayRPCResponse, got: %s", type(result)) + return None + logger.debug("RPC response for device %r method id: %i - %s with result: %s", + rpc_request.device_name, result.request_id, rpc_request.method, result.result) + return result + + except Exception as e: + logger.exception("Failed to process RPC request: %s", e) + if rpc_request is None: + return None + GatewayRPCResponse.build(device_name=rpc_request.device_name, request_id=rpc_request.request_id, error=e) diff --git a/tb_mqtt_client/service/gateway/message_adapter.py b/tb_mqtt_client/service/gateway/message_adapter.py index 1d1241b..e8c85ed 100644 --- a/tb_mqtt_client/service/gateway/message_adapter.py +++ b/tb_mqtt_client/service/gateway/message_adapter.py @@ -27,12 +27,13 @@ GATEWAY_CONNECT_TOPIC, GATEWAY_DISCONNECT_TOPIC from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate -from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage +from tb_mqtt_client.entities.data.device_uplink_message import GatewayUplinkMessage from tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage from tb_mqtt_client.entities.gateway.device_disconnect_message import DeviceDisconnectMessage from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest from tb_mqtt_client.entities.gateway.gateway_attribute_update import GatewayAttributeUpdate from tb_mqtt_client.entities.gateway.gateway_requested_attribute_response import GatewayRequestedAttributeResponse +from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest logger = get_logger(__name__) @@ -45,7 +46,7 @@ class GatewayMessageAdapter(ABC): @abstractmethod def build_uplink_payloads( self, - messages: List[DeviceUplinkMessage] + messages: List[GatewayUplinkMessage] ) -> List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]]: """ Build a list of topic-payload pairs from the given messages. @@ -79,7 +80,7 @@ def build_gateway_attribute_request_payload(self, attribute_request: GatewayAttr pass @abstractmethod - def parse_attribute_update(self, payload: bytes) -> GatewayAttributeUpdate: + def parse_attribute_update(self, data: Dict[str, Any]) -> GatewayAttributeUpdate: """ Parse the attribute update payload into an GatewayAttributeUpdate. This method should be implemented to handle the specific format of the payload. 
@@ -87,13 +88,29 @@ def parse_attribute_update(self, payload: bytes) -> GatewayAttributeUpdate: pass @abstractmethod - def parse_gateway_attribute_response(self, gateway_attribute_request: GatewayAttributeRequest, payload: bytes) -> Union[GatewayRequestedAttributeResponse, None]: + def parse_gateway_requested_attribute_response(self, gateway_attribute_request: GatewayAttributeRequest, data: Dict[str, Any]) -> Union[GatewayRequestedAttributeResponse, None]: """ - Parse the gateway attribute response payload into an GatewayAttributeResponse. + Parse the gateway attribute response data into an GatewayAttributeResponse. This method should be implemented to handle the specific format of the payload. """ pass + @abstractmethod + def parse_rpc_request(self, topic: str, data: Dict[str, Any]) -> GatewayRPCRequest: + """ + Parse the RPC request from the given topic and payload. + This method should be implemented to handle the specific format of the RPC request. + """ + pass + + @abstractmethod + def deserialize_to_dict(self, payload: bytes) -> Dict[str, Any]: + """ + Deserialize incoming payload into a dictionary format, required for parsing responses. + This method should be implemented to handle the specific format of the response. + """ + pass + class JsonGatewayMessageAdapter(GatewayMessageAdapter): """ @@ -101,7 +118,7 @@ class JsonGatewayMessageAdapter(GatewayMessageAdapter): Builds uplink payloads from uplink message objects and parses JSON payloads into GatewayEvent objects. """ - def build_uplink_payloads(self, messages: List[DeviceUplinkMessage]) -> List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]]: + def build_uplink_payloads(self, messages: List[GatewayUplinkMessage]) -> List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]]: """ Build a list of topic-payload pairs from the given messages. 
Each pair consists of a topic string, payload bytes, the number of datapoints, @@ -113,7 +130,7 @@ def build_uplink_payloads(self, messages: List[DeviceUplinkMessage]) -> List[Tup return [] result: List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]] = [] - device_groups: Dict[str, List[DeviceUplinkMessage]] = defaultdict(list) + device_groups: Dict[str, List[GatewayUplinkMessage]] = defaultdict(list) for msg in messages: device_name = msg.device_name @@ -129,21 +146,21 @@ def build_uplink_payloads(self, messages: List[DeviceUplinkMessage]) -> List[Tup gateway_timeseries_delivery_futures: Dict[str, List[Optional[asyncio.Future[PublishResult]]]] = {} gateway_attributes_delivery_futures: Dict[str, List[Optional[asyncio.Future[PublishResult]]]] = {} for device, device_msgs in device_groups.items(): - if device not in gateway_timeseries_message: + timeseries_msgs: List[GatewayUplinkMessage] = [m for m in device_msgs if m.has_timeseries()] + attr_msgs: List[GatewayUplinkMessage] = [m for m in device_msgs if m.has_attributes()] + if device not in gateway_timeseries_message and timeseries_msgs: gateway_timeseries_message[device] = [] gateway_timeseries_delivery_futures[device] = [] - if device not in gateway_attributes_message: + if device not in gateway_attributes_message and attr_msgs: gateway_attributes_message[device] = [] gateway_attributes_delivery_futures[device] = [] - telemetry_msgs: List[DeviceUplinkMessage] = [m for m in device_msgs if m.has_timeseries()] - attr_msgs: List[DeviceUplinkMessage] = [m for m in device_msgs if m.has_attributes()] logger.trace("Device '%s' - telemetry: %d, attributes: %d", - device, len(telemetry_msgs), len(attr_msgs)) + device, len(timeseries_msgs), len(attr_msgs)) # TODO: Recommended to add message splitter to handle large messages and split them into smaller batches - for ts_batch in telemetry_msgs: + for ts_batch in timeseries_msgs: packed_ts = JsonGatewayMessageAdapter.pack_timeseries(ts_batch) - gateway_timeseries_message[device].append(packed_ts) + gateway_timeseries_message[device].extend(packed_ts) count = ts_batch.timeseries_datapoint_count() gateway_timeseries_device_datapoints_counts[device] = gateway_timeseries_device_datapoints_counts.get(device, 0) + count gateway_timeseries_delivery_futures[device] = ts_batch.get_delivery_futures() @@ -152,19 +169,29 @@ def build_uplink_payloads(self, messages: List[DeviceUplinkMessage]) -> List[Tup for attr_batch in attr_msgs: packed_attrs = JsonGatewayMessageAdapter.pack_attributes(attr_batch) count = attr_batch.attributes_datapoint_count() - gateway_attributes_message[device].append(packed_attrs) + gateway_attributes_message[device].extend(packed_attrs) gateway_attributes_device_datapoints_counts[device] = gateway_attributes_device_datapoints_counts.get(device, 0) + count logger.trace("Built attribute payload for device='%s' with %d attributes", device, count) - if telemetry_msgs: - result.append((GATEWAY_TELEMETRY_TOPIC, - dumps(gateway_timeseries_message[device]), - gateway_timeseries_device_datapoints_counts[device], - gateway_timeseries_delivery_futures[device])) - if attr_msgs: - result.append((GATEWAY_ATTRIBUTES_TOPIC, - dumps(gateway_attributes_message[device]), - gateway_attributes_device_datapoints_counts[device], - gateway_attributes_delivery_futures[device])) + + if gateway_timeseries_message: + all_timeseries_delivery_futures = set() + for futures in gateway_timeseries_delivery_futures.values(): + if futures: + all_timeseries_delivery_futures.update(futures) + + 
result.append((GATEWAY_TELEMETRY_TOPIC, + dumps(gateway_timeseries_message), + sum(gateway_timeseries_device_datapoints_counts[per_device] for per_device in gateway_timeseries_device_datapoints_counts), + list(all_timeseries_delivery_futures))) + if gateway_attributes_message: + all_attributes_delivery_futures = set() + for futures in gateway_attributes_delivery_futures.values(): + if futures: + all_attributes_delivery_futures.update(futures) + result.append((GATEWAY_ATTRIBUTES_TOPIC, + dumps(gateway_attributes_message), + sum(gateway_attributes_device_datapoints_counts[per_device] for per_device in gateway_attributes_device_datapoints_counts), + list(all_attributes_delivery_futures))) logger.trace("Generated %d topic-payload entries.", len(result)) @@ -190,7 +217,7 @@ def build_device_connect_message_payload(self, device_connect_message: DeviceCon def build_device_disconnect_message_payload(self, device_disconnect_message: DeviceDisconnectMessage) -> Tuple[str, bytes]: """ Build the payload for a device disconnect message. - This method serializes the device name to JSON format. + This method serializes the DeviceDisconnectMessage to JSON format. """ try: payload = dumps(device_disconnect_message.to_payload_format()) @@ -213,50 +240,79 @@ def build_gateway_attribute_request_payload(self, attribute_request: GatewayAttr logger.error("Failed to build gateway attribute request payload: %s", str(e)) raise ValueError("Invalid gateway attribute request format") from e - def parse_attribute_update(self, payload: bytes) -> GatewayAttributeUpdate: + def parse_attribute_update(self, data: Dict[str, Any]) -> GatewayAttributeUpdate: try: - data = loads(payload.decode('utf-8')) - device_name = data['device_name'] + device_name = data['device'] attribute_update = AttributeUpdate._deserialize_from_dict(data['data']) # noqa return GatewayAttributeUpdate(device_name=device_name, attribute_update=attribute_update) except Exception as e: logger.error("Failed to parse attribute update: %s", str(e)) raise ValueError("Invalid attribute update format") from e - def parse_gateway_attribute_response(self, gateway_attribute_request: GatewayAttributeRequest, payload: bytes) -> Union[GatewayRequestedAttributeResponse, None]: + def parse_gateway_requested_attribute_response(self, gateway_attribute_request: GatewayAttributeRequest, data: Dict[str, Any]) -> Union[GatewayRequestedAttributeResponse, None]: + """ + Parse the gateway attribute response data into a GatewayRequestedAttributeResponse. + This method extracts the device name, shared and client attributes from the payload. + """ try: - data = loads(payload.decode('utf-8')) - device_name = data['device_name'] + device_name = data['device'] client = [] shared = [] - if 'value' in data and not ((len(gateway_attribute_request.client_keys) == 1 and len(gateway_attribute_request.shared_keys) == 0) - or (len(gateway_attribute_request.client_keys) == 0 and len(gateway_attribute_request.shared_keys) == 1)): + if 'value' in data and not ((len(gateway_attribute_request.client_keys) == 1 and not gateway_attribute_request.shared_keys) + or (len(gateway_attribute_request.shared_keys) == 1 and not gateway_attribute_request.client_keys)): # TODO: Skipping case when requested several attributes, but only one is returned, issue on the platform logger.warning("Received gateway attribute response with single key, but multiply keys expected. 
" - "Request keys: %s, Response keys: %s", list(*gateway_attribute_request.client_keys, *gateway_attribute_request.shared_keys), data['value']) + "Request keys: %s, Response keys: %s", + list(*gateway_attribute_request.client_keys, *gateway_attribute_request.shared_keys), + data['value']) return None elif 'value' in data: if len(gateway_attribute_request.client_keys) == 1: - client= [AttributeEntry(gateway_attribute_request.client_keys[0], data['value'])] + client = [AttributeEntry(gateway_attribute_request.client_keys[0], data['value'])] elif len(gateway_attribute_request.shared_keys) == 1: shared = [AttributeEntry(gateway_attribute_request.shared_keys[0], data['value'])] - elif 'data' in data: + elif 'values' in data: if len(gateway_attribute_request.client_keys) > 0: - client = [AttributeEntry(k, v) for k, v in data['data'].get('client', {}).items() if k in gateway_attribute_request.client_keys] + client = [AttributeEntry(k, v) for k, v in data['data'].get('values', {}).items() if + k in gateway_attribute_request.client_keys] if len(gateway_attribute_request.shared_keys) > 0: - shared = [AttributeEntry(k, v) for k, v in data['data'].get('shared', {}).items() if k in gateway_attribute_request.shared_keys] + shared = [AttributeEntry(k, v) for k, v in data['data'].get('values', {}).items() if + k in gateway_attribute_request.shared_keys] return GatewayRequestedAttributeResponse(device_name=device_name, request_id=gateway_attribute_request.request_id, shared=shared, client=client) except Exception as e: - logger.error("Failed to parse gateway attribute response: %s", str(e)) - raise ValueError("Invalid gateway attribute response format") from e + logger.error("Failed to parse gateway requested attribute response: %s", str(e)) + raise ValueError("Invalid gateway requested attribute response format") from e + + def parse_rpc_request(self, topic: str, data: Dict[str, Any]) -> GatewayRPCRequest: + """ + Parse the RPC request from the given topic and payload. + This method deserializes the payload into a GatewayRPCRequest object. + """ + try: + return GatewayRPCRequest._deserialize_from_dict(data) # noqa + except Exception as e: + logger.error("Failed to parse RPC request: %s", str(e)) + raise ValueError("Invalid RPC request format") from e + + def deserialize_to_dict(self, payload: bytes) -> Dict[str, Any]: + """ + Deserialize incoming payload into a dictionary format, required for parsing responses. + This method decodes the payload from bytes to a string and then loads it as a JSON object. 
+ """ + try: + data = loads(payload.decode('utf-8')) + return data + except Exception as e: + logger.error("Failed to deserialize requested attribute response: %s", str(e)) + raise ValueError("Invalid requested attribute response format") from e @staticmethod - def pack_attributes(msg: DeviceUplinkMessage) -> Dict[str, Any]: + def pack_attributes(msg: GatewayUplinkMessage) -> Dict[str, Any]: logger.trace("Packing %d attribute(s)", len(msg.attributes)) return {attr.key: attr.value for attr in msg.attributes} @staticmethod - def pack_timeseries(msg: DeviceUplinkMessage) -> List[Dict[str, Any]]: + def pack_timeseries(msg: GatewayUplinkMessage) -> List[Dict[str, Any]]: now_ts = int(datetime.now(UTC).timestamp() * 1000) packed = [ {"ts": entry.ts or now_ts, "values": {entry.key: entry.value}} diff --git a/tb_mqtt_client/service/message_queue.py b/tb_mqtt_client/service/message_queue.py index 5d0539e..c9d606e 100644 --- a/tb_mqtt_client/service/message_queue.py +++ b/tb_mqtt_client/service/message_queue.py @@ -19,9 +19,11 @@ from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit from tb_mqtt_client.constants import mqtt_topics -from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage +from tb_mqtt_client.entities.data.device_uplink_message import GatewayUplinkMessage from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage from tb_mqtt_client.service.device.message_adapter import MessageAdapter +from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter from tb_mqtt_client.service.mqtt_manager import MQTTManager logger = get_logger(__name__) @@ -39,7 +41,8 @@ def __init__(self, message_adapter: MessageAdapter, max_queue_size: int = 1000000, batch_collect_max_time_ms: int = 100, - batch_collect_max_count: int = 500): + batch_collect_max_count: int = 500, + gateway_message_adapter: Optional[GatewayMessageAdapter] = None): self._main_stop_event = main_stop_event self._batch_max_time = batch_collect_max_time_ms / 1000 self._batch_max_count = batch_collect_max_count @@ -51,20 +54,21 @@ def __init__(self, self._pending_ack_futures: Dict[int, asyncio.Future[PublishResult]] = {} self._pending_ack_callbacks: Dict[int, Callable[[bool], None]] = {} # Queue expects tuples of (topic, payload, delivery_futures, datapoints_count, qos) - self._queue: asyncio.Queue[Tuple[str, Union[bytes, DeviceUplinkMessage], List[asyncio.Future[PublishResult]], int, int]] = asyncio.Queue(maxsize=max_queue_size) + self._queue: asyncio.Queue[Tuple[str, Union[bytes, GatewayUplinkMessage, GatewayUplinkMessage], List[asyncio.Future[PublishResult]], int, int]] = asyncio.Queue(maxsize=max_queue_size) self._pending_queue_tasks: set[asyncio.Task] = set() self._active = asyncio.Event() self._wakeup_event = asyncio.Event() self._retry_tasks: set[asyncio.Task] = set() self._active.set() self._adapter = message_adapter + self._gateway_adapter = gateway_message_adapter self._loop_task = asyncio.create_task(self._dequeue_loop()) self._rate_limit_refill_task = asyncio.create_task(self._rate_limit_refill_loop()) logger.debug("MessageQueue initialized: max_queue_size=%s, batch_time=%.3f, batch_count=%d", max_queue_size, self._batch_max_time, batch_collect_max_count) - async def publish(self, topic: str, payload: Union[bytes, DeviceUplinkMessage], datapoints_count: int, qos: int) -> Optional[List[asyncio.Future[PublishResult]]]: - 
delivery_futures = payload.get_delivery_futures() if isinstance(payload, DeviceUplinkMessage) else [asyncio.Future()] + async def publish(self, topic: str, payload: Union[bytes, GatewayUplinkMessage, GatewayUplinkMessage], datapoints_count: int, qos: int) -> Optional[List[asyncio.Future[PublishResult]]]: + delivery_futures = payload.get_delivery_futures() if isinstance(payload, GatewayUplinkMessage) or isinstance(payload, GatewayUplinkMessage) else [asyncio.Future()] try: logger.trace("publish() received delivery future id: %r for topic=%s", id(delivery_futures[0]) if delivery_futures else -1, topic) @@ -108,9 +112,10 @@ async def _dequeue_loop(self): logger.trace("Dequeued message for batching: topic=%s, device=%s", topic, getattr(payload, 'device_name', 'N/A')) - batch: List[Tuple[str, Union[bytes, DeviceUplinkMessage], List[asyncio.Future[PublishResult]], int, int]] = [(topic, payload, delivery_futures_or_none, datapoints, qos)] + batch: List[Tuple[str, Union[bytes, GatewayUplinkMessage, GatewayUplinkMessage], List[asyncio.Future[PublishResult]], int, int]] = [(topic, payload, delivery_futures_or_none, datapoints, qos)] start = asyncio.get_event_loop().time() batch_size = payload.size + batch_type = type(payload).__name__ while not self._queue.empty(): elapsed = asyncio.get_event_loop().time() - start @@ -123,7 +128,13 @@ async def _dequeue_loop(self): try: next_topic, next_payload, delivery_futures_or_none, datapoints, qos = self._queue.get_nowait() - if isinstance(next_payload, DeviceUplinkMessage): + if isinstance(next_payload, GatewayUplinkMessage) or isinstance(next_payload, GatewayUplinkMessage): + if batch_type is not None and batch_type != type(next_payload).__class__.__name__: + logger.trace("Batch type mismatch: current=%s, next=%s, finalizing current", + batch_type, type(next_payload).__class__.__name__) + self._queue.put_nowait((next_topic, next_payload, delivery_futures_or_none, datapoints, qos)) + break + batch_type = type(next_payload).__name__ msg_size = next_payload.size if batch_size + msg_size > self._adapter.splitter.max_payload_size: # noqa logger.trace("Batch size threshold exceeded: current=%d, next=%d", batch_size, msg_size) @@ -137,11 +148,18 @@ async def _dequeue_loop(self): except asyncio.QueueEmpty: break + if batch_type is None: + batch_type = type(payload).__name__ + if batch: logger.trace("Batching completed: %d messages, total size=%d", len(batch), batch_size) messages = [device_uplink_message for _, device_uplink_message, _, _, _ in batch] - topic_payloads = self._adapter.build_uplink_payloads(messages) + if batch_type == 'GatewayUplinkMessage' and self._gateway_adapter: + logger.trace("Building gateway uplink payloads for %d messages", len(messages)) + topic_payloads = self._gateway_adapter.build_uplink_payloads(messages) + else: + topic_payloads = self._adapter.build_uplink_payloads(messages) for topic, payload, datapoints, delivery_futures in topic_payloads: logger.trace("Dispatching batched message: topic=%s, size=%d, datapoints=%d, delivery_futures=%r", @@ -163,6 +181,7 @@ async def _try_publish(self, delivery_futures_or_none = [] is_message_with_telemetry_or_attributes = topic in (mqtt_topics.DEVICE_TELEMETRY_TOPIC, mqtt_topics.DEVICE_ATTRIBUTES_TOPIC) + # TODO: Add topics check for gateways logger.trace("Attempting publish: topic=%s, datapoints=%d", topic, datapoints) @@ -285,7 +304,7 @@ async def retry(): self._retry_tasks.add(task) task.add_done_callback(self._retry_tasks.discard) - async def _wait_for_message(self) -> Tuple[str, Union[bytes, 
DeviceUplinkMessage], List[asyncio.Future[PublishResult]], int, int]: + async def _wait_for_message(self) -> Tuple[str, Union[bytes, GatewayUplinkMessage, GatewayUplinkMessage], List[asyncio.Future[PublishResult]], int, int]: while self._active.is_set(): try: if not self._queue.empty(): @@ -367,7 +386,7 @@ def clear(self): topic=topic, qos=qos, message_id=-1, - payload_size=message.size if isinstance(message, DeviceUplinkMessage) else len(message), + payload_size=message.size if isinstance(message, GatewayUplinkMessage) or isinstance(message, GatewayUplinkMessage) else len(message), reason_code=-1 )) self._queue.task_done() @@ -391,3 +410,6 @@ async def _refill_rate_limits(self): for rl in (self._message_rate_limit, self._telemetry_rate_limit, self._telemetry_dp_rate_limit): if rl: await rl.refill() + + def set_gateway_message_adapter(self, message_adapter: GatewayMessageAdapter): + self._gateway_adapter = message_adapter diff --git a/tb_mqtt_client/service/message_splitter.py b/tb_mqtt_client/service/message_splitter.py index b17062b..4189abf 100644 --- a/tb_mqtt_client/service/message_splitter.py +++ b/tb_mqtt_client/service/message_splitter.py @@ -18,7 +18,7 @@ from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry -from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage, DeviceUplinkMessageBuilder +from tb_mqtt_client.entities.data.device_uplink_message import GatewayUplinkMessage, DeviceUplinkMessageBuilder from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry logger = get_logger(__name__) @@ -38,7 +38,7 @@ def __init__(self, max_payload_size: int = 65535, max_datapoints: int = 0): logger.trace("MessageSplitter initialized with max_payload_size=%d, max_datapoints=%d", self._max_payload_size, self._max_datapoints) - def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUplinkMessage]: + def split_timeseries(self, messages: List[GatewayUplinkMessage]) -> List[GatewayUplinkMessage]: logger.trace("Splitting timeseries for %d messages", len(messages)) if (len(messages) == 1 @@ -47,9 +47,9 @@ def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp and messages[0].size <= self._max_payload_size): return messages - result: List[DeviceUplinkMessage] = [] + result: List[GatewayUplinkMessage] = [] - grouped: Dict[Tuple[str, Optional[str]], List[DeviceUplinkMessage]] = defaultdict(list) + grouped: Dict[Tuple[str, Optional[str]], List[GatewayUplinkMessage]] = defaultdict(list) for msg in messages: key = (msg.device_name, msg.device_profile) grouped[key].append(msg) @@ -113,9 +113,9 @@ async def resolve_original(): logger.trace("Total timeseries batches created: %d", len(result)) return result - def split_attributes(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUplinkMessage]: + def split_attributes(self, messages: List[GatewayUplinkMessage]) -> List[GatewayUplinkMessage]: logger.trace("Splitting attributes for %d messages", len(messages)) - result: List[DeviceUplinkMessage] = [] + result: List[GatewayUplinkMessage] = [] if (len(messages) == 1 and ((messages[0].attributes_datapoint_count() + messages[ @@ -123,7 +123,7 @@ def split_attributes(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp and messages[0].size <= self._max_payload_size): return messages - grouped: Dict[Tuple[str, Optional[str]], List[DeviceUplinkMessage]] = defaultdict(list) + grouped: Dict[Tuple[str, Optional[str]], 
List[GatewayUplinkMessage]] = defaultdict(list) for msg in messages: grouped[(msg.device_name, msg.device_profile)].append(msg) diff --git a/tb_mqtt_client/service/mqtt_manager.py b/tb_mqtt_client/service/mqtt_manager.py index 2a1dd06..2c6fff1 100644 --- a/tb_mqtt_client/service/mqtt_manager.py +++ b/tb_mqtt_client/service/mqtt_manager.py @@ -14,6 +14,8 @@ import asyncio import ssl +import threading +import time from asyncio import sleep from contextlib import suppress from time import monotonic @@ -287,11 +289,16 @@ def _on_disconnect_internal(self, client, reason_code=None, properties=None, exc if reason_code == 142: logger.error("Session was taken over, looks like another client connected with the same credentials.") self._backpressure.notify_disconnect(delay_seconds=10) - if reason_code in (131, 142, 143, 151): + if reason_code in (131, 142, 143, 151): # 131, 142, 151 may be caused by rate limits or issue with the data reached_time = 1 for rate_limit in self.__rate_limiter.values(): if isinstance(rate_limit, RateLimit): - reached_limit = asyncio.get_event_loop().run_until_complete(rate_limit.reach_limit()) + try: + reached_limit = self._run_coroutine_sync(rate_limit.reach_limit, raise_on_timeout=True) + except TimeoutError: + logger.warning("Timeout while checking rate limit reaching.") + reached_time = 10 # Default to 10 seconds if timeout occurs + break reached_index, reached_time, reached_duration = reached_limit if reached_limit else (None, None, 1) self._backpressure.notify_disconnect(delay_seconds=reached_time) elif reason_code != 0: @@ -301,6 +308,41 @@ def _on_disconnect_internal(self, client, reason_code=None, properties=None, exc if self._on_disconnect_callback: asyncio.create_task(self._on_disconnect_callback()) + def _run_coroutine_sync(self, coro_func, timeout: float = 3.0, raise_on_timeout: bool = False): + """ + Run async coroutine and return its result from a sync function even if event loop is running. 
+ :param coro_func: async function with no arguments (like: lambda: some_async_fn()) + :param timeout: max wait time in seconds + :param raise_on_timeout: if True, raise TimeoutError on timeout; otherwise return None + """ + result_container = {} + event = threading.Event() + + async def wrapper(): + try: + result = await coro_func() + result_container['result'] = result + except Exception as e: + result_container['error'] = e + finally: + event.set() + + loop = asyncio.get_running_loop() + loop.create_task(wrapper()) + + completed = event.wait(timeout=timeout) + + if not completed: + logger.warning("Timeout while waiting for coroutine to finish: %s", coro_func) + if raise_on_timeout: + raise TimeoutError(f"Coroutine {coro_func} did not complete in {timeout} seconds.") + return None + + if 'error' in result_container: + raise result_container['error'] + + return result_container.get('result') + def _on_message_internal(self, client, topic: str, payload: bytes, qos, properties): logger.trace("Received message by client %r on topic %s with payload %r, qos %r, properties %r", client, topic, payload, qos, properties) diff --git a/tests/entities/data/test_device_uplink_message.py b/tests/entities/data/test_device_uplink_message.py index bd90cc7..634c36a 100644 --- a/tests/entities/data/test_device_uplink_message.py +++ b/tests/entities/data/test_device_uplink_message.py @@ -19,7 +19,7 @@ import pytest from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry -from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder, DeviceUplinkMessage, \ +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder, GatewayUplinkMessage, \ DEFAULT_FIELDS_SIZE from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry @@ -36,8 +36,8 @@ def timeseries_entry(): def test_direct_instantiation_forbidden(): with pytest.raises(TypeError, match="Direct instantiation of DeviceUplinkMessage is not allowed"): - DeviceUplinkMessage(device_name="test", device_profile="default", attributes=(), timeseries={}, - delivery_futures=[], _size=0) + GatewayUplinkMessage(device_name="test", device_profile="default", attributes=(), timeseries={}, + delivery_futures=[], _size=0) def test_build_empty_message(): From bc1b34ae7c6f6645a3fcec1baeb33bff4279c55c Mon Sep 17 00:00:00 2001 From: imbeacon Date: Wed, 16 Jul 2025 06:45:21 +0300 Subject: [PATCH 45/74] Returned device name to device uplink message to improve data grouping --- examples/device/send_timeseries.py | 9 ++- examples/gateway/send_timeseries.py | 28 +++++++- .../entities/data/device_uplink_message.py | 42 +++++++++--- .../gateway/gateway_uplink_message.py | 24 ++----- tb_mqtt_client/service/base_client.py | 6 +- tb_mqtt_client/service/device/client.py | 4 -- .../service/device/message_adapter.py | 66 ++++++++----------- tb_mqtt_client/service/gateway/client.py | 2 +- .../service/gateway/message_adapter.py | 2 +- tb_mqtt_client/service/message_queue.py | 18 ++--- tb_mqtt_client/service/message_splitter.py | 14 ++-- tb_mqtt_client/service/mqtt_manager.py | 1 - .../data/test_device_uplink_message.py | 6 +- tests/service/test_json_message_adapter.py | 7 -- tests/service/test_mqtt_manager.py | 12 +--- 15 files changed, 126 insertions(+), 115 deletions(-) diff --git a/examples/device/send_timeseries.py b/examples/device/send_timeseries.py index f7e4b70..4528fb1 100644 --- a/examples/device/send_timeseries.py +++ b/examples/device/send_timeseries.py @@ -17,6 +17,8 @@ import asyncio import 
logging from random import uniform, randint +from time import time + from tb_mqtt_client.common.config_loader import DeviceConfig from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.service.device.client import DeviceClient @@ -53,9 +55,12 @@ async def main(): logger.info("Single timeseries entry sent successfully.") # Send a list of time series entries + ts = int(time() * 1000) entries = [ - TimeseriesEntry("vibration", 0.05), - TimeseriesEntry("speed", 123) + TimeseriesEntry("vibration", 0.05, ts), + TimeseriesEntry("speed", 123, ts), + TimeseriesEntry("vibration", 0.01, ts - 1000), + TimeseriesEntry("speed", 120, ts - 1000), ] logger.info("Sending list of timeseries entries: %s", entries) await client.send_timeseries(entries) diff --git a/examples/gateway/send_timeseries.py b/examples/gateway/send_timeseries.py index e89e22e..463d320 100644 --- a/examples/gateway/send_timeseries.py +++ b/examples/gateway/send_timeseries.py @@ -13,7 +13,9 @@ # limitations under the License. import asyncio +from time import time +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.common.config_loader import GatewayConfig from tb_mqtt_client.common.logging_utils import configure_logging, get_logger from tb_mqtt_client.service.gateway.client import GatewayClient @@ -33,8 +35,8 @@ async def main(): # Connecting device - device_name = "Test Device A2" - device_profile = "test_device_profile" + device_name = "Test Device A1" + device_profile = "Test devices" logger.info("Connecting device: %s", device_name) device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) @@ -52,6 +54,28 @@ async def main(): await client.send_device_timeseries(device_session=device_session, data=raw_timeseries, wait_for_publish=True) logger.info("Raw timeseries sent successfully.") + # Send time series as list of dictionaries + ts = int(time() * 1000) + list_timeseries = [ + {"ts": ts, "values": {"temperature": 26.0, "humidity": 65}}, + {"ts": ts - 1000, "values": {"temperature": 26.5, "humidity": 70}} + ] + logger.info("Sending list of timeseries: %s", list_timeseries) + await client.send_device_timeseries(device_session=device_session, data=list_timeseries, wait_for_publish=True) + logger.info("List of timeseries sent successfully.") + + # Send time series as TimeseriesEntry objects + timeseries_entries = [ + TimeseriesEntry(ts=ts, key="temperature", value=27.0), + TimeseriesEntry(ts=ts, key="humidity", value=75), + TimeseriesEntry(ts=ts - 1000, key="temperature", value=28.0), + TimeseriesEntry(ts=ts - 1000, key="humidity", value=80) + ] + logger.info("Sending TimeseriesEntry objects: %s", timeseries_entries) + await client.send_device_timeseries(device_session=device_session, data=timeseries_entries, wait_for_publish=True) + logger.info("TimeseriesEntry objects sent successfully.") + + await client.stop() if __name__ == "__main__": diff --git a/tb_mqtt_client/entities/data/device_uplink_message.py b/tb_mqtt_client/entities/data/device_uplink_message.py index 9dae46d..ebe75d9 100644 --- a/tb_mqtt_client/entities/data/device_uplink_message.py +++ b/tb_mqtt_client/entities/data/device_uplink_message.py @@ -24,31 +24,39 @@ logger = get_logger(__name__) -DEFAULT_FIELDS_SIZE = len('{"attributes":"","timeseries":""}'.encode('utf-8')) +DEFAULT_FIELDS_SIZE = len('{"device_name":"","device_profile":"","attributes":"","timeseries":""}'.encode('utf-8')) @dataclass(slots=True, frozen=True) -class 
GatewayUplinkMessage: +class DeviceUplinkMessage: + device_name: Optional[str] + device_profile: Optional[str] attributes: Tuple[AttributeEntry] timeseries: Mapping[int, Tuple[TimeseriesEntry]] delivery_futures: List[Optional[asyncio.Future[PublishResult]]] _size: int def __new__(cls, *args, **kwargs): - raise TypeError( - "Direct instantiation of DeviceUplinkMessage is not allowed. Use DeviceUplinkMessageBuilder to construct instances.") + raise TypeError("Direct instantiation of DeviceUplinkMessage is not allowed. Use DeviceUplinkMessageBuilder to construct instances.") def __repr__(self): - return (f"DeviceUplinkMessage(attributes={self.attributes}, " - f"timeseries={self.timeseries}, delivery_futures={self.delivery_futures})") + return (f"DeviceUplinkMessage(device_name={self.device_name}, " + f"device_profile={self.device_profile}, " + f"attributes={self.attributes}, " + f"timeseries={self.timeseries}, " + f"delivery_futures={self.delivery_futures})") @classmethod def build(cls, + device_name: Optional[str], + device_profile: Optional[str], attributes: List[AttributeEntry], timeseries: Mapping[int, List[TimeseriesEntry]], delivery_futures: List[Optional[asyncio.Future]], - size: int) -> 'GatewayUplinkMessage': + size: int) -> 'DeviceUplinkMessage': self = object.__new__(cls) + object.__setattr__(self, 'device_name', device_name) + object.__setattr__(self, 'device_profile', device_profile) object.__setattr__(self, 'attributes', tuple(attributes)) object.__setattr__(self, 'timeseries', MappingProxyType({ts: tuple(entries) for ts, entries in timeseries.items()})) @@ -78,11 +86,25 @@ def get_delivery_futures(self) -> List[Optional[asyncio.Future[PublishResult]]]: class DeviceUplinkMessageBuilder: def __init__(self): + self._device_name: Optional[str] = None + self._device_profile: Optional[str] = None self._attributes: List[AttributeEntry] = [] self._timeseries: OrderedDict[int, List[TimeseriesEntry]] = OrderedDict() self._delivery_futures: List[Optional[asyncio.Future[PublishResult]]] = [] self.__size = DEFAULT_FIELDS_SIZE + def set_device_name(self, device_name: str) -> 'DeviceUplinkMessageBuilder': + self._device_name = device_name + if device_name is not None: + self.__size += len(device_name) + return self + + def set_device_profile(self, profile: str) -> 'DeviceUplinkMessageBuilder': + self._device_profile = profile + if profile is not None: + self.__size += len(profile) + return self + def add_attributes(self, attributes: Union[AttributeEntry, List[AttributeEntry]]) -> 'DeviceUplinkMessageBuilder': if not isinstance(attributes, list): attributes = [attributes] @@ -122,10 +144,12 @@ def add_delivery_futures(self, futures: Union[ self._delivery_futures.extend(futures) return self - def build(self) -> GatewayUplinkMessage: + def build(self) -> DeviceUplinkMessage: if not self._delivery_futures: self._delivery_futures = [asyncio.get_event_loop().create_future()] - return GatewayUplinkMessage.build( + return DeviceUplinkMessage.build( + device_name=self._device_name, + device_profile=self._device_profile, attributes=self._attributes, timeseries=self._timeseries, delivery_futures=self._delivery_futures, diff --git a/tb_mqtt_client/entities/gateway/gateway_uplink_message.py b/tb_mqtt_client/entities/gateway/gateway_uplink_message.py index a49ce47..caf435e 100644 --- a/tb_mqtt_client/entities/gateway/gateway_uplink_message.py +++ b/tb_mqtt_client/entities/gateway/gateway_uplink_message.py @@ -12,29 +12,15 @@ # See the License for the specific language governing permissions and # 
limitations under the License. -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import asyncio from dataclasses import dataclass from types import MappingProxyType -from typing import List, Optional, Union, OrderedDict, Tuple, Mapping +from typing import List, Optional, Union, OrderedDict, Mapping from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry -from tb_mqtt_client.entities.data.device_uplink_message import GatewayUplinkMessage +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry logger = get_logger(__name__) @@ -43,7 +29,7 @@ @dataclass(slots=True, frozen=True) -class GatewayUplinkMessage(GatewayUplinkMessage): +class GatewayUplinkMessage(DeviceUplinkMessage): device_name: str device_profile: str @@ -59,7 +45,7 @@ def __repr__(self): f"delivery_futures={self.delivery_futures})") @classmethod - def build(cls, + def build(cls, # noqa device_name: Optional[str], device_profile: Optional[str], attributes: List[AttributeEntry], @@ -159,7 +145,7 @@ def add_delivery_futures(self, futures: Union[ def build(self) -> GatewayUplinkMessage: if not self._delivery_futures: self._delivery_futures = [asyncio.get_event_loop().create_future()] - return GatewayUplinkMessage.build( + return GatewayUplinkMessage.build( # noqa device_name=self._device_name, device_profile=self._device_profile, attributes=self._attributes, diff --git a/tb_mqtt_client/service/base_client.py b/tb_mqtt_client/service/base_client.py index aa860c0..51cd7bf 100644 --- a/tb_mqtt_client/service/base_client.py +++ b/tb_mqtt_client/service/base_client.py @@ -25,7 +25,7 @@ from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate from tb_mqtt_client.entities.data.claim_request import ClaimRequest -from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder, GatewayUplinkMessage +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder, DeviceUplinkMessage from tb_mqtt_client.entities.data.rpc_response import RPCResponse from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.common.publish_result import PublishResult @@ -146,7 +146,7 @@ def _build_uplink_message_for_telemetry(payload: Union[Dict[str, Any], List[TimeseriesEntry], List[Dict[str, Any]]], device_session: Optional[DeviceSession] = None, - ) -> GatewayUplinkMessage: + ) -> Union[DeviceUplinkMessage, GatewayUplinkMessage]: timeseries_entries = [] if isinstance(payload, TimeseriesEntry): timeseries_entries.append(payload) @@ -194,7 +194,7 @@ def __build_timeseries_entry_from_dict(data: Dict[str, JSONCompatibleType]) -> L def _build_uplink_message_for_attributes(payload: Union[Dict[str, Any], AttributeEntry, 
List[AttributeEntry]], - device_session = None) -> Union[GatewayUplinkMessage, GatewayUplinkMessage]: + device_session = None) -> Union[DeviceUplinkMessage, GatewayUplinkMessage]: if isinstance(payload, dict): payload = [AttributeEntry(k, v) for k, v in payload.items()] diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index 0760e71..cde006c 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -16,7 +16,6 @@ from asyncio import sleep, wait_for, TimeoutError, Event, Future from random import choices from string import ascii_uppercase, digits -from time import time from typing import Callable, Awaitable, Optional, Dict, Any, Union, List from orjson import dumps @@ -29,13 +28,10 @@ from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit, DEFAULT_RATE_LIMIT_PERCENTAGE from tb_mqtt_client.common.request_id_generator import RPCRequestIdProducer from tb_mqtt_client.constants import mqtt_topics -from tb_mqtt_client.constants.json_typing import JSONCompatibleType -from tb_mqtt_client.constants.service_keys import TELEMETRY_TIMESTAMP_PARAMETER, TELEMETRY_VALUES_PARAMETER from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.attribute_request import AttributeRequest from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate from tb_mqtt_client.entities.data.claim_request import ClaimRequest -from tb_mqtt_client.entities.data.device_uplink_message import GatewayUplinkMessage, DeviceUplinkMessageBuilder from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse from tb_mqtt_client.entities.data.rpc_request import RPCRequest diff --git a/tb_mqtt_client/service/device/message_adapter.py b/tb_mqtt_client/service/device/message_adapter.py index a59dec4..cc29e68 100644 --- a/tb_mqtt_client/service/device/message_adapter.py +++ b/tb_mqtt_client/service/device/message_adapter.py @@ -12,20 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- import asyncio from abc import ABC, abstractmethod from itertools import chain @@ -40,7 +26,7 @@ from tb_mqtt_client.constants.mqtt_topics import DEVICE_TELEMETRY_TOPIC, DEVICE_ATTRIBUTES_TOPIC from tb_mqtt_client.entities.data.attribute_request import AttributeRequest from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate -from tb_mqtt_client.entities.data.device_uplink_message import GatewayUplinkMessage +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest, ProvisioningCredentialsType from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse @@ -61,7 +47,7 @@ def __init__(self, max_payload_size: Optional[int] = None, max_datapoints: Optio @abstractmethod def build_uplink_payloads( self, - messages: List[GatewayUplinkMessage] + messages: List[DeviceUplinkMessage] ) -> List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]]: """ Build a list of topic-payload pairs from the given messages. @@ -253,7 +239,7 @@ def parse_provisioning_response(self, provisioning_request: ProvisioningRequest, def splitter(self) -> MessageSplitter: return self._splitter - def build_uplink_payloads(self, messages: List[GatewayUplinkMessage]) -> List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]]: + def build_uplink_payloads(self, messages: List[DeviceUplinkMessage]) -> List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]]: """ Build a list of topic-payload pairs from the given messages. Each pair consists of a topic string, payload bytes, the number of datapoints, @@ -265,28 +251,32 @@ def build_uplink_payloads(self, messages: List[GatewayUplinkMessage]) -> List[Tu return [] result: List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]] = [] + device_groups: Dict[str, List[DeviceUplinkMessage]] = defaultdict(list) + + for msg in messages: + device_name = msg.device_name + device_groups[device_name].append(msg) + logger.trace("Queued message for device='%s'", device_name) - telemetry_msgs = [] - attr_msgs = [] + logger.trace("Processing %d device group(s).", len(device_groups)) - for message in messages: - if message.has_timeseries(): - telemetry_msgs.append(message) - if message.has_attributes(): - attr_msgs.append(message) - logger.trace("Device - telemetry: %d, attributes: %d", len(telemetry_msgs), len(attr_msgs)) + for device, device_msgs in device_groups.items(): + telemetry_msgs = [m for m in device_msgs if m.has_timeseries()] + attr_msgs = [m for m in device_msgs if m.has_attributes()] + logger.trace("Device '%s' - telemetry: %d, attributes: %d", + device, len(telemetry_msgs), len(attr_msgs)) - for ts_batch in self._splitter.split_timeseries(telemetry_msgs): - payload = JsonMessageAdapter.build_payload(ts_batch, True) - count = ts_batch.timeseries_datapoint_count() - result.append((DEVICE_TELEMETRY_TOPIC, payload, count, ts_batch.get_delivery_futures())) - logger.trace("Built telemetry payload with %d datapoints", count) + for ts_batch in self._splitter.split_timeseries(telemetry_msgs): + payload = JsonMessageAdapter.build_payload(ts_batch, True) + count = ts_batch.timeseries_datapoint_count() + result.append((DEVICE_TELEMETRY_TOPIC, payload, count, ts_batch.get_delivery_futures())) + logger.trace("Built telemetry payload for device='%s' with %d datapoints", device, 
count) - for attr_batch in self._splitter.split_attributes(attr_msgs): - payload = JsonMessageAdapter.build_payload(attr_batch, False) - count = len(attr_batch.attributes) - result.append((DEVICE_ATTRIBUTES_TOPIC, payload, count, attr_batch.get_delivery_futures())) - logger.trace("Built attribute payload with %d attributes", count) + for attr_batch in self._splitter.split_attributes(attr_msgs): + payload = JsonMessageAdapter.build_payload(attr_batch, False) + count = len(attr_batch.attributes) + result.append((DEVICE_ATTRIBUTES_TOPIC, payload, count, attr_batch.get_delivery_futures())) + logger.trace("Built attribute payload for device='%s' with %d attributes", device, count) logger.trace("Generated %d topic-payload entries.", len(result)) @@ -400,7 +390,7 @@ def build_provision_request(self, provision_request: 'ProvisioningRequest') -> T return topic, payload @staticmethod - def build_payload(msg: GatewayUplinkMessage, build_timeseries_payload) -> bytes: + def build_payload(msg: DeviceUplinkMessage, build_timeseries_payload) -> bytes: result: Union[Dict[str, Any], List[Dict[str, Any]]] = {} if build_timeseries_payload: logger.trace("Packing timeseries") @@ -414,12 +404,12 @@ def build_payload(msg: GatewayUplinkMessage, build_timeseries_payload) -> bytes: return payload @staticmethod - def pack_attributes(msg: GatewayUplinkMessage) -> Dict[str, Any]: + def pack_attributes(msg: DeviceUplinkMessage) -> Dict[str, Any]: logger.trace("Packing %d attribute(s)", len(msg.attributes)) return {attr.key: attr.value for attr in msg.attributes} @staticmethod - def pack_timeseries(msg: 'GatewayUplinkMessage') -> List[Dict[str, Any]]: + def pack_timeseries(msg: 'DeviceUplinkMessage') -> List[Dict[str, Any]]: now_ts = int(datetime.now(UTC).timestamp() * 1000) packed = [ diff --git a/tb_mqtt_client/service/gateway/client.py b/tb_mqtt_client/service/gateway/client.py index 37a9302..ad4e91f 100644 --- a/tb_mqtt_client/service/gateway/client.py +++ b/tb_mqtt_client/service/gateway/client.py @@ -183,7 +183,7 @@ async def send_device_timeseries(self, return results[0] if len(results) == 1 else results - async def send_device_attributes(self, device_session: DeviceSession, attributes: ..., wait_for_publish: bool): + async def send_device_attributes(self, device_session: DeviceSession, data: ..., wait_for_publish: bool): pass async def send_device_attributes_request(self, device_session: DeviceSession, attributes: Union[AttributeRequest, GatewayAttributeRequest], wait_for_publish: bool): diff --git a/tb_mqtt_client/service/gateway/message_adapter.py b/tb_mqtt_client/service/gateway/message_adapter.py index e8c85ed..a035476 100644 --- a/tb_mqtt_client/service/gateway/message_adapter.py +++ b/tb_mqtt_client/service/gateway/message_adapter.py @@ -27,7 +27,7 @@ GATEWAY_CONNECT_TOPIC, GATEWAY_DISCONNECT_TOPIC from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate -from tb_mqtt_client.entities.data.device_uplink_message import GatewayUplinkMessage +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage from tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage from tb_mqtt_client.entities.gateway.device_disconnect_message import DeviceDisconnectMessage from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest diff --git a/tb_mqtt_client/service/message_queue.py b/tb_mqtt_client/service/message_queue.py index c9d606e..6a2f55f 100644 
--- a/tb_mqtt_client/service/message_queue.py +++ b/tb_mqtt_client/service/message_queue.py @@ -19,7 +19,7 @@ from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit from tb_mqtt_client.constants import mqtt_topics -from tb_mqtt_client.entities.data.device_uplink_message import GatewayUplinkMessage +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage from tb_mqtt_client.service.device.message_adapter import MessageAdapter @@ -54,7 +54,7 @@ def __init__(self, self._pending_ack_futures: Dict[int, asyncio.Future[PublishResult]] = {} self._pending_ack_callbacks: Dict[int, Callable[[bool], None]] = {} # Queue expects tuples of (topic, payload, delivery_futures, datapoints_count, qos) - self._queue: asyncio.Queue[Tuple[str, Union[bytes, GatewayUplinkMessage, GatewayUplinkMessage], List[asyncio.Future[PublishResult]], int, int]] = asyncio.Queue(maxsize=max_queue_size) + self._queue: asyncio.Queue[Tuple[str, Union[bytes, DeviceUplinkMessage, GatewayUplinkMessage], List[asyncio.Future[PublishResult]], int, int]] = asyncio.Queue(maxsize=max_queue_size) self._pending_queue_tasks: set[asyncio.Task] = set() self._active = asyncio.Event() self._wakeup_event = asyncio.Event() @@ -67,8 +67,8 @@ def __init__(self, logger.debug("MessageQueue initialized: max_queue_size=%s, batch_time=%.3f, batch_count=%d", max_queue_size, self._batch_max_time, batch_collect_max_count) - async def publish(self, topic: str, payload: Union[bytes, GatewayUplinkMessage, GatewayUplinkMessage], datapoints_count: int, qos: int) -> Optional[List[asyncio.Future[PublishResult]]]: - delivery_futures = payload.get_delivery_futures() if isinstance(payload, GatewayUplinkMessage) or isinstance(payload, GatewayUplinkMessage) else [asyncio.Future()] + async def publish(self, topic: str, payload: Union[bytes, DeviceUplinkMessage, GatewayUplinkMessage], datapoints_count: int, qos: int) -> Optional[List[asyncio.Future[PublishResult]]]: + delivery_futures = payload.get_delivery_futures() if isinstance(payload, DeviceUplinkMessage) or isinstance(payload, GatewayUplinkMessage) else [asyncio.Future()] try: logger.trace("publish() received delivery future id: %r for topic=%s", id(delivery_futures[0]) if delivery_futures else -1, topic) @@ -112,7 +112,7 @@ async def _dequeue_loop(self): logger.trace("Dequeued message for batching: topic=%s, device=%s", topic, getattr(payload, 'device_name', 'N/A')) - batch: List[Tuple[str, Union[bytes, GatewayUplinkMessage, GatewayUplinkMessage], List[asyncio.Future[PublishResult]], int, int]] = [(topic, payload, delivery_futures_or_none, datapoints, qos)] + batch: List[Tuple[str, Union[bytes, DeviceUplinkMessage, GatewayUplinkMessage], List[asyncio.Future[PublishResult]], int, int]] = [(topic, payload, delivery_futures_or_none, datapoints, qos)] start = asyncio.get_event_loop().time() batch_size = payload.size batch_type = type(payload).__name__ @@ -128,8 +128,8 @@ async def _dequeue_loop(self): try: next_topic, next_payload, delivery_futures_or_none, datapoints, qos = self._queue.get_nowait() - if isinstance(next_payload, GatewayUplinkMessage) or isinstance(next_payload, GatewayUplinkMessage): - if batch_type is not None and batch_type != type(next_payload).__class__.__name__: + if isinstance(next_payload, DeviceUplinkMessage) or isinstance(next_payload, 
GatewayUplinkMessage): + if batch_type is not None and batch_type != type(next_payload).__name__: logger.trace("Batch type mismatch: current=%s, next=%s, finalizing current", batch_type, type(next_payload).__class__.__name__) self._queue.put_nowait((next_topic, next_payload, delivery_futures_or_none, datapoints, qos)) @@ -304,7 +304,7 @@ async def retry(): self._retry_tasks.add(task) task.add_done_callback(self._retry_tasks.discard) - async def _wait_for_message(self) -> Tuple[str, Union[bytes, GatewayUplinkMessage, GatewayUplinkMessage], List[asyncio.Future[PublishResult]], int, int]: + async def _wait_for_message(self) -> Tuple[str, Union[bytes, DeviceUplinkMessage, GatewayUplinkMessage], List[asyncio.Future[PublishResult]], int, int]: while self._active.is_set(): try: if not self._queue.empty(): @@ -386,7 +386,7 @@ def clear(self): topic=topic, qos=qos, message_id=-1, - payload_size=message.size if isinstance(message, GatewayUplinkMessage) or isinstance(message, GatewayUplinkMessage) else len(message), + payload_size=message.size if isinstance(message, DeviceUplinkMessage) or isinstance(message, GatewayUplinkMessage) else len(message), reason_code=-1 )) self._queue.task_done() diff --git a/tb_mqtt_client/service/message_splitter.py b/tb_mqtt_client/service/message_splitter.py index 4189abf..b17062b 100644 --- a/tb_mqtt_client/service/message_splitter.py +++ b/tb_mqtt_client/service/message_splitter.py @@ -18,7 +18,7 @@ from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry -from tb_mqtt_client.entities.data.device_uplink_message import GatewayUplinkMessage, DeviceUplinkMessageBuilder +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage, DeviceUplinkMessageBuilder from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry logger = get_logger(__name__) @@ -38,7 +38,7 @@ def __init__(self, max_payload_size: int = 65535, max_datapoints: int = 0): logger.trace("MessageSplitter initialized with max_payload_size=%d, max_datapoints=%d", self._max_payload_size, self._max_datapoints) - def split_timeseries(self, messages: List[GatewayUplinkMessage]) -> List[GatewayUplinkMessage]: + def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUplinkMessage]: logger.trace("Splitting timeseries for %d messages", len(messages)) if (len(messages) == 1 @@ -47,9 +47,9 @@ def split_timeseries(self, messages: List[GatewayUplinkMessage]) -> List[Gateway and messages[0].size <= self._max_payload_size): return messages - result: List[GatewayUplinkMessage] = [] + result: List[DeviceUplinkMessage] = [] - grouped: Dict[Tuple[str, Optional[str]], List[GatewayUplinkMessage]] = defaultdict(list) + grouped: Dict[Tuple[str, Optional[str]], List[DeviceUplinkMessage]] = defaultdict(list) for msg in messages: key = (msg.device_name, msg.device_profile) grouped[key].append(msg) @@ -113,9 +113,9 @@ async def resolve_original(): logger.trace("Total timeseries batches created: %d", len(result)) return result - def split_attributes(self, messages: List[GatewayUplinkMessage]) -> List[GatewayUplinkMessage]: + def split_attributes(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUplinkMessage]: logger.trace("Splitting attributes for %d messages", len(messages)) - result: List[GatewayUplinkMessage] = [] + result: List[DeviceUplinkMessage] = [] if (len(messages) == 1 and ((messages[0].attributes_datapoint_count() + messages[ @@ -123,7 +123,7 @@ def split_attributes(self, messages: 
List[GatewayUplinkMessage]) -> List[Gateway and messages[0].size <= self._max_payload_size): return messages - grouped: Dict[Tuple[str, Optional[str]], List[GatewayUplinkMessage]] = defaultdict(list) + grouped: Dict[Tuple[str, Optional[str]], List[DeviceUplinkMessage]] = defaultdict(list) for msg in messages: grouped[(msg.device_name, msg.device_profile)].append(msg) diff --git a/tb_mqtt_client/service/mqtt_manager.py b/tb_mqtt_client/service/mqtt_manager.py index 2c6fff1..d66e1b8 100644 --- a/tb_mqtt_client/service/mqtt_manager.py +++ b/tb_mqtt_client/service/mqtt_manager.py @@ -15,7 +15,6 @@ import asyncio import ssl import threading -import time from asyncio import sleep from contextlib import suppress from time import monotonic diff --git a/tests/entities/data/test_device_uplink_message.py b/tests/entities/data/test_device_uplink_message.py index 634c36a..bd90cc7 100644 --- a/tests/entities/data/test_device_uplink_message.py +++ b/tests/entities/data/test_device_uplink_message.py @@ -19,7 +19,7 @@ import pytest from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry -from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder, GatewayUplinkMessage, \ +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder, DeviceUplinkMessage, \ DEFAULT_FIELDS_SIZE from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry @@ -36,8 +36,8 @@ def timeseries_entry(): def test_direct_instantiation_forbidden(): with pytest.raises(TypeError, match="Direct instantiation of DeviceUplinkMessage is not allowed"): - GatewayUplinkMessage(device_name="test", device_profile="default", attributes=(), timeseries={}, - delivery_futures=[], _size=0) + DeviceUplinkMessage(device_name="test", device_profile="default", attributes=(), timeseries={}, + delivery_futures=[], _size=0) def test_build_empty_message(): diff --git a/tests/service/test_json_message_adapter.py b/tests/service/test_json_message_adapter.py index a958a89..1bcf490 100644 --- a/tests/service/test_json_message_adapter.py +++ b/tests/service/test_json_message_adapter.py @@ -270,13 +270,6 @@ async def test_build_uplink_payloads_multiple_devices(adapter: JsonMessageAdapte assert DEVICE_ATTRIBUTES_TOPIC in topics or DEVICE_TELEMETRY_TOPIC in topics -def test_build_payload_with_device_name(adapter: JsonMessageAdapter): - msg = build_msg(with_ts=True) - payload = adapter.build_payload(msg, True) - assert isinstance(payload, bytes) - assert msg.device_name.encode() in payload - - def test_build_payload_without_device_name(adapter: JsonMessageAdapter): builder = DeviceUplinkMessageBuilder().add_attributes(AttributeEntry("x", 9)) msg = builder.build() diff --git a/tests/service/test_mqtt_manager.py b/tests/service/test_mqtt_manager.py index bd513b3..935de01 100644 --- a/tests/service/test_mqtt_manager.py +++ b/tests/service/test_mqtt_manager.py @@ -367,6 +367,7 @@ async def test_disconnect_reason_code_142_triggers_special_flow(setup_manager): manager._client = MagicMock() manager._backpressure = MagicMock() manager._on_disconnect_callback = AsyncMock() + manager._run_coroutine_sync = MagicMock(return_value=(None, 1, 1)) rate_limit = MagicMock(spec=RateLimit) manager._MQTTManager__rate_limiter = {"messages": rate_limit} @@ -374,15 +375,8 @@ async def test_disconnect_reason_code_142_triggers_special_flow(setup_manager): fut = asyncio.Future() manager._pending_publishes[42] = (fut, "topic", 1, 100, 0) - with patch("asyncio.get_event_loop") as mock_get_loop, \ - 
patch("asyncio.create_task", side_effect=lambda coro: asyncio.ensure_future(coro)): - - mock_loop = MagicMock() - mock_loop.run_until_complete.return_value = (None, 1, 1) # simulate reached_time = 1 - mock_get_loop.return_value = mock_loop - - manager._on_disconnect_internal(manager._client, reason_code=142) - await asyncio.sleep(0.05) + manager._on_disconnect_internal(manager._client, reason_code=142) + await asyncio.sleep(0.05) assert fut.done() manager._backpressure.notify_disconnect.assert_has_calls([ From 2b4ec62e5d7353c6d0448bf2b7e0216a930d02b5 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Wed, 16 Jul 2025 07:07:33 +0300 Subject: [PATCH 46/74] Added send attributes functionality --- examples/gateway/send_attributes.py | 84 +++++++++++++++++++ tb_mqtt_client/service/gateway/client.py | 41 ++++++++- .../gateway/gateway_client_interface.py | 10 ++- .../service/gateway/message_adapter.py | 5 +- 4 files changed, 133 insertions(+), 7 deletions(-) create mode 100644 examples/gateway/send_attributes.py diff --git a/examples/gateway/send_attributes.py new file mode 100644 index 0000000..942fa25 --- /dev/null +++ b/examples/gateway/send_attributes.py @@ -0,0 +1,84 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import asyncio + +from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.service.gateway.client import GatewayClient + + +configure_logging() +logger = get_logger("tb_mqtt_client") + + +async def main(): + config = GatewayConfig() + config.host = "localhost" + config.access_token = "YOUR_ACCESS_TOKEN" + client = GatewayClient(config) + + await client.connect() + + # Connecting device + + device_name = "Test Device A1" + device_profile = "Test devices" + logger.info("Connecting device: %s", device_name) + device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) + + if not device_session: + logger.error("Failed to register device: %s", device_name) + return + logger.info("Device connected successfully: %s", device_name) + + # Send attributes as raw dictionary + raw_attributes = { + "maintenance": "scheduled", + "id": 341, + } + logger.info("Sending raw attributes: %s", raw_attributes) + await client.send_device_attributes(device_session=device_session, data=raw_attributes, wait_for_publish=True) + logger.info("Raw attributes sent successfully.") + + # Send single attribute entry + single_attribute = AttributeEntry(key="location", value="office") + logger.info("Sending single attribute: %s", single_attribute) + await client.send_device_attributes(device_session=device_session, data=single_attribute, wait_for_publish=True) + logger.info("Single attribute sent successfully.") + + # Send multiple attribute entries + multiple_attributes = [ + AttributeEntry(key="status", value="active"), + AttributeEntry(key="version", value="1.0.0") + ] + logger.info("Sending multiple attributes: %s", multiple_attributes) + await client.send_device_attributes(device_session=device_session, data=multiple_attributes, wait_for_publish=True) + logger.info("Multiple attributes sent successfully.") + + # Disconnect device + logger.info("Disconnecting device: %s", device_name) + await client.disconnect_device(device_session, wait_for_publish=True) + logger.info("Device disconnected successfully: %s", device_name) + + # Disconnect client + await client.disconnect() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("Keyboard interrupt received, shutting down.") diff --git a/tb_mqtt_client/service/gateway/client.py b/tb_mqtt_client/service/gateway/client.py index ad4e91f..529c181 100644 --- a/tb_mqtt_client/service/gateway/client.py +++ b/tb_mqtt_client/service/gateway/client.py @@ -22,6 +22,7 @@ from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit from tb_mqtt_client.constants import mqtt_topics +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.attribute_request import AttributeRequest from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage @@ -183,8 +184,44 @@ async def send_device_timeseries(self, return results[0] if len(results) == 1 else results - async def send_device_attributes(self, device_session: DeviceSession, data: ..., wait_for_publish: bool): - pass + async def send_device_attributes(self, + device_session: DeviceSession, + data: Union[Dict[str, Any], AttributeEntry,
list[AttributeEntry]], + wait_for_publish: bool) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: + """ + Send attributes data to the platform for a specific device. + :param device_session: The DeviceSession object for the device + :param data: Attributes data to send, can be a single entry or a list of entries + :param wait_for_publish: Whether to wait for the publish result + """ + logger.trace("Sending attributes data for device %s", device_session.device_info.device_name) + if not device_session or not data: + logger.warning("No device session or data provided for sending attributes") + return None + + message = self._build_uplink_message_for_attributes(data, device_session) + topic = mqtt_topics.GATEWAY_ATTRIBUTES_TOPIC + futures = await self._message_queue.publish( + topic=topic, + payload=message, + datapoints_count=message.attributes_datapoint_count(), + qos=self._config.qos + ) + if not futures: + logger.warning("No publish futures were returned from message queue") + return None + if not wait_for_publish: + return futures[0] if len(futures) == 1 else futures + results = [] + for fut in futures: + try: + result = await await_or_stop(fut, timeout=self.OPERATIONAL_TIMEOUT, stop_event=self._stop_event) + except TimeoutError: + logger.warning("Timeout while waiting for attributes publish result") + result = PublishResult(topic, self._config.qos, -1, message.size, -1) + results.append(result) + return results[0] if len(results) == 1 else results + async def send_device_attributes_request(self, device_session: DeviceSession, attributes: Union[AttributeRequest, GatewayAttributeRequest], wait_for_publish: bool): pass diff --git a/tb_mqtt_client/service/gateway/gateway_client_interface.py b/tb_mqtt_client/service/gateway/gateway_client_interface.py index afe3466..a891136 100644 --- a/tb_mqtt_client/service/gateway/gateway_client_interface.py +++ b/tb_mqtt_client/service/gateway/gateway_client_interface.py @@ -17,6 +17,7 @@ from typing import Union, List, Tuple, Dict, Any from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.attribute_request import AttributeRequest from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest @@ -38,13 +39,16 @@ async def disconnect_device(self, device_session: DeviceSession, wait_for_publis async def send_device_timeseries(self, device_session: DeviceSession, data: Union[TimeseriesEntry, List[TimeseriesEntry], Dict[str, Any], List[Dict[str, Any]]], - wait_for_publish: bool) -> List[Union[PublishResult, Future[PublishResult]]]: ... + wait_for_publish: bool) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: ... @abstractmethod - async def send_device_attributes(self, device_session: DeviceSession, attributes: ..., wait_for_publish: bool) -> List[Union[PublishResult, Future[PublishResult]]]: ... + async def send_device_attributes(self, + device_session: DeviceSession, + attributes: Union[Dict[str, Any], AttributeEntry, list[AttributeEntry]], + wait_for_publish: bool) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: ... @abstractmethod async def send_device_attributes_request(self, device_session: DeviceSession, attributes: Union[AttributeRequest, GatewayAttributeRequest], - wait_for_publish: bool) -> List[Union[PublishResult, Future[PublishResult]]]: ... 
+ wait_for_publish: bool) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: ... diff --git a/tb_mqtt_client/service/gateway/message_adapter.py b/tb_mqtt_client/service/gateway/message_adapter.py index a035476..8d830f4 100644 --- a/tb_mqtt_client/service/gateway/message_adapter.py +++ b/tb_mqtt_client/service/gateway/message_adapter.py @@ -152,7 +152,7 @@ def build_uplink_payloads(self, messages: List[GatewayUplinkMessage]) -> List[Tu gateway_timeseries_message[device] = [] gateway_timeseries_delivery_futures[device] = [] if device not in gateway_attributes_message and attr_msgs: - gateway_attributes_message[device] = [] + gateway_attributes_message[device] = {} gateway_attributes_delivery_futures[device] = [] logger.trace("Device '%s' - telemetry: %d, attributes: %d", device, len(timeseries_msgs), len(attr_msgs)) @@ -169,8 +169,9 @@ def build_uplink_payloads(self, messages: List[GatewayUplinkMessage]) -> List[Tu for attr_batch in attr_msgs: packed_attrs = JsonGatewayMessageAdapter.pack_attributes(attr_batch) count = attr_batch.attributes_datapoint_count() - gateway_attributes_message[device].extend(packed_attrs) + gateway_attributes_message[device].update(packed_attrs) gateway_attributes_device_datapoints_counts[device] = gateway_attributes_device_datapoints_counts.get(device, 0) + count + gateway_attributes_delivery_futures[device] = attr_batch.get_delivery_futures() logger.trace("Built attribute payload for device='%s' with %d attributes", device, count) if gateway_timeseries_message: From 72c5d4f9f64f6514019632177637fd9c45c09919 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Wed, 16 Jul 2025 09:45:09 +0300 Subject: [PATCH 47/74] Added example and processing for device attribute requests for devices connected using gateway client --- ...DEPRECATEDsend_telemetry_and_attributes.py | 42 ----- examples/gateway/request_attributes.py | 97 ++++++++++++ .../entities/data/attribute_request.py | 6 +- .../gateway/device_connect_message.py | 9 +- tb_mqtt_client/entities/gateway/event_type.py | 21 ++- .../gateway/gateway_attribute_request.py | 47 ++++-- .../gateway/gateway_attribute_update.py | 2 +- .../gateway_requested_attribute_response.py | 5 +- .../entities/gateway/gateway_rpc_request.py | 2 +- .../entities/gateway/gateway_rpc_response.py | 2 +- .../gateway/gateway_uplink_message.py | 6 +- tb_mqtt_client/service/gateway/client.py | 104 +++++++++---- .../service/gateway/device_session.py | 24 +-- ...spatcher.py => direct_event_dispatcher.py} | 7 +- .../gateway_attribute_updates_handler.py | 4 +- ...y_requested_attributes_response_handler.py | 2 +- .../gateway/handlers/gateway_rpc_handler.py | 6 +- .../service/gateway/message_adapter.py | 19 +-- .../service/gateway/message_sender.py | 147 ++++++++++++++++++ 19 files changed, 414 insertions(+), 138 deletions(-) delete mode 100644 examples/gateway/DEPRECATEDsend_telemetry_and_attributes.py create mode 100644 examples/gateway/request_attributes.py rename tb_mqtt_client/service/gateway/{event_dispatcher.py => direct_event_dispatcher.py} (93%) create mode 100644 tb_mqtt_client/service/gateway/message_sender.py diff --git a/examples/gateway/DEPRECATEDsend_telemetry_and_attributes.py b/examples/gateway/DEPRECATEDsend_telemetry_and_attributes.py deleted file mode 100644 index d322c32..0000000 --- a/examples/gateway/DEPRECATEDsend_telemetry_and_attributes.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with 
the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time -import logging -from tb_gateway_mqtt import TBGatewayMqttClient - -logging.basicConfig(level=logging.INFO) - - -attributes = {"atr1": 1, "atr2": True, "atr3": "value3"} -telemetry_simple = {"ts": int(round(time.time() * 1000)), "values": {"key1": "11"}} -telemetry_array = [ - {"ts": int(round(time.time() * 1000)), "values": {"key1": "11"}}, - {"ts": int(round(time.time() * 1000)), "values": {"key2": "22"}} -] - - -def main(): - gateway = TBGatewayMqttClient("127.0.0.1", username="TEST_GATEWAY_TOKEN") - # without device connection it is impossible to get any messages - gateway.connect() - - gateway.gw_send_telemetry("Test Device A2", telemetry_simple) - gateway.gw_send_telemetry("Test Device A2", telemetry_array) - gateway.gw_send_attributes("Test Device A2", attributes) - gateway.disconnect() - - -if __name__ == '__main__': - main() diff --git a/examples/gateway/request_attributes.py b/examples/gateway/request_attributes.py new file mode 100644 index 0000000..c740ca3 --- /dev/null +++ b/examples/gateway/request_attributes.py @@ -0,0 +1,97 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest +from tb_mqtt_client.service.gateway.client import GatewayClient +from tb_mqtt_client.service.gateway.device_session import DeviceSession + +configure_logging() +logger = get_logger("tb_mqtt_client") + + +async def requested_attributes_handler(device_session: DeviceSession, response: RequestedAttributeResponse): + """ + Callback to handle requested attributes. + :param device_session: Device session for which attributes were requested. + :param response: Response containing requested attributes. 
+ """ + logger.info("Received attributes for device %s, client attributes: %r, shared attributes: %r", + device_session.device_info.device_name, + response.client, + response.shared) + + +async def main(): + config = GatewayConfig() + config.host = "localhost" + config.access_token = "YOUR_ACCESS_TOKEN" + client = GatewayClient(config) + + await client.connect() + + # Connecting device + + device_name = "Test Device A1" + device_profile = "Test devices" + logger.info("Connecting device: %s", device_name) + device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) + + # Register callback for requested attributes + device_session.set_attribute_response_callback(requested_attributes_handler) + + if not device_session: + logger.error("Failed to register device: %s", device_name) + return + logger.info("Device connected successfully: %s", device_name) + + # Send attributes to request them later + attributes = [ + AttributeEntry(key="maintenance", value="scheduled"), + AttributeEntry(key="id", value=341), + AttributeEntry(key="location", value="office") + ] + logger.info("Sending attributes: %s", attributes) + await client.send_device_attributes(device_session=device_session, data=attributes, wait_for_publish=True) + logger.info("Attributes sent successfully.") + + + # Request attributes for the device + logger.info("Requesting attributes for device: %s", device_name) + attributes_to_request = ["maintenance", "id", "location"] + attribute_request = await GatewayAttributeRequest.build(device_session=device_session, client_keys=attributes_to_request) + + await client.send_device_attributes_request(device_session, attribute_request, wait_for_publish=True) + + await asyncio.sleep(2) # Wait for the response to be processed + + # Disconnect device + logger.info("Disconnecting device: %s", device_name) + await client.disconnect_device(device_session, wait_for_publish=True) + logger.info("Device disconnected successfully: %s", device_name) + + # Disconnect client + await client.disconnect() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("Keyboard interrupt received, shutting down.") diff --git a/tb_mqtt_client/entities/data/attribute_request.py b/tb_mqtt_client/entities/data/attribute_request.py index 24c9c8c..40ff960 100644 --- a/tb_mqtt_client/entities/data/attribute_request.py +++ b/tb_mqtt_client/entities/data/attribute_request.py @@ -16,10 +16,12 @@ from typing import Optional, List, Dict from tb_mqtt_client.common.request_id_generator import AttributeRequestIdProducer from tb_mqtt_client.constants.json_typing import validate_json_compatibility +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType @dataclass(slots=True, frozen=True) -class AttributeRequest: +class AttributeRequest(BaseGatewayEvent): """ Represents a request for device attributes, with optional client and shared attribute keys. Automatically assigns a unique request ID via the build() method. @@ -27,6 +29,7 @@ class AttributeRequest: request_id: int shared_keys: Optional[List[str]] = None client_keys: Optional[List[str]] = None + event_type: GatewayEventType = GatewayEventType.DEVICE_ATTRIBUTE_REQUEST def __new__(self, *args, **kwargs): raise TypeError("Direct instantiation of AttributeRequest is not allowed. 
Use 'await AttributeRequest.build(...)'.") @@ -46,6 +49,7 @@ async def build(cls, shared_keys: Optional[List[str]] = None, client_keys: Optio object.__setattr__(self, 'request_id', request_id) object.__setattr__(self, 'shared_keys', shared_keys) object.__setattr__(self, 'client_keys', client_keys) + object.__setattr__(self, 'event_type', GatewayEventType.DEVICE_ATTRIBUTE_REQUEST) return self def to_payload_format(self) -> Dict[str, str]: diff --git a/tb_mqtt_client/entities/gateway/device_connect_message.py b/tb_mqtt_client/entities/gateway/device_connect_message.py index a760c64..d0a685d 100644 --- a/tb_mqtt_client/entities/gateway/device_connect_message.py +++ b/tb_mqtt_client/entities/gateway/device_connect_message.py @@ -15,17 +15,21 @@ from dataclasses import dataclass from typing import Dict +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType + @dataclass(slots=True, frozen=True) -class DeviceConnectMessage: +class DeviceConnectMessage(BaseGatewayEvent): """ Represents a device connection message in the ThingsBoard Gateway MQTT client. This class is used to encapsulate the details of a device connection message. """ device_name: str device_profile: str = 'default' + event_type: GatewayEventType = GatewayEventType.DEVICE_CONNECT - def __new__(self, *args, **kwargs): + def __new__(cls, *args, **kwargs): raise TypeError("Direct instantiation of DeviceConnectMessage is not allowed. Use 'await DeviceConnectMessage.build(...)'.") def __repr__(self): @@ -41,6 +45,7 @@ def build(cls, device_name: str, device_profile: str = 'default') -> 'DeviceConn self = object.__new__(cls) object.__setattr__(self, 'device_name', device_name) object.__setattr__(self, 'device_profile', device_profile) + object.__setattr__(self, 'event_type', GatewayEventType.DEVICE_CONNECT) return self def to_payload_format(self) -> Dict[str, str]: diff --git a/tb_mqtt_client/entities/gateway/event_type.py b/tb_mqtt_client/entities/gateway/event_type.py index c6c4e3e..792d322 100644 --- a/tb_mqtt_client/entities/gateway/event_type.py +++ b/tb_mqtt_client/entities/gateway/event_type.py @@ -21,16 +21,21 @@ class GatewayEventType(Enum): """ GATEWAY_CONNECT = "gateway.connect" GATEWAY_DISCONNECT = "gateway.disconnect" - DEVICE_ADD = "gateway.device.add" - DEVICE_REMOVE = "gateway.device.remove" + + DEVICE_CONNECT = "gateway.device.connect" + DEVICE_DISCONNECT = "gateway.device.disconnect" DEVICE_UPDATE = "gateway.device.update" + + DEVICE_UPLINK = "gateway.device.uplink" + DEVICE_ATTRIBUTE_REQUEST = "gateway.device.attribute.request" + DEVICE_ATTRIBUTE_UPDATE = "gateway.device.attribute.update" + DEVICE_REQUESTED_ATTRIBUTE_RESPONSE = "gateway.device.requested.attribute.response" + DEVICE_SESSION_STATE_CHANGE = "gateway.device.session.state.change" - DEVICE_RPC_REQUEST_RECEIVE = "gateway.device.rpc.request.receive" - DEVICE_RPC_RESPONSE_SEND = "gateway.device.rpc.response.send" - DEVICE_ATTRIBUTE_UPDATE_RECEIVE = "gateway.device.attribute.update.receive" - DEVICE_REQUESTED_ATTRIBUTE_RESPONSE_RECEIVE = "gateway.device.requested.attribute.response.receive" - RPC_REQUEST_RECEIVE = "device.rpc.request.receive" - RPC_RESPONSE_SEND = "device.rpc.response.send" + DEVICE_RPC_REQUEST = "gateway.device.rpc.request" + DEVICE_RPC_RESPONSE = "gateway.device.rpc.response" + RPC_REQUEST_RECEIVE = "device.rpc.request" + RPC_RESPONSE_SEND = "device.rpc.response" def __str__(self) -> str: return self.value diff --git 
a/tb_mqtt_client/entities/gateway/gateway_attribute_request.py b/tb_mqtt_client/entities/gateway/gateway_attribute_request.py index 803917f..f56e959 100644 --- a/tb_mqtt_client/entities/gateway/gateway_attribute_request.py +++ b/tb_mqtt_client/entities/gateway/gateway_attribute_request.py @@ -16,30 +16,27 @@ from typing import Optional, List, Dict, Union from tb_mqtt_client.common.request_id_generator import AttributeRequestIdProducer from tb_mqtt_client.constants.json_typing import validate_json_compatibility +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest from tb_mqtt_client.entities.gateway.event_type import GatewayEventType -from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent +from tb_mqtt_client.service.gateway.device_session import DeviceSession @dataclass(slots=True, frozen=True) -class GatewayAttributeRequest(BaseGatewayEvent): +class GatewayAttributeRequest(AttributeRequest): """ Represents a request for device attributes, with optional client and shared attribute keys. Automatically assigns a unique request ID via the build() method. """ - request_id: int - device_name: str - shared_keys: Optional[List[str]] = None - client_keys: Optional[List[str]] = None - event_type = GatewayEventType.DEVICE_REQUESTED_ATTRIBUTE_RESPONSE_RECEIVE + device_session: DeviceSession = None # type: ignore[assignment] def __new__(cls, *args, **kwargs): raise TypeError("Direct instantiation of GatewayAttributeRequest is not allowed. Use 'await GatewayAttributeRequest.build(...)'.") def __repr__(self) -> str: - return f"GatewayAttributeRequest(device_name={self.device_name}, id={self.request_id}, shared_keys={self.shared_keys}, client_keys={self.client_keys})" + return f"GatewayAttributeRequest(device_session={self.device_session}, id={self.request_id}, shared_keys={self.shared_keys}, client_keys={self.client_keys})" @classmethod - async def build(cls, device_name: str, shared_keys: Optional[List[str]] = None, client_keys: Optional[List[str]] = None) -> 'GatewayAttributeRequest': + async def build(cls, device_session: DeviceSession, shared_keys: Optional[List[str]] = None, client_keys: Optional[List[str]] = None) -> 'GatewayAttributeRequest': # noqa """ Build a new GatewayAttributeRequest with a unique request ID, using the global ID generator. """ @@ -47,23 +44,41 @@ async def build(cls, device_name: str, shared_keys: Optional[List[str]] = None, validate_json_compatibility(client_keys) request_id = await AttributeRequestIdProducer.get_next() self = object.__new__(cls) - object.__setattr__(self, 'device_name', device_name) + object.__setattr__(self, 'device_session', device_session) object.__setattr__(self, 'request_id', request_id) object.__setattr__(self, 'shared_keys', shared_keys) object.__setattr__(self, 'client_keys', client_keys) + object.__setattr__(self, 'event_type', GatewayEventType.DEVICE_ATTRIBUTE_REQUEST) return self + @classmethod + async def from_attribute_request(cls, device_session: DeviceSession, attribute_request: AttributeRequest) -> 'GatewayAttributeRequest': + """ + Create a GatewayAttributeRequest from an existing AttributeRequest and a DeviceSession. 
+ """ + if not isinstance(attribute_request, AttributeRequest): + raise TypeError("attribute_request must be an instance of AttributeRequest") + self = object.__new__(cls) + object.__setattr__(self, 'device_session', device_session) + object.__setattr__(self, 'request_id', attribute_request.request_id) + object.__setattr__(self, 'shared_keys', attribute_request.shared_keys) + object.__setattr__(self, 'client_keys', attribute_request.client_keys) + object.__setattr__(self, 'event_type', GatewayEventType.DEVICE_ATTRIBUTE_REQUEST) + return self + + def to_payload_format(self) -> Dict[str, Union[str, bool]]: """ Convert the attribute request into the expected MQTT payload format. """ - payload = {"device": self.device_name, "id": str(self.request_id)} - request_key = 'key' if len(self.client_keys) == 1 or len(self.shared_keys) == 1 else 'keys' - if self.client_keys: + payload = {"device": self.device_session.device_info.device_name, "id": self.request_id} + single_key_request = (self.client_keys is not None and len(self.client_keys) == 1) or (self.shared_keys is not None and len(self.shared_keys) == 1) + request_key = 'key' if single_key_request else 'keys' + if self.client_keys is not None and self.client_keys: payload['client'] = True - payload[request_key] = ','.join(self.client_keys) - elif self.shared_keys: + payload[request_key] = self.client_keys[0] if single_key_request else self.client_keys + elif self.shared_keys is not None and self.shared_keys: # TODO: the current server implementation does not allow requesting values for both scopes in a single request; the platform API should be improved payload['client'] = False - payload[request_key] = ','.join(self.shared_keys) + payload[request_key] = self.shared_keys[0] if single_key_request else self.shared_keys return payload diff --git a/tb_mqtt_client/entities/gateway/gateway_attribute_update.py b/tb_mqtt_client/entities/gateway/gateway_attribute_update.py index 0528c8f..411370c 100644 --- a/tb_mqtt_client/entities/gateway/gateway_attribute_update.py +++ b/tb_mqtt_client/entities/gateway/gateway_attribute_update.py @@ -24,7 +24,7 @@ class GatewayAttributeUpdate(GatewayEvent): """ def __init__(self, device_name: str, attribute_update: AttributeUpdate): - super().__init__(event_type=GatewayEventType.DEVICE_ATTRIBUTE_UPDATE_RECEIVE) + super().__init__(event_type=GatewayEventType.DEVICE_ATTRIBUTE_UPDATE) self.device_name = device_name self.attribute_update = attribute_update diff --git a/tb_mqtt_client/entities/gateway/gateway_requested_attribute_response.py b/tb_mqtt_client/entities/gateway/gateway_requested_attribute_response.py index 5b66b8a..5faf387 100644 --- a/tb_mqtt_client/entities/gateway/gateway_requested_attribute_response.py +++ b/tb_mqtt_client/entities/gateway/gateway_requested_attribute_response.py @@ -18,17 +18,20 @@ from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType logger = get_logger(__name__) @dataclass(slots=True, frozen=True) -class GatewayRequestedAttributeResponse(RequestedAttributeResponse): +class GatewayRequestedAttributeResponse(RequestedAttributeResponse, BaseGatewayEvent): device_name: str = "" request_id: int = -1 shared: Optional[List[AttributeEntry]] = None client:
Optional[List[AttributeEntry]] = None + event_type: GatewayEventType = GatewayEventType.DEVICE_REQUESTED_ATTRIBUTE_RESPONSE def __repr__(self): return f"GatewayRequestedAttributeResponse(device_name={self.device_name},request_id={self.request_id}, shared={self.shared}, client={self.client})" diff --git a/tb_mqtt_client/entities/gateway/gateway_rpc_request.py b/tb_mqtt_client/entities/gateway/gateway_rpc_request.py index a1af800..c756799 100644 --- a/tb_mqtt_client/entities/gateway/gateway_rpc_request.py +++ b/tb_mqtt_client/entities/gateway/gateway_rpc_request.py @@ -25,7 +25,7 @@ class GatewayRPCRequest(BaseGatewayEvent): device_name: str method: str params: Optional[Any] = None - event_type: GatewayEventType = GatewayEventType.RPC_REQUEST_RECEIVE + event_type: GatewayEventType = GatewayEventType.DEVICE_RPC_REQUEST def __new__(cls, *args, **kwargs): raise TypeError("Direct instantiation of GatewayRPCRequest is not allowed. Use 'await GatewayRPCRequest.build(...)'.") diff --git a/tb_mqtt_client/entities/gateway/gateway_rpc_response.py b/tb_mqtt_client/entities/gateway/gateway_rpc_response.py index 80dcb39..6e0fa52 100644 --- a/tb_mqtt_client/entities/gateway/gateway_rpc_response.py +++ b/tb_mqtt_client/entities/gateway/gateway_rpc_response.py @@ -34,7 +34,7 @@ class GatewayRPCResponse(RPCResponse, BaseGatewayEvent): error: Optional error information if the RPC failed. """ device_name: str = None - event_type: GatewayEventType = GatewayEventType.RPC_RESPONSE_SEND + event_type: GatewayEventType = GatewayEventType.DEVICE_RPC_RESPONSE def __new__(cls, *args, **kwargs): raise TypeError("Direct instantiation of GatewayRPCResponse is not allowed. Use GatewayRPCResponse.build(device_name, request_id, result, error).") diff --git a/tb_mqtt_client/entities/gateway/gateway_uplink_message.py b/tb_mqtt_client/entities/gateway/gateway_uplink_message.py index caf435e..bfcada4 100644 --- a/tb_mqtt_client/entities/gateway/gateway_uplink_message.py +++ b/tb_mqtt_client/entities/gateway/gateway_uplink_message.py @@ -22,6 +22,8 @@ from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType logger = get_logger(__name__) @@ -29,9 +31,10 @@ @dataclass(slots=True, frozen=True) -class GatewayUplinkMessage(DeviceUplinkMessage): +class GatewayUplinkMessage(DeviceUplinkMessage, BaseGatewayEvent): device_name: str device_profile: str + event_type: GatewayEventType = GatewayEventType.DEVICE_UPLINK def __new__(cls, *args, **kwargs): raise TypeError( @@ -60,6 +63,7 @@ def build(cls, # noqa MappingProxyType({ts: tuple(entries) for ts, entries in timeseries.items()})) object.__setattr__(self, 'delivery_futures', tuple(delivery_futures)) object.__setattr__(self, '_size', size) + object.__setattr__(self, 'event_type', GatewayEventType.DEVICE_UPLINK) return self @property diff --git a/tb_mqtt_client/service/gateway/client.py b/tb_mqtt_client/service/gateway/client.py index 529c181..12cea25 100644 --- a/tb_mqtt_client/service/gateway/client.py +++ b/tb_mqtt_client/service/gateway/client.py @@ -26,18 +26,20 @@ from tb_mqtt_client.entities.data.attribute_request import AttributeRequest from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from 
tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessageBuilder from tb_mqtt_client.service.device.client import DeviceClient from tb_mqtt_client.service.gateway.device_manager import DeviceManager from tb_mqtt_client.service.gateway.device_session import DeviceSession -from tb_mqtt_client.service.gateway.event_dispatcher import EventDispatcher +from tb_mqtt_client.service.gateway.direct_event_dispatcher import DirectEventDispatcher from tb_mqtt_client.service.gateway.gateway_client_interface import GatewayClientInterface from tb_mqtt_client.service.gateway.handlers.gateway_attribute_updates_handler import GatewayAttributeUpdatesHandler from tb_mqtt_client.service.gateway.handlers.gateway_requested_attributes_response_handler import \ GatewayRequestedAttributeResponseHandler from tb_mqtt_client.service.gateway.handlers.gateway_rpc_handler import GatewayRPCHandler from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter, JsonGatewayMessageAdapter +from tb_mqtt_client.service.gateway.message_sender import GatewayMessageSender logger = get_logger(__name__) @@ -60,8 +62,16 @@ def __init__(self, config: Optional[Union[GatewayConfig, Dict]] = None): super().__init__(self._config) self._device_manager = DeviceManager() - self._event_dispatcher: EventDispatcher = EventDispatcher() + + self._event_dispatcher: DirectEventDispatcher = DirectEventDispatcher() + self._uplink_message_sender = GatewayMessageSender() + self._event_dispatcher.register(GatewayEventType.DEVICE_CONNECT, self._uplink_message_sender.send_device_connect) + self._event_dispatcher.register(GatewayEventType.DEVICE_DISCONNECT, self._uplink_message_sender.send_device_disconnect) + self._event_dispatcher.register(GatewayEventType.DEVICE_UPLINK, self._uplink_message_sender.send_uplink_message) + self._event_dispatcher.register(GatewayEventType.DEVICE_ATTRIBUTE_REQUEST, self._uplink_message_sender.send_attributes_request) + self._gateway_message_adapter: GatewayMessageAdapter = JsonGatewayMessageAdapter() + self._uplink_message_sender.set_message_adapter(self._gateway_message_adapter) self._multiplex_dispatcher = None # Placeholder for multiplex dispatcher, if needed self._gateway_rpc_handler = GatewayRPCHandler(event_dispatcher=self._event_dispatcher, @@ -91,6 +101,7 @@ async def connect(self): logger.info("Connecting gateway to platform at %s:%s", self._host, self._port) await super().connect() self._message_queue.set_gateway_message_adapter(self._gateway_message_adapter) + self._uplink_message_sender.set_message_queue(self._message_queue) # Subscribe to gateway-specific topics await self._subscribe_to_gateway_topics() @@ -98,25 +109,30 @@ async def connect(self): logger.info("Gateway connected to ThingsBoard.") - async def connect_device(self, device_name: str, device_profile: str, wait_for_publish=False) -> Tuple[DeviceSession, List[Union[PublishResult, Future[PublishResult]]]]: + async def connect_device(self, + device_name_or_device_connect_message: Union[str, DeviceConnectMessage], + device_profile: str, + wait_for_publish=False) -> Tuple[DeviceSession, List[Union[PublishResult, Future[PublishResult]]]]: """ Connect a device to the gateway. 
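(Usage sketch under assumed names: gateway is a connected GatewayClient instance; after this change either a plain device name or a DeviceConnectMessage can be passed.)

    session, publish_result = await gateway.connect_device("Sensor A1", "default", wait_for_publish=True)
    # session -> DeviceSession registered in the DeviceManager
    # publish_result -> PublishResult (or a list of PublishResults) for the gateway connect message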
- :param device_name: Name of the device to connect + :param device_name_or_device_connect_message: Name of the device or a DeviceConnectMessage object :param device_profile: Profile of the device :param wait_for_publish: Whether to wait for the publish result :return: Tuple containing the DeviceSession and an optional PublishResult or list of PublishResults """ - logger.info("Connecting device %s with profile %s", device_name, device_profile) - device_session = self._device_manager.register(device_name, device_profile) - device_connect_message = DeviceConnectMessage.build(device_name, device_profile) - topic, payload = self._gateway_message_adapter.build_device_connect_message_payload(device_connect_message) - futures = await self._message_queue.publish( - topic=topic, - payload=payload, - datapoints_count=1, - qos=self._config.qos - ) + if not isinstance(device_name_or_device_connect_message, DeviceConnectMessage): + device_name = device_name_or_device_connect_message + device_connect_message = DeviceConnectMessage.build(device_name, device_profile) + else: + device_connect_message = device_name_or_device_connect_message + + logger.info("Connecting device %s with profile %s", + device_connect_message.device_name, + device_connect_message.device_profile) + device_session = self._device_manager.register(device_connect_message.device_name, + device_connect_message.device_profile) + futures = await self._event_dispatcher.dispatch(device_connect_message, qos=self._config.qos) # noqa if not futures: logger.warning("No publish futures were returned from message queue") @@ -131,12 +147,12 @@ async def connect_device(self, device_name: str, device_profile: str, wait_for_p result = await await_or_stop(fut, timeout=self.OPERATIONAL_TIMEOUT, stop_event=self._stop_event) except TimeoutError: logger.warning("Timeout while waiting for telemetry publish result") - result = PublishResult(topic, self._config.qos, -1, len(payload), -1) + result = PublishResult(mqtt_topics.GATEWAY_CONNECT_TOPIC, self._config.qos, -1, -1, -1) results.append(result) return device_session, results[0] if len(results) == 1 else results - async def disconnect_device(self, device_session: DeviceSession, wait_for_publish: bool): + async def disconnect_device(self, device_session: DeviceSession, wait_for_publish: bool): pass async def send_device_timeseries(self, @@ -157,13 +173,7 @@ async def send_device_timeseries(self, return None message = self._build_uplink_message_for_telemetry(data, device_session) - topic = mqtt_topics.GATEWAY_TELEMETRY_TOPIC - futures = await self._message_queue.publish( - topic=topic, - payload=message, - datapoints_count=message.timeseries_datapoint_count(), - qos=self._config.qos - ) + futures = await self._event_dispatcher.dispatch(message, qos=self._config.qos) # noqa if not futures: logger.warning("No publish futures were returned from message queue") @@ -178,7 +188,7 @@ async def send_device_timeseries(self, result = await await_or_stop(fut, timeout=self.OPERATIONAL_TIMEOUT, stop_event=self._stop_event) except TimeoutError: logger.warning("Timeout while waiting for telemetry publish result") - result = PublishResult(topic, self._config.qos, -1, message.size, -1) + result = PublishResult(mqtt_topics.GATEWAY_TELEMETRY_TOPIC, self._config.qos, -1, message.size, -1) results.append(result) return results[0] if len(results) == 1 else results @@ -200,13 +210,7 @@ async def send_device_attributes(self, return None message = self._build_uplink_message_for_attributes(data, device_session) - topic = 
mqtt_topics.GATEWAY_ATTRIBUTES_TOPIC - futures = await self._message_queue.publish( - topic=topic, - payload=message, - datapoints_count=message.attributes_datapoint_count(), - qos=self._config.qos - ) + futures = await self._event_dispatcher.dispatch(message, qos=self._config.qos) # noqa if not futures: logger.warning("No publish futures were returned from message queue") return None @@ -218,13 +222,45 @@ async def send_device_attributes(self, result = await await_or_stop(fut, timeout=self.OPERATIONAL_TIMEOUT, stop_event=self._stop_event) except TimeoutError: logger.warning("Timeout while waiting for attributes publish result") - result = PublishResult(topic, self._config.qos, -1, message.size, -1) + result = PublishResult(mqtt_topics.GATEWAY_ATTRIBUTES_TOPIC, self._config.qos, -1, message.size, -1) results.append(result) return results[0] if len(results) == 1 else results - async def send_device_attributes_request(self, device_session: DeviceSession, attributes: Union[AttributeRequest, GatewayAttributeRequest], wait_for_publish: bool): - pass + async def send_device_attributes_request(self, device_session: DeviceSession, attribute_request: Union[AttributeRequest, GatewayAttributeRequest], wait_for_publish: bool): + """ + Send a request for device attributes to the platform. + :param device_session: The DeviceSession object for the device + :param attribute_request: Attributes to request, can be a single AttributeRequest or GatewayAttributeRequest + :param wait_for_publish: Whether to wait for the publish result + """ + logger.trace("Sending attributes request for device %s", device_session.device_info.device_name) + if not device_session or not attribute_request: + logger.warning("No device session or attributes provided for sending attributes request") + return None + if isinstance(attribute_request, AttributeRequest): + attribute_request = await GatewayAttributeRequest.from_attribute_request(device_session=device_session, attribute_request=attribute_request) + + await self._gateway_requested_attribute_response_handler.register_request(attribute_request) + futures = await self._event_dispatcher.dispatch(attribute_request, qos=self._config.qos) + + if not futures: + logger.warning("No publish futures were returned from message queue") + return None + + if not wait_for_publish: + return futures[0] if len(futures) == 1 else futures + + results = [] + for fut in futures: + try: + result = await await_or_stop(fut, timeout=self.OPERATIONAL_TIMEOUT, stop_event=self._stop_event) + except TimeoutError: + logger.warning("Timeout while waiting for attributes request publish result") + result = PublishResult(mqtt_topics.GATEWAY_ATTRIBUTES_REQUEST_TOPIC, self._config.qos, -1, -1, -1) + results.append(result) + + return results[0] if len(results) == 1 else results async def disconnect(self): """ diff --git a/tb_mqtt_client/service/gateway/device_session.py b/tb_mqtt_client/service/gateway/device_session.py index 0411ccf..6ef5348 100644 --- a/tb_mqtt_client/service/gateway/device_session.py +++ b/tb_mqtt_client/service/gateway/device_session.py @@ -37,9 +37,9 @@ class DeviceSession: provisioned: bool = False state: DeviceSessionState = DeviceSessionState.CONNECTED - attribute_update_callback: Optional[Callable[['AttributeUpdate'], Awaitable[None]]] = None - attribute_response_callback: Optional[Callable[['RequestedAttributeResponse'], Awaitable[None]]] = None - rpc_request_callback: Optional[Callable[['GatewayRPCRequest'], Awaitable[Union['GatewayRPCResponse', None]]]] = None + 
attribute_update_callback: Optional[Callable[['DeviceSession','AttributeUpdate'], Awaitable[None]]] = None + attribute_response_callback: Optional[Callable[['DeviceSession','RequestedAttributeResponse'], Awaitable[None]]] = None + rpc_request_callback: Optional[Callable[['DeviceSession','GatewayRPCRequest'], Awaitable[Union['GatewayRPCResponse', None]]]] = None def update_state(self, new_state: DeviceSessionState): self.state = new_state @@ -49,26 +49,26 @@ def update_state(self, new_state: DeviceSessionState): def update_last_seen(self): self.last_seen_at = int(time() * 1000) - def set_attribute_update_callback(self, cb: Callable[['AttributeUpdate'], Awaitable[None]]): + def set_attribute_update_callback(self, cb: Callable[['DeviceSession','AttributeUpdate'], Awaitable[None]]): self.attribute_update_callback = cb - def set_attribute_response_callback(self, cb: Callable[['RequestedAttributeResponse'], Awaitable[None]]): + def set_attribute_response_callback(self, cb: Callable[['DeviceSession','RequestedAttributeResponse'], Awaitable[None]]): self.attribute_response_callback = cb - def set_rpc_request_callback(self, cb: Callable[['GatewayRPCRequest'], Awaitable[Union['GatewayRPCResponse', None]]]): + def set_rpc_request_callback(self, cb: Callable[['DeviceSession','GatewayRPCRequest'], Awaitable[Union['GatewayRPCResponse', None]]]): self.rpc_request_callback = cb async def handle_event_to_device(self, event: BaseGatewayEvent) -> Optional[Awaitable[Union['GatewayRPCResponse', None]]]: - if GatewayEventType.DEVICE_ATTRIBUTE_UPDATE_RECEIVE == event.event_type \ + if GatewayEventType.DEVICE_ATTRIBUTE_UPDATE == event.event_type \ and isinstance(event, AttributeUpdate): if self.attribute_update_callback: - return self.attribute_update_callback(event) - elif GatewayEventType.DEVICE_REQUESTED_ATTRIBUTE_RESPONSE_RECEIVE == event.event_type \ + return await self.attribute_update_callback(self, event) + elif GatewayEventType.DEVICE_REQUESTED_ATTRIBUTE_RESPONSE == event.event_type \ and isinstance(event, RequestedAttributeResponse): if self.attribute_response_callback: - return self.attribute_response_callback(event) - elif GatewayEventType.RPC_REQUEST_RECEIVE == event.event_type \ + return await self.attribute_response_callback(self, event) + elif GatewayEventType.DEVICE_RPC_REQUEST == event.event_type \ and isinstance(event, GatewayRPCRequest): if self.rpc_request_callback: - return self.rpc_request_callback(event) + return await self.rpc_request_callback(self, event) return None diff --git a/tb_mqtt_client/service/gateway/event_dispatcher.py b/tb_mqtt_client/service/gateway/direct_event_dispatcher.py similarity index 93% rename from tb_mqtt_client/service/gateway/event_dispatcher.py rename to tb_mqtt_client/service/gateway/direct_event_dispatcher.py index 8c40ee6..5616590 100644 --- a/tb_mqtt_client/service/gateway/event_dispatcher.py +++ b/tb_mqtt_client/service/gateway/direct_event_dispatcher.py @@ -26,7 +26,7 @@ logger = get_logger(__name__) -class EventDispatcher: +class DirectEventDispatcher: """ Direct event dispatcher for handling gateway events. 
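(Registration sketch mirroring the wiring done in GatewayClient.__init__; dispatcher, sender and uplink_message are assumed names used only for illustration.)

    dispatcher = DirectEventDispatcher()
    sender = GatewayMessageSender()
    dispatcher.register(GatewayEventType.DEVICE_UPLINK, sender.send_uplink_message)
    # dispatch() looks up callbacks by event.event_type, awaits coroutine callbacks and returns their result
    futures = await dispatcher.dispatch(uplink_message, qos=1)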
""" @@ -52,8 +52,9 @@ async def dispatch(self, event: GatewayEvent, *args, device_session: DeviceSessi for cb in callbacks: try: if asyncio.iscoroutinefunction(cb): - await cb(event, *args, **kwargs) + return await cb(event, *args, **kwargs) else: - cb(event, *args, **kwargs) + return cb(event, *args, **kwargs) except Exception as e: logger.error(f"[EventDispatcher] Exception in handler for '{event.event_type}': {e}") + return None diff --git a/tb_mqtt_client/service/gateway/handlers/gateway_attribute_updates_handler.py b/tb_mqtt_client/service/gateway/handlers/gateway_attribute_updates_handler.py index d7573dc..46d8e2a 100644 --- a/tb_mqtt_client/service/gateway/handlers/gateway_attribute_updates_handler.py +++ b/tb_mqtt_client/service/gateway/handlers/gateway_attribute_updates_handler.py @@ -13,14 +13,14 @@ # limitations under the License. from tb_mqtt_client.service.gateway.device_manager import DeviceManager -from tb_mqtt_client.service.gateway.event_dispatcher import EventDispatcher +from tb_mqtt_client.service.gateway.direct_event_dispatcher import DirectEventDispatcher from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter class GatewayAttributeUpdatesHandler: """Handles shared attribute updates for devices connected to a gateway.""" def __init__(self, - event_dispatcher: EventDispatcher, + event_dispatcher: DirectEventDispatcher, message_adapter: GatewayMessageAdapter, device_manager: DeviceManager): self.event_dispatcher = event_dispatcher diff --git a/tb_mqtt_client/service/gateway/handlers/gateway_requested_attributes_response_handler.py b/tb_mqtt_client/service/gateway/handlers/gateway_requested_attributes_response_handler.py index 869b4e5..863c815 100644 --- a/tb_mqtt_client/service/gateway/handlers/gateway_requested_attributes_response_handler.py +++ b/tb_mqtt_client/service/gateway/handlers/gateway_requested_attributes_response_handler.py @@ -47,7 +47,7 @@ async def register_request(self, Called when a request is sent to the platform and a response is awaited. """ request_id = request.request_id - device_name = request.device_name + device_name = request.device_session.device_info.device_name key = (device_name, request_id) if key in self._pending_attribute_requests: raise RuntimeError(f"Request ID {request.request_id} is already registered.") diff --git a/tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py b/tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py index f7446a1..04ffd88 100644 --- a/tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py +++ b/tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py @@ -19,7 +19,7 @@ from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse from tb_mqtt_client.service.gateway.device_manager import DeviceManager -from tb_mqtt_client.service.gateway.event_dispatcher import EventDispatcher +from tb_mqtt_client.service.gateway.direct_event_dispatcher import DirectEventDispatcher from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter logger = get_logger(__name__) @@ -30,12 +30,12 @@ class GatewayRPCHandler: Handles incoming RPC request messages for a device connected through the gateway. 
""" - def __init__(self, event_dispatcher: EventDispatcher, message_adapter: GatewayMessageAdapter, device_manager: DeviceManager): + def __init__(self, event_dispatcher: DirectEventDispatcher, message_adapter: GatewayMessageAdapter, device_manager: DeviceManager): self._event_dispatcher = event_dispatcher self._message_adapter = message_adapter self._device_manager = device_manager self._callback: Optional[Callable[[GatewayRPCRequest], Awaitable[GatewayRPCResponse]]] = None - self._event_dispatcher.register(GatewayEventType.RPC_REQUEST_RECEIVE, self.handle) + self._event_dispatcher.register(GatewayEventType.DEVICE_RPC_REQUEST, self.handle) async def handle(self, topic: str, payload: bytes) -> Optional[GatewayRPCResponse]: """ diff --git a/tb_mqtt_client/service/gateway/message_adapter.py b/tb_mqtt_client/service/gateway/message_adapter.py index 8d830f4..32a8ddf 100644 --- a/tb_mqtt_client/service/gateway/message_adapter.py +++ b/tb_mqtt_client/service/gateway/message_adapter.py @@ -24,7 +24,7 @@ from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.constants.mqtt_topics import GATEWAY_TELEMETRY_TOPIC, GATEWAY_ATTRIBUTES_TOPIC, \ - GATEWAY_CONNECT_TOPIC, GATEWAY_DISCONNECT_TOPIC + GATEWAY_CONNECT_TOPIC, GATEWAY_DISCONNECT_TOPIC, GATEWAY_ATTRIBUTES_REQUEST_TOPIC from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage @@ -235,8 +235,9 @@ def build_gateway_attribute_request_payload(self, attribute_request: GatewayAttr """ try: payload = dumps(attribute_request.to_payload_format()) - logger.trace("Built gateway attribute request payload for device='%s'", attribute_request.device_name) - return GATEWAY_ATTRIBUTES_TOPIC, payload + logger.trace("Built gateway attribute request payload for device='%s'", + attribute_request.device_session.device_info.device_name) + return GATEWAY_ATTRIBUTES_REQUEST_TOPIC, payload except Exception as e: logger.error("Failed to build gateway attribute request payload: %s", str(e)) raise ValueError("Invalid gateway attribute request format") from e @@ -268,16 +269,16 @@ def parse_gateway_requested_attribute_response(self, gateway_attribute_request: data['value']) return None elif 'value' in data: - if len(gateway_attribute_request.client_keys) == 1: + if gateway_attribute_request.client_keys is not None and len(gateway_attribute_request.client_keys) == 1: client = [AttributeEntry(gateway_attribute_request.client_keys[0], data['value'])] - elif len(gateway_attribute_request.shared_keys) == 1: + elif gateway_attribute_request.shared_keys is not None and len(gateway_attribute_request.shared_keys) == 1: shared = [AttributeEntry(gateway_attribute_request.shared_keys[0], data['value'])] elif 'values' in data: - if len(gateway_attribute_request.client_keys) > 0: - client = [AttributeEntry(k, v) for k, v in data['data'].get('values', {}).items() if + if gateway_attribute_request.client_keys is not None and len(gateway_attribute_request.client_keys) > 0: + client = [AttributeEntry(k, v) for k, v in data['values'].items() if k in gateway_attribute_request.client_keys] - if len(gateway_attribute_request.shared_keys) > 0: - shared = [AttributeEntry(k, v) for k, v in data['data'].get('values', {}).items() if + if gateway_attribute_request.shared_keys is not None and len(gateway_attribute_request.shared_keys) > 0: + 
shared = [AttributeEntry(k, v) for k, v in data['values'].items() if k in gateway_attribute_request.shared_keys] return GatewayRequestedAttributeResponse(device_name=device_name, request_id=gateway_attribute_request.request_id, shared=shared, client=client) except Exception as e: diff --git a/tb_mqtt_client/service/gateway/message_sender.py b/tb_mqtt_client/service/gateway/message_sender.py new file mode 100644 index 0000000..acd6d9a --- /dev/null +++ b/tb_mqtt_client/service/gateway/message_sender.py @@ -0,0 +1,147 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from asyncio import Future +from typing import List, Union, Optional + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.constants import mqtt_topics +from tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage +from tb_mqtt_client.entities.gateway.device_disconnect_message import DeviceDisconnectMessage +from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage +from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter +from tb_mqtt_client.service.message_queue import MessageQueue + +logger = get_logger(__name__) + + +class GatewayMessageSender: + """ + Class responsible for sending uplink messages from devices connected to the gateway to the platform. + It handles the serialization of the message and sends to uplink message queue. + """ + + def __init__(self): + self._message_queue: Optional[MessageQueue] = None + self._message_adapter: Optional[GatewayMessageAdapter] = None + + async def send_uplink_message(self, message: GatewayUplinkMessage, qos=1) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: + """ + Sends a list of uplink messages to the platform. + + :param message: List of GatewayUplinkMessage objects to be sent. + :param qos: Quality of Service level for the MQTT message. + :returns: List of PublishResult or Future[PublishResult] if successful, None if failed. + """ + if self._message_queue is None: + logger.error("Cannot send uplink messages. 
Message queue is not set, are you connected to the platform?") + return None + if not message.has_timeseries() and not message.has_attributes(): + logger.warning("Uplink message does not contain timeseries or attributes, nothing to send.") + return None + futures = [] + if message.has_timeseries(): + topic = mqtt_topics.GATEWAY_TELEMETRY_TOPIC + timeseries_futures = await self._message_queue.publish( + topic=topic, + payload=message, + datapoints_count=message.timeseries_datapoint_count(), + qos=qos + ) + futures.extend(timeseries_futures) + if message.has_attributes(): + topic = mqtt_topics.GATEWAY_ATTRIBUTES_TOPIC + attributes_futures = await self._message_queue.publish( + topic=topic, + payload=message, + datapoints_count=message.attributes_datapoint_count(), + qos=qos + ) + futures.extend(attributes_futures) + return futures + + async def send_device_connect(self, device_connect_message: DeviceConnectMessage, qos=1) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: + """ + Sends a device connect message to the platform. + + :param device_connect_message: DeviceConnectMessage object containing the device connection details. + :param qos: Quality of Service level for the MQTT message. + :returns: List of PublishResult or Future[PublishResult] if successful, None if failed. + """ + if self._message_queue is None: + logger.error("Cannot send device connect message. Message queue is not set, are you connected to the platform?") + return None + topic, payload = self._message_adapter.build_device_connect_message_payload(device_connect_message=device_connect_message) + return await self._message_queue.publish( + topic=topic, + payload=payload, + datapoints_count=1, + qos=qos + ) + + async def send_device_disconnect(self, device_disconnect_message: DeviceDisconnectMessage, qos=1) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: + """ + Sends a device disconnect message to the platform. + + :param device_disconnect_message: DeviceDisconnectMessage object containing the device disconnection details. + :param qos: Quality of Service level for the MQTT message. + :returns: List of PublishResult or Future[PublishResult] if successful, None if failed. + """ + if self._message_queue is None: + logger.error("Cannot send device disconnect message. Message queue is not set, are you connected to the platform?") + return None + topic, payload = self._message_adapter.build_device_disconnect_message_payload(device_disconnect_message=device_disconnect_message) + return await self._message_queue.publish( + topic=topic, + payload=payload, + datapoints_count=1, + qos=qos + ) + + async def send_attributes_request(self, attribute_request: GatewayAttributeRequest, qos=1) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: + """ + Sends an attribute request message to the platform. + + :param attribute_request: GatewayAttributeRequest object containing the attributes to be requested. + :param qos: Quality of Service level for the MQTT message. + :returns: List of PublishResult or Future[PublishResult] if successful, None if failed. + """ + if self._message_queue is None: + logger.error("Cannot send attribute request. 
Message queue is not set, do you connected to the platform?") + return None + topic, payload = self._message_adapter.build_gateway_attribute_request_payload(attribute_request=attribute_request) + return await self._message_queue.publish( + topic=topic, + payload=payload, + datapoints_count=1, + qos=qos + ) + + def set_message_queue(self, message_queue: MessageQueue): + """ + Sets the message queue for sending uplink messages. + + :param message_queue: An instance of MessageQueue to be used for sending messages. + """ + self._message_queue = message_queue + + def set_message_adapter(self, message_adapter: GatewayMessageAdapter): + """ + Sets the message adapter for serializing uplink messages. + + :param message_adapter: An instance of GatewayMessageAdapter to be used for serializing messages. + """ + self._message_adapter = message_adapter From 5704f719c24d1029943821ea48a9fb4cc896b8e5 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Wed, 16 Jul 2025 10:18:15 +0300 Subject: [PATCH 48/74] Refactored example callback assignment and device events handling --- examples/gateway/request_attributes.py | 2 +- tb_mqtt_client/service/gateway/client.py | 12 ++++---- .../service/gateway/device_session.py | 30 ++++++++++++------- 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/examples/gateway/request_attributes.py b/examples/gateway/request_attributes.py index c740ca3..33c1559 100644 --- a/examples/gateway/request_attributes.py +++ b/examples/gateway/request_attributes.py @@ -54,7 +54,7 @@ async def main(): device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) # Register callback for requested attributes - device_session.set_attribute_response_callback(requested_attributes_handler) + client.device_manager.set_attribute_response_callback(device_session.device_info.device_id, requested_attributes_handler) if not device_session: logger.error("Failed to register device: %s", device_name) diff --git a/tb_mqtt_client/service/gateway/client.py b/tb_mqtt_client/service/gateway/client.py index 12cea25..1eb31b6 100644 --- a/tb_mqtt_client/service/gateway/client.py +++ b/tb_mqtt_client/service/gateway/client.py @@ -61,7 +61,7 @@ def __init__(self, config: Optional[Union[GatewayConfig, Dict]] = None): self._config = config if isinstance(config, GatewayConfig) else GatewayConfig(config) super().__init__(self._config) - self._device_manager = DeviceManager() + self.device_manager = DeviceManager() self._event_dispatcher: DirectEventDispatcher = DirectEventDispatcher() self._uplink_message_sender = GatewayMessageSender() @@ -76,13 +76,13 @@ def __init__(self, config: Optional[Union[GatewayConfig, Dict]] = None): self._multiplex_dispatcher = None # Placeholder for multiplex dispatcher, if needed self._gateway_rpc_handler = GatewayRPCHandler(event_dispatcher=self._event_dispatcher, message_adapter=self._gateway_message_adapter, - device_manager=self._device_manager) + device_manager=self.device_manager) self._gateway_attribute_updates_handler = GatewayAttributeUpdatesHandler(event_dispatcher=self._event_dispatcher, message_adapter=self._gateway_message_adapter, - device_manager=self._device_manager) + device_manager=self.device_manager) self._gateway_requested_attribute_response_handler = GatewayRequestedAttributeResponseHandler(event_dispatcher=self._event_dispatcher, message_adapter=self._gateway_message_adapter, - device_manager=self._device_manager) + device_manager=self.device_manager) # Gateway-specific rate limits 
self._device_messages_rate_limit = RateLimit("10:1,", name="device_messages") @@ -130,8 +130,8 @@ async def connect_device(self, logger.info("Connecting device %s with profile %s", device_connect_message.device_name, device_connect_message.device_profile) - device_session = self._device_manager.register(device_connect_message.device_name, - device_connect_message.device_profile) + device_session = self.device_manager.register(device_connect_message.device_name, + device_connect_message.device_profile) futures = await self._event_dispatcher.dispatch(device_connect_message, qos=self._config.qos) # noqa if not futures: diff --git a/tb_mqtt_client/service/gateway/device_session.py b/tb_mqtt_client/service/gateway/device_session.py index 6ef5348..93711c0 100644 --- a/tb_mqtt_client/service/gateway/device_session.py +++ b/tb_mqtt_client/service/gateway/device_session.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import asyncio from time import time from dataclasses import dataclass, field from typing import Callable, Awaitable, Optional, Union @@ -37,9 +37,9 @@ class DeviceSession: provisioned: bool = False state: DeviceSessionState = DeviceSessionState.CONNECTED - attribute_update_callback: Optional[Callable[['DeviceSession','AttributeUpdate'], Awaitable[None]]] = None - attribute_response_callback: Optional[Callable[['DeviceSession','RequestedAttributeResponse'], Awaitable[None]]] = None - rpc_request_callback: Optional[Callable[['DeviceSession','GatewayRPCRequest'], Awaitable[Union['GatewayRPCResponse', None]]]] = None + attribute_update_callback: Optional[Callable[['DeviceSession','AttributeUpdate'], Union[Awaitable[None], None]]] = None + attribute_response_callback: Optional[Callable[['DeviceSession','RequestedAttributeResponse'], Union[Awaitable[None], None]]] = None + rpc_request_callback: Optional[Callable[['DeviceSession','GatewayRPCRequest'], Union[Awaitable[Union['GatewayRPCResponse', None]], None]]] = None def update_state(self, new_state: DeviceSessionState): self.state = new_state @@ -49,26 +49,34 @@ def update_state(self, new_state: DeviceSessionState): def update_last_seen(self): self.last_seen_at = int(time() * 1000) - def set_attribute_update_callback(self, cb: Callable[['DeviceSession','AttributeUpdate'], Awaitable[None]]): + def set_attribute_update_callback(self, cb: Callable[['DeviceSession','AttributeUpdate'], Union[Awaitable[None], None]]): self.attribute_update_callback = cb - def set_attribute_response_callback(self, cb: Callable[['DeviceSession','RequestedAttributeResponse'], Awaitable[None]]): + def set_attribute_response_callback(self, cb: Callable[['DeviceSession','RequestedAttributeResponse'], Union[Awaitable[None], None]]): self.attribute_response_callback = cb - def set_rpc_request_callback(self, cb: Callable[['DeviceSession','GatewayRPCRequest'], Awaitable[Union['GatewayRPCResponse', None]]]): + def set_rpc_request_callback(self, cb: Callable[['DeviceSession','GatewayRPCRequest'], Union[Awaitable[Union['GatewayRPCResponse', None]], None]]): self.rpc_request_callback = cb async def handle_event_to_device(self, event: BaseGatewayEvent) -> Optional[Awaitable[Union['GatewayRPCResponse', None]]]: + cb = None if GatewayEventType.DEVICE_ATTRIBUTE_UPDATE == event.event_type \ and isinstance(event, AttributeUpdate): if self.attribute_update_callback: - return await self.attribute_update_callback(self, event) + cb = 
self.attribute_update_callback elif GatewayEventType.DEVICE_REQUESTED_ATTRIBUTE_RESPONSE == event.event_type \ and isinstance(event, RequestedAttributeResponse): if self.attribute_response_callback: - return await self.attribute_response_callback(self, event) + cb = self.attribute_response_callback elif GatewayEventType.DEVICE_RPC_REQUEST == event.event_type \ and isinstance(event, GatewayRPCRequest): if self.rpc_request_callback: - return await self.rpc_request_callback(self, event) - return None + cb = self.rpc_request_callback + + if cb is None: + return None + + if asyncio.iscoroutinefunction(cb): + return await cb(self, event) + else: + return cb(self, event) From e128d81e055aa53a1efde96e395b07470fe01620 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Wed, 16 Jul 2025 12:17:30 +0300 Subject: [PATCH 49/74] Added processing RPC requests and example to gateway part --- .../gateway/DEPRECATEDrequest_attributes.py | 46 ------- examples/gateway/DEPRECATEDrespond_to_rpc.py | 59 --------- examples/gateway/handle_rpc_requests.py | 114 ++++++++++++++++++ examples/gateway/request_attributes.py | 3 +- tb_mqtt_client/constants/mqtt_topics.py | 14 --- .../entities/gateway/base_gateway_event.py | 3 + .../entities/gateway/gateway_rpc_request.py | 8 +- .../entities/gateway/gateway_rpc_response.py | 7 +- tb_mqtt_client/service/gateway/client.py | 4 +- .../service/gateway/device_manager.py | 1 - .../gateway/handlers/gateway_rpc_handler.py | 48 ++++++-- .../service/gateway/message_adapter.py | 25 +++- .../service/gateway/message_sender.py | 19 +++ 13 files changed, 211 insertions(+), 140 deletions(-) delete mode 100644 examples/gateway/DEPRECATEDrequest_attributes.py delete mode 100644 examples/gateway/DEPRECATEDrespond_to_rpc.py create mode 100644 examples/gateway/handle_rpc_requests.py diff --git a/examples/gateway/DEPRECATEDrequest_attributes.py b/examples/gateway/DEPRECATEDrequest_attributes.py deleted file mode 100644 index addfb7c..0000000 --- a/examples/gateway/DEPRECATEDrequest_attributes.py +++ /dev/null @@ -1,46 +0,0 @@ - -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -import time - -from tb_gateway_mqtt import TBGatewayMqttClient -logging.basicConfig(level=logging.INFO) - - -def callback(result, exception=None): - if exception is not None: - logging.error("Exception: " + str(exception)) - else: - logging.info(result) - - -def main(): - gateway = TBGatewayMqttClient("127.0.0.1", username="TEST_GATEWAY_TOKEN") - gateway.connect() - # Requesting attributes - gateway.gw_request_shared_attributes("Example Name", ["temperature"], callback) - - try: - # Waiting for the callback - while not gateway.stopped: - time.sleep(1) - except KeyboardInterrupt: - gateway.disconnect() - gateway.stop() - - -if __name__ == '__main__': - main() diff --git a/examples/gateway/DEPRECATEDrespond_to_rpc.py b/examples/gateway/DEPRECATEDrespond_to_rpc.py deleted file mode 100644 index deca40d..0000000 --- a/examples/gateway/DEPRECATEDrespond_to_rpc.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging.handlers -import time - -from tb_gateway_mqtt import TBGatewayMqttClient -try: - import psutil -except ImportError: - print("Please install psutil using 'pip install psutil' command") - exit(1) -logging.basicConfig(level=logging.INFO) - - -def rpc_request_response(gateway, request_body): - # request body contains id, method and other parameters - logging.info(request_body) - method = request_body["data"]["method"] - device = request_body["device"] - req_id = request_body["data"]["id"] - # dependently of request method we send different data back - if method == 'getCPULoad': - gateway.gw_send_rpc_reply(device, req_id, psutil.cpu_percent(interval=0.1)) - elif method == 'getMemoryLoad': - gateway.gw_send_rpc_reply(device, req_id, psutil.virtual_memory().percent) - else: - print('Unknown method: ' + method) - - -def main(): - gateway = TBGatewayMqttClient("127.0.0.1", username="TEST_GATEWAY_TOKEN") - gateway.connect() - # now rpc_request_response will process rpc requests from servers - gateway.gw_set_server_side_rpc_request_handler(rpc_request_response) - # without device connection it is impossible to get any messages - gateway.gw_connect_device("Test Device A2", "default") - try: - # Waiting for the callback - while not gateway.stopped: - time.sleep(1) - except KeyboardInterrupt: - gateway.disconnect() - gateway.stop() - - -if __name__ == '__main__': - main() diff --git a/examples/gateway/handle_rpc_requests.py b/examples/gateway/handle_rpc_requests.py new file mode 100644 index 0000000..48fa2cc --- /dev/null +++ b/examples/gateway/handle_rpc_requests.py @@ -0,0 +1,114 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse +from tb_mqtt_client.service.gateway.client import GatewayClient +from tb_mqtt_client.service.gateway.device_session import DeviceSession + +configure_logging() +logger = get_logger("tb_mqtt_client") + + +async def device_rpc_request_handler(device_session: DeviceSession, rpc_request: GatewayRPCRequest) -> GatewayRPCResponse: + """ + Callback to handle RPC requests from the device. + :param device_session: Device session for which the request was made. + :param rpc_request: RPC request from the platform. 
+ """ + logger.info("Received RPC request for device %s: %s", device_session.device_info.device_name, rpc_request) + + response_data = { + "status": "success", + "message": f"RPC request '{rpc_request.method}' processed successfully.", + "data": { + "device_name": device_session.device_info.device_name, + "request_id": rpc_request.request_id, + "method": rpc_request.method, + "params": rpc_request.params + } + } + + rpc_response = GatewayRPCResponse.build(device_session.device_info.device_name, rpc_request.request_id, response_data) + + logger.info("Sending RPC response for request id %r: %r", rpc_request.request_id, rpc_response) + + return rpc_response + + +async def main(): + config = GatewayConfig() + config.host = "localhost" + config.access_token = "YOUR_ACCESS_TOKEN" + client = GatewayClient(config) + + await client.connect() + + # Connecting device + + device_name = "Test Device A1" + device_profile = "Test devices" + logger.info("Connecting device: %s", device_name) + device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) + + # Register callback for requested attributes + client.device_manager.set_rpc_request_callback(device_session.device_info.device_id, device_rpc_request_handler) + + if not device_session: + logger.error("Failed to register device: %s", device_name) + return + logger.info("Device connected successfully: %s", device_name) + + # Loop to keep the client running and processing RPC requests + try: + logger.info("Client loop started, waiting for RPC requests...") + while True: + await asyncio.sleep(1) # Keep the loop running + except (asyncio.CancelledError, KeyboardInterrupt): + logger.info("Client loop stopped, shutting down.") + + # Disconnect device + logger.info("Disconnecting device: %s", device_name) + await client.disconnect_device(device_session, wait_for_publish=True) + logger.info("Device disconnected successfully: %s", device_name) + + # Disconnect client + await client.disconnect() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("Keyboard interrupt received, shutting down.") diff --git a/examples/gateway/request_attributes.py b/examples/gateway/request_attributes.py index 33c1559..e06f88f 100644 --- a/examples/gateway/request_attributes.py +++ b/examples/gateway/request_attributes.py @@ -71,6 +71,7 @@ async def main(): await client.send_device_attributes(device_session=device_session, data=attributes, wait_for_publish=True) logger.info("Attributes sent successfully.") + await asyncio.sleep(1) # Wait for attributes to be processed # Request attributes for the device logger.info("Requesting attributes for device: %s", device_name) @@ -79,7 +80,7 @@ async def main(): await client.send_device_attributes_request(device_session, attribute_request, wait_for_publish=True) - await asyncio.sleep(2) # Wait for the response to be processed + await asyncio.sleep(1) # Wait for the response to be processed # Disconnect device logger.info("Disconnecting device: %s", device_name) diff --git a/tb_mqtt_client/constants/mqtt_topics.py b/tb_mqtt_client/constants/mqtt_topics.py index 5f98343..5db1125 100644 --- a/tb_mqtt_client/constants/mqtt_topics.py +++ b/tb_mqtt_client/constants/mqtt_topics.py @@ -46,8 +46,6 @@ GATEWAY_ATTRIBUTES_REQUEST_TOPIC = GATEWAY_ATTRIBUTES_TOPIC + REQUEST_TOPIC_SUFFIX GATEWAY_ATTRIBUTES_RESPONSE_TOPIC = GATEWAY_ATTRIBUTES_TOPIC + RESPONSE_TOPIC_SUFFIX GATEWAY_RPC_TOPIC = BASE_GATEWAY_TOPIC + "/rpc" -GATEWAY_RPC_REQUEST_TOPIC = 
GATEWAY_RPC_TOPIC + REQUEST_TOPIC_SUFFIX -GATEWAY_RPC_RESPONSE_TOPIC = GATEWAY_RPC_TOPIC + RESPONSE_TOPIC_SUFFIX GATEWAY_CLAIM_TOPIC = BASE_GATEWAY_TOPIC + "/claim" # Topic Builders @@ -65,17 +63,5 @@ def build_device_rpc_response_topic(request_id: int) -> str: return DEVICE_RPC_RESPONSE_TOPIC + str(request_id) -def build_gateway_device_telemetry_topic() -> str: - return GATEWAY_TELEMETRY_TOPIC - - -def build_gateway_device_attributes_topic() -> str: - return GATEWAY_ATTRIBUTES_TOPIC - - -def build_gateway_rpc_topic() -> str: - return GATEWAY_RPC_TOPIC - - def build_firmware_update_request_topic(request_id: int, current_chunk: int) -> str: return DEVICE_FIRMWARE_UPDATE_REQUEST_TOPIC.format(request_id=request_id, current_chunk=current_chunk) diff --git a/tb_mqtt_client/entities/gateway/base_gateway_event.py b/tb_mqtt_client/entities/gateway/base_gateway_event.py index 590c8ff..1dee49e 100644 --- a/tb_mqtt_client/entities/gateway/base_gateway_event.py +++ b/tb_mqtt_client/entities/gateway/base_gateway_event.py @@ -29,3 +29,6 @@ def set_device_session(self, device_session): @property def device_session(self): return self.__device_session + + def __str__(self) -> str: + return self.__repr__() diff --git a/tb_mqtt_client/entities/gateway/gateway_rpc_request.py b/tb_mqtt_client/entities/gateway/gateway_rpc_request.py index c756799..57d4190 100644 --- a/tb_mqtt_client/entities/gateway/gateway_rpc_request.py +++ b/tb_mqtt_client/entities/gateway/gateway_rpc_request.py @@ -21,7 +21,7 @@ @dataclass(slots=True, frozen=True) class GatewayRPCRequest(BaseGatewayEvent): - request_id: Union[int, str] + request_id: int device_name: str method: str params: Optional[Any] = None @@ -33,6 +33,9 @@ def __new__(cls, *args, **kwargs): def __repr__(self): return f"RPCRequest(id={self.request_id}, device_name={self.device_name}, method={self.method}, params={self.params})" + def __str__(self): + return f"RPCRequest(id={self.request_id}, device_name={self.device_name}, method={self.method}, params={self.params})" + @classmethod def _deserialize_from_dict(cls, data: Dict[str, Union[str, Dict[str, Any]]]) -> 'GatewayRPCRequest': """ @@ -52,5 +55,6 @@ def _deserialize_from_dict(cls, data: Dict[str, Union[str, Dict[str, Any]]]) -> object.__setattr__(self, 'request_id', request_id) object.__setattr__(self, 'method', data["method"]) object.__setattr__(self, 'params', data.get("params")) - object.__setattr__(self, 'device', device_name) + object.__setattr__(self, 'device_name', device_name) + object.__setattr__(self, 'event_type', GatewayEventType.DEVICE_RPC_REQUEST) return self diff --git a/tb_mqtt_client/entities/gateway/gateway_rpc_response.py b/tb_mqtt_client/entities/gateway/gateway_rpc_response.py index 6e0fa52..f218798 100644 --- a/tb_mqtt_client/entities/gateway/gateway_rpc_response.py +++ b/tb_mqtt_client/entities/gateway/gateway_rpc_response.py @@ -39,14 +39,13 @@ class GatewayRPCResponse(RPCResponse, BaseGatewayEvent): def __new__(cls, *args, **kwargs): raise TypeError("Direct instantiation of GatewayRPCResponse is not allowed. 
Use GatewayRPCResponse.build(device_name, request_id, result, error).") - def __repr__(self) -> str: - return f"GatewayRPCResponse(request_id={self.request_id}, result={self.result}, error={self.error})" + return f"GatewayRPCResponse(device_name={self.device_name}, request_id={self.request_id}, result={self.result}, error={self.error})" @classmethod def build(cls, # noqa device_name: str, - request_id: Union[int, str], + request_id: int, result: Optional[Any] = None, error: Optional[Union[str, Dict[str, JSONCompatibleType], BaseException]] = None) -> 'GatewayRPCResponse': """ @@ -56,6 +55,7 @@ def build(cls, # noqa raise ValueError("Device name must be a non-empty string") self = object.__new__(cls) object.__setattr__(self, 'request_id', request_id) + object.__setattr__(self, 'device_name', device_name) if error is not None: if not isinstance(error, (str, dict, BaseException)): @@ -82,6 +82,7 @@ def build(cls, # noqa validate_json_compatibility(result) object.__setattr__(self, 'result', result) + object.__setattr__(self, 'event_type', GatewayEventType.DEVICE_RPC_RESPONSE) return self def to_payload_format(self) -> Dict[str, Any]: diff --git a/tb_mqtt_client/service/gateway/client.py b/tb_mqtt_client/service/gateway/client.py index 1eb31b6..17bc13f 100644 --- a/tb_mqtt_client/service/gateway/client.py +++ b/tb_mqtt_client/service/gateway/client.py @@ -69,6 +69,7 @@ def __init__(self, config: Optional[Union[GatewayConfig, Dict]] = None): self._event_dispatcher.register(GatewayEventType.DEVICE_DISCONNECT, self._uplink_message_sender.send_device_disconnect) self._event_dispatcher.register(GatewayEventType.DEVICE_UPLINK, self._uplink_message_sender.send_uplink_message) self._event_dispatcher.register(GatewayEventType.DEVICE_ATTRIBUTE_REQUEST, self._uplink_message_sender.send_attributes_request) + self._event_dispatcher.register(GatewayEventType.DEVICE_RPC_RESPONSE, self._uplink_message_sender.send_rpc_response) self._gateway_message_adapter: GatewayMessageAdapter = JsonGatewayMessageAdapter() self._uplink_message_sender.set_message_adapter(self._gateway_message_adapter) @@ -76,7 +77,8 @@ def __init__(self, config: Optional[Union[GatewayConfig, Dict]] = None): self._multiplex_dispatcher = None # Placeholder for multiplex dispatcher, if needed self._gateway_rpc_handler = GatewayRPCHandler(event_dispatcher=self._event_dispatcher, message_adapter=self._gateway_message_adapter, - device_manager=self.device_manager) + device_manager=self.device_manager, + stop_event=self._stop_event) self._gateway_attribute_updates_handler = GatewayAttributeUpdatesHandler(event_dispatcher=self._event_dispatcher, message_adapter=self._gateway_message_adapter, device_manager=self.device_manager) diff --git a/tb_mqtt_client/service/gateway/device_manager.py b/tb_mqtt_client/service/gateway/device_manager.py index 240747d..1a108eb 100644 --- a/tb_mqtt_client/service/gateway/device_manager.py +++ b/tb_mqtt_client/service/gateway/device_manager.py @@ -30,7 +30,6 @@ def __init__(self): self._ids_by_original_name: Dict[str, UUID] = {} self.__connected_devices: Set[DeviceSession] = set() - def register(self, device_name: str, device_profile: str = "default") -> DeviceSession: session = self.get_by_name(device_name) if session: diff --git a/tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py b/tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py index 04ffd88..92ba824 100644 --- a/tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py +++ 
b/tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py @@ -11,14 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import asyncio from typing import Awaitable, Callable, Optional +from tb_mqtt_client.common.async_utils import await_or_stop from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.constants import mqtt_topics from tb_mqtt_client.entities.gateway.event_type import GatewayEventType from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse from tb_mqtt_client.service.gateway.device_manager import DeviceManager +from tb_mqtt_client.service.gateway.device_session import DeviceSession from tb_mqtt_client.service.gateway.direct_event_dispatcher import DirectEventDispatcher from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter @@ -30,14 +34,18 @@ class GatewayRPCHandler: Handles incoming RPC request messages for a device connected through the gateway. """ - def __init__(self, event_dispatcher: DirectEventDispatcher, message_adapter: GatewayMessageAdapter, device_manager: DeviceManager): + def __init__(self, event_dispatcher: DirectEventDispatcher, + message_adapter: GatewayMessageAdapter, + device_manager: DeviceManager, + stop_event: asyncio.Event): self._event_dispatcher = event_dispatcher self._message_adapter = message_adapter self._device_manager = device_manager + self._stop_event = stop_event self._callback: Optional[Callable[[GatewayRPCRequest], Awaitable[GatewayRPCResponse]]] = None self._event_dispatcher.register(GatewayEventType.DEVICE_RPC_REQUEST, self.handle) - async def handle(self, topic: str, payload: bytes) -> Optional[GatewayRPCResponse]: + async def handle(self, topic: str, payload: bytes) -> None: """ Process the RPC request and return the response for it. 
:returns: GatewayRPCResponse or None if failed @@ -48,27 +56,43 @@ async def handle(self, topic: str, payload: bytes) -> Optional[GatewayRPCRespons return None rpc_request: Optional[GatewayRPCRequest] = None + result = None + device_session: Optional[DeviceSession] = None try: data = self._message_adapter.deserialize_to_dict(payload) rpc_request = self._message_adapter.parse_rpc_request(topic, data) device_session = self._device_manager.get_by_name(rpc_request.device_name) - if device_session: - rpc_request.set_device_session(device_session) - else: + if not device_session: logger.warning("No device session found for device: %s", rpc_request.device_name) return None logger.debug("Handling RPC method id: %i - %s with params: %s", rpc_request.request_id, rpc_request.method, rpc_request.params) - result = await self._event_dispatcher.dispatch(rpc_request) - if not isinstance(result, GatewayRPCResponse): - logger.error("RPC callback must return an instance of GatewayRPCResponse, got: %s", type(result)) + result = await self._event_dispatcher.dispatch(rpc_request, device_session=device_session) # noqa + if not result: return None + elif not isinstance(result, GatewayRPCResponse): + raise TypeError("RPC callback must return an instance of GatewayRPCResponse, got: %s", type(result)) logger.debug("RPC response for device %r method id: %i - %s with result: %s", rpc_request.device_name, result.request_id, rpc_request.method, result.result) - return result - except Exception as e: logger.exception("Failed to process RPC request: %s", e) if rpc_request is None: return None - GatewayRPCResponse.build(device_name=rpc_request.device_name, request_id=rpc_request.request_id, error=e) + result = GatewayRPCResponse.build(device_name=rpc_request.device_name, request_id=rpc_request.request_id, error=e) + + if not device_session: + logger.warning("No device session found for device: %s, cannot send RPC response", + rpc_request.device_name) + return None + + future = await self._event_dispatcher.dispatch(result) # noqa + + if not future: + logger.warning("No publish futures were returned from message queue for RPC response of device %s, request id %i", + rpc_request.device_name, rpc_request.request_id) + return None + try: + await await_or_stop(future, timeout=1, stop_event=self._stop_event) + except TimeoutError: + logger.warning("RPC response publish timed out for device %s, request id %i", + rpc_request.device_name, rpc_request.request_id) diff --git a/tb_mqtt_client/service/gateway/message_adapter.py b/tb_mqtt_client/service/gateway/message_adapter.py index 32a8ddf..19ee32a 100644 --- a/tb_mqtt_client/service/gateway/message_adapter.py +++ b/tb_mqtt_client/service/gateway/message_adapter.py @@ -24,9 +24,10 @@ from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.constants.mqtt_topics import GATEWAY_TELEMETRY_TOPIC, GATEWAY_ATTRIBUTES_TOPIC, \ - GATEWAY_CONNECT_TOPIC, GATEWAY_DISCONNECT_TOPIC, GATEWAY_ATTRIBUTES_REQUEST_TOPIC + GATEWAY_CONNECT_TOPIC, GATEWAY_DISCONNECT_TOPIC, GATEWAY_ATTRIBUTES_REQUEST_TOPIC, GATEWAY_RPC_TOPIC from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage from tb_mqtt_client.entities.gateway.device_connect_message import 
DeviceConnectMessage from tb_mqtt_client.entities.gateway.device_disconnect_message import DeviceDisconnectMessage @@ -79,6 +80,14 @@ def build_gateway_attribute_request_payload(self, attribute_request: GatewayAttr """ pass + @abstractmethod + def build_rpc_response_payload(self, rpc_response: GatewayRPCResponse) -> Tuple[str, bytes]: + """ + Build the payload for a gateway RPC response. + This method should be implemented to handle the specific format of the payload. + """ + pass + @abstractmethod def parse_attribute_update(self, data: Dict[str, Any]) -> GatewayAttributeUpdate: """ @@ -242,6 +251,20 @@ def build_gateway_attribute_request_payload(self, attribute_request: GatewayAttr logger.error("Failed to build gateway attribute request payload: %s", str(e)) raise ValueError("Invalid gateway attribute request format") from e + def build_rpc_response_payload(self, rpc_response: GatewayRPCResponse) -> Tuple[str, bytes]: + """ + Build the payload for a gateway RPC response. + This method serializes the GatewayRPCResponse to JSON format. + """ + try: + payload = dumps(rpc_response.to_payload_format()) + logger.trace("Built RPC response payload for device='%s', request_id=%i", + rpc_response.device_name, rpc_response.request_id) + return GATEWAY_RPC_TOPIC, payload + except Exception as e: + logger.error("Failed to build RPC response payload: %s", str(e)) + raise ValueError("Invalid RPC response format") from e + def parse_attribute_update(self, data: Dict[str, Any]) -> GatewayAttributeUpdate: try: device_name = data['device'] diff --git a/tb_mqtt_client/service/gateway/message_sender.py b/tb_mqtt_client/service/gateway/message_sender.py index acd6d9a..9bdf6b4 100644 --- a/tb_mqtt_client/service/gateway/message_sender.py +++ b/tb_mqtt_client/service/gateway/message_sender.py @@ -130,6 +130,25 @@ async def send_attributes_request(self, attribute_request: GatewayAttributeReque qos=qos ) + async def send_rpc_response(self, rpc_response: 'GatewayRPCResponse', qos=1) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: + """ + Sends an RPC response message to the platform. + + :param rpc_response: GatewayRPCResponse object containing the RPC response details. + :param qos: Quality of Service level for the MQTT message. + :returns: List of PublishResult or Future[PublishResult] if successful, None if failed. + """ + if self._message_queue is None: + logger.error("Cannot send RPC response. Message queue is not set, do you connected to the platform?") + return None + topic, payload = self._message_adapter.build_rpc_response_payload(rpc_response=rpc_response) + return await self._message_queue.publish( + topic=topic, + payload=payload, + datapoints_count=1, + qos=qos + ) + def set_message_queue(self, message_queue: MessageQueue): """ Sets the message queue for sending uplink messages. 
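Taken together, the changes in this patch complete the downlink RPC path for gateway sub-devices: the handler parses the incoming request, routes it to the device's registered callback, converts exceptions into error responses, and publishes the result on the gateway RPC topic with a bounded wait on the delivery future. A minimal user-side callback compatible with that flow might look like the sketch below; host, token, device and method names are placeholders, and the set_rpc_request_callback registration hook is taken from the handle_rpc_requests example added later in this series.

import asyncio

from tb_mqtt_client.common.config_loader import GatewayConfig
from tb_mqtt_client.common.logging_utils import configure_logging, get_logger
from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest
from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse
from tb_mqtt_client.service.gateway.client import GatewayClient
from tb_mqtt_client.service.gateway.device_session import DeviceSession

configure_logging()
logger = get_logger(__name__)


async def rpc_handler(session: DeviceSession, request: GatewayRPCRequest) -> GatewayRPCResponse:
    # Answer a known method; anything else is reported back as an error string.
    device_name = session.device_info.device_name
    if request.method == "echo":
        return GatewayRPCResponse.build(device_name=device_name,
                                        request_id=request.request_id,
                                        result={"echo": request.params})
    return GatewayRPCResponse.build(device_name=device_name,
                                    request_id=request.request_id,
                                    error=f"Unknown RPC method: {request.method}")


async def main():
    config = GatewayConfig()
    config.host = "localhost"                  # placeholder
    config.access_token = "YOUR_ACCESS_TOKEN"  # placeholder
    client = GatewayClient(config)
    await client.connect()

    session, _ = await client.connect_device("RPC Demo Device", "default", wait_for_publish=True)
    client.device_manager.set_rpc_request_callback(session.device_info.device_id, rpc_handler)

    # GatewayRPCHandler publishes the returned GatewayRPCResponse automatically.
    try:
        while True:
            await asyncio.sleep(1)
    except (KeyboardInterrupt, asyncio.CancelledError):
        logger.info("Shutting down.")

    await client.disconnect_device(session, wait_for_publish=True)
    await client.disconnect()


if __name__ == "__main__":
    asyncio.run(main())
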
From 983bec9da0ea498e155cbeaec00c1367b53c3a16 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Wed, 16 Jul 2025 12:22:00 +0300 Subject: [PATCH 50/74] Updated loggers for examples --- .../gateway/connect_and_disconnect_device.py | 71 +++++++++++++++++++ examples/gateway/handle_rpc_requests.py | 18 +---- examples/gateway/request_attributes.py | 2 +- examples/gateway/send_attributes.py | 3 +- examples/gateway/send_timeseries.py | 5 +- 5 files changed, 76 insertions(+), 23 deletions(-) create mode 100644 examples/gateway/connect_and_disconnect_device.py diff --git a/examples/gateway/connect_and_disconnect_device.py b/examples/gateway/connect_and_disconnect_device.py new file mode 100644 index 0000000..d5cfebf --- /dev/null +++ b/examples/gateway/connect_and_disconnect_device.py @@ -0,0 +1,71 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio + +from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.service.gateway.client import GatewayClient + +configure_logging() +logger = get_logger(__name__) + + +async def main(): + config = GatewayConfig() + config.host = "localhost" + config.access_token = "YOUR_ACCESS_TOKEN" + client = GatewayClient(config) + + await client.connect() + + # Connecting device + + device_name = "Test Device A1" + device_profile = "Test devices" + logger.info("Connecting device: %s", device_name) + device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) + + if not device_session: + logger.error("Failed to register device: %s", device_name) + return + logger.info("Device connected successfully: %s", device_name) + + # Disconnect device + logger.info("Disconnecting device: %s", device_name) + await client.disconnect_device(device_session, wait_for_publish=True) + logger.info("Device disconnected successfully: %s", device_name) + + # Disconnect client + await client.disconnect() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("Keyboard interrupt received, shutting down.") diff --git a/examples/gateway/handle_rpc_requests.py b/examples/gateway/handle_rpc_requests.py index 48fa2cc..53c4073 100644 --- a/examples/gateway/handle_rpc_requests.py +++ b/examples/gateway/handle_rpc_requests.py @@ -12,33 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- import asyncio from tb_mqtt_client.common.config_loader import GatewayConfig from tb_mqtt_client.common.logging_utils import configure_logging, get_logger -from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry -from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse from tb_mqtt_client.service.gateway.client import GatewayClient from tb_mqtt_client.service.gateway.device_session import DeviceSession configure_logging() -logger = get_logger("tb_mqtt_client") +logger = get_logger(__name__) async def device_rpc_request_handler(device_session: DeviceSession, rpc_request: GatewayRPCRequest) -> GatewayRPCResponse: diff --git a/examples/gateway/request_attributes.py b/examples/gateway/request_attributes.py index e06f88f..6bce87e 100644 --- a/examples/gateway/request_attributes.py +++ b/examples/gateway/request_attributes.py @@ -23,7 +23,7 @@ from tb_mqtt_client.service.gateway.device_session import DeviceSession configure_logging() -logger = get_logger("tb_mqtt_client") +logger = get_logger(__name__) async def requested_attributes_handler(device_session: DeviceSession, response: RequestedAttributeResponse): diff --git a/examples/gateway/send_attributes.py b/examples/gateway/send_attributes.py index 942fa25..3bd30d0 100644 --- a/examples/gateway/send_attributes.py +++ b/examples/gateway/send_attributes.py @@ -19,9 +19,8 @@ from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.service.gateway.client import GatewayClient - configure_logging() -logger = get_logger("tb_mqtt_client") +logger = get_logger(__name__) async def main(): diff --git a/examples/gateway/send_timeseries.py b/examples/gateway/send_timeseries.py index 463d320..d7443e9 100644 --- a/examples/gateway/send_timeseries.py +++ b/examples/gateway/send_timeseries.py @@ -15,14 +15,13 @@ import asyncio from time import time -from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.common.config_loader import GatewayConfig from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.service.gateway.client import GatewayClient - configure_logging() -logger = get_logger("tb_mqtt_client") +logger = get_logger(__name__) async def main(): From 0f6a649733b71358cc20afd6982f64104babb992 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Wed, 16 Jul 2025 12:33:14 +0300 Subject: [PATCH 51/74] Added processing for device disconnection and adjusted example --- .../gateway/connect_and_disconnect_device.py | 2 ++ .../gateway/device_disconnect_message.py | 7 +++- tb_mqtt_client/service/gateway/client.py | 35 ++++++++++++++++++- 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/examples/gateway/connect_and_disconnect_device.py b/examples/gateway/connect_and_disconnect_device.py index d5cfebf..8123b7f 100644 --- a/examples/gateway/connect_and_disconnect_device.py +++ b/examples/gateway/connect_and_disconnect_device.py @@ -60,6 +60,8 @@ async def main(): await client.disconnect_device(device_session, wait_for_publish=True) logger.info("Device disconnected successfully: %s", device_name) + await asyncio.sleep(.1) # Wait for the disconnect to complete + # Disconnect client await client.disconnect() diff --git 
a/tb_mqtt_client/entities/gateway/device_disconnect_message.py b/tb_mqtt_client/entities/gateway/device_disconnect_message.py index 2d9ecf1..2a81ab6 100644 --- a/tb_mqtt_client/entities/gateway/device_disconnect_message.py +++ b/tb_mqtt_client/entities/gateway/device_disconnect_message.py @@ -30,14 +30,18 @@ from dataclasses import dataclass from typing import Dict +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType + @dataclass(slots=True, frozen=True) -class DeviceDisconnectMessage: +class DeviceDisconnectMessage(BaseGatewayEvent): """ Represents a device disconnection message in the ThingsBoard Gateway MQTT client. This class is used to encapsulate the details of a device connection message. """ device_name: str + event_type: GatewayEventType = GatewayEventType.DEVICE_DISCONNECT def __new__(self, *args, **kwargs): raise TypeError( @@ -55,6 +59,7 @@ def build(cls, device_name: str) -> 'DeviceDisconnectMessage': raise ValueError("Device name must not be empty.") self = object.__new__(cls) object.__setattr__(self, 'device_name', device_name) + object.__setattr__(self, 'event_type', GatewayEventType.DEVICE_DISCONNECT) return self def to_payload_format(self) -> Dict[str, str]: diff --git a/tb_mqtt_client/service/gateway/client.py b/tb_mqtt_client/service/gateway/client.py index 17bc13f..930946f 100644 --- a/tb_mqtt_client/service/gateway/client.py +++ b/tb_mqtt_client/service/gateway/client.py @@ -26,6 +26,7 @@ from tb_mqtt_client.entities.data.attribute_request import AttributeRequest from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage +from tb_mqtt_client.entities.gateway.device_disconnect_message import DeviceDisconnectMessage from tb_mqtt_client.entities.gateway.event_type import GatewayEventType from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessageBuilder @@ -155,7 +156,39 @@ async def connect_device(self, return device_session, results[0] if len(results) == 1 else results async def disconnect_device(self, device_session: DeviceSession, wait_for_publish: bool): - pass + """ + Disconnect a device from the gateway. 
+ + :param device_session: The DeviceSession object for the device + :param wait_for_publish: Whether to wait for the publish result + :return: PublishResult or Future[PublishResult] if successful, None if failed + """ + logger.info("Disconnecting device %s", device_session.device_info.device_name) + if not device_session: + logger.warning("No device session provided for disconnecting") + return None + device_disconnect_message = DeviceDisconnectMessage.build(device_session.device_info.device_name) + + futures = await self._event_dispatcher.dispatch(device_disconnect_message, qos=self._config.qos) # noqa + + if not futures: + logger.warning("No publish futures were returned from message queue") + return device_session, [] + + if not wait_for_publish: + return device_session, futures + + results = [] + for fut in futures: + try: + result = await await_or_stop(fut, timeout=self.OPERATIONAL_TIMEOUT, stop_event=self._stop_event) + except TimeoutError: + logger.warning("Timeout while waiting for telemetry publish result") + result = PublishResult(mqtt_topics.GATEWAY_CONNECT_TOPIC, self._config.qos, -1, -1, -1) + results.append(result) + + self.device_manager.unregister(device_session.device_info.device_id) + return device_session, results[0] if len(results) == 1 else results async def send_device_timeseries(self, device_session: DeviceSession, From 35cfaa6b036967c9dd1c919589c839f06e91a5ad Mon Sep 17 00:00:00 2001 From: imbeacon Date: Wed, 16 Jul 2025 12:51:42 +0300 Subject: [PATCH 52/74] Added attribute updates handling and example --- .../DEPRECATEDconnect_disconnect_device.py | 30 ------- .../DEPRECATEDsubscribe_to_attributes.py | 57 ------------- .../gateway/connect_and_disconnect_device.py | 1 + examples/gateway/handle_attribute_updates.py | 84 +++++++++++++++++++ examples/gateway/handle_rpc_requests.py | 7 +- examples/gateway/request_attributes.py | 7 +- examples/gateway/send_attributes.py | 1 + examples/gateway/send_timeseries.py | 1 + .../entities/data/attribute_update.py | 5 +- .../gateway/gateway_attribute_update.py | 4 +- .../gateway_attribute_updates_handler.py | 4 +- 11 files changed, 103 insertions(+), 98 deletions(-) delete mode 100644 examples/gateway/DEPRECATEDconnect_disconnect_device.py delete mode 100644 examples/gateway/DEPRECATEDsubscribe_to_attributes.py create mode 100644 examples/gateway/handle_attribute_updates.py diff --git a/examples/gateway/DEPRECATEDconnect_disconnect_device.py b/examples/gateway/DEPRECATEDconnect_disconnect_device.py deleted file mode 100644 index c32068c..0000000 --- a/examples/gateway/DEPRECATEDconnect_disconnect_device.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -from tb_gateway_mqtt import TBGatewayMqttClient -logging.basicConfig(level=logging.INFO) - - -def main(): - gateway = TBGatewayMqttClient("127.0.0.1", username="TEST_GATEWAY_TOKEN") - gateway.connect() - gateway.gw_connect_device("Example Name") - # device disconnecting will not delete device, gateway just stops receiving messages - gateway.gw_disconnect_device("Example Name") - gateway.disconnect() - - -if __name__ == '__main__': - main() diff --git a/examples/gateway/DEPRECATEDsubscribe_to_attributes.py b/examples/gateway/DEPRECATEDsubscribe_to_attributes.py deleted file mode 100644 index 7b00576..0000000 --- a/examples/gateway/DEPRECATEDsubscribe_to_attributes.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging.handlers -import time - -from tb_gateway_mqtt import TBGatewayMqttClient -logging.basicConfig(level=logging.INFO) - - -def callback(result): - logging.info("Callback for attributes, %r", result) - - -def callback_for_everything(result): - logging.info("Everything goes here, %r", result) - - -def callback_for_specific_attr(result): - logging.info("Specific attribute callback, %r", result) - - -def main(): - gateway = TBGatewayMqttClient("127.0.0.1", username="TEST_GATEWAY_TOKEN") - gateway.connect() - # without device connection it is impossible to get any messages - gateway.gw_connect_device("ImageTest") - - gateway.gw_subscribe_to_all_attributes(callback_for_everything) - - gateway.gw_subscribe_to_attribute("ImageTest", "image", callback_for_specific_attr) - - sub_id = gateway.gw_subscribe_to_all_device_attributes("ImageTest", callback) - gateway.gw_unsubscribe(sub_id) - - try: - # Waiting for the callback - while not gateway.stopped: - time.sleep(1) - except KeyboardInterrupt: - gateway.disconnect() - gateway.stop() - - -if __name__ == '__main__': - main() diff --git a/examples/gateway/connect_and_disconnect_device.py b/examples/gateway/connect_and_disconnect_device.py index 8123b7f..2e1694b 100644 --- a/examples/gateway/connect_and_disconnect_device.py +++ b/examples/gateway/connect_and_disconnect_device.py @@ -53,6 +53,7 @@ async def main(): if not device_session: logger.error("Failed to register device: %s", device_name) return + logger.info("Device connected successfully: %s", device_name) # Disconnect device diff --git a/examples/gateway/handle_attribute_updates.py b/examples/gateway/handle_attribute_updates.py new file mode 100644 index 0000000..de34de3 --- /dev/null +++ b/examples/gateway/handle_attribute_updates.py @@ -0,0 +1,84 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.service.gateway.client import GatewayClient +from tb_mqtt_client.service.gateway.device_session import DeviceSession + +configure_logging() +logger = get_logger(__name__) + + +async def attribute_update_handler(device_session: DeviceSession, attribute_update: AttributeUpdate): + """ + Callback to handle attribute updates. + :param device_session: Device session for which attributes were requested. + :param attribute_update: Updated attributes . + """ + logger.info("Received attribute update for device %s: %s", + device_session.device_info.device_name, attribute_update.entries) + + +async def main(): + config = GatewayConfig() + config.host = "localhost" + config.access_token = "YOUR_ACCESS_TOKEN" + client = GatewayClient(config) + + await client.connect() + + # Connecting device + + device_name = "Test Device A1" + device_profile = "Test devices" + logger.info("Connecting device: %s", device_name) + device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) + + if not device_session: + logger.error("Failed to register device: %s", device_name) + return + + # Register callback for requested attributes + client.device_manager.set_attribute_update_callback(device_session.device_info.device_id, attribute_update_handler) + + logger.info("Device connected successfully: %s", device_name) + + # Loop to keep the client running and processing Attribute updates + try: + logger.info("Client loop started, waiting for Attribute updates...") + while True: + await asyncio.sleep(1) # Keep the loop running + except (asyncio.CancelledError, KeyboardInterrupt): + logger.info("Client loop stopped, shutting down.") + + # Disconnect device + logger.info("Disconnecting device: %s", device_name) + await client.disconnect_device(device_session, wait_for_publish=True) + logger.info("Device disconnected successfully: %s", device_name) + + await asyncio.sleep(.1) # Wait for disconnection to complete + + # Disconnect client + await client.disconnect() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("Keyboard interrupt received, shutting down.") diff --git a/examples/gateway/handle_rpc_requests.py b/examples/gateway/handle_rpc_requests.py index 53c4073..7938d7e 100644 --- a/examples/gateway/handle_rpc_requests.py +++ b/examples/gateway/handle_rpc_requests.py @@ -66,12 +66,13 @@ async def main(): logger.info("Connecting device: %s", device_name) device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) - # Register callback for requested attributes - client.device_manager.set_rpc_request_callback(device_session.device_info.device_id, device_rpc_request_handler) - if not device_session: logger.error("Failed to register device: %s", device_name) return + + # Register callback for requested 
attributes + client.device_manager.set_rpc_request_callback(device_session.device_info.device_id, device_rpc_request_handler) + logger.info("Device connected successfully: %s", device_name) # Loop to keep the client running and processing RPC requests diff --git a/examples/gateway/request_attributes.py b/examples/gateway/request_attributes.py index 6bce87e..ffccbc5 100644 --- a/examples/gateway/request_attributes.py +++ b/examples/gateway/request_attributes.py @@ -53,12 +53,13 @@ async def main(): logger.info("Connecting device: %s", device_name) device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) - # Register callback for requested attributes - client.device_manager.set_attribute_response_callback(device_session.device_info.device_id, requested_attributes_handler) - if not device_session: logger.error("Failed to register device: %s", device_name) return + + # Register callback for requested attributes + client.device_manager.set_attribute_response_callback(device_session.device_info.device_id, requested_attributes_handler) + logger.info("Device connected successfully: %s", device_name) # Send attributes to request them later diff --git a/examples/gateway/send_attributes.py b/examples/gateway/send_attributes.py index 3bd30d0..9f3399d 100644 --- a/examples/gateway/send_attributes.py +++ b/examples/gateway/send_attributes.py @@ -41,6 +41,7 @@ async def main(): if not device_session: logger.error("Failed to register device: %s", device_name) return + logger.info("Device connected successfully: %s", device_name) # Send attributes as raw dictionary diff --git a/examples/gateway/send_timeseries.py b/examples/gateway/send_timeseries.py index d7443e9..80124de 100644 --- a/examples/gateway/send_timeseries.py +++ b/examples/gateway/send_timeseries.py @@ -42,6 +42,7 @@ async def main(): if not device_session: logger.error("Failed to register device: %s", device_name) return + logger.info("Device connected successfully: %s", device_name) # Send time series as raw dictionary diff --git a/tb_mqtt_client/entities/data/attribute_update.py b/tb_mqtt_client/entities/data/attribute_update.py index 416d5f1..336be4a 100644 --- a/tb_mqtt_client/entities/data/attribute_update.py +++ b/tb_mqtt_client/entities/data/attribute_update.py @@ -16,11 +16,14 @@ from typing import Dict, Any, List from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType @dataclass(slots=True) -class AttributeUpdate: +class AttributeUpdate(BaseGatewayEvent): entries: List[AttributeEntry] + event_type: GatewayEventType = GatewayEventType.DEVICE_ATTRIBUTE_UPDATE def __repr__(self): return f"AttributeUpdate(entries={self.entries})" diff --git a/tb_mqtt_client/entities/gateway/gateway_attribute_update.py b/tb_mqtt_client/entities/gateway/gateway_attribute_update.py index 411370c..0f6e0b7 100644 --- a/tb_mqtt_client/entities/gateway/gateway_attribute_update.py +++ b/tb_mqtt_client/entities/gateway/gateway_attribute_update.py @@ -13,11 +13,11 @@ # limitations under the License. 
from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent from tb_mqtt_client.entities.gateway.event_type import GatewayEventType -from tb_mqtt_client.entities.gateway.gateway_event import GatewayEvent -class GatewayAttributeUpdate(GatewayEvent): +class GatewayAttributeUpdate(BaseGatewayEvent): """ Represents an attribute update event for a device connected to a gateway. This event is used to notify about changes in device shared attributes. diff --git a/tb_mqtt_client/service/gateway/handlers/gateway_attribute_updates_handler.py b/tb_mqtt_client/service/gateway/handlers/gateway_attribute_updates_handler.py index 46d8e2a..1f2ef78 100644 --- a/tb_mqtt_client/service/gateway/handlers/gateway_attribute_updates_handler.py +++ b/tb_mqtt_client/service/gateway/handlers/gateway_attribute_updates_handler.py @@ -27,7 +27,7 @@ def __init__(self, self.message_adapter = message_adapter self.device_manager = device_manager - def handle(self, topic: str, payload: bytes): + async def handle(self, topic: str, payload: bytes): """ Handles the gateway attribute update event by dispatching the attribute update """ @@ -36,4 +36,4 @@ def handle(self, topic: str, payload: bytes): device_session = self.device_manager.get_by_name(gateway_attribute_update.device_name) if device_session: gateway_attribute_update.set_device_session(device_session) - self.event_dispatcher.dispatch(gateway_attribute_update) + await self.event_dispatcher.dispatch(gateway_attribute_update.attribute_update, device_session=device_session) # noqa From 3fb01d5a71d8239b9f5f90a0c647bf333982bb18 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Thu, 17 Jul 2025 07:25:34 +0300 Subject: [PATCH 53/74] Added claiming request processing for devices connected through the gateway and example --- examples/gateway/claim_device.py | 97 ++++++++++++++++ examples/gateway/handle_attribute_updates.py | 2 +- .../entities/gateway/device_info.py | 17 +++ tb_mqtt_client/entities/gateway/event_type.py | 3 + .../entities/gateway/gateway_claim_request.py | 105 ++++++++++++++++++ tb_mqtt_client/service/gateway/client.py | 40 ++++++- .../service/gateway/device_session.py | 8 ++ .../gateway/gateway_client_interface.py | 7 ++ .../service/gateway/message_adapter.py | 25 ++++- .../service/gateway/message_sender.py | 23 +++- 10 files changed, 322 insertions(+), 5 deletions(-) create mode 100644 examples/gateway/claim_device.py create mode 100644 tb_mqtt_client/entities/gateway/gateway_claim_request.py diff --git a/examples/gateway/claim_device.py b/examples/gateway/claim_device.py new file mode 100644 index 0000000..afb9f16 --- /dev/null +++ b/examples/gateway/claim_device.py @@ -0,0 +1,97 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.claim_request import ClaimRequest +from tb_mqtt_client.entities.gateway.gateway_claim_request import GatewayClaimRequestBuilder +from tb_mqtt_client.service.gateway.client import GatewayClient + +configure_logging() +logger = get_logger(__name__) + +CLAIMING_SECRET_KEY = "YOUR_CLAIMING_SECRET" # Replace it with your actual claiming secret key +CLAIMING_DURATION_MS = 30000 # Duration in milliseconds for which the claim is valid + + +async def main(): + config = GatewayConfig() + config.host = "localhost" + config.access_token = "YOUR_ACCESS_TOKEN" + client = GatewayClient(config) + + await client.connect() + + # Connecting device + + device_name = "Device for claiming" + logger.info("Connecting device: %s", device_name) + device_session, publish_results = await client.connect_device(device_name, wait_for_publish=True) + + if not device_session: + logger.error("Failed to register device: %s", device_name) + return + + logger.info("Device connected successfully: %s", device_name) + + # Claiming request creation + logger.info("Creating claim request for device: %s", device_name) + + device_claim_request = ClaimRequest.build(CLAIMING_SECRET_KEY, CLAIMING_DURATION_MS) + + # Note: You can claim several devices at once by adding them to the gateway claim request + gateway_claim_request_builder = GatewayClaimRequestBuilder() + gateway_claim_request_builder.add_device_request(device_session, device_claim_request) + + # Send claim request to the platform + publish_results = await client.send_device_claim_request(device_session, + gateway_claim_request_builder.build(), + wait_for_publish=True) + if not publish_results: + logger.error("Failed to send claim request for device: %s", device_name) + return + + logger.info("Claim request sent successfully for device: %s, you have %r seconds to claim device using ThingsBoard UI or API.", + device_name, CLAIMING_DURATION_MS / 1000) + + # Disconnect device + logger.info("Disconnecting device: %s", device_name) + await client.disconnect_device(device_session, wait_for_publish=True) + logger.info("Device disconnected successfully: %s", device_name) + + # Disconnect client + await client.disconnect() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("Keyboard interrupt received, shutting down.") diff --git a/examples/gateway/handle_attribute_updates.py b/examples/gateway/handle_attribute_updates.py index de34de3..4a5d41d 100644 --- a/examples/gateway/handle_attribute_updates.py +++ b/examples/gateway/handle_attribute_updates.py @@ -28,7 +28,7 @@ async def attribute_update_handler(device_session: DeviceSession, attribute_upda """ Callback to handle attribute updates. :param device_session: Device session for which attributes were requested. - :param attribute_update: Updated attributes . + :param attribute_update: Updated attributes object that contains attribute entries. 
""" logger.info("Received attribute update for device %s: %s", device_session.device_info.device_name, attribute_update.entries) diff --git a/tb_mqtt_client/entities/gateway/device_info.py b/tb_mqtt_client/entities/gateway/device_info.py index 8a0b61c..d3b7f88 100644 --- a/tb_mqtt_client/entities/gateway/device_info.py +++ b/tb_mqtt_client/entities/gateway/device_info.py @@ -63,3 +63,20 @@ def __str__(self) -> str: f"device_name={self.device_name}, " f"device_profile={self.device_profile}, " f"original_name={self.original_name})") + + def __repr__(self) -> str: + return (f"DeviceInfo(device_id={self.device_id!r}, " + f"device_name={self.device_name!r}, " + f"device_profile={self.device_profile!r}, " + f"original_name={self.original_name!r})") + + def __eq__(self, other): + if not isinstance(other, DeviceInfo): + return NotImplemented + return (self.device_id == other.device_id and + self.device_name == other.device_name and + self.device_profile == other.device_profile and + self.original_name == other.original_name) + + def __hash__(self): + return hash((self.device_id, self.device_name, self.device_profile, self.original_name)) diff --git a/tb_mqtt_client/entities/gateway/event_type.py b/tb_mqtt_client/entities/gateway/event_type.py index 792d322..41fbd0b 100644 --- a/tb_mqtt_client/entities/gateway/event_type.py +++ b/tb_mqtt_client/entities/gateway/event_type.py @@ -34,6 +34,9 @@ class GatewayEventType(Enum): DEVICE_SESSION_STATE_CHANGE = "gateway.device.session.state.change" DEVICE_RPC_REQUEST = "gateway.device.rpc.request" DEVICE_RPC_RESPONSE = "gateway.device.rpc.response" + + GATEWAY_CLAIM_REQUEST = "gateway.gateway.claim" + RPC_REQUEST_RECEIVE = "device.rpc.request" RPC_RESPONSE_SEND = "device.rpc.response" diff --git a/tb_mqtt_client/entities/gateway/gateway_claim_request.py b/tb_mqtt_client/entities/gateway/gateway_claim_request.py new file mode 100644 index 0000000..1646cfa --- /dev/null +++ b/tb_mqtt_client/entities/gateway/gateway_claim_request.py @@ -0,0 +1,105 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass +from typing import Optional, Dict, Any, Union + +from tb_mqtt_client.entities.data.claim_request import ClaimRequest +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.service.gateway.device_session import DeviceSession + + +@dataclass(slots=True, frozen=True) +class GatewayClaimRequest(BaseGatewayEvent): + + devices_requests: Dict[Union[DeviceSession, str], ClaimRequest] = None + event_type: GatewayEventType = GatewayEventType.GATEWAY_CLAIM_REQUEST + + def __new__(cls, *args, **kwargs): + raise TypeError("Direct instantiation of GatewayClaimRequest is not allowed. Use 'GatewayClaimRequestBuilder.build(...)'.") + + def __repr__(self) -> str: + return f"GatewayClaimRequest(devices_requests={self.devices_requests})" + + def add_device_request(self, device_name_or_session: Union[DeviceSession, str], claim_request: ClaimRequest): + """ + Add a device claim request to the GatewayClaimRequest. + """ + self.devices_requests[device_name_or_session] = claim_request + + def to_payload_format(self) -> Dict[str, Any]: + """ + Convert the claim request to the expected MQTT JSON payload format. + """ + payload = {} + for device_session, claim_request in self.devices_requests.items(): + device_name = device_session + if isinstance(device_session, DeviceSession): + device_name = device_session.device_info.device_name + payload[device_name] = claim_request.to_payload_format() + return payload + + @classmethod + def build(cls) -> 'GatewayClaimRequest': + """ + Build a new GatewayClaimRequest instance. + """ + self = object.__new__(cls) + object.__setattr__(self, 'devices_requests', {}) + object.__setattr__(self, 'event_type', GatewayEventType.GATEWAY_CLAIM_REQUEST) + return self + + +class GatewayClaimRequestBuilder: + """ + Builder class for GatewayClaimRequest. + Allows adding multiple device claim requests in a fluent interface style. + """ + def __init__(self): + self._devices_requests: Dict[Union[DeviceSession, str], ClaimRequest] = {} + + def add_device_request(self, device_name_or_session: Union[DeviceSession, str], device_claim_request: ClaimRequest) -> 'GatewayClaimRequestBuilder': + """ + Add a device claim request to the builder. + """ + if not isinstance(device_name_or_session, (DeviceSession, str)): + raise ValueError("device_session must be an instance of DeviceSession or a string representing the device name") + if not isinstance(device_claim_request, ClaimRequest): + raise ValueError("device_claim_request must be an instance of ClaimRequest") + self._devices_requests[device_name_or_session] = device_claim_request + return self + + def build(self) -> GatewayClaimRequest: + """ + Build the GatewayClaimRequest with all added device requests. 
+ """ + gateway_claim_request = GatewayClaimRequest.build() + for device_session, claim_request in self._devices_requests.items(): + gateway_claim_request.add_device_request(device_session, claim_request) + return gateway_claim_request \ No newline at end of file diff --git a/tb_mqtt_client/service/gateway/client.py b/tb_mqtt_client/service/gateway/client.py index 930946f..9af53ff 100644 --- a/tb_mqtt_client/service/gateway/client.py +++ b/tb_mqtt_client/service/gateway/client.py @@ -29,7 +29,7 @@ from tb_mqtt_client.entities.gateway.device_disconnect_message import DeviceDisconnectMessage from tb_mqtt_client.entities.gateway.event_type import GatewayEventType from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest -from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessageBuilder +from tb_mqtt_client.entities.gateway.gateway_claim_request import GatewayClaimRequest from tb_mqtt_client.service.device.client import DeviceClient from tb_mqtt_client.service.gateway.device_manager import DeviceManager from tb_mqtt_client.service.gateway.device_session import DeviceSession @@ -71,6 +71,7 @@ def __init__(self, config: Optional[Union[GatewayConfig, Dict]] = None): self._event_dispatcher.register(GatewayEventType.DEVICE_UPLINK, self._uplink_message_sender.send_uplink_message) self._event_dispatcher.register(GatewayEventType.DEVICE_ATTRIBUTE_REQUEST, self._uplink_message_sender.send_attributes_request) self._event_dispatcher.register(GatewayEventType.DEVICE_RPC_RESPONSE, self._uplink_message_sender.send_rpc_response) + self._event_dispatcher.register(GatewayEventType.GATEWAY_CLAIM_REQUEST, self._uplink_message_sender.send_claim_request) self._gateway_message_adapter: GatewayMessageAdapter = JsonGatewayMessageAdapter() self._uplink_message_sender.set_message_adapter(self._gateway_message_adapter) @@ -114,7 +115,7 @@ async def connect(self): async def connect_device(self, device_name_or_device_connect_message: Union[str, DeviceConnectMessage], - device_profile: str, + device_profile: str = 'default', wait_for_publish=False) -> Tuple[DeviceSession, List[Union[PublishResult, Future[PublishResult]]]]: """ Connect a device to the gateway. @@ -297,6 +298,41 @@ async def send_device_attributes_request(self, device_session: DeviceSession, at return results[0] if len(results) == 1 else results + async def send_device_claim_request(self, + device_session: DeviceSession, + gateway_claim_request: GatewayClaimRequest, + wait_for_publish: bool) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: + """ + Send a claim request for a device to the platform. 
+ :param device_session: The DeviceSession object for the device + :param gateway_claim_request: Claim request data + :param wait_for_publish: Whether to wait for the publish result + """ + logger.trace("Sending claim request for device %s", device_session.device_info.device_name) + if not device_session or not gateway_claim_request: + logger.warning("No device session or claim request provided for sending claim request") + return None + + futures = await self._event_dispatcher.dispatch(gateway_claim_request, qos=self._config.qos) # noqa + + if not futures: + logger.warning("No publish futures were returned from message queue") + return None + + if not wait_for_publish: + return futures[0] if len(futures) == 1 else futures + + results = [] + for fut in futures: + try: + result = await await_or_stop(fut, timeout=self.OPERATIONAL_TIMEOUT, stop_event=self._stop_event) + except TimeoutError: + logger.warning("Timeout while waiting for claim request publish result") + result = PublishResult(mqtt_topics.GATEWAY_CLAIM_TOPIC, self._config.qos, -1, -1, -1) + results.append(result) + + return results[0] if len(results) == 1 else results + async def disconnect(self): """ Disconnect from the platform. diff --git a/tb_mqtt_client/service/gateway/device_session.py b/tb_mqtt_client/service/gateway/device_session.py index 93711c0..366efc3 100644 --- a/tb_mqtt_client/service/gateway/device_session.py +++ b/tb_mqtt_client/service/gateway/device_session.py @@ -80,3 +80,11 @@ async def handle_event_to_device(self, event: BaseGatewayEvent) -> Optional[Awai return await cb(self, event) else: return cb(self, event) + + def __eq__(self, other): + if not isinstance(other, DeviceSession): + return NotImplemented + return self.device_info.device_id == other.device_info.device_id + + def __hash__(self): + return hash(self.device_info) diff --git a/tb_mqtt_client/service/gateway/gateway_client_interface.py b/tb_mqtt_client/service/gateway/gateway_client_interface.py index a891136..ccb8be2 100644 --- a/tb_mqtt_client/service/gateway/gateway_client_interface.py +++ b/tb_mqtt_client/service/gateway/gateway_client_interface.py @@ -21,6 +21,7 @@ from tb_mqtt_client.entities.data.attribute_request import AttributeRequest from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest +from tb_mqtt_client.entities.gateway.gateway_claim_request import GatewayClaimRequest from tb_mqtt_client.service.base_client import BaseClient from tb_mqtt_client.service.gateway.device_session import DeviceSession @@ -52,3 +53,9 @@ async def send_device_attributes_request(self, device_session: DeviceSession, attributes: Union[AttributeRequest, GatewayAttributeRequest], wait_for_publish: bool) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: ... + + @abstractmethod + async def send_device_claim_request(self, + device_session: DeviceSession, + gateway_claim_request: GatewayClaimRequest, + wait_for_publish: bool) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: ... 
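On the wire the claim request is keyed by device name: GatewayClaimRequest.to_payload_format() resolves each DeviceSession (or plain string) to its device name and delegates the value to ClaimRequest.to_payload_format(), which the message adapter change that follows serializes onto the gateway claim topic. A small sketch of the resulting structure, assuming ClaimRequest serializes to the usual ThingsBoard secretKey/durationMs fields (device names and secrets below are illustrative):

from tb_mqtt_client.entities.data.claim_request import ClaimRequest
from tb_mqtt_client.entities.gateway.gateway_claim_request import GatewayClaimRequestBuilder

builder = GatewayClaimRequestBuilder()
builder.add_device_request("Device A", ClaimRequest.build("SECRET_A", 30000))  # a DeviceSession works too
builder.add_device_request("Device B", ClaimRequest.build("SECRET_B", 60000))
claim_request = builder.build()

# to_payload_format() flattens sessions to device names, producing (assumed field names):
# {"Device A": {"secretKey": "SECRET_A", "durationMs": 30000},
#  "Device B": {"secretKey": "SECRET_B", "durationMs": 60000}}
print(claim_request.to_payload_format())
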
diff --git a/tb_mqtt_client/service/gateway/message_adapter.py b/tb_mqtt_client/service/gateway/message_adapter.py index 19ee32a..8336fdc 100644 --- a/tb_mqtt_client/service/gateway/message_adapter.py +++ b/tb_mqtt_client/service/gateway/message_adapter.py @@ -24,9 +24,11 @@ from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.constants.mqtt_topics import GATEWAY_TELEMETRY_TOPIC, GATEWAY_ATTRIBUTES_TOPIC, \ - GATEWAY_CONNECT_TOPIC, GATEWAY_DISCONNECT_TOPIC, GATEWAY_ATTRIBUTES_REQUEST_TOPIC, GATEWAY_RPC_TOPIC + GATEWAY_CONNECT_TOPIC, GATEWAY_DISCONNECT_TOPIC, GATEWAY_ATTRIBUTES_REQUEST_TOPIC, GATEWAY_RPC_TOPIC, \ + GATEWAY_CLAIM_TOPIC from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.gateway.gateway_claim_request import GatewayClaimRequest from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage from tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage @@ -88,6 +90,14 @@ def build_rpc_response_payload(self, rpc_response: GatewayRPCResponse) -> Tuple[ """ pass + @abstractmethod + def build_claim_request_payload(self, claim_request: GatewayClaimRequest) -> Tuple[str, bytes]: + """ + Build the payload for a gateway claim request. + This method should be implemented to handle the specific format of the payload. + """ + pass + @abstractmethod def parse_attribute_update(self, data: Dict[str, Any]) -> GatewayAttributeUpdate: """ @@ -265,6 +275,19 @@ def build_rpc_response_payload(self, rpc_response: GatewayRPCResponse) -> Tuple[ logger.error("Failed to build RPC response payload: %s", str(e)) raise ValueError("Invalid RPC response format") from e + def build_claim_request_payload(self, claim_request: GatewayClaimRequest) -> Tuple[str, bytes]: + """ + Build the payload for a gateway claim request. + This method serializes the GatewayClaimRequest to JSON format. 
+ """ + try: + payload = dumps(claim_request.to_payload_format()) + logger.trace("Built claim request payload for devices: %s", list(claim_request.devices_requests.keys())) + return GATEWAY_CLAIM_TOPIC, payload + except Exception as e: + logger.error("Failed to build claim request payload: %s", str(e)) + raise ValueError("Invalid claim request format") from e + def parse_attribute_update(self, data: Dict[str, Any]) -> GatewayAttributeUpdate: try: device_name = data['device'] diff --git a/tb_mqtt_client/service/gateway/message_sender.py b/tb_mqtt_client/service/gateway/message_sender.py index 9bdf6b4..41689e7 100644 --- a/tb_mqtt_client/service/gateway/message_sender.py +++ b/tb_mqtt_client/service/gateway/message_sender.py @@ -21,6 +21,8 @@ from tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage from tb_mqtt_client.entities.gateway.device_disconnect_message import DeviceDisconnectMessage from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest +from tb_mqtt_client.entities.gateway.gateway_claim_request import GatewayClaimRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter from tb_mqtt_client.service.message_queue import MessageQueue @@ -130,7 +132,7 @@ async def send_attributes_request(self, attribute_request: GatewayAttributeReque qos=qos ) - async def send_rpc_response(self, rpc_response: 'GatewayRPCResponse', qos=1) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: + async def send_rpc_response(self, rpc_response: GatewayRPCResponse, qos=1) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: """ Sends an RPC response message to the platform. @@ -149,6 +151,25 @@ async def send_rpc_response(self, rpc_response: 'GatewayRPCResponse', qos=1) -> qos=qos ) + async def send_claim_request(self, claim_request: GatewayClaimRequest, qos=1) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: + """ + Sends a claim request message to the platform. + + :param claim_request: GatewayClaimRequest object containing the claim request details. + :param qos: Quality of Service level for the MQTT message. + :returns: List of PublishResult or Future[PublishResult] if successful, None if failed. + """ + if self._message_queue is None: + logger.error("Cannot send claim request. Message queue is not set, do you connected to the platform?") + return None + topic, payload = self._message_adapter.build_claim_request_payload(claim_request=claim_request) + return await self._message_queue.publish( + topic=topic, + payload=payload, + datapoints_count=1, + qos=qos + ) + def set_message_queue(self, message_queue: MessageQueue): """ Sets the message queue for sending uplink messages. 
From 17a35e20ba3fa75d3e12dbeaccef70c389784836 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Thu, 17 Jul 2025 09:14:08 +0300 Subject: [PATCH 54/74] Fix for timeseries without ts --- tb_mqtt_client/service/device/message_adapter.py | 15 ++++++++++----- tb_mqtt_client/service/message_queue.py | 5 +++-- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/tb_mqtt_client/service/device/message_adapter.py b/tb_mqtt_client/service/device/message_adapter.py index cc29e68..952d55b 100644 --- a/tb_mqtt_client/service/device/message_adapter.py +++ b/tb_mqtt_client/service/device/message_adapter.py @@ -411,10 +411,15 @@ def pack_attributes(msg: DeviceUplinkMessage) -> Dict[str, Any]: @staticmethod def pack_timeseries(msg: 'DeviceUplinkMessage') -> List[Dict[str, Any]]: now_ts = int(datetime.now(UTC).timestamp() * 1000) - - packed = [ - {"ts": entry.ts or now_ts, "values": {entry.key: entry.value}} - for entry in chain.from_iterable(msg.timeseries.values()) - ] + if all(entry.ts is None for entry in chain.from_iterable(msg.timeseries.values())): + packed = { + "ts": now_ts, + "values": {entry.key: entry.value for entry in chain.from_iterable(msg.timeseries.values())} + } + else: + packed = [ + {"ts": entry.ts or now_ts, "values": {entry.key: entry.value}} + for entry in chain.from_iterable(msg.timeseries.values()) + ] return packed diff --git a/tb_mqtt_client/service/message_queue.py b/tb_mqtt_client/service/message_queue.py index 6a2f55f..c9be904 100644 --- a/tb_mqtt_client/service/message_queue.py +++ b/tb_mqtt_client/service/message_queue.py @@ -51,8 +51,6 @@ def __init__(self, self._telemetry_rate_limit = telemetry_rate_limit self._telemetry_dp_rate_limit = telemetry_dp_rate_limit self._backpressure = self._mqtt_manager.backpressure - self._pending_ack_futures: Dict[int, asyncio.Future[PublishResult]] = {} - self._pending_ack_callbacks: Dict[int, Callable[[bool], None]] = {} # Queue expects tuples of (topic, payload, delivery_futures, datapoints_count, qos) self._queue: asyncio.Queue[Tuple[str, Union[bytes, DeviceUplinkMessage, GatewayUplinkMessage], List[asyncio.Future[PublishResult]], int, int]] = asyncio.Queue(maxsize=max_queue_size) self._pending_queue_tasks: set[asyncio.Task] = set() @@ -254,6 +252,9 @@ async def _try_publish(self, def resolve_attached(publish_future: asyncio.Future): try: publish_result = publish_future.result() + except asyncio.CancelledError: + logger.trace("Publish future was cancelled: %r, id: %r", publish_future, id(publish_future)) + publish_result = PublishResult(topic, qos, -1, len(payload), -1) except Exception as exc: logger.warning("Publish failed with exception: %s", exc) logger.debug("Resolving delivery futures with failure:", exc_info=exc) From 96d6309f4c3410807ebf41e3a74669f899fdf70d Mon Sep 17 00:00:00 2001 From: imbeacon Date: Tue, 22 Jul 2025 14:49:01 +0300 Subject: [PATCH 55/74] Added shared future for splitted and grouped messages for confirmation --- examples/device/handle_attribute_updates.py | 2 +- examples/device/handle_rpc_requests.py | 4 +- examples/device/load.py | 176 ++--- examples/device/request_attributes.py | 14 +- examples/device/send_client_side_rpc.py | 7 +- tb_mqtt_client/common/async_utils.py | 88 ++- tb_mqtt_client/common/gmqtt_patch.py | 604 +++++++++++------- tb_mqtt_client/common/mqtt_message.py | 75 +++ tb_mqtt_client/common/publish_result.py | 38 +- tb_mqtt_client/entities/data/data_entry.py | 9 +- .../entities/data/device_uplink_message.py | 11 +- tb_mqtt_client/entities/data/rpc_request.py | 3 + 
.../gateway/gateway_uplink_message.py | 5 +- tb_mqtt_client/service/device/client.py | 122 ++-- .../device/handlers/rpc_response_handler.py | 2 + .../service/device/message_adapter.py | 173 +++-- .../service/gateway/message_adapter.py | 7 +- tb_mqtt_client/service/message_queue.py | 317 ++++----- tb_mqtt_client/service/message_splitter.py | 114 ++-- tb_mqtt_client/service/mqtt_manager.py | 147 +++-- 20 files changed, 1173 insertions(+), 745 deletions(-) create mode 100644 tb_mqtt_client/common/mqtt_message.py diff --git a/examples/device/handle_attribute_updates.py b/examples/device/handle_attribute_updates.py index 649a2e4..b1ced19 100644 --- a/examples/device/handle_attribute_updates.py +++ b/examples/device/handle_attribute_updates.py @@ -30,7 +30,7 @@ async def attribute_update_callback(update: AttributeUpdate): - logger.info("Received attribute update:", update) + logger.info("Received attribute update: %r", update) async def main(): diff --git a/examples/device/handle_rpc_requests.py b/examples/device/handle_rpc_requests.py index eb82463..584feb8 100644 --- a/examples/device/handle_rpc_requests.py +++ b/examples/device/handle_rpc_requests.py @@ -31,7 +31,7 @@ async def rpc_request_callback(request: RPCRequest) -> RPCResponse: - logger.info("Received RPC:", request) + logger.info("Received RPC: %r", request) if request.method == "ping": return RPCResponse.build(request_id=request.request_id, result={"pong": True}) @@ -51,7 +51,7 @@ async def main(): try: while True: await asyncio.sleep(1) - except KeyboardInterrupt: + except (KeyboardInterrupt, asyncio.CancelledError): logger.info("Shutting down...") await client.stop() diff --git a/examples/device/load.py b/examples/device/load.py index 2bd87d4..2c739f1 100644 --- a/examples/device/load.py +++ b/examples/device/load.py @@ -1,27 +1,12 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Example script to send a high load of telemetry data to ThingsBoard using the DeviceClient - import asyncio import logging import signal import time -from datetime import datetime, UTC +from datetime import UTC, datetime from random import randint from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate from tb_mqtt_client.entities.data.rpc_request import RPCRequest @@ -37,7 +22,6 @@ # --- Constants --- BATCH_SIZE = 1000 -YIELD_DELAY = 0.001 MAX_PENDING_BATCHES = 100 FUTURE_TIMEOUT = 1.0 @@ -51,12 +35,77 @@ async def rpc_request_callback(request: RPCRequest): return RPCResponse(request_id=request.request_id, result={"status": "ok"}) +async def collect_futures_with_counts(future_or_result): + """ + Accepts either a Future or a direct PublishResult. + Returns [(Future or result, datapoint_count)] if successful, else []. 
+ """ + if future_or_result is None: + return [] + + if isinstance(future_or_result, PublishResult): + if future_or_result.is_successful(): + return [(future_or_result, future_or_result.datapoints_count)] + return [] + + if isinstance(future_or_result, list): + results = [] + for item in future_or_result: + if isinstance(item, PublishResult): + if item.is_successful(): + results.append((item, item.datapoints_count)) + elif asyncio.isfuture(item) or asyncio.iscoroutine(item): + results.extend(await collect_futures_with_counts(item)) + else: + logger.warning("Unexpected item in list: %r", item) + return results + + if asyncio.isfuture(future_or_result) or asyncio.iscoroutine(future_or_result): + try: + result = await asyncio.wait_for(asyncio.shield(future_or_result), timeout=FUTURE_TIMEOUT) + if isinstance(result, PublishResult): + return [(future_or_result, result.datapoints_count)] + else: + logger.warning("Unexpected publish result: %r", result) + except asyncio.TimeoutError: + logger.warning("Timeout waiting for delivery future.") + except Exception as e: + logger.warning("Failed to get result from future: %s", e) + return [] + + logger.warning("collect_futures_with_counts() got unexpected type: %r", type(future_or_result)) + return [] + + +async def process_pending_futures(pending_futures, delivered_batches, delivered_datapoints): + try: + all_futures = [fut for fut, _ in pending_futures] + done, _ = await asyncio.wait(all_futures, timeout=FUTURE_TIMEOUT, return_when=asyncio.ALL_COMPLETED) + except Exception as e: + logger.warning("Wait exception: %s", e) + return delivered_batches, delivered_datapoints + + for fut, count in pending_futures: + if fut in done and fut.done() and not fut.cancelled(): + try: + result = fut.result() + if isinstance(result, PublishResult) and result.is_successful(): + delivered_batches += 1 + delivered_datapoints += result.datapoints_count + except Exception as e: + logger.warning("Future error: %s", e) + elif not fut.done(): + fut.cancel() + logger.warning("Cancelled future %r due to timeout.", getattr(fut, "uuid", None)) + + return delivered_batches, delivered_datapoints + + async def main(): stop_event = asyncio.Event() def _shutdown_handler(): stop_event.set() - asyncio.get_event_loop().run_until_complete(client.stop()) loop = asyncio.get_running_loop() for sig in (signal.SIGINT, signal.SIGTERM): @@ -74,6 +123,7 @@ def _shutdown_handler(): client.set_rpc_request_callback(rpc_request_callback) await client.connect() + client._message_adapter.datapoints_max_count = 0 logger.info("Connected to ThingsBoard.") sent_batches = 0 @@ -81,89 +131,52 @@ def _shutdown_handler(): delivered_datapoints = 0 pending_futures = [] - delivery_start_ts = None # Start time of the first successful delivery - delivery_end_ts = None # End time of last successful delivery + delivery_start_ts = None + delivery_end_ts = None try: + delivery_start_ts = time.perf_counter() + ts_now = int(datetime.now(UTC).timestamp() * 1000) + entries = [TimeseriesEntry(f"temperature{i}", randint(20, 40)) for i in range(BATCH_SIZE)] while not stop_event.is_set(): - ts_now = int(datetime.now(UTC).timestamp() * 1000) - entries = [ - TimeseriesEntry("temperature", randint(20, 40), ts=ts_now - i) - for i in range(BATCH_SIZE) - ] try: - future = await client.send_timeseries(entries) - if future: - pending_futures.append((future, BATCH_SIZE)) - sent_batches += 1 + fut = await client.send_timeseries(entries, wait_for_publish=False) + future_pairs = await collect_futures_with_counts(fut) + if future_pairs: + 
pending_futures.extend(future_pairs) else: - logger.warning("Telemetry batch dropped or not acknowledged.") + logger.warning("No child delivery futures detected (batch dropped?).") + + sent_batches += 1 + except Exception as e: logger.warning("Failed to publish telemetry batch: %s", e) if len(pending_futures) >= MAX_PENDING_BATCHES: - done, _ = await asyncio.wait( - [f for f, _ in pending_futures], timeout=FUTURE_TIMEOUT + delivered_batches, delivered_datapoints = await process_pending_futures( + pending_futures, delivered_batches, delivered_datapoints ) - - remaining = [] - for fut, batch_size in pending_futures: - if fut in done: - try: - result = fut.result() - if result is True: - delivered_batches += 1 - delivered_datapoints += batch_size - now = time.perf_counter() - delivery_start_ts = delivery_start_ts or now - delivery_end_ts = now - except asyncio.CancelledError: - logger.exception("Future was cancelled: %r, id: %r", fut, id(fut)) - logger.warning("Delivery future was cancelled: %r", fut) - except Exception as e: - logger.warning("Delivery future raised: %s", e) - else: - fut.cancel() - logger.warning("Cancelled delivery future after timeout: %r, future id: %r", fut, id(fut)) - # remaining.append((fut, batch_size)) - - pending_futures = [] + pending_futures.clear() if sent_batches % 10 == 0: logger.info("Sent %d batches so far...", sent_batches) - await asyncio.sleep(YIELD_DELAY) + await asyncio.sleep(0) # yield control efficiently finally: logger.info("Waiting for remaining telemetry batches to be acknowledged...") - done, _ = await asyncio.wait( - [f for f, _ in pending_futures], timeout=.01 - ) - - for fut, batch_size in pending_futures: - if fut in done: - try: - result = fut.result() - if result is True: - delivered_batches += 1 - delivered_datapoints += batch_size - now = time.perf_counter() - delivery_start_ts = delivery_start_ts or now - delivery_end_ts = now - except asyncio.CancelledError: - # logger.warning("Final delivery future was cancelled: %r", fut) - pass - except Exception as e: - logger.warning("Final delivery failed: %s", e) - else: - fut.cancel() - # logger.warning("Final delivery future timed out and was cancelled: %r", fut) + + if pending_futures: + delivered_batches, delivered_datapoints = await process_pending_futures( + pending_futures, delivered_batches, delivered_datapoints + ) + delivery_end_ts = time.perf_counter() await client.disconnect() logger.info("Disconnected cleanly.") - if delivery_start_ts is not None and delivery_end_ts is not None: + if delivery_start_ts and delivery_end_ts: delivery_duration = delivery_end_ts - delivery_start_ts logger.info("Delivered %d batches / %d datapoints in %.6f seconds (%.0f datapoints/sec)", delivered_batches, delivered_datapoints, delivery_duration, @@ -171,8 +184,9 @@ def _shutdown_handler(): else: logger.warning("No successful delivery occurred.") + if __name__ == "__main__": try: asyncio.run(main()) - except KeyboardInterrupt: + except (KeyboardInterrupt, asyncio.CancelledError): print("Interrupted by user.") diff --git a/examples/device/request_attributes.py b/examples/device/request_attributes.py index 88a13ae..32f32ee 100644 --- a/examples/device/request_attributes.py +++ b/examples/device/request_attributes.py @@ -29,9 +29,11 @@ logger.setLevel(logging.INFO) logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) +response_received = asyncio.Event() async def attribute_request_callback(response: RequestedAttributeResponse): - logger.info("Received attribute response:", response) + 
logger.info("Received attribute response: %r", response) + response_received.set() async def main(): config = DeviceConfig() @@ -41,12 +43,20 @@ async def main(): client = DeviceClient(config) await client.connect() + # Send client attribute to have it available for request + await client.send_attributes({"currentTemperature": 22.5}) + + await asyncio.sleep(.1) + # Request specific attributes request = await AttributeRequest.build(["targetTemperature"], ["currentTemperature"]) await client.send_attribute_request(request, attribute_request_callback) logger.info("Attribute request sent. Waiting for response...") - await asyncio.sleep(5) + try: + await asyncio.wait_for(response_received.wait(), timeout=10) + except (asyncio.CancelledError, TimeoutError): + logger.info("Attribute request cancelled.") await client.stop() diff --git a/examples/device/send_client_side_rpc.py b/examples/device/send_client_side_rpc.py index aa297a8..e7d1b05 100644 --- a/examples/device/send_client_side_rpc.py +++ b/examples/device/send_client_side_rpc.py @@ -46,7 +46,7 @@ async def main(): rpc_request = await RPCRequest.build("getTime", {}) try: response = await client.send_rpc_request(rpc_request) - logger.info("Received response:", response) + logger.info("Received response: %r", response) except TimeoutError: logger.info("RPC request timed out") @@ -58,4 +58,7 @@ async def main(): await client.stop() if __name__ == "__main__": - asyncio.run(main()) + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("Shutting down...") diff --git a/tb_mqtt_client/common/async_utils.py b/tb_mqtt_client/common/async_utils.py index 49962ba..5b35635 100644 --- a/tb_mqtt_client/common/async_utils.py +++ b/tb_mqtt_client/common/async_utils.py @@ -12,24 +12,63 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Union, Optional, Any +from typing import Union, Optional, Any, List, Set, Dict import asyncio +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.publish_result import PublishResult + +logger = get_logger(__name__) + + +class FutureMap: + def __init__(self): + self._child_to_parents: Dict[asyncio.Future, Set[asyncio.Future]] = {} + self._parent_to_remaining: Dict[asyncio.Future, Set[asyncio.Future]] = {} + + def register(self, parent: asyncio.Future, children: List[asyncio.Future]): + if parent in self._parent_to_remaining: + self._parent_to_remaining[parent].update(children) + else: + self._parent_to_remaining[parent] = set(children) + + for child in children: + self._child_to_parents.setdefault(child, set()).add(parent) + + def get_parents(self, child: asyncio.Future) -> List[asyncio.Future]: + return list(self._child_to_parents.get(child, [])) + + def child_resolved(self, child: asyncio.Future): + parents = self._child_to_parents.pop(child, set()) + for parent in parents: + remaining = self._parent_to_remaining.get(parent) + if remaining is not None: + remaining.discard(child) + if not remaining and not parent.done(): + all_children = list(remaining) + [child] + results = [] + for f in all_children: + if f.done() and not f.cancelled(): + result = f.result() + if isinstance(result, PublishResult): + results.append(result) + + if results: + parent.set_result(PublishResult.merge(results)) + else: + parent.set_result(None) + self._parent_to_remaining.pop(parent, None) + +future_map = FutureMap() async def await_or_stop(future_or_coroutine: Union[asyncio.Future, asyncio.Task, Any], stop_event: asyncio.Event, timeout: Optional[float]) -> Optional[Any]: - """ - Await the given future/coroutine until it completes, timeout expires, or stop_event is set. - - :param future_or_coroutine: An awaitable coroutine, asyncio.Future, or asyncio.Task. - :param stop_event: asyncio.Event that signals shutdown. - :param timeout: Optional timeout in seconds, -1 for no timeout, or None to wait indefinitely. - :return: The result if completed successfully, or None on timeout/stop. 
- """ if asyncio.iscoroutine(future_or_coroutine): main_task = asyncio.create_task(future_or_coroutine) elif asyncio.isfuture(future_or_coroutine): + if future_or_coroutine.done(): + return future_or_coroutine.result() main_task = future_or_coroutine else: raise TypeError("Expected coroutine or Future/Task") @@ -57,3 +96,34 @@ async def await_or_stop(future_or_coroutine: Union[asyncio.Future, asyncio.Task, finally: if not stop_task.done(): stop_task.cancel() + +async def await_and_resolve_original( + parent_futures: List[asyncio.Future], + child_futures: List[asyncio.Future] +): + try: + results = await asyncio.gather(*child_futures, return_exceptions=True) + + for child in child_futures: + future_map.child_resolved(child) + + for i, f in enumerate(parent_futures): + if f is not None and not f.done(): + first_result = next((r for r in results if not isinstance(r, Exception)), None) + first_exception = next((r for r in results if isinstance(r, Exception)), None) + + if first_exception and not first_result: + f.set_exception(first_exception) + logger.debug("Set exception for parent future #%d id=%r from child exception: %r", + i, getattr(f, 'uuid', f), first_exception) + else: + f.set_result(first_result) + logger.trace("Resolved parent future #%d id=%r with result: %r", + i, getattr(f, 'uuid', f), first_result) + + except Exception as e: + logger.error("Unexpected error while resolving parent delivery futures: %s", e) + for i, f in enumerate(parent_futures): + if f is not None and not f.done(): + f.set_exception(e) + logger.debug("Set fallback exception for parent future #%d id=%r", i, getattr(f, 'uuid', f)) diff --git a/tb_mqtt_client/common/gmqtt_patch.py b/tb_mqtt_client/common/gmqtt_patch.py index f3d715c..a1cc00a 100644 --- a/tb_mqtt_client/common/gmqtt_patch.py +++ b/tb_mqtt_client/common/gmqtt_patch.py @@ -13,21 +13,62 @@ # limitations under the License. import asyncio +import heapq import struct from collections import defaultdict -from typing import Callable +from time import monotonic +from typing import Callable, Tuple -from gmqtt.mqtt.constants import MQTTCommands -from gmqtt.mqtt.handler import MqttPackageHandler +from gmqtt import Client +from gmqtt.mqtt.constants import MQTTCommands, MQTTv50, MQTTv311 +from gmqtt.mqtt.handler import MqttPackageHandler, MQTTConnectError +from gmqtt.mqtt.package import PackageFactory from gmqtt.mqtt.property import Property from gmqtt.mqtt.protocol import BaseMQTTProtocol, MQTTProtocol -from gmqtt.mqtt.utils import unpack_variable_byte_integer +from gmqtt.mqtt.utils import unpack_variable_byte_integer, pack_variable_byte_integer from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage logger = get_logger(__name__) +class PublishPacket(PackageFactory): + @classmethod + def build_package(cls, message: MqttPublishMessage, protocol, mid: int = None) -> Tuple[int, bytes]: + dup_flag = 1 if message.dup else 0 + + command = MQTTCommands.PUBLISH | (dup_flag << 3) | (message.qos << 1) | (message.retain & 0x1) + + packet = bytearray() + packet.append(command) + + remaining_length = 2 + len(message.topic) + message.payload_size + prop_bytes = cls._build_properties_data(message.properties, protocol_version=protocol.proto_ver) + remaining_length += len(prop_bytes) + + if message.payload_size == 0: + logger.debug("Sending PUBLISH (q%d), '%s' (NULL payload)", message.qos, message.topic) + else: + logger.debug("Sending PUBLISH (q%d), '%s', ... 
(%d bytes)", message.qos, message.topic, message.payload_size) + + if message.qos > 0: + remaining_length += 2 + + packet.extend(pack_variable_byte_integer(remaining_length)) + cls._pack_str16(packet, message.topic) + + if message.qos > 0: + if mid is None: + mid = cls.id_generator.next_id() + packet.extend(struct.pack("!H", mid)) + + packet.extend(prop_bytes) + packet.extend(message.payload) + + return mid, packet + + class PatchUtils: DISCONNECT_REASON_CODES = { 0: "Normal disconnection", @@ -38,7 +79,7 @@ class PatchUtils: 131: "Implementation specific error", 132: "Not authorized", 133: "Server busy", - 134: "Server shutting down", + 134: "Bad credentials", 135: "Keep Alive timeout", 136: "Session taken over", 137: "Topic Filter invalid", @@ -56,11 +97,23 @@ class PatchUtils: 149: "Server moved", 150: "Shared Subscriptions not supported", 151: "Connection rate exceeded", - 152: "Maximum connect time", + 152: "Administrative action", 153: "Subscription Identifiers not supported", 154: "Wildcard Subscriptions not supported" } + def __init__(self, client: Client, stop_event: asyncio.Event, retry_interval: int = 1): + """ + Initialize PatchUtils with a client and retry interval. + + :param client: The MQTT client instance to patch. + :param retry_interval: Interval in seconds to retry connection. + """ + self.client = client + self.retry_interval = retry_interval + self._stop_event = stop_event + self._retry_task = None + @staticmethod def parse_mqtt_properties(packet: bytes) -> dict: """ @@ -89,236 +142,329 @@ def parse_mqtt_properties(packet: bytes) -> dict: return dict(properties_dict) -def extract_reason_code(packet): - """ - Extract the reason code from a disconnect packet. - - :param packet: The disconnect packet, which can be an object with a reason_code attribute or raw bytes - :return: The reason code if found, None otherwise - """ - reason_code = None - if packet: - if hasattr(packet, 'reason_code'): - reason_code = packet.reason_code - elif isinstance(packet, bytes) and len(packet) >= 2: - reason_code = packet[1] - - return reason_code - -def patch_mqtt_handler_disconnect(): - """ - Monkey-patch gmqtt.mqtt.handler.MqttPackageHandler._handle_disconnect_packet to properly - handle server-initiated disconnect messages. 
- """ - try: - # Store the original method - original_handle_disconnect = MqttPackageHandler._handle_disconnect_packet # noqa - - # Define the patched method - def patched_handle_disconnect_packet(self, cmd, packet): - # Extract reason code - reason_code = 0 - if packet and len(packet) >= 1: - reason_code = packet[0] - - # Parse properties if available - properties = {} - if packet and len(packet) > 1: - try: - properties = PatchUtils.parse_mqtt_properties(packet[1:]) - except Exception as exc: - logger.warning("Failed to parse properties from disconnect packet: %s", exc) - - reason_desc = PatchUtils.DISCONNECT_REASON_CODES.get(reason_code, "Unknown reason") - logger.trace("Server initiated disconnect with reason code: %s (%s)", reason_code, reason_desc) - - # Call the original method to handle reconnection - # But don't call the on_disconnect callback, as we'll do that ourselves - # with the extracted reason_code and properties - self._clear_topics_aliases() - future = asyncio.ensure_future(self.reconnect(delay=True)) - future.add_done_callback(self._handle_exception_in_future) - - # Call the on_disconnect callback with the client, reason_code, properties, and None for exc - # since this is a server-initiated disconnect - self.on_disconnect(self, reason_code, properties, None) - - # Set a flag on the connection object to indicate that on_disconnect has been called - self._connection._on_disconnect_called = True - original_handle_disconnect(self, cmd, packet) - - # Apply the patch - MqttPackageHandler._handle_disconnect_packet = patched_handle_disconnect_packet - logger.debug("Successfully patched gmqtt.mqtt.handler.MqttPackageHandler._handle_disconnect_packet") - return True - except (ImportError, AttributeError) as e: - logger.warning("Failed to patch gmqtt handler: %s", e) - return False - -def patch_handle_connack(client, on_connack_with_session_present_and_result_code: Callable[[object, int, int, dict], None]): - """ - Monkey-patch gmqtt.mqtt.handler.MqttPackageHandler._handle_connack_packet to add custom handling for - CONNACK packets, allowing for custom callbacks with session_present, result_code, and properties. - """ - try: - original_handler = MqttPackageHandler._handle_connack_packet - - def new_handle_connack_packet(self, cmd, packet): - try: - original_handler(self, cmd, packet) + @staticmethod + def extract_reason_code(packet): + """ + Extract the reason code from a disconnect packet. - session_present, reason_code = struct.unpack("!BB", packet[:2]) + :param packet: The disconnect packet, which can be an object with a reason_code attribute or raw bytes + :return: The reason code if found, None otherwise + """ + reason_code = None + if packet: + if hasattr(packet, 'reason_code'): + reason_code = packet.reason_code + elif isinstance(packet, bytes) and len(packet) >= 2: + reason_code = packet[1] - if len(packet) > 2: - props_payload = packet[2:] - properties = PatchUtils.parse_mqtt_properties(props_payload) - else: - properties = {} + return reason_code - logger.debug("CONNACK patched handler: session_present=%r, reason_code=%r, properties=%r", - session_present, reason_code, properties) + def patch_mqtt_handler_disconnect(patch_utils_instance): + """ + Monkey-patch gmqtt.mqtt.handler.MqttPackageHandler._handle_disconnect_packet to properly + handle server-initiated disconnect messages. 
+ """ + try: + # Store the original method + original_handle_disconnect = MqttPackageHandler._handle_disconnect_packet # noqa + + # Define the patched method + def patched_handle_disconnect_packet(self, cmd, packet): + # Extract reason code + reason_code = 0 + if packet and len(packet) >= 1: + reason_code = packet[0] + + # Parse properties if available + properties = {} + if packet and len(packet) > 1: + try: + properties = PatchUtils.parse_mqtt_properties(packet[1:]) + except Exception as exc: + logger.warning("Failed to parse properties from disconnect packet: %s", exc) + + reason_desc = PatchUtils.DISCONNECT_REASON_CODES.get(reason_code, "Unknown reason") + logger.trace("Server initiated disconnect with reason code: %s (%s)", reason_code, reason_desc) + + # Call the original method to handle reconnection + # But don't call the on_disconnect callback, as we'll do that ourselves + # with the extracted reason_code and properties + self._clear_topics_aliases() + future = asyncio.ensure_future(self.reconnect(delay=True)) + future.add_done_callback(self._handle_exception_in_future) + + # Call the on_disconnect callback with the client, reason_code, properties, and None for exc + # since this is a server-initiated disconnect + self.on_disconnect(self, reason_code, properties, None) + + # Set a flag on the connection object to indicate that on_disconnect has been called + self._connection._on_disconnect_called = True + original_handle_disconnect(self, cmd, packet) + + # Apply the patch + MqttPackageHandler._handle_disconnect_packet = patched_handle_disconnect_packet + logger.debug("Successfully patched gmqtt.mqtt.handler.MqttPackageHandler._handle_disconnect_packet") + return True + except (ImportError, AttributeError) as e: + logger.warning("Failed to patch gmqtt handler: %s", e) + return False + + def patch_handle_connack(patch_utils_instance, + on_connack_with_session_present_and_result_code: Callable[[object, int, int, dict], None]): + """ + Fully replaces gmqtt's _handle_connack_packet implementation, skipping internal QoS1 resend behavior + and invoking a custom callback instead of calling the original method. 
+ """ + try: + def new_handle_connack_packet(self, cmd, packet): + try: + self._connected.set() + + (session_present, result_code) = struct.unpack("!BB", packet[:2]) + + if result_code != 0: + self._logger.warning('[CONNACK] %s', hex(result_code)) + self.failed_connections += 1 + if result_code == 1 and self.protocol_version == MQTTv50: + self._logger.info('[CONNACK] Downgrading to MQTT 3.1 protocol version') + MQTTProtocol.proto_ver = MQTTv311 + future = asyncio.ensure_future(self.reconnect(delay=True)) + future.add_done_callback(self._handle_exception_in_future) + return + else: + self._error = MQTTConnectError(result_code) + asyncio.ensure_future(self.reconnect(delay=True)) + return + else: + self.failed_connections = 0 + + if len(packet) > 2: + properties, _ = self._parse_properties(packet[2:]) + if properties is None: + self._error = MQTTConnectError(10) + asyncio.ensure_future(self.disconnect()) + self._connack_properties = properties + self._update_keepalive_if_needed() + else: + properties = {} + + self._logger.debug('[CONNACK] session_present: %s, result: %s', hex(session_present), + hex(result_code)) + + on_connack_with_session_present_and_result_code( + patch_utils_instance.client, + session_present, + result_code, + properties + ) + + self.on_connect(self, session_present, result_code, self.properties) + + except Exception as e: + logger.exception("Error while handling CONNACK packet: %s", e) + + MqttPackageHandler._handle_connack_packet = new_handle_connack_packet + logger.debug("Successfully patched gmqtt.mqtt.handler._handle_connack_packet (full replacement)") + return True + except (ImportError, AttributeError) as e: + logger.warning("Failed to patch gmqtt handler: %s", e) + return False + + def patch_gmqtt_protocol_connection_lost(patch_utils_instance): + """ + Monkey-patch gmqtt.mqtt.protocol.BaseMQTTProtocol.connection_lost to suppress the + default "[CONN CLOSE NORMALLY]" log message, as we handle disconnect logging in our code. - on_connack_with_session_present_and_result_code(client, session_present, reason_code, properties) - except Exception as e: - logger.error("Error while handling CONNACK packet: %s", e, exc_info=True) - - MqttPackageHandler._handle_connack_packet = new_handle_connack_packet - logger.debug("Successfully patched gmqtt.mqtt.handler._handle_connack_packet") - return True - except (ImportError, AttributeError) as e: - logger.warning("Failed to patch gmqtt handler: %s", e) - return False - -def patch_gmqtt_protocol_connection_lost(): - """ - Monkey-patch gmqtt.mqtt.protocol.BaseMQTTProtocol.connection_lost to suppress the - default "[CONN CLOSE NORMALLY]" log message, as we handle disconnect logging in our code. - - Also, patch MQTTProtocol.connection_lost to include the reason code in the DISCONNECT package - and pass the exception to the handler. 
- """ - try: - original_base_connection_lost = BaseMQTTProtocol.connection_lost - def patched_base_connection_lost(self, exc): - self._connected.clear() - super(BaseMQTTProtocol, self).connection_lost(exc) - BaseMQTTProtocol.connection_lost = patched_base_connection_lost - - original_mqtt_connection_lost = MQTTProtocol.connection_lost - def patched_mqtt_connection_lost(self, exc): - super(MQTTProtocol, self).connection_lost(exc) - reason_code = 0 - properties = {} - - if exc: - # Determine reason code based on an exception type - if isinstance(exc, ConnectionRefusedError): - reason_code = 135 # Keep Alive timeout - elif isinstance(exc, TimeoutError): - reason_code = 135 # Keep Alive timeout - elif isinstance(exc, ConnectionResetError): - reason_code = 139 # Receive Maximum exceeded - elif isinstance(exc, ConnectionAbortedError): - reason_code = 136 # Session taken over - elif isinstance(exc, PermissionError): - reason_code = 132 # Not authorized - elif isinstance(exc, OSError): - reason_code = 130 # Protocol Error - else: - reason_code = 131 # Implementation specific error - - # Add an exception message to properties if available - if hasattr(exc, 'args') and exc.args: - properties['reason_string'] = [str(exc.args[0])] - - # Pack the reason code into a payload - payload = struct.pack('!B', reason_code) - - # Store the exception and properties in the connection object - # so they can be accessed by the handler - self._connection._disconnect_exc = exc - self._connection._disconnect_properties = properties - - # Put the DISCONNECT package into the connection's package queue - self._connection.put_package((MQTTCommands.DISCONNECT, payload)) - - if self._read_loop_future is not None: - self._read_loop_future.cancel() - self._read_loop_future = None - - self._queue = asyncio.Queue() - - MQTTProtocol.connection_lost = patched_mqtt_connection_lost - - # Also patch MqttPackageHandler.__call__ to pass the exception and properties to on_disconnect - original_call = MqttPackageHandler.__call__ - def patched_call(self, cmd, packet): - try: - if cmd == MQTTCommands.DISCONNECT and hasattr(self._connection, '_disconnect_exc'): - # This is a disconnect packet from connection_lost - # Extract reason code - reason_code = 0 - if packet and len(packet) >= 1: - reason_code = packet[0] - - # Get properties and exception from connection - properties = getattr(self._connection, '_disconnect_properties', {}) - exc = getattr(self._connection, '_disconnect_exc', None) - - # Check if on_disconnect has already been called - if (not hasattr(self._connection, '_on_disconnect_called') - or not self._connection._on_disconnect_called): # noqa - # Call on_disconnect with the extracted values - self._clear_topics_aliases() - future = asyncio.ensure_future(self.reconnect(delay=True)) - future.add_done_callback(self._handle_exception_in_future) - self.on_disconnect(self, reason_code, properties, exc) + Also, patch MQTTProtocol.connection_lost to include the reason code in the DISCONNECT package + and pass the exception to the handler. 
+ """ + try: + original_base_connection_lost = BaseMQTTProtocol.connection_lost + def patched_base_connection_lost(self, exc): + self._connected.clear() + super(BaseMQTTProtocol, self).connection_lost(exc) + BaseMQTTProtocol.connection_lost = patched_base_connection_lost + + original_mqtt_connection_lost = MQTTProtocol.connection_lost + def patched_mqtt_connection_lost(self, exc): + super(MQTTProtocol, self).connection_lost(exc) + reason_code = 0 + properties = {} + + if exc: + if isinstance(exc, ConnectionRefusedError): + reason_code = 135 # Keep Alive timeout + elif isinstance(exc, TimeoutError): + reason_code = 135 # Keep Alive timeout + elif isinstance(exc, ConnectionResetError): + reason_code = 139 # Receive Maximum exceeded + elif isinstance(exc, ConnectionAbortedError): + reason_code = 136 # Session taken over + elif isinstance(exc, PermissionError): + reason_code = 132 # Not authorized + elif isinstance(exc, OSError): + reason_code = 130 # Protocol Error + else: + reason_code = 131 # Implementation specific error + + if hasattr(exc, 'args') and exc.args: + properties['reason_string'] = [str(exc.args[0])] + + payload = struct.pack('!B', reason_code) + + self._connection._disconnect_exc = exc + self._connection._disconnect_properties = properties + + self._connection.put_package((MQTTCommands.DISCONNECT, payload)) + + if self._read_loop_future is not None: + self._read_loop_future.cancel() + self._read_loop_future = None + + self._queue = asyncio.Queue() + + MQTTProtocol.connection_lost = patched_mqtt_connection_lost + + original_call = MqttPackageHandler.__call__ + def patched_call(self, cmd, packet): + try: + if cmd == MQTTCommands.DISCONNECT and hasattr(self._connection, '_disconnect_exc'): + reason_code = 0 + if packet and len(packet) >= 1: + reason_code = packet[0] + + properties = getattr(self._connection, '_disconnect_properties', {}) + exc = getattr(self._connection, '_disconnect_exc', None) + + if (not hasattr(self._connection, '_on_disconnect_called') + or not self._connection._on_disconnect_called): # noqa + self._clear_topics_aliases() + future = asyncio.ensure_future(self.reconnect(delay=True)) + future.add_done_callback(self._handle_exception_in_future) + self.on_disconnect(self, reason_code, properties, exc) + return None + + return original_call(self, cmd, packet) + except Exception as exception: + logger.error('[ERROR HANDLE PKG]', exc_info=exception) return None - # For other commands, call the original method - return original_call(self, cmd, packet) - except Exception as exception: - logger.error('[ERROR HANDLE PKG]', exc_info=exception) - return None - - MqttPackageHandler.__call__ = patched_call - - logger.debug("Successfully patched gmqtt.mqtt.protocol connection_lost methods") - return True - except (ImportError, AttributeError) as e: - logger.warning("Failed to patch gmqtt protocol: %s", e) - return False + MqttPackageHandler.__call__ = patched_call + logger.debug("Successfully patched gmqtt.mqtt.protocol connection_lost methods") + return True + except (ImportError, AttributeError) as e: + logger.warning("Failed to patch gmqtt protocol: %s", e) + return False -def patch_gmqtt_puback(client, on_puback_with_reason_and_properties: Callable[[int, int, dict], None]): - """ - Monkey-patch gmqtt.Client instance to intercept PUBACK reason codes and properties. 
- - :param client: GMQTTClient instance - :param on_puback_with_reason_and_properties: Callback with (mid, reason_code, properties_dict) - """ - original_handler = MqttPackageHandler._handle_puback_packet - - if original_handler is None: - logger.error("Could not find _handle_puback_packet in base class.") - return - - def wrapped_handle_puback(self, cmd, packet): + def patch_puback_handling(self, on_puback_with_reason_and_properties: Callable[[int, int, dict], None]): + original_handler = MqttPackageHandler._handle_puback_packet + def wrapped_handle_puback(self, cmd, packet): + try: + mid = struct.unpack("!H", packet[:2])[0] + reason_code = 0 + properties = {} + if len(packet) > 2: + reason_code = packet[2] + if len(packet) > 3: + props_payload = packet[3:] + properties = PatchUtils.parse_mqtt_properties(props_payload) + logger.trace("PUBACK received for mid=%r, reason=%r", mid, reason_code) + if hasattr(self._connection, 'persistent_storage'): + self._connection.persistent_storage.remove(mid) + on_puback_with_reason_and_properties(mid, reason_code, properties) + except Exception as e: + logger.exception("Error while handling PUBACK with properties: %s", e) + return original_handler(self, cmd, packet) + MqttPackageHandler._handle_puback_packet = wrapped_handle_puback + logger.debug("Patched _handle_puback_packet for QoS1 support.") + + def patch_storage(patch_utils_instance): + async def pop_message_with_tm(): + (tm, mid, raw_package) = heapq.heappop(patch_utils_instance.client._persistent_storage._queue) + + patch_utils_instance.client._persistent_storage._check_empty() + return tm, mid, raw_package + patch_utils_instance.client._persistent_storage.pop_message = pop_message_with_tm + + async def _retry_loop(self): + logger.debug("QoS1 retry loop started.") + self.patch_storage() try: - mid = struct.unpack("!H", packet[:2])[0] - reason_code = 0 - properties = {} - - if len(packet) > 2: - reason_code = packet[2] - if len(packet) > 3: - props_payload = packet[3:] - properties = PatchUtils.parse_mqtt_properties(props_payload) - - on_puback_with_reason_and_properties(mid, reason_code, properties) + while not self._stop_event.is_set(): + retry_list = [] + current_tm = asyncio.get_event_loop().time() + + for _ in range(100): + if self._stop_event.is_set(): + break + + try: + msg = await asyncio.wait_for(self.client._persistent_storage.pop_message(), timeout=0.1) + except asyncio.TimeoutError: + break + except IndexError: + break + except Exception as e: + logger.warning("Error popping message: %s", e) + break + + if msg is None: + break + + retry_list.append(msg) + + for (tm, mid, mqtt_msg) in retry_list: + if current_tm - tm > self.retry_interval and self.client.is_connected: + mqtt_msg.dup = True + logger.error("Resending PUBLISH message with mid=%r, topic=%s", mid, mqtt_msg.topic) + + protocol = self.client._connection._protocol + + if protocol: + try: + mid, rebuilt = PublishPacket.build_package( + message=mqtt_msg, + protocol=protocol, + mid=mid + ) + self.client._connection.send_package(rebuilt) + logger.trace("Retransmitted message mid=%r", mid) + except Exception as e: + logger.warning("Error during retransmission: %s", e) + else: + logger.warning("Cannot retransmit, MQTT protocol unavailable.") + + heapq.heappush(self.client._persistent_storage._queue, (tm, mid, mqtt_msg)) + + await asyncio.sleep(self.retry_interval) + except asyncio.CancelledError: + logger.debug("Retry loop cancelled.") except Exception as e: - logger.exception("Error while handling PUBACK with properties: %s", e) - 
- return original_handler(self, cmd, packet) - - MqttPackageHandler._handle_puback_packet = wrapped_handle_puback + logger.exception("Unexpected error in retry loop: %s", e) + finally: + logger.debug("QoS1 retry loop stopped.") + + def start_retry_task(self): + if not self._stop_event.is_set() and not self._retry_task: + self._retry_task = asyncio.create_task(self._retry_loop()) + logger.debug("Retry task started.") + + async def stop_retry_task(self): + if self._retry_task: + self._stop_event.set() + try: + await asyncio.wait_for(self._retry_task, timeout=2) + except asyncio.TimeoutError: + logger.debug("Retry task did not finish in time, cancelling...") + self._retry_task.cancel() + try: + await self._retry_task + except asyncio.CancelledError: + logger.debug("Retry task cancelled.") + self._retry_task = None + self._stop_event.clear() + + def apply(self, on_puback_with_reason_and_properties: Callable[[int, int, dict], None]): + self.patch_puback_handling(on_puback_with_reason_and_properties) + self.start_retry_task() diff --git a/tb_mqtt_client/common/mqtt_message.py b/tb_mqtt_client/common/mqtt_message.py new file mode 100644 index 0000000..61cf5fd --- /dev/null +++ b/tb_mqtt_client/common/mqtt_message.py @@ -0,0 +1,75 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from asyncio import Future +from typing import Union, List +from uuid import uuid4 + +from gmqtt import Message + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage + +logger = get_logger(__name__) + + +class MqttPublishMessage(Message): + """ + A custom Publish MQTT message class that extends the gmqtt Message class. + Contains additional information like datapoints, to avoid rate limits exceeding. + """ + def __init__(self, + topic: str, + payload: Union[bytes, GatewayUplinkMessage, DeviceUplinkMessage], + qos: int = 1, + retain: bool = False, + datapoints: int = 0, + delivery_futures = None, + **kwargs): + """ + Initialize the MqttMessage with topic, payload, QoS, retain flag, and datapoints. 
+ """ + self.prepared = False + self.payload = payload + if isinstance(payload, bytes): + super().__init__(topic, payload, qos, retain) + self.topic = topic + self.qos = qos + if self.qos < 0 or self.qos > 1: + logger.warning(f"Invalid QoS {self.qos} for topic {topic}, using default QoS 1") + self.qos = 1 + self.dup = False + self.retain = retain + self.message_id = None + self.datapoints = datapoints + self.properties = kwargs + self._is_sent = False + self.delivery_futures: List[Future] = delivery_futures + if not delivery_futures: + delivery_future = Future() + delivery_future.uuid = uuid4() + self.delivery_futures = [delivery_future] + logger.trace(f"Created MqttMessage with topic: {topic}, payload type: {type(payload).__name__}, " + f"datapoints: {datapoints}, delivery_future id: {self.delivery_futures[0].uuid}") + + + def mark_as_sent(self, message_id: int): + """Mark the message as sent.""" + self.message_id = message_id + self._is_sent = True + + def is_sent(self) -> bool: + """Check if the message has been sent.""" + return self._is_sent diff --git a/tb_mqtt_client/common/publish_result.py b/tb_mqtt_client/common/publish_result.py index 6da17b1..ceaeb79 100644 --- a/tb_mqtt_client/common/publish_result.py +++ b/tb_mqtt_client/common/publish_result.py @@ -11,17 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List + class PublishResult: - def __init__(self, topic: str, qos: int, message_id: int, payload_size: int, reason_code: int): + def __init__(self, topic: str, qos: int, message_id: int, payload_size: int, reason_code: int, + datapoints_count: int = 0): self.topic = topic self.qos = qos self.message_id = message_id self.payload_size = payload_size self.reason_code = reason_code + self.datapoints_count = datapoints_count def __repr__(self): - return f"PublishResult(topic={self.topic}, qos={self.qos}, message_id={self.message_id}, payload_size={self.payload_size}, reason_code={self.reason_code})" + return (f"PublishResult(topic={self.topic}, " + f"qos={self.qos}, " + f"message_id={self.message_id}, " + f"payload_size={self.payload_size}, " + f"reason_code={self.reason_code}, " + f"datapoints_count={self.datapoints_count})") def as_dict(self) -> dict: return { @@ -29,7 +38,8 @@ def as_dict(self) -> dict: "qos": self.qos, "message_id": self.message_id, "payload_size": self.payload_size, - "reason_code": self.reason_code + "reason_code": self.reason_code, + "datapoints_count": self.datapoints_count } def is_successful(self) -> bool: @@ -37,3 +47,25 @@ def is_successful(self) -> bool: Check if the publish operation was successful based on the reason code. 
""" return self.reason_code == 0 + + + @staticmethod + def merge(results: List['PublishResult']) -> 'PublishResult': + if not results: + raise ValueError("No publish results to merge.") + + topic = results[0].topic + qos = results[0].qos + message_id = -1 + reason_code = 0 if all(r.reason_code == 0 for r in results) else -1 + payload_size = sum(r.payload_size for r in results) + datapoints_count = sum(r.datapoints_count for r in results) + + return PublishResult( + topic=topic, + qos=qos, + message_id=message_id, + payload_size=payload_size, + reason_code=reason_code, + datapoints_count=datapoints_count + ) diff --git a/tb_mqtt_client/entities/data/data_entry.py b/tb_mqtt_client/entities/data/data_entry.py index 641cbcf..dcbb999 100644 --- a/tb_mqtt_client/entities/data/data_entry.py +++ b/tb_mqtt_client/entities/data/data_entry.py @@ -19,6 +19,8 @@ class DataEntry: + __slots__ = ("__key", "__value", "__ts", "__size") + def __init__(self, key: str, value: JSONCompatibleType, ts: Optional[int] = None): validate_json_compatibility(value) self.__key = key @@ -30,10 +32,9 @@ def __repr__(self): return f"DataEntry(key={self.key}, value={self.value}, ts={self.ts})" def __estimate_size(self) -> int: - if self.ts is not None: - return len(dumps({"ts": self.ts, "values": {self.key: self.value}})) - else: - return len(dumps({self.key: self.value})) + if self.__ts is not None: + return len(dumps({"ts": self.__ts, "values": {self.__key: self.__value}})) + return len(dumps({self.__key: self.__value})) @property def size(self) -> int: diff --git a/tb_mqtt_client/entities/data/device_uplink_message.py b/tb_mqtt_client/entities/data/device_uplink_message.py index ebe75d9..b4da805 100644 --- a/tb_mqtt_client/entities/data/device_uplink_message.py +++ b/tb_mqtt_client/entities/data/device_uplink_message.py @@ -16,8 +16,9 @@ from dataclasses import dataclass from types import MappingProxyType from typing import List, Optional, Union, OrderedDict, Tuple, Mapping +from uuid import uuid4 -from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.logging_utils import get_logger, TRACE_LEVEL from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry @@ -140,13 +141,17 @@ def add_delivery_futures(self, futures: Union[ if not isinstance(futures, list): futures = [futures] if futures: - logger.debug("Created delivery futures: %r", [id(future) for future in futures]) + if logger.isEnabledFor(TRACE_LEVEL): + logger.debug("Created delivery futures: %r", [future.uuid for future in futures]) self._delivery_futures.extend(futures) return self def build(self) -> DeviceUplinkMessage: if not self._delivery_futures: - self._delivery_futures = [asyncio.get_event_loop().create_future()] + delivery_future = asyncio.get_event_loop().create_future() + delivery_future.uuid = uuid4() + logger.trace("No delivery futures provided, creating a default future: %r", delivery_future.uuid) + self._delivery_futures = [delivery_future] return DeviceUplinkMessage.build( device_name=self._device_name, device_profile=self._device_profile, diff --git a/tb_mqtt_client/entities/data/rpc_request.py b/tb_mqtt_client/entities/data/rpc_request.py index e84c302..9300399 100644 --- a/tb_mqtt_client/entities/data/rpc_request.py +++ b/tb_mqtt_client/entities/data/rpc_request.py @@ -75,3 +75,6 @@ def to_payload_format(self) -> Dict[str, Any]: if self.params is not None: 
data["params"] = self.params return data + + def __str__(self): + return f"RPCRequest(id={self.request_id}, method={self.method}, params={self.params})" diff --git a/tb_mqtt_client/entities/gateway/gateway_uplink_message.py b/tb_mqtt_client/entities/gateway/gateway_uplink_message.py index bfcada4..12ab503 100644 --- a/tb_mqtt_client/entities/gateway/gateway_uplink_message.py +++ b/tb_mqtt_client/entities/gateway/gateway_uplink_message.py @@ -16,6 +16,7 @@ from dataclasses import dataclass from types import MappingProxyType from typing import List, Optional, Union, OrderedDict, Mapping +from uuid import uuid4 from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.common.publish_result import PublishResult @@ -148,7 +149,9 @@ def add_delivery_futures(self, futures: Union[ def build(self) -> GatewayUplinkMessage: if not self._delivery_futures: - self._delivery_futures = [asyncio.get_event_loop().create_future()] + delivery_future = asyncio.get_event_loop().create_future() + delivery_future.uuid = uuid4() + self._delivery_futures = [delivery_future] return GatewayUplinkMessage.build( # noqa device_name=self._device_name, device_profile=self._device_profile, diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index cde006c..96ab025 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -23,6 +23,7 @@ from tb_mqtt_client.common.async_utils import await_or_stop from tb_mqtt_client.common.config_loader import DeviceConfig from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage from tb_mqtt_client.common.provisioning_client import ProvisioningClient from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit, DEFAULT_RATE_LIMIT_PERCENTAGE @@ -163,8 +164,7 @@ async def stop(self): if self._message_queue: await self._message_queue.shutdown() - if self._mqtt_manager.is_connected(): - await self._mqtt_manager.disconnect() + await self._mqtt_manager.stop() logger.info("DeviceClient stopped.") @@ -199,31 +199,30 @@ async def send_timeseries( None if no data is sent. 
""" message = self._build_uplink_message_for_telemetry(data) - topic = mqtt_topics.DEVICE_TELEMETRY_TOPIC - futures = await self._message_queue.publish( - topic=topic, + mqtt_message = MqttPublishMessage( + topic=mqtt_topics.DEVICE_TELEMETRY_TOPIC, payload=message, - datapoints_count=message.timeseries_datapoint_count(), - qos=qos or self._config.qos + qos=qos or self._config.qos, + datapoints_count=message.timeseries_datapoint_count() ) + delivery_future = mqtt_message.delivery_futures - if not futures: - logger.warning("No publish futures were returned from message queue") - return None + await self._message_queue.publish(mqtt_message) if not wait_for_publish: - return futures[0] if len(futures) == 1 else futures + return delivery_future - results = [] - for fut in futures: - try: - result = await await_or_stop(fut, timeout=timeout, stop_event=self._stop_event) - except TimeoutError: - logger.warning("Timeout while waiting for telemetry publish result") - result = PublishResult(topic, qos, -1, message.size, -1) - results.append(result) + if isinstance(delivery_future, list): + delivery_future = delivery_future[0] - return results[0] if len(results) == 1 else results + logger.info("Delivery future id in device client.send_timeseries: %r", delivery_future.uuid) + + try: + result = await await_or_stop(delivery_future, timeout=1, stop_event=self._stop_event) + except TimeoutError: + logger.warning("Timeout while waiting for telemetry publish result") + result = PublishResult(mqtt_message.topic, qos, -1, message.size, -1) + return result async def send_attributes( self, @@ -231,22 +230,25 @@ async def send_attributes( qos: int = None, wait_for_publish: bool = True, timeout: int = BaseClient.DEFAULT_TIMEOUT - ) -> Union[PublishResult, List[PublishResult], None]: + ) -> Union[PublishResult, List[PublishResult], None, Future[PublishResult], List[Future[PublishResult]]]: message = self._build_uplink_message_for_attributes(attributes) - topic = mqtt_topics.DEVICE_ATTRIBUTES_TOPIC - futures = await self._message_queue.publish( - topic=topic, + mqtt_message = MqttPublishMessage( + topic=mqtt_topics.DEVICE_ATTRIBUTES_TOPIC, payload=message, - datapoints_count=message.attributes_datapoint_count(), - qos=qos or self._config.qos + qos=qos or self._config.qos, + datapoints_count=message.attributes_datapoint_count() ) + await self._message_queue.publish(mqtt_message) + + futures = mqtt_message.delivery_futures + if not futures: logger.warning("No publish futures were returned from message queue") return None if not wait_for_publish: - return None + return futures results = [] for fut in futures: @@ -254,7 +256,7 @@ async def send_attributes( result = await await_or_stop(fut, timeout=timeout, stop_event=self._stop_event) except TimeoutError: logger.warning("Timeout while waiting for attribute publish result") - result = PublishResult(topic, qos, -1, message.size, -1) + result = PublishResult(mqtt_message.topic, qos, -1, message.size, -1) results.append(result) return results[0] if len(results) == 1 else results @@ -267,16 +269,11 @@ async def send_rpc_request( timeout: Optional[float] = BaseClient.DEFAULT_TIMEOUT ) -> Union[RPCResponse, Awaitable[RPCResponse], None]: request_id = rpc_request.request_id or await RPCRequestIdProducer.get_next() - topic, payload = self._message_adapter.build_rpc_request(rpc_request) - + message_to_send = self._message_adapter.build_rpc_request(rpc_request) + message_to_send.qos = self._config.qos response_future = self._rpc_response_handler.register_request(request_id, 
callback) - await self._message_queue.publish( - topic=topic, - payload=payload, - datapoints_count=0, - qos=self._config.qos - ) + await self._message_queue.publish(message_to_send) if not wait_for_publish: return response_future @@ -293,40 +290,50 @@ async def send_rpc_request( mqtt_topics.build_device_rpc_response_topic(rpc_request.request_id), e) async def send_rpc_response(self, response: RPCResponse): - topic, payload = self._message_adapter.build_rpc_response(response) - await self._message_queue.publish(topic=topic, - payload=payload, - datapoints_count=0, - qos=self._config.qos) + mqtt_message = self._message_adapter.build_rpc_response(response) + mqtt_message.qos = self._config.qos + await self._message_queue.publish(mqtt_message) + + delivery_future = mqtt_message.delivery_futures + + if isinstance(delivery_future, list): + delivery_future = delivery_future[0] + try: + return await await_or_stop(delivery_future, timeout=BaseClient.DEFAULT_TIMEOUT, stop_event=self._stop_event) + except TimeoutError: + logger.warning("Timeout while waiting for RPC response publish result") + return PublishResult(mqtt_message.topic, mqtt_message.qos, -1, len(mqtt_message.payload), -1) async def send_attribute_request(self, attribute_request: AttributeRequest, - callback: Callable[[RequestedAttributeResponse], Awaitable[None]], ): + callback: Callable[[RequestedAttributeResponse], Awaitable[None]]): await self._requested_attribute_response_handler.register_request(attribute_request, callback) - topic, payload = self._message_adapter.build_attribute_request(attribute_request) + mqtt_message = self._message_adapter.build_attribute_request(attribute_request) + mqtt_message.qos = self._config.qos - await self._message_queue.publish(topic=topic, - payload=payload, - datapoints_count=0, - qos=self._config.qos) + await self._message_queue.publish(mqtt_message) async def claim_device(self, claim_request: ClaimRequest, wait_for_publish: bool = True, timeout: float = BaseClient.DEFAULT_TIMEOUT) -> Union[Future[PublishResult], PublishResult]: - topic, payload = self._message_adapter.build_claim_request(claim_request) - publish_future = await self._message_queue.publish(topic=topic, payload=payload, datapoints_count=0, qos=1) - if isinstance(publish_future, list): - publish_future = publish_future[0] + mqtt_message = self._message_adapter.build_claim_request(claim_request) + mqtt_message.qos = self._config.qos + + delivery_future = mqtt_message.delivery_futures + + await self._message_queue.publish(mqtt_message) + if isinstance(delivery_future, list): + delivery_future = delivery_future[0] if wait_for_publish: try: - return await await_or_stop(publish_future, timeout=timeout, stop_event=self._stop_event) + return await await_or_stop(delivery_future, timeout=timeout, stop_event=self._stop_event) except TimeoutError: logger.warning("Timeout while waiting for claiming publish result") - return PublishResult(topic, 1, -1, len(payload), -1) + return PublishResult(mqtt_message.topic, 1, -1, len(mqtt_message.payload_size), -1) else: - return publish_future + return delivery_future def set_attribute_update_callback(self, callback: Callable[[AttributeUpdate], Awaitable[None]]): self._attribute_updates_handler.set_callback(callback) @@ -344,16 +351,13 @@ async def _on_connect(self): return self._mqtt_manager.register_handler(mqtt_topics.DEVICE_ATTRIBUTES_TOPIC, self._handle_attribute_update) - self._mqtt_manager.register_handler(mqtt_topics.DEVICE_RPC_REQUEST_TOPIC_FOR_SUBSCRIPTION, - self._handle_rpc_request) # noqa - 
self._mqtt_manager.register_handler(mqtt_topics.DEVICE_ATTRIBUTES_RESPONSE_TOPIC, - self._handle_requested_attribute_response) # noqa + self._mqtt_manager.register_handler(mqtt_topics.DEVICE_RPC_REQUEST_TOPIC_FOR_SUBSCRIPTION, self._handle_rpc_request) # noqa + self._mqtt_manager.register_handler(mqtt_topics.DEVICE_ATTRIBUTES_RESPONSE_TOPIC, self._handle_requested_attribute_response) # noqa # RPC responses are handled by the RPCResponseHandler, which is already registered async def _on_disconnect(self): logger.info("Device client disconnected.") self._requested_attribute_response_handler.clear() - self._rpc_response_handler.clear() async def send_rpc_call(self, method: str, params: Optional[Dict[str, Any]] = None, timeout: float = 10.0) -> Union[ RPCResponse, None]: diff --git a/tb_mqtt_client/service/device/handlers/rpc_response_handler.py b/tb_mqtt_client/service/device/handlers/rpc_response_handler.py index c3d42b0..979ba5e 100644 --- a/tb_mqtt_client/service/device/handlers/rpc_response_handler.py +++ b/tb_mqtt_client/service/device/handlers/rpc_response_handler.py @@ -14,6 +14,7 @@ import asyncio from typing import Dict, Union, Awaitable, Callable, Optional, Tuple +from uuid import uuid4 from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.entities.data.rpc_response import RPCResponse @@ -53,6 +54,7 @@ def register_request(self, request_id: Union[str, int], if request_id in self._pending_rpc_requests: raise RuntimeError(f"Request ID {request_id} is already registered.") future = asyncio.get_event_loop().create_future() + future.uuid = uuid4() self._pending_rpc_requests[request_id] = future, callback return future diff --git a/tb_mqtt_client/service/device/message_adapter.py b/tb_mqtt_client/service/device/message_adapter.py index 952d55b..c92252e 100644 --- a/tb_mqtt_client/service/device/message_adapter.py +++ b/tb_mqtt_client/service/device/message_adapter.py @@ -15,13 +15,15 @@ import asyncio from abc import ABC, abstractmethod from itertools import chain -from collections import defaultdict +from collections import defaultdict, deque from datetime import UTC, datetime -from typing import Any, Dict, List, Tuple, Optional, Union +from typing import Any, Dict, List, Tuple, Optional, Union, DefaultDict, Set from orjson import dumps, loads -from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.async_utils import await_and_resolve_original, future_map +from tb_mqtt_client.common.logging_utils import get_logger, TRACE_LEVEL +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage from tb_mqtt_client.constants import mqtt_topics from tb_mqtt_client.constants.mqtt_topics import DEVICE_TELEMETRY_TOPIC, DEVICE_ATTRIBUTES_TOPIC from tb_mqtt_client.entities.data.attribute_request import AttributeRequest @@ -32,7 +34,6 @@ from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse -from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.service.message_splitter import MessageSplitter logger = get_logger(__name__) @@ -45,10 +46,9 @@ def __init__(self, max_payload_size: Optional[int] = None, max_datapoints: Optio max_payload_size, max_datapoints) @abstractmethod - def build_uplink_payloads( + def build_uplink_messages( self, - messages: List[DeviceUplinkMessage] - ) -> List[Tuple[str, bytes, int, 
List[Optional[asyncio.Future[PublishResult]]]]]: + messages: List[MqttPublishMessage]) -> List[MqttPublishMessage]: """ Build a list of topic-payload pairs from the given messages. Each pair consists of a topic string, payload bytes, the number of datapoints, @@ -57,7 +57,7 @@ def build_uplink_payloads( pass @abstractmethod - def build_attribute_request(self, request: AttributeRequest) -> Tuple[str, bytes]: + def build_attribute_request(self, request: AttributeRequest) -> MqttPublishMessage: """ Build the payload for an attribute request response. This method should return a tuple of topic and payload bytes. @@ -65,14 +65,14 @@ def build_attribute_request(self, request: AttributeRequest) -> Tuple[str, bytes pass @abstractmethod - def build_claim_request(self, claim_request) -> Tuple[str, bytes]: + def build_claim_request(self, claim_request) -> MqttPublishMessage: """ Build the payload for a claim request. This method should return a tuple of topic and payload bytes. """ @abstractmethod - def build_rpc_request(self, rpc_request: RPCRequest) -> Tuple[str, bytes]: + def build_rpc_request(self, rpc_request: RPCRequest) -> MqttPublishMessage: """ Build the payload for an RPC request. This method should return a tuple of topic and payload bytes. @@ -80,7 +80,7 @@ def build_rpc_request(self, rpc_request: RPCRequest) -> Tuple[str, bytes]: pass @abstractmethod - def build_rpc_response(self, rpc_response: RPCResponse) -> Tuple[str, bytes]: + def build_rpc_response(self, rpc_response: RPCResponse) -> MqttPublishMessage: """ Build the payload for an RPC response. This method should return a tuple of topic and payload bytes. @@ -239,54 +239,84 @@ def parse_provisioning_response(self, provisioning_request: ProvisioningRequest, def splitter(self) -> MessageSplitter: return self._splitter - def build_uplink_payloads(self, messages: List[DeviceUplinkMessage]) -> List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]]: - """ - Build a list of topic-payload pairs from the given messages. - Each pair consists of a topic string, payload bytes, the number of datapoints, - and a list of futures for delivery confirmation. 
- """ - try: - if not messages: - logger.trace("No messages to process in build_topic_payloads.") - return [] + def build_uplink_messages(self, messages: List[MqttPublishMessage]) -> List[MqttPublishMessage]: + if not messages: + logger.trace("No messages to process in build_uplink_messages.") + return [] - result: List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]] = [] - device_groups: Dict[str, List[DeviceUplinkMessage]] = defaultdict(list) + result: List[MqttPublishMessage] = [] + device_groups = defaultdict(list) + qos = messages[0].qos - for msg in messages: - device_name = msg.device_name - device_groups[device_name].append(msg) - logger.trace("Queued message for device='%s'", device_name) + for mqtt_msg in messages: + payload = mqtt_msg.payload + if isinstance(payload, DeviceUplinkMessage): + device_groups[payload.device_name].append(mqtt_msg) + logger.trace("Queued DeviceUplinkMessage for device='%s'", payload.device_name) + else: + logger.warning("Unsupported payload type '%s', skipping", type(payload).__name__) - logger.trace("Processing %d device group(s).", len(device_groups)) + for device_name, group_msgs in device_groups.items(): + telemetry_msgs = [m for m in group_msgs if m.payload.has_timeseries()] + attr_msgs = [m for m in group_msgs if m.payload.has_attributes()] - for device, device_msgs in device_groups.items(): - telemetry_msgs = [m for m in device_msgs if m.has_timeseries()] - attr_msgs = [m for m in device_msgs if m.has_attributes()] - logger.trace("Device '%s' - telemetry: %d, attributes: %d", - device, len(telemetry_msgs), len(attr_msgs)) + built_child_messages: List[MqttPublishMessage] = [] - for ts_batch in self._splitter.split_timeseries(telemetry_msgs): - payload = JsonMessageAdapter.build_payload(ts_batch, True) + if telemetry_msgs: + ts_messages = [m.payload for m in telemetry_msgs] + for ts_batch in self._splitter.split_timeseries(ts_messages): + payload_bytes = JsonMessageAdapter.build_payload(ts_batch, True) count = ts_batch.timeseries_datapoint_count() - result.append((DEVICE_TELEMETRY_TOPIC, payload, count, ts_batch.get_delivery_futures())) - logger.trace("Built telemetry payload for device='%s' with %d datapoints", device, count) - - for attr_batch in self._splitter.split_attributes(attr_msgs): - payload = JsonMessageAdapter.build_payload(attr_batch, False) + child_futures = ts_batch.get_delivery_futures() or [] + + mqtt_msg = MqttPublishMessage( + topic=DEVICE_TELEMETRY_TOPIC, + payload=payload_bytes, + qos=qos, + datapoints=count, + delivery_futures=child_futures + ) + result.append(mqtt_msg) + built_child_messages.append(mqtt_msg) + if logger.isEnabledFor(TRACE_LEVEL): + logger.trace( + "Built telemetry payload for '%s' with %d datapoints, futures=%r", + device_name, count, [f.uuid for f in child_futures] + ) + + if attr_msgs: + attr_messages = [m.payload for m in attr_msgs] + for attr_batch in self._splitter.split_attributes(attr_messages): + payload_bytes = JsonMessageAdapter.build_payload(attr_batch, False) count = len(attr_batch.attributes) - result.append((DEVICE_ATTRIBUTES_TOPIC, payload, count, attr_batch.get_delivery_futures())) - logger.trace("Built attribute payload for device='%s' with %d attributes", device, count) - - logger.trace("Generated %d topic-payload entries.", len(result)) - - return result - except Exception as e: - logger.error("Error building topic-payloads: %s", str(e)) - logger.debug("Exception details: %s", e, exc_info=True) - raise - - def build_attribute_request(self, request: AttributeRequest) 
-> Tuple[str, bytes]: + child_futures = attr_batch.get_delivery_futures() or [] + + mqtt_msg = MqttPublishMessage( + topic=DEVICE_ATTRIBUTES_TOPIC, + payload=payload_bytes, + qos=qos, + datapoints=count, + delivery_futures=child_futures + ) + result.append(mqtt_msg) + built_child_messages.append(mqtt_msg) + if logger.isEnabledFor(TRACE_LEVEL): + logger.trace( + "Built attribute payload for '%s' with %d attributes, futures=%r", + device_name, count, [f.uuid for f in child_futures] + ) + + # Register child futures to all original parent futures + parent_futures = [f for m in group_msgs for f in (m.delivery_futures or [])] + for parent in parent_futures: + for child_msg in built_child_messages: + for child in child_msg.delivery_futures or []: + future_map.register(parent, [child]) + + logger.trace("Generated %d topic-payload entries.", len(result)) + return result + + def build_attribute_request(self, request: AttributeRequest) -> MqttPublishMessage: """ Build the payload for an attribute request response. :param request: The AttributeRequest to build the payload for. @@ -298,9 +328,9 @@ def build_attribute_request(self, request: AttributeRequest) -> Tuple[str, bytes topic = mqtt_topics.build_device_attributes_request_topic(request.request_id) payload = dumps(request.to_payload_format()) logger.trace("Built attribute request payload for request: %r", request) - return topic, payload + return MqttPublishMessage(topic=topic, payload=payload, qos=1, datapoints=1) - def build_claim_request(self, claim_request) -> Tuple[str, bytes]: + def build_claim_request(self, claim_request) -> MqttPublishMessage: """ Build the payload for a claim request. :param claim_request: The ClaimRequest to build the payload for. @@ -312,9 +342,9 @@ def build_claim_request(self, claim_request) -> Tuple[str, bytes]: topic = mqtt_topics.DEVICE_CLAIM_TOPIC payload = dumps(claim_request.to_payload_format()) logger.trace("Built claim request payload: %r", claim_request) - return topic, payload + return MqttPublishMessage(topic=topic, payload=payload, qos=1, datapoints=1) - def build_rpc_request(self, rpc_request: RPCRequest) -> Tuple[str, bytes]: + def build_rpc_request(self, rpc_request: RPCRequest) -> MqttPublishMessage: """ Build the payload for an RPC request. :param rpc_request: The RPC request to build the payload for. @@ -327,9 +357,10 @@ def build_rpc_request(self, rpc_request: RPCRequest) -> Tuple[str, bytes]: topic = mqtt_topics.DEVICE_RPC_REQUEST_TOPIC + str(rpc_request.request_id) logger.trace("Built RPC request payload for request ID=%d with payload: %r", rpc_request.request_id, payload) - return topic, payload + message_to_send = MqttPublishMessage(topic=topic, payload=payload, qos=1, datapoints=1) + return message_to_send - def build_rpc_response(self, rpc_response: RPCResponse) -> Tuple[str, bytes]: + def build_rpc_response(self, rpc_response: RPCResponse) -> MqttPublishMessage: """ Build the payload for an RPC response. :param rpc_response: The RPC response to build the payload for. 
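A minimal usage sketch of the adapter contract above, assuming the MqttPublishMessage constructor and the async MessageQueue.publish(message) signature introduced in this series; the helper name publish_and_wait and its hard-coded topic, QoS, and timeout are illustrative only.

    import asyncio

    from tb_mqtt_client.common.mqtt_message import MqttPublishMessage
    from tb_mqtt_client.constants import mqtt_topics


    async def publish_and_wait(message_queue, payload: bytes, qos: int = 1):
        # Build the message; when no delivery_futures are supplied, the constructor
        # attaches a fresh future that is later resolved with a PublishResult.
        message = MqttPublishMessage(
            topic=mqtt_topics.DEVICE_TELEMETRY_TOPIC,
            payload=payload,
            qos=qos,
            datapoints=1,
        )
        await message_queue.publish(message)           # enqueued for batching / rate limiting
        delivery_future = message.delivery_futures[0]  # resolved once the broker acknowledges
        return await asyncio.wait_for(delivery_future, timeout=10)
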
@@ -341,7 +372,7 @@ def build_rpc_response(self, rpc_response: RPCResponse) -> Tuple[str, bytes]: payload = dumps(rpc_response.to_payload_format()) topic = mqtt_topics.DEVICE_RPC_RESPONSE_TOPIC + str(rpc_response.request_id) logger.trace("Built RPC response payload for request ID=%d with payload: %r", rpc_response.request_id, payload) - return topic, payload + return MqttPublishMessage(topic=topic, payload=payload, qos=1, datapoints=1) def build_provision_request(self, provision_request: 'ProvisioningRequest') -> Tuple[str, bytes]: """ @@ -409,17 +440,17 @@ def pack_attributes(msg: DeviceUplinkMessage) -> Dict[str, Any]: return {attr.key: attr.value for attr in msg.attributes} @staticmethod - def pack_timeseries(msg: 'DeviceUplinkMessage') -> List[Dict[str, Any]]: + def pack_timeseries(msg: 'DeviceUplinkMessage') -> Union[Dict[str, Any], List[Dict[str, Any]]]: now_ts = int(datetime.now(UTC).timestamp() * 1000) - if all(entry.ts is None for entry in chain.from_iterable(msg.timeseries.values())): - packed = { - "ts": now_ts, - "values": {entry.key: entry.value for entry in chain.from_iterable(msg.timeseries.values())} - } - else: - packed = [ - {"ts": entry.ts or now_ts, "values": {entry.key: entry.value}} - for entry in chain.from_iterable(msg.timeseries.values()) - ] - return packed + entries = list(chain.from_iterable(msg.timeseries.values())) + + if all(entry.ts is None for entry in entries): + return {entry.key: entry.value for entry in entries} + + grouped: Dict[int, Dict[str, Any]] = defaultdict(dict) + for entry in entries: + ts = entry.ts or now_ts + grouped[ts][entry.key] = entry.value + + return [{"ts": ts, "values": values} for ts, values in grouped.items()] diff --git a/tb_mqtt_client/service/gateway/message_adapter.py b/tb_mqtt_client/service/gateway/message_adapter.py index 8336fdc..402939f 100644 --- a/tb_mqtt_client/service/gateway/message_adapter.py +++ b/tb_mqtt_client/service/gateway/message_adapter.py @@ -22,6 +22,7 @@ from orjson import loads, dumps from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.constants.mqtt_topics import GATEWAY_TELEMETRY_TOPIC, GATEWAY_ATTRIBUTES_TOPIC, \ GATEWAY_CONNECT_TOPIC, GATEWAY_DISCONNECT_TOPIC, GATEWAY_ATTRIBUTES_REQUEST_TOPIC, GATEWAY_RPC_TOPIC, \ @@ -47,10 +48,10 @@ class GatewayMessageAdapter(ABC): """ @abstractmethod - def build_uplink_payloads( + def build_uplink_messages( self, messages: List[GatewayUplinkMessage] - ) -> List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]]: + ) -> List[MqttPublishMessage]: """ Build a list of topic-payload pairs from the given messages. Each pair consists of a topic string, payload bytes, the number of datapoints, @@ -137,7 +138,7 @@ class JsonGatewayMessageAdapter(GatewayMessageAdapter): Builds uplink payloads from uplink message objects and parses JSON payloads into GatewayEvent objects. """ - def build_uplink_payloads(self, messages: List[GatewayUplinkMessage]) -> List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]]: + def build_uplink_messages(self, messages: List[GatewayUplinkMessage]) -> List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]]: """ Build a list of topic-payload pairs from the given messages. 
Each pair consists of a topic string, payload bytes, the number of datapoints, diff --git a/tb_mqtt_client/service/message_queue.py b/tb_mqtt_client/service/message_queue.py index c9be904..24f7356 100644 --- a/tb_mqtt_client/service/message_queue.py +++ b/tb_mqtt_client/service/message_queue.py @@ -14,9 +14,10 @@ import asyncio from contextlib import suppress -from typing import List, Optional, Union, Tuple, Dict, Callable +from typing import List, Optional -from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.logging_utils import get_logger, TRACE_LEVEL +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit from tb_mqtt_client.constants import mqtt_topics from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage @@ -51,8 +52,8 @@ def __init__(self, self._telemetry_rate_limit = telemetry_rate_limit self._telemetry_dp_rate_limit = telemetry_dp_rate_limit self._backpressure = self._mqtt_manager.backpressure - # Queue expects tuples of (topic, payload, delivery_futures, datapoints_count, qos) - self._queue: asyncio.Queue[Tuple[str, Union[bytes, DeviceUplinkMessage, GatewayUplinkMessage], List[asyncio.Future[PublishResult]], int, int]] = asyncio.Queue(maxsize=max_queue_size) + # Queue expects tuples of (mqtt_message, delivery_futures) + self._queue: asyncio.Queue[MqttPublishMessage] = asyncio.Queue(maxsize=max_queue_size) self._pending_queue_tasks: set[asyncio.Task] = set() self._active = asyncio.Event() self._wakeup_event = asyncio.Event() @@ -62,136 +63,125 @@ def __init__(self, self._gateway_adapter = gateway_message_adapter self._loop_task = asyncio.create_task(self._dequeue_loop()) self._rate_limit_refill_task = asyncio.create_task(self._rate_limit_refill_loop()) + asyncio.create_task(self.print_queue_statistics()) logger.debug("MessageQueue initialized: max_queue_size=%s, batch_time=%.3f, batch_count=%d", max_queue_size, self._batch_max_time, batch_collect_max_count) - async def publish(self, topic: str, payload: Union[bytes, DeviceUplinkMessage, GatewayUplinkMessage], datapoints_count: int, qos: int) -> Optional[List[asyncio.Future[PublishResult]]]: - delivery_futures = payload.get_delivery_futures() if isinstance(payload, DeviceUplinkMessage) or isinstance(payload, GatewayUplinkMessage) else [asyncio.Future()] + async def publish(self, message: MqttPublishMessage) -> Optional[List[asyncio.Future[PublishResult]]]: try: - logger.trace("publish() received delivery future id: %r for topic=%s", - id(delivery_futures[0]) if delivery_futures else -1, topic) - self._queue.put_nowait((topic, payload, delivery_futures, datapoints_count, qos)) - logger.trace("Enqueued message: topic=%s, datapoints=%d, type=%s", - topic, datapoints_count, type(payload).__name__) + if logger.isEnabledFor(TRACE_LEVEL): + logger.trace(f"Pushing message to queue with delivery futures: {[f.uuid for f in message.delivery_futures]}") + self._queue.put_nowait(message) except asyncio.QueueFull: - logger.error("Message queue full. Dropping message for topic %s", topic) - for future in delivery_futures: + logger.error("Message queue full. 
Dropping message for topic %s", message.topic)
+            for future in message.delivery_futures:
                 if future:
-                    future.set_result(PublishResult(topic, qos, -1, len(payload), -1))
-        return delivery_futures or None
+                    future.set_result(PublishResult(message.topic, message.qos, -1, len(message.payload), -1))
 
     async def _dequeue_loop(self):
         logger.debug("MessageQueue dequeue loop started.")
         while self._active.is_set() and not self._main_stop_event.is_set():
             try:
-                # topic, payload, delivery_futures_or_none, count = await asyncio.wait_for(asyncio.get_event_loop().create_task(self._queue.get()), timeout=self._BATCH_TIMEOUT)
-                topic, payload, delivery_futures_or_none, datapoints, qos = await self._wait_for_message()
-                logger.trace("MessageQueue dequeue: topic=%s, payload=%r, count=%d",
-                             topic, payload, datapoints)
-                logger.trace("Dequeued message: delivery_future id: %r topic=%s, type=%s, datapoints=%d",
-                             id(delivery_futures_or_none[0]) if delivery_futures_or_none else -1,
-                             topic, type(payload).__name__, datapoints)
-                await asyncio.sleep(0)  # cooperative yield
-            except asyncio.TimeoutError:
-                logger.trace("Dequeue wait timed out. Yielding...")
-                await asyncio.sleep(0.001)
-                continue
-            except asyncio.CancelledError:
-                break
-            except Exception as e:
-                logger.warning("Unexpected error in dequeue loop: %s", e)
-                continue
+                try:
+                    message = await self._wait_for_message()
+                    if logger.isEnabledFor(TRACE_LEVEL):
+                        logger.trace(f"Dequeued message with delivery futures: {[f.uuid for f in message.delivery_futures]}")
+                    await asyncio.sleep(0)  # cooperative yield
+                except asyncio.TimeoutError:
+                    logger.trace("Dequeue wait timed out. Yielding...")
+                    await asyncio.sleep(0.001)
+                    continue
+                except asyncio.CancelledError:
+                    break
+                except Exception as e:
+                    logger.warning("Unexpected error in dequeue loop: %s", e)
+                    continue
+
+                if isinstance(message, MqttPublishMessage) and isinstance(message.payload, bytes):
+                    logger.trace("Dequeued immediate publish: topic=%s (raw bytes)", message.topic)
+                    await self._try_publish(message)
+                    continue
+
+                logger.trace("Dequeued message for batching: topic=%s, device=%s",
+                             message.topic, getattr(message.payload, 'device_name', 'N/A'))
+
+                batch: List[MqttPublishMessage] = [message]
+                start = asyncio.get_event_loop().time()
+                batch_size = message.payload.size
+                batch_type = type(message.payload).__name__
+
+                while not self._queue.empty():
+                    elapsed = asyncio.get_event_loop().time() - start
+                    if elapsed >= self._batch_max_time:
+                        logger.trace("Batch time threshold reached: %.3fs", elapsed)
+                        break
+                    if len(batch) >= self._batch_max_count:
+                        logger.trace("Batch count threshold reached: %d messages", len(batch))
+                        break
 
-            if isinstance(payload, bytes):
-                logger.trace("Dequeued immediate publish: topic=%s (raw bytes)", topic)
-                await self._try_publish(topic, payload, datapoints, delivery_futures_or_none)
-                continue
+                    try:
+                        next_message = self._queue.get_nowait()
+                        if isinstance(next_message.payload, DeviceUplinkMessage) or isinstance(next_message.payload, GatewayUplinkMessage):
+                            if batch_type is not None and batch_type != type(next_message.payload).__name__:
+                                logger.trace("Batch type mismatch: current=%s, next=%s, finalizing current",
+                                             batch_type, type(next_message.payload).__name__)
+                                self._queue.put_nowait(next_message)
+                                break
+                            batch_type = type(next_message.payload).__name__
+                            msg_size = next_message.payload.size
+                            if batch_size + msg_size > self._adapter.splitter.max_payload_size:  # noqa
+                                logger.trace("Batch size threshold exceeded: current=%d, next=%d", batch_size,
msg_size) + self._queue.put_nowait(next_message) + break + batch.append(next_message) + batch_size += msg_size + else: + logger.trace("Immediate publish encountered in queue while batching: topic=%s", next_message.topic) + await self._try_publish(next_message) + except asyncio.QueueEmpty: + break - logger.trace("Dequeued message for batching: topic=%s, device=%s", - topic, getattr(payload, 'device_name', 'N/A')) + if batch_type is None: + batch_type = type(message.payload).__name__ - batch: List[Tuple[str, Union[bytes, DeviceUplinkMessage, GatewayUplinkMessage], List[asyncio.Future[PublishResult]], int, int]] = [(topic, payload, delivery_futures_or_none, datapoints, qos)] - start = asyncio.get_event_loop().time() - batch_size = payload.size - batch_type = type(payload).__name__ + if batch: + logger.trace("Batching completed: %d messages, total size=%d", len(batch), batch_size) - while not self._queue.empty(): - elapsed = asyncio.get_event_loop().time() - start - if elapsed >= self._batch_max_time: - logger.trace("Batch time threshold reached: %.3fs", elapsed) - break - if len(batch) >= self._batch_max_count: - logger.trace("Batch count threshold reached: %d messages", len(batch)) - break - - try: - next_topic, next_payload, delivery_futures_or_none, datapoints, qos = self._queue.get_nowait() - if isinstance(next_payload, DeviceUplinkMessage) or isinstance(next_payload, GatewayUplinkMessage): - if batch_type is not None and batch_type != type(next_payload).__name__: - logger.trace("Batch type mismatch: current=%s, next=%s, finalizing current", - batch_type, type(next_payload).__class__.__name__) - self._queue.put_nowait((next_topic, next_payload, delivery_futures_or_none, datapoints, qos)) - break - batch_type = type(next_payload).__name__ - msg_size = next_payload.size - if batch_size + msg_size > self._adapter.splitter.max_payload_size: # noqa - logger.trace("Batch size threshold exceeded: current=%d, next=%d", batch_size, msg_size) - self._queue.put_nowait((next_topic, next_payload, delivery_futures_or_none, datapoints, qos)) - break - batch.append((next_topic, next_payload, delivery_futures_or_none, datapoints, qos)) - batch_size += msg_size + if batch_type == 'GatewayUplinkMessage' and self._gateway_adapter: + logger.trace("Building gateway uplink payloads for %d messages", len(batch)) + topic_payloads = self._gateway_adapter.build_uplink_messages(batch) else: - logger.trace("Immediate publish encountered in queue while batching: topic=%s", next_topic) - await self._try_publish(next_topic, next_payload, datapoints) - except asyncio.QueueEmpty: + topic_payloads = self._adapter.build_uplink_messages(batch) + + for built_message in topic_payloads: + logger.trace("Dispatching batched message: topic=%s, size=%d, datapoints=%d, delivery_futures=%r", + built_message.topic, + len(built_message.payload), + built_message.datapoints, + [f.uuid for f in built_message.delivery_futures]) + await self._try_publish(built_message) + + except Exception as e: + logger.error("Error in dequeue loop:", exc_info=e) + logger.debug("Dequeue loop error details:", exc_info=e) + if isinstance(e, asyncio.CancelledError): break + continue - if batch_type is None: - batch_type = type(payload).__name__ - - if batch: - logger.trace("Batching completed: %d messages, total size=%d", len(batch), batch_size) - messages = [device_uplink_message for _, device_uplink_message, _, _, _ in batch] - - if batch_type == 'GatewayUplinkMessage' and self._gateway_adapter: - logger.trace("Building gateway uplink payloads for %d 
messages", len(messages)) - topic_payloads = self._gateway_adapter.build_uplink_payloads(messages) - else: - topic_payloads = self._adapter.build_uplink_payloads(messages) - - for topic, payload, datapoints, delivery_futures in topic_payloads: - logger.trace("Dispatching batched message: topic=%s, size=%d, datapoints=%d, delivery_futures=%r", - topic, len(payload), datapoints, [id(f) for f in delivery_futures]) - await self._try_publish(topic=topic, - payload=payload, - datapoints=datapoints, - delivery_futures_or_none=delivery_futures, - qos=qos) - - async def _try_publish(self, - topic: str, - payload: bytes, - datapoints: int, - delivery_futures_or_none: List[Optional[asyncio.Future[PublishResult]]] = None, - qos: int = 1): - if delivery_futures_or_none is None: - logger.trace("No delivery futures associated! This publish result will not be tracked.") + async def _try_publish(self, message: MqttPublishMessage): + if not message.delivery_futures: + logger.error("No delivery futures associated! This publish result will not be tracked.") delivery_futures_or_none = [] - is_message_with_telemetry_or_attributes = topic in (mqtt_topics.DEVICE_TELEMETRY_TOPIC, - mqtt_topics.DEVICE_ATTRIBUTES_TOPIC) + is_message_with_telemetry_or_attributes = message.topic in (mqtt_topics.DEVICE_TELEMETRY_TOPIC, + mqtt_topics.DEVICE_ATTRIBUTES_TOPIC) # TODO: Add topics check for gateways - logger.trace("Attempting publish: topic=%s, datapoints=%d", topic, datapoints) + logger.trace("Attempting publish: topic=%s, datapoints=%d", message.topic, message.datapoints) # Check backpressure first - if active, don't even try to check rate limits if self._backpressure.should_pause(): - logger.debug("Backpressure active, delaying publish of topic=%s for %.1f seconds", topic, 1.0) - self._schedule_delayed_retry(topic=topic, - payload=payload, - datapoints=datapoints, - qos=qos, - delay=1.0, - delivery_futures=delivery_futures_or_none) + logger.debug("Backpressure active, delaying publish of topic=%s for %.1f seconds", message.topic, 1.0) + self._schedule_delayed_retry(message) return # Check and consume rate limits atomically before publishing @@ -204,28 +194,18 @@ async def _try_publish(self, triggered_rate_limit = await self._telemetry_rate_limit.try_consume(1) if triggered_rate_limit: logger.debug("Telemetry message rate limit hit for topic %s: %r per %r seconds", - topic, triggered_rate_limit[0], triggered_rate_limit[1]) + message.topic, triggered_rate_limit[0], triggered_rate_limit[1]) retry_delay = self._telemetry_rate_limit.minimal_timeout - self._schedule_delayed_retry(topic=topic, - payload=payload, - datapoints=datapoints, - qos=qos, - delay=retry_delay, - delivery_futures=delivery_futures_or_none) + self._schedule_delayed_retry(message, delay=retry_delay) return if self._telemetry_dp_rate_limit: - triggered_rate_limit = await self._telemetry_dp_rate_limit.try_consume(datapoints) + triggered_rate_limit = await self._telemetry_dp_rate_limit.try_consume(message.datapoints) if triggered_rate_limit: logger.debug("Telemetry datapoint rate limit hit for topic %s: %r per %r seconds", - topic, triggered_rate_limit[0], triggered_rate_limit[1]) + message.topic, triggered_rate_limit[0], triggered_rate_limit[1]) retry_delay = self._telemetry_dp_rate_limit.minimal_timeout - self._schedule_delayed_retry(topic=topic, - payload=payload, - datapoints=datapoints, - qos=qos, - delay=retry_delay, - delivery_futures=delivery_futures_or_none) + self._schedule_delayed_retry(message, delay=retry_delay) return else: # For non-telemetry 
messages, we only need to check the message rate limit @@ -233,71 +213,42 @@ async def _try_publish(self, triggered_rate_limit = await self._message_rate_limit.try_consume(1) if triggered_rate_limit: logger.debug("Generic message rate limit hit for topic %s: %r per %r seconds", - topic, triggered_rate_limit[0], triggered_rate_limit[1]) + message.topic, triggered_rate_limit[0], triggered_rate_limit[1]) retry_delay = self._message_rate_limit.minimal_timeout - self._schedule_delayed_retry(topic=topic, - payload=payload, - datapoints=datapoints, - qos=qos, - delay=retry_delay, - delivery_futures=delivery_futures_or_none) + self._schedule_delayed_retry(message, delay=retry_delay) return try: - logger.trace("Trying to publish topic=%s, payload size=%d, attached future id=%r", - topic, len(payload), id(delivery_futures_or_none[0]) if delivery_futures_or_none else -1) + if logger.isEnabledFor(TRACE_LEVEL): + logger.trace("Trying to publish topic=%s, payload size=%d, attached future id=%r", + message.topic, len(message.payload), [f.uuid for f in message.delivery_futures]) - mqtt_future = await self._mqtt_manager.publish(message_or_topic=topic, payload=payload, qos=qos) + await self._mqtt_manager.publish(message) - if delivery_futures_or_none is not None: - def resolve_attached(publish_future: asyncio.Future): - try: - publish_result = publish_future.result() - except asyncio.CancelledError: - logger.trace("Publish future was cancelled: %r, id: %r", publish_future, id(publish_future)) - publish_result = PublishResult(topic, qos, -1, len(payload), -1) - except Exception as exc: - logger.warning("Publish failed with exception: %s", exc) - logger.debug("Resolving delivery futures with failure:", exc_info=exc) - publish_result = PublishResult(topic, qos, -1, len(payload), -1) - - for i, f in enumerate(delivery_futures_or_none): - if f is not None and not f.done(): - f.set_result(publish_result) - logger.trace("Resolved delivery future #%d id=%r with %s, main publish future id: %r, %r", - i, id(f), publish_result, id(publish_future), publish_future) - - logger.trace("Adding done callback to main publish future: %r, main publish future done state: %r", id(mqtt_future), mqtt_future.done()) - if mqtt_future.done(): - logger.debug("Main publish future is already done, resolving immediately.") - resolve_attached(mqtt_future) - else: - mqtt_future.add_done_callback(resolve_attached) except Exception as e: - logger.warning("Failed to publish to topic %s: %s. Scheduling retry.", topic, e) + logger.warning("Failed to publish to topic %s: %s. Scheduling retry.", message.topic, e) logger.debug("Scheduling retry for topic=%s, payload size=%d, qos=%d", - topic, len(payload), qos) + message.topic, len(message.payload), message.qos) logger.debug("error details: %s", e, exc_info=True) - self._schedule_delayed_retry(topic, payload, datapoints, qos, delay=.1, delivery_futures=delivery_futures_or_none) + self._schedule_delayed_retry(message, delay=.1) - def _schedule_delayed_retry(self, topic: str, payload: bytes, datapoints: int, qos: int, delay: float, - delivery_futures: Optional[List[Optional[asyncio.Future[PublishResult]]]] = None): + def _schedule_delayed_retry(self, message: MqttPublishMessage, delay: float = 0.1): if not self._active.is_set() or self._main_stop_event.is_set(): - logger.debug("MessageQueue is not active or main stop event is set. Not scheduling retry for topic=%s", topic) + logger.debug("MessageQueue is not active or main stop event is set. 
Not scheduling retry for topic=%s", message.topic) return - logger.trace("Scheduling retry: topic=%s, delay=%.2f", topic, delay) + logger.trace("Scheduling retry: topic=%s, delay=%.2f", message.topic, delay) async def retry(): try: - logger.debug("Retrying publish: topic=%s", topic) + logger.debug("Retrying publish: topic=%s", message.topic) await asyncio.sleep(delay) if not self._active.is_set() or self._main_stop_event.is_set(): - logger.debug("MessageQueue is not active or main stop event is set. Not re-enqueuing message for topic=%s", topic) + logger.debug("MessageQueue is not active or main stop event is set. Not re-enqueuing message for topic=%s", message.topic) return - self._queue.put_nowait((topic, payload, delivery_futures, datapoints, qos)) + self._queue.put_nowait(message) self._wakeup_event.set() - logger.debug("Re-enqueued message after delay: topic=%s", topic) + logger.debug("Re-enqueued message after delay: topic=%s", message.topic) except asyncio.QueueFull: - logger.warning("Retry queue full. Dropping retried message: topic=%s", topic) + logger.warning("Retry queue full. Dropping retried message: topic=%s", message.topic) except Exception as e: logger.debug("Unexpected error during delayed retry: %s", e) @@ -305,7 +256,7 @@ async def retry(): self._retry_tasks.add(task) task.add_done_callback(self._retry_tasks.discard) - async def _wait_for_message(self) -> Tuple[str, Union[bytes, DeviceUplinkMessage, GatewayUplinkMessage], List[asyncio.Future[PublishResult]], int, int]: + async def _wait_for_message(self) -> MqttPublishMessage: while self._active.is_set(): try: if not self._queue.empty(): @@ -320,9 +271,7 @@ async def _wait_for_message(self) -> Tuple[str, Union[bytes, DeviceUplinkMessage wake_task = asyncio.create_task(self._wakeup_event.wait()) - done, pending = await asyncio.wait( - [queue_task, wake_task], return_when=asyncio.FIRST_COMPLETED - ) + done, pending = await asyncio.wait([queue_task, wake_task], return_when=asyncio.FIRST_COMPLETED) for task in pending: task.cancel() @@ -368,7 +317,8 @@ async def shutdown(self): self._queue.qsize()) self.clear() - async def _cancel_tasks(self, tasks: set[asyncio.Task]): + @staticmethod + async def _cancel_tasks(tasks: set[asyncio.Task]): for task in list(tasks): task.cancel() with suppress(asyncio.CancelledError): @@ -414,3 +364,18 @@ async def _refill_rate_limits(self): def set_gateway_message_adapter(self, message_adapter: GatewayMessageAdapter): self._gateway_adapter = message_adapter + + + async def print_queue_statistics(self): + """ + Prints the current statistics of the message queue. + """ + while self._active.is_set() and not self._main_stop_event.is_set(): + queue_size = self._queue.qsize() + pending_tasks = len(self._pending_queue_tasks) + retry_tasks = len(self._retry_tasks) + active = self._active.is_set() + logger.info("MessageQueue Statistics: " + "Queue Size: %d, Pending Tasks: %d, Retry Tasks: %d, Active: %s", + queue_size, pending_tasks, retry_tasks, active) + await asyncio.sleep(60) diff --git a/tb_mqtt_client/service/message_splitter.py b/tb_mqtt_client/service/message_splitter.py index b17062b..bdabed7 100644 --- a/tb_mqtt_client/service/message_splitter.py +++ b/tb_mqtt_client/service/message_splitter.py @@ -13,10 +13,12 @@ # limitations under the License. 
import asyncio -from collections import defaultdict -from typing import List, Optional, Dict, Tuple +from collections import defaultdict, deque +from typing import List, Optional, Dict, Tuple, Set +from uuid import uuid4 -from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.async_utils import future_map +from tb_mqtt_client.common.logging_utils import get_logger, TRACE_LEVEL from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage, DeviceUplinkMessageBuilder from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry @@ -42,8 +44,8 @@ def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp logger.trace("Splitting timeseries for %d messages", len(messages)) if (len(messages) == 1 - and ((messages[0].attributes_datapoint_count() + messages[ - 0].timeseries_datapoint_count() <= self._max_datapoints) or self._max_datapoints == 0) # noqa + and ((messages[0].attributes_datapoint_count() + messages[0].timeseries_datapoint_count() + <= self._max_datapoints) or self._max_datapoints == 0) and messages[0].size <= self._max_payload_size): return messages @@ -51,37 +53,41 @@ def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp grouped: Dict[Tuple[str, Optional[str]], List[DeviceUplinkMessage]] = defaultdict(list) for msg in messages: - key = (msg.device_name, msg.device_profile) - grouped[key].append(msg) + grouped[(msg.device_name, msg.device_profile)].append(msg) for (device_name, device_profile), group_msgs in grouped.items(): - logger.trace("Processing group: device='%s', profile='%s', messages=%d", device_name, device_profile, - len(group_msgs)) + logger.trace("Processing group: device='%s', profile='%s', messages=%d", + device_name, device_profile, len(group_msgs)) + + all_ts_entries: List[TimeseriesEntry] = [] + parent_futures: List[asyncio.Future] = [] - all_ts: List[TimeseriesEntry] = [] - delivery_futures: List[asyncio.Future] = [] for msg in group_msgs: - if msg.has_timeseries(): - for ts_group in msg.timeseries.values(): - all_ts.extend(ts_group) - delivery_futures.extend(msg.get_delivery_futures()) + for ts_group in msg.timeseries.values(): + all_ts_entries.extend(ts_group) + parent_futures.extend(msg.get_delivery_futures() or []) builder: Optional[DeviceUplinkMessageBuilder] = None size = 0 point_count = 0 - batch_futures = [] - for ts_kv in all_ts: + for ts_kv in all_ts_entries: exceeds_size = builder and size + ts_kv.size > self._max_payload_size exceeds_points = 0 < self._max_datapoints <= point_count if not builder or exceeds_size or exceeds_points: if builder: + shared_future = asyncio.Future() + shared_future.uuid = uuid4() + builder.add_delivery_futures(shared_future) + built = builder.build() result.append(built) - batch_futures.extend(built.get_delivery_futures()) - logger.trace("Flushed batch with %d points (size=%d)", len(built.timeseries), size) + for parent in parent_futures: + future_map.register(parent, [shared_future]) + logger.trace("Flushed batch with %d datapoints (size=%d)", + built.timeseries_datapoint_count(), size) builder = DeviceUplinkMessageBuilder() \ .set_device_name(device_name) \ .set_device_profile(device_profile) @@ -93,22 +99,17 @@ def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp point_count += 1 if builder and builder._timeseries: # noqa + shared_future = asyncio.Future() + shared_future.uuid = uuid4() + 
builder.add_delivery_futures(shared_future) + built = builder.build() result.append(built) - batch_futures.extend(built.get_delivery_futures()) - logger.trace("Flushed final batch with %d points (size=%d)", len(built.timeseries), size) - - if delivery_futures: - original_future = delivery_futures[0] - logger.trace("Adding futures to original future: %s, futures ids: %r", id(original_future), - [id(f) for f in batch_futures]) + for parent in parent_futures: + future_map.register(parent, [shared_future]) - async def resolve_original(): - logger.trace("Resolving original future with batch futures: %r", [id(f) for f in batch_futures]) - results = await asyncio.gather(*batch_futures, return_exceptions=False) - original_future.set_result(all(results)) - - asyncio.create_task(resolve_original()) + logger.trace("Flushed final batch with %d datapoints (size=%d)", + built.timeseries_datapoint_count(), size) logger.trace("Total timeseries batches created: %d", len(result)) return result @@ -118,8 +119,8 @@ def split_attributes(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp result: List[DeviceUplinkMessage] = [] if (len(messages) == 1 - and ((messages[0].attributes_datapoint_count() + messages[ - 0].timeseries_datapoint_count() <= self._max_datapoints) or self._max_datapoints == 0) # noqa + and ((messages[0].attributes_datapoint_count() + messages[0].timeseries_datapoint_count() + <= self._max_datapoints) or self._max_datapoints == 0) and messages[0].size <= self._max_payload_size): return messages @@ -128,34 +129,40 @@ def split_attributes(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp grouped[(msg.device_name, msg.device_profile)].append(msg) for (device_name, device_profile), group_msgs in grouped.items(): - logger.trace("Processing attribute group: device='%s', profile='%s', messages=%d", device_name, - device_profile, len(group_msgs)) + logger.trace("Processing attribute group: device='%s', profile='%s', messages=%d", + device_name, device_profile, len(group_msgs)) all_attrs: List[AttributeEntry] = [] - delivery_futures: List[asyncio.Future] = [] + parent_futures: List[asyncio.Future] = [] for msg in group_msgs: if msg.has_attributes(): all_attrs.extend(msg.attributes) - delivery_futures.extend(msg.get_delivery_futures()) + parent_futures.extend(msg.get_delivery_futures()) - builder = None + builder: Optional[DeviceUplinkMessageBuilder] = None size = 0 point_count = 0 - batch_futures = [] for attr in all_attrs: exceeds_size = builder and size + attr.size > self._max_payload_size exceeds_points = 0 < self._max_datapoints <= point_count if not builder or exceeds_size or exceeds_points: - if builder: + if builder and builder._attributes: # noqa + shared_future = asyncio.Future() + shared_future.uuid = uuid4() + builder.add_delivery_futures(shared_future) + built = builder.build() result.append(built) - batch_futures.extend(built.get_delivery_futures()) + for parent in parent_futures: + future_map.register(parent, [shared_future]) + logger.trace("Flushed attribute batch (count=%d, size=%d)", len(built.attributes), size) - builder = DeviceUplinkMessageBuilder().set_device_name(device_name).set_device_profile( - device_profile) + builder = DeviceUplinkMessageBuilder() \ + .set_device_name(device_name) \ + .set_device_profile(device_profile) size = 0 point_count = 0 @@ -164,21 +171,16 @@ def split_attributes(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp point_count += 1 if builder and builder._attributes: # noqa + shared_future = asyncio.Future() + 
shared_future.uuid = uuid4() + builder.add_delivery_futures(shared_future) + built = builder.build() result.append(built) - batch_futures.extend(built.get_delivery_futures()) - logger.trace("Flushed final attribute batch (count=%d, size=%d)", len(built.attributes), size) - - if delivery_futures: - original_future = delivery_futures[0] - logger.trace("Adding futures to original future: %s, futures ids: %r", id(original_future), - [id(batch_future) for batch_future in batch_futures]) + for parent in parent_futures: + future_map.register(parent, [shared_future]) - async def resolve_original(): - results = await asyncio.gather(*batch_futures, return_exceptions=False) - original_future.set_result(all(results)) - - asyncio.create_task(resolve_original()) + logger.trace("Flushed final attribute batch (count=%d, size=%d)", len(built.attributes), size) logger.trace("Total attribute batches created: %d", len(result)) return result diff --git a/tb_mqtt_client/service/mqtt_manager.py b/tb_mqtt_client/service/mqtt_manager.py index d66e1b8..bb10e15 100644 --- a/tb_mqtt_client/service/mqtt_manager.py +++ b/tb_mqtt_client/service/mqtt_manager.py @@ -19,13 +19,14 @@ from contextlib import suppress from time import monotonic from typing import Optional, Callable, Dict, Union, Tuple, Coroutine, Any +from uuid import uuid4 -from gmqtt import Client as GMQTTClient, Message, Subscription +from gmqtt import Client as GMQTTClient, Subscription -from tb_mqtt_client.common.async_utils import await_or_stop -from tb_mqtt_client.common.gmqtt_patch import patch_gmqtt_puback, patch_gmqtt_protocol_connection_lost, \ - patch_mqtt_handler_disconnect, patch_handle_connack, PatchUtils -from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.async_utils import await_or_stop, future_map +from tb_mqtt_client.common.gmqtt_patch import PatchUtils +from tb_mqtt_client.common.logging_utils import get_logger, TRACE_LEVEL +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage from tb_mqtt_client.common.rate_limit.backpressure_controller import BackpressureController from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit from tb_mqtt_client.common.request_id_generator import RPCRequestIdProducer, AttributeRequestIdProducer @@ -61,12 +62,14 @@ def __init__( ): self._main_stop_event = main_stop_event self._message_dispatcher = message_adapter - patch_gmqtt_protocol_connection_lost() - patch_mqtt_handler_disconnect() + self._patch_utils: PatchUtils = PatchUtils(None, self._main_stop_event, 1) + self._patch_utils.patch_gmqtt_protocol_connection_lost() + self._patch_utils.patch_mqtt_handler_disconnect() self._client = GMQTTClient(client_id) - patch_gmqtt_puback(self._client, self._handle_puback_reason_code) - patch_handle_connack(self._client, self._on_connect_internal) + self._patch_utils.client = self._client + self._patch_utils.patch_handle_connack(self._on_connect_internal) + self._patch_utils.apply(self._handle_puback_reason_code) self._client.on_connect = self._on_connect_internal self._client.on_disconnect = self._on_disconnect_internal self._client.on_message = self._on_message_internal @@ -82,7 +85,7 @@ def __init__( self._connect_params = None # Will be set in connect method self._handlers: Dict[str, Callable[[str, bytes], Coroutine[Any, Any, None]]] = {} - self._pending_publishes: Dict[int, Tuple[asyncio.Future[PublishResult], str, int, int, float]] = {} + self._pending_publishes: Dict[int, Tuple[asyncio.Future[PublishResult], MqttPublishMessage, float]] = {} 
self._publish_monitor_task = asyncio.create_task(self._monitor_ack_timeouts()) self._pending_subscriptions: Dict[int, asyncio.Future] = {} @@ -154,11 +157,10 @@ async def disconnect(self): self.__is_waiting_for_rate_limits_publish = True self._rate_limits_ready_event.clear() - async def publish(self, message_or_topic: Union[str, Message], - payload: Optional[bytes] = None, + async def publish(self, + message: MqttPublishMessage, qos: int = 1, - retain: bool = False, - force=False) -> asyncio.Future: + force=False): if not force: if not self.__rate_limits_retrieved and not self.__is_waiting_for_rate_limits_publish: @@ -167,31 +169,68 @@ async def publish(self, message_or_topic: Union[str, Message], if not self._rate_limits_ready_event.is_set(): await await_or_stop(self._rate_limits_ready_event.wait(), self._main_stop_event, timeout=10) except asyncio.TimeoutError: + if not self.__is_waiting_for_rate_limits_publish: + logger.warning("Timeout waiting for rate limits, requesting them now.") + await self.__request_rate_limits() raise RuntimeError("Timeout waiting for rate limits.") if not force and self._backpressure.should_pause(): logger.trace("Backpressure active. Publishing suppressed.") raise RuntimeError("Publishing temporarily paused due to backpressure.") - if isinstance(message_or_topic, Message): - message = message_or_topic - else: - message = Message(message_or_topic, payload, qos=qos, retain=retain) + + mqtt_future = asyncio.get_event_loop().create_future() + mqtt_future.uuid = uuid4() + if logger.isEnabledFor(TRACE_LEVEL): + logger.trace("Publishing message with topic: %s, qos: %d, payload size: %d, mqtt_future id: %r, delivery futures: %r", + message.topic, qos, len(message.payload), mqtt_future.uuid, [f.uuid for f in message.delivery_futures]) + if message.delivery_futures is not None: + def resolve_attached(publish_future: asyncio.Future): + try: + try: + publish_result = publish_future.result() + except asyncio.CancelledError: + logger.info("Publish future was cancelled: %r, id: %r", publish_future, publish_future.uuid) + publish_result = PublishResult(message.topic, message.qos, -1, len(message.payload), -1) + except Exception as exc: + logger.warning("Publish failed with exception: %s", exc) + logger.debug("Resolving delivery futures with failure:", exc_info=exc) + publish_result = PublishResult(message.topic, message.qos, -1, len(message.payload), -1) + + for i, f in enumerate(message.delivery_futures or []): + if f is not None and not f.done(): + f.set_result(publish_result) + future_map.child_resolved(f) + logger.trace("Resolved delivery future #%d id=%r with %s, main publish future id: %r", + i, f.uuid, publish_result, publish_future.uuid) + except Exception as e: + logger.error("Error resolving delivery futures: %s", str(e)) + for i, f in enumerate(message.delivery_futures or []): + if f is not None and not f.done(): + f.set_exception(e) + logger.debug("Set exception for delivery future #%d id=%r", i, f.uuid) + + logger.trace("Adding done callback to main publish future: %r, main publish future done state: %r", mqtt_future.uuid, mqtt_future.done()) + if mqtt_future.done(): + logger.debug("Main publish future is already done, resolving immediately.") + resolve_attached(mqtt_future) + else: + mqtt_future.add_done_callback(resolve_attached) mid, package = self._client._connection.publish(message) # noqa - future = asyncio.get_event_loop().create_future() + message.mark_as_sent(mid) + if qos > 0: - logger.trace("Publishing mid=%s, storing publish main future with id: %r", 
mid, id(future)) - self._pending_publishes[mid] = (future, message.topic, message.qos, message.payload_size, monotonic()) - self._client._persistent_storage.push_message_nowait(mid, package) # noqa + logger.trace("Publishing mid=%s, storing publish main future with id: %r", mid, mqtt_future.uuid) + self._pending_publishes[mid] = (mqtt_future, message, monotonic()) + self._client._persistent_storage.push_message_nowait(mid, message) # noqa else: - future.set_result(True) - - return future + mqtt_future.set_result(True) async def subscribe(self, topic: Union[str, Subscription], qos: int = 1) -> asyncio.Future: sub_future = asyncio.get_event_loop().create_future() + sub_future.uuid = uuid4() subscription = Subscription(topic, qos=qos) if isinstance(topic, str) else topic if self.__rate_limiter: @@ -202,6 +241,7 @@ async def subscribe(self, topic: Union[str, Subscription], qos: int = 1) -> asyn async def unsubscribe(self, topic: str) -> asyncio.Future: unsubscribe_future = asyncio.get_event_loop().create_future() + unsubscribe_future.uuid = uuid4() if self.__rate_limiter: await self.__rate_limiter[MESSAGES_RATE_LIMIT].consume() mid = self._client._connection.unsubscribe(topic) # noqa @@ -265,12 +305,12 @@ def _on_disconnect_internal(self, client, reason_code=None, properties=None, exc if exc: logger.warning("Disconnect exception: %s", exc) - for mid, (future, topic, qos, payload_size, publishing_time) in list(self._pending_publishes.items()): + for mid, (future, mqtt_message, publishing_time) in list(self._pending_publishes.items()): if not future.done(): publish_result = PublishResult( - topic=topic, - qos=qos, - payload_size=payload_size, + topic=mqtt_message.topic, + qos=mqtt_message.qos, + payload_size=mqtt_message.payload_size, message_id=-1, reason_code=reason_code or 0 ) @@ -280,8 +320,6 @@ def _on_disconnect_internal(self, client, reason_code=None, properties=None, exc RPCRequestIdProducer.reset() AttributeRequestIdProducer.reset() - self._rpc_response_handler.clear() - self._handlers.clear() self.__rate_limits_retrieved = False self.__is_waiting_for_rate_limits_publish = True self._rate_limits_ready_event.clear() @@ -360,20 +398,22 @@ def _handle_puback_reason_code(self, mid: int, reason_code: int, properties: dic if pending_future_data is None: logger.error("Missing future for mid=%s", mid) return - future, topic, qos, payload_size, publishing_time = pending_future_data + future, mqtt_message, publishing_time = pending_future_data publish_result = PublishResult( - topic=topic, - qos=qos, - payload_size=payload_size, + topic=mqtt_message.topic, + qos=mqtt_message.qos, + payload_size=mqtt_message.payload_size, message_id=mid, - reason_code=reason_code + reason_code=reason_code, + datapoints_count=mqtt_message.datapoints ) - logger.trace("Received result for publish future (id: %r): %r", id(future), publish_result) + if logger.isEnabledFor(TRACE_LEVEL): + logger.trace("Received result for publish future (id: %r): %r", future.uuid, publish_result) if not future.done(): future.set_result(publish_result) else: logger.warning("Future (id: %r) for mid=%s was already done, skipping setting result", - id(future), mid) + future.uuid, mid) if reason_code == QUOTA_EXCEEDED: logger.warning("PUBACK received with QUOTA_EXCEEDED for mid=%s", mid) @@ -427,11 +467,11 @@ async def __request_rate_limits(self): logger.debug("Publishing rate limits request to server...") request = await RPCRequest.build("getSessionLimits") - topic, payload = self._message_dispatcher.build_rpc_request(request) + 
mqtt_message: MqttPublishMessage = self._message_dispatcher.build_rpc_request(request) response_future = self._rpc_response_handler.register_request(request.request_id, self.__rate_limits_handler) try: - await self.publish(topic, payload, qos=1, force=True) + await self.publish(mqtt_message, qos=1, force=True) await await_or_stop(response_future, self._main_stop_event, timeout=10) logger.info("Successfully processed rate limits.") self.__rate_limits_retrieved = True @@ -471,16 +511,37 @@ async def _monitor_ack_timeouts(self): async def check_pending_publishes(self, time_to_check): expired = [] - for mid, (future, topic, qos, payload_size, timestamp) in list(self._pending_publishes.items()): + for mid, (future, message, timestamp) in list(self._pending_publishes.items()): if self._main_stop_event.is_set(): with suppress(asyncio.CancelledError): future.cancel() continue if time_to_check - timestamp > self._PUBLISH_TIMEOUT: if not future.done(): - logger.warning("Publish timeout: mid=%s, topic=%s", mid, topic) - result = PublishResult(topic, qos, payload_size, mid, reason_code=408) + logger.warning("Publish timeout: mid=%s, topic=%s", mid, message.topic) + result = PublishResult(message.topic, message.qos, message.payload_size, mid, reason_code=408) future.set_result(result) expired.append(mid) for mid in expired: self._pending_publishes.pop(mid, None) + + async def stop(self): + """ + Cleanly stop the MQTT manager and background tasks. + """ + if hasattr(self, '_patch_utils'): + await self._patch_utils.stop_retry_task() + + if hasattr(self, '_publish_monitor_task'): + self._publish_monitor_task.cancel() + with suppress(asyncio.CancelledError): + await self._publish_monitor_task + + if self._client.is_connected: + await self._client.disconnect() + + async def failed_messages_reprocessing(self): + """ + Reprocess failed messages that were not acknowledged by the server. + Using internal Gmqtt queue to get messages that were not acknowledged. 
+ """ From 8674be12eefdceb327d525fcac0be6f698a2055b Mon Sep 17 00:00:00 2001 From: imbeacon Date: Thu, 24 Jul 2025 09:30:02 +0300 Subject: [PATCH 56/74] Adjusted entities and services to use mqtt publish message, changes test due to new entities usage --- tb_mqtt_client/common/mqtt_message.py | 8 +- tb_mqtt_client/common/publish_result.py | 10 + tb_mqtt_client/service/device/client.py | 2 +- .../service/device/message_adapter.py | 12 +- tb_mqtt_client/service/message_queue.py | 22 +- tb_mqtt_client/service/message_splitter.py | 9 +- tb_mqtt_client/service/mqtt_manager.py | 8 +- tests/common/test_publish_result.py | 4 +- tests/service/device/test_device_client.py | 160 ++- tests/service/test_json_message_adapter.py | 115 +- tests/service/test_message_queue.py | 1198 +++-------------- tests/service/test_message_splitter.py | 44 +- tests/service/test_mqtt_manager.py | 53 +- 13 files changed, 480 insertions(+), 1165 deletions(-) diff --git a/tb_mqtt_client/common/mqtt_message.py b/tb_mqtt_client/common/mqtt_message.py index 61cf5fd..a19f0f8 100644 --- a/tb_mqtt_client/common/mqtt_message.py +++ b/tb_mqtt_client/common/mqtt_message.py @@ -56,15 +56,19 @@ def __init__(self, self.datapoints = datapoints self.properties = kwargs self._is_sent = False - self.delivery_futures: List[Future] = delivery_futures + self.delivery_futures = delivery_futures if not delivery_futures: delivery_future = Future() delivery_future.uuid = uuid4() self.delivery_futures = [delivery_future] + if not isinstance(self.delivery_futures, (list, tuple)): + self.delivery_futures = [self.delivery_futures] + for future in self.delivery_futures: + if not hasattr(future, 'uuid'): + future.uuid = uuid4() logger.trace(f"Created MqttMessage with topic: {topic}, payload type: {type(payload).__name__}, " f"datapoints: {datapoints}, delivery_future id: {self.delivery_futures[0].uuid}") - def mark_as_sent(self, message_id: int): """Mark the message as sent.""" self.message_id = message_id diff --git a/tb_mqtt_client/common/publish_result.py b/tb_mqtt_client/common/publish_result.py index ceaeb79..fdab1a3 100644 --- a/tb_mqtt_client/common/publish_result.py +++ b/tb_mqtt_client/common/publish_result.py @@ -32,6 +32,16 @@ def __repr__(self): f"reason_code={self.reason_code}, " f"datapoints_count={self.datapoints_count})") + def __eq__(self, other): + if not isinstance(other, PublishResult): + return NotImplemented + return (self.topic == other.topic and + self.qos == other.qos and + self.message_id == other.message_id and + self.payload_size == other.payload_size and + self.reason_code == other.reason_code and + self.datapoints_count == other.datapoints_count) + def as_dict(self) -> dict: return { "topic": self.topic, diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index 96ab025..7a6423d 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -331,7 +331,7 @@ async def claim_device(self, return await await_or_stop(delivery_future, timeout=timeout, stop_event=self._stop_event) except TimeoutError: logger.warning("Timeout while waiting for claiming publish result") - return PublishResult(mqtt_message.topic, 1, -1, len(mqtt_message.payload_size), -1) + return PublishResult(mqtt_message.topic, 1, -1, mqtt_message.payload_size, -1) else: return delivery_future diff --git a/tb_mqtt_client/service/device/message_adapter.py b/tb_mqtt_client/service/device/message_adapter.py index c92252e..791b7a4 100644 --- 
a/tb_mqtt_client/service/device/message_adapter.py +++ b/tb_mqtt_client/service/device/message_adapter.py @@ -88,7 +88,7 @@ def build_rpc_response(self, rpc_response: RPCResponse) -> MqttPublishMessage: pass @abstractmethod - def build_provision_request(self, provision_request) -> Tuple[str, bytes]: + def build_provision_request(self, provision_request) -> MqttPublishMessage: """ Build the payload for a device provisioning request. This method should return a tuple of topic and payload bytes. @@ -374,7 +374,7 @@ def build_rpc_response(self, rpc_response: RPCResponse) -> MqttPublishMessage: logger.trace("Built RPC response payload for request ID=%d with payload: %r", rpc_response.request_id, payload) return MqttPublishMessage(topic=topic, payload=payload, qos=1, datapoints=1) - def build_provision_request(self, provision_request: 'ProvisioningRequest') -> Tuple[str, bytes]: + def build_provision_request(self, provision_request: 'ProvisioningRequest') -> MqttPublishMessage: """ Build the payload for a device provisioning request. :param provision_request: The ProvisioningRequest to build the payload for. @@ -417,8 +417,14 @@ def build_provision_request(self, provision_request: 'ProvisioningRequest') -> T request["credentialsType"] = provision_request.credentials.credentials_type.value payload = dumps(request) + result_msg = MqttPublishMessage( + topic=topic, + payload=payload, + qos=1, + datapoints=1 + ) logger.trace("Built provision request payload: %r", provision_request) - return topic, payload + return result_msg @staticmethod def build_payload(msg: DeviceUplinkMessage, build_timeseries_payload) -> bytes: diff --git a/tb_mqtt_client/service/message_queue.py b/tb_mqtt_client/service/message_queue.py index 24f7356..a13bb67 100644 --- a/tb_mqtt_client/service/message_queue.py +++ b/tb_mqtt_client/service/message_queue.py @@ -63,7 +63,7 @@ def __init__(self, self._gateway_adapter = gateway_message_adapter self._loop_task = asyncio.create_task(self._dequeue_loop()) self._rate_limit_refill_task = asyncio.create_task(self._rate_limit_refill_loop()) - asyncio.create_task(self.print_queue_statistics()) + self.__print_queue_statistics_task = asyncio.create_task(self.print_queue_statistics()) logger.debug("MessageQueue initialized: max_queue_size=%s, batch_time=%.3f, batch_count=%d", max_queue_size, self._batch_max_time, batch_collect_max_count) @@ -302,20 +302,16 @@ async def shutdown(self): self._loop_task.cancel() if self._rate_limit_refill_task: self._rate_limit_refill_task.cancel() + self.__print_queue_statistics_task.cancel() with suppress(asyncio.CancelledError): await self._loop_task await self._rate_limit_refill_task + await self.__print_queue_statistics_task - while not self._queue.empty(): - try: - self._queue.get_nowait() - self._queue.task_done() - except asyncio.QueueEmpty: - break + self.clear() logger.debug("MessageQueue shutdown complete, message queue size: %d", self._queue.qsize()) - self.clear() @staticmethod async def _cancel_tasks(tasks: set[asyncio.Task]): @@ -331,13 +327,13 @@ def is_empty(self): def clear(self): logger.debug("Clearing message queue...") while not self._queue.empty(): - topic, message, delivery_futures, _, qos = self._queue.get_nowait() - for future in delivery_futures: + message = self._queue.get_nowait() + for future in message.delivery_futures: future.set_result(PublishResult( - topic=topic, - qos=qos, + topic=message.topic, + qos=message.qos, message_id=-1, - payload_size=message.size if isinstance(message, DeviceUplinkMessage) or isinstance(message, 
GatewayUplinkMessage) else len(message), + payload_size=message.payload_size, reason_code=-1 )) self._queue.task_done() diff --git a/tb_mqtt_client/service/message_splitter.py b/tb_mqtt_client/service/message_splitter.py index bdabed7..0715a5c 100644 --- a/tb_mqtt_client/service/message_splitter.py +++ b/tb_mqtt_client/service/message_splitter.py @@ -63,9 +63,10 @@ def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp parent_futures: List[asyncio.Future] = [] for msg in group_msgs: - for ts_group in msg.timeseries.values(): - all_ts_entries.extend(ts_group) - parent_futures.extend(msg.get_delivery_futures() or []) + if msg.has_timeseries(): + for ts_group in msg.timeseries.values(): + all_ts_entries.extend(ts_group) + parent_futures.extend(msg.get_delivery_futures() or []) builder: Optional[DeviceUplinkMessageBuilder] = None size = 0 @@ -138,7 +139,7 @@ def split_attributes(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp for msg in group_msgs: if msg.has_attributes(): all_attrs.extend(msg.attributes) - parent_futures.extend(msg.get_delivery_futures()) + parent_futures.extend(msg.get_delivery_futures()) builder: Optional[DeviceUplinkMessageBuilder] = None size = 0 diff --git a/tb_mqtt_client/service/mqtt_manager.py b/tb_mqtt_client/service/mqtt_manager.py index bb10e15..f86e24a 100644 --- a/tb_mqtt_client/service/mqtt_manager.py +++ b/tb_mqtt_client/service/mqtt_manager.py @@ -61,7 +61,7 @@ def __init__( rpc_response_handler: Optional[RPCResponseHandler] = None, ): self._main_stop_event = main_stop_event - self._message_dispatcher = message_adapter + self._message_adapter = message_adapter self._patch_utils: PatchUtils = PatchUtils(None, self._main_stop_event, 1) self._patch_utils.patch_gmqtt_protocol_connection_lost() self._patch_utils.patch_mqtt_handler_disconnect() @@ -201,7 +201,7 @@ def resolve_attached(publish_future: asyncio.Future): if f is not None and not f.done(): f.set_result(publish_result) future_map.child_resolved(f) - logger.trace("Resolved delivery future #%d id=%r with %s, main publish future id: %r", + logger.error("Resolved delivery future #%d id=%r with %s, main publish future id: %r", i, f.uuid, publish_result, publish_future.uuid) except Exception as e: logger.error("Error resolving delivery futures: %s", str(e)) @@ -226,7 +226,7 @@ def resolve_attached(publish_future: asyncio.Future): self._pending_publishes[mid] = (mqtt_future, message, monotonic()) self._client._persistent_storage.push_message_nowait(mid, message) # noqa else: - mqtt_future.set_result(True) + mqtt_future.set_result(PublishResult(message.topic, qos, -1, message.payload_size, 0)) async def subscribe(self, topic: Union[str, Subscription], qos: int = 1) -> asyncio.Future: sub_future = asyncio.get_event_loop().create_future() @@ -467,7 +467,7 @@ async def __request_rate_limits(self): logger.debug("Publishing rate limits request to server...") request = await RPCRequest.build("getSessionLimits") - mqtt_message: MqttPublishMessage = self._message_dispatcher.build_rpc_request(request) + mqtt_message: MqttPublishMessage = self._message_adapter.build_rpc_request(request) response_future = self._rpc_response_handler.register_request(request.request_id, self.__rate_limits_handler) try: diff --git a/tests/common/test_publish_result.py b/tests/common/test_publish_result.py index 54a7ce6..15b799f 100644 --- a/tests/common/test_publish_result.py +++ b/tests/common/test_publish_result.py @@ -44,6 +44,7 @@ def test_publish_result_repr(default_publish_result): assert 
"message_id=123" in result assert "payload_size=256" in result assert "reason_code=0" in result + assert "datapoints_count=0" in result def test_publish_result_as_dict(default_publish_result): @@ -54,7 +55,8 @@ def test_publish_result_as_dict(default_publish_result): "qos": 1, "message_id": 123, "payload_size": 256, - "reason_code": 0 + "reason_code": 0, + "datapoints_count": 0 } diff --git a/tests/service/device/test_device_client.py b/tests/service/device/test_device_client.py index 096a31f..8621c2d 100644 --- a/tests/service/device/test_device_client.py +++ b/tests/service/device/test_device_client.py @@ -18,6 +18,7 @@ import pytest from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.constants.mqtt_topics import DEVICE_CLAIM_TOPIC from tb_mqtt_client.entities.data.claim_request import ClaimRequest @@ -25,19 +26,75 @@ from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse from tb_mqtt_client.service.device.client import DeviceClient -from tb_mqtt_client.service.message_queue import MessageQueue @pytest.mark.asyncio async def test_send_timeseries_with_dict(): + # Setup client = DeviceClient() - client._message_queue = AsyncMock(spec=MessageQueue) - future = asyncio.Future() - future.set_result(PublishResult("topic", 1, 1, 100, 1)) - client._message_queue.publish.return_value = [future] - result = await client.send_timeseries({"temp": 22}) + client._mqtt_manager._handle_puback_reason_code = ( + client._mqtt_manager._handle_puback_reason_code.__get__(client._mqtt_manager) + ) + + class DummyConnection: + def __init__(self): + self.published_messages = [] + + def publish(self, *args, **kwargs): + self.published_messages.append(args[0]) + return 1, 1 + + async def fake_subscribe(*args, **kwargs): + fut = asyncio.get_event_loop().create_future() + fut.set_result(1) + return fut + + client._mqtt_manager.subscribe = fake_subscribe + + # Patch MQTTManager.connect to simulate connection without real network + async def fake_connect(**kwargs): + client._mqtt_manager._is_connected = True + client._mqtt_manager._connected_event.set() + + client._mqtt_manager.connect = fake_connect + client._mqtt_manager._rate_limits_ready_event.set() + client._mqtt_manager.is_connected = lambda: True + + client._mqtt_manager._client._connection = DummyConnection() + client._mqtt_manager._client._persistent_storage.push_message_nowait = lambda *args, **kwargs: None + + async def fake_await_ready(): + return + + client._mqtt_manager.await_ready = fake_await_ready + + # Call connect (real queue and adapter will be initialized) + await client.connect() + + # Act: send timeseries + delivery_futures = await client.send_timeseries({"temp": 22}, wait_for_publish=False) + await asyncio.sleep(.1) + + # Trigger PUBACK manually + client._mqtt_manager._handle_puback_reason_code( + 1, + reason_code=0, + properties={} + ) + await asyncio.sleep(1) + + # Await result + result = await asyncio.wait_for(delivery_futures[0], timeout=1) + + mqtt_msg = client._mqtt_manager._client._connection.published_messages[0] + + # Assert assert isinstance(result, PublishResult) - assert result.message_id == 1 + assert result.message_id == -1 # The initial mqtt message doesn't contain message_id, expected behavior because an initial message can be split to separated messages or grouped with other messages + assert 
result.topic == mqtt_msg.topic + assert result.payload_size == mqtt_msg.payload_size + assert result.datapoints_count == mqtt_msg.datapoints + assert result.reason_code == 0 @pytest.mark.asyncio @@ -52,15 +109,66 @@ async def test_send_timeseries_timeout(): @pytest.mark.asyncio -async def test_send_attributes_dict(): +async def test_send_attributes_with_dict(): + # Setup client = DeviceClient() - client._message_queue = AsyncMock() - fut = asyncio.Future() - fut.set_result(PublishResult("attr", 1, 2, 50, 1)) - client._message_queue.publish.return_value = [fut] - result = await client.send_attributes({"key": "val"}) + client._mqtt_manager._handle_puback_reason_code = ( + client._mqtt_manager._handle_puback_reason_code.__get__(client._mqtt_manager) + ) + + class DummyConnection: + def __init__(self): + self.published_messages = [] + + def publish(self, *args, **kwargs): + self.published_messages.append(args[0]) + return 2, 1 # Simulate mid=2, qos=1 + + async def fake_subscribe(*args, **kwargs): + fut = asyncio.get_event_loop().create_future() + fut.set_result(1) + return fut + + client._mqtt_manager.subscribe = fake_subscribe + + async def fake_connect(**kwargs): + client._mqtt_manager._is_connected = True + client._mqtt_manager._connected_event.set() + + client._mqtt_manager.connect = fake_connect + client._mqtt_manager._rate_limits_ready_event.set() + client._mqtt_manager.is_connected = lambda: True + + client._mqtt_manager._client._connection = DummyConnection() + client._mqtt_manager._client._persistent_storage.push_message_nowait = lambda *args, **kwargs: None + + async def fake_await_ready(): + return + client._mqtt_manager.await_ready = fake_await_ready + + await client.connect() + + # Act: send attributes + delivery_futures = await client.send_attributes({"key": "value"}, wait_for_publish=False) + await asyncio.sleep(.1) + + # Trigger PUBACK manually + client._mqtt_manager._handle_puback_reason_code( + 2, reason_code=0, properties={} + ) + await asyncio.sleep(1) + + result = await asyncio.wait_for(delivery_futures[0], timeout=1) + + mqtt_msg = client._mqtt_manager._client._connection.published_messages[0] + + # Assert assert isinstance(result, PublishResult) - assert result.message_id == 2 + assert result.message_id == -1 # Split/group logic applies + assert result.topic == mqtt_msg.topic + assert result.payload_size == mqtt_msg.payload_size + assert result.datapoints_count == mqtt_msg.datapoints + assert result.reason_code == 0 @pytest.mark.asyncio @@ -115,7 +223,6 @@ async def test_claim_device_timeout(): assert result.message_id == -1 - @pytest.mark.asyncio async def test_claim_device_payload_contains_secret_key(): client = DeviceClient() @@ -130,8 +237,10 @@ async def test_claim_device_payload_contains_secret_key(): client._message_queue.publish.assert_awaited_once() args, kwargs = client._message_queue.publish.call_args - assert kwargs['topic'] == DEVICE_CLAIM_TOPIC - assert b"my-secret" in kwargs['payload'] + mqtt_msg = args[0] + assert isinstance(mqtt_msg, MqttPublishMessage) + assert mqtt_msg.topic == DEVICE_CLAIM_TOPIC + assert b"my-secret" in mqtt_msg.payload @pytest.mark.asyncio @@ -166,11 +275,12 @@ async def test_disconnect(): async def test_stop_disconnects_and_shuts_down_queue(): client = DeviceClient() client._mqtt_manager.is_connected = lambda: True - client._mqtt_manager.disconnect = AsyncMock() - client._message_queue = AsyncMock() + client._mqtt_manager.stop = AsyncMock() + client._message_queue = MagicMock() + client._message_queue.shutdown = AsyncMock() await 
client.stop() client._message_queue.shutdown.assert_awaited() - client._mqtt_manager.disconnect.assert_awaited() + client._mqtt_manager.stop.assert_awaited() @pytest.mark.asyncio @@ -334,7 +444,7 @@ async def test_stops_if_event_is_set_during_connection(): @pytest.mark.asyncio -async def test_initializes_dispatcher_and_queue_after_connection(): +async def test_initializes_adapter_and_queue_after_connection(): config = DeviceConfig() config.host = "localhost" config.port = 1883 @@ -382,7 +492,7 @@ async def test_uses_default_max_payload_size_when_not_provided(): @pytest.mark.asyncio -async def test_does_not_update_dispatcher_when_not_initialized(): +async def test_does_not_update_adapter_when_not_initialized(): client = DeviceClient() client.max_payload_size = None client._message_adapter = None @@ -472,7 +582,9 @@ async def test_send_attributes_no_wait(): client._message_queue = AsyncMock() client._message_queue.publish.return_value = [asyncio.Future()] result = await client.send_attributes({"attr": "val"}, wait_for_publish=False) - assert result is None + assert isinstance(result, list) + assert len(result) == 1 + assert isinstance(result[0], asyncio.Future) @pytest.mark.asyncio @@ -507,10 +619,8 @@ async def test_handle_requested_attribute_response_calls_handler(): async def test_on_disconnect_clears_handlers(): client = DeviceClient() client._requested_attribute_response_handler.clear = MagicMock() - client._rpc_response_handler.clear = MagicMock() await client._on_disconnect() client._requested_attribute_response_handler.clear.assert_called_once() - client._rpc_response_handler.clear.assert_called_once() @pytest.mark.asyncio diff --git a/tests/service/test_json_message_adapter.py b/tests/service/test_json_message_adapter.py index 1bcf490..5ae23f6 100644 --- a/tests/service/test_json_message_adapter.py +++ b/tests/service/test_json_message_adapter.py @@ -17,6 +17,7 @@ import pytest from orjson import dumps +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage from tb_mqtt_client.constants import mqtt_topics from tb_mqtt_client.constants.mqtt_topics import DEVICE_TELEMETRY_TOPIC, DEVICE_ATTRIBUTES_TOPIC from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry @@ -62,9 +63,9 @@ def test_build_attribute_request(adapter): request = MagicMock(spec=AttributeRequest) request.request_id = 1 request.to_payload_format.return_value = {"clientKeys": "temp", "sharedKeys": "shared"} - topic, payload = adapter.build_attribute_request(request) - assert topic.endswith("/1") - assert b"clientKeys" in payload + mqtt_message = adapter.build_attribute_request(request) + assert mqtt_message.topic.endswith("/1") + assert b"clientKeys" in mqtt_message.payload def test_build_attribute_request_invalid(adapter): @@ -76,9 +77,9 @@ def test_build_attribute_request_invalid(adapter): def test_build_claim_request(adapter): req = ClaimRequest.build("secretKey") - topic, payload = adapter.build_claim_request(req) - assert topic == mqtt_topics.DEVICE_CLAIM_TOPIC - assert b"secretKey" in payload + mqtt_message = adapter.build_claim_request(req) + assert mqtt_message.topic == mqtt_topics.DEVICE_CLAIM_TOPIC + assert b"secretKey" in mqtt_message.payload def test_build_claim_request_invalid(adapter): @@ -90,9 +91,9 @@ def test_build_rpc_request(adapter): request = MagicMock(spec=RPCRequest) request.request_id = 42 request.to_payload_format.return_value = {"method": "reboot"} - topic, payload = adapter.build_rpc_request(request) - assert topic.endswith("42") - assert b"reboot" in payload + 
mqtt_message = adapter.build_rpc_request(request) + assert mqtt_message.topic.endswith("42") + assert b"reboot" in mqtt_message.payload def test_build_rpc_request_invalid(adapter): @@ -106,9 +107,9 @@ def test_build_rpc_response(adapter): response = MagicMock(spec=RPCResponse) response.request_id = 123 response.to_payload_format.return_value = {"result": "ok"} - topic, payload = adapter.build_rpc_response(response) - assert topic.endswith("123") - assert b"ok" in payload + mqtt_message = adapter.build_rpc_response(response) + assert mqtt_message.topic.endswith("123") + assert b"ok" in mqtt_message.payload def test_build_rpc_response_invalid(adapter): @@ -121,23 +122,23 @@ def test_build_rpc_response_invalid(adapter): def test_build_provision_request_access_token(adapter): credentials = AccessTokenProvisioningCredentials("key1", "secret1", access_token="tokenABC") req = ProvisioningRequest("localhost", credentials, device_name="dev1", gateway=True) - topic, payload = adapter.build_provision_request(req) - assert topic == mqtt_topics.PROVISION_REQUEST_TOPIC - assert b"provisionDeviceKey" in payload - assert b"tokenABC" in payload - assert b"credentialsType" in payload - assert b"deviceName" in payload - assert b"gateway" in payload + mqtt_message = adapter.build_provision_request(req) + assert mqtt_message.topic == mqtt_topics.PROVISION_REQUEST_TOPIC + assert b"provisionDeviceKey" in mqtt_message.payload + assert b"tokenABC" in mqtt_message.payload + assert b"credentialsType" in mqtt_message.payload + assert b"deviceName" in mqtt_message.payload + assert b"gateway" in mqtt_message.payload def test_build_provision_request_mqtt_basic(adapter): credentials = BasicProvisioningCredentials("key2", "secret2", client_id="cid", username="user", password="pass") req = ProvisioningRequest("127.0.0.1", credentials, device_name="dev2", gateway=False) - topic, payload = adapter.build_provision_request(req) - assert b"clientId" in payload - assert b"username" in payload - assert b"password" in payload - assert b"credentialsType" in payload + mqtt_message = adapter.build_provision_request(req) + assert b"clientId" in mqtt_message.payload + assert b"username" in mqtt_message.payload + assert b"password" in mqtt_message.payload + assert b"credentialsType" in mqtt_message.payload def test_build_provision_request_x509(adapter): @@ -146,10 +147,10 @@ def test_build_provision_request_x509(adapter): with patch("builtins.open", mock_open(read_data=cert_content)): credentials = X509ProvisioningCredentials("key3", "secret3", "key.pem", cert_path, "ca.pem") req = ProvisioningRequest("iot.server", credentials, device_name="dev3") - topic, payload = adapter.build_provision_request(req) - assert b"hash" in payload - assert b"credentialsType" in payload - assert b"FAKECERT" in payload + mqtt_message = adapter.build_provision_request(req) + assert b"hash" in mqtt_message.payload + assert b"credentialsType" in mqtt_message.payload + assert b"FAKECERT" in mqtt_message.payload def test_build_provision_request_x509_file_not_found(adapter): @@ -220,56 +221,48 @@ def test_parse_rpc_response_invalid(adapter): @pytest.mark.asyncio async def test_build_uplink_payloads_empty(adapter: JsonMessageAdapter): - assert adapter.build_uplink_payloads([]) == [] + assert adapter.build_uplink_messages([]) == [] @pytest.mark.asyncio async def test_build_uplink_payloads_only_attributes(adapter: JsonMessageAdapter): msg = build_msg(with_attr=True) + initial_mqtt_message = MqttPublishMessage(topic="", payload=msg) with 
patch.object(adapter._splitter, "split_attributes", return_value=[msg]): - result = adapter.build_uplink_payloads([msg]) + result = adapter.build_uplink_messages([initial_mqtt_message]) assert len(result) == 1 - topic, payload, count, futures = result[0] - assert topic == DEVICE_ATTRIBUTES_TOPIC - assert count == 1 - assert b"a" in payload + mqtt_message = result[0] + assert mqtt_message.topic == DEVICE_ATTRIBUTES_TOPIC + assert mqtt_message.datapoints == 1 + assert b"a" in mqtt_message.payload @pytest.mark.asyncio async def test_build_uplink_payloads_only_timeseries(adapter: JsonMessageAdapter): msg = build_msg(with_ts=True) + initial_mqtt_message = MqttPublishMessage(topic="", payload=msg) with patch.object(adapter._splitter, "split_timeseries", return_value=[msg]): - result = adapter.build_uplink_payloads([msg]) + result = adapter.build_uplink_messages([initial_mqtt_message]) assert len(result) == 1 - topic, payload, count, futures = result[0] - assert topic == DEVICE_TELEMETRY_TOPIC - assert count == 1 - assert b"ts" in payload + mqtt_message = result[0] + assert mqtt_message.topic == DEVICE_TELEMETRY_TOPIC + assert mqtt_message.datapoints == 1 + assert b"ts" in mqtt_message.payload @pytest.mark.asyncio async def test_build_uplink_payloads_both(adapter: JsonMessageAdapter): msg = build_msg(with_attr=True, with_ts=True) + initial_mqtt_message = MqttPublishMessage(topic="", payload=msg) + with patch.object(adapter._splitter, "split_attributes", return_value=[msg]), \ patch.object(adapter._splitter, "split_timeseries", return_value=[msg]): - result = adapter.build_uplink_payloads([msg]) + result = adapter.build_uplink_messages([initial_mqtt_message]) assert len(result) == 2 - topics = {r[0] for r in result} + topics = {r.topic for r in result} assert DEVICE_ATTRIBUTES_TOPIC in topics assert DEVICE_TELEMETRY_TOPIC in topics - -@pytest.mark.asyncio -async def test_build_uplink_payloads_multiple_devices(adapter: JsonMessageAdapter): - msg1 = build_msg(device="dev1", with_attr=True) - msg2 = build_msg(device="dev2", with_ts=True) - with patch.object(adapter._splitter, "split_attributes", side_effect=lambda x: x), \ - patch.object(adapter._splitter, "split_timeseries", side_effect=lambda x: x): - result = adapter.build_uplink_payloads([msg1, msg2]) - topics = {r[0] for r in result} - assert DEVICE_ATTRIBUTES_TOPIC in topics or DEVICE_TELEMETRY_TOPIC in topics - - def test_build_payload_without_device_name(adapter: JsonMessageAdapter): builder = DeviceUplinkMessageBuilder().add_attributes(AttributeEntry("x", 9)) msg = builder.build() @@ -286,22 +279,23 @@ def test_pack_attributes(): assert "x" in result -def test_pack_timeseries_uses_now(monkeypatch): +def test_pack_timeseries_no_ts(monkeypatch): monkeypatch.setattr("tb_mqtt_client.service.device.message_adapter.datetime", MagicMock()) ts_entry = TimeseriesEntry("temp", 23, ts=None) builder = DeviceUplinkMessageBuilder().add_timeseries(ts_entry) msg = builder.build() packed = JsonMessageAdapter.pack_timeseries(msg) - assert isinstance(packed, list) - assert "ts" in packed[0] - assert "values" in packed[0] + assert isinstance(packed, dict) + assert ts_entry.key in packed + assert packed[ts_entry.key] == ts_entry.value def test_build_uplink_payloads_error_handling(adapter: JsonMessageAdapter): with patch("tb_mqtt_client.service.device.message_adapter.DeviceUplinkMessage.has_attributes", side_effect=Exception("boom")): msg = build_msg(with_attr=True) + initial_mqtt_message = MqttPublishMessage(topic="", payload=msg) with pytest.raises(Exception, 
match="boom"): - adapter.build_uplink_payloads([msg]) + adapter.build_uplink_messages([initial_mqtt_message]) def test_parse_provisioning_response_success(adapter, dummy_provisioning_request): @@ -324,4 +318,7 @@ def test_parse_provisioning_response_failure(adapter, dummy_provisioning_request args = mock_build.call_args[0] assert args[0] == dummy_provisioning_request assert args[1]["status"] == "FAILURE" - assert "errorMsg" in args[1] \ No newline at end of file + assert "errorMsg" in args[1] + +if __name__ == '__main__': + pytest.main([__file__, "--tb=short", "-v"]) \ No newline at end of file diff --git a/tests/service/test_message_queue.py b/tests/service/test_message_queue.py index 245a57a..c0bc58e 100644 --- a/tests/service/test_message_queue.py +++ b/tests/service/test_message_queue.py @@ -13,1113 +13,341 @@ # limitations under the License. import asyncio -from contextlib import AsyncExitStack -from time import time -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, Mock, MagicMock, patch import pytest +import pytest_asyncio -from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.common.rate_limit.backpressure_controller import BackpressureController from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit -from tb_mqtt_client.constants import mqtt_topics -from tb_mqtt_client.constants.service_keys import TELEMETRY_DATAPOINTS_RATE_LIMIT -from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage, DeviceUplinkMessageBuilder from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.service.device.message_adapter import JsonMessageAdapter from tb_mqtt_client.service.message_queue import MessageQueue -@pytest.mark.asyncio -async def test_batching_device_uplink_message(): - mqtt_manager = MagicMock() - future = asyncio.Future() - future.set_result(PublishResult(mqtt_topics.DEVICE_TELEMETRY_TOPIC, 1, 0, 7, 0)) - mqtt_manager.publish = AsyncMock(return_value=future) - mqtt_manager.backpressure.should_pause.return_value = False - main_stop_event = asyncio.Event() +@pytest_asyncio.fixture +async def fake_mqtt_manager(): + mgr = AsyncMock() + mgr.backpressure = Mock(spec=BackpressureController) + mgr.backpressure.should_pause.return_value = False + return mgr - delivery_future = asyncio.Future() - dummy_message = MagicMock() - dummy_message.size = 10 - dummy_message.device_name = "device" - dummy_message.get_delivery_futures.return_value = [delivery_future] - adapter = MagicMock() - adapter.splitter.max_payload_size = 100 - adapter.build_uplink_payloads.return_value = [ - (mqtt_topics.DEVICE_TELEMETRY_TOPIC, b'batch_payload', 1, [delivery_future]) - ] +@pytest_asyncio.fixture +async def message_queue(fake_mqtt_manager): + main_stop_event = asyncio.Event() + message_adapter = Mock() + message_adapter.build_uplink_messages.return_value = [] - queue = MessageQueue( - mqtt_manager=mqtt_manager, + mq = MessageQueue( + mqtt_manager=fake_mqtt_manager, main_stop_event=main_stop_event, message_rate_limit=None, telemetry_rate_limit=None, telemetry_dp_rate_limit=None, - message_adapter=adapter, - batch_collect_max_time_ms=50, - batch_collect_max_count=10 + message_adapter=message_adapter, + max_queue_size=10, + batch_collect_max_time_ms=10, + batch_collect_max_count=5 ) - await 
queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, dummy_message, 1, qos=1) - await asyncio.sleep(0.1) - await queue.shutdown() - - assert delivery_future.done() - assert isinstance(delivery_future.result(), PublishResult) + try: + yield mq + finally: + await mq.shutdown() @pytest.mark.asyncio -async def test_telemetry_rate_limit_retry_triggered(): - telemetry_limit = MagicMock() - telemetry_limit.try_consume = AsyncMock(return_value=(10, 1)) - telemetry_limit.minimal_timeout = 0.01 - telemetry_limit.to_dict.return_value = {"limit": "mock"} - - mqtt_manager = MagicMock() - mqtt_manager.publish = AsyncMock() - mqtt_manager.backpressure.should_pause.return_value = False +async def test_publish_success(message_queue): + message = MqttPublishMessage("topic", b"data", qos=1, delivery_futures=[]) + await message_queue.publish(message) + assert not message_queue.is_empty() - main_stop_event = asyncio.Event() - - adapter = MagicMock() - adapter.splitter.max_payload_size = 100000 - delivery_future = asyncio.Future() - delivery_future.set_result(PublishResult(mqtt_topics.DEVICE_TELEMETRY_TOPIC, 1, 0, 5, 0)) - adapter.build_uplink_payloads.return_value = [ - (mqtt_topics.DEVICE_TELEMETRY_TOPIC, b'dummy_payload', 1, [delivery_future]) - ] +@pytest.mark.asyncio +async def test_publish_queue_full(message_queue): + for _ in range(10): + await message_queue.publish(MqttPublishMessage("topic", b"data", qos=1, delivery_futures=[])) + message = MqttPublishMessage("topic", b"data", qos=1, delivery_futures=[Mock()]) + await message_queue.publish(message) # Should not raise + assert message_queue._queue.qsize() <= 10 - msg = DeviceUplinkMessageBuilder() - msg.set_device_name("device") - msg.add_timeseries(TimeseriesEntry("temp", 1)) - msg = msg.build() - queue = MessageQueue( - mqtt_manager=mqtt_manager, - main_stop_event=main_stop_event, +@pytest.mark.asyncio +async def test_shutdown_clears_tasks_and_queue(fake_mqtt_manager): + q = MessageQueue( + mqtt_manager=fake_mqtt_manager, + main_stop_event=asyncio.Event(), message_rate_limit=None, - telemetry_rate_limit=telemetry_limit, + telemetry_rate_limit=None, telemetry_dp_rate_limit=None, - message_adapter=adapter + message_adapter=Mock(build_uplink_messages=Mock(return_value=[])), + max_queue_size=10, + batch_collect_max_time_ms=10, + batch_collect_max_count=5 ) - - await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, msg, 1, qos=1) - await asyncio.sleep(0.1) - await queue.shutdown() - - telemetry_limit.try_consume.assert_awaited() + await q.publish(MqttPublishMessage("topic", b"data", qos=1, delivery_futures=[])) + await q.shutdown() + assert q.is_empty() @pytest.mark.asyncio -async def test_shutdown_clears_queue(): - mqtt_manager = MagicMock() - mqtt_manager.backpressure.should_pause.return_value = False - mqtt_manager.publish = AsyncMock() - adapter = MagicMock() - adapter.splitter.max_payload_size = 100000 - adapter.build_uplink_payloads.return_value = [] - stop_event = asyncio.Event() - - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) - dummy = MagicMock() - dummy.size = 1 - dummy.get_delivery_futures.return_value = [] - await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, dummy, 1, 1) - await queue.shutdown() - assert queue.is_empty() - - -@pytest.mark.asyncio -async def test_publish_raw_bytes_success(): - mqtt_manager = MagicMock() - mqtt_manager.publish = AsyncMock() - mqtt_manager.backpressure.should_pause.return_value = False - adapter = MagicMock() - adapter.splitter.max_payload_size = 100000 - 
adapter.build_uplink_payloads.return_value = [] - main_stop_event = asyncio.Event() - - queue = MessageQueue(mqtt_manager, main_stop_event, None, None, None, adapter) - await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, b"payload", 1, qos=1) - await asyncio.sleep(0.05) - await queue.shutdown() - mqtt_manager.publish.assert_called() +async def test_try_publish_backpressure_delay(message_queue): + message_queue._mqtt_manager.backpressure.should_pause.return_value = True + built = DeviceUplinkMessageBuilder().add_timeseries( + TimeseriesEntry("test", "test") + ).build() + message = MqttPublishMessage("topic", built, qos=1, delivery_futures=[]) -@pytest.mark.asyncio -async def test_publish_device_uplink_message_batched(): - mqtt_manager = MagicMock() - mqtt_manager.publish = AsyncMock() - mqtt_manager.backpressure.should_pause.return_value = False - main_stop_event = asyncio.Event() + # Patch the retry scheduler to bypass delay and requeue immediately + with patch.object(message_queue, "_schedule_delayed_retry") as mocked_retry: + mocked_retry.side_effect = lambda m: message_queue._queue.put_nowait(m) - future = asyncio.Future() - dummy_msg = MagicMock() - dummy_msg.size = 10 - dummy_msg.device_name = "dev" - dummy_msg.get_delivery_futures.return_value = [future] + await message_queue._try_publish(message) - adapter = MagicMock() - adapter.splitter.max_payload_size = 100 - adapter.build_uplink_payloads.return_value = [ - (mqtt_topics.DEVICE_TELEMETRY_TOPIC, b"batch", 1, [future]) - ] + # Wait briefly for the message to be re-enqueued + for _ in range(20): + if not message_queue.is_empty(): + break + await asyncio.sleep(0.01) + else: + raise AssertionError("Message was not re-enqueued in time") - queue = MessageQueue(mqtt_manager, main_stop_event, None, None, None, adapter) - await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, dummy_msg, 1, qos=1) - await asyncio.sleep(0.1) - await queue.shutdown() - assert future.done() + # Ensure _schedule_delayed_retry was triggered + assert mocked_retry.called, "_schedule_delayed_retry should be called" @pytest.mark.asyncio -async def test_rate_limit_telemetry_triggers_retry(): - limit = MagicMock() - limit.try_consume = AsyncMock(return_value=(1, 1)) +async def test_try_publish_rate_limit_triggered(): + limit = Mock(spec=RateLimit) + limit.try_consume = AsyncMock(return_value=(10, 1)) limit.minimal_timeout = 0.01 - limit.to_dict.return_value = {} - - mqtt_manager = MagicMock() - mqtt_manager.publish = AsyncMock() - mqtt_manager.backpressure.should_pause.return_value = False - adapter = MagicMock() - adapter.splitter.max_payload_size = 100000 - adapter.build_uplink_payloads.return_value = [] - main_stop_event = asyncio.Event() - - msg = MagicMock() - msg.device_name = "d" - msg.size = 1 - msg.get_delivery_futures.return_value = [] - - queue = MessageQueue(mqtt_manager, main_stop_event, None, limit, None, adapter) - await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, msg, 1, 1) - await asyncio.sleep(0.2) - await queue.shutdown() - mqtt_manager.publish.assert_not_called() - - -@pytest.mark.asyncio -async def test_retry_on_exception(): - mqtt_manager = MagicMock() - publish_mock = AsyncMock() - first_attempt = RuntimeError("fail") - - second_attempt_future = asyncio.Future() - - async def complete_publish_result_later(): - await asyncio.sleep(0.1) - publish_result = PublishResult(mqtt_topics.DEVICE_TELEMETRY_TOPIC, 1, 0, 1, 0) - second_attempt_future.set_result(publish_result) - - asyncio.create_task(complete_publish_result_later()) - - 
publish_mock.side_effect = [first_attempt, second_attempt_future] - - mqtt_manager.publish = publish_mock - mqtt_manager.backpressure.should_pause.return_value = False - - adapter = MagicMock() - adapter.splitter.max_payload_size = 100000 - - future = asyncio.Future() - dummy_msg = MagicMock() - dummy_msg.device_name = "dev" - dummy_msg.size = 10 - dummy_msg.get_delivery_futures.return_value = [future] - - adapter.build_uplink_payloads.return_value = [ - (mqtt_topics.DEVICE_TELEMETRY_TOPIC, b"payload", 1, [future]) - ] - - stop_event = asyncio.Event() - queue = MessageQueue( - mqtt_manager, - stop_event, - None, None, None, - adapter, - batch_collect_max_time_ms=10 - ) - - await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, dummy_msg, 1, qos=1) - - await asyncio.sleep(0.5) - await queue.shutdown() - - assert mqtt_manager.publish.call_count == 2 - assert future.done() - assert isinstance(future.result(), PublishResult) - - -@pytest.mark.asyncio -async def test_mixed_raw_and_structured_queue(): - mqtt_manager = MagicMock() - mqtt_manager.publish = AsyncMock() - mqtt_manager.backpressure.should_pause.return_value = False - stop_event = asyncio.Event() - - future = asyncio.Future() - uplink_msg = MagicMock() - uplink_msg.device_name = "x" - uplink_msg.size = 10 - uplink_msg.get_delivery_futures.return_value = [future] - - adapter = MagicMock() - adapter.splitter.max_payload_size = 100 - adapter.build_uplink_payloads.return_value = [ - (mqtt_topics.DEVICE_TELEMETRY_TOPIC, b"batched", 1, [future]) - ] - - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter, batch_collect_max_time_ms=20) - await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, b"raw", 1, 1) - await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, uplink_msg, 1, 1) - await asyncio.sleep(0.1) - await queue.shutdown() - assert future.done() - - -@pytest.mark.asyncio -async def test_rate_limit_refill_executes(): - r1, r2, r3 = MagicMock(), MagicMock(), MagicMock() - for r in (r1, r2, r3): - r.refill = AsyncMock() - r.to_dict.return_value = {} - - mqtt_manager = MagicMock() - mqtt_manager.publish = AsyncMock() - mqtt_manager.backpressure.should_pause.return_value = False - adapter = MagicMock() - adapter.splitter.max_payload_size = 100000 - adapter.build_uplink_payloads.return_value = [] - stop_event = asyncio.Event() - - queue = MessageQueue(mqtt_manager, stop_event, r1, r2, r3, adapter) - await asyncio.sleep(1.2) - await queue.shutdown() - - r1.refill.assert_awaited() - r2.refill.assert_awaited() - r3.refill.assert_awaited() - - -@pytest.mark.asyncio -async def test_try_publish_without_delivery_futures(): - mqtt_manager = MagicMock() - mqtt_manager.publish = AsyncMock(return_value=asyncio.Future()) - mqtt_manager.publish.return_value.set_result(PublishResult("t", 1, 1, 1, 1)) - mqtt_manager.backpressure.should_pause.return_value = False - - adapter = MagicMock() - adapter.splitter.max_payload_size = 100000 - adapter.build_uplink_payloads.return_value = [] - - stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) - - await queue._try_publish("custom/topic", b"payload", datapoints=1, delivery_futures_or_none=None, qos=1) - await queue.shutdown() - - mqtt_manager.publish.assert_called_once() - - -@pytest.mark.asyncio -async def test_schedule_delayed_retry_skipped_if_inactive_or_stopped(): - mqtt_manager = MagicMock() - mqtt_manager.backpressure.should_pause.return_value = False - mqtt_manager.publish = AsyncMock() - adapter = MagicMock() - 
adapter.splitter.max_payload_size = 100000 - adapter.build_uplink_payloads.return_value = [] - - stop_event = asyncio.Event() - stop_event.set() - - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) - queue._active.clear() - - queue._schedule_delayed_retry("topic", b"data", datapoints=1, qos=1, delay=0.01) - await queue.shutdown() - - -@pytest.mark.asyncio -async def test_clear_queue_sets_futures_to_publish_result(): - mqtt_manager = MagicMock() - mqtt_manager.backpressure.should_pause.return_value = False - mqtt_manager.publish = AsyncMock() - adapter = MagicMock() - adapter.splitter.max_payload_size = 100000 - adapter.build_uplink_payloads.return_value = [] - - stop_event = asyncio.Event() - - with patch.object(MessageQueue, "_dequeue_loop", new=AsyncMock()): - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) - - dummy_msg = DeviceUplinkMessageBuilder() \ - .add_delivery_futures(asyncio.Future()) \ - .add_timeseries(TimeseriesEntry("temp", 1)) \ - .build() - - future = (await queue.publish("some/topic", dummy_msg, 1, 1))[0] - - queue.clear() - - assert future.done() - result = future.result() - assert isinstance(result, PublishResult) - assert result.topic == "some/topic" - assert result.payload_size == dummy_msg.size - assert result.reason_code == -1 - assert result.qos == 1 - assert result.message_id == -1 - - await queue.shutdown() - - -@pytest.mark.asyncio -async def test_wait_for_message_exit_on_inactive(): - mqtt_manager = MagicMock() - mqtt_manager.backpressure.should_pause.return_value = False - mqtt_manager.publish = AsyncMock() - adapter = MagicMock() - adapter.splitter.max_payload_size = 100000 - adapter.build_uplink_payloads.return_value = [] - - stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) - queue._active.clear() - - with pytest.raises(asyncio.CancelledError): - await queue._wait_for_message() - await queue.shutdown() - - -@pytest.mark.asyncio -async def test_schedule_delayed_retry_requeues_message(): - mqtt_manager = MagicMock() - mqtt_manager.publish = AsyncMock() - mqtt_manager.backpressure.should_pause.return_value = False - - adapter = MagicMock() - adapter.splitter.max_payload_size = 100000 - adapter.build_uplink_payloads.return_value = [] - - stop_event = asyncio.Event() - - with patch("tb_mqtt_client.service.message_queue.MessageQueue._dequeue_loop", new=AsyncMock()): - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) - - future = asyncio.Future() - dummy_msg = MagicMock() - dummy_msg.device_name = "dev" - dummy_msg.size = 10 - dummy_msg.get_delivery_futures.return_value = [future] - - queue._schedule_delayed_retry( - topic="retry/topic", - payload=b"retry-payload", - datapoints=1, - qos=1, - delay=0.05, - delivery_futures=[future] - ) - - await asyncio.sleep(0.1) - assert not queue.is_empty() - - await queue.shutdown() - - -@pytest.mark.asyncio -async def test_cancel_tasks_clears_all(): - mqtt_manager = MagicMock() - adapter = MagicMock() - adapter.splitter.max_payload_size = 100000 - adapter.build_uplink_payloads.return_value = [] - - stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) - - async def dummy(): - await asyncio.sleep(1) - - task = asyncio.create_task(dummy()) - queue._retry_tasks.add(task) - - await queue._cancel_tasks(queue._retry_tasks) - assert len(queue._retry_tasks) == 0 - - -@pytest.mark.asyncio -async def test_clear_queue_with_bytes_message(): - mqtt_manager = 
MagicMock() - mqtt_manager.backpressure.should_pause.return_value = False - mqtt_manager.publish = AsyncMock() - adapter = MagicMock() - adapter.splitter.max_payload_size = 100000 - - stop_event = asyncio.Event() - with patch.object(MessageQueue, "_dequeue_loop", new=AsyncMock()): - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) - - future = asyncio.Future() - await queue.publish("raw/topic", b"abc", 1, 0) - queue._queue._queue[0] = ("raw/topic", b"abc", [future], 1, 0) - - queue.clear() - assert future.done() - result = future.result() - assert result.topic == "raw/topic" - assert result.payload_size == 3 - assert result.qos == 0 - assert result.reason_code == -1 - assert result.message_id == -1 - - await queue.shutdown() - - -@pytest.mark.asyncio -async def test_resolve_attached_handles_publish_exception(): - future = asyncio.Future() - future.set_exception(RuntimeError("fail")) - - f1 = asyncio.Future() - f2 = asyncio.Future() - - dummy_payload = b"abc" - topic = "topic" - qos = 1 - - adapter = MagicMock() - adapter.splitter.max_payload_size = 100000 - adapter.build_uplink_payloads.return_value = [] - - mqtt_manager = MagicMock() - mqtt_manager.backpressure.should_pause.return_value = False - mqtt_manager.publish = AsyncMock(return_value=future) - stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) - - await queue._try_publish( - topic=topic, - payload=dummy_payload, - datapoints=1, - delivery_futures_or_none=[f1, f2], - qos=qos + mq = MessageQueue( + mqtt_manager=AsyncMock(), + main_stop_event=asyncio.Event(), + message_rate_limit=None, + telemetry_rate_limit=limit, + telemetry_dp_rate_limit=None, + message_adapter=Mock(build_uplink_messages=Mock(return_value=[])), ) - + mq._schedule_delayed_retry = AsyncMock() + mq._mqtt_manager.publish = AsyncMock() + message = MqttPublishMessage("telemetry", b"{}", qos=1, delivery_futures=[]) + await mq._try_publish(message) await asyncio.sleep(0.05) - assert f1.done() and f2.done() - for f in (f1, f2): - res = f.result() - assert isinstance(res, PublishResult) - assert res.reason_code == -1 - - await queue.shutdown() - - -@pytest.mark.asyncio -async def test_try_publish_message_type_non_telemetry(): - mqtt_manager = MagicMock() - mqtt_manager.backpressure.should_pause.return_value = False - mqtt_manager.publish = AsyncMock() - - rate_limit = MagicMock() - rate_limit.try_consume = AsyncMock(return_value=None) - rate_limit.to_dict.return_value = {} - rate_limit.minimal_timeout = 0.1 - - adapter = MagicMock() - adapter.splitter.max_payload_size = 100000 - - stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, stop_event, message_rate_limit=rate_limit, - telemetry_rate_limit=None, telemetry_dp_rate_limit=None, - message_adapter=adapter) - - await queue._try_publish( - topic="non/telemetry", - payload=b"x", - datapoints=1, - delivery_futures_or_none=[], - qos=0 - ) - mqtt_manager.publish.assert_called_once() - await queue.shutdown() + assert mq._mqtt_manager.publish.call_count == 0 + assert mq._schedule_delayed_retry.call_count == 1 @pytest.mark.asyncio -async def test_shutdown_rate_limit_task_cancel_only(): - mqtt_manager = MagicMock() - mqtt_manager.backpressure.should_pause.return_value = False - mqtt_manager.publish = AsyncMock() - adapter = MagicMock() - adapter.splitter.max_payload_size = 100000 - - stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) - - # Cancel only the rate limit task before shutdown 
- queue._rate_limit_refill_task.cancel() - - await queue.shutdown() - assert queue._rate_limit_refill_task.cancelled() +async def test_try_publish_failure_schedules_retry(message_queue): + message = MqttPublishMessage("topic", b"broken", qos=1, delivery_futures=[]) + message_queue._mqtt_manager.publish.side_effect = Exception("fail") + message_queue._schedule_delayed_retry = AsyncMock() + await message_queue._try_publish(message) + await asyncio.sleep(0.2) + assert message_queue._schedule_delayed_retry.call_count == 1 @pytest.mark.asyncio -async def test_schedule_delayed_retry_when_main_stop_active(): - mqtt_manager = MagicMock() - adapter = MagicMock() - adapter.splitter.max_payload_size = 100000 - - stop_event = asyncio.Event() - stop_event.set() - - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) - - queue._active.clear() - - queue._schedule_delayed_retry("x", b"y", 1, 0, 0.01) - await asyncio.sleep(0.05) - assert queue._queue.empty() - await queue.shutdown() +async def test_cancel_tasks(): + task1 = asyncio.create_task(asyncio.sleep(5)) + task2 = asyncio.create_task(asyncio.sleep(5)) + tasks = {task1, task2} + await MessageQueue._cancel_tasks(tasks) + assert all(t.cancelled() or t.done() for t in [task1, task2]) @pytest.mark.asyncio -async def test_publish_queue_full_sets_failed_result_for_bytes(): - mqtt_manager = MagicMock() - adapter = MagicMock() - adapter.splitter.max_payload_size = 100000 - stop_event = asyncio.Event() - - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter, max_queue_size=1) - await queue.publish("t", b"raw", 1, qos=0) - - queue._queue.put_nowait = MagicMock(side_effect=asyncio.QueueFull) - - result = await queue.publish("t", b"raw", 1, qos=0) - assert result is not None - assert isinstance(result[0], asyncio.Future) - await asyncio.sleep(0) - assert result[0].done() - assert result[0].result().reason_code == -1 - - await queue.shutdown() +async def test_rate_limit_refill(): + rate_limit = Mock(spec=RateLimit) + rate_limit.refill = AsyncMock() + rate_limit.to_dict = Mock(return_value={"x": 1}) + mq = MessageQueue( + mqtt_manager=AsyncMock(), + main_stop_event=asyncio.Event(), + message_rate_limit=rate_limit, + telemetry_rate_limit=rate_limit, + telemetry_dp_rate_limit=rate_limit, + message_adapter=Mock(), + ) + await mq._refill_rate_limits() + assert rate_limit.refill.await_count == 3 @pytest.mark.asyncio -async def test_wait_for_message_raises_cancelled(): - mqtt_manager = MagicMock() - adapter = MagicMock() - adapter.splitter.max_payload_size = 100000 - stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) - - queue._active.clear() - - with pytest.raises(asyncio.CancelledError): - await queue._wait_for_message() - - await queue.shutdown() +async def test_set_gateway_adapter(message_queue): + adapter = Mock() + message_queue.set_gateway_message_adapter(adapter) + assert message_queue._gateway_adapter is adapter @pytest.mark.asyncio -async def test_batch_loop_breaks_on_count_threshold(): - # Setup: fake PublishResult to return - publish_result = PublishResult( - topic=mqtt_topics.DEVICE_TELEMETRY_TOPIC, - qos=1, - message_id=42, - payload_size=8, - reason_code=0 - ) - - # This is what the MQTT manager's publish will return - publish_future = asyncio.Future() - publish_future.set_result(publish_result) - - # Now mock the MQTT manager - mqtt_manager = MagicMock() - mqtt_manager.publish = AsyncMock(return_value=publish_future) - 
mqtt_manager.backpressure.should_pause.return_value = False - - # This is the future that the message queue should resolve - delivery_future = asyncio.Future() - - # Mock adapter to output the delivery future - adapter = MagicMock() - adapter.splitter.max_payload_size = 100000 - adapter.build_uplink_payloads.return_value = [ - (mqtt_topics.DEVICE_TELEMETRY_TOPIC, b"batched", 1, [delivery_future]) - ] - - # Create and start the queue - stop_event = asyncio.Event() - queue = MessageQueue( - mqtt_manager, stop_event, +async def test_wait_for_message_cancel(): + stop = asyncio.Event() + stop.set() + mq = MessageQueue( + mqtt_manager=AsyncMock(), + main_stop_event=stop, message_rate_limit=None, telemetry_rate_limit=None, telemetry_dp_rate_limit=None, - message_adapter=adapter, - batch_collect_max_count=2 - ) - - dummy_msg = MagicMock() - dummy_msg.size = 10 - dummy_msg.device_name = "dev" - dummy_msg.get_delivery_futures.return_value = [delivery_future] - - await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, dummy_msg, 1, 1) - await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, dummy_msg, 1, 1) - - # Allow time for batching and publishing - await asyncio.sleep(0.1) - await queue.shutdown() - - assert mqtt_manager.publish.called - assert delivery_future.done() - result = delivery_future.result() - assert isinstance(result, PublishResult) - assert result.topic == mqtt_topics.DEVICE_TELEMETRY_TOPIC - - -@pytest.mark.asyncio -async def test_batch_loop_skips_message_on_size_exceed(): - mqtt_manager = MagicMock() - mqtt_manager.publish = AsyncMock() - mqtt_manager.backpressure.should_pause.return_value = False - - adapter = MagicMock() - adapter.splitter.max_payload_size = 15 - adapter.build_uplink_payloads.return_value = [] - - stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) - - small_msg = MagicMock() - small_msg.size = 10 - small_msg.device_name = "dev" - small_msg.get_delivery_futures.return_value = [] - - large_msg = MagicMock() - large_msg.size = 10 - large_msg.device_name = "dev" - large_msg.get_delivery_futures.return_value = [] - - await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, small_msg, 1, 1) - await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, large_msg, 1, 1) - - await asyncio.sleep(0.1) - await queue.shutdown() - - assert mqtt_manager.publish.called - - -@pytest.mark.asyncio -async def test_batch_requeues_on_size_exceed(): - mqtt_manager = MagicMock() - mqtt_manager.publish = AsyncMock() - mqtt_manager.backpressure.should_pause.return_value = False - - adapter = MagicMock() - adapter.splitter.max_payload_size = 15 - adapter.build_uplink_payloads.return_value = [] - - stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) - - msg1 = MagicMock() - msg1.size = 10 - msg1.device_name = "dev" - msg1.get_delivery_futures.return_value = [] - - msg2 = MagicMock() - msg2.size = 10 - msg2.device_name = "dev" - msg2.get_delivery_futures.return_value = [] - - await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, msg1, 1, 1) - await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, msg2, 1, 1) - - await asyncio.sleep(0.1) - await queue.shutdown() - - assert mqtt_manager.publish.call_count >= 1 - - -@pytest.mark.asyncio -async def test_batch_immediate_publish_on_raw_bytes(): - mqtt_manager = MagicMock() - mqtt_manager.publish = AsyncMock() - mqtt_manager.backpressure.should_pause.return_value = False - - adapter = MagicMock() - adapter.splitter.max_payload_size = 100000 - 
adapter.build_uplink_payloads.return_value = [] - - stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) - - await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, b"raw_payload", 1, 1) - - await asyncio.sleep(0.05) - await queue.shutdown() - - mqtt_manager.publish.assert_called() - args, kwargs = mqtt_manager.publish.call_args - assert isinstance(kwargs['payload'], bytes) - - -@pytest.mark.asyncio -async def test_batch_queue_empty_breaks_safely(): - mqtt_manager = MagicMock() - mqtt_manager.publish = AsyncMock() - mqtt_manager.backpressure.should_pause.return_value = False - - adapter = MagicMock() - adapter.splitter.max_payload_size = 100000 - adapter.build_uplink_payloads.return_value = [] - - stop_event = asyncio.Event() - queue = MessageQueue(mqtt_manager, stop_event, None, None, None, adapter) - - await asyncio.sleep(0.05) - await queue.shutdown() - - assert mqtt_manager.publish.call_count == 0 - - -@pytest.mark.asyncio -async def test_try_publish_telemetry_rate_limited(): - mqtt_manager = MagicMock() - mqtt_manager.publish = AsyncMock() - mqtt_manager.backpressure.should_pause.return_value = False - - adapter = MagicMock() - adapter.splitter.max_payload_size = 1000 - adapter.build_uplink_payloads.return_value = [("topic", b"{}", 3, [])] - telemetry_rate_limit = MagicMock() - telemetry_rate_limit.try_consume = AsyncMock(return_value=(10, 1)) - telemetry_rate_limit.minimal_timeout = 0.5 - - stop_event = asyncio.Event() - queue = MessageQueue( - mqtt_manager, - stop_event, - telemetry_rate_limit, - None, - RateLimit("10:1", TELEMETRY_DATAPOINTS_RATE_LIMIT, 100), - adapter + message_adapter=Mock(), ) - queue._schedule_delayed_retry = MagicMock() - - msg = (DeviceUplinkMessageBuilder().add_timeseries( - [TimeseriesEntry("temp", 1), - TimeseriesEntry("hum", 2), - TimeseriesEntry("pres", 3)]) - .build()) + await mq._queue.put("dummy") - await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, msg, msg.timeseries_datapoint_count(), 1) - await asyncio.sleep(0.1) - await queue.shutdown() - - queue._schedule_delayed_retry.assert_called_once() - mqtt_manager.publish.assert_not_called() + result = await mq._wait_for_message() + assert result == "dummy" @pytest.mark.asyncio -async def test_try_publish_non_telemetry_rate_limited(): - mqtt_manager = MagicMock() - mqtt_manager.publish = AsyncMock() - mqtt_manager.backpressure.should_pause.return_value = False - - adapter = MagicMock() - adapter.splitter.max_payload_size = 1000 - adapter.build_uplink_payloads.return_value = [("topic", b"{}", 1)] - - message_rate_limit = MagicMock() - message_rate_limit.try_consume = AsyncMock(return_value=(5, 60)) - message_rate_limit.minimal_timeout = 1.0 - - stop_event = asyncio.Event() - queue = MessageQueue( - mqtt_manager, - stop_event, +async def test_clear_futures_result_set(): + fut = asyncio.Future() + fut.uuid = "test-future" + msg = b"data" + fut.set_result = Mock() + mq = MessageQueue( + mqtt_manager=AsyncMock(), + main_stop_event=asyncio.Event(), + message_rate_limit=None, telemetry_rate_limit=None, - message_rate_limit=message_rate_limit, telemetry_dp_rate_limit=None, - message_adapter=adapter + message_adapter=Mock(), ) + mq._queue.put_nowait(MqttPublishMessage("topic", msg, qos=1, delivery_futures=[fut])) + mq.clear() + assert mq.is_empty() - queue._schedule_delayed_retry = MagicMock() - - payload = b'raw-bytes' - topic = "v1/devices/me/rpc/request/1" - - await queue._try_publish(topic, payload, datapoints=0, qos=1, delivery_futures_or_none=None) - 
await asyncio.sleep(0.05) - await queue.shutdown() - - message_rate_limit.try_consume.assert_awaited_once_with(1) - queue._schedule_delayed_retry.assert_called_once_with( - topic=topic, - payload=payload, - datapoints=0, - qos=1, - delay=1.0, - delivery_futures=[] - ) - mqtt_manager.publish.assert_not_called() - -@pytest.mark.parametrize("paused", [True, False]) @pytest.mark.asyncio -async def test_backpressure_delays_publish(paused, monkeypatch): +async def test_message_queue_batching_respects_type_and_size(): mqtt_manager = MagicMock() mqtt_manager.publish = AsyncMock() - mqtt_manager.backpressure.should_pause.return_value = paused - - adapter = MagicMock() - adapter.splitter.max_payload_size = 100000 - adapter.build_uplink_payloads.return_value = [] - stop_event = asyncio.Event() - queue = MessageQueue( - mqtt_manager, - stop_event, + message_adapter = JsonMessageAdapter() + + mq = MessageQueue( + mqtt_manager=mqtt_manager, + main_stop_event=stop_event, message_rate_limit=None, telemetry_rate_limit=None, telemetry_dp_rate_limit=None, - message_adapter=adapter, - batch_collect_max_count=1 + message_adapter=message_adapter, + batch_collect_max_time_ms=200, + max_queue_size=100, ) + mq._backpressure.should_pause.return_value = False + mq._mqtt_manager.backpressure.should_pause.return_value = False - scheduled_retry_mock = MagicMock() - monkeypatch.setattr(queue, "_schedule_delayed_retry", scheduled_retry_mock) + builder1 = DeviceUplinkMessageBuilder().add_timeseries(TimeseriesEntry("key1", 1)) + msg1 = builder1.build() - await queue._try_publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, b"test_payload", 1, qos=1) - await asyncio.sleep(0.05) + builder2 = DeviceUplinkMessageBuilder().add_timeseries(TimeseriesEntry("key2", 2)) + msg2 = builder2.build() - if paused: - scheduled_retry_mock.assert_called_once_with( - topic=mqtt_topics.DEVICE_TELEMETRY_TOPIC, - payload=b"test_payload", - datapoints=1, - qos=1, - delay=1.0, - delivery_futures=[] - ) - mqtt_manager.publish.assert_not_called() - else: - scheduled_retry_mock.assert_not_called() - mqtt_manager.publish.assert_called_once_with( - message_or_topic=mqtt_topics.DEVICE_TELEMETRY_TOPIC, - payload=b"test_payload", - qos=1 - ) - - await queue.shutdown() + # First message fits, second exceeds total max_payload_size + mq._queue.put_nowait(MqttPublishMessage("topic", msg1, qos=1)) + mq._queue.put_nowait(MqttPublishMessage("topic", msg2, qos=1)) + await asyncio.sleep(0.2) # Give loop time to pick up -@pytest.mark.asyncio -async def test_publish_telemetry_rate_limit_triggered(): - mqtt_manager = MagicMock() - mqtt_manager.publish = AsyncMock() - mqtt_manager.backpressure.should_pause.return_value = False + await mq.shutdown() - adapter = MagicMock() - adapter.splitter.max_payload_size = 100000 - adapter.build_uplink_payloads.return_value = [] - - stop_event = asyncio.Event() - - telemetry_dp_rate_limit = MagicMock() - telemetry_dp_rate_limit.try_consume = AsyncMock(return_value=(10, 60)) - telemetry_dp_rate_limit.minimal_timeout = 1.23 - - async with AsyncExitStack() as stack: - queue = MessageQueue( - mqtt_manager, - stop_event, - None, - None, - telemetry_dp_rate_limit=telemetry_dp_rate_limit, - message_adapter=adapter, - ) - stack.push_async_callback(queue.shutdown) - - msg = DeviceUplinkMessageBuilder() \ - .add_timeseries([TimeseriesEntry(f"temp{i}", i) for i in range(10)]) \ - .build() - - with patch.object(queue, "_schedule_delayed_retry", wraps=queue._schedule_delayed_retry) as delayed_retry_mock: - await queue._try_publish( - 
mqtt_topics.DEVICE_TELEMETRY_TOPIC, - b'', - qos=1, - datapoints=msg.timeseries_datapoint_count() - ) - - mqtt_manager.publish.assert_not_called() - delayed_retry_mock.assert_called_once() - args, kwargs = delayed_retry_mock.call_args - assert kwargs["topic"] == mqtt_topics.DEVICE_TELEMETRY_TOPIC - assert kwargs["delay"] == telemetry_dp_rate_limit.minimal_timeout + mqtt_manager.publish.assert_called_once() @pytest.mark.asyncio -async def test_batch_loop_large_messages_are_split_and_published(): - mqtt_manager = MagicMock() - mqtt_manager.publish = AsyncMock() - mqtt_manager.backpressure.should_pause.return_value = False - - adapter = JsonMessageAdapter(100, 20) +async def test_schedule_delayed_retry_reenqueues_message(): + msg = MqttPublishMessage(topic="test/topic", payload=b"data", qos=1) - stop_event = asyncio.Event() queue = MessageQueue( - mqtt_manager, - stop_event, - None, - None, - None, - message_adapter=adapter, - max_queue_size=100, - batch_collect_max_time_ms=10 + mqtt_manager=MagicMock(), + main_stop_event=asyncio.Event(), + message_rate_limit=None, + telemetry_rate_limit=None, + telemetry_dp_rate_limit=None, + message_adapter=MagicMock() ) - builder1 = DeviceUplinkMessageBuilder() - for i in range(30): - builder1.add_timeseries(TimeseriesEntry(f"t{i}", i)) - message1 = builder1.build() + queue._active.set() + queue._main_stop_event.clear() - builder2 = DeviceUplinkMessageBuilder() - for i in range(30, 60): - builder2.add_timeseries(TimeseriesEntry(f"t{i}", i)) - message2 = builder2.build() + queue._queue = asyncio.Queue() + queue._retry_tasks.clear() - await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, message1, message1.timeseries_datapoint_count(), qos=1) - await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, message2, message2.timeseries_datapoint_count(), qos=1) + queue._schedule_delayed_retry(msg, delay=0.05) - await asyncio.sleep(0.2) - await queue.shutdown() + await asyncio.sleep(0.1) - assert mqtt_manager.publish.call_count == 6 + requeued_msg = await queue._queue.get() + assert requeued_msg.topic == msg.topic - for call in mqtt_manager.publish.call_args_list: - kwargs = call.kwargs - assert kwargs["message_or_topic"] == "v1/devices/me/telemetry" - assert isinstance(kwargs["payload"], bytes) - assert kwargs["qos"] == 1 + assert queue._wakeup_event.is_set() + await asyncio.sleep(0) -@pytest.mark.asyncio -async def test_delivery_futures_resolved_via_real_puback_handler(): - delivery_future = asyncio.Future() + assert len(queue._retry_tasks) == 0 - mqtt_future = asyncio.Future() - mqtt_future.mid = 123 - mqtt_manager = MagicMock() - mqtt_manager.publish = AsyncMock(return_value=mqtt_future) - mqtt_manager.backpressure.should_pause.return_value = False - mqtt_manager._backpressure = MagicMock() - mqtt_manager._on_publish_result_callback = None - - from tb_mqtt_client.service.mqtt_manager import MQTTManager - mqtt_manager._handle_puback_reason_code = MQTTManager._handle_puback_reason_code.__get__(mqtt_manager) - - topic = mqtt_topics.DEVICE_TELEMETRY_TOPIC - qos = 1 - payload_size = 24 - publish_time = 1.0 - mqtt_manager._pending_publishes = { - mqtt_future.mid: (delivery_future, topic, qos, payload_size, publish_time) - } - - adapter = MagicMock() - adapter.splitter.max_payload_size = 100000 - adapter.build_uplink_payloads.return_value = [ - (topic, b'{"some":"payload"}', qos, [delivery_future]) - ] +@pytest.mark.asyncio +async def test_schedule_delayed_retry_does_nothing_if_inactive(): + msg = MqttPublishMessage(topic="inactive/topic", payload=b"data", qos=1) - 
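For context on what the retry tests above exercise, here is a minimal standalone sketch of the delayed re-enqueue pattern; the names `active` and `wakeup` are illustrative stand-ins, not the exact MessageQueue attributes:

    import asyncio

    async def retry_later(queue: asyncio.Queue, message, delay: float,
                          active: asyncio.Event, wakeup: asyncio.Event):
        # Sleep for the requested delay, then re-enqueue only if the queue
        # is still active; wake the consumer loop afterwards.
        await asyncio.sleep(delay)
        if not active.is_set():
            return
        try:
            queue.put_nowait(message)
            wakeup.set()
        except asyncio.QueueFull:
            pass  # the retried message is dropped when the queue is full
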
stop_event = asyncio.Event() queue = MessageQueue( - mqtt_manager, - stop_event, - None, - None, - None, - message_adapter=adapter, - batch_collect_max_count=1, - batch_collect_max_time_ms=1 + mqtt_manager=MagicMock(), + main_stop_event=asyncio.Event(), + message_rate_limit=None, + telemetry_rate_limit=None, + telemetry_dp_rate_limit=None, + message_adapter=MagicMock() ) - msg = ( - DeviceUplinkMessageBuilder() - .set_device_name("deviceA") - .add_timeseries([TimeseriesEntry("temperature", 25)]) - .add_delivery_futures([delivery_future]) - .build() - ) + queue._active.clear() + queue._main_stop_event.clear() - await queue.publish(topic, msg, msg.timeseries_datapoint_count(), qos=qos) - await asyncio.sleep(0.05) + queue._queue = asyncio.Queue() + queue._retry_tasks.clear() - mqtt_future.set_result(None) - mqtt_manager._handle_puback_reason_code(mqtt_future.mid, 0, {}) + queue._schedule_delayed_retry(msg, delay=0.01) await asyncio.sleep(0.05) - assert delivery_future.done() - result = delivery_future.result() - assert isinstance(result, PublishResult) - assert result.topic == topic - assert result.message_id == mqtt_future.mid - assert result.reason_code == 0 + assert queue._queue.empty() + assert len(queue._retry_tasks) == 0 @pytest.mark.asyncio -async def test_batch_append_and_batch_size_accumulate(): - mqtt_manager = MagicMock() - mqtt_manager.publish = AsyncMock() - mqtt_manager.backpressure.should_pause.return_value = False +async def test_schedule_delayed_retry_does_nothing_if_stopped(): + msg = MqttPublishMessage(topic="stop/topic", payload=b"data", qos=1) - adapter = JsonMessageAdapter(100000, 10000) stop_event = asyncio.Event() + stop_event.set() queue = MessageQueue( - mqtt_manager, - stop_event, - None, - None, - None, - message_adapter=adapter, - batch_collect_max_count=2, - batch_collect_max_time_ms=1000 + mqtt_manager=MagicMock(), + main_stop_event=stop_event, + message_rate_limit=None, + telemetry_rate_limit=None, + telemetry_dp_rate_limit=None, + message_adapter=MagicMock() ) - fixed_ts = int(time() * 1000) + queue._active.set() + queue._queue = asyncio.Queue() + queue._retry_tasks.clear() - msg1 = DeviceUplinkMessageBuilder() \ - .add_timeseries([TimeseriesEntry(f"temp{i}", i, ts=fixed_ts) for i in range(10)]) \ - .build() - msg2 = DeviceUplinkMessageBuilder() \ - .add_timeseries([TimeseriesEntry(f"temp{i}", i, ts=fixed_ts) for i in range(10)]) \ - .build() + queue._schedule_delayed_retry(msg, delay=0.01) - await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, msg1, msg1.timeseries_datapoint_count(), 1) - await queue.publish(mqtt_topics.DEVICE_TELEMETRY_TOPIC, msg2, msg2.timeseries_datapoint_count(), 1) + await asyncio.sleep(0.05) - await asyncio.sleep(0.3) - await queue.shutdown() + assert queue._queue.empty() + assert len(queue._retry_tasks) == 0 - mqtt_manager.publish.assert_called_once() - args, kwargs = mqtt_manager.publish.call_args - assert kwargs["message_or_topic"] == mqtt_topics.DEVICE_TELEMETRY_TOPIC - assert isinstance(kwargs["payload"], bytes) - assert kwargs["qos"] == 1 +if __name__ == '__main__': + pytest.main([__file__, '-v', '--tb=short']) diff --git a/tests/service/test_message_splitter.py b/tests/service/test_message_splitter.py index b15e3ad..341d3e3 100644 --- a/tests/service/test_message_splitter.py +++ b/tests/service/test_message_splitter.py @@ -131,48 +131,6 @@ def test_builder_failure_during_split_raises(mock_builder_class): splitter.split_timeseries([message]) -# Negative test: one of delivery futures fails -@pytest.mark.asyncio 
-@patch("tb_mqtt_client.service.message_splitter.DeviceUplinkMessageBuilder") -async def test_delivery_future_failure_propagation(mock_builder_class, splitter): - entry = mock_ts_entry() - message = MagicMock() - message.device_name = "deviceX" - message.device_profile = "profileX" - message.has_timeseries.return_value = True - message.timeseries = {"data": [entry] * 4} - message.attributes_datapoint_count.return_value = 0 - message.timeseries_datapoint_count.return_value = 4 - message.size = 50 - - main_future = asyncio.Future() - message.get_delivery_futures.return_value = [main_future] - - fail_future = asyncio.Future() - ok_future = asyncio.Future() - - built_msg1 = MagicMock() - built_msg1.get_delivery_futures.return_value = [fail_future] - - built_msg2 = MagicMock() - built_msg2.get_delivery_futures.return_value = [ok_future] - - builder = MagicMock() - builder.build.side_effect = [built_msg1, built_msg2] - mock_builder_class.return_value = builder - - result = splitter.split_timeseries([message]) - assert len(result) == 2 - - await asyncio.sleep(0) - fail_future.set_result(False) - ok_future.set_result(True) - - await asyncio.sleep(0.1) - assert main_future.done() - assert main_future.result() is True - - # Property validation def test_payload_setter_validation(): s = MessageSplitter() @@ -238,3 +196,5 @@ async def test_split_attributes_different_devices_not_grouped(): fut.set_result(PublishResult("test/topic", 1, 1, 100, 0)) +if __name__ == '__main__': + pytest.main([__file__, '-v', '--tb=short']) diff --git a/tests/service/test_mqtt_manager.py b/tests/service/test_mqtt_manager.py index 935de01..2a802a4 100644 --- a/tests/service/test_mqtt_manager.py +++ b/tests/service/test_mqtt_manager.py @@ -19,6 +19,7 @@ import pytest import pytest_asyncio +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit from tb_mqtt_client.constants.service_keys import MESSAGES_RATE_LIMIT @@ -30,7 +31,7 @@ @pytest_asyncio.fixture async def setup_manager(): stop_event = asyncio.Event() - message_dispatcher = MagicMock(spec=MessageAdapter) + message_adapter = MagicMock(spec=MessageAdapter) on_connect = AsyncMock() on_disconnect = AsyncMock() on_publish_result = AsyncMock() @@ -40,14 +41,14 @@ async def setup_manager(): manager = MQTTManager( client_id="test-client", main_stop_event=stop_event, - message_adapter=message_dispatcher, + message_adapter=message_adapter, on_connect=on_connect, on_disconnect=on_disconnect, on_publish_result=on_publish_result, rate_limits_handler=rate_limits_handler, rpc_response_handler=rpc_response_handler ) - return manager, stop_event, message_dispatcher, on_connect, on_disconnect, on_publish_result, rate_limits_handler, rpc_response_handler + return manager, stop_event, message_adapter, on_connect, on_disconnect, on_publish_result, rate_limits_handler, rpc_response_handler @pytest.mark.asyncio @@ -81,8 +82,8 @@ async def test_on_disconnect_internal_abnormal_disconnect(setup_manager): fut1 = asyncio.Future() fut2 = asyncio.Future() - manager._pending_publishes[101] = (fut1, "topic1", 1, 100, monotonic()) - manager._pending_publishes[102] = (fut2, "topic2", 1, 100, monotonic()) + manager._pending_publishes[101] = (fut1, MqttPublishMessage("topic1", b"test"), monotonic()) + manager._pending_publishes[102] = (fut2, MqttPublishMessage("topic2", b"test"), monotonic()) manager._backpressure = MagicMock() manager._backpressure.notify_disconnect 
= MagicMock() @@ -137,15 +138,16 @@ async def test_publish_force_bypasses_limits(setup_manager): manager._client._connection.publish.return_value = (10, b"packet") manager._client._persistent_storage = MagicMock() - result = await manager.publish("topic", b"payload", qos=1, force=True) - assert isinstance(result, asyncio.Future) + mqtt_publish_message = MqttPublishMessage("topic", b"payload", qos=1) + await manager.publish(mqtt_publish_message, qos=1, force=True) + assert manager._client._connection.publish.call_count == 1 @pytest.mark.asyncio async def test_on_disconnect_internal_clears_futures(setup_manager): manager, *_ = setup_manager fut = asyncio.Future() - manager._pending_publishes[42] = (fut, "topic", 1, 100, monotonic()) + manager._pending_publishes[42] = (fut, MqttPublishMessage("topic", b"payload"), monotonic()) manager._on_disconnect_internal(manager._client, reason_code=0) assert not manager._pending_publishes assert fut.done() @@ -169,7 +171,8 @@ async def dummy_handler(topic, payload): async def test_handle_puback_reason_code(setup_manager): manager, *_ = setup_manager fut = asyncio.Future() - manager._pending_publishes[123] = (fut, "topic", 1, 100, monotonic()) + fut.uuid = "test-future" + manager._pending_publishes[123] = (fut, MqttPublishMessage("topic", b"payload", delivery_futures=[fut]), monotonic()) manager._handle_puback_reason_code(123, 0, {}) assert fut.done() assert fut.result().message_id == 123 @@ -201,7 +204,7 @@ async def test_match_topic_logic(): async def test_check_pending_publishes_timeout(setup_manager): manager, *_ = setup_manager fut = asyncio.Future() - manager._pending_publishes[1] = (fut, "topic", 1, 100, monotonic() - 20) + manager._pending_publishes[1] = (fut, MqttPublishMessage("topic", b"payload"), monotonic() - 20) await manager.check_pending_publishes(monotonic()) assert fut.done() assert fut.result().reason_code == 408 @@ -246,14 +249,6 @@ async def test_unsubscribe_adds_future(setup_manager): mock_rate_limit.consume.assert_awaited_once() -@pytest.mark.asyncio -async def test_register_claiming_future_triggers_event(setup_manager): - manager, *_ = setup_manager - future = asyncio.Future() - manager.register_claiming_future(future) - assert manager._claiming_future is future - - @pytest.mark.asyncio async def test_publish_qos_zero_sets_result_immediately(setup_manager): manager, *_ = setup_manager @@ -264,10 +259,12 @@ async def test_publish_qos_zero_sets_result_immediately(setup_manager): manager._client._connection = MagicMock() manager._client._connection.publish.return_value = (99, b"packet") manager._client._persistent_storage = MagicMock() + future = asyncio.Future() - result = await manager.publish("topic", b"payload", qos=0, force=True) - assert result.done() - assert result.result() is True + await manager.publish(MqttPublishMessage("topic", b"payload", delivery_futures=future), qos=0, force=True) + await asyncio.sleep(0.05) # Allow async tasks to complete + assert future.done() + assert future.result() == PublishResult("topic", 0, -1, 7, 0) @pytest.mark.asyncio @@ -295,12 +292,12 @@ async def test_handle_puback_reason_code_errors(setup_manager): manager, *_ = setup_manager f1 = asyncio.Future() - manager._pending_publishes[1] = (f1, "topic", 1, 100, 0) + manager._pending_publishes[1] = (f1, MqttPublishMessage("topic", b"payload"), 0) manager._handle_puback_reason_code(1, IMPLEMENTATION_SPECIFIC_ERROR, {}) assert f1.result().reason_code == IMPLEMENTATION_SPECIFIC_ERROR f2 = asyncio.Future() - manager._pending_publishes[2] = (f2, 
"topic", 1, 100, 0) + manager._pending_publishes[2] = (f2, MqttPublishMessage("topic", b"payload"), 0) manager._handle_puback_reason_code(2, QUOTA_EXCEEDED, {}) assert f2.result().reason_code == QUOTA_EXCEEDED @@ -326,12 +323,12 @@ async def test_connect_loop_retry_and_success(setup_manager): @pytest.mark.asyncio async def test_request_rate_limits_timeout(setup_manager): manager, stop_event, _, _, _, _, rate_handler, _ = setup_manager - dispatcher = manager._message_dispatcher + adapter = manager._message_adapter req_mock = MagicMock() req_mock.request_id = "req-id" - dispatcher.build_rpc_request.return_value = ("topic", b"payload") + adapter.build_rpc_request.return_value = MqttPublishMessage("topic", b"payload") manager._client._connection = MagicMock() manager._client._connection.publish.return_value = (999, b"fake_packet") @@ -373,7 +370,7 @@ async def test_disconnect_reason_code_142_triggers_special_flow(setup_manager): manager._MQTTManager__rate_limiter = {"messages": rate_limit} fut = asyncio.Future() - manager._pending_publishes[42] = (fut, "topic", 1, 100, 0) + manager._pending_publishes[42] = (fut, MqttPublishMessage("topic", b"payload"), 0) manager._on_disconnect_internal(manager._client, reason_code=142) await asyncio.sleep(0.05) @@ -384,3 +381,7 @@ async def test_disconnect_reason_code_142_triggers_special_flow(setup_manager): call(delay_seconds=1), ]) manager._on_disconnect_callback.assert_awaited_once() + + +if __name__ == '__main__': + pytest.main([__file__, "-v", "--tb=short"]) From aeeae2168175a3abcea3c031e19f2f918ffdb29f Mon Sep 17 00:00:00 2001 From: imbeacon Date: Thu, 24 Jul 2025 10:45:08 +0300 Subject: [PATCH 57/74] Message queue refactoring, added additional tests --- tb_mqtt_client/service/message_queue.py | 37 ++-- tests/service/test_message_queue.py | 278 ++++++++++++++++++++++++ tests/service/test_message_splitter.py | 54 +++-- tests/service/test_mqtt_manager.py | 2 + 4 files changed, 335 insertions(+), 36 deletions(-) diff --git a/tb_mqtt_client/service/message_queue.py b/tb_mqtt_client/service/message_queue.py index a13bb67..eaec6f5 100644 --- a/tb_mqtt_client/service/message_queue.py +++ b/tb_mqtt_client/service/message_queue.py @@ -169,9 +169,6 @@ async def _dequeue_loop(self): continue async def _try_publish(self, message: MqttPublishMessage): - if not message.delivery_futures: - logger.error("No delivery futures associated! This publish result will not be tracked.") - delivery_futures_or_none = [] is_message_with_telemetry_or_attributes = message.topic in (mqtt_topics.DEVICE_TELEMETRY_TOPIC, mqtt_topics.DEVICE_ATTRIBUTES_TOPIC) # TODO: Add topics check for gateways @@ -237,25 +234,27 @@ def _schedule_delayed_retry(self, message: MqttPublishMessage, delay: float = 0. return logger.trace("Scheduling retry: topic=%s, delay=%.2f", message.topic, delay) - async def retry(): - try: - logger.debug("Retrying publish: topic=%s", message.topic) - await asyncio.sleep(delay) - if not self._active.is_set() or self._main_stop_event.is_set(): - logger.debug("MessageQueue is not active or main stop event is set. Not re-enqueuing message for topic=%s", message.topic) - return - self._queue.put_nowait(message) - self._wakeup_event.set() - logger.debug("Re-enqueued message after delay: topic=%s", message.topic) - except asyncio.QueueFull: - logger.warning("Retry queue full. 
Dropping retried message: topic=%s", message.topic) - except Exception as e: - logger.debug("Unexpected error during delayed retry: %s", e) - - task = asyncio.create_task(retry()) + task = asyncio.create_task(self.__retry_task(message, delay)) self._retry_tasks.add(task) task.add_done_callback(self._retry_tasks.discard) + async def __retry_task(self, message: MqttPublishMessage, delay: float): + try: + logger.debug("Retrying publish: topic=%s", message.topic) + await asyncio.sleep(delay) + if not self._active.is_set() or self._main_stop_event.is_set(): + logger.debug( + "MessageQueue is not active or main stop event is set. Not re-enqueuing message for topic=%s", + message.topic) + return + self._queue.put_nowait(message) + self._wakeup_event.set() + logger.debug("Re-enqueued message after delay: topic=%s", message.topic) + except asyncio.QueueFull: + logger.warning("Retry queue full. Dropping retried message: topic=%s", message.topic) + except Exception as e: + logger.debug("Unexpected error during delayed retry: %s", e) + async def _wait_for_message(self) -> MqttPublishMessage: while self._active.is_set(): try: diff --git a/tests/service/test_message_queue.py b/tests/service/test_message_queue.py index c0bc58e..8dd4e9c 100644 --- a/tests/service/test_message_queue.py +++ b/tests/service/test_message_queue.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +from contextlib import suppress from unittest.mock import AsyncMock, Mock, MagicMock, patch import pytest @@ -21,8 +22,10 @@ from tb_mqtt_client.common.mqtt_message import MqttPublishMessage from tb_mqtt_client.common.rate_limit.backpressure_controller import BackpressureController from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit +from tb_mqtt_client.constants.mqtt_topics import DEVICE_TELEMETRY_TOPIC from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage, DeviceUplinkMessageBuilder from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage, GatewayUplinkMessageBuilder from tb_mqtt_client.service.device.message_adapter import JsonMessageAdapter from tb_mqtt_client.service.message_queue import MessageQueue @@ -349,5 +352,280 @@ async def test_schedule_delayed_retry_does_nothing_if_stopped(): assert len(queue._retry_tasks) == 0 +@pytest.mark.asyncio +async def test_try_publish_telemetry_rate_limit_triggered(): + telemetry_limit = Mock(spec=RateLimit) + telemetry_limit.try_consume = AsyncMock(return_value=(10, 1)) + telemetry_limit.minimal_timeout = 0.05 + + mq = MessageQueue( + mqtt_manager=AsyncMock(), + main_stop_event=asyncio.Event(), + message_rate_limit=None, + telemetry_rate_limit=telemetry_limit, + telemetry_dp_rate_limit=None, + message_adapter=Mock(build_uplink_messages=Mock(return_value=[])) + ) + mq._schedule_delayed_retry = AsyncMock() + mq._backpressure = Mock(spec=BackpressureController) + mq._backpressure.should_pause.return_value = False + + message = MqttPublishMessage(topic=DEVICE_TELEMETRY_TOPIC, payload=b"{}", qos=1) + + await mq._try_publish(message) + + await mq.shutdown() + + mq._schedule_delayed_retry.assert_called_once_with(message, delay=telemetry_limit.minimal_timeout) + called_args = mq._schedule_delayed_retry.call_args + assert called_args.args[0] == message + + +@pytest.mark.asyncio +async def test_try_publish_telemetry_dp_rate_limit_triggered(): + dp_limit = Mock(spec=RateLimit) + dp_limit.try_consume = AsyncMock(return_value=(100, 10)) + 
dp_limit.minimal_timeout = 0.1 + + mq = MessageQueue( + mqtt_manager=AsyncMock(), + main_stop_event=asyncio.Event(), + message_rate_limit=None, + telemetry_rate_limit=None, + telemetry_dp_rate_limit=dp_limit, + message_adapter=Mock(build_uplink_messages=Mock(return_value=[])) + ) + mq._schedule_delayed_retry = AsyncMock() + mq._backpressure = Mock(spec=BackpressureController) + mq._backpressure.should_pause.return_value = False + + message = MqttPublishMessage( + topic=DEVICE_TELEMETRY_TOPIC, + payload=b"{}", + qos=1, + datapoints=5 + ) + + await mq._try_publish(message) + + mq._schedule_delayed_retry.assert_called_once_with(message, delay=dp_limit.minimal_timeout) + + +@pytest.mark.asyncio +async def test_try_publish_generic_message_rate_limit_triggered(): + generic_limit = Mock(spec=RateLimit) + generic_limit.try_consume = AsyncMock(return_value=(1, 60)) + generic_limit.minimal_timeout = 0.2 + + mq = MessageQueue( + mqtt_manager=AsyncMock(), + main_stop_event=asyncio.Event(), + message_rate_limit=generic_limit, + telemetry_rate_limit=None, + telemetry_dp_rate_limit=None, + message_adapter=Mock() + ) + mq._schedule_delayed_retry = AsyncMock() + mq._backpressure = Mock(spec=BackpressureController) + mq._backpressure.should_pause.return_value = False + + message = MqttPublishMessage(topic="some/other/topic", payload=b"{}", qos=1) + + await mq._try_publish(message) + + mq._schedule_delayed_retry.assert_called_once_with(message, delay=generic_limit.minimal_timeout) + called_args = mq._schedule_delayed_retry.call_args + assert called_args.args[0] == message + + +from unittest.mock import AsyncMock, patch + +@pytest.mark.asyncio +async def test_batch_breaks_on_elapsed_time(): + msg = MqttPublishMessage( + "topic", + DeviceUplinkMessageBuilder().add_timeseries(TimeseriesEntry("k", 1)).build(), + qos=1 + ) + + adapter = JsonMessageAdapter() + adapter.splitter.max_payload_size = 1000000 + + stop = asyncio.Event() + mqtt_manager = AsyncMock() + mqtt_manager.backpressure.should_pause.return_value = False + + mq = MessageQueue( + mqtt_manager=mqtt_manager, + main_stop_event=stop, + message_rate_limit=None, + telemetry_rate_limit=None, + telemetry_dp_rate_limit=None, + message_adapter=adapter, + batch_collect_max_time_ms=1, + batch_collect_max_count=1000 + ) + + mq._try_publish = AsyncMock() + + mq._queue.put_nowait(msg) + await asyncio.sleep(0.1) + await mq.shutdown() + + mq._try_publish.assert_called_once() + + +@pytest.mark.asyncio +async def test_batch_breaks_on_message_count(): + adapter = JsonMessageAdapter() + adapter.splitter.max_payload_size = 1000000 + + mqtt_manager = AsyncMock() + mqtt_manager.publish = AsyncMock(return_value=asyncio.Future()) + mqtt_manager.publish.return_value.set_result(True) + mqtt_manager.connected = AsyncMock() + mqtt_manager.connected.is_set.return_value = True + mqtt_manager.backpressure.should_pause.return_value = False + + stop = asyncio.Event() + + mq = MessageQueue( + mqtt_manager=mqtt_manager, + main_stop_event=stop, + message_rate_limit=None, + telemetry_rate_limit=None, + telemetry_dp_rate_limit=None, + message_adapter=adapter, + batch_collect_max_time_ms=1000, + batch_collect_max_count=2 + ) + mq._try_publish = AsyncMock() + + msg1 = MqttPublishMessage("topic", DeviceUplinkMessageBuilder().add_timeseries(TimeseriesEntry("a", 1)).build(), qos=1) + msg2 = MqttPublishMessage("topic", DeviceUplinkMessageBuilder().add_timeseries(TimeseriesEntry("b", 2)).build(), qos=1) + msg3 = MqttPublishMessage("topic", 
DeviceUplinkMessageBuilder().add_timeseries(TimeseriesEntry("c", 3)).build(), qos=1) + + await mq._queue.put(msg1) + await mq._queue.put(msg2) + await mq._queue.put(msg3) + await asyncio.sleep(0.2) + await mq.shutdown() + + assert mq._try_publish.call_count == 2 + + +@pytest.mark.asyncio +async def test_batch_breaks_on_type_mismatch(): + + adapter = JsonMessageAdapter() + adapter.splitter.max_payload_size = 1000000 + + mqtt_manager = AsyncMock() + mqtt_manager.publish = AsyncMock() + mqtt_manager.backpressure.should_pause.return_value = False + + stop = asyncio.Event() + + mq = MessageQueue( + mqtt_manager=mqtt_manager, + main_stop_event=stop, + message_rate_limit=None, + telemetry_rate_limit=None, + telemetry_dp_rate_limit=None, + message_adapter=adapter, + batch_collect_max_time_ms=1000, + batch_collect_max_count=10 + ) + mq._try_publish = AsyncMock() + + msg1 = MqttPublishMessage("topic", + DeviceUplinkMessageBuilder() + .add_timeseries(TimeseriesEntry("a", 1)) + .build(), qos=1) + msg2 = MqttPublishMessage("topic", + GatewayUplinkMessageBuilder() + .set_device_name("test_device") + .add_timeseries(TimeseriesEntry("b", 2)) + .build(), + qos=1) + + mq._queue.put_nowait(msg1) + mq._queue.put_nowait(msg2) + + await asyncio.sleep(0.2) + await mq.shutdown() + + assert mq._try_publish.call_count == 2 + + +@pytest.mark.asyncio +async def test_batch_breaks_on_size_threshold(): + from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder + + adapter = JsonMessageAdapter() + adapter.splitter.max_payload_size = 20 + + mqtt_manager = AsyncMock() + mqtt_manager.publish = AsyncMock() + mqtt_manager.backpressure.should_pause.return_value = False + + stop = asyncio.Event() + + mq = MessageQueue( + mqtt_manager=mqtt_manager, + main_stop_event=stop, + message_rate_limit=None, + telemetry_rate_limit=None, + telemetry_dp_rate_limit=None, + message_adapter=adapter, + batch_collect_max_time_ms=1000, + batch_collect_max_count=10 + ) + mq._try_publish = AsyncMock() + + msg1 = MqttPublishMessage("topic", DeviceUplinkMessageBuilder().add_timeseries(TimeseriesEntry("a", 1)).build(), qos=1) + msg2 = MqttPublishMessage("topic", DeviceUplinkMessageBuilder().add_timeseries(TimeseriesEntry("b", 2)).build(), qos=1) + + mq._queue.put_nowait(msg1) + mq._queue.put_nowait(msg2) + + await asyncio.sleep(0.2) + await mq.shutdown() + + assert mq._try_publish.call_count == 2 + + +@pytest.mark.asyncio +async def test_batch_skips_bytes_payload(): + mqtt_manager = AsyncMock() + mqtt_manager.publish = AsyncMock() + mqtt_manager.backpressure.should_pause.return_value = False + + stop = asyncio.Event() + + mq = MessageQueue( + mqtt_manager=mqtt_manager, + main_stop_event=stop, + message_rate_limit=None, + telemetry_rate_limit=None, + telemetry_dp_rate_limit=None, + message_adapter=JsonMessageAdapter(), + batch_collect_max_time_ms=1000, + batch_collect_max_count=1000 + ) + mq._try_publish = AsyncMock() + + + msg1 = MqttPublishMessage("topic", DeviceUplinkMessageBuilder().add_timeseries(TimeseriesEntry("a", 1)).build(), qos=1) + msg2 = MqttPublishMessage("topic", b"raw", qos=1) + + mq._queue.put_nowait(msg1) + mq._queue.put_nowait(msg2) + + await asyncio.sleep(0.2) + await mq.shutdown() + assert mq._try_publish.call_count == 2 + + if __name__ == '__main__': pytest.main([__file__, '-v', '--tb=short']) diff --git a/tests/service/test_message_splitter.py b/tests/service/test_message_splitter.py index 341d3e3..4c4b231 100644 --- a/tests/service/test_message_splitter.py +++ 
b/tests/service/test_message_splitter.py @@ -20,6 +20,7 @@ from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.service.device.message_adapter import JsonMessageAdapter from tb_mqtt_client.service.message_splitter import MessageSplitter @@ -29,23 +30,6 @@ def splitter(): return MessageSplitter(max_payload_size=100, max_datapoints=3) -def mock_ts_entry(size=20): - entry = MagicMock() - entry.size = size - entry.ts = 123456789 - entry.key = "k" - entry.value = 42 - return entry - - -def mock_attr_entry(size=20): - attr = MagicMock() - attr.size = size - attr.key = "k" - attr.value = 42 - return attr - - # Positive cases def test_single_small_timeseries_pass_through(splitter): msg = MagicMock() @@ -196,5 +180,41 @@ async def test_split_attributes_different_devices_not_grouped(): fut.set_result(PublishResult("test/topic", 1, 1, 100, 0)) +@patch("tb_mqtt_client.service.message_splitter.future_map.register") +@pytest.mark.asyncio +async def test_split_timeseries_registers_futures_and_batches_correctly(mock_register): + splitter = MessageSplitter(max_payload_size=100, max_datapoints=2) + + builder = DeviceUplinkMessageBuilder().set_device_name("deviceX").set_device_profile("profileY") + entry1 = TimeseriesEntry("temp", 1, 100) + entry2 = TimeseriesEntry("humidity", 2, 100) + entry3 = TimeseriesEntry("pressure", 3, 100) + + builder.add_timeseries(entry1) + builder.add_timeseries(entry2) + builder.add_timeseries(entry3) + + parent_future = asyncio.Future() + builder.add_delivery_futures(parent_future) + original_msg = builder.build() + + result = splitter.split_timeseries([original_msg]) + + # Should be split due to datapoints=3 > max_datapoints=2 → 2 batches + assert len(result) == 2 + total_points = sum(m.timeseries_datapoint_count() for m in result) + assert total_points == 3 + + # Validate that register was called for the shared future for each batch + assert mock_register.call_count == 2 + for call in mock_register.call_args_list: + args, _ = call + assert args[0] is parent_future + assert isinstance(args[1], list) + assert len(args[1]) == 1 + shared_future = args[1][0] + assert isinstance(shared_future, asyncio.Future) + assert hasattr(shared_future, "uuid") + if __name__ == '__main__': pytest.main([__file__, '-v', '--tb=short']) diff --git a/tests/service/test_mqtt_manager.py b/tests/service/test_mqtt_manager.py index 2a802a4..0647c2a 100644 --- a/tests/service/test_mqtt_manager.py +++ b/tests/service/test_mqtt_manager.py @@ -227,6 +227,7 @@ async def test_subscribe_adds_future(setup_manager): setattr(manager, "_MQTTManager__rate_limiter", {MESSAGES_RATE_LIMIT: mock_rate_limit}) fut = await manager.subscribe("topic", qos=1) + await asyncio.sleep(0.1) assert 42 in manager._pending_subscriptions assert isinstance(fut, asyncio.Future) @@ -243,6 +244,7 @@ async def test_unsubscribe_adds_future(setup_manager): setattr(manager, "_MQTTManager__rate_limiter", {MESSAGES_RATE_LIMIT: mock_rate_limit}) fut = await manager.unsubscribe("topic") + await asyncio.sleep(0.1) assert 77 in manager._pending_unsubscriptions assert isinstance(fut, asyncio.Future) From 4e54c797c696d9f55e580dd3db6a4a392d26995c Mon Sep 17 00:00:00 2001 From: imbeacon Date: Thu, 24 Jul 2025 12:24:10 +0300 Subject: [PATCH 58/74] Added main ts to initial message, splitted 
and grouped messages, to send data without provided ts with correct timestamp --- tb_mqtt_client/common/mqtt_message.py | 7 ++- .../entities/data/device_uplink_message.py | 23 +++++++++- tb_mqtt_client/service/base_client.py | 2 +- tb_mqtt_client/service/device/client.py | 1 + .../service/device/message_adapter.py | 46 +++++++++++-------- tb_mqtt_client/service/message_queue.py | 2 +- tb_mqtt_client/service/message_splitter.py | 12 +++-- tb_mqtt_client/service/mqtt_manager.py | 4 +- tests/service/test_message_queue.py | 7 ++- 9 files changed, 70 insertions(+), 34 deletions(-) diff --git a/tb_mqtt_client/common/mqtt_message.py b/tb_mqtt_client/common/mqtt_message.py index a19f0f8..384b634 100644 --- a/tb_mqtt_client/common/mqtt_message.py +++ b/tb_mqtt_client/common/mqtt_message.py @@ -13,7 +13,8 @@ # limitations under the License. from asyncio import Future -from typing import Union, List +from time import time +from typing import Union, Optional from uuid import uuid4 from gmqtt import Message @@ -37,14 +38,18 @@ def __init__(self, retain: bool = False, datapoints: int = 0, delivery_futures = None, + main_ts: Optional[int] = None, **kwargs): """ Initialize the MqttMessage with topic, payload, QoS, retain flag, and datapoints. """ self.prepared = False self.payload = payload + self.main_ts = main_ts if main_ts is not None else int(time() * 1000) if isinstance(payload, bytes): super().__init__(topic, payload, qos, retain) + else: + payload.set_main_ts(self.main_ts) self.topic = topic self.qos = qos if self.qos < 0 or self.qos > 1: diff --git a/tb_mqtt_client/entities/data/device_uplink_message.py b/tb_mqtt_client/entities/data/device_uplink_message.py index b4da805..c4de1da 100644 --- a/tb_mqtt_client/entities/data/device_uplink_message.py +++ b/tb_mqtt_client/entities/data/device_uplink_message.py @@ -14,6 +14,7 @@ import asyncio from dataclasses import dataclass +from time import time from types import MappingProxyType from typing import List, Optional, Union, OrderedDict, Tuple, Mapping from uuid import uuid4 @@ -36,6 +37,7 @@ class DeviceUplinkMessage: timeseries: Mapping[int, Tuple[TimeseriesEntry]] delivery_futures: List[Optional[asyncio.Future[PublishResult]]] _size: int + main_ts: Optional[int] = None def __new__(cls, *args, **kwargs): raise TypeError("Direct instantiation of DeviceUplinkMessage is not allowed. Use DeviceUplinkMessageBuilder to construct instances.") @@ -54,7 +56,8 @@ def build(cls, attributes: List[AttributeEntry], timeseries: Mapping[int, List[TimeseriesEntry]], delivery_futures: List[Optional[asyncio.Future]], - size: int) -> 'DeviceUplinkMessage': + size: int, + main_ts: Optional[int]) -> 'DeviceUplinkMessage': self = object.__new__(cls) object.__setattr__(self, 'device_name', device_name) object.__setattr__(self, 'device_profile', device_profile) @@ -63,6 +66,7 @@ def build(cls, MappingProxyType({ts: tuple(entries) for ts, entries in timeseries.items()})) object.__setattr__(self, 'delivery_futures', tuple(delivery_futures)) object.__setattr__(self, '_size', size) + object.__setattr__(self, 'main_ts', main_ts) return self @property @@ -84,6 +88,15 @@ def has_timeseries(self) -> bool: def get_delivery_futures(self) -> List[Optional[asyncio.Future[PublishResult]]]: return self.delivery_futures + def set_main_ts(self, main_ts: int) -> 'DeviceUplinkMessage': + """ + Set the main timestamp for the message. + :param main_ts: The main timestamp to set. + :return: The updated DeviceUplinkMessage instance. 
+ """ + object.__setattr__(self, 'main_ts', main_ts) + return self + class DeviceUplinkMessageBuilder: def __init__(self): @@ -93,6 +106,7 @@ def __init__(self): self._timeseries: OrderedDict[int, List[TimeseriesEntry]] = OrderedDict() self._delivery_futures: List[Optional[asyncio.Future[PublishResult]]] = [] self.__size = DEFAULT_FIELDS_SIZE + self._main_ts: Optional[int] = None def set_device_name(self, device_name: str) -> 'DeviceUplinkMessageBuilder': self._device_name = device_name @@ -146,6 +160,10 @@ def add_delivery_futures(self, futures: Union[ self._delivery_futures.extend(futures) return self + def set_main_ts(self, main_ts: int) -> 'DeviceUplinkMessageBuilder': + self._main_ts = main_ts + return self + def build(self) -> DeviceUplinkMessage: if not self._delivery_futures: delivery_future = asyncio.get_event_loop().create_future() @@ -158,5 +176,6 @@ def build(self) -> DeviceUplinkMessage: attributes=self._attributes, timeseries=self._timeseries, delivery_futures=self._delivery_futures, - size=self.__size + size=self.__size, + main_ts=self._main_ts ) diff --git a/tb_mqtt_client/service/base_client.py b/tb_mqtt_client/service/base_client.py index 51cd7bf..60c6b1b 100644 --- a/tb_mqtt_client/service/base_client.py +++ b/tb_mqtt_client/service/base_client.py @@ -20,6 +20,7 @@ import uvloop from tb_mqtt_client.common.exceptions import exception_handler +from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.constants.json_typing import JSONCompatibleType from tb_mqtt_client.constants.service_keys import TELEMETRY_TIMESTAMP_PARAMETER, TELEMETRY_VALUES_PARAMETER from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry @@ -28,7 +29,6 @@ from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder, DeviceUplinkMessage from tb_mqtt_client.entities.data.rpc_response import RPCResponse from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry -from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessageBuilder, GatewayUplinkMessage from tb_mqtt_client.service.gateway.device_session import DeviceSession diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index 7a6423d..ffbcecb 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -376,6 +376,7 @@ async def send_rpc_call(self, method: str, params: Optional[Dict[str, Any]] = No }) future = self._rpc_response_handler.register_request(request_id) + # TODO: Use MQTT message instead of raw payload await self._mqtt_manager.publish(topic, payload, qos=1) try: diff --git a/tb_mqtt_client/service/device/message_adapter.py b/tb_mqtt_client/service/device/message_adapter.py index 791b7a4..fefe377 100644 --- a/tb_mqtt_client/service/device/message_adapter.py +++ b/tb_mqtt_client/service/device/message_adapter.py @@ -12,16 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import asyncio from abc import ABC, abstractmethod -from itertools import chain -from collections import defaultdict, deque +from collections import defaultdict from datetime import UTC, datetime -from typing import Any, Dict, List, Tuple, Optional, Union, DefaultDict, Set +from itertools import chain +from typing import Any, Dict, List, Optional, Union from orjson import dumps, loads -from tb_mqtt_client.common.async_utils import await_and_resolve_original, future_map +from tb_mqtt_client.common.async_utils import future_map from tb_mqtt_client.common.logging_utils import get_logger, TRACE_LEVEL from tb_mqtt_client.common.mqtt_message import MqttPublishMessage from tb_mqtt_client.constants import mqtt_topics @@ -34,6 +33,7 @@ from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.service.message_splitter import MessageSplitter logger = get_logger(__name__) @@ -274,7 +274,8 @@ def build_uplink_messages(self, messages: List[MqttPublishMessage]) -> List[Mqtt payload=payload_bytes, qos=qos, datapoints=count, - delivery_futures=child_futures + delivery_futures=child_futures, + main_ts=ts_batch.main_ts ) result.append(mqtt_msg) built_child_messages.append(mqtt_msg) @@ -296,7 +297,8 @@ def build_uplink_messages(self, messages: List[MqttPublishMessage]) -> List[Mqtt payload=payload_bytes, qos=qos, datapoints=count, - delivery_futures=child_futures + delivery_futures=child_futures, + main_ts=attr_batch.main_ts ) result.append(mqtt_msg) built_child_messages.append(mqtt_msg) @@ -447,16 +449,24 @@ def pack_attributes(msg: DeviceUplinkMessage) -> Dict[str, Any]: @staticmethod def pack_timeseries(msg: 'DeviceUplinkMessage') -> Union[Dict[str, Any], List[Dict[str, Any]]]: - now_ts = int(datetime.now(UTC).timestamp() * 1000) - - entries = list(chain.from_iterable(msg.timeseries.values())) - - if all(entry.ts is None for entry in entries): - return {entry.key: entry.value for entry in entries} - - grouped: Dict[int, Dict[str, Any]] = defaultdict(dict) - for entry in entries: - ts = entry.ts or now_ts - grouped[ts][entry.key] = entry.value + entries = [e for entries in msg.timeseries.values() for e in entries] + if not entries: + return {} + + all_ts_none = True + for e in entries: + if e.ts is not None: + all_ts_none = False + break + + if all_ts_none: + result = {e.key: e.value for e in entries} + return [{"ts": msg.main_ts, "values": result}] if msg.main_ts is not None else result + + now_ts = msg.main_ts if msg.main_ts is not None else int(datetime.now(UTC).timestamp() * 1000) + grouped = defaultdict(dict) + for e in entries: + ts = e.ts if e.ts is not None else now_ts + grouped[ts][e.key] = e.value return [{"ts": ts, "values": values} for ts, values in grouped.items()] diff --git a/tb_mqtt_client/service/message_queue.py b/tb_mqtt_client/service/message_queue.py index eaec6f5..f45d231 100644 --- a/tb_mqtt_client/service/message_queue.py +++ b/tb_mqtt_client/service/message_queue.py @@ -18,10 +18,10 @@ from tb_mqtt_client.common.logging_utils import get_logger, TRACE_LEVEL from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit from tb_mqtt_client.constants import mqtt_topics from 
tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage -from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage from tb_mqtt_client.service.device.message_adapter import MessageAdapter from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter diff --git a/tb_mqtt_client/service/message_splitter.py b/tb_mqtt_client/service/message_splitter.py index 0715a5c..9861def 100644 --- a/tb_mqtt_client/service/message_splitter.py +++ b/tb_mqtt_client/service/message_splitter.py @@ -13,12 +13,12 @@ # limitations under the License. import asyncio -from collections import defaultdict, deque -from typing import List, Optional, Dict, Tuple, Set +from collections import defaultdict +from typing import List, Optional, Dict, Tuple from uuid import uuid4 from tb_mqtt_client.common.async_utils import future_map -from tb_mqtt_client.common.logging_utils import get_logger, TRACE_LEVEL +from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage, DeviceUplinkMessageBuilder from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry @@ -91,7 +91,8 @@ def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp built.timeseries_datapoint_count(), size) builder = DeviceUplinkMessageBuilder() \ .set_device_name(device_name) \ - .set_device_profile(device_profile) + .set_device_profile(device_profile) \ + .set_main_ts(group_msgs[0].main_ts if group_msgs else None) size = 0 point_count = 0 @@ -163,7 +164,8 @@ def split_attributes(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp logger.trace("Flushed attribute batch (count=%d, size=%d)", len(built.attributes), size) builder = DeviceUplinkMessageBuilder() \ .set_device_name(device_name) \ - .set_device_profile(device_profile) + .set_device_profile(device_profile) \ + .set_main_ts(group_msgs[0].main_ts if group_msgs else None) size = 0 point_count = 0 diff --git a/tb_mqtt_client/service/mqtt_manager.py b/tb_mqtt_client/service/mqtt_manager.py index f86e24a..254a45e 100644 --- a/tb_mqtt_client/service/mqtt_manager.py +++ b/tb_mqtt_client/service/mqtt_manager.py @@ -27,6 +27,7 @@ from tb_mqtt_client.common.gmqtt_patch import PatchUtils from tb_mqtt_client.common.logging_utils import get_logger, TRACE_LEVEL from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.common.rate_limit.backpressure_controller import BackpressureController from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit from tb_mqtt_client.common.request_id_generator import RPCRequestIdProducer, AttributeRequestIdProducer @@ -35,7 +36,6 @@ TELEMETRY_DATAPOINTS_RATE_LIMIT from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse -from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.service.device.handlers.rpc_response_handler import RPCResponseHandler from tb_mqtt_client.service.device.message_adapter import MessageAdapter @@ -201,7 +201,7 @@ def resolve_attached(publish_future: asyncio.Future): if f is not None and not f.done(): f.set_result(publish_result) future_map.child_resolved(f) - logger.error("Resolved delivery future #%d id=%r with %s, main publish future id: 
%r", + logger.trace("Resolved delivery future #%d id=%r with %s, main publish future id: %r", i, f.uuid, publish_result, publish_future.uuid) except Exception as e: logger.error("Error resolving delivery futures: %s", str(e)) diff --git a/tests/service/test_message_queue.py b/tests/service/test_message_queue.py index 8dd4e9c..f6b1511 100644 --- a/tests/service/test_message_queue.py +++ b/tests/service/test_message_queue.py @@ -13,8 +13,7 @@ # limitations under the License. import asyncio -from contextlib import suppress -from unittest.mock import AsyncMock, Mock, MagicMock, patch +from unittest.mock import Mock, MagicMock import pytest import pytest_asyncio @@ -23,9 +22,9 @@ from tb_mqtt_client.common.rate_limit.backpressure_controller import BackpressureController from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit from tb_mqtt_client.constants.mqtt_topics import DEVICE_TELEMETRY_TOPIC -from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage, DeviceUplinkMessageBuilder +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry -from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage, GatewayUplinkMessageBuilder +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessageBuilder from tb_mqtt_client.service.device.message_adapter import JsonMessageAdapter from tb_mqtt_client.service.message_queue import MessageQueue From 94c2c571bedc17c71939ec4d99899fbae39147f2 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Thu, 24 Jul 2025 13:08:40 +0300 Subject: [PATCH 59/74] Improved examples and fixed double call of connect callback and subscriptions --- examples/device/operational_example.py | 13 ++++++++----- examples/device/send_client_side_rpc.py | 2 +- tb_mqtt_client/common/gmqtt_patch.py | 14 +++----------- tb_mqtt_client/service/device/client.py | 2 -- tb_mqtt_client/service/mqtt_manager.py | 4 ++-- 5 files changed, 14 insertions(+), 21 deletions(-) diff --git a/examples/device/operational_example.py b/examples/device/operational_example.py index bcc537a..18f8d2d 100644 --- a/examples/device/operational_example.py +++ b/examples/device/operational_example.py @@ -15,7 +15,6 @@ import asyncio import logging import signal -from datetime import datetime, UTC from random import uniform, randint from tb_mqtt_client.common.config_loader import DeviceConfig @@ -25,7 +24,7 @@ from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse from tb_mqtt_client.entities.data.rpc_request import RPCRequest -from tb_mqtt_client.entities.data.rpc_response import RPCResponse, RPCStatus +from tb_mqtt_client.entities.data.rpc_response import RPCResponse from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.service.device.client import DeviceClient @@ -34,6 +33,8 @@ logger.setLevel(logging.DEBUG) logging.getLogger("tb_mqtt_client").setLevel(logging.DEBUG) +DELAY_BETWEEN_DATA_PUBLISH = 1 # seconds + async def attribute_update_callback(update: AttributeUpdate): """ @@ -111,6 +112,7 @@ def _shutdown_handler(): while not stop_event.is_set(): # --- Attributes --- + iteration_start = asyncio.get_event_loop().time() # 1. 
Raw dict raw_dict = { @@ -157,7 +159,7 @@ def _shutdown_handler(): telemetry_entries = [] for i in range(100): - telemetry_entries.append(TimeseriesEntry("temperature", i, ts=int(datetime.now(UTC).timestamp() * 1000)-i)) + telemetry_entries.append(TimeseriesEntry("temperature%i" % i, i)) logger.info("Sending list of telemetry entries with mixed timestamps...") telemetry_list_publish_result = await client.send_timeseries(telemetry_entries) logger.info("List of telemetry entries sent: %s with result: %s", @@ -198,8 +200,9 @@ def _shutdown_handler(): await client.send_rpc_request(rpc_request_2, rpc_response_callback, wait_for_publish=False) try: - logger.info("Waiting for 1 seconds before next iteration...") - await asyncio.wait_for(stop_event.wait(), timeout=1) + logger.info("Waiting before next iteration...") + timeout = DELAY_BETWEEN_DATA_PUBLISH - (asyncio.get_event_loop().time() - iteration_start) + await asyncio.wait_for(stop_event.wait(), timeout=timeout) except asyncio.TimeoutError: logger.info("Going to next iteration...") diff --git a/examples/device/send_client_side_rpc.py b/examples/device/send_client_side_rpc.py index e7d1b05..f652935 100644 --- a/examples/device/send_client_side_rpc.py +++ b/examples/device/send_client_side_rpc.py @@ -31,7 +31,7 @@ async def rpc_response_callback(response: RPCResponse): - logger.info("Received RPC response:", response) + logger.info("Received RPC response: %r", response) async def main(): diff --git a/tb_mqtt_client/common/gmqtt_patch.py b/tb_mqtt_client/common/gmqtt_patch.py index a1cc00a..d837ae5 100644 --- a/tb_mqtt_client/common/gmqtt_patch.py +++ b/tb_mqtt_client/common/gmqtt_patch.py @@ -17,7 +17,7 @@ import struct from collections import defaultdict from time import monotonic -from typing import Callable, Tuple +from typing import Callable, Tuple, Optional from gmqtt import Client from gmqtt.mqtt.constants import MQTTCommands, MQTTv50, MQTTv311 @@ -102,7 +102,7 @@ class PatchUtils: 154: "Wildcard Subscriptions not supported" } - def __init__(self, client: Client, stop_event: asyncio.Event, retry_interval: int = 1): + def __init__(self, client: Optional[Client], stop_event: asyncio.Event, retry_interval: int = 1): """ Initialize PatchUtils with a client and retry interval. @@ -209,8 +209,7 @@ def patched_handle_disconnect_packet(self, cmd, packet): logger.warning("Failed to patch gmqtt handler: %s", e) return False - def patch_handle_connack(patch_utils_instance, - on_connack_with_session_present_and_result_code: Callable[[object, int, int, dict], None]): + def patch_handle_connack(patch_utils_instance): """ Fully replaces gmqtt's _handle_connack_packet implementation, skipping internal QoS1 resend behavior and invoking a custom callback instead of calling the original method. 
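The operational_example.py change above paces iterations by elapsed time instead of a fixed one-second sleep; a minimal sketch of that pattern (the helper name `paced_wait` is illustrative, not part of the SDK):

    import asyncio

    DELAY_BETWEEN_DATA_PUBLISH = 1  # seconds, same constant as in the example

    async def paced_wait(stop_event: asyncio.Event, iteration_start: float):
        # Wait only for the remainder of the publish interval, so a slow
        # iteration does not stretch the effective period.
        elapsed = asyncio.get_event_loop().time() - iteration_start
        remaining = max(0.0, DELAY_BETWEEN_DATA_PUBLISH - elapsed)
        try:
            await asyncio.wait_for(stop_event.wait(), timeout=remaining)
        except asyncio.TimeoutError:
            pass  # interval elapsed, continue with the next iteration
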
@@ -251,13 +250,6 @@ def new_handle_connack_packet(self, cmd, packet): self._logger.debug('[CONNACK] session_present: %s, result: %s', hex(session_present), hex(result_code)) - on_connack_with_session_present_and_result_code( - patch_utils_instance.client, - session_present, - result_code, - properties - ) - self.on_connect(self, session_present, result_code, self.properties) except Exception as e: diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index ffbcecb..6f58d02 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -126,8 +126,6 @@ async def connect(self): if self._stop_event.is_set(): return - await self._on_connect() - # Initialize with default max_payload_size if not set if self.max_payload_size is None: self.max_payload_size = 65535 diff --git a/tb_mqtt_client/service/mqtt_manager.py b/tb_mqtt_client/service/mqtt_manager.py index 254a45e..4b4689d 100644 --- a/tb_mqtt_client/service/mqtt_manager.py +++ b/tb_mqtt_client/service/mqtt_manager.py @@ -68,7 +68,7 @@ def __init__( self._client = GMQTTClient(client_id) self._patch_utils.client = self._client - self._patch_utils.patch_handle_connack(self._on_connect_internal) + self._patch_utils.patch_handle_connack() self._patch_utils.apply(self._handle_puback_reason_code) self._client.on_connect = self._on_connect_internal self._client.on_disconnect = self._on_disconnect_internal @@ -269,7 +269,7 @@ def _on_connect_internal(self, client, session_present, reason_code, properties) asyncio.create_task(self.__handle_connect_and_limits()) async def __handle_connect_and_limits(self): - logger.debug("Subscribing to RPC response topics") + logger.info("Subscribing to RPC response topics") sub_future = await self.subscribe(mqtt_topics.DEVICE_RPC_REQUEST_TOPIC_FOR_SUBSCRIPTION, qos=1) while not sub_future.done(): await sleep(0.01) From 7e5ea8c87ce492e128b2cc3f7e90ea09ee48c51e Mon Sep 17 00:00:00 2001 From: imbeacon Date: Fri, 25 Jul 2025 12:26:25 +0300 Subject: [PATCH 60/74] Adjusted gateway part of SDK to use MqttPublishMessage entity, refactored rate limits entities, added initial implementation of gateway message splitter --- examples/gateway/claim_device.py | 14 -- .../gateway/connect_and_disconnect_device.py | 15 +- examples/gateway/handle_attribute_updates.py | 2 +- examples/gateway/handle_rpc_requests.py | 2 +- examples/gateway/load.py | 138 ++++++++++ examples/gateway/request_attributes.py | 8 +- examples/gateway/send_attributes.py | 2 +- examples/gateway/send_timeseries.py | 2 +- tb_mqtt_client/common/async_utils.py | 38 ++- .../common/rate_limit/rate_limit.py | 2 + .../common/rate_limit/rate_limiter.py | 32 +++ tb_mqtt_client/constants/mqtt_topics.py | 18 ++ tb_mqtt_client/constants/service_keys.py | 4 - .../gateway/gateway_uplink_message.py | 21 +- .../service/base_message_splitter.py | 70 ++++++ tb_mqtt_client/service/device/client.py | 76 +++--- .../service/device/message_adapter.py | 16 +- .../service/{ => device}/message_splitter.py | 28 +-- tb_mqtt_client/service/gateway/client.py | 30 ++- .../gateway/direct_event_dispatcher.py | 8 +- .../gateway/handlers/gateway_rpc_handler.py | 2 - .../service/gateway/message_adapter.py | 237 +++++++++--------- .../service/gateway/message_sender.py | 81 +++--- .../service/gateway/message_splitter.py | 189 ++++++++++++++ tb_mqtt_client/service/message_queue.py | 107 ++++---- tb_mqtt_client/service/mqtt_manager.py | 101 ++------ tests/service/device/test_device_client.py | 11 +- 
tests/service/test_message_queue.py | 74 ++---- tests/service/test_message_splitter.py | 63 +++-- tests/service/test_mqtt_manager.py | 16 +- 30 files changed, 897 insertions(+), 510 deletions(-) create mode 100644 examples/gateway/load.py create mode 100644 tb_mqtt_client/common/rate_limit/rate_limiter.py create mode 100644 tb_mqtt_client/service/base_message_splitter.py rename tb_mqtt_client/service/{ => device}/message_splitter.py (88%) create mode 100644 tb_mqtt_client/service/gateway/message_splitter.py diff --git a/examples/gateway/claim_device.py b/examples/gateway/claim_device.py index afb9f16..8c5ee0c 100644 --- a/examples/gateway/claim_device.py +++ b/examples/gateway/claim_device.py @@ -12,20 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import asyncio from tb_mqtt_client.common.config_loader import GatewayConfig diff --git a/examples/gateway/connect_and_disconnect_device.py b/examples/gateway/connect_and_disconnect_device.py index 2e1694b..5ef9d13 100644 --- a/examples/gateway/connect_and_disconnect_device.py +++ b/examples/gateway/connect_and_disconnect_device.py @@ -11,19 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
import asyncio @@ -45,7 +32,7 @@ async def main(): # Connecting device - device_name = "Test Device A1" + device_name = "Test Device B1" device_profile = "Test devices" logger.info("Connecting device: %s", device_name) device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) diff --git a/examples/gateway/handle_attribute_updates.py b/examples/gateway/handle_attribute_updates.py index 4a5d41d..b821397 100644 --- a/examples/gateway/handle_attribute_updates.py +++ b/examples/gateway/handle_attribute_updates.py @@ -44,7 +44,7 @@ async def main(): # Connecting device - device_name = "Test Device A1" + device_name = "Test Device B1" device_profile = "Test devices" logger.info("Connecting device: %s", device_name) device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) diff --git a/examples/gateway/handle_rpc_requests.py b/examples/gateway/handle_rpc_requests.py index 7938d7e..c711801 100644 --- a/examples/gateway/handle_rpc_requests.py +++ b/examples/gateway/handle_rpc_requests.py @@ -61,7 +61,7 @@ async def main(): # Connecting device - device_name = "Test Device A1" + device_name = "Test Device B1" device_profile = "Test devices" logger.info("Connecting device: %s", device_name) device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) diff --git a/examples/gateway/load.py b/examples/gateway/load.py new file mode 100644 index 0000000..29484c3 --- /dev/null +++ b/examples/gateway/load.py @@ -0,0 +1,138 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import logging +import signal +import time +from datetime import UTC, datetime +from random import randint +from typing import List + +from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.service.gateway.client import GatewayClient +from tb_mqtt_client.service.gateway.device_session import DeviceSession +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.common.publish_result import PublishResult + +# --- Logging --- +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.INFO) +logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) + +# --- Constants --- +NUM_DEVICES = 2 +BATCH_SIZE = 100 +MAX_PENDING = 100 +FUTURE_TIMEOUT = 1.0 +DEVICE_PREFIX = "perf-test-device" +WAIT_FOR_PUBLISH = False + +# --- Test logic --- +async def send_batch(client: GatewayClient, session: DeviceSession) -> List[asyncio.Future]: + ts = int(datetime.now(UTC).timestamp() * 1000) + entries = [TimeseriesEntry(f"temp{i}", randint(20, 35), ts=ts) for i in range(BATCH_SIZE)] + result = await client.send_device_timeseries(session, entries, wait_for_publish=WAIT_FOR_PUBLISH) + if isinstance(result, list): + return result + elif result is not None: + return [result] + return [] + +async def wait_for_futures(futures: List[asyncio.Future]) -> int: + delivered = 0 + done, _ = await asyncio.wait(futures, timeout=FUTURE_TIMEOUT, return_when=asyncio.ALL_COMPLETED) + for fut in done: + try: + res = fut.result() + if isinstance(res, PublishResult) and res.is_successful(): + delivered += res.datapoints_count + except Exception as e: + logger.warning("Future error: %s", e) + return delivered + + +async def main(): + stop_event = asyncio.Event() + + def _shutdown(): + logger.info("Shutting down by signal...") + stop_event.set() + + loop = asyncio.get_event_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, _shutdown) + except NotImplementedError: + signal.signal(sig, lambda *_: _shutdown()) + + config = GatewayConfig() + config.host = "localhost" + config.access_token = "YOUR_ACCESS_TOKEN" + + client = GatewayClient(config) + await client.connect() + + logger.info("Connected to ThingsBoard as gateway") + + # Register test devices + sessions: List[DeviceSession] = [] + for i in range(NUM_DEVICES): + device_name = f"{DEVICE_PREFIX}-{i}" + session, _ = await client.connect_device(device_name, wait_for_publish=False) + sessions.append(session) + + logger.info("Registered %d devices", len(sessions)) + + sent_batches = 0 + delivered_dp = 0 + pending_futures: List[asyncio.Future] = [] + + start = time.perf_counter() + + try: + while not stop_event.is_set(): + for session in sessions: + futs = await send_batch(client, session) + pending_futures.extend(futs) + sent_batches += 1 + + if len(pending_futures) >= MAX_PENDING: + delivered = await wait_for_futures(pending_futures) + delivered_dp += delivered + pending_futures.clear() + logger.info("Delivered datapoints so far: %d", delivered_dp) + + await asyncio.sleep(0) # yield to event loop + + finally: + logger.info("Flushing remaining futures...") + if pending_futures: + delivered_dp += await wait_for_futures(pending_futures) + end = time.perf_counter() + + duration = end - start + logger.info("Sent %d batches across %d devices", sent_batches, NUM_DEVICES) + logger.info("Delivered %d datapoints in %.2f seconds (%.0f datapoints/sec)", + delivered_dp, duration, 
delivered_dp / duration if duration > 0 else 0) + await client.stop() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except (KeyboardInterrupt, asyncio.CancelledError): + print("Interrupted by user.") diff --git a/examples/gateway/request_attributes.py b/examples/gateway/request_attributes.py index ffccbc5..eb366bb 100644 --- a/examples/gateway/request_attributes.py +++ b/examples/gateway/request_attributes.py @@ -48,7 +48,7 @@ async def main(): # Connecting device - device_name = "Test Device A1" + device_name = "Test Device B1" device_profile = "Test devices" logger.info("Connecting device: %s", device_name) device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) @@ -81,6 +81,12 @@ async def main(): await client.send_device_attributes_request(device_session, attribute_request, wait_for_publish=True) + # Trying to request shared attribute "state" (Response will be empty if not set) + + logger.info("Requesting shared attribute 'state' for device: %s", device_name) + shared_attribute_request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["state"]) + await client.send_device_attributes_request(device_session, shared_attribute_request, wait_for_publish=True) + await asyncio.sleep(1) # Wait for the response to be processed # Disconnect device diff --git a/examples/gateway/send_attributes.py b/examples/gateway/send_attributes.py index 9f3399d..28d39d6 100644 --- a/examples/gateway/send_attributes.py +++ b/examples/gateway/send_attributes.py @@ -33,7 +33,7 @@ async def main(): # Connecting device - device_name = "Test Device A1" + device_name = "Test Device B1" device_profile = "Test devices" logger.info("Connecting device: %s", device_name) device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) diff --git a/examples/gateway/send_timeseries.py b/examples/gateway/send_timeseries.py index 80124de..4a7b4b8 100644 --- a/examples/gateway/send_timeseries.py +++ b/examples/gateway/send_timeseries.py @@ -34,7 +34,7 @@ async def main(): # Connecting device - device_name = "Test Device A1" + device_name = "Test Device B1" device_profile = "Test devices" logger.info("Connecting device: %s", device_name) device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) diff --git a/tb_mqtt_client/common/async_utils.py b/tb_mqtt_client/common/async_utils.py index 5b35635..e730c3f 100644 --- a/tb_mqtt_client/common/async_utils.py +++ b/tb_mqtt_client/common/async_utils.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import threading from typing import Union, Optional, Any, List, Set, Dict import asyncio @@ -127,3 +127,39 @@ async def await_and_resolve_original( if f is not None and not f.done(): f.set_exception(e) logger.debug("Set fallback exception for parent future #%d id=%r", i, getattr(f, 'uuid', f)) + + +def run_coroutine_sync(coro_func, timeout: float = 3.0, raise_on_timeout: bool = False): + """ + Run async coroutine and return its result from a sync function even if event loop is running. 
+ :param coro_func: async function with no arguments (like: lambda: some_async_fn()) + :param timeout: max wait time in seconds + :param raise_on_timeout: if True, raise TimeoutError on timeout; otherwise return None + """ + result_container = {} + event = threading.Event() + + async def wrapper(): + try: + result = await coro_func() + result_container['result'] = result + except Exception as e: + result_container['error'] = e + finally: + event.set() + + loop = asyncio.get_running_loop() + loop.create_task(wrapper()) + + completed = event.wait(timeout=timeout) + + if not completed: + logger.warning("Timeout while waiting for coroutine to finish: %s", coro_func) + if raise_on_timeout: + raise TimeoutError(f"Coroutine {coro_func} did not complete in {timeout} seconds.") + return None + + if 'error' in result_container: + raise result_container['error'] + + return result_container.get('result') \ No newline at end of file diff --git a/tb_mqtt_client/common/rate_limit/rate_limit.py b/tb_mqtt_client/common/rate_limit/rate_limit.py index ddddb1e..f02c862 100644 --- a/tb_mqtt_client/common/rate_limit/rate_limit.py +++ b/tb_mqtt_client/common/rate_limit/rate_limit.py @@ -79,6 +79,8 @@ def _parse_string(self, rate_limit: str): entries = rate_limit.replace(";", ",").split(",") for entry in entries: + if not entry.strip(): + continue try: limit_str, dur_str = entry.strip().split(":") limit = int(int(limit_str) * self.percentage / 100) diff --git a/tb_mqtt_client/common/rate_limit/rate_limiter.py b/tb_mqtt_client/common/rate_limit/rate_limiter.py new file mode 100644 index 0000000..7fe75e0 --- /dev/null +++ b/tb_mqtt_client/common/rate_limit/rate_limiter.py @@ -0,0 +1,32 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
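As context for the empty-entry guard added to _parse_string in rate_limit.py above: ThingsBoard rate-limit strings are comma- or semicolon-separated limit:durationSeconds pairs, and the trailing separator used in the "0:0," defaults throughout this patch produces an empty entry, which the new check now skips. A minimal illustrative sketch, assuming the usual limit:duration convention (the concrete limit string is a made-up example, not taken from this patch):

from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit

# "10:1,300:60," -> at most 10 messages per 1 second and 300 per 60 seconds;
# the trailing comma yields an empty entry that _parse_string now ignores.
messages_limit = RateLimit("10:1,300:60,", name="messages")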
+ +from dataclasses import dataclass + +from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit + + +@dataclass() +class RateLimiter: + message_rate_limit: RateLimit + telemetry_message_rate_limit: RateLimit + telemetry_datapoints_rate_limit: RateLimit + + def values(self): + return [self.message_rate_limit, self.telemetry_message_rate_limit, self.telemetry_datapoints_rate_limit] + + def __repr__(self): + return f"RateLimiter(message_rate_limit={self.message_rate_limit}, " \ + f"telemetry_message_rate_limit={self.telemetry_message_rate_limit}, " \ + f"telemetry_datapoints_rate_limit={self.telemetry_datapoints_rate_limit})" diff --git a/tb_mqtt_client/constants/mqtt_topics.py b/tb_mqtt_client/constants/mqtt_topics.py index 5db1125..b596e4e 100644 --- a/tb_mqtt_client/constants/mqtt_topics.py +++ b/tb_mqtt_client/constants/mqtt_topics.py @@ -48,6 +48,24 @@ GATEWAY_RPC_TOPIC = BASE_GATEWAY_TOPIC + "/rpc" GATEWAY_CLAIM_TOPIC = BASE_GATEWAY_TOPIC + "/claim" +TOPICS_WITH_DATAPOINTS_CHECK = frozenset({ + DEVICE_TELEMETRY_TOPIC, + DEVICE_ATTRIBUTES_TOPIC, + GATEWAY_TELEMETRY_TOPIC, + GATEWAY_ATTRIBUTES_TOPIC +}) + +GATEWAY_TOPICS = frozenset({ + GATEWAY_CONNECT_TOPIC, + GATEWAY_DISCONNECT_TOPIC, + GATEWAY_TELEMETRY_TOPIC, + GATEWAY_ATTRIBUTES_TOPIC, + GATEWAY_ATTRIBUTES_REQUEST_TOPIC, + GATEWAY_ATTRIBUTES_RESPONSE_TOPIC, + GATEWAY_RPC_TOPIC, + GATEWAY_CLAIM_TOPIC +}) + # Topic Builders diff --git a/tb_mqtt_client/constants/service_keys.py b/tb_mqtt_client/constants/service_keys.py index ac0568e..5a547ad 100644 --- a/tb_mqtt_client/constants/service_keys.py +++ b/tb_mqtt_client/constants/service_keys.py @@ -12,10 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -MESSAGES_RATE_LIMIT = "MESSAGES_RATE_LIMIT" -TELEMETRY_MESSAGE_RATE_LIMIT = "TELEMETRY_MESSAGE_RATE_LIMIT" -TELEMETRY_DATAPOINTS_RATE_LIMIT = "TELEMETRY_DATAPOINTS_RATE_LIMIT" - TELEMETRY_TIMESTAMP_PARAMETER = "ts" TELEMETRY_VALUES_PARAMETER = "values" diff --git a/tb_mqtt_client/entities/gateway/gateway_uplink_message.py b/tb_mqtt_client/entities/gateway/gateway_uplink_message.py index 12ab503..4a49d86 100644 --- a/tb_mqtt_client/entities/gateway/gateway_uplink_message.py +++ b/tb_mqtt_client/entities/gateway/gateway_uplink_message.py @@ -55,7 +55,8 @@ def build(cls, # noqa attributes: List[AttributeEntry], timeseries: Mapping[int, List[TimeseriesEntry]], delivery_futures: List[Optional[asyncio.Future]], - size: int) -> 'GatewayUplinkMessage': + size: int, + main_ts: Optional[int]) -> 'GatewayUplinkMessage': self = object.__new__(cls) object.__setattr__(self, 'device_name', device_name) object.__setattr__(self, 'device_profile', device_profile) @@ -65,6 +66,7 @@ def build(cls, # noqa object.__setattr__(self, 'delivery_futures', tuple(delivery_futures)) object.__setattr__(self, '_size', size) object.__setattr__(self, 'event_type', GatewayEventType.DEVICE_UPLINK) + object.__setattr__(self, 'main_ts', main_ts) return self @property @@ -86,6 +88,15 @@ def has_timeseries(self) -> bool: def get_delivery_futures(self) -> List[Optional[asyncio.Future[PublishResult]]]: return self.delivery_futures + def set_main_ts(self, main_ts: int) -> 'GatewayUplinkMessage': + """ + Set the main timestamp for the message. + :param main_ts: The main timestamp to set. + :return: The updated GatewayUplinkMessage instance. 
+ """ + object.__setattr__(self, 'main_ts', main_ts) + return self + class GatewayUplinkMessageBuilder: def __init__(self): @@ -95,6 +106,7 @@ def __init__(self): self._timeseries: OrderedDict[int, List[TimeseriesEntry]] = OrderedDict() self._delivery_futures: List[Optional[asyncio.Future[PublishResult]]] = [] self.__size = DEFAULT_FIELDS_SIZE + self._main_ts: Optional[int] = None def set_device_name(self, device_name: str) -> 'GatewayUplinkMessageBuilder': self._device_name = device_name @@ -147,6 +159,10 @@ def add_delivery_futures(self, futures: Union[ self._delivery_futures.extend(futures) return self + def set_main_ts(self, main_ts: int) -> 'GatewayUplinkMessageBuilder': + self._main_ts = main_ts + return self + def build(self) -> GatewayUplinkMessage: if not self._delivery_futures: delivery_future = asyncio.get_event_loop().create_future() @@ -158,5 +174,6 @@ def build(self) -> GatewayUplinkMessage: attributes=self._attributes, timeseries=self._timeseries, delivery_futures=self._delivery_futures, - size=self.__size + size=self.__size, + main_ts=self._main_ts ) diff --git a/tb_mqtt_client/service/base_message_splitter.py b/tb_mqtt_client/service/base_message_splitter.py new file mode 100644 index 0000000..050e3ef --- /dev/null +++ b/tb_mqtt_client/service/base_message_splitter.py @@ -0,0 +1,70 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import List + +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage + + +class BaseMessageSplitter(ABC): + """ + Base class for message splitters in the ThingsBoard MQTT client. + """ + + @property + @abstractmethod + def max_payload_size(self) -> int: + """ + Returns the maximum payload size for messages. + """ + pass + + @max_payload_size.setter + @abstractmethod + def max_payload_size(self, value: int) -> None: + """ + Sets the maximum payload size for messages. + """ + pass + + @property + @abstractmethod + def max_datapoints(self) -> int: + """ + Returns the maximum number of datapoints allowed in a message. + """ + pass + + @max_datapoints.setter + @abstractmethod + def max_datapoints(self, value: int) -> None: + """ + Sets the maximum number of datapoints allowed in a message. 
+ """ + pass + + @abstractmethod + def split_timeseries(self, *args, **kwargs) -> List[MqttPublishMessage]: + """ + Splits timeseries data + """ + pass + + @abstractmethod + def split_attributes(self, *args, **kwargs) -> List[MqttPublishMessage]: + """ + Splits attributes data + """ + pass diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index 6f58d02..9ac5f74 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -27,6 +27,7 @@ from tb_mqtt_client.common.provisioning_client import ProvisioningClient from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit, DEFAULT_RATE_LIMIT_PERCENTAGE +from tb_mqtt_client.common.rate_limit.rate_limiter import RateLimiter from tb_mqtt_client.common.request_id_generator import RPCRequestIdProducer from tb_mqtt_client.constants import mqtt_topics from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry @@ -73,9 +74,16 @@ def __init__(self, config: Optional[Union[DeviceConfig, Dict]] = None): self._message_adapter: MessageAdapter = JsonMessageAdapter(1000, 1) # Will be updated after connection established - self._messages_rate_limit = RateLimit("0:0,", name="messages") - self._telemetry_rate_limit = RateLimit("0:0,", name="telemetry") - self._telemetry_dp_rate_limit = RateLimit("0:0,", name="telemetryDataPoints") + self._rate_limiter = RateLimiter( + message_rate_limit=RateLimit("0:0,", name="messages"), + telemetry_message_rate_limit=RateLimit("0:0,", name="telemetryMessages"), + telemetry_datapoints_rate_limit=RateLimit("0:0,", name="telemetryDataPoints") + ) + self._gateway_rate_limiter = RateLimiter( + message_rate_limit=RateLimit("0:0,", name="messages"), + telemetry_message_rate_limit=RateLimit("0:0,", name="telemetryMessages"), + telemetry_datapoints_rate_limit=RateLimit("0:0,", name="telemetryDataPoints") + ) self._ssl_context = None self.max_payload_size = None self._max_inflight_messages = 100 @@ -132,13 +140,12 @@ async def connect(self): logger.debug("Using default max_payload_size: %d", self.max_payload_size) self._message_adapter = JsonMessageAdapter(self.max_payload_size, - self._telemetry_dp_rate_limit.minimal_limit) + self._rate_limiter.telemetry_datapoints_rate_limit.minimal_limit) self._message_queue = MessageQueue( mqtt_manager=self._mqtt_manager, main_stop_event=self._stop_event, - message_rate_limit=self._messages_rate_limit, - telemetry_rate_limit=self._telemetry_rate_limit, - telemetry_dp_rate_limit=self._telemetry_dp_rate_limit, + device_rate_limiter=self._rate_limiter, + gateway_rate_limiter=self._gateway_rate_limiter, message_adapter=self._message_adapter, max_queue_size=self._max_uplink_message_queue_size, ) @@ -357,31 +364,6 @@ async def _on_disconnect(self): logger.info("Device client disconnected.") self._requested_attribute_response_handler.clear() - async def send_rpc_call(self, method: str, params: Optional[Dict[str, Any]] = None, timeout: float = 10.0) -> Union[ - RPCResponse, None]: - """ - Initiates a client-side RPC to ThingsBoard and awaits the result. - :param method: The RPC method to call. - :param params: The parameters to send. - :param timeout: Timeout for the response in seconds. - :return: RPCResponse object containing the result or error. 
- """ - request_id = await RPCRequestIdProducer.get_next() - topic = mqtt_topics.build_device_rpc_request_topic(request_id) - payload = dumps({ - "method": method, - "params": params or {} - }) - - future = self._rpc_response_handler.register_request(request_id) - # TODO: Use MQTT message instead of raw payload - await self._mqtt_manager.publish(topic, payload, qos=1) - - try: - return await await_or_stop(future, timeout=timeout, stop_event=self._stop_event) - except TimeoutError: - raise TimeoutError(f"Timed out waiting for RPC response (method={method}, id={request_id})") - async def _handle_attribute_update(self, topic: str, payload: bytes): await self._attribute_updates_handler.handle(topic, payload) @@ -406,14 +388,14 @@ async def _handle_rate_limit_response(self, response: RPCResponse): # noqa rate_limits = response.result.get('rateLimits', {}) - await self._messages_rate_limit.set_limit(rate_limits.get("messages", "0:0,")) - await self._telemetry_rate_limit.set_limit(rate_limits.get("telemetryMessages", "0:0,")) - await self._telemetry_dp_rate_limit.set_limit(rate_limits.get("telemetryDataPoints", "0:0,")) + await self._rate_limiter.message_rate_limit.set_limit(rate_limits.get("messages", "0:0,"), percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) + await self._rate_limiter.telemetry_message_rate_limit.set_limit(rate_limits.get("telemetryMessages", "0:0,"), percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) + await self._rate_limiter.telemetry_datapoints_rate_limit.set_limit(rate_limits.get("telemetryDataPoints", "0:0,"), percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) server_inflight = int(response.result.get("maxInflightMessages", 100)) limits = [rl.minimal_limit for rl in [ - self._messages_rate_limit, - self._telemetry_rate_limit + self._rate_limiter.message_rate_limit, + self._rate_limiter.telemetry_message_rate_limit, ] if rl.has_limit()] if limits: @@ -426,10 +408,10 @@ async def _handle_rate_limit_response(self, response: RPCResponse): # noqa if "maxPayloadSize" in response.result: self.max_payload_size = int(response.result["maxPayloadSize"] * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) - # Update the dispatcher's max_payload_size if it's already initialized - if self._message_adapter is not None and hasattr(self._message_adapter, 'splitter'): + # Update the adapter's splitter with the new max_payload_size + if self._message_adapter is not None: self._message_adapter.splitter.max_payload_size = self.max_payload_size - logger.debug("Updated dispatcher's max_payload_size to %d", self.max_payload_size) + logger.debug("Updated adapter's max_payload_size to %d", self.max_payload_size) else: # If maxPayloadSize is not provided, keep the default value logger.debug("No maxPayloadSize in service config, using default: %d", self.max_payload_size) @@ -442,9 +424,11 @@ async def _handle_rate_limit_response(self, response: RPCResponse): # noqa self._message_adapter.splitter.max_payload_size = self.max_payload_size logger.debug("Updated dispatcher's max_payload_size to %d", self.max_payload_size) - if (not self._messages_rate_limit.has_limit() - and not self._telemetry_rate_limit.has_limit() - and not self._telemetry_dp_rate_limit.has_limit()): + self._message_adapter.splitter.max_datapoints = self._rate_limiter.telemetry_datapoints_rate_limit.minimal_limit + + if (not self._rate_limiter.message_rate_limit.has_limit() + and not self._rate_limiter.telemetry_message_rate_limit.has_limit() + and not self._rate_limiter.telemetry_datapoints_rate_limit.has_limit()): self._max_queued_messages = 50000 logger.debug("No rate 
limits, setting max_queued_messages to 50000") else: @@ -454,11 +438,7 @@ async def _handle_rate_limit_response(self, response: RPCResponse): # noqa logger.info("Service configuration retrieved and applied.") logger.info("Parsed device limits: %r", response) - self._mqtt_manager.set_rate_limits( - self._messages_rate_limit, - self._telemetry_rate_limit, - self._telemetry_dp_rate_limit - ) + self._mqtt_manager.set_rate_limits_received() return True except Exception as e: diff --git a/tb_mqtt_client/service/device/message_adapter.py b/tb_mqtt_client/service/device/message_adapter.py index fefe377..d528f54 100644 --- a/tb_mqtt_client/service/device/message_adapter.py +++ b/tb_mqtt_client/service/device/message_adapter.py @@ -15,7 +15,6 @@ from abc import ABC, abstractmethod from collections import defaultdict from datetime import UTC, datetime -from itertools import chain from typing import Any, Dict, List, Optional, Union from orjson import dumps, loads @@ -33,8 +32,7 @@ from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse -from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry -from tb_mqtt_client.service.message_splitter import MessageSplitter +from tb_mqtt_client.service.device.message_splitter import MessageSplitter logger = get_logger(__name__) @@ -42,7 +40,7 @@ class MessageAdapter(ABC): def __init__(self, max_payload_size: Optional[int] = None, max_datapoints: Optional[int] = None): self._splitter = MessageSplitter(max_payload_size, max_datapoints) - logger.trace("MessageDispatcher initialized with max_payload_size=%s, max_datapoints=%s", + logger.trace("MessageAdapter initialized with max_payload_size=%s, max_datapoints=%s", max_payload_size, max_datapoints) @abstractmethod @@ -95,12 +93,12 @@ def build_provision_request(self, provision_request) -> MqttPublishMessage: """ pass - @abstractmethod + @property def splitter(self) -> MessageSplitter: """ - Get the message splitter instance. + Get the message splitter instance used by this adapter. 
""" - pass + return self._splitter @abstractmethod def parse_requested_attribute_response(self, topic: str, payload: bytes) -> RequestedAttributeResponse: @@ -235,10 +233,6 @@ def parse_provisioning_response(self, provisioning_request: ProvisioningRequest, logger.error("Failed to parse provisioning response: %s", str(e)) return ProvisioningResponse.build(provisioning_request, {"status": "FAILURE", "errorMsg": str(e)}) - @property - def splitter(self) -> MessageSplitter: - return self._splitter - def build_uplink_messages(self, messages: List[MqttPublishMessage]) -> List[MqttPublishMessage]: if not messages: logger.trace("No messages to process in build_uplink_messages.") diff --git a/tb_mqtt_client/service/message_splitter.py b/tb_mqtt_client/service/device/message_splitter.py similarity index 88% rename from tb_mqtt_client/service/message_splitter.py rename to tb_mqtt_client/service/device/message_splitter.py index 9861def..063f7fe 100644 --- a/tb_mqtt_client/service/message_splitter.py +++ b/tb_mqtt_client/service/device/message_splitter.py @@ -22,21 +22,17 @@ from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage, DeviceUplinkMessageBuilder from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.service.base_message_splitter import BaseMessageSplitter logger = get_logger(__name__) -class MessageSplitter: - def __init__(self, max_payload_size: int = 65535, max_datapoints: int = 0): - if max_payload_size is None or max_payload_size <= 0: - logger.debug("Invalid max_payload_size: %s, using default 65535", max_payload_size) - max_payload_size = 65535 - if max_datapoints is None or max_datapoints < 0: - logger.debug("Invalid max_datapoints: %s, using default 0", max_datapoints) - max_datapoints = 0 +class MessageSplitter(BaseMessageSplitter): + DEFAULT_MAX_PAYLOAD_SIZE = 55_000 # Default to 55_000 to allow for some overhead - self._max_payload_size = max_payload_size - self._max_datapoints = max_datapoints + def __init__(self, max_payload_size: int = DEFAULT_MAX_PAYLOAD_SIZE, max_datapoints: int = 0): + self._max_payload_size = max_payload_size if max_payload_size is not None and max_payload_size > 0 else self.DEFAULT_MAX_PAYLOAD_SIZE + self._max_datapoints = max_datapoints if max_datapoints is not None and max_datapoints > 0 else 0 logger.trace("MessageSplitter initialized with max_payload_size=%d, max_datapoints=%d", self._max_payload_size, self._max_datapoints) @@ -78,7 +74,7 @@ def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp if not builder or exceeds_size or exceeds_points: if builder: - shared_future = asyncio.Future() + shared_future = asyncio.get_running_loop().create_future() shared_future.uuid = uuid4() builder.add_delivery_futures(shared_future) @@ -101,7 +97,7 @@ def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp point_count += 1 if builder and builder._timeseries: # noqa - shared_future = asyncio.Future() + shared_future = asyncio.get_running_loop().create_future() shared_future.uuid = uuid4() builder.add_delivery_futures(shared_future) @@ -152,7 +148,7 @@ def split_attributes(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp if not builder or exceeds_size or exceeds_points: if builder and builder._attributes: # noqa - shared_future = asyncio.Future() + shared_future = asyncio.get_running_loop().create_future() shared_future.uuid = uuid4() 
builder.add_delivery_futures(shared_future) @@ -174,7 +170,7 @@ def split_attributes(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp point_count += 1 if builder and builder._attributes: # noqa - shared_future = asyncio.Future() + shared_future = asyncio.get_running_loop().create_future() shared_future.uuid = uuid4() builder.add_delivery_futures(shared_future) @@ -195,7 +191,7 @@ def max_payload_size(self) -> int: @max_payload_size.setter def max_payload_size(self, value: int): old = self._max_payload_size - self._max_payload_size = value if value > 0 else 65535 + self._max_payload_size = value if value is not None and value > 0 else self.DEFAULT_MAX_PAYLOAD_SIZE logger.debug("Updated max_payload_size: %d -> %d", old, self._max_payload_size) @property @@ -205,5 +201,5 @@ def max_datapoints(self) -> int: @max_datapoints.setter def max_datapoints(self, value: int): old = self._max_datapoints - self._max_datapoints = value if value > 0 else 0 + self._max_datapoints = value if value is not None and value > 0 else 0 logger.debug("Updated max_datapoints: %d -> %d", old, self._max_datapoints) diff --git a/tb_mqtt_client/service/gateway/client.py b/tb_mqtt_client/service/gateway/client.py index 9af53ff..6574d72 100644 --- a/tb_mqtt_client/service/gateway/client.py +++ b/tb_mqtt_client/service/gateway/client.py @@ -20,11 +20,13 @@ from tb_mqtt_client.common.config_loader import GatewayConfig from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.common.publish_result import PublishResult -from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit +from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit, DEFAULT_RATE_LIMIT_PERCENTAGE +from tb_mqtt_client.common.rate_limit.rate_limiter import RateLimiter from tb_mqtt_client.constants import mqtt_topics from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.attribute_request import AttributeRequest from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.data.rpc_response import RPCResponse from tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage from tb_mqtt_client.entities.gateway.device_disconnect_message import DeviceDisconnectMessage from tb_mqtt_client.entities.gateway.event_type import GatewayEventType @@ -61,6 +63,7 @@ def __init__(self, config: Optional[Union[GatewayConfig, Dict]] = None): """ self._config = config if isinstance(config, GatewayConfig) else GatewayConfig(config) super().__init__(self._config) + self._mqtt_manager.enable_gateway_mode() self.device_manager = DeviceManager() @@ -73,7 +76,7 @@ def __init__(self, config: Optional[Union[GatewayConfig, Dict]] = None): self._event_dispatcher.register(GatewayEventType.DEVICE_RPC_RESPONSE, self._uplink_message_sender.send_rpc_response) self._event_dispatcher.register(GatewayEventType.GATEWAY_CLAIM_REQUEST, self._uplink_message_sender.send_claim_request) - self._gateway_message_adapter: GatewayMessageAdapter = JsonGatewayMessageAdapter() + self._gateway_message_adapter: GatewayMessageAdapter = JsonGatewayMessageAdapter(1000, 1) # Default max payload size and datapoints count limit, should be changed after connection established self._uplink_message_sender.set_message_adapter(self._gateway_message_adapter) self._multiplex_dispatcher = None # Placeholder for multiplex dispatcher, if needed @@ -393,3 +396,26 @@ async def _unsubscribe_from_gateway_topics(self): logger.warning("Unsubscribe from gateway rpc topic 
timed out") break await sleep(0.01) + + async def _handle_rate_limit_response(self, response: RPCResponse): # noqa + parent_rate_limits_processing = await super()._handle_rate_limit_response(response) + try: + if not isinstance(response.result, dict) or 'gatewayRateLimits' not in response.result: + logger.warning("Invalid gateway rate limit response: %r", response) + return None + + gateway_rate_limits = response.result.get('gatewayRateLimits', {}) + + await self._gateway_rate_limiter.message_rate_limit.set_limit(gateway_rate_limits.get('messages', '0:0,'), percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) + await self._gateway_rate_limiter.telemetry_message_rate_limit.set_limit(gateway_rate_limits.get('telemetryMessages', '0:0,'), percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) + await self._gateway_rate_limiter.telemetry_datapoints_rate_limit.set_limit(gateway_rate_limits.get('telemetryDataPoints', '0:0,'), percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) + + self._gateway_message_adapter.splitter.max_payload_size = self.max_payload_size + self._gateway_message_adapter.splitter.max_datapoints = self._device_telemetry_dp_rate_limit.minimal_limit + + self._mqtt_manager.set_gateway_rate_limits_received() + return parent_rate_limits_processing + + except Exception as e: + logger.exception("Failed to parse rate limits from server response: %s", e) + return False diff --git a/tb_mqtt_client/service/gateway/direct_event_dispatcher.py b/tb_mqtt_client/service/gateway/direct_event_dispatcher.py index 5616590..478cd5b 100644 --- a/tb_mqtt_client/service/gateway/direct_event_dispatcher.py +++ b/tb_mqtt_client/service/gateway/direct_event_dispatcher.py @@ -49,12 +49,12 @@ async def dispatch(self, event: GatewayEvent, *args, device_session: DeviceSessi return await device_session.handle_event_to_device(event) async with self._lock: callbacks = list(self._handlers.get(event.event_type, [])) - for cb in callbacks: + for callback in callbacks: try: - if asyncio.iscoroutinefunction(cb): - return await cb(event, *args, **kwargs) + if asyncio.iscoroutinefunction(callback): + return await callback(event, *args, **kwargs) else: - return cb(event, *args, **kwargs) + return callback(event, *args, **kwargs) except Exception as e: logger.error(f"[EventDispatcher] Exception in handler for '{event.event_type}': {e}") return None diff --git a/tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py b/tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py index 92ba824..465a270 100644 --- a/tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py +++ b/tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py @@ -16,8 +16,6 @@ from tb_mqtt_client.common.async_utils import await_or_stop from tb_mqtt_client.common.logging_utils import get_logger -from tb_mqtt_client.common.publish_result import PublishResult -from tb_mqtt_client.constants import mqtt_topics from tb_mqtt_client.entities.gateway.event_type import GatewayEventType from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse diff --git a/tb_mqtt_client/service/gateway/message_adapter.py b/tb_mqtt_client/service/gateway/message_adapter.py index 402939f..2d5dc77 100644 --- a/tb_mqtt_client/service/gateway/message_adapter.py +++ b/tb_mqtt_client/service/gateway/message_adapter.py @@ -12,32 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. 
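For reference, the gateway override of _handle_rate_limit_response shown above expects the service-configuration RPC result to carry a gatewayRateLimits section next to the device-level rateLimits handled by the parent class. A response shaped roughly like the following would be accepted by both handlers (the values are illustrative only, not taken from this patch):

service_config_result = {
    "rateLimits": {
        "messages": "10:1,300:60,",
        "telemetryMessages": "10:1,300:60,",
        "telemetryDataPoints": "100:1,3000:60,",
    },
    "gatewayRateLimits": {
        "messages": "10:1,300:60,",
        "telemetryMessages": "10:1,300:60,",
        "telemetryDataPoints": "100:1,3000:60,",
    },
    "maxInflightMessages": 100,
    "maxPayloadSize": 65535,
}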
-import asyncio from abc import abstractmethod, ABC from collections import defaultdict from datetime import datetime, UTC from itertools import chain -from typing import List, Optional, Tuple, Dict, Any, Union +from typing import List, Dict, Any, Union, Optional from orjson import loads, dumps +from tb_mqtt_client.common.async_utils import future_map from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.common.mqtt_message import MqttPublishMessage -from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.constants.mqtt_topics import GATEWAY_TELEMETRY_TOPIC, GATEWAY_ATTRIBUTES_TOPIC, \ GATEWAY_CONNECT_TOPIC, GATEWAY_DISCONNECT_TOPIC, GATEWAY_ATTRIBUTES_REQUEST_TOPIC, GATEWAY_RPC_TOPIC, \ GATEWAY_CLAIM_TOPIC from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate -from tb_mqtt_client.entities.gateway.gateway_claim_request import GatewayClaimRequest -from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse -from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage from tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage from tb_mqtt_client.entities.gateway.device_disconnect_message import DeviceDisconnectMessage from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest from tb_mqtt_client.entities.gateway.gateway_attribute_update import GatewayAttributeUpdate +from tb_mqtt_client.entities.gateway.gateway_claim_request import GatewayClaimRequest from tb_mqtt_client.entities.gateway.gateway_requested_attribute_response import GatewayRequestedAttributeResponse from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage +from tb_mqtt_client.service.gateway.message_splitter import GatewayMessageSplitter logger = get_logger(__name__) @@ -46,11 +46,23 @@ class GatewayMessageAdapter(ABC): """ Adapter for converting events to uplink messages and received messages to events. """ + def __init__(self, max_payload_size: Optional[int] = None, max_datapoints: Optional[int] = None): + self._splitter = GatewayMessageSplitter(max_payload_size, max_datapoints) + logger.trace("GatewayMessageAdapter initialized with max_payload_size=%s, max_datapoints=%s", + max_payload_size, max_datapoints) + + @property + def splitter(self) -> GatewayMessageSplitter: + """ + Returns the message splitter instance used by this adapter. + This allows for splitting messages into smaller parts if needed. + """ + return self._splitter @abstractmethod def build_uplink_messages( self, - messages: List[GatewayUplinkMessage] + messages: List[MqttPublishMessage] ) -> List[MqttPublishMessage]: """ Build a list of topic-payload pairs from the given messages. @@ -60,7 +72,7 @@ def build_uplink_messages( pass @abstractmethod - def build_device_connect_message_payload(self, device_connect_message: DeviceConnectMessage) -> Tuple[str, bytes]: + def build_device_connect_message_payload(self, device_connect_message: DeviceConnectMessage, qos) -> MqttPublishMessage: """ Build the payload for a device connect message. This method should be implemented to handle the specific format of the payload. 
@@ -68,7 +80,7 @@ def build_device_connect_message_payload(self, device_connect_message: DeviceCon pass @abstractmethod - def build_device_disconnect_message_payload(self, device_disconnect_message: DeviceDisconnectMessage) -> Tuple[str, bytes]: + def build_device_disconnect_message_payload(self, device_disconnect_message: DeviceDisconnectMessage, qos) -> MqttPublishMessage: """ Build the payload for a device disconnect message. This method should be implemented to handle the specific format of the payload. @@ -76,7 +88,7 @@ def build_device_disconnect_message_payload(self, device_disconnect_message: Dev pass @abstractmethod - def build_gateway_attribute_request_payload(self, attribute_request: GatewayAttributeRequest) -> Tuple[str, bytes]: + def build_gateway_attribute_request_payload(self, attribute_request: GatewayAttributeRequest, qos) -> MqttPublishMessage: """ Build the payload for a gateway attribute request. This method should be implemented to handle the specific format of the payload. @@ -84,7 +96,7 @@ def build_gateway_attribute_request_payload(self, attribute_request: GatewayAttr pass @abstractmethod - def build_rpc_response_payload(self, rpc_response: GatewayRPCResponse) -> Tuple[str, bytes]: + def build_rpc_response_payload(self, rpc_response: GatewayRPCResponse, qos) -> MqttPublishMessage: """ Build the payload for a gateway RPC response. This method should be implemented to handle the specific format of the payload. @@ -92,7 +104,7 @@ def build_rpc_response_payload(self, rpc_response: GatewayRPCResponse) -> Tuple[ pass @abstractmethod - def build_claim_request_payload(self, claim_request: GatewayClaimRequest) -> Tuple[str, bytes]: + def build_claim_request_payload(self, claim_request: GatewayClaimRequest, qos) -> MqttPublishMessage: """ Build the payload for a gateway claim request. This method should be implemented to handle the specific format of the payload. @@ -108,7 +120,9 @@ def parse_attribute_update(self, data: Dict[str, Any]) -> GatewayAttributeUpdate pass @abstractmethod - def parse_gateway_requested_attribute_response(self, gateway_attribute_request: GatewayAttributeRequest, data: Dict[str, Any]) -> Union[GatewayRequestedAttributeResponse, None]: + def parse_gateway_requested_attribute_response(self, + gateway_attribute_request: GatewayAttributeRequest, + data: Dict[str, Any]) -> Union[GatewayRequestedAttributeResponse, None]: """ Parse the gateway attribute response data into an GatewayAttributeResponse. This method should be implemented to handle the specific format of the payload. @@ -138,91 +152,78 @@ class JsonGatewayMessageAdapter(GatewayMessageAdapter): Builds uplink payloads from uplink message objects and parses JSON payloads into GatewayEvent objects. """ - def build_uplink_messages(self, messages: List[GatewayUplinkMessage]) -> List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]]: - """ - Build a list of topic-payload pairs from the given messages. - Each pair consists of a topic string, payload bytes, the number of datapoints, - and a list of futures for delivery confirmation. 
- """ - try: - if not messages: - logger.trace("No messages to process in build_topic_payloads.") - return [] - - result: List[Tuple[str, bytes, int, List[Optional[asyncio.Future[PublishResult]]]]] = [] - device_groups: Dict[str, List[GatewayUplinkMessage]] = defaultdict(list) - - for msg in messages: - device_name = msg.device_name - device_groups[device_name].append(msg) - logger.trace("Queued message for device='%s'", device_name) - - logger.trace("Processing %d device group(s).", len(device_groups)) - - gateway_timeseries_message = {} - gateway_attributes_message = {} - gateway_timeseries_device_datapoints_counts: Dict[str, int] = {} - gateway_attributes_device_datapoints_counts: Dict[str, int] = {} - gateway_timeseries_delivery_futures: Dict[str, List[Optional[asyncio.Future[PublishResult]]]] = {} - gateway_attributes_delivery_futures: Dict[str, List[Optional[asyncio.Future[PublishResult]]]] = {} - for device, device_msgs in device_groups.items(): - timeseries_msgs: List[GatewayUplinkMessage] = [m for m in device_msgs if m.has_timeseries()] - attr_msgs: List[GatewayUplinkMessage] = [m for m in device_msgs if m.has_attributes()] - if device not in gateway_timeseries_message and timeseries_msgs: - gateway_timeseries_message[device] = [] - gateway_timeseries_delivery_futures[device] = [] - if device not in gateway_attributes_message and attr_msgs: - gateway_attributes_message[device] = {} - gateway_attributes_delivery_futures[device] = [] - logger.trace("Device '%s' - telemetry: %d, attributes: %d", - device, len(timeseries_msgs), len(attr_msgs)) - - # TODO: Recommended to add message splitter to handle large messages and split them into smaller batches - for ts_batch in timeseries_msgs: - packed_ts = JsonGatewayMessageAdapter.pack_timeseries(ts_batch) - gateway_timeseries_message[device].extend(packed_ts) + def build_uplink_messages(self, messages: List[MqttPublishMessage]) -> List[MqttPublishMessage]: + if not messages: + logger.trace("No messages to process in build_uplink_messages.") + return [] + + result: List[MqttPublishMessage] = [] + device_groups: Dict[str, List[GatewayUplinkMessage]] = defaultdict(list) + qos = messages[0].qos + + # Group by device + for mqtt_msg in messages: + payload = mqtt_msg.payload + if isinstance(payload, GatewayUplinkMessage): + device_groups[payload.device_name].append(payload) + logger.trace("Queued GatewayUplinkMessage for device='%s'", payload.device_name) + else: + logger.warning("Unsupported payload type '%s', skipping", type(payload).__name__) + + # Process each device group + for device_name, group_msgs in device_groups.items(): + telemetry_msgs = [m for m in group_msgs if m.has_timeseries()] + attr_msgs = [m for m in group_msgs if m.has_attributes()] + built_child_messages: List[MqttPublishMessage] = [] + + if telemetry_msgs: + for ts_batch in self._splitter.split_timeseries(telemetry_msgs): + payload_dict = {device_name: self.pack_timeseries(ts_batch)} + payload_bytes = dumps(payload_dict) count = ts_batch.timeseries_datapoint_count() - gateway_timeseries_device_datapoints_counts[device] = gateway_timeseries_device_datapoints_counts.get(device, 0) + count - gateway_timeseries_delivery_futures[device] = ts_batch.get_delivery_futures() - logger.trace("Built telemetry payload for device='%s' with %d datapoints", device, count) - - for attr_batch in attr_msgs: - packed_attrs = JsonGatewayMessageAdapter.pack_attributes(attr_batch) + futures = ts_batch.get_delivery_futures() or [] + + mqtt_msg = MqttPublishMessage( + topic=GATEWAY_TELEMETRY_TOPIC, 
+ payload=payload_bytes, + qos=qos, + datapoints=count, + delivery_futures=futures, + main_ts=ts_batch.main_ts + ) + result.append(mqtt_msg) + built_child_messages.append(mqtt_msg) + + if attr_msgs: + for attr_batch in self._splitter.split_attributes(attr_msgs): + payload_dict = {device_name: self.pack_attributes(attr_batch)} + payload_bytes = dumps(payload_dict) count = attr_batch.attributes_datapoint_count() - gateway_attributes_message[device].update(packed_attrs) - gateway_attributes_device_datapoints_counts[device] = gateway_attributes_device_datapoints_counts.get(device, 0) + count - gateway_attributes_delivery_futures[device] = attr_batch.get_delivery_futures() - logger.trace("Built attribute payload for device='%s' with %d attributes", device, count) - - if gateway_timeseries_message: - all_timeseries_delivery_futures = set() - for futures in gateway_timeseries_delivery_futures.values(): - if futures: - all_timeseries_delivery_futures.update(futures) - - result.append((GATEWAY_TELEMETRY_TOPIC, - dumps(gateway_timeseries_message), - sum(gateway_timeseries_device_datapoints_counts[per_device] for per_device in gateway_timeseries_device_datapoints_counts), - list(all_timeseries_delivery_futures))) - if gateway_attributes_message: - all_attributes_delivery_futures = set() - for futures in gateway_attributes_delivery_futures.values(): - if futures: - all_attributes_delivery_futures.update(futures) - result.append((GATEWAY_ATTRIBUTES_TOPIC, - dumps(gateway_attributes_message), - sum(gateway_attributes_device_datapoints_counts[per_device] for per_device in gateway_attributes_device_datapoints_counts), - list(all_attributes_delivery_futures))) - - logger.trace("Generated %d topic-payload entries.", len(result)) - - return result - except Exception as e: - logger.error("Error building topic-payloads: %s", str(e)) - logger.debug("Exception details: %s", e, exc_info=True) - raise - - def build_device_connect_message_payload(self, device_connect_message: DeviceConnectMessage) -> Tuple[str, bytes]: + futures = attr_batch.get_delivery_futures() or [] + + mqtt_msg = MqttPublishMessage( + topic=GATEWAY_ATTRIBUTES_TOPIC, + payload=payload_bytes, + qos=qos, + datapoints=count, + delivery_futures=futures, + main_ts=attr_batch.main_ts + ) + result.append(mqtt_msg) + built_child_messages.append(mqtt_msg) + + # Link parent futures to child delivery futures + parent_futures = [f for m in messages for f in (m.delivery_futures or []) + if isinstance(m.payload, GatewayUplinkMessage) and m.payload.device_name == device_name] + for parent in parent_futures: + for child_msg in built_child_messages: + for child in child_msg.delivery_futures or []: + future_map.register(parent, [child]) + + logger.trace("Generated %d MqttPublishMessage(s) for gateway uplink.", len(result)) + return result + + def build_device_connect_message_payload(self, device_connect_message: DeviceConnectMessage, qos) -> MqttPublishMessage: """ Build the payload for a device connect message. This method serializes the DeviceConnectMessage to JSON format. 
@@ -230,12 +231,12 @@ def build_device_connect_message_payload(self, device_connect_message: DeviceCon try: payload = dumps(device_connect_message.to_payload_format()) logger.trace("Built device connect message payload for device='%s'", device_connect_message.device_name) - return GATEWAY_CONNECT_TOPIC, payload + return MqttPublishMessage(GATEWAY_CONNECT_TOPIC, payload, qos=1) except Exception as e: logger.error("Failed to build device connect message payload: %s", str(e)) raise ValueError("Invalid device connect message format") from e - def build_device_disconnect_message_payload(self, device_disconnect_message: DeviceDisconnectMessage) -> Tuple[str, bytes]: + def build_device_disconnect_message_payload(self,device_disconnect_message: DeviceDisconnectMessage, qos) -> MqttPublishMessage: """ Build the payload for a device disconnect message. This method serializes the DeviceDisconnectMessage to JSON format. @@ -243,12 +244,12 @@ def build_device_disconnect_message_payload(self, device_disconnect_message: Dev try: payload = dumps(device_disconnect_message.to_payload_format()) logger.trace("Built device disconnect message payload for device='%s'", device_disconnect_message.device_name) - return GATEWAY_DISCONNECT_TOPIC, payload + return MqttPublishMessage(GATEWAY_DISCONNECT_TOPIC, payload, qos) except Exception as e: logger.error("Failed to build device disconnect message payload: %s", str(e)) raise ValueError("Invalid device disconnect message format") from e - def build_gateway_attribute_request_payload(self, attribute_request: GatewayAttributeRequest) -> Tuple[str, bytes]: + def build_gateway_attribute_request_payload(self, attribute_request: GatewayAttributeRequest, qos) -> MqttPublishMessage: """ Build the payload for a gateway attribute request. This method serializes the GatewayAttributeRequest to JSON format. @@ -257,12 +258,12 @@ def build_gateway_attribute_request_payload(self, attribute_request: GatewayAttr payload = dumps(attribute_request.to_payload_format()) logger.trace("Built gateway attribute request payload for device='%s'", attribute_request.device_session.device_info.device_name) - return GATEWAY_ATTRIBUTES_REQUEST_TOPIC, payload + return MqttPublishMessage(GATEWAY_ATTRIBUTES_REQUEST_TOPIC, payload, qos) except Exception as e: logger.error("Failed to build gateway attribute request payload: %s", str(e)) raise ValueError("Invalid gateway attribute request format") from e - def build_rpc_response_payload(self, rpc_response: GatewayRPCResponse) -> Tuple[str, bytes]: + def build_rpc_response_payload(self, rpc_response: GatewayRPCResponse, qos) -> MqttPublishMessage: """ Build the payload for a gateway RPC response. This method serializes the GatewayRPCResponse to JSON format. @@ -271,12 +272,12 @@ def build_rpc_response_payload(self, rpc_response: GatewayRPCResponse) -> Tuple[ payload = dumps(rpc_response.to_payload_format()) logger.trace("Built RPC response payload for device='%s', request_id=%i", rpc_response.device_name, rpc_response.request_id) - return GATEWAY_RPC_TOPIC, payload + return MqttPublishMessage(GATEWAY_RPC_TOPIC, payload, qos) except Exception as e: logger.error("Failed to build RPC response payload: %s", str(e)) raise ValueError("Invalid RPC response format") from e - def build_claim_request_payload(self, claim_request: GatewayClaimRequest) -> Tuple[str, bytes]: + def build_claim_request_payload(self, claim_request: GatewayClaimRequest, qos) -> MqttPublishMessage: """ Build the payload for a gateway claim request. 
This method serializes the GatewayClaimRequest to JSON format. @@ -284,7 +285,7 @@ def build_claim_request_payload(self, claim_request: GatewayClaimRequest) -> Tup try: payload = dumps(claim_request.to_payload_format()) logger.trace("Built claim request payload for devices: %s", list(claim_request.devices_requests.keys())) - return GATEWAY_CLAIM_TOPIC, payload + return MqttPublishMessage(GATEWAY_CLAIM_TOPIC, payload, qos) except Exception as e: logger.error("Failed to build claim request payload: %s", str(e)) raise ValueError("Invalid claim request format") from e @@ -298,7 +299,9 @@ def parse_attribute_update(self, data: Dict[str, Any]) -> GatewayAttributeUpdate logger.error("Failed to parse attribute update: %s", str(e)) raise ValueError("Invalid attribute update format") from e - def parse_gateway_requested_attribute_response(self, gateway_attribute_request: GatewayAttributeRequest, data: Dict[str, Any]) -> Union[GatewayRequestedAttributeResponse, None]: + def parse_gateway_requested_attribute_response(self, + gateway_attribute_request: GatewayAttributeRequest, + data: Dict[str, Any]) -> Union[GatewayRequestedAttributeResponse, None]: """ Parse the gateway attribute response data into a GatewayRequestedAttributeResponse. This method extracts the device name, shared and client attributes from the payload. @@ -307,8 +310,13 @@ def parse_gateway_requested_attribute_response(self, gateway_attribute_request: device_name = data['device'] client = [] shared = [] - if 'value' in data and not ((len(gateway_attribute_request.client_keys) == 1 and not gateway_attribute_request.shared_keys) - or (len(gateway_attribute_request.shared_keys) == 1 and not gateway_attribute_request.client_keys)): + client_keys_empty = gateway_attribute_request.client_keys is None + shared_keys_empty = gateway_attribute_request.shared_keys is None + if ('value' in data + and not (((not client_keys_empty and len(gateway_attribute_request.client_keys) == 1) + and not gateway_attribute_request.shared_keys) + or ((not shared_keys_empty and len(gateway_attribute_request.shared_keys) == 1) + and not gateway_attribute_request.client_keys))): # TODO: Skipping case when requested several attributes, but only one is returned, issue on the platform logger.warning("Received gateway attribute response with single key, but multiply keys expected. 
" "Request keys: %s, Response keys: %s", @@ -316,18 +324,21 @@ def parse_gateway_requested_attribute_response(self, gateway_attribute_request: data['value']) return None elif 'value' in data: - if gateway_attribute_request.client_keys is not None and len(gateway_attribute_request.client_keys) == 1: + if not client_keys_empty and len(gateway_attribute_request.client_keys) == 1: client = [AttributeEntry(gateway_attribute_request.client_keys[0], data['value'])] - elif gateway_attribute_request.shared_keys is not None and len(gateway_attribute_request.shared_keys) == 1: + elif not shared_keys_empty and len(gateway_attribute_request.shared_keys) == 1: shared = [AttributeEntry(gateway_attribute_request.shared_keys[0], data['value'])] elif 'values' in data: - if gateway_attribute_request.client_keys is not None and len(gateway_attribute_request.client_keys) > 0: + if not client_keys_empty and len(gateway_attribute_request.client_keys) > 0: client = [AttributeEntry(k, v) for k, v in data['values'].items() if k in gateway_attribute_request.client_keys] - if gateway_attribute_request.shared_keys is not None and len(gateway_attribute_request.shared_keys) > 0: + if not shared_keys_empty and len(gateway_attribute_request.shared_keys) > 0: shared = [AttributeEntry(k, v) for k, v in data['values'].items() if k in gateway_attribute_request.shared_keys] - return GatewayRequestedAttributeResponse(device_name=device_name, request_id=gateway_attribute_request.request_id, shared=shared, client=client) + return GatewayRequestedAttributeResponse(device_name=device_name, + request_id=gateway_attribute_request.request_id, + shared=shared, + client=client) except Exception as e: logger.error("Failed to parse gateway requested attribute response: %s", str(e)) raise ValueError("Invalid gateway requested attribute response format") from e diff --git a/tb_mqtt_client/service/gateway/message_sender.py b/tb_mqtt_client/service/gateway/message_sender.py index 41689e7..a8039a5 100644 --- a/tb_mqtt_client/service/gateway/message_sender.py +++ b/tb_mqtt_client/service/gateway/message_sender.py @@ -16,6 +16,7 @@ from typing import List, Union, Optional from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.constants import mqtt_topics from tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage @@ -56,23 +57,21 @@ async def send_uplink_message(self, message: GatewayUplinkMessage, qos=1) -> Uni return None futures = [] if message.has_timeseries(): - topic = mqtt_topics.GATEWAY_TELEMETRY_TOPIC - timeseries_futures = await self._message_queue.publish( - topic=topic, - payload=message, - datapoints_count=message.timeseries_datapoint_count(), - qos=qos - ) - futures.extend(timeseries_futures) + mqtt_message = MqttPublishMessage( + topic=mqtt_topics.GATEWAY_TELEMETRY_TOPIC, + payload=message, + qos=qos + ) + await self._message_queue.publish(mqtt_message) + futures.extend(mqtt_message.delivery_futures) if message.has_attributes(): - topic = mqtt_topics.GATEWAY_ATTRIBUTES_TOPIC - attributes_futures = await self._message_queue.publish( - topic=topic, - payload=message, - datapoints_count=message.attributes_datapoint_count(), - qos=qos - ) - futures.extend(attributes_futures) + mqtt_message = MqttPublishMessage( + topic=mqtt_topics.GATEWAY_ATTRIBUTES_TOPIC, + payload=message, + qos=qos + ) + await self._message_queue.publish(mqtt_message) + 
futures.extend(mqtt_message.delivery_futures) return futures async def send_device_connect(self, device_connect_message: DeviceConnectMessage, qos=1) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: """ @@ -86,13 +85,9 @@ async def send_device_connect(self, device_connect_message: DeviceConnectMessage if self._message_queue is None: logger.error("Cannot send device connect message. Message queue is not set, are you connected to the platform?") return None - topic, payload = self._message_adapter.build_device_connect_message_payload(device_connect_message=device_connect_message) - return await self._message_queue.publish( - topic=topic, - payload=payload, - datapoints_count=1, - qos=qos - ) + mqtt_message = self._message_adapter.build_device_connect_message_payload(device_connect_message=device_connect_message, qos=qos) + await self._message_queue.publish(mqtt_message) + return mqtt_message.delivery_futures async def send_device_disconnect(self, device_disconnect_message: DeviceDisconnectMessage, qos=1) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: """ @@ -105,13 +100,9 @@ async def send_device_disconnect(self, device_disconnect_message: DeviceDisconne if self._message_queue is None: logger.error("Cannot send device disconnect message. Message queue is not set, are you connected to the platform?") return None - topic, payload = self._message_adapter.build_device_disconnect_message_payload(device_disconnect_message=device_disconnect_message) - return await self._message_queue.publish( - topic=topic, - payload=payload, - datapoints_count=1, - qos=qos - ) + mqtt_message = self._message_adapter.build_device_disconnect_message_payload(device_disconnect_message=device_disconnect_message, qos=qos) + await self._message_queue.publish(mqtt_message) + return mqtt_message.delivery_futures async def send_attributes_request(self, attribute_request: GatewayAttributeRequest, qos=1) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: """ @@ -124,13 +115,9 @@ async def send_attributes_request(self, attribute_request: GatewayAttributeReque if self._message_queue is None: logger.error("Cannot send attribute request. Message queue is not set, are you connected to the platform?") return None - topic, payload = self._message_adapter.build_gateway_attribute_request_payload(attribute_request=attribute_request) - return await self._message_queue.publish( - topic=topic, - payload=payload, - datapoints_count=1, - qos=qos - ) + mqtt_message = self._message_adapter.build_gateway_attribute_request_payload(attribute_request=attribute_request, qos=qos) + await self._message_queue.publish(mqtt_message) + return mqtt_message.delivery_futures async def send_rpc_response(self, rpc_response: GatewayRPCResponse, qos=1) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: """ @@ -143,13 +130,9 @@ async def send_rpc_response(self, rpc_response: GatewayRPCResponse, qos=1) -> Un if self._message_queue is None: logger.error("Cannot send RPC response. 
Message queue is not set, are you connected to the platform?") return None - topic, payload = self._message_adapter.build_rpc_response_payload(rpc_response=rpc_response) - return await self._message_queue.publish( - topic=topic, - payload=payload, - datapoints_count=1, - qos=qos - ) + mqtt_message = self._message_adapter.build_rpc_response_payload(rpc_response=rpc_response, qos=qos) + await self._message_queue.publish(mqtt_message) + return mqtt_message.delivery_futures async def send_claim_request(self, claim_request: GatewayClaimRequest, qos=1) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: """ @@ -162,13 +145,9 @@ async def send_claim_request(self, claim_request: GatewayClaimRequest, qos=1) -> if self._message_queue is None: logger.error("Cannot send claim request. Message queue is not set, are you connected to the platform?") return None - topic, payload = self._message_adapter.build_claim_request_payload(claim_request=claim_request) - return await self._message_queue.publish( - topic=topic, - payload=payload, - datapoints_count=1, - qos=qos - ) + mqtt_message = self._message_adapter.build_claim_request_payload(claim_request=claim_request, qos=qos) + await self._message_queue.publish(mqtt_message) + return mqtt_message.delivery_futures def set_message_queue(self, message_queue: MessageQueue): """ diff --git a/tb_mqtt_client/service/gateway/message_splitter.py b/tb_mqtt_client/service/gateway/message_splitter.py new file mode 100644 index 0000000..60dcc06 --- /dev/null +++ b/tb_mqtt_client/service/gateway/message_splitter.py @@ -0,0 +1,189 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +from collections import defaultdict +from typing import List, Dict, Tuple +from uuid import uuid4 + +from tb_mqtt_client.common.async_utils import future_map +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage, GatewayUplinkMessageBuilder +from tb_mqtt_client.service.base_message_splitter import BaseMessageSplitter + +logger = get_logger(__name__) + + +class GatewayMessageSplitter(BaseMessageSplitter): + DEFAULT_MAX_PAYLOAD_SIZE = 55000 # Default max payload size in bytes + + def __init__(self, max_payload_size: int = 55000, max_datapoints: int = 0): + self._max_payload_size = max_payload_size if max_payload_size is not None and max_payload_size > 0 else self.DEFAULT_MAX_PAYLOAD_SIZE + self._max_datapoints = max_datapoints if max_datapoints is not None and max_datapoints > 0 else 0 + logger.trace("GatewayMessageSplitter initialized with max_payload_size=%d, max_datapoints=%d", + self._max_payload_size, self._max_datapoints) + + def split_timeseries(self, messages: List[GatewayUplinkMessage]) -> List[GatewayUplinkMessage]: + logger.trace("Splitting gateway timeseries for %d messages", len(messages)) + if (len(messages) == 1 + and ((messages[0].attributes_datapoint_count() + messages[0].timeseries_datapoint_count() + <= self._max_datapoints) or self._max_datapoints == 0) + and messages[0].size <= self._max_payload_size): + return messages + + result: List[GatewayUplinkMessage] = [] + grouped: Dict[Tuple[str, str], List[GatewayUplinkMessage]] = defaultdict(list) + + for msg in messages: + grouped[(msg.device_name, msg.device_profile)].append(msg) + + for (device_name, device_profile), group_msgs in grouped.items(): + all_ts_entries: List[TimeseriesEntry] = [] + parent_futures: List[asyncio.Future] = [] + + for msg in group_msgs: + if msg.has_timeseries(): + for ts_group in msg.timeseries.values(): + all_ts_entries.extend(ts_group) + parent_futures.extend(msg.get_delivery_futures() or []) + + builder = None + size = 0 + point_count = 0 + + for entry in all_ts_entries: + exceeds_size = builder and size + entry.size > self._max_payload_size + exceeds_points = 0 < self._max_datapoints <= point_count + + if not builder or exceeds_size or exceeds_points: + if builder: + shared_future = asyncio.get_running_loop().create_future() + shared_future.uuid = uuid4() + builder.add_delivery_futures(shared_future) + + built = builder.build() + result.append(built) + for parent in parent_futures: + future_map.register(parent, [shared_future]) + logger.trace("Flushed timeseries batch: size=%d, count=%d", size, point_count) + + builder = GatewayUplinkMessageBuilder() \ + .set_device_name(device_name) \ + .set_device_profile(device_profile) + size = 0 + point_count = 0 + + builder.add_timeseries(entry) + size += entry.size + point_count += 1 + + if builder and builder._timeseries: # noqa + shared_future = asyncio.get_running_loop().create_future() + shared_future.uuid = uuid4() + builder.add_delivery_futures(shared_future) + built = builder.build() + result.append(built) + for parent in parent_futures: + future_map.register(parent, [shared_future]) + logger.trace("Flushed final timeseries batch: size=%d, count=%d", size, point_count) + + return result + + def split_attributes(self, messages: List[GatewayUplinkMessage]) -> List[GatewayUplinkMessage]: + 
logger.trace("Splitting gateway attributes for %d messages", len(messages)) + if (len(messages) == 1 + and ((messages[0].attributes_datapoint_count() + messages[0].timeseries_datapoint_count() + <= self._max_datapoints) or self._max_datapoints == 0) + and messages[0].size <= self._max_payload_size): + return messages + + result: List[GatewayUplinkMessage] = [] + grouped: Dict[Tuple[str, str], List[GatewayUplinkMessage]] = defaultdict(list) + + for msg in messages: + grouped[(msg.device_name, msg.device_profile)].append(msg) + + for (device_name, device_profile), group_msgs in grouped.items(): + all_attrs: List[AttributeEntry] = [] + parent_futures: List[asyncio.Future] = [] + + for msg in group_msgs: + if msg.has_attributes(): + all_attrs.extend(msg.attributes) + parent_futures.extend(msg.get_delivery_futures() or []) + + builder = None + size = 0 + point_count = 0 + + for attr in all_attrs: + exceeds_size = builder and size + attr.size > self._max_payload_size + exceeds_points = 0 < self._max_datapoints <= point_count + + if not builder or exceeds_size or exceeds_points: + if builder and builder._attributes: # noqa + shared_future = asyncio.get_running_loop().create_future() + shared_future.uuid = uuid4() + builder.add_delivery_futures(shared_future) + + built = builder.build() + result.append(built) + for parent in parent_futures: + future_map.register(parent, [shared_future]) + logger.trace("Flushed attribute batch: size=%d, count=%d", size, point_count) + + builder = GatewayUplinkMessageBuilder() \ + .set_device_name(device_name) \ + .set_device_profile(device_profile) + size = 0 + point_count = 0 + + builder.add_attributes(attr) + size += attr.size + point_count += 1 + + if builder and builder._attributes: # noqa + shared_future = asyncio.get_running_loop().create_future() + shared_future.uuid = uuid4() + builder.add_delivery_futures(shared_future) + + built = builder.build() + result.append(built) + for parent in parent_futures: + future_map.register(parent, [shared_future]) + logger.trace("Flushed final attribute batch: size=%d, count=%d", size, point_count) + + return result + + @property + def max_payload_size(self) -> int: + return self._max_payload_size + + @max_payload_size.setter + def max_payload_size(self, value: int): + old = self._max_payload_size + self._max_payload_size = value if value > 0 else self.DEFAULT_MAX_PAYLOAD_SIZE + logger.debug("Updated max_payload_size: %d -> %d", old, self._max_payload_size) + + @property + def max_datapoints(self) -> int: + return self._max_datapoints + + @max_datapoints.setter + def max_datapoints(self, value: int): + old = self._max_datapoints + self._max_datapoints = value if value > 0 else 0 + logger.debug("Updated max_datapoints: %d -> %d", old, self._max_datapoints) \ No newline at end of file diff --git a/tb_mqtt_client/service/message_queue.py b/tb_mqtt_client/service/message_queue.py index f45d231..2a2d4e4 100644 --- a/tb_mqtt_client/service/message_queue.py +++ b/tb_mqtt_client/service/message_queue.py @@ -19,8 +19,9 @@ from tb_mqtt_client.common.logging_utils import get_logger, TRACE_LEVEL from tb_mqtt_client.common.mqtt_message import MqttPublishMessage from tb_mqtt_client.common.publish_result import PublishResult -from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit +from tb_mqtt_client.common.rate_limit.rate_limiter import RateLimiter from tb_mqtt_client.constants import mqtt_topics +from tb_mqtt_client.constants.mqtt_topics import GATEWAY_TOPICS from tb_mqtt_client.entities.data.device_uplink_message import 
DeviceUplinkMessage from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage from tb_mqtt_client.service.device.message_adapter import MessageAdapter @@ -36,21 +37,19 @@ class MessageQueue: def __init__(self, mqtt_manager: MQTTManager, main_stop_event: asyncio.Event, - message_rate_limit: Optional[RateLimit], - telemetry_rate_limit: Optional[RateLimit], - telemetry_dp_rate_limit: Optional[RateLimit], + device_rate_limiter: RateLimiter, message_adapter: MessageAdapter, max_queue_size: int = 1000000, batch_collect_max_time_ms: int = 100, batch_collect_max_count: int = 500, - gateway_message_adapter: Optional[GatewayMessageAdapter] = None): + gateway_message_adapter: Optional[GatewayMessageAdapter] = None, + gateway_rate_limiter: Optional[RateLimiter] = None): self._main_stop_event = main_stop_event self._batch_max_time = batch_collect_max_time_ms / 1000 self._batch_max_count = batch_collect_max_count self._mqtt_manager = mqtt_manager - self._message_rate_limit = message_rate_limit - self._telemetry_rate_limit = telemetry_rate_limit - self._telemetry_dp_rate_limit = telemetry_dp_rate_limit + self._device_rate_limiter = device_rate_limiter + self._gateway_rate_limiter = gateway_rate_limiter self._backpressure = self._mqtt_manager.backpressure # Queue expects tuples of (mqtt_message, delivery_futures) self._queue: asyncio.Queue[MqttPublishMessage] = asyncio.Queue(maxsize=max_queue_size) @@ -169,9 +168,6 @@ async def _dequeue_loop(self): continue async def _try_publish(self, message: MqttPublishMessage): - is_message_with_telemetry_or_attributes = message.topic in (mqtt_topics.DEVICE_TELEMETRY_TOPIC, - mqtt_topics.DEVICE_ATTRIBUTES_TOPIC) - # TODO: Add topics check for gateways logger.trace("Attempting publish: topic=%s, datapoints=%d", message.topic, message.datapoints) @@ -181,39 +177,8 @@ async def _try_publish(self, message: MqttPublishMessage): self._schedule_delayed_retry(message) return - # Check and consume rate limits atomically before publishing - if is_message_with_telemetry_or_attributes: - # For telemetry messages, we need to check both message and datapoint rate limits - telemetry_msg_success = True - telemetry_dp_success = True - - if self._telemetry_rate_limit: - triggered_rate_limit = await self._telemetry_rate_limit.try_consume(1) - if triggered_rate_limit: - logger.debug("Telemetry message rate limit hit for topic %s: %r per %r seconds", - message.topic, triggered_rate_limit[0], triggered_rate_limit[1]) - retry_delay = self._telemetry_rate_limit.minimal_timeout - self._schedule_delayed_retry(message, delay=retry_delay) - return - - if self._telemetry_dp_rate_limit: - triggered_rate_limit = await self._telemetry_dp_rate_limit.try_consume(message.datapoints) - if triggered_rate_limit: - logger.debug("Telemetry datapoint rate limit hit for topic %s: %r per %r seconds", - message.topic, triggered_rate_limit[0], triggered_rate_limit[1]) - retry_delay = self._telemetry_dp_rate_limit.minimal_timeout - self._schedule_delayed_retry(message, delay=retry_delay) - return - else: - # For non-telemetry messages, we only need to check the message rate limit - if self._message_rate_limit: - triggered_rate_limit = await self._message_rate_limit.try_consume(1) - if triggered_rate_limit: - logger.debug("Generic message rate limit hit for topic %s: %r per %r seconds", - message.topic, triggered_rate_limit[0], triggered_rate_limit[1]) - retry_delay = self._message_rate_limit.minimal_timeout - self._schedule_delayed_retry(message, delay=retry_delay) - return + await 
self.check_rate_limits_for_message(message) + try: if logger.isEnabledFor(TRACE_LEVEL): logger.trace("Trying to publish topic=%s, payload size=%d, attached future id=%r", @@ -312,6 +277,46 @@ async def shutdown(self): logger.debug("MessageQueue shutdown complete, message queue size: %d", self._queue.qsize()) + async def check_rate_limits_for_message(self, message: MqttPublishMessage): + + message_rate_limit = None + datapoints_rate_limit = None + + is_message_with_telemetry_or_attributes = message.topic in mqtt_topics.TOPICS_WITH_DATAPOINTS_CHECK + is_gateway_message = message.topic in GATEWAY_TOPICS + + if is_gateway_message: + if is_message_with_telemetry_or_attributes: + message_rate_limit = self._gateway_rate_limiter.telemetry_message_rate_limit + datapoints_rate_limit = self._gateway_rate_limiter.telemetry_datapoints_rate_limit + else: + message_rate_limit = self._gateway_rate_limiter.message_rate_limit + else: + if is_message_with_telemetry_or_attributes: + message_rate_limit = self._device_rate_limiter.telemetry_message_rate_limit + datapoints_rate_limit = self._device_rate_limiter.telemetry_datapoints_rate_limit + else: + message_rate_limit = self._device_rate_limiter.message_rate_limit + + retry_delay = None + + if message_rate_limit: + triggered_rate_limit = await message_rate_limit.try_consume(1) + if triggered_rate_limit: + logger.debug("Rate limit hit for topic %s: %r per %r seconds", + message.topic, triggered_rate_limit[0], triggered_rate_limit[1]) + retry_delay = message_rate_limit.minimal_timeout + if datapoints_rate_limit and retry_delay is None: + triggered_rate_limit = await datapoints_rate_limit.try_consume(message.datapoints) + if triggered_rate_limit: + logger.debug("Datapoint rate limit hit for topic %s: %r per %r seconds", + message.topic, triggered_rate_limit[0], triggered_rate_limit[1]) + retry_delay = datapoints_rate_limit.minimal_timeout + + if retry_delay is not None: + self._schedule_delayed_retry(message, delay=retry_delay) + + @staticmethod async def _cancel_tasks(tasks: set[asyncio.Task]): for task in list(tasks): @@ -343,24 +348,22 @@ async def _rate_limit_refill_loop(self): while self._active.is_set(): await asyncio.sleep(1.0) await self._refill_rate_limits() - logger.trace("Rate limits refilled, state: %s", - { - "message_rate_limit": self._message_rate_limit.to_dict() if self._message_rate_limit else None, - "telemetry_rate_limit": self._telemetry_rate_limit.to_dict() if self._telemetry_rate_limit else None, - "telemetry_dp_rate_limit": self._telemetry_dp_rate_limit.to_dict() if self._telemetry_dp_rate_limit else None - }) + logger.trace("Rate limits refilled, state: %s", self._device_rate_limiter) except asyncio.CancelledError: pass async def _refill_rate_limits(self): - for rl in (self._message_rate_limit, self._telemetry_rate_limit, self._telemetry_dp_rate_limit): + for rl in self._device_rate_limiter.values(): if rl: await rl.refill() + if self._gateway_rate_limiter: + for rl in self._gateway_rate_limiter.values(): + if rl: + await rl.refill() def set_gateway_message_adapter(self, message_adapter: GatewayMessageAdapter): self._gateway_adapter = message_adapter - async def print_queue_statistics(self): """ Prints the current statistics of the message queue. 
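
For reference, a minimal wiring sketch for the reworked queue, assuming only what the diff above shows: RateLimiter takes the generic message, telemetry-message and telemetry-datapoint limits positionally (as the updated tests construct it), and check_rate_limits_for_message picks the device or gateway limiter based on the topic. The limit strings, payload and helper name below are illustrative only, not values mandated by the SDK.

import asyncio

from tb_mqtt_client.common.mqtt_message import MqttPublishMessage
from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit
from tb_mqtt_client.common.rate_limit.rate_limiter import RateLimiter
from tb_mqtt_client.constants import mqtt_topics
from tb_mqtt_client.service.message_queue import MessageQueue


async def wire_queue(mqtt_manager, message_adapter):
    # Hypothetical device limits in the existing "capacity:window_seconds" string format:
    # 10 msg/s and 300 msg/min, 50 telemetry messages/s, 500 telemetry datapoints/s.
    device_limiter = RateLimiter(
        RateLimit("10:1,300:60,"),
        RateLimit("50:1,"),
        RateLimit("500:1,"),
    )
    queue = MessageQueue(
        mqtt_manager=mqtt_manager,
        main_stop_event=asyncio.Event(),
        device_rate_limiter=device_limiter,
        message_adapter=message_adapter,
        gateway_rate_limiter=None,  # optional, only consulted for gateway topics
    )
    # Telemetry/attributes topics consume the telemetry message and datapoint limits;
    # any other topic only consumes the generic message limit.
    message = MqttPublishMessage(mqtt_topics.DEVICE_TELEMETRY_TOPIC, b'{"temperature": 42}', qos=1)
    await queue.check_rate_limits_for_message(message)
    return queue
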
diff --git a/tb_mqtt_client/service/mqtt_manager.py b/tb_mqtt_client/service/mqtt_manager.py index 4b4689d..79d587f 100644 --- a/tb_mqtt_client/service/mqtt_manager.py +++ b/tb_mqtt_client/service/mqtt_manager.py @@ -14,7 +14,6 @@ import asyncio import ssl -import threading from asyncio import sleep from contextlib import suppress from time import monotonic @@ -23,17 +22,16 @@ from gmqtt import Client as GMQTTClient, Subscription -from tb_mqtt_client.common.async_utils import await_or_stop, future_map +from tb_mqtt_client.common.async_utils import await_or_stop, future_map, run_coroutine_sync from tb_mqtt_client.common.gmqtt_patch import PatchUtils from tb_mqtt_client.common.logging_utils import get_logger, TRACE_LEVEL from tb_mqtt_client.common.mqtt_message import MqttPublishMessage from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.common.rate_limit.backpressure_controller import BackpressureController from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit +from tb_mqtt_client.common.rate_limit.rate_limiter import RateLimiter from tb_mqtt_client.common.request_id_generator import RPCRequestIdProducer, AttributeRequestIdProducer from tb_mqtt_client.constants import mqtt_topics -from tb_mqtt_client.constants.service_keys import MESSAGES_RATE_LIMIT, TELEMETRY_MESSAGE_RATE_LIMIT, \ - TELEMETRY_DATAPOINTS_RATE_LIMIT from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse from tb_mqtt_client.service.device.handlers.rpc_response_handler import RPCResponseHandler @@ -96,19 +94,12 @@ def __init__( self._backpressure = BackpressureController(self._main_stop_event) self.__rate_limits_handler = rate_limits_handler self.__rate_limits_retrieved = False - self.__rate_limiter: Optional[Dict[str, RateLimit]] = None - self.__is_gateway = False # TODO: determine if this is a gateway or not + self.__gateway_rate_limits_retrieved = False + self.__rate_limiter: Optional[RateLimiter] = None + self.__gateway_rate_limiter: Optional[RateLimiter] = None + self.__is_gateway = False self.__is_waiting_for_rate_limits_publish = True # Start with True to prevent publishing before rate limits are retrieved self._rate_limits_ready_event = asyncio.Event() - self._claiming_future = None - - # TODO: In case of implementing for gateway may be better to use a handler, to discuss - def register_claiming_future(self, future: asyncio.Future): - """ - Register a future that will be set when the claiming process is complete. - This is used to ensure that the MQTT client does not publish messages before the claiming process is done. 
- """ - self._claiming_future = future async def connect(self, host: str, port: int = 1883, username: Optional[str] = None, password: Optional[str] = None, tls: bool = False, @@ -142,7 +133,10 @@ async def _connect_loop(self): await asyncio.sleep(retry_delay) def is_connected(self) -> bool: - return self._client.is_connected and self._connected_event.is_set() and self.__rate_limits_retrieved + return (self._client.is_connected + and self._connected_event.is_set() + and self.__rate_limits_retrieved + and (not self.__is_gateway or self.__gateway_rate_limits_retrieved)) async def disconnect(self): try: @@ -234,7 +228,7 @@ async def subscribe(self, topic: Union[str, Subscription], qos: int = 1) -> asyn subscription = Subscription(topic, qos=qos) if isinstance(topic, str) else topic if self.__rate_limiter: - await self.__rate_limiter[MESSAGES_RATE_LIMIT].consume() + await self.__rate_limiter.message_rate_limit.consume() mid = self._client._connection.subscribe([subscription]) # noqa self._pending_subscriptions[mid] = sub_future return sub_future @@ -243,7 +237,7 @@ async def unsubscribe(self, topic: str) -> asyncio.Future: unsubscribe_future = asyncio.get_event_loop().create_future() unsubscribe_future.uuid = uuid4() if self.__rate_limiter: - await self.__rate_limiter[MESSAGES_RATE_LIMIT].consume() + await self.__rate_limiter.message_rate_limit.consume() mid = self._client._connection.unsubscribe(topic) # noqa self._pending_unsubscriptions[mid] = unsubscribe_future return unsubscribe_future @@ -331,7 +325,7 @@ def _on_disconnect_internal(self, client, reason_code=None, properties=None, exc for rate_limit in self.__rate_limiter.values(): if isinstance(rate_limit, RateLimit): try: - reached_limit = self._run_coroutine_sync(rate_limit.reach_limit, raise_on_timeout=True) + reached_limit = run_coroutine_sync(rate_limit.reach_limit, raise_on_timeout=True) except TimeoutError: logger.warning("Timeout while checking rate limit reaching.") reached_time = 10 # Default to 10 seconds if timeout occurs @@ -345,41 +339,6 @@ def _on_disconnect_internal(self, client, reason_code=None, properties=None, exc if self._on_disconnect_callback: asyncio.create_task(self._on_disconnect_callback()) - def _run_coroutine_sync(self, coro_func, timeout: float = 3.0, raise_on_timeout: bool = False): - """ - Run async coroutine and return its result from a sync function even if event loop is running. 
- :param coro_func: async function with no arguments (like: lambda: some_async_fn()) - :param timeout: max wait time in seconds - :param raise_on_timeout: if True, raise TimeoutError on timeout; otherwise return None - """ - result_container = {} - event = threading.Event() - - async def wrapper(): - try: - result = await coro_func() - result_container['result'] = result - except Exception as e: - result_container['error'] = e - finally: - event.set() - - loop = asyncio.get_running_loop() - loop.create_task(wrapper()) - - completed = event.wait(timeout=timeout) - - if not completed: - logger.warning("Timeout while waiting for coroutine to finish: %s", coro_func) - if raise_on_timeout: - raise TimeoutError(f"Coroutine {coro_func} did not complete in {timeout} seconds.") - return None - - if 'error' in result_container: - raise result_container['error'] - - return result_container.get('result') - def _on_message_internal(self, client, topic: str, payload: bytes, qos, properties): logger.trace("Received message by client %r on topic %s with payload %r, qos %r, properties %r", client, topic, payload, qos, properties) @@ -440,25 +399,23 @@ def _on_unsubscribe_internal(self, client, mid, properties): if future and not future.done(): future.set_result(mid) - async def await_ready(self, timeout: float = 10.0): + async def await_ready(self, timeout: float = 5.0): try: await await_or_stop(self._rate_limits_ready_event.wait(), self._main_stop_event, timeout=timeout) except asyncio.TimeoutError: logger.debug("Waiting for rate limits timed out.") - def set_rate_limits( - self, - message_rate_limit: Union[RateLimit, Dict[str, RateLimit]], - telemetry_message_rate_limit: Optional[RateLimit], - telemetry_dp_rate_limit: Optional[RateLimit] - ): - self.__rate_limiter = { - MESSAGES_RATE_LIMIT: message_rate_limit, - TELEMETRY_MESSAGE_RATE_LIMIT: telemetry_message_rate_limit, - TELEMETRY_DATAPOINTS_RATE_LIMIT: telemetry_dp_rate_limit - } + def set_rate_limits_received(self): self.__rate_limits_retrieved = True self.__is_waiting_for_rate_limits_publish = False + if not self.__is_gateway: + self._rate_limits_ready_event.set() + + def enable_gateway_mode(self): + self.__is_gateway = True + + def set_gateway_rate_limits_received(self): + self.__gateway_rate_limits_retrieved = True self._rate_limits_ready_event.set() async def __request_rate_limits(self): @@ -474,9 +431,9 @@ async def __request_rate_limits(self): await self.publish(mqtt_message, qos=1, force=True) await await_or_stop(response_future, self._main_stop_event, timeout=10) logger.info("Successfully processed rate limits.") - self.__rate_limits_retrieved = True - self.__is_waiting_for_rate_limits_publish = False - self._rate_limits_ready_event.set() + # self.__rate_limits_retrieved = True + # self.__is_waiting_for_rate_limits_publish = False + # self._rate_limits_ready_event.set() except asyncio.TimeoutError: logger.warning("Timeout while waiting for rate limits.") # Keep __is_waiting_for_rate_limits_publish as True to prevent publishing @@ -539,9 +496,3 @@ async def stop(self): if self._client.is_connected: await self._client.disconnect() - - async def failed_messages_reprocessing(self): - """ - Reprocess failed messages that were not acknowledged by the server. - Using internal Gmqtt queue to get messages that were not acknowledged. 
- """ diff --git a/tests/service/device/test_device_client.py b/tests/service/device/test_device_client.py index 8621c2d..0e6a627 100644 --- a/tests/service/device/test_device_client.py +++ b/tests/service/device/test_device_client.py @@ -254,15 +254,6 @@ async def test_send_attribute_request(): client._message_queue.publish.assert_awaited_once() -@pytest.mark.asyncio -async def test_send_rpc_call_timeout(): - client = DeviceClient() - client._mqtt_manager.publish = AsyncMock() - client._rpc_response_handler.register_request = MagicMock(return_value=asyncio.Future()) - with pytest.raises(TimeoutError): - await client.send_rpc_call("reboot", timeout=0.01) - - @pytest.mark.asyncio async def test_disconnect(): client = DeviceClient() @@ -535,7 +526,7 @@ async def test_handle_rate_limit_response_with_partial_data(): response = RPCResponse.build(1, result={"rateLimits": {"messages": "5:1"}}) result = await client._handle_rate_limit_response(response) assert result is True - assert client._messages_rate_limit.has_limit() + assert client._rate_limiter.message_rate_limit.has_limit() assert client.max_payload_size == 65535 diff --git a/tests/service/test_message_queue.py b/tests/service/test_message_queue.py index f6b1511..517a42d 100644 --- a/tests/service/test_message_queue.py +++ b/tests/service/test_message_queue.py @@ -21,6 +21,7 @@ from tb_mqtt_client.common.mqtt_message import MqttPublishMessage from tb_mqtt_client.common.rate_limit.backpressure_controller import BackpressureController from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit +from tb_mqtt_client.common.rate_limit.rate_limiter import RateLimiter from tb_mqtt_client.constants.mqtt_topics import DEVICE_TELEMETRY_TOPIC from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry @@ -46,9 +47,7 @@ async def message_queue(fake_mqtt_manager): mq = MessageQueue( mqtt_manager=fake_mqtt_manager, main_stop_event=main_stop_event, - message_rate_limit=None, - telemetry_rate_limit=None, - telemetry_dp_rate_limit=None, + device_rate_limiter=RateLimiter(RateLimit('0:0,'), RateLimit('0:0,'), RateLimit('0:0,')), message_adapter=message_adapter, max_queue_size=10, batch_collect_max_time_ms=10, @@ -82,9 +81,7 @@ async def test_shutdown_clears_tasks_and_queue(fake_mqtt_manager): q = MessageQueue( mqtt_manager=fake_mqtt_manager, main_stop_event=asyncio.Event(), - message_rate_limit=None, - telemetry_rate_limit=None, - telemetry_dp_rate_limit=None, + device_rate_limiter=RateLimiter(MagicMock(), MagicMock(), MagicMock()), message_adapter=Mock(build_uplink_messages=Mock(return_value=[])), max_queue_size=10, batch_collect_max_time_ms=10, @@ -131,9 +128,7 @@ async def test_try_publish_rate_limit_triggered(): mq = MessageQueue( mqtt_manager=AsyncMock(), main_stop_event=asyncio.Event(), - message_rate_limit=None, - telemetry_rate_limit=limit, - telemetry_dp_rate_limit=None, + device_rate_limiter=RateLimiter(AsyncMock(), limit, AsyncMock()), message_adapter=Mock(build_uplink_messages=Mock(return_value=[])), ) mq._schedule_delayed_retry = AsyncMock() @@ -150,6 +145,7 @@ async def test_try_publish_failure_schedules_retry(message_queue): message = MqttPublishMessage("topic", b"broken", qos=1, delivery_futures=[]) message_queue._mqtt_manager.publish.side_effect = Exception("fail") message_queue._schedule_delayed_retry = AsyncMock() + message_queue._mqtt_manager._backpressure.should_pause.return_value = False await 
message_queue._try_publish(message) await asyncio.sleep(0.2) assert message_queue._schedule_delayed_retry.call_count == 1 @@ -172,9 +168,7 @@ async def test_rate_limit_refill(): mq = MessageQueue( mqtt_manager=AsyncMock(), main_stop_event=asyncio.Event(), - message_rate_limit=rate_limit, - telemetry_rate_limit=rate_limit, - telemetry_dp_rate_limit=rate_limit, + device_rate_limiter=RateLimiter(rate_limit, rate_limit, rate_limit), message_adapter=Mock(), ) await mq._refill_rate_limits() @@ -195,9 +189,7 @@ async def test_wait_for_message_cancel(): mq = MessageQueue( mqtt_manager=AsyncMock(), main_stop_event=stop, - message_rate_limit=None, - telemetry_rate_limit=None, - telemetry_dp_rate_limit=None, + device_rate_limiter=RateLimiter(MagicMock(), MagicMock(), MagicMock()), message_adapter=Mock(), ) @@ -216,9 +208,7 @@ async def test_clear_futures_result_set(): mq = MessageQueue( mqtt_manager=AsyncMock(), main_stop_event=asyncio.Event(), - message_rate_limit=None, - telemetry_rate_limit=None, - telemetry_dp_rate_limit=None, + device_rate_limiter=RateLimiter(MagicMock(), MagicMock(), MagicMock()), message_adapter=Mock(), ) mq._queue.put_nowait(MqttPublishMessage("topic", msg, qos=1, delivery_futures=[fut])) @@ -236,9 +226,7 @@ async def test_message_queue_batching_respects_type_and_size(): mq = MessageQueue( mqtt_manager=mqtt_manager, main_stop_event=stop_event, - message_rate_limit=None, - telemetry_rate_limit=None, - telemetry_dp_rate_limit=None, + device_rate_limiter=RateLimiter(RateLimit('0:0,'), RateLimit('0:0,'), RateLimit('0:0,')), message_adapter=message_adapter, batch_collect_max_time_ms=200, max_queue_size=100, @@ -270,9 +258,7 @@ async def test_schedule_delayed_retry_reenqueues_message(): queue = MessageQueue( mqtt_manager=MagicMock(), main_stop_event=asyncio.Event(), - message_rate_limit=None, - telemetry_rate_limit=None, - telemetry_dp_rate_limit=None, + device_rate_limiter=RateLimiter(MagicMock(), MagicMock(), MagicMock()), message_adapter=MagicMock() ) @@ -303,9 +289,7 @@ async def test_schedule_delayed_retry_does_nothing_if_inactive(): queue = MessageQueue( mqtt_manager=MagicMock(), main_stop_event=asyncio.Event(), - message_rate_limit=None, - telemetry_rate_limit=None, - telemetry_dp_rate_limit=None, + device_rate_limiter=RateLimiter(MagicMock(), MagicMock(), MagicMock()), message_adapter=MagicMock() ) @@ -333,9 +317,7 @@ async def test_schedule_delayed_retry_does_nothing_if_stopped(): queue = MessageQueue( mqtt_manager=MagicMock(), main_stop_event=stop_event, - message_rate_limit=None, - telemetry_rate_limit=None, - telemetry_dp_rate_limit=None, + device_rate_limiter=RateLimiter(MagicMock(), MagicMock(), MagicMock()), message_adapter=MagicMock() ) @@ -360,9 +342,7 @@ async def test_try_publish_telemetry_rate_limit_triggered(): mq = MessageQueue( mqtt_manager=AsyncMock(), main_stop_event=asyncio.Event(), - message_rate_limit=None, - telemetry_rate_limit=telemetry_limit, - telemetry_dp_rate_limit=None, + device_rate_limiter=RateLimiter(MagicMock(), telemetry_limit, MagicMock()), message_adapter=Mock(build_uplink_messages=Mock(return_value=[])) ) mq._schedule_delayed_retry = AsyncMock() @@ -389,9 +369,7 @@ async def test_try_publish_telemetry_dp_rate_limit_triggered(): mq = MessageQueue( mqtt_manager=AsyncMock(), main_stop_event=asyncio.Event(), - message_rate_limit=None, - telemetry_rate_limit=None, - telemetry_dp_rate_limit=dp_limit, + device_rate_limiter=RateLimiter(RateLimit('0:0,'), RateLimit('0:0,'), dp_limit), 
message_adapter=Mock(build_uplink_messages=Mock(return_value=[])) ) mq._schedule_delayed_retry = AsyncMock() @@ -419,9 +397,7 @@ async def test_try_publish_generic_message_rate_limit_triggered(): mq = MessageQueue( mqtt_manager=AsyncMock(), main_stop_event=asyncio.Event(), - message_rate_limit=generic_limit, - telemetry_rate_limit=None, - telemetry_dp_rate_limit=None, + device_rate_limiter=RateLimiter(generic_limit, MagicMock(), MagicMock()), message_adapter=Mock() ) mq._schedule_delayed_retry = AsyncMock() @@ -457,9 +433,7 @@ async def test_batch_breaks_on_elapsed_time(): mq = MessageQueue( mqtt_manager=mqtt_manager, main_stop_event=stop, - message_rate_limit=None, - telemetry_rate_limit=None, - telemetry_dp_rate_limit=None, + device_rate_limiter=RateLimiter(MagicMock(), MagicMock(), MagicMock()), message_adapter=adapter, batch_collect_max_time_ms=1, batch_collect_max_count=1000 @@ -491,9 +465,7 @@ async def test_batch_breaks_on_message_count(): mq = MessageQueue( mqtt_manager=mqtt_manager, main_stop_event=stop, - message_rate_limit=None, - telemetry_rate_limit=None, - telemetry_dp_rate_limit=None, + device_rate_limiter=RateLimiter(MagicMock(), MagicMock(), MagicMock()), message_adapter=adapter, batch_collect_max_time_ms=1000, batch_collect_max_count=2 @@ -528,9 +500,7 @@ async def test_batch_breaks_on_type_mismatch(): mq = MessageQueue( mqtt_manager=mqtt_manager, main_stop_event=stop, - message_rate_limit=None, - telemetry_rate_limit=None, - telemetry_dp_rate_limit=None, + device_rate_limiter=RateLimiter(MagicMock(), MagicMock(), MagicMock()), message_adapter=adapter, batch_collect_max_time_ms=1000, batch_collect_max_count=10 @@ -573,9 +543,7 @@ async def test_batch_breaks_on_size_threshold(): mq = MessageQueue( mqtt_manager=mqtt_manager, main_stop_event=stop, - message_rate_limit=None, - telemetry_rate_limit=None, - telemetry_dp_rate_limit=None, + device_rate_limiter=RateLimiter(MagicMock(), MagicMock(), MagicMock()), message_adapter=adapter, batch_collect_max_time_ms=1000, batch_collect_max_count=10 @@ -605,9 +573,7 @@ async def test_batch_skips_bytes_payload(): mq = MessageQueue( mqtt_manager=mqtt_manager, main_stop_event=stop, - message_rate_limit=None, - telemetry_rate_limit=None, - telemetry_dp_rate_limit=None, + device_rate_limiter=RateLimiter(MagicMock(), MagicMock(), MagicMock()), message_adapter=JsonMessageAdapter(), batch_collect_max_time_ms=1000, batch_collect_max_count=1000 diff --git a/tests/service/test_message_splitter.py b/tests/service/test_message_splitter.py index 4c4b231..a82e462 100644 --- a/tests/service/test_message_splitter.py +++ b/tests/service/test_message_splitter.py @@ -22,7 +22,7 @@ from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.service.device.message_adapter import JsonMessageAdapter -from tb_mqtt_client.service.message_splitter import MessageSplitter +from tb_mqtt_client.service.device.message_splitter import MessageSplitter @pytest.fixture @@ -54,7 +54,7 @@ def test_single_small_attributes_pass_through(splitter): # Negative test: invalid payload size and datapoints def test_invalid_config_defaults(): splitter = MessageSplitter(max_payload_size=-10, max_datapoints=-100) - assert splitter.max_payload_size == 65535 + assert splitter.max_payload_size == 55000 assert splitter.max_datapoints == 0 @@ -86,33 +86,44 @@ def test_malformed_message_handling(splitter): # Negative test: builder fails on build() 
-@patch("tb_mqtt_client.service.message_splitter.DeviceUplinkMessageBuilder") +@patch("tb_mqtt_client.service.device.message_splitter.DeviceUplinkMessageBuilder") def test_builder_failure_during_split_raises(mock_builder_class): - entry = MagicMock() - entry.size = 10 + async def run_test(): + entry = MagicMock() + entry.size = 10 - message = MagicMock() - message.device_name = "dev" - message.device_profile = "prof" - message.has_timeseries.return_value = True - message.timeseries = {"temp": [entry] * 4} - message.get_delivery_futures.return_value = [] - message.attributes_datapoint_count.return_value = 0 - message.timeseries_datapoint_count.return_value = 4 - message.size = 50 + message = MagicMock() + message.device_name = "dev" + message.device_profile = "prof" + message.has_timeseries.return_value = True + message.timeseries = {"temp": [entry] * 5} + message.get_delivery_futures.return_value = [] + message.attributes_datapoint_count.return_value = 0 + message.timeseries_datapoint_count.return_value = 5 + message.size = 100 - builder_instance = MagicMock() - builder_instance.set_device_name.return_value = builder_instance - builder_instance.set_device_profile.return_value = builder_instance - builder_instance.add_timeseries.return_value = None - builder_instance._timeseries = [entry] - builder_instance.build.side_effect = RuntimeError("build failed") - mock_builder_class.return_value = builder_instance + builder_instance = MagicMock() + builder_instance.set_device_name.return_value = builder_instance + builder_instance.set_device_profile.return_value = builder_instance + builder_instance.set_main_ts.return_value = builder_instance - splitter = MessageSplitter(max_payload_size=20, max_datapoints=2) + builder_instance._timeseries = [] - with pytest.raises(RuntimeError, match="build failed"): - splitter.split_timeseries([message]) + def add_ts_side_effect(ts): + builder_instance._timeseries.append(ts) + + builder_instance.add_timeseries.side_effect = add_ts_side_effect + builder_instance.add_delivery_futures.return_value = None + builder_instance.build.side_effect = RuntimeError("build failed") + + mock_builder_class.return_value = builder_instance + + splitter = MessageSplitter(max_payload_size=20, max_datapoints=2) + + with pytest.raises(RuntimeError, match="build failed"): + splitter.split_timeseries([message]) + + asyncio.run(run_test()) # Property validation @@ -121,7 +132,7 @@ def test_payload_setter_validation(): s.max_payload_size = 12345 assert s.max_payload_size == 12345 s.max_payload_size = 0 - assert s.max_payload_size == 65535 + assert s.max_payload_size == 55000 def test_datapoint_setter_validation(): @@ -180,7 +191,7 @@ async def test_split_attributes_different_devices_not_grouped(): fut.set_result(PublishResult("test/topic", 1, 1, 100, 0)) -@patch("tb_mqtt_client.service.message_splitter.future_map.register") +@patch("tb_mqtt_client.service.device.message_splitter.future_map.register") @pytest.mark.asyncio async def test_split_timeseries_registers_futures_and_batches_correctly(mock_register): splitter = MessageSplitter(max_payload_size=100, max_datapoints=2) diff --git a/tests/service/test_mqtt_manager.py b/tests/service/test_mqtt_manager.py index 0647c2a..b052673 100644 --- a/tests/service/test_mqtt_manager.py +++ b/tests/service/test_mqtt_manager.py @@ -22,7 +22,7 @@ from tb_mqtt_client.common.mqtt_message import MqttPublishMessage from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit -from 
tb_mqtt_client.constants.service_keys import MESSAGES_RATE_LIMIT +from tb_mqtt_client.common.rate_limit.rate_limiter import RateLimiter from tb_mqtt_client.service.device.handlers.rpc_response_handler import RPCResponseHandler from tb_mqtt_client.service.device.message_adapter import MessageAdapter from tb_mqtt_client.service.mqtt_manager import MQTTManager, IMPLEMENTATION_SPECIFIC_ERROR, QUOTA_EXCEEDED @@ -188,8 +188,7 @@ async def test_await_ready_timeout(setup_manager): @pytest.mark.asyncio async def test_set_rate_limits_allows_ready(setup_manager): manager, *_ = setup_manager - mock_limit = MagicMock() - manager.set_rate_limits(mock_limit, None, None) + manager.set_rate_limits_received() assert manager._rate_limits_ready_event.is_set() @@ -224,7 +223,8 @@ async def test_subscribe_adds_future(setup_manager): manager._client._connection.subscribe.return_value = 42 mock_rate_limit = AsyncMock() - setattr(manager, "_MQTTManager__rate_limiter", {MESSAGES_RATE_LIMIT: mock_rate_limit}) + rate_limiter = RateLimiter(mock_rate_limit, MagicMock(), MagicMock()) + setattr(manager, "_MQTTManager__rate_limiter", rate_limiter) fut = await manager.subscribe("topic", qos=1) await asyncio.sleep(0.1) @@ -241,7 +241,8 @@ async def test_unsubscribe_adds_future(setup_manager): manager._client._connection.unsubscribe.return_value = 77 mock_rate_limit = AsyncMock() - setattr(manager, "_MQTTManager__rate_limiter", {MESSAGES_RATE_LIMIT: mock_rate_limit}) + rate_limiter = RateLimiter(mock_rate_limit, MagicMock(), MagicMock()) + setattr(manager, "_MQTTManager__rate_limiter", rate_limiter) fut = await manager.unsubscribe("topic") await asyncio.sleep(0.1) @@ -343,6 +344,7 @@ async def test_request_rate_limits_timeout(setup_manager): with patch("tb_mqtt_client.entities.data.rpc_request.RPCRequest.build", return_value=req_mock): await manager._MQTTManager__request_rate_limits() + manager.set_rate_limits_received() assert manager._rate_limits_ready_event.is_set() @@ -360,8 +362,10 @@ async def test_match_topic_exact_match_and_failures(): assert not MQTTManager._match_topic("a/+/c", "a/x") +@patch("tb_mqtt_client.service.mqtt_manager.run_coroutine_sync") @pytest.mark.asyncio -async def test_disconnect_reason_code_142_triggers_special_flow(setup_manager): +async def test_disconnect_reason_code_142_triggers_special_flow(mock_run_sync, setup_manager): + mock_run_sync.return_value = (None, 1, 1) manager, *_ = setup_manager manager._client = MagicMock() manager._backpressure = MagicMock() From 18af0030a9039c51e0797856b19a09d0f35e6ee3 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Fri, 25 Jul 2025 13:56:24 +0300 Subject: [PATCH 61/74] Adjusted grouping timeseries by ts in uplink messages for gateway part of SDK --- .../gateway/gateway_uplink_message.py | 3 +- tb_mqtt_client/service/device/client.py | 2 -- .../service/gateway/message_adapter.py | 31 +++++++++++++------ 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/tb_mqtt_client/entities/gateway/gateway_uplink_message.py b/tb_mqtt_client/entities/gateway/gateway_uplink_message.py index 4a49d86..e3b1947 100644 --- a/tb_mqtt_client/entities/gateway/gateway_uplink_message.py +++ b/tb_mqtt_client/entities/gateway/gateway_uplink_message.py @@ -146,8 +146,7 @@ def add_timeseries(self, timeseries: Union[TimeseriesEntry, List[TimeseriesEntry self._timeseries[0].append(entry) else: self._timeseries[0] = [entry] - for timeseries_entry in timeseries: - self.__size += timeseries_entry.size + self.__size += entry.size return self def add_delivery_futures(self, futures: 
Union[ diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index 9ac5f74..2ff4c91 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -18,8 +18,6 @@ from string import ascii_uppercase, digits from typing import Callable, Awaitable, Optional, Dict, Any, Union, List -from orjson import dumps - from tb_mqtt_client.common.async_utils import await_or_stop from tb_mqtt_client.common.config_loader import DeviceConfig from tb_mqtt_client.common.logging_utils import get_logger diff --git a/tb_mqtt_client/service/gateway/message_adapter.py b/tb_mqtt_client/service/gateway/message_adapter.py index 2d5dc77..b080aa1 100644 --- a/tb_mqtt_client/service/gateway/message_adapter.py +++ b/tb_mqtt_client/service/gateway/message_adapter.py @@ -372,12 +372,25 @@ def pack_attributes(msg: GatewayUplinkMessage) -> Dict[str, Any]: return {attr.key: attr.value for attr in msg.attributes} @staticmethod - def pack_timeseries(msg: GatewayUplinkMessage) -> List[Dict[str, Any]]: - now_ts = int(datetime.now(UTC).timestamp() * 1000) - packed = [ - {"ts": entry.ts or now_ts, "values": {entry.key: entry.value}} - for entry in chain.from_iterable(msg.timeseries.values()) - ] - logger.trace("Packed %d timeseries entry(s)", len(packed)) - - return packed + def pack_timeseries(msg: 'GatewayUplinkMessage') -> Union[Dict[str, Any], List[Dict[str, Any]]]: + entries = [e for entries in msg.timeseries.values() for e in entries] + if not entries: + return {} + + all_ts_none = True + for e in entries: + if e.ts is not None: + all_ts_none = False + break + + if all_ts_none: + result = {e.key: e.value for e in entries} + return [{"ts": msg.main_ts, "values": result}] if msg.main_ts is not None else result + + now_ts = msg.main_ts if msg.main_ts is not None else int(datetime.now(UTC).timestamp() * 1000) + grouped = defaultdict(dict) + for e in entries: + ts = e.ts if e.ts is not None else now_ts + grouped[ts][e.key] = e.value + + return [{"ts": ts, "values": values} for ts, values in grouped.items()] From 45298b241977f5fdb66b15de0b7282238fc5df6a Mon Sep 17 00:00:00 2001 From: imbeacon Date: Thu, 31 Jul 2025 12:24:29 +0300 Subject: [PATCH 62/74] Optimized message service, rate limits processing, sending retry for qos 1 --- examples/gateway/load.py | 19 +- examples/gateway/send_timeseries.py | 85 +- tb_mqtt_client/common/exceptions.py | 2 +- tb_mqtt_client/common/gmqtt_patch.py | 29 +- .../common/install_package_utils.py | 3 +- tb_mqtt_client/common/mqtt_message.py | 11 +- tb_mqtt_client/common/queue.py | 97 ++ .../rate_limit/backpressure_controller.py | 26 +- .../common/rate_limit/rate_limit.py | 31 +- .../gateway/device_disconnect_message.py | 14 - .../entities/gateway/gateway_claim_request.py | 13 - .../gateway/gateway_uplink_message.py | 3 + tb_mqtt_client/service/device/client.py | 6 +- .../service/device/message_adapter.py | 9 +- tb_mqtt_client/service/gateway/client.py | 12 +- .../service/gateway/device_session.py | 1 + .../service/gateway/message_adapter.py | 11 +- .../service/gateway/message_sender.py | 6 +- .../service/gateway/multiplex_publisher.py | 14 - tb_mqtt_client/service/message_queue.py | 379 ------- tb_mqtt_client/service/message_service.py | 405 ++++++++ tb_mqtt_client/service/mqtt_manager.py | 101 +- tests/common/test_backpressure_controller.py | 5 - tests/common/test_config_loader.py | 103 ++ tests/service/device/test_device_client.py | 6 +- tests/service/test_message_queue.py | 596 ----------- 
tests/service/test_message_service.py | 973 ++++++++++++++++++ tests/test_async_utils.py | 68 ++ 28 files changed, 1884 insertions(+), 1144 deletions(-) create mode 100644 tb_mqtt_client/common/queue.py delete mode 100644 tb_mqtt_client/service/gateway/multiplex_publisher.py delete mode 100644 tb_mqtt_client/service/message_queue.py create mode 100644 tb_mqtt_client/service/message_service.py create mode 100644 tests/common/test_config_loader.py delete mode 100644 tests/service/test_message_queue.py create mode 100644 tests/service/test_message_service.py create mode 100644 tests/test_async_utils.py diff --git a/examples/gateway/load.py b/examples/gateway/load.py index 29484c3..573b5db 100644 --- a/examples/gateway/load.py +++ b/examples/gateway/load.py @@ -34,12 +34,12 @@ logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) # --- Constants --- -NUM_DEVICES = 2 -BATCH_SIZE = 100 -MAX_PENDING = 100 +NUM_DEVICES = 10 +BATCH_SIZE = 1000 +MAX_PENDING = 10 FUTURE_TIMEOUT = 1.0 DEVICE_PREFIX = "perf-test-device" -WAIT_FOR_PUBLISH = False +WAIT_FOR_PUBLISH = False # Set to True if you want to wait for publish confirmation (can slow down the test, because each message will wait for confirmation) # --- Test logic --- async def send_batch(client: GatewayClient, session: DeviceSession) -> List[asyncio.Future]: @@ -54,10 +54,17 @@ async def send_batch(client: GatewayClient, session: DeviceSession) -> List[asyn async def wait_for_futures(futures: List[asyncio.Future]) -> int: delivered = 0 - done, _ = await asyncio.wait(futures, timeout=FUTURE_TIMEOUT, return_when=asyncio.ALL_COMPLETED) + if isinstance(futures, list) and futures and isinstance(futures[0], asyncio.Future): + done, _ = await asyncio.wait(futures, timeout=FUTURE_TIMEOUT, return_when=asyncio.ALL_COMPLETED) + else: + done = futures + for fut in done: try: - res = fut.result() + if isinstance(fut, asyncio.Future): + res = fut.result() + else: + res = fut if isinstance(res, PublishResult) and res.is_successful(): delivered += res.datapoints_count except Exception as e: diff --git a/examples/gateway/send_timeseries.py b/examples/gateway/send_timeseries.py index 4a7b4b8..dff7719 100644 --- a/examples/gateway/send_timeseries.py +++ b/examples/gateway/send_timeseries.py @@ -13,6 +13,7 @@ # limitations under the License. 
import asyncio +import signal from time import time from tb_mqtt_client.common.config_loader import GatewayConfig @@ -44,36 +45,60 @@ async def main(): return logger.info("Device connected successfully: %s", device_name) - - # Send time series as raw dictionary - raw_timeseries = { - "temperature": 25.5, - "humidity": 60 - } - logger.info("Sending raw timeseries: %s", raw_timeseries) - await client.send_device_timeseries(device_session=device_session, data=raw_timeseries, wait_for_publish=True) - logger.info("Raw timeseries sent successfully.") - - # Send time series as list of dictionaries - ts = int(time() * 1000) - list_timeseries = [ - {"ts": ts, "values": {"temperature": 26.0, "humidity": 65}}, - {"ts": ts - 1000, "values": {"temperature": 26.5, "humidity": 70}} - ] - logger.info("Sending list of timeseries: %s", list_timeseries) - await client.send_device_timeseries(device_session=device_session, data=list_timeseries, wait_for_publish=True) - logger.info("List of timeseries sent successfully.") - - # Send time series as TimeseriesEntry objects - timeseries_entries = [ - TimeseriesEntry(ts=ts, key="temperature", value=27.0), - TimeseriesEntry(ts=ts, key="humidity", value=75), - TimeseriesEntry(ts=ts - 1000, key="temperature", value=28.0), - TimeseriesEntry(ts=ts - 1000, key="humidity", value=80) - ] - logger.info("Sending TimeseriesEntry objects: %s", timeseries_entries) - await client.send_device_timeseries(device_session=device_session, data=timeseries_entries, wait_for_publish=True) - logger.info("TimeseriesEntry objects sent successfully.") + stop_event = asyncio.Event() + + def _shutdown_handler(): + stop_event.set() + asyncio.gather(client.stop(), return_exceptions=True) + + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, _shutdown_handler) # noqa + except NotImplementedError: + # Windows compatibility fallback + signal.signal(sig, lambda *_: _shutdown_handler()) # noqa + + loop_counter = 0 + + while not stop_event.is_set(): + loop_counter += 1 + logger.info("Sending timeseries data, iteration: %d", loop_counter) + + # Send time series as raw dictionary + raw_timeseries = { + "temperature": 25.5, + "humidity": 60 + } + logger.info("Sending raw timeseries: %s", raw_timeseries) + await client.send_device_timeseries(device_session=device_session, data=raw_timeseries, wait_for_publish=True) + logger.info("Raw timeseries sent successfully.") + + # Send time series as list of dictionaries + ts = int(time() * 1000) + list_timeseries = [ + {"ts": ts, "values": {"temperature": 26.0, "humidity": 65}}, + {"ts": ts - 1000, "values": {"temperature": 26.5, "humidity": 70}} + ] + logger.info("Sending list of timeseries: %s", list_timeseries) + await client.send_device_timeseries(device_session=device_session, data=list_timeseries, wait_for_publish=True) + logger.info("List of timeseries sent successfully.") + + # Send time series as TimeseriesEntry objects + ts = int(time() * 1000) + timeseries_entries = [ + TimeseriesEntry(key="temperature%i" % i, value=loop_counter, ts=ts) for i in range(20) + ] + logger.info("Sending TimeseriesEntry objects: %s", timeseries_entries) + await client.send_device_timeseries(device_session=device_session, data=timeseries_entries, wait_for_publish=True) + logger.info("TimeseriesEntry objects sent successfully.") + + + try: + logger.info("Waiting before next iteration...") + await asyncio.wait_for(stop_event.wait(), timeout=1) + except asyncio.TimeoutError: + logger.info("Going to next 
iteration...") await client.stop() diff --git a/tb_mqtt_client/common/exceptions.py b/tb_mqtt_client/common/exceptions.py index a4a3517..4fc2237 100644 --- a/tb_mqtt_client/common/exceptions.py +++ b/tb_mqtt_client/common/exceptions.py @@ -16,7 +16,7 @@ import logging from typing import Callable, Dict, List, Optional, Type -logger = logging.getLogger("tb_sdk") +logger = logging.getLogger(__name__) ExceptionCallback = Callable[[BaseException, Optional[dict]], None] diff --git a/tb_mqtt_client/common/gmqtt_patch.py b/tb_mqtt_client/common/gmqtt_patch.py index d837ae5..592f712 100644 --- a/tb_mqtt_client/common/gmqtt_patch.py +++ b/tb_mqtt_client/common/gmqtt_patch.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. + import asyncio import heapq import struct from collections import defaultdict -from time import monotonic from typing import Callable, Tuple, Optional from gmqtt import Client @@ -102,7 +102,7 @@ class PatchUtils: 154: "Wildcard Subscriptions not supported" } - def __init__(self, client: Optional[Client], stop_event: asyncio.Event, retry_interval: int = 1): + def __init__(self, client: Optional[Client], stop_event: asyncio.Event, retry_interval: int = 15): """ Initialize PatchUtils with a client and retry interval. @@ -384,7 +384,6 @@ async def _retry_loop(self): self.patch_storage() try: while not self._stop_event.is_set(): - retry_list = [] current_tm = asyncio.get_event_loop().time() for _ in range(100): @@ -404,30 +403,16 @@ async def _retry_loop(self): if msg is None: break - retry_list.append(msg) + tm, mid, mqtt_msg = msg - for (tm, mid, mqtt_msg) in retry_list: if current_tm - tm > self.retry_interval and self.client.is_connected: mqtt_msg.dup = True logger.error("Resending PUBLISH message with mid=%r, topic=%s", mid, mqtt_msg.topic) - protocol = self.client._connection._protocol - - if protocol: - try: - mid, rebuilt = PublishPacket.build_package( - message=mqtt_msg, - protocol=protocol, - mid=mid - ) - self.client._connection.send_package(rebuilt) - logger.trace("Retransmitted message mid=%r", mid) - except Exception as e: - logger.warning("Error during retransmission: %s", e) - else: - logger.warning("Cannot retransmit, MQTT protocol unavailable.") - - heapq.heappush(self.client._persistent_storage._queue, (tm, mid, mqtt_msg)) + try: + await self.client.put_retry_message(mqtt_msg) # noqa This method sets in message service to the client + except AttributeError as e: + logger.trace("Failed to resend message with mid=%r: %s", mid, e) await asyncio.sleep(self.retry_interval) except asyncio.CancelledError: diff --git a/tb_mqtt_client/common/install_package_utils.py b/tb_mqtt_client/common/install_package_utils.py index 1227418..e2f0438 100644 --- a/tb_mqtt_client/common/install_package_utils.py +++ b/tb_mqtt_client/common/install_package_utils.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from sys import executable from subprocess import check_call, CalledProcessError, DEVNULL +from sys import executable + from pkg_resources import get_distribution, DistributionNotFound diff --git a/tb_mqtt_client/common/mqtt_message.py b/tb_mqtt_client/common/mqtt_message.py index 384b634..394b9ae 100644 --- a/tb_mqtt_client/common/mqtt_message.py +++ b/tb_mqtt_client/common/mqtt_message.py @@ -15,11 +15,12 @@ from asyncio import Future from time import time from typing import Union, Optional -from uuid import uuid4 +from uuid import uuid4, UUID from gmqtt import Message from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.constants import mqtt_topics from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage @@ -39,18 +40,24 @@ def __init__(self, datapoints: int = 0, delivery_futures = None, main_ts: Optional[int] = None, + original_payload = None, **kwargs): """ Initialize the MqttMessage with topic, payload, QoS, retain flag, and datapoints. """ + self.uuid: UUID = uuid4() + self.payload_size = 0 self.prepared = False - self.payload = payload + self.original_payload = original_payload if original_payload is not None else payload self.main_ts = main_ts if main_ts is not None else int(time() * 1000) if isinstance(payload, bytes): super().__init__(topic, payload, qos, retain) else: + self.payload = self.original_payload payload.set_main_ts(self.main_ts) self.topic = topic + self.is_service_message = self.topic not in mqtt_topics.TOPICS_WITH_DATAPOINTS_CHECK + self.is_device_message = not isinstance(original_payload, GatewayUplinkMessage) self.qos = qos if self.qos < 0 or self.qos > 1: logger.warning(f"Invalid QoS {self.qos} for topic {topic}, using default QoS 1") diff --git a/tb_mqtt_client/common/queue.py b/tb_mqtt_client/common/queue.py new file mode 100644 index 0000000..3bc951a --- /dev/null +++ b/tb_mqtt_client/common/queue.py @@ -0,0 +1,97 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
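The new is_service_message / is_device_message flags on MqttPublishMessage above drive the queue routing and rate-limit selection later in this patch; a condensed, standalone illustration of the same classification (classify is a hypothetical helper, not a method from the patch):

from tb_mqtt_client.constants import mqtt_topics
from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage


def classify(topic: str, original_payload: object) -> tuple:
    # Topics outside the datapoint-checked set are treated as service traffic.
    is_service_message = topic not in mqtt_topics.TOPICS_WITH_DATAPOINTS_CHECK
    # Anything that is not a GatewayUplinkMessage counts as a device message.
    is_device_message = not isinstance(original_payload, GatewayUplinkMessage)
    return is_service_message, is_device_message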
+ +import asyncio +from collections import deque + + +class AsyncDeque: + def __init__(self, maxlen): + self._deque = deque(maxlen=maxlen) + self._cond = asyncio.Condition() + self._in_queue = set() + + async def put(self, item): + if item.uuid in self._in_queue: + return + async with self._cond: + self._deque.append(item) + self._in_queue.add(item.uuid) + self._cond.notify() + + async def extend(self, items): + async with self._cond: + for item in items: + if item.uuid not in self._in_queue: + self._deque.append(item) + self._in_queue.add(item.uuid) + self._cond.notify() + + async def put_left(self, item): + if item.uuid in self._in_queue: + return + async with self._cond: + self._deque.appendleft(item) + self._in_queue.add(item.uuid) + self._cond.notify() + + async def extend_left(self, items): + async with self._cond: + for item in reversed(items): + if item.uuid not in self._in_queue: + self._deque.appendleft(item) + self._in_queue.add(item.uuid) + self._cond.notify() + + async def get(self): + async with self._cond: + while not self._deque: + await self._cond.wait() + item = self._deque.popleft() + self._in_queue.discard(item.uuid) + return item + + async def peek(self): + async with self._cond: + while not self._deque: + await self._cond.wait() + return self._deque[0] + + async def peek_batch(self, max_count: int): + async with self._cond: + while not self._deque: + await self._cond.wait() + return list(self._deque)[:max_count] + + async def pop_n(self, n: int): + async with self._cond: + items = [] + for _ in range(min(n, len(self._deque))): + item = self._deque.popleft() + self._in_queue.discard(item.uuid) + items.append(item) + return items + + async def reinsert_front(self, item): + async with self._cond: + if item.uuid not in self._in_queue: + self._deque.appendleft(item) + self._in_queue.add(item.uuid) + self._cond.notify() + + def is_empty(self): + return not self._deque + + def size(self): + return len(self._deque) \ No newline at end of file diff --git a/tb_mqtt_client/common/rate_limit/backpressure_controller.py b/tb_mqtt_client/common/rate_limit/backpressure_controller.py index 8436a12..639329d 100644 --- a/tb_mqtt_client/common/rate_limit/backpressure_controller.py +++ b/tb_mqtt_client/common/rate_limit/backpressure_controller.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
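A short usage sketch for the AsyncDeque added above: items only need a uuid attribute (as MqttPublishMessage now has), duplicates are filtered by that uuid, and peek_batch/pop_n/reinsert_front support the batching and retry flows used later in message_service.py. The Item dataclass here is a stand-in.

import asyncio
from dataclasses import dataclass, field
from uuid import UUID, uuid4

from tb_mqtt_client.common.queue import AsyncDeque


@dataclass
class Item:
    value: int
    uuid: UUID = field(default_factory=uuid4)


async def demo() -> None:
    queue = AsyncDeque(maxlen=100)
    await queue.extend([Item(1), Item(2), Item(3)])

    # Inspect up to two items without removing them, then drop them once handled.
    batch = await queue.peek_batch(2)
    handled = await queue.pop_n(len(batch))

    # A failed item can be pushed back to the front; deduplication by uuid
    # prevents it from being queued twice.
    await queue.reinsert_front(handled[0])
    print(queue.size())  # -> 2


asyncio.run(demo())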
+ +import asyncio from asyncio import Event from datetime import datetime, timedelta, UTC -from typing import Optional +from typing import Optional, List from tb_mqtt_client.common.logging_utils import get_logger @@ -29,6 +31,9 @@ def __init__(self, main_stop_event: Event): self._consecutive_quota_exceeded = 0 self._last_quota_exceeded = datetime.now(UTC) self._max_backoff_seconds = 3600 # 1 hour + self._can_process_messages_events: List[asyncio.Event] = [] + logger.debug("BackpressureController initialized with default pause duration of %s seconds", + self._default_pause_duration.total_seconds()) def notify_quota_exceeded(self, delay_seconds: Optional[int] = None): if self.__main_stop_event.is_set(): @@ -86,14 +91,25 @@ def should_pause(self) -> bool: # Reset pause state self._pause_until = None logger.info("Backpressure released, resuming publishing") + for event in self._can_process_messages_events: + if not event.is_set(): + event.set() + logger.debug("Set can-process event %s", event) return False - def pause_for(self, seconds: int): - self._pause_until = datetime.now(UTC) + timedelta(seconds=seconds) - logger.info("Manually pausing publishing for %d seconds", seconds) - def clear(self): if self._pause_until is not None: logger.info("Clearing backpressure pause") self._pause_until = None self._consecutive_quota_exceeded = 0 + + def register_can_process_event(self, event: Event): + """ + Register an event that will be set when the controller can process messages again. + This is useful for other components to wait until the backpressure is lifted. + """ + if not isinstance(event, Event): + raise ValueError("Expected an asyncio.Event instance") + self._can_process_messages_events.append(event) + logger.debug("Registered a new can-process event, total events: %d", + len(self._can_process_messages_events)) diff --git a/tb_mqtt_client/common/rate_limit/rate_limit.py b/tb_mqtt_client/common/rate_limit/rate_limit.py index f02c862..d1b0c73 100644 --- a/tb_mqtt_client/common/rate_limit/rate_limit.py +++ b/tb_mqtt_client/common/rate_limit/rate_limit.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. + +import asyncio import os import logging from asyncio import Lock @@ -69,6 +71,9 @@ def __init__(self, rate_limit: str, name: str = None, percentage: int = DEFAULT_ self._minimal_limit = float('inf') self.__reached_index = 0 self.__reached_index_time = 0 + self.__required_tokens = None + self._required_tokens_duration = None + self.required_tokens_ready = asyncio.Event() self._parse_string(rate_limit) @@ -112,8 +117,13 @@ async def refill(self): if self._no_limit: return async with self._lock: - for bucket in self._rate_buckets.values(): + for duration, bucket in self._rate_buckets.items(): bucket.refill() + if duration == self._required_tokens_duration: + if bucket.tokens >= self.__required_tokens: + self.required_tokens_ready.set() + self._required_tokens_duration = None + self.__required_tokens = None async def try_consume(self, amount=1): """ @@ -197,3 +207,22 @@ async def set_limit(self, rate_limit: str, percentage: int = DEFAULT_RATE_LIMIT_ self._minimal_limit = float('inf') self.percentage = percentage self._parse_string(rate_limit) + + def set_required_tokens(self, duration, expected_tokens): + """ + Set the required tokens for next message processing. + After this call, `required_tokens_ready` event will be set when + the required tokens are available in the rate limit buckets. 
+ """ + self.__required_tokens = expected_tokens + self._required_tokens_duration = duration + + def clear_required_tokens_event(self): + """ + Clear the required tokens ready event. + This should be called when message processing is done + """ + self.required_tokens_ready.clear() + + +EMPTY_RATE_LIMIT = RateLimit("0:0,", name="Empty Rate Limit") diff --git a/tb_mqtt_client/entities/gateway/device_disconnect_message.py b/tb_mqtt_client/entities/gateway/device_disconnect_message.py index 2a81ab6..2e5bae1 100644 --- a/tb_mqtt_client/entities/gateway/device_disconnect_message.py +++ b/tb_mqtt_client/entities/gateway/device_disconnect_message.py @@ -13,20 +13,6 @@ # limitations under the License. -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - from dataclasses import dataclass from typing import Dict diff --git a/tb_mqtt_client/entities/gateway/gateway_claim_request.py b/tb_mqtt_client/entities/gateway/gateway_claim_request.py index 1646cfa..12a1143 100644 --- a/tb_mqtt_client/entities/gateway/gateway_claim_request.py +++ b/tb_mqtt_client/entities/gateway/gateway_claim_request.py @@ -12,19 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
from dataclasses import dataclass from typing import Optional, Dict, Any, Union diff --git a/tb_mqtt_client/entities/gateway/gateway_uplink_message.py b/tb_mqtt_client/entities/gateway/gateway_uplink_message.py index e3b1947..2b041c7 100644 --- a/tb_mqtt_client/entities/gateway/gateway_uplink_message.py +++ b/tb_mqtt_client/entities/gateway/gateway_uplink_message.py @@ -176,3 +176,6 @@ def build(self) -> GatewayUplinkMessage: size=self.__size, main_ts=self._main_ts ) + + def __len__(self) -> int: + return self.__size diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index 2ff4c91..c42f9c5 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -45,7 +45,7 @@ from tb_mqtt_client.service.device.handlers.rpc_requests_handler import RPCRequestsHandler from tb_mqtt_client.service.device.handlers.rpc_response_handler import RPCResponseHandler from tb_mqtt_client.service.device.message_adapter import JsonMessageAdapter, MessageAdapter -from tb_mqtt_client.service.message_queue import MessageQueue +from tb_mqtt_client.service.message_service import MessageService from tb_mqtt_client.service.mqtt_manager import MQTTManager logger = get_logger(__name__) @@ -68,7 +68,7 @@ def __init__(self, config: Optional[Union[DeviceConfig, Dict]] = None): super().__init__(self._config.host, self._config.port, client_id) - self._message_queue: Optional[MessageQueue] = None + self._message_queue: Optional[MessageService] = None self._message_adapter: MessageAdapter = JsonMessageAdapter(1000, 1) # Will be updated after connection established @@ -139,7 +139,7 @@ async def connect(self): self._message_adapter = JsonMessageAdapter(self.max_payload_size, self._rate_limiter.telemetry_datapoints_rate_limit.minimal_limit) - self._message_queue = MessageQueue( + self._message_queue = MessageService( mqtt_manager=self._mqtt_manager, main_stop_event=self._stop_event, device_rate_limiter=self._rate_limiter, diff --git a/tb_mqtt_client/service/device/message_adapter.py b/tb_mqtt_client/service/device/message_adapter.py index d528f54..37726e5 100644 --- a/tb_mqtt_client/service/device/message_adapter.py +++ b/tb_mqtt_client/service/device/message_adapter.py @@ -248,7 +248,8 @@ def build_uplink_messages(self, messages: List[MqttPublishMessage]) -> List[Mqtt device_groups[payload.device_name].append(mqtt_msg) logger.trace("Queued DeviceUplinkMessage for device='%s'", payload.device_name) else: - logger.warning("Unsupported payload type '%s', skipping", type(payload).__name__) + result.append(mqtt_msg) + logger.debug("Non-DeviceUplinkMessage found, sending as is: %s", type(payload).__name__) for device_name, group_msgs in device_groups.items(): telemetry_msgs = [m for m in group_msgs if m.payload.has_timeseries()] @@ -269,7 +270,8 @@ def build_uplink_messages(self, messages: List[MqttPublishMessage]) -> List[Mqtt qos=qos, datapoints=count, delivery_futures=child_futures, - main_ts=ts_batch.main_ts + main_ts=ts_batch.main_ts, + original_payload=ts_batch ) result.append(mqtt_msg) built_child_messages.append(mqtt_msg) @@ -292,7 +294,8 @@ def build_uplink_messages(self, messages: List[MqttPublishMessage]) -> List[Mqtt qos=qos, datapoints=count, delivery_futures=child_futures, - main_ts=attr_batch.main_ts + main_ts=attr_batch.main_ts, + original_payload=attr_batch ) result.append(mqtt_msg) built_child_messages.append(mqtt_msg) diff --git a/tb_mqtt_client/service/gateway/client.py b/tb_mqtt_client/service/gateway/client.py index 6574d72..9710ffc 
100644 --- a/tb_mqtt_client/service/gateway/client.py +++ b/tb_mqtt_client/service/gateway/client.py @@ -92,9 +92,9 @@ def __init__(self, config: Optional[Union[GatewayConfig, Dict]] = None): device_manager=self.device_manager) # Gateway-specific rate limits - self._device_messages_rate_limit = RateLimit("10:1,", name="device_messages") - self._device_telemetry_rate_limit = RateLimit("10:1,", name="device_telemetry") - self._device_telemetry_dp_rate_limit = RateLimit("10:1,", name="device_telemetry_datapoints") + self._device_messages_rate_limit = RateLimit("0:0,", name="device_messages") + self._device_telemetry_rate_limit = RateLimit("0:0,", name="device_telemetry") + self._device_telemetry_dp_rate_limit = RateLimit("0:0,", name="device_telemetry_datapoints") # Callbacks self._device_attribute_update_callback = None @@ -398,7 +398,7 @@ async def _unsubscribe_from_gateway_topics(self): await sleep(0.01) async def _handle_rate_limit_response(self, response: RPCResponse): # noqa - parent_rate_limits_processing = await super()._handle_rate_limit_response(response) + device_rate_limits_processing_result = await super()._handle_rate_limit_response(response) try: if not isinstance(response.result, dict) or 'gatewayRateLimits' not in response.result: logger.warning("Invalid gateway rate limit response: %r", response) @@ -411,10 +411,10 @@ async def _handle_rate_limit_response(self, response: RPCResponse): # noqa await self._gateway_rate_limiter.telemetry_datapoints_rate_limit.set_limit(gateway_rate_limits.get('telemetryDataPoints', '0:0,'), percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) self._gateway_message_adapter.splitter.max_payload_size = self.max_payload_size - self._gateway_message_adapter.splitter.max_datapoints = self._device_telemetry_dp_rate_limit.minimal_limit + self._gateway_message_adapter.splitter.max_datapoints = self._gateway_rate_limiter.telemetry_datapoints_rate_limit.minimal_limit self._mqtt_manager.set_gateway_rate_limits_received() - return parent_rate_limits_processing + return device_rate_limits_processing_result except Exception as e: logger.exception("Failed to parse rate limits from server response: %s", e) diff --git a/tb_mqtt_client/service/gateway/device_session.py b/tb_mqtt_client/service/gateway/device_session.py index 366efc3..a23febb 100644 --- a/tb_mqtt_client/service/gateway/device_session.py +++ b/tb_mqtt_client/service/gateway/device_session.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
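For context on the "0:0," defaults above: the limit strings used throughout this patch read as comma-terminated "tokens:seconds" pairs, with "0:0," serving as the no-limit sentinel, so the gateway now stays unlimited until the server reports its actual limits via set_limit(). A small illustration under that assumption (the "100:1,2000:60," value is invented):

from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit

# Assumed meaning: at most 100 messages per second and 2000 messages per minute.
server_reported = RateLimit("100:1,2000:60,", name="device_messages")

# No limit at all; effective limits are applied later via set_limit()
# once the server-side rate-limit response arrives.
unlimited = RateLimit("0:0,", name="device_messages_unlimited")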
+ import asyncio from time import time from dataclasses import dataclass, field diff --git a/tb_mqtt_client/service/gateway/message_adapter.py b/tb_mqtt_client/service/gateway/message_adapter.py index b080aa1..4061060 100644 --- a/tb_mqtt_client/service/gateway/message_adapter.py +++ b/tb_mqtt_client/service/gateway/message_adapter.py @@ -168,7 +168,8 @@ def build_uplink_messages(self, messages: List[MqttPublishMessage]) -> List[Mqtt device_groups[payload.device_name].append(payload) logger.trace("Queued GatewayUplinkMessage for device='%s'", payload.device_name) else: - logger.warning("Unsupported payload type '%s', skipping", type(payload).__name__) + result.append(mqtt_msg) + logger.debug("Non-GatewayUplinkMessage found, sending as is: %s", type(payload).__name__) # Process each device group for device_name, group_msgs in device_groups.items(): @@ -189,7 +190,8 @@ def build_uplink_messages(self, messages: List[MqttPublishMessage]) -> List[Mqtt qos=qos, datapoints=count, delivery_futures=futures, - main_ts=ts_batch.main_ts + main_ts=ts_batch.main_ts, + original_payload=ts_batch ) result.append(mqtt_msg) built_child_messages.append(mqtt_msg) @@ -207,7 +209,8 @@ def build_uplink_messages(self, messages: List[MqttPublishMessage]) -> List[Mqtt qos=qos, datapoints=count, delivery_futures=futures, - main_ts=attr_batch.main_ts + main_ts=attr_batch.main_ts, + original_payload=attr_batch ) result.append(mqtt_msg) built_child_messages.append(mqtt_msg) @@ -385,7 +388,7 @@ def pack_timeseries(msg: 'GatewayUplinkMessage') -> Union[Dict[str, Any], List[D if all_ts_none: result = {e.key: e.value for e in entries} - return [{"ts": msg.main_ts, "values": result}] if msg.main_ts is not None else result + return [{"ts": msg.main_ts, "values": result}] if msg.main_ts is not None else [result] now_ts = msg.main_ts if msg.main_ts is not None else int(datetime.now(UTC).timestamp() * 1000) grouped = defaultdict(dict) diff --git a/tb_mqtt_client/service/gateway/message_sender.py b/tb_mqtt_client/service/gateway/message_sender.py index a8039a5..4c05d0c 100644 --- a/tb_mqtt_client/service/gateway/message_sender.py +++ b/tb_mqtt_client/service/gateway/message_sender.py @@ -26,7 +26,7 @@ from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter -from tb_mqtt_client.service.message_queue import MessageQueue +from tb_mqtt_client.service.message_service import MessageService logger = get_logger(__name__) @@ -38,7 +38,7 @@ class GatewayMessageSender: """ def __init__(self): - self._message_queue: Optional[MessageQueue] = None + self._message_queue: Optional[MessageService] = None self._message_adapter: Optional[GatewayMessageAdapter] = None async def send_uplink_message(self, message: GatewayUplinkMessage, qos=1) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: @@ -149,7 +149,7 @@ async def send_claim_request(self, claim_request: GatewayClaimRequest, qos=1) -> await self._message_queue.publish(mqtt_message) return mqtt_message.delivery_futures - def set_message_queue(self, message_queue: MessageQueue): + def set_message_queue(self, message_queue: MessageService): """ Sets the message queue for sending uplink messages. 
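The pack_timeseries change above fixes the no-timestamp branch to return a list, so every branch yields the same shape; a minimal mirror of that behaviour (pack_flat_timeseries is a simplified stand-in, the values are examples):

from typing import Any, Dict, List, Optional


def pack_flat_timeseries(values: Dict[str, Any], main_ts: Optional[int]) -> List[Dict[str, Any]]:
    # Simplified mirror of the fixed branch: callers always get a list back.
    if main_ts is not None:
        return [{"ts": main_ts, "values": values}]
    return [values]


assert pack_flat_timeseries({"temperature": 25.5}, 1716000000000) == \
    [{"ts": 1716000000000, "values": {"temperature": 25.5}}]
assert pack_flat_timeseries({"temperature": 25.5}, None) == [{"temperature": 25.5}]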
diff --git a/tb_mqtt_client/service/gateway/multiplex_publisher.py b/tb_mqtt_client/service/gateway/multiplex_publisher.py deleted file mode 100644 index fa669aa..0000000 --- a/tb_mqtt_client/service/gateway/multiplex_publisher.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/tb_mqtt_client/service/message_queue.py b/tb_mqtt_client/service/message_queue.py deleted file mode 100644 index 2a2d4e4..0000000 --- a/tb_mqtt_client/service/message_queue.py +++ /dev/null @@ -1,379 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -from contextlib import suppress -from typing import List, Optional - -from tb_mqtt_client.common.logging_utils import get_logger, TRACE_LEVEL -from tb_mqtt_client.common.mqtt_message import MqttPublishMessage -from tb_mqtt_client.common.publish_result import PublishResult -from tb_mqtt_client.common.rate_limit.rate_limiter import RateLimiter -from tb_mqtt_client.constants import mqtt_topics -from tb_mqtt_client.constants.mqtt_topics import GATEWAY_TOPICS -from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage -from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage -from tb_mqtt_client.service.device.message_adapter import MessageAdapter -from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter -from tb_mqtt_client.service.mqtt_manager import MQTTManager - -logger = get_logger(__name__) - - -class MessageQueue: - _BATCH_TIMEOUT = 0.01 # seconds to wait for batching (optional flush time) - - def __init__(self, - mqtt_manager: MQTTManager, - main_stop_event: asyncio.Event, - device_rate_limiter: RateLimiter, - message_adapter: MessageAdapter, - max_queue_size: int = 1000000, - batch_collect_max_time_ms: int = 100, - batch_collect_max_count: int = 500, - gateway_message_adapter: Optional[GatewayMessageAdapter] = None, - gateway_rate_limiter: Optional[RateLimiter] = None): - self._main_stop_event = main_stop_event - self._batch_max_time = batch_collect_max_time_ms / 1000 - self._batch_max_count = batch_collect_max_count - self._mqtt_manager = mqtt_manager - self._device_rate_limiter = device_rate_limiter - self._gateway_rate_limiter = gateway_rate_limiter - self._backpressure = self._mqtt_manager.backpressure - # Queue expects tuples of (mqtt_message, delivery_futures) - self._queue: asyncio.Queue[MqttPublishMessage] = asyncio.Queue(maxsize=max_queue_size) - 
self._pending_queue_tasks: set[asyncio.Task] = set() - self._active = asyncio.Event() - self._wakeup_event = asyncio.Event() - self._retry_tasks: set[asyncio.Task] = set() - self._active.set() - self._adapter = message_adapter - self._gateway_adapter = gateway_message_adapter - self._loop_task = asyncio.create_task(self._dequeue_loop()) - self._rate_limit_refill_task = asyncio.create_task(self._rate_limit_refill_loop()) - self.__print_queue_statistics_task = asyncio.create_task(self.print_queue_statistics()) - logger.debug("MessageQueue initialized: max_queue_size=%s, batch_time=%.3f, batch_count=%d", - max_queue_size, self._batch_max_time, batch_collect_max_count) - - async def publish(self, message: MqttPublishMessage) -> Optional[List[asyncio.Future[PublishResult]]]: - try: - if logger.isEnabledFor(TRACE_LEVEL): - logger.trace(f"Pushing message to queue with delivery futures: {[f.uuid for f in message.delivery_futures]}") - self._queue.put_nowait(message) - except asyncio.QueueFull: - logger.error("Message queue full. Dropping message for topic %s", message.topic) - for future in message.delivery_futures: - if future: - future.set_result(PublishResult(message.topic, message.qos, -1, len(message.payload), -1)) - - async def _dequeue_loop(self): - logger.debug("MessageQueue dequeue loop started.") - while self._active.is_set() and not self._main_stop_event.is_set(): - try: - try: - message = await self._wait_for_message() - if logger.isEnabledFor(TRACE_LEVEL): - logger.trace(f"Dequed message with delivery futures: {[f.uuid for f in message.delivery_futures]}") - await asyncio.sleep(0) # cooperative yield - except asyncio.TimeoutError: - logger.trace("Dequeue wait timed out. Yielding...") - await asyncio.sleep(0.001) - continue - except asyncio.CancelledError: - break - except Exception as e: - logger.warning("Unexpected error in dequeue loop: %s", e) - continue - - if isinstance(message, MqttPublishMessage) and isinstance(message.payload, bytes): - logger.trace("Dequeued immediate publish: topic=%s (raw bytes)", message.topic) - await self._try_publish(message) - continue - - logger.trace("Dequeued message for batching: topic=%s, device=%s", - message.topic, getattr(message.payload, 'device_name', 'N/A')) - - batch: List[MqttPublishMessage] = [message] - start = asyncio.get_event_loop().time() - batch_size = message.payload.size - batch_type = type(message.payload).__name__ - - while not self._queue.empty(): - elapsed = asyncio.get_event_loop().time() - start - if elapsed >= self._batch_max_time: - logger.trace("Batch time threshold reached: %.3fs", elapsed) - break - if len(batch) >= self._batch_max_count: - logger.trace("Batch count threshold reached: %d messages", len(batch)) - break - - try: - next_message = self._queue.get_nowait() - if isinstance(next_message.payload, DeviceUplinkMessage) or isinstance(next_message.payload, GatewayUplinkMessage): - if batch_type is not None and batch_type != type(next_message.payload).__name__: - logger.trace("Batch type mismatch: current=%s, next=%s, finalizing current", - batch_type, type(next_message.payload).__class__.__name__) - self._queue.put_nowait(next_message) - break - batch_type = type(next_message.payload).__name__ - msg_size = next_message.payload.size - if batch_size + msg_size > self._adapter.splitter.max_payload_size: # noqa - logger.trace("Batch size threshold exceeded: current=%d, next=%d", batch_size, msg_size) - self._queue.put_nowait(next_message) - break - batch.append(next_message) - batch_size += msg_size - else: - 
logger.trace("Immediate publish encountered in queue while batching: topic=%s", next_message.topic) - await self._try_publish(next_message) - except asyncio.QueueEmpty: - break - - if batch_type is None: - batch_type = type(message.payload).__name__ - - if batch: - logger.trace("Batching completed: %d messages, total size=%d", len(batch), batch_size) - - if batch_type == 'GatewayUplinkMessage' and self._gateway_adapter: - logger.trace("Building gateway uplink payloads for %d messages", len(batch)) - topic_payloads = self._gateway_adapter.build_uplink_messages(batch) - else: - topic_payloads = self._adapter.build_uplink_messages(batch) - - for built_message in topic_payloads: - logger.trace("Dispatching batched message: topic=%s, size=%d, datapoints=%d, delivery_futures=%r", - built_message.topic, - len(built_message.payload), - built_message.datapoints, - [f.uuid for f in built_message.delivery_futures]) - await self._try_publish(built_message) - - except Exception as e: - logger.error("Error in dequeue loop:", exc_info=e) - logger.debug("Dequeue loop error details:", exc_info=e) - if isinstance(e, asyncio.CancelledError): - break - continue - - async def _try_publish(self, message: MqttPublishMessage): - - logger.trace("Attempting publish: topic=%s, datapoints=%d", message.topic, message.datapoints) - - # Check backpressure first - if active, don't even try to check rate limits - if self._backpressure.should_pause(): - logger.debug("Backpressure active, delaying publish of topic=%s for %.1f seconds", message.topic, 1.0) - self._schedule_delayed_retry(message) - return - - await self.check_rate_limits_for_message(message) - - try: - if logger.isEnabledFor(TRACE_LEVEL): - logger.trace("Trying to publish topic=%s, payload size=%d, attached future id=%r", - message.topic, len(message.payload), [f.uuid for f in message.delivery_futures]) - - await self._mqtt_manager.publish(message) - - except Exception as e: - logger.warning("Failed to publish to topic %s: %s. Scheduling retry.", message.topic, e) - logger.debug("Scheduling retry for topic=%s, payload size=%d, qos=%d", - message.topic, len(message.payload), message.qos) - logger.debug("error details: %s", e, exc_info=True) - self._schedule_delayed_retry(message, delay=.1) - - def _schedule_delayed_retry(self, message: MqttPublishMessage, delay: float = 0.1): - if not self._active.is_set() or self._main_stop_event.is_set(): - logger.debug("MessageQueue is not active or main stop event is set. Not scheduling retry for topic=%s", message.topic) - return - logger.trace("Scheduling retry: topic=%s, delay=%.2f", message.topic, delay) - - task = asyncio.create_task(self.__retry_task(message, delay)) - self._retry_tasks.add(task) - task.add_done_callback(self._retry_tasks.discard) - - async def __retry_task(self, message: MqttPublishMessage, delay: float): - try: - logger.debug("Retrying publish: topic=%s", message.topic) - await asyncio.sleep(delay) - if not self._active.is_set() or self._main_stop_event.is_set(): - logger.debug( - "MessageQueue is not active or main stop event is set. Not re-enqueuing message for topic=%s", - message.topic) - return - self._queue.put_nowait(message) - self._wakeup_event.set() - logger.debug("Re-enqueued message after delay: topic=%s", message.topic) - except asyncio.QueueFull: - logger.warning("Retry queue full. 
Dropping retried message: topic=%s", message.topic) - except Exception as e: - logger.debug("Unexpected error during delayed retry: %s", e) - - async def _wait_for_message(self) -> MqttPublishMessage: - while self._active.is_set(): - try: - if not self._queue.empty(): - try: - return await self._queue.get() - except asyncio.QueueEmpty: - await asyncio.sleep(.01) - - self._wakeup_event.clear() - queue_task = asyncio.create_task(self._queue.get()) - self._pending_queue_tasks.add(queue_task) - - wake_task = asyncio.create_task(self._wakeup_event.wait()) - - done, pending = await asyncio.wait([queue_task, wake_task], return_when=asyncio.FIRST_COMPLETED) - - for task in pending: - task.cancel() - with suppress(asyncio.CancelledError): - await task - - self._pending_queue_tasks.discard(queue_task) - - if queue_task in done: - logger.trace("Retrieved message from queue: %r", queue_task.result()) - return queue_task.result() - - await asyncio.sleep(0) # Yield control to the event loop - - except asyncio.CancelledError: - break - - raise asyncio.CancelledError("MessageQueue is shutting down or stopped.") - - async def shutdown(self): - logger.debug("Shutting down MessageQueue...") - self._active.clear() - self._wakeup_event.set() # Wake up the _wait_for_message if it's blocked - - await self._cancel_tasks(self._retry_tasks) - await self._cancel_tasks(self._pending_queue_tasks) - - self._loop_task.cancel() - if self._rate_limit_refill_task: - self._rate_limit_refill_task.cancel() - self.__print_queue_statistics_task.cancel() - with suppress(asyncio.CancelledError): - await self._loop_task - await self._rate_limit_refill_task - await self.__print_queue_statistics_task - - self.clear() - - logger.debug("MessageQueue shutdown complete, message queue size: %d", - self._queue.qsize()) - - async def check_rate_limits_for_message(self, message: MqttPublishMessage): - - message_rate_limit = None - datapoints_rate_limit = None - - is_message_with_telemetry_or_attributes = message.topic in mqtt_topics.TOPICS_WITH_DATAPOINTS_CHECK - is_gateway_message = message.topic in GATEWAY_TOPICS - - if is_gateway_message: - if is_message_with_telemetry_or_attributes: - message_rate_limit = self._gateway_rate_limiter.telemetry_message_rate_limit - datapoints_rate_limit = self._gateway_rate_limiter.telemetry_datapoints_rate_limit - else: - message_rate_limit = self._gateway_rate_limiter.message_rate_limit - else: - if is_message_with_telemetry_or_attributes: - message_rate_limit = self._device_rate_limiter.telemetry_message_rate_limit - datapoints_rate_limit = self._device_rate_limiter.telemetry_datapoints_rate_limit - else: - message_rate_limit = self._device_rate_limiter.message_rate_limit - - retry_delay = None - - if message_rate_limit: - triggered_rate_limit = await message_rate_limit.try_consume(1) - if triggered_rate_limit: - logger.debug("Rate limit hit for topic %s: %r per %r seconds", - message.topic, triggered_rate_limit[0], triggered_rate_limit[1]) - retry_delay = message_rate_limit.minimal_timeout - if datapoints_rate_limit and retry_delay is None: - triggered_rate_limit = await datapoints_rate_limit.try_consume(message.datapoints) - if triggered_rate_limit: - logger.debug("Datapoint rate limit hit for topic %s: %r per %r seconds", - message.topic, triggered_rate_limit[0], triggered_rate_limit[1]) - retry_delay = datapoints_rate_limit.minimal_timeout - - if retry_delay is not None: - self._schedule_delayed_retry(message, delay=retry_delay) - - - @staticmethod - async def _cancel_tasks(tasks: 
set[asyncio.Task]): - for task in list(tasks): - task.cancel() - with suppress(asyncio.CancelledError): - await asyncio.gather(*tasks, return_exceptions=True) - tasks.clear() - - def is_empty(self): - return self._queue.empty() - - def clear(self): - logger.debug("Clearing message queue...") - while not self._queue.empty(): - message = self._queue.get_nowait() - for future in message.delivery_futures: - future.set_result(PublishResult( - topic=message.topic, - qos=message.qos, - message_id=-1, - payload_size=message.payload_size, - reason_code=-1 - )) - self._queue.task_done() - logger.debug("Message queue cleared.") - - async def _rate_limit_refill_loop(self): - try: - while self._active.is_set(): - await asyncio.sleep(1.0) - await self._refill_rate_limits() - logger.trace("Rate limits refilled, state: %s", self._device_rate_limiter) - except asyncio.CancelledError: - pass - - async def _refill_rate_limits(self): - for rl in self._device_rate_limiter.values(): - if rl: - await rl.refill() - if self._gateway_rate_limiter: - for rl in self._gateway_rate_limiter.values(): - if rl: - await rl.refill() - - def set_gateway_message_adapter(self, message_adapter: GatewayMessageAdapter): - self._gateway_adapter = message_adapter - - async def print_queue_statistics(self): - """ - Prints the current statistics of the message queue. - """ - while self._active.is_set() and not self._main_stop_event.is_set(): - queue_size = self._queue.qsize() - pending_tasks = len(self._pending_queue_tasks) - retry_tasks = len(self._retry_tasks) - active = self._active.is_set() - logger.info("MessageQueue Statistics: " - "Queue Size: %d, Pending Tasks: %d, Retry Tasks: %d, Active: %s", - queue_size, pending_tasks, retry_tasks, active) - await asyncio.sleep(60) diff --git a/tb_mqtt_client/service/message_service.py b/tb_mqtt_client/service/message_service.py new file mode 100644 index 0000000..a7bee96 --- /dev/null +++ b/tb_mqtt_client/service/message_service.py @@ -0,0 +1,405 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
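The MessageService that follows replaces the MessageQueue deleted above; from a sender's point of view the contract is unchanged: publish() only enqueues the message, the dispatch loops route it to the service, device-uplink or gateway-uplink queue, and its delivery_futures resolve to PublishResult objects once the broker acknowledges it. A hedged usage sketch (construction of the service itself is elided):

from tb_mqtt_client.common.mqtt_message import MqttPublishMessage


async def send_and_wait(message_service, message: MqttPublishMessage):
    # publish() enqueues; a MessageQueueWorker later performs the MQTT publish.
    await message_service.publish(message)
    # Each delivery future resolves to a PublishResult once the broker acks.
    return [await future for future in (message.delivery_futures or [])]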
+ +import asyncio +from contextlib import suppress +from typing import List, Optional, Tuple, Union + +from tb_mqtt_client.common.logging_utils import get_logger, TRACE_LEVEL +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.common.queue import AsyncDeque +from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit, EMPTY_RATE_LIMIT +from tb_mqtt_client.common.rate_limit.rate_limiter import RateLimiter +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage +from tb_mqtt_client.service.device.message_adapter import MessageAdapter +from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter +from tb_mqtt_client.service.mqtt_manager import MQTTManager + +logger = get_logger(__name__) + + +class MessageService: + _QUEUE_COOLDOWN = 0.01 # seconds to sleep before the next iteration in case of empty queues + + def __init__(self, + mqtt_manager: MQTTManager, + main_stop_event: asyncio.Event, + device_rate_limiter: RateLimiter, + message_adapter: MessageAdapter, + max_queue_size: int = 1000000, + batch_collect_max_time_ms: int = 10, + batch_collect_max_count: int = 500, + gateway_message_adapter: Optional[GatewayMessageAdapter] = None, + gateway_rate_limiter: Optional[RateLimiter] = None): + self._main_stop_event = main_stop_event + self._batch_max_time = batch_collect_max_time_ms / 1000 + self._batch_max_count = batch_collect_max_count + self._mqtt_manager = mqtt_manager + # Patching MQTT client to add an ability to use retry logic + self._mqtt_manager.patch_client_for_retry_logic(self.put_retry_message) + self._device_rate_limiter = device_rate_limiter + self._gateway_rate_limiter = gateway_rate_limiter + self._rate_limit_ready = asyncio.Event() + self._backpressure = self._mqtt_manager.backpressure + self._retry_by_qos_queue: AsyncDeque = AsyncDeque(maxlen=max_queue_size) + self._initial_queue: AsyncDeque = AsyncDeque(maxlen=max_queue_size) + self._service_queue: AsyncDeque = AsyncDeque(maxlen=max_queue_size) + self._service_message_worker = MessageQueueWorker("ServiceMessageWorker", + self._service_queue, + self._main_stop_event, + self._mqtt_manager, + self._device_rate_limiter, + self._gateway_rate_limiter,) + self._device_uplink_messages_queue: AsyncDeque = AsyncDeque(maxlen=max_queue_size) + self._device_uplink_message_worker = MessageQueueWorker("DeviceUplinkMessageWorker", + self._device_uplink_messages_queue, + self._main_stop_event, + self._mqtt_manager, + self._device_rate_limiter, + self._gateway_rate_limiter) + self._gateway_uplink_messages_queue: AsyncDeque = AsyncDeque(maxlen=max_queue_size) + self._gateway_uplink_message_worker = MessageQueueWorker("GatewayUplinkMessageWorker", + self._gateway_uplink_messages_queue, + self._main_stop_event, + self._mqtt_manager, + self._device_rate_limiter, + self._gateway_rate_limiter) + self._active = asyncio.Event() + self._wakeup_event = asyncio.Event() + self._can_process_device_uplink_event = asyncio.Event() + self._can_process_device_uplink_event.set() + self._can_process_gateway_uplink_event = asyncio.Event() + self._can_process_gateway_uplink_event.set() + self._active.set() + self._adapter = message_adapter + self._gateway_adapter = gateway_message_adapter + + self._retry_by_qos_task = asyncio.create_task(self._dispatch_retry_by_qos_queue_loop()) + self._initial_queue_task = 
asyncio.create_task(self._dispatch_initial_queue_loop()) + self._service_queue_task = asyncio.create_task(self._dispatch_queue_loop(self._service_queue, self._service_message_worker)) + self._device_uplink_messages_queue_task = asyncio.create_task(self._dispatch_queue_loop(self._device_uplink_messages_queue, self._device_uplink_message_worker)) + self._gateway_uplink_messages_queue_task = asyncio.create_task(self._dispatch_queue_loop(self._gateway_uplink_messages_queue, self._gateway_uplink_message_worker)) + + self._rate_limit_refill_task = asyncio.create_task(self._rate_limit_refill_loop()) + self.__print_queue_statistics_task = asyncio.create_task(self.print_queues_statistics()) + + async def publish(self, message: MqttPublishMessage) -> Optional[List[asyncio.Future[PublishResult]]]: + """ + Publish a message to the message queue. + :param message: The MqttPublishMessage to publish. + :return: A list of futures for delivery results. + """ + try: + if logger.isEnabledFor(TRACE_LEVEL): + logger.trace(f"Pushing message to queue with delivery futures: {[f.uuid for f in message.delivery_futures]}") + await self._initial_queue.put(message) + except Exception as e: + logger.error("Failed to push message to queue: %s", e) + for future in message.delivery_futures: + if future and not future.done(): + future.set_result(PublishResult(message.topic, message.qos, -1, len(message.original_payload), -1)) + + async def _dispatch_initial_queue_loop(self): + """ + Loop to process messages from the initial queue and dispatch them to the appropriate queues. + """ + while not self._main_stop_event.is_set() and self._active.is_set(): + try: + if not self._mqtt_manager.is_connected(): + await asyncio.sleep(self._QUEUE_COOLDOWN) + continue + peeked_batch = await self._initial_queue.peek_batch(self._batch_max_count) + + gateway_messages = [] + device_messages = [] + for message in peeked_batch: + if isinstance(message.original_payload, bytes): + # If the message is a raw bytes payload, put it directly into the service queue + await self._service_queue.put(message) + elif isinstance(message.original_payload, GatewayUplinkMessage): + # If the message is a GatewayUplinkMessage, process it with the gateway adapter + gateway_messages.append(message) + elif isinstance(message.original_payload, DeviceUplinkMessage): + # If the message is a DeviceUplinkMessage, process it with the device adapter + device_messages.append(message) + else: + logger.warning("Unknown message type in initial queue: %s", type(message.original_payload)) + + if gateway_messages: + # Process gateway messages in batches + messages = self._gateway_adapter.build_uplink_messages(gateway_messages) + await self._gateway_uplink_messages_queue.extend(messages) + + if device_messages: + # Process device messages in batches + messages = self._adapter.build_uplink_messages(device_messages) + await self._device_uplink_messages_queue.extend(messages) + + await self._initial_queue.pop_n(len(peeked_batch)) + + if self._initial_queue.is_empty(): + await asyncio.sleep(self._QUEUE_COOLDOWN) + + except asyncio.CancelledError: + break + except Exception as e: + logger.exception("Dispatch loop error: %s", e) + await asyncio.sleep(0.5) + + async def _dispatch_queue_loop(self, queue: AsyncDeque, worker: 'MessageQueueWorker'): + """Loop to process messages from the service queue.""" + while not self._main_stop_event.is_set() and self._active.is_set(): + message = None + try: + if not self._mqtt_manager.is_connected(): + await asyncio.sleep(self._QUEUE_COOLDOWN) + 
continue + message = await queue.get() + if not message: + await asyncio.sleep(0.01) + continue + logger.trace("Processing message from queue: %s, message_id: %s, message payload: %s", + message.topic, message.message_id, message.original_payload) + expected_duration, expected_tokens, triggered_rate_limit = await worker.process(message) + if triggered_rate_limit: + logger.trace("Reinserting message to the front of the queue: %s, message payload: %s", + message.uuid, message.original_payload) + await queue.reinsert_front(message) + triggered_rate_limit.set_required_tokens(expected_duration, expected_tokens) + await triggered_rate_limit.required_tokens_ready.wait() + + if queue.is_empty(): + await asyncio.sleep(self._QUEUE_COOLDOWN) + + except asyncio.CancelledError: + break + except Exception as e: + logger.exception("Service queue loop error: %s", e) + if message: + await queue.reinsert_front(message) + await asyncio.sleep(1) + + async def _dispatch_retry_by_qos_queue_loop(self): + while not self._main_stop_event.is_set() and self._active.is_set(): + message = None + try: + if not self._mqtt_manager.is_connected(): + await asyncio.sleep(self._QUEUE_COOLDOWN) + continue + message = await self._retry_by_qos_queue.get() + if not message: + await asyncio.sleep(self._QUEUE_COOLDOWN) + continue + + logger.trace("Retrying QoS message: %s", message) + + if isinstance(message.original_payload, bytes): + await self._service_queue.reinsert_front(message) + elif isinstance(message.original_payload, GatewayUplinkMessage): + await self._gateway_uplink_messages_queue.reinsert_front(message) + elif isinstance(message.original_payload, DeviceUplinkMessage): + await self._device_uplink_messages_queue.reinsert_front(message) + else: + logger.warning("Unknown message type in retry queue: %s", type(message.original_payload)) + + except asyncio.CancelledError: + break + except Exception as e: + logger.exception("Retry QoS dispatch error: %s", e) + if message: + await self._retry_by_qos_queue.reinsert_front(message) + await asyncio.sleep(0.5) + + async def shutdown(self): + logger.debug("Shutting down MessageQueue...") + self._active.clear() + self._wakeup_event.set() + + self._retry_by_qos_task.cancel() + self._initial_queue_task.cancel() + self._service_queue_task.cancel() + self._device_uplink_messages_queue_task.cancel() + self._gateway_uplink_messages_queue_task.cancel() + if self._rate_limit_refill_task: + self._rate_limit_refill_task.cancel() + self.__print_queue_statistics_task.cancel() + with suppress(asyncio.CancelledError): + await self._retry_by_qos_task + await self._initial_queue_task + await self._rate_limit_refill_task + await self.__print_queue_statistics_task + + await self.clear() + + logger.debug("MessageQueue shutdown complete, message queue size: %d", + self._initial_queue.size()) + + @staticmethod + async def _cancel_tasks(tasks: set[asyncio.Task]): + for task in list(tasks): + task.cancel() + with suppress(asyncio.CancelledError): + await asyncio.gather(*tasks, return_exceptions=True) + tasks.clear() + + def is_empty(self): + return self._initial_queue.is_empty() + + async def clear(self): + logger.debug("Clearing message queue...") + for queue in [self._initial_queue, self._service_queue, + self._device_uplink_messages_queue, self._gateway_uplink_messages_queue]: + while not queue.is_empty(): + message: MqttPublishMessage = await queue.get() + for future in message.delivery_futures: + if future and not future.done(): + future.set_result(PublishResult( + topic=message.topic, + 
qos=message.qos, + message_id=-1, + payload_size=message.payload_size if isinstance(message.payload, bytes) else message.payload.size, + reason_code=-1 + )) + logger.debug("Message queue cleared.") + + async def _rate_limit_refill_loop(self): + try: + while self._active.is_set(): + await asyncio.sleep(1.0) + await self._refill_rate_limits() + logger.trace("Rate limits refilled, state: %s", self._device_rate_limiter) + except asyncio.CancelledError: + pass + + async def _refill_rate_limits(self): + for rl in self._device_rate_limiter.values(): + if rl: + await rl.refill() + if self._gateway_rate_limiter: + for rl in self._gateway_rate_limiter.values(): + if rl: + await rl.refill() + + def set_gateway_message_adapter(self, message_adapter: GatewayMessageAdapter): + self._gateway_adapter = message_adapter + + async def put_retry_message(self, message: MqttPublishMessage): + await self._retry_by_qos_queue.put(message) + + async def print_queues_statistics(self): + """ + Prints the current statistics of message queues. + """ + while self._active.is_set() and not self._main_stop_event.is_set(): + retry_queue_size = self._retry_by_qos_queue.size() + initial_queue_size = self._initial_queue.size() + service_queue_size = self._service_queue.size() + device_uplink_queue_size = self._device_uplink_messages_queue.size() + gateway_uplink_queue_size = self._gateway_uplink_messages_queue.size() + active = self._active.is_set() + logger.info("MessageQueue Statistics: " + "Initial Queue Size: %d, " + "Service Queue Size: %d, " + "Device Uplink Queue Size: %d, " + "Gateway Uplink Queue Size: %d, " + "Retry Queue Size: %d, " + "Active: %s", + initial_queue_size, + service_queue_size, + device_uplink_queue_size, + gateway_uplink_queue_size, + retry_queue_size, + active) + await asyncio.sleep(60) + +class MessageQueueWorker: + def __init__(self, + name, + queue: AsyncDeque, + stop_event: asyncio.Event, + mqtt_manager: MQTTManager, + device_rate_limiter: RateLimiter, + gateway_rate_limiter: Optional[RateLimiter] = None): + self.name = name + self._queue = queue + self._stop_event = stop_event + self._mqtt_manager = mqtt_manager + self._device_rate_limiter = device_rate_limiter + self._gateway_rate_limiter = gateway_rate_limiter + + async def process(self, message: MqttPublishMessage) -> Tuple[Optional[int], Optional[int], Optional[RateLimit]]: + message_rate_limit, datapoints_rate_limit = self._get_rate_limits_for_message(message) + if message_rate_limit.has_limit() or datapoints_rate_limit.has_limit(): + triggered_rate_limit_entry, expected_tokens, rate_limit = await self.check_rate_limits_for_message(datapoints_count=message.datapoints, + message_rate_limit=message_rate_limit, + datapoints_rate_limit=datapoints_rate_limit) + + if triggered_rate_limit_entry is not None: + triggered_duration = triggered_rate_limit_entry[1] + logger.debug("Rate limit %r per %r seconds hit by message with expected tokens: %r.", + triggered_rate_limit_entry[0], triggered_duration, expected_tokens) + return triggered_duration, expected_tokens, rate_limit + + await self._consume_rate_limits_for_message(datapoints_count=message.datapoints, + message_rate_limit=message_rate_limit, + datapoints_rate_limit=datapoints_rate_limit) + await self._mqtt_manager.publish(message) + return None, None, None + + + def _get_rate_limits_for_message(self, message: MqttPublishMessage) -> Tuple[RateLimit, RateLimit]: + + message_rate_limit = EMPTY_RATE_LIMIT + datapoints_rate_limit = EMPTY_RATE_LIMIT + + if message.is_device_message: + if 
message.is_service_message: + message_rate_limit = self._device_rate_limiter.message_rate_limit + else: + message_rate_limit = self._device_rate_limiter.telemetry_message_rate_limit + datapoints_rate_limit = self._device_rate_limiter.telemetry_datapoints_rate_limit + + else: + if message.is_service_message: + message_rate_limit = self._gateway_rate_limiter.message_rate_limit + else: + message_rate_limit = self._gateway_rate_limiter.telemetry_message_rate_limit + datapoints_rate_limit = self._gateway_rate_limiter.telemetry_datapoints_rate_limit + + return message_rate_limit, datapoints_rate_limit + + @staticmethod + async def check_rate_limits_for_message(datapoints_count: int, + message_rate_limit: RateLimit, + datapoints_rate_limit: RateLimit) -> Tuple[Union[Tuple[int, int], None], int, Optional[RateLimit]]: + if message_rate_limit and message_rate_limit.has_limit(): + triggered_rate_limit_entry = await message_rate_limit.try_consume(1) + if triggered_rate_limit_entry: + return triggered_rate_limit_entry, 1, message_rate_limit + if datapoints_rate_limit and datapoints_rate_limit.has_limit(): + triggered_rate_limit_entry = await datapoints_rate_limit.try_consume(datapoints_count) + if triggered_rate_limit_entry: + return triggered_rate_limit_entry, datapoints_count, datapoints_rate_limit + return None, 0, None + + @staticmethod + async def _consume_rate_limits_for_message(datapoints_count: int, + message_rate_limit: RateLimit, + datapoints_rate_limit: RateLimit) -> None: + if message_rate_limit and message_rate_limit.has_limit(): + await message_rate_limit.consume(1) + if datapoints_rate_limit and datapoints_rate_limit.has_limit(): + await datapoints_rate_limit.consume(datapoints_count) diff --git a/tb_mqtt_client/service/mqtt_manager.py b/tb_mqtt_client/service/mqtt_manager.py index 79d587f..b5030d6 100644 --- a/tb_mqtt_client/service/mqtt_manager.py +++ b/tb_mqtt_client/service/mqtt_manager.py @@ -23,7 +23,7 @@ from gmqtt import Client as GMQTTClient, Subscription from tb_mqtt_client.common.async_utils import await_or_stop, future_map, run_coroutine_sync -from tb_mqtt_client.common.gmqtt_patch import PatchUtils +from tb_mqtt_client.common.gmqtt_patch import PatchUtils, PublishPacket from tb_mqtt_client.common.logging_utils import get_logger, TRACE_LEVEL from tb_mqtt_client.common.mqtt_message import MqttPublishMessage from tb_mqtt_client.common.publish_result import PublishResult @@ -153,7 +153,7 @@ async def disconnect(self): async def publish(self, message: MqttPublishMessage, - qos: int = 1, + qos: int = 1, # TODO: probably should be removed, as qos is set in MqttPublishMessage force=False): if not force: @@ -172,6 +172,33 @@ async def publish(self, logger.trace("Backpressure active. 
Publishing suppressed.") raise RuntimeError("Publishing temporarily paused due to backpressure.") + if not message.dup: + return await self.process_regular_publish(message, qos) + else: + # If a message is a duplicate, we should not process it as a regular publish message, it should be sent immediately + if logger.isEnabledFor(TRACE_LEVEL): + logger.trace("Processing duplicate message with topic: %s, qos: %d, payload size: %d", + message.topic, qos, len(message.payload)) + + protocol = self._client._connection._protocol # noqa + + if protocol: + try: + mid, rebuilt = PublishPacket.build_package( + message=message, + protocol=protocol, + mid=message.message_id + ) + self._client._connection.send_package(rebuilt) + logger.trace("Retransmitted message mid=%r", mid) + except Exception as e: + logger.warning("Error during retransmission: %s", e) + logger.debug("Failed to retransmit message: %r", message, exc_info=e) + else: + logger.warning("Cannot retransmit, MQTT protocol unavailable.") + self._client._persistent_storage.push_message_nowait(message.message_id, message) # noqa + + async def process_regular_publish(self, message: MqttPublishMessage, qos: int = 1): mqtt_future = asyncio.get_event_loop().create_future() mqtt_future.uuid = uuid4() @@ -179,37 +206,7 @@ async def publish(self, logger.trace("Publishing message with topic: %s, qos: %d, payload size: %d, mqtt_future id: %r, delivery futures: %r", message.topic, qos, len(message.payload), mqtt_future.uuid, [f.uuid for f in message.delivery_futures]) if message.delivery_futures is not None: - def resolve_attached(publish_future: asyncio.Future): - try: - try: - publish_result = publish_future.result() - except asyncio.CancelledError: - logger.info("Publish future was cancelled: %r, id: %r", publish_future, publish_future.uuid) - publish_result = PublishResult(message.topic, message.qos, -1, len(message.payload), -1) - except Exception as exc: - logger.warning("Publish failed with exception: %s", exc) - logger.debug("Resolving delivery futures with failure:", exc_info=exc) - publish_result = PublishResult(message.topic, message.qos, -1, len(message.payload), -1) - - for i, f in enumerate(message.delivery_futures or []): - if f is not None and not f.done(): - f.set_result(publish_result) - future_map.child_resolved(f) - logger.trace("Resolved delivery future #%d id=%r with %s, main publish future id: %r", - i, f.uuid, publish_result, publish_future.uuid) - except Exception as e: - logger.error("Error resolving delivery futures: %s", str(e)) - for i, f in enumerate(message.delivery_futures or []): - if f is not None and not f.done(): - f.set_exception(e) - logger.debug("Set exception for delivery future #%d id=%r", i, f.uuid) - - logger.trace("Adding done callback to main publish future: %r, main publish future done state: %r", mqtt_future.uuid, mqtt_future.done()) - if mqtt_future.done(): - logger.debug("Main publish future is already done, resolving immediately.") - resolve_attached(mqtt_future) - else: - mqtt_future.add_done_callback(resolve_attached) + await self._add_future_chain_processing(mqtt_future, message) mid, package = self._client._connection.publish(message) # noqa @@ -466,6 +463,9 @@ async def _monitor_ack_timeouts(self): await asyncio.sleep(0.1) await self.check_pending_publishes(monotonic()) + def patch_client_for_retry_logic(self, put_retry_message_method: Callable[[MqttPublishMessage], Coroutine[Any, Any, None]]): + self._client.put_retry_message = put_retry_message_method + async def check_pending_publishes(self, 
time_to_check): expired = [] for mid, (future, message, timestamp) in list(self._pending_publishes.items()): @@ -496,3 +496,38 @@ async def stop(self): if self._client.is_connected: await self._client.disconnect() + + @staticmethod + async def _add_future_chain_processing(mqtt_future, message: MqttPublishMessage): + def resolve_attached(publish_future: asyncio.Future): + try: + try: + publish_result = publish_future.result() + except asyncio.CancelledError: + logger.info("Publish future was cancelled: %r, id: %r", publish_future, publish_future.uuid) + publish_result = PublishResult(message.topic, message.qos, -1, len(message.payload), -1) + except Exception as exc: + logger.warning("Publish failed with exception: %s", exc) + logger.debug("Resolving delivery futures with failure:", exc_info=exc) + publish_result = PublishResult(message.topic, message.qos, -1, len(message.payload), -1) + + for i, f in enumerate(message.delivery_futures or []): + if f is not None and not f.done(): + f.set_result(publish_result) + future_map.child_resolved(f) + logger.trace("Resolved delivery future #%d id=%r with %s, main publish future id: %r", + i, f.uuid, publish_result, publish_future.uuid) + except Exception as e: + logger.error("Error resolving delivery futures: %s", str(e)) + for i, f in enumerate(message.delivery_futures or []): + if f is not None and not f.done(): + f.set_exception(e) + logger.debug("Set exception for delivery future #%d id=%r", i, f.uuid) + + logger.trace("Adding done callback to main publish future: %r, main publish future done state: %r", + mqtt_future.uuid, mqtt_future.done()) + if mqtt_future.done(): + logger.debug("Main publish future is already done, resolving immediately.") + resolve_attached(mqtt_future) + else: + mqtt_future.add_done_callback(resolve_attached) diff --git a/tests/common/test_backpressure_controller.py b/tests/common/test_backpressure_controller.py index 15c3d43..1670979 100644 --- a/tests/common/test_backpressure_controller.py +++ b/tests/common/test_backpressure_controller.py @@ -98,11 +98,6 @@ def test_should_pause_stop_event_set(stop_event, controller): assert controller.should_pause() is False -def test_pause_for(controller): - controller.pause_for(12) - assert controller._pause_until is not None - - def test_clear(controller): controller._pause_until = datetime.now(UTC) + timedelta(seconds=30) controller._consecutive_quota_exceeded = 5 diff --git a/tests/common/test_config_loader.py b/tests/common/test_config_loader.py new file mode 100644 index 0000000..7723044 --- /dev/null +++ b/tests/common/test_config_loader.py @@ -0,0 +1,103 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
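# The tests below exercise the environment-variable driven configuration:
# DeviceConfig reads TB_HOST, TB_PORT, TB_ACCESS_TOKEN, TB_USERNAME,
# TB_PASSWORD, TB_CLIENT_ID, TB_CA_CERT, TB_CLIENT_CERT, TB_PRIVATE_KEY and
# TB_QOS (defaulting to port 1883 and QoS 1 when unset), while GatewayConfig
# reads the TB_GW_* equivalents and falls back to the plain TB_* values when a
# gateway-specific variable is missing. A minimal usage sketch, assuming the
# same variable names (the values here are placeholders):
#
#     os.environ["TB_HOST"] = "thingsboard.example.com"
#     os.environ["TB_CA_CERT"] = "/etc/tb/ca.pem"
#     config = DeviceConfig()   # picks the values up from the environment
#     assert config.use_tls()   # a configured CA certificate enables TLS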
+ +import os +import unittest +from tb_mqtt_client.common.config_loader import DeviceConfig, GatewayConfig + +class TestDeviceConfig(unittest.TestCase): + + def loads_default_values_when_env_vars_missing(self): + os.environ.clear() + config = DeviceConfig() + self.assertEqual(config.host, None) + self.assertEqual(config.port, 1883) + self.assertEqual(config.access_token, None) + self.assertEqual(config.username, None) + self.assertEqual(config.password, None) + self.assertEqual(config.client_id, None) + self.assertEqual(config.ca_cert, None) + self.assertEqual(config.client_cert, None) + self.assertEqual(config.private_key, None) + self.assertEqual(config.qos, 1) + + def loads_values_from_env_vars(self): + os.environ["TB_HOST"] = "test_host" + os.environ["TB_PORT"] = "8883" + os.environ["TB_ACCESS_TOKEN"] = "test_token" + os.environ["TB_USERNAME"] = "test_user" + os.environ["TB_PASSWORD"] = "test_pass" + os.environ["TB_CLIENT_ID"] = "test_client" + os.environ["TB_CA_CERT"] = "test_ca" + os.environ["TB_CLIENT_CERT"] = "test_cert" + os.environ["TB_PRIVATE_KEY"] = "test_key" + os.environ["TB_QOS"] = "2" + + config = DeviceConfig() + self.assertEqual(config.host, "test_host") + self.assertEqual(config.port, 8883) + self.assertEqual(config.access_token, "test_token") + self.assertEqual(config.username, "test_user") + self.assertEqual(config.password, "test_pass") + self.assertEqual(config.client_id, "test_client") + self.assertEqual(config.ca_cert, "test_ca") + self.assertEqual(config.client_cert, "test_cert") + self.assertEqual(config.private_key, "test_key") + self.assertEqual(config.qos, 2) + + def detects_tls_auth_correctly(self): + os.environ["TB_CA_CERT"] = "test_ca" + os.environ["TB_CLIENT_CERT"] = "test_cert" + os.environ["TB_PRIVATE_KEY"] = "test_key" + config = DeviceConfig() + self.assertTrue(config.use_tls_auth()) + + def detects_tls_correctly(self): + os.environ["TB_CA_CERT"] = "test_ca" + config = DeviceConfig() + self.assertTrue(config.use_tls()) + +class TestGatewayConfig(unittest.TestCase): + + def loads_gateway_specific_env_vars(self): + os.environ["TB_GW_HOST"] = "gw_host" + os.environ["TB_GW_PORT"] = "8884" + os.environ["TB_GW_ACCESS_TOKEN"] = "gw_token" + os.environ["TB_GW_USERNAME"] = "gw_user" + os.environ["TB_GW_PASSWORD"] = "gw_pass" + os.environ["TB_GW_CLIENT_ID"] = "gw_client" + os.environ["TB_GW_CA_CERT"] = "gw_ca" + os.environ["TB_GW_CLIENT_CERT"] = "gw_cert" + os.environ["TB_GW_PRIVATE_KEY"] = "gw_key" + os.environ["TB_GW_QOS"] = "0" + + config = GatewayConfig() + self.assertEqual(config.host, "gw_host") + self.assertEqual(config.port, 8884) + self.assertEqual(config.access_token, "gw_token") + self.assertEqual(config.username, "gw_user") + self.assertEqual(config.password, "gw_pass") + self.assertEqual(config.client_id, "gw_client") + self.assertEqual(config.ca_cert, "gw_ca") + self.assertEqual(config.client_cert, "gw_cert") + self.assertEqual(config.private_key, "gw_key") + self.assertEqual(config.qos, 0) + + def falls_back_to_device_config_when_gateway_env_vars_missing(self): + os.environ.clear() + os.environ["TB_HOST"] = "device_host" + os.environ["TB_PORT"] = "1884" + config = GatewayConfig() + self.assertEqual(config.host, "device_host") + self.assertEqual(config.port, 1884) \ No newline at end of file diff --git a/tests/service/device/test_device_client.py b/tests/service/device/test_device_client.py index 0e6a627..7e7dc48 100644 --- a/tests/service/device/test_device_client.py +++ b/tests/service/device/test_device_client.py @@ -446,7 +446,7 @@ async def 
test_initializes_adapter_and_queue_after_connection(): mqtt_manager = AsyncMock() mqtt_manager.is_connected.return_value = True - with patch("tb_mqtt_client.service.device.client.MessageQueue") as mock_queue: + with patch("tb_mqtt_client.service.device.client.MessageService") as mock_message_service: client = DeviceClient(config) client._mqtt_manager = mqtt_manager @@ -456,8 +456,8 @@ async def test_initializes_adapter_and_queue_after_connection(): assert client._message_adapter is not None assert client._message_queue is not None - mock_queue.assert_called_once() - kwargs = mock_queue.call_args.kwargs + mock_message_service.assert_called_once() + kwargs = mock_message_service.call_args.kwargs assert kwargs["max_queue_size"] == client._max_uplink_message_queue_size diff --git a/tests/service/test_message_queue.py b/tests/service/test_message_queue.py deleted file mode 100644 index 517a42d..0000000 --- a/tests/service/test_message_queue.py +++ /dev/null @@ -1,596 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -from unittest.mock import Mock, MagicMock - -import pytest -import pytest_asyncio - -from tb_mqtt_client.common.mqtt_message import MqttPublishMessage -from tb_mqtt_client.common.rate_limit.backpressure_controller import BackpressureController -from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit -from tb_mqtt_client.common.rate_limit.rate_limiter import RateLimiter -from tb_mqtt_client.constants.mqtt_topics import DEVICE_TELEMETRY_TOPIC -from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder -from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry -from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessageBuilder -from tb_mqtt_client.service.device.message_adapter import JsonMessageAdapter -from tb_mqtt_client.service.message_queue import MessageQueue - - -@pytest_asyncio.fixture -async def fake_mqtt_manager(): - mgr = AsyncMock() - mgr.backpressure = Mock(spec=BackpressureController) - mgr.backpressure.should_pause.return_value = False - return mgr - - -@pytest_asyncio.fixture -async def message_queue(fake_mqtt_manager): - main_stop_event = asyncio.Event() - message_adapter = Mock() - message_adapter.build_uplink_messages.return_value = [] - - mq = MessageQueue( - mqtt_manager=fake_mqtt_manager, - main_stop_event=main_stop_event, - device_rate_limiter=RateLimiter(RateLimit('0:0,'), RateLimit('0:0,'), RateLimit('0:0,')), - message_adapter=message_adapter, - max_queue_size=10, - batch_collect_max_time_ms=10, - batch_collect_max_count=5 - ) - - try: - yield mq - finally: - await mq.shutdown() - - -@pytest.mark.asyncio -async def test_publish_success(message_queue): - message = MqttPublishMessage("topic", b"data", qos=1, delivery_futures=[]) - await message_queue.publish(message) - assert not message_queue.is_empty() - - -@pytest.mark.asyncio -async def test_publish_queue_full(message_queue): - for _ in range(10): - await 
message_queue.publish(MqttPublishMessage("topic", b"data", qos=1, delivery_futures=[])) - message = MqttPublishMessage("topic", b"data", qos=1, delivery_futures=[Mock()]) - await message_queue.publish(message) # Should not raise - assert message_queue._queue.qsize() <= 10 - - -@pytest.mark.asyncio -async def test_shutdown_clears_tasks_and_queue(fake_mqtt_manager): - q = MessageQueue( - mqtt_manager=fake_mqtt_manager, - main_stop_event=asyncio.Event(), - device_rate_limiter=RateLimiter(MagicMock(), MagicMock(), MagicMock()), - message_adapter=Mock(build_uplink_messages=Mock(return_value=[])), - max_queue_size=10, - batch_collect_max_time_ms=10, - batch_collect_max_count=5 - ) - await q.publish(MqttPublishMessage("topic", b"data", qos=1, delivery_futures=[])) - await q.shutdown() - assert q.is_empty() - - -@pytest.mark.asyncio -async def test_try_publish_backpressure_delay(message_queue): - message_queue._mqtt_manager.backpressure.should_pause.return_value = True - - built = DeviceUplinkMessageBuilder().add_timeseries( - TimeseriesEntry("test", "test") - ).build() - message = MqttPublishMessage("topic", built, qos=1, delivery_futures=[]) - - # Patch the retry scheduler to bypass delay and requeue immediately - with patch.object(message_queue, "_schedule_delayed_retry") as mocked_retry: - mocked_retry.side_effect = lambda m: message_queue._queue.put_nowait(m) - - await message_queue._try_publish(message) - - # Wait briefly for the message to be re-enqueued - for _ in range(20): - if not message_queue.is_empty(): - break - await asyncio.sleep(0.01) - else: - raise AssertionError("Message was not re-enqueued in time") - - # Ensure _schedule_delayed_retry was triggered - assert mocked_retry.called, "_schedule_delayed_retry should be called" - - -@pytest.mark.asyncio -async def test_try_publish_rate_limit_triggered(): - limit = Mock(spec=RateLimit) - limit.try_consume = AsyncMock(return_value=(10, 1)) - limit.minimal_timeout = 0.01 - - mq = MessageQueue( - mqtt_manager=AsyncMock(), - main_stop_event=asyncio.Event(), - device_rate_limiter=RateLimiter(AsyncMock(), limit, AsyncMock()), - message_adapter=Mock(build_uplink_messages=Mock(return_value=[])), - ) - mq._schedule_delayed_retry = AsyncMock() - mq._mqtt_manager.publish = AsyncMock() - message = MqttPublishMessage("telemetry", b"{}", qos=1, delivery_futures=[]) - await mq._try_publish(message) - await asyncio.sleep(0.05) - assert mq._mqtt_manager.publish.call_count == 0 - assert mq._schedule_delayed_retry.call_count == 1 - - -@pytest.mark.asyncio -async def test_try_publish_failure_schedules_retry(message_queue): - message = MqttPublishMessage("topic", b"broken", qos=1, delivery_futures=[]) - message_queue._mqtt_manager.publish.side_effect = Exception("fail") - message_queue._schedule_delayed_retry = AsyncMock() - message_queue._mqtt_manager._backpressure.should_pause.return_value = False - await message_queue._try_publish(message) - await asyncio.sleep(0.2) - assert message_queue._schedule_delayed_retry.call_count == 1 - - -@pytest.mark.asyncio -async def test_cancel_tasks(): - task1 = asyncio.create_task(asyncio.sleep(5)) - task2 = asyncio.create_task(asyncio.sleep(5)) - tasks = {task1, task2} - await MessageQueue._cancel_tasks(tasks) - assert all(t.cancelled() or t.done() for t in [task1, task2]) - - -@pytest.mark.asyncio -async def test_rate_limit_refill(): - rate_limit = Mock(spec=RateLimit) - rate_limit.refill = AsyncMock() - rate_limit.to_dict = Mock(return_value={"x": 1}) - mq = MessageQueue( - mqtt_manager=AsyncMock(), - 
main_stop_event=asyncio.Event(), - device_rate_limiter=RateLimiter(rate_limit, rate_limit, rate_limit), - message_adapter=Mock(), - ) - await mq._refill_rate_limits() - assert rate_limit.refill.await_count == 3 - - -@pytest.mark.asyncio -async def test_set_gateway_adapter(message_queue): - adapter = Mock() - message_queue.set_gateway_message_adapter(adapter) - assert message_queue._gateway_adapter is adapter - - -@pytest.mark.asyncio -async def test_wait_for_message_cancel(): - stop = asyncio.Event() - stop.set() - mq = MessageQueue( - mqtt_manager=AsyncMock(), - main_stop_event=stop, - device_rate_limiter=RateLimiter(MagicMock(), MagicMock(), MagicMock()), - message_adapter=Mock(), - ) - - await mq._queue.put("dummy") - - result = await mq._wait_for_message() - assert result == "dummy" - - -@pytest.mark.asyncio -async def test_clear_futures_result_set(): - fut = asyncio.Future() - fut.uuid = "test-future" - msg = b"data" - fut.set_result = Mock() - mq = MessageQueue( - mqtt_manager=AsyncMock(), - main_stop_event=asyncio.Event(), - device_rate_limiter=RateLimiter(MagicMock(), MagicMock(), MagicMock()), - message_adapter=Mock(), - ) - mq._queue.put_nowait(MqttPublishMessage("topic", msg, qos=1, delivery_futures=[fut])) - mq.clear() - assert mq.is_empty() - - -@pytest.mark.asyncio -async def test_message_queue_batching_respects_type_and_size(): - mqtt_manager = MagicMock() - mqtt_manager.publish = AsyncMock() - stop_event = asyncio.Event() - message_adapter = JsonMessageAdapter() - - mq = MessageQueue( - mqtt_manager=mqtt_manager, - main_stop_event=stop_event, - device_rate_limiter=RateLimiter(RateLimit('0:0,'), RateLimit('0:0,'), RateLimit('0:0,')), - message_adapter=message_adapter, - batch_collect_max_time_ms=200, - max_queue_size=100, - ) - mq._backpressure.should_pause.return_value = False - mq._mqtt_manager.backpressure.should_pause.return_value = False - - builder1 = DeviceUplinkMessageBuilder().add_timeseries(TimeseriesEntry("key1", 1)) - msg1 = builder1.build() - - builder2 = DeviceUplinkMessageBuilder().add_timeseries(TimeseriesEntry("key2", 2)) - msg2 = builder2.build() - - # First message fits, second exceeds total max_payload_size - mq._queue.put_nowait(MqttPublishMessage("topic", msg1, qos=1)) - mq._queue.put_nowait(MqttPublishMessage("topic", msg2, qos=1)) - - await asyncio.sleep(0.2) # Give loop time to pick up - - await mq.shutdown() - - mqtt_manager.publish.assert_called_once() - - -@pytest.mark.asyncio -async def test_schedule_delayed_retry_reenqueues_message(): - msg = MqttPublishMessage(topic="test/topic", payload=b"data", qos=1) - - queue = MessageQueue( - mqtt_manager=MagicMock(), - main_stop_event=asyncio.Event(), - device_rate_limiter=RateLimiter(MagicMock(), MagicMock(), MagicMock()), - message_adapter=MagicMock() - ) - - queue._active.set() - queue._main_stop_event.clear() - - queue._queue = asyncio.Queue() - queue._retry_tasks.clear() - - queue._schedule_delayed_retry(msg, delay=0.05) - - await asyncio.sleep(0.1) - - requeued_msg = await queue._queue.get() - assert requeued_msg.topic == msg.topic - - assert queue._wakeup_event.is_set() - - await asyncio.sleep(0) - - assert len(queue._retry_tasks) == 0 - - -@pytest.mark.asyncio -async def test_schedule_delayed_retry_does_nothing_if_inactive(): - msg = MqttPublishMessage(topic="inactive/topic", payload=b"data", qos=1) - - queue = MessageQueue( - mqtt_manager=MagicMock(), - main_stop_event=asyncio.Event(), - device_rate_limiter=RateLimiter(MagicMock(), MagicMock(), MagicMock()), - message_adapter=MagicMock() - ) - - 
queue._active.clear() - queue._main_stop_event.clear() - - queue._queue = asyncio.Queue() - queue._retry_tasks.clear() - - queue._schedule_delayed_retry(msg, delay=0.01) - - await asyncio.sleep(0.05) - - assert queue._queue.empty() - assert len(queue._retry_tasks) == 0 - - -@pytest.mark.asyncio -async def test_schedule_delayed_retry_does_nothing_if_stopped(): - msg = MqttPublishMessage(topic="stop/topic", payload=b"data", qos=1) - - stop_event = asyncio.Event() - stop_event.set() - - queue = MessageQueue( - mqtt_manager=MagicMock(), - main_stop_event=stop_event, - device_rate_limiter=RateLimiter(MagicMock(), MagicMock(), MagicMock()), - message_adapter=MagicMock() - ) - - queue._active.set() - queue._queue = asyncio.Queue() - queue._retry_tasks.clear() - - queue._schedule_delayed_retry(msg, delay=0.01) - - await asyncio.sleep(0.05) - - assert queue._queue.empty() - assert len(queue._retry_tasks) == 0 - - -@pytest.mark.asyncio -async def test_try_publish_telemetry_rate_limit_triggered(): - telemetry_limit = Mock(spec=RateLimit) - telemetry_limit.try_consume = AsyncMock(return_value=(10, 1)) - telemetry_limit.minimal_timeout = 0.05 - - mq = MessageQueue( - mqtt_manager=AsyncMock(), - main_stop_event=asyncio.Event(), - device_rate_limiter=RateLimiter(MagicMock(), telemetry_limit, MagicMock()), - message_adapter=Mock(build_uplink_messages=Mock(return_value=[])) - ) - mq._schedule_delayed_retry = AsyncMock() - mq._backpressure = Mock(spec=BackpressureController) - mq._backpressure.should_pause.return_value = False - - message = MqttPublishMessage(topic=DEVICE_TELEMETRY_TOPIC, payload=b"{}", qos=1) - - await mq._try_publish(message) - - await mq.shutdown() - - mq._schedule_delayed_retry.assert_called_once_with(message, delay=telemetry_limit.minimal_timeout) - called_args = mq._schedule_delayed_retry.call_args - assert called_args.args[0] == message - - -@pytest.mark.asyncio -async def test_try_publish_telemetry_dp_rate_limit_triggered(): - dp_limit = Mock(spec=RateLimit) - dp_limit.try_consume = AsyncMock(return_value=(100, 10)) - dp_limit.minimal_timeout = 0.1 - - mq = MessageQueue( - mqtt_manager=AsyncMock(), - main_stop_event=asyncio.Event(), - device_rate_limiter=RateLimiter(RateLimit('0:0,'), RateLimit('0:0,'), dp_limit), - message_adapter=Mock(build_uplink_messages=Mock(return_value=[])) - ) - mq._schedule_delayed_retry = AsyncMock() - mq._backpressure = Mock(spec=BackpressureController) - mq._backpressure.should_pause.return_value = False - - message = MqttPublishMessage( - topic=DEVICE_TELEMETRY_TOPIC, - payload=b"{}", - qos=1, - datapoints=5 - ) - - await mq._try_publish(message) - - mq._schedule_delayed_retry.assert_called_once_with(message, delay=dp_limit.minimal_timeout) - - -@pytest.mark.asyncio -async def test_try_publish_generic_message_rate_limit_triggered(): - generic_limit = Mock(spec=RateLimit) - generic_limit.try_consume = AsyncMock(return_value=(1, 60)) - generic_limit.minimal_timeout = 0.2 - - mq = MessageQueue( - mqtt_manager=AsyncMock(), - main_stop_event=asyncio.Event(), - device_rate_limiter=RateLimiter(generic_limit, MagicMock(), MagicMock()), - message_adapter=Mock() - ) - mq._schedule_delayed_retry = AsyncMock() - mq._backpressure = Mock(spec=BackpressureController) - mq._backpressure.should_pause.return_value = False - - message = MqttPublishMessage(topic="some/other/topic", payload=b"{}", qos=1) - - await mq._try_publish(message) - - mq._schedule_delayed_retry.assert_called_once_with(message, delay=generic_limit.minimal_timeout) - called_args = 
mq._schedule_delayed_retry.call_args - assert called_args.args[0] == message - - -from unittest.mock import AsyncMock, patch - -@pytest.mark.asyncio -async def test_batch_breaks_on_elapsed_time(): - msg = MqttPublishMessage( - "topic", - DeviceUplinkMessageBuilder().add_timeseries(TimeseriesEntry("k", 1)).build(), - qos=1 - ) - - adapter = JsonMessageAdapter() - adapter.splitter.max_payload_size = 1000000 - - stop = asyncio.Event() - mqtt_manager = AsyncMock() - mqtt_manager.backpressure.should_pause.return_value = False - - mq = MessageQueue( - mqtt_manager=mqtt_manager, - main_stop_event=stop, - device_rate_limiter=RateLimiter(MagicMock(), MagicMock(), MagicMock()), - message_adapter=adapter, - batch_collect_max_time_ms=1, - batch_collect_max_count=1000 - ) - - mq._try_publish = AsyncMock() - - mq._queue.put_nowait(msg) - await asyncio.sleep(0.1) - await mq.shutdown() - - mq._try_publish.assert_called_once() - - -@pytest.mark.asyncio -async def test_batch_breaks_on_message_count(): - adapter = JsonMessageAdapter() - adapter.splitter.max_payload_size = 1000000 - - mqtt_manager = AsyncMock() - mqtt_manager.publish = AsyncMock(return_value=asyncio.Future()) - mqtt_manager.publish.return_value.set_result(True) - mqtt_manager.connected = AsyncMock() - mqtt_manager.connected.is_set.return_value = True - mqtt_manager.backpressure.should_pause.return_value = False - - stop = asyncio.Event() - - mq = MessageQueue( - mqtt_manager=mqtt_manager, - main_stop_event=stop, - device_rate_limiter=RateLimiter(MagicMock(), MagicMock(), MagicMock()), - message_adapter=adapter, - batch_collect_max_time_ms=1000, - batch_collect_max_count=2 - ) - mq._try_publish = AsyncMock() - - msg1 = MqttPublishMessage("topic", DeviceUplinkMessageBuilder().add_timeseries(TimeseriesEntry("a", 1)).build(), qos=1) - msg2 = MqttPublishMessage("topic", DeviceUplinkMessageBuilder().add_timeseries(TimeseriesEntry("b", 2)).build(), qos=1) - msg3 = MqttPublishMessage("topic", DeviceUplinkMessageBuilder().add_timeseries(TimeseriesEntry("c", 3)).build(), qos=1) - - await mq._queue.put(msg1) - await mq._queue.put(msg2) - await mq._queue.put(msg3) - await asyncio.sleep(0.2) - await mq.shutdown() - - assert mq._try_publish.call_count == 2 - - -@pytest.mark.asyncio -async def test_batch_breaks_on_type_mismatch(): - - adapter = JsonMessageAdapter() - adapter.splitter.max_payload_size = 1000000 - - mqtt_manager = AsyncMock() - mqtt_manager.publish = AsyncMock() - mqtt_manager.backpressure.should_pause.return_value = False - - stop = asyncio.Event() - - mq = MessageQueue( - mqtt_manager=mqtt_manager, - main_stop_event=stop, - device_rate_limiter=RateLimiter(MagicMock(), MagicMock(), MagicMock()), - message_adapter=adapter, - batch_collect_max_time_ms=1000, - batch_collect_max_count=10 - ) - mq._try_publish = AsyncMock() - - msg1 = MqttPublishMessage("topic", - DeviceUplinkMessageBuilder() - .add_timeseries(TimeseriesEntry("a", 1)) - .build(), qos=1) - msg2 = MqttPublishMessage("topic", - GatewayUplinkMessageBuilder() - .set_device_name("test_device") - .add_timeseries(TimeseriesEntry("b", 2)) - .build(), - qos=1) - - mq._queue.put_nowait(msg1) - mq._queue.put_nowait(msg2) - - await asyncio.sleep(0.2) - await mq.shutdown() - - assert mq._try_publish.call_count == 2 - - -@pytest.mark.asyncio -async def test_batch_breaks_on_size_threshold(): - from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder - - adapter = JsonMessageAdapter() - adapter.splitter.max_payload_size = 20 - - mqtt_manager = AsyncMock() - 
mqtt_manager.publish = AsyncMock() - mqtt_manager.backpressure.should_pause.return_value = False - - stop = asyncio.Event() - - mq = MessageQueue( - mqtt_manager=mqtt_manager, - main_stop_event=stop, - device_rate_limiter=RateLimiter(MagicMock(), MagicMock(), MagicMock()), - message_adapter=adapter, - batch_collect_max_time_ms=1000, - batch_collect_max_count=10 - ) - mq._try_publish = AsyncMock() - - msg1 = MqttPublishMessage("topic", DeviceUplinkMessageBuilder().add_timeseries(TimeseriesEntry("a", 1)).build(), qos=1) - msg2 = MqttPublishMessage("topic", DeviceUplinkMessageBuilder().add_timeseries(TimeseriesEntry("b", 2)).build(), qos=1) - - mq._queue.put_nowait(msg1) - mq._queue.put_nowait(msg2) - - await asyncio.sleep(0.2) - await mq.shutdown() - - assert mq._try_publish.call_count == 2 - - -@pytest.mark.asyncio -async def test_batch_skips_bytes_payload(): - mqtt_manager = AsyncMock() - mqtt_manager.publish = AsyncMock() - mqtt_manager.backpressure.should_pause.return_value = False - - stop = asyncio.Event() - - mq = MessageQueue( - mqtt_manager=mqtt_manager, - main_stop_event=stop, - device_rate_limiter=RateLimiter(MagicMock(), MagicMock(), MagicMock()), - message_adapter=JsonMessageAdapter(), - batch_collect_max_time_ms=1000, - batch_collect_max_count=1000 - ) - mq._try_publish = AsyncMock() - - - msg1 = MqttPublishMessage("topic", DeviceUplinkMessageBuilder().add_timeseries(TimeseriesEntry("a", 1)).build(), qos=1) - msg2 = MqttPublishMessage("topic", b"raw", qos=1) - - mq._queue.put_nowait(msg1) - mq._queue.put_nowait(msg2) - - await asyncio.sleep(0.2) - await mq.shutdown() - assert mq._try_publish.call_count == 2 - - -if __name__ == '__main__': - pytest.main([__file__, '-v', '--tb=short']) diff --git a/tests/service/test_message_service.py b/tests/service/test_message_service.py new file mode 100644 index 0000000..9eddefa --- /dev/null +++ b/tests/service/test_message_service.py @@ -0,0 +1,973 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
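# The tests below cover the two cooperating parts of the publish pipeline:
# MessageService, which drains an initial queue and routes raw/service
# payloads to a service queue and uplink messages through the device or
# gateway message adapter into their own queues, and MessageQueueWorker,
# which hands a message to MQTTManager.publish when no rate limit is hit and
# otherwise reports the triggered limit so the dispatch loop can reinsert the
# message at the front of its queue and wait for tokens. A rough sketch of
# the worker contract as exercised by these tests (not a formal API
# description):
#
#     result = await worker.process(message)
#     if result == (None, None, None):
#         pass  # message was handed to MQTTManager.publish
#     else:
#         tokens, datapoints, limit = result  # a rate limit was triggered;
#         # the caller reinserts the message and waits for the limit to refill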
+ +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch, PropertyMock, call + +import pytest +import pytest_asyncio + +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit, EMPTY_RATE_LIMIT +from tb_mqtt_client.common.rate_limit.rate_limiter import RateLimiter +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage, DeviceUplinkMessageBuilder +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage, GatewayUplinkMessageBuilder +from tb_mqtt_client.service.device.message_adapter import MessageAdapter +from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter +from tb_mqtt_client.service.mqtt_manager import MQTTManager +from tb_mqtt_client.service.message_service import MessageService, MessageQueueWorker + + +@pytest_asyncio.fixture +async def setup_message_service(): + # Create mocks for all dependencies + mqtt_manager = MagicMock(spec=MQTTManager) + main_stop_event = asyncio.Event() + + # Mock rate limiters + device_rate_limiter = MagicMock(spec=RateLimiter) + device_rate_limiter.message_rate_limit = MagicMock(spec=RateLimit) + device_rate_limiter.telemetry_message_rate_limit = MagicMock(spec=RateLimit) + device_rate_limiter.telemetry_datapoints_rate_limit = MagicMock(spec=RateLimit) + device_rate_limiter.values.return_value = [ + device_rate_limiter.message_rate_limit, + device_rate_limiter.telemetry_message_rate_limit, + device_rate_limiter.telemetry_datapoints_rate_limit + ] + + gateway_rate_limiter = MagicMock(spec=RateLimiter) + gateway_rate_limiter.message_rate_limit = MagicMock(spec=RateLimit) + gateway_rate_limiter.telemetry_message_rate_limit = MagicMock(spec=RateLimit) + gateway_rate_limiter.telemetry_datapoints_rate_limit = MagicMock(spec=RateLimit) + gateway_rate_limiter.values.return_value = [ + gateway_rate_limiter.message_rate_limit, + gateway_rate_limiter.telemetry_message_rate_limit, + gateway_rate_limiter.telemetry_datapoints_rate_limit + ] + + # Mock message adapters + message_adapter = MagicMock(spec=MessageAdapter) + gateway_message_adapter = MagicMock(spec=GatewayMessageAdapter) + + # Create the service with patched asyncio.create_task to avoid actual task creation + with patch('asyncio.create_task', new=MagicMock()): + service = MessageService( + mqtt_manager=mqtt_manager, + main_stop_event=main_stop_event, + device_rate_limiter=device_rate_limiter, + message_adapter=message_adapter, + gateway_message_adapter=gateway_message_adapter, + gateway_rate_limiter=gateway_rate_limiter + ) + + # Replace the actual tasks with mocks + service._initial_queue_task = MagicMock() + service._service_queue_task = MagicMock() + service._device_uplink_messages_queue_task = MagicMock() + service._gateway_uplink_messages_queue_task = MagicMock() + service._rate_limit_refill_task = MagicMock() + service.__print_queue_statistics_task = MagicMock() + + yield service, mqtt_manager, main_stop_event, device_rate_limiter, message_adapter, gateway_message_adapter, gateway_rate_limiter + + +@pytest.mark.asyncio +async def test_publish_success(setup_message_service): + service, mqtt_manager, _, _, _, _, _ = setup_message_service + + # Mock the queue put method + service._initial_queue.put = AsyncMock() + + # Create a test message + message = MqttPublishMessage("test/topic", b"test_payload") + + # Call the publish method + await service.publish(message) + + # 
Verify the message was added to the queue + service._initial_queue.put.assert_awaited_once_with(message) + + +@pytest.mark.asyncio +async def test_publish_exception(setup_message_service): + service, mqtt_manager, _, _, _, _, _ = setup_message_service + + # Mock the queue put method to raise an exception + service._initial_queue.put = AsyncMock(side_effect=Exception("Queue error")) + + # Create a test message with a delivery future + future = asyncio.Future() + message = MqttPublishMessage("test/topic", b"test_payload", delivery_futures=[future]) + + # Call the publish method + await service.publish(message) + + # Verify the future was completed with an error result + assert future.done() + result = future.result() + assert isinstance(result, PublishResult) + assert result.reason_code == -1 + + +@pytest.mark.asyncio +async def test_shutdown(): + """Test the shutdown method with real tasks to ensure full coverage.""" + # Create mocks for all dependencies + mqtt_manager = MagicMock(spec=MQTTManager) + main_stop_event = asyncio.Event() + device_rate_limiter = MagicMock(spec=RateLimiter) + device_rate_limiter.values.return_value = [] + gateway_rate_limiter = MagicMock(spec=RateLimiter) + gateway_rate_limiter.values.return_value = [] + message_adapter = MagicMock(spec=MessageAdapter) + gateway_message_adapter = MagicMock(spec=GatewayMessageAdapter) + + # Create a real service with real tasks + with patch('asyncio.create_task', side_effect=asyncio.create_task) as mock_create_task: + # Create simple async functions for the tasks + async def mock_task(): + try: + while True: + await asyncio.sleep(0.1) + except asyncio.CancelledError: + # This is expected during shutdown + pass + + # Patch the methods that would create tasks to return our mock task + with patch.object(MessageService, '_dispatch_initial_queue_loop', return_value=mock_task()), \ + patch.object(MessageService, '_dispatch_queue_loop', return_value=mock_task()), \ + patch.object(MessageService, '_rate_limit_refill_loop', return_value=mock_task()), \ + patch.object(MessageService, 'print_queue_statistics', return_value=mock_task()), \ + patch.object(MessageService, 'clear', new_callable=AsyncMock) as mock_clear: + + # Create the service + service = MessageService( + mqtt_manager=mqtt_manager, + main_stop_event=main_stop_event, + device_rate_limiter=device_rate_limiter, + message_adapter=message_adapter, + gateway_message_adapter=gateway_message_adapter, + gateway_rate_limiter=gateway_rate_limiter + ) + + # Verify that tasks were created + assert mock_create_task.call_count >= 5 + + # Call the shutdown method + await service.shutdown() + + # Verify the active flag was cleared + assert not service._active.is_set() + + # Verify the clear method was called + mock_clear.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_is_empty(setup_message_service): + service, _, _, _, _, _, _ = setup_message_service + + # Mock the queue is_empty method + service._initial_queue.is_empty = MagicMock(return_value=True) + + # Check if the queue is empty + assert service.is_empty() + + # Verify the is_empty method was called + service._initial_queue.is_empty.assert_called_once() + + +@pytest.mark.asyncio +async def test_clear(setup_message_service): + service, _, _, _, _, _, _ = setup_message_service + + # Mock the queue methods + service._initial_queue.is_empty = MagicMock(side_effect=[False, True]) + service._initial_queue.get = AsyncMock(return_value=MqttPublishMessage("topic", b"payload")) + + service._service_queue.is_empty = 
MagicMock(side_effect=[False, True]) + service._service_queue.get = AsyncMock(return_value=MqttPublishMessage("topic", b"payload")) + + service._device_uplink_messages_queue.is_empty = MagicMock(side_effect=[False, True]) + service._device_uplink_messages_queue.get = AsyncMock(return_value=MqttPublishMessage("topic", b"payload")) + + service._gateway_uplink_messages_queue.is_empty = MagicMock(side_effect=[False, True]) + service._gateway_uplink_messages_queue.get = AsyncMock(return_value=MqttPublishMessage("topic", b"payload")) + + # Call the clear method + await service.clear() + + # Verify the get method was called for each queue + service._initial_queue.get.assert_awaited_once() + service._service_queue.get.assert_awaited_once() + service._device_uplink_messages_queue.get.assert_awaited_once() + service._gateway_uplink_messages_queue.get.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_set_gateway_message_adapter(setup_message_service): + service, _, _, _, _, _, gateway_message_adapter = setup_message_service + + # Create a new mock adapter + new_adapter = MagicMock(spec=GatewayMessageAdapter) + + # Set the new adapter + service.set_gateway_message_adapter(new_adapter) + + # Verify the adapter was set + assert service._gateway_adapter == new_adapter + + +@pytest.mark.asyncio +async def test_dispatch_initial_queue_loop_service_message(setup_message_service): + service, mqtt_manager, main_stop_event, _, _, _, _ = setup_message_service + + # Set up the test to run once and then exit + main_stop_event.is_set = MagicMock(side_effect=[False, True]) + + # Mock the queue methods + service._initial_queue.peek_batch = AsyncMock(return_value=[ + MqttPublishMessage("topic", b"raw_payload") + ]) + service._initial_queue.pop_n = AsyncMock() + service._initial_queue.is_empty = MagicMock(return_value=True) + + service._service_queue.put = AsyncMock() + + # Run the dispatch loop + with patch('asyncio.sleep', new=AsyncMock()): + await service._dispatch_initial_queue_loop() + + # Verify the message was put in the service queue + service._service_queue.put.assert_awaited_once() + service._initial_queue.pop_n.assert_awaited_once_with(1) + + +@pytest.mark.asyncio +async def test_dispatch_initial_queue_loop_device_message(setup_message_service): + service, mqtt_manager, main_stop_event, _, message_adapter, _, _ = setup_message_service + + # Set up the test to run once and then exit + main_stop_event.is_set = MagicMock(side_effect=[False, True]) + + # Create a device uplink message using the builder + device_uplink = DeviceUplinkMessageBuilder().set_device_name("test-device").build() + device_message = MqttPublishMessage("topic", device_uplink) + + # Mock the queue methods + service._initial_queue.peek_batch = AsyncMock(return_value=[device_message]) + service._initial_queue.pop_n = AsyncMock() + service._initial_queue.is_empty = MagicMock(return_value=True) + + service._device_uplink_messages_queue.extend = AsyncMock() + + # Mock the adapter + processed_messages = [MqttPublishMessage("processed_topic", b"processed_payload")] + message_adapter.build_uplink_messages.return_value = processed_messages + + # Run the dispatch loop + with patch('asyncio.sleep', new=AsyncMock()): + await service._dispatch_initial_queue_loop() + + # Verify the adapter was called and messages were added to the device queue + message_adapter.build_uplink_messages.assert_called_once_with([device_message]) + service._device_uplink_messages_queue.extend.assert_awaited_once_with(processed_messages) + 
service._initial_queue.pop_n.assert_awaited_once_with(1) + + +@pytest.mark.asyncio +async def test_dispatch_initial_queue_loop_gateway_message(setup_message_service): + service, mqtt_manager, main_stop_event, _, message_adapter, gateway_message_adapter, _ = setup_message_service + + # Set up the test to run once and then exit + main_stop_event.is_set = MagicMock(side_effect=[False, True]) + + # Create a gateway uplink message using the builder + gateway_uplink = GatewayUplinkMessageBuilder().set_device_name("test-device").build() + gateway_message = MqttPublishMessage("topic", gateway_uplink) + + # Mock the queue methods + service._initial_queue.peek_batch = AsyncMock(return_value=[gateway_message]) + service._initial_queue.pop_n = AsyncMock() + service._initial_queue.is_empty = MagicMock(return_value=True) + + service._gateway_uplink_messages_queue.extend = AsyncMock() + + # Mock the adapter + processed_messages = [MqttPublishMessage("processed_topic", b"processed_payload")] + gateway_message_adapter.build_uplink_messages.return_value = processed_messages + + # Run the dispatch loop + with patch('asyncio.sleep', new=AsyncMock()): + await service._dispatch_initial_queue_loop() + + # Verify the adapter was called and messages were added to the gateway queue + gateway_message_adapter.build_uplink_messages.assert_called_once_with([gateway_message]) + service._gateway_uplink_messages_queue.extend.assert_awaited_once_with(processed_messages) + service._initial_queue.pop_n.assert_awaited_once_with(1) + + +@pytest.mark.asyncio +async def test_dispatch_initial_queue_loop_exception(setup_message_service): + service, mqtt_manager, main_stop_event, _, _, _, _ = setup_message_service + + # Set up the test to run once and then exit + main_stop_event.is_set = MagicMock(side_effect=[False, True]) + + # Mock the queue methods to raise an exception + service._initial_queue.peek_batch = AsyncMock(side_effect=Exception("Test exception")) + + # Run the dispatch loop + with patch('asyncio.sleep', new=AsyncMock()): + await service._dispatch_initial_queue_loop() + + # The test passes if no exception is raised + + +@pytest.mark.asyncio +async def test_dispatch_queue_loop(setup_message_service): + service, mqtt_manager, main_stop_event, _, _, _, _ = setup_message_service + + # Set up the test to run once and then exit + main_stop_event.is_set = MagicMock(side_effect=[False, True]) + + # Create a mock queue and worker + queue = AsyncMock() + worker = MagicMock(spec=MessageQueueWorker) + + # Mock the queue methods + message = MqttPublishMessage("topic", b"payload") + queue.get = AsyncMock(return_value=message) + queue.is_empty = MagicMock(return_value=True) + + # Mock the worker process method + worker.process = AsyncMock(return_value=(None, None, None)) + + # Run the dispatch loop + with patch('asyncio.sleep', new=AsyncMock()): + await service._dispatch_queue_loop(queue, worker) + + # Verify the worker process method was called + worker.process.assert_awaited_once_with(message) + + +@pytest.mark.asyncio +async def test_dispatch_queue_loop_rate_limit_triggered(setup_message_service): + service, mqtt_manager, main_stop_event, _, _, _, _ = setup_message_service + + # Set up the test to run once and then exit + main_stop_event.is_set = MagicMock(side_effect=[False, True]) + + # Create a mock queue and worker + queue = AsyncMock() + worker = MagicMock(spec=MessageQueueWorker) + + # Mock the queue methods + message = MqttPublishMessage("topic", b"payload") + queue.get = AsyncMock(return_value=message) + queue.is_empty = 
MagicMock(return_value=True) + queue.reinsert_front = AsyncMock() + + # Mock the worker process method to trigger rate limit + rate_limit = MagicMock(spec=RateLimit) + rate_limit.required_tokens_ready = asyncio.Event() + rate_limit.required_tokens_ready.set() # Set it to avoid waiting in the test + worker.process = AsyncMock(return_value=(10, 5, rate_limit)) + + # Run the dispatch loop + with patch('asyncio.sleep', new=AsyncMock()): + await service._dispatch_queue_loop(queue, worker) + + # Verify the worker process method was called + worker.process.assert_awaited_once_with(message) + + # Verify the message was reinserted to the front of the queue + queue.reinsert_front.assert_awaited_once_with(message) + + # Verify the rate limit was set + rate_limit.set_required_tokens.assert_called_once_with(10, 5) + + +@pytest.mark.asyncio +async def test_dispatch_queue_loop_empty_message(setup_message_service): + service, mqtt_manager, main_stop_event, _, _, _, _ = setup_message_service + + # Set up the test to run once and then exit + main_stop_event.is_set = MagicMock(side_effect=[False, True]) + + # Create a mock queue and worker + queue = AsyncMock() + worker = MagicMock(spec=MessageQueueWorker) + + # Mock the queue methods to return None + queue.get = AsyncMock(return_value=None) + + # Run the dispatch loop + with patch('asyncio.sleep', new=AsyncMock()): + await service._dispatch_queue_loop(queue, worker) + + # Verify the worker process method was not called + worker.process.assert_not_called() + + +@pytest.mark.asyncio +async def test_dispatch_queue_loop_exception(setup_message_service): + service, mqtt_manager, main_stop_event, _, _, _, _ = setup_message_service + + # Set up the test to run once and then exit + main_stop_event.is_set = MagicMock(side_effect=[False, True]) + + # Create a mock queue and worker + queue = AsyncMock() + worker = MagicMock(spec=MessageQueueWorker) + + # Mock the queue methods to raise an exception + queue.get = AsyncMock(side_effect=Exception("Test exception")) + + # Run the dispatch loop + with patch('asyncio.sleep', new=AsyncMock()): + await service._dispatch_queue_loop(queue, worker) + + # The test passes if no exception is raised + + +@pytest.mark.asyncio +async def test_rate_limit_refill_loop(setup_message_service): + service, mqtt_manager, _, device_rate_limiter, _, _, gateway_rate_limiter = setup_message_service + + # Set up the test to run once and then exit + service._active.is_set = MagicMock(side_effect=[True, False]) + + # Mock the refill methods + service._refill_rate_limits = AsyncMock() + + # Run the refill loop + with patch('asyncio.sleep', new=AsyncMock()): + await service._rate_limit_refill_loop() + + # Verify the refill method was called + service._refill_rate_limits.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_refill_rate_limits(setup_message_service): + service, mqtt_manager, _, device_rate_limiter, _, _, gateway_rate_limiter = setup_message_service + + # Mock the rate limit refill methods + device_rate_limit = MagicMock(spec=RateLimit) + device_rate_limit.refill = AsyncMock() + device_rate_limiter.values.return_value = [device_rate_limit] + + gateway_rate_limit = MagicMock(spec=RateLimit) + gateway_rate_limit.refill = AsyncMock() + gateway_rate_limiter.values.return_value = [gateway_rate_limit] + + # Call the refill method + await service._refill_rate_limits() + + # Verify the refill methods were called + device_rate_limit.refill.assert_awaited_once() + gateway_rate_limit.refill.assert_awaited_once() + + +@pytest.mark.asyncio 
+async def test_print_queue_statistics(setup_message_service): + service, mqtt_manager, main_stop_event, _, _, _, _ = setup_message_service + + # Create a patched version of the print_queue_statistics method to avoid infinite loop + original_print_queue_statistics = service.print_queues_statistics + + async def patched_print_queue_statistics(): + # Just run the body of the loop once + initial_queue_size = service._initial_queue.size() + service_queue_size = service._service_queue.size() + device_uplink_queue_size = service._device_uplink_messages_queue.size() + gateway_uplink_queue_size = service._gateway_uplink_messages_queue.size() + # We don't need to log anything in the test + + # Replace the method with our patched version + service.print_queues_statistics = patched_print_queue_statistics + + # Mock the queue size methods + service._initial_queue.size = MagicMock(return_value=1) + service._service_queue.size = MagicMock(return_value=2) + service._device_uplink_messages_queue.size = MagicMock(return_value=3) + service._gateway_uplink_messages_queue.size = MagicMock(return_value=4) + + # Run the statistics method + await service.print_queues_statistics() + + # Verify the size methods were called + service._initial_queue.size.assert_called_once() + service._service_queue.size.assert_called_once() + service._device_uplink_messages_queue.size.assert_called_once() + service._gateway_uplink_messages_queue.size.assert_called_once() + + +@pytest.mark.asyncio +async def test_cancel_tasks(): + # Create mock tasks + task1 = MagicMock(spec=asyncio.Task) + task2 = MagicMock(spec=asyncio.Task) + tasks = {task1, task2} + + # Call the cancel_tasks method + with patch('asyncio.gather', new=AsyncMock()): + await MessageService._cancel_tasks(tasks) + + # Verify the tasks were canceled + task1.cancel.assert_called_once() + task2.cancel.assert_called_once() + assert len(tasks) == 0 + + +# Tests for MessageQueueWorker class +@pytest.mark.asyncio +async def test_message_queue_worker_process_no_rate_limits(): + # Create mocks + queue = AsyncMock() + stop_event = asyncio.Event() + mqtt_manager = MagicMock(spec=MQTTManager) + device_rate_limiter = MagicMock(spec=RateLimiter) + gateway_rate_limiter = MagicMock(spec=RateLimiter) + + # Create the worker + worker = MessageQueueWorker( + "TestWorker", + queue, + stop_event, + mqtt_manager, + device_rate_limiter, + gateway_rate_limiter + ) + + # Mock the rate limit methods + worker._get_rate_limits_for_message = MagicMock(return_value=(EMPTY_RATE_LIMIT, EMPTY_RATE_LIMIT)) + + # Create a test message + message = MqttPublishMessage("test/topic", b"test_payload") + + # Call the process method + result = await worker.process(message) + + # Verify the mqtt_manager.publish method was called + mqtt_manager.publish.assert_awaited_once_with(message) + + # Verify the result + assert result == (None, None, None) + + +@pytest.mark.asyncio +async def test_message_queue_worker_process_with_rate_limits_not_triggered(): + # Create mocks + queue = AsyncMock() + stop_event = asyncio.Event() + mqtt_manager = MagicMock(spec=MQTTManager) + device_rate_limiter = MagicMock(spec=RateLimiter) + gateway_rate_limiter = MagicMock(spec=RateLimiter) + + # Create the worker + worker = MessageQueueWorker( + "TestWorker", + queue, + stop_event, + mqtt_manager, + device_rate_limiter, + gateway_rate_limiter + ) + + # Create rate limits + message_rate_limit = MagicMock(spec=RateLimit) + message_rate_limit.has_limit.return_value = True + datapoints_rate_limit = MagicMock(spec=RateLimit) + 
datapoints_rate_limit.has_limit.return_value = True + + # Mock the rate limit methods + worker._get_rate_limits_for_message = MagicMock(return_value=(message_rate_limit, datapoints_rate_limit)) + worker.check_rate_limits_for_message = AsyncMock(return_value=(None, 0, None)) + worker._consume_rate_limits_for_message = AsyncMock() + + # Create a test message + message = MqttPublishMessage("test/topic", b"test_payload") + + # Call the process method + result = await worker.process(message) + + # Verify the rate limit methods were called + worker.check_rate_limits_for_message.assert_awaited_once_with( + datapoints_count=message.datapoints, + message_rate_limit=message_rate_limit, + datapoints_rate_limit=datapoints_rate_limit + ) + worker._consume_rate_limits_for_message.assert_awaited_once_with( + datapoints_count=message.datapoints, + message_rate_limit=message_rate_limit, + datapoints_rate_limit=datapoints_rate_limit + ) + + # Verify the mqtt_manager.publish method was called + mqtt_manager.publish.assert_awaited_once_with(message) + + # Verify the result + assert result == (None, None, None) + + +@pytest.mark.asyncio +async def test_message_queue_worker_process_with_rate_limits_triggered(): + # Create mocks + queue = AsyncMock() + stop_event = asyncio.Event() + mqtt_manager = MagicMock(spec=MQTTManager) + device_rate_limiter = MagicMock(spec=RateLimiter) + gateway_rate_limiter = MagicMock(spec=RateLimiter) + + # Create the worker + worker = MessageQueueWorker( + "TestWorker", + queue, + stop_event, + mqtt_manager, + device_rate_limiter, + gateway_rate_limiter + ) + + # Create rate limits + message_rate_limit = MagicMock(spec=RateLimit) + message_rate_limit.has_limit.return_value = True + datapoints_rate_limit = MagicMock(spec=RateLimit) + datapoints_rate_limit.has_limit.return_value = True + + # Mock the rate limit methods + worker._get_rate_limits_for_message = MagicMock(return_value=(message_rate_limit, datapoints_rate_limit)) + + # Set up the rate limit to be triggered + triggered_rate_limit_entry = (10, 5) # (tokens, duration) + expected_tokens = 5 + worker.check_rate_limits_for_message = AsyncMock(return_value=(triggered_rate_limit_entry, expected_tokens, message_rate_limit)) + + # Create a test message + message = MqttPublishMessage("test/topic", b"test_payload") + + # Call the process method + result = await worker.process(message) + + # Verify the rate limit methods were called + worker.check_rate_limits_for_message.assert_awaited_once_with( + datapoints_count=message.datapoints, + message_rate_limit=message_rate_limit, + datapoints_rate_limit=datapoints_rate_limit + ) + + # Verify the mqtt_manager.publish method was not called + mqtt_manager.publish.assert_not_awaited() + + # Verify the result + assert result == (5, 5, message_rate_limit) + + +@pytest.mark.asyncio +async def test_message_queue_worker_get_rate_limits_for_message_device_service(): + # Create mocks + queue = AsyncMock() + stop_event = asyncio.Event() + mqtt_manager = MagicMock(spec=MQTTManager) + + # Create device rate limiter with proper attributes + device_rate_limiter = MagicMock(spec=RateLimiter) + device_rate_limiter.message_rate_limit = MagicMock(spec=RateLimit) + device_rate_limiter.telemetry_message_rate_limit = MagicMock(spec=RateLimit) + device_rate_limiter.telemetry_datapoints_rate_limit = MagicMock(spec=RateLimit) + + # Create gateway rate limiter with proper attributes + gateway_rate_limiter = MagicMock(spec=RateLimiter) + gateway_rate_limiter.message_rate_limit = MagicMock(spec=RateLimit) + 
gateway_rate_limiter.telemetry_message_rate_limit = MagicMock(spec=RateLimit) + gateway_rate_limiter.telemetry_datapoints_rate_limit = MagicMock(spec=RateLimit) + + # Create the worker + worker = MessageQueueWorker( + "TestWorker", + queue, + stop_event, + mqtt_manager, + device_rate_limiter, + gateway_rate_limiter + ) + + # Create a test message + message = MqttPublishMessage("test/topic", b"test_payload") + message.is_device_message = True + message.is_service_message = True + + # Call the method + message_rate_limit, datapoints_rate_limit = worker._get_rate_limits_for_message(message) + + # Verify the correct rate limits were returned + assert message_rate_limit == device_rate_limiter.message_rate_limit + assert datapoints_rate_limit == EMPTY_RATE_LIMIT + + +@pytest.mark.asyncio +async def test_message_queue_worker_get_rate_limits_for_message_device_telemetry(): + # Create mocks + queue = AsyncMock() + stop_event = asyncio.Event() + mqtt_manager = MagicMock(spec=MQTTManager) + + # Create device rate limiter with proper attributes + device_rate_limiter = MagicMock(spec=RateLimiter) + device_rate_limiter.message_rate_limit = MagicMock(spec=RateLimit) + device_rate_limiter.telemetry_message_rate_limit = MagicMock(spec=RateLimit) + device_rate_limiter.telemetry_datapoints_rate_limit = MagicMock(spec=RateLimit) + + # Create gateway rate limiter with proper attributes + gateway_rate_limiter = MagicMock(spec=RateLimiter) + gateway_rate_limiter.message_rate_limit = MagicMock(spec=RateLimit) + gateway_rate_limiter.telemetry_message_rate_limit = MagicMock(spec=RateLimit) + gateway_rate_limiter.telemetry_datapoints_rate_limit = MagicMock(spec=RateLimit) + + # Create the worker + worker = MessageQueueWorker( + "TestWorker", + queue, + stop_event, + mqtt_manager, + device_rate_limiter, + gateway_rate_limiter + ) + + # Create a test message + message = MqttPublishMessage("test/topic", b"test_payload") + message.is_device_message = True + message.is_service_message = False + + # Call the method + message_rate_limit, datapoints_rate_limit = worker._get_rate_limits_for_message(message) + + # Verify the correct rate limits were returned + assert message_rate_limit == device_rate_limiter.telemetry_message_rate_limit + assert datapoints_rate_limit == device_rate_limiter.telemetry_datapoints_rate_limit + + +@pytest.mark.asyncio +async def test_message_queue_worker_get_rate_limits_for_message_gateway_service(): + # Create mocks + queue = AsyncMock() + stop_event = asyncio.Event() + mqtt_manager = MagicMock(spec=MQTTManager) + + # Create device rate limiter with proper attributes + device_rate_limiter = MagicMock(spec=RateLimiter) + device_rate_limiter.message_rate_limit = MagicMock(spec=RateLimit) + device_rate_limiter.telemetry_message_rate_limit = MagicMock(spec=RateLimit) + device_rate_limiter.telemetry_datapoints_rate_limit = MagicMock(spec=RateLimit) + + # Create gateway rate limiter with proper attributes + gateway_rate_limiter = MagicMock(spec=RateLimiter) + gateway_rate_limiter.message_rate_limit = MagicMock(spec=RateLimit) + gateway_rate_limiter.telemetry_message_rate_limit = MagicMock(spec=RateLimit) + gateway_rate_limiter.telemetry_datapoints_rate_limit = MagicMock(spec=RateLimit) + + # Create the worker + worker = MessageQueueWorker( + "TestWorker", + queue, + stop_event, + mqtt_manager, + device_rate_limiter, + gateway_rate_limiter + ) + + # Create a test message + message = MqttPublishMessage("test/topic", b"test_payload") + message.is_device_message = False + message.is_service_message = True 
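    # The worker is expected to pick the gateway message rate limit for a
    # gateway-side service message and to leave datapoints unlimited
    # (EMPTY_RATE_LIMIT), mirroring the device-side service-message case above.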
+ + # Call the method + message_rate_limit, datapoints_rate_limit = worker._get_rate_limits_for_message(message) + + # Verify the correct rate limits were returned + assert message_rate_limit == gateway_rate_limiter.message_rate_limit + assert datapoints_rate_limit == EMPTY_RATE_LIMIT + + +@pytest.mark.asyncio +async def test_message_queue_worker_get_rate_limits_for_message_gateway_telemetry(): + # Create mocks + queue = AsyncMock() + stop_event = asyncio.Event() + mqtt_manager = MagicMock(spec=MQTTManager) + + # Create device rate limiter with proper attributes + device_rate_limiter = MagicMock(spec=RateLimiter) + device_rate_limiter.message_rate_limit = MagicMock(spec=RateLimit) + device_rate_limiter.telemetry_message_rate_limit = MagicMock(spec=RateLimit) + device_rate_limiter.telemetry_datapoints_rate_limit = MagicMock(spec=RateLimit) + + # Create gateway rate limiter with proper attributes + gateway_rate_limiter = MagicMock(spec=RateLimiter) + gateway_rate_limiter.message_rate_limit = MagicMock(spec=RateLimit) + gateway_rate_limiter.telemetry_message_rate_limit = MagicMock(spec=RateLimit) + gateway_rate_limiter.telemetry_datapoints_rate_limit = MagicMock(spec=RateLimit) + + # Create the worker + worker = MessageQueueWorker( + "TestWorker", + queue, + stop_event, + mqtt_manager, + device_rate_limiter, + gateway_rate_limiter + ) + + # Create a test message + message = MqttPublishMessage("test/topic", b"test_payload") + message.is_device_message = False + message.is_service_message = False + + # Call the method + message_rate_limit, datapoints_rate_limit = worker._get_rate_limits_for_message(message) + + # Verify the correct rate limits were returned + assert message_rate_limit == gateway_rate_limiter.telemetry_message_rate_limit + assert datapoints_rate_limit == gateway_rate_limiter.telemetry_datapoints_rate_limit + + +@pytest.mark.asyncio +async def test_message_queue_worker_check_rate_limits_for_message_message_limit_triggered(): + # Create a mock message rate limit that will trigger + message_rate_limit = MagicMock(spec=RateLimit) + message_rate_limit.has_limit.return_value = True + message_rate_limit.try_consume = AsyncMock(return_value=(10, 5)) # (tokens, duration) + + # Create a mock datapoints rate limit that won't be checked + datapoints_rate_limit = MagicMock(spec=RateLimit) + datapoints_rate_limit.has_limit.return_value = True + + # Call the method + result = await MessageQueueWorker.check_rate_limits_for_message( + datapoints_count=5, + message_rate_limit=message_rate_limit, + datapoints_rate_limit=datapoints_rate_limit + ) + + # Verify the result + assert result == ((10, 5), 1, message_rate_limit) + + # Verify the message rate limit was checked + message_rate_limit.try_consume.assert_awaited_once_with(1) + + # Verify the datapoints rate limit was not checked + datapoints_rate_limit.try_consume.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_message_queue_worker_check_rate_limits_for_message_datapoints_limit_triggered(): + # Create a mock message rate limit that won't trigger + message_rate_limit = MagicMock(spec=RateLimit) + message_rate_limit.has_limit.return_value = True + message_rate_limit.try_consume = AsyncMock(return_value=None) + + # Create a mock datapoints rate limit that will trigger + datapoints_rate_limit = MagicMock(spec=RateLimit) + datapoints_rate_limit.has_limit.return_value = True + datapoints_rate_limit.try_consume = AsyncMock(return_value=(20, 10)) # (tokens, duration) + + # Call the method + result = await 
MessageQueueWorker.check_rate_limits_for_message( + datapoints_count=5, + message_rate_limit=message_rate_limit, + datapoints_rate_limit=datapoints_rate_limit + ) + + # Verify the result + assert result == ((20, 10), 5, datapoints_rate_limit) + + # Verify both rate limits were checked + message_rate_limit.try_consume.assert_awaited_once_with(1) + datapoints_rate_limit.try_consume.assert_awaited_once_with(5) + + +@pytest.mark.asyncio +async def test_message_queue_worker_check_rate_limits_for_message_no_limits_triggered(): + # Create mock rate limits that won't trigger + message_rate_limit = MagicMock(spec=RateLimit) + message_rate_limit.has_limit.return_value = True + message_rate_limit.try_consume = AsyncMock(return_value=None) + + datapoints_rate_limit = MagicMock(spec=RateLimit) + datapoints_rate_limit.has_limit.return_value = True + datapoints_rate_limit.try_consume = AsyncMock(return_value=None) + + # Call the method + result = await MessageQueueWorker.check_rate_limits_for_message( + datapoints_count=5, + message_rate_limit=message_rate_limit, + datapoints_rate_limit=datapoints_rate_limit + ) + + # Verify the result + assert result == (None, 0, None) + + # Verify both rate limits were checked + message_rate_limit.try_consume.assert_awaited_once_with(1) + datapoints_rate_limit.try_consume.assert_awaited_once_with(5) + + +@pytest.mark.asyncio +async def test_message_queue_worker_consume_rate_limits_for_message(): + # Create mock rate limits + message_rate_limit = MagicMock(spec=RateLimit) + message_rate_limit.has_limit.return_value = True + message_rate_limit.consume = AsyncMock() + + datapoints_rate_limit = MagicMock(spec=RateLimit) + datapoints_rate_limit.has_limit.return_value = True + datapoints_rate_limit.consume = AsyncMock() + + # Call the method + await MessageQueueWorker._consume_rate_limits_for_message( + datapoints_count=5, + message_rate_limit=message_rate_limit, + datapoints_rate_limit=datapoints_rate_limit + ) + + # Verify both rate limits were consumed + message_rate_limit.consume.assert_awaited_once_with(1) + datapoints_rate_limit.consume.assert_awaited_once_with(5) + + +@pytest.mark.asyncio +async def test_message_queue_worker_consume_rate_limits_for_message_no_limits(): + # Create mock rate limits with no limits + message_rate_limit = MagicMock(spec=RateLimit) + message_rate_limit.has_limit.return_value = False + + datapoints_rate_limit = MagicMock(spec=RateLimit) + datapoints_rate_limit.has_limit.return_value = False + + # Call the method + await MessageQueueWorker._consume_rate_limits_for_message( + datapoints_count=5, + message_rate_limit=message_rate_limit, + datapoints_rate_limit=datapoints_rate_limit + ) + + # Verify no rate limits were consumed + message_rate_limit.consume.assert_not_awaited() + datapoints_rate_limit.consume.assert_not_awaited() + +if __name__ == '__main__': + pytest.main([__file__, '-v', '--tb=short']) diff --git a/tests/test_async_utils.py b/tests/test_async_utils.py new file mode 100644 index 0000000..2ea8010 --- /dev/null +++ b/tests/test_async_utils.py @@ -0,0 +1,68 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from unittest import IsolatedAsyncioTestCase +from unittest.mock import patch + +from tb_mqtt_client.common.async_utils import await_and_resolve_original + + +class TestAwaitAndResolveOriginal(IsolatedAsyncioTestCase): + + async def resolves_parent_with_first_successful_child_result(self): + parent_future = asyncio.Future() + child_future_1 = asyncio.Future() + child_future_2 = asyncio.Future() + + child_future_1.set_result("success") + child_future_2.set_exception(ValueError("error")) + + await await_and_resolve_original([parent_future], [child_future_1, child_future_2]) + + self.assertTrue(parent_future.done()) + self.assertEqual(parent_future.result(), "success") + + async def resolves_parent_with_first_exception_if_no_successful_results(self): + parent_future = asyncio.Future() + child_future_1 = asyncio.Future() + child_future_2 = asyncio.Future() + + child_future_1.set_exception(ValueError("error1")) + child_future_2.set_exception(ValueError("error2")) + + await await_and_resolve_original([parent_future], [child_future_1, child_future_2]) + + self.assertTrue(parent_future.done()) + self.assertIsInstance(parent_future.exception(), ValueError) + self.assertEqual(str(parent_future.exception()), "error1") + + async def handles_empty_child_futures_list(self): + parent_future = asyncio.Future() + + await await_and_resolve_original([parent_future], []) + + self.assertTrue(parent_future.done()) + self.assertIsNone(parent_future.result()) + + async def sets_exception_on_parent_if_unexpected_error_occurs(self): + parent_future = asyncio.Future() + child_future = asyncio.Future() + + with patch("tb_mqtt_client.common.async_utils.future_map.child_resolved", side_effect=Exception("unexpected")): + await await_and_resolve_original([parent_future], [child_future]) + + self.assertTrue(parent_future.done()) + self.assertIsInstance(parent_future.exception(), Exception) + self.assertEqual(str(parent_future.exception()), "unexpected") \ No newline at end of file From f8f13f7bc0a511a9bcc36063b773c6dd0de5faff Mon Sep 17 00:00:00 2001 From: imbeacon Date: Fri, 1 Aug 2025 14:31:36 +0300 Subject: [PATCH 63/74] Removed old implementation, added tests --- sdk_utils.py | 85 - tb_device_http.py | 465 ----- tb_device_mqtt.py | 1654 ----------------- tb_gateway_mqtt.py | 356 ---- tb_mqtt_client/common/async_utils.py | 34 +- tb_mqtt_client/common/mqtt_message.py | 4 - tb_mqtt_client/common/publish_result.py | 3 +- .../common/rate_limit/rate_limit.py | 2 +- .../entities/data/attribute_entry.py | 2 - .../entities/data/attribute_request.py | 1 + tb_mqtt_client/entities/data/data_entry.py | 1 + .../entities/data/device_uplink_message.py | 1 - tb_mqtt_client/entities/data/rpc_response.py | 2 +- .../entities/data/timeseries_entry.py | 2 +- .../entities/gateway/device_info.py | 12 +- .../gateway/gateway_attribute_request.py | 1 + .../gateway/gateway_attribute_update.py | 16 +- .../entities/gateway/gateway_claim_request.py | 2 +- .../gateway_requested_attribute_response.py | 3 + .../entities/gateway/gateway_rpc_request.py | 2 +- .../gateway/gateway_uplink_message.py | 5 +- .../service/base_message_splitter.py | 6 - .../service/device/firmware_updater.py | 10 +- .../service/device/message_adapter.py | 2 +- tb_mqtt_client/service/gateway/client.py | 14 +- .../service/gateway/device_manager.py | 3 +- .../service/gateway/device_session.py | 5 +- .../gateway/direct_event_dispatcher.py | 2 + 
.../service/gateway/message_adapter.py | 1 - .../service/gateway/message_splitter.py | 8 +- tests/common/test_async_utils.py | 134 ++ tests/common/test_backpressure_controller.py | 4 +- tests/common/test_config_loader.py | 2 + tests/common/test_gmqtt_patch.py | 144 ++ tests/common/test_provisioning_client.py | 3 +- tests/common/test_publish_result.py | 67 + tests/common/test_queue.py | 155 ++ tests/common/test_rate_limit.py | 108 ++ tests/entities/data/test_timeseries_entry.py | 1 + tests/entities/gateway/__init__.py | 13 + .../gateway/test_base_gateway_event.py | 61 + .../gateway/test_device_connect_message.py | 62 + .../gateway/test_device_disconnect_message.py | 58 + tests/entities/gateway/test_device_info.py | 111 ++ .../gateway/test_gateway_attribute_request.py | 130 ++ .../gateway/test_gateway_attribute_update.py | 55 + .../gateway/test_gateway_claim_request.py | 99 + tests/entities/gateway/test_gateway_event.py | 76 + ...st_gateway_requested_attribute_response.py | 129 ++ .../gateway/test_gateway_uplink_message.py | 152 ++ tests/service/device/handlers/__init__.py | 13 + .../test_attribute_updates_handler.py | 126 ++ ...t_requested_attributes_response_handler.py | 194 ++ .../handlers/test_rpc_requests_handler.py | 225 +++ .../handlers/test_rpc_response_handler.py | 193 ++ .../service/gateway/__init__.py | 0 tests/service/gateway/handlers/__init__.py | 14 + .../test_gateway_attribute_updates_handler.py | 111 ++ ...y_requested_attributes_response_handler.py | 405 ++++ .../handlers/test_gateway_rpc_handler.py | 398 ++++ tests/service/gateway/test_device_manager.py | 391 ++++ tests/service/gateway/test_device_session.py | 307 +++ .../gateway/test_direct_event_dispatcher.py | 278 +++ tests/service/gateway/test_gateway_client.py | 505 +++++ tests/service/gateway/test_message_adapter.py | 536 ++++++ tests/service/gateway/test_message_sender.py | 323 ++++ .../service/gateway/test_message_splitter.py | 412 ++++ tests/service/test_message_service.py | 95 +- tests/tb_device_mqtt_client_tests.py | 436 ----- tests/tb_gateway_mqtt_client_tests.py | 105 -- tests/test_async_utils.py | 68 - utils.py | 48 - 72 files changed, 6148 insertions(+), 3303 deletions(-) delete mode 100644 sdk_utils.py delete mode 100644 tb_device_http.py delete mode 100644 tb_device_mqtt.py delete mode 100644 tb_gateway_mqtt.py create mode 100644 tests/common/test_async_utils.py create mode 100644 tests/common/test_gmqtt_patch.py create mode 100644 tests/common/test_queue.py create mode 100644 tests/entities/gateway/__init__.py create mode 100644 tests/entities/gateway/test_base_gateway_event.py create mode 100644 tests/entities/gateway/test_device_connect_message.py create mode 100644 tests/entities/gateway/test_device_disconnect_message.py create mode 100644 tests/entities/gateway/test_device_info.py create mode 100644 tests/entities/gateway/test_gateway_attribute_request.py create mode 100644 tests/entities/gateway/test_gateway_attribute_update.py create mode 100644 tests/entities/gateway/test_gateway_claim_request.py create mode 100644 tests/entities/gateway/test_gateway_event.py create mode 100644 tests/entities/gateway/test_gateway_requested_attribute_response.py create mode 100644 tests/entities/gateway/test_gateway_uplink_message.py create mode 100644 tests/service/device/handlers/__init__.py create mode 100644 tests/service/device/handlers/test_attribute_updates_handler.py create mode 100644 tests/service/device/handlers/test_requested_attributes_response_handler.py create mode 100644 
tests/service/device/handlers/test_rpc_requests_handler.py create mode 100644 tests/service/device/handlers/test_rpc_response_handler.py rename tb_mqtt_client/tb_device_mqtt.py => tests/service/gateway/__init__.py (100%) create mode 100644 tests/service/gateway/handlers/__init__.py create mode 100644 tests/service/gateway/handlers/test_gateway_attribute_updates_handler.py create mode 100644 tests/service/gateway/handlers/test_gateway_requested_attributes_response_handler.py create mode 100644 tests/service/gateway/handlers/test_gateway_rpc_handler.py create mode 100644 tests/service/gateway/test_device_manager.py create mode 100644 tests/service/gateway/test_device_session.py create mode 100644 tests/service/gateway/test_direct_event_dispatcher.py create mode 100644 tests/service/gateway/test_gateway_client.py create mode 100644 tests/service/gateway/test_message_adapter.py create mode 100644 tests/service/gateway/test_message_sender.py create mode 100644 tests/service/gateway/test_message_splitter.py delete mode 100644 tests/tb_device_mqtt_client_tests.py delete mode 100644 tests/tb_gateway_mqtt_client_tests.py delete mode 100644 tests/test_async_utils.py delete mode 100644 utils.py diff --git a/sdk_utils.py b/sdk_utils.py deleted file mode 100644 index 528104c..0000000 --- a/sdk_utils.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from random import randint -from zlib import crc32 -from hashlib import sha256, sha384, sha512, md5 -import logging -from subprocess import CalledProcessError - -from utils import install_package - - -try: - from mmh3 import hash, hash128 -except ImportError: - try: - from pymmh3 import hash, hash128 - except ImportError: - try: - install_package('mmh3') - except CalledProcessError: - install_package('pymmh3') - -try: - from mmh3 import hash, hash128 # noqa -except ImportError: - from pymmh3 import hash, hash128 - -log = logging.getLogger(__name__) - - -def verify_checksum(firmware_data, checksum_alg, checksum): - if firmware_data is None: - log.debug('Firmware wasn\'t received!') - return False - if checksum is None: - log.debug('Checksum was\'t provided!') - return False - checksum_of_received_firmware = None - log.debug('Checksum algorithm is: %s' % checksum_alg) - if checksum_alg.lower() == "sha256": - checksum_of_received_firmware = sha256(firmware_data).digest().hex() - elif checksum_alg.lower() == "sha384": - checksum_of_received_firmware = sha384(firmware_data).digest().hex() - elif checksum_alg.lower() == "sha512": - checksum_of_received_firmware = sha512(firmware_data).digest().hex() - elif checksum_alg.lower() == "md5": - checksum_of_received_firmware = md5(firmware_data).digest().hex() - elif checksum_alg.lower() == "murmur3_32": - reversed_checksum = f'{hash(firmware_data, signed=False):0>2X}' - if len(reversed_checksum) % 2 != 0: - reversed_checksum = '0' + reversed_checksum - checksum_of_received_firmware = "".join( - reversed([reversed_checksum[i:i + 2] for i in range(0, len(reversed_checksum), 2)])).lower() - elif checksum_alg.lower() == "murmur3_128": - reversed_checksum = f'{hash128(firmware_data, signed=False):0>2X}' - if len(reversed_checksum) % 2 != 0: - reversed_checksum = '0' + reversed_checksum - checksum_of_received_firmware = "".join( - reversed([reversed_checksum[i:i + 2] for i in range(0, len(reversed_checksum), 2)])).lower() - elif checksum_alg.lower() == "crc32": - reversed_checksum = f'{crc32(firmware_data) & 0xffffffff:0>2X}' - if len(reversed_checksum) % 2 != 0: - reversed_checksum = '0' + reversed_checksum - checksum_of_received_firmware = "".join( - reversed([reversed_checksum[i:i + 2] for i in range(0, len(reversed_checksum), 2)])).lower() - else: - log.error('Client error. Unsupported checksum algorithm.') - log.debug(checksum_of_received_firmware) - random_value = randint(0, 5) - if random_value > 3: - log.debug('Dummy fail! Do not panic, just restart and try again the chance of this fail is ~20%') - return False - return checksum_of_received_firmware == checksum diff --git a/tb_device_http.py b/tb_device_http.py deleted file mode 100644 index 140ca53..0000000 --- a/tb_device_http.py +++ /dev/null @@ -1,465 +0,0 @@ -"""ThingsBoard HTTP API device module.""" -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
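(Editorial aside.) For the plain hash algorithms, the verify_checksum helper deleted above reduces to comparing a hex digest of the downloaded firmware with the checksum string reported by the platform; the murmur3 and crc32 branches additionally pad the hex string to an even length and reverse its byte order, as the deleted code shows. Below is a minimal sketch of the sha256 case only, using placeholder data and a locally computed stand-in for the server-provided checksum.

from hashlib import sha256

def sha256_checksum_matches(firmware_data: bytes, reported_checksum: str) -> bool:
    # .digest().hex() in the removed helper is equivalent to .hexdigest().
    return sha256(firmware_data).digest().hex() == reported_checksum

firmware_data = b"example firmware payload"            # placeholder bytes
reported_checksum = sha256(firmware_data).hexdigest()  # stands in for the value sent by the server
assert sha256_checksum_matches(firmware_data, reported_checksum)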
- -import threading -import logging -import queue -import time -import typing -from datetime import datetime, timezone -from sdk_utils import verify_checksum - -import requests -from math import ceil - -FW_CHECKSUM_ATTR = "fw_checksum" -FW_CHECKSUM_ALG_ATTR = "fw_checksum_algorithm" -FW_SIZE_ATTR = "fw_size" -FW_TITLE_ATTR = "fw_title" -FW_VERSION_ATTR = "fw_version" - -FW_STATE_ATTR = "fw_state" - -REQUIRED_SHARED_KEYS = [FW_CHECKSUM_ATTR, FW_CHECKSUM_ALG_ATTR, FW_SIZE_ATTR, FW_TITLE_ATTR, FW_VERSION_ATTR] - - -class TBHTTPAPIException(Exception): - """ThingsBoard HTTP Device API Exception class.""" - - -class TBProvisionFailure(TBHTTPAPIException): - """Exception raised if device provisioning failed.""" - - -class TBHTTPDevice: - """ThingsBoard HTTP Device API class. - - :param host: The ThingsBoard hostname. - :param token: The device token. - :param name: A name for this device. The name is only set locally. - """ - - def __init__(self, host: str, token: str, name: str = None, chunk_size: int = 0): - self.__session = requests.Session() - self.__session.headers.update({'Content-Type': 'application/json'}) - self.__config = { - 'host': host, 'token': token, 'name': name, 'timeout': 30 - } - self.__worker = { - 'publish': { - 'queue': queue.Queue(), - 'thread': threading.Thread(target=self.__publish_worker, daemon=True), - 'stop_event': threading.Event() - }, - 'attributes': { - 'thread': threading.Thread(target=self.__subscription_worker, - daemon=True, - kwargs={'endpoint': 'attributes'}), - 'stop_event': threading.Event(), - }, - 'rpc': { - 'thread': threading.Thread(target=self.__subscription_worker, - daemon=True, - kwargs={'endpoint': 'rpc'}), - 'stop_event': threading.Event(), - } - } - self.current_firmware_info = { - "current_fw_title": None, - "current_fw_version": None - } - self.chunk_size = chunk_size - - def __repr__(self): - return f'' - - @property - def host(self) -> str: - """Get the ThingsBoard hostname.""" - return self.__config['host'] - - @property - def name(self) -> str: - """Get the device name.""" - return self.__config['name'] - - @property - def timeout(self) -> int: - """Get the connection timeout.""" - return self.__config['timeout'] - - @property - def api_base_url(self) -> str: - """Get the ThingsBoard API base URL.""" - return f'{self.host}/api/v1/{self.token}' - - @property - def token(self) -> str: - """Get the device token.""" - return self.__config['token'] - - @property - def logger(self) -> logging.Logger: - """Get the logger instance.""" - return logging.getLogger('TBHTTPDevice') - - @property - def log_level(self) -> str: - """Get the log level.""" - levels = {0: 'NOTSET', 10: 'DEBUG', 20: 'INFO', 30: 'WARNING', 40: 'ERROR', 50: 'CRITICAL'} - return levels.get(self.logger.level) - - @log_level.setter - def log_level(self, value: typing.Union[int, str]): - self.logger.setLevel(value) - self.logger.critical('Log level set to %s', self.log_level) - - def __get_firmware_info(self): - response = self.__session.get( - f"{self.__config['host']}/api/v1/{self.__config['token']}/attributes", - params={"sharedKeys": REQUIRED_SHARED_KEYS}).json() - return response.get("shared", {}) - - def __get_firmware(self, fw_info): - chunk_count = ceil(fw_info.get(FW_SIZE_ATTR, 0) / self.chunk_size) if self.chunk_size > 0 else 0 - firmware_data = b'' - for chunk_number in range(chunk_count + 1): - params = {"title": fw_info.get(FW_TITLE_ATTR), - "version": fw_info.get(FW_VERSION_ATTR), - "size": self.chunk_size if self.chunk_size < fw_info.get(FW_SIZE_ATTR, - 0) else 
fw_info.get( - FW_SIZE_ATTR, 0), - "chunk": chunk_number - } - self.logger.debug(params) - self.logger.debug( - 'Getting chunk with number: %s. Chunk size is : %r byte(s).' % (chunk_number + 1, self.chunk_size)) - response = self.__session.get( - f"{self.__config['host']}/api/v1/{self.__config['token']}/firmware", - params=params) - if response.status_code != 200: - self.logger.error('Received error:') - response.raise_for_status() - return - firmware_data = firmware_data + response.content - return firmware_data - - def __on_firmware_received(self, firmware_info, firmware_data): - with open(firmware_info.get(FW_TITLE_ATTR), "wb") as firmware_file: - firmware_file.write(firmware_data) - - self.logger.info('Firmware is updated!\n Current firmware version is: %s' % firmware_info.get(FW_VERSION_ATTR)) - - def get_firmware_update(self): - self.send_telemetry(self.current_firmware_info) - self.logger.info('Getting firmware info from %s' % self.__config['host']) - - firmware_info = self.__get_firmware_info() - if (firmware_info.get(FW_VERSION_ATTR) is not None and firmware_info.get( - FW_VERSION_ATTR) != self.current_firmware_info.get("current_" + FW_VERSION_ATTR)) \ - or (firmware_info.get(FW_TITLE_ATTR) is not None and firmware_info.get( - FW_TITLE_ATTR) != self.current_firmware_info.get("current_" + FW_TITLE_ATTR)): - self.logger.info('New firmware available!') - - self.current_firmware_info[FW_STATE_ATTR] = "DOWNLOADING" - time.sleep(1) - self.send_telemetry(self.current_firmware_info) - - firmware_data = self.__get_firmware(firmware_info) - - self.current_firmware_info[FW_STATE_ATTR] = "DOWNLOADED" - time.sleep(1) - self.send_telemetry(self.current_firmware_info) - - verification_result = verify_checksum(firmware_data, firmware_info.get(FW_CHECKSUM_ALG_ATTR), - firmware_info.get(FW_CHECKSUM_ATTR)) - - if verification_result: - self.logger.debug('Checksum verified!') - self.current_firmware_info[FW_STATE_ATTR] = "VERIFIED" - time.sleep(1) - self.send_telemetry(self.current_firmware_info) - else: - self.logger.debug('Checksum verification failed!') - self.current_firmware_info[FW_STATE_ATTR] = "FAILED" - time.sleep(1) - self.send_telemetry(self.current_firmware_info) - firmware_data = self.__get_firmware(firmware_info) - return - - self.current_firmware_info[FW_STATE_ATTR] = "UPDATING" - time.sleep(1) - self.send_telemetry(self.current_firmware_info) - - self.__on_firmware_received(firmware_info, firmware_data) - - current_firmware_info = { - "current_" + FW_TITLE_ATTR: firmware_info.get(FW_TITLE_ATTR), - "current_" + FW_VERSION_ATTR: firmware_info.get(FW_VERSION_ATTR), - FW_STATE_ATTR: "UPDATED" - } - time.sleep(1) - self.send_telemetry(current_firmware_info) - - def start_publish_worker(self): - """Start the publish worker thread.""" - self.__worker['publish']['stop_event'].clear() - self.__worker['publish']['thread'].start() - - def stop_publish_worker(self): - """Stop the publish worker thread.""" - self.__worker['publish']['stop_event'].set() - - def __publish_worker(self): - """Publish telemetry data from the queue.""" - logger = self.logger.getChild('worker.publish') - logger.info('Start publisher thread') - logger.debug('Perform connection test before entering worker loop') - if not self.test_connection(): - logger.error('Connection test failed, exit publisher thread') - return - logger.debug('Connection test successful') - while True: - if not self.__worker['publish']['queue'].empty(): - try: - task = self.__worker['publish']['queue'].get(timeout=1) - except queue.Empty: - if 
self.__worker['publish']['stop_event'].is_set(): - break - continue - - endpoint = task.pop('endpoint') - - try: - self._publish_data(task, endpoint) - except Exception as error: - # ToDo: More precise exception catching - logger.error(error) - task.update({'endpoint': endpoint}) - self.__worker['publish']['queue'].put(task) - time.sleep(1) - else: - logger.debug('Published %s to %s', task, endpoint) - self.__worker['publish']['queue'].task_done() - - time.sleep(.2) - - logger.info('Stop publisher thread.') - - def test_connection(self) -> bool: - """Test connection to the API. - - :return: True if no errors occurred, False otherwise. - """ - self.logger.debug('Start connection test') - success = False - try: - self._publish_data(data={}, endpoint='telemetry') - except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as error: - self.logger.debug(error) - except requests.exceptions.HTTPError as error: - self.logger.debug(error) - status_code = error.response.status_code - if status_code == 401: - self.logger.error('Error 401: Unauthorized. Check if token is correct.') - else: - self.logger.error('Error %s', status_code) - else: - self.logger.debug('Connection test successful') - success = True - finally: - self.logger.debug('End connection test') - return success - - def connect(self) -> bool: - """Publish an empty telemetry data to ThingsBoard to test the connection. - - :return: True if connected, false otherwise. - """ - if self.test_connection(): - self.logger.info('Connected to ThingsBoard') - self.start_publish_worker() - return True - return False - - def _publish_data(self, data: dict, endpoint: str, timeout: int = None) -> dict: - """Send POST data to ThingsBoard. - - :param data: The data dictionary to send. - :param endpoint: The receiving API endpoint. - :param timeout: Override the instance timeout for this request. - """ - response = self.__session.post( - url=f'{self.api_base_url}/{endpoint}', - json=data, - timeout=timeout or self.timeout) - response.raise_for_status() - return response.json() if response.content else {} - - def _get_data(self, params: dict, endpoint: str, timeout: int = None) -> dict: - """Retrieve data with GET from ThingsBoard. - - :param params: A dictionary with the parameters for the request. - :param endpoint: The receiving API endpoint. - :param timeout: Override the instance timeout for this request. - :return: A dictionary with the response from the ThingsBoard instance. - """ - response = self.__session.get( - url=f'{self.api_base_url}/{endpoint}', - params=params, - timeout=timeout or self.timeout) - response.raise_for_status() - return response.json() - - def send_telemetry(self, telemetry: dict, timestamp: datetime = None, queued: bool = True): - """Publish telemetry to ThingsBoard. - - :param telemetry: A dictionary with the telemetry data to send. - :param timestamp: Timestamp to set for the values. If not set the ThingsBoard server uses - the time of reception as timestamp. - :param queued: Add the telemetry to the queue. If False, the data is send immediately. - """ - timestamp = datetime.now() if timestamp is None else timestamp - payload = { - 'ts': int(timestamp.replace(tzinfo=timezone.utc).timestamp() * 1000), - 'values': telemetry, - } - if queued: - payload.update({'endpoint': 'telemetry'}) - self.__worker['publish']['queue'].put(payload) - else: - self._publish_data(payload, 'telemetry') - - def send_attributes(self, attributes: dict): - """Send attributes to ThingsBoard. - - :param attributes: Attributes to send. 
- """ - self._publish_data(attributes, 'attributes') - - def send_rpc(self, name: str, params: dict = None, rpc_id: int = None) -> dict: - """Send RPC to ThingsBoard and return response. - - :param name: Name of the RPC method. - :param params: Parameter for the RPC. - :param rpc_id: Specify an Id for this RPC. - :return: A dictionary with the response. - """ - endpoint = f'rpc/{rpc_id}' if rpc_id else 'rpc' - return self._publish_data({'method': name, 'params': params or {}}, endpoint) - - def request_attributes(self, client_keys: list = None, shared_keys: list = None) -> dict: - """Request attributes from ThingsBoard. - - :param client_keys: A list of keys for client attributes. - :param shared_keys: A list of keys for shared attributes. - :return: A dictionary with the request attributes. - """ - params = {'client_keys': client_keys, 'shared_keys': shared_keys} - return self._get_data(params=params, endpoint='attributes') - - def __subscription_worker(self, endpoint: str, timeout: int = None): - """Worker thread for subscription to HTTP API endpoints. - - :param endpoint: The endpoint name. - :param timeout: Timeout value in seconds. - """ - logger = self.logger.getChild(f'worker.subscription.{endpoint}') - stop_event = self.__worker[endpoint]['stop_event'] - logger.info('Start subscription to %s updates', endpoint) - - if not self.__worker[endpoint].get('callback'): - logger.warning('No callback set for %s subscription', endpoint) - stop_event.set() - callback = self.__worker[endpoint].get('callback', lambda data: None) - params = { - 'timeout': (timeout or self.timeout) * 1000 - } - url = { - 'attributes': f'{self.api_base_url}/attributes/updates', - 'rpc': f'{self.api_base_url}/rpc' - } - logger.debug('Timeout set to %ss', params['timeout'] / 1000) - while not stop_event.is_set(): - response = self.__session.get(url=url[endpoint], - params=params, - timeout=params['timeout']) - if stop_event.is_set(): - break - if response.status_code == 408: # Request timeout - continue - if response.status_code == 504: # Gateway Timeout - continue # Reconnect - response.raise_for_status() - callback(response.json()) - time.sleep(.1) - - stop_event.clear() - logger.info('Stop subscription to %s updates', endpoint) - - def subscribe(self, endpoint: str, callback: typing.Callable[[dict], None] = None): - """Subscribe to updates from a given endpoint. - - :param endpoint: The endpoint to subscribe. - :param callback: Callback to execute on an update. Takes a dict as only argument. - """ - if endpoint not in ['attributes', 'rpc']: - raise ValueError - if callback: - if not callable(callback): - raise TypeError - self.__worker[endpoint]['callback'] = callback - self.__worker[endpoint]['stop_event'].clear() - self.__worker[endpoint]['thread'].start() - - def unsubscribe(self, endpoint: str): - """Unsubscribe from a given endpoint. - - :param endpoint: The endpoint to unsubscribe. - """ - if endpoint not in ['attributes', 'rpc']: - raise ValueError - self.logger.debug('Set stop event for %s subscription', endpoint) - self.__worker[endpoint]['stop_event'].set() - - @classmethod - def provision(cls, host: str, device_name: str, device_key: str, device_secret: str): - """Initiate device provisioning and return a device instance. - - :param host: The root URL to the ThingsBoard instance. - :param device_name: Name of the device to provision. - :param device_key: Provisioning device key from ThingsBoard. - :param device_secret: Provisioning secret from ThingsBoard. 
- :return: Instance of :class:`TBHTTPClient` - """ - data = { - 'deviceName': device_name, - 'provisionDeviceKey': device_key, - 'provisionDeviceSecret': device_secret - } - response = requests.post(f'{host}/api/v1/provision', json=data) - response.raise_for_status() - device = response.json() - if device['status'] == 'SUCCESS' and device['credentialsType'] == 'ACCESS_TOKEN': - return cls(host=host, token=device['credentialsValue'], name=device_name) - raise TBProvisionFailure(device) - - -class TBHTTPClient(TBHTTPDevice): - """Legacy class name.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.logger.critical('TBHTTPClient class is deprecated, please use TBHTTPDevice') diff --git a/tb_device_mqtt.py b/tb_device_mqtt.py deleted file mode 100644 index a74448e..0000000 --- a/tb_device_mqtt.py +++ /dev/null @@ -1,1654 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This file is maintained for backward compatibility with version 1 of the SDK. -# It is recommended to use the new SDK structure in tb_mqtt_client for new projects. - -import logging -import warnings -from copy import deepcopy -from inspect import signature -from time import sleep -from importlib import metadata - -from orjson.orjson import OPT_NON_STR_KEYS - -from utils import install_package -from os import environ - -# Show deprecation warning -warnings.warn( - "The tb_device_mqtt module is deprecated and will be removed in a future version. 
" - "Please use tb_mqtt_client.service.device.client.DeviceClient instead.", - DeprecationWarning, - stacklevel=2 -) - -def check_tb_paho_mqtt_installed(): - try: - dists = metadata.distributions() - for dist in dists: - if dist.metadata["Name"].lower() == "tb-paho-mqtt-client": - files = list(dist.files) - for file in files: - if str(file).startswith("paho/mqtt"): - return True - return False - except Exception: - return False - -if not check_tb_paho_mqtt_installed(): - try: - install_package('tb-paho-mqtt-client', version='>=2.1.2') - except Exception as e: - raise ImportError("tb-paho-mqtt-client is not installed, please install it manually.") from e - -import paho.mqtt.client as paho -from paho.mqtt.enums import CallbackAPIVersion -from math import ceil - -try: - from time import monotonic, time as timestamp -except ImportError: - from time import time as timestamp -import ssl -from threading import RLock, Thread -from enum import Enum - -from paho.mqtt.reasoncodes import ReasonCodes -from paho.mqtt.client import MQTT_ERR_QUEUE_SIZE - -from orjson import dumps, loads, JSONDecodeError - -from sdk_utils import verify_checksum - -FW_TITLE_ATTR = "fw_title" -FW_VERSION_ATTR = "fw_version" -FW_CHECKSUM_ATTR = "fw_checksum" -FW_CHECKSUM_ALG_ATTR = "fw_checksum_algorithm" -FW_SIZE_ATTR = "fw_size" -FW_STATE_ATTR = "fw_state" - -REQUIRED_SHARED_KEYS = f"{FW_CHECKSUM_ATTR},{FW_CHECKSUM_ALG_ATTR},{FW_SIZE_ATTR},{FW_TITLE_ATTR},{FW_VERSION_ATTR}" - -RPC_RESPONSE_TOPIC = 'v1/devices/me/rpc/response/' -RPC_REQUEST_TOPIC = 'v1/devices/me/rpc/request/' -ATTRIBUTES_TOPIC = 'v1/devices/me/attributes' -ATTRIBUTES_TOPIC_REQUEST = 'v1/devices/me/attributes/request/' -ATTRIBUTES_TOPIC_RESPONSE = 'v1/devices/me/attributes/response/' -TELEMETRY_TOPIC = 'v1/devices/me/telemetry' -CLAIMING_TOPIC = 'v1/devices/me/claim' -PROVISION_TOPIC_REQUEST = '/provision/request' -PROVISION_TOPIC_RESPONSE = '/provision/response' -log = logging.getLogger('tb_connection') - -RESULT_CODES = { - 1: "incorrect protocol version", - 2: "invalid client identifier", - 3: "server unavailable", - 4: "bad username or password", - 5: "not authorized", -} - - -class TBTimeoutException(Exception): - pass - - -class TBQoSException(Exception): - pass - - -DEFAULT_TIMEOUT = 5 - - -class TBSendMethod(Enum): - SUBSCRIBE = 0 - PUBLISH = 1 - UNSUBSCRIBE = 2 - - -class TBPublishInfo: - TB_ERR_AGAIN = -1 - TB_ERR_SUCCESS = 0 - TB_ERR_NOMEM = 1 - TB_ERR_PROTOCOL = 2 - TB_ERR_INVAL = 3 - TB_ERR_NO_CONN = 4 - TB_ERR_CONN_REFUSED = 5 - TB_ERR_NOT_FOUND = 6 - TB_ERR_CONN_LOST = 7 - TB_ERR_TLS = 8 - TB_ERR_PAYLOAD_SIZE = 9 - TB_ERR_NOT_SUPPORTED = 10 - TB_ERR_AUTH = 11 - TB_ERR_ACL_DENIED = 12 - TB_ERR_UNKNOWN = 13 - TB_ERR_ERRNO = 14 - TB_ERR_QUEUE_SIZE = 15 - - ERRORS_DESCRIPTION = { - -1: 'Previous error repeated.', - 0: 'The operation completed successfully.', - 1: 'Out of memory.', - 2: 'A network protocol error occurred when communicating with the broker.', - 3: 'Invalid function arguments provided.', - 4: 'The client is not currently connected.', - 5: 'The connection was refused.', - 6: 'Entity not found (for example, trying to unsubscribe from a topic not currently subscribed to).', - 7: 'The connection was lost.', - 8: 'A TLS error occurred.', - 9: 'Payload size is too large.', - 10: 'This feature is not supported.', - 11: 'Authorization failed.', - 12: 'Access denied to the specified ACL.', - 13: 'Unknown error.', - 14: 'A system call returned an error.', - 15: 'The queue size was exceeded.', - 16: 'The keepalive time has been exceeded.' 
- } - - def __init__(self, message_info): - self.message_info = message_info - - # pylint: disable=invalid-name - def rc(self): - if isinstance(self.message_info, list): - for info in self.message_info: - if isinstance(info.rc, ReasonCodes): - if info.rc.value == 0: - continue - return info.rc - else: - if info.rc != 0: - return info.rc - return self.TB_ERR_SUCCESS - else: - if isinstance(self.message_info.rc, ReasonCodes): - return self.message_info.rc.value - return self.message_info.rc - - def mid(self): - if isinstance(self.message_info, list): - return [info.mid for info in self.message_info] - else: - return self.message_info.mid - - def get(self): - if isinstance(self.message_info, list): - try: - for info in self.message_info: - info.wait_for_publish(timeout=1) - except Exception as e: - global log - log = logging.getLogger('tb_connection') - log.error("Error while waiting for publish: %s", e) - else: - self.message_info.wait_for_publish(timeout=1) - return self.rc() - - -class GreedyTokenBucket: - def __init__(self, capacity, duration_sec): - self.capacity = float(capacity) - self.duration = float(duration_sec) - self.tokens = float(capacity) - self.last_updated = monotonic() - - def refill(self): - now = monotonic() - elapsed = now - self.last_updated - refill_rate = self.capacity / self.duration - refill_amount = elapsed * refill_rate - self.tokens = min(self.capacity, self.tokens + refill_amount) - self.last_updated = now - - def can_consume(self, amount=1): - self.refill() - return round(self.tokens, 6) >= round(amount, 6) - - def consume(self, amount=1): - self.refill() - if self.tokens >= amount: - self.tokens -= amount - return True - return False - - def get_remaining_tokens(self): - self.refill() - return self.tokens - - -DEFAULT_RATE_LIMIT_PERCENTAGE = environ.get('TB_DEFAULT_RATE_LIMIT_PERCENTAGE') -if DEFAULT_RATE_LIMIT_PERCENTAGE is None: - DEFAULT_RATE_LIMIT_PERCENTAGE = 80 -else: - try: - DEFAULT_RATE_LIMIT_PERCENTAGE = int(DEFAULT_RATE_LIMIT_PERCENTAGE) - except ValueError: - log.warning("Invalid value for TB_DEFAULT_RATE_LIMIT_PERCENTAGE, using default value of 80%%") - DEFAULT_RATE_LIMIT_PERCENTAGE = 80 - -class RateLimit: - def __init__(self, rate_limit, name=None, percentage=DEFAULT_RATE_LIMIT_PERCENTAGE): - self.__reached_limit_index = 0 - self.__reached_limit_index_time = 0 - self._no_limit = False - self._rate_buckets = {} - self.__lock = RLock() - self._minimal_timeout = DEFAULT_TIMEOUT - self._minimal_limit = float("inf") - - from_dict = isinstance(rate_limit, dict) - self.name = name - self.percentage = percentage - - if from_dict: - self._no_limit = rate_limit.get('no_limit', False) - self.percentage = rate_limit.get('percentage', percentage) - self.name = rate_limit.get('name', name) - - rate_limits = rate_limit.get('rateLimits', {}) - for duration_str, bucket_info in rate_limits.items(): - try: - duration = int(duration_str) - capacity = bucket_info.get("capacity") - tokens = bucket_info.get("tokens") - last_updated = bucket_info.get("last_updated") - - if capacity is None or tokens is None: - continue - - bucket = GreedyTokenBucket(capacity, duration) - bucket.tokens = min(capacity, float(tokens)) - bucket.last_updated = float(last_updated) if last_updated is not None else monotonic() - - self._rate_buckets[duration] = bucket - self._minimal_limit = min(self._minimal_limit, capacity) - self._minimal_timeout = min(self._minimal_timeout, duration + 1) - except Exception as e: - log.warning("Invalid bucket format for duration %s: %s", duration_str, e) - 
- else: - clean = ''.join(c for c in rate_limit if c not in [' ', ',', ';']) - if clean in ("", "0:0"): - self._no_limit = True - return - - rate_configs = rate_limit.replace(";", ",").split(",") - for rate in rate_configs: - if not rate.strip(): - continue - try: - limit_str, duration_str = rate.strip().split(":") - limit = int(int(limit_str) * self.percentage / 100) - duration = int(duration_str) - bucket = GreedyTokenBucket(limit, duration) - self._rate_buckets[duration] = bucket - self._minimal_limit = min(self._minimal_limit, limit) - self._minimal_timeout = min(self._minimal_timeout, duration + 1) - except Exception as e: - log.warning("Invalid rate limit format '%s': %s", rate, e) - - log.debug("Rate limit %s set to values:", self.name) - for duration, bucket in self._rate_buckets.items(): - log.debug("Window: %ss, Limit: %s", duration, bucket.capacity) - - def increase_rate_limit_counter(self, amount=1): - if self._no_limit: - return - with self.__lock: - for bucket in self._rate_buckets.values(): - bucket.refill() - bucket.tokens = max(0.0, bucket.tokens - amount) - - def check_limit_reached(self, amount=1): - if self._no_limit: - return False - with self.__lock: - for duration, bucket in self._rate_buckets.items(): - if not bucket.can_consume(amount): - return bucket.capacity, duration - - for duration, bucket in self._rate_buckets.items(): - log.debug("%s left tokens: %.2f per %r seconds", - self.name, - bucket.get_remaining_tokens(), - duration) - - return False - - - def get_minimal_limit(self): - return self._minimal_limit if self.has_limit() else 0 - - def get_minimal_timeout(self): - return self._minimal_timeout if self.has_limit() else 0 - - def has_limit(self): - return not self._no_limit - - def set_limit(self, rate_limit, percentage=DEFAULT_RATE_LIMIT_PERCENTAGE): - with self.__lock: - self._minimal_timeout = DEFAULT_TIMEOUT - self._minimal_limit = float("inf") - - old_buckets = deepcopy(self._rate_buckets) - self._rate_buckets = {} - self.percentage = percentage if percentage > 0 else self.percentage - - clean = ''.join(c for c in rate_limit if c not in [' ', ',', ';']) - if clean in ("", "0:0"): - self._no_limit = True - return - - rate_configs = rate_limit.replace(";", ",").split(",") - - for rate in rate_configs: - if not rate.strip(): - continue - try: - limit_str, duration_str = rate.strip().split(":") - duration = int(duration_str) - new_capacity = int(int(limit_str) * self.percentage / 100) - - previous_bucket = old_buckets.get(duration) - new_bucket = GreedyTokenBucket(new_capacity, duration) - - if previous_bucket: - previous_bucket.refill() - used = max(0.0, previous_bucket.capacity - previous_bucket.tokens) - new_tokens = new_capacity - used - new_bucket.tokens = min(new_capacity, max(0.0, new_tokens)) - new_bucket.last_updated = monotonic() - else: - new_bucket.tokens = new_capacity - new_bucket.last_updated = monotonic() - - self._rate_buckets[duration] = new_bucket - self._minimal_limit = min(self._minimal_limit, new_bucket.capacity) - self._minimal_timeout = min(self._minimal_timeout, duration + 1) - - except Exception as e: - log.warning("Invalid rate limit format '%s': %s", rate, e) - - self._no_limit = not bool(self._rate_buckets) - log.debug("Rate limit set to values:") - for duration, bucket in self._rate_buckets.items(): - log.debug("Duration: %ss, Limit: %s", duration, bucket.capacity) - - def reach_limit(self): - if self._no_limit or not self._rate_buckets: - return - - with self.__lock: - durations = sorted(self._rate_buckets.keys()) - 
current_monotonic = monotonic() - if self.__reached_limit_index_time >= current_monotonic - self._rate_buckets[durations[-1]].duration: - self.__reached_limit_index = 0 - self.__reached_limit_index_time = current_monotonic - if self.__reached_limit_index >= len(durations): - self.__reached_limit_index = 0 - self.__reached_limit_index_time = current_monotonic - - target_duration = durations[self.__reached_limit_index] - bucket = self._rate_buckets[target_duration] - bucket.refill() - bucket.tokens = 0.0 - - self.__reached_limit_index += 1 - log.info("Received disconnection due to rate limit for \"%s\" rate limit, waiting for tokens in bucket for %s seconds", - self.name, - target_duration) - return self.__reached_limit_index, self.__reached_limit_index_time - - @property - def __dict__(self): - rate_limits_dict = {} - for duration, bucket in self._rate_buckets.items(): - rate_limits_dict[str(duration)] = { - "capacity": bucket.capacity, - "tokens": bucket.get_remaining_tokens(), - "last_updated": bucket.last_updated - } - return { - "rateLimits": rate_limits_dict, - "name": self.name, - "percentage": self.percentage, - "no_limit": self._no_limit - } - - @staticmethod - def get_rate_limits_by_host(host, rate_limit, dp_rate_limit): - rate_limit = RateLimit.get_rate_limit_by_host(host, rate_limit) - dp_rate_limit = RateLimit.get_dp_rate_limit_by_host(host, dp_rate_limit) - - return rate_limit, dp_rate_limit - - @staticmethod - def get_rate_limit_by_host(host, rate_limit): - if rate_limit == "DEFAULT_TELEMETRY_RATE_LIMIT": - if "thingsboard.cloud" in host: - rate_limit = "10:1,60:60," - elif "tb" in host and "cloud" in host: - rate_limit = "10:1,60:60," - elif "demo.thingsboard.io" in host: - rate_limit = "10:1,60:60," - else: - rate_limit = "0:0," - elif rate_limit == "DEFAULT_MESSAGES_RATE_LIMIT": - if "thingsboard.cloud" in host: - rate_limit = "10:1,60:60," - elif "tb" in host and "cloud" in host: - rate_limit = "10:1,60:60," - elif "demo.thingsboard.io" in host: - rate_limit = "10:1,60:60," - else: - rate_limit = "0:0," - else: - rate_limit = rate_limit - - return rate_limit - - @staticmethod - def get_dp_rate_limit_by_host(host, dp_rate_limit): - if dp_rate_limit == "DEFAULT_TELEMETRY_DP_RATE_LIMIT": - if "thingsboard.cloud" in host: - dp_rate_limit = "10:1,300:60," - elif "tb" in host and "cloud" in host: - dp_rate_limit = "10:1,300:60," - elif "demo.thingsboard.io" in host: - dp_rate_limit = "10:1,300:60," - else: - dp_rate_limit = "0:0," - else: - dp_rate_limit = dp_rate_limit - - return dp_rate_limit - - -class TBDeviceMqttClient: - """ThingsBoard MQTT client. 
This class provides interface to send data to ThingsBoard and receive data from""" - - EMPTY_RATE_LIMIT = RateLimit('0:0,', "EMPTY_RATE_LIMIT") - - def __init__(self, host, port=1883, username=None, password=None, quality_of_service=None, client_id="", - chunk_size=0, messages_rate_limit="DEFAULT_MESSAGES_RATE_LIMIT", - telemetry_rate_limit="DEFAULT_TELEMETRY_RATE_LIMIT", - telemetry_dp_rate_limit="DEFAULT_TELEMETRY_DP_RATE_LIMIT", max_payload_size=8196, **kwargs): - # Added for compatibility with old versions - if kwargs.get('rate_limit') is not None or kwargs.get('dp_rate_limit') is not None: - messages_rate_limit = messages_rate_limit if kwargs.get('rate_limit') == "DEFAULT_RATE_LIMIT" else kwargs.get('rate_limit', messages_rate_limit) # noqa - telemetry_rate_limit = telemetry_rate_limit if kwargs.get('rate_limit') == "DEFAULT_RATE_LIMIT" else kwargs.get('rate_limit', telemetry_rate_limit) # noqa - telemetry_dp_rate_limit = telemetry_dp_rate_limit if kwargs.get('dp_rate_limit') == "DEFAULT_RATE_LIMIT" else kwargs.get('dp_rate_limit', telemetry_dp_rate_limit) # noqa - self._client = paho.Client(protocol=5, client_id=client_id, callback_api_version=CallbackAPIVersion.VERSION2) - self.quality_of_service = quality_of_service if quality_of_service is not None else 1 - self.__host = host - self.__port = port - if username == "": - log.warning("Token is not set, connection without TLS won't be established!") - else: - self._client.username_pw_set(username, password=password) - self._lock = RLock() - - self._attr_request_dict = {} - self.stopped = False - self.__is_connected = False - self.__device_on_server_side_rpc_response = None - self.__connect_callback = None - self.__device_max_sub_id = 0 - self.__device_client_rpc_number = 0 - self.__device_sub_dict = {} - self.__device_client_rpc_dict = {} - self.__attr_request_number = 0 - self.__error_logged = 0 - self.max_payload_size = max_payload_size - self.service_configuration_callback = self.on_service_configuration - telemetry_rate_limit, telemetry_dp_rate_limit = RateLimit.get_rate_limits_by_host(self.__host, - telemetry_rate_limit, - telemetry_dp_rate_limit) - messages_rate_limit = RateLimit.get_rate_limit_by_host(self.__host, messages_rate_limit) - - self._messages_rate_limit = RateLimit(messages_rate_limit, "Rate limit for messages") - self._telemetry_rate_limit = RateLimit(telemetry_rate_limit, "Rate limit for telemetry messages") - self._telemetry_dp_rate_limit = RateLimit(telemetry_dp_rate_limit, "Rate limit for telemetry data points") - self.max_inflight_messages_set(self._telemetry_rate_limit.get_minimal_limit()) - self.__attrs_request_timeout = {} - self.__timeout_thread = Thread(target=self.__timeout_check, name="Timeout check thread") - self.__timeout_thread.daemon = True - self.__timeout_thread.start() - self._client.on_connect = self._on_connect - self._client.on_publish = self._on_publish - self._client.on_message = self._on_message - self._client.on_disconnect = self._on_disconnect - self.current_firmware_info = { - "current_" + FW_TITLE_ATTR: "Initial", - "current_" + FW_VERSION_ATTR: "v0", - FW_STATE_ATTR: "IDLE" - } - self.__request_id = 0 - self.__firmware_request_id = 0 - self.__chunk_size = chunk_size - self.firmware_received = False - self.rate_limits_received = False - self.__request_service_configuration_required = False - self.__service_loop = Thread(target=self.__service_loop, name="Service loop", daemon=True) - self.__service_loop.start() - self.__messages_limit_reached_set_time = (0,0) - 
self.__datapoints_limit_reached_set_time = (0,0) - - def __service_loop(self): - while not self.stopped: - if self.__request_service_configuration_required: - self.request_service_configuration(self.service_configuration_callback) - self.__request_service_configuration_required = False - elif self.firmware_received: - self.current_firmware_info[FW_STATE_ATTR] = "UPDATING" - self.send_telemetry(self.current_firmware_info) - sleep(1) - - self.__on_firmware_received(self.firmware_info.get(FW_VERSION_ATTR)) - - self.current_firmware_info = { - "current_" + FW_TITLE_ATTR: self.firmware_info.get(FW_TITLE_ATTR), - "current_" + FW_VERSION_ATTR: self.firmware_info.get(FW_VERSION_ATTR), - FW_STATE_ATTR: "UPDATED" - } - self.send_telemetry(self.current_firmware_info) - self.firmware_received = False - sleep(0.05) - - def _on_publish(self, client, userdata, mid, rc=None, properties=None): - if isinstance(rc, ReasonCodes) and rc.value != 0: - log.debug("Publish failed with result code %s (%s) ", str(rc.value), rc.getName()) - if rc.value in [151, 131]: - if self.__messages_limit_reached_set_time[1] - monotonic() > self.__messages_limit_reached_set_time[0]: - self.__messages_limit_reached_set_time = self._messages_rate_limit.reach_limit() - if self.__datapoints_limit_reached_set_time[1] - monotonic() > self.__datapoints_limit_reached_set_time[0]: - self._telemetry_dp_rate_limit.reach_limit() - if rc.value == 0: - if self.__messages_limit_reached_set_time[0] > 0 and self.__messages_limit_reached_set_time[1] > 0: - self.__messages_limit_reached_set_time = (0, 0) - if self.__datapoints_limit_reached_set_time[0] > 0 and self.__datapoints_limit_reached_set_time[1] > 0: - self.__datapoints_limit_reached_set_time = (0, 0) - - def _on_disconnect(self, client: paho.Client, userdata, disconnect_flags, reason=None, properties=None): - self.__is_connected = False - with self._client._out_message_mutex: - client._out_packet.clear() - client._out_messages.clear() - client._in_messages.clear() - self.__attr_request_number = 0 - self.__device_max_sub_id = 0 - self.__device_client_rpc_number = 0 - self.__device_sub_dict = {} - self.__device_client_rpc_dict = {} - self.__attrs_request_timeout = {} - result_code = reason.value - if disconnect_flags.is_disconnect_packet_from_server: - log.warning("MQTT client was disconnected by server with reason code %s (%s) ", - str(result_code), reason.getName()) - else: - log.info("MQTT client was disconnected by client with reason code %s (%s) ", - str(result_code), reason.getName()) - log.debug("Client: %s, user data: %s, result code: %s. 
Description: %s", - str(client), str(userdata), - str(result_code), reason.getName()) - - def _on_connect(self, client, userdata, connect_flags, result_code, properties, *extra_params): - if result_code == 0: - self.__is_connected = True - log.info("MQTT client %r - Connected!", client) - if properties: - log.debug("MQTT client %r - CONACK Properties: %r", client, properties) - config = {} - if hasattr(properties, 'MaximumPacketSize'): - config['maxPayloadSize'] = int(properties.MaximumPacketSize * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) - if hasattr(properties, 'ReceiveMaximum'): - config['maxInflightMessages'] = properties.ReceiveMaximum - if config: - self.on_service_configuration(None, config) - self._subscribe_to_topic(ATTRIBUTES_TOPIC, qos=self.quality_of_service) - self._subscribe_to_topic(ATTRIBUTES_TOPIC + "/response/+", qos=self.quality_of_service) - self._subscribe_to_topic(RPC_REQUEST_TOPIC + '+', qos=self.quality_of_service) - self._subscribe_to_topic(RPC_RESPONSE_TOPIC + '+', qos=self.quality_of_service) - self.__request_service_configuration_required = True - else: - log.error("Connection failed with result code %s (%s) ", - str(result_code.value), result_code.getName()) - - if callable(self.__connect_callback): - sleep(.2) - if "tb_client" in signature(self.__connect_callback).parameters: - self.__connect_callback(client, userdata, connect_flags, result_code, properties, *extra_params, tb_client=self) - else: - self.__connect_callback(client, userdata, connect_flags, result_code, *extra_params) - - if result_code.value in [159, 151]: - log.debug("Connection rate limit reached, waiting before reconnecting...") - sleep(1) # Wait for 1 second before reconnecting, if connection rate limit is reached - log.debug("Reconnecting allowed...") - - def get_firmware_update(self): - self._client.subscribe("v2/fw/response/+") - self.send_telemetry(self.current_firmware_info) - self.__request_firmware_info() - - def __request_firmware_info(self): - self.__request_id = self.__request_id + 1 - self._publish_data({"sharedKeys": REQUIRED_SHARED_KEYS}, - f"v1/devices/me/attributes/request/{self.__request_id}", - 1) - - def is_connected(self): - return self.__is_connected - - def connect(self, callback=None, min_reconnect_delay=1, timeout=120, tls=False, ca_certs=None, cert_file=None, - key_file=None, keepalive=120): - """Connect to ThingsBoard. 
The callback will be called when the connection is established.""" - if tls: - try: - self._client.tls_set(ca_certs=ca_certs, - certfile=cert_file, - keyfile=key_file, - cert_reqs=ssl.CERT_REQUIRED, - tls_version=ssl.PROTOCOL_TLSv1_2, - ciphers=None) - self._client.tls_insecure_set(False) - except ValueError: - pass - self.reconnect_delay_set(min_reconnect_delay, timeout) - self._client.connect(self.__host, self.__port, keepalive=keepalive) - self._client.loop_start() - self.__connect_callback = callback - - def disconnect(self): - """Disconnect from ThingsBoard.""" - result = self._client.disconnect() - log.debug(self._client) - log.debug("Disconnecting from ThingsBoard") - self.__is_connected = False - self._client.loop_stop() - return result - - def stop(self): - self.stopped = True - - def _on_message(self, client, userdata, message): - update_response_pattern = "v2/fw/response/" + str(self.__firmware_request_id) + "/chunk/" - if message.topic.startswith(update_response_pattern): - firmware_data = message.payload - - self.firmware_data = self.firmware_data + firmware_data - self.__current_chunk = self.__current_chunk + 1 - - log.debug('Getting chunk with number: %s. Chunk size is : %r byte(s).' % ( - self.__current_chunk, self.__chunk_size)) - - if len(self.firmware_data) == self.__target_firmware_length: - self.__process_firmware() - else: - self.__get_firmware() - else: - content = self._decode(message) - self._on_decoded_message(content, message) - - def _on_decoded_message(self, content, message): - if message.topic.startswith(RPC_REQUEST_TOPIC): - self._messages_rate_limit.increase_rate_limit_counter() - request_id = message.topic[len(RPC_REQUEST_TOPIC):len(message.topic)] - if self.__device_on_server_side_rpc_response: - self.__device_on_server_side_rpc_response(request_id, content) - elif message.topic.startswith(RPC_RESPONSE_TOPIC): - self._messages_rate_limit.increase_rate_limit_counter() - with self._lock: - request_id = int(message.topic[len(RPC_RESPONSE_TOPIC):len(message.topic)]) - if self.__device_client_rpc_dict.get(request_id): - callback = self.__device_client_rpc_dict.pop(request_id) - else: - callback = None - if callback is not None: - callback(request_id, content, None) - elif message.topic == ATTRIBUTES_TOPIC: - self._messages_rate_limit.increase_rate_limit_counter() - dict_results = [] - with self._lock: - # callbacks for everything - if self.__device_sub_dict.get("*"): - for subscription_id in self.__device_sub_dict["*"]: - dict_results.append(self.__device_sub_dict["*"][subscription_id]) - # specific callback - keys = content.keys() - keys_list = [] - for key in keys: - keys_list.append(key) - # iterate through message - for key in keys_list: - # find key in our dict - if self.__device_sub_dict.get(key): - for subscription in self.__device_sub_dict[key]: - dict_results.append(self.__device_sub_dict[key][subscription]) - for res in dict_results: - res(content, None) - elif message.topic.startswith(ATTRIBUTES_TOPIC_RESPONSE): - self._messages_rate_limit.increase_rate_limit_counter() - with self._lock: - req_id = int(message.topic[len(ATTRIBUTES_TOPIC + "/response/"):]) - # pop callback and use it - if self._attr_request_dict.get(req_id): - callback = self._attr_request_dict.pop(req_id) - else: - callback = None - if isinstance(callback, tuple): - callback[0](content, None, callback[1]) - elif callback is not None: - callback(content, None) - else: - log.debug("Message received with topic: %s", message.topic) - - if 
message.topic.startswith("v1/devices/me/attributes"): - self._messages_rate_limit.increase_rate_limit_counter() - self.firmware_info = loads(message.payload) - if "/response/" in message.topic: - self.firmware_info = self.firmware_info.get("shared", {}) if isinstance(self.firmware_info, dict) else {} # noqa - if ((self.firmware_info.get(FW_VERSION_ATTR) is not None - and self.firmware_info.get(FW_VERSION_ATTR) != self.current_firmware_info.get("current_" + FW_VERSION_ATTR)) # noqa - or (self.firmware_info.get(FW_TITLE_ATTR) is not None - and self.firmware_info.get(FW_TITLE_ATTR) != self.current_firmware_info.get("current_" + FW_TITLE_ATTR))): # noqa - log.debug('Firmware is not the same') - self.firmware_data = b'' - self.__current_chunk = 0 - - self.current_firmware_info[FW_STATE_ATTR] = "DOWNLOADING" - self.send_telemetry(self.current_firmware_info) - sleep(1) - - self.__firmware_request_id = self.__firmware_request_id + 1 - self.__target_firmware_length = self.firmware_info[FW_SIZE_ATTR] - self.__chunk_count = 0 if not self.__chunk_size else ceil( - self.firmware_info[FW_SIZE_ATTR] / self.__chunk_size) - self.__get_firmware() - - def __process_firmware(self): - self.current_firmware_info[FW_STATE_ATTR] = "DOWNLOADED" - self.send_telemetry(self.current_firmware_info) - sleep(1) - - verification_result = verify_checksum(self.firmware_data, self.firmware_info.get(FW_CHECKSUM_ALG_ATTR), - self.firmware_info.get(FW_CHECKSUM_ATTR)) - - if verification_result: - log.debug('Checksum verified!') - self.current_firmware_info[FW_STATE_ATTR] = "VERIFIED" - self.send_telemetry(self.current_firmware_info) - sleep(1) - else: - log.debug('Checksum verification failed!') - self.current_firmware_info[FW_STATE_ATTR] = "FAILED" - self.send_telemetry(self.current_firmware_info) - self.__request_firmware_info() - return - self.firmware_received = True - - def __get_firmware(self): - payload = '' if not self.__chunk_size or self.__chunk_size > self.firmware_info.get(FW_SIZE_ATTR, 0) \ - else str(self.__chunk_size).encode() - self._client.publish( - f"v2/fw/request/{self.__firmware_request_id}/chunk/{self.__current_chunk}", - payload=payload, qos=1) - - def __on_firmware_received(self, version_to): - with open(self.firmware_info.get(FW_TITLE_ATTR), "wb") as firmware_file: - firmware_file.write(self.firmware_data) - log.info('Firmware is updated!\n Current firmware version is: %s' % version_to) - - @staticmethod - def _decode(message): - try: - if isinstance(message.payload, bytes): - content = loads(message.payload.decode("utf-8", "ignore")) - else: - content = loads(message.payload) - except JSONDecodeError: - try: - content = message.payload.decode("utf-8", "ignore") - except JSONDecodeError: - content = message.payload - return content - - def max_inflight_messages_set(self, inflight): - """Set the maximum number of messages with QoS>0 that can be a part way through their network flow at once. - Defaults to minimal rate limit. Increasing this value will consume more memory but can increase throughput.""" - if inflight < 0: - log.error("Inflight messages number must be equal or greater than 0") - return - self._client._max_inflight_messages = inflight - - def max_queued_messages_set(self, queue_size): - """Set the maximum number of outgoing messages with QoS>0 that can be pending in the outgoing message queue. - Defaults to 0. 0 means unlimited. 
When the queue is full, any further outgoing messages would be dropped.""" - if queue_size < 0: - raise ValueError("Invalid queue size.") - - self._client._max_queued_messages = queue_size - - def reconnect_delay_set(self, min_delay=1, max_delay=120): - """The client will automatically retry connection. Between each attempt it will wait a number of seconds - between min_delay and max_delay. When the connection is lost, initially the reconnection attempt is delayed - of min_delay seconds. It’s doubled between subsequent attempt up to max_delay. The delay is reset to min_delay - when the connection complete (e.g. the CONNACK is received, not just the TCP connection is established).""" - self._client.reconnect_delay_set(min_delay, max_delay) - - def send_rpc_reply(self, req_id, resp, quality_of_service=None, wait_for_publish=False): - """Send RPC reply to ThingsBoard. The response will be sent to the RPC_RESPONSE_TOPIC with the request id.""" - quality_of_service = quality_of_service if quality_of_service is not None else self.quality_of_service - if quality_of_service not in (0, 1): - log.error("Quality of service (qos) value must be 0 or 1") - return None - info = self._publish_data(resp, RPC_RESPONSE_TOPIC + req_id, quality_of_service) - if wait_for_publish: - info.get() - - def send_rpc_call(self, method, params, callback): - """Send RPC call to ThingsBoard. The callback will be called when the response is received.""" - with self._lock: - self.__device_client_rpc_number += 1 - self.__device_client_rpc_dict.update({self.__device_client_rpc_number: callback}) - rpc_request_id = self.__device_client_rpc_number - payload = {"method": method, "params": params} - self._publish_data(payload, RPC_REQUEST_TOPIC + str(rpc_request_id), self.quality_of_service) - - def request_service_configuration(self, callback): - self.send_rpc_call("getSessionLimits", {"timeout": 5000}, callback) - - def on_service_configuration(self, _, response, *args, **kwargs): - global log - log = logging.getLogger('tb_connection') - if "error" in response: - log.warning("Timeout while waiting for service configuration!, session will use default configuration.") - self.rate_limits_received = True - return - service_config = response - if not isinstance(service_config, dict) or 'rateLimits' not in service_config: - log.warning("Cannot retrieve service configuration, session will use default configuration.") - log.debug("Received the following response: %r", service_config) - return - if service_config.get("rateLimits"): - rate_limits_config = service_config.get("rateLimits") - - if rate_limits_config.get('messages'): - self._messages_rate_limit.set_limit(rate_limits_config.get('messages')) - else: - self._messages_rate_limit.set_limit('0:0,') - - if rate_limits_config.get('telemetryMessages'): - self._telemetry_rate_limit.set_limit(rate_limits_config.get('telemetryMessages')) - else: - self._telemetry_rate_limit.set_limit('0:0,') - - if rate_limits_config.get('telemetryDataPoints'): - self._telemetry_dp_rate_limit.set_limit(rate_limits_config.get('telemetryDataPoints')) - else: - self._telemetry_dp_rate_limit.set_limit('0:0,') - - if service_config.get('maxInflightMessages'): - use_messages_rate_limit_factor = self._messages_rate_limit.has_limit() - use_telemetry_rate_limit_factor = self._telemetry_rate_limit.has_limit() - service_config_inflight_messages = int(service_config.get('maxInflightMessages', 100)) - if use_messages_rate_limit_factor and use_telemetry_rate_limit_factor: - max_inflight_messages = 
int(min(self._messages_rate_limit.get_minimal_limit(), - self._telemetry_rate_limit.get_minimal_limit(), - service_config_inflight_messages) * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) - elif use_messages_rate_limit_factor: - max_inflight_messages = int(min(self._messages_rate_limit.get_minimal_limit(), - service_config_inflight_messages) * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) - elif use_telemetry_rate_limit_factor: - max_inflight_messages = int(min(self._telemetry_rate_limit.get_minimal_limit(), - service_config_inflight_messages) * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) - else: - max_inflight_messages = int(service_config.get('maxInflightMessages', 100) * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) - if max_inflight_messages == 0: - max_inflight_messages = 10_000 # No limitation on device queue on transport level - if max_inflight_messages < 1: - max_inflight_messages = 1 - self.max_inflight_messages_set(max_inflight_messages) - - if (not self._messages_rate_limit.has_limit() and - not self._telemetry_rate_limit.has_limit() and - not self._telemetry_dp_rate_limit.has_limit() and - not kwargs.get("gateway_limits_present", False)): - log.debug("No rate limits for device, setting max_queued_messages to 50000") - self.max_queued_messages_set(50000) - else: - log.debug("Rate limits for device, setting max_queued_messages to %r", max_inflight_messages) - self.max_queued_messages_set(max_inflight_messages) - - if service_config.get('maxPayloadSize'): - self.max_payload_size = int(int(service_config.get('maxPayloadSize')) * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) - log.info("Service configuration was successfully retrieved and applied.") - log.info("Current device limits: %r", service_config) - self.rate_limits_received = True - - def set_server_side_rpc_request_handler(self, handler): - """Set the callback that will be called when a server-side RPC is received.""" - self.__device_on_server_side_rpc_response = handler - - def _wait_for_rate_limit_released(self, timeout, message_rate_limit, dp_rate_limit=None, amount=1): - if not message_rate_limit.has_limit() and not (dp_rate_limit is None or dp_rate_limit.has_limit()): - return - start_time = int(monotonic()) - dp_rate_limit_timeout = dp_rate_limit.get_minimal_timeout() if dp_rate_limit is not None else 0 - timeout = max(message_rate_limit.get_minimal_timeout(), dp_rate_limit_timeout, timeout) + 10 - timeout_updated = False - disconnected = False - limit_reached_check = True - log_posted = False - waited = False - while limit_reached_check: - - message_rate_limit_check = message_rate_limit.check_limit_reached() - datapoints_rate_limit_check = dp_rate_limit.check_limit_reached(amount=amount) if dp_rate_limit is not None else False - limit_reached_check = (message_rate_limit_check - or datapoints_rate_limit_check - or not self.is_connected()) - if isinstance(limit_reached_check, tuple) and timeout < limit_reached_check[1]: - timeout = limit_reached_check[1] - if not timeout_updated and limit_reached_check: - timeout += 10 - timeout_updated = True - if self.stopped: - return TBPublishInfo(paho.MQTTMessageInfo(None)) - if not disconnected and not self.is_connected(): - log.warning("Waiting for connection to be established before sending data to ThingsBoard!") - disconnected = True - timeout = max(timeout, 180) + 10 - if int(monotonic()) >= timeout + start_time: - if message_rate_limit_check: - log.warning("Timeout while waiting for rate limit for messages to be released! 
Rate limit: %r:%r", - message_rate_limit_check, - message_rate_limit_check) - elif datapoints_rate_limit_check: - log.warning("Timeout while waiting for rate limit for data points to be released! Rate limit: %r:%r", - datapoints_rate_limit_check, - datapoints_rate_limit_check) - return TBPublishInfo(paho.MQTTMessageInfo(None)) - if not log_posted and limit_reached_check: - if log.isEnabledFor(logging.DEBUG): - if isinstance(message_rate_limit_check, tuple): - log.debug("Rate limit for messages (%r messages per %r second(s)) - almost reached, waiting for rate limit to be released...", - *message_rate_limit_check) - if isinstance(datapoints_rate_limit_check, tuple): - log.debug("Rate limit for data points (%r data points per %r second(s)) - almost reached, waiting for rate limit to be released...", - *datapoints_rate_limit_check) - waited = True - log_posted = True - if limit_reached_check: - sleep(.005) - if waited: - log.debug("Rate limit released, sending data to ThingsBoard...") - - def _wait_until_current_queued_messages_processed(self): - logger = None - - max_wait_time = 300 - log_interval = 5 - stuck_threshold = 15 - polling_interval = 0.05 - max_inflight = self._client._max_inflight_messages - - if len(self._client._out_messages) < max_inflight or max_inflight == 0: - return - - waiting_start = monotonic() - last_log_time = waiting_start - last_queue_size = len(self._client._out_messages) - last_queue_change_time = waiting_start - - while not self.stopped: - now = monotonic() - elapsed = now - waiting_start - current_queue_size = len(self._client._out_messages) - - if current_queue_size < max_inflight: - return - - if current_queue_size != last_queue_size: - last_queue_size = current_queue_size - last_queue_change_time = now - - if (now - last_queue_change_time > stuck_threshold - and not self._client.is_connected()): - if logger is None: - logger = logging.getLogger('tb_connection') - logger.warning( - "MQTT out_messages queue is stuck (%d messages) and client is disconnected. " - "Clearing queue after %.2f seconds.", - current_queue_size, now - last_queue_change_time - ) - with self._client._out_message_mutex: - self._client._out_packet.clear() - return - - if now - last_log_time >= log_interval: - if logger is None: - logger = logging.getLogger('tb_connection') - logger.debug( - "Waiting for MQTT queue to drain: %d messages (max inflight %d). " - "Elapsed: %.2f s", - current_queue_size, max_inflight, elapsed - ) - last_log_time = now - - if elapsed > max_wait_time: - if logger is None: - logger = logging.getLogger('tb_connection') - logger.warning( - "MQTT wait timeout reached (%.2f s). 
Queue still has %d messages.", - elapsed, current_queue_size - ) - return - - sleep(polling_interval) - - def _send_request(self, _type, kwargs, timeout=DEFAULT_TIMEOUT, device=None, - msg_rate_limit=None, dp_rate_limit=None): - topic = kwargs['topic'] - if msg_rate_limit is None: - if topic == TELEMETRY_TOPIC or topic ==ATTRIBUTES_TOPIC: - msg_rate_limit = self._telemetry_rate_limit - else: - msg_rate_limit = self._messages_rate_limit - if dp_rate_limit is None: - if topic == TELEMETRY_TOPIC or topic ==ATTRIBUTES_TOPIC: - dp_rate_limit = self._telemetry_dp_rate_limit - else: - dp_rate_limit = self.EMPTY_RATE_LIMIT - if msg_rate_limit.has_limit() or dp_rate_limit.has_limit(): - msg_rate_limit.increase_rate_limit_counter() - is_reached = self._wait_for_rate_limit_released(timeout, msg_rate_limit, dp_rate_limit) - if is_reached: - return is_reached - - if _type == TBSendMethod.PUBLISH: - self.__add_metadata_to_data_dict_from_device(kwargs["payload"]) - return self.__send_publish_with_limitations(kwargs, timeout, device, msg_rate_limit, dp_rate_limit) - elif _type == TBSendMethod.SUBSCRIBE: - return self._client.subscribe(**kwargs) - elif _type == TBSendMethod.UNSUBSCRIBE: - return self._client.unsubscribe(**kwargs) - - def __add_metadata_to_data_dict_from_device(self, data): - if isinstance(data, dict) and ("metadata" in data and isinstance(data["metadata"], dict)): - data["metadata"]["publishedTs"] = int(timestamp() * 1000) - elif isinstance(data, list): - current_time = int(timestamp() * 1000) - for data_item in data: - if isinstance(data_item, dict): - if 'ts' in data_item and ('metadata' in data_item and isinstance(data_item["metadata"], dict)): - data_item["metadata"]["publishedTs"] = current_time - elif isinstance(data, dict): - for key, value in data.items(): - self.__add_metadata_to_data_dict_from_device(value) - - def __get_rate_limits_by_topic(self, topic, device=None, msg_rate_limit=None, dp_rate_limit=None): - if device is not None: - return msg_rate_limit, dp_rate_limit - else: - if topic == TELEMETRY_TOPIC: - return self._telemetry_rate_limit, self._telemetry_dp_rate_limit - else: - return self._messages_rate_limit, None - - def __send_publish_with_limitations(self, kwargs, timeout, device=None, msg_rate_limit: RateLimit = None, - dp_rate_limit: RateLimit = None): - data = kwargs.get("payload") - if isinstance(data, str): - data = loads(data) - topic = kwargs["topic"] - attributes_format = topic.endswith('attributes') - if topic.endswith('telemetry') or attributes_format: - if device is None or data.get(device) is None: - device_split_messages = self._split_message(data, int(dp_rate_limit.get_minimal_limit()), self.max_payload_size) # noqa - if attributes_format: - split_messages = [{'message': msg_data, 'datapoints': len(msg_data)} for split_message in device_split_messages for msg_data in split_message['data']] # noqa - else: - split_messages = [{'message': split_message['data'], 'datapoints': split_message['datapoints']} for split_message in device_split_messages] # noqa - else: - device_data = data.get(device) - device_split_messages = self._split_message(device_data, int(dp_rate_limit.get_minimal_limit()), self.max_payload_size) # noqa - if attributes_format: - split_messages = [{'message': {device: msg_data}, 'datapoints': len(msg_data)} for split_message in device_split_messages for msg_data in split_message['data']] # noqa - else: - split_messages = [{'message': {device: split_message['data']}, 'datapoints': split_message['datapoints']} for split_message in 
device_split_messages] # noqa - else: - split_messages = [{'message': data, 'datapoints': self._count_datapoints_in_message(data, device)}] - - results = [] - for part in split_messages: - if not part: - continue - self.__send_split_message(results, part, kwargs, timeout, device, msg_rate_limit, dp_rate_limit, topic) - return TBPublishInfo(results) - - def __send_split_message(self, results, part, kwargs, timeout, device, msg_rate_limit, dp_rate_limit, - topic): - if msg_rate_limit.has_limit() or dp_rate_limit.has_limit(): - dp_rate_limit.increase_rate_limit_counter(part['datapoints']) - rate_limited = self._wait_for_rate_limit_released(timeout, - message_rate_limit=msg_rate_limit, - dp_rate_limit=dp_rate_limit, - amount=part['datapoints']) - if rate_limited: - return rate_limited - if msg_rate_limit.has_limit() or dp_rate_limit.has_limit(): - msg_rate_limit.increase_rate_limit_counter() - kwargs["payload"] = dumps(part['message'], option=OPT_NON_STR_KEYS) - if msg_rate_limit.has_limit() or dp_rate_limit.has_limit(): - self._wait_until_current_queued_messages_processed() - if not self.stopped: - if device is not None: - log.debug("Device: %s, Sending message to topic: %s ", device, topic) - if msg_rate_limit.has_limit() or dp_rate_limit.has_limit(): - if part['datapoints'] > 0: - log.debug("Sending message with %i datapoints", part['datapoints']) - if log.isEnabledFor(5) and hasattr(log, 'trace'): - log.trace("Message payload: %r", kwargs["payload"]) - log.debug("Rate limits after sending message: %r", msg_rate_limit.__dict__) - log.debug("Data points rate limits after sending message: %r", dp_rate_limit.__dict__) - else: - if log.isEnabledFor(5) and hasattr(log, 'trace'): - log.trace("Sending message with %r", kwargs["payload"]) - log.debug("Rate limits after sending message: %r", msg_rate_limit.__dict__) - log.debug("Data points rate limits after sending message: %r", dp_rate_limit.__dict__) - result = self._client.publish(**kwargs) - if result.rc == MQTT_ERR_QUEUE_SIZE: - error_appear_counter = 1 - sleep_time = 0.1 # 100 ms, in case of change - change max tries in while loop - while not self.stopped and result.rc == MQTT_ERR_QUEUE_SIZE: - error_appear_counter += 1 - if error_appear_counter > 78: # 78 tries ~ totally 300 seconds for sleep 0.1 - log.warning("Cannot send message to platform in %i seconds, queue size exceeded, current max inflight messages: %r, max queued messages: %r.", # noqa - int(error_appear_counter * sleep_time), - self._client._max_inflight_messages, - self._client._max_queued_messages) - if int(monotonic()) - self.__error_logged > 10: - log.debug("Queue size exceeded, waiting for messages to be processed by paho client.") - self.__error_logged = int(monotonic()) - sleep(sleep_time) # Give some time for paho to process messages - result = self._client.publish(**kwargs) - results.append(result) - - def _subscribe_to_topic(self, topic, qos=None, timeout=DEFAULT_TIMEOUT): - if qos is None: - qos = self.quality_of_service - - waiting_for_connection_message_time = 0 - while not self.is_connected() and not self.stopped: - if self.stopped: - return TBPublishInfo(paho.MQTTMessageInfo(None)) - if monotonic() - waiting_for_connection_message_time > 10.0: - log.warning("Waiting for connection to be established before subscribing for data on ThingsBoard!") - waiting_for_connection_message_time = monotonic() - sleep(0.01) - - return self._send_request(TBSendMethod.SUBSCRIBE, - {"topic": topic, "qos": qos}, - timeout, - msg_rate_limit=self._messages_rate_limit) - - def 
_publish_data(self, data, topic, qos, timeout=DEFAULT_TIMEOUT, device=None, - msg_rate_limit=None, dp_rate_limit=None): - if qos is None: - qos = self.quality_of_service - if qos not in (0, 1): - log.exception("Quality of service (qos) value must be 0 or 1") - raise TBQoSException("Quality of service (qos) value must be 0 or 1") - - waiting_for_connection_message_time = 0.0 - while not self.is_connected(): - if self.stopped: - return TBPublishInfo(paho.MQTTMessageInfo(None)) - if monotonic() - waiting_for_connection_message_time > 10.0: - log.warning("Waiting for connection to be established before sending data to ThingsBoard!") - waiting_for_connection_message_time = monotonic() - sleep(0.01) - - return self._send_request(TBSendMethod.PUBLISH, {"topic": topic, "payload": data, "qos": qos}, timeout, - device=device, msg_rate_limit=msg_rate_limit, dp_rate_limit=dp_rate_limit) - - def send_telemetry(self, telemetry, quality_of_service=None, wait_for_publish=True): - """Send telemetry to ThingsBoard. The telemetry can be a single dictionary or a list of dictionaries.""" - quality_of_service = quality_of_service if quality_of_service is not None else self.quality_of_service - if not isinstance(telemetry, list) and not (isinstance(telemetry, dict) and telemetry.get("ts") is not None): - telemetry = [telemetry] - return self._publish_data(telemetry, TELEMETRY_TOPIC, quality_of_service, wait_for_publish) - - def send_attributes(self, attributes, quality_of_service=None, wait_for_publish=True): - """Send attributes to ThingsBoard. The attributes can be a single dictionary or a list of dictionaries.""" - quality_of_service = quality_of_service if quality_of_service is not None else self.quality_of_service - return self._publish_data(attributes, ATTRIBUTES_TOPIC, quality_of_service, wait_for_publish) - - def unsubscribe_from_attribute(self, subscription_id): - """Unsubscribe from attribute updates for subscription_id.""" - with self._lock: - for attribute in self.__device_sub_dict: - if self.__device_sub_dict[attribute].get(subscription_id): - del self.__device_sub_dict[attribute][subscription_id] - log.debug("Unsubscribed from %s, subscription id %i", attribute, subscription_id) - if subscription_id == '*': - self.__device_sub_dict = {} - self.__device_sub_dict = dict((k, v) for k, v in self.__device_sub_dict.items() if v) - - def clean_device_sub_dict(self): - self.__device_sub_dict = {} - - def subscribe_to_all_attributes(self, callback): - """Subscribe to all attribute updates. The callback will be called when an attribute update is received.""" - return self.subscribe_to_attribute("*", callback) - - def subscribe_to_attribute(self, key, callback): - """Subscribe to attribute updates for attribute with key. - The callback will be called when an attribute update is received.""" - with self._lock: - self.__device_max_sub_id += 1 - if key not in self.__device_sub_dict: - self.__device_sub_dict.update({key: {self.__device_max_sub_id: callback}}) - else: - self.__device_sub_dict[key].update({self.__device_max_sub_id: callback}) - log.debug("Subscribed to %s with id %i", key, self.__device_max_sub_id) - return self.__device_max_sub_id - - def request_attributes(self, client_keys=None, shared_keys=None, callback=None): - """Request attributes from ThingsBoard. 
The callback will be called when the response is received.""" - msg = {} - if client_keys: - tmp = "" - for key in client_keys: - tmp += key + "," - tmp = tmp[:len(tmp) - 1] - msg.update({"clientKeys": tmp}) - if shared_keys: - tmp = "" - for key in shared_keys: - tmp += key + "," - tmp = tmp[:len(tmp) - 1] - msg.update({"sharedKeys": tmp}) - - start_processing_attribute_request = int(monotonic()) - - attr_request_number = self._add_attr_request_callback(callback) - - info = self._publish_data(msg, ATTRIBUTES_TOPIC_REQUEST + str(attr_request_number), self.quality_of_service) - - self.__attrs_request_timeout[attr_request_number] = start_processing_attribute_request + 20 - return info - - def _add_attr_request_callback(self, callback): - with self._lock: - self.__attr_request_number += 1 - self._attr_request_dict.update({self.__attr_request_number: callback}) - attr_request_number = self.__attr_request_number - return attr_request_number - - def __timeout_check(self): - while not self.stopped: - current_time = int(monotonic()) - for (attr_request_number, ts) in tuple(self.__attrs_request_timeout.items()): - if current_time < ts: - continue - - with self._lock: - callback = None - if self._attr_request_dict.get(attr_request_number): - callback = self._attr_request_dict.pop(attr_request_number) - - if callback is not None: - if isinstance(callback, tuple): - callback[0](None, TBTimeoutException("Timeout while waiting for a reply for attribute request from ThingsBoard!"), # noqa - callback[1]) - else: - callback(None, TBTimeoutException("Timeout while waiting for a reply for attribute request from ThingsBoard!")) # noqa - - self.__attrs_request_timeout.pop(attr_request_number) - - sleep(0.1) - - def claim(self, secret_key, duration=30000): - """Claim the device in Thingsboard. The duration is in milliseconds.""" - claiming_request = { - "secretKey": secret_key, - "durationMs": duration - } - info = self._publish_data(claiming_request, CLAIMING_TOPIC, self.quality_of_service) - return info - - @staticmethod - def _count_datapoints_in_message(data, device=None): - datapoints = 0 - if device is not None: - if isinstance(data.get(device), list): - for data_object in data[device]: - datapoints += TBDeviceMqttClient._count_datapoints_in_message(data_object) # noqa - elif isinstance(data.get(device), dict): - datapoints += TBDeviceMqttClient._count_datapoints_in_message(data.get(device, data.get('device'))) - else: - datapoints += 1 - else: - if isinstance(data, dict): - datapoints += TBDeviceMqttClient._get_data_points_from_message(data) - elif isinstance(data, list): - for item in data: - datapoints += TBDeviceMqttClient._get_data_points_from_message(item) - else: - datapoints += 1 - return datapoints - - @staticmethod - def _get_data_points_from_message(data): - if isinstance(data, dict): - if data.get("ts") is not None and data.get("values") is not None: - datapoints_in_message_amount = len(data['values']) + len(str(data['values'])) / 1000 - else: - datapoints_in_message_amount = len(data.keys()) + len(str(data)) / 1000 - else: - datapoints_in_message_amount = len(data) + len(str(data)) / 1000 - return int(datapoints_in_message_amount) - - @staticmethod - def provision(host, - provision_device_key, - provision_device_secret, - port=1883, - device_name=None, - access_token=None, - client_id=None, - username=None, - password=None, - hash=None, - gateway=None): - """Provision the device in ThingsBoard. 
Returns the credentials for the device.""" - provision_request = { - "provisionDeviceKey": provision_device_key, - "provisionDeviceSecret": provision_device_secret - } - - if access_token is not None: - provision_request["token"] = access_token - provision_request["credentialsType"] = "ACCESS_TOKEN" - elif username is not None or password is not None or client_id is not None: - provision_request["username"] = username - provision_request["password"] = password - provision_request["clientId"] = client_id - provision_request["credentialsType"] = "MQTT_BASIC" - elif hash is not None: - provision_request["hash"] = hash - provision_request["credentialsType"] = "X509_CERTIFICATE" - - if device_name is not None: - provision_request["deviceName"] = device_name - - if gateway is not None: - provision_request["gateway"] = gateway - - provisioning_client = ProvisionClient(host=host, port=port, provision_request=provision_request) - provisioning_client.provision() - return provisioning_client.get_credentials() - - @staticmethod - def _split_message(message_pack, datapoints_max_count, max_payload_size): - if not message_pack: - return [] - - split_messages = [] - - if isinstance(message_pack, dict) and message_pack.get('device') and len(message_pack) in [1, 2]: - return [{ - 'data': message_pack, - 'datapoints': TBDeviceMqttClient._count_datapoints_in_message(message_pack), - 'message': message_pack - }] - - if not isinstance(message_pack, list): - message_pack = [message_pack] - - def _get_metadata_repr(metadata): - return tuple(sorted(metadata.items())) if isinstance(metadata, dict) else None - - def estimate_chunk_size(chunk): - if isinstance(chunk, dict) and "values" in chunk: - size = sum(len(str(k)) + len(str(v)) for k, v in chunk["values"].items()) - size += len(str(chunk.get("ts", ""))) - if "metadata" in chunk: - size += sum(len(str(k)) + len(str(v)) for k, v in chunk["metadata"].items()) - return size + 40 - elif isinstance(chunk, dict): - return sum(len(str(k)) + len(str(v)) for k, v in chunk.items()) + 20 - else: - return len(str(chunk)) + 20 - - ts_group_cache = {} - current_message = {"data": [], "datapoints": 0} - current_size = 0 - current_datapoints = 0 - - def flush_current_message(): - nonlocal current_message, current_size, current_datapoints - if current_message["data"]: - split_messages.append(current_message) - current_message = {"data": [], "datapoints": 0} - current_size = 0 - current_datapoints = 0 - - def split_and_add_chunk(chunk, chunk_datapoints): - nonlocal current_message, current_size, current_datapoints - - chunk_size = estimate_chunk_size(chunk) - - if (0 < datapoints_max_count <= current_datapoints + chunk_datapoints) or \ - (current_size + chunk_size > max_payload_size): - flush_current_message() - - if chunk_datapoints > datapoints_max_count > 0 or chunk_size > max_payload_size: - keys = list(chunk.get("values", {}).keys()) if isinstance(chunk, dict) else list(chunk.keys()) - if len(keys) == 1: - flush_current_message() - current_message["data"].append(chunk) - current_message["datapoints"] += chunk_datapoints - current_size += chunk_size - current_datapoints += chunk_datapoints - flush_current_message() - return - - max_step = max(1, datapoints_max_count if datapoints_max_count > 0 else len(keys)) - for i in range(0, len(keys), max_step): - sub_values = ( - {k: chunk["values"][k] for k in keys[i:i + max_step]} if "values" in chunk - else {k: chunk[k] for k in keys[i:i + max_step]} - ) - sub_chunk = {} - if "ts" in chunk: - sub_chunk = {"ts": chunk["ts"], "values": 
sub_values} - if "metadata" in chunk: - sub_chunk["metadata"] = chunk["metadata"] - else: - sub_chunk = sub_values.copy() - - sub_datapoints = len(sub_values) - sub_size = estimate_chunk_size(sub_chunk) - - if sub_size > max_payload_size or (0 < datapoints_max_count <= sub_datapoints): - flush_current_message() - current_message["data"].append(sub_chunk) - current_message["datapoints"] += sub_datapoints - current_size += sub_size - current_datapoints += sub_datapoints - flush_current_message() - else: - split_and_add_chunk(sub_chunk, sub_datapoints) - return - - current_message["data"].append(chunk) - current_message["datapoints"] += chunk_datapoints - current_size += chunk_size - current_datapoints += chunk_datapoints - - if 0 < datapoints_max_count == current_datapoints: - flush_current_message() - - def add_chunk_to_current_message(chunk, chunk_datapoints): - nonlocal current_message, current_size, current_datapoints - - chunk_size = estimate_chunk_size(chunk) - - if (0 < datapoints_max_count <= chunk_datapoints) or chunk_size > max_payload_size: - split_and_add_chunk(chunk, chunk_datapoints) - return - - if (0 < datapoints_max_count <= current_datapoints + chunk_datapoints) or \ - (current_size + chunk_size > max_payload_size): - flush_current_message() - - current_message["data"].append(chunk) - current_message["datapoints"] += chunk_datapoints - current_size += chunk_size - current_datapoints += chunk_datapoints - - if 0 < datapoints_max_count == current_datapoints: - flush_current_message() - - def flush_ts_group(ts_key, ts, metadata_repr): - nonlocal current_message, current_size, current_datapoints - if ts_key not in ts_group_cache: - return - - values, _, metadata = ts_group_cache.pop(ts_key) - keys = list(values.keys()) - - step = max(1, datapoints_max_count if datapoints_max_count > 0 else len(keys)) - for i in range(0, len(keys), step): - chunk_values = {k: values[k] for k in keys[i:i + step]} - if ts is not None: - chunk = {"ts": ts, "values": chunk_values} - if metadata: - chunk["metadata"] = metadata - else: - chunk = chunk_values.copy() - - chunk_datapoints = len(chunk_values) - chunk_size = estimate_chunk_size(chunk) - - if chunk_size > max_payload_size or (0 < datapoints_max_count <= chunk_datapoints): - flush_current_message() - current_message["data"].append(chunk) - current_message["datapoints"] += chunk_datapoints - current_size += chunk_size - current_datapoints += chunk_datapoints - flush_current_message() - else: - add_chunk_to_current_message(chunk, chunk_datapoints) - - for message in message_pack: - if not isinstance(message, dict): - continue - - ts = message.get("ts") - metadata = message.get("metadata") if isinstance(message.get("metadata"), dict) else None - values = message.get("values") if isinstance(message.get("values"), dict) else \ - message if isinstance(message, dict) else {} - - metadata_repr = _get_metadata_repr(metadata) - ts_key = (ts, metadata_repr) - - for key, value in values.items(): - pair_size = len(str(key)) + len(str(value)) + 4 - if ts_key not in ts_group_cache: - ts_group_cache[ts_key] = ({}, 0, metadata) - - group_values, group_size, group_metadata = ts_group_cache[ts_key] - - can_add = ( - (datapoints_max_count == 0 or len(group_values) < datapoints_max_count) and - (group_size + pair_size <= max_payload_size) - ) - - if can_add: - group_values[key] = value - ts_group_cache[ts_key] = (group_values, group_size + pair_size, group_metadata) - else: - flush_ts_group(ts_key, ts, metadata_repr) - ts_group_cache[ts_key] = ({key: value}, 
pair_size, metadata) - - for ts_key in list(ts_group_cache.keys()): - ts, metadata_repr = ts_key - flush_ts_group(ts_key, ts, metadata_repr) - - flush_current_message() - return split_messages - - @staticmethod - def _datapoints_limit_reached(datapoints_max_count, current_datapoints_size, current_size): - return 0 < datapoints_max_count <= current_datapoints_size + current_size // 1024 - - @staticmethod - def _payload_size_limit_reached(max_payload_size, current_size, additional_size): - return current_size + additional_size >= max_payload_size - - def add_attrs_request_timeout(self, attr_request_number, timeout): - self.__attrs_request_timeout[attr_request_number] = timeout - - -class ProvisionClient(paho.Client): - PROVISION_REQUEST_TOPIC = "/provision/request" - PROVISION_RESPONSE_TOPIC = "/provision/response" - - def __init__(self, host, port, provision_request): - super().__init__() - self._host = host - self._port = port - self._username = "provision" - self.__credentials = None - self.on_connect = self.__on_connect - self.on_message = self.__on_message - self.__provision_request = provision_request - - def __on_connect(self, client, _, __, rc): # Callback for connect - if rc == 0: - log.info("[Provisioning client] Connected to ThingsBoard ") - client.subscribe(self.PROVISION_RESPONSE_TOPIC) # Subscribe to provisioning response topic - provision_request = dumps(self.__provision_request, option=OPT_NON_STR_KEYS) - log.info("[Provisioning client] Sending provisioning request %s" % provision_request) - client.publish(self.PROVISION_REQUEST_TOPIC, provision_request) # Publishing provisioning request topic - else: - log.info("[Provisioning client] Cannot connect to ThingsBoard!, result: %s" % RESULT_CODES[rc]) - - def __on_message(self, _, __, msg): - decoded_payload = msg.payload.decode("UTF-8") - log.info("[Provisioning client] Received data from ThingsBoard: %s" % decoded_payload) - decoded_message = loads(decoded_payload) - provision_device_status = decoded_message.get("status") - if provision_device_status == "SUCCESS": - self.__credentials = decoded_message - else: - log.error("[Provisioning client] Provisioning was unsuccessful with status %s and message: %s" % ( - provision_device_status, decoded_message["errorMsg"])) - self.disconnect() - - def provision(self): - log.info("[Provisioning client] Connecting to ThingsBoard") - self.__credentials = None - self.connect(self._host, self._port, 60) - self.loop_forever() - - def get_credentials(self): - return self.__credentials diff --git a/tb_gateway_mqtt.py b/tb_gateway_mqtt.py deleted file mode 100644 index a41def8..0000000 --- a/tb_gateway_mqtt.py +++ /dev/null @@ -1,356 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This file is maintained for backward compatibility with version 1 of the SDK. -# It is recommended to use the new SDK structure in tb_mqtt_client for new projects. 
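
For migration reference, a minimal sketch of how the legacy TBDeviceMqttClient removed above was typically driven. The host, access token, attribute key, and handler bodies are placeholder assumptions; the method names and call shapes (connect, subscribe_to_attribute, set_server_side_rpc_request_handler, send_telemetry, send_rpc_reply, disconnect) are taken from the deleted code.

# Hedged sketch of the legacy (v1) synchronous device client removed by this patch.
# "demo.thingsboard.io" and "ACCESS_TOKEN" are placeholders, not values from the patch.
from tb_device_mqtt import TBDeviceMqttClient

def on_attribute_update(content, exception):
    # Invoked for shared-attribute updates pushed by the platform.
    print("Attribute update:", content)

def on_server_rpc(request_id, request_body):
    # Answer a server-side RPC with the request id parsed from the topic.
    client.send_rpc_reply(request_id, {"result": "ok"})

client = TBDeviceMqttClient("demo.thingsboard.io", 1883, "ACCESS_TOKEN")
client.connect()  # non-blocking in the removed implementation: starts the paho loop and subscribes on CONNACK
client.subscribe_to_attribute("targetTemperature", on_attribute_update)
client.set_server_side_rpc_request_handler(on_server_rpc)
client.send_telemetry({"temperature": 23.4})  # wrapped into a list by the client before publishing
client.disconnect()
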
- -import logging -import warnings - -try: - from time import monotonic as time -except ImportError: - from time import time - -from tb_device_mqtt import TBDeviceMqttClient, RateLimit, TBSendMethod - -# Show deprecation warning -warnings.warn( - "The tb_gateway_mqtt module is deprecated and will be removed in a future version. " - "Please use tb_mqtt_client.service.gateway.client.GatewayClient instead.", - DeprecationWarning, - stacklevel=2 -) - -GATEWAY_ATTRIBUTES_TOPIC = "v1/gateway/attributes" -GATEWAY_TELEMETRY_TOPIC = "v1/gateway/telemetry" -GATEWAY_DISCONNECT_TOPIC = "v1/gateway/disconnect" -GATEWAY_ATTRIBUTES_REQUEST_TOPIC = "v1/gateway/attributes/request" -GATEWAY_ATTRIBUTES_RESPONSE_TOPIC = "v1/gateway/attributes/response" -GATEWAY_MAIN_TOPIC = "v1/gateway/" -GATEWAY_RPC_TOPIC = "v1/gateway/rpc" -GATEWAY_RPC_RESPONSE_TOPIC = "v1/gateway/rpc/response" -GATEWAY_CLAIMING_TOPIC = "v1/gateway/claim" - -log = logging.getLogger("tb_connection") - - -class TBGatewayAPI: - pass - - -class TBGatewayMqttClient(TBDeviceMqttClient): - def __init__(self, host, port=1883, username=None, password=None, gateway=None, quality_of_service=1, client_id="", - messages_rate_limit="DEFAULT_MESSAGES_RATE_LIMIT", - telemetry_rate_limit="DEFAULT_TELEMETRY_RATE_LIMIT", - telemetry_dp_rate_limit="DEFAULT_TELEMETRY_DP_RATE_LIMIT", - device_messages_rate_limit="DEFAULT_MESSAGES_RATE_LIMIT", - device_telemetry_rate_limit="DEFAULT_TELEMETRY_RATE_LIMIT", - device_telemetry_dp_rate_limit="DEFAULT_TELEMETRY_DP_RATE_LIMIT", **kwargs): - # Added for compatibility with the old versions - if kwargs.get('rate_limit') or kwargs.get('dp_rate_limit'): - messages_rate_limit = messages_rate_limit if kwargs.get('rate_limit') == "DEFAULT_RATE_LIMIT" else kwargs.get('rate_limit', messages_rate_limit) # noqa - telemetry_rate_limit = telemetry_rate_limit if kwargs.get('rate_limit') == "DEFAULT_RATE_LIMIT" else kwargs.get('rate_limit', telemetry_rate_limit) # noqa - device_messages_rate_limit = device_messages_rate_limit if kwargs.get('rate_limit') == "DEFAULT_RATE_LIMIT" else kwargs.get('rate_limit', device_messages_rate_limit) # noqa - device_telemetry_rate_limit = device_telemetry_rate_limit if kwargs.get('rate_limit') == "DEFAULT_RATE_LIMIT" else kwargs.get('rate_limit', device_telemetry_rate_limit) # noqa - telemetry_dp_rate_limit = telemetry_dp_rate_limit if kwargs.get('dp_rate_limit') == "DEFAULT_RATE_LIMIT" else kwargs.get('dp_rate_limit', telemetry_dp_rate_limit) # noqa - device_telemetry_dp_rate_limit = device_telemetry_dp_rate_limit if kwargs.get('dp_rate_limit') == "DEFAULT_RATE_LIMIT" else kwargs.get('dp_rate_limit', device_telemetry_dp_rate_limit) # noqa - - super().__init__(host, port, username, password, quality_of_service, client_id, - messages_rate_limit=messages_rate_limit, telemetry_rate_limit=telemetry_rate_limit, - telemetry_dp_rate_limit=telemetry_dp_rate_limit) - - self.__device_telemetry_rate_limit, self.__device_telemetry_dp_rate_limit = RateLimit.get_rate_limits_by_host( - host, device_telemetry_rate_limit, device_telemetry_dp_rate_limit) - self.__device_messages_rate_limit = RateLimit.get_rate_limit_by_host(host, device_messages_rate_limit) - - self._devices_connected_through_gateway_telemetry_messages_rate_limit = RateLimit(self.__device_telemetry_rate_limit, "Rate limit for devices connected through gateway telemetry messages") # noqa - self._devices_connected_through_gateway_telemetry_datapoints_rate_limit = RateLimit(self.__device_telemetry_dp_rate_limit, "Rate limit for devices connected through 
gateway telemetry data points") # noqa - self._devices_connected_through_gateway_messages_rate_limit = RateLimit(self.__device_messages_rate_limit, "Rate limit for devices connected through gateway messages") # noqa - - self.service_configuration_callback = self.__on_service_configuration - self.quality_of_service = quality_of_service - self.__max_sub_id = 0 - self.__sub_dict = {} - self.__connected_devices = set("*") - self.devices_server_side_rpc_request_handler = None - self.device_disconnect_callback = None - self._client.on_connect = self._on_connect - self._client.on_message = self._on_message - self._client.on_subscribe = self._on_subscribe - self._client._on_unsubscribe = self._on_unsubscribe - self._gw_subscriptions = {} - self.gateway = gateway - - def _on_connect(self, client, userdata, flags, result_code, properties, *extra_params): - super()._on_connect(client, userdata, flags, result_code, properties, *extra_params) - if result_code == 0: - gateway_attributes_topic_sub_id = int(self._subscribe_to_topic(GATEWAY_ATTRIBUTES_TOPIC, qos=1)[1]) - self._add_or_delete_subscription(GATEWAY_ATTRIBUTES_TOPIC, gateway_attributes_topic_sub_id) - - gateway_attributes_resp_sub_id = int(self._subscribe_to_topic(GATEWAY_ATTRIBUTES_RESPONSE_TOPIC, qos=1)[1]) - self._add_or_delete_subscription(GATEWAY_ATTRIBUTES_RESPONSE_TOPIC, gateway_attributes_resp_sub_id) - - gateway_rpc_topic_sub_id = int(self._subscribe_to_topic(GATEWAY_RPC_TOPIC, qos=1)[1]) - self._add_or_delete_subscription(GATEWAY_RPC_TOPIC, gateway_rpc_topic_sub_id) - - def _on_subscribe(self, client, userdata, mid, reasoncodes, properties=None, *extra_params): - subscription = self._gw_subscriptions.get(mid) - if subscription is not None: - if mid == 128: - self._delete_subscription(subscription, mid) - else: - log.debug("Service subscription to topic %s - successfully completed.", subscription) - del self._gw_subscriptions[mid] - - def _delete_subscription(self, topic, subscription_id): - log.error("Service subscription to topic %s - failed.", topic) - if subscription_id in self._gw_subscriptions: - del self._gw_subscriptions[subscription_id] - - def _add_or_delete_subscription(self, topic, subscription_id): - if subscription_id == 128: - self._delete_subscription(topic, subscription_id) - else: - self._gw_subscriptions[subscription_id] = topic - - @staticmethod - def _on_unsubscribe(*args): - log.debug("Unsubscribe callback called with args: %r", args) - - def get_subscriptions_in_progress(self): - return True if self._gw_subscriptions else False - - def _on_message(self, client, userdata, message): - content = self._decode(message) - super()._on_decoded_message(content, message) - self._on_decoded_message(content, message) - - def _on_decoded_message(self, content, message, **kwargs): - if message.topic.startswith(GATEWAY_ATTRIBUTES_RESPONSE_TOPIC): - with self._lock: - req_id = content["id"] - self._devices_connected_through_gateway_messages_rate_limit.increase_rate_limit_counter(1) - # pop callback and use it - if self._attr_request_dict.get(req_id): - callback = self._attr_request_dict.pop(req_id) - if isinstance(callback, tuple): - callback[0](content, None, callback[1]) - else: - callback(content, None) - else: - log.error("Unable to find callback to process attributes response from platform.") - elif message.topic == GATEWAY_ATTRIBUTES_TOPIC: - with self._lock: - # callbacks for everything - if self.__sub_dict.get("*|*"): - for device in self.__sub_dict["*|*"]: - self.__sub_dict["*|*"][device](content) - # callbacks for 
device. in this case callback executes for all attributes in message - if content.get("device") is None: - return - self._devices_connected_through_gateway_messages_rate_limit.increase_rate_limit_counter(1) - target = content["device"] + "|*" - if self.__sub_dict.get(target): - for device in self.__sub_dict[target]: - self.__sub_dict[target][device](content) - # callback for atr. in this case callback executes for all attributes in message - targets = [content["device"] + "|" + attribute for attribute in content["data"]] - for target in targets: - if self.__sub_dict.get(target): - for device in self.__sub_dict[target]: - self.__sub_dict[target][device](content) - elif message.topic == GATEWAY_RPC_TOPIC: - self._devices_connected_through_gateway_messages_rate_limit.increase_rate_limit_counter(1) - if self.devices_server_side_rpc_request_handler: - self.devices_server_side_rpc_request_handler(self, content) - elif message.topic == GATEWAY_DISCONNECT_TOPIC: - if content.get("reason"): - reason = content["reason"] - log.info("Device \"%s\" disconnected with reason %s", content["device"], content["reason"]) - if reason == 150: # 150 - Rate limit reached - self._devices_connected_through_gateway_messages_rate_limit.reach_limit() - self._devices_connected_through_gateway_telemetry_messages_rate_limit.reach_limit() - self._devices_connected_through_gateway_telemetry_datapoints_rate_limit.reach_limit() - if self.device_disconnect_callback is not None: - self.device_disconnect_callback(self, content) - else: - log.debug("Unexpected message from topic %r, content: %r", message.topic, content) - - def __request_attributes(self, device, keys, callback, type_is_client=False): - - attr_request_number = self._add_attr_request_callback(callback) - msg = {"keys": keys, - "device": device, - "client": type_is_client, - "id": attr_request_number} - info = self._send_device_request(TBSendMethod.PUBLISH, device, topic=GATEWAY_ATTRIBUTES_REQUEST_TOPIC, data=msg, - qos=1) - self.add_attrs_request_timeout(attr_request_number, int(time()) + 20) - return info - - def _send_device_request(self, _type, device_name, **kwargs): - if _type == TBSendMethod.PUBLISH: - topic = kwargs['topic'] - device_msg_rate_limit = self._devices_connected_through_gateway_messages_rate_limit - device_dp_rate_limit = self.EMPTY_RATE_LIMIT - if topic == GATEWAY_TELEMETRY_TOPIC or topic == GATEWAY_ATTRIBUTES_TOPIC: - device_msg_rate_limit = self._devices_connected_through_gateway_telemetry_messages_rate_limit - device_dp_rate_limit = self._devices_connected_through_gateway_telemetry_datapoints_rate_limit - info = self._publish_data(**kwargs, device=device_name, - msg_rate_limit=device_msg_rate_limit, - dp_rate_limit=device_dp_rate_limit) - return info - - def gw_request_shared_attributes(self, device_name, keys, callback): - return self.__request_attributes(device_name, keys, callback, False) - - def gw_request_client_attributes(self, device_name, keys, callback): - return self.__request_attributes(device_name, keys, callback, True) - - def gw_send_attributes(self, device, attributes, quality_of_service=1): - return self._send_device_request(TBSendMethod.PUBLISH, - device, - topic=GATEWAY_ATTRIBUTES_TOPIC, - data={device: attributes}, - qos=quality_of_service) - - def gw_send_telemetry(self, device, telemetry, quality_of_service=1): - if not isinstance(telemetry, list): - telemetry = [telemetry] - - return self._send_device_request(TBSendMethod.PUBLISH, - device, - topic=GATEWAY_TELEMETRY_TOPIC, - data={device: telemetry}, - 
qos=quality_of_service) - - def gw_connect_device(self, device_name, device_type="default"): - info = self._send_device_request(TBSendMethod.PUBLISH, device_name, topic=GATEWAY_MAIN_TOPIC + "connect", - data={"device": device_name, "type": device_type}, - qos=self.quality_of_service) - - self.__connected_devices.add(device_name) - - log.debug("Connected device %s", device_name) - return info - - def gw_disconnect_device(self, device_name): - info = self._send_device_request(TBSendMethod.PUBLISH, device_name, topic=GATEWAY_MAIN_TOPIC + "disconnect", - data={"device": device_name}, qos=self.quality_of_service) - - if device_name in self.__connected_devices: - self.__connected_devices.remove(device_name) - - log.debug("Disconnected device %s", device_name) - return info - - def gw_subscribe_to_all_attributes(self, callback): - return self.gw_subscribe_to_attribute("*", "*", callback) - - def gw_subscribe_to_all_device_attributes(self, device, callback): - return self.gw_subscribe_to_attribute(device, "*", callback) - - def gw_subscribe_to_attribute(self, device, attribute, callback): - if device not in self.__connected_devices: - log.error("Device %s is not connected", device) - return False - - with self._lock: - self.__max_sub_id += 1 - key = device + "|" + attribute - if key not in self.__sub_dict: - self.__sub_dict.update({key: {device: callback}}) - else: - self.__sub_dict[key].update({device: callback}) - - log.info("Subscribed to %s with id %i for device %s", key, self.__max_sub_id, device) - return self.__max_sub_id - - def gw_unsubscribe(self, subscription_id): - with self._lock: - for attribute in self.__sub_dict: - if self.__sub_dict[attribute].get(subscription_id): - del self.__sub_dict[attribute][subscription_id] - log.info("Unsubscribed from %s, subscription id %r", attribute, subscription_id) - if subscription_id == '*': - self.__sub_dict = {} - - def gw_set_server_side_rpc_request_handler(self, handler): - self.devices_server_side_rpc_request_handler = handler - - def gw_send_rpc_reply(self, device, req_id, resp, quality_of_service=None): - if quality_of_service is None: - quality_of_service = self.quality_of_service - if quality_of_service not in (0, 1): - log.error("Quality of service (qos) value must be 0 or 1") - return None - - info = self._send_device_request(TBSendMethod.PUBLISH, device, topic=GATEWAY_RPC_TOPIC, - data={"device": device, "id": req_id, "data": resp}, - qos=quality_of_service) - return info - - def gw_claim(self, device_name, secret_key, duration, claiming_request=None): - if claiming_request is None: - claiming_request = { - device_name: { - "secretKey": secret_key, - "durationMs": duration - } - } - - info = self._send_device_request(TBSendMethod.PUBLISH, device_name, topic=GATEWAY_CLAIMING_TOPIC, - data=claiming_request, qos=self.quality_of_service) - return info - - def __on_service_configuration(self, _, response, *args, **kwargs): - if "error" in response: - log.warning("Timeout while waiting for service configuration!, session will use default configuration.") - self.rate_limits_received = True - return - service_config = response - gateway_devices_rate_limit_config = service_config.pop('gatewayRateLimits', {}) - gateway_device_itself_rate_limit_config = service_config.pop('rateLimits', {}) - - if gateway_devices_rate_limit_config.get("messages"): - self._devices_connected_through_gateway_messages_rate_limit.set_limit( - gateway_devices_rate_limit_config.get("messages")) - else: - 
self._devices_connected_through_gateway_messages_rate_limit.set_limit('0:0,') - - if gateway_devices_rate_limit_config.get('telemetryMessages'): - self._devices_connected_through_gateway_telemetry_messages_rate_limit.set_limit( - gateway_devices_rate_limit_config.get('telemetryMessages')) - else: - self._devices_connected_through_gateway_telemetry_messages_rate_limit.set_limit('0:0,') - - if gateway_devices_rate_limit_config.get('telemetryDataPoints'): - self._devices_connected_through_gateway_telemetry_datapoints_rate_limit.set_limit( - gateway_devices_rate_limit_config.get('telemetryDataPoints')) - else: - self._devices_connected_through_gateway_telemetry_datapoints_rate_limit.set_limit('0:0,') - - gateway_limits_present = any( - [self._devices_connected_through_gateway_messages_rate_limit.has_limit(), - self._devices_connected_through_gateway_telemetry_messages_rate_limit.has_limit(), - self._devices_connected_through_gateway_telemetry_datapoints_rate_limit.has_limit()] - ) - - super().on_service_configuration(_, - {'rateLimits': gateway_device_itself_rate_limit_config, **service_config}, - *args, - **kwargs, - gateway_limits_present=gateway_limits_present) - log.info("Current limits for devices connected through the gateway: %r", gateway_devices_rate_limit_config) diff --git a/tb_mqtt_client/common/async_utils.py b/tb_mqtt_client/common/async_utils.py index e730c3f..912a9ee 100644 --- a/tb_mqtt_client/common/async_utils.py +++ b/tb_mqtt_client/common/async_utils.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import threading from typing import Union, Optional, Any, List, Set, Dict -import asyncio from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.common.publish_result import PublishResult @@ -97,38 +97,6 @@ async def await_or_stop(future_or_coroutine: Union[asyncio.Future, asyncio.Task, if not stop_task.done(): stop_task.cancel() -async def await_and_resolve_original( - parent_futures: List[asyncio.Future], - child_futures: List[asyncio.Future] -): - try: - results = await asyncio.gather(*child_futures, return_exceptions=True) - - for child in child_futures: - future_map.child_resolved(child) - - for i, f in enumerate(parent_futures): - if f is not None and not f.done(): - first_result = next((r for r in results if not isinstance(r, Exception)), None) - first_exception = next((r for r in results if isinstance(r, Exception)), None) - - if first_exception and not first_result: - f.set_exception(first_exception) - logger.debug("Set exception for parent future #%d id=%r from child exception: %r", - i, getattr(f, 'uuid', f), first_exception) - else: - f.set_result(first_result) - logger.trace("Resolved parent future #%d id=%r with result: %r", - i, getattr(f, 'uuid', f), first_result) - - except Exception as e: - logger.error("Unexpected error while resolving parent delivery futures: %s", e) - for i, f in enumerate(parent_futures): - if f is not None and not f.done(): - f.set_exception(e) - logger.debug("Set fallback exception for parent future #%d id=%r", i, getattr(f, 'uuid', f)) - - def run_coroutine_sync(coro_func, timeout: float = 3.0, raise_on_timeout: bool = False): """ Run async coroutine and return its result from a sync function even if event loop is running. 
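
The gateway-side API removed above is what the deprecation warning points away from; the following sketch, with placeholder host, token, and device names, shows how the legacy TBGatewayMqttClient was used to proxy a sub-device. The gw_* method names and argument orders come from the deleted code; everything else is illustrative.

# Hedged sketch of the legacy (v1) gateway client removed by this patch.
# Connection parameters and the "Sensor-A" sub-device are placeholders.
from time import time
from tb_gateway_mqtt import TBGatewayMqttClient

def on_device_attributes(content):
    # Receives {"device": ..., "data": {...}} for the subscribed attribute keys.
    print("Attributes for", content.get("device"), ":", content.get("data"))

gateway = TBGatewayMqttClient("demo.thingsboard.io", 1883, "GATEWAY_TOKEN")
gateway.connect()
gateway.gw_connect_device("Sensor-A", "thermostat")  # announce the sub-device to the platform
gateway.gw_subscribe_to_attribute("Sensor-A", "targetTemperature", on_device_attributes)
gateway.gw_send_telemetry("Sensor-A", {"ts": int(time() * 1000), "values": {"temperature": 23.4}})
gateway.gw_send_attributes("Sensor-A", {"firmwareVersion": "1.0.3"})
gateway.gw_disconnect_device("Sensor-A")
gateway.disconnect()
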
diff --git a/tb_mqtt_client/common/mqtt_message.py b/tb_mqtt_client/common/mqtt_message.py index 394b9ae..631900a 100644 --- a/tb_mqtt_client/common/mqtt_message.py +++ b/tb_mqtt_client/common/mqtt_message.py @@ -85,7 +85,3 @@ def mark_as_sent(self, message_id: int): """Mark the message as sent.""" self.message_id = message_id self._is_sent = True - - def is_sent(self) -> bool: - """Check if the message has been sent.""" - return self._is_sent diff --git a/tb_mqtt_client/common/publish_result.py b/tb_mqtt_client/common/publish_result.py index fdab1a3..429db0f 100644 --- a/tb_mqtt_client/common/publish_result.py +++ b/tb_mqtt_client/common/publish_result.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from typing import List @@ -34,7 +35,7 @@ def __repr__(self): def __eq__(self, other): if not isinstance(other, PublishResult): - return NotImplemented + return False return (self.topic == other.topic and self.qos == other.qos and self.message_id == other.message_id and diff --git a/tb_mqtt_client/common/rate_limit/rate_limit.py b/tb_mqtt_client/common/rate_limit/rate_limit.py index d1b0c73..cb7f02b 100644 --- a/tb_mqtt_client/common/rate_limit/rate_limit.py +++ b/tb_mqtt_client/common/rate_limit/rate_limit.py @@ -14,8 +14,8 @@ import asyncio -import os import logging +import os from asyncio import Lock from time import monotonic diff --git a/tb_mqtt_client/entities/data/attribute_entry.py b/tb_mqtt_client/entities/data/attribute_entry.py index b145f62..214d83a 100644 --- a/tb_mqtt_client/entities/data/attribute_entry.py +++ b/tb_mqtt_client/entities/data/attribute_entry.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any - from tb_mqtt_client.constants.json_typing import JSONCompatibleType from tb_mqtt_client.entities.data.data_entry import DataEntry diff --git a/tb_mqtt_client/entities/data/attribute_request.py b/tb_mqtt_client/entities/data/attribute_request.py index 40ff960..3b6fb9b 100644 --- a/tb_mqtt_client/entities/data/attribute_request.py +++ b/tb_mqtt_client/entities/data/attribute_request.py @@ -14,6 +14,7 @@ from dataclasses import dataclass from typing import Optional, List, Dict + from tb_mqtt_client.common.request_id_generator import AttributeRequestIdProducer from tb_mqtt_client.constants.json_typing import validate_json_compatibility from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent diff --git a/tb_mqtt_client/entities/data/data_entry.py b/tb_mqtt_client/entities/data/data_entry.py index dcbb999..73df046 100644 --- a/tb_mqtt_client/entities/data/data_entry.py +++ b/tb_mqtt_client/entities/data/data_entry.py @@ -13,6 +13,7 @@ # limitations under the License. 
from typing import Any, Optional + from orjson import dumps from tb_mqtt_client.constants.json_typing import JSONCompatibleType, validate_json_compatibility diff --git a/tb_mqtt_client/entities/data/device_uplink_message.py b/tb_mqtt_client/entities/data/device_uplink_message.py index c4de1da..47daad5 100644 --- a/tb_mqtt_client/entities/data/device_uplink_message.py +++ b/tb_mqtt_client/entities/data/device_uplink_message.py @@ -14,7 +14,6 @@ import asyncio from dataclasses import dataclass -from time import time from types import MappingProxyType from typing import List, Optional, Union, OrderedDict, Tuple, Mapping from uuid import uuid4 diff --git a/tb_mqtt_client/entities/data/rpc_response.py b/tb_mqtt_client/entities/data/rpc_response.py index 7eafa84..ba2b5d0 100644 --- a/tb_mqtt_client/entities/data/rpc_response.py +++ b/tb_mqtt_client/entities/data/rpc_response.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from enum import Enum from dataclasses import dataclass +from enum import Enum from traceback import format_exception from typing import Union, Optional, Dict, Any diff --git a/tb_mqtt_client/entities/data/timeseries_entry.py b/tb_mqtt_client/entities/data/timeseries_entry.py index 0136403..d63f7bf 100644 --- a/tb_mqtt_client/entities/data/timeseries_entry.py +++ b/tb_mqtt_client/entities/data/timeseries_entry.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Optional from tb_mqtt_client.constants.json_typing import JSONCompatibleType from tb_mqtt_client.entities.data.data_entry import DataEntry diff --git a/tb_mqtt_client/entities/gateway/device_info.py b/tb_mqtt_client/entities/gateway/device_info.py index d3b7f88..2e44f5b 100644 --- a/tb_mqtt_client/entities/gateway/device_info.py +++ b/tb_mqtt_client/entities/gateway/device_info.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from dataclasses import dataclass, field import uuid +from dataclasses import dataclass, field @dataclass() -class DeviceInfo: +class DeviceInfo(object): device_name: str device_profile: str original_name: str = field(init=False) @@ -37,10 +37,12 @@ def __setattr__(self, key, value): def rename(self, new_name: str): if new_name != self.device_name: - self.device_name = new_name + super().__setattr__('device_name', new_name) @classmethod def from_dict(cls, data: dict) -> 'DeviceInfo': + original_post_init = cls.__post_init__ + cls.__post_init__ = lambda self: None instance = cls( device_name=data['device_name'], device_profile=data.get('device_profile', 'default') @@ -48,6 +50,10 @@ def from_dict(cls, data: dict) -> 'DeviceInfo': instance.__setattr__("device_id", uuid.UUID(data['device_id'])) if 'original_name' in data: instance.__setattr__("original_name", data['original_name']) + else: + instance.__setattr__("original_name", instance.device_name) + instance._initializing = False + cls.__post_init__ = original_post_init return instance def to_dict(self) -> dict: diff --git a/tb_mqtt_client/entities/gateway/gateway_attribute_request.py b/tb_mqtt_client/entities/gateway/gateway_attribute_request.py index f56e959..5958e6a 100644 --- a/tb_mqtt_client/entities/gateway/gateway_attribute_request.py +++ b/tb_mqtt_client/entities/gateway/gateway_attribute_request.py @@ -14,6 +14,7 @@ from dataclasses import dataclass from typing import Optional, List, Dict, Union + from tb_mqtt_client.common.request_id_generator import AttributeRequestIdProducer from tb_mqtt_client.constants.json_typing import validate_json_compatibility from tb_mqtt_client.entities.data.attribute_request import AttributeRequest diff --git a/tb_mqtt_client/entities/gateway/gateway_attribute_update.py b/tb_mqtt_client/entities/gateway/gateway_attribute_update.py index 0f6e0b7..51fd359 100644 --- a/tb_mqtt_client/entities/gateway/gateway_attribute_update.py +++ b/tb_mqtt_client/entities/gateway/gateway_attribute_update.py @@ -11,22 +11,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Union +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent from tb_mqtt_client.entities.gateway.event_type import GatewayEventType -class GatewayAttributeUpdate(BaseGatewayEvent): +class GatewayAttributeUpdate(AttributeUpdate, BaseGatewayEvent): """ Represents an attribute update event for a device connected to a gateway. This event is used to notify about changes in device shared attributes. 
""" - def __init__(self, device_name: str, attribute_update: AttributeUpdate): - super().__init__(event_type=GatewayEventType.DEVICE_ATTRIBUTE_UPDATE) + def __init__(self, device_name: str, attribute_update: Union[AttributeUpdate, List[AttributeEntry], AttributeEntry]): + super().__init__(GatewayEventType.DEVICE_ATTRIBUTE_UPDATE) + if isinstance(attribute_update, list) and all(isinstance(entry, AttributeEntry) for entry in attribute_update): + attribute_update = AttributeUpdate(entries=attribute_update) + elif isinstance(attribute_update, AttributeEntry): + attribute_update = AttributeUpdate(entries=[attribute_update]) + elif not isinstance(attribute_update, AttributeUpdate): + raise TypeError("attribute_update must be an instance of AttributeUpdate, list of AttributeEntry, or a single AttributeEntry.") self.device_name = device_name + self.entries = attribute_update.entries self.attribute_update = attribute_update + def __str__(self) -> str: return f"GatewayAttributeUpdate(device_name={self.device_name}, attribute_update={self.attribute_update})" diff --git a/tb_mqtt_client/entities/gateway/gateway_claim_request.py b/tb_mqtt_client/entities/gateway/gateway_claim_request.py index 12a1143..dbf57e8 100644 --- a/tb_mqtt_client/entities/gateway/gateway_claim_request.py +++ b/tb_mqtt_client/entities/gateway/gateway_claim_request.py @@ -14,7 +14,7 @@ from dataclasses import dataclass -from typing import Optional, Dict, Any, Union +from typing import Dict, Any, Union from tb_mqtt_client.entities.data.claim_request import ClaimRequest from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent diff --git a/tb_mqtt_client/entities/gateway/gateway_requested_attribute_response.py b/tb_mqtt_client/entities/gateway/gateway_requested_attribute_response.py index 5faf387..612843d 100644 --- a/tb_mqtt_client/entities/gateway/gateway_requested_attribute_response.py +++ b/tb_mqtt_client/entities/gateway/gateway_requested_attribute_response.py @@ -33,6 +33,9 @@ class GatewayRequestedAttributeResponse(RequestedAttributeResponse, BaseGatewayE client: Optional[List[AttributeEntry]] = None event_type: GatewayEventType = GatewayEventType.DEVICE_REQUESTED_ATTRIBUTE_RESPONSE + def __post_init__(self): + super(BaseGatewayEvent, self).__setattr__('event_type', GatewayEventType.DEVICE_REQUESTED_ATTRIBUTE_RESPONSE) + def __repr__(self): return f"GatewayRequestedAttributeResponse(device_name={self.device_name},request_id={self.request_id}, shared={self.shared}, client={self.client})" diff --git a/tb_mqtt_client/entities/gateway/gateway_rpc_request.py b/tb_mqtt_client/entities/gateway/gateway_rpc_request.py index 57d4190..b47c423 100644 --- a/tb_mqtt_client/entities/gateway/gateway_rpc_request.py +++ b/tb_mqtt_client/entities/gateway/gateway_rpc_request.py @@ -28,7 +28,7 @@ class GatewayRPCRequest(BaseGatewayEvent): event_type: GatewayEventType = GatewayEventType.DEVICE_RPC_REQUEST def __new__(cls, *args, **kwargs): - raise TypeError("Direct instantiation of GatewayRPCRequest is not allowed. Use 'await GatewayRPCRequest.build(...)'.") + raise TypeError("Direct instantiation of GatewayRPCRequest is not allowed. 
Use 'GatewayRPCRequest._deserialize_from_dict(...)'.") def __repr__(self): return f"RPCRequest(id={self.request_id}, device_name={self.device_name}, method={self.method}, params={self.params})" diff --git a/tb_mqtt_client/entities/gateway/gateway_uplink_message.py b/tb_mqtt_client/entities/gateway/gateway_uplink_message.py index 2b041c7..5c79368 100644 --- a/tb_mqtt_client/entities/gateway/gateway_uplink_message.py +++ b/tb_mqtt_client/entities/gateway/gateway_uplink_message.py @@ -28,7 +28,7 @@ logger = get_logger(__name__) -DEFAULT_FIELDS_SIZE = len('{"device_name":"","device_profile":"","attributes":"","timeseries":""}'.encode('utf-8')) +DEFAULT_FIELDS_SIZE = len('{"device_name":"","attributes":"","timeseries":""}'.encode('utf-8')) @dataclass(slots=True, frozen=True) @@ -167,6 +167,9 @@ def build(self) -> GatewayUplinkMessage: delivery_future = asyncio.get_event_loop().create_future() delivery_future.uuid = uuid4() self._delivery_futures = [delivery_future] + if not self._device_profile: + self._device_profile = 'default' + self.__size += len(self._device_profile) return GatewayUplinkMessage.build( # noqa device_name=self._device_name, device_profile=self._device_profile, diff --git a/tb_mqtt_client/service/base_message_splitter.py b/tb_mqtt_client/service/base_message_splitter.py index 050e3ef..6c50dc9 100644 --- a/tb_mqtt_client/service/base_message_splitter.py +++ b/tb_mqtt_client/service/base_message_splitter.py @@ -29,7 +29,6 @@ def max_payload_size(self) -> int: """ Returns the maximum payload size for messages. """ - pass @max_payload_size.setter @abstractmethod @@ -37,7 +36,6 @@ def max_payload_size(self, value: int) -> None: """ Sets the maximum payload size for messages. """ - pass @property @abstractmethod @@ -45,7 +43,6 @@ def max_datapoints(self) -> int: """ Returns the maximum number of datapoints allowed in a message. """ - pass @max_datapoints.setter @abstractmethod @@ -53,18 +50,15 @@ def max_datapoints(self, value: int) -> None: """ Sets the maximum number of datapoints allowed in a message. """ - pass @abstractmethod def split_timeseries(self, *args, **kwargs) -> List[MqttPublishMessage]: """ Splits timeseries data """ - pass @abstractmethod def split_attributes(self, *args, **kwargs) -> List[MqttPublishMessage]: """ Splits attributes data """ - pass diff --git a/tb_mqtt_client/service/device/firmware_updater.py b/tb_mqtt_client/service/device/firmware_updater.py index 358984a..4a716db 100644 --- a/tb_mqtt_client/service/device/firmware_updater.py +++ b/tb_mqtt_client/service/device/firmware_updater.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from random import randint -from zlib import crc32 -from hashlib import sha256, sha384, sha512, md5 -from subprocess import CalledProcessError from asyncio import sleep +from hashlib import sha256, sha384, sha512, md5 from os.path import sep +from random import randint +from subprocess import CalledProcessError from typing import Awaitable, Callable, Optional +from zlib import crc32 + from tb_mqtt_client.common.install_package_utils import install_package from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.constants import mqtt_topics @@ -32,7 +33,6 @@ REQUIRED_SHARED_KEYS, FirmwareStates ) - from tb_mqtt_client.entities.data.attribute_request import AttributeRequest from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry diff --git a/tb_mqtt_client/service/device/message_adapter.py b/tb_mqtt_client/service/device/message_adapter.py index 37726e5..67f44b0 100644 --- a/tb_mqtt_client/service/device/message_adapter.py +++ b/tb_mqtt_client/service/device/message_adapter.py @@ -133,7 +133,7 @@ def parse_rpc_response(self, topic: str, payload: Union[bytes, Exception]) -> RP pass @abstractmethod - def parse_provisioning_response(self, provisioning_request: ProvisioningRequest, payload: bytes) -> 'ProvisioningResponse': + def parse_provisioning_response(self, provisioning_request: ProvisioningRequest, payload: bytes) -> ProvisioningResponse: """ Parse the provisioning response from the given payload. This method should be implemented to handle the specific format of the provisioning response. diff --git a/tb_mqtt_client/service/gateway/client.py b/tb_mqtt_client/service/gateway/client.py index 9710ffc..303b292 100644 --- a/tb_mqtt_client/service/gateway/client.py +++ b/tb_mqtt_client/service/gateway/client.py @@ -21,12 +21,11 @@ from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit, DEFAULT_RATE_LIMIT_PERCENTAGE -from tb_mqtt_client.common.rate_limit.rate_limiter import RateLimiter from tb_mqtt_client.constants import mqtt_topics from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.attribute_request import AttributeRequest -from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage from tb_mqtt_client.entities.gateway.device_disconnect_message import DeviceDisconnectMessage from tb_mqtt_client.entities.gateway.event_type import GatewayEventType @@ -197,7 +196,7 @@ async def disconnect_device(self, device_session: DeviceSession, wait_for_publis async def send_device_timeseries(self, device_session: DeviceSession, data: Union[TimeseriesEntry, List[TimeseriesEntry], Dict[str, Any], List[Dict[str, Any]]], - wait_for_publish: bool) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: + wait_for_publish: bool) -> Union[List[Union[PublishResult, Future[PublishResult]]], PublishResult, Future[PublishResult], None]: """ Send timeseries data to the platform for a specific device. 
:param device_session: The DeviceSession object for the device @@ -236,7 +235,7 @@ async def send_device_timeseries(self, async def send_device_attributes(self, device_session: DeviceSession, data: Union[Dict[str, Any], AttributeEntry, list[AttributeEntry]], - wait_for_publish: bool) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: + wait_for_publish: bool) -> Union[List[Union[PublishResult, Future[PublishResult]]], PublishResult, Future[PublishResult], None]: """ Send attributes data to the platform for a specific device. :param device_session: The DeviceSession object for the device @@ -266,7 +265,10 @@ async def send_device_attributes(self, return results[0] if len(results) == 1 else results - async def send_device_attributes_request(self, device_session: DeviceSession, attribute_request: Union[AttributeRequest, GatewayAttributeRequest], wait_for_publish: bool): + async def send_device_attributes_request(self, + device_session: DeviceSession, + attribute_request: Union[AttributeRequest, GatewayAttributeRequest], + wait_for_publish: bool) -> Union[List[Union[PublishResult, Future[PublishResult]]], PublishResult, Future[PublishResult], None]: """ Send a request for device attributes to the platform. :param device_session: The DeviceSession object for the device @@ -304,7 +306,7 @@ async def send_device_attributes_request(self, device_session: DeviceSession, at async def send_device_claim_request(self, device_session: DeviceSession, gateway_claim_request: GatewayClaimRequest, - wait_for_publish: bool) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: + wait_for_publish: bool) -> Union[List[Union[PublishResult, Future[PublishResult]]], PublishResult, Future[PublishResult], None]: """ Send a claim request for a device to the platform. :param device_session: The DeviceSession object for the device diff --git a/tb_mqtt_client/service/gateway/device_manager.py b/tb_mqtt_client/service/gateway/device_manager.py index 1a108eb..e7d786f 100644 --- a/tb_mqtt_client/service/gateway/device_manager.py +++ b/tb_mqtt_client/service/gateway/device_manager.py @@ -16,9 +16,8 @@ from uuid import UUID from tb_mqtt_client.common.logging_utils import get_logger -from tb_mqtt_client.service.gateway.device_session import DeviceSession from tb_mqtt_client.entities.gateway.device_info import DeviceInfo - +from tb_mqtt_client.service.gateway.device_session import DeviceSession logger = get_logger(__name__) diff --git a/tb_mqtt_client/service/gateway/device_session.py b/tb_mqtt_client/service/gateway/device_session.py index a23febb..281e611 100644 --- a/tb_mqtt_client/service/gateway/device_session.py +++ b/tb_mqtt_client/service/gateway/device_session.py @@ -13,14 +13,13 @@ # limitations under the License. 
import asyncio -from time import time from dataclasses import dataclass, field +from time import time from typing import Callable, Awaitable, Optional, Union -from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent - from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent from tb_mqtt_client.entities.gateway.device_info import DeviceInfo from tb_mqtt_client.entities.gateway.device_session_state import DeviceSessionState from tb_mqtt_client.entities.gateway.event_type import GatewayEventType diff --git a/tb_mqtt_client/service/gateway/direct_event_dispatcher.py b/tb_mqtt_client/service/gateway/direct_event_dispatcher.py index 478cd5b..b2ab01e 100644 --- a/tb_mqtt_client/service/gateway/direct_event_dispatcher.py +++ b/tb_mqtt_client/service/gateway/direct_event_dispatcher.py @@ -39,6 +39,8 @@ def register(self, event_type: GatewayEventType, callback: EventCallback): self._handlers[event_type].append(callback) def unregister(self, event_type: GatewayEventType, callback: EventCallback): + if event_type not in self._handlers: + return if callback in self._handlers[event_type]: self._handlers[event_type].remove(callback) if not self._handlers[event_type]: diff --git a/tb_mqtt_client/service/gateway/message_adapter.py b/tb_mqtt_client/service/gateway/message_adapter.py index 4061060..0eb41d8 100644 --- a/tb_mqtt_client/service/gateway/message_adapter.py +++ b/tb_mqtt_client/service/gateway/message_adapter.py @@ -15,7 +15,6 @@ from abc import abstractmethod, ABC from collections import defaultdict from datetime import datetime, UTC -from itertools import chain from typing import List, Dict, Any, Union, Optional from orjson import loads, dumps diff --git a/tb_mqtt_client/service/gateway/message_splitter.py b/tb_mqtt_client/service/gateway/message_splitter.py index 60dcc06..c9b416f 100644 --- a/tb_mqtt_client/service/gateway/message_splitter.py +++ b/tb_mqtt_client/service/gateway/message_splitter.py @@ -21,7 +21,8 @@ from tb_mqtt_client.common.logging_utils import get_logger from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry -from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage, GatewayUplinkMessageBuilder +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage, GatewayUplinkMessageBuilder, \ + DEFAULT_FIELDS_SIZE from tb_mqtt_client.service.base_message_splitter import BaseMessageSplitter logger = get_logger(__name__) @@ -32,6 +33,7 @@ class GatewayMessageSplitter(BaseMessageSplitter): def __init__(self, max_payload_size: int = 55000, max_datapoints: int = 0): self._max_payload_size = max_payload_size if max_payload_size is not None and max_payload_size > 0 else self.DEFAULT_MAX_PAYLOAD_SIZE + self._max_payload_size = self._max_payload_size - DEFAULT_FIELDS_SIZE self._max_datapoints = max_datapoints if max_datapoints is not None and max_datapoints > 0 else 0 logger.trace("GatewayMessageSplitter initialized with max_payload_size=%d, max_datapoints=%d", self._max_payload_size, self._max_datapoints) @@ -63,9 +65,10 @@ def split_timeseries(self, messages: List[GatewayUplinkMessage]) -> List[Gateway builder = None size = 0 point_count = 0 + names_len = len(device_name) + len(device_profile) for entry in all_ts_entries: - 
exceeds_size = builder and size + entry.size > self._max_payload_size + exceeds_size = builder and size + entry.size > self._max_payload_size - names_len exceeds_points = 0 < self._max_datapoints <= point_count if not builder or exceeds_size or exceeds_points: @@ -176,6 +179,7 @@ def max_payload_size(self) -> int: def max_payload_size(self, value: int): old = self._max_payload_size self._max_payload_size = value if value > 0 else self.DEFAULT_MAX_PAYLOAD_SIZE + self._max_payload_size = self._max_payload_size - DEFAULT_FIELDS_SIZE logger.debug("Updated max_payload_size: %d -> %d", old, self._max_payload_size) @property diff --git a/tests/common/test_async_utils.py b/tests/common/test_async_utils.py new file mode 100644 index 0000000..a588f5b --- /dev/null +++ b/tests/common/test_async_utils.py @@ -0,0 +1,134 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +import pytest + +from tb_mqtt_client.common.async_utils import FutureMap, await_or_stop +from tb_mqtt_client.common.publish_result import PublishResult + + +@pytest.mark.asyncio +async def test_future_map_register_and_get_parents(): + fm = FutureMap() + parent = asyncio.Future() + child1, child2 = asyncio.Future(), asyncio.Future() + + fm.register(parent, [child1]) + assert fm.get_parents(child1) == [parent] + + # Add more children to same parent + fm.register(parent, [child2]) + assert set(fm.get_parents(child1)) == {parent} + assert set(fm.get_parents(child2)) == {parent} + +@pytest.mark.asyncio +async def test_future_map_child_resolved_merges_results(): + fm = FutureMap() + parent = asyncio.Future() + child1, child2 = asyncio.Future(), asyncio.Future() + + fm.register(parent, [child1, child2]) + child1.set_result(PublishResult("topic", 1, 1, 1, 0)) + child2.set_result(PublishResult("topic", 1, 1, 1, 0)) + + # Resolving child1 won't complete parent + fm.child_resolved(child1) + assert not parent.done() + + # Resolving last child completes parent with merged PublishResult + fm.child_resolved(child2) + assert parent.done() + result = parent.result() + assert isinstance(result, PublishResult) + assert result.reason_code == 0 + +@pytest.mark.asyncio +async def test_future_map_child_resolved_no_publish_result(): + fm = FutureMap() + parent = asyncio.Future() + child = asyncio.Future() + + fm.register(parent, [child]) + child.set_result("not publish result") + fm.child_resolved(child) + assert parent.done() + assert parent.result() is None + +@pytest.mark.asyncio +async def test_future_map_child_resolved_with_cancelled_child(): + fm = FutureMap() + parent = asyncio.Future() + child = asyncio.Future() + fm.register(parent, [child]) + child.cancel() + fm.child_resolved(child) + assert parent.done() + assert parent.result() is None + +@pytest.mark.asyncio +async def test_await_or_stop_coroutine_finishes_first(): + stop_event = asyncio.Event() + async def coro(): return 123 + result = await await_or_stop(coro(), stop_event, timeout=1) + assert result == 123 + +@pytest.mark.asyncio +async def 
test_await_or_stop_stop_event_first(): + stop_event = asyncio.Event() + async def coro(): await asyncio.sleep(0.5) + asyncio.get_event_loop().call_soon(stop_event.set) + result = await await_or_stop(coro(), stop_event, timeout=1) + assert result is None + +@pytest.mark.asyncio +async def test_await_or_stop_timeout(): + stop_event = asyncio.Event() + async def coro(): await asyncio.sleep(1) + with pytest.raises(asyncio.TimeoutError): + await await_or_stop(coro(), stop_event, timeout=0.01) + +@pytest.mark.asyncio +async def test_await_or_stop_negative_timeout(): + stop_event = asyncio.Event() + async def coro(): return "ok" + result = await await_or_stop(coro(), stop_event, timeout=-1) + assert result == "ok" + +@pytest.mark.asyncio +async def test_await_or_stop_future_done(): + stop_event = asyncio.Event() + fut = asyncio.Future() + fut.set_result("done") + result = await await_or_stop(fut, stop_event, timeout=1) + assert result == "done" + +@pytest.mark.asyncio +async def test_await_or_stop_invalid_type(): + stop_event = asyncio.Event() + with pytest.raises(TypeError): + await await_or_stop("not a future", stop_event, timeout=1) + +@pytest.mark.asyncio +async def test_await_or_stop_cancelled_error(): + stop_event = asyncio.Event() + async def coro(): + raise asyncio.CancelledError() + result = await await_or_stop(coro(), stop_event, timeout=1) + assert result is None + + +if __name__ == '__main__': + pytest.main([__file__, "--tb=short", "-v"]) diff --git a/tests/common/test_backpressure_controller.py b/tests/common/test_backpressure_controller.py index 1670979..d06d287 100644 --- a/tests/common/test_backpressure_controller.py +++ b/tests/common/test_backpressure_controller.py @@ -13,11 +13,13 @@ # limitations under the License. import asyncio -import pytest from datetime import datetime, timedelta, UTC +import pytest + from tb_mqtt_client.common.rate_limit.backpressure_controller import BackpressureController + @pytest.fixture def stop_event(): return asyncio.Event() diff --git a/tests/common/test_config_loader.py b/tests/common/test_config_loader.py index 7723044..6148d6e 100644 --- a/tests/common/test_config_loader.py +++ b/tests/common/test_config_loader.py @@ -14,8 +14,10 @@ import os import unittest + from tb_mqtt_client.common.config_loader import DeviceConfig, GatewayConfig + class TestDeviceConfig(unittest.TestCase): def loads_default_values_when_env_vars_missing(self): diff --git a/tests/common/test_gmqtt_patch.py b/tests/common/test_gmqtt_patch.py new file mode 100644 index 0000000..414d690 --- /dev/null +++ b/tests/common/test_gmqtt_patch.py @@ -0,0 +1,144 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import asyncio +import heapq +import struct +import types + +import pytest + +from tb_mqtt_client.common.gmqtt_patch import PatchUtils, PublishPacket +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage + + +def test_parse_mqtt_properties_valid_and_invalid(): + # Unknown property id triggers warning branch + pkt = bytes([1]) + bytes([255]) + assert PatchUtils.parse_mqtt_properties(pkt) == {} + + # Exception path (invalid varint) + assert PatchUtils.parse_mqtt_properties(b"\xff") == {} + + +def test_extract_reason_code_all_paths(): + class Obj: reason_code = 42 + assert PatchUtils.extract_reason_code(Obj()) == 42 + assert PatchUtils.extract_reason_code(b"\x00*") == 42 + assert PatchUtils.extract_reason_code(b"") is None + assert PatchUtils.extract_reason_code(None) is None + + +def test_patch_puback_handling_and_storage(monkeypatch): + pu = PatchUtils(None, asyncio.Event()) + called = {} + def on_puback(mid, reason, props): + called["hit"] = (mid, reason, props) + + monkeypatch.setattr("gmqtt.mqtt.handler.MqttPackageHandler._handle_puback_packet", lambda *a, **k: None) + pu.patch_puback_handling(on_puback) + + # Call wrapped handler + pkt = struct.pack("!HB", 10, 1) + b"\x00" + handler = types.SimpleNamespace( + _connection=types.SimpleNamespace(persistent_storage=types.SimpleNamespace(remove=lambda m: None)) + ) + handler2 = handler + MqttHandlerClass = type("H", (), {}) + pu_handler = MqttHandlerClass() + pu_handler._connection = handler._connection + pu_handler._handle_puback_packet = lambda *a, **k: None + pu_handler = types.SimpleNamespace(**{"_connection": handler._connection}) + # Ensure parsing works + PatchUtils.parse_mqtt_properties(b"\x00") + called.clear() + + # patch_storage test + client = types.SimpleNamespace(_persistent_storage=types.SimpleNamespace( + _queue=[(0,1,"raw")], + _check_empty=lambda : None + )) + pu.client = client + pu.patch_storage() + assert asyncio.get_event_loop().run_until_complete(client._persistent_storage.pop_message()) + + +@pytest.mark.asyncio +async def test_retry_loop_and_task_controls(monkeypatch): + storage_queue = [] + class Storage: + def _check_empty(self): pass + async def pop_message(self): + if not storage_queue: + raise IndexError + return storage_queue.pop(0) + + msgs_sent = [] + class FakeClient: + is_connected = True + async def put_retry_message(self, msg): + msgs_sent.append(msg) + _persistent_storage = Storage() + + pu = PatchUtils(FakeClient(), asyncio.Event(), retry_interval=0) + heapq.heappush(storage_queue, (0, 1, types.SimpleNamespace(topic="t", dup=False))) + pu._stop_event.set() # immediate exit + await pu._retry_loop() + + # start_retry_task + stop_retry_task normal + pu._stop_event.clear() + pu.start_retry_task() + assert pu._retry_task + await pu.stop_retry_task() + assert pu._retry_task is None + + # Timeout branch in stop_retry_task + pu._retry_task = asyncio.create_task(asyncio.sleep(1)) + await pu.stop_retry_task() + + +def test_apply_calls_patch_and_starts_task(monkeypatch): + pu = PatchUtils(None, asyncio.Event()) + monkeypatch.setattr(pu, "patch_puback_handling", lambda cb: setattr(pu, "_patched", True)) + monkeypatch.setattr(pu, "start_retry_task", lambda : setattr(pu, "_started", True)) + pu.apply(lambda a,b,c: None) + assert pu._patched + assert pu._started + + +def test_build_package_qos1_with_provided_mid(): + msg = MqttPublishMessage(topic="topic", payload=b"PAY", qos=1, retain=True) + msg.dup = True + protocol = types.SimpleNamespace(proto_ver=5) + mid, packet = 
PublishPacket.build_package(msg, protocol, mid=77) + assert mid == 77 + # Verify mid is encoded at correct spot (after topic length and topic) + assert struct.pack("!H", 77) in packet + # DUP flag set + assert packet[0] & 0x08 + + +def test_build_package_qos1_with_generated_mid(monkeypatch): + msg = MqttPublishMessage(topic="gen", payload=b"PAY", qos=1) + # Force known id from id_generator + monkeypatch.setattr(PublishPacket, "id_generator", types.SimpleNamespace(next_id=lambda: 1234)) + protocol = types.SimpleNamespace(proto_ver=5) + mid, packet = PublishPacket.build_package(msg, protocol) + assert mid == 1234 + assert struct.pack("!H", 1234) in packet + + +if __name__ == '__main__': + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/common/test_provisioning_client.py b/tests/common/test_provisioning_client.py index 5f71855..0d00244 100644 --- a/tests/common/test_provisioning_client.py +++ b/tests/common/test_provisioning_client.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest from unittest.mock import AsyncMock, MagicMock, patch +import pytest + from tb_mqtt_client.common.provisioning_client import ProvisioningClient from tb_mqtt_client.constants.mqtt_topics import PROVISION_RESPONSE_TOPIC from tb_mqtt_client.constants.provisioning import ProvisioningResponseStatus diff --git a/tests/common/test_publish_result.py b/tests/common/test_publish_result.py index 15b799f..738c592 100644 --- a/tests/common/test_publish_result.py +++ b/tests/common/test_publish_result.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest + from tb_mqtt_client.common.publish_result import PublishResult @@ -59,6 +60,32 @@ def test_publish_result_as_dict(default_publish_result): "datapoints_count": 0 } +def test_publish_request_merge(): + result1 = PublishResult( + topic="v1/devices/me/telemetry", + qos=1, + message_id=123, + payload_size=256, + reason_code=0 + ) + result2 = PublishResult( + topic="v1/devices/me/telemetry", + qos=1, + message_id=124, + payload_size=512, + reason_code=0 + ) + merged_result = PublishResult.merge([result1, result2]) + + assert merged_result.topic == "v1/devices/me/telemetry" + assert merged_result.qos == 1 + assert merged_result.message_id == -1 # Merged results do not have a specific message_id + assert merged_result.payload_size == 768 # Combined payload size + assert merged_result.reason_code == 0 # All successful + +def test_publish_result_merge_with_empty_list(): + with pytest.raises(ValueError, match="No publish results to merge."): + PublishResult.merge([]) def test_publish_result_is_successful_true(default_publish_result): assert default_publish_result.is_successful() is True @@ -75,6 +102,43 @@ def test_publish_result_is_successful_false(): assert result.is_successful() is False +def test_publish_result_equality(): + result1 = PublishResult( + topic="v1/devices/me/telemetry", + qos=1, + message_id=123, + payload_size=256, + reason_code=0 + ) + result2 = PublishResult( + topic="v1/devices/me/telemetry", + qos=1, + message_id=123, + payload_size=256, + reason_code=0 + ) + assert result1 == result2 + + +def test_publish_result_inequality(): + result1 = PublishResult( + topic="v1/devices/me/telemetry", + qos=1, + message_id=123, + payload_size=256, + reason_code=0 + ) + result2 = PublishResult( + topic="v1/devices/me/telemetry", + qos=1, + message_id=124, # Different message_id + payload_size=256, + reason_code=0 + ) + assert result1 != result2 + assert result1 != 
"Not a PublishResult" # Different type comparison + + @pytest.mark.parametrize("reason_code", [1, 2, 3, 16, 255]) def test_publish_result_various_failure_codes(reason_code): result = PublishResult( @@ -85,3 +149,6 @@ def test_publish_result_various_failure_codes(reason_code): reason_code=reason_code ) assert result.is_successful() is False + +if __name__ == '__main__': + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/common/test_queue.py b/tests/common/test_queue.py new file mode 100644 index 0000000..1e02c30 --- /dev/null +++ b/tests/common/test_queue.py @@ -0,0 +1,155 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from types import SimpleNamespace + +import pytest + +from tb_mqtt_client.common.queue import AsyncDeque + + +@pytest.fixture +def make_item(): + """Factory to create unique or duplicate-like test items.""" + def _make(uuid): + return SimpleNamespace(uuid=uuid) + return _make + + +@pytest.mark.asyncio +async def test_put_and_get(make_item): + q = AsyncDeque(maxlen=5) + item = make_item("a") + await q.put(item) + assert q.size() == 1 + assert not q.is_empty() + + got = await q.get() + assert got.uuid == "a" + assert q.is_empty() + + +@pytest.mark.asyncio +async def test_put_duplicate_not_added(make_item): + q = AsyncDeque(maxlen=5) + item = make_item("dup") + await q.put(item) + await q.put(item) # duplicate + assert q.size() == 1 + + +@pytest.mark.asyncio +async def test_extend_and_duplicates(make_item): + q = AsyncDeque(maxlen=5) + items = [make_item(str(i)) for i in range(3)] + await q.extend(items) + assert q.size() == 3 + + # Duplicate extend + await q.extend(items) + assert q.size() == 3 + + +@pytest.mark.asyncio +async def test_put_left_and_duplicate(make_item): + q = AsyncDeque(maxlen=5) + item = make_item("x") + await q.put_left(item) + assert list(await q.peek_batch(1))[0].uuid == "x" + + await q.put_left(item) # duplicate, ignored + assert q.size() == 1 + + +@pytest.mark.asyncio +async def test_extend_left_and_duplicates(make_item): + q = AsyncDeque(maxlen=5) + items = [make_item(str(i)) for i in range(3)] + await q.extend_left(items) + assert [it.uuid for it in await q.peek_batch(3)] == ["0", "1", "2"] + + await q.extend_left(items) # duplicates ignored + assert q.size() == 3 + + +@pytest.mark.asyncio +async def test_peek_and_peek_batch_waits(make_item): + q = AsyncDeque(maxlen=5) + + async def delayed_put(): + await asyncio.sleep(0.01) + await q.put(make_item("peeked")) + + asyncio.create_task(delayed_put()) + first = await q.peek() + assert first.uuid == "peeked" + + batch = await q.peek_batch(1) + assert batch[0].uuid == "peeked" + + +@pytest.mark.asyncio +async def test_pop_n_partial_and_full(make_item): + q = AsyncDeque(maxlen=5) + items = [make_item(str(i)) for i in range(3)] + await q.extend(items) + + # pop more than available + popped = await q.pop_n(5) + assert [it.uuid for it in popped] == ["0", "1", "2"] + assert q.is_empty() + + # fill again + await q.extend(items) + popped2 = await 
q.pop_n(2) + assert [it.uuid for it in popped2] == ["0", "1"] + assert q.size() == 1 + + +@pytest.mark.asyncio +async def test_reinsert_front_duplicate_and_new(make_item): + q = AsyncDeque(maxlen=5) + item1 = make_item("1") + item2 = make_item("2") + + # Insert new + await q.reinsert_front(item1) + assert q.size() == 1 + assert (await q.peek()).uuid == "1" + + # Try duplicate + await q.reinsert_front(item1) + assert q.size() == 1 + + # Insert another new + await q.reinsert_front(item2) + assert q.size() == 2 + assert (await q.peek()).uuid == "2" + + +@pytest.mark.asyncio +async def test_get_waits_until_item_available(make_item): + q = AsyncDeque(maxlen=5) + + async def delayed_put(): + await asyncio.sleep(0.01) + await q.put(make_item("later")) + + asyncio.create_task(delayed_put()) + got = await q.get() + assert got.uuid == "later" + +if __name__ == '__main__': + pytest.main([__file__, '-v', '--tb=short']) diff --git a/tests/common/test_rate_limit.py b/tests/common/test_rate_limit.py index c487097..e46686b 100644 --- a/tests/common/test_rate_limit.py +++ b/tests/common/test_rate_limit.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +import math from time import sleep import pytest @@ -119,3 +120,110 @@ async def test_rate_limit_refill_behavior(): await asyncio.sleep(1.1) await rl.refill() assert (await rl.try_consume()) is None + +@pytest.mark.asyncio +async def test_set_required_tokens_and_clear_event(): + rl = RateLimit("5:1", "token-test") + rl.set_required_tokens(1, 3) + # Initially event should not be set + assert not rl.required_tokens_ready.is_set() + # Force refill to satisfy requirement + for bucket in rl._rate_buckets.values(): + bucket.tokens = 3 + await rl.refill() + assert rl.required_tokens_ready.is_set() + rl.clear_required_tokens_event() + assert not rl.required_tokens_ready.is_set() + + +@pytest.mark.asyncio +async def test_refill_does_not_break_with_no_limit(): + rl = RateLimit("", "no-limit-refill") + # Should not raise and should not change anything + await rl.refill() + assert rl._no_limit + + +@pytest.mark.asyncio +async def test_consume_with_real_limit_changes_tokens(): + rl = RateLimit("5:1", "consume-test") + before = rl._rate_buckets[1].tokens + await rl.consume(1) + after = rl._rate_buckets[1].tokens + assert after < before + + +def test_minimal_limit_and_timeout_no_limit_case(): + rl = RateLimit("", "min-test") + assert rl.minimal_limit == 0 + assert rl.minimal_timeout == 0 + + +@pytest.mark.asyncio +async def test_to_dict_contains_expected_structure(): + rl = RateLimit("5:1", "dict-test") + d = rl.to_dict() + assert "name" in d and d["name"] == "dict-test" + assert "percentage" in d and isinstance(d["percentage"], int) + assert "rateLimits" in d and isinstance(d["rateLimits"], dict) + + +@pytest.mark.asyncio +async def test_set_limit_with_no_entries_sets_no_limit(): + rl = RateLimit("5:1", "reset-test") + await rl.set_limit("0:0,") + assert not rl.has_limit() + + +@pytest.mark.asyncio +async def test_try_consume_returns_tuple_when_not_enough_tokens(): + rl = RateLimit("1:1", "tuple-test") + # Exhaust tokens + await rl.try_consume() + res = await rl.try_consume() + assert isinstance(res, tuple) + assert len(res) == 2 + + +def test_parse_string_with_invalid_entries(): + # This will hit logger.warning branch for parse_string exception + rl = RateLimit("abc:def", "invalid-test") + assert not rl.has_limit() + + +@pytest.mark.asyncio +async def test_reach_limit_when_no_limit_returns_none(): + rl = RateLimit("", "reach-none") + assert await 
rl.reach_limit() is None + + +@pytest.mark.asyncio +async def test_reach_limit_resets_index_on_duration_window(): + rl = RateLimit("1:1,1:2", "reset-index") + rl._RateLimit__reached_index = 10 # force overflow + rl._RateLimit__reached_index_time = 0 + res = await rl.reach_limit() + assert isinstance(res, tuple) + assert rl._RateLimit__reached_index > 0 + + +def test_env_variable_invalid(monkeypatch): + # Simulate bad env variable + monkeypatch.setenv("TB_DEFAULT_RATE_LIMIT_PERCENTAGE", "bad") + # Reload module under test to re-run env reading + import importlib + import tb_mqtt_client.common.rate_limit.rate_limit as rl_mod + importlib.reload(rl_mod) + assert rl_mod.DEFAULT_RATE_LIMIT_PERCENTAGE == 80 + + +def test_greedy_token_bucket_edge_cases(): + b = GreedyTokenBucket(2, 1) + b.tokens = 0 + assert not b.can_consume(1) + assert not b.consume(5) + assert math.isclose(b.get_remaining_tokens(), b.tokens) + + +if __name__ == '__main__': + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/entities/data/test_timeseries_entry.py b/tests/entities/data/test_timeseries_entry.py index 9bcec4c..ee12b7e 100644 --- a/tests/entities/data/test_timeseries_entry.py +++ b/tests/entities/data/test_timeseries_entry.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest + from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry diff --git a/tests/entities/gateway/__init__.py b/tests/entities/gateway/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tests/entities/gateway/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/entities/gateway/test_base_gateway_event.py b/tests/entities/gateway/test_base_gateway_event.py new file mode 100644 index 0000000..3206a09 --- /dev/null +++ b/tests/entities/gateway/test_base_gateway_event.py @@ -0,0 +1,61 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest + +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType + + +class DummyDeviceSession: + """A simple dummy object to mimic a device session.""" + pass + + +def test_initialization_and_event_type(): + # Create event with a specific GatewayEventType + event = BaseGatewayEvent(GatewayEventType.GATEWAY_CONNECT) + + # Ensure the event_type property returns what we set + assert event.event_type == GatewayEventType.GATEWAY_CONNECT + + # device_session should be None initially + assert event.device_session is None + + +def test_set_and_get_device_session(): + event = BaseGatewayEvent(GatewayEventType.GATEWAY_DISCONNECT) + + # Initially None + assert event.device_session is None + + # Set a device session + dummy_session = DummyDeviceSession() + event.set_device_session(dummy_session) + + # Ensure the getter returns what we set + assert event.device_session is dummy_session + + +def test_str_calls_repr(monkeypatch): + event = BaseGatewayEvent(GatewayEventType.GATEWAY_CONNECT) + + # Monkeypatch __repr__ to a known value + monkeypatch.setattr(event, "__repr__", lambda: "mocked_repr") + + # __str__ should return whatever __repr__ returns + assert str(event) == "mocked_repr" + + +if __name__ == '__main__': + pytest.main([__file__, '-v', '--tb=short']) diff --git a/tests/entities/gateway/test_device_connect_message.py b/tests/entities/gateway/test_device_connect_message.py new file mode 100644 index 0000000..f2c9e53 --- /dev/null +++ b/tests/entities/gateway/test_device_connect_message.py @@ -0,0 +1,62 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import pytest + +from tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType + + +def test_direct_instantiation_not_allowed(): + # Direct instantiation should raise a TypeError + with pytest.raises(TypeError) as exc_info: + DeviceConnectMessage("device1") + assert "Direct instantiation" in str(exc_info.value) + + +def test_build_with_empty_device_name_raises(): + # Empty device name should raise ValueError + with pytest.raises(ValueError) as exc_info: + DeviceConnectMessage.build("") + assert "must not be empty" in str(exc_info.value) + + +def test_build_and_repr_and_payload(): + msg = DeviceConnectMessage.build("MyDevice", "MyProfile") + + # Check attributes + assert msg.device_name == "MyDevice" + assert msg.device_profile == "MyProfile" + assert msg.event_type == GatewayEventType.DEVICE_CONNECT + + # __repr__ format check + repr_str = repr(msg) + assert "DeviceConnectMessage" in repr_str + assert "MyDevice" in repr_str + assert "MyProfile" in repr_str + + # Payload format check + payload = msg.to_payload_format() + assert payload == {"device": "MyDevice", "type": "MyProfile"} + + +def test_build_with_default_profile(): + msg = DeviceConnectMessage.build("DefaultDevice") + assert msg.device_profile == "default" + assert msg.to_payload_format() == {"device": "DefaultDevice", "type": "default"} + + +if __name__ == '__main__': + pytest.main([__file__, "--tb=short", "-v"]) diff --git a/tests/entities/gateway/test_device_disconnect_message.py b/tests/entities/gateway/test_device_disconnect_message.py new file mode 100644 index 0000000..7ee78ef --- /dev/null +++ b/tests/entities/gateway/test_device_disconnect_message.py @@ -0,0 +1,58 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from tb_mqtt_client.entities.gateway.device_disconnect_message import DeviceDisconnectMessage +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType + + +def test_direct_instantiation_not_allowed(): + # Directly calling constructor should raise TypeError + with pytest.raises(TypeError) as exc_info: + DeviceDisconnectMessage("DeviceX") + assert "Direct instantiation of DeviceDisconnectMessage" in str(exc_info.value) + + +def test_build_and_attributes(): + # Build using correct way + msg = DeviceDisconnectMessage.build("TestDevice") + assert isinstance(msg, DeviceDisconnectMessage) + assert msg.device_name == "TestDevice" + assert msg.event_type == GatewayEventType.DEVICE_DISCONNECT + # Check __repr__ + assert repr(msg) == "DeviceDisconnectMessage(device_name=TestDevice)" + # Check to_payload_format + assert msg.to_payload_format() == {"device": "TestDevice"} + + +def test_build_with_empty_device_name(): + # Empty device name should raise ValueError + with pytest.raises(ValueError) as exc_info: + DeviceDisconnectMessage.build("") + assert "Device name must not be empty" in str(exc_info.value) + + +def test_frozen_slots_and_equality_behavior(): + msg1 = DeviceDisconnectMessage.build("A") + msg2 = DeviceDisconnectMessage.build("A") + msg3 = DeviceDisconnectMessage.build("B") + + # Equality check works for dataclass with frozen=True + assert msg1 == msg2 + assert msg1 != msg3 + + +if __name__ == '__main__': + pytest.main([__file__, "--tb=short", "-v"]) diff --git a/tests/entities/gateway/test_device_info.py b/tests/entities/gateway/test_device_info.py new file mode 100644 index 0000000..ab9a3af --- /dev/null +++ b/tests/entities/gateway/test_device_info.py @@ -0,0 +1,111 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import uuid + +import pytest + +from tb_mqtt_client.entities.gateway.device_info import DeviceInfo + + +def test_post_init_and_basic_properties(): + """Test that __post_init__ sets original_name and _initializing is False.""" + d = DeviceInfo(device_name="dev1", device_profile="profile1") + assert d.device_name == "dev1" + assert d.original_name == "dev1" + assert isinstance(d.device_id, uuid.UUID) + assert not d._initializing + + +def test_setattr_after_init_raises(): + """Test that modifying attributes after init raises AttributeError.""" + d = DeviceInfo(device_name="dev1", device_profile="p1") + with pytest.raises(AttributeError) as exc: + d.device_name = "new_name" + assert "Cannot modify attribute" in str(exc.value) + + +def test_rename_changes_name(): + """Test rename method changes device_name.""" + d = DeviceInfo(device_name="dev1", device_profile="p1") + old_id = d.device_id + d.rename("new_dev") + assert d.device_name == "new_dev" + assert d.device_id == old_id # ID stays the same + + +def test_rename_same_name_no_change(): + """Renaming to the same name should not alter anything.""" + d = DeviceInfo(device_name="dev1", device_profile="p1") + d.rename("dev1") # No change expected + assert d.device_name == "dev1" + + +def test_from_dict_and_to_dict(): + """Test creating from dict and converting back to dict.""" + dev_id = str(uuid.uuid4()) + data = { + "device_name": "dname", + "device_profile": "prof", + "device_id": dev_id, + "original_name": "orig" + } + d = DeviceInfo.from_dict(data) + assert isinstance(d, DeviceInfo) + assert str(d.device_id) == dev_id + assert d.device_name == "dname" + assert d.original_name == "orig" + + # Check to_dict output + out = d.to_dict() + assert out["device_name"] == "dname" + assert out["device_profile"] == "prof" + assert out["device_id"] == dev_id + assert out["original_name"] == "orig" + + +def test_str_and_repr(): + """Test __str__ and __repr__ produce expected formats.""" + d = DeviceInfo(device_name="dname", device_profile="prof") + s = str(d) + r = repr(d) + assert "DeviceInfo" in s + assert "DeviceInfo" in r + assert str(d.device_id) in s + assert repr(d.device_id) in r + + +def test_eq_and_hash(): + """Test equality and hash implementation.""" + d1 = DeviceInfo(device_name="dname", device_profile="prof") + d2 = DeviceInfo(device_name="dname", device_profile="prof") + # Manually make IDs equal for equality check + object.__setattr__(d2, "device_id", d1.device_id) + object.__setattr__(d2, "original_name", d1.original_name) + + assert d1 == d2 + assert hash(d1) == hash(d2) + + # Different object type should return NotImplemented for eq + assert d1.__eq__("not a DeviceInfo") is NotImplemented + + +def test_hash_works_in_set(): + """Test that DeviceInfo is hashable and works in a set.""" + d = DeviceInfo(device_name="dname", device_profile="prof") + s = {d} + assert d in s + +if __name__ == '__main__': + pytest.main([__file__, "--tb=short", "-v"]) \ No newline at end of file diff --git a/tests/entities/gateway/test_gateway_attribute_request.py b/tests/entities/gateway/test_gateway_attribute_request.py new file mode 100644 index 0000000..b1a85eb --- /dev/null +++ b/tests/entities/gateway/test_gateway_attribute_request.py @@ -0,0 +1,130 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock, patch, AsyncMock + +import pytest + +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest + + +@pytest.mark.asyncio +async def test_direct_instantiation_not_allowed(): + with pytest.raises(TypeError) as exc: + GatewayAttributeRequest() + assert "Direct instantiation" in str(exc.value) + + +@pytest.mark.asyncio +async def test_build_assigns_values_and_repr(): + mock_session = MagicMock() + mock_session.device_info.device_name = "TestDevice" + + with patch("tb_mqtt_client.common.request_id_generator.AttributeRequestIdProducer.get_next", new=AsyncMock(return_value=123)): + req = await GatewayAttributeRequest.build( + device_session=mock_session, + shared_keys=["s1", "s2"], + client_keys=["c1"] + ) + + # Check attributes set correctly + assert req.device_session is mock_session + assert req.request_id == 123 + assert req.shared_keys == ["s1", "s2"] + assert req.client_keys == ["c1"] + assert req.event_type == GatewayEventType.DEVICE_ATTRIBUTE_REQUEST + + # __repr__ should include these values + r = repr(req) + assert "TestDevice" in r or "device_session" in r + assert "shared_keys" in r + assert "client_keys" in r + + +@pytest.mark.asyncio +async def test_from_attribute_request_success(): + mock_session = MagicMock() + mock_session.device_info.device_name = "MyDev" + + attr_req = await AttributeRequest.build(shared_keys=["sh1"], client_keys=["cl1"]) + + result = await GatewayAttributeRequest.from_attribute_request(mock_session, attr_req) + + assert result.device_session is mock_session + assert result.request_id == attr_req.request_id + assert result.shared_keys == ["sh1"] + assert result.client_keys == ["cl1"] + assert result.event_type == GatewayEventType.DEVICE_ATTRIBUTE_REQUEST + + +@pytest.mark.asyncio +async def test_from_attribute_request_invalid_type(): + mock_session = MagicMock() + with pytest.raises(TypeError) as exc: + await GatewayAttributeRequest.from_attribute_request(mock_session, object()) + assert "must be an instance of AttributeRequest" in str(exc.value) + + +def make_request_for_payload(shared=None, client=None): + mock_session = MagicMock() + mock_session.device_info.device_name = "DeviceX" + + req = object.__new__(GatewayAttributeRequest) + object.__setattr__(req, 'device_session', mock_session) + object.__setattr__(req, 'request_id', 999) + object.__setattr__(req, 'shared_keys', shared) + object.__setattr__(req, 'client_keys', client) + object.__setattr__(req, 'event_type', GatewayEventType.DEVICE_ATTRIBUTE_REQUEST) + return req + + +def test_to_payload_format_client_multiple_keys(): + req = make_request_for_payload(client=["a", "b"]) + payload = req.to_payload_format() + assert payload["device"] == "DeviceX" + assert payload["client"] is True + assert payload["keys"] == ["a", "b"] + + +def test_to_payload_format_client_single_key(): + req = make_request_for_payload(client=["only_one"]) + payload = req.to_payload_format() + assert 
payload["client"] is True + assert payload["key"] == "only_one" + + +def test_to_payload_format_shared_multiple_keys(): + req = make_request_for_payload(shared=["s1", "s2"]) + payload = req.to_payload_format() + assert payload["client"] is False + assert payload["keys"] == ["s1", "s2"] + + +def test_to_payload_format_shared_single_key(): + req = make_request_for_payload(shared=["s1"]) + payload = req.to_payload_format() + assert payload["client"] is False + assert payload["key"] == "s1" + + +def test_to_payload_format_no_keys(): + req = make_request_for_payload(shared=None, client=None) + payload = req.to_payload_format() + assert set(payload.keys()) == {"device", "id"} + + +if __name__ == '__main__': + pytest.main([__file__, "--tb=short", "-v"]) diff --git a/tests/entities/gateway/test_gateway_attribute_update.py b/tests/entities/gateway/test_gateway_attribute_update.py new file mode 100644 index 0000000..2842dc2 --- /dev/null +++ b/tests/entities/gateway/test_gateway_attribute_update.py @@ -0,0 +1,55 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.gateway.gateway_attribute_update import GatewayAttributeUpdate + + +def test_init_with_list_of_entries(): + entries = [AttributeEntry("k1", "v1"), AttributeEntry("k2", "v2")] + obj = GatewayAttributeUpdate("deviceA", entries) + assert obj.device_name == "deviceA" + assert len(obj.entries) == 2 + assert all(isinstance(e, AttributeEntry) for e in obj.entries) + assert isinstance(obj.attribute_update, AttributeUpdate) + assert str(obj) == f"GatewayAttributeUpdate(device_name=deviceA, attribute_update={obj.attribute_update})" + +def test_init_with_single_entry(): + entry = AttributeEntry("k", "v") + obj = GatewayAttributeUpdate("deviceB", entry) + assert obj.device_name == "deviceB" + assert len(obj.entries) == 1 + assert obj.entries[0].key == "k" + assert isinstance(obj.attribute_update, AttributeUpdate) + assert "deviceB" in str(obj) + +def test_init_with_attribute_update(): + update = AttributeUpdate([AttributeEntry("ka", "va")]) + obj = GatewayAttributeUpdate("deviceC", update) + assert obj.device_name == "deviceC" + assert len(obj.entries) == 1 + assert isinstance(obj.attribute_update, AttributeUpdate) + assert "deviceC" in str(obj) + +def test_init_with_invalid_type(): + with pytest.raises(TypeError) as exc: + GatewayAttributeUpdate("deviceD", {"invalid": "type"}) + assert "attribute_update must be an instance of AttributeUpdate" in str(exc.value) + + +if __name__ == '__main__': + pytest.main([__file__, "--tb=short", "-v"]) diff --git a/tests/entities/gateway/test_gateway_claim_request.py b/tests/entities/gateway/test_gateway_claim_request.py new file mode 100644 index 0000000..16a061f --- /dev/null +++ b/tests/entities/gateway/test_gateway_claim_request.py @@ -0,0 +1,99 @@ +# Copyright 2025 ThingsBoard +# +# Licensed 
under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from tb_mqtt_client.entities.data.claim_request import ClaimRequest +from tb_mqtt_client.entities.gateway.device_info import DeviceInfo +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.gateway.gateway_claim_request import ( + GatewayClaimRequest, + GatewayClaimRequestBuilder +) +from tb_mqtt_client.service.gateway.device_session import DeviceSession + + +def test_direct_instantiation_not_allowed(): + with pytest.raises(TypeError) as exc: + GatewayClaimRequest() + assert "Direct instantiation" in str(exc.value) + + +def test_repr_and_add_device_request_with_str(): + claim_req = ClaimRequest.build("secret", 1000) + + req = GatewayClaimRequest.build() + assert isinstance(req.devices_requests, dict) + assert req.event_type == GatewayEventType.GATEWAY_CLAIM_REQUEST + + req.add_device_request("device_1", claim_req) + assert "device_1" in req.devices_requests + assert repr(req) == f"GatewayClaimRequest(devices_requests={req.devices_requests})" + + payload = req.to_payload_format() + assert payload == {"device_1": claim_req.to_payload_format()} + + +def test_add_device_request_with_device_session(): + claim_req = ClaimRequest.build("key2", 2000) + dev_info = DeviceInfo(device_name="devA", device_profile="default") + dev_session = DeviceSession(dev_info) + + req = GatewayClaimRequest.build() + req.add_device_request(dev_session, claim_req) + + payload = req.to_payload_format() + assert payload == {"devA": claim_req.to_payload_format()} + + +def test_builder_add_and_build_with_str(): + claim_req = ClaimRequest.build("builder_secret", 3000) + builder = GatewayClaimRequestBuilder() + builder.add_device_request("devB", claim_req) + + built_req = builder.build() + assert isinstance(built_req, GatewayClaimRequest) + assert built_req.to_payload_format() == {"devB": claim_req.to_payload_format()} + + +def test_builder_add_and_build_with_device_session(): + claim_req = ClaimRequest.build("builder_secret_2", 4000) + dev_info = DeviceInfo(device_name="devC", device_profile="default") + dev_session = DeviceSession(dev_info) + + builder = GatewayClaimRequestBuilder() + builder.add_device_request(dev_session, claim_req) + + built_req = builder.build() + assert built_req.to_payload_format() == {"devC": claim_req.to_payload_format()} + + +@pytest.mark.parametrize("bad_device", [123, object()]) +def test_builder_add_device_request_invalid_device_type(bad_device): + builder = GatewayClaimRequestBuilder() + with pytest.raises(ValueError) as exc: + builder.add_device_request(bad_device, ClaimRequest.build("key", 100)) + assert "DeviceSession or a string" in str(exc.value) + + +@pytest.mark.parametrize("bad_claim", [123, object(), "string"]) +def test_builder_add_device_request_invalid_claim_request(bad_claim): + builder = GatewayClaimRequestBuilder() + with pytest.raises(ValueError) as exc: + builder.add_device_request("devName", bad_claim) + assert "must be an instance of ClaimRequest" in str(exc.value) + + +if 
__name__ == '__main__': + pytest.main([__file__, "--tb=short", "-v"]) diff --git a/tests/entities/gateway/test_gateway_event.py b/tests/entities/gateway/test_gateway_event.py new file mode 100644 index 0000000..73d05d2 --- /dev/null +++ b/tests/entities/gateway/test_gateway_event.py @@ -0,0 +1,76 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from tb_mqtt_client.entities.gateway.device_info import DeviceInfo +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.gateway.gateway_event import GatewayEvent +from tb_mqtt_client.service.gateway.device_session import DeviceSession + + +def test_initialization_and_event_type(): + # Create an event and check event type + event = GatewayEvent(GatewayEventType.GATEWAY_CONNECT) + assert event.event_type == GatewayEventType.GATEWAY_CONNECT + + # Initially __device_session should be None + assert event.get_device_session() is None + + +def test_set_and_get_device_session(): + event = GatewayEvent(GatewayEventType.GATEWAY_DISCONNECT) + + device_info = DeviceInfo("dummy_device", "default") + dummy_session = DeviceSession(device_info) + event.set_device_session(dummy_session) + + # Ensure get_device_session returns the same object + assert event.get_device_session() is dummy_session + + +def test_str_representation_with_and_without_device_session(): + event = GatewayEvent(GatewayEventType.GATEWAY_CONNECT) + + # Without device session + no_session_str = str(event) + assert "GatewayEvent(type=" in no_session_str + assert "device_session=None" in no_session_str + + # With device session + device_info = DeviceInfo("dummy_device", "default") + dummy_session = DeviceSession(device_info) + event.set_device_session(dummy_session) + session_str = str(event) + assert "DeviceSession" in session_str + + +def test_to_dict_with_and_without_device_session(): + event = GatewayEvent(GatewayEventType.GATEWAY_CONNECT) + + # Without device session + d = event.to_dict() + assert d["event_type"] == GatewayEventType.GATEWAY_CONNECT + assert d["device_session"] is None + + # With device session + device_info = DeviceInfo("dummy_device", "default") + dummy_session = DeviceSession(device_info) + event.set_device_session(dummy_session) + d2 = event.to_dict() + assert d2["device_session"] is dummy_session + + +if __name__ == '__main__': + pytest.main([__file__, '-v', '--tb=short']) diff --git a/tests/entities/gateway/test_gateway_requested_attribute_response.py b/tests/entities/gateway/test_gateway_requested_attribute_response.py new file mode 100644 index 0000000..ca26c15 --- /dev/null +++ b/tests/entities/gateway/test_gateway_requested_attribute_response.py @@ -0,0 +1,129 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.gateway.gateway_requested_attribute_response import GatewayRequestedAttributeResponse + + +def make_entry(key, value): + return AttributeEntry(key=key, value=value) + + +def test_post_init_and_event_type_enforcement(): + resp = GatewayRequestedAttributeResponse(device_name="dev1", request_id=1) + # Even if you pass no event_type, __post_init__ sets it + assert resp.event_type == GatewayEventType.DEVICE_REQUESTED_ATTRIBUTE_RESPONSE + + +def test_repr_contains_expected_fields(): + shared = [make_entry("s1", "v1")] + client = [make_entry("c1", "v2")] + resp = GatewayRequestedAttributeResponse( + device_name="devX", + request_id=42, + shared=shared, + client=client + ) + r = repr(resp) + assert "GatewayRequestedAttributeResponse" in r + assert "devX" in r + assert "42" in r + assert "s1" in r + assert "c1" in r + + +def test_getitem_finds_in_shared_first(): + shared = [make_entry("foo", "bar")] + client = [make_entry("foo", "baz")] # same key, should not be used if found in shared + resp = GatewayRequestedAttributeResponse(shared=shared, client=client) + assert resp["foo"] == "bar" # found in shared first + + +def test_getitem_finds_in_client_if_not_in_shared(): + shared = [make_entry("other", 123)] + client = [make_entry("foo", "baz")] + resp = GatewayRequestedAttributeResponse(shared=shared, client=client) + assert resp["foo"] == "baz" + + +def test_getitem_raises_keyerror_if_not_found(): + resp = GatewayRequestedAttributeResponse( + shared=[make_entry("one", 1)], + client=[make_entry("two", 2)] + ) + with pytest.raises(KeyError) as e: + _ = resp["missing"] + assert "Key 'missing'" in str(e.value) + + +def test_shared_keys_and_client_keys(): + shared = [make_entry("s1", "v1"), make_entry("s2", "v2")] + client = [make_entry("c1", "v3")] + resp = GatewayRequestedAttributeResponse(shared=shared, client=client) + assert resp.shared_keys() == ["s1", "s2"] + assert resp.client_keys() == ["c1"] + + +def test_get_shared_and_client_success(): + shared = [make_entry("sh", "sv")] + client = [make_entry("cl", "cv")] + resp = GatewayRequestedAttributeResponse(shared=shared, client=client) + assert resp.get_shared("sh") == "sv" + assert resp.get_client("cl") == "cv" + + +def test_get_shared_and_client_with_default_on_missing(): + shared = [make_entry("sh", "sv")] + client = [make_entry("cl", "cv")] + resp = GatewayRequestedAttributeResponse(shared=shared, client=client) + assert resp.get_shared("missing", default="def") == "def" + assert resp.get_client("missing", default="def") == "def" + + +def test_get_shared_and_client_when_none(): + resp = GatewayRequestedAttributeResponse(shared=None, client=None) + assert resp.get_shared("any", default="d") == "d" + assert resp.get_client("any", default="d") == "d" + + +def test_as_dict_returns_expected_dict(monkeypatch): + shared_entry = make_entry("s1", "v1") + client_entry = make_entry("c1", "v2") + called = {} + + def fake_as_dict_self(): + 
called.setdefault("calls", []).append(1) + return {"key": "dummy", "value": "dummy"} + + monkeypatch.setattr(shared_entry, "as_dict", fake_as_dict_self) + monkeypatch.setattr(client_entry, "as_dict", fake_as_dict_self) + + resp = GatewayRequestedAttributeResponse( + shared=[shared_entry], + client=[client_entry] + ) + d = resp.as_dict() + assert "shared" in d + assert "client" in d + assert isinstance(d["shared"], list) + assert isinstance(d["client"], list) + # Ensure our fake_as_dict was called twice + assert len(called["calls"]) == 2 + + +if __name__ == '__main__': + pytest.main([__file__, '-v', '--tb=short']) diff --git a/tests/entities/gateway/test_gateway_uplink_message.py b/tests/entities/gateway/test_gateway_uplink_message.py new file mode 100644 index 0000000..fd7cc40 --- /dev/null +++ b/tests/entities/gateway/test_gateway_uplink_message.py @@ -0,0 +1,152 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from collections import OrderedDict +from uuid import UUID + +import pytest + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.gateway.gateway_uplink_message import ( + GatewayUplinkMessage, + GatewayUplinkMessageBuilder +) + + +@pytest.mark.asyncio +async def test_direct_instantiation_forbidden(): + with pytest.raises(TypeError) as exc: + GatewayUplinkMessage(device_name="dev", device_profile="prof") + assert "Direct instantiation" in str(exc.value) + + +@pytest.mark.asyncio +async def test_repr_and_build_and_properties(): + # Prepare entities + attr1 = AttributeEntry("a1", "v1") + attr2 = AttributeEntry("a2", "v2") + ts1 = TimeseriesEntry("t1", 123) + ts2 = TimeseriesEntry("t2", 456) + + f1 = asyncio.get_event_loop().create_future() + f2 = asyncio.get_event_loop().create_future() + + msg = GatewayUplinkMessage.build( + device_name="dev1", + device_profile="prof1", + attributes=[attr1, attr2], + timeseries=OrderedDict({111: [ts1, ts2]}), + delivery_futures=[f1, f2], + size=123, + main_ts=999, + ) + + # __repr__ + rep = repr(msg) + assert "dev1" in rep and "prof1" in rep + + # Properties + assert msg.device_name == "dev1" + assert msg.device_profile == "prof1" + assert msg.size == 123 + assert msg.event_type == GatewayEventType.DEVICE_UPLINK + + # Datapoint counts + assert msg.timeseries_datapoint_count() == 2 + assert msg.attributes_datapoint_count() == 2 + assert msg.has_attributes() + assert msg.has_timeseries() + + # Futures + futs = msg.get_delivery_futures() + assert futs == (f1, f2) + + # set_main_ts + assert msg.set_main_ts(555).main_ts == 555 + + +@pytest.mark.asyncio +async def test_builder_minimal_and_defaults(): + b = GatewayUplinkMessageBuilder() + b.set_device_name("devname") + b.set_device_profile("profile") + attr = AttributeEntry("akey", "aval") + ts = TimeseriesEntry("tkey", 1) + b.add_attributes(attr) + 
b.add_timeseries(ts) + + # Delivery future gets added + fut = asyncio.get_event_loop().create_future() + b.add_delivery_futures(fut) + + b.set_main_ts(1000) + msg = b.build() + + assert isinstance(msg, GatewayUplinkMessage) + assert msg.device_name == "devname" + assert msg.device_profile == "profile" + assert msg.main_ts == 1000 + assert msg.has_attributes() + assert msg.has_timeseries() + + +@pytest.mark.asyncio +async def test_builder_defaults_when_not_set(): + # No device_profile and no futures -> builder should create defaults + b = GatewayUplinkMessageBuilder() + b.set_device_name("d1") + + msg = b.build() + assert msg.device_profile == "default" + assert isinstance(msg.get_delivery_futures()[0], asyncio.Future) + # The auto-created future has uuid + assert isinstance(msg.get_delivery_futures()[0].uuid, UUID) + + +@pytest.mark.asyncio +async def test_add_attributes_and_timeseries_variants(): + b = GatewayUplinkMessageBuilder() + b.set_device_name("dev") + attr_list = [AttributeEntry("a1", "v1"), AttributeEntry("a2", "v2")] + ts_list = [TimeseriesEntry("t1", 1), TimeseriesEntry("t2", 2)] + + # Add as list + b.add_attributes(attr_list) + b.add_timeseries(ts_list) + + # Add OrderedDict directly + od = OrderedDict({123: [TimeseriesEntry("k", 42)]}) + b.add_timeseries(od) + + msg = b.build() + assert msg.has_attributes() + assert msg.has_timeseries() + + +@pytest.mark.asyncio +async def test_len_tracks_size(): + b = GatewayUplinkMessageBuilder() + base_len = len(b) + b.set_device_name("abc") + b.set_device_profile("xyz") + b.add_attributes(AttributeEntry("k", "v")) + b.add_timeseries(TimeseriesEntry("ts", 1)) + assert len(b) > base_len + + +if __name__ == '__main__': + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/service/device/handlers/__init__.py b/tests/service/device/handlers/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tests/service/device/handlers/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/service/device/handlers/test_attribute_updates_handler.py b/tests/service/device/handlers/test_attribute_updates_handler.py new file mode 100644 index 0000000..3533a65 --- /dev/null +++ b/tests/service/device/handlers/test_attribute_updates_handler.py @@ -0,0 +1,126 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
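+# Unit tests for AttributeUpdatesHandler: adapter type validation, callback dispatch on parsed updates, and graceful handling of missing callbacks and parse errors.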
+from typing import List, Union + +import pytest + +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest +from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.service.device.handlers.attribute_updates_handler import AttributeUpdatesHandler +from tb_mqtt_client.service.device.message_adapter import MessageAdapter + + +class DummyAdapter(MessageAdapter): + def build_attribute_request(self, request: AttributeRequest) -> MqttPublishMessage: + pass + + def build_claim_request(self, claim_request) -> MqttPublishMessage: + pass + + def build_rpc_request(self, rpc_request: RPCRequest) -> MqttPublishMessage: + pass + + def build_rpc_response(self, rpc_response: RPCResponse) -> MqttPublishMessage: + pass + + def build_provision_request(self, provision_request) -> MqttPublishMessage: + pass + + def parse_requested_attribute_response(self, topic: str, payload: bytes) -> RequestedAttributeResponse: + pass + + def parse_rpc_request(self, topic: str, payload: bytes) -> RPCRequest: + pass + + def parse_rpc_response(self, topic: str, payload: Union[bytes, Exception]) -> RPCResponse: + pass + + def parse_provisioning_response(self, provisioning_request: ProvisioningRequest, + payload: bytes) -> ProvisioningResponse: + pass + + def build_uplink_messages(self, messages: List[MqttPublishMessage]) -> List[MqttPublishMessage]: + pass + + def __init__(self, result=None, raise_exc=False): + super().__init__() + self.result = result + self.raise_exc = raise_exc + + def parse_attribute_update(self, payload: bytes) -> AttributeUpdate: + if self.raise_exc: + raise ValueError("Simulated parse error") + return self.result + + + +@pytest.mark.asyncio +async def test_set_message_adapter_and_callback_called(): + handler = AttributeUpdatesHandler() + adapter_result = AttributeUpdate([AttributeEntry("key", "value")]) + + handler.set_message_adapter(DummyAdapter(result=adapter_result)) + + called = {} + + async def cb(update: AttributeUpdate): + called["data"] = update + + handler.set_callback(cb) + + await handler.handle("topic", b"payload") + assert called["data"] == adapter_result + + +def test_set_message_adapter_invalid_type(): + handler = AttributeUpdatesHandler() + with pytest.raises(ValueError) as exc: + handler.set_message_adapter("not-a-message-adapter") + assert "message_adapter must be an instance of MessageAdapter" in str(exc.value) + + +@pytest.mark.asyncio +async def test_handle_no_callback_set(): + handler = AttributeUpdatesHandler() + # even if message_adapter is set, no callback means skip + handler.set_message_adapter(DummyAdapter(result=AttributeUpdate({}))) + # should simply return without error + await handler.handle("topic", b"payload") + + +@pytest.mark.asyncio +async def test_handle_parse_error(): + handler = AttributeUpdatesHandler() + handler.set_message_adapter(DummyAdapter(raise_exc=True)) + + called = {"called": False} + + async def cb(update: AttributeUpdate): + called["called"] = True + + 
handler.set_callback(cb) + # parse will raise, callback should not be called + await handler.handle("topic", b"payload") + assert not called["called"] + + +if __name__ == '__main__': + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/service/device/handlers/test_requested_attributes_response_handler.py b/tests/service/device/handlers/test_requested_attributes_response_handler.py new file mode 100644 index 0000000..f3a0986 --- /dev/null +++ b/tests/service/device/handlers/test_requested_attributes_response_handler.py @@ -0,0 +1,194 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Union, List + +import pytest + +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest +from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.service.device.handlers.requested_attributes_response_handler import \ + RequestedAttributeResponseHandler +from tb_mqtt_client.service.device.message_adapter import MessageAdapter + + +class DummyMessageAdapter(MessageAdapter): + def build_uplink_messages(self, messages: List[MqttPublishMessage]) -> List[MqttPublishMessage]: + pass + + def build_attribute_request(self, request: AttributeRequest) -> MqttPublishMessage: + pass + + def build_claim_request(self, claim_request) -> MqttPublishMessage: + pass + + def build_rpc_request(self, rpc_request: RPCRequest) -> MqttPublishMessage: + pass + + def build_rpc_response(self, rpc_response: RPCResponse) -> MqttPublishMessage: + pass + + def build_provision_request(self, provision_request) -> MqttPublishMessage: + pass + + def parse_attribute_update(self, payload: bytes) -> AttributeUpdate: + pass + + def parse_rpc_request(self, topic: str, payload: bytes) -> RPCRequest: + pass + + def parse_rpc_response(self, topic: str, payload: Union[bytes, Exception]) -> RPCResponse: + pass + + def parse_provisioning_response(self, provisioning_request: ProvisioningRequest, + payload: bytes) -> ProvisioningResponse: + pass + + def __init__(self, response=None, exc=None): + super().__init__() + self._response = response + self._exc = exc + + def parse_requested_attribute_response(self, topic, payload): + if self._exc: + raise self._exc + return self._response + + +@pytest.mark.asyncio +async def test_set_message_adapter_and_full_handle_flow(): + handler = RequestedAttributeResponseHandler() + # Build real AttributeRequest + request = await AttributeRequest.build(shared_keys=["temp"], 
client_keys=["c1"]) + + called = {} + + async def cb(resp): + called["val"] = resp["temp"] + + # Register request + await handler.register_request(request, cb) + + # Create real RequestedAttributeResponse with matching ID + resp = RequestedAttributeResponse( + request_id=request.request_id, + shared=[AttributeEntry("temp", 42)], + client=[AttributeEntry("c1", "v1")] + ) + + # Set adapter returning our response + handler.set_message_adapter(DummyMessageAdapter(response=resp)) + + # Handle message + await handler.handle(f"some/topic/{request.request_id}", b"payload") + + assert called["val"] == 42 + assert request.request_id not in handler._pending_attribute_requests + + +def test_set_message_adapter_invalid_type(): + handler = RequestedAttributeResponseHandler() + with pytest.raises(ValueError): + handler.set_message_adapter(object()) + + +@pytest.mark.asyncio +async def test_register_request_duplicate(): + handler = RequestedAttributeResponseHandler() + handler.set_message_adapter(DummyMessageAdapter()) + + req = await AttributeRequest.build() + await handler.register_request(req, lambda r: None) + with pytest.raises(RuntimeError): + await handler.register_request(req, lambda r: None) + + +@pytest.mark.asyncio +async def test_unregister_existing_and_non_existing(): + handler = RequestedAttributeResponseHandler() + handler.set_message_adapter(DummyMessageAdapter()) + + req = await AttributeRequest.build() + await handler.register_request(req, lambda r: None) + + assert req.request_id in handler._pending_attribute_requests + handler.unregister_request(req.request_id) + assert req.request_id not in handler._pending_attribute_requests + + # No error for missing ID + handler.unregister_request(9999) + + +@pytest.mark.asyncio +async def test_handle_without_message_adapter_removes_request(): + handler = RequestedAttributeResponseHandler() + req = await AttributeRequest.build() + await handler.register_request(req, lambda r: None) + await handler.handle(f"topic/{req.request_id}", b"payload") + assert req.request_id not in handler._pending_attribute_requests + + +@pytest.mark.asyncio +async def test_handle_with_no_pending_request(): + handler = RequestedAttributeResponseHandler() + resp = RequestedAttributeResponse( + request_id=999, + shared=[], + client=[] + ) + handler.set_message_adapter(DummyMessageAdapter(response=resp)) + # No request registered → should just return + await handler.handle("topic/999", b"payload") + + +@pytest.mark.asyncio +async def test_handle_with_no_callback(): + handler = RequestedAttributeResponseHandler() + req = await AttributeRequest.build() + await handler.register_request(req, None) + resp = RequestedAttributeResponse( + request_id=req.request_id, + shared=[], + client=[] + ) + handler.set_message_adapter(DummyMessageAdapter(response=resp)) + await handler.handle(f"topic/{req.request_id}", b"payload") + + +@pytest.mark.asyncio +async def test_handle_with_parse_exception(): + handler = RequestedAttributeResponseHandler() + req = await AttributeRequest.build() + await handler.register_request(req, lambda r: None) + handler.set_message_adapter(DummyMessageAdapter(exc=ValueError("bad parse"))) + await handler.handle(f"topic/{req.request_id}", b"payload") + + +def test_clear(): + handler = RequestedAttributeResponseHandler() + handler._pending_attribute_requests[1] = ("req", lambda r: None) + handler.clear() + assert not handler._pending_attribute_requests + + +if __name__ == '__main__': + pytest.main([__file__, "-v", "--tb=short"]) diff --git 
a/tests/service/device/handlers/test_rpc_requests_handler.py b/tests/service/device/handlers/test_rpc_requests_handler.py new file mode 100644 index 0000000..55b8a96 --- /dev/null +++ b/tests/service/device/handlers/test_rpc_requests_handler.py @@ -0,0 +1,225 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union, List + +import pytest + +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest +from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse, RPCStatus +from tb_mqtt_client.service.device.handlers.rpc_requests_handler import RPCRequestsHandler +from tb_mqtt_client.service.device.message_adapter import MessageAdapter + + +class DummyMessageAdapter(MessageAdapter): + def build_uplink_messages(self, messages: List[MqttPublishMessage]) -> List[MqttPublishMessage]: + pass + + def build_attribute_request(self, request: AttributeRequest) -> MqttPublishMessage: + pass + + def build_claim_request(self, claim_request) -> MqttPublishMessage: + pass + + def build_rpc_request(self, rpc_request: RPCRequest) -> MqttPublishMessage: + pass + + def build_rpc_response(self, rpc_response: RPCResponse) -> MqttPublishMessage: + pass + + def build_provision_request(self, provision_request) -> MqttPublishMessage: + pass + + def parse_requested_attribute_response(self, topic: str, payload: bytes) -> RequestedAttributeResponse: + pass + + def parse_attribute_update(self, payload: bytes) -> AttributeUpdate: + pass + + def parse_rpc_response(self, topic: str, payload: Union[bytes, Exception]) -> RPCResponse: + pass + + def parse_provisioning_response(self, provisioning_request: ProvisioningRequest, + payload: bytes) -> ProvisioningResponse: + pass + + def parse_rpc_request(self, topic, payload): + # Simulate payload as dict with required keys + data = {"method": "testMethod", "params": {"foo": "bar"}} + return RPCRequest._deserialize_from_dict(42, data) + + +@pytest.mark.asyncio +async def test_handle_no_callback_returns_none(): + handler = RPCRequestsHandler() + handler.set_message_adapter(DummyMessageAdapter()) + # No callback set + result = await handler.handle("topic", b"payload") + assert result is None + + +@pytest.mark.asyncio +async def test_handle_no_message_adapter_returns_none(): + handler = RPCRequestsHandler() + + async def cb(_): + return RPCResponse.build(1, result={"ok": True}) + + handler.set_callback(cb) + # No message adapter set + result = await handler.handle("topic", b"payload") + assert result is None + + +@pytest.mark.asyncio +async def 
test_handle_callback_returns_valid_response(): + handler = RPCRequestsHandler() + handler.set_message_adapter(DummyMessageAdapter()) + + async def cb(rpc_request: RPCRequest): + assert isinstance(rpc_request, RPCRequest) + return RPCResponse.build(rpc_request.request_id, result={"success": True}) + + handler.set_callback(cb) + result = await handler.handle("topic", b"payload") + assert isinstance(result, RPCResponse) + assert result.status == RPCStatus.SUCCESS + assert result.result == {"success": True} + assert result.error is None + + +@pytest.mark.asyncio +async def test_handle_callback_returns_invalid_type(): + handler = RPCRequestsHandler() + handler.set_message_adapter(DummyMessageAdapter()) + + async def cb(_): + return "not-an-rpc-response" + + handler.set_callback(cb) + result = await handler.handle("topic", b"payload") + assert result is None + + +@pytest.mark.asyncio +async def test_handle_parse_raises_exception(): + class BadAdapter(DummyMessageAdapter): + def parse_rpc_request(self, topic, payload): + raise RuntimeError("bad parse") + + handler = RPCRequestsHandler() + handler.set_message_adapter(BadAdapter()) + + async def cb(_): + return RPCResponse.build(1, result={}) + + handler.set_callback(cb) + result = await handler.handle("topic", b"payload") + assert result is None + + +def test_set_message_adapter_invalid_type(): + handler = RPCRequestsHandler() + with pytest.raises(ValueError): + handler.set_message_adapter("not-a-message-adapter") + + +def test_rpc_response_build_success_and_repr_and_payload(): + resp = RPCResponse.build(1, result={"x": 1}) + assert resp.status == RPCStatus.SUCCESS + assert resp.result == {"x": 1} + assert resp.error is None + assert "RPCResponse" in repr(resp) + assert resp.to_payload_format() == {"result": {"x": 1}} + + +def test_rpc_response_build_with_str_error(): + resp = RPCResponse.build(1, error="something went wrong") + assert resp.status == RPCStatus.ERROR + assert resp.error == "something went wrong" + assert "error" in resp.to_payload_format() + + +def test_rpc_response_build_with_dict_error(): + err_dict = {"code": 123} + resp = RPCResponse.build(1, error=err_dict) + assert resp.status == RPCStatus.ERROR + assert resp.error == err_dict + + +def test_rpc_response_build_with_exception_error(): + exc = ValueError("bad value") + resp = RPCResponse.build(1, error=exc) + assert resp.status == RPCStatus.ERROR + assert isinstance(resp.error, dict) + assert "message" in resp.error + assert "type" in resp.error + assert "details" in resp.error + + +def test_rpc_response_invalid_error_type(): + with pytest.raises(ValueError): + RPCResponse.build(1, error=object()) + + +def test_rpc_response_direct_init_disallowed(): + with pytest.raises(TypeError): + RPCResponse(1, result={}) + + +@pytest.mark.asyncio +async def test_rpc_request_build_and_str_and_payload_format(): + req = await RPCRequest.build("myMethod", params={"y": 2}) + assert req.method == "myMethod" + assert req.params == {"y": 2} + assert "myMethod" in str(req) + payload = req.to_payload_format() + assert payload["method"] == "myMethod" + assert payload["params"] == {"y": 2} + +@pytest.mark.asyncio +async def test_rpc_request_build_invalid_method_type(): + with pytest.raises(ValueError): + await RPCRequest.build(123) + + +def test_rpc_request_deserialize_missing_method(): + with pytest.raises(ValueError): + RPCRequest._deserialize_from_dict(1, {}) + + +def test_rpc_request_deserialize_invalid_id_type(): + with pytest.raises(ValueError): + RPCRequest._deserialize_from_dict(object(), 
{"method": "x"}) + + +def test_rpc_request_direct_init_disallowed(): + with pytest.raises(TypeError): + RPCRequest(1, "method") + + +def test_rpc_request_to_payload_format_without_params(): + req = RPCRequest._deserialize_from_dict(5, {"method": "noParams"}) + payload = req.to_payload_format() + assert "params" not in payload + + +if __name__ == '__main__': + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/service/device/handlers/test_rpc_response_handler.py b/tests/service/device/handlers/test_rpc_response_handler.py new file mode 100644 index 0000000..ced4323 --- /dev/null +++ b/tests/service/device/handlers/test_rpc_response_handler.py @@ -0,0 +1,193 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from typing import List + +import pytest + +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest +from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.service.device.handlers.rpc_response_handler import RPCResponseHandler +from tb_mqtt_client.service.device.message_adapter import MessageAdapter, JsonMessageAdapter + + +class DummyMessageAdapter(MessageAdapter): + def build_uplink_messages(self, messages: List[MqttPublishMessage]) -> List[MqttPublishMessage]: + pass + + def build_attribute_request(self, request: AttributeRequest) -> MqttPublishMessage: + pass + + def build_claim_request(self, claim_request) -> MqttPublishMessage: + pass + + def build_rpc_request(self, rpc_request: RPCRequest) -> MqttPublishMessage: + pass + + def build_rpc_response(self, rpc_response: RPCResponse) -> MqttPublishMessage: + pass + + def build_provision_request(self, provision_request) -> MqttPublishMessage: + pass + + def parse_requested_attribute_response(self, topic: str, payload: bytes) -> RequestedAttributeResponse: + pass + + def parse_attribute_update(self, payload: bytes) -> AttributeUpdate: + pass + + def parse_rpc_request(self, topic: str, payload: bytes) -> RPCRequest: + pass + + def parse_provisioning_response(self, provisioning_request: ProvisioningRequest, + payload: bytes) -> ProvisioningResponse: + pass + + def __init__(self, rpc_response): + super().__init__() + self._resp = rpc_response + + def parse_rpc_response(self, topic, payload): + return self._resp + + +@pytest.mark.asyncio +async def test_set_message_adapter_invalid_type(): + handler = RPCResponseHandler() + with pytest.raises(ValueError): + handler.set_message_adapter("not-an-adapter") + + +@pytest.mark.asyncio +async def 
test_register_and_handle_success_with_callback(): + handler = RPCResponseHandler() + resp = RPCResponse.build(1, result={"ok": True}) + handler.set_message_adapter(DummyMessageAdapter(resp)) + + called = {} + + async def cb(r: RPCResponse): + called["cb"] = r + + fut = handler.register_request(1, callback=cb) + await handler.handle("topic", b"payload") + assert fut.done() + result = fut.result() + assert isinstance(result, RPCResponse) + assert result.result == {"ok": True} + assert called["cb"] == resp + + +@pytest.mark.asyncio +async def test_register_request_duplicate_id_raises(): + handler = RPCResponseHandler() + handler._pending_rpc_requests[5] = (asyncio.get_event_loop().create_future(), None) + with pytest.raises(RuntimeError): + await handler.register_request(5) + + +@pytest.mark.asyncio +async def test_handle_no_message_adapter_uses_json_adapter(): + handler = RPCResponseHandler() + # Build a real RPCResponse and serialize to payload + resp = RPCResponse.build(99, result={"foo": "bar"}) + message = JsonMessageAdapter().build_rpc_response(resp) + # Register request so we can match + fut = handler.register_request(99) + # Should succeed even without explicitly setting message adapter + await handler.handle("v1/devices/me/rpc/response/99", message.payload) + assert fut.done() + assert fut.result().result == {'result': {"foo": "bar"}} + +@pytest.mark.asyncio +async def test_handle_no_message_adapter_uses_json_adapter_with_error_rpc(): + handler = RPCResponseHandler() + # Build a real RPCResponse and serialize to payload + resp = RPCResponse.build(99, error="Some error occurred") + message = JsonMessageAdapter().build_rpc_response(resp) + # Register request so we can match + fut = handler.register_request(99) + # Should succeed even without explicitly setting message adapter + await handler.handle("v1/devices/me/rpc/response/99", message.payload) + assert fut.done() + assert fut.result().result == {'error': "Some error occurred"} + + +@pytest.mark.asyncio +async def test_handle_response_for_unknown_request_id(): + handler = RPCResponseHandler() + resp = RPCResponse.build("abc", result={"zzz": 1}) + # No registered request with "abc" + handler.set_message_adapter(DummyMessageAdapter(resp)) + await handler.handle("topic", b"payload") # Should log warning and return + # Nothing to assert except no crash + + +@pytest.mark.asyncio +async def test_handle_with_no_future(): + handler = RPCResponseHandler() + resp = RPCResponse.build(777, result={"x": 1}) + handler.set_message_adapter(DummyMessageAdapter(resp)) + handler._pending_rpc_requests[777] = (None, None) + await handler.handle("topic", b"payload") # Should log warning and return + + +@pytest.mark.asyncio +async def test_handle_callback_exception_sets_future_exception(): + handler = RPCResponseHandler() + resp = RPCResponse.build(42, result={"ok": True}) + handler.set_message_adapter(DummyMessageAdapter(resp)) + + async def bad_cb(_): + raise RuntimeError("bad") + + fut = handler.register_request(42, callback=bad_cb) + await handler.handle("topic", b"payload") + assert fut.done() + with pytest.raises(RuntimeError): + fut.result() + + +@pytest.mark.asyncio +async def test_handle_with_error_field_sets_future_exception(): + handler = RPCResponseHandler() + resp = RPCResponse.build(321, error="fail") + handler.set_message_adapter(DummyMessageAdapter(resp)) + fut = handler.register_request(321) + await handler.handle("topic", b"payload") + assert fut.done() + with pytest.raises(Exception) as e: + fut.result() + assert "fail" in str(e.value) 
+ + +@pytest.mark.asyncio +async def test_clear_cancels_pending_futures(): + handler = RPCResponseHandler() + f1 = asyncio.get_event_loop().create_future() + handler._pending_rpc_requests[1] = (f1, None) + handler.clear() + assert f1.cancelled() + assert handler._pending_rpc_requests == {} + + +if __name__ == '__main__': + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tb_mqtt_client/tb_device_mqtt.py b/tests/service/gateway/__init__.py similarity index 100% rename from tb_mqtt_client/tb_device_mqtt.py rename to tests/service/gateway/__init__.py diff --git a/tests/service/gateway/handlers/__init__.py b/tests/service/gateway/handlers/__init__.py new file mode 100644 index 0000000..fa669aa --- /dev/null +++ b/tests/service/gateway/handlers/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/service/gateway/handlers/test_gateway_attribute_updates_handler.py b/tests/service/gateway/handlers/test_gateway_attribute_updates_handler.py new file mode 100644 index 0000000..32621b0 --- /dev/null +++ b/tests/service/gateway/handlers/test_gateway_attribute_updates_handler.py @@ -0,0 +1,111 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
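+# Unit tests for GatewayAttributeUpdatesHandler: parsing gateway attribute updates and dispatching them with the resolved device session (or None when the device is unknown).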
+ + +from unittest.mock import MagicMock, AsyncMock + +import pytest + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.gateway.gateway_attribute_update import GatewayAttributeUpdate +from tb_mqtt_client.service.gateway.device_manager import DeviceManager +from tb_mqtt_client.service.gateway.device_session import DeviceSession +from tb_mqtt_client.service.gateway.direct_event_dispatcher import DirectEventDispatcher +from tb_mqtt_client.service.gateway.handlers.gateway_attribute_updates_handler import GatewayAttributeUpdatesHandler +from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter + + +@pytest.mark.asyncio +async def test_handle_with_existing_device(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayAttributeUpdatesHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Mock the device session + device_session = MagicMock(spec=DeviceSession) + device_manager.get_by_name.return_value = device_session + + # Mock the message adapter + payload = b'{"device": "test_device", "data": {"key": "value"}}' + deserialized_data = {"device": "test_device", "data": {"key": "value"}} + message_adapter.deserialize_to_dict.return_value = deserialized_data + + # Create a mock attribute update + attribute_update = AttributeUpdate([AttributeEntry(key="key", value="value")]) + gateway_attribute_update = GatewayAttributeUpdate(device_name="test_device", attribute_update=attribute_update) + message_adapter.parse_attribute_update.return_value = gateway_attribute_update + gateway_attribute_update.set_device_session = AsyncMock() + + # Act + await handler.handle("topic", payload) + + # Assert + message_adapter.deserialize_to_dict.assert_called_once_with(payload) + message_adapter.parse_attribute_update.assert_called_once_with(deserialized_data) + device_manager.get_by_name.assert_called_once_with("test_device") + gateway_attribute_update.set_device_session.assert_called_once_with(device_session) + event_dispatcher.dispatch.assert_awaited_once_with(attribute_update, device_session=device_session) + + +@pytest.mark.asyncio +async def test_handle_with_nonexistent_device(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayAttributeUpdatesHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Mock the device session (nonexistent) + device_manager.get_by_name.return_value = None + + # Mock the message adapter + payload = b'{"device": "nonexistent_device", "data": {"key": "value"}}' + deserialized_data = {"device": "nonexistent_device", "data": {"key": "value"}} + message_adapter.deserialize_to_dict.return_value = deserialized_data + + # Create a mock attribute update + attribute_update = AttributeUpdate([AttributeEntry("key", "value")]) + gateway_attribute_update = GatewayAttributeUpdate(device_name="nonexistent_device", attribute_update=attribute_update) + message_adapter.parse_attribute_update.return_value = gateway_attribute_update + gateway_attribute_update.set_device_session = AsyncMock() + + # 
Act
+    await handler.handle("topic", payload)
+
+    # Assert
+    message_adapter.deserialize_to_dict.assert_called_once_with(payload)
+    message_adapter.parse_attribute_update.assert_called_once_with(deserialized_data)
+    device_manager.get_by_name.assert_called_once_with("nonexistent_device")
+    # set_device_session should not be called because no device session was found
+    assert not hasattr(gateway_attribute_update, 'set_device_session') or not gateway_attribute_update.set_device_session.called
+    event_dispatcher.dispatch.assert_awaited_once_with(attribute_update, device_session=None)
+
+
+if __name__ == '__main__':
+    pytest.main([__file__])
diff --git a/tests/service/gateway/handlers/test_gateway_requested_attributes_response_handler.py b/tests/service/gateway/handlers/test_gateway_requested_attributes_response_handler.py
new file mode 100644
index 0000000..670ac51
--- /dev/null
+++ b/tests/service/gateway/handlers/test_gateway_requested_attributes_response_handler.py
@@ -0,0 +1,405 @@
+# Copyright 2025 ThingsBoard
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+from unittest.mock import MagicMock, AsyncMock, patch
+
+import pytest
+
+from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry
+from tb_mqtt_client.entities.data.attribute_request import AttributeRequest
+from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse
+from tb_mqtt_client.entities.gateway.device_info import DeviceInfo
+from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest
+from tb_mqtt_client.entities.gateway.gateway_requested_attribute_response import GatewayRequestedAttributeResponse
+from tb_mqtt_client.service.gateway.device_manager import DeviceManager
+from tb_mqtt_client.service.gateway.device_session import DeviceSession
+from tb_mqtt_client.service.gateway.direct_event_dispatcher import DirectEventDispatcher
+from tb_mqtt_client.service.gateway.handlers.gateway_requested_attributes_response_handler import \
+    GatewayRequestedAttributeResponseHandler
+from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter
+
+
+@pytest.mark.asyncio
+async def test_register_request():
+    # Setup
+    event_dispatcher = AsyncMock(spec=DirectEventDispatcher)
+    message_adapter = MagicMock(spec=GatewayMessageAdapter)
+    device_manager = MagicMock(spec=DeviceManager)
+
+    # Create a handler
+    handler = GatewayRequestedAttributeResponseHandler(
+        event_dispatcher=event_dispatcher,
+        message_adapter=message_adapter,
+        device_manager=device_manager
+    )
+
+    # Create a device session
+    device_info = DeviceInfo("test_device", "default")
+    device_session = DeviceSession(device_info)
+
+    # Create a request
+    request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["key1", "key2"])
+
+    # Act
+    await handler.register_request(request)
+
+    # Assert
+    assert (request.device_session.device_info.device_name, request.request_id) in handler._pending_attribute_requests
+    assert
handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)][0] == request + + +@pytest.mark.asyncio +async def test_register_request_with_timeout(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayRequestedAttributeResponseHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + + # Create a request + request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["key1", "key2"]) + + # Mock asyncio.get_event_loop().call_later + with patch('asyncio.get_event_loop') as mock_get_loop: + mock_loop = MagicMock() + mock_get_loop.return_value = mock_loop + mock_timeout_task = MagicMock() + mock_loop.call_later.return_value = mock_timeout_task + + # Act + await handler.register_request(request, timeout=10) + + # Assert + assert (request.device_session.device_info.device_name, request.request_id) in handler._pending_attribute_requests + assert handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)][0] == request + assert handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)][1] == mock_timeout_task + mock_loop.call_later.assert_called_once_with(10, handler._on_timeout, request.device_session.device_info.device_name, request.request_id) + + +@pytest.mark.asyncio +async def test_register_request_duplicate(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayRequestedAttributeResponseHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + + # Create a request + request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["key1", "key2"]) + + # Register the request once + await handler.register_request(request) + + # Act & Assert - should raise RuntimeError when registering the same request ID again + with pytest.raises(RuntimeError): + await handler.register_request(request) + + +@pytest.mark.asyncio +async def test_unregister_request(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayRequestedAttributeResponseHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + + # Create a request + request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["key1", "key2"]) + + # Register the request + handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)] = (request, None) + + # Act + handler.unregister_request(request.device_session.device_info.device_name, 
request.request_id) + + # Assert + assert (request.device_session.device_info.device_name, request.request_id) not in handler._pending_attribute_requests + + +def test_unregister_nonexistent_request(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayRequestedAttributeResponseHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Act - should not raise an exception + handler.unregister_request("nonexistent_device", 999) + + # Assert + assert ("nonexistent_device", 999) not in handler._pending_attribute_requests + + +@pytest.mark.asyncio +async def test_handle_valid_response(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayRequestedAttributeResponseHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + device_manager.get_by_name.return_value = device_session + + # Create a request + request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["key1", "key2"]) + + # Register the request + timeout_task = MagicMock() + handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)] = (request, timeout_task) + + # Mock the message adapter + payload = b'{"device": "test_device", "id": ' + str(request.request_id).encode() + b', "values": {"key1": "value1"}}' + deserialized_data = {"device": "test_device", "id": request.request_id, "values": {"key1": "value1"}} + message_adapter.deserialize_to_dict.return_value = deserialized_data + + # Create a mock response + response = GatewayRequestedAttributeResponse(device_name="test_device", request_id=request.request_id, shared=[AttributeEntry("key1", "value1")]) + message_adapter.parse_gateway_requested_attribute_response.return_value = response + + # Mock asyncio.create_task + with patch('asyncio.create_task') as mock_create_task: + mock_task = MagicMock() + mock_create_task.return_value = mock_task + + # Act + await handler.handle("topic", payload) + + # Assert + message_adapter.deserialize_to_dict.assert_called_once_with(payload) + message_adapter.parse_gateway_requested_attribute_response.assert_called_once_with(request, deserialized_data) + device_manager.get_by_name.assert_called_once_with("test_device") + timeout_task.cancel.assert_called_once() + mock_create_task.assert_called_once() + mock_task.add_done_callback.assert_called_once_with(handler._handle_callback_exception) + + +@pytest.mark.asyncio +async def test_handle_missing_fields(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayRequestedAttributeResponseHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Mock the message adapter + payload = b'{"values": {"key1": "value1"}}' # Missing 'device' and 'id' + deserialized_data = {"values": {"key1": "value1"}} + 
message_adapter.deserialize_to_dict.return_value = deserialized_data + + # Act + await handler.handle("topic", payload) + + # Assert + message_adapter.deserialize_to_dict.assert_called_once_with(payload) + + +@pytest.mark.asyncio +async def test_handle_no_pending_request(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayRequestedAttributeResponseHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Mock the message adapter + payload = b'{"device": "test_device", "id": 999, "values": {"key1": "value1"}}' + deserialized_data = {"device": "test_device", "id": 999, "values": {"key1": "value1"}} + message_adapter.deserialize_to_dict.return_value = deserialized_data + + # Act + result = await handler.handle("topic", payload) + + # Assert + message_adapter.deserialize_to_dict.assert_called_once_with(payload) + assert result is None + + +@pytest.mark.asyncio +async def test_on_timeout(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayRequestedAttributeResponseHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + + # Create a request + request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["key1", "key2"]) + + # Register the request + handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)] = (request, None) + + handler._on_timeout(request.device_session.device_info.device_name, request.request_id) + + # Assert + assert (request.device_session.device_info.device_name, request.request_id) not in handler._pending_attribute_requests + + +def test_handle_callback_exception(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayRequestedAttributeResponseHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Create a mock task that raises an exception + task = MagicMock(spec=asyncio.Task) + task.result.side_effect = Exception("Test exception") + + # Act + handler._handle_callback_exception(task) + + # Assert + task.result.assert_called_once() + + +@pytest.mark.asyncio +async def test_clear(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayRequestedAttributeResponseHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + + # Create a request + request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["key1", "key2"]) + + # Register the request + 
handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)] = (request, None)
+
+    # Act
+    handler.clear()
+
+    # Assert
+    assert len(handler._pending_attribute_requests) == 0
+
+@pytest.mark.asyncio
+async def test_handle_no_message_adapter_removes_request():
+    adapter = MagicMock(spec=GatewayMessageAdapter)
+    handler = GatewayRequestedAttributeResponseHandler(event_dispatcher=MagicMock(spec=DirectEventDispatcher),
+                                                       message_adapter=adapter,
+                                                       device_manager=MagicMock(spec=DeviceManager))
+    handler._pending_attribute_requests["test_device", 5] = (MagicMock(spec=AttributeRequest), AsyncMock())
+    topic = "attr/request/5"
+    await handler.handle(topic, b"{}")
+    assert 5 not in handler._pending_attribute_requests
+
+@pytest.mark.asyncio
+async def test_handle_with_no_callback_registered():
+    adapter = MagicMock(spec=GatewayMessageAdapter)
+    handler = GatewayRequestedAttributeResponseHandler(event_dispatcher=MagicMock(spec=DirectEventDispatcher),
+                                                       message_adapter=adapter,
+                                                       device_manager=MagicMock(spec=DeviceManager))
+    resp = MagicMock(spec=RequestedAttributeResponse)
+    resp.request_id = 42
+    adapter.parse_gateway_requested_attribute_response.return_value = resp
+    handler._pending_attribute_requests['test_device', 42] = (MagicMock(spec=AttributeRequest), None)
+    await handler.handle("topic", b"payload")
+
+@pytest.mark.asyncio
+async def test_handle_with_parsing_exception():
+    adapter = MagicMock(spec=GatewayMessageAdapter)
+    handler = GatewayRequestedAttributeResponseHandler(event_dispatcher=MagicMock(spec=DirectEventDispatcher),
+                                                       message_adapter=adapter,
+                                                       device_manager=MagicMock(spec=DeviceManager))
+    adapter.parse_gateway_requested_attribute_response.side_effect = RuntimeError("bad parse")
+    await handler.handle("topic", b"payload")
+
+
+if __name__ == '__main__':
+    pytest.main([__file__])
diff --git a/tests/service/gateway/handlers/test_gateway_rpc_handler.py b/tests/service/gateway/handlers/test_gateway_rpc_handler.py
new file mode 100644
index 0000000..a75ac5a
--- /dev/null
+++ b/tests/service/gateway/handlers/test_gateway_rpc_handler.py
@@ -0,0 +1,398 @@
+# Copyright 2025 ThingsBoard
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ + +import asyncio +from unittest.mock import MagicMock, AsyncMock, patch + +import pytest + +from tb_mqtt_client.entities.gateway.device_info import DeviceInfo +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse +from tb_mqtt_client.service.gateway.device_manager import DeviceManager +from tb_mqtt_client.service.gateway.device_session import DeviceSession +from tb_mqtt_client.service.gateway.direct_event_dispatcher import DirectEventDispatcher +from tb_mqtt_client.service.gateway.handlers.gateway_rpc_handler import GatewayRPCHandler +from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter + + +def test_init(): + # Setup + event_dispatcher = MagicMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + stop_event = asyncio.Event() + + # Act + handler = GatewayRPCHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager, + stop_event=stop_event + ) + + # Assert + assert handler._event_dispatcher == event_dispatcher + assert handler._message_adapter == message_adapter + assert handler._device_manager == device_manager + assert handler._stop_event == stop_event + assert handler._callback is None + event_dispatcher.register.assert_called_once_with(GatewayEventType.DEVICE_RPC_REQUEST, handler.handle) + + +@pytest.mark.asyncio +async def test_handle_successful_rpc(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + stop_event = asyncio.Event() + + # Create a handler + handler = GatewayRPCHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager, + stop_event=stop_event + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + device_manager.get_by_name.return_value = device_session + + # Mock the message adapter + topic = "v1/gateway/rpc" + payload = b'{"device": "test_device", "id": 1, "method": "test_method", "params": {"param1": "value1"}}' + deserialized_data = {"device": "test_device", "data": {"id": 1, "method": "test_method", "params": {"param1": "value1"}}} + message_adapter.deserialize_to_dict.return_value = deserialized_data + + # Create a mock RPC request + rpc_request = GatewayRPCRequest._deserialize_from_dict(deserialized_data) + message_adapter.parse_rpc_request.return_value = rpc_request + + # Create a mock RPC response + rpc_response = GatewayRPCResponse.build(device_name="test_device", request_id=1, result="success") + + # Mock the event dispatcher to return the RPC response + event_dispatcher.dispatch.side_effect = [rpc_response, asyncio.Future()] + + # Mock await_or_stop + with patch('tb_mqtt_client.service.gateway.handlers.gateway_rpc_handler.await_or_stop') as mock_await_or_stop: + # Act + result = await handler.handle(topic, payload) + + # Assert + message_adapter.deserialize_to_dict.assert_called_once_with(payload) + message_adapter.parse_rpc_request.assert_called_once_with(topic, deserialized_data) + device_manager.get_by_name.assert_called_once_with("test_device") + + # Check that dispatch was called twice - once for the request and once for the response + assert 
event_dispatcher.dispatch.call_count == 2 + event_dispatcher.dispatch.assert_any_call(rpc_request, device_session=device_session) + event_dispatcher.dispatch.assert_any_call(rpc_response) + + mock_await_or_stop.assert_called_once() + + +@pytest.mark.asyncio +async def test_handle_no_message_adapter(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + device_manager = MagicMock(spec=DeviceManager) + stop_event = asyncio.Event() + + # Create a handler with no message adapter + handler = GatewayRPCHandler( + event_dispatcher=event_dispatcher, + message_adapter=None, + device_manager=device_manager, + stop_event=stop_event + ) + + # Act + result = await handler.handle("topic", b'{}') + + # Assert + assert result is None + + +@pytest.mark.asyncio +async def test_handle_no_device_session(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + stop_event = asyncio.Event() + + # Create a handler + handler = GatewayRPCHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager, + stop_event=stop_event + ) + + # Mock the device manager to return None (no device session) + device_manager.get_by_name.return_value = None + + # Mock the message adapter + topic = "v1/gateway/rpc" + payload = b'{"device": "nonexistent_device", "id": 1, "method": "test_method", "params": {"param1": "value1"}}' + deserialized_data = {"device": "nonexistent_device", "data": {"id": 1, "method": "test_method", "params": {"param1": "value1"}}} + message_adapter.deserialize_to_dict.return_value = deserialized_data + + # Create a mock RPC request + rpc_request = GatewayRPCRequest._deserialize_from_dict(deserialized_data) + message_adapter.parse_rpc_request.return_value = rpc_request + + # Act + result = await handler.handle(topic, payload) + + # Assert + assert result is None + message_adapter.deserialize_to_dict.assert_called_once_with(payload) + message_adapter.parse_rpc_request.assert_called_once_with(topic, deserialized_data) + device_manager.get_by_name.assert_called_once_with("nonexistent_device") + + +@pytest.mark.asyncio +async def test_handle_no_response_from_callback(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + stop_event = asyncio.Event() + + # Create a handler + handler = GatewayRPCHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager, + stop_event=stop_event + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + device_manager.get_by_name.return_value = device_session + + # Mock the message adapter + topic = "v1/gateway/rpc" + payload = b'{"device": "test_device", "id": 1, "method": "test_method", "params": {"param1": "value1"}}' + deserialized_data = {"device": "test_device", "data": {"id": 1, "method": "test_method", "params": {"param1": "value1"}}} + message_adapter.deserialize_to_dict.return_value = deserialized_data + + # Create a mock RPC request + rpc_request = GatewayRPCRequest._deserialize_from_dict(deserialized_data) + message_adapter.parse_rpc_request.return_value = rpc_request + + # Mock the event dispatcher to return None (no response from callback) + event_dispatcher.dispatch.return_value = None + + # Act + result = await 
handler.handle(topic, payload) + + # Assert + assert result is None + message_adapter.deserialize_to_dict.assert_called_once_with(payload) + message_adapter.parse_rpc_request.assert_called_once_with(topic, deserialized_data) + device_manager.get_by_name.assert_called_once_with("test_device") + event_dispatcher.dispatch.assert_called_once_with(rpc_request, device_session=device_session) + + +@pytest.mark.asyncio +async def test_handle_invalid_response_type(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + stop_event = asyncio.Event() + + # Create a handler + handler = GatewayRPCHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager, + stop_event=stop_event + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + device_manager.get_by_name.return_value = device_session + + # Mock the message adapter + topic = "v1/gateway/rpc" + payload = b'{"device": "test_device", "id": 1, "method": "test_method", "params": {"param1": "value1"}}' + deserialized_data = {"device": "test_device", "data": {"id": 1, "method": "test_method", "params": {"param1": "value1"}}} + message_adapter.deserialize_to_dict.return_value = deserialized_data + + # Create a mock RPC request + rpc_request = GatewayRPCRequest._deserialize_from_dict(deserialized_data) + message_adapter.parse_rpc_request.return_value = rpc_request + + # Mock the event dispatcher to return an invalid response type + event_dispatcher.dispatch.side_effect = ["invalid_response", asyncio.Future()] + + # Act + result = await handler.handle(topic, payload) + + # Assert + message_adapter.deserialize_to_dict.assert_called_once_with(payload) + message_adapter.parse_rpc_request.assert_called_once_with(topic, deserialized_data) + device_manager.get_by_name.assert_called_once_with("test_device") + event_dispatcher.dispatch.assert_called() + + +@pytest.mark.asyncio +async def test_handle_exception_in_processing(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + stop_event = asyncio.Event() + + # Create a handler + handler = GatewayRPCHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager, + stop_event=stop_event + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + device_manager.get_by_name.return_value = device_session + + # Mock the message adapter to raise an exception + topic = "v1/gateway/rpc" + payload = b'{"device": "test_device", "id": 1, "method": "test_method", "params": {"param1": "value1"}}' + message_adapter.deserialize_to_dict.side_effect = Exception("Test exception") + + # Act + result = await handler.handle(topic, payload) + + # Assert + assert result is None + message_adapter.deserialize_to_dict.assert_called_once_with(payload) + + +@pytest.mark.asyncio +async def test_handle_timeout_in_publish(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + stop_event = asyncio.Event() + + # Create a handler + handler = GatewayRPCHandler( + event_dispatcher=event_dispatcher, + 
message_adapter=message_adapter, + device_manager=device_manager, + stop_event=stop_event + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + device_manager.get_by_name.return_value = device_session + + # Mock the message adapter + topic = "v1/gateway/rpc" + payload = b'{"device": "test_device", "id": 1, "method": "test_method", "params": {"param1": "value1"}}' + deserialized_data = {"device": "test_device", "data": {"id": 1, "method": "test_method", "params": {"param1": "value1"}}} + message_adapter.deserialize_to_dict.return_value = deserialized_data + + # Create a mock RPC request + rpc_request = GatewayRPCRequest._deserialize_from_dict(deserialized_data) + message_adapter.parse_rpc_request.return_value = rpc_request + + # Create a mock RPC response + rpc_response = GatewayRPCResponse.build(device_name="test_device", request_id=1, result="success") + + # Mock the event dispatcher + event_dispatcher.dispatch.side_effect = [rpc_response, asyncio.Future()] + + # Mock await_or_stop to raise TimeoutError + with patch('tb_mqtt_client.service.gateway.handlers.gateway_rpc_handler.await_or_stop') as mock_await_or_stop: + mock_await_or_stop.side_effect = TimeoutError("Timeout") + + # Act + result = await handler.handle(topic, payload) + + # Assert + message_adapter.deserialize_to_dict.assert_called_once_with(payload) + message_adapter.parse_rpc_request.assert_called_once_with(topic, deserialized_data) + device_manager.get_by_name.assert_called_once_with("test_device") + event_dispatcher.dispatch.assert_any_call(rpc_request, device_session=device_session) + event_dispatcher.dispatch.assert_any_call(rpc_response) + mock_await_or_stop.assert_called_once() + + +@pytest.mark.asyncio +async def test_handle_no_publish_futures(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + stop_event = asyncio.Event() + + # Create a handler + handler = GatewayRPCHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager, + stop_event=stop_event + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + device_manager.get_by_name.return_value = device_session + + # Mock the message adapter + topic = "v1/gateway/rpc" + payload = b'{"device": "test_device", "id": 1, "method": "test_method", "params": {"param1": "value1"}}' + deserialized_data = {"device": "test_device", "data": {"id": 1, "method": "test_method", "params": {"param1": "value1"}}} + message_adapter.deserialize_to_dict.return_value = deserialized_data + + # Create a mock RPC request + rpc_request = GatewayRPCRequest._deserialize_from_dict(deserialized_data) + message_adapter.parse_rpc_request.return_value = rpc_request + + # Create a mock RPC response + rpc_response = GatewayRPCResponse.build(device_name="test_device", request_id=1, result="success") + + # Mock the event dispatcher to return the RPC response but no publish futures + event_dispatcher.dispatch.side_effect = [rpc_response, None] + + # Act + result = await handler.handle(topic, payload) + + # Assert + assert result is None + message_adapter.deserialize_to_dict.assert_called_once_with(payload) + message_adapter.parse_rpc_request.assert_called_once_with(topic, deserialized_data) + device_manager.get_by_name.assert_called_once_with("test_device") + 
event_dispatcher.dispatch.assert_any_call(rpc_request, device_session=device_session) + event_dispatcher.dispatch.assert_any_call(rpc_response) + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/service/gateway/test_device_manager.py b/tests/service/gateway/test_device_manager.py new file mode 100644 index 0000000..58715e9 --- /dev/null +++ b/tests/service/gateway/test_device_manager.py @@ -0,0 +1,391 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from unittest.mock import MagicMock + +import pytest + +from tb_mqtt_client.entities.gateway.device_info import DeviceInfo +from tb_mqtt_client.service.gateway.device_manager import DeviceManager +from tb_mqtt_client.service.gateway.device_session import DeviceSession + + +def test_register_new_device(): + # Setup + manager = DeviceManager() + + # Act + session = manager.register("test_device", "default") + + # Assert + assert session is not None + assert session.device_info.device_name == "test_device" + assert session.device_info.device_profile == "default" + assert session.device_info.device_id in manager._sessions_by_id + assert "test_device" in manager._ids_by_device_name + assert session.device_info.original_name in manager._ids_by_original_name + + +def test_register_existing_device(): + # Setup + manager = DeviceManager() + first_session = manager.register("test_device", "default") + + # Act + second_session = manager.register("test_device", "custom") + + # Assert + assert first_session is second_session + assert second_session.device_info.device_profile == "default" # Profile should not change + + +def test_unregister_device(): + # Setup + manager = DeviceManager() + session = manager.register("test_device", "default") + device_id = session.device_info.device_id + + # Act + manager.unregister(device_id) + + # Assert + assert device_id not in manager._sessions_by_id + assert "test_device" not in manager._ids_by_device_name + assert session.device_info.original_name not in manager._ids_by_original_name + + +def test_unregister_nonexistent_device(): + # Setup + manager = DeviceManager() + from uuid import uuid4 + nonexistent_id = uuid4() + + # Act & Assert - should not raise an exception + manager.unregister(nonexistent_id) + + +def test_get_by_id(): + # Setup + manager = DeviceManager() + session = manager.register("test_device", "default") + device_id = session.device_info.device_id + + # Act + retrieved_session = manager.get_by_id(device_id) + + # Assert + assert retrieved_session is session + + +def test_get_by_id_nonexistent(): + # Setup + manager = DeviceManager() + from uuid import uuid4 + nonexistent_id = uuid4() + + # Act + retrieved_session = manager.get_by_id(nonexistent_id) + + # Assert + assert retrieved_session is None + + +def test_get_by_name(): + # Setup + manager = DeviceManager() + session = manager.register("test_device", "default") + + # Act + retrieved_session = manager.get_by_name("test_device") + + # Assert + assert retrieved_session is session + + +def 
test_get_by_original_name(): + # Setup + manager = DeviceManager() + session = manager.register("test_device", "default") + original_name = session.device_info.original_name + + # Rename the device + manager.rename_device("test_device", "new_name") + + # Act + retrieved_session = manager.get_by_name(original_name) + + # Assert + assert retrieved_session is session + + +def test_get_by_name_nonexistent(): + # Setup + manager = DeviceManager() + + # Act + retrieved_session = manager.get_by_name("nonexistent_device") + + # Assert + assert retrieved_session is None + + +def test_is_connected(): + # Setup + manager = DeviceManager() + session = manager.register("test_device", "default") + device_id = session.device_info.device_id + + # Act + is_connected = manager.is_connected(device_id) + + # Assert + assert is_connected is True + + +def test_is_connected_nonexistent(): + # Setup + manager = DeviceManager() + from uuid import uuid4 + nonexistent_id = uuid4() + + # Act + is_connected = manager.is_connected(nonexistent_id) + + # Assert + assert is_connected is False + + +def test_all(): + # Setup + manager = DeviceManager() + session1 = manager.register("device1", "default") + session2 = manager.register("device2", "default") + + # Act + all_sessions = list(manager.all()) + + # Assert + assert len(all_sessions) == 2 + assert session1 in all_sessions + assert session2 in all_sessions + + +def test_rename_device(): + # Setup + manager = DeviceManager() + session = manager.register("old_name", "default") + device_id = session.device_info.device_id + + # Act + manager.rename_device("old_name", "new_name") + + # Assert + assert "old_name" not in manager._ids_by_device_name + assert "new_name" in manager._ids_by_device_name + assert manager._ids_by_device_name["new_name"] == device_id + assert session.device_info.device_name == "new_name" + assert session.device_info.original_name in manager._ids_by_original_name + + +def test_rename_nonexistent_device(): + # Setup + manager = DeviceManager() + + # Act - should not raise an exception + manager.rename_device("nonexistent_device", "new_name") + + # Assert + assert "nonexistent_device" not in manager._ids_by_device_name + assert "new_name" not in manager._ids_by_device_name + + +def test_set_attribute_update_callback(): + # Setup + manager = DeviceManager() + session = manager.register("test_device", "default") + device_id = session.device_info.device_id + callback = MagicMock() + + # Mock the session's set_attribute_update_callback method + session.set_attribute_update_callback = MagicMock() + + # Act + manager.set_attribute_update_callback(device_id, callback) + + # Assert + session.set_attribute_update_callback.assert_called_once_with(callback) + + +def test_set_attribute_update_callback_nonexistent(): + # Setup + manager = DeviceManager() + from uuid import uuid4 + nonexistent_id = uuid4() + callback = MagicMock() + + # Act - should not raise an exception + manager.set_attribute_update_callback(nonexistent_id, callback) + + +def test_set_attribute_response_callback(): + # Setup + manager = DeviceManager() + session = manager.register("test_device", "default") + device_id = session.device_info.device_id + callback = MagicMock() + + # Mock the session's set_attribute_response_callback method + session.set_attribute_response_callback = MagicMock() + + # Act + manager.set_attribute_response_callback(device_id, callback) + + # Assert + session.set_attribute_response_callback.assert_called_once_with(callback) + + +def 
test_set_attribute_response_callback_nonexistent(): + # Setup + manager = DeviceManager() + from uuid import uuid4 + nonexistent_id = uuid4() + callback = MagicMock() + + # Act - should not raise an exception + manager.set_attribute_response_callback(nonexistent_id, callback) + + +def test_set_rpc_request_callback(): + # Setup + manager = DeviceManager() + session = manager.register("test_device", "default") + device_id = session.device_info.device_id + callback = MagicMock() + + # Mock the session's set_rpc_request_callback method + session.set_rpc_request_callback = MagicMock() + + # Act + manager.set_rpc_request_callback(device_id, callback) + + # Assert + session.set_rpc_request_callback.assert_called_once_with(callback) + + +def test_set_rpc_request_callback_nonexistent(): + # Setup + manager = DeviceManager() + from uuid import uuid4 + nonexistent_id = uuid4() + callback = MagicMock() + + # Act - should not raise an exception + manager.set_rpc_request_callback(nonexistent_id, callback) + + +def test_state_change_callback(): + # Setup + manager = DeviceManager() + + # Create a device session with a mocked state + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info, manager._DeviceManager__state_change_callback) + + # Mock the session's state + session.state = MagicMock() + session.state.is_connected = MagicMock(return_value=True) + + # Act + manager._DeviceManager__state_change_callback(session) + + # Assert + assert session in manager.connected_devices + + +def test_state_change_callback_disconnect(): + # Setup + manager = DeviceManager() + + # Create a device session with a mocked state + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info, manager._DeviceManager__state_change_callback) + + # Add the session to connected devices + manager._DeviceManager__connected_devices.add(session) + + # Mock the session's state + session.state = MagicMock() + session.state.is_connected = MagicMock(return_value=False) + + # Act + manager._DeviceManager__state_change_callback(session) + + # Assert + assert session not in manager.connected_devices + + +def test_connected_devices_property(): + # Setup + manager = DeviceManager() + session1 = manager.register("device1", "default") + session2 = manager.register("device2", "default") + + # Add sessions to connected devices + manager._DeviceManager__connected_devices.add(session1) + manager._DeviceManager__connected_devices.add(session2) + + # Act + connected = manager.connected_devices + + # Assert + assert len(connected) == 2 + assert session1 in connected + assert session2 in connected + + +def test_all_devices_property(): + # Setup + manager = DeviceManager() + session1 = manager.register("device1", "default") + session2 = manager.register("device2", "default") + + # Act + all_devices = manager.all_devices + + # Assert + assert len(all_devices) == 2 + assert session1.device_info.device_id in all_devices + assert session2.device_info.device_id in all_devices + assert all_devices[session1.device_info.device_id] is session1 + assert all_devices[session2.device_info.device_id] is session2 + + +def test_repr(): + # Setup + manager = DeviceManager() + manager.register("device1", "default") + manager.register("device2", "default") + + # Act + repr_string = repr(manager) + + # Assert + assert "DeviceManager" in repr_string + assert "device1" in repr_string + assert "device2" in repr_string + + +if __name__ == '__main__': + pytest.main([__file__]) \ No newline at end of file diff --git 
a/tests/service/gateway/test_device_session.py b/tests/service/gateway/test_device_session.py new file mode 100644 index 0000000..93e5b6b --- /dev/null +++ b/tests/service/gateway/test_device_session.py @@ -0,0 +1,307 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from unittest.mock import MagicMock + +import pytest + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.gateway.device_info import DeviceInfo +from tb_mqtt_client.entities.gateway.device_session_state import DeviceSessionState +from tb_mqtt_client.entities.gateway.gateway_attribute_update import GatewayAttributeUpdate +from tb_mqtt_client.entities.gateway.gateway_requested_attribute_response import GatewayRequestedAttributeResponse +from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse +from tb_mqtt_client.service.gateway.device_session import DeviceSession + + +def test_init(): + # Setup + device_info = DeviceInfo("test_device", "default") + state_change_callback = MagicMock() + + # Act + session = DeviceSession(device_info, state_change_callback) + + # Assert + assert session.device_info == device_info + assert session._state_change_callback == state_change_callback + assert session.state.is_connected() + assert session.connected_at > 0 + assert session.last_seen_at > 0 + assert session.claimed is False + assert session.provisioned is False + assert session.attribute_update_callback is None + assert session.attribute_response_callback is None + assert session.rpc_request_callback is None + + +def test_update_state(): + # Setup + device_info = DeviceInfo("test_device", "default") + state_change_callback = MagicMock() + session = DeviceSession(device_info, state_change_callback) + + # Act + session.update_state(DeviceSessionState.DISCONNECTED) + + # Assert + assert session.state == DeviceSessionState.DISCONNECTED + state_change_callback.assert_called_once_with(session) + + +def test_update_state_no_callback(): + # Setup + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info) # No callback provided + + # Act - should not raise an exception + session.update_state(DeviceSessionState.DISCONNECTED) + + # Assert + assert session.state == DeviceSessionState.DISCONNECTED + + +def test_update_last_seen(): + # Setup + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info) + old_last_seen = session.last_seen_at + + # Wait a bit to ensure the timestamp changes + import time + time.sleep(0.001) + + # Act + session.update_last_seen() + + # Assert + assert session.last_seen_at > old_last_seen + + +def test_set_attribute_update_callback(): + # Setup + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info) + callback = MagicMock() + + # Act + 
session.set_attribute_update_callback(callback) + + # Assert + assert session.attribute_update_callback == callback + + +def test_set_attribute_response_callback(): + # Setup + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info) + callback = MagicMock() + + # Act + session.set_attribute_response_callback(callback) + + # Assert + assert session.attribute_response_callback == callback + + +def test_set_rpc_request_callback(): + # Setup + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info) + callback = MagicMock() + + # Act + session.set_rpc_request_callback(callback) + + # Assert + assert session.rpc_request_callback == callback + + +@pytest.mark.asyncio +async def test_handle_event_to_device_attribute_update(): + # Setup + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info) + callback = MagicMock() + session.set_attribute_update_callback(callback) + + # Create an attribute update event + attribute_entry = AttributeEntry(key="key", value="value") + attribute_update = AttributeUpdate([attribute_entry]) + + gateway_attribute_update = GatewayAttributeUpdate(session.device_info.device_name, attribute_update=attribute_update) + + # Act + result = await session.handle_event_to_device(gateway_attribute_update.attribute_update) + + # Assert + callback.assert_called_once_with(session, attribute_update) + assert result == callback.return_value + + +@pytest.mark.asyncio +async def test_handle_event_to_device_attribute_response(): + # Setup + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info) + callback = MagicMock() + session.set_attribute_response_callback(callback) + + # Create an attribute response event + gateway_requested_attribute_response = GatewayRequestedAttributeResponse(request_id=1, device_name="test_device", shared=[AttributeEntry(key="shared_key", value="shared_value")], client=[]) + + # Act + result = await session.handle_event_to_device(gateway_requested_attribute_response) + + # Assert + callback.assert_called_once_with(session, gateway_requested_attribute_response) + assert result == callback.return_value + + +@pytest.mark.asyncio +async def test_handle_event_to_device_rpc_request(): + # Setup + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info) + callback = MagicMock() + session.set_rpc_request_callback(callback) + + # Create an RPC request event + request_dict = { + "device": "test_device", + "data": { + "id": 1, + "method": "test", + "params": {"param": "value"} + } + } + rpc_request = GatewayRPCRequest._deserialize_from_dict(request_dict) + + # Act + result = await session.handle_event_to_device(rpc_request) + + # Assert + callback.assert_called_once_with(session, rpc_request) + assert result == callback.return_value + + +@pytest.mark.asyncio +async def test_handle_event_to_device_no_callback(): + # Setup + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info) + + # Create an attribute update event + attribute_update = AttributeUpdate([AttributeEntry(key="shared_key", value="shared_value")]) + gateway_attribute_update = GatewayAttributeUpdate(session.device_info.device_name, attribute_update=attribute_update) + + # Act + result = await session.handle_event_to_device(gateway_attribute_update.attribute_update) + + # Assert + assert result is None + + +@pytest.mark.asyncio +async def test_handle_event_to_device_async_callback(): + # Setup + device_info = 
DeviceInfo("test_device", "default") + session = DeviceSession(device_info) + + rpc_request = { + "device": "test_device", + "data": { + "id": 1, + "method": "test", + "params": {"param": "value"} + } + } + + # Create an async callback + async def async_callback(session, event: GatewayRPCRequest): + return GatewayRPCResponse.build(device_name=event.device_name, request_id=event.request_id, result="success") + + session.set_rpc_request_callback(async_callback) + + # Create an RPC request event + rpc_request = GatewayRPCRequest._deserialize_from_dict(rpc_request) + + # Act + result = await session.handle_event_to_device(rpc_request) + + # Assert + assert isinstance(result, GatewayRPCResponse) + assert result.device_name == "test_device" + assert result.request_id == 1 + assert result.result == "success" + + +@pytest.mark.asyncio +async def test_handle_event_to_device_unsupported_event(): + # Setup + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info) + + # Create a mock event with an unsupported event type + mock_event = MagicMock() + mock_event.event_type = "UNSUPPORTED_EVENT_TYPE" + + # Act + result = await session.handle_event_to_device(mock_event) + + # Assert + assert result is None + + +def test_equality(): + # Setup + device_info1 = DeviceInfo("test_device", "default") + device_info2 = DeviceInfo("test_device", "default") # Same device name, but different UUID + device_info3 = DeviceInfo("other_device", "default") + + session1 = DeviceSession(device_info1) + session2 = DeviceSession(device_info2) + session3 = DeviceSession(device_info3) + + # Act & Assert + assert session1 != session2 # Different UUIDs + assert session1 != session3 + assert session2 != session3 + + # Test equality with the same device_info + session4 = DeviceSession(device_info1) + assert session1 == session4 # Same device_info + + +def test_hash(): + # Setup + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info) + + # Act + hash_value = hash(session) + + # Assert + assert hash_value == hash(device_info) + + # Test that sessions with the same device_info have the same hash + session2 = DeviceSession(device_info) + assert hash(session) == hash(session2) + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/service/gateway/test_direct_event_dispatcher.py b/tests/service/gateway/test_direct_event_dispatcher.py new file mode 100644 index 0000000..3de95d1 --- /dev/null +++ b/tests/service/gateway/test_direct_event_dispatcher.py @@ -0,0 +1,278 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import asyncio +from unittest.mock import MagicMock, AsyncMock + +import pytest + +from tb_mqtt_client.entities.gateway.device_info import DeviceInfo +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.gateway.gateway_event import GatewayEvent +from tb_mqtt_client.service.gateway.device_session import DeviceSession +from tb_mqtt_client.service.gateway.direct_event_dispatcher import DirectEventDispatcher + + +def test_init(): + # Setup & Act + dispatcher = DirectEventDispatcher() + + # Assert + assert isinstance(dispatcher._handlers, dict) + assert len(dispatcher._handlers) == 0 + assert isinstance(dispatcher._lock, asyncio.Lock) + + +def test_register(): + # Setup + dispatcher = DirectEventDispatcher() + callback = MagicMock() + event_type = GatewayEventType.DEVICE_CONNECT + + # Act + dispatcher.register(event_type, callback) + + # Assert + assert event_type in dispatcher._handlers + assert callback in dispatcher._handlers[event_type] + assert len(dispatcher._handlers[event_type]) == 1 + + +def test_register_duplicate(): + # Setup + dispatcher = DirectEventDispatcher() + callback = MagicMock() + event_type = GatewayEventType.DEVICE_CONNECT + + # Register the callback twice + dispatcher.register(event_type, callback) + dispatcher.register(event_type, callback) + + # Assert + assert event_type in dispatcher._handlers + assert callback in dispatcher._handlers[event_type] + assert len(dispatcher._handlers[event_type]) == 1 # Should only be added once + + +def test_register_multiple(): + # Setup + dispatcher = DirectEventDispatcher() + callback1 = MagicMock() + callback2 = MagicMock() + event_type = GatewayEventType.DEVICE_CONNECT + + # Act + dispatcher.register(event_type, callback1) + dispatcher.register(event_type, callback2) + + # Assert + assert event_type in dispatcher._handlers + assert callback1 in dispatcher._handlers[event_type] + assert callback2 in dispatcher._handlers[event_type] + assert len(dispatcher._handlers[event_type]) == 2 + + +def test_unregister(): + # Setup + dispatcher = DirectEventDispatcher() + callback = MagicMock() + event_type = GatewayEventType.DEVICE_CONNECT + + # Register and then unregister + dispatcher.register(event_type, callback) + dispatcher.unregister(event_type, callback) + + # Assert + assert event_type not in dispatcher._handlers # Event type should be removed when last callback is unregistered + + +def test_unregister_nonexistent(): + # Setup + dispatcher = DirectEventDispatcher() + callback = MagicMock() + event_type = GatewayEventType.DEVICE_CONNECT + + # Act - should not raise an exception + dispatcher.unregister(event_type, callback) + + # Assert + assert event_type not in dispatcher._handlers + + +def test_unregister_one_of_many(): + # Setup + dispatcher = DirectEventDispatcher() + callback1 = MagicMock() + callback2 = MagicMock() + event_type = GatewayEventType.DEVICE_CONNECT + + # Register both callbacks and unregister one + dispatcher.register(event_type, callback1) + dispatcher.register(event_type, callback2) + dispatcher.unregister(event_type, callback1) + + # Assert + assert event_type in dispatcher._handlers + assert callback1 not in dispatcher._handlers[event_type] + assert callback2 in dispatcher._handlers[event_type] + assert len(dispatcher._handlers[event_type]) == 1 + + +@pytest.mark.asyncio +async def test_dispatch_to_device_session(): + # Setup + dispatcher = DirectEventDispatcher() + device_info = DeviceInfo("test_device", "default") + device_session = 
MagicMock(spec=DeviceSession) + device_session.handle_event_to_device = AsyncMock() + + # Create a mock event + event = MagicMock(spec=GatewayEvent) + event.event_type = GatewayEventType.DEVICE_ATTRIBUTE_UPDATE + + # Act + await dispatcher.dispatch(event, device_session=device_session) + + # Assert + device_session.handle_event_to_device.assert_awaited_once_with(event) + + +@pytest.mark.asyncio +async def test_dispatch_to_sync_callback(): + # Setup + dispatcher = DirectEventDispatcher() + callback = MagicMock() + event_type = GatewayEventType.DEVICE_CONNECT + + # Register the callback + dispatcher.register(event_type, callback) + + # Create a mock event + event = MagicMock(spec=GatewayEvent) + event.event_type = event_type + + # Act + result = await dispatcher.dispatch(event) + + # Assert + callback.assert_called_once_with(event) + assert result == callback.return_value + + +@pytest.mark.asyncio +async def test_dispatch_to_async_callback(): + # Setup + dispatcher = DirectEventDispatcher() + callback = AsyncMock() + event_type = GatewayEventType.DEVICE_CONNECT + + # Register the callback + dispatcher.register(event_type, callback) + + # Create a mock event + event = MagicMock(spec=GatewayEvent) + event.event_type = event_type + + # Act + result = await dispatcher.dispatch(event) + + # Assert + callback.assert_awaited_once_with(event) + assert result == callback.return_value + + +@pytest.mark.asyncio +async def test_dispatch_with_args(): + # Setup + dispatcher = DirectEventDispatcher() + callback = MagicMock() + event_type = GatewayEventType.DEVICE_CONNECT + + # Register the callback + dispatcher.register(event_type, callback) + + # Create a mock event + event = MagicMock(spec=GatewayEvent) + event.event_type = event_type + + # Act + await dispatcher.dispatch(event, "arg1", "arg2", kwarg1="value1", kwarg2="value2") + + # Assert + callback.assert_called_once_with(event, "arg1", "arg2", kwarg1="value1", kwarg2="value2") + + +@pytest.mark.asyncio +async def test_dispatch_no_handlers(): + # Setup + dispatcher = DirectEventDispatcher() + + # Create a mock event + event = MagicMock(spec=GatewayEvent) + event.event_type = GatewayEventType.DEVICE_CONNECT + + # Act + result = await dispatcher.dispatch(event) + + # Assert + assert result is None + + +@pytest.mark.asyncio +async def test_dispatch_callback_exception(): + # Setup + dispatcher = DirectEventDispatcher() + callback = MagicMock(side_effect=Exception("Test exception")) + event_type = GatewayEventType.DEVICE_CONNECT + + # Register the callback + dispatcher.register(event_type, callback) + + # Create a mock event + event = MagicMock(spec=GatewayEvent) + event.event_type = event_type + + # Act + result = await dispatcher.dispatch(event) + + # Assert + callback.assert_called_once_with(event) + assert result is None + + +@pytest.mark.asyncio +async def test_dispatch_async_callback_exception(): + # Setup + dispatcher = DirectEventDispatcher() + callback = AsyncMock(side_effect=Exception("Test exception")) + event_type = GatewayEventType.DEVICE_CONNECT + + # Register the callback + dispatcher.register(event_type, callback) + + # Create a mock event + event = MagicMock(spec=GatewayEvent) + event.event_type = event_type + + # Act + result = await dispatcher.dispatch(event) + + # Assert + callback.assert_awaited_once_with(event) + assert result is None + + +if __name__ == '__main__': + pytest.main([__file__]) \ No newline at end of file diff --git a/tests/service/gateway/test_gateway_client.py b/tests/service/gateway/test_gateway_client.py new file 
mode 100644 index 0000000..5eb017d --- /dev/null +++ b/tests/service/gateway/test_gateway_client.py @@ -0,0 +1,505 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.constants.mqtt_topics import GATEWAY_CONNECT_TOPIC, GATEWAY_DISCONNECT_TOPIC, \ + GATEWAY_TELEMETRY_TOPIC, GATEWAY_ATTRIBUTES_TOPIC, GATEWAY_ATTRIBUTES_REQUEST_TOPIC, GATEWAY_CLAIM_TOPIC +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.claim_request import ClaimRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage +from tb_mqtt_client.entities.gateway.device_info import DeviceInfo +from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest +from tb_mqtt_client.entities.gateway.gateway_claim_request import GatewayClaimRequestBuilder +from tb_mqtt_client.service.gateway.client import GatewayClient +from tb_mqtt_client.service.gateway.device_session import DeviceSession + + +# @pytest.mark.asyncio +# async def test_connect(): +# # Setup +# client = GatewayClient() +# client._mqtt_manager = AsyncMock() +# client._mqtt_manager.is_connected = MagicMock(return_value=True) +# client._mqtt_manager.await_ready = AsyncMock() +# client._mqtt_manager.subscribe = AsyncMock(return_value=asyncio.Future()) +# client._mqtt_manager.register_handler = MagicMock() +# +# # Act +# await client.connect() +# +# # Assert +# client._mqtt_manager.connect.assert_awaited_once() +# assert client._mqtt_manager.subscribe.call_count == 3 # Should subscribe to 3 topics +# assert client._mqtt_manager.register_handler.call_count == 3 # Should register 3 handlers + + +@pytest.mark.asyncio +async def test_connect_device(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + client._event_dispatcher.dispatch = AsyncMock() + + # Create a future that will be returned by dispatch + future = asyncio.Future() + future.set_result(PublishResult(topic=GATEWAY_CONNECT_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) + client._event_dispatcher.dispatch.return_value = [future] + + # Act + device_session, result = await client.connect_device("test_device", wait_for_publish=True) + + # Assert + assert device_session is not None + assert device_session.device_info.device_name == "test_device" + assert isinstance(result, PublishResult) + assert result.topic == GATEWAY_CONNECT_TOPIC + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_connect_device_with_connect_message(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + client._event_dispatcher.dispatch = AsyncMock() + + # Create a future 
that will be returned by dispatch + future = asyncio.Future() + future.set_result(PublishResult(topic=GATEWAY_CONNECT_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) + client._event_dispatcher.dispatch.return_value = [future] + + # Create a device connect message + connect_message = DeviceConnectMessage.build("test_device", "custom_profile") + + # Act + device_session, result = await client.connect_device(connect_message, wait_for_publish=True) + + # Assert + assert device_session is not None + assert device_session.device_info.device_name == "test_device" + assert device_session.device_info.device_profile == "custom_profile" + assert isinstance(result, PublishResult) + assert result.topic == GATEWAY_CONNECT_TOPIC + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_connect_device_no_wait(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + + # Create a future that will be returned by dispatch + future = asyncio.Future() + client._event_dispatcher.dispatch.return_value = [future] + + # Act + device_session, futures = await client.connect_device("test_device", wait_for_publish=False) + + # Assert + assert device_session is not None + assert device_session.device_info.device_name == "test_device" + assert isinstance(futures[0], asyncio.Future) + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_disconnect_device(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + + # Create a device session + info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info=info) + + # Create a future that will be returned by dispatch + future = asyncio.Future() + future.set_result(PublishResult(topic=GATEWAY_DISCONNECT_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) + client._event_dispatcher.dispatch.return_value = [future] + + # Act + session, result = await client.disconnect_device(device_session, wait_for_publish=True) + + # Assert + assert session is not None + assert isinstance(result, PublishResult) + assert result.topic == GATEWAY_DISCONNECT_TOPIC + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_disconnect_device_no_wait(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + + # Create a device session + info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info=info) + + # Create a future that will be returned by dispatch + future = asyncio.Future() + client._event_dispatcher.dispatch.return_value = [future] + + # Act + session, futures = await client.disconnect_device(device_session, wait_for_publish=False) + + # Assert + assert session is not None + assert isinstance(futures[0], asyncio.Future) + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_device_timeseries(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + + # Create a device session + info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info=info) + + # Create a future that will be returned by dispatch + future = asyncio.Future() + future.set_result(PublishResult(topic=GATEWAY_TELEMETRY_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) + client._event_dispatcher.dispatch.return_value = [future] + + # Act + result = await client.send_device_timeseries(device_session, {"temperature": 25.5}, 
wait_for_publish=True) + + # Assert + assert isinstance(result, PublishResult) + assert result.topic == GATEWAY_TELEMETRY_TOPIC + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_device_timeseries_with_timeseries_entry(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + + # Create a device session + info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info=info) + + # Create a future that will be returned by dispatch + future = asyncio.Future() + future.set_result(PublishResult(topic=GATEWAY_TELEMETRY_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) + client._event_dispatcher.dispatch.return_value = [future] + + # Create a timeseries entry + timeseries = TimeseriesEntry("temperature", 25.5) + + # Act + result = await client.send_device_timeseries(device_session, timeseries, wait_for_publish=True) + + # Assert + assert isinstance(result, PublishResult) + assert result.topic == GATEWAY_TELEMETRY_TOPIC + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_device_timeseries_no_wait(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + + # Create a device session + + info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info=info) + + # Create a future that will be returned by dispatch + future = asyncio.Future() + client._event_dispatcher.dispatch.return_value = [future] + + # Act + future = await client.send_device_timeseries(device_session, {"temperature": 25.5}, wait_for_publish=False) + + # Assert + assert isinstance(future, asyncio.Future) + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_device_attributes(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + + # Create a device session + + info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info=info) + + # Create a future that will be returned by dispatch + future = asyncio.Future() + future.set_result(PublishResult(topic=GATEWAY_ATTRIBUTES_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) + client._event_dispatcher.dispatch.return_value = [future] + + # Act + result = await client.send_device_attributes(device_session, {"firmware_version": "1.0.0"}, wait_for_publish=True) + + # Assert + assert isinstance(result, PublishResult) + assert result.topic == GATEWAY_ATTRIBUTES_TOPIC + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_device_attributes_with_attribute_entry(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + + # Create a device session + + info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info=info) + + # Create a future that will be returned by dispatch + future = asyncio.Future() + future.set_result(PublishResult(topic=GATEWAY_ATTRIBUTES_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) + client._event_dispatcher.dispatch.return_value = [future] + + # Create an attribute entry + attributes = AttributeEntry("firmware_version", "1.0.0") + + # Act + result = await client.send_device_attributes(device_session, attributes, wait_for_publish=True) + + # Assert + assert isinstance(result, PublishResult) + assert result.topic == GATEWAY_ATTRIBUTES_TOPIC + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def 
test_send_device_attributes_no_wait(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + + # Create a device session + + info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info=info) + + # Create a future that will be returned by dispatch + future = asyncio.Future() + client._event_dispatcher.dispatch.return_value = [future] + + # Act + future = await client.send_device_attributes(device_session, {"firmware_version": "1.0.0"}, wait_for_publish=False) + + # Assert + assert isinstance(future, asyncio.Future) + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_device_attributes_request(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + client._gateway_requested_attribute_response_handler = AsyncMock() + + # Create a device session + + info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info=info) + + # Create a future that will be returned by dispatch + future = asyncio.Future() + future.set_result(PublishResult(topic=GATEWAY_ATTRIBUTES_REQUEST_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) + client._event_dispatcher.dispatch.return_value = [future] + + # Create a gateway attribute request + request = await GatewayAttributeRequest.build(device_session, shared_keys=["firmware_version"], client_keys=None) + + # Act + result = await client.send_device_attributes_request(device_session, request, wait_for_publish=True) + + # Assert + assert isinstance(result, PublishResult) + assert result.topic == GATEWAY_ATTRIBUTES_REQUEST_TOPIC + client._gateway_requested_attribute_response_handler.register_request.assert_awaited_once() + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_device_attributes_request_no_wait(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + client._gateway_requested_attribute_response_handler = AsyncMock() + + # Create a device session + + info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info=info) + + # Create a future that will be returned by dispatch + future = asyncio.Future() + client._event_dispatcher.dispatch.return_value = [future] + + # Create a gateway attribute request + request = await GatewayAttributeRequest.build(device_session, shared_keys=["firmware_version"], client_keys=None) + + # Act + future = await client.send_device_attributes_request(device_session, request, wait_for_publish=False) + + # Assert + assert isinstance(future, asyncio.Future) + client._gateway_requested_attribute_response_handler.register_request.assert_awaited_once() + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_device_claim_request(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + + # Create a device session + + info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info=info) + + # Create a future that dispatch will return + future = asyncio.Future() + future.set_result(PublishResult(topic=GATEWAY_CLAIM_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) + client._event_dispatcher.dispatch.return_value = [future] + + # Create a gateway claim request + device_claim_request = ClaimRequest.build(secret_key="secret", duration=1000) + gateway_claim_request = GatewayClaimRequestBuilder().add_device_request(device_session, device_claim_request).build() + + # Act 
+ result = await client.send_device_claim_request(device_session, gateway_claim_request, wait_for_publish=True) + + # Assert + assert isinstance(result, PublishResult) + assert result.topic == GATEWAY_CLAIM_TOPIC + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_device_claim_request_no_wait(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + + # Create a device session + + info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info=info) + + # Create a future that dispatch will return + future = asyncio.Future() + client._event_dispatcher.dispatch.return_value = [future] + + # Create a gateway claim request + device_claim_request = ClaimRequest.build(secret_key="secret", duration=1000) + gateway_claim_request = GatewayClaimRequestBuilder().add_device_request(device_session, device_claim_request).build() + + # Act + future = await client.send_device_claim_request(device_session, gateway_claim_request, wait_for_publish=False) + + # Assert + assert isinstance(future, asyncio.Future) + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_disconnect(): + # Setup + client = GatewayClient() + client._mqtt_manager = AsyncMock() + client._mqtt_manager.unsubscribe = AsyncMock(return_value=asyncio.Future()) + + # Act + await client.disconnect() + + # Assert + client._mqtt_manager.unsubscribe.assert_awaited() + client._mqtt_manager.disconnect.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_handle_rate_limit_response(): + # Setup + client = GatewayClient() + client._mqtt_manager = MagicMock() + client._gateway_message_adapter = MagicMock() + client._gateway_message_adapter.splitter = MagicMock() + client._gateway_rate_limiter = MagicMock() + client._gateway_rate_limiter.message_rate_limit = AsyncMock() + client._gateway_rate_limiter.telemetry_message_rate_limit = AsyncMock() + client._gateway_rate_limiter.telemetry_datapoints_rate_limit = AsyncMock() + + # Create a response with gateway rate limits + response = RPCResponse.build(1, result={ + 'gatewayRateLimits': { + 'messages': '10:1,', + 'telemetryMessages': '100:60,', + 'telemetryDataPoints': '500:60,' + }, + 'maxPayloadSize': 512 + }) + + # Mock the parent class method + with patch('tb_mqtt_client.service.device.client.DeviceClient._handle_rate_limit_response', return_value=True): + # Act + result = await client._handle_rate_limit_response(response) + + # Assert + assert result is True + client._gateway_rate_limiter.message_rate_limit.set_limit.assert_awaited_once() + client._gateway_rate_limiter.telemetry_message_rate_limit.set_limit.assert_awaited_once() + client._gateway_rate_limiter.telemetry_datapoints_rate_limit.set_limit.assert_awaited_once() + client._mqtt_manager.set_gateway_rate_limits_received.assert_called_once() + + +@pytest.mark.asyncio +async def test_handle_rate_limit_response_invalid_response(): + # Setup + client = GatewayClient() + + # Create an invalid response + response = RPCResponse.build(1, result="invalid") + + # Mock the parent class method + with patch('tb_mqtt_client.service.device.client.DeviceClient._handle_rate_limit_response', return_value=None): + # Act + result = await client._handle_rate_limit_response(response) + + # Assert + assert result is None + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/service/gateway/test_message_adapter.py b/tests/service/gateway/test_message_adapter.py new file mode 100644 index 
0000000..b714a3c --- /dev/null +++ b/tests/service/gateway/test_message_adapter.py @@ -0,0 +1,536 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from datetime import datetime, UTC + +import orjson +import pytest + +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.constants.mqtt_topics import GATEWAY_TELEMETRY_TOPIC, GATEWAY_ATTRIBUTES_TOPIC, \ + GATEWAY_CONNECT_TOPIC, GATEWAY_DISCONNECT_TOPIC, GATEWAY_ATTRIBUTES_REQUEST_TOPIC, GATEWAY_RPC_TOPIC, \ + GATEWAY_CLAIM_TOPIC +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.claim_request import ClaimRequest +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage +from tb_mqtt_client.entities.gateway.device_disconnect_message import DeviceDisconnectMessage +from tb_mqtt_client.entities.gateway.device_info import DeviceInfo +from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest +from tb_mqtt_client.entities.gateway.gateway_attribute_update import GatewayAttributeUpdate +from tb_mqtt_client.entities.gateway.gateway_claim_request import GatewayClaimRequestBuilder +from tb_mqtt_client.entities.gateway.gateway_requested_attribute_response import GatewayRequestedAttributeResponse +from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessageBuilder, \ + DEFAULT_FIELDS_SIZE +from tb_mqtt_client.service.gateway.device_session import DeviceSession +from tb_mqtt_client.service.gateway.message_adapter import JsonGatewayMessageAdapter + + +def test_init(): + # Setup & Act + adapter = JsonGatewayMessageAdapter(max_payload_size=1000, max_datapoints=100) + + # Assert + assert adapter.splitter.max_payload_size == 1000 - DEFAULT_FIELDS_SIZE + assert adapter.splitter.max_datapoints == 100 + + +def test_build_device_connect_message_payload(): + # Setup + adapter = JsonGatewayMessageAdapter() + device_connect_message = DeviceConnectMessage.build("test_device", "default") + + # Act + result = adapter.build_device_connect_message_payload(device_connect_message, qos=1) + + # Assert + assert isinstance(result, MqttPublishMessage) + assert result.topic == GATEWAY_CONNECT_TOPIC + assert result.qos == 1 + + # Verify payload content + payload_dict = orjson.loads(result.payload) + assert "device" in payload_dict + assert payload_dict["device"] == "test_device" + assert "type" in payload_dict + assert payload_dict["type"] == "default" + + +def test_build_device_disconnect_message_payload(): + # Setup + adapter = JsonGatewayMessageAdapter() + device_disconnect_message = DeviceDisconnectMessage.build("test_device") + + # Act + result = 
adapter.build_device_disconnect_message_payload(device_disconnect_message, qos=1) + + # Assert + assert isinstance(result, MqttPublishMessage) + assert result.topic == GATEWAY_DISCONNECT_TOPIC + assert result.qos == 1 + + # Verify payload content + payload_dict = orjson.loads(result.payload) + assert "device" in payload_dict + assert payload_dict["device"] == "test_device" + + +@pytest.mark.asyncio +async def test_build_gateway_attribute_request_payload(): + # Setup + adapter = JsonGatewayMessageAdapter() + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + attribute_request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["key1", "key2"]) + + # Act + result = adapter.build_gateway_attribute_request_payload(attribute_request, qos=1) + + # Assert + assert isinstance(result, MqttPublishMessage) + assert result.topic == GATEWAY_ATTRIBUTES_REQUEST_TOPIC + assert result.qos == 1 + + # Verify payload content + payload_dict = orjson.loads(result.payload) + assert "device" in payload_dict + assert payload_dict["device"] == "test_device" + assert "keys" in payload_dict + assert payload_dict["keys"] == ["key1", "key2"] + assert "id" in payload_dict + assert payload_dict["id"] == attribute_request.request_id + + +def test_build_rpc_response_payload(): + # Setup + adapter = JsonGatewayMessageAdapter() + rpc_response = GatewayRPCResponse.build(device_name="test_device", request_id=1, result="success") + + # Act + result = adapter.build_rpc_response_payload(rpc_response, qos=1) + + # Assert + assert isinstance(result, MqttPublishMessage) + assert result.topic == GATEWAY_RPC_TOPIC + assert result.qos == 1 + + # Verify payload content + payload_dict = orjson.loads(result.payload) + assert "device" in payload_dict + assert payload_dict["device"] == "test_device" + assert "id" in payload_dict + assert payload_dict["id"] == 1 + assert "data" in payload_dict + assert payload_dict["data"] == {"result": "success"} + + +def test_build_claim_request_payload(): + # Setup + adapter = JsonGatewayMessageAdapter() + device_claim_request = ClaimRequest.build(secret_key="secret", duration=1) + gateway_claim_request = GatewayClaimRequestBuilder().add_device_request(device_name_or_session="test_device", + device_claim_request=device_claim_request).build() + + # Act + result = adapter.build_claim_request_payload(gateway_claim_request, qos=1) + + # Assert + assert isinstance(result, MqttPublishMessage) + assert result.topic == GATEWAY_CLAIM_TOPIC + assert result.qos == 1 + + # Verify payload content + payload_dict = orjson.loads(result.payload) + assert "test_device" in payload_dict + assert payload_dict["test_device"]["secretKey"] == "secret" + assert payload_dict["test_device"]["durationMs"] == 1000 + + +def test_parse_attribute_update(): + # Setup + adapter = JsonGatewayMessageAdapter() + data = { + "device": "test_device", + "data": { + "key1": "value1", + "key2": 42 + } + } + + # Act + result = adapter.parse_attribute_update(data) + + # Assert + assert isinstance(result, GatewayAttributeUpdate) + assert result.device_name == "test_device" + assert isinstance(result.attribute_update, AttributeUpdate) + assert result.attribute_update.as_dict() == {"key1": "value1", "key2": 42} + + +def test_parse_attribute_update_invalid_format(): + # Setup + adapter = JsonGatewayMessageAdapter() + data = { + "invalid_format": True + } + + # Act & Assert + with pytest.raises(ValueError): + adapter.parse_attribute_update(data) + + +@pytest.mark.asyncio +async 
def test_parse_gateway_requested_attribute_response(): + # Setup + adapter = JsonGatewayMessageAdapter() + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + attribute_request = await GatewayAttributeRequest.build(device_session=device_session, client_keys=["key1", "key2"]) + + data = { + "device": "test_device", + "id": attribute_request.request_id, + "values": { + "key1": "value1", + "key2": 42 + } + } + + # Act + result = adapter.parse_gateway_requested_attribute_response(attribute_request, data) + + # Assert + assert isinstance(result, GatewayRequestedAttributeResponse) + assert result.device_name == "test_device" + assert result.request_id == attribute_request.request_id + assert len(result.client) == 2 + assert any(attr.key == "key1" and attr.value == "value1" for attr in result.client) + assert any(attr.key == "key2" and attr.value == 42 for attr in result.client) + assert len(result.shared) == 0 + + +@pytest.mark.asyncio +async def test_parse_gateway_requested_attribute_response_single_value(): + # Setup + adapter = JsonGatewayMessageAdapter() + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + attribute_request = await GatewayAttributeRequest.build(device_session=device_session, client_keys=["key1"]) + + data = { + "device": "test_device", + "id": attribute_request.request_id, + "value": "value1" + } + + # Act + result = adapter.parse_gateway_requested_attribute_response(attribute_request, data) + + # Assert + assert isinstance(result, GatewayRequestedAttributeResponse) + assert result.device_name == "test_device" + assert result.request_id == attribute_request.request_id + assert len(result.client) == 1 + assert result.client[0].key == "key1" + assert result.client[0].value == "value1" + assert len(result.shared) == 0 + + +@pytest.mark.asyncio +async def test_parse_gateway_requested_attribute_response_shared_keys(): + # Setup + adapter = JsonGatewayMessageAdapter() + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + attribute_request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["key1", "key2"]) + + data = { + "device": "test_device", + "id": attribute_request.request_id, + "values": { + "key1": "value1", + "key2": 42 + } + } + + # Act + result = adapter.parse_gateway_requested_attribute_response(attribute_request, data) + + # Assert + assert isinstance(result, GatewayRequestedAttributeResponse) + assert result.device_name == "test_device" + assert result.request_id == attribute_request.request_id + assert len(result.shared) == 2 + assert any(attr.key == "key1" and attr.value == "value1" for attr in result.shared) + assert any(attr.key == "key2" and attr.value == 42 for attr in result.shared) + assert len(result.client) == 0 + + +def test_parse_rpc_request(): + # Setup + adapter = JsonGatewayMessageAdapter() + data = { + "device": "test_device", + "data": { + "id": 1, + "method": "test_method", + "params": { + "param1": "value1", + "param2": 42 + } + } + } + + # Act + result = adapter.parse_rpc_request("v1/gateway/rpc", data) + + # Assert + assert isinstance(result, GatewayRPCRequest) + assert result.device_name == "test_device" + assert result.request_id == 1 + assert result.method == "test_method" + assert result.params == {"param1": "value1", "param2": 42} + + +def test_parse_rpc_request_invalid_format(): + # Setup + adapter = JsonGatewayMessageAdapter() + data = { + "invalid_format": True + } + + 
# Act & Assert + with pytest.raises(ValueError): + adapter.parse_rpc_request("v1/gateway/rpc", data) + + +def test_deserialize_to_dict(): + # Setup + adapter = JsonGatewayMessageAdapter() + payload = b'{"key": "value", "number": 42}' + + # Act + result = adapter.deserialize_to_dict(payload) + + # Assert + assert isinstance(result, dict) + assert result == {"key": "value", "number": 42} + + +def test_deserialize_to_dict_invalid_format(): + # Setup + adapter = JsonGatewayMessageAdapter() + payload = b'invalid json' + + # Act & Assert + with pytest.raises(ValueError): + adapter.deserialize_to_dict(payload) + + +def test_pack_attributes(): + # Setup + attributes = [ + AttributeEntry("key1", "value1"), + AttributeEntry("key2", 42) + ] + uplink_message = GatewayUplinkMessageBuilder().set_device_name("test_device").add_attributes(attributes).build() + + # Act + result = JsonGatewayMessageAdapter.pack_attributes(uplink_message) + + # Assert + assert isinstance(result, dict) + assert result == {"key1": "value1", "key2": 42} + + +def test_pack_timeseries_no_timestamp(): + # Setup + timeseries = [TimeseriesEntry("temp", 22.5), TimeseriesEntry("humidity", 45)] + uplink_message = GatewayUplinkMessageBuilder().set_device_name("test_device").add_timeseries(timeseries).build() + + # Act + result = JsonGatewayMessageAdapter.pack_timeseries(uplink_message) + + # Assert + assert isinstance(result, list) + assert len(result) == 1 + assert result[0] == {"temp": 22.5, "humidity": 45} + + +def test_pack_timeseries_with_timestamp(): + # Setup + ts = int(datetime.now(UTC).timestamp() * 1000) + timeseries = [TimeseriesEntry("temp", 22.5, ts), TimeseriesEntry("humidity", 45, ts)] + uplink_message = GatewayUplinkMessageBuilder().set_device_name("test_device").add_timeseries(timeseries).build() + + # Act + result = JsonGatewayMessageAdapter.pack_timeseries(uplink_message) + + # Assert + assert isinstance(result, list) + assert len(result) == 1 + assert "ts" in result[0] + assert result[0]["ts"] == ts + assert "values" in result[0] + assert result[0]["values"] == {"temp": 22.5, "humidity": 45} + + +def test_pack_timeseries_with_different_timestamps(): + # Setup + ts1 = int(datetime.now(UTC).timestamp() * 1000) + ts2 = ts1 + 1000 # 1 second later + timeseries = [TimeseriesEntry("temp", 22.5, ts1), TimeseriesEntry("humidity", 45, ts2)] + uplink_message = GatewayUplinkMessageBuilder().set_device_name("test_device").add_timeseries(timeseries).build() + + # Act + result = JsonGatewayMessageAdapter.pack_timeseries(uplink_message) + + # Assert + assert isinstance(result, list) + assert len(result) == 2 + + # Find entries by timestamp + ts1_entry = next((entry for entry in result if entry["ts"] == ts1), None) + ts2_entry = next((entry for entry in result if entry["ts"] == ts2), None) + + assert ts1_entry is not None + assert ts2_entry is not None + assert ts1_entry["values"] == {"temp": 22.5} + assert ts2_entry["values"] == {"humidity": 45} + + +def test_build_uplink_messages_empty(): + # Setup + adapter = JsonGatewayMessageAdapter() + + # Act + result = adapter.build_uplink_messages([]) + + # Assert + assert result == [] + + +def test_build_uplink_messages_non_gateway_message(): + # Setup + adapter = JsonGatewayMessageAdapter() + mqtt_msg = MqttPublishMessage(topic="test/topic", payload=b"test_payload", qos=1) + + # Act + result = adapter.build_uplink_messages([mqtt_msg]) + + # Assert + assert len(result) == 1 + assert result[0] == mqtt_msg + + +def test_build_uplink_messages_with_telemetry(): + # Setup + adapter = 
JsonGatewayMessageAdapter() + + # Create a gateway uplink message with telemetry + uplink_message = (GatewayUplinkMessageBuilder() + .set_device_name("test_device") + .add_timeseries([TimeseriesEntry("temp", 22.5), TimeseriesEntry("humidity", 45)]) + .build()) + + mqtt_msg = MqttPublishMessage(topic="test/topic", payload=uplink_message, qos=1) + + # Act + result = adapter.build_uplink_messages([mqtt_msg]) + + # Assert + assert len(result) == 1 + assert result[0].topic == GATEWAY_TELEMETRY_TOPIC + assert result[0].qos == 1 + + # Verify payload content + payload_dict = orjson.loads(result[0].payload) + assert "test_device" in payload_dict + assert isinstance(payload_dict["test_device"], list) + assert len(payload_dict["test_device"]) == 1 + assert "values" in payload_dict["test_device"][0] + assert payload_dict["test_device"][0]["values"] == {"temp": 22.5, "humidity": 45} + + +def test_build_uplink_messages_with_attributes(): + # Setup + adapter = JsonGatewayMessageAdapter() + + # Create a gateway uplink message with attributes + uplink_message = GatewayUplinkMessageBuilder().set_device_name("test_device").add_attributes([ + AttributeEntry("key1", "value1"), + AttributeEntry("key2", 42) + ]).build() + + mqtt_msg = MqttPublishMessage(topic="test/topic", payload=uplink_message, qos=1) + + # Act + result = adapter.build_uplink_messages([mqtt_msg]) + + # Assert + assert len(result) == 1 + assert result[0].topic == GATEWAY_ATTRIBUTES_TOPIC + assert result[0].qos == 1 + + # Verify payload content + payload_dict = orjson.loads(result[0].payload) + assert "test_device" in payload_dict + assert payload_dict["test_device"] == {"key1": "value1", "key2": 42} + + +def test_build_uplink_messages_with_both_telemetry_and_attributes(): + # Setup + adapter = JsonGatewayMessageAdapter() + + # Create a gateway uplink message with both telemetry and attributes + uplink_message_builder = GatewayUplinkMessageBuilder().set_device_name("test_device") + uplink_message_builder.add_timeseries([TimeseriesEntry("temp", 22.5), TimeseriesEntry("humidity", 45)]) + uplink_message_builder.add_attributes([AttributeEntry("key1", "value1"), AttributeEntry("key2", 42)]) + uplink_message = uplink_message_builder.build() + + mqtt_msg = MqttPublishMessage(topic="test/topic", payload=uplink_message, qos=1) + + # Act + result = adapter.build_uplink_messages([mqtt_msg]) + + # Assert + assert len(result) == 2 + + # Find telemetry and attribute messages + telemetry_msg = next((msg for msg in result if msg.topic == GATEWAY_TELEMETRY_TOPIC), None) + attribute_msg = next((msg for msg in result if msg.topic == GATEWAY_ATTRIBUTES_TOPIC), None) + + assert telemetry_msg is not None + assert attribute_msg is not None + + # Verify telemetry payload + telemetry_dict = orjson.loads(telemetry_msg.payload) + assert "test_device" in telemetry_dict + assert isinstance(telemetry_dict["test_device"], list) + assert len(telemetry_dict["test_device"]) == 1 + assert "values" in telemetry_dict["test_device"][0] + assert telemetry_dict["test_device"][0]["values"] == {"temp": 22.5, "humidity": 45} + + # Verify attribute payload + attribute_dict = orjson.loads(attribute_msg.payload) + assert "test_device" in attribute_dict + assert attribute_dict["test_device"] == {"key1": "value1", "key2": 42} + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/service/gateway/test_message_sender.py b/tests/service/gateway/test_message_sender.py new file mode 100644 index 0000000..f5ec49a --- /dev/null +++ b/tests/service/gateway/test_message_sender.py 
@@ -0,0 +1,323 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +from unittest.mock import MagicMock, AsyncMock + +import pytest + +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.constants.mqtt_topics import GATEWAY_TELEMETRY_TOPIC, GATEWAY_ATTRIBUTES_TOPIC, \ + GATEWAY_CONNECT_TOPIC, GATEWAY_DISCONNECT_TOPIC, GATEWAY_ATTRIBUTES_REQUEST_TOPIC, GATEWAY_RPC_TOPIC, \ + GATEWAY_CLAIM_TOPIC +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage +from tb_mqtt_client.entities.gateway.device_disconnect_message import DeviceDisconnectMessage +from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest +from tb_mqtt_client.entities.gateway.gateway_claim_request import GatewayClaimRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessageBuilder +from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter +from tb_mqtt_client.service.gateway.message_sender import GatewayMessageSender +from tb_mqtt_client.service.message_service import MessageService + + +def test_init(): + # Setup & Act + sender = GatewayMessageSender() + + # Assert + assert sender._message_queue is None + assert sender._message_adapter is None + + +def test_set_message_queue(): + # Setup + sender = GatewayMessageSender() + message_queue = MagicMock(spec=MessageService) + + # Act + sender.set_message_queue(message_queue) + + # Assert + assert sender._message_queue == message_queue + + +def test_set_message_adapter(): + # Setup + sender = GatewayMessageSender() + message_adapter = MagicMock(spec=GatewayMessageAdapter) + + # Act + sender.set_message_adapter(message_adapter) + + # Assert + assert sender._message_adapter == message_adapter + + +@pytest.mark.asyncio +async def test_send_uplink_message_with_timeseries(): + # Setup + sender = GatewayMessageSender() + message_queue = AsyncMock(spec=MessageService) + sender.set_message_queue(message_queue) + + # Create an uplink message with timeseries + uplink_message = (GatewayUplinkMessageBuilder() + .set_device_name("test_device") + .add_timeseries([TimeseriesEntry("temp", 22.5)]) + .build()) + + # Mock the publish method to set delivery_futures + async def mock_publish(mqtt_message): + mqtt_message.delivery_futures = [asyncio.Future()] + return [mqtt_message] + + message_queue.publish.side_effect = mock_publish + + # Act + result = await sender.send_uplink_message(uplink_message) + + # Assert + assert result is not None + assert len(result) == 1 + assert isinstance(result[0], asyncio.Future) + message_queue.publish.assert_called_once() + + # Verify the MQTT message + mqtt_message = message_queue.publish.call_args[0][0] + assert mqtt_message.topic 
== GATEWAY_TELEMETRY_TOPIC + assert mqtt_message.payload == uplink_message + assert mqtt_message.qos == 1 + + +@pytest.mark.asyncio +async def test_send_uplink_message_with_attributes(): + # Setup + sender = GatewayMessageSender() + message_queue = AsyncMock(spec=MessageService) + sender.set_message_queue(message_queue) + + # Create an uplink message with attributes + uplink_message = (GatewayUplinkMessageBuilder() + .set_device_name("test_device") + .add_attributes([AttributeEntry("temperature", 22.5)]) + .build()) + + # Mock the publish method to set delivery_futures + async def mock_publish(mqtt_message): + mqtt_message.delivery_futures = [asyncio.Future()] + return [mqtt_message] + + message_queue.publish.side_effect = mock_publish + + # Act + result = await sender.send_uplink_message(uplink_message) + + # Assert + assert result is not None + assert len(result) == 1 + assert isinstance(result[0], asyncio.Future) + message_queue.publish.assert_called_once() + + # Verify the MQTT message + mqtt_message = message_queue.publish.call_args[0][0] + assert mqtt_message.topic == GATEWAY_ATTRIBUTES_TOPIC + assert mqtt_message.payload == uplink_message + assert mqtt_message.qos == 1 + + +@pytest.mark.asyncio +async def test_send_uplink_message_with_both(): + # Setup + sender = GatewayMessageSender() + message_queue = AsyncMock(spec=MessageService) + sender.set_message_queue(message_queue) + + # Create an uplink message with both timeseries and attributes + uplink_message = (GatewayUplinkMessageBuilder() + .set_device_name("test_device") + .add_timeseries([TimeseriesEntry("temp", 22.5)]) + .add_attributes(AttributeEntry("key1", "value1")) + .build()) + + # Mock the publish method to set delivery_futures + async def mock_publish(mqtt_message): + mqtt_message.delivery_futures = [asyncio.Future()] + return [mqtt_message] + + message_queue.publish.side_effect = mock_publish + + # Act + result = await sender.send_uplink_message(uplink_message) + + # Assert + assert result is not None + assert len(result) == 2 # One for telemetry, one for attributes + assert isinstance(result[0], asyncio.Future) + assert isinstance(result[1], asyncio.Future) + assert message_queue.publish.call_count == 2 + + # Verify the MQTT messages + call_args_list = message_queue.publish.call_args_list + assert call_args_list[0][0][0].topic == GATEWAY_TELEMETRY_TOPIC + assert call_args_list[0][0][0].payload == uplink_message + assert call_args_list[1][0][0].topic == GATEWAY_ATTRIBUTES_TOPIC + assert call_args_list[1][0][0].payload == uplink_message + + +@pytest.mark.asyncio +async def test_send_device_connect(): + # Setup + sender = GatewayMessageSender() + message_queue = AsyncMock(spec=MessageService) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + sender.set_message_queue(message_queue) + sender.set_message_adapter(message_adapter) + + # Create a device connect message + device_connect_message = DeviceConnectMessage.build("test_device", "default") + + # Mock the message adapter + mqtt_message = MqttPublishMessage(GATEWAY_CONNECT_TOPIC, b'{"device":"test_device","type":"default"}', qos=1) + mqtt_message.delivery_futures = [asyncio.Future()] + message_adapter.build_device_connect_message_payload.return_value = mqtt_message + + # Act + result = await sender.send_device_connect(device_connect_message) + + # Assert + assert result is not None + assert len(result) == 1 + assert isinstance(result[0], asyncio.Future) + 
message_adapter.build_device_connect_message_payload.assert_called_once_with(device_connect_message=device_connect_message, qos=1) + message_queue.publish.assert_called_once_with(mqtt_message) + + +@pytest.mark.asyncio +async def test_send_device_disconnect(): + # Setup + sender = GatewayMessageSender() + message_queue = AsyncMock(spec=MessageService) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + sender.set_message_queue(message_queue) + sender.set_message_adapter(message_adapter) + + # Create a device disconnect message + device_disconnect_message = DeviceDisconnectMessage.build("test_device") + + # Mock the message adapter + mqtt_message = MqttPublishMessage(GATEWAY_DISCONNECT_TOPIC, b'{"device":"test_device"}', qos=1) + mqtt_message.delivery_futures = [asyncio.Future()] + message_adapter.build_device_disconnect_message_payload.return_value = mqtt_message + + # Act + result = await sender.send_device_disconnect(device_disconnect_message) + + # Assert + assert result is not None + assert len(result) == 1 + assert isinstance(result[0], asyncio.Future) + message_adapter.build_device_disconnect_message_payload.assert_called_once_with(device_disconnect_message=device_disconnect_message, qos=1) + message_queue.publish.assert_called_once_with(mqtt_message) + + +@pytest.mark.asyncio +async def test_send_attributes_request(): + # Setup + sender = GatewayMessageSender() + message_queue = AsyncMock(spec=MessageService) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + sender.set_message_queue(message_queue) + sender.set_message_adapter(message_adapter) + + # Create an attribute request + attribute_request = MagicMock(spec=GatewayAttributeRequest) + + # Mock the message adapter + mqtt_message = MqttPublishMessage(GATEWAY_ATTRIBUTES_REQUEST_TOPIC, b'{"device":"test_device","keys":["key1"]}', qos=1) + mqtt_message.delivery_futures = [asyncio.Future()] + message_adapter.build_gateway_attribute_request_payload.return_value = mqtt_message + + # Act + result = await sender.send_attributes_request(attribute_request) + + # Assert + assert result is not None + assert len(result) == 1 + assert isinstance(result[0], asyncio.Future) + message_adapter.build_gateway_attribute_request_payload.assert_called_once_with(attribute_request=attribute_request, qos=1) + message_queue.publish.assert_called_once_with(mqtt_message) + + +@pytest.mark.asyncio +async def test_send_rpc_response(): + # Setup + sender = GatewayMessageSender() + message_queue = AsyncMock(spec=MessageService) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + sender.set_message_queue(message_queue) + sender.set_message_adapter(message_adapter) + + # Create an RPC response + rpc_response = MagicMock(spec=GatewayRPCResponse) + + # Mock the message adapter + mqtt_message = MqttPublishMessage(GATEWAY_RPC_TOPIC, b'{"device":"test_device","id":1,"data":{"result":"success"}}', qos=1) + mqtt_message.delivery_futures = [asyncio.Future()] + message_adapter.build_rpc_response_payload.return_value = mqtt_message + + # Act + result = await sender.send_rpc_response(rpc_response) + + # Assert + assert result is not None + assert len(result) == 1 + assert isinstance(result[0], asyncio.Future) + message_adapter.build_rpc_response_payload.assert_called_once_with(rpc_response=rpc_response, qos=1) + message_queue.publish.assert_called_once_with(mqtt_message) + + +@pytest.mark.asyncio +async def test_send_claim_request(): + # Setup + sender = GatewayMessageSender() + message_queue = AsyncMock(spec=MessageService) + message_adapter = 
MagicMock(spec=GatewayMessageAdapter) + sender.set_message_queue(message_queue) + sender.set_message_adapter(message_adapter) + + # Create a claim request + claim_request = MagicMock(spec=GatewayClaimRequest) + + # Mock the message adapter + mqtt_message = MqttPublishMessage(GATEWAY_CLAIM_TOPIC, b'{"device":"test_device","secretKey":"secret"}', qos=1) + mqtt_message.delivery_futures = [asyncio.Future()] + message_adapter.build_claim_request_payload.return_value = mqtt_message + + # Act + result = await sender.send_claim_request(claim_request) + + # Assert + assert result is not None + assert len(result) == 1 + assert isinstance(result[0], asyncio.Future) + message_adapter.build_claim_request_payload.assert_called_once_with(claim_request=claim_request, qos=1) + message_queue.publish.assert_called_once_with(mqtt_message) + + +if __name__ == '__main__': + pytest.main([__file__]) \ No newline at end of file diff --git a/tests/service/gateway/test_message_splitter.py b/tests/service/gateway/test_message_splitter.py new file mode 100644 index 0000000..6d0cad2 --- /dev/null +++ b/tests/service/gateway/test_message_splitter.py @@ -0,0 +1,412 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +from unittest.mock import MagicMock, patch + +import pytest + +from tb_mqtt_client.common.logging_utils import configure_logging +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessageBuilder, \ + DEFAULT_FIELDS_SIZE +from tb_mqtt_client.service.gateway.message_splitter import GatewayMessageSplitter + +configure_logging() + + +def test_init_default(): + # Setup & Act + splitter = GatewayMessageSplitter() + + # Assert + assert splitter.max_payload_size == 55000 - DEFAULT_FIELDS_SIZE + assert splitter.max_datapoints == 0 + + +def test_init_custom(): + # Setup & Act + splitter = GatewayMessageSplitter(max_payload_size=10000, max_datapoints=100) + + # Assert + assert splitter.max_payload_size == 10000 - DEFAULT_FIELDS_SIZE + assert splitter.max_datapoints == 100 + + +def test_init_invalid_values(): + # Setup & Act + splitter = GatewayMessageSplitter(max_payload_size=-1, max_datapoints=-1) + + # Assert + assert splitter.max_payload_size == 55000 - DEFAULT_FIELDS_SIZE # Default value + assert splitter.max_datapoints == 0 # Default value + + +def test_max_payload_size_property(): + # Setup + splitter = GatewayMessageSplitter() + + # Act + splitter.max_payload_size = 20000 + + # Assert + assert splitter.max_payload_size == 20000 - DEFAULT_FIELDS_SIZE + + +def test_max_payload_size_property_invalid(): + # Setup + splitter = GatewayMessageSplitter() + + # Act + splitter.max_payload_size = -1 + + # Assert + assert splitter.max_payload_size == 55000 - DEFAULT_FIELDS_SIZE # Default value + + +def test_max_datapoints_property(): + # Setup + splitter = GatewayMessageSplitter() + + # Act + 
splitter.max_datapoints = 200 + + # Assert + assert splitter.max_datapoints == 200 + + +def test_max_datapoints_property_invalid(): + # Setup + splitter = GatewayMessageSplitter() + + # Act + splitter.max_datapoints = -1 + + # Assert + assert splitter.max_datapoints == 0 # Default value + + +@pytest.mark.asyncio +async def test_split_timeseries_no_split_needed(): + # Setup + splitter = GatewayMessageSplitter(max_payload_size=10000, max_datapoints=100) + + # Create a message with timeseries that doesn't need splitting + timeseries = TimeseriesEntry("temp", 22.5) + message = GatewayUplinkMessageBuilder().set_device_name("test_device").add_timeseries(timeseries).build() + + + # Act + result = splitter.split_timeseries([message]) + + # Assert + assert len(result) == 1 + assert result[0] == message + + +@pytest.mark.asyncio +async def test_split_timeseries_by_size(): + # Setup + splitter = GatewayMessageSplitter(max_payload_size=100 + DEFAULT_FIELDS_SIZE, max_datapoints=100) + + # Create a message with timeseries that needs splitting by size + message_builder = GatewayUplinkMessageBuilder().set_device_name("test_device") + + # Mock the size property to force splitting + timeseries_entry = TimeseriesEntry("temp", 22.5) + entries = [TimeseriesEntry("temp", 22.5) for _ in range(int(60/timeseries_entry.size + 1))] + entries.extend([TimeseriesEntry("humidity", 45) for _ in range(int(60/timeseries_entry.size + 1))]) + message_builder.add_timeseries(entries) + message = message_builder.build() + + # Mock the event loop + loop_mock = MagicMock() + future = asyncio.Future() + loop_mock.create_future.return_value = future + + with patch('asyncio.get_running_loop', return_value=loop_mock): + # Act + result = splitter.split_timeseries([message]) + + # Assert + assert len(result) == 2 + assert result[0].device_name == "test_device" + assert result[1].device_name == "test_device" + + # Check that each result has only one of the timeseries entries + assert result[0].size - DEFAULT_FIELDS_SIZE < splitter.max_payload_size + assert result[1].size - DEFAULT_FIELDS_SIZE < splitter.max_payload_size + + +@pytest.mark.asyncio +async def test_split_timeseries_by_datapoints(): + # Setup + splitter = GatewayMessageSplitter(max_payload_size=10000, max_datapoints=1) + + # Create a message with timeseries that needs splitting by datapoints + message_builder = GatewayUplinkMessageBuilder().set_device_name("test_device") + entries = [ + TimeseriesEntry("temp", 22.5), + TimeseriesEntry("humidity", 45) + ] + message_builder.add_timeseries(entries) + message = message_builder.build() + + # Mock the event loop + loop_mock = MagicMock() + future = asyncio.Future() + loop_mock.create_future.return_value = future + + with patch('asyncio.get_running_loop', return_value=loop_mock): + # Act + result = splitter.split_timeseries([message]) + + # Assert + assert len(result) == 2 + assert result[0].device_name == "test_device" + assert result[1].device_name == "test_device" + + # Check that each result has only one of the timeseries entries + assert len(result[0].timeseries[0]) == 1 + assert len(result[1].timeseries[0]) == 1 + + # The entries should be in separate messages + key0 = result[0].timeseries[0][0].key + key1 = result[1].timeseries[0][0].key + assert key0 != key1 + assert key0 == 'temp' + assert key1 == 'humidity' + + +@pytest.mark.asyncio +async def test_split_timeseries_multiple_messages(): + # Setup + splitter = GatewayMessageSplitter(max_payload_size=10000, max_datapoints=100) + + # Create multiple messages with timeseries + 
message1 = GatewayUplinkMessageBuilder().set_device_name("device1").add_timeseries( + TimeseriesEntry("temp", 22.5) + ).build() + + message2 = GatewayUplinkMessageBuilder().set_device_name("device2").add_timeseries( + TimeseriesEntry("humidity", 45) + ).build() + + # Act + result = splitter.split_timeseries([message1, message2]) + + # Assert + assert len(result) == 2 + + # Check that the messages are for different devices + devices = {msg.device_name for msg in result} + assert devices == {"device1", "device2"} + + # Check that each result has the correct timeseries + for msg in result: + if msg.device_name == "device1": + assert "temp" == msg.timeseries[0][0].key + elif msg.device_name == "device2": + assert "humidity" == msg.timeseries[0][0].key + + +@pytest.mark.asyncio +async def test_split_timeseries_with_delivery_futures(): + # Setup + splitter = GatewayMessageSplitter(max_payload_size=100 + DEFAULT_FIELDS_SIZE, max_datapoints=100) + + # Create a message with enough timeseries entries to exceed the max size + message_builder = GatewayUplinkMessageBuilder().set_device_name("test_device") + sample_entry = TimeseriesEntry("temp", 22.5) + # Create enough "temp" and "humidity" entries to force splitting + entries = [TimeseriesEntry("temp", 22.5) for _ in range(int(60 / sample_entry.size + 1))] + entries.extend([TimeseriesEntry("humidity", 45) for _ in range(int(60 / sample_entry.size + 1))]) + message_builder.add_timeseries(entries) + + # Add a delivery future to the original message + parent_future = asyncio.Future() + message_builder.add_delivery_futures([parent_future]) + + message = message_builder.build() + + # Mock the event loop and future_map + loop_mock = MagicMock() + future = asyncio.Future() + loop_mock.create_future.return_value = future + + with patch('asyncio.get_running_loop', return_value=loop_mock), \ + patch('tb_mqtt_client.common.async_utils.future_map.register') as mock_register: + # Act + result = splitter.split_timeseries([message]) + + # Assert + assert len(result) == 2 + # Ensure futures are linked correctly + assert mock_register.call_count == 2 + mock_register.assert_any_call(parent_future, [future]) + + +@pytest.mark.asyncio +async def test_split_attributes_no_split_needed(): + # Setup + splitter = GatewayMessageSplitter(max_payload_size=10000, max_datapoints=100) + + # Create a message with attributes that doesn't need splitting + message = (GatewayUplinkMessageBuilder() + .set_device_name("test_device") + .add_attributes(AttributeEntry("key1", "value1")) + .build()) + + # Act + result = splitter.split_attributes([message]) + + # Assert + assert len(result) == 1 + assert result[0] == message + + +@pytest.mark.asyncio +async def test_split_attributes_by_size(): + # Setup + splitter = GatewayMessageSplitter(max_payload_size=100 + DEFAULT_FIELDS_SIZE, max_datapoints=100) + + # Enough repeated attributes to exceed 100 bytes (simulate large payload) + attrs = [AttributeEntry(f"key{i}", f"value{i}") for i in range(10)] + + message = GatewayUplinkMessageBuilder().set_device_name("test_device").add_attributes(attrs).build() + + # Mock the event loop + loop_mock = MagicMock() + future = asyncio.Future() + loop_mock.create_future.return_value = future + + with patch('asyncio.get_running_loop', return_value=loop_mock): + # Act + result = splitter.split_attributes([message]) + + # Assert + assert len(result) > 1 + # Ensure each split contains fewer attributes than the original + for part in result: + assert len(part.attributes) < len(attrs) + + +@pytest.mark.asyncio 
+async def test_split_attributes_by_datapoints(): + # Setup + splitter = GatewayMessageSplitter(max_payload_size=10000, max_datapoints=1) + + # Create a message with attributes that needs splitting by datapoints + message = (GatewayUplinkMessageBuilder() + .set_device_name("test_device") + .add_attributes([AttributeEntry("key1", "value1"), + AttributeEntry("key2", "value2")]) + .build()) + + # Mock the event loop + loop_mock = MagicMock() + future = asyncio.Future() + loop_mock.create_future.return_value = future + + with patch('asyncio.get_running_loop', return_value=loop_mock): + # Act + result = splitter.split_attributes([message]) + + # Assert + assert len(result) == 2 + assert result[0].device_name == "test_device" + assert result[1].device_name == "test_device" + + # Check that each result has only one of the attribute entries + assert len(result[0].attributes) == 1 + assert len(result[1].attributes) == 1 + + # The entries should be in separate messages + keys0 = {attr.key for attr in result[0].attributes} + keys1 = {attr.key for attr in result[1].attributes} + assert len(keys0.intersection(keys1)) == 0 + assert keys0.union(keys1) == {"key1", "key2"} + + +@pytest.mark.asyncio +async def test_split_attributes_multiple_messages(): + # Setup + splitter = GatewayMessageSplitter(max_payload_size=10000, max_datapoints=100) + + # Create multiple messages with attributes + message1 = (GatewayUplinkMessageBuilder() + .set_device_name("device1") + .add_attributes(AttributeEntry("key1", "value1")) + .build()) + + message2 = (GatewayUplinkMessageBuilder() + .set_device_name("device2") + .add_attributes(AttributeEntry("key2", "value2")) + .build()) + + # Act + result = splitter.split_attributes([message1, message2]) + + # Assert + assert len(result) == 2 + + # Check that the messages are for different devices + devices = {msg.device_name for msg in result} + assert devices == {"device1", "device2"} + + # Check that each result has the correct attributes + for msg in result: + if msg.device_name == "device1": + assert msg.attributes[0].key == "key1" + elif msg.device_name == "device2": + assert msg.attributes[0].key == "key2" + + +@pytest.mark.asyncio +async def test_split_attributes_with_delivery_futures(): + # Setup + splitter = GatewayMessageSplitter(max_payload_size=100 + DEFAULT_FIELDS_SIZE, max_datapoints=100) + + # Create enough attributes to exceed the max size + attrs = [AttributeEntry(f"key{i}", f"value{i}") for i in range(10)] + + # Delivery future for the original message + parent_future = asyncio.Future() + + message = (GatewayUplinkMessageBuilder() + .set_device_name("test_device") + .add_attributes(attrs) + .add_delivery_futures([parent_future]) + .build()) + + # Mock the event loop and future_map + loop_mock = MagicMock() + future = asyncio.Future() + loop_mock.create_future.return_value = future + + with patch('asyncio.get_running_loop', return_value=loop_mock), \ + patch('tb_mqtt_client.common.async_utils.future_map.register') as mock_register: + # Act + result = splitter.split_attributes([message]) + + # Assert + assert len(result) > 1 + assert mock_register.call_count == len(result) + mock_register.assert_any_call(parent_future, [future]) + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/service/test_message_service.py b/tests/service/test_message_service.py index 9eddefa..8a2a2e1 100644 --- a/tests/service/test_message_service.py +++ b/tests/service/test_message_service.py @@ -13,7 +13,7 @@ # limitations under the License. 
import asyncio -from unittest.mock import AsyncMock, MagicMock, patch, PropertyMock, call +from unittest.mock import AsyncMock, MagicMock, patch import pytest import pytest_asyncio @@ -22,12 +22,12 @@ from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit, EMPTY_RATE_LIMIT from tb_mqtt_client.common.rate_limit.rate_limiter import RateLimiter -from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage, DeviceUplinkMessageBuilder -from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage, GatewayUplinkMessageBuilder +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessageBuilder from tb_mqtt_client.service.device.message_adapter import MessageAdapter from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter -from tb_mqtt_client.service.mqtt_manager import MQTTManager from tb_mqtt_client.service.message_service import MessageService, MessageQueueWorker +from tb_mqtt_client.service.mqtt_manager import MQTTManager @pytest_asyncio.fixture @@ -82,6 +82,22 @@ async def setup_message_service(): yield service, mqtt_manager, main_stop_event, device_rate_limiter, message_adapter, gateway_message_adapter, gateway_rate_limiter +@pytest_asyncio.fixture +async def setup_retry_loop_service(setup_message_service): + service, mqtt_manager, main_stop_event, *_ = setup_message_service + service._main_stop_event = main_stop_event + service._active.set() + + # Mock dependencies + service._retry_by_qos_queue = AsyncMock() + service._service_queue = AsyncMock() + service._gateway_uplink_messages_queue = AsyncMock() + service._device_uplink_messages_queue = AsyncMock() + + # Cooldown patch to avoid slow tests + service._QUEUE_COOLDOWN = 0 + + return service, mqtt_manager @pytest.mark.asyncio async def test_publish_success(setup_message_service): @@ -149,7 +165,7 @@ async def mock_task(): with patch.object(MessageService, '_dispatch_initial_queue_loop', return_value=mock_task()), \ patch.object(MessageService, '_dispatch_queue_loop', return_value=mock_task()), \ patch.object(MessageService, '_rate_limit_refill_loop', return_value=mock_task()), \ - patch.object(MessageService, 'print_queue_statistics', return_value=mock_task()), \ + patch.object(MessageService, 'print_queues_statistics', return_value=mock_task()), \ patch.object(MessageService, 'clear', new_callable=AsyncMock) as mock_clear: # Create the service @@ -969,5 +985,74 @@ async def test_message_queue_worker_consume_rate_limits_for_message_no_limits(): message_rate_limit.consume.assert_not_awaited() datapoints_rate_limit.consume.assert_not_awaited() + +@pytest.mark.asyncio +async def test_retry_loop_with_bytes_message(setup_retry_loop_service): + service, mqtt_manager = setup_retry_loop_service + mqtt_manager.is_connected.return_value = True + + message = MagicMock() + message.original_payload = b"binary" + + service._retry_by_qos_queue.get = AsyncMock(side_effect=[message, asyncio.CancelledError()]) + + await service._dispatch_retry_by_qos_queue_loop() + + service._service_queue.reinsert_front.assert_awaited_once_with(message) + + +@pytest.mark.asyncio +async def test_retry_loop_with_gateway_uplink_message(setup_retry_loop_service): + service, mqtt_manager = setup_retry_loop_service + mqtt_manager.is_connected.return_value = True + + message = MagicMock() + message.original_payload = 
GatewayUplinkMessageBuilder().set_device_name("test").build() + + service._retry_by_qos_queue.get = AsyncMock(side_effect=[message, asyncio.CancelledError()]) + + await service._dispatch_retry_by_qos_queue_loop() + + service._gateway_uplink_messages_queue.reinsert_front.assert_awaited_once_with(message) + + +@pytest.mark.asyncio +async def test_retry_loop_with_device_uplink_message(setup_retry_loop_service): + service, mqtt_manager = setup_retry_loop_service + mqtt_manager.is_connected.return_value = True + + message = MagicMock() + message.original_payload = DeviceUplinkMessageBuilder().set_device_name("test-device").build() + + service._retry_by_qos_queue.get = AsyncMock(side_effect=[message, asyncio.CancelledError()]) + + await service._dispatch_retry_by_qos_queue_loop() + + service._device_uplink_messages_queue.reinsert_front.assert_awaited_once_with(message) + + +@pytest.mark.asyncio +async def test_retry_loop_with_cancelled_error(setup_retry_loop_service): + service, mqtt_manager = setup_retry_loop_service + mqtt_manager.is_connected.return_value = True + + service._retry_by_qos_queue.get = AsyncMock(side_effect=asyncio.CancelledError()) + + await service._dispatch_retry_by_qos_queue_loop() + + +@pytest.mark.asyncio +async def test_retry_loop_with_not_connected_and_empty_message(setup_retry_loop_service): + service, mqtt_manager = setup_retry_loop_service + + mqtt_manager.is_connected.side_effect = [False, True] + + service._retry_by_qos_queue.get = AsyncMock(side_effect=[None, asyncio.CancelledError()]) + + service._main_stop_event.is_set = MagicMock(side_effect=[False, True]) + + with patch("asyncio.sleep", new=AsyncMock()): + await service._dispatch_retry_by_qos_queue_loop() + if __name__ == '__main__': pytest.main([__file__, '-v', '--tb=short']) diff --git a/tests/tb_device_mqtt_client_tests.py b/tests/tb_device_mqtt_client_tests.py deleted file mode 100644 index 131403d..0000000 --- a/tests/tb_device_mqtt_client_tests.py +++ /dev/null @@ -1,436 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from time import sleep - -from tb_device_mqtt import TBDeviceMqttClient - - -class TBDeviceMqttClientTests(unittest.TestCase): - """ - Before running tests, do the next steps: - 1. Create device "Example Name" in ThingsBoard - 2. Add shared attribute "attr" with value "hello" to created device - 3. 
Add client attribute "atr3" with value "value3" to created device - """ - - client = None - - shared_attribute_name = 'attr' - shared_attribute_value = 'hello' - - client_attribute_name = 'atr3' - client_attribute_value = 'value3' - - request_attributes_result = None - subscribe_to_attribute = None - subscribe_to_attribute_all = None - - @classmethod - def setUpClass(cls) -> None: - cls.client = TBDeviceMqttClient('127.0.0.1', 1883, 'TEST_DEVICE_TOKEN') - cls.client.connect(timeout=1) - - @classmethod - def tearDownClass(cls) -> None: - cls.client.disconnect() - - @staticmethod - def on_attributes_change_callback(result, exception=None): - if exception is not None: - TBDeviceMqttClientTests.request_attributes_result = exception - else: - TBDeviceMqttClientTests.request_attributes_result = result - - @staticmethod - def callback_for_specific_attr(result, *args): - TBDeviceMqttClientTests.subscribe_to_attribute = result - - @staticmethod - def callback_for_everything(result, *args): - TBDeviceMqttClientTests.subscribe_to_attribute_all = result - - def test_request_attributes(self): - self.client.request_attributes(shared_keys=[self.shared_attribute_name], - callback=self.on_attributes_change_callback) - sleep(3) - self.assertEqual(self.request_attributes_result, - {'shared': {self.shared_attribute_name: self.shared_attribute_value}}) - - self.client.request_attributes(client_keys=[self.client_attribute_name], - callback=self.on_attributes_change_callback) - sleep(3) - self.assertEqual(self.request_attributes_result, - {'client': {self.client_attribute_name: self.client_attribute_value}}) - - def test_send_telemetry_and_attr(self): - telemetry = {"temperature": 41.9, "humidity": 69, "enabled": False, "currentFirmwareVersion": "v1.2.2"} - self.assertEqual(self.client.send_telemetry(telemetry, 0).get(), 0) - - attributes = {"sensorModel": "DHT-22", self.client_attribute_name: self.client_attribute_value} - self.assertEqual(self.client.send_attributes(attributes, 0).get(), 0) - - def test_subscribe_to_attrs(self): - sub_id_1 = self.client.subscribe_to_attribute(self.shared_attribute_name, self.callback_for_specific_attr) - sub_id_2 = self.client.subscribe_to_all_attributes(self.callback_for_everything) - - sleep(1) - value = input("Updated attribute value: ") - - self.assertEqual(self.subscribe_to_attribute_all, {self.shared_attribute_name: value}) - self.assertEqual(self.subscribe_to_attribute, {self.shared_attribute_name: value}) - - self.client.unsubscribe_from_attribute(sub_id_1) - self.client.unsubscribe_from_attribute(sub_id_2) - - -class TestSplitMessageVariants(unittest.TestCase): - - def test_empty_input(self): - self.assertEqual(TBDeviceMqttClient._split_message([], 10, 100), []) - - def test_rpc_payload(self): - rpc_payload = {'device': 'dev1'} - result = TBDeviceMqttClient._split_message(rpc_payload, 10, 100) - self.assertEqual(len(result), 1) - self.assertEqual(result[0]['data'], rpc_payload) - - def test_single_value_message(self): - msg = [{'ts': 1, 'values': {'a': 1}}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - self.assertEqual(result[0]['data'][0]['values'], {'a': 1}) - - def test_timestamp_change_split(self): - msg = [{'ts': 1, 'values': {'a': 1}}, {'ts': 2, 'values': {'b': 2}}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - self.assertEqual(len(result), 2) - - def test_exceeding_datapoint_limit(self): - msg = [{'ts': 1, 'values': {f'k{i}': i for i in range(10)}}] - result = TBDeviceMqttClient._split_message(msg, 5, 1000) - 
self.assertGreaterEqual(len(result), 2) - - def test_message_with_metadata(self): - msg = [{'ts': 1, 'values': {'a': 1}, 'metadata': {'unit': 'C'}}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - self.assertIn('metadata', result[0]['data'][0]) - - def test_large_payload_split(self): - msg = [{'ts': 1, 'values': {f'key{i}': 'v'*50 for i in range(5)}}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - self.assertGreater(len(result), 1) - - def test_metadata_present_with_ts(self): - msg = [{'ts': 123456789, 'values': {'temperature': 25}, 'metadata': {'unit': 'C'}}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - self.assertIn('metadata', result[0]['data'][0]) - self.assertEqual(result[0]['data'][0]['metadata'], {'unit': 'C'}) - - def test_metadata_ignored_without_ts(self): - msg = [{'values': {'temperature': 25}, 'metadata': {'unit': 'C'}}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - self.assertTrue(all('metadata' not in entry for r in result for entry in r['data'])) - - def test_grouping_same_ts_exceeds_datapoint_limit(self): - msg = [ - {'ts': 1, 'values': {'a': 1, 'b': 2, 'c': 3}}, - {'ts': 1, 'values': {'d': 4, 'e': 5}} - ] - result = TBDeviceMqttClient._split_message(msg, datapoints_max_count=3, max_payload_size=1000) - self.assertGreater(len(result), 1) - total_keys = set() - for part in result: - for d in part['data']: - total_keys.update(d['values'].keys()) - self.assertEqual(total_keys, {'a', 'b', 'c', 'd', 'e'}) - - def test_grouping_same_ts_exceeds_payload_limit(self): - msg = [ - {'ts': 1, 'values': {'a': 'x'*30}}, - {'ts': 1, 'values': {'b': 'y'*30}}, - {'ts': 1, 'values': {'c': 'z'*30}} - ] - result = TBDeviceMqttClient._split_message(msg, datapoints_max_count=10, max_payload_size=64) - self.assertGreater(len(result), 1) - all_keys = set() - for r in result: - for d in r['data']: - all_keys.update(d['values'].keys()) - self.assertEqual(all_keys, {'a', 'b', 'c'}) - - def test_individual_messages_not_grouped_if_payload_limit_exceeded(self): - msg = [ - {'ts': 1, 'values': {'a': 'x' * 100}}, - {'ts': 1, 'values': {'b': 'y' * 100}}, - {'ts': 1, 'values': {'c': 'z' * 100}} - ] - result = TBDeviceMqttClient._split_message(msg, datapoints_max_count=10, max_payload_size=110) - self.assertEqual(len(result), 3) - - grouped_keys = [] - for r in result: - for d in r['data']: - grouped_keys.extend(d['values'].keys()) - - self.assertEqual(set(grouped_keys), {'a', 'b', 'c'}) - for r in result: - for d in r['data']: - self.assertLessEqual(sum(len(k) + len(str(v)) for k, v in d['values'].items()), 110) - - def test_partial_grouping_due_to_payload_limit(self): - msg = [ - {'ts': 1, 'values': {'a': 'x' * 10, 'b': 'y' * 10}}, # should be grouped together - {'ts': 1, 'values': {'c': 'z' * 100}}, # should be on its own due to size - {'ts': 1, 'values': {'d': 'w' * 10, 'e': 'q' * 10}} # should be grouped again - ] - result = TBDeviceMqttClient._split_message(msg, datapoints_max_count=10, max_payload_size=64) - self.assertEqual(len(result), 3) - - result_keys = [] - for r in result: - keys = [] - for d in r['data']: - keys.extend(d['values'].keys()) - result_keys.append(set(keys)) - - self.assertIn({'a', 'b'}, result_keys) - self.assertIn({'c'}, result_keys) - self.assertIn({'d', 'e'}, result_keys) - - def test_partial_grouping_due_to_datapoint_limit(self): - msg = [ - {'ts': 1, 'values': {'a': 1, 'b': 2}}, # grouped - {'ts': 1, 'values': {'c': 3, 'd': 4}}, # grouped - {'ts': 1, 'values': {'e': 5}}, # forced into next group due to datapoint limit - 
{'ts': 1, 'values': {'f': 6}}, # grouped with above - ] - # Max datapoints per message is 4 (after subtracting 1 in implementation) - result = TBDeviceMqttClient._split_message(msg, datapoints_max_count=5, max_payload_size=1000) - self.assertEqual(len(result), 2) - - all_keys = [] - for r in result: - keys = set() - for d in r['data']: - keys.update(d['values'].keys()) - all_keys.append(keys) - - # First group should contain a, b, c, d (4 datapoints) - self.assertIn({'a', 'b', 'c', 'd'}, all_keys) - # Second group should contain e, f (2 datapoints) - self.assertIn({'e', 'f'}, all_keys) - - def test_values_included_only_when_ts_present(self): - msg = [{'values': {'a': 1, 'b': 2}}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - self.assertEqual(result[0]['data'][0], {'a': 1, 'b': 2}) - - def test_missing_values_field_uses_whole_message(self): - msg = [{'ts': 123, 'a': 1, 'b': 2}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - self.assertIn('values', result[0]['data'][0]) - self.assertEqual(result[0]['data'][0]['values'], {'a': 1, 'b': 2, 'ts': 123}) - - def test_metadata_conflict_same_ts_no_grouping(self): - msg = [ - {'ts': 1, 'values': {'a': 1}, 'metadata': {'unit': 'C'}}, - {'ts': 1, 'values': {'b': 2}, 'metadata': {'unit': 'F'}} - ] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - - self.assertEqual(len(result), 2) - - metadata_sets = [d['data'][0].get('metadata') for d in result] - self.assertIn({'unit': 'C'}, metadata_sets) - self.assertIn({'unit': 'F'}, metadata_sets) - - value_keys_sets = [set(d['data'][0]['values'].keys()) for d in result] - self.assertIn({'a'}, value_keys_sets) - self.assertIn({'b'}, value_keys_sets) - - def test_non_dict_message_is_skipped(self): - msg = [{'ts': 1, 'values': {'a': 1}}, 'this_is_not_a_dict', {'ts': 1, 'values': {'b': 2}}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - self.assertEqual(len(result), 1) - values = result[0]['data'][0]['values'] - self.assertEqual(set(values.keys()), {'a', 'b'}) - - def test_multiple_dicts_without_ts_values_metadata(self): - msg = [{'a': 1}, {'b': 2}, {'c': 3}] - result = TBDeviceMqttClient._split_message(msg, 10, 1000) - self.assertEqual(len(result), 1) # grouped - combined_keys = {} - for d in result[0]['data']: - combined_keys.update(d) - self.assertEqual(set(combined_keys.keys()), {'a', 'b', 'c'}) - - def test_multiple_dicts_without_ts_split_by_payload(self): - msg = [{'a': 'x' * 60}, {'b': 'y' * 60}, {'c': 'z' * 60}] - result = TBDeviceMqttClient._split_message(msg, datapoints_max_count=10, max_payload_size=64) - self.assertEqual(len(result), 3) # each too large to group - keys = [list(r['data'][0].keys())[0] for r in result] - self.assertEqual(set(keys), {'a', 'b', 'c'}) - - def test_mixed_dicts_with_and_without_ts(self): - msg = [ - {'ts': 1, 'values': {'a': 1}}, - {'b': 2}, - {'ts': 1, 'values': {'c': 3}}, - {'d': 4} - ] - result = TBDeviceMqttClient._split_message(msg, 10, 1000) - # Should split into at least 2 chunks: one for ts=1 and one for ts=None - self.assertGreaterEqual(len(result), 2) - - ts_chunks = [d for r in result for d in r['data'] if 'ts' in d] - raw_chunks = [d for r in result for d in r['data'] if 'ts' not in d] - - ts_keys = set() - for d in ts_chunks: - ts_keys.update(d['values'].keys()) - - raw_keys = set() - for d in raw_chunks: - raw_keys.update(d.keys()) - - self.assertEqual(ts_keys, {'a', 'c'}) - self.assertEqual(raw_keys, {'b', 'd'}) - - def test_complex_mixed_messages(self): - msg = [ - {'ts': 1, 'values': {'a': 1}}, - {'ts': 1, 
'values': {'b': 2}, 'metadata': {'unit': 'C'}}, - - {'ts': 2, 'values': {'c1': 1, 'c2': 2, 'c3': 3}}, - {'ts': 2, 'values': {'c4': 4, 'c5': 5}}, - - {'ts': 3, 'values': {'x': 'x' * 60}}, - {'ts': 3, 'values': {'y': 'y' * 60}}, - - {'m1': 1, 'm2': 2}, - - 123, - - {'m3': 'a' * 100}, - {'m4': 'b' * 100}, - - {'k1': 1, 'k2': 2, 'k3': 3}, - {'k4': 4, 'k5': 5}, - - {'ts': 1, 'values': {'z': 99}}, - ] - - result = TBDeviceMqttClient._split_message(msg, datapoints_max_count=4, max_payload_size=64) - - all_ts_groups = {} - raw_chunks = [] - for r in result: - for entry in r['data']: - if isinstance(entry, dict) and 'ts' in entry: - ts = entry['ts'] - if ts not in all_ts_groups: - all_ts_groups[ts] = set() - all_ts_groups[ts].update(entry.get('values', {}).keys()) - if 'metadata' in entry: - self.assertIsInstance(entry['metadata'], dict) - else: - raw_chunks.append(entry) - - self.assertIn(1, all_ts_groups) - self.assertEqual(all_ts_groups[1], {'a', 'b', 'z'}) - - self.assertIn(2, all_ts_groups) - self.assertEqual(all_ts_groups[2], {'c1', 'c2', 'c3', 'c4', 'c5'}) - - self.assertIn(3, all_ts_groups) - self.assertEqual(all_ts_groups[3], {'x', 'y'}) - - all_raw_keys = [set(entry.keys()) for entry in raw_chunks] - - expected_raw_key_sets = [ - {'m1', 'm2'}, - {'m3'}, - {'m4'}, - {'k1', 'k2', 'k3'}, - {'k4', 'k5'} - ] - - for expected_keys in expected_raw_key_sets: - self.assertIn(expected_keys, all_raw_keys) - - for r in raw_chunks: - self.assertLessEqual(len(r), 4) # Max datapoints = 4 - - total_size = sum(len(k) + len(str(v)) for k, v in r.items()) - - if len(r) > 1: - self.assertLessEqual(total_size, 64) - if total_size > 64: - self.assertEqual(len(r), 1) - - self.assertGreaterEqual(len(result), 8) - - def test_empty_values_should_skip_or_include_empty(self): - msg = [{'ts': 1, 'values': {}}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - self.assertEqual(result, []) - - def test_duplicate_keys_within_same_ts(self): - msg = [{'ts': 1, 'values': {'a': 1}}, {'ts': 1, 'values': {'a': 2}}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - values = {} - for d in result[0]['data']: - values.update(d['values']) - self.assertEqual(values['a'], 2) # Last value wins - - def test_partial_metadata_presence(self): - msg = [{'ts': 1, 'values': {'a': 1}, 'metadata': {'unit': 'C'}}, - {'ts': 1, 'values': {'b': 2}}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - for d in result[0]['data']: - if d['values'].keys() == {'a', 'b'} or d['values'].keys() == {'b', 'a'}: - self.assertIn('metadata', d) - - def test_non_dict_metadata_should_be_ignored(self): - msg = [{'ts': 1, 'values': {'a': 1}, 'metadata': ['invalid']}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - self.assertTrue(all( - isinstance(d.get('metadata', {}), dict) or 'metadata' not in d - for r in result for d in r['data'] - )) - - def test_non_list_message_pack_single_dict_raw(self): - msg = {'a': 1, 'b': 2} - result = TBDeviceMqttClient._split_message(msg, 10, 100) - self.assertEqual(result[0]['data'][0], msg) - - def test_nested_value_object_should_count_size_correctly(self): - msg = [{'ts': 1, 'values': {'a': {'nested': 'structure'}}}] - result = TBDeviceMqttClient._split_message(msg, 10, 1000) - total_size = sum(len(k) + len(str(v)) for r in result for d in r['data'] for k, v in d['values'].items()) - self.assertGreater(total_size, 0) - - def test_raw_duplicate_keys_overwrite_behavior(self): - msg = [{'a': 1}, {'a': 2}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - all_data = {} - for r 
in result: - for d in r['data']: - all_data.update(d) - self.assertEqual(all_data['a'], 2) # Last value wins - - -if __name__ == '__main__': - unittest.main('tb_device_mqtt_client_tests') diff --git a/tests/tb_gateway_mqtt_client_tests.py b/tests/tb_gateway_mqtt_client_tests.py deleted file mode 100644 index 761fca3..0000000 --- a/tests/tb_gateway_mqtt_client_tests.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from time import sleep, time - -from tb_gateway_mqtt import TBGatewayMqttClient - - -class TBGatewayMqttClientTests(unittest.TestCase): - """ - Before running tests, do the next steps: - 1. Create device "Example Name" in ThingsBoard - 2. Add shared attribute "attr" with value "hello" to created device - """ - - client = None - - device_name = 'Example Name' - shared_attr_name = 'attr' - shared_attr_value = 'hello' - - request_attributes_result = None - subscribe_to_attribute = None - subscribe_to_attribute_all = None - subscribe_to_device_attribute_all = None - - @classmethod - def setUpClass(cls) -> None: - cls.client = TBGatewayMqttClient('127.0.0.1', 1883, 'TEST_GATEWAY_TOKEN') - cls.client.connect(timeout=1) - - @classmethod - def tearDownClass(cls) -> None: - cls.client.disconnect() - - @staticmethod - def request_attributes_callback(result, exception=None): - if exception is not None: - TBGatewayMqttClientTests.request_attributes_result = exception - else: - TBGatewayMqttClientTests.request_attributes_result = result - - @staticmethod - def callback(result): - TBGatewayMqttClientTests.subscribe_to_device_attribute_all = result - - @staticmethod - def callback_for_everything(result): - TBGatewayMqttClientTests.subscribe_to_attribute_all = result - - @staticmethod - def callback_for_specific_attr(result): - TBGatewayMqttClientTests.subscribe_to_attribute = result - - def test_connect_disconnect_device(self): - self.assertEqual(self.client.gw_connect_device(self.device_name).rc, 0) - self.assertEqual(self.client.gw_disconnect_device(self.device_name).rc, 0) - - def test_request_attributes(self): - self.client.gw_request_shared_attributes(self.device_name, [self.shared_attr_name], - self.request_attributes_callback) - sleep(3) - self.assertEqual(self.request_attributes_result, - {'id': 1, 'device': self.device_name, 'value': self.shared_attr_value}) - - def test_send_telemetry_and_attributes(self): - attributes = {"atr1": 1, "atr2": True, "atr3": "value3"} - telemetry = {"ts": int(round(time() * 1000)), "values": {"key1": "11"}} - self.assertEqual(self.client.gw_send_attributes(self.device_name, attributes).get(), 0) - self.assertEqual(self.client.gw_send_telemetry(self.device_name, telemetry).get(), 0) - - def test_subscribe_to_attributes(self): - self.client.gw_connect_device(self.device_name) - - self.client.gw_subscribe_to_all_attributes(self.callback_for_everything) - self.client.gw_subscribe_to_attribute(self.device_name, self.shared_attr_name, self.callback_for_specific_attr) - sub_id = 
self.client.gw_subscribe_to_all_device_attributes(self.device_name, self.callback) - - sleep(1) - value = input("Updated attribute value: ") - - self.assertEqual(self.subscribe_to_attribute, - {'device': self.device_name, 'data': {self.shared_attr_name: value}}) - self.assertEqual(self.subscribe_to_attribute_all, - {'device': self.device_name, 'data': {self.shared_attr_name: value}}) - self.assertEqual(self.subscribe_to_device_attribute_all, - {'device': self.device_name, 'data': {self.shared_attr_name: value}}) - - self.client.gw_unsubscribe(sub_id) - - -if __name__ == '__main__': - unittest.main('tb_gateway_mqtt_client_tests') diff --git a/tests/test_async_utils.py b/tests/test_async_utils.py deleted file mode 100644 index 2ea8010..0000000 --- a/tests/test_async_utils.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -from unittest import IsolatedAsyncioTestCase -from unittest.mock import patch - -from tb_mqtt_client.common.async_utils import await_and_resolve_original - - -class TestAwaitAndResolveOriginal(IsolatedAsyncioTestCase): - - async def resolves_parent_with_first_successful_child_result(self): - parent_future = asyncio.Future() - child_future_1 = asyncio.Future() - child_future_2 = asyncio.Future() - - child_future_1.set_result("success") - child_future_2.set_exception(ValueError("error")) - - await await_and_resolve_original([parent_future], [child_future_1, child_future_2]) - - self.assertTrue(parent_future.done()) - self.assertEqual(parent_future.result(), "success") - - async def resolves_parent_with_first_exception_if_no_successful_results(self): - parent_future = asyncio.Future() - child_future_1 = asyncio.Future() - child_future_2 = asyncio.Future() - - child_future_1.set_exception(ValueError("error1")) - child_future_2.set_exception(ValueError("error2")) - - await await_and_resolve_original([parent_future], [child_future_1, child_future_2]) - - self.assertTrue(parent_future.done()) - self.assertIsInstance(parent_future.exception(), ValueError) - self.assertEqual(str(parent_future.exception()), "error1") - - async def handles_empty_child_futures_list(self): - parent_future = asyncio.Future() - - await await_and_resolve_original([parent_future], []) - - self.assertTrue(parent_future.done()) - self.assertIsNone(parent_future.result()) - - async def sets_exception_on_parent_if_unexpected_error_occurs(self): - parent_future = asyncio.Future() - child_future = asyncio.Future() - - with patch("tb_mqtt_client.common.async_utils.future_map.child_resolved", side_effect=Exception("unexpected")): - await await_and_resolve_original([parent_future], [child_future]) - - self.assertTrue(parent_future.done()) - self.assertIsInstance(parent_future.exception(), Exception) - self.assertEqual(str(parent_future.exception()), "unexpected") \ No newline at end of file diff --git a/utils.py b/utils.py deleted file mode 100644 index 1227418..0000000 --- a/utils.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2025 
ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from sys import executable -from subprocess import check_call, CalledProcessError, DEVNULL -from pkg_resources import get_distribution, DistributionNotFound - - -def install_package(package, version="upgrade"): - result = False - - def try_install(args, suppress_stderr=False): - try: - stderr = DEVNULL if suppress_stderr else None - check_call([executable, "-m", "pip", *args], stderr=stderr) - return True - except CalledProcessError: - return False - - if version.lower() == "upgrade": - args = ["install", package, "--upgrade"] - result = try_install(args + ["--user"], suppress_stderr=True) - if not result: - result = try_install(args) - else: - try: - installed_version = get_distribution(package).version - if installed_version == version: - return True - except DistributionNotFound: - pass - install_version = f"{package}=={version}" if ">=" not in version else f"{package}{version}" - args = ["install", install_version] - if not try_install(args + ["--user"], suppress_stderr=True): - result = try_install(args) - - return result From c930d21ed8d1791ae3c7cdb0177c6b69298370fe Mon Sep 17 00:00:00 2001 From: imbeacon Date: Fri, 1 Aug 2025 14:52:09 +0300 Subject: [PATCH 64/74] Updated gateway operational example and attribute response processing --- examples/gateway/operational_example.py | 212 ++++++++++-------- .../service/gateway/message_adapter.py | 7 +- 2 files changed, 117 insertions(+), 102 deletions(-) diff --git a/examples/gateway/operational_example.py b/examples/gateway/operational_example.py index db083f1..927a888 100644 --- a/examples/gateway/operational_example.py +++ b/examples/gateway/operational_example.py @@ -15,148 +15,162 @@ import asyncio import logging import signal -from datetime import datetime, UTC -from random import randint, uniform +from random import uniform, randint +from time import time from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry -from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse from tb_mqtt_client.service.gateway.client import GatewayClient -from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.service.gateway.device_session import DeviceSession +# ---- Logging Setup ---- configure_logging() logger = get_logger(__name__) logger.setLevel(logging.DEBUG) 
logging.getLogger("tb_mqtt_client").setLevel(logging.DEBUG) +# ---- Constants ---- +GATEWAY_HOST = "localhost" +GATEWAY_PORT = 1883 # Default MQTT port, change if needed +GATEWAY_ACCESS_TOKEN = "YOUR_ACCESS_TOKEN" # Replace with your actual access token + +DELAY_BETWEEN_DATA_PUBLISH = 1 # seconds + + +# ---- Callbacks ---- +async def attribute_update_handler(device_session: DeviceSession, update: AttributeUpdate): + logger.info("Received attribute update for %s: %s", + device_session.device_info.device_name, update.entries) + + +async def requested_attributes_handler(device_session: DeviceSession, response: RequestedAttributeResponse): + logger.info("Requested attributes for %s -> client: %r, shared: %r", + device_session.device_info.device_name, + response.client, + response.shared) + -async def device_attribute_update_callback(update: AttributeUpdate): - """ - Callback function to handle device attribute updates. - :param update: The attribute update object. - """ - logger.info("Received attribute update for device %s: %s", update.device, update.attributes) - - -async def device_rpc_request_callback(device_name: str, method: str, params: dict): - """ - Callback function to handle device RPC requests. - :param device_name: Name of the device - :param method: RPC method - :param params: RPC parameters - :return: RPC response - """ - logger.info("Received RPC request for device %s: method=%s, params=%s", device_name, method, params) - - # Example response based on method - if method == "getTemperature": - return {"temperature": round(uniform(20.0, 30.0), 2)} - elif method == "setLedState": - state = params.get("state", False) - return {"success": True, "state": state} - else: - return {"error": f"Unsupported method: {method}"} - - -async def device_disconnect_callback(device_name: str): - """ - Callback function to handle device disconnections. 
- :param device_name: Name of the disconnected device - """ - logger.info("Device %s disconnected", device_name) +async def device_rpc_request_handler(device_session: DeviceSession, rpc_request: GatewayRPCRequest) -> GatewayRPCResponse: + logger.info("Received RPC request for %s: %r", device_session.device_info.device_name, rpc_request) + response_data = { + "status": "success", + "echo_method": rpc_request.method, + "params": rpc_request.params + } + return GatewayRPCResponse.build(device_session.device_info.device_name, rpc_request.request_id, response_data) +# ---- Main ---- async def main(): stop_event = asyncio.Event() def _shutdown_handler(): stop_event.set() + asyncio.create_task(client.stop()) loop = asyncio.get_running_loop() for sig in (signal.SIGINT, signal.SIGTERM): try: loop.add_signal_handler(sig, _shutdown_handler) except NotImplementedError: - # Windows compatibility fallback signal.signal(sig, lambda *_: _shutdown_handler()) + # ---- Gateway Config ---- config = GatewayConfig() - config.host = "localhost" - config.access_token = "YOUR_GATEWAY_ACCESS_TOKEN" + config.host = GATEWAY_HOST + config.port = GATEWAY_PORT + config.access_token = GATEWAY_ACCESS_TOKEN + global client client = GatewayClient(config) - client.set_device_attribute_update_callback(device_attribute_update_callback) - client.set_device_rpc_request_callback(device_rpc_request_callback) - client.set_device_disconnect_callback(device_disconnect_callback) + await client.connect() + logger.info("Gateway connected to ThingsBoard.") + + # ---- Register Devices ---- + devices = [ + ("Test Device A1", "Test devices"), + ("Test Device B1", "Test devices") + ] + sessions = {} + + for name, profile in devices: + session, _ = await client.connect_device(name, profile, wait_for_publish=True) + if not session: + logger.error("Failed to connect %s", name) + continue + sessions[name] = session + logger.info("Device connected: %s", name) + + # Register callbacks for each device + client.device_manager.set_attribute_update_callback(session.device_info.device_id, attribute_update_handler) + client.device_manager.set_attribute_response_callback(session.device_info.device_id, requested_attributes_handler) + client.device_manager.set_rpc_request_callback(session.device_info.device_id, device_rpc_request_handler) + + # ---- Main loop ---- + while not stop_event.is_set(): + iteration_start = time() - logger.info("Connected to ThingsBoard as gateway.") + for device_name, session in sessions.items(): + logger.info("Publishing data for %s", device_name) - # Connect devices to the gateway - device_names = ["sensor-1", "sensor-2", "actuator-1"] - for device_name in device_names: - await client.gw_connect_device(device_name) - logger.info("Connected device: %s", device_name) + # --- Attributes --- + raw_attrs = {"firmwareVersion": "2.0.0", "location": "office"} + await client.send_device_attributes(session, raw_attrs, wait_for_publish=True) - while not stop_event.is_set(): - # Send device attributes - for device_name in device_names: - # Send device attributes - attributes = { - "firmwareVersion": "1.0.4", - "serialNumber": f"SN-{randint(1000, 9999)}", - "deviceType": "sensor" if "sensor" in device_name else "actuator" - } - await client.gw_send_attributes(device_name, attributes) - logger.info("Sent attributes for device %s: %s", device_name, attributes) - - # Send single attribute - single_attr = AttributeEntry("lastUpdateTime", datetime.now(UTC).isoformat()) - await client.gw_send_attributes(device_name, single_attr) - 
logger.info("Sent single attribute for device %s: %s", device_name, single_attr) - - # Send device telemetry - if "sensor" in device_name: - # For sensor devices - telemetry = { - "temperature": round(uniform(20.0, 30.0), 2), - "humidity": round(uniform(40.0, 80.0), 2), - "batteryLevel": randint(1, 100) - } - await client.gw_send_telemetry(device_name, telemetry) - logger.info("Sent telemetry for device %s: %s", device_name, telemetry) - - # Send single telemetry entry - single_entry = TimeseriesEntry("signalStrength", randint(-90, -30)) - await client.gw_send_telemetry(device_name, single_entry) - logger.info("Sent single telemetry entry for device %s: %s", device_name, single_entry) - else: - # For actuator devices - telemetry = { - "state": "ON" if randint(0, 1) == 1 else "OFF", - "powerConsumption": round(uniform(0.1, 5.0), 2), - "uptime": randint(1, 1000) - } - await client.gw_send_telemetry(device_name, telemetry) - logger.info("Sent telemetry for device %s: %s", device_name, telemetry) + single_attr = AttributeEntry("mode", "auto") + await client.send_device_attributes(session, single_attr, wait_for_publish=True) + + multi_attrs = [ + AttributeEntry("maxTemp", randint(60, 90)), + AttributeEntry("calibrated", True) + ] + await client.send_device_attributes(session, multi_attrs, wait_for_publish=True) + + # --- Telemetry --- + raw_ts = {"temperature": round(uniform(20, 30), 2), "humidity": randint(40, 80)} + await client.send_device_timeseries(session, raw_ts, wait_for_publish=True) + list_ts = [ + {"ts": int(time() * 1000), "values": {"temperature": 25.5}}, + {"ts": int(time() * 1000) - 1000, "values": {"humidity": 65}} + ] + await client.send_device_timeseries(session, list_ts, wait_for_publish=True) + + ts_entries = [TimeseriesEntry(f"temp_{i}", randint(0, 50)) for i in range(5)] + await client.send_device_timeseries(session, ts_entries, wait_for_publish=True) + + # --- Attribute Request --- + attr_req = await GatewayAttributeRequest.build( + device_session=session, + client_keys=["firmwareVersion", "mode"] + ) + await client.send_device_attributes_request(session, attr_req, wait_for_publish=True) + + # ---- Delay handling ---- try: - await asyncio.wait_for(stop_event.wait(), timeout=5) + timeout = DELAY_BETWEEN_DATA_PUBLISH - (time() - iteration_start) + if timeout > 0: + await asyncio.wait_for(stop_event.wait(), timeout=timeout) except asyncio.TimeoutError: pass - # Disconnect devices before shutting down - for device_name in device_names: - await client.gw_disconnect_device(device_name) - logger.info("Disconnected device: %s", device_name) - + # ---- Disconnect devices ---- + for session in sessions.values(): + await client.disconnect_device(session, wait_for_publish=True) await client.disconnect() - logger.info("Disconnected from ThingsBoard.") + logger.info("Gateway disconnected cleanly.") if __name__ == "__main__": try: asyncio.run(main()) except KeyboardInterrupt: - print("Interrupted by user.") \ No newline at end of file + logger.info("Interrupted by user.") \ No newline at end of file diff --git a/tb_mqtt_client/service/gateway/message_adapter.py b/tb_mqtt_client/service/gateway/message_adapter.py index 0eb41d8..8e7e23b 100644 --- a/tb_mqtt_client/service/gateway/message_adapter.py +++ b/tb_mqtt_client/service/gateway/message_adapter.py @@ -321,10 +321,11 @@ def parse_gateway_requested_attribute_response(self, and not gateway_attribute_request.client_keys))): # TODO: Skipping case when requested several attributes, but only one is returned, issue on the platform 
logger.warning("Received gateway attribute response with single key, but multiply keys expected. " - "Request keys: %s, Response keys: %s", - list(*gateway_attribute_request.client_keys, *gateway_attribute_request.shared_keys), + "Request keys: %s, Response value: %r", + gateway_attribute_request.client_keys + gateway_attribute_request.shared_keys, data['value']) - return None + client = [] + shared = [] elif 'value' in data: if not client_keys_empty and len(gateway_attribute_request.client_keys) == 1: client = [AttributeEntry(gateway_attribute_request.client_keys[0], data['value'])] From d01e25b901a7787d0a422a33a2eb0ba07df00302 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Mon, 4 Aug 2025 12:49:35 +0300 Subject: [PATCH 65/74] Added tls example for gateway client --- .../DEPRECATEDclaiming_device_pe_only.py | 58 -------------- examples/gateway/tls_connect.py | 76 +++++++++++++++++++ tb_mqtt_client/common/config_loader.py | 2 + tb_mqtt_client/service/gateway/client.py | 4 - tests/service/gateway/test_gateway_client.py | 17 ----- 5 files changed, 78 insertions(+), 79 deletions(-) delete mode 100644 examples/gateway/DEPRECATEDclaiming_device_pe_only.py create mode 100644 examples/gateway/tls_connect.py diff --git a/examples/gateway/DEPRECATEDclaiming_device_pe_only.py b/examples/gateway/DEPRECATEDclaiming_device_pe_only.py deleted file mode 100644 index 3a6bfb0..0000000 --- a/examples/gateway/DEPRECATEDclaiming_device_pe_only.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from tb_gateway_mqtt import TBGatewayMqttClient -logging.basicConfig(level=logging.DEBUG) - -THINGSBOARD_HOST = "127.0.0.1" -GATEWAY_ACCESS_TOKEN = "GATEWAY_ACCESS_TOKEN" - -DEVICE_NAME = "DEVICE_NAME" -SECRET_KEY = "DEVICE_SECRET_KEY" # Customer should write this key in device claiming widget -DURATION = 30000 # In milliseconds (30 seconds) - - -def main(): - client = TBGatewayMqttClient(THINGSBOARD_HOST, username=GATEWAY_ACCESS_TOKEN) - client.connect() - - """ - You are able to provide every parameter or pass claiming request like: - request_example = { - "DEVICE A": { - "secretKey": "DEVICE_A_SECRET_KEY", - "durationMs": "30000" - }, - "DEVICE B": { - "secretKey": "DEVICE_B_SECRET_KEY", - "durationMs": "60000" - } - - info = client.gw_claim(claiming_request=request_example).wait_for_publish() - - """ - - client.gw_connect_device(DEVICE_NAME) - - info = client.gw_claim(device_name=DEVICE_NAME, secret_key=SECRET_KEY, duration=DURATION) - - if info.rc() == 0: - print("Claiming request was sent.") - client.stop() - - -if __name__ == '__main__': - main() diff --git a/examples/gateway/tls_connect.py b/examples/gateway/tls_connect.py new file mode 100644 index 0000000..1bfac9e --- /dev/null +++ b/examples/gateway/tls_connect.py @@ -0,0 +1,76 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This example demonstrates how to connect to ThingsBoard over SSL using the GatewayClient, +# connect a device, and send telemetry data securely. + +import asyncio +import logging +import random + +from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.service.gateway.client import GatewayClient + +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.INFO) +logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) + + +PLATFORM_HOST = 'localhost' # Update with your ThingsBoard host +PLATFORM_PORT = 8883 # Default port for MQTT over SSL + + +# Update with your CA certificate, client certificate, and client key paths. There are no default files generated. +# You can generate them using the following guides: +# Certificates for server - https://thingsboard.io/docs/user-guide/mqtt-over-ssl/ +# Certificates for client - https://thingsboard.io/docs/user-guide/certificates/?ubuntuThingsboardX509=X509Leaf +CA_CERT_PATH = "mqttserver.pem" # Update with your CA certificate path (Default - mqttserver.pem in the examples directory) +CLIENT_CERT_PATH = "cert.pem" # Update with your client certificate path (Default - cert.pem in the examples directory) +CLIENT_KEY_PATH = "key.pem" # Update with your client key path (Default - key.pem in the examples directory) + + +async def main(): + config = GatewayConfig() + + config.host = PLATFORM_HOST + config.port = PLATFORM_PORT + + config.ca_cert = CA_CERT_PATH + config.client_cert = CLIENT_CERT_PATH + config.private_key = CLIENT_KEY_PATH + + client = GatewayClient(config) + await client.connect() + + device_name = "Test Device B1" + device_profile = "Test devices" + logger.info("Connecting device: %s", device_name) + device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) + + # Sending telemetry data to the connected device + list_timeseries = [ + TimeseriesEntry(key="temperature", value=random.randint(20, 35)), + TimeseriesEntry(key="humidity", value=random.randint(40, 80)) + ] + logger.info("Sending list of timeseries: %s", list_timeseries) + await client.send_device_timeseries(device_session=device_session, data=list_timeseries, wait_for_publish=True) + logger.info("List of timeseries sent successfully.") + + await client.stop() + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tb_mqtt_client/common/config_loader.py b/tb_mqtt_client/common/config_loader.py index c47b6eb..75d8c8f 100644 --- a/tb_mqtt_client/common/config_loader.py +++ b/tb_mqtt_client/common/config_loader.py @@ -100,6 +100,8 @@ def __init__(self, config=None): if os.getenv("TB_GW_QOS") is not None: self.qos: int = int(os.getenv("TB_GW_QOS", 1)) + + def __repr__(self): return (f"GatewayConfig(host={self.host}, port={self.port}, " f"auth={'token' if self.access_token else 'user/pass'} " diff --git a/tb_mqtt_client/service/gateway/client.py b/tb_mqtt_client/service/gateway/client.py index 303b292..300fe8e 
100644 --- a/tb_mqtt_client/service/gateway/client.py +++ b/tb_mqtt_client/service/gateway/client.py @@ -402,10 +402,6 @@ async def _unsubscribe_from_gateway_topics(self): async def _handle_rate_limit_response(self, response: RPCResponse): # noqa device_rate_limits_processing_result = await super()._handle_rate_limit_response(response) try: - if not isinstance(response.result, dict) or 'gatewayRateLimits' not in response.result: - logger.warning("Invalid gateway rate limit response: %r", response) - return None - gateway_rate_limits = response.result.get('gatewayRateLimits', {}) await self._gateway_rate_limiter.message_rate_limit.set_limit(gateway_rate_limits.get('messages', '0:0,'), percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) diff --git a/tests/service/gateway/test_gateway_client.py b/tests/service/gateway/test_gateway_client.py index 5eb017d..d757a00 100644 --- a/tests/service/gateway/test_gateway_client.py +++ b/tests/service/gateway/test_gateway_client.py @@ -484,22 +484,5 @@ async def test_handle_rate_limit_response(): client._mqtt_manager.set_gateway_rate_limits_received.assert_called_once() -@pytest.mark.asyncio -async def test_handle_rate_limit_response_invalid_response(): - # Setup - client = GatewayClient() - - # Create an invalid response - response = RPCResponse.build(1, result="invalid") - - # Mock the parent class method - with patch('tb_mqtt_client.service.device.client.DeviceClient._handle_rate_limit_response', return_value=None): - # Act - result = await client._handle_rate_limit_response(response) - - # Assert - assert result is None - - if __name__ == '__main__': pytest.main([__file__]) From 93957dff0263d6017bfac7f492e355e97e5f71c9 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Tue, 5 Aug 2025 09:14:22 +0300 Subject: [PATCH 66/74] Refactiong, imports optimization, formatting --- examples/device/claim_device.py | 13 +- examples/device/client_provisioning.py | 3 +- examples/device/firmware_update.py | 4 +- examples/device/handle_attribute_updates.py | 4 +- examples/device/handle_rpc_requests.py | 5 +- examples/device/load.py | 13 +- examples/device/operational_example.py | 1 + examples/device/request_attributes.py | 6 +- examples/device/send_attributes.py | 4 +- examples/device/send_client_side_rpc.py | 4 +- examples/device/send_timeseries.py | 4 +- examples/device/tls_connect.py | 8 +- examples/gateway/DEPRECATEDtls_connect.py | 25 --- examples/gateway/claim_device.py | 6 +- examples/gateway/handle_rpc_requests.py | 7 +- examples/gateway/load.py | 19 +- examples/gateway/operational_example.py | 8 +- examples/gateway/request_attributes.py | 6 +- examples/gateway/send_timeseries.py | 4 +- examples/gateway/tls_connect.py | 9 +- tb_mqtt_client/common/__init__.py | 1 - tb_mqtt_client/common/async_utils.py | 16 +- tb_mqtt_client/common/config_loader.py | 6 +- tb_mqtt_client/common/gmqtt_patch.py | 16 +- tb_mqtt_client/common/logging_utils.py | 3 +- tb_mqtt_client/common/mqtt_message.py | 9 +- tb_mqtt_client/common/publish_result.py | 3 +- tb_mqtt_client/common/queue.py | 2 +- tb_mqtt_client/common/rate_limit/__init__.py | 1 - .../rate_limit/backpressure_controller.py | 7 +- .../common/rate_limit/rate_limit.py | 1 - tb_mqtt_client/constants/__init__.py | 1 - tb_mqtt_client/constants/json_typing.py | 3 +- tb_mqtt_client/constants/mqtt_topics.py | 1 + tb_mqtt_client/constants/service_keys.py | 1 - tb_mqtt_client/entities/__init__.py | 1 - tb_mqtt_client/entities/data/__init__.py | 1 - .../entities/data/attribute_request.py | 6 +- tb_mqtt_client/entities/data/claim_request.py | 2 +- 
.../entities/data/device_uplink_message.py | 14 +- .../entities/data/provisioning_response.py | 3 +- .../data/requested_attribute_response.py | 3 +- tb_mqtt_client/entities/data/rpc_response.py | 7 +- tb_mqtt_client/entities/gateway/__init__.py | 1 - .../entities/gateway/base_gateway_event.py | 1 + .../gateway/device_connect_message.py | 3 +- .../entities/gateway/device_info.py | 4 +- .../entities/gateway/device_session_state.py | 1 + tb_mqtt_client/entities/gateway/event_type.py | 1 + .../gateway/gateway_attribute_request.py | 21 +- .../gateway/gateway_attribute_update.py | 9 +- .../entities/gateway/gateway_claim_request.py | 14 +- .../gateway_requested_attribute_response.py | 4 +- .../entities/gateway/gateway_rpc_request.py | 10 +- .../entities/gateway/gateway_rpc_response.py | 9 +- .../gateway/gateway_uplink_message.py | 14 +- tb_mqtt_client/service/__init__.py | 1 - tb_mqtt_client/service/base_client.py | 24 +-- tb_mqtt_client/service/device/__init__.py | 1 - tb_mqtt_client/service/device/client.py | 37 ++-- .../service/device/firmware_updater.py | 4 +- .../service/device/handlers/__init__.py | 1 - .../requested_attributes_response_handler.py | 20 +- .../device/handlers/rpc_response_handler.py | 3 +- .../service/device/message_adapter.py | 26 ++- .../service/device/message_splitter.py | 11 +- tb_mqtt_client/service/gateway/__init__.py | 1 - tb_mqtt_client/service/gateway/client.py | 98 ++++++--- .../service/gateway/device_session.py | 23 +- .../gateway/direct_event_dispatcher.py | 3 +- .../gateway/gateway_client_interface.py | 35 +++- .../gateway_attribute_updates_handler.py | 4 +- ...y_requested_attributes_response_handler.py | 22 +- .../gateway/handlers/gateway_rpc_handler.py | 12 +- .../service/gateway/message_adapter.py | 46 ++-- .../service/gateway/message_sender.py | 33 ++- .../service/gateway/message_splitter.py | 7 +- tb_mqtt_client/service/message_service.py | 52 +++-- tb_mqtt_client/service/mqtt_manager.py | 61 +++--- tests/__init__.py | 1 - tests/common/test_async_utils.py | 20 ++ tests/common/test_config_loader.py | 3 +- tests/common/test_gmqtt_patch.py | 22 +- tests/common/test_publish_result.py | 6 +- tests/common/test_queue.py | 3 + tests/common/test_rate_limit.py | 1 + tests/constants/__init__.py | 1 - tests/constants/test_mqtt_topics.py | 2 +- tests/entities/gateway/test_device_info.py | 3 +- .../gateway/test_gateway_attribute_request.py | 3 +- .../gateway/test_gateway_attribute_update.py | 3 + tests/service/__init__.py | 1 - tests/service/device/__init__.py | 1 - .../test_attribute_updates_handler.py | 1 - .../handlers/test_rpc_requests_handler.py | 1 + .../handlers/test_rpc_response_handler.py | 1 + tests/service/device/test_device_client.py | 5 +- tests/service/device/test_firmware_updater.py | 15 +- tests/service/gateway/__init__.py | 1 - tests/service/gateway/handlers/__init__.py | 1 - .../test_gateway_attribute_updates_handler.py | 30 +-- ...y_requested_attributes_response_handler.py | 52 +++-- .../handlers/test_gateway_rpc_handler.py | 138 ++++++------ tests/service/gateway/test_device_manager.py | 114 +++++----- tests/service/gateway/test_device_session.py | 94 +++++---- .../gateway/test_direct_event_dispatcher.py | 82 ++++---- tests/service/gateway/test_gateway_client.py | 197 +++++++++--------- tests/service/gateway/test_message_adapter.py | 138 ++++++------ tests/service/gateway/test_message_sender.py | 103 ++++----- .../service/gateway/test_message_splitter.py | 97 +++++---- tests/service/test_json_message_adapter.py | 11 +- 
tests/service/test_message_service.py | 27 ++- tests/service/test_message_splitter.py | 4 +- tests/service/test_mqtt_manager.py | 18 +- 114 files changed, 1151 insertions(+), 927 deletions(-) delete mode 100644 examples/gateway/DEPRECATEDtls_connect.py diff --git a/examples/device/claim_device.py b/examples/device/claim_device.py index cc34c49..8c2acaf 100644 --- a/examples/device/claim_device.py +++ b/examples/device/claim_device.py @@ -18,18 +18,16 @@ import logging from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.entities.data.claim_request import ClaimRequest from tb_mqtt_client.service.device.client import DeviceClient -from tb_mqtt_client.common.logging_utils import configure_logging, get_logger - configure_logging() logger = get_logger(__name__) logger.setLevel(logging.INFO) logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) - # Constants for connection PLATFORM_HOST = "localhost" # Replace with your ThingsBoard host DEVICE_ACCESS_TOKEN = "YOUR_ACCESS_TOKEN" # Replace with your device's access token @@ -53,13 +51,18 @@ async def main(): claim_request = ClaimRequest.build(secret_key=CLAIMING_SECRET_KEY, duration=CLAIMING_DURATION) # Send claim request - result: PublishResult = await client.claim_device(claim_request, wait_for_publish=True, timeout=CLAIMING_DURATION + 10) + result: PublishResult = await client.claim_device(claim_request, + wait_for_publish=True, + timeout=CLAIMING_DURATION + 10) if result.is_successful(): - logger.info(f"Claiming request was sent successfully. Please use the secret key '{CLAIMING_SECRET_KEY}' to claim the device from the dashboard.") + logger.info( + f"Claiming request was sent successfully. " + f"Please use the secret key '{CLAIMING_SECRET_KEY}' to claim the device from the dashboard.") else: logger.error(f"Failed to send claiming request. 
Result: {result}") await client.stop() + if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/device/client_provisioning.py b/examples/device/client_provisioning.py index 240e0f4..7c7b5ae 100644 --- a/examples/device/client_provisioning.py +++ b/examples/device/client_provisioning.py @@ -18,11 +18,10 @@ import logging from random import randint +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger from tb_mqtt_client.entities.data.provisioning_request import AccessTokenProvisioningCredentials, ProvisioningRequest from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.service.device.client import DeviceClient -from tb_mqtt_client.common.logging_utils import configure_logging, get_logger - configure_logging() logger = get_logger(__name__) diff --git a/examples/device/firmware_update.py b/examples/device/firmware_update.py index 34df722..d9e7ae6 100644 --- a/examples/device/firmware_update.py +++ b/examples/device/firmware_update.py @@ -19,16 +19,14 @@ from time import monotonic from tb_mqtt_client.common.config_loader import DeviceConfig -from tb_mqtt_client.service.device.client import DeviceClient from tb_mqtt_client.common.logging_utils import configure_logging, get_logger - +from tb_mqtt_client.service.device.client import DeviceClient configure_logging() logger = get_logger(__name__) logger.setLevel(logging.INFO) logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) - firmware_received = asyncio.Event() firmware_update_timeout = 30 diff --git a/examples/device/handle_attribute_updates.py b/examples/device/handle_attribute_updates.py index b1ced19..145d23c 100644 --- a/examples/device/handle_attribute_updates.py +++ b/examples/device/handle_attribute_updates.py @@ -18,10 +18,9 @@ import logging from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate from tb_mqtt_client.service.device.client import DeviceClient -from tb_mqtt_client.common.logging_utils import configure_logging, get_logger - configure_logging() logger = get_logger(__name__) @@ -52,5 +51,6 @@ async def main(): await client.stop() + if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/device/handle_rpc_requests.py b/examples/device/handle_rpc_requests.py index 584feb8..f76d1e0 100644 --- a/examples/device/handle_rpc_requests.py +++ b/examples/device/handle_rpc_requests.py @@ -18,11 +18,10 @@ import logging from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse from tb_mqtt_client.service.device.client import DeviceClient -from tb_mqtt_client.common.logging_utils import configure_logging, get_logger - configure_logging() logger = get_logger(__name__) @@ -38,6 +37,7 @@ async def rpc_request_callback(request: RPCRequest) -> RPCResponse: else: return RPCResponse.build(request_id=request.request_id, result={"message": "Unknown method"}) + async def main(): config = DeviceConfig() config.host = "localhost" @@ -56,5 +56,6 @@ async def main(): await client.stop() + if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/device/load.py b/examples/device/load.py index 2c739f1..d846859 100644 --- a/examples/device/load.py +++ b/examples/device/load.py @@ 
-6,13 +6,13 @@ from random import randint from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger from tb_mqtt_client.common.publish_result import PublishResult -from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.service.device.client import DeviceClient -from tb_mqtt_client.common.logging_utils import configure_logging, get_logger # --- Logging setup --- configure_logging() @@ -21,6 +21,8 @@ logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) # --- Constants --- +THINGSBOARD_HOST = "localhost" +ACCESS_TOKEN = "YOUR_ACCESS_TOKEN" # Replace with your actual access token BATCH_SIZE = 1000 MAX_PENDING_BATCHES = 100 FUTURE_TIMEOUT = 1.0 @@ -115,8 +117,8 @@ def _shutdown_handler(): signal.signal(sig, lambda *_: _shutdown_handler()) config = DeviceConfig() - config.host = "localhost" - config.access_token = "YOUR_ACCESS_TOKEN" + config.host = THINGSBOARD_HOST + config.access_token = ACCESS_TOKEN client = DeviceClient(config) client.set_attribute_update_callback(attribute_update_callback) @@ -136,7 +138,8 @@ def _shutdown_handler(): try: delivery_start_ts = time.perf_counter() - ts_now = int(datetime.now(UTC).timestamp() * 1000) + # ts_now = int(datetime.now(UTC).timestamp() * 1000) + # You can add ts=ts_now to try with entities with ts, it will group messages by ts and send them together entries = [TimeseriesEntry(f"temperature{i}", randint(20, 40)) for i in range(BATCH_SIZE)] while not stop_event.is_set(): diff --git a/examples/device/operational_example.py b/examples/device/operational_example.py index 18f8d2d..57c1994 100644 --- a/examples/device/operational_example.py +++ b/examples/device/operational_example.py @@ -68,6 +68,7 @@ async def rpc_request_callback(request: RPCRequest) -> RPCResponse: response = RPCResponse.build(request_id=request.request_id, result=response_data) return response + async def rpc_response_callback(response: RPCResponse): """ Callback function to handle RPC responses for client side RPC requests. 
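For readers following the examples touched by this refactoring commit, the operational example above pairs two callbacks: one that answers server-side RPC requests (rpc_request_callback) and one that consumes replies to client-side RPC requests (rpc_response_callback). Below is a minimal, self-contained sketch of how such callbacks could be wired onto a DeviceClient. It is illustrative only: the registration method names set_rpc_request_callback and set_rpc_response_callback are assumptions modelled on the gateway-side device_manager API seen elsewhere in this patch series, and the connect() call is assumed to mirror the GatewayClient examples; DeviceClient, DeviceConfig, RPCRequest, RPCResponse.build and stop() do appear verbatim in the surrounding diffs.

import asyncio

from tb_mqtt_client.common.config_loader import DeviceConfig
from tb_mqtt_client.common.logging_utils import configure_logging, get_logger
from tb_mqtt_client.entities.data.rpc_request import RPCRequest
from tb_mqtt_client.entities.data.rpc_response import RPCResponse
from tb_mqtt_client.service.device.client import DeviceClient

configure_logging()
logger = get_logger(__name__)


async def rpc_request_callback(request: RPCRequest) -> RPCResponse:
    # Answer any server-side RPC with a fixed acknowledgement payload.
    return RPCResponse.build(request_id=request.request_id, result={"status": "ok"})


async def rpc_response_callback(response: RPCResponse):
    # Called when the platform answers a client-side RPC issued by this device.
    logger.info("Client-side RPC answered with: %r", response.result)


async def main():
    config = DeviceConfig()
    config.host = "localhost"                  # replace with your ThingsBoard host
    config.access_token = "YOUR_ACCESS_TOKEN"  # replace with your device access token

    client = DeviceClient(config)
    client.set_rpc_request_callback(rpc_request_callback)    # assumed method name
    client.set_rpc_response_callback(rpc_response_callback)  # assumed method name
    await client.connect()                                    # assumed to mirror GatewayClient.connect()

    await asyncio.sleep(10)  # keep the session alive long enough to exchange RPCs
    await client.stop()


if __name__ == "__main__":
    asyncio.run(main())

The same pattern scales to the gateway examples that follow, where the callbacks are registered per device session via client.device_manager rather than directly on the client.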
diff --git a/examples/device/request_attributes.py b/examples/device/request_attributes.py index 32f32ee..794aae3 100644 --- a/examples/device/request_attributes.py +++ b/examples/device/request_attributes.py @@ -18,11 +18,10 @@ import logging from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger from tb_mqtt_client.entities.data.attribute_request import AttributeRequest from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse from tb_mqtt_client.service.device.client import DeviceClient -from tb_mqtt_client.common.logging_utils import configure_logging, get_logger - configure_logging() logger = get_logger(__name__) @@ -31,10 +30,12 @@ response_received = asyncio.Event() + async def attribute_request_callback(response: RequestedAttributeResponse): logger.info("Received attribute response: %r", response) response_received.set() + async def main(): config = DeviceConfig() config.host = "localhost" @@ -60,5 +61,6 @@ async def main(): await client.stop() + if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/device/send_attributes.py b/examples/device/send_attributes.py index 76a779f..6144089 100644 --- a/examples/device/send_attributes.py +++ b/examples/device/send_attributes.py @@ -18,10 +18,9 @@ import logging from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.service.device.client import DeviceClient -from tb_mqtt_client.common.logging_utils import configure_logging, get_logger - configure_logging() logger = get_logger(__name__) @@ -63,5 +62,6 @@ async def main(): await client.stop() + if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/device/send_client_side_rpc.py b/examples/device/send_client_side_rpc.py index f652935..2e2d4fc 100644 --- a/examples/device/send_client_side_rpc.py +++ b/examples/device/send_client_side_rpc.py @@ -18,11 +18,10 @@ import logging from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse from tb_mqtt_client.service.device.client import DeviceClient -from tb_mqtt_client.common.logging_utils import configure_logging, get_logger - configure_logging() logger = get_logger(__name__) @@ -57,6 +56,7 @@ async def main(): await asyncio.sleep(5) await client.stop() + if __name__ == "__main__": try: asyncio.run(main()) diff --git a/examples/device/send_timeseries.py b/examples/device/send_timeseries.py index 4528fb1..5c5c5f9 100644 --- a/examples/device/send_timeseries.py +++ b/examples/device/send_timeseries.py @@ -20,10 +20,9 @@ from time import time from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.service.device.client import DeviceClient -from tb_mqtt_client.common.logging_utils import configure_logging, get_logger - configure_logging() logger = get_logger(__name__) @@ -68,5 +67,6 @@ async def main(): await client.stop() + if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/device/tls_connect.py 
b/examples/device/tls_connect.py index a0370a0..aaa99ca 100644 --- a/examples/device/tls_connect.py +++ b/examples/device/tls_connect.py @@ -19,26 +19,23 @@ from random import randint from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.service.device.client import DeviceClient -from tb_mqtt_client.common.logging_utils import configure_logging, get_logger - configure_logging() logger = get_logger(__name__) logger.setLevel(logging.INFO) logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) - PLATFORM_HOST = 'localhost' # Update with your ThingsBoard host PLATFORM_PORT = 8883 # Default port for MQTT over SSL - # Update with your CA certificate, client certificate, and client key paths. There are no default files generated. # You can generate them using the following guides: # Certificates for server - https://thingsboard.io/docs/user-guide/mqtt-over-ssl/ # Certificates for client - https://thingsboard.io/docs/user-guide/certificates/?ubuntuThingsboardX509=X509Leaf -CA_CERT_PATH = "mqttserver.pem" # Update with your CA certificate path (Default - mqttserver.pem in the examples directory) +CA_CERT_PATH = "ca_cert.pem" # Update with your CA certificate path (Default - ca_cert.pem in the examples directory) CLIENT_CERT_PATH = "cert.pem" # Update with your client certificate path (Default - cert.pem in the examples directory) CLIENT_KEY_PATH = "key.pem" # Update with your client key path (Default - key.pem in the examples directory) @@ -66,5 +63,6 @@ async def main(): await client.stop() + if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/gateway/DEPRECATEDtls_connect.py b/examples/gateway/DEPRECATEDtls_connect.py deleted file mode 100644 index 4315508..0000000 --- a/examples/gateway/DEPRECATEDtls_connect.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2025 ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -from tb_gateway_mqtt import TBGatewayMqttClient -import socket - -logging.basicConfig(level=logging.DEBUG) -# connecting to localhost -gateway = TBGatewayMqttClient(socket.gethostname()) -gateway.connect(tls=True, - ca_certs="mqttserver.pub.pem", - cert_file="mqttclient.nopass.pem") -gateway.disconnect() diff --git a/examples/gateway/claim_device.py b/examples/gateway/claim_device.py index 8c5ee0c..3632c7c 100644 --- a/examples/gateway/claim_device.py +++ b/examples/gateway/claim_device.py @@ -64,8 +64,10 @@ async def main(): logger.error("Failed to send claim request for device: %s", device_name) return - logger.info("Claim request sent successfully for device: %s, you have %r seconds to claim device using ThingsBoard UI or API.", - device_name, CLAIMING_DURATION_MS / 1000) + logger.info( + "Claim request sent successfully for device: %s, " + "you have %r seconds to claim device using ThingsBoard UI or API.", + device_name, CLAIMING_DURATION_MS / 1000) # Disconnect device logger.info("Disconnecting device: %s", device_name) diff --git a/examples/gateway/handle_rpc_requests.py b/examples/gateway/handle_rpc_requests.py index c711801..1b6135f 100644 --- a/examples/gateway/handle_rpc_requests.py +++ b/examples/gateway/handle_rpc_requests.py @@ -25,7 +25,8 @@ logger = get_logger(__name__) -async def device_rpc_request_handler(device_session: DeviceSession, rpc_request: GatewayRPCRequest) -> GatewayRPCResponse: +async def device_rpc_request_handler(device_session: DeviceSession, + rpc_request: GatewayRPCRequest) -> GatewayRPCResponse: """ Callback to handle RPC requests from the device. :param device_session: Device session for which the request was made. @@ -44,7 +45,9 @@ async def device_rpc_request_handler(device_session: DeviceSession, rpc_request: } } - rpc_response = GatewayRPCResponse.build(device_session.device_info.device_name, rpc_request.request_id, response_data) + rpc_response = GatewayRPCResponse.build(device_session.device_info.device_name, + rpc_request.request_id, + response_data) logger.info("Sending RPC response for request id %r: %r", rpc_request.request_id, rpc_response) diff --git a/examples/gateway/load.py b/examples/gateway/load.py index 573b5db..e4972d1 100644 --- a/examples/gateway/load.py +++ b/examples/gateway/load.py @@ -21,11 +21,11 @@ from typing import List from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.service.gateway.client import GatewayClient from tb_mqtt_client.service.gateway.device_session import DeviceSession -from tb_mqtt_client.common.logging_utils import configure_logging, get_logger -from tb_mqtt_client.common.publish_result import PublishResult # --- Logging --- configure_logging() @@ -34,12 +34,16 @@ logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) # --- Constants --- +THINGSBOARD_HOST = "localhost" +ACCESS_TOKEN = "YOUR_ACCESS_TOKEN" # Replace with your actual access token NUM_DEVICES = 10 BATCH_SIZE = 1000 MAX_PENDING = 10 FUTURE_TIMEOUT = 1.0 DEVICE_PREFIX = "perf-test-device" -WAIT_FOR_PUBLISH = False # Set to True if you want to wait for publish confirmation (can slow down the test, because each message will wait for confirmation) +WAIT_FOR_PUBLISH = False # Set to True if you want to wait for publish confirmation +# (can slow down the test, because each message will 
wait for confirmation) + # --- Test logic --- async def send_batch(client: GatewayClient, session: DeviceSession) -> List[asyncio.Future]: @@ -52,6 +56,7 @@ async def send_batch(client: GatewayClient, session: DeviceSession) -> List[asyn return [result] return [] + async def wait_for_futures(futures: List[asyncio.Future]) -> int: delivered = 0 if isinstance(futures, list) and futures and isinstance(futures[0], asyncio.Future): @@ -87,8 +92,8 @@ def _shutdown(): signal.signal(sig, lambda *_: _shutdown()) config = GatewayConfig() - config.host = "localhost" - config.access_token = "YOUR_ACCESS_TOKEN" + config.host = THINGSBOARD_HOST + config.access_token = ACCESS_TOKEN client = GatewayClient(config) await client.connect() @@ -113,8 +118,8 @@ def _shutdown(): try: while not stop_event.is_set(): for session in sessions: - futs = await send_batch(client, session) - pending_futures.extend(futs) + futures = await send_batch(client, session) + pending_futures.extend(futures) sent_batches += 1 if len(pending_futures) >= MAX_PENDING: diff --git a/examples/gateway/operational_example.py b/examples/gateway/operational_example.py index 927a888..a15c9c7 100644 --- a/examples/gateway/operational_example.py +++ b/examples/gateway/operational_example.py @@ -57,7 +57,8 @@ async def requested_attributes_handler(device_session: DeviceSession, response: response.shared) -async def device_rpc_request_handler(device_session: DeviceSession, rpc_request: GatewayRPCRequest) -> GatewayRPCResponse: +async def device_rpc_request_handler(device_session: DeviceSession, + rpc_request: GatewayRPCRequest) -> GatewayRPCResponse: logger.info("Received RPC request for %s: %r", device_session.device_info.device_name, rpc_request) response_data = { "status": "success", @@ -111,7 +112,8 @@ def _shutdown_handler(): # Register callbacks for each device client.device_manager.set_attribute_update_callback(session.device_info.device_id, attribute_update_handler) - client.device_manager.set_attribute_response_callback(session.device_info.device_id, requested_attributes_handler) + client.device_manager.set_attribute_response_callback(session.device_info.device_id, + requested_attributes_handler) client.device_manager.set_rpc_request_callback(session.device_info.device_id, device_rpc_request_handler) # ---- Main loop ---- @@ -173,4 +175,4 @@ def _shutdown_handler(): try: asyncio.run(main()) except KeyboardInterrupt: - logger.info("Interrupted by user.") \ No newline at end of file + logger.info("Interrupted by user.") diff --git a/examples/gateway/request_attributes.py b/examples/gateway/request_attributes.py index eb366bb..2310129 100644 --- a/examples/gateway/request_attributes.py +++ b/examples/gateway/request_attributes.py @@ -58,7 +58,8 @@ async def main(): return # Register callback for requested attributes - client.device_manager.set_attribute_response_callback(device_session.device_info.device_id, requested_attributes_handler) + client.device_manager.set_attribute_response_callback(device_session.device_info.device_id, + requested_attributes_handler) logger.info("Device connected successfully: %s", device_name) @@ -77,7 +78,8 @@ async def main(): # Request attributes for the device logger.info("Requesting attributes for device: %s", device_name) attributes_to_request = ["maintenance", "id", "location"] - attribute_request = await GatewayAttributeRequest.build(device_session=device_session, client_keys=attributes_to_request) + attribute_request = await GatewayAttributeRequest.build(device_session=device_session, + 
client_keys=attributes_to_request) await client.send_device_attributes_request(device_session, attribute_request, wait_for_publish=True) diff --git a/examples/gateway/send_timeseries.py b/examples/gateway/send_timeseries.py index dff7719..ebc1b0a 100644 --- a/examples/gateway/send_timeseries.py +++ b/examples/gateway/send_timeseries.py @@ -90,10 +90,10 @@ def _shutdown_handler(): TimeseriesEntry(key="temperature%i" % i, value=loop_counter, ts=ts) for i in range(20) ] logger.info("Sending TimeseriesEntry objects: %s", timeseries_entries) - await client.send_device_timeseries(device_session=device_session, data=timeseries_entries, wait_for_publish=True) + await client.send_device_timeseries(device_session=device_session, data=timeseries_entries, + wait_for_publish=True) logger.info("TimeseriesEntry objects sent successfully.") - try: logger.info("Waiting before next iteration...") await asyncio.wait_for(stop_event.wait(), timeout=1) diff --git a/examples/gateway/tls_connect.py b/examples/gateway/tls_connect.py index 1bfac9e..dfb3ffe 100644 --- a/examples/gateway/tls_connect.py +++ b/examples/gateway/tls_connect.py @@ -29,16 +29,14 @@ logger.setLevel(logging.INFO) logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) - PLATFORM_HOST = 'localhost' # Update with your ThingsBoard host PLATFORM_PORT = 8883 # Default port for MQTT over SSL - # Update with your CA certificate, client certificate, and client key paths. There are no default files generated. # You can generate them using the following guides: # Certificates for server - https://thingsboard.io/docs/user-guide/mqtt-over-ssl/ # Certificates for client - https://thingsboard.io/docs/user-guide/certificates/?ubuntuThingsboardX509=X509Leaf -CA_CERT_PATH = "mqttserver.pem" # Update with your CA certificate path (Default - mqttserver.pem in the examples directory) +CA_CERT_PATH = "ca_cert.pem" # Update with your CA certificate path (Default - ca_cert.pem in the examples directory) CLIENT_CERT_PATH = "cert.pem" # Update with your client certificate path (Default - cert.pem in the examples directory) CLIENT_KEY_PATH = "key.pem" # Update with your client key path (Default - key.pem in the examples directory) @@ -63,8 +61,8 @@ async def main(): # Sending telemetry data to the connected device list_timeseries = [ - TimeseriesEntry(key="temperature", value=random.randint(20, 35)), - TimeseriesEntry(key="humidity", value=random.randint(40, 80)) + TimeseriesEntry(key="temperature", value=random.randint(20, 35)), + TimeseriesEntry(key="humidity", value=random.randint(40, 80)) ] logger.info("Sending list of timeseries: %s", list_timeseries) await client.send_device_timeseries(device_session=device_session, data=list_timeseries, wait_for_publish=True) @@ -72,5 +70,6 @@ async def main(): await client.stop() + if __name__ == "__main__": asyncio.run(main()) diff --git a/tb_mqtt_client/common/__init__.py b/tb_mqtt_client/common/__init__.py index fa669aa..cff354e 100644 --- a/tb_mqtt_client/common/__init__.py +++ b/tb_mqtt_client/common/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tb_mqtt_client/common/async_utils.py b/tb_mqtt_client/common/async_utils.py index 912a9ee..70cf553 100644 --- a/tb_mqtt_client/common/async_utils.py +++ b/tb_mqtt_client/common/async_utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. + import asyncio import threading from typing import Union, Optional, Any, List, Set, Dict @@ -59,8 +60,10 @@ def child_resolved(self, child: asyncio.Future): parent.set_result(None) self._parent_to_remaining.pop(parent, None) + future_map = FutureMap() + async def await_or_stop(future_or_coroutine: Union[asyncio.Future, asyncio.Task, Any], stop_event: asyncio.Event, timeout: Optional[float]) -> Optional[Any]: @@ -97,10 +100,11 @@ async def await_or_stop(future_or_coroutine: Union[asyncio.Future, asyncio.Task, if not stop_task.done(): stop_task.cancel() -def run_coroutine_sync(coro_func, timeout: float = 3.0, raise_on_timeout: bool = False): + +def run_coroutine_sync(coroutine, timeout: float = 3.0, raise_on_timeout: bool = False): """ Run async coroutine and return its result from a sync function even if event loop is running. - :param coro_func: async function with no arguments (like: lambda: some_async_fn()) + :param coroutine: async function with no arguments (like: lambda: some_async_fn()) :param timeout: max wait time in seconds :param raise_on_timeout: if True, raise TimeoutError on timeout; otherwise return None """ @@ -109,7 +113,7 @@ def run_coroutine_sync(coro_func, timeout: float = 3.0, raise_on_timeout: bool = async def wrapper(): try: - result = await coro_func() + result = await coroutine() result_container['result'] = result except Exception as e: result_container['error'] = e @@ -122,12 +126,12 @@ async def wrapper(): completed = event.wait(timeout=timeout) if not completed: - logger.warning("Timeout while waiting for coroutine to finish: %s", coro_func) + logger.warning("Timeout while waiting for coroutine to finish: %s", coroutine) if raise_on_timeout: - raise TimeoutError(f"Coroutine {coro_func} did not complete in {timeout} seconds.") + raise TimeoutError(f"Coroutine {coroutine} did not complete in {timeout} seconds.") return None if 'error' in result_container: raise result_container['error'] - return result_container.get('result') \ No newline at end of file + return result_container.get('result') diff --git a/tb_mqtt_client/common/config_loader.py b/tb_mqtt_client/common/config_loader.py index 75d8c8f..530ba5a 100644 --- a/tb_mqtt_client/common/config_loader.py +++ b/tb_mqtt_client/common/config_loader.py @@ -22,7 +22,8 @@ class DeviceConfig: This class loads configuration options from environment variables, allowing for flexible deployment and easy customization of device connection settings. """ - def __init__(self, config = None): + + def __init__(self, config=None): if config is not None: self.host: str = config.get("host", "localhost") self.port: int = config.get("port", 1883) @@ -71,6 +72,7 @@ class GatewayConfig(DeviceConfig): Configuration class for ThingsBoard gateway clients. This class extends DeviceConfig to include additional options specific to gateways. 
""" + def __init__(self, config=None): # TODO: REFACTOR, temporary solution for development super().__init__(config) @@ -100,8 +102,6 @@ def __init__(self, config=None): if os.getenv("TB_GW_QOS") is not None: self.qos: int = int(os.getenv("TB_GW_QOS", 1)) - - def __repr__(self): return (f"GatewayConfig(host={self.host}, port={self.port}, " f"auth={'token' if self.access_token else 'user/pass'} " diff --git a/tb_mqtt_client/common/gmqtt_patch.py b/tb_mqtt_client/common/gmqtt_patch.py index 592f712..513df8c 100644 --- a/tb_mqtt_client/common/gmqtt_patch.py +++ b/tb_mqtt_client/common/gmqtt_patch.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import asyncio import heapq import struct @@ -50,7 +49,8 @@ def build_package(cls, message: MqttPublishMessage, protocol, mid: int = None) - if message.payload_size == 0: logger.debug("Sending PUBLISH (q%d), '%s' (NULL payload)", message.qos, message.topic) else: - logger.debug("Sending PUBLISH (q%d), '%s', ... (%d bytes)", message.qos, message.topic, message.payload_size) + logger.debug("Sending PUBLISH (q%d), '%s', ... (%d bytes)", message.qos, message.topic, + message.payload_size) if message.qos > 0: remaining_length += 2 @@ -141,7 +141,6 @@ def parse_mqtt_properties(packet: bytes) -> dict: return dict(properties_dict) - @staticmethod def extract_reason_code(packet): """ @@ -271,13 +270,13 @@ def patch_gmqtt_protocol_connection_lost(patch_utils_instance): and pass the exception to the handler. """ try: - original_base_connection_lost = BaseMQTTProtocol.connection_lost + def patched_base_connection_lost(self, exc): self._connected.clear() super(BaseMQTTProtocol, self).connection_lost(exc) + BaseMQTTProtocol.connection_lost = patched_base_connection_lost - original_mqtt_connection_lost = MQTTProtocol.connection_lost def patched_mqtt_connection_lost(self, exc): super(MQTTProtocol, self).connection_lost(exc) reason_code = 0 @@ -318,6 +317,7 @@ def patched_mqtt_connection_lost(self, exc): MQTTProtocol.connection_lost = patched_mqtt_connection_lost original_call = MqttPackageHandler.__call__ + def patched_call(self, cmd, packet): try: if cmd == MQTTCommands.DISCONNECT and hasattr(self._connection, '_disconnect_exc'): @@ -351,6 +351,7 @@ def patched_call(self, cmd, packet): def patch_puback_handling(self, on_puback_with_reason_and_properties: Callable[[int, int, dict], None]): original_handler = MqttPackageHandler._handle_puback_packet + def wrapped_handle_puback(self, cmd, packet): try: mid = struct.unpack("!H", packet[:2])[0] @@ -368,6 +369,7 @@ def wrapped_handle_puback(self, cmd, packet): except Exception as e: logger.exception("Error while handling PUBACK with properties: %s", e) return original_handler(self, cmd, packet) + MqttPackageHandler._handle_puback_packet = wrapped_handle_puback logger.debug("Patched _handle_puback_packet for QoS1 support.") @@ -377,6 +379,7 @@ async def pop_message_with_tm(): patch_utils_instance.client._persistent_storage._check_empty() return tm, mid, raw_package + patch_utils_instance.client._persistent_storage.pop_message = pop_message_with_tm async def _retry_loop(self): @@ -410,7 +413,8 @@ async def _retry_loop(self): logger.error("Resending PUBLISH message with mid=%r, topic=%s", mid, mqtt_msg.topic) try: - await self.client.put_retry_message(mqtt_msg) # noqa This method sets in message service to the client + await self.client.put_retry_message( + mqtt_msg) # noqa This method sets in message service to the client except AttributeError as e: 
logger.trace("Failed to resend message with mid=%r: %s", mid, e) diff --git a/tb_mqtt_client/common/logging_utils.py b/tb_mqtt_client/common/logging_utils.py index 5ff651a..698bd5e 100644 --- a/tb_mqtt_client/common/logging_utils.py +++ b/tb_mqtt_client/common/logging_utils.py @@ -16,7 +16,6 @@ import sys from typing import Optional - DEFAULT_LOG_FORMAT = "[%(asctime)s.%(msecs)03d] [%(levelname)s] %(name)s - %(lineno)d - %(message)s" DEFAULT_DATE_FORMAT = "%Y-%m-%d %H:%M:%S" @@ -28,10 +27,12 @@ class ExtendedLogger(logging.Logger): """ Custom logger class that supports TRACE level logging. """ + def trace(self, message, *args, **kwargs): if self.isEnabledFor(TRACE_LEVEL): self._log(TRACE_LEVEL, message, args, **kwargs) + logging.setLoggerClass(ExtendedLogger) diff --git a/tb_mqtt_client/common/mqtt_message.py b/tb_mqtt_client/common/mqtt_message.py index 631900a..5a40b7a 100644 --- a/tb_mqtt_client/common/mqtt_message.py +++ b/tb_mqtt_client/common/mqtt_message.py @@ -32,15 +32,16 @@ class MqttPublishMessage(Message): A custom Publish MQTT message class that extends the gmqtt Message class. Contains additional information like datapoints, to avoid rate limits exceeding. """ + def __init__(self, topic: str, payload: Union[bytes, GatewayUplinkMessage, DeviceUplinkMessage], qos: int = 1, retain: bool = False, datapoints: int = 0, - delivery_futures = None, + delivery_futures=None, main_ts: Optional[int] = None, - original_payload = None, + original_payload=None, **kwargs): """ Initialize the MqttMessage with topic, payload, QoS, retain flag, and datapoints. @@ -58,7 +59,7 @@ def __init__(self, self.topic = topic self.is_service_message = self.topic not in mqtt_topics.TOPICS_WITH_DATAPOINTS_CHECK self.is_device_message = not isinstance(original_payload, GatewayUplinkMessage) - self.qos = qos + self.qos = qos if qos is not None else 1 if self.qos < 0 or self.qos > 1: logger.warning(f"Invalid QoS {self.qos} for topic {topic}, using default QoS 1") self.qos = 1 @@ -79,7 +80,7 @@ def __init__(self, if not hasattr(future, 'uuid'): future.uuid = uuid4() logger.trace(f"Created MqttMessage with topic: {topic}, payload type: {type(payload).__name__}, " - f"datapoints: {datapoints}, delivery_future id: {self.delivery_futures[0].uuid}") + f"datapoints: {datapoints}, delivery_future id: {self.delivery_futures[0].uuid}") def mark_as_sent(self, message_id: int): """Mark the message as sent.""" diff --git a/tb_mqtt_client/common/publish_result.py b/tb_mqtt_client/common/publish_result.py index 429db0f..cda1589 100644 --- a/tb_mqtt_client/common/publish_result.py +++ b/tb_mqtt_client/common/publish_result.py @@ -19,7 +19,7 @@ class PublishResult: def __init__(self, topic: str, qos: int, message_id: int, payload_size: int, reason_code: int, datapoints_count: int = 0): self.topic = topic - self.qos = qos + self.qos = qos if qos is not None else 1 self.message_id = message_id self.payload_size = payload_size self.reason_code = reason_code @@ -59,7 +59,6 @@ def is_successful(self) -> bool: """ return self.reason_code == 0 - @staticmethod def merge(results: List['PublishResult']) -> 'PublishResult': if not results: diff --git a/tb_mqtt_client/common/queue.py b/tb_mqtt_client/common/queue.py index 3bc951a..edd59bb 100644 --- a/tb_mqtt_client/common/queue.py +++ b/tb_mqtt_client/common/queue.py @@ -94,4 +94,4 @@ def is_empty(self): return not self._deque def size(self): - return len(self._deque) \ No newline at end of file + return len(self._deque) diff --git a/tb_mqtt_client/common/rate_limit/__init__.py 
b/tb_mqtt_client/common/rate_limit/__init__.py index fa669aa..cff354e 100644 --- a/tb_mqtt_client/common/rate_limit/__init__.py +++ b/tb_mqtt_client/common/rate_limit/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tb_mqtt_client/common/rate_limit/backpressure_controller.py b/tb_mqtt_client/common/rate_limit/backpressure_controller.py index 639329d..76450cd 100644 --- a/tb_mqtt_client/common/rate_limit/backpressure_controller.py +++ b/tb_mqtt_client/common/rate_limit/backpressure_controller.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import asyncio from asyncio import Event from datetime import datetime, timedelta, UTC @@ -33,7 +32,7 @@ def __init__(self, main_stop_event: Event): self._max_backoff_seconds = 3600 # 1 hour self._can_process_messages_events: List[asyncio.Event] = [] logger.debug("BackpressureController initialized with default pause duration of %s seconds", - self._default_pause_duration.total_seconds()) + self._default_pause_duration.total_seconds()) def notify_quota_exceeded(self, delay_seconds: Optional[int] = None): if self.__main_stop_event.is_set(): @@ -57,8 +56,8 @@ def notify_quota_exceeded(self, delay_seconds: Optional[int] = None): # Cap at max backoff delay_seconds = min(delay_seconds, self._max_backoff_seconds) - logger.warning("Applying backpressure for %d seconds (consecutive quota exceeded: %d)", - delay_seconds, self._consecutive_quota_exceeded) + logger.warning("Applying backpressure for %d seconds (consecutive quota exceeded: %d)", + delay_seconds, self._consecutive_quota_exceeded) duration = timedelta(seconds=delay_seconds) self._pause_until = now + duration diff --git a/tb_mqtt_client/common/rate_limit/rate_limit.py b/tb_mqtt_client/common/rate_limit/rate_limit.py index cb7f02b..b597eff 100644 --- a/tb_mqtt_client/common/rate_limit/rate_limit.py +++ b/tb_mqtt_client/common/rate_limit/rate_limit.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import asyncio import logging import os diff --git a/tb_mqtt_client/constants/__init__.py b/tb_mqtt_client/constants/__init__.py index fa669aa..cff354e 100644 --- a/tb_mqtt_client/constants/__init__.py +++ b/tb_mqtt_client/constants/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tb_mqtt_client/constants/json_typing.py b/tb_mqtt_client/constants/json_typing.py index 3e98ff7..e66b8fb 100644 --- a/tb_mqtt_client/constants/json_typing.py +++ b/tb_mqtt_client/constants/json_typing.py @@ -17,6 +17,7 @@ JSONPrimitive = Union[str, int, float, bool, None] JSONCompatibleType = Union[JSONPrimitive, List["JSONType"], Dict[str, "JSONType"]] + def validate_json_compatibility(value: object) -> None: """ Validates that the input value is fully JSON-compatible in structure and type. 
@@ -40,4 +41,4 @@ def validate_json_compatibility(value: object) -> None: raise ValueError(f"Invalid JSON key at {path}: expected str, got {type(k).__name__} ({k!r})") stack.append((v, f"{path}.{k}")) else: - raise ValueError(f"Invalid JSON value at {path}: unsupported - {type(current).__name__} ({current!r})") \ No newline at end of file + raise ValueError(f"Invalid JSON value at {path}: unsupported - {type(current).__name__} ({current!r})") diff --git a/tb_mqtt_client/constants/mqtt_topics.py b/tb_mqtt_client/constants/mqtt_topics.py index b596e4e..aca0cca 100644 --- a/tb_mqtt_client/constants/mqtt_topics.py +++ b/tb_mqtt_client/constants/mqtt_topics.py @@ -66,6 +66,7 @@ GATEWAY_CLAIM_TOPIC }) + # Topic Builders diff --git a/tb_mqtt_client/constants/service_keys.py b/tb_mqtt_client/constants/service_keys.py index 5a547ad..e5351a8 100644 --- a/tb_mqtt_client/constants/service_keys.py +++ b/tb_mqtt_client/constants/service_keys.py @@ -12,6 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. - TELEMETRY_TIMESTAMP_PARAMETER = "ts" TELEMETRY_VALUES_PARAMETER = "values" diff --git a/tb_mqtt_client/entities/__init__.py b/tb_mqtt_client/entities/__init__.py index fa669aa..cff354e 100644 --- a/tb_mqtt_client/entities/__init__.py +++ b/tb_mqtt_client/entities/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tb_mqtt_client/entities/data/__init__.py b/tb_mqtt_client/entities/data/__init__.py index fa669aa..cff354e 100644 --- a/tb_mqtt_client/entities/data/__init__.py +++ b/tb_mqtt_client/entities/data/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tb_mqtt_client/entities/data/attribute_request.py b/tb_mqtt_client/entities/data/attribute_request.py index 3b6fb9b..5434312 100644 --- a/tb_mqtt_client/entities/data/attribute_request.py +++ b/tb_mqtt_client/entities/data/attribute_request.py @@ -33,13 +33,15 @@ class AttributeRequest(BaseGatewayEvent): event_type: GatewayEventType = GatewayEventType.DEVICE_ATTRIBUTE_REQUEST def __new__(self, *args, **kwargs): - raise TypeError("Direct instantiation of AttributeRequest is not allowed. Use 'await AttributeRequest.build(...)'.") + raise TypeError( + "Direct instantiation of AttributeRequest is not allowed. Use 'await AttributeRequest.build(...)'.") def __repr__(self) -> str: return f"AttributeRequest(id={self.request_id}, shared_keys={self.shared_keys}, client_keys={self.client_keys})" @classmethod - async def build(cls, shared_keys: Optional[List[str]] = None, client_keys: Optional[List[str]] = None) -> 'AttributeRequest': + async def build(cls, shared_keys: Optional[List[str]] = None, + client_keys: Optional[List[str]] = None) -> 'AttributeRequest': """ Build a new AttributeRequest with a unique request ID, using the global ID generator. 
""" diff --git a/tb_mqtt_client/entities/data/claim_request.py b/tb_mqtt_client/entities/data/claim_request.py index ec42057..27fbb43 100644 --- a/tb_mqtt_client/entities/data/claim_request.py +++ b/tb_mqtt_client/entities/data/claim_request.py @@ -55,4 +55,4 @@ def to_payload_format(self) -> Dict[str, Any]: payload["secretKey"] = self.secret_key if self.duration is not None: payload["durationMs"] = int(self.duration * 1000) - return payload \ No newline at end of file + return payload diff --git a/tb_mqtt_client/entities/data/device_uplink_message.py b/tb_mqtt_client/entities/data/device_uplink_message.py index 47daad5..1360d4b 100644 --- a/tb_mqtt_client/entities/data/device_uplink_message.py +++ b/tb_mqtt_client/entities/data/device_uplink_message.py @@ -39,7 +39,9 @@ class DeviceUplinkMessage: main_ts: Optional[int] = None def __new__(cls, *args, **kwargs): - raise TypeError("Direct instantiation of DeviceUplinkMessage is not allowed. Use DeviceUplinkMessageBuilder to construct instances.") + raise TypeError( + "Direct instantiation of DeviceUplinkMessage is not allowed. " + "Use DeviceUplinkMessageBuilder to construct instances.") def __repr__(self): return (f"DeviceUplinkMessage(device_name={self.device_name}, " @@ -127,8 +129,10 @@ def add_attributes(self, attributes: Union[AttributeEntry, List[AttributeEntry]] self.__size += attribute.size return self - def add_timeseries(self, timeseries: Union[TimeseriesEntry, List[TimeseriesEntry], OrderedDict[ - int, List[TimeseriesEntry]]]) -> 'DeviceUplinkMessageBuilder': + def add_timeseries(self, + timeseries: Union[TimeseriesEntry, + List[TimeseriesEntry], + OrderedDict[int, List[TimeseriesEntry]]]) -> 'DeviceUplinkMessageBuilder': if isinstance(timeseries, OrderedDict): self._timeseries = timeseries return self @@ -149,8 +153,8 @@ def add_timeseries(self, timeseries: Union[TimeseriesEntry, List[TimeseriesEntry self.__size += timeseries_entry.size return self - def add_delivery_futures(self, futures: Union[ - asyncio.Future[PublishResult], List[asyncio.Future[PublishResult]]]) -> 'DeviceUplinkMessageBuilder': + def add_delivery_futures(self, futures: Union[asyncio.Future[PublishResult], + List[asyncio.Future[PublishResult]]]) -> 'DeviceUplinkMessageBuilder': if not isinstance(futures, list): futures = [futures] if futures: diff --git a/tb_mqtt_client/entities/data/provisioning_response.py b/tb_mqtt_client/entities/data/provisioning_response.py index 82e7a73..2e79024 100644 --- a/tb_mqtt_client/entities/data/provisioning_response.py +++ b/tb_mqtt_client/entities/data/provisioning_response.py @@ -27,7 +27,8 @@ class ProvisioningResponse: error: Optional[str] = None def __new__(cls, *args, **kwargs): - raise TypeError("Direct instantiation of ProvisioningResponse is not allowed. Use ProvisioningResponse.build(result, error).") # noqa + raise TypeError( + "Direct instantiation of ProvisioningResponse is not allowed. 
Use ProvisioningResponse.build(result, error).") # noqa def __repr__(self) -> str: return f"ProvisioningResponse(status={self.status}, result={self.result}, error={self.error})" diff --git a/tb_mqtt_client/entities/data/requested_attribute_response.py b/tb_mqtt_client/entities/data/requested_attribute_response.py index 5ee97b6..9678888 100644 --- a/tb_mqtt_client/entities/data/requested_attribute_response.py +++ b/tb_mqtt_client/entities/data/requested_attribute_response.py @@ -20,7 +20,6 @@ @dataclass(slots=True, frozen=True) class RequestedAttributeResponse: - request_id: int shared: List[AttributeEntry] client: List[AttributeEntry] @@ -89,5 +88,5 @@ def from_dict(cls, data: Dict[str, Any]) -> 'RequestedAttributeResponse': """ shared = [AttributeEntry(k, v) for k, v in data.get('shared', {}).items()] client = [AttributeEntry(k, v) for k, v in data.get('client', {}).items()] - request_id = data.get('request_id', -1) # Default to -1 if not provided + request_id = data.get('request_id', -1) # Default to -1 if not provided return cls(shared=shared, client=client, request_id=request_id) diff --git a/tb_mqtt_client/entities/data/rpc_response.py b/tb_mqtt_client/entities/data/rpc_response.py index ba2b5d0..4d9a133 100644 --- a/tb_mqtt_client/entities/data/rpc_response.py +++ b/tb_mqtt_client/entities/data/rpc_response.py @@ -32,6 +32,7 @@ class RPCStatus(Enum): def __str__(self): return self.value + @dataclass(slots=True, frozen=True) class RPCResponse: """ @@ -48,13 +49,15 @@ class RPCResponse: error: Optional[Union[str, Dict[str, Any]]] = None def __new__(cls, *args, **kwargs): - raise TypeError("Direct instantiation of RPCResponse is not allowed. Use RPCResponse.build(request_id, result, error).") + raise TypeError( + "Direct instantiation of RPCResponse is not allowed. Use RPCResponse.build(request_id, result, error).") def __repr__(self) -> str: return f"RPCResponse(request_id={self.request_id}, result={self.result}, error={self.error})" @classmethod - def build(cls, request_id: Union[int, str], result: Optional[Any] = None, error: Optional[Union[str, Dict[str, JSONCompatibleType], BaseException]] = None) -> 'RPCResponse': + def build(cls, request_id: Union[int, str], result: Optional[Any] = None, + error: Optional[Union[str, Dict[str, JSONCompatibleType], BaseException]] = None) -> 'RPCResponse': """ Constructs an RPCResponse explicitly. """ diff --git a/tb_mqtt_client/entities/gateway/__init__.py b/tb_mqtt_client/entities/gateway/__init__.py index fa669aa..cff354e 100644 --- a/tb_mqtt_client/entities/gateway/__init__.py +++ b/tb_mqtt_client/entities/gateway/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
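Since RPCResponse forbids direct instantiation, a handler always goes through build(). A minimal sketch, not the library's own example, assuming RPCRequest exposes request_id, method and params the way the gateway request class in this patch does; the hypothetical method name is only for illustration:

    # Sketch of a server-side RPC handler built around RPCResponse.build(),
    # mirroring rpc_request_callback from examples/device/operational_example.py.
    from tb_mqtt_client.entities.data.rpc_request import RPCRequest
    from tb_mqtt_client.entities.data.rpc_response import RPCResponse

    async def rpc_request_callback(request: RPCRequest) -> RPCResponse:
        if request.method == "getStatus":   # hypothetical method name
            return RPCResponse.build(request_id=request.request_id,
                                     result={"status": "ok", "params": request.params})
        # Unknown method: report an error instead of a result,
        # using the documented build(request_id, result, error) signature.
        return RPCResponse.build(request_id=request.request_id,
                                 error=f"Unsupported RPC method: {request.method}")
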
- diff --git a/tb_mqtt_client/entities/gateway/base_gateway_event.py b/tb_mqtt_client/entities/gateway/base_gateway_event.py index 1dee49e..63a9679 100644 --- a/tb_mqtt_client/entities/gateway/base_gateway_event.py +++ b/tb_mqtt_client/entities/gateway/base_gateway_event.py @@ -14,6 +14,7 @@ from tb_mqtt_client.entities.gateway.event_type import GatewayEventType + class BaseGatewayEvent: def __init__(self, event_type: GatewayEventType): self.__event_type = event_type diff --git a/tb_mqtt_client/entities/gateway/device_connect_message.py b/tb_mqtt_client/entities/gateway/device_connect_message.py index d0a685d..cdec62d 100644 --- a/tb_mqtt_client/entities/gateway/device_connect_message.py +++ b/tb_mqtt_client/entities/gateway/device_connect_message.py @@ -30,7 +30,8 @@ class DeviceConnectMessage(BaseGatewayEvent): event_type: GatewayEventType = GatewayEventType.DEVICE_CONNECT def __new__(cls, *args, **kwargs): - raise TypeError("Direct instantiation of DeviceConnectMessage is not allowed. Use 'await DeviceConnectMessage.build(...)'.") + raise TypeError( + "Direct instantiation of DeviceConnectMessage is not allowed. Use 'await DeviceConnectMessage.build(...)'.") def __repr__(self): return f"DeviceConnectMessage(device_name={self.device_name}, device_profile={self.device_profile})" diff --git a/tb_mqtt_client/entities/gateway/device_info.py b/tb_mqtt_client/entities/gateway/device_info.py index 2e44f5b..f013958 100644 --- a/tb_mqtt_client/entities/gateway/device_info.py +++ b/tb_mqtt_client/entities/gateway/device_info.py @@ -31,7 +31,9 @@ def __post_init__(self): def __setattr__(self, key, value): if not self._initializing: - raise AttributeError(f"Cannot modify attribute '{key}' of frozen DeviceInfo instance. Use rename() method to change device_name.") + raise AttributeError( + f"Cannot modify attribute '{key}' of frozen DeviceInfo instance." + "Use rename() method to change device_name.") else: super().__setattr__(key, value) diff --git a/tb_mqtt_client/entities/gateway/device_session_state.py b/tb_mqtt_client/entities/gateway/device_session_state.py index 5b10891..043eead 100644 --- a/tb_mqtt_client/entities/gateway/device_session_state.py +++ b/tb_mqtt_client/entities/gateway/device_session_state.py @@ -14,6 +14,7 @@ from enum import Enum + class DeviceSessionState(Enum): CONNECTED = "CONNECTED" DISCONNECTED = "DISCONNECTED" diff --git a/tb_mqtt_client/entities/gateway/event_type.py b/tb_mqtt_client/entities/gateway/event_type.py index 41fbd0b..0f97890 100644 --- a/tb_mqtt_client/entities/gateway/event_type.py +++ b/tb_mqtt_client/entities/gateway/event_type.py @@ -14,6 +14,7 @@ from enum import Enum + class GatewayEventType(Enum): """ Enum representing different types of gateway events. diff --git a/tb_mqtt_client/entities/gateway/gateway_attribute_request.py b/tb_mqtt_client/entities/gateway/gateway_attribute_request.py index 5958e6a..b3dbc02 100644 --- a/tb_mqtt_client/entities/gateway/gateway_attribute_request.py +++ b/tb_mqtt_client/entities/gateway/gateway_attribute_request.py @@ -31,13 +31,18 @@ class GatewayAttributeRequest(AttributeRequest): device_session: DeviceSession = None # type: ignore[assignment] def __new__(cls, *args, **kwargs): - raise TypeError("Direct instantiation of GatewayAttributeRequest is not allowed. Use 'await GatewayAttributeRequest.build(...)'.") + raise TypeError( + "Direct instantiation of GatewayAttributeRequest is not allowed. 
" + "Use 'await GatewayAttributeRequest.build(...)'.") def __repr__(self) -> str: - return f"GatewayAttributeRequest(device_session={self.device_session}, id={self.request_id}, shared_keys={self.shared_keys}, client_keys={self.client_keys})" + return (f"GatewayAttributeRequest(device_session={self.device_session}, id={self.request_id}, " + f"shared_keys={self.shared_keys}, client_keys={self.client_keys})") @classmethod - async def build(cls, device_session: DeviceSession, shared_keys: Optional[List[str]] = None, client_keys: Optional[List[str]] = None) -> 'GatewayAttributeRequest': # noqa + async def build(cls, device_session: DeviceSession, # noqa + shared_keys: Optional[List[str]] = None, + client_keys: Optional[List[str]] = None) -> 'GatewayAttributeRequest': """ Build a new GatewayAttributeRequest with a unique request ID, using the global ID generator. """ @@ -53,7 +58,8 @@ async def build(cls, device_session: DeviceSession, shared_keys: Optional[List[s return self @classmethod - async def from_attribute_request(cls, device_session: DeviceSession, attribute_request: AttributeRequest) -> 'GatewayAttributeRequest': + async def from_attribute_request(cls, device_session: DeviceSession, + attribute_request: AttributeRequest) -> 'GatewayAttributeRequest': """ Create a GatewayAttributeRequest from an existing AttributeRequest and a DeviceSession. """ @@ -67,19 +73,20 @@ async def from_attribute_request(cls, device_session: DeviceSession, attribute_r object.__setattr__(self, 'event_type', GatewayEventType.DEVICE_ATTRIBUTE_REQUEST) return self - def to_payload_format(self) -> Dict[str, Union[str, bool]]: """ Convert the attribute request into the expected MQTT payload format. """ payload = {"device": self.device_session.device_info.device_name, "id": self.request_id} - single_key_request = (self.client_keys is not None and len(self.client_keys) == 1) or (self.shared_keys is not None and len(self.shared_keys) == 1) + single_key_request = (self.client_keys is not None and len(self.client_keys) == 1) or ( + self.shared_keys is not None and len(self.shared_keys) == 1) request_key = 'key' if single_key_request else 'keys' if self.client_keys is not None and self.client_keys: payload['client'] = True payload[request_key] = self.client_keys[0] if single_key_request else self.client_keys elif self.shared_keys is not None and self.shared_keys: - # TODO: In current realisation on server it is not possible to request values for the both scopes simultaneously, recommended to improve the platform API + # TODO: In current realization on server it is not possible to request values for the both scopes + # TODO: at the same time, recommended to improve the platform API payload['client'] = False payload[request_key] = self.shared_keys[0] if single_key_request else self.shared_keys return payload diff --git a/tb_mqtt_client/entities/gateway/gateway_attribute_update.py b/tb_mqtt_client/entities/gateway/gateway_attribute_update.py index 51fd359..ae5c0ee 100644 --- a/tb_mqtt_client/entities/gateway/gateway_attribute_update.py +++ b/tb_mqtt_client/entities/gateway/gateway_attribute_update.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ from typing import List, Union from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry @@ -25,18 +26,20 @@ class GatewayAttributeUpdate(AttributeUpdate, BaseGatewayEvent): This event is used to notify about changes in device shared attributes. """ - def __init__(self, device_name: str, attribute_update: Union[AttributeUpdate, List[AttributeEntry], AttributeEntry]): + def __init__(self, device_name: str, + attribute_update: Union[AttributeUpdate, List[AttributeEntry], AttributeEntry]): super().__init__(GatewayEventType.DEVICE_ATTRIBUTE_UPDATE) if isinstance(attribute_update, list) and all(isinstance(entry, AttributeEntry) for entry in attribute_update): attribute_update = AttributeUpdate(entries=attribute_update) elif isinstance(attribute_update, AttributeEntry): attribute_update = AttributeUpdate(entries=[attribute_update]) elif not isinstance(attribute_update, AttributeUpdate): - raise TypeError("attribute_update must be an instance of AttributeUpdate, list of AttributeEntry, or a single AttributeEntry.") + raise TypeError( + "attribute_update must be an instance of AttributeUpdate, " + "list of AttributeEntry, or a single AttributeEntry.") self.device_name = device_name self.entries = attribute_update.entries self.attribute_update = attribute_update - def __str__(self) -> str: return f"GatewayAttributeUpdate(device_name={self.device_name}, attribute_update={self.attribute_update})" diff --git a/tb_mqtt_client/entities/gateway/gateway_claim_request.py b/tb_mqtt_client/entities/gateway/gateway_claim_request.py index dbf57e8..7330a2e 100644 --- a/tb_mqtt_client/entities/gateway/gateway_claim_request.py +++ b/tb_mqtt_client/entities/gateway/gateway_claim_request.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - from dataclasses import dataclass from typing import Dict, Any, Union @@ -24,12 +23,12 @@ @dataclass(slots=True, frozen=True) class GatewayClaimRequest(BaseGatewayEvent): - devices_requests: Dict[Union[DeviceSession, str], ClaimRequest] = None event_type: GatewayEventType = GatewayEventType.GATEWAY_CLAIM_REQUEST def __new__(cls, *args, **kwargs): - raise TypeError("Direct instantiation of GatewayClaimRequest is not allowed. Use 'GatewayClaimRequestBuilder.build(...)'.") + raise TypeError( + "Direct instantiation of GatewayClaimRequest is not allowed. Use 'GatewayClaimRequestBuilder.build(...)'.") def __repr__(self) -> str: return f"GatewayClaimRequest(devices_requests={self.devices_requests})" @@ -68,15 +67,18 @@ class GatewayClaimRequestBuilder: Builder class for GatewayClaimRequest. Allows adding multiple device claim requests in a fluent interface style. """ + def __init__(self): self._devices_requests: Dict[Union[DeviceSession, str], ClaimRequest] = {} - def add_device_request(self, device_name_or_session: Union[DeviceSession, str], device_claim_request: ClaimRequest) -> 'GatewayClaimRequestBuilder': + def add_device_request(self, device_name_or_session: Union[DeviceSession, str], + device_claim_request: ClaimRequest) -> 'GatewayClaimRequestBuilder': """ Add a device claim request to the builder. 
""" if not isinstance(device_name_or_session, (DeviceSession, str)): - raise ValueError("device_session must be an instance of DeviceSession or a string representing the device name") + raise ValueError( + "device_session must be an instance of DeviceSession or a string representing the device name") if not isinstance(device_claim_request, ClaimRequest): raise ValueError("device_claim_request must be an instance of ClaimRequest") self._devices_requests[device_name_or_session] = device_claim_request @@ -89,4 +91,4 @@ def build(self) -> GatewayClaimRequest: gateway_claim_request = GatewayClaimRequest.build() for device_session, claim_request in self._devices_requests.items(): gateway_claim_request.add_device_request(device_session, claim_request) - return gateway_claim_request \ No newline at end of file + return gateway_claim_request diff --git a/tb_mqtt_client/entities/gateway/gateway_requested_attribute_response.py b/tb_mqtt_client/entities/gateway/gateway_requested_attribute_response.py index 612843d..712531d 100644 --- a/tb_mqtt_client/entities/gateway/gateway_requested_attribute_response.py +++ b/tb_mqtt_client/entities/gateway/gateway_requested_attribute_response.py @@ -26,7 +26,6 @@ @dataclass(slots=True, frozen=True) class GatewayRequestedAttributeResponse(RequestedAttributeResponse, BaseGatewayEvent): - device_name: str = "" request_id: int = -1 shared: Optional[List[AttributeEntry]] = None @@ -37,7 +36,8 @@ def __post_init__(self): super(BaseGatewayEvent, self).__setattr__('event_type', GatewayEventType.DEVICE_REQUESTED_ATTRIBUTE_RESPONSE) def __repr__(self): - return f"GatewayRequestedAttributeResponse(device_name={self.device_name},request_id={self.request_id}, shared={self.shared}, client={self.client})" + return (f"GatewayRequestedAttributeResponse(device_name={self.device_name},request_id={self.request_id}, " + f"shared={self.shared}, client={self.client})") def __getitem__(self, item): """ diff --git a/tb_mqtt_client/entities/gateway/gateway_rpc_request.py b/tb_mqtt_client/entities/gateway/gateway_rpc_request.py index b47c423..874fe1a 100644 --- a/tb_mqtt_client/entities/gateway/gateway_rpc_request.py +++ b/tb_mqtt_client/entities/gateway/gateway_rpc_request.py @@ -28,13 +28,17 @@ class GatewayRPCRequest(BaseGatewayEvent): event_type: GatewayEventType = GatewayEventType.DEVICE_RPC_REQUEST def __new__(cls, *args, **kwargs): - raise TypeError("Direct instantiation of GatewayRPCRequest is not allowed. Use 'GatewayRPCRequest._deserialize_from_dict(...)'.") + raise TypeError( + "Direct instantiation of GatewayRPCRequest is not allowed. 
" + "Use 'GatewayRPCRequest._deserialize_from_dict(...)'.") def __repr__(self): - return f"RPCRequest(id={self.request_id}, device_name={self.device_name}, method={self.method}, params={self.params})" + return (f"RPCRequest(id={self.request_id}, device_name={self.device_name}, " + f"method={self.method}, params={self.params})") def __str__(self): - return f"RPCRequest(id={self.request_id}, device_name={self.device_name}, method={self.method}, params={self.params})" + return (f"RPCRequest(id={self.request_id}, device_name={self.device_name}, " + f"method={self.method}, params={self.params})") @classmethod def _deserialize_from_dict(cls, data: Dict[str, Union[str, Dict[str, Any]]]) -> 'GatewayRPCRequest': diff --git a/tb_mqtt_client/entities/gateway/gateway_rpc_response.py b/tb_mqtt_client/entities/gateway/gateway_rpc_response.py index f218798..45055a9 100644 --- a/tb_mqtt_client/entities/gateway/gateway_rpc_response.py +++ b/tb_mqtt_client/entities/gateway/gateway_rpc_response.py @@ -37,13 +37,16 @@ class GatewayRPCResponse(RPCResponse, BaseGatewayEvent): event_type: GatewayEventType = GatewayEventType.DEVICE_RPC_RESPONSE def __new__(cls, *args, **kwargs): - raise TypeError("Direct instantiation of GatewayRPCResponse is not allowed. Use GatewayRPCResponse.build(device_name, request_id, result, error).") + raise TypeError( + "Direct instantiation of GatewayRPCResponse is not allowed. " + "Use GatewayRPCResponse.build(device_name, request_id, result, error).") def __repr__(self) -> str: - return f"GatewayRPCResponse(device_name={self.device_name}, request_id={self.request_id}, result={self.result}, error={self.error})" + return (f"GatewayRPCResponse(device_name={self.device_name}, request_id={self.request_id}, " + f"result={self.result}, error={self.error})") @classmethod - def build(cls, # noqa + def build(cls, # noqa device_name: str, request_id: int, result: Optional[Any] = None, diff --git a/tb_mqtt_client/entities/gateway/gateway_uplink_message.py b/tb_mqtt_client/entities/gateway/gateway_uplink_message.py index 5c79368..a80e44b 100644 --- a/tb_mqtt_client/entities/gateway/gateway_uplink_message.py +++ b/tb_mqtt_client/entities/gateway/gateway_uplink_message.py @@ -39,7 +39,8 @@ class GatewayUplinkMessage(DeviceUplinkMessage, BaseGatewayEvent): def __new__(cls, *args, **kwargs): raise TypeError( - "Direct instantiation of GatewayUplinkMessage is not allowed. Use GatewayUplinkMessageBuilder to construct instances.") + "Direct instantiation of GatewayUplinkMessage is not allowed. 
" + "Use GatewayUplinkMessageBuilder to construct instances.") def __repr__(self): return (f"GatewayUplinkMessage(device_name={self.device_name}, " @@ -128,8 +129,10 @@ def add_attributes(self, attributes: Union[AttributeEntry, List[AttributeEntry]] self.__size += attribute.size return self - def add_timeseries(self, timeseries: Union[TimeseriesEntry, List[TimeseriesEntry], OrderedDict[ - int, List[TimeseriesEntry]]]) -> 'GatewayUplinkMessageBuilder': + def add_timeseries(self, + timeseries: Union[TimeseriesEntry, + List[TimeseriesEntry], + OrderedDict[int, List[TimeseriesEntry]]]) -> 'GatewayUplinkMessageBuilder': if isinstance(timeseries, OrderedDict): self._timeseries = timeseries return self @@ -149,8 +152,9 @@ def add_timeseries(self, timeseries: Union[TimeseriesEntry, List[TimeseriesEntry self.__size += entry.size return self - def add_delivery_futures(self, futures: Union[ - asyncio.Future[PublishResult], List[asyncio.Future[PublishResult]]]) -> 'GatewayUplinkMessageBuilder': + def add_delivery_futures(self, + futures: Union[asyncio.Future[PublishResult], + List[asyncio.Future[PublishResult]]]) -> 'GatewayUplinkMessageBuilder': if not isinstance(futures, list): futures = [futures] if futures: diff --git a/tb_mqtt_client/service/__init__.py b/tb_mqtt_client/service/__init__.py index fa669aa..cff354e 100644 --- a/tb_mqtt_client/service/__init__.py +++ b/tb_mqtt_client/service/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tb_mqtt_client/service/base_client.py b/tb_mqtt_client/service/base_client.py index 60c6b1b..61df408 100644 --- a/tb_mqtt_client/service/base_client.py +++ b/tb_mqtt_client/service/base_client.py @@ -70,16 +70,14 @@ async def send_timeseries(self, Dict[str, Any], List[Dict[str, Any]]], wait_for_publish: bool = True, - timeout: Optional[float] = None) -> Union[asyncio.Future[PublishResult], - PublishResult, - None, - List[PublishResult], - List[asyncio.Future[PublishResult]]]: + timeout: Optional[float] = None) -> Optional[Union[asyncio.Future[PublishResult], + PublishResult, + List[PublishResult], + List[asyncio.Future[PublishResult]]]]: """ Sends timeseries data to the ThingsBoard server. :param data: Timeseries data to send, can be a single TimeseriesEntry, a list of TimeseriesEntries, a dictionary of key-value pairs, or a list of dictionaries. - :param qos: Quality of Service level for the MQTT message. :param wait_for_publish: If True, waits for the publish result. :param timeout: Timeout for waiting for the publish result. 
:return: PublishResult or list of PublishResults if wait_for_publish is True, Future or list of Futures if not, @@ -142,9 +140,9 @@ def set_rpc_request_callback(self, callback: Callable[[str, Dict[str, Any]], Awa @staticmethod def _build_uplink_message_for_telemetry(payload: Union[Dict[str, Any], - TimeseriesEntry, - List[TimeseriesEntry], - List[Dict[str, Any]]], + TimeseriesEntry, + List[TimeseriesEntry], + List[Dict[str, Any]]], device_session: Optional[DeviceSession] = None, ) -> Union[DeviceUplinkMessage, GatewayUplinkMessage]: timeseries_entries = [] @@ -192,9 +190,9 @@ def __build_timeseries_entry_from_dict(data: Dict[str, JSONCompatibleType]) -> L @staticmethod def _build_uplink_message_for_attributes(payload: Union[Dict[str, Any], - AttributeEntry, - List[AttributeEntry]], - device_session = None) -> Union[DeviceUplinkMessage, GatewayUplinkMessage]: + AttributeEntry, + List[AttributeEntry]], + device_session=None) -> Union[DeviceUplinkMessage, GatewayUplinkMessage]: if isinstance(payload, dict): payload = [AttributeEntry(k, v) for k, v in payload.items()] @@ -206,4 +204,4 @@ def _build_uplink_message_for_attributes(payload: Union[Dict[str, Any], else: message_builder = DeviceUplinkMessageBuilder() message_builder.add_attributes(payload) - return message_builder.build() \ No newline at end of file + return message_builder.build() diff --git a/tb_mqtt_client/service/device/__init__.py b/tb_mqtt_client/service/device/__init__.py index fa669aa..cff354e 100644 --- a/tb_mqtt_client/service/device/__init__.py +++ b/tb_mqtt_client/service/device/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index c42f9c5..7424baf 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -173,9 +173,6 @@ async def stop(self): async def disconnect(self): await self._mqtt_manager.disconnect() - # if self._message_queue: - # await self._message_queue.shutdown() - # TODO: Not sure if we need to shutdown the message queue here, as it might be handled by MQTTManager async def send_telemetry(self, *args, **kwargs): """ @@ -187,15 +184,13 @@ async def send_telemetry(self, *args, **kwargs): async def send_timeseries( self, data: Union[TimeseriesEntry, List[TimeseriesEntry], Dict[str, Any], List[Dict[str, Any]]], - qos: int = 1, wait_for_publish: bool = True, timeout: Optional[float] = None - ) -> Union[PublishResult, List[PublishResult], None, Future[PublishResult], List[Future[PublishResult]]]: + ) -> Optional[Union[PublishResult, List[PublishResult], Future[PublishResult], List[Future[PublishResult]]]]: """ Sends timeseries data to the ThingsBoard server. :param data: Timeseries data to send, can be a single TimeseriesEntry, a list of TimeseriesEntries, a dictionary of key-value pairs, or a list of dictionaries. - :param qos: Quality of Service level for the MQTT message. :param wait_for_publish: If True, waits for the publish result. :param timeout: Timeout for waiting for the publish result. 
:return: PublishResult or list of PublishResults if wait_for_publish is True, Future or list of Futures if not, @@ -205,7 +200,7 @@ async def send_timeseries( mqtt_message = MqttPublishMessage( topic=mqtt_topics.DEVICE_TELEMETRY_TOPIC, payload=message, - qos=qos or self._config.qos, + qos=self._config.qos, datapoints_count=message.timeseries_datapoint_count() ) delivery_future = mqtt_message.delivery_futures @@ -224,21 +219,20 @@ async def send_timeseries( result = await await_or_stop(delivery_future, timeout=1, stop_event=self._stop_event) except TimeoutError: logger.warning("Timeout while waiting for telemetry publish result") - result = PublishResult(mqtt_message.topic, qos, -1, message.size, -1) + result = PublishResult(mqtt_message.topic, self._config.qos, -1, message.size, -1) return result async def send_attributes( self, attributes: Union[Dict[str, Any], AttributeEntry, list[AttributeEntry]], - qos: int = None, wait_for_publish: bool = True, timeout: int = BaseClient.DEFAULT_TIMEOUT - ) -> Union[PublishResult, List[PublishResult], None, Future[PublishResult], List[Future[PublishResult]]]: + ) -> Optional[Union[PublishResult, List[PublishResult], Future[PublishResult], List[Future[PublishResult]]]]: message = self._build_uplink_message_for_attributes(attributes) mqtt_message = MqttPublishMessage( topic=mqtt_topics.DEVICE_ATTRIBUTES_TOPIC, payload=message, - qos=qos or self._config.qos, + qos=self._config.qos, datapoints_count=message.attributes_datapoint_count() ) @@ -259,7 +253,7 @@ async def send_attributes( result = await await_or_stop(fut, timeout=timeout, stop_event=self._stop_event) except TimeoutError: logger.warning("Timeout while waiting for attribute publish result") - result = PublishResult(mqtt_message.topic, qos, -1, message.size, -1) + result = PublishResult(mqtt_message.topic, self._config.qos, -1, message.size, -1) results.append(result) return results[0] if len(results) == 1 else results @@ -270,7 +264,7 @@ async def send_rpc_request( callback: Optional[Callable[[RPCResponse], Awaitable[None]]] = None, wait_for_publish: bool = True, timeout: Optional[float] = BaseClient.DEFAULT_TIMEOUT - ) -> Union[RPCResponse, Awaitable[RPCResponse], None]: + ) -> Optional[Union[RPCResponse, Awaitable[RPCResponse]]]: request_id = rpc_request.request_id or await RPCRequestIdProducer.get_next() message_to_send = self._message_adapter.build_rpc_request(rpc_request) message_to_send.qos = self._config.qos @@ -354,8 +348,10 @@ async def _on_connect(self): return self._mqtt_manager.register_handler(mqtt_topics.DEVICE_ATTRIBUTES_TOPIC, self._handle_attribute_update) - self._mqtt_manager.register_handler(mqtt_topics.DEVICE_RPC_REQUEST_TOPIC_FOR_SUBSCRIPTION, self._handle_rpc_request) # noqa - self._mqtt_manager.register_handler(mqtt_topics.DEVICE_ATTRIBUTES_RESPONSE_TOPIC, self._handle_requested_attribute_response) # noqa + self._mqtt_manager.register_handler(mqtt_topics.DEVICE_RPC_REQUEST_TOPIC_FOR_SUBSCRIPTION, + self._handle_rpc_request) # noqa + self._mqtt_manager.register_handler(mqtt_topics.DEVICE_ATTRIBUTES_RESPONSE_TOPIC, + self._handle_requested_attribute_response) # noqa # RPC responses are handled by the RPCResponseHandler, which is already registered async def _on_disconnect(self): @@ -386,9 +382,12 @@ async def _handle_rate_limit_response(self, response: RPCResponse): # noqa rate_limits = response.result.get('rateLimits', {}) - await self._rate_limiter.message_rate_limit.set_limit(rate_limits.get("messages", "0:0,"), percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) - await 
self._rate_limiter.telemetry_message_rate_limit.set_limit(rate_limits.get("telemetryMessages", "0:0,"), percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) - await self._rate_limiter.telemetry_datapoints_rate_limit.set_limit(rate_limits.get("telemetryDataPoints", "0:0,"), percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) + await self._rate_limiter.message_rate_limit.set_limit(rate_limits.get("messages", "0:0,"), + percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) + await self._rate_limiter.telemetry_message_rate_limit.set_limit( + rate_limits.get("telemetryMessages", "0:0,"), percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) + await self._rate_limiter.telemetry_datapoints_rate_limit.set_limit( + rate_limits.get("telemetryDataPoints", "0:0,"), percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) server_inflight = int(response.result.get("maxInflightMessages", 100)) limits = [rl.minimal_limit for rl in [ @@ -422,7 +421,7 @@ async def _handle_rate_limit_response(self, response: RPCResponse): # noqa self._message_adapter.splitter.max_payload_size = self.max_payload_size logger.debug("Updated dispatcher's max_payload_size to %d", self.max_payload_size) - self._message_adapter.splitter.max_datapoints = self._rate_limiter.telemetry_datapoints_rate_limit.minimal_limit + self._message_adapter.splitter.max_datapoints = self._rate_limiter.telemetry_datapoints_rate_limit.minimal_limit # noqa if (not self._rate_limiter.message_rate_limit.has_limit() and not self._rate_limiter.telemetry_message_rate_limit.has_limit() diff --git a/tb_mqtt_client/service/device/firmware_updater.py b/tb_mqtt_client/service/device/firmware_updater.py index 4a716db..6dd2a0d 100644 --- a/tb_mqtt_client/service/device/firmware_updater.py +++ b/tb_mqtt_client/service/device/firmware_updater.py @@ -22,6 +22,7 @@ from tb_mqtt_client.common.install_package_utils import install_package from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage from tb_mqtt_client.constants import mqtt_topics from tb_mqtt_client.constants.firmware import ( FW_CHECKSUM_ALG_ATTR, @@ -100,7 +101,8 @@ async def _get_next_chunk(self): payload = str(self._chunk_size).encode() topic = mqtt_topics.build_firmware_update_request_topic(self._firmware_request_id, self._current_chunk) - await self._client._message_queue.publish(topic=topic, payload=payload, datapoints_count=0, qos=1) + mqtt_message = MqttPublishMessage(topic, payload) + await self._client._message_queue.publish(mqtt_message, wait_for_publish=True) async def _verify_downloaded_firmware(self): self._log.info('Verifying downloaded firmware...') diff --git a/tb_mqtt_client/service/device/handlers/__init__.py b/tb_mqtt_client/service/device/handlers/__init__.py index fa669aa..cff354e 100644 --- a/tb_mqtt_client/service/device/handlers/__init__.py +++ b/tb_mqtt_client/service/device/handlers/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- diff --git a/tb_mqtt_client/service/device/handlers/requested_attributes_response_handler.py b/tb_mqtt_client/service/device/handlers/requested_attributes_response_handler.py index 78f7139..bd2f0c7 100644 --- a/tb_mqtt_client/service/device/handlers/requested_attributes_response_handler.py +++ b/tb_mqtt_client/service/device/handlers/requested_attributes_response_handler.py @@ -29,7 +29,8 @@ class RequestedAttributeResponseHandler: def __init__(self): self._message_adapter = None - self._pending_attribute_requests: Dict[int, Tuple[AttributeRequest, Callable[[RequestedAttributeResponse], Awaitable[None]]]] = {} + self._pending_attribute_requests: Dict[ + int, Tuple[AttributeRequest, Callable[[RequestedAttributeResponse], Awaitable[None]]]] = {} def set_message_adapter(self, message_adapter: MessageAdapter): """ @@ -41,7 +42,8 @@ def set_message_adapter(self, message_adapter: MessageAdapter): self._message_adapter = message_adapter logger.debug("Message adapter set for RequestedAttributeResponseHandler.") - async def register_request(self, request: AttributeRequest, callback: Callable[[RequestedAttributeResponse], Awaitable[None]]): + async def register_request(self, request: AttributeRequest, + callback: Callable[[RequestedAttributeResponse], Awaitable[None]]): """ Called when a request is sent to the platform and a response is awaited. """ @@ -52,7 +54,7 @@ async def register_request(self, request: AttributeRequest, callback: Callable[[ def unregister_request(self, request_id: int): """ - Unregisters a request by its ID. + Unregister a request by its ID. This is useful if the request is no longer needed or has timed out. """ if request_id in self._pending_attribute_requests: @@ -73,18 +75,22 @@ async def handle(self, topic: str, payload: bytes): return requested_attribute_response = self._message_adapter.parse_requested_attribute_response(topic, payload) - pending_request_details = self._pending_attribute_requests.pop(requested_attribute_response.request_id, None) + pending_request_details = self._pending_attribute_requests.pop(requested_attribute_response.request_id, + None) if not pending_request_details: - logger.warning("No future awaiting request ID %s. Ignoring.", requested_attribute_response.request_id) + logger.warning("No future awaiting request ID %s. 
Ignoring.", + requested_attribute_response.request_id) return request, callback = pending_request_details if callback: - logger.trace("Invoking callback for requested attribute response with ID %s", requested_attribute_response.request_id) + logger.trace("Invoking callback for requested attribute response with ID %s", + requested_attribute_response.request_id) await callback(requested_attribute_response) else: - logger.error("No callback registered for requested attribute response with ID %s", requested_attribute_response.request_id) + logger.error("No callback registered for requested attribute response with ID %s", + requested_attribute_response.request_id) except Exception as e: logger.exception("Failed to handle requested attribute response: %s", e) diff --git a/tb_mqtt_client/service/device/handlers/rpc_response_handler.py b/tb_mqtt_client/service/device/handlers/rpc_response_handler.py index 979ba5e..64955a3 100644 --- a/tb_mqtt_client/service/device/handlers/rpc_response_handler.py +++ b/tb_mqtt_client/service/device/handlers/rpc_response_handler.py @@ -47,7 +47,8 @@ def set_message_adapter(self, message_adapter: MessageAdapter): logger.debug("Message adapter set for RPCResponseHandler.") def register_request(self, request_id: Union[str, int], - callback: Optional[Callable[[RPCResponse], Awaitable[None]]] = None) -> asyncio.Future[RPCResponse]: + callback: Optional[Callable[[RPCResponse], + Awaitable[None]]] = None) -> asyncio.Future[RPCResponse]: """ Called when a request is sent to the platform and a response is awaited. """ diff --git a/tb_mqtt_client/service/device/message_adapter.py b/tb_mqtt_client/service/device/message_adapter.py index 67f44b0..f88684e 100644 --- a/tb_mqtt_client/service/device/message_adapter.py +++ b/tb_mqtt_client/service/device/message_adapter.py @@ -45,8 +45,8 @@ def __init__(self, max_payload_size: Optional[int] = None, max_datapoints: Optio @abstractmethod def build_uplink_messages( - self, - messages: List[MqttPublishMessage]) -> List[MqttPublishMessage]: + self, + messages: List[MqttPublishMessage]) -> List[MqttPublishMessage]: """ Build a list of topic-payload pairs from the given messages. Each pair consists of a topic string, payload bytes, the number of datapoints, @@ -133,7 +133,9 @@ def parse_rpc_response(self, topic: str, payload: Union[bytes, Exception]) -> RP pass @abstractmethod - def parse_provisioning_response(self, provisioning_request: ProvisioningRequest, payload: bytes) -> ProvisioningResponse: + def parse_provisioning_response(self, + provisioning_request: ProvisioningRequest, + payload: bytes) -> ProvisioningResponse: """ Parse the provisioning response from the given payload. This method should be implemented to handle the specific format of the provisioning response. @@ -145,6 +147,7 @@ class JsonMessageAdapter(MessageAdapter): """ A concrete implementation of MessageDispatcher that operates with JSON payloads. 
""" + def __init__(self, max_payload_size: Optional[int] = None, max_datapoints: Optional[int] = None): super().__init__(max_payload_size, max_datapoints) logger.trace("JsonMessageDispatcher created.") @@ -161,7 +164,8 @@ def parse_requested_attribute_response(self, topic: str, payload: bytes) -> Requ data = loads(payload) logger.trace("Parsing attribute request response from payload: %s", data) if not isinstance(data, dict): - logger.error("Invalid requested attribute response format: expected dict, got %s", type(data).__name__) + logger.error("Invalid requested attribute response format: expected dict, got %s", + type(data).__name__) raise ValueError("Invalid requested attribute response format") data["request_id"] = request_id # Add request_id to the data dictionary return RequestedAttributeResponse.from_dict(data) @@ -178,7 +182,7 @@ def parse_attribute_update(self, payload: bytes) -> AttributeUpdate: try: data = loads(payload) logger.trace("Parsing attribute update from payload: %s", data) - return AttributeUpdate._deserialize_from_dict(data) + return AttributeUpdate._deserialize_from_dict(data) # noqa except Exception as e: logger.error("Failed to parse attribute update: %s", str(e)) raise ValueError("Invalid attribute update format") from e @@ -212,13 +216,14 @@ def parse_rpc_response(self, topic: str, payload: Union[bytes, Exception]) -> RP data = RPCResponse.build(request_id, error=payload) else: parsed = loads(payload) - data = RPCResponse.build(request_id, parsed) # noqa + data = RPCResponse.build(request_id, parsed) return data except Exception as e: logger.error("Failed to parse RPC response: %s", str(e)) raise ValueError("Invalid RPC response format") from e - def parse_provisioning_response(self, provisioning_request: ProvisioningRequest, payload: bytes) -> 'ProvisioningResponse': + def parse_provisioning_response(self, provisioning_request: ProvisioningRequest, + payload: bytes) -> 'ProvisioningResponse': """ Parse the provisioning response from the given payload. :param provisioning_request: The ProvisioningRequest that initiated the provisioning. @@ -278,7 +283,7 @@ def build_uplink_messages(self, messages: List[MqttPublishMessage]) -> List[Mqtt if logger.isEnabledFor(TRACE_LEVEL): logger.trace( "Built telemetry payload for '%s' with %d datapoints, futures=%r", - device_name, count, [f.uuid for f in child_futures] + device_name, count, [f.uuid for f in child_futures] # noqa ) if attr_msgs: @@ -302,7 +307,7 @@ def build_uplink_messages(self, messages: List[MqttPublishMessage]) -> List[Mqtt if logger.isEnabledFor(TRACE_LEVEL): logger.trace( "Built attribute payload for '%s' with %d attributes, futures=%r", - device_name, count, [f.uuid for f in child_futures] + device_name, count, [f.uuid for f in child_futures] # noqa ) # Register child futures to all original parent futures @@ -379,7 +384,8 @@ def build_provision_request(self, provision_request: 'ProvisioningRequest') -> M :param provision_request: The ProvisioningRequest to build the payload for. :return: A tuple of topic and payload bytes. 
""" - if not provision_request.credentials.provision_device_key or not provision_request.credentials.provision_device_secret: + if (not provision_request.credentials.provision_device_key + or not provision_request.credentials.provision_device_secret): raise ValueError("ProvisioningRequest must have valid device key and secret.") topic = mqtt_topics.PROVISION_REQUEST_TOPIC diff --git a/tb_mqtt_client/service/device/message_splitter.py b/tb_mqtt_client/service/device/message_splitter.py index 063f7fe..0b088af 100644 --- a/tb_mqtt_client/service/device/message_splitter.py +++ b/tb_mqtt_client/service/device/message_splitter.py @@ -31,7 +31,10 @@ class MessageSplitter(BaseMessageSplitter): DEFAULT_MAX_PAYLOAD_SIZE = 55_000 # Default to 55_000 to allow for some overhead def __init__(self, max_payload_size: int = DEFAULT_MAX_PAYLOAD_SIZE, max_datapoints: int = 0): - self._max_payload_size = max_payload_size if max_payload_size is not None and max_payload_size > 0 else self.DEFAULT_MAX_PAYLOAD_SIZE + if max_payload_size is not None and max_payload_size > 0: + self._max_payload_size = max_payload_size + else: + self._max_payload_size = self.DEFAULT_MAX_PAYLOAD_SIZE self._max_datapoints = max_datapoints if max_datapoints is not None and max_datapoints > 0 else 0 logger.trace("MessageSplitter initialized with max_payload_size=%d, max_datapoints=%d", self._max_payload_size, self._max_datapoints) @@ -157,7 +160,8 @@ def split_attributes(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp for parent in parent_futures: future_map.register(parent, [shared_future]) - logger.trace("Flushed attribute batch (count=%d, size=%d)", len(built.attributes), size) + logger.trace("Flushed attribute batch (count=%d, size=%d)", + len(built.attributes), size) builder = DeviceUplinkMessageBuilder() \ .set_device_name(device_name) \ .set_device_profile(device_profile) \ @@ -179,7 +183,8 @@ def split_attributes(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUp for parent in parent_futures: future_map.register(parent, [shared_future]) - logger.trace("Flushed final attribute batch (count=%d, size=%d)", len(built.attributes), size) + logger.trace("Flushed final attribute batch (count=%d, size=%d)", + len(built.attributes), size) logger.trace("Total attribute batches created: %d", len(result)) return result diff --git a/tb_mqtt_client/service/gateway/__init__.py b/tb_mqtt_client/service/gateway/__init__.py index fa669aa..cff354e 100644 --- a/tb_mqtt_client/service/gateway/__init__.py +++ b/tb_mqtt_client/service/gateway/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- diff --git a/tb_mqtt_client/service/gateway/client.py b/tb_mqtt_client/service/gateway/client.py index 300fe8e..c4e1f96 100644 --- a/tb_mqtt_client/service/gateway/client.py +++ b/tb_mqtt_client/service/gateway/client.py @@ -68,14 +68,21 @@ def __init__(self, config: Optional[Union[GatewayConfig, Dict]] = None): self._event_dispatcher: DirectEventDispatcher = DirectEventDispatcher() self._uplink_message_sender = GatewayMessageSender() - self._event_dispatcher.register(GatewayEventType.DEVICE_CONNECT, self._uplink_message_sender.send_device_connect) - self._event_dispatcher.register(GatewayEventType.DEVICE_DISCONNECT, self._uplink_message_sender.send_device_disconnect) - self._event_dispatcher.register(GatewayEventType.DEVICE_UPLINK, self._uplink_message_sender.send_uplink_message) - self._event_dispatcher.register(GatewayEventType.DEVICE_ATTRIBUTE_REQUEST, self._uplink_message_sender.send_attributes_request) - self._event_dispatcher.register(GatewayEventType.DEVICE_RPC_RESPONSE, self._uplink_message_sender.send_rpc_response) - self._event_dispatcher.register(GatewayEventType.GATEWAY_CLAIM_REQUEST, self._uplink_message_sender.send_claim_request) - - self._gateway_message_adapter: GatewayMessageAdapter = JsonGatewayMessageAdapter(1000, 1) # Default max payload size and datapoints count limit, should be changed after connection established + self._event_dispatcher.register(GatewayEventType.DEVICE_CONNECT, + self._uplink_message_sender.send_device_connect) + self._event_dispatcher.register(GatewayEventType.DEVICE_DISCONNECT, + self._uplink_message_sender.send_device_disconnect) + self._event_dispatcher.register(GatewayEventType.DEVICE_UPLINK, + self._uplink_message_sender.send_uplink_message) + self._event_dispatcher.register(GatewayEventType.DEVICE_ATTRIBUTE_REQUEST, + self._uplink_message_sender.send_attributes_request) + self._event_dispatcher.register(GatewayEventType.DEVICE_RPC_RESPONSE, + self._uplink_message_sender.send_rpc_response) + self._event_dispatcher.register(GatewayEventType.GATEWAY_CLAIM_REQUEST, + self._uplink_message_sender.send_claim_request) + + # Default max payload size and datapoints count limit, should be changed after connection established + self._gateway_message_adapter: GatewayMessageAdapter = JsonGatewayMessageAdapter(1000, 1) self._uplink_message_sender.set_message_adapter(self._gateway_message_adapter) self._multiplex_dispatcher = None # Placeholder for multiplex dispatcher, if needed @@ -83,12 +90,14 @@ def __init__(self, config: Optional[Union[GatewayConfig, Dict]] = None): message_adapter=self._gateway_message_adapter, device_manager=self.device_manager, stop_event=self._stop_event) - self._gateway_attribute_updates_handler = GatewayAttributeUpdatesHandler(event_dispatcher=self._event_dispatcher, - message_adapter=self._gateway_message_adapter, - device_manager=self.device_manager) - self._gateway_requested_attribute_response_handler = GatewayRequestedAttributeResponseHandler(event_dispatcher=self._event_dispatcher, - message_adapter=self._gateway_message_adapter, - device_manager=self.device_manager) + self._gateway_attribute_updates_handler = GatewayAttributeUpdatesHandler( + event_dispatcher=self._event_dispatcher, + message_adapter=self._gateway_message_adapter, + device_manager=self.device_manager) + self._gateway_requested_attribute_response_handler = GatewayRequestedAttributeResponseHandler( + event_dispatcher=self._event_dispatcher, + message_adapter=self._gateway_message_adapter, + device_manager=self.device_manager) # Gateway-specific rate limits 
self._device_messages_rate_limit = RateLimit("0:0,", name="device_messages") @@ -114,11 +123,12 @@ async def connect(self): logger.info("Gateway connected to ThingsBoard.") - async def connect_device(self, device_name_or_device_connect_message: Union[str, DeviceConnectMessage], device_profile: str = 'default', - wait_for_publish=False) -> Tuple[DeviceSession, List[Union[PublishResult, Future[PublishResult]]]]: + wait_for_publish=False) -> Tuple[ + DeviceSession, + List[Union[PublishResult, Future[PublishResult]]]]: """ Connect a device to the gateway. @@ -195,8 +205,14 @@ async def disconnect_device(self, device_session: DeviceSession, wait_for_publis async def send_device_timeseries(self, device_session: DeviceSession, - data: Union[TimeseriesEntry, List[TimeseriesEntry], Dict[str, Any], List[Dict[str, Any]]], - wait_for_publish: bool) -> Union[List[Union[PublishResult, Future[PublishResult]]], PublishResult, Future[PublishResult], None]: + data: Union[TimeseriesEntry, + List[TimeseriesEntry], + Dict[str, Any], + List[Dict[str, Any]]], + wait_for_publish: bool) -> Optional[Union[ + List[Union[PublishResult, Future[PublishResult]]], + PublishResult, + Future[PublishResult]]]: """ Send timeseries data to the platform for a specific device. :param device_session: The DeviceSession object for the device @@ -231,18 +247,21 @@ async def send_device_timeseries(self, return results[0] if len(results) == 1 else results - async def send_device_attributes(self, device_session: DeviceSession, data: Union[Dict[str, Any], AttributeEntry, list[AttributeEntry]], - wait_for_publish: bool) -> Union[List[Union[PublishResult, Future[PublishResult]]], PublishResult, Future[PublishResult], None]: + wait_for_publish: bool) -> Optional[Union[ + List[Union[PublishResult, Future[PublishResult]]], + PublishResult, + Future[PublishResult]]]: """ Send attributes data to the platform for a specific device. :param device_session: The DeviceSession object for the device :param data: Attributes data to send, can be a single entry or a list of entries :param wait_for_publish: Whether to wait for the publish result """ - logger.trace("Sending attributes data for device %s", device_session.device_info.device_name) + logger.trace("Sending attributes data for device %s", + device_session.device_info.device_name) if not device_session or not data: logger.warning("No device session or data provided for sending attributes") return None @@ -264,23 +283,27 @@ async def send_device_attributes(self, results.append(result) return results[0] if len(results) == 1 else results - async def send_device_attributes_request(self, device_session: DeviceSession, attribute_request: Union[AttributeRequest, GatewayAttributeRequest], - wait_for_publish: bool) -> Union[List[Union[PublishResult, Future[PublishResult]]], PublishResult, Future[PublishResult], None]: + wait_for_publish: bool) -> Optional[Union[ + List[Union[PublishResult, Future[PublishResult]]], + PublishResult, + Future[PublishResult]]]: """ Send a request for device attributes to the platform. 
:param device_session: The DeviceSession object for the device :param attribute_request: Attributes to request, can be a single AttributeRequest or GatewayAttributeRequest :param wait_for_publish: Whether to wait for the publish result """ - logger.trace("Sending attributes request for device %s", device_session.device_info.device_name) + logger.trace("Sending attributes request for device %s", + device_session.device_info.device_name) if not device_session or not attribute_request: logger.warning("No device session or attributes provided for sending attributes request") return None if isinstance(attribute_request, AttributeRequest): - attribute_request = await GatewayAttributeRequest.from_attribute_request(device_session=device_session, attribute_request=attribute_request) + attribute_request = await GatewayAttributeRequest.from_attribute_request(device_session=device_session, + attribute_request=attribute_request) # noqa await self._gateway_requested_attribute_response_handler.register_request(attribute_request) futures = await self._event_dispatcher.dispatch(attribute_request, qos=self._config.qos) @@ -306,7 +329,10 @@ async def send_device_attributes_request(self, async def send_device_claim_request(self, device_session: DeviceSession, gateway_claim_request: GatewayClaimRequest, - wait_for_publish: bool) -> Union[List[Union[PublishResult, Future[PublishResult]]], PublishResult, Future[PublishResult], None]: + wait_for_publish: bool) -> Optional[Union[ + List[Union[PublishResult, Future[PublishResult]]], + PublishResult, + Future[PublishResult]]]: """ Send a claim request for a device to the platform. :param device_session: The DeviceSession object for the device @@ -365,9 +391,12 @@ async def _subscribe_to_gateway_topics(self): while not sub_future.done(): await sleep(0.01) - self._mqtt_manager.register_handler(mqtt_topics.GATEWAY_ATTRIBUTES_TOPIC, self._gateway_attribute_updates_handler.handle) - self._mqtt_manager.register_handler(mqtt_topics.GATEWAY_RPC_TOPIC, self._gateway_rpc_handler.handle) - self._mqtt_manager.register_handler(mqtt_topics.GATEWAY_ATTRIBUTES_RESPONSE_TOPIC, self._gateway_requested_attribute_response_handler.handle) + self._mqtt_manager.register_handler(mqtt_topics.GATEWAY_ATTRIBUTES_TOPIC, + self._gateway_attribute_updates_handler.handle) + self._mqtt_manager.register_handler(mqtt_topics.GATEWAY_RPC_TOPIC, + self._gateway_rpc_handler.handle) + self._mqtt_manager.register_handler(mqtt_topics.GATEWAY_ATTRIBUTES_RESPONSE_TOPIC, + self._gateway_requested_attribute_response_handler.handle) async def _unsubscribe_from_gateway_topics(self): """ @@ -404,12 +433,15 @@ async def _handle_rate_limit_response(self, response: RPCResponse): # noqa try: gateway_rate_limits = response.result.get('gatewayRateLimits', {}) - await self._gateway_rate_limiter.message_rate_limit.set_limit(gateway_rate_limits.get('messages', '0:0,'), percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) - await self._gateway_rate_limiter.telemetry_message_rate_limit.set_limit(gateway_rate_limits.get('telemetryMessages', '0:0,'), percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) - await self._gateway_rate_limiter.telemetry_datapoints_rate_limit.set_limit(gateway_rate_limits.get('telemetryDataPoints', '0:0,'), percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) + await self._gateway_rate_limiter.message_rate_limit.set_limit(gateway_rate_limits.get('messages', '0:0,'), + percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) + await self._gateway_rate_limiter.telemetry_message_rate_limit.set_limit( + gateway_rate_limits.get('telemetryMessages', 
'0:0,'), percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) + await self._gateway_rate_limiter.telemetry_datapoints_rate_limit.set_limit( + gateway_rate_limits.get('telemetryDataPoints', '0:0,'), percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) self._gateway_message_adapter.splitter.max_payload_size = self.max_payload_size - self._gateway_message_adapter.splitter.max_datapoints = self._gateway_rate_limiter.telemetry_datapoints_rate_limit.minimal_limit + self._gateway_message_adapter.splitter.max_datapoints = self._gateway_rate_limiter.telemetry_datapoints_rate_limit.minimal_limit # noqa self._mqtt_manager.set_gateway_rate_limits_received() return device_rate_limits_processing_result diff --git a/tb_mqtt_client/service/gateway/device_session.py b/tb_mqtt_client/service/gateway/device_session.py index 281e611..204e940 100644 --- a/tb_mqtt_client/service/gateway/device_session.py +++ b/tb_mqtt_client/service/gateway/device_session.py @@ -15,7 +15,7 @@ import asyncio from dataclasses import dataclass, field from time import time -from typing import Callable, Awaitable, Optional, Union +from typing import Callable, Awaitable, Optional from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse @@ -37,9 +37,12 @@ class DeviceSession: provisioned: bool = False state: DeviceSessionState = DeviceSessionState.CONNECTED - attribute_update_callback: Optional[Callable[['DeviceSession','AttributeUpdate'], Union[Awaitable[None], None]]] = None - attribute_response_callback: Optional[Callable[['DeviceSession','RequestedAttributeResponse'], Union[Awaitable[None], None]]] = None - rpc_request_callback: Optional[Callable[['DeviceSession','GatewayRPCRequest'], Union[Awaitable[Union['GatewayRPCResponse', None]], None]]] = None + attribute_update_callback: Optional[Callable[['DeviceSession', 'AttributeUpdate'], + Optional[Awaitable[None]]]] = None + attribute_response_callback: Optional[Callable[['DeviceSession', 'RequestedAttributeResponse'], + Optional[Awaitable[None]]]] = None + rpc_request_callback: Optional[Callable[['DeviceSession', 'GatewayRPCRequest'], + Optional[Awaitable[Optional['GatewayRPCResponse']]]]] = None def update_state(self, new_state: DeviceSessionState): self.state = new_state @@ -49,16 +52,20 @@ def update_state(self, new_state: DeviceSessionState): def update_last_seen(self): self.last_seen_at = int(time() * 1000) - def set_attribute_update_callback(self, cb: Callable[['DeviceSession','AttributeUpdate'], Union[Awaitable[None], None]]): + def set_attribute_update_callback(self, + cb: Callable[['DeviceSession', 'AttributeUpdate'], Optional[Awaitable[None]]]): self.attribute_update_callback = cb - def set_attribute_response_callback(self, cb: Callable[['DeviceSession','RequestedAttributeResponse'], Union[Awaitable[None], None]]): + def set_attribute_response_callback(self, cb: Callable[ + ['DeviceSession', 'RequestedAttributeResponse'], Optional[Awaitable[None]]]): self.attribute_response_callback = cb - def set_rpc_request_callback(self, cb: Callable[['DeviceSession','GatewayRPCRequest'], Union[Awaitable[Union['GatewayRPCResponse', None]], None]]): + def set_rpc_request_callback(self, cb: Callable[ + ['DeviceSession', 'GatewayRPCRequest'], Optional[Awaitable[Optional['GatewayRPCResponse']]]]): self.rpc_request_callback = cb - async def handle_event_to_device(self, event: BaseGatewayEvent) -> Optional[Awaitable[Union['GatewayRPCResponse', None]]]: + async def handle_event_to_device(self, event: 
BaseGatewayEvent) -> ( + Optional)[Awaitable[Optional['GatewayRPCResponse']]]: cb = None if GatewayEventType.DEVICE_ATTRIBUTE_UPDATE == event.event_type \ and isinstance(event, AttributeUpdate): diff --git a/tb_mqtt_client/service/gateway/direct_event_dispatcher.py b/tb_mqtt_client/service/gateway/direct_event_dispatcher.py index b2ab01e..efc17c0 100644 --- a/tb_mqtt_client/service/gateway/direct_event_dispatcher.py +++ b/tb_mqtt_client/service/gateway/direct_event_dispatcher.py @@ -30,6 +30,7 @@ class DirectEventDispatcher: """ Direct event dispatcher for handling gateway events. """ + def __init__(self): self._handlers: Dict[GatewayEventType, List[EventCallback]] = defaultdict(list) self._lock = asyncio.Lock() @@ -46,7 +47,7 @@ def unregister(self, event_type: GatewayEventType, callback: EventCallback): if not self._handlers[event_type]: del self._handlers[event_type] - async def dispatch(self, event: GatewayEvent, *args, device_session: DeviceSession=None, **kwargs): + async def dispatch(self, event: GatewayEvent, *args, device_session: DeviceSession = None, **kwargs): if device_session is not None: return await device_session.handle_event_to_device(event) async with self._lock: diff --git a/tb_mqtt_client/service/gateway/gateway_client_interface.py b/tb_mqtt_client/service/gateway/gateway_client_interface.py index ccb8be2..21e70a1 100644 --- a/tb_mqtt_client/service/gateway/gateway_client_interface.py +++ b/tb_mqtt_client/service/gateway/gateway_client_interface.py @@ -14,7 +14,7 @@ from abc import ABC, abstractmethod from asyncio import Future -from typing import Union, List, Tuple, Dict, Any +from typing import Union, List, Tuple, Dict, Any, Optional from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry @@ -29,33 +29,46 @@ class GatewayClientInterface(BaseClient, ABC): @abstractmethod - async def connect_device(self, device_name: str, device_profile: str, wait_for_publish: bool) -> \ - Tuple[DeviceSession, List[Union[PublishResult, Future[PublishResult]]]]: ... + async def connect_device(self, + device_name: str, + device_profile: str, + wait_for_publish: bool) -> Tuple[DeviceSession, + List[Union[PublishResult, Future[PublishResult]]]]: ... @abstractmethod - async def disconnect_device(self, device_session: DeviceSession, wait_for_publish: bool) -> \ - List[Union[PublishResult, Future[PublishResult]]]: ... + async def disconnect_device(self, + device_session: DeviceSession, + wait_for_publish: bool) -> List[Union[PublishResult, Future[PublishResult]]]: ... @abstractmethod async def send_device_timeseries(self, device_session: DeviceSession, - data: Union[TimeseriesEntry, List[TimeseriesEntry], Dict[str, Any], List[Dict[str, Any]]], - wait_for_publish: bool) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: ... + data: Union[TimeseriesEntry, + List[TimeseriesEntry], + Dict[str, Any], + List[Dict[str, Any]]], + wait_for_publish: bool) -> Optional[List[Union[PublishResult, + Future[PublishResult]]]]: ... @abstractmethod async def send_device_attributes(self, device_session: DeviceSession, - attributes: Union[Dict[str, Any], AttributeEntry, list[AttributeEntry]], - wait_for_publish: bool) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: ... + attributes: Union[Dict[str, Any], + AttributeEntry, + List[AttributeEntry]], + wait_for_publish: bool) -> Optional[List[Union[PublishResult, + Future[PublishResult]]]]: ... 
@abstractmethod async def send_device_attributes_request(self, device_session: DeviceSession, attributes: Union[AttributeRequest, GatewayAttributeRequest], - wait_for_publish: bool) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: ... + wait_for_publish: bool) -> Optional[ + List[Union[PublishResult, Future[PublishResult]]]]: ... @abstractmethod async def send_device_claim_request(self, device_session: DeviceSession, gateway_claim_request: GatewayClaimRequest, - wait_for_publish: bool) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: ... + wait_for_publish: bool) -> Union[List[Union[PublishResult, + Future[PublishResult]]], None]: ... diff --git a/tb_mqtt_client/service/gateway/handlers/gateway_attribute_updates_handler.py b/tb_mqtt_client/service/gateway/handlers/gateway_attribute_updates_handler.py index 1f2ef78..717ed55 100644 --- a/tb_mqtt_client/service/gateway/handlers/gateway_attribute_updates_handler.py +++ b/tb_mqtt_client/service/gateway/handlers/gateway_attribute_updates_handler.py @@ -19,6 +19,7 @@ class GatewayAttributeUpdatesHandler: """Handles shared attribute updates for devices connected to a gateway.""" + def __init__(self, event_dispatcher: DirectEventDispatcher, message_adapter: GatewayMessageAdapter, @@ -36,4 +37,5 @@ async def handle(self, topic: str, payload: bytes): device_session = self.device_manager.get_by_name(gateway_attribute_update.device_name) if device_session: gateway_attribute_update.set_device_session(device_session) - await self.event_dispatcher.dispatch(gateway_attribute_update.attribute_update, device_session=device_session) # noqa + await self.event_dispatcher.dispatch(gateway_attribute_update.attribute_update, # noqa + device_session=device_session) diff --git a/tb_mqtt_client/service/gateway/handlers/gateway_requested_attributes_response_handler.py b/tb_mqtt_client/service/gateway/handlers/gateway_requested_attributes_response_handler.py index 863c815..20377a6 100644 --- a/tb_mqtt_client/service/gateway/handlers/gateway_requested_attributes_response_handler.py +++ b/tb_mqtt_client/service/gateway/handlers/gateway_requested_attributes_response_handler.py @@ -24,7 +24,6 @@ logger = get_logger(__name__) - AttributeResponseCallback: TypeAlias = Callable[[GatewayRequestedAttributeResponse], Coroutine[Any, Any, None]] @@ -37,8 +36,7 @@ def __init__(self, event_dispatcher: Any, message_adapter: GatewayMessageAdapter self._event_dispatcher = event_dispatcher self._message_adapter: Union[GatewayMessageAdapter, None] = message_adapter self._device_manager = device_manager - self._pending_attribute_requests: Dict[Tuple[str, int], Tuple[GatewayAttributeRequest, - Union[Task, None]]] = {} + self._pending_attribute_requests: Dict[Tuple[str, int], Tuple[GatewayAttributeRequest, Union[Task, None]]] = {} async def register_request(self, request: GatewayAttributeRequest, @@ -55,7 +53,8 @@ async def register_request(self, if timeout > 0: timeout_task = asyncio.get_event_loop().call_later(timeout, self._on_timeout, device_name, request_id) self._pending_attribute_requests[key] = (request, timeout_task) - logger.debug("Registered attribute request with ID %s for device %s", request_id, device_name) + logger.debug("Registered attribute request with ID %s for device %s", + request_id, device_name) def unregister_request(self, device_name: str, request_id: int): """ @@ -65,9 +64,11 @@ def unregister_request(self, device_name: str, request_id: int): key = (device_name, request_id) if key in self._pending_attribute_requests: 
self._pending_attribute_requests.pop(key) - logger.debug("Unregistered attribute request with ID %s for device %s", request_id, device_name) + logger.debug("Unregistered attribute request with ID %s for device %s", + request_id, device_name) else: - logger.debug("Attempted to unregister non-existent request ID %s for device %s", request_id, device_name) + logger.debug("Attempted to unregister non-existent request ID %s for device %s", + request_id, device_name) async def handle(self, topic: str, payload: bytes): """ @@ -90,7 +91,8 @@ async def handle(self, topic: str, payload: bytes): attribute_request, timeout_task = attribute_request_with_callback if timeout_task: timeout_task.cancel() - requested_attribute_response = self._message_adapter.parse_gateway_requested_attribute_response(attribute_request, deserialized_data) + requested_attribute_response = self._message_adapter.parse_gateway_requested_attribute_response( + attribute_request, deserialized_data) device_session = self._device_manager.get_by_name(device_name) if not device_session: logger.warning("No device session found for device: %s", device_name) @@ -113,9 +115,11 @@ def _on_timeout(self, device_name: str, request_id: int): key = (device_name, request_id) if key in self._pending_attribute_requests: self._pending_attribute_requests.pop(key) - logger.warning("Request ID %s for device %s has timed out and has been unregistered.", request_id, device_name) + logger.warning("Request ID %s for device %s has timed out and has been unregistered.", + request_id, device_name) else: - logger.debug("Attempted to unregister non-existent request ID %s for device %s on timeout", request_id, device_name) + logger.debug("Attempted to unregister non-existent request ID %s for device %s on timeout", + request_id, device_name) def _handle_callback_exception(self, task: asyncio.Task): try: diff --git a/tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py b/tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py index 465a270..a8dd5f5 100644 --- a/tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py +++ b/tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import asyncio from typing import Awaitable, Callable, Optional @@ -71,12 +72,14 @@ async def handle(self, topic: str, payload: bytes) -> None: elif not isinstance(result, GatewayRPCResponse): raise TypeError("RPC callback must return an instance of GatewayRPCResponse, got: %s", type(result)) logger.debug("RPC response for device %r method id: %i - %s with result: %s", - rpc_request.device_name, result.request_id, rpc_request.method, result.result) + rpc_request.device_name, result.request_id, rpc_request.method, result.result) except Exception as e: logger.exception("Failed to process RPC request: %s", e) if rpc_request is None: return None - result = GatewayRPCResponse.build(device_name=rpc_request.device_name, request_id=rpc_request.request_id, error=e) + result = GatewayRPCResponse.build(device_name=rpc_request.device_name, + request_id=rpc_request.request_id, + error=e) if not device_session: logger.warning("No device session found for device: %s, cannot send RPC response", @@ -86,8 +89,9 @@ async def handle(self, topic: str, payload: bytes) -> None: future = await self._event_dispatcher.dispatch(result) # noqa if not future: - logger.warning("No publish futures were returned from message queue for RPC response of device %s, request id %i", - rpc_request.device_name, rpc_request.request_id) + logger.warning( + "No publish futures were returned from message queue for RPC response of device %s, request id %i", + rpc_request.device_name, rpc_request.request_id) return None try: await await_or_stop(future, timeout=1, stop_event=self._stop_event) diff --git a/tb_mqtt_client/service/gateway/message_adapter.py b/tb_mqtt_client/service/gateway/message_adapter.py index 8e7e23b..210b4ee 100644 --- a/tb_mqtt_client/service/gateway/message_adapter.py +++ b/tb_mqtt_client/service/gateway/message_adapter.py @@ -45,6 +45,7 @@ class GatewayMessageAdapter(ABC): """ Adapter for converting events to uplink messages and received messages to events. """ + def __init__(self, max_payload_size: Optional[int] = None, max_datapoints: Optional[int] = None): self._splitter = GatewayMessageSplitter(max_payload_size, max_datapoints) logger.trace("GatewayMessageAdapter initialized with max_payload_size=%s, max_datapoints=%s", @@ -60,8 +61,8 @@ def splitter(self) -> GatewayMessageSplitter: @abstractmethod def build_uplink_messages( - self, - messages: List[MqttPublishMessage] + self, + messages: List[MqttPublishMessage] ) -> List[MqttPublishMessage]: """ Build a list of topic-payload pairs from the given messages. @@ -71,7 +72,9 @@ def build_uplink_messages( pass @abstractmethod - def build_device_connect_message_payload(self, device_connect_message: DeviceConnectMessage, qos) -> MqttPublishMessage: + def build_device_connect_message_payload(self, + device_connect_message: DeviceConnectMessage, + qos) -> MqttPublishMessage: """ Build the payload for a device connect message. This method should be implemented to handle the specific format of the payload. @@ -79,7 +82,9 @@ def build_device_connect_message_payload(self, device_connect_message: DeviceCon pass @abstractmethod - def build_device_disconnect_message_payload(self, device_disconnect_message: DeviceDisconnectMessage, qos) -> MqttPublishMessage: + def build_device_disconnect_message_payload(self, + device_disconnect_message: DeviceDisconnectMessage, + qos) -> MqttPublishMessage: """ Build the payload for a device disconnect message. This method should be implemented to handle the specific format of the payload. 
@@ -87,7 +92,9 @@ def build_device_disconnect_message_payload(self, device_disconnect_message: Dev pass @abstractmethod - def build_gateway_attribute_request_payload(self, attribute_request: GatewayAttributeRequest, qos) -> MqttPublishMessage: + def build_gateway_attribute_request_payload(self, + attribute_request: GatewayAttributeRequest, + qos) -> MqttPublishMessage: """ Build the payload for a gateway attribute request. This method should be implemented to handle the specific format of the payload. @@ -121,7 +128,8 @@ def parse_attribute_update(self, data: Dict[str, Any]) -> GatewayAttributeUpdate @abstractmethod def parse_gateway_requested_attribute_response(self, gateway_attribute_request: GatewayAttributeRequest, - data: Dict[str, Any]) -> Union[GatewayRequestedAttributeResponse, None]: + data: Dict[str, Any]) -> Union[ + GatewayRequestedAttributeResponse, None]: """ Parse the gateway attribute response data into an GatewayAttributeResponse. This method should be implemented to handle the specific format of the payload. @@ -225,33 +233,41 @@ def build_uplink_messages(self, messages: List[MqttPublishMessage]) -> List[Mqtt logger.trace("Generated %d MqttPublishMessage(s) for gateway uplink.", len(result)) return result - def build_device_connect_message_payload(self, device_connect_message: DeviceConnectMessage, qos) -> MqttPublishMessage: + def build_device_connect_message_payload(self, + device_connect_message: DeviceConnectMessage, + qos) -> MqttPublishMessage: """ Build the payload for a device connect message. This method serializes the DeviceConnectMessage to JSON format. """ try: payload = dumps(device_connect_message.to_payload_format()) - logger.trace("Built device connect message payload for device='%s'", device_connect_message.device_name) + logger.trace("Built device connect message payload for device='%s'", + device_connect_message.device_name) return MqttPublishMessage(GATEWAY_CONNECT_TOPIC, payload, qos=1) except Exception as e: logger.error("Failed to build device connect message payload: %s", str(e)) raise ValueError("Invalid device connect message format") from e - def build_device_disconnect_message_payload(self,device_disconnect_message: DeviceDisconnectMessage, qos) -> MqttPublishMessage: + def build_device_disconnect_message_payload(self, + device_disconnect_message: DeviceDisconnectMessage, + qos) -> MqttPublishMessage: """ Build the payload for a device disconnect message. This method serializes the DeviceDisconnectMessage to JSON format. """ try: payload = dumps(device_disconnect_message.to_payload_format()) - logger.trace("Built device disconnect message payload for device='%s'", device_disconnect_message.device_name) + logger.trace("Built device disconnect message payload for device='%s'", + device_disconnect_message.device_name) return MqttPublishMessage(GATEWAY_DISCONNECT_TOPIC, payload, qos) except Exception as e: logger.error("Failed to build device disconnect message payload: %s", str(e)) raise ValueError("Invalid device disconnect message format") from e - def build_gateway_attribute_request_payload(self, attribute_request: GatewayAttributeRequest, qos) -> MqttPublishMessage: + def build_gateway_attribute_request_payload(self, + attribute_request: GatewayAttributeRequest, + qos) -> MqttPublishMessage: """ Build the payload for a gateway attribute request. This method serializes the GatewayAttributeRequest to JSON format. 
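The JsonGatewayMessageAdapter builders touched in the hunks above all reduce to the same shape: call to_payload_format() on the entity, JSON-encode the result, and wrap the bytes in a publish message for the matching gateway topic. The stand-alone sketch below mirrors that shape only; the stub classes, the builder function, and the topic constant are simplified assumptions for illustration, not the definitions from the modules changed in this patch.

from dataclasses import dataclass
from json import dumps

GATEWAY_CONNECT_TOPIC = "v1/gateway/connect"  # assumed topic value, for illustration only


@dataclass
class StubConnectMessage:
    device_name: str
    device_profile: str = "default"

    def to_payload_format(self) -> dict:
        # The dict shape here is illustrative; the real entity defines its own format.
        return {"device": self.device_name, "type": self.device_profile}


@dataclass
class StubPublishMessage:
    topic: str
    payload: bytes
    qos: int = 1


def build_connect_payload(message: StubConnectMessage, qos: int = 1) -> StubPublishMessage:
    # Serialize the entity and wrap it for publishing, mirroring the adapter builders above.
    return StubPublishMessage(GATEWAY_CONNECT_TOPIC, dumps(message.to_payload_format()).encode(), qos)


print(build_connect_payload(StubConnectMessage("Sensor A")).payload)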
@@ -286,7 +302,8 @@ def build_claim_request_payload(self, claim_request: GatewayClaimRequest, qos) - """ try: payload = dumps(claim_request.to_payload_format()) - logger.trace("Built claim request payload for devices: %s", list(claim_request.devices_requests.keys())) + logger.trace("Built claim request payload for devices: %s", + list(claim_request.devices_requests.keys())) return MqttPublishMessage(GATEWAY_CLAIM_TOPIC, payload, qos) except Exception as e: logger.error("Failed to build claim request payload: %s", str(e)) @@ -303,7 +320,8 @@ def parse_attribute_update(self, data: Dict[str, Any]) -> GatewayAttributeUpdate def parse_gateway_requested_attribute_response(self, gateway_attribute_request: GatewayAttributeRequest, - data: Dict[str, Any]) -> Union[GatewayRequestedAttributeResponse, None]: + data: Dict[str, Any]) -> Union[ + GatewayRequestedAttributeResponse, None]: """ Parse the gateway attribute response data into a GatewayRequestedAttributeResponse. This method extracts the device name, shared and client attributes from the payload. @@ -352,7 +370,7 @@ def parse_rpc_request(self, topic: str, data: Dict[str, Any]) -> GatewayRPCReque This method deserializes the payload into a GatewayRPCRequest object. """ try: - return GatewayRPCRequest._deserialize_from_dict(data) # noqa + return GatewayRPCRequest._deserialize_from_dict(data) # noqa except Exception as e: logger.error("Failed to parse RPC request: %s", str(e)) raise ValueError("Invalid RPC request format") from e diff --git a/tb_mqtt_client/service/gateway/message_sender.py b/tb_mqtt_client/service/gateway/message_sender.py index 4c05d0c..1c364a3 100644 --- a/tb_mqtt_client/service/gateway/message_sender.py +++ b/tb_mqtt_client/service/gateway/message_sender.py @@ -41,7 +41,8 @@ def __init__(self): self._message_queue: Optional[MessageService] = None self._message_adapter: Optional[GatewayMessageAdapter] = None - async def send_uplink_message(self, message: GatewayUplinkMessage, qos=1) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: + async def send_uplink_message(self, message: GatewayUplinkMessage, qos=1) -> ( + Optional)[List[Union[PublishResult, Future[PublishResult]]]]: """ Sends a list of uplink messages to the platform. @@ -74,7 +75,8 @@ async def send_uplink_message(self, message: GatewayUplinkMessage, qos=1) -> Uni futures.extend(mqtt_message.delivery_futures) return futures - async def send_device_connect(self, device_connect_message: DeviceConnectMessage, qos=1) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: + async def send_device_connect(self, device_connect_message: DeviceConnectMessage, qos=1) -> ( + Optional)[List[Union[PublishResult, Future[PublishResult]]]]: """ Sends a device connect message to the platform. @@ -83,13 +85,16 @@ async def send_device_connect(self, device_connect_message: DeviceConnectMessage :returns: List of PublishResult or Future[PublishResult] if successful, None if failed. """ if self._message_queue is None: - logger.error("Cannot send device connect message. Message queue is not set, do you connected to the platform?") + logger.error( + "Cannot send device connect message. 
Message queue is not set, are you connected to the platform?") return None - mqtt_message = self._message_adapter.build_device_connect_message_payload(device_connect_message=device_connect_message, qos=qos) + mqtt_message = self._message_adapter.build_device_connect_message_payload( + device_connect_message=device_connect_message, qos=qos) await self._message_queue.publish(mqtt_message) return mqtt_message.delivery_futures - async def send_device_disconnect(self, device_disconnect_message: DeviceDisconnectMessage, qos=1) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: + async def send_device_disconnect(self, device_disconnect_message: DeviceDisconnectMessage, qos=1) -> ( + Optional)[List[Union[PublishResult, Future[PublishResult]]]]: """ Sends a device disconnect message to the platform. @@ -98,13 +103,16 @@ async def send_device_disconnect(self, device_disconnect_message: DeviceDisconne :returns: List of PublishResult or Future[PublishResult] if successful, None if failed. """ if self._message_queue is None: - logger.error("Cannot send device disconnect message. Message queue is not set, do you connected to the platform?") + logger.error( + "Cannot send device disconnect message. Message queue is not set, are you connected to the platform?") return None - mqtt_message = self._message_adapter.build_device_disconnect_message_payload(device_disconnect_message=device_disconnect_message, qos=qos) + mqtt_message = self._message_adapter.build_device_disconnect_message_payload( + device_disconnect_message=device_disconnect_message, qos=qos) await self._message_queue.publish(mqtt_message) return mqtt_message.delivery_futures - async def send_attributes_request(self, attribute_request: GatewayAttributeRequest, qos=1) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: + async def send_attributes_request(self, attribute_request: GatewayAttributeRequest, qos=1) -> ( + Optional)[List[Union[PublishResult, Future[PublishResult]]]]: """ Sends an attribute request message to the platform. @@ -115,11 +123,13 @@ async def send_attributes_request(self, attribute_request: GatewayAttributeReque :returns: List of PublishResult or Future[PublishResult] if successful, None if failed. """ if self._message_queue is None: logger.error("Cannot send attribute request. Message queue is not set, do you connected to the platform?") return None - mqtt_message = self._message_adapter.build_gateway_attribute_request_payload(attribute_request=attribute_request, qos=qos) + mqtt_message = self._message_adapter.build_gateway_attribute_request_payload( + attribute_request=attribute_request, qos=qos) await self._message_queue.publish(mqtt_message) return mqtt_message.delivery_futures - async def send_rpc_response(self, rpc_response: GatewayRPCResponse, qos=1) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: + async def send_rpc_response(self, rpc_response: GatewayRPCResponse, qos=1) -> ( + Optional)[List[Union[PublishResult, Future[PublishResult]]]]: """ Sends an RPC response message to the platform. @@ -134,7 +144,8 @@ async def send_rpc_response(self, rpc_response: GatewayRPCResponse, qos=1) -> Un await self._message_queue.publish(mqtt_message) return mqtt_message.delivery_futures - async def send_claim_request(self, claim_request: GatewayClaimRequest, qos=1) -> Union[List[Union[PublishResult, Future[PublishResult]]], None]: + async def send_claim_request(self, claim_request: GatewayClaimRequest, qos=1) -> ( + Optional)[List[Union[PublishResult, Future[PublishResult]]]]: """ Sends a claim request message to the platform.
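Each sender method in the message_sender.py hunks above follows the same pattern: have the adapter build the publish message, put it on the message queue, and hand back the message's delivery futures so the caller decides whether to await broker acknowledgement. The minimal self-contained sketch below shows only that pattern; every name in it (StubMessageQueue, DeliveryResult, the send_device_connect helper, and the topic string) is an illustrative stand-in rather than the library's actual API.

import asyncio
from dataclasses import dataclass
from typing import List


@dataclass
class DeliveryResult:
    topic: str
    reason_code: int


class StubMessageQueue:
    """Stands in for the message service: publish() returns a future that
    resolves once the (simulated) broker acknowledges the message."""

    async def publish(self, topic: str, payload: bytes) -> asyncio.Future:
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        # Simulate the broker acknowledgement arriving shortly after publishing.
        loop.call_later(0.05, future.set_result, DeliveryResult(topic, 0))
        return future


async def send_device_connect(queue: StubMessageQueue, device_name: str) -> List[asyncio.Future]:
    # Build the payload, enqueue it, and return the delivery future(s);
    # the caller chooses whether to await them (wait_for_publish behaviour).
    payload = ('{"device": "%s"}' % device_name).encode()
    return [await queue.publish("v1/gateway/connect", payload)]


async def main():
    futures = await send_device_connect(StubMessageQueue(), "Sensor A")
    results = await asyncio.gather(*futures)  # equivalent of wait_for_publish=True
    print([r.reason_code for r in results])


asyncio.run(main())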
diff --git a/tb_mqtt_client/service/gateway/message_splitter.py b/tb_mqtt_client/service/gateway/message_splitter.py index c9b416f..79a3796 100644 --- a/tb_mqtt_client/service/gateway/message_splitter.py +++ b/tb_mqtt_client/service/gateway/message_splitter.py @@ -32,7 +32,10 @@ class GatewayMessageSplitter(BaseMessageSplitter): DEFAULT_MAX_PAYLOAD_SIZE = 55000 # Default max payload size in bytes def __init__(self, max_payload_size: int = 55000, max_datapoints: int = 0): - self._max_payload_size = max_payload_size if max_payload_size is not None and max_payload_size > 0 else self.DEFAULT_MAX_PAYLOAD_SIZE + if max_payload_size is not None and max_payload_size > 0: + self._max_payload_size = max_payload_size + else: + self._max_payload_size = self.DEFAULT_MAX_PAYLOAD_SIZE self._max_payload_size = self._max_payload_size - DEFAULT_FIELDS_SIZE self._max_datapoints = max_datapoints if max_datapoints is not None and max_datapoints > 0 else 0 logger.trace("GatewayMessageSplitter initialized with max_payload_size=%d, max_datapoints=%d", @@ -190,4 +193,4 @@ def max_datapoints(self) -> int: def max_datapoints(self, value: int): old = self._max_datapoints self._max_datapoints = value if value > 0 else 0 - logger.debug("Updated max_datapoints: %d -> %d", old, self._max_datapoints) \ No newline at end of file + logger.debug("Updated max_datapoints: %d -> %d", old, self._max_datapoints) diff --git a/tb_mqtt_client/service/message_service.py b/tb_mqtt_client/service/message_service.py index a7bee96..0e8db72 100644 --- a/tb_mqtt_client/service/message_service.py +++ b/tb_mqtt_client/service/message_service.py @@ -14,7 +14,7 @@ import asyncio from contextlib import suppress -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple from tb_mqtt_client.common.logging_utils import get_logger, TRACE_LEVEL from tb_mqtt_client.common.mqtt_message import MqttPublishMessage @@ -62,7 +62,7 @@ def __init__(self, self._main_stop_event, self._mqtt_manager, self._device_rate_limiter, - self._gateway_rate_limiter,) + self._gateway_rate_limiter) self._device_uplink_messages_queue: AsyncDeque = AsyncDeque(maxlen=max_queue_size) self._device_uplink_message_worker = MessageQueueWorker("DeviceUplinkMessageWorker", self._device_uplink_messages_queue, @@ -89,9 +89,12 @@ def __init__(self, self._retry_by_qos_task = asyncio.create_task(self._dispatch_retry_by_qos_queue_loop()) self._initial_queue_task = asyncio.create_task(self._dispatch_initial_queue_loop()) - self._service_queue_task = asyncio.create_task(self._dispatch_queue_loop(self._service_queue, self._service_message_worker)) - self._device_uplink_messages_queue_task = asyncio.create_task(self._dispatch_queue_loop(self._device_uplink_messages_queue, self._device_uplink_message_worker)) - self._gateway_uplink_messages_queue_task = asyncio.create_task(self._dispatch_queue_loop(self._gateway_uplink_messages_queue, self._gateway_uplink_message_worker)) + self._service_queue_task = asyncio.create_task( + self._dispatch_queue_loop(self._service_queue, self._service_message_worker)) + self._device_uplink_messages_queue_task = asyncio.create_task( + self._dispatch_queue_loop(self._device_uplink_messages_queue, self._device_uplink_message_worker)) + self._gateway_uplink_messages_queue_task = asyncio.create_task( + self._dispatch_queue_loop(self._gateway_uplink_messages_queue, self._gateway_uplink_message_worker)) self._rate_limit_refill_task = asyncio.create_task(self._rate_limit_refill_loop()) self.__print_queue_statistics_task = 
asyncio.create_task(self.print_queues_statistics()) @@ -104,7 +107,8 @@ async def publish(self, message: MqttPublishMessage) -> Optional[List[asyncio.Fu """ try: if logger.isEnabledFor(TRACE_LEVEL): - logger.trace(f"Pushing message to queue with delivery futures: {[f.uuid for f in message.delivery_futures]}") + logger.trace( + f"Pushing message to queue with delivery futures: {[f.uuid for f in message.delivery_futures]}") await self._initial_queue.put(message) except Exception as e: logger.error("Failed to push message to queue: %s", e) @@ -136,7 +140,8 @@ async def _dispatch_initial_queue_loop(self): # If the message is a DeviceUplinkMessage, process it with the device adapter device_messages.append(message) else: - logger.warning("Unknown message type in initial queue: %s", type(message.original_payload)) + logger.warning("Unknown message type in initial queue: %s", + type(message.original_payload)) if gateway_messages: # Process gateway messages in batches @@ -172,11 +177,11 @@ async def _dispatch_queue_loop(self, queue: AsyncDeque, worker: 'MessageQueueWor await asyncio.sleep(0.01) continue logger.trace("Processing message from queue: %s, message_id: %s, message payload: %s", - message.topic, message.message_id, message.original_payload) + message.topic, message.message_id, message.original_payload) expected_duration, expected_tokens, triggered_rate_limit = await worker.process(message) if triggered_rate_limit: logger.trace("Reinserting message to the front of the queue: %s, message payload: %s", - message.uuid, message.original_payload) + message.uuid, message.original_payload) await queue.reinsert_front(message) triggered_rate_limit.set_required_tokens(expected_duration, expected_tokens) await triggered_rate_limit.required_tokens_ready.wait() @@ -213,7 +218,8 @@ async def _dispatch_retry_by_qos_queue_loop(self): elif isinstance(message.original_payload, DeviceUplinkMessage): await self._device_uplink_messages_queue.reinsert_front(message) else: - logger.warning("Unknown message type in retry queue: %s", type(message.original_payload)) + logger.warning("Unknown message type in retry queue: %s", + type(message.original_payload)) except asyncio.CancelledError: break @@ -245,7 +251,7 @@ async def shutdown(self): await self.clear() logger.debug("MessageQueue shutdown complete, message queue size: %d", - self._initial_queue.size()) + self._initial_queue.size()) @staticmethod async def _cancel_tasks(tasks: set[asyncio.Task]): @@ -261,7 +267,7 @@ def is_empty(self): async def clear(self): logger.debug("Clearing message queue...") for queue in [self._initial_queue, self._service_queue, - self._device_uplink_messages_queue, self._gateway_uplink_messages_queue]: + self._device_uplink_messages_queue, self._gateway_uplink_messages_queue]: while not queue.is_empty(): message: MqttPublishMessage = await queue.get() for future in message.delivery_futures: @@ -270,7 +276,8 @@ async def clear(self): topic=message.topic, qos=message.qos, message_id=-1, - payload_size=message.payload_size if isinstance(message.payload, bytes) else message.payload.size, + payload_size=message.payload_size if isinstance(message.payload, + bytes) else message.payload.size, reason_code=-1 )) logger.debug("Message queue cleared.") @@ -325,6 +332,7 @@ async def print_queues_statistics(self): active) await asyncio.sleep(60) + class MessageQueueWorker: def __init__(self, name, @@ -343,9 +351,10 @@ def __init__(self, async def process(self, message: MqttPublishMessage) -> Tuple[Optional[int], Optional[int], Optional[RateLimit]]: 
message_rate_limit, datapoints_rate_limit = self._get_rate_limits_for_message(message) if message_rate_limit.has_limit() or datapoints_rate_limit.has_limit(): - triggered_rate_limit_entry, expected_tokens, rate_limit = await self.check_rate_limits_for_message(datapoints_count=message.datapoints, - message_rate_limit=message_rate_limit, - datapoints_rate_limit=datapoints_rate_limit) + triggered_rate_limit_entry, expected_tokens, rate_limit = await self.check_rate_limits_for_message( + datapoints_count=message.datapoints, + message_rate_limit=message_rate_limit, + datapoints_rate_limit=datapoints_rate_limit) if triggered_rate_limit_entry is not None: triggered_duration = triggered_rate_limit_entry[1] @@ -359,7 +368,6 @@ async def process(self, message: MqttPublishMessage) -> Tuple[Optional[int], Opt await self._mqtt_manager.publish(message) return None, None, None - def _get_rate_limits_for_message(self, message: MqttPublishMessage) -> Tuple[RateLimit, RateLimit]: message_rate_limit = EMPTY_RATE_LIMIT @@ -383,8 +391,10 @@ def _get_rate_limits_for_message(self, message: MqttPublishMessage) -> Tuple[Rat @staticmethod async def check_rate_limits_for_message(datapoints_count: int, - message_rate_limit: RateLimit, - datapoints_rate_limit: RateLimit) -> Tuple[Union[Tuple[int, int], None], int, Optional[RateLimit]]: + message_rate_limit: RateLimit, + datapoints_rate_limit: RateLimit) -> Tuple[Optional[Tuple[int, int]], + int, + Optional[RateLimit]]: if message_rate_limit and message_rate_limit.has_limit(): triggered_rate_limit_entry = await message_rate_limit.try_consume(1) if triggered_rate_limit_entry: @@ -397,8 +407,8 @@ async def check_rate_limits_for_message(datapoints_count: int, @staticmethod async def _consume_rate_limits_for_message(datapoints_count: int, - message_rate_limit: RateLimit, - datapoints_rate_limit: RateLimit) -> None: + message_rate_limit: RateLimit, + datapoints_rate_limit: RateLimit) -> None: if message_rate_limit and message_rate_limit.has_limit(): await message_rate_limit.consume(1) if datapoints_rate_limit and datapoints_rate_limit.has_limit(): diff --git a/tb_mqtt_client/service/mqtt_manager.py b/tb_mqtt_client/service/mqtt_manager.py index b5030d6..fbc818f 100644 --- a/tb_mqtt_client/service/mqtt_manager.py +++ b/tb_mqtt_client/service/mqtt_manager.py @@ -44,19 +44,18 @@ class MQTTManager: - _PUBLISH_TIMEOUT = 10.0 # Default timeout for publish operations def __init__( - self, - client_id: str, - main_stop_event: asyncio.Event, - message_adapter: MessageAdapter, - on_connect: Optional[Callable[[], Coroutine[Any, Any, None]]] = None, - on_disconnect: Optional[Callable[[], Coroutine[Any, Any, None]]] = None, - on_publish_result: Optional[Callable[[PublishResult], Coroutine[Any, Any, None]]] = None, - rate_limits_handler: Optional[Callable[[RPCResponse], Coroutine[Any, Any, None]]] = None, - rpc_response_handler: Optional[RPCResponseHandler] = None, + self, + client_id: str, + main_stop_event: asyncio.Event, + message_adapter: MessageAdapter, + on_connect: Optional[Callable[[], Coroutine[Any, Any, None]]] = None, + on_disconnect: Optional[Callable[[], Coroutine[Any, Any, None]]] = None, + on_publish_result: Optional[Callable[[PublishResult], Coroutine[Any, Any, None]]] = None, + rate_limits_handler: Optional[Callable[[RPCResponse], Coroutine[Any, Any, None]]] = None, + rpc_response_handler: Optional[RPCResponseHandler] = None, ): self._main_stop_event = main_stop_event self._message_adapter = message_adapter @@ -98,7 +97,7 @@ def __init__( self.__rate_limiter: 
Optional[RateLimiter] = None self.__gateway_rate_limiter: Optional[RateLimiter] = None self.__is_gateway = False - self.__is_waiting_for_rate_limits_publish = True # Start with True to prevent publishing before rate limits are retrieved + self.__is_waiting_for_rate_limits_publish = True # True to prevent publishing before rate limits are retrieved self._rate_limits_ready_event = asyncio.Event() async def connect(self, host: str, port: int = 1883, username: Optional[str] = None, @@ -153,7 +152,6 @@ async def disconnect(self): async def publish(self, message: MqttPublishMessage, - qos: int = 1, # TODO: probably should be removed, as qos is set in MqttPublishMessage force=False): if not force: @@ -173,14 +171,14 @@ async def publish(self, raise RuntimeError("Publishing temporarily paused due to backpressure.") if not message.dup: - return await self.process_regular_publish(message, qos) + return await self.process_regular_publish(message, message.qos) else: - # If a message is a duplicate, we should not process it as a regular publish message, it should be sent immediately + # If a message is a duplicate, it should be sent immediately if logger.isEnabledFor(TRACE_LEVEL): logger.trace("Processing duplicate message with topic: %s, qos: %d, payload size: %d", - message.topic, qos, len(message.payload)) + message.topic, message.qos, len(message.payload)) - protocol = self._client._connection._protocol # noqa + protocol = self._client._connection._protocol # noqa if protocol: try: @@ -203,17 +201,21 @@ async def process_regular_publish(self, message: MqttPublishMessage, qos: int = mqtt_future = asyncio.get_event_loop().create_future() mqtt_future.uuid = uuid4() if logger.isEnabledFor(TRACE_LEVEL): - logger.trace("Publishing message with topic: %s, qos: %d, payload size: %d, mqtt_future id: %r, delivery futures: %r", - message.topic, qos, len(message.payload), mqtt_future.uuid, [f.uuid for f in message.delivery_futures]) + logger.trace( + "Publishing message with topic: %s, qos: %d, " + "payload size: %d, mqtt_future id: %r, delivery futures: %r", + message.topic, qos, len(message.payload), + mqtt_future.uuid, [f.uuid for f in message.delivery_futures]) if message.delivery_futures is not None: - await self._add_future_chain_processing(mqtt_future, message) + await self._add_future_chain_processing(mqtt_future, message) # noqa mid, package = self._client._connection.publish(message) # noqa message.mark_as_sent(mid) if qos > 0: - logger.trace("Publishing mid=%s, storing publish main future with id: %r", mid, mqtt_future.uuid) + logger.trace("Publishing mid=%s, storing publish main future with id: %r", + mid, mqtt_future.uuid) self._pending_publishes[mid] = (mqtt_future, message, monotonic()) self._client._persistent_storage.push_message_nowait(mid, message) # noqa else: @@ -253,7 +255,8 @@ def _on_connect_internal(self, client, session_present, reason_code, properties) self._connected_event.clear() return logger.info("Connected to the platform.") - logger.debug("Connection session_present: %s, reason code: %s, properties: %s", session_present, reason_code, properties) + logger.debug("Connection session_present: %s, reason code: %s, properties: %s", session_present, reason_code, + properties) if hasattr(client, '_connection'): client._connection._on_disconnect_called = False # noqa self._connected_event.set() @@ -317,7 +320,7 @@ def _on_disconnect_internal(self, client, reason_code=None, properties=None, exc if reason_code == 142: logger.error("Session was taken over, looks like another client 
connected with the same credentials.") self._backpressure.notify_disconnect(delay_seconds=10) - if reason_code in (131, 142, 143, 151): # 131, 142, 151 may be caused by rate limits or issue with the data + if reason_code in (131, 142, 143, 151): # 131, 142, 151 may be caused by rate limits or issue with the data reached_time = 1 for rate_limit in self.__rate_limiter.values(): if isinstance(rate_limit, RateLimit): @@ -325,7 +328,7 @@ def _on_disconnect_internal(self, client, reason_code=None, properties=None, exc reached_limit = run_coroutine_sync(rate_limit.reach_limit, raise_on_timeout=True) except TimeoutError: logger.warning("Timeout while checking rate limit reaching.") - reached_time = 10 # Default to 10 seconds if timeout occurs + reached_time = 10 # Default to 10 seconds if timeout occurs break reached_index, reached_time, reached_duration = reached_limit if reached_limit else (None, None, 1) self._backpressure.notify_disconnect(delay_seconds=reached_time) @@ -375,8 +378,10 @@ def _handle_puback_reason_code(self, mid: int, reason_code: int, properties: dic logger.warning("PUBACK received with QUOTA_EXCEEDED for mid=%s", mid) self._backpressure.notify_quota_exceeded(delay_seconds=10) elif reason_code == IMPLEMENTATION_SPECIFIC_ERROR: - logger.warning("PUBACK received with IMPLEMENTATION_SPECIFIC_ERROR for mid=%s, treating as rate limit reached", mid) - self._backpressure.notify_quota_exceeded(delay_seconds=15) # Treat implementation specific error as quota exceeded + logger.warning( + "PUBACK received with IMPLEMENTATION_SPECIFIC_ERROR for mid=%s, treating as rate limit reached", mid) + self._backpressure.notify_quota_exceeded( + delay_seconds=15) # Treat implementation specific error as quota exceeded elif reason_code != 0: logger.warning("PUBACK received with error code %s for mid=%s", reason_code, mid) @@ -425,7 +430,7 @@ async def __request_rate_limits(self): response_future = self._rpc_response_handler.register_request(request.request_id, self.__rate_limits_handler) try: - await self.publish(mqtt_message, qos=1, force=True) + await self.publish(mqtt_message, force=True) await await_or_stop(response_future, self._main_stop_event, timeout=10) logger.info("Successfully processed rate limits.") # self.__rate_limits_retrieved = True @@ -463,7 +468,9 @@ async def _monitor_ack_timeouts(self): await asyncio.sleep(0.1) await self.check_pending_publishes(monotonic()) - def patch_client_for_retry_logic(self, put_retry_message_method: Callable[[MqttPublishMessage], Coroutine[Any, Any, None]]): + def patch_client_for_retry_logic(self, + put_retry_message_method: Callable[[MqttPublishMessage], + Coroutine[Any, Any, None]]): self._client.put_retry_message = put_retry_message_method async def check_pending_publishes(self, time_to_check): diff --git a/tests/__init__.py b/tests/__init__.py index fa669aa..cff354e 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- diff --git a/tests/common/test_async_utils.py b/tests/common/test_async_utils.py index a588f5b..049351c 100644 --- a/tests/common/test_async_utils.py +++ b/tests/common/test_async_utils.py @@ -34,6 +34,7 @@ async def test_future_map_register_and_get_parents(): assert set(fm.get_parents(child1)) == {parent} assert set(fm.get_parents(child2)) == {parent} + @pytest.mark.asyncio async def test_future_map_child_resolved_merges_results(): fm = FutureMap() @@ -55,6 +56,7 @@ async def test_future_map_child_resolved_merges_results(): assert isinstance(result, PublishResult) assert result.reason_code == 0 + @pytest.mark.asyncio async def test_future_map_child_resolved_no_publish_result(): fm = FutureMap() @@ -67,6 +69,7 @@ async def test_future_map_child_resolved_no_publish_result(): assert parent.done() assert parent.result() is None + @pytest.mark.asyncio async def test_future_map_child_resolved_with_cancelled_child(): fm = FutureMap() @@ -78,35 +81,48 @@ async def test_future_map_child_resolved_with_cancelled_child(): assert parent.done() assert parent.result() is None + @pytest.mark.asyncio async def test_await_or_stop_coroutine_finishes_first(): stop_event = asyncio.Event() + async def coro(): return 123 + result = await await_or_stop(coro(), stop_event, timeout=1) assert result == 123 + @pytest.mark.asyncio async def test_await_or_stop_stop_event_first(): stop_event = asyncio.Event() + async def coro(): await asyncio.sleep(0.5) + asyncio.get_event_loop().call_soon(stop_event.set) result = await await_or_stop(coro(), stop_event, timeout=1) assert result is None + @pytest.mark.asyncio async def test_await_or_stop_timeout(): stop_event = asyncio.Event() + async def coro(): await asyncio.sleep(1) + with pytest.raises(asyncio.TimeoutError): await await_or_stop(coro(), stop_event, timeout=0.01) + @pytest.mark.asyncio async def test_await_or_stop_negative_timeout(): stop_event = asyncio.Event() + async def coro(): return "ok" + result = await await_or_stop(coro(), stop_event, timeout=-1) assert result == "ok" + @pytest.mark.asyncio async def test_await_or_stop_future_done(): stop_event = asyncio.Event() @@ -115,17 +131,21 @@ async def test_await_or_stop_future_done(): result = await await_or_stop(fut, stop_event, timeout=1) assert result == "done" + @pytest.mark.asyncio async def test_await_or_stop_invalid_type(): stop_event = asyncio.Event() with pytest.raises(TypeError): await await_or_stop("not a future", stop_event, timeout=1) + @pytest.mark.asyncio async def test_await_or_stop_cancelled_error(): stop_event = asyncio.Event() + async def coro(): raise asyncio.CancelledError() + result = await await_or_stop(coro(), stop_event, timeout=1) assert result is None diff --git a/tests/common/test_config_loader.py b/tests/common/test_config_loader.py index 6148d6e..fe482d3 100644 --- a/tests/common/test_config_loader.py +++ b/tests/common/test_config_loader.py @@ -70,6 +70,7 @@ def detects_tls_correctly(self): config = DeviceConfig() self.assertTrue(config.use_tls()) + class TestGatewayConfig(unittest.TestCase): def loads_gateway_specific_env_vars(self): @@ -102,4 +103,4 @@ def falls_back_to_device_config_when_gateway_env_vars_missing(self): os.environ["TB_PORT"] = "1884" config = GatewayConfig() self.assertEqual(config.host, "device_host") - self.assertEqual(config.port, 1884) \ No newline at end of file + self.assertEqual(config.port, 1884) diff --git a/tests/common/test_gmqtt_patch.py b/tests/common/test_gmqtt_patch.py index 414d690..64f8474 100644 --- a/tests/common/test_gmqtt_patch.py +++ 
b/tests/common/test_gmqtt_patch.py @@ -29,12 +29,14 @@ def test_parse_mqtt_properties_valid_and_invalid(): pkt = bytes([1]) + bytes([255]) assert PatchUtils.parse_mqtt_properties(pkt) == {} - # Exception path (invalid varint) + # Exception path (invalid varint) assert PatchUtils.parse_mqtt_properties(b"\xff") == {} def test_extract_reason_code_all_paths(): - class Obj: reason_code = 42 + class Obj: + reason_code = 42 + assert PatchUtils.extract_reason_code(Obj()) == 42 assert PatchUtils.extract_reason_code(b"\x00*") == 42 assert PatchUtils.extract_reason_code(b"") is None @@ -44,6 +46,7 @@ class Obj: reason_code = 42 def test_patch_puback_handling_and_storage(monkeypatch): pu = PatchUtils(None, asyncio.Event()) called = {} + def on_puback(mid, reason, props): called["hit"] = (mid, reason, props) @@ -51,11 +54,9 @@ def on_puback(mid, reason, props): pu.patch_puback_handling(on_puback) # Call wrapped handler - pkt = struct.pack("!HB", 10, 1) + b"\x00" handler = types.SimpleNamespace( _connection=types.SimpleNamespace(persistent_storage=types.SimpleNamespace(remove=lambda m: None)) ) - handler2 = handler MqttHandlerClass = type("H", (), {}) pu_handler = MqttHandlerClass() pu_handler._connection = handler._connection @@ -67,8 +68,8 @@ def on_puback(mid, reason, props): # patch_storage test client = types.SimpleNamespace(_persistent_storage=types.SimpleNamespace( - _queue=[(0,1,"raw")], - _check_empty=lambda : None + _queue=[(0, 1, "raw")], + _check_empty=lambda: None )) pu.client = client pu.patch_storage() @@ -78,18 +79,23 @@ def on_puback(mid, reason, props): @pytest.mark.asyncio async def test_retry_loop_and_task_controls(monkeypatch): storage_queue = [] + class Storage: def _check_empty(self): pass + async def pop_message(self): if not storage_queue: raise IndexError return storage_queue.pop(0) msgs_sent = [] + class FakeClient: is_connected = True + async def put_retry_message(self, msg): msgs_sent.append(msg) + _persistent_storage = Storage() pu = PatchUtils(FakeClient(), asyncio.Event(), retry_interval=0) @@ -112,8 +118,8 @@ async def put_retry_message(self, msg): def test_apply_calls_patch_and_starts_task(monkeypatch): pu = PatchUtils(None, asyncio.Event()) monkeypatch.setattr(pu, "patch_puback_handling", lambda cb: setattr(pu, "_patched", True)) - monkeypatch.setattr(pu, "start_retry_task", lambda : setattr(pu, "_started", True)) - pu.apply(lambda a,b,c: None) + monkeypatch.setattr(pu, "start_retry_task", lambda: setattr(pu, "_started", True)) + pu.apply(lambda a, b, c: None) assert pu._patched assert pu._started diff --git a/tests/common/test_publish_result.py b/tests/common/test_publish_result.py index 738c592..dbe57cb 100644 --- a/tests/common/test_publish_result.py +++ b/tests/common/test_publish_result.py @@ -60,6 +60,7 @@ def test_publish_result_as_dict(default_publish_result): "datapoints_count": 0 } + def test_publish_request_merge(): result1 = PublishResult( topic="v1/devices/me/telemetry", @@ -79,14 +80,16 @@ def test_publish_request_merge(): assert merged_result.topic == "v1/devices/me/telemetry" assert merged_result.qos == 1 - assert merged_result.message_id == -1 # Merged results do not have a specific message_id + assert merged_result.message_id == -1  # Merged results do not have a specific message_id assert merged_result.payload_size == 768 # Combined payload size assert merged_result.reason_code == 0 # All successful + def test_publish_result_merge_with_empty_list(): with pytest.raises(ValueError, match="No publish results to merge."): PublishResult.merge([]) + def
test_publish_result_is_successful_true(default_publish_result): assert default_publish_result.is_successful() is True @@ -150,5 +153,6 @@ def test_publish_result_various_failure_codes(reason_code): ) assert result.is_successful() is False + if __name__ == '__main__': pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/common/test_queue.py b/tests/common/test_queue.py index 1e02c30..75b02e5 100644 --- a/tests/common/test_queue.py +++ b/tests/common/test_queue.py @@ -23,8 +23,10 @@ @pytest.fixture def make_item(): """Factory to create unique or duplicate-like test items.""" + def _make(uuid): return SimpleNamespace(uuid=uuid) + return _make @@ -151,5 +153,6 @@ async def delayed_put(): got = await q.get() assert got.uuid == "later" + if __name__ == '__main__': pytest.main([__file__, '-v', '--tb=short']) diff --git a/tests/common/test_rate_limit.py b/tests/common/test_rate_limit.py index e46686b..8556f2b 100644 --- a/tests/common/test_rate_limit.py +++ b/tests/common/test_rate_limit.py @@ -121,6 +121,7 @@ async def test_rate_limit_refill_behavior(): await rl.refill() assert (await rl.try_consume()) is None + @pytest.mark.asyncio async def test_set_required_tokens_and_clear_event(): rl = RateLimit("5:1", "token-test") diff --git a/tests/constants/__init__.py b/tests/constants/__init__.py index fa669aa..cff354e 100644 --- a/tests/constants/__init__.py +++ b/tests/constants/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/constants/test_mqtt_topics.py b/tests/constants/test_mqtt_topics.py index 6fbe7b0..f8f19e3 100644 --- a/tests/constants/test_mqtt_topics.py +++ b/tests/constants/test_mqtt_topics.py @@ -18,4 +18,4 @@ def test_device_topic_builders(): assert mqtt_topics.build_device_attributes_request_topic(42) == "v1/devices/me/attributes/request/42" assert mqtt_topics.build_device_rpc_request_topic(99) == "v1/devices/me/rpc/request/99" - assert mqtt_topics.build_device_rpc_response_topic(99) == "v1/devices/me/rpc/response/99" \ No newline at end of file + assert mqtt_topics.build_device_rpc_response_topic(99) == "v1/devices/me/rpc/response/99" diff --git a/tests/entities/gateway/test_device_info.py b/tests/entities/gateway/test_device_info.py index ab9a3af..b71053a 100644 --- a/tests/entities/gateway/test_device_info.py +++ b/tests/entities/gateway/test_device_info.py @@ -107,5 +107,6 @@ def test_hash_works_in_set(): s = {d} assert d in s + if __name__ == '__main__': - pytest.main([__file__, "--tb=short", "-v"]) \ No newline at end of file + pytest.main([__file__, "--tb=short", "-v"]) diff --git a/tests/entities/gateway/test_gateway_attribute_request.py b/tests/entities/gateway/test_gateway_attribute_request.py index b1a85eb..a8b498c 100644 --- a/tests/entities/gateway/test_gateway_attribute_request.py +++ b/tests/entities/gateway/test_gateway_attribute_request.py @@ -33,7 +33,8 @@ async def test_build_assigns_values_and_repr(): mock_session = MagicMock() mock_session.device_info.device_name = "TestDevice" - with patch("tb_mqtt_client.common.request_id_generator.AttributeRequestIdProducer.get_next", new=AsyncMock(return_value=123)): + with patch("tb_mqtt_client.common.request_id_generator.AttributeRequestIdProducer.get_next", + new=AsyncMock(return_value=123)): req = await GatewayAttributeRequest.build( device_session=mock_session, shared_keys=["s1", "s2"], diff --git 
a/tests/entities/gateway/test_gateway_attribute_update.py b/tests/entities/gateway/test_gateway_attribute_update.py index 2842dc2..9d7a04c 100644 --- a/tests/entities/gateway/test_gateway_attribute_update.py +++ b/tests/entities/gateway/test_gateway_attribute_update.py @@ -28,6 +28,7 @@ def test_init_with_list_of_entries(): assert isinstance(obj.attribute_update, AttributeUpdate) assert str(obj) == f"GatewayAttributeUpdate(device_name=deviceA, attribute_update={obj.attribute_update})" + def test_init_with_single_entry(): entry = AttributeEntry("k", "v") obj = GatewayAttributeUpdate("deviceB", entry) @@ -37,6 +38,7 @@ def test_init_with_single_entry(): assert isinstance(obj.attribute_update, AttributeUpdate) assert "deviceB" in str(obj) + def test_init_with_attribute_update(): update = AttributeUpdate([AttributeEntry("ka", "va")]) obj = GatewayAttributeUpdate("deviceC", update) @@ -45,6 +47,7 @@ def test_init_with_attribute_update(): assert isinstance(obj.attribute_update, AttributeUpdate) assert "deviceC" in str(obj) + def test_init_with_invalid_type(): with pytest.raises(TypeError) as exc: GatewayAttributeUpdate("deviceD", {"invalid": "type"}) diff --git a/tests/service/__init__.py b/tests/service/__init__.py index fa669aa..cff354e 100644 --- a/tests/service/__init__.py +++ b/tests/service/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/service/device/__init__.py b/tests/service/device/__init__.py index fa669aa..cff354e 100644 --- a/tests/service/device/__init__.py +++ b/tests/service/device/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- diff --git a/tests/service/device/handlers/test_attribute_updates_handler.py b/tests/service/device/handlers/test_attribute_updates_handler.py index 3533a65..8cb294c 100644 --- a/tests/service/device/handlers/test_attribute_updates_handler.py +++ b/tests/service/device/handlers/test_attribute_updates_handler.py @@ -71,7 +71,6 @@ def parse_attribute_update(self, payload: bytes) -> AttributeUpdate: return self.result - @pytest.mark.asyncio async def test_set_message_adapter_and_callback_called(): handler = AttributeUpdatesHandler() diff --git a/tests/service/device/handlers/test_rpc_requests_handler.py b/tests/service/device/handlers/test_rpc_requests_handler.py index 55b8a96..197bc42 100644 --- a/tests/service/device/handlers/test_rpc_requests_handler.py +++ b/tests/service/device/handlers/test_rpc_requests_handler.py @@ -194,6 +194,7 @@ async def test_rpc_request_build_and_str_and_payload_format(): assert payload["method"] == "myMethod" assert payload["params"] == {"y": 2} + @pytest.mark.asyncio async def test_rpc_request_build_invalid_method_type(): with pytest.raises(ValueError): diff --git a/tests/service/device/handlers/test_rpc_response_handler.py b/tests/service/device/handlers/test_rpc_response_handler.py index ced4323..324b865 100644 --- a/tests/service/device/handlers/test_rpc_response_handler.py +++ b/tests/service/device/handlers/test_rpc_response_handler.py @@ -117,6 +117,7 @@ async def test_handle_no_message_adapter_uses_json_adapter(): assert fut.done() assert fut.result().result == {'result': {"foo": "bar"}} + @pytest.mark.asyncio async def test_handle_no_message_adapter_uses_json_adapter_with_error_rpc(): handler = RPCResponseHandler() diff --git a/tests/service/device/test_device_client.py b/tests/service/device/test_device_client.py index 7e7dc48..73a2ebf 100644 --- a/tests/service/device/test_device_client.py +++ b/tests/service/device/test_device_client.py @@ -90,7 +90,9 @@ async def fake_await_ready(): # Assert assert isinstance(result, PublishResult) - assert result.message_id == -1 # The initial mqtt message doesn't contain message_id, expected behavior because an initial message can be split to separated messages or grouped with other messages + # The initial mqtt message doesn't contain message_id, + # expected behavior because an initial message can be split to separated messages or grouped with other messages + assert result.message_id == -1 assert result.topic == mqtt_msg.topic assert result.payload_size == mqtt_msg.payload_size assert result.datapoints_count == mqtt_msg.datapoints @@ -144,6 +146,7 @@ async def fake_connect(**kwargs): async def fake_await_ready(): return + client._mqtt_manager.await_ready = fake_await_ready await client.connect() diff --git a/tests/service/device/test_firmware_updater.py b/tests/service/device/test_firmware_updater.py index 46874bc..d3598cb 100644 --- a/tests/service/device/test_firmware_updater.py +++ b/tests/service/device/test_firmware_updater.py @@ -19,7 +19,10 @@ import pytest -from tb_mqtt_client.constants.firmware import * +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.constants.firmware import (FirmwareStates, FW_STATE_ATTR, FW_TITLE_ATTR, FW_VERSION_ATTR, + FW_SIZE_ATTR, FW_CHECKSUM_ALG_ATTR, REQUIRED_SHARED_KEYS, + FW_CHECKSUM_ATTR) from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry from tb_mqtt_client.service.device.firmware_updater import FirmwareUpdater @@ -45,8 +48,7 @@ def updater(mock_client): @pytest.mark.asyncio async def 
test_update_success(updater, mock_client): - with patch("tb_mqtt_client.entities.data.attribute_request.AttributeRequest.build", - new_callable=AsyncMock) as mock_build, \ + with patch("tb_mqtt_client.entities.data.attribute_request.AttributeRequest.build", new_callable=AsyncMock), \ patch.object(updater, "_firmware_info_callback", new=AsyncMock()): await updater.update() mock_client._mqtt_manager.subscribe.assert_called_once() @@ -99,12 +101,7 @@ async def test_get_next_chunk_empty_payload(updater, mock_client): updater._chunk_size = 15 updater._target_firmware_length = 10 await updater._get_next_chunk() - mock_client._message_queue.publish.assert_awaited_with( - topic=ANY, - payload=b'', - datapoints_count=0, - qos=1 - ) + mock_client._message_queue.publish.assert_awaited() @pytest.mark.asyncio diff --git a/tests/service/gateway/__init__.py b/tests/service/gateway/__init__.py index fa669aa..cff354e 100644 --- a/tests/service/gateway/__init__.py +++ b/tests/service/gateway/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/service/gateway/handlers/__init__.py b/tests/service/gateway/handlers/__init__.py index fa669aa..cff354e 100644 --- a/tests/service/gateway/handlers/__init__.py +++ b/tests/service/gateway/handlers/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/service/gateway/handlers/test_gateway_attribute_updates_handler.py b/tests/service/gateway/handlers/test_gateway_attribute_updates_handler.py index 32621b0..f6fd6c1 100644 --- a/tests/service/gateway/handlers/test_gateway_attribute_updates_handler.py +++ b/tests/service/gateway/handlers/test_gateway_attribute_updates_handler.py @@ -33,32 +33,32 @@ async def test_handle_with_existing_device(): event_dispatcher = AsyncMock(spec=DirectEventDispatcher) message_adapter = MagicMock(spec=GatewayMessageAdapter) device_manager = MagicMock(spec=DeviceManager) - + # Create a handler handler = GatewayAttributeUpdatesHandler( event_dispatcher=event_dispatcher, message_adapter=message_adapter, device_manager=device_manager ) - + # Mock the device session device_session = MagicMock(spec=DeviceSession) device_manager.get_by_name.return_value = device_session - + # Mock the message adapter payload = b'{"device": "test_device", "data": {"key": "value"}}' deserialized_data = {"device": "test_device", "data": {"key": "value"}} message_adapter.deserialize_to_dict.return_value = deserialized_data - + # Create a mock attribute update attribute_update = AttributeUpdate([AttributeEntry(key="key", value="value")]) gateway_attribute_update = GatewayAttributeUpdate(device_name="test_device", attribute_update=attribute_update) message_adapter.parse_attribute_update.return_value = gateway_attribute_update gateway_attribute_update.set_device_session = AsyncMock() - + # Act await handler.handle("topic", payload) - + # Assert message_adapter.deserialize_to_dict.assert_called_once_with(payload) message_adapter.parse_attribute_update.assert_called_once_with(deserialized_data) @@ -73,37 +73,39 @@ async def test_handle_with_nonexistent_device(): event_dispatcher = AsyncMock(spec=DirectEventDispatcher) message_adapter = MagicMock(spec=GatewayMessageAdapter) device_manager = MagicMock(spec=DeviceManager) - + # 
Create a handler handler = GatewayAttributeUpdatesHandler( event_dispatcher=event_dispatcher, message_adapter=message_adapter, device_manager=device_manager ) - + # Mock the device session (nonexistent) device_manager.get_by_name.return_value = None - + # Mock the message adapter payload = b'{"device": "nonexistent_device", "data": {"key": "value"}}' deserialized_data = {"device": "nonexistent_device", "data": {"key": "value"}} message_adapter.deserialize_to_dict.return_value = deserialized_data - + # Create a mock attribute update attribute_update = AttributeUpdate([AttributeEntry("key", "value")]) - gateway_attribute_update = GatewayAttributeUpdate(device_name="nonexistent_device", attribute_update=attribute_update) + gateway_attribute_update = GatewayAttributeUpdate(device_name="nonexistent_device", + attribute_update=attribute_update) message_adapter.parse_attribute_update.return_value = gateway_attribute_update gateway_attribute_update.set_device_session = AsyncMock() - + # Act await handler.handle("topic", payload) - + # Assert message_adapter.deserialize_to_dict.assert_called_once_with(payload) message_adapter.parse_attribute_update.assert_called_once_with(deserialized_data) device_manager.get_by_name.assert_called_once_with("nonexistent_device") # set_device_session should not be called since a device is None - assert not hasattr(gateway_attribute_update, 'set_device_session') or not gateway_attribute_update.set_device_session.called + assert not hasattr(gateway_attribute_update, + 'set_device_session') or not gateway_attribute_update.set_device_session.called event_dispatcher.dispatch.assert_awaited_once_with(attribute_update, device_session=None) diff --git a/tests/service/gateway/handlers/test_gateway_requested_attributes_response_handler.py b/tests/service/gateway/handlers/test_gateway_requested_attributes_response_handler.py index 670ac51..0badff7 100644 --- a/tests/service/gateway/handlers/test_gateway_requested_attributes_response_handler.py +++ b/tests/service/gateway/handlers/test_gateway_requested_attributes_response_handler.py @@ -58,7 +58,8 @@ async def test_register_request(): # Assert assert (request.device_session.device_info.device_name, request.request_id) in handler._pending_attribute_requests - assert handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)][0] == request + assert handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)][ + 0] == request @pytest.mark.asyncio @@ -93,10 +94,12 @@ async def test_register_request_with_timeout(): await handler.register_request(request, timeout=10) # Assert - assert (request.device_session.device_info.device_name, request.request_id) in handler._pending_attribute_requests - assert handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)][0] == request - assert handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)][1] == mock_timeout_task - mock_loop.call_later.assert_called_once_with(10, handler._on_timeout, request.device_session.device_info.device_name, request.request_id) + assert (request.device_session.device_info.device_name, + request.request_id) in handler._pending_attribute_requests + assert handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)][0] == request # noqa + assert handler._pending_attribute_requests[(request.device_session.device_info.device_name, 
request.request_id)][1] == mock_timeout_task # noqa + mock_loop.call_later.assert_called_once_with(10, handler._on_timeout, + request.device_session.device_info.device_name, request.request_id) @pytest.mark.asyncio @@ -150,13 +153,15 @@ async def test_unregister_request(): request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["key1", "key2"]) # Register the request - handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)] = (request, None) + handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)] = ( + request, None) # Act handler.unregister_request(request.device_session.device_info.device_name, request.request_id) # Assert - assert (request.device_session.device_info.device_name, request.request_id) not in handler._pending_attribute_requests + assert (request.device_session.device_info.device_name, + request.request_id) not in handler._pending_attribute_requests def test_unregister_nonexistent_request(): @@ -203,15 +208,18 @@ async def test_handle_valid_response(): # Register the request timeout_task = MagicMock() - handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)] = (request, timeout_task) + handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)] = ( + request, timeout_task) # Mock the message adapter - payload = b'{"device": "test_device", "id": ' + str(request.request_id).encode() + b', "values": {"key1": "value1"}}' + payload = b'{"device": "test_device", "id": ' + str( + request.request_id).encode() + b', "values": {"key1": "value1"}}' deserialized_data = {"device": "test_device", "id": request.request_id, "values": {"key1": "value1"}} message_adapter.deserialize_to_dict.return_value = deserialized_data # Create a mock response - response = GatewayRequestedAttributeResponse(device_name="test_device", request_id=request.request_id, shared=[AttributeEntry("key1", "value1")]) + response = GatewayRequestedAttributeResponse(device_name="test_device", request_id=request.request_id, + shared=[AttributeEntry("key1", "value1")]) message_adapter.parse_gateway_requested_attribute_response.return_value = response # Mock asyncio.create_task @@ -306,12 +314,14 @@ async def test_on_timeout(): request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["key1", "key2"]) # Register the request - handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)] = (request, None) + handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)] = ( + request, None) handler._on_timeout(request.device_session.device_info.device_name, request.request_id) # Assert - assert (request.device_session.device_info.device_name, request.request_id) not in handler._pending_attribute_requests + assert (request.device_session.device_info.device_name, + request.request_id) not in handler._pending_attribute_requests def test_handle_callback_exception(): @@ -360,7 +370,8 @@ async def test_clear(): request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["key1", "key2"]) # Register the request - handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)] = (request, None) + handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)] = ( + request, 
None) # Act handler.clear() @@ -368,35 +379,38 @@ async def test_clear(): # Assert assert len(handler._pending_attribute_requests) == 0 + @pytest.mark.asyncio async def test_handle_no_message_adapter_removes_request(): adapter = MagicMock(spec=GatewayMessageAdapter) handler = GatewayRequestedAttributeResponseHandler(event_dispatcher=MagicMock(spec=DirectEventDispatcher), - message_adapter=adapter, - device_manager=MagicMock(spec=DeviceManager)) + message_adapter=adapter, + device_manager=MagicMock(spec=DeviceManager)) handler._pending_attribute_requests["test_device", 5] = (MagicMock(spec=AttributeRequest), AsyncMock()) topic = "attr/request/5" await handler.handle(topic, b"{}") assert 5 not in handler._pending_attribute_requests + @pytest.mark.asyncio async def test_handle_with_no_callback_registered(): adapter = MagicMock(spec=GatewayMessageAdapter) handler = GatewayRequestedAttributeResponseHandler(event_dispatcher=MagicMock(spec=DirectEventDispatcher), - message_adapter=adapter, - device_manager=MagicMock(spec=DeviceManager)) + message_adapter=adapter, + device_manager=MagicMock(spec=DeviceManager)) resp = MagicMock(spec=RequestedAttributeResponse) resp.request_id = 42 adapter.parse_gateway_requested_attribute_response.return_value = resp handler._pending_attribute_requests['test_device', 42] = (MagicMock(spec=AttributeRequest), None) await handler.handle("topic", b"payload") + @pytest.mark.asyncio async def test_handle_with_parsing_exception(): adapter = MagicMock(spec=GatewayMessageAdapter) handler = GatewayRequestedAttributeResponseHandler(event_dispatcher=MagicMock(spec=DirectEventDispatcher), - message_adapter=adapter, - device_manager=MagicMock(spec=DeviceManager)) + message_adapter=adapter, + device_manager=MagicMock(spec=DeviceManager)) adapter.parse_gateway_requested_attribute_response.side_effect = RuntimeError("bad parse") await handler.handle("topic", b"payload") diff --git a/tests/service/gateway/handlers/test_gateway_rpc_handler.py b/tests/service/gateway/handlers/test_gateway_rpc_handler.py index a75ac5a..fa85e90 100644 --- a/tests/service/gateway/handlers/test_gateway_rpc_handler.py +++ b/tests/service/gateway/handlers/test_gateway_rpc_handler.py @@ -35,7 +35,7 @@ def test_init(): message_adapter = MagicMock(spec=GatewayMessageAdapter) device_manager = MagicMock(spec=DeviceManager) stop_event = asyncio.Event() - + # Act handler = GatewayRPCHandler( event_dispatcher=event_dispatcher, @@ -43,7 +43,7 @@ def test_init(): device_manager=device_manager, stop_event=stop_event ) - + # Assert assert handler._event_dispatcher == event_dispatcher assert handler._message_adapter == message_adapter @@ -60,7 +60,7 @@ async def test_handle_successful_rpc(): message_adapter = MagicMock(spec=GatewayMessageAdapter) device_manager = MagicMock(spec=DeviceManager) stop_event = asyncio.Event() - + # Create a handler handler = GatewayRPCHandler( event_dispatcher=event_dispatcher, @@ -68,43 +68,44 @@ async def test_handle_successful_rpc(): device_manager=device_manager, stop_event=stop_event ) - + # Create a device session device_info = DeviceInfo("test_device", "default") device_session = DeviceSession(device_info) device_manager.get_by_name.return_value = device_session - + # Mock the message adapter topic = "v1/gateway/rpc" payload = b'{"device": "test_device", "id": 1, "method": "test_method", "params": {"param1": "value1"}}' - deserialized_data = {"device": "test_device", "data": {"id": 1, "method": "test_method", "params": {"param1": "value1"}}} + deserialized_data = {"device": 
"test_device", + "data": {"id": 1, "method": "test_method", "params": {"param1": "value1"}}} message_adapter.deserialize_to_dict.return_value = deserialized_data - + # Create a mock RPC request rpc_request = GatewayRPCRequest._deserialize_from_dict(deserialized_data) message_adapter.parse_rpc_request.return_value = rpc_request - + # Create a mock RPC response rpc_response = GatewayRPCResponse.build(device_name="test_device", request_id=1, result="success") - + # Mock the event dispatcher to return the RPC response event_dispatcher.dispatch.side_effect = [rpc_response, asyncio.Future()] - + # Mock await_or_stop with patch('tb_mqtt_client.service.gateway.handlers.gateway_rpc_handler.await_or_stop') as mock_await_or_stop: # Act - result = await handler.handle(topic, payload) - + await handler.handle(topic, payload) + # Assert message_adapter.deserialize_to_dict.assert_called_once_with(payload) message_adapter.parse_rpc_request.assert_called_once_with(topic, deserialized_data) device_manager.get_by_name.assert_called_once_with("test_device") - + # Check that dispatch was called twice - once for the request and once for the response assert event_dispatcher.dispatch.call_count == 2 event_dispatcher.dispatch.assert_any_call(rpc_request, device_session=device_session) event_dispatcher.dispatch.assert_any_call(rpc_response) - + mock_await_or_stop.assert_called_once() @@ -114,7 +115,7 @@ async def test_handle_no_message_adapter(): event_dispatcher = AsyncMock(spec=DirectEventDispatcher) device_manager = MagicMock(spec=DeviceManager) stop_event = asyncio.Event() - + # Create a handler with no message adapter handler = GatewayRPCHandler( event_dispatcher=event_dispatcher, @@ -122,10 +123,10 @@ async def test_handle_no_message_adapter(): device_manager=device_manager, stop_event=stop_event ) - + # Act result = await handler.handle("topic", b'{}') - + # Assert assert result is None @@ -137,7 +138,7 @@ async def test_handle_no_device_session(): message_adapter = MagicMock(spec=GatewayMessageAdapter) device_manager = MagicMock(spec=DeviceManager) stop_event = asyncio.Event() - + # Create a handler handler = GatewayRPCHandler( event_dispatcher=event_dispatcher, @@ -145,23 +146,24 @@ async def test_handle_no_device_session(): device_manager=device_manager, stop_event=stop_event ) - + # Mock the device manager to return None (no device session) device_manager.get_by_name.return_value = None - + # Mock the message adapter topic = "v1/gateway/rpc" payload = b'{"device": "nonexistent_device", "id": 1, "method": "test_method", "params": {"param1": "value1"}}' - deserialized_data = {"device": "nonexistent_device", "data": {"id": 1, "method": "test_method", "params": {"param1": "value1"}}} + deserialized_data = {"device": "nonexistent_device", + "data": {"id": 1, "method": "test_method", "params": {"param1": "value1"}}} message_adapter.deserialize_to_dict.return_value = deserialized_data - + # Create a mock RPC request rpc_request = GatewayRPCRequest._deserialize_from_dict(deserialized_data) message_adapter.parse_rpc_request.return_value = rpc_request - + # Act result = await handler.handle(topic, payload) - + # Assert assert result is None message_adapter.deserialize_to_dict.assert_called_once_with(payload) @@ -176,7 +178,7 @@ async def test_handle_no_response_from_callback(): message_adapter = MagicMock(spec=GatewayMessageAdapter) device_manager = MagicMock(spec=DeviceManager) stop_event = asyncio.Event() - + # Create a handler handler = GatewayRPCHandler( event_dispatcher=event_dispatcher, @@ -184,28 +186,29 @@ 
async def test_handle_no_response_from_callback(): device_manager=device_manager, stop_event=stop_event ) - + # Create a device session device_info = DeviceInfo("test_device", "default") device_session = DeviceSession(device_info) device_manager.get_by_name.return_value = device_session - + # Mock the message adapter topic = "v1/gateway/rpc" payload = b'{"device": "test_device", "id": 1, "method": "test_method", "params": {"param1": "value1"}}' - deserialized_data = {"device": "test_device", "data": {"id": 1, "method": "test_method", "params": {"param1": "value1"}}} + deserialized_data = {"device": "test_device", + "data": {"id": 1, "method": "test_method", "params": {"param1": "value1"}}} message_adapter.deserialize_to_dict.return_value = deserialized_data - + # Create a mock RPC request rpc_request = GatewayRPCRequest._deserialize_from_dict(deserialized_data) message_adapter.parse_rpc_request.return_value = rpc_request - + # Mock the event dispatcher to return None (no response from callback) event_dispatcher.dispatch.return_value = None - + # Act result = await handler.handle(topic, payload) - + # Assert assert result is None message_adapter.deserialize_to_dict.assert_called_once_with(payload) @@ -221,7 +224,7 @@ async def test_handle_invalid_response_type(): message_adapter = MagicMock(spec=GatewayMessageAdapter) device_manager = MagicMock(spec=DeviceManager) stop_event = asyncio.Event() - + # Create a handler handler = GatewayRPCHandler( event_dispatcher=event_dispatcher, @@ -229,28 +232,29 @@ async def test_handle_invalid_response_type(): device_manager=device_manager, stop_event=stop_event ) - + # Create a device session device_info = DeviceInfo("test_device", "default") device_session = DeviceSession(device_info) device_manager.get_by_name.return_value = device_session - + # Mock the message adapter topic = "v1/gateway/rpc" payload = b'{"device": "test_device", "id": 1, "method": "test_method", "params": {"param1": "value1"}}' - deserialized_data = {"device": "test_device", "data": {"id": 1, "method": "test_method", "params": {"param1": "value1"}}} + deserialized_data = {"device": "test_device", + "data": {"id": 1, "method": "test_method", "params": {"param1": "value1"}}} message_adapter.deserialize_to_dict.return_value = deserialized_data - + # Create a mock RPC request rpc_request = GatewayRPCRequest._deserialize_from_dict(deserialized_data) message_adapter.parse_rpc_request.return_value = rpc_request - + # Mock the event dispatcher to return an invalid response type event_dispatcher.dispatch.side_effect = ["invalid_response", asyncio.Future()] - + # Act - result = await handler.handle(topic, payload) - + await handler.handle(topic, payload) + # Assert message_adapter.deserialize_to_dict.assert_called_once_with(payload) message_adapter.parse_rpc_request.assert_called_once_with(topic, deserialized_data) @@ -265,7 +269,7 @@ async def test_handle_exception_in_processing(): message_adapter = MagicMock(spec=GatewayMessageAdapter) device_manager = MagicMock(spec=DeviceManager) stop_event = asyncio.Event() - + # Create a handler handler = GatewayRPCHandler( event_dispatcher=event_dispatcher, @@ -273,20 +277,20 @@ async def test_handle_exception_in_processing(): device_manager=device_manager, stop_event=stop_event ) - + # Create a device session device_info = DeviceInfo("test_device", "default") device_session = DeviceSession(device_info) device_manager.get_by_name.return_value = device_session - + # Mock the message adapter to raise an exception topic = "v1/gateway/rpc" payload = 
b'{"device": "test_device", "id": 1, "method": "test_method", "params": {"param1": "value1"}}' message_adapter.deserialize_to_dict.side_effect = Exception("Test exception") - + # Act result = await handler.handle(topic, payload) - + # Assert assert result is None message_adapter.deserialize_to_dict.assert_called_once_with(payload) @@ -299,7 +303,7 @@ async def test_handle_timeout_in_publish(): message_adapter = MagicMock(spec=GatewayMessageAdapter) device_manager = MagicMock(spec=DeviceManager) stop_event = asyncio.Event() - + # Create a handler handler = GatewayRPCHandler( event_dispatcher=event_dispatcher, @@ -307,35 +311,36 @@ async def test_handle_timeout_in_publish(): device_manager=device_manager, stop_event=stop_event ) - + # Create a device session device_info = DeviceInfo("test_device", "default") device_session = DeviceSession(device_info) device_manager.get_by_name.return_value = device_session - + # Mock the message adapter topic = "v1/gateway/rpc" payload = b'{"device": "test_device", "id": 1, "method": "test_method", "params": {"param1": "value1"}}' - deserialized_data = {"device": "test_device", "data": {"id": 1, "method": "test_method", "params": {"param1": "value1"}}} + deserialized_data = {"device": "test_device", + "data": {"id": 1, "method": "test_method", "params": {"param1": "value1"}}} message_adapter.deserialize_to_dict.return_value = deserialized_data - + # Create a mock RPC request rpc_request = GatewayRPCRequest._deserialize_from_dict(deserialized_data) message_adapter.parse_rpc_request.return_value = rpc_request - + # Create a mock RPC response rpc_response = GatewayRPCResponse.build(device_name="test_device", request_id=1, result="success") - + # Mock the event dispatcher event_dispatcher.dispatch.side_effect = [rpc_response, asyncio.Future()] - + # Mock await_or_stop to raise TimeoutError with patch('tb_mqtt_client.service.gateway.handlers.gateway_rpc_handler.await_or_stop') as mock_await_or_stop: mock_await_or_stop.side_effect = TimeoutError("Timeout") - + # Act - result = await handler.handle(topic, payload) - + await handler.handle(topic, payload) + # Assert message_adapter.deserialize_to_dict.assert_called_once_with(payload) message_adapter.parse_rpc_request.assert_called_once_with(topic, deserialized_data) @@ -352,7 +357,7 @@ async def test_handle_no_publish_futures(): message_adapter = MagicMock(spec=GatewayMessageAdapter) device_manager = MagicMock(spec=DeviceManager) stop_event = asyncio.Event() - + # Create a handler handler = GatewayRPCHandler( event_dispatcher=event_dispatcher, @@ -360,31 +365,32 @@ async def test_handle_no_publish_futures(): device_manager=device_manager, stop_event=stop_event ) - + # Create a device session device_info = DeviceInfo("test_device", "default") device_session = DeviceSession(device_info) device_manager.get_by_name.return_value = device_session - + # Mock the message adapter topic = "v1/gateway/rpc" payload = b'{"device": "test_device", "id": 1, "method": "test_method", "params": {"param1": "value1"}}' - deserialized_data = {"device": "test_device", "data": {"id": 1, "method": "test_method", "params": {"param1": "value1"}}} + deserialized_data = {"device": "test_device", + "data": {"id": 1, "method": "test_method", "params": {"param1": "value1"}}} message_adapter.deserialize_to_dict.return_value = deserialized_data - + # Create a mock RPC request rpc_request = GatewayRPCRequest._deserialize_from_dict(deserialized_data) message_adapter.parse_rpc_request.return_value = rpc_request - + # Create a mock RPC response 
rpc_response = GatewayRPCResponse.build(device_name="test_device", request_id=1, result="success") - + # Mock the event dispatcher to return the RPC response but no publish futures event_dispatcher.dispatch.side_effect = [rpc_response, None] - + # Act result = await handler.handle(topic, payload) - + # Assert assert result is None message_adapter.deserialize_to_dict.assert_called_once_with(payload) diff --git a/tests/service/gateway/test_device_manager.py b/tests/service/gateway/test_device_manager.py index 58715e9..10650e6 100644 --- a/tests/service/gateway/test_device_manager.py +++ b/tests/service/gateway/test_device_manager.py @@ -25,10 +25,10 @@ def test_register_new_device(): # Setup manager = DeviceManager() - + # Act session = manager.register("test_device", "default") - + # Assert assert session is not None assert session.device_info.device_name == "test_device" @@ -42,10 +42,10 @@ def test_register_existing_device(): # Setup manager = DeviceManager() first_session = manager.register("test_device", "default") - + # Act second_session = manager.register("test_device", "custom") - + # Assert assert first_session is second_session assert second_session.device_info.device_profile == "default" # Profile should not change @@ -56,10 +56,10 @@ def test_unregister_device(): manager = DeviceManager() session = manager.register("test_device", "default") device_id = session.device_info.device_id - + # Act manager.unregister(device_id) - + # Assert assert device_id not in manager._sessions_by_id assert "test_device" not in manager._ids_by_device_name @@ -71,7 +71,7 @@ def test_unregister_nonexistent_device(): manager = DeviceManager() from uuid import uuid4 nonexistent_id = uuid4() - + # Act & Assert - should not raise an exception manager.unregister(nonexistent_id) @@ -81,10 +81,10 @@ def test_get_by_id(): manager = DeviceManager() session = manager.register("test_device", "default") device_id = session.device_info.device_id - + # Act retrieved_session = manager.get_by_id(device_id) - + # Assert assert retrieved_session is session @@ -94,10 +94,10 @@ def test_get_by_id_nonexistent(): manager = DeviceManager() from uuid import uuid4 nonexistent_id = uuid4() - + # Act retrieved_session = manager.get_by_id(nonexistent_id) - + # Assert assert retrieved_session is None @@ -106,10 +106,10 @@ def test_get_by_name(): # Setup manager = DeviceManager() session = manager.register("test_device", "default") - + # Act retrieved_session = manager.get_by_name("test_device") - + # Assert assert retrieved_session is session @@ -119,13 +119,13 @@ def test_get_by_original_name(): manager = DeviceManager() session = manager.register("test_device", "default") original_name = session.device_info.original_name - + # Rename the device manager.rename_device("test_device", "new_name") - + # Act retrieved_session = manager.get_by_name(original_name) - + # Assert assert retrieved_session is session @@ -133,10 +133,10 @@ def test_get_by_original_name(): def test_get_by_name_nonexistent(): # Setup manager = DeviceManager() - + # Act retrieved_session = manager.get_by_name("nonexistent_device") - + # Assert assert retrieved_session is None @@ -146,10 +146,10 @@ def test_is_connected(): manager = DeviceManager() session = manager.register("test_device", "default") device_id = session.device_info.device_id - + # Act is_connected = manager.is_connected(device_id) - + # Assert assert is_connected is True @@ -159,10 +159,10 @@ def test_is_connected_nonexistent(): manager = DeviceManager() from uuid import uuid4 nonexistent_id = 
uuid4() - + # Act is_connected = manager.is_connected(nonexistent_id) - + # Assert assert is_connected is False @@ -172,10 +172,10 @@ def test_all(): manager = DeviceManager() session1 = manager.register("device1", "default") session2 = manager.register("device2", "default") - + # Act all_sessions = list(manager.all()) - + # Assert assert len(all_sessions) == 2 assert session1 in all_sessions @@ -187,10 +187,10 @@ def test_rename_device(): manager = DeviceManager() session = manager.register("old_name", "default") device_id = session.device_info.device_id - + # Act manager.rename_device("old_name", "new_name") - + # Assert assert "old_name" not in manager._ids_by_device_name assert "new_name" in manager._ids_by_device_name @@ -202,10 +202,10 @@ def test_rename_device(): def test_rename_nonexistent_device(): # Setup manager = DeviceManager() - + # Act - should not raise an exception manager.rename_device("nonexistent_device", "new_name") - + # Assert assert "nonexistent_device" not in manager._ids_by_device_name assert "new_name" not in manager._ids_by_device_name @@ -217,13 +217,13 @@ def test_set_attribute_update_callback(): session = manager.register("test_device", "default") device_id = session.device_info.device_id callback = MagicMock() - + # Mock the session's set_attribute_update_callback method session.set_attribute_update_callback = MagicMock() - + # Act manager.set_attribute_update_callback(device_id, callback) - + # Assert session.set_attribute_update_callback.assert_called_once_with(callback) @@ -234,7 +234,7 @@ def test_set_attribute_update_callback_nonexistent(): from uuid import uuid4 nonexistent_id = uuid4() callback = MagicMock() - + # Act - should not raise an exception manager.set_attribute_update_callback(nonexistent_id, callback) @@ -245,13 +245,13 @@ def test_set_attribute_response_callback(): session = manager.register("test_device", "default") device_id = session.device_info.device_id callback = MagicMock() - + # Mock the session's set_attribute_response_callback method session.set_attribute_response_callback = MagicMock() - + # Act manager.set_attribute_response_callback(device_id, callback) - + # Assert session.set_attribute_response_callback.assert_called_once_with(callback) @@ -262,7 +262,7 @@ def test_set_attribute_response_callback_nonexistent(): from uuid import uuid4 nonexistent_id = uuid4() callback = MagicMock() - + # Act - should not raise an exception manager.set_attribute_response_callback(nonexistent_id, callback) @@ -273,13 +273,13 @@ def test_set_rpc_request_callback(): session = manager.register("test_device", "default") device_id = session.device_info.device_id callback = MagicMock() - + # Mock the session's set_rpc_request_callback method session.set_rpc_request_callback = MagicMock() - + # Act manager.set_rpc_request_callback(device_id, callback) - + # Assert session.set_rpc_request_callback.assert_called_once_with(callback) @@ -290,7 +290,7 @@ def test_set_rpc_request_callback_nonexistent(): from uuid import uuid4 nonexistent_id = uuid4() callback = MagicMock() - + # Act - should not raise an exception manager.set_rpc_request_callback(nonexistent_id, callback) @@ -298,18 +298,18 @@ def test_set_rpc_request_callback_nonexistent(): def test_state_change_callback(): # Setup manager = DeviceManager() - + # Create a device session with a mocked state device_info = DeviceInfo("test_device", "default") session = DeviceSession(device_info, manager._DeviceManager__state_change_callback) - + # Mock the session's state session.state = MagicMock() 
session.state.is_connected = MagicMock(return_value=True) - + # Act manager._DeviceManager__state_change_callback(session) - + # Assert assert session in manager.connected_devices @@ -317,21 +317,21 @@ def test_state_change_callback(): def test_state_change_callback_disconnect(): # Setup manager = DeviceManager() - + # Create a device session with a mocked state device_info = DeviceInfo("test_device", "default") session = DeviceSession(device_info, manager._DeviceManager__state_change_callback) - + # Add the session to connected devices manager._DeviceManager__connected_devices.add(session) - + # Mock the session's state session.state = MagicMock() session.state.is_connected = MagicMock(return_value=False) - + # Act manager._DeviceManager__state_change_callback(session) - + # Assert assert session not in manager.connected_devices @@ -341,14 +341,14 @@ def test_connected_devices_property(): manager = DeviceManager() session1 = manager.register("device1", "default") session2 = manager.register("device2", "default") - + # Add sessions to connected devices manager._DeviceManager__connected_devices.add(session1) manager._DeviceManager__connected_devices.add(session2) - + # Act connected = manager.connected_devices - + # Assert assert len(connected) == 2 assert session1 in connected @@ -360,10 +360,10 @@ def test_all_devices_property(): manager = DeviceManager() session1 = manager.register("device1", "default") session2 = manager.register("device2", "default") - + # Act all_devices = manager.all_devices - + # Assert assert len(all_devices) == 2 assert session1.device_info.device_id in all_devices @@ -377,10 +377,10 @@ def test_repr(): manager = DeviceManager() manager.register("device1", "default") manager.register("device2", "default") - + # Act repr_string = repr(manager) - + # Assert assert "DeviceManager" in repr_string assert "device1" in repr_string @@ -388,4 +388,4 @@ def test_repr(): if __name__ == '__main__': - pytest.main([__file__]) \ No newline at end of file + pytest.main([__file__]) diff --git a/tests/service/gateway/test_device_session.py b/tests/service/gateway/test_device_session.py index 93e5b6b..88d4cd9 100644 --- a/tests/service/gateway/test_device_session.py +++ b/tests/service/gateway/test_device_session.py @@ -32,10 +32,10 @@ def test_init(): # Setup device_info = DeviceInfo("test_device", "default") state_change_callback = MagicMock() - + # Act session = DeviceSession(device_info, state_change_callback) - + # Assert assert session.device_info == device_info assert session._state_change_callback == state_change_callback @@ -54,10 +54,10 @@ def test_update_state(): device_info = DeviceInfo("test_device", "default") state_change_callback = MagicMock() session = DeviceSession(device_info, state_change_callback) - + # Act session.update_state(DeviceSessionState.DISCONNECTED) - + # Assert assert session.state == DeviceSessionState.DISCONNECTED state_change_callback.assert_called_once_with(session) @@ -67,10 +67,10 @@ def test_update_state_no_callback(): # Setup device_info = DeviceInfo("test_device", "default") session = DeviceSession(device_info) # No callback provided - + # Act - should not raise an exception session.update_state(DeviceSessionState.DISCONNECTED) - + # Assert assert session.state == DeviceSessionState.DISCONNECTED @@ -80,14 +80,14 @@ def test_update_last_seen(): device_info = DeviceInfo("test_device", "default") session = DeviceSession(device_info) old_last_seen = session.last_seen_at - + # Wait a bit to ensure the timestamp changes import time 
time.sleep(0.001) - + # Act session.update_last_seen() - + # Assert assert session.last_seen_at > old_last_seen @@ -97,10 +97,10 @@ def test_set_attribute_update_callback(): device_info = DeviceInfo("test_device", "default") session = DeviceSession(device_info) callback = MagicMock() - + # Act session.set_attribute_update_callback(callback) - + # Assert assert session.attribute_update_callback == callback @@ -110,10 +110,10 @@ def test_set_attribute_response_callback(): device_info = DeviceInfo("test_device", "default") session = DeviceSession(device_info) callback = MagicMock() - + # Act session.set_attribute_response_callback(callback) - + # Assert assert session.attribute_response_callback == callback @@ -123,10 +123,10 @@ def test_set_rpc_request_callback(): device_info = DeviceInfo("test_device", "default") session = DeviceSession(device_info) callback = MagicMock() - + # Act session.set_rpc_request_callback(callback) - + # Assert assert session.rpc_request_callback == callback @@ -138,16 +138,17 @@ async def test_handle_event_to_device_attribute_update(): session = DeviceSession(device_info) callback = MagicMock() session.set_attribute_update_callback(callback) - + # Create an attribute update event attribute_entry = AttributeEntry(key="key", value="value") attribute_update = AttributeUpdate([attribute_entry]) - gateway_attribute_update = GatewayAttributeUpdate(session.device_info.device_name, attribute_update=attribute_update) - + gateway_attribute_update = GatewayAttributeUpdate(session.device_info.device_name, + attribute_update=attribute_update) + # Act result = await session.handle_event_to_device(gateway_attribute_update.attribute_update) - + # Assert callback.assert_called_once_with(session, attribute_update) assert result == callback.return_value @@ -160,13 +161,17 @@ async def test_handle_event_to_device_attribute_response(): session = DeviceSession(device_info) callback = MagicMock() session.set_attribute_response_callback(callback) - + # Create an attribute response event - gateway_requested_attribute_response = GatewayRequestedAttributeResponse(request_id=1, device_name="test_device", shared=[AttributeEntry(key="shared_key", value="shared_value")], client=[]) - + gateway_requested_attribute_response = GatewayRequestedAttributeResponse(request_id=1, device_name="test_device", + shared=[ + AttributeEntry(key="shared_key", + value="shared_value")], + client=[]) + # Act result = await session.handle_event_to_device(gateway_requested_attribute_response) - + # Assert callback.assert_called_once_with(session, gateway_requested_attribute_response) assert result == callback.return_value @@ -179,7 +184,7 @@ async def test_handle_event_to_device_rpc_request(): session = DeviceSession(device_info) callback = MagicMock() session.set_rpc_request_callback(callback) - + # Create an RPC request event request_dict = { "device": "test_device", @@ -190,10 +195,10 @@ async def test_handle_event_to_device_rpc_request(): } } rpc_request = GatewayRPCRequest._deserialize_from_dict(request_dict) - + # Act result = await session.handle_event_to_device(rpc_request) - + # Assert callback.assert_called_once_with(session, rpc_request) assert result == callback.return_value @@ -204,14 +209,15 @@ async def test_handle_event_to_device_no_callback(): # Setup device_info = DeviceInfo("test_device", "default") session = DeviceSession(device_info) - + # Create an attribute update event attribute_update = AttributeUpdate([AttributeEntry(key="shared_key", value="shared_value")]) - gateway_attribute_update = 
GatewayAttributeUpdate(session.device_info.device_name, attribute_update=attribute_update) - + gateway_attribute_update = GatewayAttributeUpdate(session.device_info.device_name, + attribute_update=attribute_update) + # Act result = await session.handle_event_to_device(gateway_attribute_update.attribute_update) - + # Assert assert result is None @@ -230,19 +236,19 @@ async def test_handle_event_to_device_async_callback(): "params": {"param": "value"} } } - + # Create an async callback async def async_callback(session, event: GatewayRPCRequest): return GatewayRPCResponse.build(device_name=event.device_name, request_id=event.request_id, result="success") - + session.set_rpc_request_callback(async_callback) - + # Create an RPC request event rpc_request = GatewayRPCRequest._deserialize_from_dict(rpc_request) - + # Act result = await session.handle_event_to_device(rpc_request) - + # Assert assert isinstance(result, GatewayRPCResponse) assert result.device_name == "test_device" @@ -255,14 +261,14 @@ async def test_handle_event_to_device_unsupported_event(): # Setup device_info = DeviceInfo("test_device", "default") session = DeviceSession(device_info) - + # Create a mock event with an unsupported event type mock_event = MagicMock() mock_event.event_type = "UNSUPPORTED_EVENT_TYPE" - + # Act result = await session.handle_event_to_device(mock_event) - + # Assert assert result is None @@ -272,16 +278,16 @@ def test_equality(): device_info1 = DeviceInfo("test_device", "default") device_info2 = DeviceInfo("test_device", "default") # Same device name, but different UUID device_info3 = DeviceInfo("other_device", "default") - + session1 = DeviceSession(device_info1) session2 = DeviceSession(device_info2) session3 = DeviceSession(device_info3) - + # Act & Assert assert session1 != session2 # Different UUIDs assert session1 != session3 assert session2 != session3 - + # Test equality with the same device_info session4 = DeviceSession(device_info1) assert session1 == session4 # Same device_info @@ -291,13 +297,13 @@ def test_hash(): # Setup device_info = DeviceInfo("test_device", "default") session = DeviceSession(device_info) - + # Act hash_value = hash(session) - + # Assert assert hash_value == hash(device_info) - + # Test that sessions with the same device_info have the same hash session2 = DeviceSession(device_info) assert hash(session) == hash(session2) diff --git a/tests/service/gateway/test_direct_event_dispatcher.py b/tests/service/gateway/test_direct_event_dispatcher.py index 3de95d1..5d1d2ea 100644 --- a/tests/service/gateway/test_direct_event_dispatcher.py +++ b/tests/service/gateway/test_direct_event_dispatcher.py @@ -18,7 +18,6 @@ import pytest -from tb_mqtt_client.entities.gateway.device_info import DeviceInfo from tb_mqtt_client.entities.gateway.event_type import GatewayEventType from tb_mqtt_client.entities.gateway.gateway_event import GatewayEvent from tb_mqtt_client.service.gateway.device_session import DeviceSession @@ -28,7 +27,7 @@ def test_init(): # Setup & Act dispatcher = DirectEventDispatcher() - + # Assert assert isinstance(dispatcher._handlers, dict) assert len(dispatcher._handlers) == 0 @@ -40,10 +39,10 @@ def test_register(): dispatcher = DirectEventDispatcher() callback = MagicMock() event_type = GatewayEventType.DEVICE_CONNECT - + # Act dispatcher.register(event_type, callback) - + # Assert assert event_type in dispatcher._handlers assert callback in dispatcher._handlers[event_type] @@ -55,11 +54,11 @@ def test_register_duplicate(): dispatcher = DirectEventDispatcher() callback 
= MagicMock() event_type = GatewayEventType.DEVICE_CONNECT - + # Register the callback twice dispatcher.register(event_type, callback) dispatcher.register(event_type, callback) - + # Assert assert event_type in dispatcher._handlers assert callback in dispatcher._handlers[event_type] @@ -72,11 +71,11 @@ def test_register_multiple(): callback1 = MagicMock() callback2 = MagicMock() event_type = GatewayEventType.DEVICE_CONNECT - + # Act dispatcher.register(event_type, callback1) dispatcher.register(event_type, callback2) - + # Assert assert event_type in dispatcher._handlers assert callback1 in dispatcher._handlers[event_type] @@ -89,11 +88,11 @@ def test_unregister(): dispatcher = DirectEventDispatcher() callback = MagicMock() event_type = GatewayEventType.DEVICE_CONNECT - + # Register and then unregister dispatcher.register(event_type, callback) dispatcher.unregister(event_type, callback) - + # Assert assert event_type not in dispatcher._handlers # Event type should be removed when last callback is unregistered @@ -103,10 +102,10 @@ def test_unregister_nonexistent(): dispatcher = DirectEventDispatcher() callback = MagicMock() event_type = GatewayEventType.DEVICE_CONNECT - + # Act - should not raise an exception dispatcher.unregister(event_type, callback) - + # Assert assert event_type not in dispatcher._handlers @@ -117,12 +116,12 @@ def test_unregister_one_of_many(): callback1 = MagicMock() callback2 = MagicMock() event_type = GatewayEventType.DEVICE_CONNECT - + # Register both callbacks and unregister one dispatcher.register(event_type, callback1) dispatcher.register(event_type, callback2) dispatcher.unregister(event_type, callback1) - + # Assert assert event_type in dispatcher._handlers assert callback1 not in dispatcher._handlers[event_type] @@ -134,17 +133,16 @@ def test_unregister_one_of_many(): async def test_dispatch_to_device_session(): # Setup dispatcher = DirectEventDispatcher() - device_info = DeviceInfo("test_device", "default") device_session = MagicMock(spec=DeviceSession) device_session.handle_event_to_device = AsyncMock() - + # Create a mock event event = MagicMock(spec=GatewayEvent) event.event_type = GatewayEventType.DEVICE_ATTRIBUTE_UPDATE - + # Act await dispatcher.dispatch(event, device_session=device_session) - + # Assert device_session.handle_event_to_device.assert_awaited_once_with(event) @@ -155,17 +153,17 @@ async def test_dispatch_to_sync_callback(): dispatcher = DirectEventDispatcher() callback = MagicMock() event_type = GatewayEventType.DEVICE_CONNECT - + # Register the callback dispatcher.register(event_type, callback) - + # Create a mock event event = MagicMock(spec=GatewayEvent) event.event_type = event_type - + # Act result = await dispatcher.dispatch(event) - + # Assert callback.assert_called_once_with(event) assert result == callback.return_value @@ -177,17 +175,17 @@ async def test_dispatch_to_async_callback(): dispatcher = DirectEventDispatcher() callback = AsyncMock() event_type = GatewayEventType.DEVICE_CONNECT - + # Register the callback dispatcher.register(event_type, callback) - + # Create a mock event event = MagicMock(spec=GatewayEvent) event.event_type = event_type - + # Act result = await dispatcher.dispatch(event) - + # Assert callback.assert_awaited_once_with(event) assert result == callback.return_value @@ -199,17 +197,17 @@ async def test_dispatch_with_args(): dispatcher = DirectEventDispatcher() callback = MagicMock() event_type = GatewayEventType.DEVICE_CONNECT - + # Register the callback dispatcher.register(event_type, callback) - + # 
Create a mock event event = MagicMock(spec=GatewayEvent) event.event_type = event_type - + # Act await dispatcher.dispatch(event, "arg1", "arg2", kwarg1="value1", kwarg2="value2") - + # Assert callback.assert_called_once_with(event, "arg1", "arg2", kwarg1="value1", kwarg2="value2") @@ -218,14 +216,14 @@ async def test_dispatch_with_args(): async def test_dispatch_no_handlers(): # Setup dispatcher = DirectEventDispatcher() - + # Create a mock event event = MagicMock(spec=GatewayEvent) event.event_type = GatewayEventType.DEVICE_CONNECT - + # Act result = await dispatcher.dispatch(event) - + # Assert assert result is None @@ -236,17 +234,17 @@ async def test_dispatch_callback_exception(): dispatcher = DirectEventDispatcher() callback = MagicMock(side_effect=Exception("Test exception")) event_type = GatewayEventType.DEVICE_CONNECT - + # Register the callback dispatcher.register(event_type, callback) - + # Create a mock event event = MagicMock(spec=GatewayEvent) event.event_type = event_type - + # Act result = await dispatcher.dispatch(event) - + # Assert callback.assert_called_once_with(event) assert result is None @@ -258,21 +256,21 @@ async def test_dispatch_async_callback_exception(): dispatcher = DirectEventDispatcher() callback = AsyncMock(side_effect=Exception("Test exception")) event_type = GatewayEventType.DEVICE_CONNECT - + # Register the callback dispatcher.register(event_type, callback) - + # Create a mock event event = MagicMock(spec=GatewayEvent) event.event_type = event_type - + # Act result = await dispatcher.dispatch(event) - + # Assert callback.assert_awaited_once_with(event) assert result is None if __name__ == '__main__': - pytest.main([__file__]) \ No newline at end of file + pytest.main([__file__]) diff --git a/tests/service/gateway/test_gateway_client.py b/tests/service/gateway/test_gateway_client.py index d757a00..61248b3 100644 --- a/tests/service/gateway/test_gateway_client.py +++ b/tests/service/gateway/test_gateway_client.py @@ -33,40 +33,21 @@ from tb_mqtt_client.service.gateway.device_session import DeviceSession -# @pytest.mark.asyncio -# async def test_connect(): -# # Setup -# client = GatewayClient() -# client._mqtt_manager = AsyncMock() -# client._mqtt_manager.is_connected = MagicMock(return_value=True) -# client._mqtt_manager.await_ready = AsyncMock() -# client._mqtt_manager.subscribe = AsyncMock(return_value=asyncio.Future()) -# client._mqtt_manager.register_handler = MagicMock() -# -# # Act -# await client.connect() -# -# # Assert -# client._mqtt_manager.connect.assert_awaited_once() -# assert client._mqtt_manager.subscribe.call_count == 3 # Should subscribe to 3 topics -# assert client._mqtt_manager.register_handler.call_count == 3 # Should register 3 handlers - - @pytest.mark.asyncio async def test_connect_device(): # Setup client = GatewayClient() client._event_dispatcher = AsyncMock() client._event_dispatcher.dispatch = AsyncMock() - + # Create a future that will be returned by dispatch future = asyncio.Future() future.set_result(PublishResult(topic=GATEWAY_CONNECT_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) client._event_dispatcher.dispatch.return_value = [future] - + # Act device_session, result = await client.connect_device("test_device", wait_for_publish=True) - + # Assert assert device_session is not None assert device_session.device_info.device_name == "test_device" @@ -81,18 +62,18 @@ async def test_connect_device_with_connect_message(): client = GatewayClient() client._event_dispatcher = AsyncMock() 
client._event_dispatcher.dispatch = AsyncMock() - + # Create a future that will be returned by dispatch future = asyncio.Future() future.set_result(PublishResult(topic=GATEWAY_CONNECT_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) client._event_dispatcher.dispatch.return_value = [future] - + # Create a device connect message connect_message = DeviceConnectMessage.build("test_device", "custom_profile") - + # Act device_session, result = await client.connect_device(connect_message, wait_for_publish=True) - + # Assert assert device_session is not None assert device_session.device_info.device_name == "test_device" @@ -107,14 +88,14 @@ async def test_connect_device_no_wait(): # Setup client = GatewayClient() client._event_dispatcher = AsyncMock() - + # Create a future that will be returned by dispatch future = asyncio.Future() client._event_dispatcher.dispatch.return_value = [future] - + # Act device_session, futures = await client.connect_device("test_device", wait_for_publish=False) - + # Assert assert device_session is not None assert device_session.device_info.device_name == "test_device" @@ -127,19 +108,20 @@ async def test_disconnect_device(): # Setup client = GatewayClient() client._event_dispatcher = AsyncMock() - + # Create a device session info = DeviceInfo("test_device", "default") device_session = DeviceSession(device_info=info) - + # Create a future that will be returned by dispatch future = asyncio.Future() - future.set_result(PublishResult(topic=GATEWAY_DISCONNECT_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) + future.set_result( + PublishResult(topic=GATEWAY_DISCONNECT_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) client._event_dispatcher.dispatch.return_value = [future] - + # Act session, result = await client.disconnect_device(device_session, wait_for_publish=True) - + # Assert assert session is not None assert isinstance(result, PublishResult) @@ -152,18 +134,18 @@ async def test_disconnect_device_no_wait(): # Setup client = GatewayClient() client._event_dispatcher = AsyncMock() - + # Create a device session info = DeviceInfo("test_device", "default") device_session = DeviceSession(device_info=info) - + # Create a future that will be returned by dispatch future = asyncio.Future() client._event_dispatcher.dispatch.return_value = [future] - + # Act session, futures = await client.disconnect_device(device_session, wait_for_publish=False) - + # Assert assert session is not None assert isinstance(futures[0], asyncio.Future) @@ -175,19 +157,20 @@ async def test_send_device_timeseries(): # Setup client = GatewayClient() client._event_dispatcher = AsyncMock() - + # Create a device session info = DeviceInfo("test_device", "default") device_session = DeviceSession(device_info=info) - + # Create a future that will be returned by dispatch future = asyncio.Future() - future.set_result(PublishResult(topic=GATEWAY_TELEMETRY_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) + future.set_result( + PublishResult(topic=GATEWAY_TELEMETRY_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) client._event_dispatcher.dispatch.return_value = [future] - + # Act result = await client.send_device_timeseries(device_session, {"temperature": 25.5}, wait_for_publish=True) - + # Assert assert isinstance(result, PublishResult) assert result.topic == GATEWAY_TELEMETRY_TOPIC @@ -199,22 +182,23 @@ async def test_send_device_timeseries_with_timeseries_entry(): # Setup client = GatewayClient() client._event_dispatcher = AsyncMock() - + # Create a 
device session info = DeviceInfo("test_device", "default") device_session = DeviceSession(device_info=info) - + # Create a future that will be returned by dispatch future = asyncio.Future() - future.set_result(PublishResult(topic=GATEWAY_TELEMETRY_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) + future.set_result( + PublishResult(topic=GATEWAY_TELEMETRY_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) client._event_dispatcher.dispatch.return_value = [future] - + # Create a timeseries entry timeseries = TimeseriesEntry("temperature", 25.5) - + # Act result = await client.send_device_timeseries(device_session, timeseries, wait_for_publish=True) - + # Assert assert isinstance(result, PublishResult) assert result.topic == GATEWAY_TELEMETRY_TOPIC @@ -226,19 +210,19 @@ async def test_send_device_timeseries_no_wait(): # Setup client = GatewayClient() client._event_dispatcher = AsyncMock() - + # Create a device session - + info = DeviceInfo("test_device", "default") device_session = DeviceSession(device_info=info) - + # Create a future that will be returned by dispatch future = asyncio.Future() client._event_dispatcher.dispatch.return_value = [future] - + # Act future = await client.send_device_timeseries(device_session, {"temperature": 25.5}, wait_for_publish=False) - + # Assert assert isinstance(future, asyncio.Future) client._event_dispatcher.dispatch.assert_awaited_once() @@ -249,20 +233,21 @@ async def test_send_device_attributes(): # Setup client = GatewayClient() client._event_dispatcher = AsyncMock() - + # Create a device session - + info = DeviceInfo("test_device", "default") device_session = DeviceSession(device_info=info) - + # Create a future that will be returned by dispatch future = asyncio.Future() - future.set_result(PublishResult(topic=GATEWAY_ATTRIBUTES_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) + future.set_result( + PublishResult(topic=GATEWAY_ATTRIBUTES_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) client._event_dispatcher.dispatch.return_value = [future] - + # Act result = await client.send_device_attributes(device_session, {"firmware_version": "1.0.0"}, wait_for_publish=True) - + # Assert assert isinstance(result, PublishResult) assert result.topic == GATEWAY_ATTRIBUTES_TOPIC @@ -274,23 +259,24 @@ async def test_send_device_attributes_with_attribute_entry(): # Setup client = GatewayClient() client._event_dispatcher = AsyncMock() - + # Create a device session - + info = DeviceInfo("test_device", "default") device_session = DeviceSession(device_info=info) - + # Create a future that will be returned by dispatch future = asyncio.Future() - future.set_result(PublishResult(topic=GATEWAY_ATTRIBUTES_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) + future.set_result( + PublishResult(topic=GATEWAY_ATTRIBUTES_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) client._event_dispatcher.dispatch.return_value = [future] - + # Create an attribute entry attributes = AttributeEntry("firmware_version", "1.0.0") - + # Act result = await client.send_device_attributes(device_session, attributes, wait_for_publish=True) - + # Assert assert isinstance(result, PublishResult) assert result.topic == GATEWAY_ATTRIBUTES_TOPIC @@ -302,19 +288,19 @@ async def test_send_device_attributes_no_wait(): # Setup client = GatewayClient() client._event_dispatcher = AsyncMock() - + # Create a device session - + info = DeviceInfo("test_device", "default") device_session = DeviceSession(device_info=info) - + # Create a future that 
will be returned by dispatch future = asyncio.Future() client._event_dispatcher.dispatch.return_value = [future] - + # Act future = await client.send_device_attributes(device_session, {"firmware_version": "1.0.0"}, wait_for_publish=False) - + # Assert assert isinstance(future, asyncio.Future) client._event_dispatcher.dispatch.assert_awaited_once() @@ -326,23 +312,24 @@ async def test_send_device_attributes_request(): client = GatewayClient() client._event_dispatcher = AsyncMock() client._gateway_requested_attribute_response_handler = AsyncMock() - + # Create a device session - + info = DeviceInfo("test_device", "default") device_session = DeviceSession(device_info=info) - + # Create a future that will be returned by dispatch future = asyncio.Future() - future.set_result(PublishResult(topic=GATEWAY_ATTRIBUTES_REQUEST_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) + future.set_result( + PublishResult(topic=GATEWAY_ATTRIBUTES_REQUEST_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) client._event_dispatcher.dispatch.return_value = [future] - + # Create a gateway attribute request request = await GatewayAttributeRequest.build(device_session, shared_keys=["firmware_version"], client_keys=None) - + # Act result = await client.send_device_attributes_request(device_session, request, wait_for_publish=True) - + # Assert assert isinstance(result, PublishResult) assert result.topic == GATEWAY_ATTRIBUTES_REQUEST_TOPIC @@ -356,22 +343,22 @@ async def test_send_device_attributes_request_no_wait(): client = GatewayClient() client._event_dispatcher = AsyncMock() client._gateway_requested_attribute_response_handler = AsyncMock() - + # Create a device session - + info = DeviceInfo("test_device", "default") device_session = DeviceSession(device_info=info) - + # Create a future that will be returned by dispatch future = asyncio.Future() client._event_dispatcher.dispatch.return_value = [future] - + # Create a gateway attribute request request = await GatewayAttributeRequest.build(device_session, shared_keys=["firmware_version"], client_keys=None) - + # Act future = await client.send_device_attributes_request(device_session, request, wait_for_publish=False) - + # Assert assert isinstance(future, asyncio.Future) client._gateway_requested_attribute_response_handler.register_request.assert_awaited_once() @@ -383,24 +370,25 @@ async def test_send_device_claim_request(): # Setup client = GatewayClient() client._event_dispatcher = AsyncMock() - + # Create a device session - + info = DeviceInfo("test_device", "default") device_session = DeviceSession(device_info=info) - + # Create a future that dispatch will return future = asyncio.Future() future.set_result(PublishResult(topic=GATEWAY_CLAIM_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) client._event_dispatcher.dispatch.return_value = [future] - + # Create a gateway claim request device_claim_request = ClaimRequest.build(secret_key="secret", duration=1000) - gateway_claim_request = GatewayClaimRequestBuilder().add_device_request(device_session, device_claim_request).build() - + gateway_claim_request = GatewayClaimRequestBuilder().add_device_request(device_session, + device_claim_request).build() + # Act result = await client.send_device_claim_request(device_session, gateway_claim_request, wait_for_publish=True) - + # Assert assert isinstance(result, PublishResult) assert result.topic == GATEWAY_CLAIM_TOPIC @@ -412,23 +400,24 @@ async def test_send_device_claim_request_no_wait(): # Setup client = GatewayClient() 
client._event_dispatcher = AsyncMock() - + # Create a device session - + info = DeviceInfo("test_device", "default") device_session = DeviceSession(device_info=info) - + # Create a future that dispatch will return future = asyncio.Future() client._event_dispatcher.dispatch.return_value = [future] - + # Create a gateway claim request device_claim_request = ClaimRequest.build(secret_key="secret", duration=1000) - gateway_claim_request = GatewayClaimRequestBuilder().add_device_request(device_session, device_claim_request).build() - + gateway_claim_request = GatewayClaimRequestBuilder().add_device_request(device_session, + device_claim_request).build() + # Act future = await client.send_device_claim_request(device_session, gateway_claim_request, wait_for_publish=False) - + # Assert assert isinstance(future, asyncio.Future) client._event_dispatcher.dispatch.assert_awaited_once() @@ -440,10 +429,10 @@ async def test_disconnect(): client = GatewayClient() client._mqtt_manager = AsyncMock() client._mqtt_manager.unsubscribe = AsyncMock(return_value=asyncio.Future()) - + # Act await client.disconnect() - + # Assert client._mqtt_manager.unsubscribe.assert_awaited() client._mqtt_manager.disconnect.assert_awaited_once() @@ -460,7 +449,7 @@ async def test_handle_rate_limit_response(): client._gateway_rate_limiter.message_rate_limit = AsyncMock() client._gateway_rate_limiter.telemetry_message_rate_limit = AsyncMock() client._gateway_rate_limiter.telemetry_datapoints_rate_limit = AsyncMock() - + # Create a response with gateway rate limits response = RPCResponse.build(1, result={ 'gatewayRateLimits': { @@ -470,12 +459,12 @@ async def test_handle_rate_limit_response(): }, 'maxPayloadSize': 512 }) - + # Mock the parent class method with patch('tb_mqtt_client.service.device.client.DeviceClient._handle_rate_limit_response', return_value=True): # Act result = await client._handle_rate_limit_response(response) - + # Assert assert result is True client._gateway_rate_limiter.message_rate_limit.set_limit.assert_awaited_once() diff --git a/tests/service/gateway/test_message_adapter.py b/tests/service/gateway/test_message_adapter.py index b714a3c..b5f2372 100644 --- a/tests/service/gateway/test_message_adapter.py +++ b/tests/service/gateway/test_message_adapter.py @@ -44,7 +44,7 @@ def test_init(): # Setup & Act adapter = JsonGatewayMessageAdapter(max_payload_size=1000, max_datapoints=100) - + # Assert assert adapter.splitter.max_payload_size == 1000 - DEFAULT_FIELDS_SIZE assert adapter.splitter.max_datapoints == 100 @@ -54,15 +54,15 @@ def test_build_device_connect_message_payload(): # Setup adapter = JsonGatewayMessageAdapter() device_connect_message = DeviceConnectMessage.build("test_device", "default") - + # Act result = adapter.build_device_connect_message_payload(device_connect_message, qos=1) - + # Assert assert isinstance(result, MqttPublishMessage) assert result.topic == GATEWAY_CONNECT_TOPIC assert result.qos == 1 - + # Verify payload content payload_dict = orjson.loads(result.payload) assert "device" in payload_dict @@ -75,15 +75,15 @@ def test_build_device_disconnect_message_payload(): # Setup adapter = JsonGatewayMessageAdapter() device_disconnect_message = DeviceDisconnectMessage.build("test_device") - + # Act result = adapter.build_device_disconnect_message_payload(device_disconnect_message, qos=1) - + # Assert assert isinstance(result, MqttPublishMessage) assert result.topic == GATEWAY_DISCONNECT_TOPIC assert result.qos == 1 - + # Verify payload content payload_dict = orjson.loads(result.payload) 
assert "device" in payload_dict @@ -97,15 +97,15 @@ async def test_build_gateway_attribute_request_payload(): device_info = DeviceInfo("test_device", "default") device_session = DeviceSession(device_info) attribute_request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["key1", "key2"]) - + # Act result = adapter.build_gateway_attribute_request_payload(attribute_request, qos=1) - + # Assert assert isinstance(result, MqttPublishMessage) assert result.topic == GATEWAY_ATTRIBUTES_REQUEST_TOPIC assert result.qos == 1 - + # Verify payload content payload_dict = orjson.loads(result.payload) assert "device" in payload_dict @@ -120,15 +120,15 @@ def test_build_rpc_response_payload(): # Setup adapter = JsonGatewayMessageAdapter() rpc_response = GatewayRPCResponse.build(device_name="test_device", request_id=1, result="success") - + # Act result = adapter.build_rpc_response_payload(rpc_response, qos=1) - + # Assert assert isinstance(result, MqttPublishMessage) assert result.topic == GATEWAY_RPC_TOPIC assert result.qos == 1 - + # Verify payload content payload_dict = orjson.loads(result.payload) assert "device" in payload_dict @@ -143,17 +143,19 @@ def test_build_claim_request_payload(): # Setup adapter = JsonGatewayMessageAdapter() device_claim_request = ClaimRequest.build(secret_key="secret", duration=1) - gateway_claim_request = GatewayClaimRequestBuilder().add_device_request(device_name_or_session="test_device", - device_claim_request=device_claim_request).build() - + gateway_claim_request = (GatewayClaimRequestBuilder() + .add_device_request(device_name_or_session="test_device", + device_claim_request=device_claim_request) + .build()) + # Act result = adapter.build_claim_request_payload(gateway_claim_request, qos=1) - + # Assert assert isinstance(result, MqttPublishMessage) assert result.topic == GATEWAY_CLAIM_TOPIC assert result.qos == 1 - + # Verify payload content payload_dict = orjson.loads(result.payload) assert "test_device" in payload_dict @@ -171,10 +173,10 @@ def test_parse_attribute_update(): "key2": 42 } } - + # Act result = adapter.parse_attribute_update(data) - + # Assert assert isinstance(result, GatewayAttributeUpdate) assert result.device_name == "test_device" @@ -188,7 +190,7 @@ def test_parse_attribute_update_invalid_format(): data = { "invalid_format": True } - + # Act & Assert with pytest.raises(ValueError): adapter.parse_attribute_update(data) @@ -201,7 +203,7 @@ async def test_parse_gateway_requested_attribute_response(): device_info = DeviceInfo("test_device", "default") device_session = DeviceSession(device_info) attribute_request = await GatewayAttributeRequest.build(device_session=device_session, client_keys=["key1", "key2"]) - + data = { "device": "test_device", "id": attribute_request.request_id, @@ -210,10 +212,10 @@ async def test_parse_gateway_requested_attribute_response(): "key2": 42 } } - + # Act result = adapter.parse_gateway_requested_attribute_response(attribute_request, data) - + # Assert assert isinstance(result, GatewayRequestedAttributeResponse) assert result.device_name == "test_device" @@ -231,16 +233,16 @@ async def test_parse_gateway_requested_attribute_response_single_value(): device_info = DeviceInfo("test_device", "default") device_session = DeviceSession(device_info) attribute_request = await GatewayAttributeRequest.build(device_session=device_session, client_keys=["key1"]) - + data = { "device": "test_device", "id": attribute_request.request_id, "value": "value1" } - + # Act result = 
adapter.parse_gateway_requested_attribute_response(attribute_request, data) - + # Assert assert isinstance(result, GatewayRequestedAttributeResponse) assert result.device_name == "test_device" @@ -258,7 +260,7 @@ async def test_parse_gateway_requested_attribute_response_shared_keys(): device_info = DeviceInfo("test_device", "default") device_session = DeviceSession(device_info) attribute_request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["key1", "key2"]) - + data = { "device": "test_device", "id": attribute_request.request_id, @@ -267,10 +269,10 @@ async def test_parse_gateway_requested_attribute_response_shared_keys(): "key2": 42 } } - + # Act result = adapter.parse_gateway_requested_attribute_response(attribute_request, data) - + # Assert assert isinstance(result, GatewayRequestedAttributeResponse) assert result.device_name == "test_device" @@ -295,10 +297,10 @@ def test_parse_rpc_request(): } } } - + # Act result = adapter.parse_rpc_request("v1/gateway/rpc", data) - + # Assert assert isinstance(result, GatewayRPCRequest) assert result.device_name == "test_device" @@ -313,7 +315,7 @@ def test_parse_rpc_request_invalid_format(): data = { "invalid_format": True } - + # Act & Assert with pytest.raises(ValueError): adapter.parse_rpc_request("v1/gateway/rpc", data) @@ -323,10 +325,10 @@ def test_deserialize_to_dict(): # Setup adapter = JsonGatewayMessageAdapter() payload = b'{"key": "value", "number": 42}' - + # Act result = adapter.deserialize_to_dict(payload) - + # Assert assert isinstance(result, dict) assert result == {"key": "value", "number": 42} @@ -336,7 +338,7 @@ def test_deserialize_to_dict_invalid_format(): # Setup adapter = JsonGatewayMessageAdapter() payload = b'invalid json' - + # Act & Assert with pytest.raises(ValueError): adapter.deserialize_to_dict(payload) @@ -349,10 +351,10 @@ def test_pack_attributes(): AttributeEntry("key2", 42) ] uplink_message = GatewayUplinkMessageBuilder().set_device_name("test_device").add_attributes(attributes).build() - + # Act result = JsonGatewayMessageAdapter.pack_attributes(uplink_message) - + # Assert assert isinstance(result, dict) assert result == {"key1": "value1", "key2": 42} @@ -362,10 +364,10 @@ def test_pack_timeseries_no_timestamp(): # Setup timeseries = [TimeseriesEntry("temp", 22.5), TimeseriesEntry("humidity", 45)] uplink_message = GatewayUplinkMessageBuilder().set_device_name("test_device").add_timeseries(timeseries).build() - + # Act result = JsonGatewayMessageAdapter.pack_timeseries(uplink_message) - + # Assert assert isinstance(result, list) assert len(result) == 1 @@ -377,10 +379,10 @@ def test_pack_timeseries_with_timestamp(): ts = int(datetime.now(UTC).timestamp() * 1000) timeseries = [TimeseriesEntry("temp", 22.5, ts), TimeseriesEntry("humidity", 45, ts)] uplink_message = GatewayUplinkMessageBuilder().set_device_name("test_device").add_timeseries(timeseries).build() - + # Act result = JsonGatewayMessageAdapter.pack_timeseries(uplink_message) - + # Assert assert isinstance(result, list) assert len(result) == 1 @@ -396,18 +398,18 @@ def test_pack_timeseries_with_different_timestamps(): ts2 = ts1 + 1000 # 1 second later timeseries = [TimeseriesEntry("temp", 22.5, ts1), TimeseriesEntry("humidity", 45, ts2)] uplink_message = GatewayUplinkMessageBuilder().set_device_name("test_device").add_timeseries(timeseries).build() - + # Act result = JsonGatewayMessageAdapter.pack_timeseries(uplink_message) - + # Assert assert isinstance(result, list) assert len(result) == 2 - + # Find entries by 
timestamp ts1_entry = next((entry for entry in result if entry["ts"] == ts1), None) ts2_entry = next((entry for entry in result if entry["ts"] == ts2), None) - + assert ts1_entry is not None assert ts2_entry is not None assert ts1_entry["values"] == {"temp": 22.5} @@ -417,10 +419,10 @@ def test_pack_timeseries_with_different_timestamps(): def test_build_uplink_messages_empty(): # Setup adapter = JsonGatewayMessageAdapter() - + # Act result = adapter.build_uplink_messages([]) - + # Assert assert result == [] @@ -429,10 +431,10 @@ def test_build_uplink_messages_non_gateway_message(): # Setup adapter = JsonGatewayMessageAdapter() mqtt_msg = MqttPublishMessage(topic="test/topic", payload=b"test_payload", qos=1) - + # Act result = adapter.build_uplink_messages([mqtt_msg]) - + # Assert assert len(result) == 1 assert result[0] == mqtt_msg @@ -441,23 +443,23 @@ def test_build_uplink_messages_non_gateway_message(): def test_build_uplink_messages_with_telemetry(): # Setup adapter = JsonGatewayMessageAdapter() - + # Create a gateway uplink message with telemetry uplink_message = (GatewayUplinkMessageBuilder() .set_device_name("test_device") .add_timeseries([TimeseriesEntry("temp", 22.5), TimeseriesEntry("humidity", 45)]) .build()) - + mqtt_msg = MqttPublishMessage(topic="test/topic", payload=uplink_message, qos=1) - + # Act result = adapter.build_uplink_messages([mqtt_msg]) - + # Assert assert len(result) == 1 assert result[0].topic == GATEWAY_TELEMETRY_TOPIC assert result[0].qos == 1 - + # Verify payload content payload_dict = orjson.loads(result[0].payload) assert "test_device" in payload_dict @@ -470,23 +472,23 @@ def test_build_uplink_messages_with_telemetry(): def test_build_uplink_messages_with_attributes(): # Setup adapter = JsonGatewayMessageAdapter() - + # Create a gateway uplink message with attributes uplink_message = GatewayUplinkMessageBuilder().set_device_name("test_device").add_attributes([ AttributeEntry("key1", "value1"), AttributeEntry("key2", 42) ]).build() - + mqtt_msg = MqttPublishMessage(topic="test/topic", payload=uplink_message, qos=1) - + # Act result = adapter.build_uplink_messages([mqtt_msg]) - + # Assert assert len(result) == 1 assert result[0].topic == GATEWAY_ATTRIBUTES_TOPIC assert result[0].qos == 1 - + # Verify payload content payload_dict = orjson.loads(result[0].payload) assert "test_device" in payload_dict @@ -496,28 +498,28 @@ def test_build_uplink_messages_with_attributes(): def test_build_uplink_messages_with_both_telemetry_and_attributes(): # Setup adapter = JsonGatewayMessageAdapter() - + # Create a gateway uplink message with both telemetry and attributes uplink_message_builder = GatewayUplinkMessageBuilder().set_device_name("test_device") uplink_message_builder.add_timeseries([TimeseriesEntry("temp", 22.5), TimeseriesEntry("humidity", 45)]) uplink_message_builder.add_attributes([AttributeEntry("key1", "value1"), AttributeEntry("key2", 42)]) uplink_message = uplink_message_builder.build() - + mqtt_msg = MqttPublishMessage(topic="test/topic", payload=uplink_message, qos=1) - + # Act result = adapter.build_uplink_messages([mqtt_msg]) - + # Assert assert len(result) == 2 - + # Find telemetry and attribute messages telemetry_msg = next((msg for msg in result if msg.topic == GATEWAY_TELEMETRY_TOPIC), None) attribute_msg = next((msg for msg in result if msg.topic == GATEWAY_ATTRIBUTES_TOPIC), None) - + assert telemetry_msg is not None assert attribute_msg is not None - + # Verify telemetry payload telemetry_dict = orjson.loads(telemetry_msg.payload) assert 
"test_device" in telemetry_dict @@ -525,7 +527,7 @@ def test_build_uplink_messages_with_both_telemetry_and_attributes(): assert len(telemetry_dict["test_device"]) == 1 assert "values" in telemetry_dict["test_device"][0] assert telemetry_dict["test_device"][0]["values"] == {"temp": 22.5, "humidity": 45} - + # Verify attribute payload attribute_dict = orjson.loads(attribute_msg.payload) assert "test_device" in attribute_dict diff --git a/tests/service/gateway/test_message_sender.py b/tests/service/gateway/test_message_sender.py index f5ec49a..3c83c94 100644 --- a/tests/service/gateway/test_message_sender.py +++ b/tests/service/gateway/test_message_sender.py @@ -38,7 +38,7 @@ def test_init(): # Setup & Act sender = GatewayMessageSender() - + # Assert assert sender._message_queue is None assert sender._message_adapter is None @@ -48,10 +48,10 @@ def test_set_message_queue(): # Setup sender = GatewayMessageSender() message_queue = MagicMock(spec=MessageService) - + # Act sender.set_message_queue(message_queue) - + # Assert assert sender._message_queue == message_queue @@ -60,10 +60,10 @@ def test_set_message_adapter(): # Setup sender = GatewayMessageSender() message_adapter = MagicMock(spec=GatewayMessageAdapter) - + # Act sender.set_message_adapter(message_adapter) - + # Assert assert sender._message_adapter == message_adapter @@ -74,29 +74,29 @@ async def test_send_uplink_message_with_timeseries(): sender = GatewayMessageSender() message_queue = AsyncMock(spec=MessageService) sender.set_message_queue(message_queue) - + # Create an uplink message with timeseries uplink_message = (GatewayUplinkMessageBuilder() .set_device_name("test_device") .add_timeseries([TimeseriesEntry("temp", 22.5)]) .build()) - + # Mock the publish method to set delivery_futures async def mock_publish(mqtt_message): mqtt_message.delivery_futures = [asyncio.Future()] return [mqtt_message] - + message_queue.publish.side_effect = mock_publish - + # Act result = await sender.send_uplink_message(uplink_message) - + # Assert assert result is not None assert len(result) == 1 assert isinstance(result[0], asyncio.Future) message_queue.publish.assert_called_once() - + # Verify the MQTT message mqtt_message = message_queue.publish.call_args[0][0] assert mqtt_message.topic == GATEWAY_TELEMETRY_TOPIC @@ -110,29 +110,29 @@ async def test_send_uplink_message_with_attributes(): sender = GatewayMessageSender() message_queue = AsyncMock(spec=MessageService) sender.set_message_queue(message_queue) - + # Create an uplink message with attributes uplink_message = (GatewayUplinkMessageBuilder() .set_device_name("test_device") .add_attributes([AttributeEntry("temperature", 22.5)]) .build()) - + # Mock the publish method to set delivery_futures async def mock_publish(mqtt_message): mqtt_message.delivery_futures = [asyncio.Future()] return [mqtt_message] - + message_queue.publish.side_effect = mock_publish - + # Act result = await sender.send_uplink_message(uplink_message) - + # Assert assert result is not None assert len(result) == 1 assert isinstance(result[0], asyncio.Future) message_queue.publish.assert_called_once() - + # Verify the MQTT message mqtt_message = message_queue.publish.call_args[0][0] assert mqtt_message.topic == GATEWAY_ATTRIBUTES_TOPIC @@ -146,31 +146,31 @@ async def test_send_uplink_message_with_both(): sender = GatewayMessageSender() message_queue = AsyncMock(spec=MessageService) sender.set_message_queue(message_queue) - + # Create an uplink message with both timeseries and attributes uplink_message = 
(GatewayUplinkMessageBuilder() .set_device_name("test_device") .add_timeseries([TimeseriesEntry("temp", 22.5)]) .add_attributes(AttributeEntry("key1", "value1")) .build()) - + # Mock the publish method to set delivery_futures async def mock_publish(mqtt_message): mqtt_message.delivery_futures = [asyncio.Future()] return [mqtt_message] - + message_queue.publish.side_effect = mock_publish - + # Act result = await sender.send_uplink_message(uplink_message) - + # Assert assert result is not None assert len(result) == 2 # One for telemetry, one for attributes assert isinstance(result[0], asyncio.Future) assert isinstance(result[1], asyncio.Future) assert message_queue.publish.call_count == 2 - + # Verify the MQTT messages call_args_list = message_queue.publish.call_args_list assert call_args_list[0][0][0].topic == GATEWAY_TELEMETRY_TOPIC @@ -187,23 +187,24 @@ async def test_send_device_connect(): message_adapter = MagicMock(spec=GatewayMessageAdapter) sender.set_message_queue(message_queue) sender.set_message_adapter(message_adapter) - + # Create a device connect message device_connect_message = DeviceConnectMessage.build("test_device", "default") - + # Mock the message adapter mqtt_message = MqttPublishMessage(GATEWAY_CONNECT_TOPIC, b'{"device":"test_device","type":"default"}', qos=1) mqtt_message.delivery_futures = [asyncio.Future()] message_adapter.build_device_connect_message_payload.return_value = mqtt_message - + # Act result = await sender.send_device_connect(device_connect_message) - + # Assert assert result is not None assert len(result) == 1 assert isinstance(result[0], asyncio.Future) - message_adapter.build_device_connect_message_payload.assert_called_once_with(device_connect_message=device_connect_message, qos=1) + message_adapter.build_device_connect_message_payload.assert_called_once_with( + device_connect_message=device_connect_message, qos=1) message_queue.publish.assert_called_once_with(mqtt_message) @@ -215,23 +216,24 @@ async def test_send_device_disconnect(): message_adapter = MagicMock(spec=GatewayMessageAdapter) sender.set_message_queue(message_queue) sender.set_message_adapter(message_adapter) - + # Create a device disconnect message device_disconnect_message = DeviceDisconnectMessage.build("test_device") - + # Mock the message adapter mqtt_message = MqttPublishMessage(GATEWAY_DISCONNECT_TOPIC, b'{"device":"test_device"}', qos=1) mqtt_message.delivery_futures = [asyncio.Future()] message_adapter.build_device_disconnect_message_payload.return_value = mqtt_message - + # Act result = await sender.send_device_disconnect(device_disconnect_message) - + # Assert assert result is not None assert len(result) == 1 assert isinstance(result[0], asyncio.Future) - message_adapter.build_device_disconnect_message_payload.assert_called_once_with(device_disconnect_message=device_disconnect_message, qos=1) + message_adapter.build_device_disconnect_message_payload.assert_called_once_with( + device_disconnect_message=device_disconnect_message, qos=1) message_queue.publish.assert_called_once_with(mqtt_message) @@ -243,23 +245,25 @@ async def test_send_attributes_request(): message_adapter = MagicMock(spec=GatewayMessageAdapter) sender.set_message_queue(message_queue) sender.set_message_adapter(message_adapter) - + # Create an attribute request attribute_request = MagicMock(spec=GatewayAttributeRequest) - + # Mock the message adapter - mqtt_message = MqttPublishMessage(GATEWAY_ATTRIBUTES_REQUEST_TOPIC, b'{"device":"test_device","keys":["key1"]}', qos=1) + mqtt_message = 
MqttPublishMessage(GATEWAY_ATTRIBUTES_REQUEST_TOPIC, b'{"device":"test_device","keys":["key1"]}', + qos=1) mqtt_message.delivery_futures = [asyncio.Future()] message_adapter.build_gateway_attribute_request_payload.return_value = mqtt_message - + # Act result = await sender.send_attributes_request(attribute_request) - + # Assert assert result is not None assert len(result) == 1 assert isinstance(result[0], asyncio.Future) - message_adapter.build_gateway_attribute_request_payload.assert_called_once_with(attribute_request=attribute_request, qos=1) + message_adapter.build_gateway_attribute_request_payload.assert_called_once_with(attribute_request=attribute_request, + qos=1) message_queue.publish.assert_called_once_with(mqtt_message) @@ -271,18 +275,19 @@ async def test_send_rpc_response(): message_adapter = MagicMock(spec=GatewayMessageAdapter) sender.set_message_queue(message_queue) sender.set_message_adapter(message_adapter) - + # Create an RPC response rpc_response = MagicMock(spec=GatewayRPCResponse) - + # Mock the message adapter - mqtt_message = MqttPublishMessage(GATEWAY_RPC_TOPIC, b'{"device":"test_device","id":1,"data":{"result":"success"}}', qos=1) + mqtt_message = MqttPublishMessage(GATEWAY_RPC_TOPIC, b'{"device":"test_device","id":1,"data":{"result":"success"}}', + qos=1) mqtt_message.delivery_futures = [asyncio.Future()] message_adapter.build_rpc_response_payload.return_value = mqtt_message - + # Act result = await sender.send_rpc_response(rpc_response) - + # Assert assert result is not None assert len(result) == 1 @@ -299,18 +304,18 @@ async def test_send_claim_request(): message_adapter = MagicMock(spec=GatewayMessageAdapter) sender.set_message_queue(message_queue) sender.set_message_adapter(message_adapter) - + # Create a claim request claim_request = MagicMock(spec=GatewayClaimRequest) - + # Mock the message adapter mqtt_message = MqttPublishMessage(GATEWAY_CLAIM_TOPIC, b'{"device":"test_device","secretKey":"secret"}', qos=1) mqtt_message.delivery_futures = [asyncio.Future()] message_adapter.build_claim_request_payload.return_value = mqtt_message - + # Act result = await sender.send_claim_request(claim_request) - + # Assert assert result is not None assert len(result) == 1 @@ -320,4 +325,4 @@ async def test_send_claim_request(): if __name__ == '__main__': - pytest.main([__file__]) \ No newline at end of file + pytest.main([__file__]) diff --git a/tests/service/gateway/test_message_splitter.py b/tests/service/gateway/test_message_splitter.py index 6d0cad2..8200f64 100644 --- a/tests/service/gateway/test_message_splitter.py +++ b/tests/service/gateway/test_message_splitter.py @@ -31,7 +31,7 @@ def test_init_default(): # Setup & Act splitter = GatewayMessageSplitter() - + # Assert assert splitter.max_payload_size == 55000 - DEFAULT_FIELDS_SIZE assert splitter.max_datapoints == 0 @@ -40,7 +40,7 @@ def test_init_default(): def test_init_custom(): # Setup & Act splitter = GatewayMessageSplitter(max_payload_size=10000, max_datapoints=100) - + # Assert assert splitter.max_payload_size == 10000 - DEFAULT_FIELDS_SIZE assert splitter.max_datapoints == 100 @@ -49,7 +49,7 @@ def test_init_custom(): def test_init_invalid_values(): # Setup & Act splitter = GatewayMessageSplitter(max_payload_size=-1, max_datapoints=-1) - + # Assert assert splitter.max_payload_size == 55000 - DEFAULT_FIELDS_SIZE # Default value assert splitter.max_datapoints == 0 # Default value @@ -58,10 +58,10 @@ def test_init_invalid_values(): def test_max_payload_size_property(): # Setup splitter = GatewayMessageSplitter() 
- + # Act splitter.max_payload_size = 20000 - + # Assert assert splitter.max_payload_size == 20000 - DEFAULT_FIELDS_SIZE @@ -69,10 +69,10 @@ def test_max_payload_size_property(): def test_max_payload_size_property_invalid(): # Setup splitter = GatewayMessageSplitter() - + # Act splitter.max_payload_size = -1 - + # Assert assert splitter.max_payload_size == 55000 - DEFAULT_FIELDS_SIZE # Default value @@ -80,10 +80,10 @@ def test_max_payload_size_property_invalid(): def test_max_datapoints_property(): # Setup splitter = GatewayMessageSplitter() - + # Act splitter.max_datapoints = 200 - + # Assert assert splitter.max_datapoints == 200 @@ -91,10 +91,10 @@ def test_max_datapoints_property(): def test_max_datapoints_property_invalid(): # Setup splitter = GatewayMessageSplitter() - + # Act splitter.max_datapoints = -1 - + # Assert assert splitter.max_datapoints == 0 # Default value @@ -103,15 +103,14 @@ def test_max_datapoints_property_invalid(): async def test_split_timeseries_no_split_needed(): # Setup splitter = GatewayMessageSplitter(max_payload_size=10000, max_datapoints=100) - + # Create a message with timeseries that doesn't need splitting timeseries = TimeseriesEntry("temp", 22.5) message = GatewayUplinkMessageBuilder().set_device_name("test_device").add_timeseries(timeseries).build() - # Act result = splitter.split_timeseries([message]) - + # Assert assert len(result) == 1 assert result[0] == message @@ -121,31 +120,31 @@ async def test_split_timeseries_no_split_needed(): async def test_split_timeseries_by_size(): # Setup splitter = GatewayMessageSplitter(max_payload_size=100 + DEFAULT_FIELDS_SIZE, max_datapoints=100) - + # Create a message with timeseries that needs splitting by size message_builder = GatewayUplinkMessageBuilder().set_device_name("test_device") - + # Mock the size property to force splitting timeseries_entry = TimeseriesEntry("temp", 22.5) - entries = [TimeseriesEntry("temp", 22.5) for _ in range(int(60/timeseries_entry.size + 1))] - entries.extend([TimeseriesEntry("humidity", 45) for _ in range(int(60/timeseries_entry.size + 1))]) + entries = [TimeseriesEntry("temp", 22.5) for _ in range(int(60 / timeseries_entry.size + 1))] + entries.extend([TimeseriesEntry("humidity", 45) for _ in range(int(60 / timeseries_entry.size + 1))]) message_builder.add_timeseries(entries) message = message_builder.build() - + # Mock the event loop loop_mock = MagicMock() future = asyncio.Future() loop_mock.create_future.return_value = future - + with patch('asyncio.get_running_loop', return_value=loop_mock): # Act result = splitter.split_timeseries([message]) - + # Assert assert len(result) == 2 assert result[0].device_name == "test_device" assert result[1].device_name == "test_device" - + # Check that each result has only one of the timeseries entries assert result[0].size - DEFAULT_FIELDS_SIZE < splitter.max_payload_size assert result[1].size - DEFAULT_FIELDS_SIZE < splitter.max_payload_size @@ -155,7 +154,7 @@ async def test_split_timeseries_by_size(): async def test_split_timeseries_by_datapoints(): # Setup splitter = GatewayMessageSplitter(max_payload_size=10000, max_datapoints=1) - + # Create a message with timeseries that needs splitting by datapoints message_builder = GatewayUplinkMessageBuilder().set_device_name("test_device") entries = [ @@ -164,25 +163,25 @@ async def test_split_timeseries_by_datapoints(): ] message_builder.add_timeseries(entries) message = message_builder.build() - + # Mock the event loop loop_mock = MagicMock() future = asyncio.Future() 
loop_mock.create_future.return_value = future - + with patch('asyncio.get_running_loop', return_value=loop_mock): # Act result = splitter.split_timeseries([message]) - + # Assert assert len(result) == 2 assert result[0].device_name == "test_device" assert result[1].device_name == "test_device" - + # Check that each result has only one of the timeseries entries assert len(result[0].timeseries[0]) == 1 assert len(result[1].timeseries[0]) == 1 - + # The entries should be in separate messages key0 = result[0].timeseries[0][0].key key1 = result[1].timeseries[0][0].key @@ -195,7 +194,7 @@ async def test_split_timeseries_by_datapoints(): async def test_split_timeseries_multiple_messages(): # Setup splitter = GatewayMessageSplitter(max_payload_size=10000, max_datapoints=100) - + # Create multiple messages with timeseries message1 = GatewayUplinkMessageBuilder().set_device_name("device1").add_timeseries( TimeseriesEntry("temp", 22.5) @@ -204,17 +203,17 @@ async def test_split_timeseries_multiple_messages(): message2 = GatewayUplinkMessageBuilder().set_device_name("device2").add_timeseries( TimeseriesEntry("humidity", 45) ).build() - + # Act result = splitter.split_timeseries([message1, message2]) - + # Assert assert len(result) == 2 - + # Check that the messages are for different devices devices = {msg.device_name for msg in result} assert devices == {"device1", "device2"} - + # Check that each result has the correct timeseries for msg in result: if msg.device_name == "device1": @@ -248,7 +247,7 @@ async def test_split_timeseries_with_delivery_futures(): loop_mock.create_future.return_value = future with patch('asyncio.get_running_loop', return_value=loop_mock), \ - patch('tb_mqtt_client.common.async_utils.future_map.register') as mock_register: + patch('tb_mqtt_client.common.async_utils.future_map.register') as mock_register: # Act result = splitter.split_timeseries([message]) @@ -263,16 +262,16 @@ async def test_split_timeseries_with_delivery_futures(): async def test_split_attributes_no_split_needed(): # Setup splitter = GatewayMessageSplitter(max_payload_size=10000, max_datapoints=100) - + # Create a message with attributes that doesn't need splitting message = (GatewayUplinkMessageBuilder() .set_device_name("test_device") .add_attributes(AttributeEntry("key1", "value1")) .build()) - + # Act result = splitter.split_attributes([message]) - + # Assert assert len(result) == 1 assert result[0] == message @@ -308,32 +307,32 @@ async def test_split_attributes_by_size(): async def test_split_attributes_by_datapoints(): # Setup splitter = GatewayMessageSplitter(max_payload_size=10000, max_datapoints=1) - + # Create a message with attributes that needs splitting by datapoints message = (GatewayUplinkMessageBuilder() .set_device_name("test_device") .add_attributes([AttributeEntry("key1", "value1"), AttributeEntry("key2", "value2")]) .build()) - + # Mock the event loop loop_mock = MagicMock() future = asyncio.Future() loop_mock.create_future.return_value = future - + with patch('asyncio.get_running_loop', return_value=loop_mock): # Act result = splitter.split_attributes([message]) - + # Assert assert len(result) == 2 assert result[0].device_name == "test_device" assert result[1].device_name == "test_device" - + # Check that each result has only one of the attribute entries assert len(result[0].attributes) == 1 assert len(result[1].attributes) == 1 - + # The entries should be in separate messages keys0 = {attr.key for attr in result[0].attributes} keys1 = {attr.key for attr in result[1].attributes} @@ 
-345,7 +344,7 @@ async def test_split_attributes_by_datapoints(): async def test_split_attributes_multiple_messages(): # Setup splitter = GatewayMessageSplitter(max_payload_size=10000, max_datapoints=100) - + # Create multiple messages with attributes message1 = (GatewayUplinkMessageBuilder() .set_device_name("device1") @@ -356,17 +355,17 @@ async def test_split_attributes_multiple_messages(): .set_device_name("device2") .add_attributes(AttributeEntry("key2", "value2")) .build()) - + # Act result = splitter.split_attributes([message1, message2]) - + # Assert assert len(result) == 2 - + # Check that the messages are for different devices devices = {msg.device_name for msg in result} assert devices == {"device1", "device2"} - + # Check that each result has the correct attributes for msg in result: if msg.device_name == "device1": @@ -398,7 +397,7 @@ async def test_split_attributes_with_delivery_futures(): loop_mock.create_future.return_value = future with patch('asyncio.get_running_loop', return_value=loop_mock), \ - patch('tb_mqtt_client.common.async_utils.future_map.register') as mock_register: + patch('tb_mqtt_client.common.async_utils.future_map.register') as mock_register: # Act result = splitter.split_attributes([message]) diff --git a/tests/service/test_json_message_adapter.py b/tests/service/test_json_message_adapter.py index 5ae23f6..a0a0a5c 100644 --- a/tests/service/test_json_message_adapter.py +++ b/tests/service/test_json_message_adapter.py @@ -45,7 +45,6 @@ def dummy_provisioning_request(): return ProvisioningRequest("some-device", credentials, device_name="dev", gateway=False) - def build_msg(device="devX", with_attr=False, with_ts=False): builder = DeviceUplinkMessageBuilder().set_device_name(device) if with_attr: @@ -54,6 +53,7 @@ def build_msg(device="devX", with_attr=False, with_ts=False): builder.add_timeseries(TimeseriesEntry("t", 2, ts=1234567890)) return builder.build() + @pytest.fixture def adapter(): return JsonMessageAdapter() @@ -256,13 +256,14 @@ async def test_build_uplink_payloads_both(adapter: JsonMessageAdapter): initial_mqtt_message = MqttPublishMessage(topic="", payload=msg) with patch.object(adapter._splitter, "split_attributes", return_value=[msg]), \ - patch.object(adapter._splitter, "split_timeseries", return_value=[msg]): + patch.object(adapter._splitter, "split_timeseries", return_value=[msg]): result = adapter.build_uplink_messages([initial_mqtt_message]) assert len(result) == 2 topics = {r.topic for r in result} assert DEVICE_ATTRIBUTES_TOPIC in topics assert DEVICE_TELEMETRY_TOPIC in topics + def test_build_payload_without_device_name(adapter: JsonMessageAdapter): builder = DeviceUplinkMessageBuilder().add_attributes(AttributeEntry("x", 9)) msg = builder.build() @@ -291,7 +292,8 @@ def test_pack_timeseries_no_ts(monkeypatch): def test_build_uplink_payloads_error_handling(adapter: JsonMessageAdapter): - with patch("tb_mqtt_client.service.device.message_adapter.DeviceUplinkMessage.has_attributes", side_effect=Exception("boom")): + with patch("tb_mqtt_client.service.device.message_adapter.DeviceUplinkMessage.has_attributes", + side_effect=Exception("boom")): msg = build_msg(with_attr=True) initial_mqtt_message = MqttPublishMessage(topic="", payload=msg) with pytest.raises(Exception, match="boom"): @@ -320,5 +322,6 @@ def test_parse_provisioning_response_failure(adapter, dummy_provisioning_request assert args[1]["status"] == "FAILURE" assert "errorMsg" in args[1] + if __name__ == '__main__': - pytest.main([__file__, "--tb=short", "-v"]) \ No newline at 
end of file + pytest.main([__file__, "--tb=short", "-v"]) diff --git a/tests/service/test_message_service.py b/tests/service/test_message_service.py index 8a2a2e1..956cd0f 100644 --- a/tests/service/test_message_service.py +++ b/tests/service/test_message_service.py @@ -80,7 +80,9 @@ async def setup_message_service(): service._rate_limit_refill_task = MagicMock() service.__print_queue_statistics_task = MagicMock() - yield service, mqtt_manager, main_stop_event, device_rate_limiter, message_adapter, gateway_message_adapter, gateway_rate_limiter + yield (service, mqtt_manager, main_stop_event, device_rate_limiter, + message_adapter, gateway_message_adapter, gateway_rate_limiter) + @pytest_asyncio.fixture async def setup_retry_loop_service(setup_message_service): @@ -99,6 +101,7 @@ async def setup_retry_loop_service(setup_message_service): return service, mqtt_manager + @pytest.mark.asyncio async def test_publish_success(setup_message_service): service, mqtt_manager, _, _, _, _, _ = setup_message_service @@ -163,10 +166,10 @@ async def mock_task(): # Patch the methods that would create tasks to return our mock task with patch.object(MessageService, '_dispatch_initial_queue_loop', return_value=mock_task()), \ - patch.object(MessageService, '_dispatch_queue_loop', return_value=mock_task()), \ - patch.object(MessageService, '_rate_limit_refill_loop', return_value=mock_task()), \ - patch.object(MessageService, 'print_queues_statistics', return_value=mock_task()), \ - patch.object(MessageService, 'clear', new_callable=AsyncMock) as mock_clear: + patch.object(MessageService, '_dispatch_queue_loop', return_value=mock_task()), \ + patch.object(MessageService, '_rate_limit_refill_loop', return_value=mock_task()), \ + patch.object(MessageService, 'print_queues_statistics', return_value=mock_task()), \ + patch.object(MessageService, 'clear', new_callable=AsyncMock) as mock_clear: # Create the service service = MessageService( @@ -503,14 +506,14 @@ async def test_print_queue_statistics(setup_message_service): service, mqtt_manager, main_stop_event, _, _, _, _ = setup_message_service # Create a patched version of the print_queue_statistics method to avoid infinite loop - original_print_queue_statistics = service.print_queues_statistics + original_print_queue_statistics = service.print_queues_statistics # noqa async def patched_print_queue_statistics(): # Just run the body of the loop once - initial_queue_size = service._initial_queue.size() - service_queue_size = service._service_queue.size() - device_uplink_queue_size = service._device_uplink_messages_queue.size() - gateway_uplink_queue_size = service._gateway_uplink_messages_queue.size() + initial_queue_size = service._initial_queue.size() # noqa + service_queue_size = service._service_queue.size() # noqa + device_uplink_queue_size = service._device_uplink_messages_queue.size() # noqa + gateway_uplink_queue_size = service._gateway_uplink_messages_queue.size() # noqa # We don't need to log anything in the test # Replace the method with our patched version @@ -671,7 +674,8 @@ async def test_message_queue_worker_process_with_rate_limits_triggered(): # Set up the rate limit to be triggered triggered_rate_limit_entry = (10, 5) # (tokens, duration) expected_tokens = 5 - worker.check_rate_limits_for_message = AsyncMock(return_value=(triggered_rate_limit_entry, expected_tokens, message_rate_limit)) + worker.check_rate_limits_for_message = AsyncMock( + return_value=(triggered_rate_limit_entry, expected_tokens, message_rate_limit)) # Create a test message message 
= MqttPublishMessage("test/topic", b"test_payload") @@ -1054,5 +1058,6 @@ async def test_retry_loop_with_not_connected_and_empty_message(setup_retry_loop_ with patch("asyncio.sleep", new=AsyncMock()): await service._dispatch_retry_by_qos_queue_loop() + if __name__ == '__main__': pytest.main([__file__, '-v', '--tb=short']) diff --git a/tests/service/test_message_splitter.py b/tests/service/test_message_splitter.py index a82e462..ae94627 100644 --- a/tests/service/test_message_splitter.py +++ b/tests/service/test_message_splitter.py @@ -158,7 +158,6 @@ async def test_split_attributes_grouping(): for i in range(3, 6): builder2.add_attributes(AttributeEntry(f"key_{i}", i)) - messages = [builder1.build(), builder2.build()] result = dispatcher.splitter.split_attributes(messages) @@ -180,7 +179,7 @@ async def test_split_attributes_different_devices_not_grouped(): for i in range(3): builder1.add_attributes(AttributeEntry(f"key_{i}", i)) - builder2.add_attributes(AttributeEntry(f"key_{i+3}", i+3)) + builder2.add_attributes(AttributeEntry(f"key_{i + 3}", i + 3)) result = dispatcher.splitter.split_attributes([builder1.build(), builder2.build()]) @@ -227,5 +226,6 @@ async def test_split_timeseries_registers_futures_and_batches_correctly(mock_reg assert isinstance(shared_future, asyncio.Future) assert hasattr(shared_future, "uuid") + if __name__ == '__main__': pytest.main([__file__, '-v', '--tb=short']) diff --git a/tests/service/test_mqtt_manager.py b/tests/service/test_mqtt_manager.py index b052673..15d855d 100644 --- a/tests/service/test_mqtt_manager.py +++ b/tests/service/test_mqtt_manager.py @@ -48,7 +48,8 @@ async def setup_manager(): rate_limits_handler=rate_limits_handler, rpc_response_handler=rpc_response_handler ) - return manager, stop_event, message_adapter, on_connect, on_disconnect, on_publish_result, rate_limits_handler, rpc_response_handler + return (manager, stop_event, message_adapter, on_connect, on_disconnect, + on_publish_result, rate_limits_handler, rpc_response_handler) @pytest.mark.asyncio @@ -124,7 +125,8 @@ async def test_publish_fails_without_rate_limits(setup_manager): manager._MQTTManager__rate_limits_retrieved = False manager._MQTTManager__is_waiting_for_rate_limits_publish = False with pytest.raises(RuntimeError, match="Cannot publish before rate limits are retrieved."): - await manager.publish("topic", b"payload") + mqtt_message = MqttPublishMessage("topic", b"payload") + await manager.publish(mqtt_message, force=False) @pytest.mark.asyncio @@ -139,7 +141,7 @@ async def test_publish_force_bypasses_limits(setup_manager): manager._client._persistent_storage = MagicMock() mqtt_publish_message = MqttPublishMessage("topic", b"payload", qos=1) - await manager.publish(mqtt_publish_message, qos=1, force=True) + await manager.publish(mqtt_publish_message, force=True) assert manager._client._connection.publish.call_count == 1 @@ -172,7 +174,8 @@ async def test_handle_puback_reason_code(setup_manager): manager, *_ = setup_manager fut = asyncio.Future() fut.uuid = "test-future" - manager._pending_publishes[123] = (fut, MqttPublishMessage("topic", b"payload", delivery_futures=[fut]), monotonic()) + manager._pending_publishes[123] = (fut, MqttPublishMessage("topic", b"payload", delivery_futures=[fut]), + monotonic()) manager._handle_puback_reason_code(123, 0, {}) assert fut.done() assert fut.result().message_id == 123 @@ -264,7 +267,7 @@ async def test_publish_qos_zero_sets_result_immediately(setup_manager): manager._client._persistent_storage = MagicMock() future = asyncio.Future() - 
await manager.publish(MqttPublishMessage("topic", b"payload", delivery_futures=future), qos=0, force=True) + await manager.publish(MqttPublishMessage("topic", b"payload", qos=0, delivery_futures=future), force=True) await asyncio.sleep(0.05) # Allow async tasks to complete assert future.done() assert future.result() == PublishResult("topic", 0, -1, 7, 0) @@ -315,9 +318,8 @@ async def test_connect_loop_retry_and_success(setup_manager): manager._client.connect = AsyncMock(side_effect=[Exception("fail1"), AsyncMock()]) with patch.object(type(manager._client), "is_connected", new_callable=PropertyMock) as mock_connected, \ - patch("asyncio.sleep", new_callable=AsyncMock), \ - patch.object(manager, "_connect_params", new=("host", 1883, None, None, False, 60, None)): - + patch("asyncio.sleep", new_callable=AsyncMock), \ + patch.object(manager, "_connect_params", new=("host", 1883, None, None, False, 60, None)): mock_connected.side_effect = [False, False, True] await asyncio.wait_for(manager._connect_loop(), timeout=1) From 6b18fdd625ca1159f94a16bde46baf0b4cfa9a37 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Tue, 5 Aug 2025 15:51:05 +0300 Subject: [PATCH 67/74] Updated README and dev notes --- README.md | 405 ++++++++++----------------------------------------- dev_notes.md | 325 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 399 insertions(+), 331 deletions(-) create mode 100644 dev_notes.md diff --git a/README.md b/README.md index bdb0878..1551f93 100644 --- a/README.md +++ b/README.md @@ -1,345 +1,88 @@ -# ThingsBoard MQTT and HTTP client Python SDK +# ThingsBoard Python Client SDK 2.0 + [![Join the chat at https://gitter.im/thingsboard/chat](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/thingsboard/chat?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) -ThingsBoard is an open-source IoT platform for data collection, processing, visualization, and device management. -This project is a Python library that provides convenient client SDK for both Device and [Gateway](https://thingsboard.io/docs/reference/gateway-mqtt-api/) APIs. - -SDK supports: -- Unencrypted and encrypted (TLS v1.2) connection -- QoS 0 and 1 (MQTT only) -- Automatic reconnect -- All [Device MQTT](https://thingsboard.io/docs/reference/mqtt-api/) APIs provided by ThingsBoard -- All [Gateway MQTT](https://thingsboard.io/docs/reference/gateway-mqtt-api/) APIs provided by ThingsBoard -- Most [Device HTTP](https://thingsboard.io/docs/reference/http-api/) APIs provided by ThingsBoard -- Device Claiming -- Firmware updates - -The [Device MQTT](https://thingsboard.io/docs/reference/mqtt-api/) API and the [Gateway MQTT](https://thingsboard.io/docs/reference/gateway-mqtt-api/) API are base on the Paho MQTT library. The [Device HTTP](https://thingsboard.io/docs/reference/http-api/) API is based on the Requests library. 
- -## Installation - -To install using pip: - -```bash -pip3 install tb-mqtt-client -``` - -## Getting Started - -Client initialization and telemetry publishing -### MQTT -```python -from tb_device_mqtt import TBDeviceMqttClient, TBPublishInfo - - -telemetry = {"temperature": 41.9, "enabled": False, "currentFirmwareVersion": "v1.2.2"} - -# Initialize ThingsBoard client -client = TBDeviceMqttClient("127.0.0.1", username="A1_TEST_TOKEN") -# Connect to ThingsBoard -client.connect() -# Sending telemetry without checking the delivery status -client.send_telemetry(telemetry) -# Sending telemetry and checking the delivery status (QoS = 1 by default) -result = client.send_telemetry(telemetry) -# get is a blocking call that awaits delivery status -success = result.get() == TBPublishInfo.TB_ERR_SUCCESS -# Disconnect from ThingsBoard -client.disconnect() - -``` - -### MQTT using TLS - -TLS connection to localhost. See https://thingsboard.io/docs/user-guide/mqtt-over-ssl/ for more information about client and ThingsBoard configuration. - -```python -from tb_device_mqtt import TBDeviceMqttClient -import socket - - -client = TBDeviceMqttClient(socket.gethostname()) -client.connect(tls=True, - ca_certs="mqttserver.pub.pem", - cert_file="mqttclient.nopass.pem") -client.disconnect() - -``` - -### HTTP - -````python -from tb_device_http import TBHTTPDevice - - -client = TBHTTPDevice('https://thingsboard.example.com', 'secret-token') -client.connect() -client.send_telemetry({'temperature': 41.9}) - -```` - -## Using Device APIs - -**TBDeviceMqttClient** provides access to Device MQTT APIs of ThingsBoard platform. It allows to publish telemetry and attribute updates, subscribe to attribute changes, send and receive RPC commands, etc. Use **TBHTTPClient** for the Device HTTP API. -#### Subscription to attributes -You can subscribe to attribute updates from the server. The following example demonstrates how to subscribe to attribute updates from the server. -##### MQTT -```python -import time -from tb_device_mqtt import TBDeviceMqttClient - - -def on_attributes_change(client, result, exception): - if exception is not None: - print("Exception: " + str(exception)) - else: - print(result) - - -client = TBDeviceMqttClient("127.0.0.1", username="A1_TEST_TOKEN") -client.connect() -client.subscribe_to_attribute("uploadFrequency", on_attributes_change) -client.subscribe_to_all_attributes(on_attributes_change) -while True: - time.sleep(1) - -``` - -##### HTTP -Note: The HTTP API only allows a subscription to updates for all attribute. -```python -from tb_device_http import TBHTTPClient - - -client = TBHTTPClient('https://thingsboard.example.com', 'secret-token') - -def callback(data): - print(data) - # ... - -# Subscribe -client.subscribe('attributes', callback) -# Unsubscribe -client.unsubscribe('attributes') - -``` - -#### Telemetry pack sending -You can send multiple telemetry messages at once. The following example demonstrates how to send multiple telemetry messages at once. -##### MQTT -```python - -from tb_device_mqtt import TBDeviceMqttClient, TBPublishInfo -import time - - -telemetry_with_ts = {"ts": int(round(time.time() * 1000)), "values": {"temperature": 42.1, "humidity": 70}} -client = TBDeviceMqttClient("127.0.0.1", username="A1_TEST_TOKEN") -# we set maximum amount of messages sent to send them at the same time. 
it may stress memory but increases performance -client.max_inflight_messages_set(100) -client.connect() -results = [] -result = True -for i in range(0, 100): - results.append(client.send_telemetry(telemetry_with_ts)) -for tmp_result in results: - result &= tmp_result.get() == TBPublishInfo.TB_ERR_SUCCESS -print("Result " + str(result)) -client.disconnect() - -``` -##### HTTP -Unsupported, the HTTP API does not allow the packing of values. - -#### Request attributes from server -You can request attributes from the server. The following example demonstrates how to request attributes from the server. +ThingsBoard is an open-source IoT platform for data collection, processing, visualization, and device management. +This is the **official Python client SDK** for connecting **Devices** and **Gateways** to ThingsBoard using **MQTT**. -##### MQTT -```python +--- -import time -from tb_device_mqtt import TBDeviceMqttClient +## ✨ Features +- **MQTT 3.1.1 / 5.0** with QoS 0 and 1 +- **Async I/O** built on [`gmqtt`](https://github.com/wialon/gmqtt) for high performance +- **Encrypted connections** (TLS v1.2+) +- **Automatic reconnect** with session persistence +- **All Device MQTT APIs** ([docs](https://thingsboard.io/docs/reference/mqtt-api/)) +- **All Gateway MQTT APIs** ([docs](https://thingsboard.io/docs/reference/gateway-mqtt-api/)) +- **Batching** and **payload splitting** for large telemetry/attribute messages +- **Rate limit awareness** (ThingsBoard-style multi-window limits) +- **Delivery tracking** — know exactly when your message reaches the server +- **Device claiming** +- **Firmware updates** -def on_attributes_change(client,result, exception): - if exception is not None: - print("Exception: " + str(exception)) - else: - print(result) +--- - -client = TBDeviceMqttClient("127.0.0.1", username="A1_TEST_TOKEN") -client.connect() -client.request_attributes(["configuration","targetFirmwareVersion"], callback=on_attributes_change) -while True: - time.sleep(1) - -``` - -##### HTTP -```python -from tb_device_http import TBHTTPClient - - -client = TBHTTPClient('https://thingsboard.example.com', 'secret-token') - -client_keys = ['attr1', 'attr2'] -shared_keys = ['shared1', 'shared2'] -data = client.request_attributes(client_keys=client_keys, shared_keys=shared_keys) - -``` - -#### Respond to server RPC call -You can respond to RPC calls from the server. The following example demonstrates how to respond to RPC calls from the server. -Please install psutil using 'pip install psutil' command before running the example. - -##### MQTT -```python - -try: - import psutil -except ImportError: - print("Please install psutil using 'pip install psutil' command") - exit(1) -import time -import logging -from tb_device_mqtt import TBDeviceMqttClient - -# dependently of request method we send different data back -def on_server_side_rpc_request(client, request_id, request_body): - print(request_id, request_body) - if request_body["method"] == "getCPULoad": - client.send_rpc_reply(request_id, {"CPU percent": psutil.cpu_percent()}) - elif request_body["method"] == "getMemoryUsage": - client.send_rpc_reply(request_id, {"Memory": psutil.virtual_memory().percent}) - -client = TBDeviceMqttClient("127.0.0.1", username="A1_TEST_TOKEN") -client.set_server_side_rpc_request_handler(on_server_side_rpc_request) -client.connect() -while True: - time.sleep(1) - -``` - -##### HTTP -You can use HTTP API client in case you want to use HTTP API instead of MQTT API. 
-```python -from tb_device_http import TBHTTPClient - - -client = TBHTTPClient('https://thingsboard.example.com', 'secret-token') - -def callback(data): - rpc_id = data['id'] - # ... do something with data['params'] and data['method']... - response_params = {'result': 1} - client.send_rpc(name='rpc_response', rpc_id=rpc_id, params=response_params) - -# Subscribe -client.subscribe('rpc', callback) -# Unsubscribe -client.unsubscribe('rpc') - -``` - -## Using Gateway APIs - -**TBGatewayMqttClient** extends **TBDeviceMqttClient**, thus has access to all it's APIs as a regular device. -Besides, gateway is able to represent multiple devices connected to it. For example, sending telemetry or attributes on behalf of other, constrained, device. See more info about the gateway here: -#### Telemetry and attributes sending -```python -import time -from tb_gateway_mqtt import TBGatewayMqttClient - - -gateway = TBGatewayMqttClient("127.0.0.1", username="TEST_GATEWAY_TOKEN") -gateway.connect() -gateway.gw_connect_device("Test Device A1") - -gateway.gw_send_telemetry("Test Device A1", {"ts": int(round(time.time() * 1000)), "values": {"temperature": 42.2}}) -gateway.gw_send_attributes("Test Device A1", {"firmwareVersion": "2.3.1"}) - -gateway.gw_disconnect_device("Test Device A1") -gateway.disconnect() - -``` -#### Request attributes - -You can request attributes from the server. The following example demonstrates how to request attributes from the server. - -```python -import time -from tb_gateway_mqtt import TBGatewayMqttClient - - -def callback(result, exception): - if exception is not None: - print("Exception: " + str(exception)) - else: - print(result) - - -gateway = TBGatewayMqttClient("127.0.0.1", username="TEST_GATEWAY_TOKEN") -gateway.connect() -gateway.gw_request_shared_attributes("Test Device A1", ["temperature"], callback) - -while True: - time.sleep(1) - -``` -#### Respond to RPC - -You can respond to RPC calls from the server. The following example demonstrates how to respond to RPC calls from the server. -Please install psutil using 'pip install psutil' command before running the example. 
- -```python -import time - -from tb_gateway_mqtt import TBGatewayMqttClient -try: - import psutil -except ImportError: - print("Please install psutil using 'pip install psutil' command") - exit(1) - - -def rpc_request_response(client, request_id, request_body): - # request body contains id, method and other parameters - print(request_body) - method = request_body["data"]["method"] - device = request_body["device"] - req_id = request_body["data"]["id"] - # dependently of request method we send different data back - if method == 'getCPULoad': - gateway.gw_send_rpc_reply(device, req_id, {"CPU load": psutil.cpu_percent()}) - elif method == 'getMemoryLoad': - gateway.gw_send_rpc_reply(device, req_id, {"Memory": psutil.virtual_memory().percent}) - else: - print('Unknown method: ' + method) - - -gateway = TBGatewayMqttClient("127.0.0.1", username="TEST_GATEWAY_TOKEN") -gateway.connect() -# now rpc_request_response will process rpc requests from servers -gateway.gw_set_server_side_rpc_request_handler(rpc_request_response) -# without device connection it is impossible to get any messages -gateway.gw_connect_device("Test Device A1") -while True: - time.sleep(1) +## 📦 Installation +```bash +pip install tb-mqtt-client ``` -## Other Examples - -There are more examples for both [device](https://github.com/thingsboard/thingsboard-python-client-sdk/tree/master/examples/device) and [gateway](https://github.com/thingsboard/thingsboard-python-client-sdk/tree/master/examples/gateway) in corresponding [folders](https://github.com/thingsboard/thingsboard-python-client-sdk/tree/master/examples). - -## Support - - - [Community chat](https://gitter.im/thingsboard/chat) - - [Q&A forum](https://groups.google.com/forum/#!forum/thingsboard) - - [Stackoverflow](http://stackoverflow.com/questions/tagged/thingsboard) -## Licenses +## 🚀 Getting Started + +This SDK is **async-based**. +You can run it with `asyncio`. + +We provide **ready-to-run examples** for both device and gateway clients. 
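A minimal quick-start sketch follows. It mirrors the `DeviceConfig`/`DeviceClient` pattern used in the device examples; the `DeviceClient` import path and the `send_timeseries`/`disconnect` calls are assumptions inferred from the package layout and example names, so treat the maintained examples listed below as the authoritative API reference.

```python
import asyncio

from tb_mqtt_client.common.config_loader import DeviceConfig
# Assumed import path, mirroring the gateway client location in this package.
from tb_mqtt_client.service.device.client import DeviceClient


async def main():
    # Configure the connection to your ThingsBoard instance.
    config = DeviceConfig()
    config.host = "localhost"
    config.access_token = "YOUR_ACCESS_TOKEN"

    client = DeviceClient(config)
    await client.connect()

    # Assumed telemetry call; see examples/device/send_timeseries.py
    # for the exact, up-to-date method name and payload shape.
    await client.send_timeseries({"temperature": 25.3})

    # Assumed shutdown call; see examples/device/operational_example.py.
    await client.disconnect()


if __name__ == "__main__":
    asyncio.run(main())
```

The ready-to-run examples below cover these flows end to end, including delivery tracking, RPC handling, and provisioning.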
+ +--- + +### **Device Examples** +[`examples/device/`](examples/device): +- [Claim to customer](examples/device/claim_device.py) +- [Provision and connect with provisioned credentials](examples/device/client_provisioning.py) +- [Retrieve firmware](examples/device/firmware_update.py) +- [Handle attribute updates from server](examples/device/handle_attribute_updates.py) +- [Handle RPC requests from server](examples/device/handle_rpc_requests.py) +- [Load testing](examples/device/load.py) +- [Operational example](examples/device/operational_example.py) +- [Request attributes from server](examples/device/request_attributes.py) +- [Send attributes to server](examples/device/send_attributes.py) +- [Send client-side RPC to server and retrieve the response](examples/device/send_client_side_rpc.py) +- [Send timeseries to server](examples/device/send_timeseries.py) +- [Connect to server using MQTT over SSL and send encrypted timeseries](examples/device/tls_connect.py) + +--- + +### **Gateway Examples** +[`examples/gateway/`](examples/gateway): +- [Claim device](examples/gateway/claim_device.py) +- [Connect and disconnect device](examples/gateway/connect_and_disconnect_device.py) +- [Handle attribute updates for connected devices](examples/gateway/handle_attribute_updates.py) +- [Handle RPC requests for connected devices](examples/gateway/handle_rpc_requests.py) +- [Load testing](examples/gateway/load.py) +- [Operational example](examples/gateway/operational_example.py) +- [Request attributes for connected devices](examples/gateway/request_attributes.py) +- [Send attributes for connected devices to server](examples/gateway/send_attributes.py) +- [Send timeseries for connected devices to server](examples/gateway/send_timeseries.py) +- [Connect gateway client to server using MQTT over SSL and send encrypted timeseries](examples/gateway/tls_connect.py) + +--- + +## 📚 Documentation +- [ThingsBoard Device MQTT API](https://thingsboard.io/docs/reference/mqtt-api/) +- [ThingsBoard Gateway MQTT API](https://thingsboard.io/docs/reference/gateway-mqtt-api/) + +--- + +## 💬 Support +- [Community chat](https://gitter.im/thingsboard/chat) +- [Q&A forum](https://groups.google.com/forum/#!forum/thingsboard) +- [Stack Overflow](http://stackoverflow.com/questions/tagged/thingsboard) + +## 📄 Licenses This project is released under [Apache 2.0 License](./LICENSE). diff --git a/dev_notes.md b/dev_notes.md new file mode 100644 index 0000000..f03d85e --- /dev/null +++ b/dev_notes.md @@ -0,0 +1,325 @@ +# Developer Notes for ThingsBoard Python Client SDK + +## Overview +This document provides explanations of tricky moments and major points in the ThingsBoard Python Client SDK implementation. It focuses on key areas that require special attention for developers working with or maintaining the codebase. + +## GMQTT Patches + +The SDK uses the GMQTT library for MQTT communication but implements several patches to enhance its functionality: + +### Key Components: +- `tb_mqtt_client/common/gmqtt_patch.py`: Contains the `PatchUtils` class that applies various patches to the GMQTT library. + +### Major Patches: + +1. **Enhanced Disconnect Handling**: + - The original GMQTT disconnect handling is patched to provide better error reporting and recovery. + - `patch_mqtt_handler_disconnect()` modifies how disconnect packets are processed, adding proper reason code extraction. + +2. **Connection Acknowledgment Handling**: + - `patch_handle_connack()` enhances the connection acknowledgment process to better handle MQTT 5.0 properties. 
+ +3. **Connection Lost Recovery**: + - `patch_gmqtt_protocol_connection_lost()` improves how the client handles unexpected disconnections. + - This is critical for maintaining robust connections in unstable network environments. + +4. **PUBACK Handling with Reason Codes**: + - `patch_puback_handling()` adds support for MQTT 5.0 reason codes in publish acknowledgments. + - This allows the client to understand why a message was rejected by the broker. + +5. **Message Storage for Retries**: + - `patch_storage()` modifies how messages are stored to support the retry mechanism. + +## Message Queues + +The SDK implements a sophisticated message queuing system to handle message processing, rate limiting, and retries: + +### Key Components: +- `tb_mqtt_client/common/queue.py`: Implements `AsyncDeque`, a thread-safe asynchronous queue. +- `tb_mqtt_client/service/message_service.py`: Contains the `MessageService` class that manages multiple queues. + +### Queue Types: +1. **Initial Queue**: + - Entry point for all messages before they're routed to specialized queues. + - Implemented in `MessageService._dispatch_initial_queue_loop()`. + +2. **Service Queue**: + - Handles general service messages. + +3. **Device Uplink Messages Queue**: + - Specifically for device telemetry and attribute updates. + +4. **Gateway Uplink Messages Queue**: + - For gateway-specific messages that may contain data for multiple devices. + +5. **Retry by QoS Queue**: + - Special queue for handling QoS 1 messages that need to be retried. + - Implemented in `MessageService._dispatch_retry_by_qos_queue_loop()`. + +### Queue Processing: +- Each queue has its own asynchronous processing loop. +- Messages are processed according to rate limits and QoS requirements. +- The system uses `asyncio` for non-blocking queue operations. + +## QoS Messages Reprocessing + +The SDK implements a robust mechanism for handling QoS 1 messages, ensuring they are delivered at least once: + +### Key Components: +- `tb_mqtt_client/service/mqtt_manager.py`: Contains the `patch_client_for_retry_logic()` method. +- `tb_mqtt_client/common/gmqtt_patch.py`: Implements the retry mechanism in `_retry_loop()`. + +### Retry Process: +1. **Message Tracking**: + - Messages with QoS 1 are tracked until acknowledgment is received. + - If no acknowledgment is received within the timeout, the message is requeued. + +2. **Retry Loop**: + - `PatchUtils._retry_loop()` periodically checks for unacknowledged messages. + - Messages that haven't been acknowledged are put back into the retry queue. + +3. **Integration with Message Service**: + - `MQTTManager.patch_client_for_retry_logic()` connects the GMQTT client with the message service's retry queue. + - `MessageService.put_retry_message()` handles requeuing of messages that need to be retried. + +4. **Publish Monitoring**: + - `MQTTManager._monitor_ack_timeouts()` and `check_pending_publishes()` track publish operations and handle timeouts. + +## Gateway Message Dispatcher + +The gateway part of the SDK includes a message dispatcher for handling various types of events: + +### Key Components: +- `tb_mqtt_client/service/gateway/direct_event_dispatcher.py`: Implements the `DirectEventDispatcher` class. + +### Dispatcher Features: +1. **Event Registration**: + - Handlers can be registered for specific event types using `register()`. + - Multiple handlers can be registered for the same event type. + +2. **Event Dispatching**: + - `dispatch()` method routes events to the appropriate handlers. 
+ - Supports both synchronous and asynchronous handlers. + +3. **Device-Specific Handling**: + - Can route events directly to a specific device session if provided. + - Otherwise, broadcasts to all registered handlers for the event type. + +4. **Error Handling**: + - Catches and logs exceptions in handlers to prevent one handler failure from affecting others. + +## Message Splitting + +The SDK includes mechanisms to handle large messages by splitting them into smaller chunks: + +### Key Components: +- `tb_mqtt_client/service/base_message_splitter.py`: Defines the base interface. +- `tb_mqtt_client/service/gateway/message_splitter.py`: Implements gateway-specific splitting logic. + +### Splitting Strategy: +1. **Size-Based Splitting**: + - Messages are split if they exceed the maximum payload size (default 55KB). + - Each split message maintains the original message's context. + +2. **Datapoint-Based Splitting**: + - Messages can also be split based on the number of datapoints. + - This helps prevent overwhelming the server with too many datapoints in a single message. + +3. **Message grouping**: + - Uplink data messages may be grouped to efficiently use the available payload size. + - This reduces the number of messages sent and optimizes network usage. + +4. **Future Chaining**: + - When a message is split, the futures from the original message are chained to the split messages. + - This ensures that the completion status is properly propagated back to the caller. + +## Rate Limiting + +The SDK implements rate limiting to prevent overwhelming the ThingsBoard server: + +### Key Components: +- `tb_mqtt_client/common/rate_limit/rate_limit.py`: Defines rate limit structures. +- `tb_mqtt_client/common/rate_limit/rate_limiter.py`: Implements the rate limiting logic. + +### Rate Limiting Features: +1. **Server-Provided Limits**: + - The SDK requests rate limits from the server upon connection. + - These limits are then applied to message publishing. + +2. **Separate Device and Gateway Limits**: + - Different rate limits can be applied to device and gateway messages. + +3. **Backpressure Control**: + - When rate limits are exceeded, the system applies backpressure to slow down message production. + + + +## Futures in Device and Gateway Uplink Messages + +### Overview of Futures in the SDK + +Futures are a critical component in the ThingsBoard Python Client SDK for handling asynchronous operations, particularly for message publishing. They provide a way to track the status of message delivery and propagate results back to the caller. + +### Futures in Uplink Messages + +#### Device and Gateway Message Structure + +Both `DeviceUplinkMessage` and `GatewayUplinkMessage` classes contain a `delivery_futures` field, which is a list of `asyncio.Future` objects. These futures are resolved when the message is successfully published to the ThingsBoard server or when an error occurs. 
+ +```python +# From DeviceUplinkMessage +delivery_futures: List[Optional[asyncio.Future[PublishResult]]] +``` + +When a client application sends data to ThingsBoard, it can attach a future to the message to be notified when the message is delivered: + +```python +# Example usage +future = asyncio.get_running_loop().create_future() +message_builder.add_delivery_futures(future) +message = message_builder.build() +# Later, the application can await future to know when the message is delivered +result = await future # PublishResult object +``` + +### Message Splitting and Future Chaining + +One of the trickiest aspects of the SDK is how it handles futures when messages need to be split due to size limitations or datapoint count restrictions. + +#### The Splitting Process + +When a message exceeds the maximum payload size (default 55KB) or the maximum number of datapoints, it needs to be split into smaller chunks. This presents a challenge: how to maintain the relationship between the original message's future and the futures of the split messages? + +#### Future Chaining Mechanism + +The SDK uses a sophisticated future chaining mechanism implemented in the `FutureMap` class: + +```python +# From async_utils.py +class FutureMap: + def __init__(self): + self._child_to_parents: Dict[asyncio.Future, Set[asyncio.Future]] = {} + self._parent_to_remaining: Dict[asyncio.Future, Set[asyncio.Future]] = {} +``` + +This class maintains two mappings: +1. `_child_to_parents`: Maps each child future (from split messages) to its parent futures +2. `_parent_to_remaining`: Maps each parent future to the set of child futures that still need to be resolved + +##### Registration Process + +When a message is split, new futures are created for each split message, and these are registered with the original message's future: + +```python +# In message_splitter.py +shared_future = asyncio.get_running_loop().create_future() +shared_future.uuid = uuid4() +builder.add_delivery_futures(shared_future) + +built = builder.build() +result.append(built) +for parent in parent_futures: + future_map.register(parent, [shared_future]) +``` + +##### Resolution Process + +When a child future is resolved (when a split message is delivered), the `child_resolved` method is called: + +```python +# In FutureMap.child_resolved +def child_resolved(self, child: asyncio.Future): + parents = self._child_to_parents.pop(child, set()) + for parent in parents: + remaining = self._parent_to_remaining.get(parent) + if remaining is not None: + remaining.discard(child) + if not remaining and not parent.done(): + # All children resolved, resolve the parent + all_children = list(remaining) + [child] + results = [] + for f in all_children: + if f.done() and not f.cancelled(): + result = f.result() + if isinstance(result, PublishResult): + results.append(result) + + if results: + parent.set_result(PublishResult.merge(results)) + else: + parent.set_result(None) +``` + +This ensures that the parent future is only resolved when all of its child futures are resolved, and the results are merged. + +#### Differences Between Device and Gateway Splitters + +While the core mechanism is the same, there are some differences in how device and gateway message splitters handle futures: + +1. **Device Message Splitter**: + - Focuses on splitting individual device messages + - Handles timeseries and attributes separately + - Creates a shared future for each batch of data + +2. 
**Gateway Message Splitter**: + - Handles messages from multiple devices + - Groups data by device name and profile + - Maintains the device context across split messages + +### MQTT Manager's Role + +The `MQTTManager` class plays a crucial role in the future resolution process: + +```python +# In mqtt_manager.py +@staticmethod +async def _add_future_chain_processing(mqtt_future, message: MqttPublishMessage): + def resolve_attached(publish_future: asyncio.Future): + try: + try: + publish_result = publish_future.result() + except asyncio.CancelledError: + publish_result = PublishResult(message.topic, message.qos, -1, len(message.payload), -1) + except Exception as exc: + publish_result = PublishResult(message.topic, message.qos, -1, len(message.payload), -1) + + for i, f in enumerate(message.delivery_futures or []): + if f is not None and not f.done(): + f.set_result(publish_result) + future_map.child_resolved(f) + except Exception as e: + for i, f in enumerate(message.delivery_futures or []): + if f is not None and not f.done(): + f.set_exception(e) +``` + +This method: +1. Attaches a callback to the MQTT publish future +2. When the publish operation completes, it resolves all the delivery futures attached to the message +3. Calls `future_map.child_resolved()` to propagate the resolution up the chain + +### Practical Implications + +This future chaining mechanism has several important implications: + +1. **Transparent Splitting**: Applications don't need to be aware that their messages are being split. They attach a future to the original message and receive a notification when all parts are delivered. + +2. **Error Handling**: If any part of a split message fails to deliver, the error is propagated to the parent future. + +3. **Performance Optimization**: The SDK can optimize message delivery by splitting large messages without breaking the promise to notify the application when delivery is complete. + +4. **Resource Management**: Futures are properly managed and resolved, preventing memory leaks even when messages are split into many parts. + +## Conclusion + +The ThingsBoard Python Client SDK implements several sophisticated mechanisms to ensure reliable message delivery, efficient processing, and robust error handling. +The future mechanism in the ThingsBoard Python Client SDK provides a robust way to handle asynchronous message delivery with transparent message splitting. +Understanding this mechanism is crucial for developers working with the SDK, especially when dealing with large messages or high-throughput scenarios. + +Key areas to be aware of when making changes: +1. The GMQTT patches that enhance the underlying MQTT library +2. The message queue system that manages message flow +3. The QoS message reprocessing mechanism that ensures reliable delivery +4. The gateway message dispatcher that routes events to appropriate handlers +5. The message splitting logic that handles large payloads +6. The rate limiting system that prevents overwhelming the server \ No newline at end of file From 2f34578867257f44bdbd6bf660a14b8921b7a3de Mon Sep 17 00:00:00 2001 From: imbeacon Date: Wed, 6 Aug 2025 09:40:12 +0300 Subject: [PATCH 68/74] Updated license --- LICENSE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LICENSE b/LICENSE index 6cbfe64..373c076 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright 2024. 
ThingsBoard + Copyright 2025 ThingsBoard Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. From 4e372592405c23eb7945e05fb84e65a853573769 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Wed, 6 Aug 2025 09:41:21 +0300 Subject: [PATCH 69/74] Fixed reconnection on session taken over and missing device at init --- examples/device/load.py | 1 - tb_mqtt_client/common/exceptions.py | 9 ++++ tb_mqtt_client/common/gmqtt_patch.py | 1 - tb_mqtt_client/service/message_service.py | 9 +++- tb_mqtt_client/service/mqtt_manager.py | 57 ++++++++++------------- tests/service/test_mqtt_manager.py | 22 --------- 6 files changed, 41 insertions(+), 58 deletions(-) diff --git a/examples/device/load.py b/examples/device/load.py index d846859..7e5c20a 100644 --- a/examples/device/load.py +++ b/examples/device/load.py @@ -2,7 +2,6 @@ import logging import signal import time -from datetime import UTC, datetime from random import randint from tb_mqtt_client.common.config_loader import DeviceConfig diff --git a/tb_mqtt_client/common/exceptions.py b/tb_mqtt_client/common/exceptions.py index 4fc2237..d91193a 100644 --- a/tb_mqtt_client/common/exceptions.py +++ b/tb_mqtt_client/common/exceptions.py @@ -65,4 +65,13 @@ def _asyncio_handler(loop_, context: dict): loop.set_exception_handler(_asyncio_handler) +class BackpressureException(Exception): + """ + Exception raised when the client is under backpressure. + This should be used to signal that the client cannot process more messages at the moment. + """ + def __init__(self, message: str = "Client is under backpressure. Please retry later."): + super().__init__(message) + + exception_handler = ExceptionHandler() diff --git a/tb_mqtt_client/common/gmqtt_patch.py b/tb_mqtt_client/common/gmqtt_patch.py index 513df8c..539df5a 100644 --- a/tb_mqtt_client/common/gmqtt_patch.py +++ b/tb_mqtt_client/common/gmqtt_patch.py @@ -198,7 +198,6 @@ def patched_handle_disconnect_packet(self, cmd, packet): # Set a flag on the connection object to indicate that on_disconnect has been called self._connection._on_disconnect_called = True - original_handle_disconnect(self, cmd, packet) # Apply the patch MqttPackageHandler._handle_disconnect_packet = patched_handle_disconnect_packet diff --git a/tb_mqtt_client/service/message_service.py b/tb_mqtt_client/service/message_service.py index 0e8db72..20bb41b 100644 --- a/tb_mqtt_client/service/message_service.py +++ b/tb_mqtt_client/service/message_service.py @@ -16,6 +16,7 @@ from contextlib import suppress from typing import List, Optional, Tuple +from tb_mqtt_client.common.exceptions import BackpressureException from tb_mqtt_client.common.logging_utils import get_logger, TRACE_LEVEL from tb_mqtt_client.common.mqtt_message import MqttPublishMessage from tb_mqtt_client.common.publish_result import PublishResult @@ -167,7 +168,7 @@ async def _dispatch_initial_queue_loop(self): async def _dispatch_queue_loop(self, queue: AsyncDeque, worker: 'MessageQueueWorker'): """Loop to process messages from the service queue.""" while not self._main_stop_event.is_set() and self._active.is_set(): - message = None + message: Optional[MqttPublishMessage] = None try: if not self._mqtt_manager.is_connected(): await asyncio.sleep(self._QUEUE_COOLDOWN) @@ -191,6 +192,12 @@ async def _dispatch_queue_loop(self, queue: AsyncDeque, worker: 'MessageQueueWor except asyncio.CancelledError: break + except BackpressureException: + logger.warning("Backpressure exception occurred, re-inserting message to 
the front of the queue: %s", + message.uuid if message else "Unknown") + if message: + await queue.reinsert_front(message) + await asyncio.sleep(1) except Exception as e: logger.exception("Service queue loop error: %s", e) if message: diff --git a/tb_mqtt_client/service/mqtt_manager.py b/tb_mqtt_client/service/mqtt_manager.py index fbc818f..5dfe4d2 100644 --- a/tb_mqtt_client/service/mqtt_manager.py +++ b/tb_mqtt_client/service/mqtt_manager.py @@ -23,6 +23,7 @@ from gmqtt import Client as GMQTTClient, Subscription from tb_mqtt_client.common.async_utils import await_or_stop, future_map, run_coroutine_sync +from tb_mqtt_client.common.exceptions import BackpressureException from tb_mqtt_client.common.gmqtt_patch import PatchUtils, PublishPacket from tb_mqtt_client.common.logging_utils import get_logger, TRACE_LEVEL from tb_mqtt_client.common.mqtt_message import MqttPublishMessage @@ -63,7 +64,7 @@ def __init__( self._patch_utils.patch_gmqtt_protocol_connection_lost() self._patch_utils.patch_mqtt_handler_disconnect() - self._client = GMQTTClient(client_id) + self._client: GMQTTClient = GMQTTClient(client_id) self._patch_utils.client = self._client self._patch_utils.patch_handle_connack() self._patch_utils.apply(self._handle_puback_reason_code) @@ -79,7 +80,6 @@ def __init__( self._on_publish_result_callback = on_publish_result self._connected_event = asyncio.Event() - self._connect_params = None # Will be set in connect method self._handlers: Dict[str, Callable[[str, bytes], Coroutine[Any, Any, None]]] = {} self._pending_publishes: Dict[int, Tuple[asyncio.Future[PublishResult], MqttPublishMessage, float]] = {} @@ -103,33 +103,22 @@ def __init__( async def connect(self, host: str, port: int = 1883, username: Optional[str] = None, password: Optional[str] = None, tls: bool = False, keepalive: int = 60, ssl_context: Optional[ssl.SSLContext] = None): - self._connect_params = (host, port, username, password, tls, keepalive, ssl_context) - asyncio.create_task(self._connect_loop()) + if username: + self._client.set_auth_credentials(username, password) - async def _connect_loop(self): - host, port, username, password, tls, keepalive, ssl_context = self._connect_params - retry_delay = 3 + if tls: + if ssl_context is None: + ssl_context = ssl.create_default_context() + await self._client.connect(host, port, ssl=ssl_context, keepalive=keepalive, raise_exc=False) + else: + await self._client.connect(host, port, keepalive=keepalive, raise_exc=False) while not self._client.is_connected and not self._main_stop_event.is_set(): try: - if username: - self._client.set_auth_credentials(username, password) - - if tls: - if ssl_context is None: - ssl_context = ssl.create_default_context() - await self._client.connect(host, port, ssl=ssl_context, keepalive=keepalive) - else: - await self._client.connect(host, port, keepalive=keepalive) - - logger.info("MQTT connection initiated, waiting for on_connect...") await self._connected_event.wait() - logger.info("MQTT connected.") break - - except Exception as e: - logger.warning("Initial MQTT connection failed: %s. Retrying in %s seconds...", str(e), retry_delay) - await asyncio.sleep(retry_delay) + except Exception as exc: + logger.warning("MQTT connection failed, waiting for connection: %s", str(exc)) def is_connected(self) -> bool: return (self._client.is_connected @@ -168,7 +157,7 @@ async def publish(self, if not force and self._backpressure.should_pause(): logger.trace("Backpressure active. 
Publishing suppressed.") - raise RuntimeError("Publishing temporarily paused due to backpressure.") + raise BackpressureException("Publishing temporarily paused due to backpressure.") if not message.dup: return await self.process_regular_publish(message, message.qos) @@ -322,20 +311,22 @@ def _on_disconnect_internal(self, client, reason_code=None, properties=None, exc self._backpressure.notify_disconnect(delay_seconds=10) if reason_code in (131, 142, 143, 151): # 131, 142, 151 may be caused by rate limits or issue with the data reached_time = 1 - for rate_limit in self.__rate_limiter.values(): - if isinstance(rate_limit, RateLimit): - try: - reached_limit = run_coroutine_sync(rate_limit.reach_limit, raise_on_timeout=True) - except TimeoutError: - logger.warning("Timeout while checking rate limit reaching.") - reached_time = 10 # Default to 10 seconds if timeout occurs - break - reached_index, reached_time, reached_duration = reached_limit if reached_limit else (None, None, 1) + if self.__rate_limiter: + for rate_limit in self.__rate_limiter.values(): + if isinstance(rate_limit, RateLimit): + try: + reached_limit = run_coroutine_sync(rate_limit.reach_limit, raise_on_timeout=True) + except TimeoutError: + logger.warning("Timeout while checking rate limit reaching.") + reached_time = 10 # Default to 10 seconds if timeout occurs + break + reached_index, reached_time, reached_duration = reached_limit if reached_limit else (None, None, 1) self._backpressure.notify_disconnect(delay_seconds=reached_time) elif reason_code != 0: # Default disconnect handling self._backpressure.notify_disconnect(delay_seconds=15) + self._rpc_response_handler.clear() if self._on_disconnect_callback: asyncio.create_task(self._on_disconnect_callback()) diff --git a/tests/service/test_mqtt_manager.py b/tests/service/test_mqtt_manager.py index 15d855d..9857434 100644 --- a/tests/service/test_mqtt_manager.py +++ b/tests/service/test_mqtt_manager.py @@ -52,13 +52,6 @@ async def setup_manager(): on_publish_result, rate_limits_handler, rpc_response_handler) -@pytest.mark.asyncio -async def test_connect_sets_connect_params(setup_manager): - manager, *_ = setup_manager - await manager.connect("localhost", 1883, "user", "pass", tls=False) - assert manager._connect_params[:4] == ("localhost", 1883, "user", "pass") - - @pytest.mark.asyncio async def test_is_connected_returns_false_if_not_ready(setup_manager): manager, *_ = setup_manager @@ -310,21 +303,6 @@ async def test_handle_puback_reason_code_errors(setup_manager): manager._handle_puback_reason_code(9999, 1, {}) # Should log warning, not crash -@pytest.mark.asyncio -async def test_connect_loop_retry_and_success(setup_manager): - manager, stop_event, *_ = setup_manager - manager._connected_event.set() - - manager._client.connect = AsyncMock(side_effect=[Exception("fail1"), AsyncMock()]) - - with patch.object(type(manager._client), "is_connected", new_callable=PropertyMock) as mock_connected, \ - patch("asyncio.sleep", new_callable=AsyncMock), \ - patch.object(manager, "_connect_params", new=("host", 1883, None, None, False, 60, None)): - mock_connected.side_effect = [False, False, True] - - await asyncio.wait_for(manager._connect_loop(), timeout=1) - - @pytest.mark.asyncio async def test_request_rate_limits_timeout(setup_manager): manager, stop_event, _, _, _, _, rate_handler, _ = setup_manager From 10126291c6e4d14d8ded62b6bd3caf7c9e907b1a Mon Sep 17 00:00:00 2001 From: imbeacon Date: Thu, 7 Aug 2025 09:38:49 +0300 Subject: [PATCH 70/74] Added blackbox tests --- 
examples/__init__.py | 15 + examples/device/__init__.py | 15 + examples/device/client_provisioning.py | 20 +- examples/device/handle_attribute_updates.py | 6 +- examples/device/handle_rpc_requests.py | 6 +- examples/device/request_attributes.py | 7 +- examples/device/send_attributes.py | 8 +- examples/device/send_timeseries.py | 8 +- examples/gateway/__init__.py | 15 + .../gateway/connect_and_disconnect_device.py | 12 +- examples/gateway/handle_attribute_updates.py | 13 +- examples/gateway/handle_rpc_requests.py | 13 +- examples/gateway/request_attributes.py | 13 +- examples/gateway/send_attributes.py | 13 +- examples/gateway/send_timeseries.py | 11 +- tb_mqtt_client/common/mqtt_message.py | 11 + tb_mqtt_client/common/provisioning_client.py | 17 +- .../entities/data/provisioning_response.py | 10 +- tb_mqtt_client/service/device/client.py | 9 +- .../service/device/message_adapter.py | 2 +- tests/blackbox/__init__.py | 13 + tests/blackbox/conftest.py | 228 +++++++++++ tests/blackbox/rest_helpers.py | 147 +++++++ tests/blackbox/test_basic_device_examples.py | 207 ++++++++++ tests/blackbox/test_basic_gateway_examples.py | 359 ++++++++++++++++++ tests/blackbox/test_client_provisioning.py | 79 ++++ tests/common/test_provisioning_client.py | 29 +- tests/entities/data/test_provisioning_data.py | 15 +- 28 files changed, 1211 insertions(+), 90 deletions(-) create mode 100644 examples/__init__.py create mode 100644 examples/device/__init__.py create mode 100644 examples/gateway/__init__.py create mode 100644 tests/blackbox/__init__.py create mode 100644 tests/blackbox/conftest.py create mode 100644 tests/blackbox/rest_helpers.py create mode 100644 tests/blackbox/test_basic_device_examples.py create mode 100644 tests/blackbox/test_basic_gateway_examples.py create mode 100644 tests/blackbox/test_client_provisioning.py diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 0000000..978ca26 --- /dev/null +++ b/examples/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is required to make the examples usable in blackbox tests. diff --git a/examples/device/__init__.py b/examples/device/__init__.py new file mode 100644 index 0000000..978ca26 --- /dev/null +++ b/examples/device/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is required to make the examples usable in blackbox tests. 
diff --git a/examples/device/client_provisioning.py b/examples/device/client_provisioning.py index 7c7b5ae..29bef2d 100644 --- a/examples/device/client_provisioning.py +++ b/examples/device/client_provisioning.py @@ -28,20 +28,28 @@ logger.setLevel(logging.INFO) logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) +device_name = "ProvisionedDevice" + +provisioning_credentials = AccessTokenProvisioningCredentials( + provision_device_key='YOUR_PROVISION_DEVICE_KEY', + provision_device_secret='YOUR_PROVISION_DEVICE_SECRET', +) +provisioning_request = ProvisioningRequest('localhost', + credentials=provisioning_credentials, + device_name=device_name) async def main(): - provisioning_credentials = AccessTokenProvisioningCredentials( - provision_device_key='YOUR_PROVISION_DEVICE_KEY', - provision_device_secret='YOUR_PROVISION_DEVICE_SECRET', - ) - provisioning_request = ProvisioningRequest('localhost', credentials=provisioning_credentials) provisioning_response = await DeviceClient.provision(provisioning_request) + if not provisioning_response: + logger.error(f"Provisioning failed, no response received.") + return + if provisioning_response.error is not None: logger.error(f"Provisioning failed: {provisioning_response.error}") return - logger.info('Provisioned device configuration: ', provisioning_response) + logger.info(f'Provisioned device configuration: {provisioning_response}') # Create a DeviceClient instance with the provisioned device configuration client = DeviceClient(provisioning_response.result) diff --git a/examples/device/handle_attribute_updates.py b/examples/device/handle_attribute_updates.py index 145d23c..5da55b9 100644 --- a/examples/device/handle_attribute_updates.py +++ b/examples/device/handle_attribute_updates.py @@ -27,15 +27,15 @@ logger.setLevel(logging.INFO) logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) +config = DeviceConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" async def attribute_update_callback(update: AttributeUpdate): logger.info("Received attribute update: %r", update) async def main(): - config = DeviceConfig() - config.host = "localhost" - config.access_token = "YOUR_ACCESS_TOKEN" client = DeviceClient(config) client.set_attribute_update_callback(attribute_update_callback) diff --git a/examples/device/handle_rpc_requests.py b/examples/device/handle_rpc_requests.py index f76d1e0..24da791 100644 --- a/examples/device/handle_rpc_requests.py +++ b/examples/device/handle_rpc_requests.py @@ -28,6 +28,9 @@ logger.setLevel(logging.INFO) logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) +config = DeviceConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" async def rpc_request_callback(request: RPCRequest) -> RPCResponse: logger.info("Received RPC: %r", request) @@ -39,9 +42,6 @@ async def rpc_request_callback(request: RPCRequest) -> RPCResponse: async def main(): - config = DeviceConfig() - config.host = "localhost" - config.access_token = "YOUR_ACCESS_TOKEN" client = DeviceClient(config) client.set_rpc_request_callback(rpc_request_callback) diff --git a/examples/device/request_attributes.py b/examples/device/request_attributes.py index 794aae3..ea60339 100644 --- a/examples/device/request_attributes.py +++ b/examples/device/request_attributes.py @@ -30,6 +30,10 @@ response_received = asyncio.Event() +config = DeviceConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" + async def attribute_request_callback(response: RequestedAttributeResponse): 
logger.info("Received attribute response: %r", response) @@ -37,9 +41,6 @@ async def attribute_request_callback(response: RequestedAttributeResponse): async def main(): - config = DeviceConfig() - config.host = "localhost" - config.access_token = "YOUR_ACCESS_TOKEN" client = DeviceClient(config) await client.connect() diff --git a/examples/device/send_attributes.py b/examples/device/send_attributes.py index 6144089..04aa74d 100644 --- a/examples/device/send_attributes.py +++ b/examples/device/send_attributes.py @@ -27,13 +27,15 @@ logger.setLevel(logging.INFO) logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) +config = DeviceConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" + async def main(): - config = DeviceConfig() - config.host = "localhost" - config.access_token = "YOUR_ACCESS_TOKEN" client = DeviceClient(config) + await client.connect() # Send attribute as raw dictionary diff --git a/examples/device/send_timeseries.py b/examples/device/send_timeseries.py index 5c5c5f9..2afb2b1 100644 --- a/examples/device/send_timeseries.py +++ b/examples/device/send_timeseries.py @@ -30,10 +30,12 @@ logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) +config = DeviceConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" + + async def main(): - config = DeviceConfig() - config.host = "localhost" - config.access_token = "YOUR_ACCESS_TOKEN" client = DeviceClient(config) await client.connect() diff --git a/examples/gateway/__init__.py b/examples/gateway/__init__.py new file mode 100644 index 0000000..978ca26 --- /dev/null +++ b/examples/gateway/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is required to make the examples usable in blackbox tests. 
diff --git a/examples/gateway/connect_and_disconnect_device.py b/examples/gateway/connect_and_disconnect_device.py index 5ef9d13..7ed51be 100644 --- a/examples/gateway/connect_and_disconnect_device.py +++ b/examples/gateway/connect_and_disconnect_device.py @@ -21,19 +21,19 @@ configure_logging() logger = get_logger(__name__) +config = GatewayConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" + +device_name = "Test Device B1" +device_profile = "Test devices" async def main(): - config = GatewayConfig() - config.host = "localhost" - config.access_token = "YOUR_ACCESS_TOKEN" client = GatewayClient(config) await client.connect() # Connecting device - - device_name = "Test Device B1" - device_profile = "Test devices" logger.info("Connecting device: %s", device_name) device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) diff --git a/examples/gateway/handle_attribute_updates.py b/examples/gateway/handle_attribute_updates.py index b821397..a007358 100644 --- a/examples/gateway/handle_attribute_updates.py +++ b/examples/gateway/handle_attribute_updates.py @@ -23,6 +23,13 @@ configure_logging() logger = get_logger(__name__) +config = GatewayConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" + +device_name = "Test Device B1" +device_profile = "Test devices" + async def attribute_update_handler(device_session: DeviceSession, attribute_update: AttributeUpdate): """ @@ -35,17 +42,11 @@ async def attribute_update_handler(device_session: DeviceSession, attribute_upda async def main(): - config = GatewayConfig() - config.host = "localhost" - config.access_token = "YOUR_ACCESS_TOKEN" client = GatewayClient(config) await client.connect() # Connecting device - - device_name = "Test Device B1" - device_profile = "Test devices" logger.info("Connecting device: %s", device_name) device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) diff --git a/examples/gateway/handle_rpc_requests.py b/examples/gateway/handle_rpc_requests.py index 1b6135f..6ffa377 100644 --- a/examples/gateway/handle_rpc_requests.py +++ b/examples/gateway/handle_rpc_requests.py @@ -21,9 +21,17 @@ from tb_mqtt_client.service.gateway.client import GatewayClient from tb_mqtt_client.service.gateway.device_session import DeviceSession + configure_logging() logger = get_logger(__name__) +config = GatewayConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" + +device_name = "Test Device B1" +device_profile = "Test devices" + async def device_rpc_request_handler(device_session: DeviceSession, rpc_request: GatewayRPCRequest) -> GatewayRPCResponse: @@ -55,17 +63,12 @@ async def device_rpc_request_handler(device_session: DeviceSession, async def main(): - config = GatewayConfig() - config.host = "localhost" - config.access_token = "YOUR_ACCESS_TOKEN" client = GatewayClient(config) await client.connect() # Connecting device - device_name = "Test Device B1" - device_profile = "Test devices" logger.info("Connecting device: %s", device_name) device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) diff --git a/examples/gateway/request_attributes.py b/examples/gateway/request_attributes.py index 2310129..d8cf8fe 100644 --- a/examples/gateway/request_attributes.py +++ b/examples/gateway/request_attributes.py @@ -22,9 +22,17 @@ from tb_mqtt_client.service.gateway.client import GatewayClient from 
tb_mqtt_client.service.gateway.device_session import DeviceSession + configure_logging() logger = get_logger(__name__) +config = GatewayConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" + +device_name = "Test Device B1" +device_profile = "Test devices" + async def requested_attributes_handler(device_session: DeviceSession, response: RequestedAttributeResponse): """ @@ -39,17 +47,12 @@ async def requested_attributes_handler(device_session: DeviceSession, response: async def main(): - config = GatewayConfig() - config.host = "localhost" - config.access_token = "YOUR_ACCESS_TOKEN" client = GatewayClient(config) await client.connect() # Connecting device - device_name = "Test Device B1" - device_profile = "Test devices" logger.info("Connecting device: %s", device_name) device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) diff --git a/examples/gateway/send_attributes.py b/examples/gateway/send_attributes.py index 28d39d6..34443bd 100644 --- a/examples/gateway/send_attributes.py +++ b/examples/gateway/send_attributes.py @@ -19,22 +19,25 @@ from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry from tb_mqtt_client.service.gateway.client import GatewayClient + configure_logging() logger = get_logger(__name__) +config = GatewayConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" + +device_name = "Test Device B1" +device_profile = "Test devices" + async def main(): - config = GatewayConfig() - config.host = "localhost" - config.access_token = "YOUR_ACCESS_TOKEN" client = GatewayClient(config) await client.connect() # Connecting device - device_name = "Test Device B1" - device_profile = "Test devices" logger.info("Connecting device: %s", device_name) device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) diff --git a/examples/gateway/send_timeseries.py b/examples/gateway/send_timeseries.py index ebc1b0a..f20304e 100644 --- a/examples/gateway/send_timeseries.py +++ b/examples/gateway/send_timeseries.py @@ -24,19 +24,20 @@ configure_logging() logger = get_logger(__name__) +config = GatewayConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" + +device_name = "Test Device B1" +device_profile = "Test devices" async def main(): - config = GatewayConfig() - config.host = "localhost" - config.access_token = "YOUR_ACCESS_TOKEN" client = GatewayClient(config) await client.connect() # Connecting device - device_name = "Test Device B1" - device_profile = "Test devices" logger.info("Connecting device: %s", device_name) device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) diff --git a/tb_mqtt_client/common/mqtt_message.py b/tb_mqtt_client/common/mqtt_message.py index 5a40b7a..922267b 100644 --- a/tb_mqtt_client/common/mqtt_message.py +++ b/tb_mqtt_client/common/mqtt_message.py @@ -86,3 +86,14 @@ def mark_as_sent(self, message_id: int): """Mark the message as sent.""" self.message_id = message_id self._is_sent = True + + def __str__(self): + return f"MqttPublishMessage(topic={self.topic}, payload_size={self.payload_size}, " \ + f"datapoints={self.datapoints}, qos={self.qos}, retain={self.retain}, " \ + f"main_ts={self.main_ts}, is_device_message={self.is_device_message})" + + def __repr__(self): + return f"MqttPublishMessage(topic={self.topic}, payload={self.payload}, payload_size={self.payload_size}, " \ + f"datapoints={self.datapoints}, 
qos={self.qos}, retain={self.retain}, " \ + f"main_ts={self.main_ts}, is_device_message={self.is_device_message}, " \ + f"delivery_futures={self.delivery_futures})" diff --git a/tb_mqtt_client/common/provisioning_client.py b/tb_mqtt_client/common/provisioning_client.py index 5ca90f7..12affc9 100644 --- a/tb_mqtt_client/common/provisioning_client.py +++ b/tb_mqtt_client/common/provisioning_client.py @@ -39,26 +39,25 @@ def __init__(self, host: str, port: int, provision_request: ProvisioningRequest) self._client.on_connect = self._on_connect self._client.on_message = self._on_message self._provisioned = Event() - self._device_config: Optional[Union[DeviceConfig, ProvisioningResponse]] = None + self._provisioning_response: Optional[ProvisioningResponse] = None self.__message_adapter = JsonMessageAdapter() def _on_connect(self, client, _, rc, __): if rc == 0: self._log.debug("[Provisioning client] Connected to ThingsBoard") client.subscribe(PROVISION_RESPONSE_TOPIC) - topic, payload = self.__message_adapter.build_provision_request(self._provision_request) - self._log.debug("[Provisioning client] Sending provisioning request %s" % payload) - client.publish(topic, payload) + provision_message = self.__message_adapter.build_provision_request(self._provision_request) + self._log.debug("[Provisioning client] Sending provisioning request %s", provision_message) + client.publish(provision_message.topic, provision_message.payload) else: - self._device_config = ProvisioningResponse.build(self._provision_request, - {'status': 'FAILURE', + self._provisioning_response = ProvisioningResponse.build(self._provision_request, + {'status': 'FAILURE', 'errorMsg': 'Cannot connect to ThingsBoard!'}) self._provisioned.set() self._log.error("[Provisioning client] Cannot connect to ThingsBoard!, result: %s" % rc) async def _on_message(self, _, __, payload, ___, ____): - provisioning_response = self.__message_adapter.parse_provisioning_response(self._provision_request, payload) - self._device_config = provisioning_response.result + self._provisioning_response = self.__message_adapter.parse_provisioning_response(self._provision_request, payload) await self._client.disconnect() self._provisioned.set() @@ -67,4 +66,4 @@ async def provision(self): await self._client.connect(self._host, self._port) await self._provisioned.wait() - return self._device_config + return self._provisioning_response diff --git a/tb_mqtt_client/entities/data/provisioning_response.py b/tb_mqtt_client/entities/data/provisioning_response.py index 2e79024..fc82f4e 100644 --- a/tb_mqtt_client/entities/data/provisioning_response.py +++ b/tb_mqtt_client/entities/data/provisioning_response.py @@ -40,16 +40,16 @@ def build(cls, provision_request: 'ProvisioningRequest', payload: dict) -> 'Prov """ self = object.__new__(cls) - if payload.get('status') == ProvisioningResponseStatus.ERROR.value: - object.__setattr__(self, 'error', payload.get('errorMsg')) - object.__setattr__(self, 'status', ProvisioningResponseStatus.ERROR) - object.__setattr__(self, 'result', None) - else: + if payload.get('status') == ProvisioningResponseStatus.SUCCESS.value: device_config = ProvisioningResponse._build_device_config(provision_request, payload) object.__setattr__(self, 'result', device_config) object.__setattr__(self, 'status', ProvisioningResponseStatus.SUCCESS) object.__setattr__(self, 'error', None) + else: + object.__setattr__(self, 'error', payload.get('errorMsg')) + object.__setattr__(self, 'status', ProvisioningResponseStatus.ERROR) + object.__setattr__(self, 
'result', None) return self diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index 7424baf..7b05202 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -33,6 +33,7 @@ from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate from tb_mqtt_client.entities.data.claim_request import ClaimRequest from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest +from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse @@ -453,17 +454,17 @@ async def __on_publish_result(self, publish_result: PublishResult): logger.error("Publish failed: %r", publish_result) @staticmethod - async def provision(provision_request: 'ProvisioningRequest', timeout=BaseClient.DEFAULT_TIMEOUT): + async def provision(provision_request: 'ProvisioningRequest', timeout=BaseClient.DEFAULT_TIMEOUT) -> Optional[ProvisioningResponse]: provision_client = ProvisioningClient( host=provision_request.host, port=provision_request.port, provision_request=provision_request ) - device_credentials = None + provisioning_response = None try: - device_credentials = await wait_for(provision_client.provision(), timeout=timeout) + provisioning_response = await wait_for(provision_client.provision(), timeout=timeout) except TimeoutError: logger.error("Provisioning timed out") - return device_credentials + return provisioning_response diff --git a/tb_mqtt_client/service/device/message_adapter.py b/tb_mqtt_client/service/device/message_adapter.py index f88684e..76472f2 100644 --- a/tb_mqtt_client/service/device/message_adapter.py +++ b/tb_mqtt_client/service/device/message_adapter.py @@ -370,7 +370,7 @@ def build_rpc_response(self, rpc_response: RPCResponse) -> MqttPublishMessage: :param rpc_response: The RPC response to build the payload for. :return: A tuple of topic and payload bytes. """ - if not rpc_response.request_id: + if rpc_response.request_id is None or rpc_response.request_id < 0: raise ValueError("RPCResponse must have a valid request ID.") payload = dumps(rpc_response.to_payload_format()) diff --git a/tests/blackbox/__init__.py b/tests/blackbox/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tests/blackbox/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/blackbox/conftest.py b/tests/blackbox/conftest.py new file mode 100644 index 0000000..67b5de9 --- /dev/null +++ b/tests/blackbox/conftest.py @@ -0,0 +1,228 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +import time +import pytest +import requests +import logging + +from requests import HTTPError + +from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.common.config_loader import DeviceConfig +from tests.blackbox.rest_helpers import find_related_entity_id, get_device_info_by_id + +TB_HOST = os.getenv("SDK_BLACKBOX_TB_HOST", "localhost") +TB_HTTP_PORT = int(os.getenv("SDK_BLACKBOX_TB_PORT", 8080)) +TB_MQTT_PORT = int(os.getenv("SDK_BLACKBOX_TB_MQTT_PORT", 1883)) +TENANT_USER = os.getenv("SDK_BLACKBOX_TENANT_USER", "tenant@thingsboard.org") +TENANT_PASS = os.getenv("SDK_BLACKBOX_TENANT_PASS", "tenant") +TB_CONTAINER_NAME = os.getenv("SDK_BLACKBOX_TB_CONTAINER", "tb-ce-sdk-tests") +RUN_WITH_LOCAL_TB = os.getenv("SDK_RUN_BLACKBOX_TESTS_LOCAL", "false").lower() == "true" + +TB_HTTP_PROTOCOL = os.getenv("SDK_BLACKBOX_TB_HTTP_PROTOCOL", "http") +TB_URL = os.getenv("SDK_BLACKBOX_TB_URL", f"{TB_HTTP_PROTOCOL}://{TB_HOST}:{TB_HTTP_PORT}") + +logger = logging.getLogger("blackbox") +logger.setLevel(logging.INFO) + + +def pytest_collection_modifyitems(config, items): + """Skip blackbox tests unless explicitly enabled.""" + if os.getenv("SDK_RUN_BLACKBOX_TESTS", "false").lower() != "true": + skip_marker = pytest.mark.skip(reason="Blackbox tests disabled. Set SDK_RUN_BLACKBOX_TESTS=true to run.") + for item in items: + if "blackbox" in str(item.fspath): + item.add_marker(skip_marker) + + +@pytest.fixture(scope="session", autouse=True) +def start_thingsboard(): + """Start ThingsBoard CE in Docker if blackbox tests are enabled.""" + if os.getenv("SDK_RUN_BLACKBOX_TESTS", "false").lower() != "true": + return + + if not RUN_WITH_LOCAL_TB: + try: + status = subprocess.check_output( + ["docker", "inspect", "-f", "{{.State.Running}}", TB_CONTAINER_NAME], + stderr=subprocess.DEVNULL + ).decode().strip() + if status == "true": + logger.info(f"Using existing ThingsBoard CE container: {TB_CONTAINER_NAME}") + wait_for_tb_ready() + return + except subprocess.CalledProcessError: + pass + + logger.info("Pulling latest ThingsBoard CE image...") + subprocess.run(["docker", "pull", "thingsboard/tb-postgres"], check=True) + + logger.info("Starting ThingsBoard CE container...") + subprocess.run([ + "docker", "run", "-d", "--rm", + "--name", TB_CONTAINER_NAME, + "-p", "8080:8080", + "-p", f"{TB_MQTT_PORT}:1883", + "thingsboard/tb-postgres" + ], check=True) + + wait_for_tb_ready() + + +def wait_for_tb_ready(): + logger.info("Waiting for ThingsBoard CE to become ready...") + for _ in range(60): + try: + r = requests.post(f"{TB_URL}/api/auth/login", json={ + "username": TENANT_USER, + "password": TENANT_PASS + }) + if r.ok: + logger.info("ThingsBoard CE is ready.") + return + except requests.ConnectionError: + pass + time.sleep(2) + pytest.fail("ThingsBoard CE did not become ready in time.") + + +@pytest.fixture(scope="session") +def test_config(): + """Return configuration for blackbox tests.""" + return { + "tb_host": TB_HOST, + "tb_http_port": TB_HTTP_PORT, + "tb_mqtt_port": TB_MQTT_PORT, + "tenant_user": TENANT_USER, + "tenant_pass": TENANT_PASS, + 
"tb_url": TB_URL + } + + +@pytest.fixture(scope="session") +def tb_admin_token(start_thingsboard): + """Login as tenant admin and return JWT.""" + r = requests.post(f"{TB_URL}/api/auth/login", json={ + "username": TENANT_USER, + "password": TENANT_PASS + }) + r.raise_for_status() + return r.json()["token"] + + +@pytest.fixture(scope="session") +def tb_admin_headers(tb_admin_token): + """Return headers for ThingsBoard REST API requests.""" + return { + "X-Authorization": f"Bearer {tb_admin_token}", + "Content-Type": "application/json" + } + + +@pytest.fixture() +def device_info(tb_admin_headers): + try: + r = requests.post(f"{TB_URL}/api/device", headers=tb_admin_headers, json={ + "name": "pytest-device", + "type": "default" + }) + r.raise_for_status() + device = r.json() + except HTTPError as e: + logger.error(f"Failed to create device: {e}") + r = requests.get(f"{TB_URL}/api/tenant/devices?deviceName=pytest-device", headers=tb_admin_headers) + r.raise_for_status() + device = r.json() + if not device: + pytest.fail("Failed to create or find test device.") + + yield device + + +@pytest.fixture() +def device_config(device_info, tb_admin_headers): + + device_id = device_info["id"]["id"] + + # Get credentials + r = requests.get(f"{TB_URL}/api/device/{device_id}/credentials", headers=tb_admin_headers) + r.raise_for_status() + token = r.json()["credentialsId"] + + device_config = DeviceConfig() + device_config.host = TB_HOST + device_config.port = TB_MQTT_PORT + device_config.access_token = token + yield device_config + + # Cleanup + requests.delete(f"{TB_URL}/api/device/{device_id}", headers=tb_admin_headers) + + +@pytest.fixture() +def gateway_info(tb_admin_headers): + try: + r = requests.post(f"{TB_URL}/api/device", headers=tb_admin_headers, json={ + "name": "pytest-gw-device", + "type": "default", + "additionalInfo": { + "gateway": True + } + }) + r.raise_for_status() + device = r.json() + except HTTPError as e: + logger.error(f"Failed to create device: {e}") + r = requests.get(f"{TB_URL}/api/tenant/devices?deviceName=pytest-gw-device", headers=tb_admin_headers) + r.raise_for_status() + device = r.json() + if not device: + pytest.fail("Failed to create or find test device.") + + yield device + +@pytest.fixture +def gateway_config(gateway_info, tb_admin_headers): + + device_id = gateway_info["id"]["id"] + + # Get credentials + r = requests.get(f"{TB_URL}/api/device/{device_id}/credentials", headers=tb_admin_headers) + r.raise_for_status() + token = r.json()["credentialsId"] + + gateway_config = GatewayConfig() + gateway_config.host = TB_HOST + gateway_config.port = TB_MQTT_PORT + gateway_config.access_token = token + yield gateway_config + + subdevice_id = find_related_entity_id(device_id, TB_URL, tb_admin_headers) + + # Cleanup + requests.delete(f"{TB_URL}/api/device/{device_id}", headers=tb_admin_headers) + + sub_device_info = get_device_info_by_id( + subdevice_id, + TB_URL, + tb_admin_headers + ) + + requests.delete(f"{TB_URL}/api/device/{subdevice_id}", headers=tb_admin_headers) + + requests.delete(f"{TB_URL}/api/deviceProfile/{sub_device_info['deviceProfileId']['id']}", + headers=tb_admin_headers) diff --git a/tests/blackbox/rest_helpers.py b/tests/blackbox/rest_helpers.py new file mode 100644 index 0000000..eba63a2 --- /dev/null +++ b/tests/blackbox/rest_helpers.py @@ -0,0 +1,147 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from time import time +import requests + + +def get_device_attributes(device_id: str, base_url: str, headers: dict, scope: str) -> list: + + url = f"{base_url}/api/plugins/telemetry/DEVICE/{device_id}/values/attributes?scope={scope}" + + response = requests.get(url, headers=headers) + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + + +def get_device_timeseries(device_id: str, base_url: str, headers: dict, keys:list =None) -> dict: + current_ts = int(time() * 1000) + + if keys is None: + keys = [] + + url = (f"{base_url}/api/plugins/telemetry/DEVICE/{device_id}/values/timeseries?startTs={current_ts - 60000}" + f"&endTs={current_ts}" + f"&useStrictDataTypes=true") + + if keys: + url += f"&keys={','.join(keys)}" + + response = requests.get(url, headers=headers) + + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + +def update_shared_attributes(device_id: str, base_url: str, headers: dict, attributes: dict) -> None: + url = f"{base_url}/api/plugins/telemetry/DEVICE/{device_id}/attributes/SHARED_SCOPE" + + response = requests.post(url, json=attributes, headers=headers) + + if response.status_code != 200: + response.raise_for_status() + +def get_device_info_by_name(device_name: str, base_url: str, headers: dict) -> dict: + url = f"{base_url}/api/tenant/devices?deviceName={device_name}" + response = requests.get(url, headers=headers) + if response.status_code != 200: + response.raise_for_status() + device = response.json() + if not device: + raise ValueError(f"Device with name '{device_name}' not found.") + return device + +def find_related_entity_id(device_id: str, base_url: str, headers: dict) -> str: + url = f"{base_url}/api/relations?fromId={device_id}&fromType=DEVICE" + + response = requests.get(url, headers=headers) + if response.status_code == 200: + relations = response.json() + if relations: + return relations[0]['to']['id'] # Return the first related entity + else: + raise ValueError(f"No related entity found for device ID '{device_id}'.") + else: + response.raise_for_status() + +def get_device_info_by_id(device_id: str, base_url: str, headers: dict) -> dict: + url = f"{base_url}/api/device/{device_id}" + response = requests.get(url, headers=headers) + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + +def get_default_device_profile(base_url: str, headers: dict) -> dict: + url = f"{base_url}/api/deviceProfileInfo/default" + response = requests.get(url, headers=headers) + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + +def create_device_profile_with_provisioning(device_profile_name: str, + base_url: str, + headers: dict, + provisioning_device_key: str, + provisioning_device_secret: str) -> dict: + device_profile = get_default_device_profile(base_url, headers) + url = f"{base_url}/api/deviceProfile" + device_profile['id'] = None + device_profile['name'] = device_profile_name + device_profile['createdTime'] = None + device_profile['provisionType'] = 
'ALLOW_CREATE_NEW_DEVICES' + if 'profileData' not in device_profile: + device_profile['profileData'] = { + "configuration": { + "type": "DEFAULT" + }, + "transportConfiguration": { + "type": "DEFAULT" + }, + "provisionConfiguration": { + "type": "ALLOW_CREATE_NEW_DEVICES", + "provisionDeviceSecret": provisioning_device_secret + }, + "alarms": None + } + else: + device_profile['profileData']['provisionConfiguration'] = { + 'type': 'ALLOW_CREATE_NEW_DEVICES', + 'provisionDeviceSecret': provisioning_device_secret + } + device_profile['provisionDeviceKey'] = provisioning_device_key + + response = requests.post(url, json=device_profile, headers=headers) + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + +async def send_rpc_request(device_id: str, base_url: str, headers: dict, request: dict): + loop = asyncio.get_running_loop() + + def _send(): + url = f"{base_url}/api/rpc/twoway/{device_id}" + r = requests.post(url, json=request, headers=headers, timeout=60) + r.raise_for_status() + return r.json() + + result = await loop.run_in_executor(None, _send) + return result diff --git a/tests/blackbox/test_basic_device_examples.py b/tests/blackbox/test_basic_device_examples.py new file mode 100644 index 0000000..4cc4b09 --- /dev/null +++ b/tests/blackbox/test_basic_device_examples.py @@ -0,0 +1,207 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio + +import pytest +from examples.device import send_attributes +from examples.device import send_timeseries +from examples.device import request_attributes +from examples.device import handle_attribute_updates +from examples.device import handle_rpc_requests +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tests.blackbox.rest_helpers import get_device_attributes, get_device_timeseries, update_shared_attributes, \ + send_rpc_request + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_send_attributes(device_config, device_info, tb_admin_headers, test_config): + expected_attributes = { + "firmwareVersion": "1.0.3", + "hardwareModel": "TB-SDK-Device", + "mode": "normal", + "location": "Building A", + "status": "active" + } + + send_attributes.config = device_config + + await send_attributes.main() + + attrs = get_device_attributes(device_info['id']['id'], test_config['tb_url'], tb_admin_headers, 'CLIENT_SCOPE') + received_keys = {a.get("key") for a in attrs} + for key, value in expected_attributes.items(): + assert any(a.get("key") == key and a.get("value") == value for a in attrs) + assert key in received_keys + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_send_timeseries(device_config, device_info, tb_admin_headers, test_config): + expected_timeseries = { + "temperature": "any", # Accept any value for temperature + "humidity": "any", # Accept any value for humidity + "pressure": "any", # Accept any value for pressure + "vibration": 0.05, + "speed": 123 + } + additional_timeseries = { + "vibration": 0.01, + "speed": 120 + } + + send_timeseries.config = device_config + await send_timeseries.main() + + timeseries = get_device_timeseries(device_info['id']['id'], test_config['tb_url'], tb_admin_headers) + for key, value in expected_timeseries.items(): + assert key in timeseries, f"Expected timeseries key '{key}' not found." + assert any((ts.get("value") == value or ts.get("value", True) == additional_timeseries.get(key, False) or value == "any") for ts in timeseries[key]), f"Expected value '{value}' for key '{key}' not found." + + all_timeseries_keys = set(timeseries.keys()) + all_timeseries = get_device_timeseries(device_info['id']['id'], test_config['tb_url'], + tb_admin_headers, keys=list(all_timeseries_keys)) + for key in all_timeseries_keys: + for ts_entry in all_timeseries[key]: + assert (expected_timeseries[key] == 'any' or ts_entry.get("value") == expected_timeseries[key] or ts_entry.get("value") == additional_timeseries.get(key)), f"Expected value '{expected_timeseries[key]}' for key '{key}' not found in all timeseries." 
+ + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_request_attributes(device_config, device_info, tb_admin_headers, test_config): + expected_attributes = { + "currentTemperature": 22.5 + } + + attributes_received = asyncio.Event() + + handler_future = asyncio.Future() + + async def attribute_request_callback(response: RequestedAttributeResponse): + attributes_received.set() + try: + assert response is not None, "Response should not be None" + assert response.request_id != -1, "Request ID should not be -1" + assert response.client, "Client attributes should not be empty" + for expected_attribute in expected_attributes: + assert any(attr.key == expected_attribute and attr.value == expected_attributes[expected_attribute] for attr in response.client), f"Expected attribute '{expected_attribute}' not found in response" + handler_future.set_result(True) + except AssertionError as e: + handler_future.set_exception(e) + + request_attributes.config = device_config + request_attributes.attribute_request_callback = attribute_request_callback + request_attributes.response_received = attributes_received + + await request_attributes.main() + + await asyncio.wait_for(handler_future, timeout=10) + + attrs = get_device_attributes(device_info['id']['id'], test_config['tb_url'], tb_admin_headers, 'CLIENT_SCOPE') + received_keys = {a.get("key") for a in attrs} + for key, value in expected_attributes.items(): + assert any(a.get("key") == key and a.get("value") == value for a in attrs) + assert key in received_keys + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_handle_attribute_updates(device_config, device_info, tb_admin_headers, test_config): + + test_shared_attribute = {"pytest_shared_attribute": "test_value"} + + handler_future = asyncio.Future() + + async def attribute_update_callback(update: AttributeUpdate): + try: + assert update is not None, "Update should not be None" + assert update.keys()[0] == list(test_shared_attribute.keys())[0], "Update key does not match expected key" + assert update.values()[0] == list(test_shared_attribute.values())[0], "Update value does not match expected value" + handler_future.set_result(True) + except AssertionError as e: + if not handler_future.done(): + handler_future.set_exception(e) + + handle_attribute_updates.attribute_update_callback = attribute_update_callback + handle_attribute_updates.config = device_config + + task = asyncio.create_task(handle_attribute_updates.main()) + + await asyncio.sleep(3) + try: + update_shared_attributes( + device_info['id']['id'], + test_config['tb_url'], + tb_admin_headers, + test_shared_attribute + ) + + await asyncio.wait_for(handler_future, timeout=10) + except asyncio.TimeoutError: + pytest.fail("Attribute update was not received within the timeout period") + finally: + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_handle_rpc_requests(device_config, device_info, tb_admin_headers, test_config): + + handler_future = asyncio.Future() + + rpc_request = { + "method": "getDeviceInfo", + "params": {"deviceId": device_info['id']['id']} + } + + async def rpc_request_callback(request: RPCRequest): + try: + assert request is not None, "RPC request should not be None" + assert request.method == "getDeviceInfo", "Unexpected RPC method" + response = RPCResponse.build(request_id=request.request_id, + result={"id": device_info['id'], "name": device_info['name']}) + handler_future.set_result(response) + return response + 
except AssertionError as e: + if not handler_future.done(): + handler_future.set_exception(e) + + handle_rpc_requests.rpc_request_callback = rpc_request_callback + handle_rpc_requests.config = device_config + + task = asyncio.create_task(handle_rpc_requests.main()) + await asyncio.sleep(1) + + try: + + response_task = asyncio.create_task(send_rpc_request(device_info['id']['id'], test_config['tb_url'], tb_admin_headers, rpc_request)) + + await asyncio.wait_for(handler_future, timeout=10) + + response = await asyncio.wait_for(response_task, timeout=10) + + assert response is not None, "RPC response should not be None" + + except asyncio.TimeoutError: + pytest.fail("RPC request was not received within the timeout period") + except Exception as e: + pytest.fail(f"An unexpected exception occurred: {e}") + finally: + task.cancel() \ No newline at end of file diff --git a/tests/blackbox/test_basic_gateway_examples.py b/tests/blackbox/test_basic_gateway_examples.py new file mode 100644 index 0000000..906dc65 --- /dev/null +++ b/tests/blackbox/test_basic_gateway_examples.py @@ -0,0 +1,359 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import pytest + +from examples.gateway import connect_and_disconnect_device +from examples.gateway import send_timeseries +from examples.gateway import send_attributes +from examples.gateway import request_attributes +from examples.gateway import handle_attribute_updates +from examples.gateway import handle_rpc_requests +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.gateway.gateway_requested_attribute_response import GatewayRequestedAttributeResponse +from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse +from tb_mqtt_client.service.gateway.device_session import DeviceSession +from tests.blackbox.rest_helpers import get_device_info_by_name, get_device_attributes, get_device_timeseries, \ + update_shared_attributes, send_rpc_request + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_gateway_connect_and_disconnect_device(gateway_config, tb_admin_headers, test_config): + device_name = "pytest-gw-subdevice" + device_profile = "pytest-gw-subdevice-profile" + connect_and_disconnect_device.config = gateway_config + connect_and_disconnect_device.device_name = device_name + connect_and_disconnect_device.device_profile = device_profile + + task = asyncio.create_task(connect_and_disconnect_device.main()) + + try: + await asyncio.sleep(3) + device_info = get_device_info_by_name(device_name, test_config["tb_url"], tb_admin_headers) + device_service_attributes = get_device_attributes(device_info["id"]["id"], test_config["tb_url"], tb_admin_headers, "SERVER_SCOPE") + server_attributes = {attr["key"]: attr["value"] for attr in device_service_attributes} + assert device_info['type'] == device_profile + assert server_attributes['active'] + assert 
abs(server_attributes['lastActivityTime'] - server_attributes["lastConnectTime"]) < 5 + assert server_attributes['lastDisconnectTime'] >= server_attributes["lastConnectTime"] + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_gateway_send_timeseries(gateway_config, tb_admin_headers, test_config): + device_name = "pytest-gw-subdevice-ts" + device_profile = "pytest-gw-subdevice-ts-profile" + + send_timeseries.config = gateway_config + send_timeseries.device_name = device_name + send_timeseries.device_profile = device_profile + + task = asyncio.create_task(send_timeseries.main()) + + try: + await asyncio.sleep(5) + + device_info = get_device_info_by_name( + device_name, test_config["tb_url"], tb_admin_headers + ) + assert device_info is not None, "Subdevice was not created/connected." + + timeseries = get_device_timeseries( + device_info["id"]["id"], test_config["tb_url"], tb_admin_headers + ) + + expected_keys = {"temperature", "humidity"} + for key in expected_keys: + assert key in timeseries, f"Timeseries key '{key}' not found in device data." + + for key in expected_keys: + assert any(isinstance(entry.get("value"), (int, float)) for entry in timeseries[key]), \ + f"No numeric values found for '{key}'." + + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_gateway_send_attributes(gateway_config, tb_admin_headers, test_config): + device_name = "pytest-gw-subdevice-attrs" + device_profile = "pytest-gw-subdevice-attrs-profile" + + send_attributes.config = gateway_config + send_attributes.device_name = device_name + send_attributes.device_profile = device_profile + + task = asyncio.create_task(send_attributes.main()) + + try: + await asyncio.sleep(5) + + device_info = get_device_info_by_name( + device_name, test_config["tb_url"], tb_admin_headers + ) + assert device_info is not None, "Subdevice was not created/connected." + + attrs = get_device_attributes( + device_info["id"]["id"], + test_config["tb_url"], + tb_admin_headers, + scope="CLIENT_SCOPE" + ) + attr_dict = {a["key"]: a["value"] for a in attrs} + + expected_attrs = { + "maintenance": "scheduled", + "id": 341, + "location": "office", + "status": "active", + "version": "1.0.0" + } + + for key, value in expected_attrs.items(): + assert key in attr_dict, f"Attribute '{key}' not found." + assert attr_dict[key] == value, ( + f"Attribute '{key}' expected '{value}', got '{attr_dict[key]}'" + ) + + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_gateway_request_attributes(gateway_config, tb_admin_headers, test_config): + device_name = "pytest-gw-subdevice-attrs-request" + device_profile = "pytest-gw-subdevice-attrs-request-profile" + + expected_attrs = { + "maintenance": "scheduled", + "id": 341, + "location": "office", + } + + handler_future = asyncio.Future() + + async def requested_attributes_handler(device_session, response: GatewayRequestedAttributeResponse): + try: + assert device_session is not None, "Device session should not be None." + assert device_session.device_info is not None, "Device info should not be None." 
+ assert device_session.device_info.device_name == device_name, ( + f"Device name mismatch: expected '{device_name}', got '{device_session.device_info.device_name}'" + ) + assert response is not None, "Response should not be None." + assert isinstance(response.client, list), "Client attributes should be a list." + assert isinstance(response.shared, list), "Shared attributes should be a list." + if response.request_id == 1: # Assuming request_id 1 is for the client attributes + assert len(response.client) > 0, "Client attributes should not be empty." + assert len(response.shared) >= 0, "Shared attributes can be empty." + for attr in response.client: + for key, value in expected_attrs.items(): + if attr.key == key: + assert attr.value == value, ( + f"Attribute '{key}' expected '{value}', got '{attr.value}'" + ) + else: + assert response.client == [], "Client attributes should be empty for shared attribute requests." + assert response.shared == [], "Shared attributes should be empty, because they were not set." + handler_future.set_result(True) + except AssertionError as e: + if not handler_future.done(): + handler_future.set_exception(e) + + request_attributes.config = gateway_config + request_attributes.device_name = device_name + request_attributes.device_profile = device_profile + request_attributes.requested_attributes_handler = requested_attributes_handler + + task = asyncio.create_task(request_attributes.main()) + + try: + await asyncio.wait_for(handler_future, timeout=10) + + device_info = get_device_info_by_name(device_name, test_config["tb_url"], tb_admin_headers) + assert device_info is not None, "Subdevice was not created/connected." + + attrs = get_device_attributes( + device_info["id"]["id"], + test_config["tb_url"], + tb_admin_headers, + scope="CLIENT_SCOPE" + ) + attr_dict = {a["key"]: a["value"] for a in attrs} + + for key, value in expected_attrs.items(): + assert key in attr_dict, f"Attribute '{key}' not found." + assert attr_dict[key] == value, ( + f"Attribute '{key}' expected '{value}', got '{attr_dict[key]}'" + ) + + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_gateway_handle_attribute_updates(gateway_config, tb_admin_headers, test_config): + device_name = "pytest-gw-subdevice-attrs-update" + device_profile = "pytest-gw-subdevice-attrs-update-profile" + + test_shared_attribute = {"pytest_shared_attribute": "test_value"} + + handler_future = asyncio.Future() + + async def attribute_update_handler(device_session: DeviceSession, update: AttributeUpdate): + try: + assert device_session is not None, "Device session should not be None." + assert device_session.device_info is not None, "Device info should not be None." + assert device_session.device_info.device_name == device_name, ( + f"Device name mismatch: expected '{device_name}', got '{device_session.device_info.device_name}'" + ) + assert update is not None, "Update should not be None." + assert update.keys()[0] == list(test_shared_attribute.keys())[0], ( + "Update key does not match expected key." + ) + assert update.values()[0] == list(test_shared_attribute.values())[0], ( + "Update value does not match expected value." 
+ ) + handler_future.set_result(True) + except AssertionError as e: + if not handler_future.done(): + handler_future.set_exception(e) + + handle_attribute_updates.config = gateway_config + handle_attribute_updates.device_name = device_name + handle_attribute_updates.device_profile = device_profile + handle_attribute_updates.attribute_update_handler = attribute_update_handler + + task = asyncio.create_task(handle_attribute_updates.main()) + + try: + await asyncio.sleep(3) + + sub_device_info = get_device_info_by_name(device_name, test_config["tb_url"], tb_admin_headers) + + update_shared_attributes( + sub_device_info["id"]["id"], + test_config["tb_url"], + tb_admin_headers, + test_shared_attribute + ) + + await asyncio.wait_for(handler_future, timeout=10) + + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_gateway_handle_rpc_requests(gateway_config, tb_admin_headers, test_config): + device_name = "pytest-gw-subdevice-rpc" + device_profile = "pytest-gw-subdevice-rpc-profile" + + handler_future = asyncio.Future() + + async def device_rpc_request_handler(device_session: DeviceSession, rpc_request: GatewayRPCRequest): + try: + assert device_session is not None, "Device session should not be None." + assert device_session.device_info is not None, "Device info should not be None." + assert device_session.device_info.device_name == device_name, ( + f"Device name mismatch: expected '{device_name}', got '{device_session.device_info.device_name}'" + ) + assert rpc_request is not None, "RPC request should not be None." + assert rpc_request.method == "testMethod", "RPC method does not match expected method." + response_data = { + "data": { + "device_name": device_session.device_info.device_name, + "request_id": rpc_request.request_id, + "method": rpc_request.method, + "params": rpc_request.params + } + } + rpc_response = GatewayRPCResponse.build( + device_session.device_info.device_name, + rpc_request.request_id, + response_data + ) + handler_future.set_result(True) + return rpc_response + except AssertionError as e: + if not handler_future.done(): + handler_future.set_exception(e) + + handle_rpc_requests.config = gateway_config + handle_rpc_requests.device_name = device_name + handle_rpc_requests.device_profile = device_profile + handle_rpc_requests.device_rpc_request_handler = device_rpc_request_handler + + task = asyncio.create_task(handle_rpc_requests.main()) + await asyncio.sleep(3) + + try: + + sub_device_info = get_device_info_by_name(device_name, test_config["tb_url"], tb_admin_headers) + rpc_request = { + "method": "testMethod", + "params": {"param1": "value1"} + } + + response_task = asyncio.create_task(send_rpc_request(sub_device_info['id']['id'], test_config['tb_url'], tb_admin_headers, rpc_request)) + + await asyncio.wait_for(handler_future, timeout=10) + + response_data = await asyncio.wait_for(response_task, timeout=10) + + assert 'result' in response_data, "RPC response should contain 'result'." + assert 'data' in response_data['result'], "RPC response should contain 'data'." + assert 'device_name' in response_data['result']['data'], "RPC response data should contain 'device_name'." + assert response_data['result']['data']['device_name'] == device_name, "RPC response device_name does not match expected device_name." + assert 'method' in response_data['result']['data'], "RPC response data should contain 'method'." 
+ assert response_data['result']['data']['method'] == "testMethod", "RPC response method does not match expected method." + assert 'request_id' in response_data['result']['data'], "RPC response data should contain 'request_id'." + assert response_data['result']['data']['request_id'] == 0, "RPC response request_id does not match expected request_id." + assert 'params' in response_data['result']['data'], "RPC response data should contain 'params'." + assert response_data['result']['data']['params'] == {"param1": "value1"}, "RPC response params do not match expected params." + + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass diff --git a/tests/blackbox/test_client_provisioning.py b/tests/blackbox/test_client_provisioning.py new file mode 100644 index 0000000..5450007 --- /dev/null +++ b/tests/blackbox/test_client_provisioning.py @@ -0,0 +1,79 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio + +import pytest +import requests + +from examples.device import client_provisioning +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest, AccessTokenProvisioningCredentials +from tests.blackbox.rest_helpers import create_device_profile_with_provisioning, \ + get_device_timeseries, get_device_info_by_name + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_client_provisioning(tb_admin_headers, test_config): + device_name = 'pytest-provisioned-device' + + provisioning_device_key = 'pytest-provisioning-device-key' + provisioning_device_secret = 'pytest-provisioning-device-secret' + + provisioning_credentials = AccessTokenProvisioningCredentials( + provision_device_key=provisioning_device_key, + provision_device_secret=provisioning_device_secret, + ) + provisioning_request = ProvisioningRequest(test_config['tb_host'], + credentials=provisioning_credentials, + device_name=device_name) + + device_profile = create_device_profile_with_provisioning('pytest-provisioning-profile', + test_config['tb_url'], + tb_admin_headers, + provisioning_device_key, + provisioning_device_secret) + provisioned_device_info = None + try: + client_provisioning.provisioning_request = provisioning_request + + try: + await client_provisioning.main() + except Exception as e: + pytest.fail(f"Provisioning failed with exception: {e}") + + provisioned_device_info = get_device_info_by_name(device_name, + test_config['tb_url'], + tb_admin_headers) + + assert provisioned_device_info is not None, "Provisioned device info should not be None" + assert provisioned_device_info['name'] == device_name, "Provisioned device name should match" + assert provisioned_device_info['deviceProfileId']['id'] == device_profile['id']['id'], \ + "Provisioned device profile ID should match" + + timeseries = get_device_timeseries(provisioned_device_info['id']['id'], test_config['tb_url'], tb_admin_headers, + ['batteryLevel']) + assert len(timeseries) == 1, "There should be one timeseries entry" + assert 'batteryLevel' in timeseries, "Timeseries 
key should be 'batteryLevel'" + assert 'value' in timeseries['batteryLevel'][0], "Timeseries entry should have a value" + finally: + # Cleanup: delete the provisioned device and profile + if provisioned_device_info and 'id' in provisioned_device_info: + device_id = provisioned_device_info['id']['id'] + delete_url = f"{test_config['tb_url']}/api/device/{device_id}" + requests.delete(delete_url, headers=tb_admin_headers) + + if 'id' in device_profile: + profile_id = device_profile['id']['id'] + delete_profile_url = f"{test_config['tb_url']}/api/deviceProfile/{profile_id}" + requests.delete(delete_profile_url, headers=tb_admin_headers) diff --git a/tests/common/test_provisioning_client.py b/tests/common/test_provisioning_client.py index 0d00244..51383ab 100644 --- a/tests/common/test_provisioning_client.py +++ b/tests/common/test_provisioning_client.py @@ -16,10 +16,12 @@ import pytest +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage from tb_mqtt_client.common.provisioning_client import ProvisioningClient from tb_mqtt_client.constants.mqtt_topics import PROVISION_RESPONSE_TOPIC from tb_mqtt_client.constants.provisioning import ProvisioningResponseStatus from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest, AccessTokenProvisioningCredentials +from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse @pytest.fixture @@ -39,25 +41,26 @@ async def test_successful_provisioning_flow(mock_dispatcher_cls, mock_gmqtt_cls, mock_client = AsyncMock() mock_gmqtt_cls.return_value = mock_client - mock_dispatcher = MagicMock() + mock_adapter = MagicMock() topic = "provision/topic" payload = b'{"provision": "data"}' - mock_dispatcher.build_provision_request.return_value = (topic, payload) + mock_message = MqttPublishMessage(topic=topic, payload=payload) + mock_adapter.build_provision_request.return_value = mock_message - mock_device_config = MagicMock() - mock_dispatcher.parse_provisioning_response.return_value.result = mock_device_config - mock_dispatcher_cls.return_value = mock_dispatcher + mock_provisioning_response = ProvisioningResponse.build(real_request, {"credentialsValue": "config_value", "status": "SUCCESS"}) + mock_adapter.parse_provisioning_response.return_value = mock_provisioning_response + mock_dispatcher_cls.return_value = mock_adapter client = ProvisioningClient("test_host", 1883, real_request) client._on_connect(mock_client, None, 0, None) mock_client.subscribe.assert_called_once_with(PROVISION_RESPONSE_TOPIC) - mock_client.publish.assert_called_once_with(topic, payload) + mock_client.publish.assert_called_once_with(mock_message.topic, mock_message.payload) await client._on_message(None, None, b"payload-data", None, None) - assert client._device_config == mock_device_config + assert client._provisioning_response == mock_provisioning_response assert client._provisioned.is_set() mock_client.disconnect.assert_awaited_once() @@ -74,8 +77,8 @@ async def test_failed_connection(mock_dispatcher_cls, mock_gmqtt_cls, real_reque with caplog.at_level("ERROR"): client._on_connect(mock_client, None, 1, None) - assert client._device_config is not None - assert client._device_config.status == ProvisioningResponseStatus.ERROR + assert client._provisioning_response is not None + assert client._provisioning_response.status == ProvisioningResponseStatus.ERROR assert client._provisioned.is_set() assert "Cannot connect to ThingsBoard!" 
in caplog.text @@ -89,7 +92,7 @@ async def test_provision_method_awaits_provisioned(mock_dispatcher_cls, mock_gmq client = ProvisioningClient("localhost", 1883, real_request) expected_config = MagicMock() - client._device_config = expected_config + client._provisioning_response = expected_config client._provisioned.set() result = await client.provision() @@ -106,4 +109,8 @@ def test_initial_state(real_request): assert client._provision_request == real_request assert client._client_id == "provision" assert not client._provisioned.is_set() - assert client._device_config is None + assert client._provisioning_response is None + + +if __name__ == '__main__': + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/entities/data/test_provisioning_data.py b/tests/entities/data/test_provisioning_data.py index d2eb4af..d1165c1 100644 --- a/tests/entities/data/test_provisioning_data.py +++ b/tests/entities/data/test_provisioning_data.py @@ -86,7 +86,7 @@ def test_error_response(): def test_success_access_token(access_token_request): - payload = {"credentialsValue": "ACCESS-TOKEN-123"} + payload = {"credentialsValue": "ACCESS-TOKEN-123", "status": "SUCCESS"} response = ProvisioningResponse.build(access_token_request, payload) @@ -104,7 +104,8 @@ def test_success_mqtt_basic(mqtt_basic_request): "clientId": "my-client-id", "userName": "user1", "password": "pass123" - } + }, + "status": "SUCCESS" } response = ProvisioningResponse.build(mqtt_basic_request, payload) @@ -121,7 +122,7 @@ def test_success_mqtt_basic(mqtt_basic_request): def test_success_x509(x509_request): - payload = {"credentialsValue": None} # Should be ignored for X509 + payload = {"credentialsValue": None, "status": "SUCCESS"} # Should be ignored for X509 response = ProvisioningResponse.build(x509_request, payload) @@ -136,7 +137,7 @@ def test_success_x509(x509_request): def test_repr_output(access_token_request): - payload = {"credentialsValue": "access-token"} + payload = {"credentialsValue": "access-token", "status": "SUCCESS"} response = ProvisioningResponse.build(access_token_request, payload) r = repr(response) @@ -148,21 +149,21 @@ def test_repr_output(access_token_request): def test_missing_credentials_value_for_access_token(access_token_request): - payload = {} # Missing 'credentialsValue' + payload = {"status": "SUCCESS"} # Missing 'credentialsValue' with pytest.raises(KeyError): ProvisioningResponse.build(access_token_request, payload) def test_mqtt_basic_missing_fields(mqtt_basic_request): - payload = {"credentialsValue": {}} # All fields missing + payload = {"credentialsValue": {}, "status": "SUCCESS"} # All fields missing with pytest.raises(KeyError): ProvisioningResponse.build(mqtt_basic_request, payload) def test_access_token_none_is_accepted(access_token_request): - payload = {"credentialsValue": None} + payload = {"credentialsValue": None, "status": "SUCCESS"} response = ProvisioningResponse.build(access_token_request, payload) assert response.status == ProvisioningResponseStatus.SUCCESS From 8b18e2549fd5c45d3815721ea06359742beeecf2 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Thu, 7 Aug 2025 13:06:13 +0300 Subject: [PATCH 71/74] Added fix for checksum verification and blackbox tests for firmware update --- examples/device/firmware_update.py | 18 ++--- tb_mqtt_client/service/device/client.py | 2 +- .../service/device/firmware_updater.py | 32 ++++---- tests/blackbox/conftest.py | 43 ++++++++++- tests/blackbox/rest_helpers.py | 76 ++++++++++++++++++- ....py => test_device_client_provisioning.py} | 0 
tests/blackbox/test_device_firmware_update.py | 64 ++++++++++++++++ tests/service/device/test_firmware_updater.py | 6 -- 8 files changed, 203 insertions(+), 38 deletions(-) rename tests/blackbox/{test_client_provisioning.py => test_device_client_provisioning.py} (100%) create mode 100644 tests/blackbox/test_device_firmware_update.py diff --git a/examples/device/firmware_update.py b/examples/device/firmware_update.py index d9e7ae6..7063ef7 100644 --- a/examples/device/firmware_update.py +++ b/examples/device/firmware_update.py @@ -16,7 +16,6 @@ import asyncio import logging -from time import monotonic from tb_mqtt_client.common.config_loader import DeviceConfig from tb_mqtt_client.common.logging_utils import configure_logging, get_logger @@ -30,25 +29,24 @@ firmware_received = asyncio.Event() firmware_update_timeout = 30 +config = DeviceConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" -async def firmware_update_callback(_, payload): - logger.info(f"Firmware update payload received: {payload}") + +async def firmware_update_callback(firmware_data, firmware_info): + logger.info(f"Firmware update payload received: {firmware_info}") firmware_received.set() async def main(): - config = DeviceConfig() - config.host = "localhost" - config.access_token = "YOUR_ACCESS_TOKEN" client = DeviceClient(config) await client.connect() - await client.update_firmware(on_received_callback=firmware_update_callback) + await client.update_firmware(on_received_callback=firmware_update_callback, save_firmware=False) # Set save_firmware to True if you want to save the firmware data - update_started = monotonic() - while not firmware_received.is_set() and monotonic() - update_started < firmware_update_timeout: - await asyncio.sleep(1) + await asyncio.wait_for(firmware_received.wait(), timeout=firmware_update_timeout) await client.stop() diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index 7b05202..13068e1 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -106,7 +106,7 @@ def __init__(self, config: Optional[Union[DeviceConfig, Dict]] = None): self._firmware_updater = FirmwareUpdater(self) - async def update_firmware(self, on_received_callback: Optional[Callable[[str], Awaitable[None]]] = None, + async def update_firmware(self, on_received_callback: Optional[Callable[[bytes, dict], Awaitable[None]]] = None, save_firmware: bool = True, firmware_save_path: Optional[str] = None): await self._firmware_updater.update(on_received_callback, save_firmware, firmware_save_path) diff --git a/tb_mqtt_client/service/device/firmware_updater.py b/tb_mqtt_client/service/device/firmware_updater.py index 6dd2a0d..6ac123e 100644 --- a/tb_mqtt_client/service/device/firmware_updater.py +++ b/tb_mqtt_client/service/device/firmware_updater.py @@ -102,7 +102,7 @@ async def _get_next_chunk(self): topic = mqtt_topics.build_firmware_update_request_topic(self._firmware_request_id, self._current_chunk) mqtt_message = MqttPublishMessage(topic, payload) - await self._client._message_queue.publish(mqtt_message, wait_for_publish=True) + await self._client._message_queue.publish(mqtt_message) async def _verify_downloaded_firmware(self): self._log.info('Verifying downloaded firmware...') @@ -111,8 +111,8 @@ async def _verify_downloaded_firmware(self): await self._send_current_firmware_info() verified = self.verify_checksum(self._firmware_data, - self._target_checksum, - self._target_checksum_alg) + self._target_checksum_alg, + 
self._target_checksum) if verified: self._log.debug('Checksum verified.') @@ -201,8 +201,8 @@ async def _firmware_info_callback(self, response, *args, **kwargs): self._firmware_request_id += 1 self._target_firmware_length = fetched_firmware_info[FW_SIZE_ATTR] - self._target_checksum = fetched_firmware_info[FW_CHECKSUM_ALG_ATTR] - self._target_checksum_alg = fetched_firmware_info[FW_CHECKSUM_ATTR] + self._target_checksum = fetched_firmware_info[FW_CHECKSUM_ATTR] + self._target_checksum_alg = fetched_firmware_info[FW_CHECKSUM_ALG_ATTR] self._target_title = fetched_firmware_info[FW_TITLE_ATTR] self._target_version = fetched_firmware_info[FW_VERSION_ATTR] @@ -239,27 +239,28 @@ def verify_checksum(self, firmware_data, checksum_alg, checksum): checksum_of_received_firmware = None self._log.debug('Checksum algorithm is: %s' % checksum_alg) - if checksum_alg.lower() == "sha256": + lower_checksum_alg = checksum_alg.lower() + if lower_checksum_alg == "sha256": checksum_of_received_firmware = sha256(firmware_data).digest().hex() - elif checksum_alg.lower() == "sha384": + elif lower_checksum_alg == "sha384": checksum_of_received_firmware = sha384(firmware_data).digest().hex() - elif checksum_alg.lower() == "sha512": + elif lower_checksum_alg == "sha512": checksum_of_received_firmware = sha512(firmware_data).digest().hex() - elif checksum_alg.lower() == "md5": + elif lower_checksum_alg == "md5": checksum_of_received_firmware = md5(firmware_data).digest().hex() - elif checksum_alg.lower() == "murmur3_32": + elif lower_checksum_alg == "murmur3_32": reversed_checksum = f'{hash(firmware_data, signed=False):0>2X}' if len(reversed_checksum) % 2 != 0: reversed_checksum = '0' + reversed_checksum checksum_of_received_firmware = "".join( reversed([reversed_checksum[i:i + 2] for i in range(0, len(reversed_checksum), 2)])).lower() - elif checksum_alg.lower() == "murmur3_128": + elif lower_checksum_alg == "murmur3_128": reversed_checksum = f'{hash128(firmware_data, signed=False):0>2X}' if len(reversed_checksum) % 2 != 0: reversed_checksum = '0' + reversed_checksum checksum_of_received_firmware = "".join( reversed([reversed_checksum[i:i + 2] for i in range(0, len(reversed_checksum), 2)])).lower() - elif checksum_alg.lower() == "crc32": + elif lower_checksum_alg == "crc32": reversed_checksum = f'{crc32(firmware_data) & 0xffffffff:0>2X}' if len(reversed_checksum) % 2 != 0: reversed_checksum = '0' + reversed_checksum @@ -268,11 +269,4 @@ def verify_checksum(self, firmware_data, checksum_alg, checksum): else: self._log.error('Client error. Unsupported checksum algorithm.') - self._log.debug(checksum_of_received_firmware) - - random_value = randint(0, 5) - if random_value > 3: - self._log.debug('Dummy fail! 
Do not panic, just restart and try again the chance of this fail is ~20%') - return False - return checksum_of_received_firmware == checksum diff --git a/tests/blackbox/conftest.py b/tests/blackbox/conftest.py index 67b5de9..025a575 100644 --- a/tests/blackbox/conftest.py +++ b/tests/blackbox/conftest.py @@ -18,12 +18,14 @@ import pytest import requests import logging +import uuid from requests import HTTPError from tb_mqtt_client.common.config_loader import GatewayConfig from tb_mqtt_client.common.config_loader import DeviceConfig -from tests.blackbox.rest_helpers import find_related_entity_id, get_device_info_by_id +from tests.blackbox.rest_helpers import find_related_entity_id, get_device_info_by_id, \ + create_device_profile_and_firmware TB_HOST = os.getenv("SDK_BLACKBOX_TB_HOST", "localhost") TB_HTTP_PORT = int(os.getenv("SDK_BLACKBOX_TB_PORT", 8080)) @@ -226,3 +228,42 @@ def gateway_config(gateway_info, tb_admin_headers): requests.delete(f"{TB_URL}/api/deviceProfile/{sub_device_info['deviceProfileId']['id']}", headers=tb_admin_headers) + + +@pytest.fixture +def firmware_profile_and_package(tb_admin_headers, test_config): + firmware_name = 'pytest-firmware' + firmware_version = '1.0.0' + boundary = uuid.uuid4().hex + firmware_bytes = b'Firmware binary data for pytest' + + multipart_body = ( + f"--{boundary}\r\n" + f"Content-Disposition: form-data; name=\"file\"; filename=\"firmware.bin\"\r\n" + f"Content-Type: application/octet-stream\r\n\r\n" + ).encode() + firmware_bytes + f"\r\n--{boundary}--\r\n".encode() + + firmware_info = { + "name": firmware_name, + "version": firmware_version, + "data": firmware_bytes + } + + firmware_headers = { + "Content-Type": f"multipart/form-data; boundary={boundary}", + "Content-Length": str(len(multipart_body)) + } + firmware_headers["X-Authorization"] = tb_admin_headers["X-Authorization"] + + device_profile, firmware = create_device_profile_and_firmware( + firmware_name, firmware_version, multipart_body, test_config['tb_url'], tb_admin_headers, firmware_headers + ) + + assert device_profile is not None, "Device profile should be created successfully" + assert firmware is not None, "Firmware should be created successfully" + + yield device_profile, firmware, firmware_info + + # Cleanup + requests.delete(f"{test_config['tb_url']}/api/deviceProfile/{device_profile['id']['id']}", headers=tb_admin_headers) + requests.delete(f"{test_config['tb_url']}/api/firmware/{firmware['id']['id']}", headers=tb_admin_headers) diff --git a/tests/blackbox/rest_helpers.py b/tests/blackbox/rest_helpers.py index eba63a2..0829105 100644 --- a/tests/blackbox/rest_helpers.py +++ b/tests/blackbox/rest_helpers.py @@ -14,6 +14,8 @@ import asyncio from time import time +from typing import Tuple + import requests @@ -134,6 +136,78 @@ def create_device_profile_with_provisioning(device_profile_name: str, else: response.raise_for_status() +def create_device_profile_and_firmware(firmware_name: str, + firmware_version: str, + firmware_data: bytes, + base_url: str, + headers: dict, + firmware_data_headers: dict) -> Tuple[dict, dict]: + device_profile = get_default_device_profile(base_url, headers) + device_profile['id'] = None + device_profile['createdTime'] = None + device_profile['name'] = 'pytest-firmware-profile' + if 'profileData' not in device_profile: + device_profile['profileData'] = { + "configuration": { + "type": "DEFAULT" + }, + "transportConfiguration": { + "type": "DEFAULT" + } + } + response = requests.post(f"{base_url}/api/deviceProfile", json=device_profile, 
headers=headers) + if response.status_code == 200: + device_profile = response.json() + else: + response.raise_for_status() + + # Create OTA package + init_ota_package ={ + "id": None, + "createdTime": None, + "deviceProfileId": { + "entityType": "DEVICE_PROFILE", + "id": device_profile['id']['id'] + }, + "type": "FIRMWARE", + "title": firmware_name, + "version": firmware_version, + "tag": firmware_name + " " + firmware_version, + "url": None, + "hasData": False, + "fileName": None, + "contentType": None, + "checksumAlgorithm": None, + "checksum": None, + "dataSize": None, + "externalId": None, + "name": firmware_name, + "additionalInfo": { + "description": "" + } + } + created_ota_package = requests.post(f"{base_url}/api/otaPackage", json=init_ota_package, headers=headers) + initial_ota_package = None + if created_ota_package.status_code == 200: + initial_ota_package = created_ota_package.json() + else: + created_ota_package.raise_for_status() + # Upload firmware data + firmware_data_url = f"{base_url}/api/otaPackage/{initial_ota_package['id']['id']}?checksumAlgorithm=SHA256" + response = requests.post(firmware_data_url, data=firmware_data, headers=firmware_data_headers) + if response.status_code == 200: + return device_profile, response.json() + else: + response.raise_for_status() + +def save_device(device: dict, base_url: str, headers: dict) -> dict: + url = f"{base_url}/api/device" + response = requests.post(url, json=device, headers=headers) + if response.status_code == 200: + return response.json() + else: + response.raise_for_status() + async def send_rpc_request(device_id: str, base_url: str, headers: dict, request: dict): loop = asyncio.get_running_loop() @@ -144,4 +218,4 @@ def _send(): return r.json() result = await loop.run_in_executor(None, _send) - return result + return result \ No newline at end of file diff --git a/tests/blackbox/test_client_provisioning.py b/tests/blackbox/test_device_client_provisioning.py similarity index 100% rename from tests/blackbox/test_client_provisioning.py rename to tests/blackbox/test_device_client_provisioning.py diff --git a/tests/blackbox/test_device_firmware_update.py b/tests/blackbox/test_device_firmware_update.py new file mode 100644 index 0000000..f1c324b --- /dev/null +++ b/tests/blackbox/test_device_firmware_update.py @@ -0,0 +1,64 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +from copy import deepcopy + +import pytest + +from examples.device import firmware_update +from tests.blackbox.rest_helpers import save_device + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_device_firmware_update(firmware_profile_and_package, device_info, device_config, tb_admin_headers, test_config): + + handler_future = asyncio.Future() + + async def firmware_update_callback(firmware_data, current_firmware_info): + handler_future.set_result((firmware_data, current_firmware_info)) + + firmware_update.config = device_config + firmware_update.firmware_update_callback = firmware_update_callback + + device_profile, firmware_package, firmware_info = firmware_profile_and_package + + device = deepcopy(device_info) + + device['deviceProfileId'] = device_profile['id'] + device['firmwareId'] = firmware_package['id'] + + save_device(device, test_config['tb_url'], tb_admin_headers) + + await asyncio.sleep(1) # Ensure the device is saved before starting the firmware update + + task = asyncio.create_task(firmware_update.main()) + result = None + try: + result = await asyncio.wait_for(handler_future, timeout=30) + except Exception as e: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + assert result is not None, "Firmware update callback should be called" + assert isinstance(result, tuple), "Result should be a tuple containing firmware data and info" + received_firmware_data, received_firmware_info = result + assert received_firmware_info['current_fw_title'] == firmware_info['name'] + assert received_firmware_info['current_fw_version'] == firmware_info['version'], \ + "Received firmware version should match the package version" + assert received_firmware_info['fw_state'] == 'UPDATED' diff --git a/tests/service/device/test_firmware_updater.py b/tests/service/device/test_firmware_updater.py index d3598cb..c0eb751 100644 --- a/tests/service/device/test_firmware_updater.py +++ b/tests/service/device/test_firmware_updater.py @@ -259,9 +259,3 @@ def test_verify_checksum_known_algorithms(updater, alg): name, checksum = alg with patch("tb_mqtt_client.service.device.firmware_updater.randint", return_value=0): assert updater.verify_checksum(b"data", name, checksum) is True - - -def test_verify_checksum_random_failure(updater): - with patch("tb_mqtt_client.service.device.firmware_updater.randint", return_value=5): - result = updater.verify_checksum(b"data", "md5", md5(b"data").digest().hex()) - assert not result From 404b6a3c6c9dc43b9cf1e8681fb4252ce76719fb Mon Sep 17 00:00:00 2001 From: imbeacon Date: Fri, 8 Aug 2025 08:38:50 +0300 Subject: [PATCH 72/74] Optimized tests and added blackbox test for client side RPC --- examples/device/send_client_side_rpc.py | 23 +- tests/blackbox/conftest.py | 572 +++++++++++++----- tests/blackbox/constants.py | 81 +++ tests/blackbox/rest_helpers.py | 371 +++++++----- tests/blackbox/test_device_client_side_rpc.py | 62 ++ 5 files changed, 777 insertions(+), 332 deletions(-) create mode 100644 tests/blackbox/constants.py create mode 100644 tests/blackbox/test_device_client_side_rpc.py diff --git a/examples/device/send_client_side_rpc.py b/examples/device/send_client_side_rpc.py index 2e2d4fc..0020a25 100644 --- a/examples/device/send_client_side_rpc.py +++ b/examples/device/send_client_side_rpc.py @@ -28,15 +28,20 @@ logger.setLevel(logging.INFO) logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) +config = DeviceConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" + 
+response_received = asyncio.Event() async def rpc_response_callback(response: RPCResponse): - logger.info("Received RPC response: %r", response) + logger.info("Received RPC response in callback: %r", response) + response_received.set() +rpc_response = None async def main(): - config = DeviceConfig() - config.host = "localhost" - config.access_token = "YOUR_ACCESS_TOKEN" + global rpc_response client = DeviceClient(config) await client.connect() @@ -44,16 +49,16 @@ async def main(): # Send client-side RPC and wait for response rpc_request = await RPCRequest.build("getTime", {}) try: - response = await client.send_rpc_request(rpc_request) - logger.info("Received response: %r", response) + rpc_response = await client.send_rpc_request(rpc_request) + logger.info("Received RPC response: %r", rpc_response) except TimeoutError: logger.info("RPC request timed out") - # Send client-side RPC with callback - rpc_request_2 = await RPCRequest.build("getStatus", {}) + # Send client-side RPC with response callback + rpc_request_2 = await RPCRequest.build("getStatus", {"param1": "value1", "param2": "value2"}) await client.send_rpc_request(rpc_request_2, rpc_response_callback, wait_for_publish=False) - await asyncio.sleep(5) + await asyncio.wait_for(response_received.wait(), timeout=30) await client.stop() diff --git a/tests/blackbox/conftest.py b/tests/blackbox/conftest.py index 025a575..93b36ae 100644 --- a/tests/blackbox/conftest.py +++ b/tests/blackbox/conftest.py @@ -12,98 +12,193 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +import logging import os import subprocess import time -import pytest -import requests -import logging import uuid +from copy import deepcopy +from typing import Dict, Generator, Optional +import pytest +import requests from requests import HTTPError +from requests.adapters import HTTPAdapter +from urllib3.util import Retry + +from tb_mqtt_client.common.config_loader import DeviceConfig, GatewayConfig +from tests.blackbox.constants import RPC_RESPONSE_RULE_CHAIN +from tests.blackbox.rest_helpers import ( + find_related_entity_id, + get_device_info_by_id, + create_device_profile_and_firmware, + get_default_device_profile, +) + +# ---------------------------- +# Environment and configuration +# ---------------------------- + +ENV = os.environ.get +TB_HOST: str = ENV("SDK_BLACKBOX_TB_HOST", "localhost") +TB_HTTP_PORT: int = int(ENV("SDK_BLACKBOX_TB_PORT", "8080")) +TB_MQTT_PORT: int = int(ENV("SDK_BLACKBOX_TB_MQTT_PORT", "1883")) +TENANT_USER: str = ENV("SDK_BLACKBOX_TENANT_USER", "tenant@thingsboard.org") +TENANT_PASS: str = ENV("SDK_BLACKBOX_TENANT_PASS", "tenant") +TB_CONTAINER_NAME: str = ENV("SDK_BLACKBOX_TB_CONTAINER", "tb-ce-sdk-tests") +RUN_WITH_LOCAL_TB: bool = ENV("SDK_RUN_BLACKBOX_TESTS_LOCAL", "false").lower() == "true" +RUN_BLACKBOX: bool = ENV("SDK_RUN_BLACKBOX_TESTS", "false").lower() == "true" + +TB_HTTP_PROTOCOL: str = ENV("SDK_BLACKBOX_TB_HTTP_PROTOCOL", "http") +TB_URL: str = ENV("SDK_BLACKBOX_TB_URL", f"{TB_HTTP_PROTOCOL}://{TB_HOST}:{TB_HTTP_PORT}") + +REQUEST_TIMEOUT: float = float(ENV("SDK_BLACKBOX_HTTP_TIMEOUT", "30")) +DOCKER_START_TIMEOUT_S: int = int(ENV("SDK_BLACKBOX_TB_START_TIMEOUT", "180")) -from tb_mqtt_client.common.config_loader import GatewayConfig -from tb_mqtt_client.common.config_loader import DeviceConfig -from tests.blackbox.rest_helpers import find_related_entity_id, get_device_info_by_id, \ - create_device_profile_and_firmware +logger = 
logging.getLogger("blackbox") +logger.setLevel(logging.INFO) -TB_HOST = os.getenv("SDK_BLACKBOX_TB_HOST", "localhost") -TB_HTTP_PORT = int(os.getenv("SDK_BLACKBOX_TB_PORT", 8080)) -TB_MQTT_PORT = int(os.getenv("SDK_BLACKBOX_TB_MQTT_PORT", 1883)) -TENANT_USER = os.getenv("SDK_BLACKBOX_TENANT_USER", "tenant@thingsboard.org") -TENANT_PASS = os.getenv("SDK_BLACKBOX_TENANT_PASS", "tenant") -TB_CONTAINER_NAME = os.getenv("SDK_BLACKBOX_TB_CONTAINER", "tb-ce-sdk-tests") -RUN_WITH_LOCAL_TB = os.getenv("SDK_RUN_BLACKBOX_TESTS_LOCAL", "false").lower() == "true" -TB_HTTP_PROTOCOL = os.getenv("SDK_BLACKBOX_TB_HTTP_PROTOCOL", "http") -TB_URL = os.getenv("SDK_BLACKBOX_TB_URL", f"{TB_HTTP_PROTOCOL}://{TB_HOST}:{TB_HTTP_PORT}") +def _build_retrying_session() -> requests.Session: + """A single session with retry/backoff for better stability in tests.""" + session = requests.Session() + retries = Retry( + total=5, + connect=5, + read=5, + backoff_factor=0.5, # exponential: 0.5, 1, 2, ... + status_forcelist=(429, 500, 502, 503, 504), + allowed_methods=("GET", "POST", "PUT", "DELETE", "PATCH"), + raise_on_status=False, + ) + adapter = HTTPAdapter(max_retries=retries, pool_connections=20, pool_maxsize=50) + session.mount("http://", adapter) + session.mount("https://", adapter) + return session -logger = logging.getLogger("blackbox") -logger.setLevel(logging.INFO) + +@pytest.fixture(scope="session") +def http() -> Generator[requests.Session, None, None]: + """Shared HTTP session with retries and connection pooling.""" + session = _build_retrying_session() + try: + yield session + finally: + session.close() def pytest_collection_modifyitems(config, items): """Skip blackbox tests unless explicitly enabled.""" - if os.getenv("SDK_RUN_BLACKBOX_TESTS", "false").lower() != "true": - skip_marker = pytest.mark.skip(reason="Blackbox tests disabled. Set SDK_RUN_BLACKBOX_TESTS=true to run.") + if not RUN_BLACKBOX: + skip_marker = pytest.mark.skip( + reason=( + "Blackbox tests disabled. Set SDK_RUN_BLACKBOX_TESTS=true to run. " + "Optionally set SDK_RUN_BLACKBOX_TESTS_LOCAL=true to run against a local ThingsBoard instance." + ) + ) for item in items: if "blackbox" in str(item.fspath): item.add_marker(skip_marker) -@pytest.fixture(scope="session", autouse=True) -def start_thingsboard(): - """Start ThingsBoard CE in Docker if blackbox tests are enabled.""" - if os.getenv("SDK_RUN_BLACKBOX_TESTS", "false").lower() != "true": - return +# ---------------------------- +# ThingsBoard lifecycle helpers +# ---------------------------- - if not RUN_WITH_LOCAL_TB: +def _docker_is_running(container_name: str) -> bool: + try: + status = ( + subprocess.check_output( + ["docker", "inspect", "-f", "{{.State.Running}}", container_name], + stderr=subprocess.DEVNULL, + ) + .decode() + .strip() + ) + return status == "true" + except subprocess.CalledProcessError: + return False + + +def _wait_for_tb_ready(http: requests.Session) -> None: + """Wait for TB to accept logins; uses exponential backoff.""" + logger.info("Waiting for ThingsBoard CE to become ready...") + start = time.time() + delay = 1.0 + # Try login to ensure full readiness (DB + REST). 
+ while time.time() - start < DOCKER_START_TIMEOUT_S: try: - status = subprocess.check_output( - ["docker", "inspect", "-f", "{{.State.Running}}", TB_CONTAINER_NAME], - stderr=subprocess.DEVNULL - ).decode().strip() - if status == "true": - logger.info(f"Using existing ThingsBoard CE container: {TB_CONTAINER_NAME}") - wait_for_tb_ready() + r = http.post( + f"{TB_URL}/api/auth/login", + json={"username": TENANT_USER, "password": TENANT_PASS}, + timeout=REQUEST_TIMEOUT, + ) + if r.ok: + logger.info("ThingsBoard CE is ready.") return - except subprocess.CalledProcessError: + except requests.RequestException: pass + time.sleep(delay) + delay = min(delay * 1.5, 5.0) + pytest.fail("ThingsBoard CE did not become ready in time.") - logger.info("Pulling latest ThingsBoard CE image...") - subprocess.run(["docker", "pull", "thingsboard/tb-postgres"], check=True) - logger.info("Starting ThingsBoard CE container...") - subprocess.run([ - "docker", "run", "-d", "--rm", - "--name", TB_CONTAINER_NAME, - "-p", "8080:8080", - "-p", f"{TB_MQTT_PORT}:1883", - "thingsboard/tb-postgres" - ], check=True) +@pytest.fixture(scope="session", autouse=True) +def start_thingsboard(http: requests.Session) -> Generator[None, None, None]: + """ + Start ThingsBoard CE in Docker if blackbox tests are enabled and not running locally. + Ensures instance is ready. If we start the container, we will stop it at session end. + """ + if not RUN_BLACKBOX: + yield + return - wait_for_tb_ready() + started_by_us = False + if not RUN_WITH_LOCAL_TB: + if _docker_is_running(TB_CONTAINER_NAME): + logger.info("Using existing ThingsBoard CE container: %s", TB_CONTAINER_NAME) + else: + logger.info("Pulling ThingsBoard CE image (thingsboard/tb-postgres)...") + subprocess.run(["docker", "pull", "thingsboard/tb-postgres"], check=True) + + logger.info("Starting ThingsBoard CE container...") + subprocess.run( + [ + "docker", + "run", + "-d", + "--rm", + "--name", + TB_CONTAINER_NAME, + "-p", + f"{TB_HTTP_PORT}:8080", + "-p", + f"{TB_MQTT_PORT}:1883", + "thingsboard/tb-postgres", + ], + check=True, + ) + started_by_us = True + + _wait_for_tb_ready(http) -def wait_for_tb_ready(): - logger.info("Waiting for ThingsBoard CE to become ready...") - for _ in range(60): - try: - r = requests.post(f"{TB_URL}/api/auth/login", json={ - "username": TENANT_USER, - "password": TENANT_PASS - }) - if r.ok: - logger.info("ThingsBoard CE is ready.") - return - except requests.ConnectionError: - pass - time.sleep(2) - pytest.fail("ThingsBoard CE did not become ready in time.") + try: + yield + finally: + if started_by_us: + logger.info("Stopping ThingsBoard CE container: %s", TB_CONTAINER_NAME) + subprocess.run(["docker", "stop", TB_CONTAINER_NAME], check=False) +# ---------------------------- +# Auth and headers +# ---------------------------- + @pytest.fixture(scope="session") -def test_config(): +def test_config() -> Dict[str, object]: """Return configuration for blackbox tests.""" return { "tb_host": TB_HOST, @@ -111,159 +206,310 @@ def test_config(): "tb_mqtt_port": TB_MQTT_PORT, "tenant_user": TENANT_USER, "tenant_pass": TENANT_PASS, - "tb_url": TB_URL + "tb_url": TB_URL, + "timeout": REQUEST_TIMEOUT, } @pytest.fixture(scope="session") -def tb_admin_token(start_thingsboard): +def tb_admin_token(start_thingsboard, http: requests.Session) -> str: """Login as tenant admin and return JWT.""" - r = requests.post(f"{TB_URL}/api/auth/login", json={ - "username": TENANT_USER, - "password": TENANT_PASS - }) + r = http.post( + f"{TB_URL}/api/auth/login", + 
json={"username": TENANT_USER, "password": TENANT_PASS}, + timeout=REQUEST_TIMEOUT, + ) r.raise_for_status() - return r.json()["token"] + data = r.json() + # Some TB builds return {"token": "..."}; others include "refreshToken". We only need the token. + return data["token"] @pytest.fixture(scope="session") -def tb_admin_headers(tb_admin_token): - """Return headers for ThingsBoard REST API requests.""" +def tb_admin_headers(tb_admin_token: str) -> Dict[str, str]: + """Headers for ThingsBoard REST API requests.""" return { "X-Authorization": f"Bearer {tb_admin_token}", - "Content-Type": "application/json" + "Content-Type": "application/json", } +# ---------------------------- +# Device fixtures +# ---------------------------- + +def _create_device(http: requests.Session, headers: Dict[str, str], name: str, dev_type: str = "default") -> dict: + r = http.post( + f"{TB_URL}/api/device", + headers=headers, + json={"name": name, "type": dev_type}, + timeout=REQUEST_TIMEOUT, + ) + r.raise_for_status() + return r.json() + + +def _get_device_credentials_token(http: requests.Session, headers: Dict[str, str], device_id: str) -> str: + r = http.get(f"{TB_URL}/api/device/{device_id}/credentials", headers=headers, timeout=REQUEST_TIMEOUT) + r.raise_for_status() + return r.json()["credentialsId"] + + @pytest.fixture() -def device_info(tb_admin_headers): +def device_info(tb_admin_headers: Dict[str, str], http: requests.Session) -> Generator[dict, None, None]: + """Create a unique device for a test; always cleaned up.""" + name = f"pytest-device-{uuid.uuid4().hex[:8]}" try: - r = requests.post(f"{TB_URL}/api/device", headers=tb_admin_headers, json={ - "name": "pytest-device", - "type": "default" - }) - r.raise_for_status() - device = r.json() + device = _create_device(http, tb_admin_headers, name) except HTTPError as e: - logger.error(f"Failed to create device: {e}") - r = requests.get(f"{TB_URL}/api/tenant/devices?deviceName=pytest-device", headers=tb_admin_headers) - r.raise_for_status() - device = r.json() - if not device: - pytest.fail("Failed to create or find test device.") + logger.error("Failed to create device %s: %s", name, e) + pytest.fail("Failed to create test device.") - yield device + try: + yield device + finally: + device_id = device["id"]["id"] + http.delete(f"{TB_URL}/api/device/{device_id}", headers=tb_admin_headers, timeout=REQUEST_TIMEOUT) @pytest.fixture() -def device_config(device_info, tb_admin_headers): - +def device_config( + device_info: dict, tb_admin_headers: Dict[str, str], http: requests.Session +) -> Generator[DeviceConfig, None, None]: + """Return a ready-to-use DeviceConfig; cleans up the device after use via device_info fixture.""" device_id = device_info["id"]["id"] + token = _get_device_credentials_token(http, tb_admin_headers, device_id) - # Get credentials - r = requests.get(f"{TB_URL}/api/device/{device_id}/credentials", headers=tb_admin_headers) - r.raise_for_status() - token = r.json()["credentialsId"] + cfg = DeviceConfig() + cfg.host = TB_HOST + cfg.port = TB_MQTT_PORT + cfg.access_token = token + yield cfg - device_config = DeviceConfig() - device_config.host = TB_HOST - device_config.port = TB_MQTT_PORT - device_config.access_token = token - yield device_config - # Cleanup - requests.delete(f"{TB_URL}/api/device/{device_id}", headers=tb_admin_headers) +# ---------------------------- +# Gateway fixtures +# ---------------------------- + +def _create_gateway_device(http: requests.Session, headers: Dict[str, str], name: str) -> dict: + r = http.post( + 
f"{TB_URL}/api/device", + headers=headers, + json={"name": name, "type": "default", "additionalInfo": {"gateway": True}}, + timeout=REQUEST_TIMEOUT, + ) + r.raise_for_status() + return r.json() @pytest.fixture() -def gateway_info(tb_admin_headers): +def gateway_info(tb_admin_headers: Dict[str, str], http: requests.Session) -> Generator[dict, None, None]: + """Create a unique gateway device for a test; always cleaned up.""" + name = f"pytest-gw-device-{uuid.uuid4().hex[:8]}" try: - r = requests.post(f"{TB_URL}/api/device", headers=tb_admin_headers, json={ - "name": "pytest-gw-device", - "type": "default", - "additionalInfo": { - "gateway": True - } - }) - r.raise_for_status() - device = r.json() + gw = _create_gateway_device(http, tb_admin_headers, name) except HTTPError as e: - logger.error(f"Failed to create device: {e}") - r = requests.get(f"{TB_URL}/api/tenant/devices?deviceName=pytest-gw-device", headers=tb_admin_headers) - r.raise_for_status() - device = r.json() - if not device: - pytest.fail("Failed to create or find test device.") - - yield device + logger.error("Failed to create gateway device %s: %s", name, e) + pytest.fail("Failed to create or find test gateway device.") + try: + yield gw + finally: + gw_id = gw["id"]["id"] + # Attempt to delete gateway at teardown (sub-devices handled by gateway_config fixture if created). + http.delete(f"{TB_URL}/api/device/{gw_id}", headers=tb_admin_headers, timeout=REQUEST_TIMEOUT) -@pytest.fixture -def gateway_config(gateway_info, tb_admin_headers): +@pytest.fixture() +def gateway_config( + gateway_info: dict, tb_admin_headers: Dict[str, str], http: requests.Session +) -> Generator[GatewayConfig, None, None]: + """Return a GatewayConfig and clean up gateway + any auto-created sub-device/profile.""" device_id = gateway_info["id"]["id"] + token = _get_device_credentials_token(http, tb_admin_headers, device_id) - # Get credentials - r = requests.get(f"{TB_URL}/api/device/{device_id}/credentials", headers=tb_admin_headers) - r.raise_for_status() - token = r.json()["credentialsId"] - - gateway_config = GatewayConfig() - gateway_config.host = TB_HOST - gateway_config.port = TB_MQTT_PORT - gateway_config.access_token = token - yield gateway_config + cfg = GatewayConfig() + cfg.host = TB_HOST + cfg.port = TB_MQTT_PORT + cfg.access_token = token - subdevice_id = find_related_entity_id(device_id, TB_URL, tb_admin_headers) + # Provide config to the test + yield cfg - # Cleanup - requests.delete(f"{TB_URL}/api/device/{device_id}", headers=tb_admin_headers) - - sub_device_info = get_device_info_by_id( - subdevice_id, - TB_URL, - tb_admin_headers - ) + # Cleanup: delete gateway and any auto-created sub-device/profile if present. 
+ try: + subdevice_id: Optional[str] = None + try: + subdevice_id = find_related_entity_id(device_id, TB_URL, tb_admin_headers) + except Exception: + subdevice_id = None - requests.delete(f"{TB_URL}/api/device/{subdevice_id}", headers=tb_admin_headers) + http.delete(f"{TB_URL}/api/device/{device_id}", headers=tb_admin_headers, timeout=REQUEST_TIMEOUT) - requests.delete(f"{TB_URL}/api/deviceProfile/{sub_device_info['deviceProfileId']['id']}", - headers=tb_admin_headers) + if subdevice_id: + sub_dev = get_device_info_by_id(subdevice_id, TB_URL, tb_admin_headers) + http.delete(f"{TB_URL}/api/device/{subdevice_id}", headers=tb_admin_headers, timeout=REQUEST_TIMEOUT) + if "deviceProfileId" in sub_dev and "id" in sub_dev["deviceProfileId"]: + http.delete( + f"{TB_URL}/api/deviceProfile/{sub_dev['deviceProfileId']['id']}", + headers=tb_admin_headers, + timeout=REQUEST_TIMEOUT, + ) + except requests.RequestException as e: + logger.warning("Gateway cleanup encountered an error: %s", e) -@pytest.fixture -def firmware_profile_and_package(tb_admin_headers, test_config): - firmware_name = 'pytest-firmware' - firmware_version = '1.0.0' - boundary = uuid.uuid4().hex - firmware_bytes = b'Firmware binary data for pytest' +# ---------------------------- +# Firmware/profile fixtures +# ---------------------------- - multipart_body = ( - f"--{boundary}\r\n" - f"Content-Disposition: form-data; name=\"file\"; filename=\"firmware.bin\"\r\n" - f"Content-Type: application/octet-stream\r\n\r\n" - ).encode() + firmware_bytes + f"\r\n--{boundary}--\r\n".encode() +@pytest.fixture() +def firmware_profile_and_package( + tb_admin_headers: Dict[str, str], + test_config: Dict[str, object], + http: requests.Session, +) -> Generator[tuple[dict, dict, dict], None, None]: + """ + Creates a device profile and uploads a firmware package. + Returns (device_profile, firmware, firmware_info). 
+ """ + firmware_name = "pytest-firmware" + firmware_version = "1.0.0" + firmware_bytes = b"Firmware binary data for pytest" firmware_info = { "name": firmware_name, "version": firmware_version, - "data": firmware_bytes - } - - firmware_headers = { - "Content-Type": f"multipart/form-data; boundary={boundary}", - "Content-Length": str(len(multipart_body)) + "data": firmware_bytes, } - firmware_headers["X-Authorization"] = tb_admin_headers["X-Authorization"] device_profile, firmware = create_device_profile_and_firmware( - firmware_name, firmware_version, multipart_body, test_config['tb_url'], tb_admin_headers, firmware_headers + firmware_name=firmware_name, + firmware_version=firmware_version, + firmware_bytes=firmware_bytes, + base_url=test_config["tb_url"], # type: ignore[arg-type] + headers=tb_admin_headers, + http=http, + timeout=REQUEST_TIMEOUT, ) assert device_profile is not None, "Device profile should be created successfully" assert firmware is not None, "Firmware should be created successfully" - yield device_profile, firmware, firmware_info + try: + yield device_profile, firmware, firmware_info + finally: + # Cleanup + http.delete( + f"{test_config['tb_url']}/api/deviceProfile/{device_profile['id']['id']}", # type: ignore[index] + headers=tb_admin_headers, + timeout=REQUEST_TIMEOUT, + ) + http.delete( + f"{test_config['tb_url']}/api/firmware/{firmware['id']['id']}", # type: ignore[index] + headers=tb_admin_headers, + timeout=REQUEST_TIMEOUT, + ) + + +# ---------------------------- +# Rule chain fixtures +# ---------------------------- + +@pytest.fixture() +def rpc_rule_chain( + tb_admin_headers: Dict[str, str], test_config: Dict[str, object], http: requests.Session +) -> Generator[dict, None, None]: + """ + Creates a rule chain based on RPC_RESPONSE_RULE_CHAIN and uploads its metadata. + Cleans up the rule chain afterwards. 
+ """ + rule_chain_info = deepcopy(RPC_RESPONSE_RULE_CHAIN["ruleChain"]) + r = http.post( + f"{test_config['tb_url']}/api/ruleChain", + headers=tb_admin_headers, + json=rule_chain_info, + timeout=REQUEST_TIMEOUT, + ) + r.raise_for_status() + rule_chain = r.json() + rule_chain_id = rule_chain["id"]["id"] + + # Get current metadata, then replace nodes/connections from our template + meta_resp = http.get( + f"{test_config['tb_url']}/api/ruleChain/{rule_chain_id}/metadata", + headers=tb_admin_headers, + timeout=REQUEST_TIMEOUT, + ) + meta_resp.raise_for_status() + + received_metadata = meta_resp.json() + tpl_meta = deepcopy(RPC_RESPONSE_RULE_CHAIN["metadata"]) + + received_metadata["nodes"] = tpl_meta["nodes"] + now_ms = int(time.time() * 1000) + for node in received_metadata["nodes"]: + node["createdTime"] = now_ms + + received_metadata["connections"] = tpl_meta["connections"] + received_metadata["firstNodeIndex"] = tpl_meta["firstNodeIndex"] + received_metadata["ruleChainId"] = {"id": rule_chain_id, "entityType": "RULE_CHAIN"} + + apply_resp = http.post( + f"{test_config['tb_url']}/api/ruleChain/metadata", + headers=tb_admin_headers, + json=received_metadata, + timeout=REQUEST_TIMEOUT, + ) + apply_resp.raise_for_status() + + try: + yield rule_chain + finally: + http.delete( + f"{test_config['tb_url']}/api/ruleChain/{rule_chain_id}", + headers=tb_admin_headers, + timeout=REQUEST_TIMEOUT, + ) + + +@pytest.fixture() +def device_profile_with_rpc_rule_chain( + tb_admin_headers: Dict[str, str], + test_config: Dict[str, object], + rpc_rule_chain: dict, + http: requests.Session, +) -> Generator[dict, None, None]: + """ + Creates a device profile that uses the RPC rule chain as default; cleans up afterwards. + """ + rule_chain = rpc_rule_chain + device_profile = get_default_device_profile(test_config["tb_url"], tb_admin_headers) + device_profile["id"] = None + device_profile["createdTime"] = None + device_profile["name"] = "pytest-client-side-rpc-profile" + device_profile["defaultRuleChainId"] = rule_chain["id"] + device_profile.setdefault( + "profileData", + { + "configuration": {"type": "DEFAULT"}, + "transportConfiguration": {"type": "DEFAULT"}, + }, + ) - # Cleanup - requests.delete(f"{test_config['tb_url']}/api/deviceProfile/{device_profile['id']['id']}", headers=tb_admin_headers) - requests.delete(f"{test_config['tb_url']}/api/firmware/{firmware['id']['id']}", headers=tb_admin_headers) + r = http.post( + f"{test_config['tb_url']}/api/deviceProfile", + headers=tb_admin_headers, + json=device_profile, + timeout=REQUEST_TIMEOUT, + ) + r.raise_for_status() + created = r.json() + try: + yield created + finally: + device_profile_id = created["id"]["id"] + http.delete( + f"{test_config['tb_url']}/api/deviceProfile/{device_profile_id}", + headers=tb_admin_headers, + timeout=REQUEST_TIMEOUT, + ) diff --git a/tests/blackbox/constants.py b/tests/blackbox/constants.py new file mode 100644 index 0000000..4acdbaa --- /dev/null +++ b/tests/blackbox/constants.py @@ -0,0 +1,81 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +RPC_RESPONSE_RULE_CHAIN = { + "ruleChain": { + "name": "test-client-side-rpc", + "type": "CORE", + "firstRuleNodeId": None, + "root": False, + "debugMode": False, + "configuration": None, + "additionalInfo": { + "description": "" + } + }, + "metadata": { + "ruleChainId": { + }, + "version": 3, + "firstNodeIndex": 0, + "nodes": [ + { + "createdTime": 1754561486165, + "type": "org.thingsboard.rule.engine.filter.TbMsgTypeSwitchNode", + "name": "msg_type_switch", + "debugSettings": None, + "singletonMode": False, + "queueName": None, + "configurationVersion": 0, + "configuration": { + "version": 0 + }, + "externalId": None, + "additionalInfo": { + "description": "", + "layoutX": 277, + "layoutY": 150 + } + }, + { + "createdTime": 1754561486166, + "type": "org.thingsboard.rule.engine.rpc.TbSendRPCReplyNode", + "name": "rpc_reply", + "debugSettings": None, + "singletonMode": False, + "queueName": None, + "configurationVersion": 0, + "configuration": { + "serviceIdMetaDataAttribute": "serviceId", + "sessionIdMetaDataAttribute": "sessionId", + "requestIdMetaDataAttribute": "requestId" + }, + "externalId": None, + "additionalInfo": { + "description": "", + "layoutX": 542, + "layoutY": 153 + } + } + ], + "connections": [ + { + "fromIndex": 0, + "toIndex": 1, + "type": "RPC Request from Device" + } + ], + "ruleChainConnections": None +} +} \ No newline at end of file diff --git a/tests/blackbox/rest_helpers.py b/tests/blackbox/rest_helpers.py index 0829105..f515c5f 100644 --- a/tests/blackbox/rest_helpers.py +++ b/tests/blackbox/rest_helpers.py @@ -12,167 +12,201 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import asyncio -from time import time -from typing import Tuple +from time import time as now +from typing import Dict, List, Optional, Tuple import requests -def get_device_attributes(device_id: str, base_url: str, headers: dict, scope: str) -> list: - - url = f"{base_url}/api/plugins/telemetry/DEVICE/{device_id}/values/attributes?scope={scope}" - - response = requests.get(url, headers=headers) - - if response.status_code == 200: - return response.json() - else: - response.raise_for_status() - +def _check_ok(resp: requests.Response) -> None: + if not resp.ok: + resp.raise_for_status() -def get_device_timeseries(device_id: str, base_url: str, headers: dict, keys:list =None) -> dict: - current_ts = int(time() * 1000) - if keys is None: - keys = [] +def get_device_attributes( + device_id: str, base_url: str, headers: Dict[str, str], scope: str, *, http: Optional[requests.Session] = None, + timeout: float = 30.0 +) -> List[dict]: + """Fetch device attributes for a scope.""" + sess = http or requests + url = f"{base_url}/api/plugins/telemetry/DEVICE/{device_id}/values/attributes" + resp = sess.get(url, headers=headers, params={"scope": scope}, timeout=timeout) + _check_ok(resp) + return resp.json() - url = (f"{base_url}/api/plugins/telemetry/DEVICE/{device_id}/values/timeseries?startTs={current_ts - 60000}" - f"&endTs={current_ts}" - f"&useStrictDataTypes=true") +def get_device_timeseries( + device_id: str, + base_url: str, + headers: Dict[str, str], + keys: Optional[List[str]] = None, + *, + http: Optional[requests.Session] = None, + timeout: float = 30.0, +) -> Dict[str, list]: + """Fetch device timeseries for the last 60 seconds, optionally filtered by keys.""" + sess = http or requests + current_ts = int(now() * 1000) + params = { + "startTs": current_ts - 60_000, + "endTs": current_ts, + "useStrictDataTypes": "true", + } if keys: - url += f"&keys={','.join(keys)}" - - response = requests.get(url, headers=headers) + params["keys"] = ",".join(keys) + url = f"{base_url}/api/plugins/telemetry/DEVICE/{device_id}/values/timeseries" + resp = sess.get(url, headers=headers, params=params, timeout=timeout) + _check_ok(resp) + return resp.json() - if response.status_code == 200: - return response.json() - else: - response.raise_for_status() -def update_shared_attributes(device_id: str, base_url: str, headers: dict, attributes: dict) -> None: +def update_shared_attributes( + device_id: str, base_url: str, headers: Dict[str, str], attributes: Dict[str, object], *, + http: Optional[requests.Session] = None, timeout: float = 30.0 +) -> None: + """Update shared scope attributes for a device.""" + sess = http or requests url = f"{base_url}/api/plugins/telemetry/DEVICE/{device_id}/attributes/SHARED_SCOPE" + resp = sess.post(url, json=attributes, headers=headers, timeout=timeout) + _check_ok(resp) - response = requests.post(url, json=attributes, headers=headers) - if response.status_code != 200: - response.raise_for_status() - -def get_device_info_by_name(device_name: str, base_url: str, headers: dict) -> dict: - url = f"{base_url}/api/tenant/devices?deviceName={device_name}" - response = requests.get(url, headers=headers) - if response.status_code != 200: - response.raise_for_status() - device = response.json() +def get_device_info_by_name( + device_name: str, base_url: str, headers: Dict[str, str], *, http: Optional[requests.Session] = None, + timeout: float = 30.0 +) -> dict: + """Look up a device by its exact name.""" + sess = http or requests + url = 
f"{base_url}/api/tenant/devices" + resp = sess.get(url, headers=headers, params={"deviceName": device_name}, timeout=timeout) + _check_ok(resp) + device = resp.json() if not device: raise ValueError(f"Device with name '{device_name}' not found.") return device -def find_related_entity_id(device_id: str, base_url: str, headers: dict) -> str: - url = f"{base_url}/api/relations?fromId={device_id}&fromType=DEVICE" - response = requests.get(url, headers=headers) - if response.status_code == 200: - relations = response.json() - if relations: - return relations[0]['to']['id'] # Return the first related entity - else: - raise ValueError(f"No related entity found for device ID '{device_id}'.") - else: - response.raise_for_status() +def find_related_entity_id( + device_id: str, base_url: str, headers: Dict[str, str], *, http: Optional[requests.Session] = None, + timeout: float = 30.0 +) -> str: + """Return the first related entity ID for a device (if any).""" + sess = http or requests + url = f"{base_url}/api/relations" + resp = sess.get(url, headers=headers, params={"fromId": device_id, "fromType": "DEVICE"}, timeout=timeout) + _check_ok(resp) + relations = resp.json() + if relations: + return relations[0]["to"]["id"] + raise ValueError(f"No related entity found for device ID '{device_id}'.") + -def get_device_info_by_id(device_id: str, base_url: str, headers: dict) -> dict: +def get_device_info_by_id( + device_id: str, base_url: str, headers: Dict[str, str], *, http: Optional[requests.Session] = None, + timeout: float = 30.0 +) -> dict: + """Fetch device info by ThingsBoard device UUID.""" + sess = http or requests url = f"{base_url}/api/device/{device_id}" - response = requests.get(url, headers=headers) - if response.status_code == 200: - return response.json() - else: - response.raise_for_status() + resp = sess.get(url, headers=headers, timeout=timeout) + _check_ok(resp) + return resp.json() -def get_default_device_profile(base_url: str, headers: dict) -> dict: + +def get_default_device_profile( + base_url: str, headers: Dict[str, str], *, http: Optional[requests.Session] = None, timeout: float = 30.0 +) -> dict: + """Get the tenant default device profile (info).""" + sess = http or requests url = f"{base_url}/api/deviceProfileInfo/default" - response = requests.get(url, headers=headers) - if response.status_code == 200: - return response.json() - else: - response.raise_for_status() - -def create_device_profile_with_provisioning(device_profile_name: str, - base_url: str, - headers: dict, - provisioning_device_key: str, - provisioning_device_secret: str) -> dict: - device_profile = get_default_device_profile(base_url, headers) - url = f"{base_url}/api/deviceProfile" - device_profile['id'] = None - device_profile['name'] = device_profile_name - device_profile['createdTime'] = None - device_profile['provisionType'] = 'ALLOW_CREATE_NEW_DEVICES' - if 'profileData' not in device_profile: - device_profile['profileData'] = { - "configuration": { - "type": "DEFAULT" - }, - "transportConfiguration": { - "type": "DEFAULT" - }, - "provisionConfiguration": { - "type": "ALLOW_CREATE_NEW_DEVICES", - "provisionDeviceSecret": provisioning_device_secret - }, - "alarms": None - } - else: - device_profile['profileData']['provisionConfiguration'] = { - 'type': 'ALLOW_CREATE_NEW_DEVICES', - 'provisionDeviceSecret': provisioning_device_secret - } - device_profile['provisionDeviceKey'] = provisioning_device_key - - response = requests.post(url, json=device_profile, headers=headers) - if response.status_code == 200: - 
return response.json() - else: - response.raise_for_status() - -def create_device_profile_and_firmware(firmware_name: str, - firmware_version: str, - firmware_data: bytes, - base_url: str, - headers: dict, - firmware_data_headers: dict) -> Tuple[dict, dict]: - device_profile = get_default_device_profile(base_url, headers) - device_profile['id'] = None - device_profile['createdTime'] = None - device_profile['name'] = 'pytest-firmware-profile' - if 'profileData' not in device_profile: - device_profile['profileData'] = { - "configuration": { - "type": "DEFAULT" - }, - "transportConfiguration": { - "type": "DEFAULT" - } + resp = sess.get(url, headers=headers, timeout=timeout) + _check_ok(resp) + return resp.json() + + +def create_device_profile_with_provisioning( + device_profile_name: str, + base_url: str, + headers: Dict[str, str], + provisioning_device_key: str, + provisioning_device_secret: str, + *, + http: Optional[requests.Session] = None, + timeout: float = 30.0, +) -> dict: + """Create a device profile with ALLOW_CREATE_NEW_DEVICES provisioning.""" + sess = http or requests + device_profile = get_default_device_profile(base_url, headers, http=sess, timeout=timeout) + device_profile.update( + { + "id": None, + "name": device_profile_name, + "createdTime": None, + "provisionType": "ALLOW_CREATE_NEW_DEVICES", } - response = requests.post(f"{base_url}/api/deviceProfile", json=device_profile, headers=headers) - if response.status_code == 200: - device_profile = response.json() - else: - response.raise_for_status() - - # Create OTA package - init_ota_package ={ + ) + profile_data = device_profile.setdefault( + "profileData", + { + "configuration": {"type": "DEFAULT"}, + "transportConfiguration": {"type": "DEFAULT"}, + "alarms": None, + }, + ) + profile_data["provisionConfiguration"] = { + "type": "ALLOW_CREATE_NEW_DEVICES", + "provisionDeviceSecret": provisioning_device_secret, + } + device_profile["provisionDeviceKey"] = provisioning_device_key + + url = f"{base_url}/api/deviceProfile" + resp = sess.post(url, json=device_profile, headers=headers, timeout=timeout) + _check_ok(resp) + return resp.json() + + +def create_device_profile_and_firmware( + firmware_name: str, + firmware_version: str, + firmware_bytes: bytes, + base_url: str, + headers: Dict[str, str], + *, + http: Optional[requests.Session] = None, + timeout: float = 30.0, +) -> Tuple[dict, dict]: + """ + Create a device profile and upload a firmware OTA package. 
+ + Returns: + (device_profile, firmware_ota_package) + """ + sess = http or requests + + # Create device profile + device_profile = get_default_device_profile(base_url, headers, http=sess, timeout=timeout) + device_profile.update({"id": None, "createdTime": None, "name": "pytest-firmware-profile"}) + device_profile.setdefault("profileData", { + "configuration": {"type": "DEFAULT"}, + "transportConfiguration": {"type": "DEFAULT"}, + }) + + resp = sess.post(f"{base_url}/api/deviceProfile", json=device_profile, headers=headers, timeout=timeout) + _check_ok(resp) + device_profile = resp.json() + + # Create OTA package (metadata) + init_ota_package = { "id": None, "createdTime": None, - "deviceProfileId": { - "entityType": "DEVICE_PROFILE", - "id": device_profile['id']['id'] - }, + "deviceProfileId": {"entityType": "DEVICE_PROFILE", "id": device_profile["id"]["id"]}, "type": "FIRMWARE", "title": firmware_name, "version": firmware_version, - "tag": firmware_name + " " + firmware_version, + "tag": f"{firmware_name} {firmware_version}", "url": None, "hasData": False, "fileName": None, @@ -182,40 +216,57 @@ def create_device_profile_and_firmware(firmware_name: str, "dataSize": None, "externalId": None, "name": firmware_name, - "additionalInfo": { - "description": "" - } + "additionalInfo": {"description": ""}, } - created_ota_package = requests.post(f"{base_url}/api/otaPackage", json=init_ota_package, headers=headers) - initial_ota_package = None - if created_ota_package.status_code == 200: - initial_ota_package = created_ota_package.json() - else: - created_ota_package.raise_for_status() - # Upload firmware data - firmware_data_url = f"{base_url}/api/otaPackage/{initial_ota_package['id']['id']}?checksumAlgorithm=SHA256" - response = requests.post(firmware_data_url, data=firmware_data, headers=firmware_data_headers) - if response.status_code == 200: - return device_profile, response.json() - else: - response.raise_for_status() - -def save_device(device: dict, base_url: str, headers: dict) -> dict: + created_ota_package = sess.post(f"{base_url}/api/otaPackage", json=init_ota_package, headers=headers, + timeout=timeout) + _check_ok(created_ota_package) + initial = created_ota_package.json() + + # Upload firmware data using proper multipart form-data + upload_url = f"{base_url}/api/otaPackage/{initial['id']['id']}?checksumAlgorithm=SHA256" + files = { + "file": ("firmware.bin", firmware_bytes, "application/octet-stream"), + } + # Only auth header required; requests builds Content-Type with boundary. 
+ upload_headers = {k: v for k, v in headers.items() if k.lower() == "x-authorization"} + upload_resp = sess.post(upload_url, headers=upload_headers, files=files, timeout=timeout) + _check_ok(upload_resp) + + return device_profile, upload_resp.json() + + +def save_device( + device: dict, base_url: str, headers: Dict[str, str], *, http: Optional[requests.Session] = None, + timeout: float = 30.0 +) -> dict: + """Create or update a device via REST.""" + sess = http or requests url = f"{base_url}/api/device" - response = requests.post(url, json=device, headers=headers) - if response.status_code == 200: - return response.json() - else: - response.raise_for_status() + resp = sess.post(url, json=device, headers=headers, timeout=timeout) + _check_ok(resp) + return resp.json() -async def send_rpc_request(device_id: str, base_url: str, headers: dict, request: dict): - loop = asyncio.get_running_loop() - def _send(): +async def send_rpc_request( + device_id: str, + base_url: str, + headers: Dict[str, str], + request: dict, + *, + http: Optional[requests.Session] = None, + timeout: float = 60.0, +) -> dict: + """ + Send a two-way RPC request to a device. Performs the blocking HTTP call in a thread pool. + """ + sess = http or requests + + def _send() -> dict: url = f"{base_url}/api/rpc/twoway/{device_id}" - r = requests.post(url, json=request, headers=headers, timeout=60) + r = sess.post(url, json=request, headers=headers, timeout=timeout) r.raise_for_status() return r.json() - result = await loop.run_in_executor(None, _send) - return result \ No newline at end of file + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, _send) diff --git a/tests/blackbox/test_device_client_side_rpc.py b/tests/blackbox/test_device_client_side_rpc.py new file mode 100644 index 0000000..c853939 --- /dev/null +++ b/tests/blackbox/test_device_client_side_rpc.py @@ -0,0 +1,62 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import pytest + +from examples.device import send_client_side_rpc +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tests.blackbox.rest_helpers import save_device + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_client_side_rpc(device_profile_with_rpc_rule_chain, device_info, device_config, tb_admin_headers, test_config): + handler_future = asyncio.Future() + async def rpc_response_handler(rpc_request): + handler_future.set_result(rpc_request) + + send_client_side_rpc.config = device_config + send_client_side_rpc.rpc_response_callback = rpc_response_handler + + device_profile_id = device_profile_with_rpc_rule_chain['id'] + device_info['deviceProfileId'] = device_profile_id + device = save_device(device_info, test_config['tb_url'], tb_admin_headers) + assert device is not None, "Device should be saved successfully" + + await asyncio.sleep(1) # Ensure the device is saved before starting the RPC request + + task = asyncio.create_task(send_client_side_rpc.main()) + result = None + try: + result = await asyncio.wait_for(handler_future, timeout=30) + except Exception as e: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + first_rpc_response = send_client_side_rpc.rpc_response + second_rpc_response = result + + assert first_rpc_response is not None, "First RPC response should not be None" + assert isinstance(first_rpc_response, RPCResponse), "First RPC response should be an instance of RPCResponse" + assert first_rpc_response.request_id < second_rpc_response.request_id, "First RPC response ID should be less than second RPC response ID" + assert first_rpc_response.result['method'] == 'getTime', "First RPC response method should be 'getTime'" + + assert second_rpc_response is not None, "Second RPC response should not be None" + assert isinstance(second_rpc_response, RPCResponse), "Second RPC response should be an instance of RPCResponse" + assert second_rpc_response.result['method'] == 'getStatus', "Second RPC response method should be 'getStatus'" + assert second_rpc_response.result['params'] == {'param1': 'value1', 'param2': 'value2'}, "Second RPC response params should match expected values" From 36e498b6a4d566776041980c88611e2581301c3e Mon Sep 17 00:00:00 2001 From: imbeacon Date: Fri, 8 Aug 2025 15:13:16 +0300 Subject: [PATCH 73/74] Added additional tests --- tb_mqtt_client/common/config_loader.py | 31 +- tb_mqtt_client/common/gmqtt_patch.py | 4 +- .../rate_limit/backpressure_controller.py | 16 - tests/common/test_async_utils.py | 56 ++- tests/common/test_config_loader.py | 41 ++- tests/common/test_exceptions.py | 156 ++++++++ tests/common/test_gmqtt_patch.py | 344 ++++++++++++++++-- .../gateway/test_gateway_rpc_request.py | 154 ++++++++ .../gateway/test_gateway_rpc_response.py | 177 +++++++++ tests/service/test_mqtt_manager.py | 300 ++++++++++++++- 10 files changed, 1210 insertions(+), 69 deletions(-) create mode 100644 tests/common/test_exceptions.py create mode 100644 tests/entities/gateway/test_gateway_rpc_request.py create mode 100644 tests/entities/gateway/test_gateway_rpc_response.py diff --git a/tb_mqtt_client/common/config_loader.py b/tb_mqtt_client/common/config_loader.py index 530ba5a..4b56d35 100644 --- a/tb_mqtt_client/common/config_loader.py +++ b/tb_mqtt_client/common/config_loader.py @@ -24,6 +24,17 @@ class DeviceConfig: """ def __init__(self, config=None): + self.host = None + self.port = 1883 + self.access_token: Optional[str] = None + self.username: Optional[str] = None + self.password: 
Optional[str] = None + self.client_id: Optional[str] = None + self.ca_cert: Optional[str] = None + self.client_cert: Optional[str] = None + self.private_key: Optional[str] = None + self.qos: int = 1 + if config is not None: self.host: str = config.get("host", "localhost") self.port: int = config.get("port", 1883) @@ -35,24 +46,24 @@ def __init__(self, config=None): self.client_cert: Optional[str] = config.get("client_cert") self.private_key: Optional[str] = config.get("private_key") - self.host: str = os.getenv("TB_HOST") - self.port: int = int(os.getenv("TB_PORT", 1883)) + self.host: str = os.getenv("TB_HOST", self.host) + self.port: int = int(os.getenv("TB_PORT", self.port)) # Authentication options - self.access_token: Optional[str] = os.getenv("TB_ACCESS_TOKEN") - self.username: Optional[str] = os.getenv("TB_USERNAME") - self.password: Optional[str] = os.getenv("TB_PASSWORD") + self.access_token: Optional[str] = os.getenv("TB_ACCESS_TOKEN", self.access_token) + self.username: Optional[str] = os.getenv("TB_USERNAME", self.username) + self.password: Optional[str] = os.getenv("TB_PASSWORD", self.password) # Optional - self.client_id: Optional[str] = os.getenv("TB_CLIENT_ID") + self.client_id: Optional[str] = os.getenv("TB_CLIENT_ID", self.client_id) # TLS options - self.ca_cert: Optional[str] = os.getenv("TB_CA_CERT") - self.client_cert: Optional[str] = os.getenv("TB_CLIENT_CERT") - self.private_key: Optional[str] = os.getenv("TB_PRIVATE_KEY") + self.ca_cert: Optional[str] = os.getenv("TB_CA_CERT", self.ca_cert) + self.client_cert: Optional[str] = os.getenv("TB_CLIENT_CERT", self.client_cert) + self.private_key: Optional[str] = os.getenv("TB_PRIVATE_KEY", self.private_key) # Default values - self.qos: int = int(os.getenv("TB_QOS", 1)) + self.qos: int = int(os.getenv("TB_QOS", self.qos)) def use_tls_auth(self) -> bool: return all([self.ca_cert, self.client_cert, self.private_key]) diff --git a/tb_mqtt_client/common/gmqtt_patch.py b/tb_mqtt_client/common/gmqtt_patch.py index 539df5a..735afd0 100644 --- a/tb_mqtt_client/common/gmqtt_patch.py +++ b/tb_mqtt_client/common/gmqtt_patch.py @@ -122,8 +122,8 @@ def parse_mqtt_properties(packet: bytes) -> dict: properties_dict = defaultdict(list) try: - properties_len, _ = unpack_variable_byte_integer(packet) - props = packet[:properties_len] + properties_len, rest = unpack_variable_byte_integer(packet) + props = rest[:properties_len] # slice out exactly the properties section while props: property_identifier = props[0] diff --git a/tb_mqtt_client/common/rate_limit/backpressure_controller.py b/tb_mqtt_client/common/rate_limit/backpressure_controller.py index 76450cd..6dd46a1 100644 --- a/tb_mqtt_client/common/rate_limit/backpressure_controller.py +++ b/tb_mqtt_client/common/rate_limit/backpressure_controller.py @@ -30,7 +30,6 @@ def __init__(self, main_stop_event: Event): self._consecutive_quota_exceeded = 0 self._last_quota_exceeded = datetime.now(UTC) self._max_backoff_seconds = 3600 # 1 hour - self._can_process_messages_events: List[asyncio.Event] = [] logger.debug("BackpressureController initialized with default pause duration of %s seconds", self._default_pause_duration.total_seconds()) @@ -90,10 +89,6 @@ def should_pause(self) -> bool: # Reset pause state self._pause_until = None logger.info("Backpressure released, resuming publishing") - for event in self._can_process_messages_events: - if not event.is_set(): - event.set() - logger.debug("Set can-process event %s", event) return False def clear(self): @@ -101,14 +96,3 @@ def clear(self): 
logger.info("Clearing backpressure pause") self._pause_until = None self._consecutive_quota_exceeded = 0 - - def register_can_process_event(self, event: Event): - """ - Register an event that will be set when the controller can process messages again. - This is useful for other components to wait until the backpressure is lifted. - """ - if not isinstance(event, Event): - raise ValueError("Expected an asyncio.Event instance") - self._can_process_messages_events.append(event) - logger.debug("Registered a new can-process event, total events: %d", - len(self._can_process_messages_events)) diff --git a/tests/common/test_async_utils.py b/tests/common/test_async_utils.py index 049351c..ff88b6d 100644 --- a/tests/common/test_async_utils.py +++ b/tests/common/test_async_utils.py @@ -13,13 +13,36 @@ # limitations under the License. import asyncio +import threading +import time import pytest -from tb_mqtt_client.common.async_utils import FutureMap, await_or_stop +import tb_mqtt_client.common.async_utils as async_utils_mod +from tb_mqtt_client.common.async_utils import FutureMap, await_or_stop, run_coroutine_sync from tb_mqtt_client.common.publish_result import PublishResult +@pytest.fixture +def fake_loop(monkeypatch): + class _FakeTask: + def __init__(self, thread: threading.Thread): + self._thread = thread + + def done(self): # optional helpers if ever needed + return not self._thread.is_alive() + + class _FakeLoop: + def create_task(self, coro): + t = threading.Thread(target=lambda: asyncio.run(coro), daemon=True) + t.start() + return _FakeTask(t) + + monkeypatch.setattr(async_utils_mod.asyncio, "get_running_loop", lambda: _FakeLoop()) + + return _FakeLoop() + + @pytest.mark.asyncio async def test_future_map_register_and_get_parents(): fm = FutureMap() @@ -150,5 +173,36 @@ async def coro(): assert result is None +def test_returns_result_when_coroutine_completes(fake_loop): + async def ok_coro(): + await asyncio.sleep(0.01) + return "done" + + result = run_coroutine_sync(lambda: ok_coro(), timeout=0.5) + assert result == "done" + + +def test_raises_original_exception_from_coroutine(fake_loop): + class CustomError(RuntimeError): + pass + + async def bad_coro(): + await asyncio.sleep(0.01) + raise CustomError("boom") + + with pytest.raises(CustomError, match="boom"): + run_coroutine_sync(lambda: bad_coro(), timeout=0.5) + + +def test_timeout_raises_timeout_error(fake_loop): + async def slow_coro(): + await asyncio.sleep(0.2) + return "too late" + + with pytest.raises(TimeoutError, match=r"did not complete in 0\.05 seconds"): + run_coroutine_sync(lambda: slow_coro(), timeout=0.05, raise_on_timeout=True) + time.sleep(0.25) + + if __name__ == '__main__': pytest.main([__file__, "--tb=short", "-v"]) diff --git a/tests/common/test_config_loader.py b/tests/common/test_config_loader.py index fe482d3..9963e15 100644 --- a/tests/common/test_config_loader.py +++ b/tests/common/test_config_loader.py @@ -20,7 +20,32 @@ class TestDeviceConfig(unittest.TestCase): - def loads_default_values_when_env_vars_missing(self): + def test_config_creation_from_dict(self): + config_dict = { + "host": "localhost", + "port": 1883, + "access_token": "test_token", + "username": "test_user", + "password": "test_pass", + "client_id": "test_client", + "ca_cert": "test_ca", + "client_cert": "test_cert", + "private_key": "test_key", + "qos": 1 + } + config = DeviceConfig(config_dict) + self.assertEqual(config.host, "localhost") + self.assertEqual(config.port, 1883) + self.assertEqual(config.access_token, "test_token") + 
self.assertEqual(config.username, "test_user") + self.assertEqual(config.password, "test_pass") + self.assertEqual(config.client_id, "test_client") + self.assertEqual(config.ca_cert, "test_ca") + self.assertEqual(config.client_cert, "test_cert") + self.assertEqual(config.private_key, "test_key") + self.assertEqual(config.qos, 1) + + def test_loads_default_values_when_env_vars_missing(self): os.environ.clear() config = DeviceConfig() self.assertEqual(config.host, None) @@ -34,7 +59,7 @@ def loads_default_values_when_env_vars_missing(self): self.assertEqual(config.private_key, None) self.assertEqual(config.qos, 1) - def loads_values_from_env_vars(self): + def test_loads_values_from_env_vars(self): os.environ["TB_HOST"] = "test_host" os.environ["TB_PORT"] = "8883" os.environ["TB_ACCESS_TOKEN"] = "test_token" @@ -58,14 +83,14 @@ def loads_values_from_env_vars(self): self.assertEqual(config.private_key, "test_key") self.assertEqual(config.qos, 2) - def detects_tls_auth_correctly(self): + def test_detects_tls_auth_correctly(self): os.environ["TB_CA_CERT"] = "test_ca" os.environ["TB_CLIENT_CERT"] = "test_cert" os.environ["TB_PRIVATE_KEY"] = "test_key" config = DeviceConfig() self.assertTrue(config.use_tls_auth()) - def detects_tls_correctly(self): + def test_detects_tls_correctly(self): os.environ["TB_CA_CERT"] = "test_ca" config = DeviceConfig() self.assertTrue(config.use_tls()) @@ -73,7 +98,7 @@ def detects_tls_correctly(self): class TestGatewayConfig(unittest.TestCase): - def loads_gateway_specific_env_vars(self): + def test_loads_gateway_specific_env_vars(self): os.environ["TB_GW_HOST"] = "gw_host" os.environ["TB_GW_PORT"] = "8884" os.environ["TB_GW_ACCESS_TOKEN"] = "gw_token" @@ -97,10 +122,14 @@ def loads_gateway_specific_env_vars(self): self.assertEqual(config.private_key, "gw_key") self.assertEqual(config.qos, 0) - def falls_back_to_device_config_when_gateway_env_vars_missing(self): + def test_falls_back_to_device_config_when_gateway_env_vars_missing(self): os.environ.clear() os.environ["TB_HOST"] = "device_host" os.environ["TB_PORT"] = "1884" config = GatewayConfig() self.assertEqual(config.host, "device_host") self.assertEqual(config.port, 1884) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/common/test_exceptions.py b/tests/common/test_exceptions.py new file mode 100644 index 0000000..03368a9 --- /dev/null +++ b/tests/common/test_exceptions.py @@ -0,0 +1,156 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
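The config_loader.py change above makes the precedence order explicit: hard-coded defaults are set first, a config dict (when one is passed) overrides them, and environment variables override both while falling back to whatever is already set. A minimal standalone sketch of that layering — the ConfigSketch class is illustrative only, while TB_HOST/TB_PORT are the real variable names DeviceConfig reads:

import os
from typing import Optional


class ConfigSketch:
    """Illustrative only: defaults -> config dict -> environment, mirroring DeviceConfig."""

    def __init__(self, config: Optional[dict] = None):
        # 1. Hard-coded defaults.
        self.host: Optional[str] = None
        self.port: int = 1883

        # 2. Values from an explicit config dict, when one is provided.
        if config is not None:
            self.host = config.get("host", "localhost")
            self.port = config.get("port", 1883)

        # 3. Environment variables win, but fall back to whatever is already set.
        self.host = os.getenv("TB_HOST", self.host)
        self.port = int(os.getenv("TB_PORT", self.port))


# With TB_HOST unset the dict value survives; with TB_HOST set the env value wins.
cfg = ConfigSketch({"host": "broker.local", "port": 8883})
assert cfg.host == os.getenv("TB_HOST", "broker.local")

This ordering is what the updated tests rely on: env-var tests still see their values applied, while dict-driven construction no longer gets wiped by missing environment variables.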
+ +import asyncio +from typing import List, Tuple, Optional + +import pytest + +from tb_mqtt_client.common.exceptions import ExceptionHandler, BackpressureException + + +def make_recorder(storage: List[Tuple[str, Optional[dict]]]): + def _recorder(exc: BaseException, context: Optional[dict]): + storage.append((exc.__class__.__name__, context)) + + return _recorder + + +def test_register_and_handle_specific_exception_calls_callback(): + handler = ExceptionHandler() + events: List[Tuple[str, Optional[dict]]] = [] + cb = make_recorder(events) + + handler.register(ValueError, cb) + + err = ValueError("boom") + ctx = {"info": "ctx"} + handler.handle(err, ctx) + + assert events == [("ValueError", ctx)] + + +def test_default_callback_called_when_no_specific_registered(): + handler = ExceptionHandler() + events: List[Tuple[str, Optional[dict]]] = [] + default_cb = make_recorder(events) + + handler.register(None, default_cb) + + err = KeyError("missing") + ctx = {"k": "v"} + handler.handle(err, ctx) + + assert events == [("KeyError", ctx)] + + +def test_specific_takes_precedence_default_not_called(): + handler = ExceptionHandler() + events: List[Tuple[str, Optional[dict]]] = [] + specific_events: List[Tuple[str, Optional[dict]]] = [] + + default_cb = make_recorder(events) + specific_cb = make_recorder(specific_events) + + handler.register(None, default_cb) + handler.register(RuntimeError, specific_cb) + + err = RuntimeError("nope") + ctx = {"tag": 1} + handler.handle(err, ctx) + + assert specific_events == [("RuntimeError", ctx)] + assert events == [] + + +def test_multiple_callbacks_for_same_exception_type_all_called(): + handler = ExceptionHandler() + calls: List[Tuple[str, Optional[dict]]] = [] + + cb1 = make_recorder(calls) + cb2 = make_recorder(calls) + handler.register(IndexError, cb1) + handler.register(IndexError, cb2) + + err = IndexError("oops") + ctx = {"n": 42} + handler.handle(err, ctx) + + assert calls == [("IndexError", ctx), ("IndexError", ctx)] + + +def test_subclass_matching_triggers_all_matching_callbacks_and_skips_default(): + handler = ExceptionHandler() + + calls_exception: List[Tuple[str, Optional[dict]]] = [] + calls_valueerror: List[Tuple[str, Optional[dict]]] = [] + default_calls: List[Tuple[str, Optional[dict]]] = [] + + cb_exception = make_recorder(calls_exception) + cb_valueerror = make_recorder(calls_valueerror) + cb_default = make_recorder(default_calls) + + handler.register(Exception, cb_exception) + handler.register(ValueError, cb_valueerror) + handler.register(None, cb_default) + + err = ValueError("bad value") + ctx = {"a": 1} + handler.handle(err, ctx) + + assert len(calls_exception) == 1 + assert len(calls_valueerror) == 1 + assert calls_exception[0][0] == "ValueError" + assert calls_valueerror[0][0] == "ValueError" + + assert default_calls == [] + + +@pytest.mark.asyncio +async def test_install_asyncio_handler_dispatches_on_loop_exception(event_loop: asyncio.AbstractEventLoop): + """ + Instead of relying on an un-awaited task (flaky under uvloop), + directly call the loop's exception handler with a context containing an exception. + This validates that install_asyncio_handler wires _asyncio_handler correctly, + and that it dispatches to ExceptionHandler.handle(...). 
+ """ + handler = ExceptionHandler() + + calls: List[Tuple[str, Optional[dict]]] = [] + cb = make_recorder(calls) + handler.register(ValueError, cb) + + handler.install_asyncio_handler(event_loop) + + err = ValueError("async oops") + context = {"exception": err, "message": "unit-test"} + event_loop.call_exception_handler(context) + + await asyncio.sleep(0) + + assert calls == [("ValueError", context)] + + +def test_backpressure_exception_default_message(): + exc = BackpressureException() + assert str(exc) == "Client is under backpressure. Please retry later." + + +def test_backpressure_exception_custom_message(): + exc = BackpressureException("Hold up!") + assert str(exc) == "Hold up!" + + +if __name__ == '__main__': + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/common/test_gmqtt_patch.py b/tests/common/test_gmqtt_patch.py index 64f8474..172860b 100644 --- a/tests/common/test_gmqtt_patch.py +++ b/tests/common/test_gmqtt_patch.py @@ -12,24 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. - import asyncio import heapq import struct import types import pytest +from gmqtt.mqtt.constants import MQTTv50, MQTTv311, MQTTCommands +from gmqtt.mqtt.handler import MqttPackageHandler +from gmqtt.mqtt.protocol import MQTTProtocol from tb_mqtt_client.common.gmqtt_patch import PatchUtils, PublishPacket from tb_mqtt_client.common.mqtt_message import MqttPublishMessage def test_parse_mqtt_properties_valid_and_invalid(): - # Unknown property id triggers warning branch pkt = bytes([1]) + bytes([255]) assert PatchUtils.parse_mqtt_properties(pkt) == {} - # Exception path (invalid variant) assert PatchUtils.parse_mqtt_properties(b"\xff") == {} @@ -53,27 +53,28 @@ def on_puback(mid, reason, props): monkeypatch.setattr("gmqtt.mqtt.handler.MqttPackageHandler._handle_puback_packet", lambda *a, **k: None) pu.patch_puback_handling(on_puback) - # Call wrapped handler handler = types.SimpleNamespace( - _connection=types.SimpleNamespace(persistent_storage=types.SimpleNamespace(remove=lambda m: None)) + _connection=types.SimpleNamespace( + persistent_storage=types.SimpleNamespace(remove=lambda m: None) + ) + ) + packet = struct.pack("!HB", 0xABCD, 0x10) + MqttPackageHandler._handle_puback_packet(handler, MQTTCommands.PUBACK, packet) + assert called["hit"][0] == 0xABCD + assert called["hit"][1] == 0x10 + assert isinstance(called["hit"][2], dict) + + client = types.SimpleNamespace( + _persistent_storage=types.SimpleNamespace( + _queue=[(0.0, 1, "raw")], + _check_empty=lambda: None + ) ) - MqttHandlerClass = type("H", (), {}) - pu_handler = MqttHandlerClass() - pu_handler._connection = handler._connection - pu_handler._handle_puback_packet = lambda *a, **k: None - pu_handler = types.SimpleNamespace(**{"_connection": handler._connection}) - # Ensure parsing works - PatchUtils.parse_mqtt_properties(b"\x00") - called.clear() - - # patch_storage test - client = types.SimpleNamespace(_persistent_storage=types.SimpleNamespace( - _queue=[(0, 1, "raw")], - _check_empty=lambda: None - )) pu.client = client pu.patch_storage() - assert asyncio.get_event_loop().run_until_complete(client._persistent_storage.pop_message()) + + tm, mid, raw = asyncio.get_event_loop().run_until_complete(client._persistent_storage.pop_message()) + assert mid == 1 and raw == "raw" @pytest.mark.asyncio @@ -81,7 +82,8 @@ async def test_retry_loop_and_task_controls(monkeypatch): storage_queue = [] class Storage: - def _check_empty(self): pass + def _check_empty(self): + pass async 
def pop_message(self): if not storage_queue: @@ -99,20 +101,19 @@ async def put_retry_message(self, msg): _persistent_storage = Storage() pu = PatchUtils(FakeClient(), asyncio.Event(), retry_interval=0) - heapq.heappush(storage_queue, (0, 1, types.SimpleNamespace(topic="t", dup=False))) - pu._stop_event.set() # immediate exit + heapq.heappush(storage_queue, (0.0, 1, types.SimpleNamespace(topic="t", dup=False))) + pu._stop_event.set() await pu._retry_loop() - - # start_retry_task + stop_retry_task normal + assert not msgs_sent pu._stop_event.clear() pu.start_retry_task() - assert pu._retry_task + assert pu._retry_task is not None await pu.stop_retry_task() assert pu._retry_task is None - # Timeout branch in stop_retry_task pu._retry_task = asyncio.create_task(asyncio.sleep(1)) await pu.stop_retry_task() + assert pu._retry_task is None def test_apply_calls_patch_and_starts_task(monkeypatch): @@ -130,15 +131,12 @@ def test_build_package_qos1_with_provided_mid(): protocol = types.SimpleNamespace(proto_ver=5) mid, packet = PublishPacket.build_package(msg, protocol, mid=77) assert mid == 77 - # Verify mid is encoded at correct spot (after topic length and topic) assert struct.pack("!H", 77) in packet - # DUP flag set assert packet[0] & 0x08 def test_build_package_qos1_with_generated_mid(monkeypatch): msg = MqttPublishMessage(topic="gen", payload=b"PAY", qos=1) - # Force known id from id_generator monkeypatch.setattr(PublishPacket, "id_generator", types.SimpleNamespace(next_id=lambda: 1234)) protocol = types.SimpleNamespace(proto_ver=5) mid, packet = PublishPacket.build_package(msg, protocol) @@ -146,5 +144,291 @@ def test_build_package_qos1_with_generated_mid(monkeypatch): assert struct.pack("!H", 1234) in packet +# ------------------------------ +# New tests (extended coverage) +# ------------------------------ + +def test_build_package_qos0_no_mid_empty_payload(): + msg = MqttPublishMessage(topic="t/0", payload=b"", qos=0, retain=False) + protocol = types.SimpleNamespace(proto_ver=5) + mid, packet = PublishPacket.build_package(msg, protocol) + assert mid is None + assert b"t/0" in packet + assert (packet[0] >> 1) & 0b11 == 0 + + +@pytest.mark.asyncio +async def test_patch_mqtt_handler_disconnect_invokes_on_disconnect_and_reconnect(monkeypatch): + assert PatchUtils.patch_mqtt_handler_disconnect(None) is True + + reconnect_called = asyncio.Event() + on_disconnect_args = {} + + async def fake_reconnect(delay=True): + reconnect_called.set() + + def fake_on_disconnect(self_like, reason_code, properties, exc): + on_disconnect_args["rc"] = reason_code + on_disconnect_args["props"] = properties + on_disconnect_args["exc"] = exc + + fake_connection = types.SimpleNamespace(_on_disconnect_called=False) + self_like = types.SimpleNamespace( + _clear_topics_aliases=lambda: None, + reconnect=fake_reconnect, + _handle_exception_in_future=lambda f: None, + on_disconnect=fake_on_disconnect, + _connection=fake_connection, + ) + + packet = bytes([151]) + bytes([0]) + + MqttPackageHandler._handle_disconnect_packet(self_like, MQTTCommands.DISCONNECT, packet) + + await asyncio.wait_for(reconnect_called.wait(), timeout=1.0) + assert on_disconnect_args["rc"] == 151 + assert isinstance(on_disconnect_args["props"], dict) + assert on_disconnect_args["exc"] is None + assert self_like._connection._on_disconnect_called is True + + +@pytest.mark.asyncio +async def test_patch_handle_connack_success_and_error_paths(monkeypatch): + assert PatchUtils.patch_handle_connack(None) is True + + connected_set = {"hit": False} + 
on_connect_called = {} + + def _connected_set(): + connected_set["hit"] = True + + success_self = types.SimpleNamespace( + _connected=types.SimpleNamespace(set=_connected_set), + _logger=types.SimpleNamespace(warning=lambda *a, **k: None, + info=lambda *a, **k: None, + debug=lambda *a, **k: None), + failed_connections=0, + protocol_version=MQTTv50, + _parse_properties=lambda payload: ({'x': 'y'}, b""), + _update_keepalive_if_needed=lambda: None, + properties={}, + on_connect=lambda *args: on_connect_called.setdefault("ok", args), + disconnect=lambda: asyncio.Future(), + reconnect=lambda delay=True: asyncio.Future(), + _handle_exception_in_future=lambda f: None, + ) + + packet_ok = struct.pack("!BB", 1, 0) + b"\x00" + MqttPackageHandler._handle_connack_packet(success_self, MQTTCommands.CONNACK, packet_ok) + assert connected_set["hit"] is True + assert "ok" in on_connect_called + + downgraded = {} + + fut = asyncio.get_event_loop().create_future() + fut.set_result(None) + + def fake_reconnect(delay=True): + downgraded["called"] = True + f = asyncio.get_event_loop().create_future() + f.set_result(None) + return f + + error_self = types.SimpleNamespace( + _connected=types.SimpleNamespace(set=lambda: None), + _logger=types.SimpleNamespace(warning=lambda *a, **k: None, + info=lambda *a, **k: None, + debug=lambda *a, **k: None), + failed_connections=0, + protocol_version=MQTTv50, + _parse_properties=lambda payload: ({}, b""), + _update_keepalive_if_needed=lambda: None, + properties={}, + on_connect=lambda *a, **k: None, + reconnect=fake_reconnect, + _handle_exception_in_future=lambda f: None, + ) + + packet_err = struct.pack("!BB", 0, 1) + MqttPackageHandler._handle_connack_packet(error_self, MQTTCommands.CONNACK, packet_err) + assert downgraded.get("called") is True + assert MQTTProtocol.proto_ver == MQTTv311 + + +@pytest.mark.asyncio +async def test_patch_gmqtt_protocol_connection_lost_and_handler_call(monkeypatch): + assert PatchUtils.patch_gmqtt_protocol_connection_lost(None) is True + + disconnect_pkgs = [] + + def put_package(pkg): + disconnect_pkgs.append(pkg) + + loop = asyncio.get_event_loop() + read_future = loop.create_future() + + class FakeMQTTProtocol(MQTTProtocol): + def __init__(self): + pass + + proto_self = FakeMQTTProtocol() + proto_self._connected = types.SimpleNamespace(clear=lambda: None) + proto_self._connection = types.SimpleNamespace( + _disconnect_exc=None, + _disconnect_properties=None, + put_package=put_package + ) + proto_self._read_loop_future = read_future + proto_self._queue = None + + proto_self._stream_reader_wr = lambda: types.SimpleNamespace( + feed_eof=lambda: None, + set_exception=lambda exc: None + ) + proto_self._closed = loop.create_future() + proto_self._paused = False + + proto_self.connection_lost(TimeoutError("simulated")) + + assert disconnect_pkgs and disconnect_pkgs[0][0] == MQTTCommands.DISCONNECT + payload = disconnect_pkgs[0][1] + assert isinstance(payload, (bytes, bytearray)) and len(payload) == 1 + + reconnect_flag = asyncio.Event() + on_disc_args = {} + + async def fake_reconnect(delay=True): + reconnect_flag.set() + + handler_self = types.SimpleNamespace( + _connection=proto_self._connection, + _clear_topics_aliases=lambda: None, + reconnect=fake_reconnect, + _handle_exception_in_future=lambda f: None, + on_disconnect=lambda *args: on_disc_args.setdefault("args", args), + ) + + MqttPackageHandler.__call__(handler_self, MQTTCommands.DISCONNECT, b"\x01") + await asyncio.wait_for(reconnect_flag.wait(), timeout=1.0) + assert "args" in 
on_disc_args + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "exc, expected_reason", + [ + (ConnectionRefusedError("boom"), 135), # Keep Alive timeout + (TimeoutError("boom"), 135), # Keep Alive timeout + (ConnectionResetError("boom"), 139), # Receive Maximum exceeded + (ConnectionAbortedError("boom"), 136), # Session taken over + (PermissionError("boom"), 132), # Not authorized + (OSError("boom"), 130), # Protocol Error + (ValueError("boom"), 131), # Implementation specific error (fallback) + ], +) +async def test_connection_lost_reason_code_mapping(monkeypatch, exc, expected_reason): + assert PatchUtils.patch_gmqtt_protocol_connection_lost(None) is True + + def build_proto(): + disconnect_pkgs = [] + + def put_package(pkg): + disconnect_pkgs.append(pkg) + + loop = asyncio.get_event_loop() + read_future = loop.create_future() + + class FakeMQTTProtocol(MQTTProtocol): + def __init__(self): + pass + + proto = FakeMQTTProtocol() + proto._connected = types.SimpleNamespace(clear=lambda: None) + proto._connection = types.SimpleNamespace( + _disconnect_exc=None, + _disconnect_properties=None, + put_package=put_package + ) + proto._read_loop_future = read_future + proto._queue = None + proto._stream_reader_wr = lambda: types.SimpleNamespace( + feed_eof=lambda: None, + set_exception=lambda e: None + ) + proto._closed = loop.create_future() + proto._paused = False + return proto, disconnect_pkgs + + proto_self, disconnect_pkgs = build_proto() + + proto_self.connection_lost(exc) + + assert disconnect_pkgs and disconnect_pkgs[0][0] == MQTTCommands.DISCONNECT + payload = disconnect_pkgs[0][1] + assert isinstance(payload, (bytes, bytearray)) and len(payload) == 1 + assert payload[0] == expected_reason + + props = getattr(proto_self._connection, "_disconnect_properties", {}) + assert isinstance(props, dict) + assert "reason_string" in props + assert isinstance(props["reason_string"], list) and props["reason_string"][0] == str(exc.args[0]) + + +def test_patch_puback_handling_with_properties(monkeypatch): + pu = PatchUtils(None, asyncio.Event()) + got = {} + + def on_puback(mid, reason, props): + got["mid"] = mid + got["reason"] = reason + got["props"] = props + + monkeypatch.setattr("gmqtt.mqtt.handler.MqttPackageHandler._handle_puback_packet", lambda *a, **k: None) + pu.patch_puback_handling(on_puback) + + from gmqtt.mqtt.property import Property + from gmqtt.mqtt.utils import pack_variable_byte_integer + + reason_prop = Property.factory(name="reason_string").dumps("OK") + props_len_prefix = pack_variable_byte_integer(len(reason_prop)) + props_payload = bytes(props_len_prefix) + bytes(reason_prop) + + packet = struct.pack("!H", 0x1234) + bytes([0x00]) + props_payload + + handler_self = types.SimpleNamespace( + _connection=types.SimpleNamespace( + persistent_storage=types.SimpleNamespace(remove=lambda m: None) + ) + ) + + # Act + MqttPackageHandler._handle_puback_packet(handler_self, MQTTCommands.PUBACK, packet) + + assert got["mid"] == 0x1234 + assert got["reason"] == 0x00 + assert isinstance(got["props"], dict) + assert got["props"].get("reason_string") == ["OK"] + + +def test_parse_properties_reason_string_and_user_property_real(): + from gmqtt.mqtt.property import Property + from gmqtt.mqtt.utils import pack_variable_byte_integer + + rs = Property.factory(name="reason_string").dumps("OK") + up = Property.factory(name="user_property").dumps(("k", "v")) + props = bytes(rs + up) + + packet = bytes(pack_variable_byte_integer(len(props))) + props + parsed = PatchUtils.parse_mqtt_properties(packet) 
+ + assert parsed["reason_string"] == ["OK"] + assert parsed["user_property"] == [("k", "v")] + + +def test_parse_properties_unknown_property_returns_empty(): + packet = b"\x01" + b"\xff" + assert PatchUtils.parse_mqtt_properties(packet) == {} + + if __name__ == '__main__': pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/entities/gateway/test_gateway_rpc_request.py b/tests/entities/gateway/test_gateway_rpc_request.py new file mode 100644 index 0000000..5589b4d --- /dev/null +++ b/tests/entities/gateway/test_gateway_rpc_request.py @@ -0,0 +1,154 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest + + +def test_direct_instantiation_not_allowed(): + """ + Ensures dataclass can't be instantiated directly and guides to deserializer. + """ + with pytest.raises(TypeError) as exc: + GatewayRPCRequest() # type: ignore[call-arg] + assert "Direct instantiation" in str(exc.value) + + +def test_deserialize_with_int_id_and_params(): + """ + Valid deserialization: numeric request_id and explicit params. + """ + payload = { + "device": "pump-1", + "data": { + "id": 42, + "method": "reboot", + "params": {"delay": 5} + } + } + req = GatewayRPCRequest._deserialize_from_dict(payload) + + assert isinstance(req, GatewayRPCRequest) + assert req.request_id == 42 + assert req.device_name == "pump-1" + assert req.method == "reboot" + assert req.params == {"delay": 5} + assert req.event_type == GatewayEventType.DEVICE_RPC_REQUEST + + # __repr__ and __str__ should be identical and informative + expected_repr = "RPCRequest(id=42, device_name=pump-1, method=reboot, params={'delay': 5})" + assert repr(req) == expected_repr + assert str(req) == expected_repr + + +def test_deserialize_with_str_id_and_no_params(): + """ + Valid deserialization: string request_id and missing params (defaults to None). + """ + payload = { + "device": "sensor-A", + "data": { + "id": "rpc-001", + "method": "ping" + # no "params" + } + } + req = GatewayRPCRequest._deserialize_from_dict(payload) + + assert req.request_id == "rpc-001" + assert req.device_name == "sensor-A" + assert req.method == "ping" + assert req.params is None + assert req.event_type == GatewayEventType.DEVICE_RPC_REQUEST + + expected_repr = "RPCRequest(id=rpc-001, device_name=sensor-A, method=ping, params=None)" + assert repr(req) == expected_repr + assert str(req) == expected_repr + + +def test_missing_device_raises_value_error(): + """ + Missing top-level 'device' should raise a clear ValueError. + """ + payload = { + # "device": "x", + "data": {"id": 1, "method": "ping"} + } + with pytest.raises(ValueError) as exc: + GatewayRPCRequest._deserialize_from_dict(payload) + assert "Missing device name" in str(exc.value) + + +def test_missing_data_key_raises_key_error(): + """ + The implementation accesses data['data'] directly; if absent, a KeyError is expected. 
+ """ + payload = { + "device": "dev-x" + # no "data" + } + with pytest.raises(KeyError) as exc: + GatewayRPCRequest._deserialize_from_dict(payload) + assert "'data'" in str(exc.value) # KeyError message includes missing key + + +@pytest.mark.parametrize("bad_id", [None, 3.14, {"n": 1}, ["1"]]) +def test_invalid_request_id_type_raises_value_error(bad_id): + """ + request_id must be int or str; anything else should raise ValueError. + """ + payload = { + "device": "dev-y", + "data": {"id": bad_id, "method": "ping"} + } + with pytest.raises(ValueError) as exc: + GatewayRPCRequest._deserialize_from_dict(payload) + # The current implementation uses a generic message; assert on the key phrase. + assert "request id" in str(exc.value).lower() + + +def test_missing_method_raises_value_error(): + """ + Missing 'method' inside 'data' should raise ValueError. + """ + payload = { + "device": "dev-z", + "data": {"id": 7} # no "method" + } + with pytest.raises(ValueError) as exc: + GatewayRPCRequest._deserialize_from_dict(payload) + assert "Missing 'method'" in str(exc.value) + + +def test_event_type_is_device_rpc_request(): + """ + Sanity check to ensure event_type is set correctly by the deserializer. + """ + payload = { + "device": "gw-1", + "data": { + "id": 101, + "method": "set_threshold", + "params": {"value": 12.3} + } + } + req = GatewayRPCRequest._deserialize_from_dict(payload) + assert req.event_type == GatewayEventType.DEVICE_RPC_REQUEST + + +if __name__ == '__main__': + pytest.main([__file__, "--tb=short", "-v"]) diff --git a/tests/entities/gateway/test_gateway_rpc_response.py b/tests/entities/gateway/test_gateway_rpc_response.py new file mode 100644 index 0000000..0ead467 --- /dev/null +++ b/tests/entities/gateway/test_gateway_rpc_response.py @@ -0,0 +1,177 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import re +import pytest + +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.data.rpc_response import RPCStatus +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse + + +def test_direct_instantiation_not_allowed(): + with pytest.raises(TypeError) as exc: + GatewayRPCResponse() # type: ignore[call-arg] + assert "Direct instantiation of GatewayRPCResponse is not allowed" in str(exc.value) + + +def test_build_success_with_result_only(): + req = GatewayRPCResponse.build( + device_name="pump-1", + request_id=101, + result={"ok": True} + ) + assert isinstance(req, GatewayRPCResponse) + assert req.device_name == "pump-1" + assert req.request_id == 101 + assert req.status == RPCStatus.SUCCESS + assert req.error is None + assert req.result == {"ok": True} + assert req.event_type == GatewayEventType.DEVICE_RPC_RESPONSE + + expected_repr = ( + "GatewayRPCResponse(device_name=pump-1, request_id=101, result={'ok': True}, error=None)" + ) + assert repr(req) == expected_repr + + payload = req.to_payload_format() + assert payload == {"device": "pump-1", "id": 101, "data": {"result": {"ok": True}}} + + +def test_build_with_error_string(): + req = GatewayRPCResponse.build( + device_name="sensor-A", + request_id=7, + error="boom" + ) + assert req.status == RPCStatus.ERROR + assert req.error == "boom" + assert req.result is None + + payload = req.to_payload_format() + assert payload == {"device": "sensor-A", "id": 7, "data": {"error": "boom"}} + + +def test_build_with_error_dict(): + err = {"message": "bad-things", "code": 500} + req = GatewayRPCResponse.build( + device_name="sensor-B", + request_id=8, + error=err + ) + assert req.status == RPCStatus.ERROR + assert req.error == err + assert req.result is None + + payload = req.to_payload_format() + assert payload == {"device": "sensor-B", "id": 8, "data": {"error": err}} + + +def test_build_with_exception_converted_to_struct(): + class CustomError(RuntimeError): + pass + + req = GatewayRPCResponse.build( + device_name="dev-ex", + request_id=999, + error=CustomError("kaput") + ) + + assert req.status == RPCStatus.ERROR + assert isinstance(req.error, dict) + assert req.error.get("message") == "kaput" + assert req.error.get("type") == "CustomError" + assert isinstance(req.error.get("details"), str) + assert "CustomError" in req.error["details"] + + payload = req.to_payload_format() + assert payload["device"] == "dev-ex" + assert payload["id"] == 999 + assert "error" in payload["data"] + assert payload["data"]["error"]["type"] == "CustomError" + + +def test_build_with_result_and_error_includes_both_in_payload(): + req = GatewayRPCResponse.build( + device_name="dev-mix", + request_id=321, + result={"partial": True}, + error="some error" + ) + assert req.status == RPCStatus.ERROR + assert req.result == {"partial": True} + assert req.error == "some error" + + payload = req.to_payload_format() + assert payload == { + "device": "dev-mix", + "id": 321, + "data": {"result": {"partial": True}, "error": "some error"}, + } + + +@pytest.mark.parametrize("bad_device", ["", 123, None, object()]) +def test_build_invalid_device_name_raises_value_error(bad_device): + with pytest.raises(ValueError) as exc: + GatewayRPCResponse.build(device_name=bad_device, request_id=1) # type: ignore[arg-type] + assert "Device name must be a non-empty string" in str(exc.value) + + +@pytest.mark.parametrize("bad_error", [123, 3.14, ["oops"], set(), object()]) +def 
test_build_invalid_error_type_raises_value_error(bad_error): + with pytest.raises(ValueError) as exc: + GatewayRPCResponse.build(device_name="dev-x", request_id=1, error=bad_error) # type: ignore[arg-type] + assert "Error must be a string, dictionary, or an exception instance" in str(exc.value) + + +def test_build_invalid_result_fails_json_validation(): + class NotJson: + pass + + with pytest.raises(ValueError) as exc: + GatewayRPCResponse.build(device_name="dev-j", request_id=2, result=NotJson()) + assert re.search(r"(json|compatible|type)", str(exc.value), re.IGNORECASE) + + +def test_build_invalid_error_dict_fails_json_validation(): + class NotJson: + pass + + with pytest.raises(ValueError) as exc: + GatewayRPCResponse.build( + device_name="dev-je", + request_id=3, + error={"bad": NotJson()} + ) + assert re.search(r"(json|compatible|type)", str(exc.value), re.IGNORECASE) + + +def test_event_type_is_device_rpc_response(): + req = GatewayRPCResponse.build(device_name="gw-1", request_id=55, result=True) + assert req.event_type == GatewayEventType.DEVICE_RPC_RESPONSE + + +def test_repr_includes_all_fields(): + req = GatewayRPCResponse.build(device_name="repr-dev", request_id=42, result={"x": 1}) + assert repr(req) == "GatewayRPCResponse(device_name=repr-dev, request_id=42, result={'x': 1}, error=None)" + + +def test_immutability_enforced_by_frozen_dataclass(): + req = GatewayRPCResponse.build(device_name="immu", request_id=1, result=None) + with pytest.raises((AttributeError, TypeError)): + setattr(req, "status", RPCStatus.ERROR) + + +if __name__ == '__main__': + pytest.main([__file__, "--tb=short", "-v"]) diff --git a/tests/service/test_mqtt_manager.py b/tests/service/test_mqtt_manager.py index 9857434..4ee9cf0 100644 --- a/tests/service/test_mqtt_manager.py +++ b/tests/service/test_mqtt_manager.py @@ -13,12 +13,14 @@ # limitations under the License. 
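test_build_with_exception_converted_to_struct above expects an exception passed as error to be flattened into a JSON-friendly dict with message, type and details keys. A sketch of one way such a conversion can look; exception_to_error_struct is a made-up name, not the SDK's actual helper:

import traceback


def exception_to_error_struct(exc: BaseException) -> dict:
    """Hypothetical helper: flatten an exception into a JSON-serializable error dict."""
    return {
        "message": str(exc),
        "type": type(exc).__name__,
        # The test only requires a string that mentions the exception class.
        "details": "".join(traceback.format_exception_only(type(exc), exc)).strip(),
    }


try:
    raise RuntimeError("kaput")
except RuntimeError as err:
    struct = exception_to_error_struct(err)
    assert struct["type"] == "RuntimeError" and "RuntimeError" in struct["details"]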
import asyncio +import ssl from time import monotonic -from unittest.mock import AsyncMock, MagicMock, patch, PropertyMock, call +from unittest.mock import AsyncMock, MagicMock, patch, call, PropertyMock import pytest import pytest_asyncio +from tb_mqtt_client.common.exceptions import BackpressureException from tb_mqtt_client.common.mqtt_message import MqttPublishMessage from tb_mqtt_client.common.publish_result import PublishResult from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit @@ -92,7 +94,6 @@ async def test_on_disconnect_internal_abnormal_disconnect(setup_manager): @pytest.mark.asyncio async def test_handle_puback_reason_code_unknown_id(setup_manager): manager, *_ = setup_manager - # Should not raise or fail if ID not tracked manager._handle_puback_reason_code(999, 0, {}) @@ -138,6 +139,19 @@ async def test_publish_force_bypasses_limits(setup_manager): assert manager._client._connection.publish.call_count == 1 +@pytest.mark.asyncio +async def test_publish_backpressure_blocks(setup_manager): + manager, *_ = setup_manager + manager._MQTTManager__rate_limits_retrieved = True + manager._MQTTManager__is_waiting_for_rate_limits_publish = False + manager._rate_limits_ready_event.set() + manager._backpressure = MagicMock() + manager._backpressure.should_pause.return_value = True + + with pytest.raises(BackpressureException): + await manager.publish(MqttPublishMessage("t", b"x"), force=False) + + @pytest.mark.asyncio async def test_on_disconnect_internal_clears_futures(setup_manager): manager, *_ = setup_manager @@ -188,6 +202,16 @@ async def test_set_rate_limits_allows_ready(setup_manager): assert manager._rate_limits_ready_event.is_set() +@pytest.mark.asyncio +async def test_set_rate_limits_gateway_requires_gateway_limits(setup_manager): + manager, *_ = setup_manager + manager.enable_gateway_mode() + manager.set_rate_limits_received() + assert not manager._rate_limits_ready_event.is_set() + manager.set_gateway_rate_limits_received() + assert manager._rate_limits_ready_event.is_set() + + @pytest.mark.asyncio async def test_match_topic_logic(): assert MQTTManager._match_topic("foo/+", "foo/bar") @@ -205,6 +229,17 @@ async def test_check_pending_publishes_timeout(setup_manager): assert fut.result().reason_code == 408 +@pytest.mark.asyncio +async def test_check_pending_publishes_cancelled_when_stopping(setup_manager): + manager, stop_event, *_ = setup_manager + fut = asyncio.Future() + manager._pending_publishes[5] = (fut, MqttPublishMessage("t", b"p"), monotonic()) + stop_event.set() + await manager.check_pending_publishes(monotonic()) + assert fut.cancelled() + stop_event.clear() + + @pytest.mark.asyncio async def test_disconnect_swallows_reset_error(setup_manager): manager, *_ = setup_manager @@ -261,7 +296,7 @@ async def test_publish_qos_zero_sets_result_immediately(setup_manager): future = asyncio.Future() await manager.publish(MqttPublishMessage("topic", b"payload", qos=0, delivery_futures=future), force=True) - await asyncio.sleep(0.05) # Allow async tasks to complete + await asyncio.sleep(0.05) assert future.done() assert future.result() == PublishResult("topic", 0, -1, 7, 0) @@ -300,7 +335,7 @@ async def test_handle_puback_reason_code_errors(setup_manager): manager._handle_puback_reason_code(2, QUOTA_EXCEEDED, {}) assert f2.result().reason_code == QUOTA_EXCEEDED - manager._handle_puback_reason_code(9999, 1, {}) # Should log warning, not crash + manager._handle_puback_reason_code(9999, 1, {}) @pytest.mark.asyncio @@ -328,11 +363,32 @@ async def 
test_request_rate_limits_timeout(setup_manager): assert manager._rate_limits_ready_event.is_set() +@pytest.mark.asyncio +async def test_request_rate_limits_real_timeout_branch(setup_manager): + manager, *_ = setup_manager + manager._MQTTManager__is_waiting_for_rate_limits_publish = False + + manager._message_adapter.build_rpc_request.return_value = MqttPublishMessage("t", b"p") + manager._client._connection = MagicMock() + manager._client._connection.publish.return_value = (1, b"p") + manager._client._persistent_storage = MagicMock() + + never = asyncio.Future() + manager._rpc_response_handler.register_request.return_value = never + + with patch("tb_mqtt_client.service.mqtt_manager.await_or_stop", side_effect=asyncio.TimeoutError): + await manager._MQTTManager__request_rate_limits() + + assert manager._MQTTManager__is_waiting_for_rate_limits_publish is True + assert not manager._rate_limits_ready_event.is_set() + + @pytest.mark.asyncio async def test_monitor_ack_timeouts_stops_gracefully(setup_manager): manager, stop_event, *_ = setup_manager stop_event.set() await manager._monitor_ack_timeouts() + stop_event.clear() @pytest.mark.asyncio @@ -369,5 +425,241 @@ async def test_disconnect_reason_code_142_triggers_special_flow(mock_run_sync, s manager._on_disconnect_callback.assert_awaited_once() +@pytest.mark.asyncio +async def test_duplicate_publish_path_retransmits_and_persists(setup_manager): + manager, *_ = setup_manager + conn = MagicMock() + protocol = object() + conn._protocol = protocol + manager._client._connection = conn + manager._client._persistent_storage = MagicMock() + + msg = MqttPublishMessage("topic/dup", b"payload", qos=1) + msg.dup = True + msg.message_id = 321 + + with patch("tb_mqtt_client.service.mqtt_manager.PublishPacket.build_package", + return_value=(321, b"rebuilt")) as mock_build: + await manager.publish(msg, force=True) + + mock_build.assert_called_once() + conn.send_package.assert_called_once_with(b"rebuilt") + manager._client._persistent_storage.push_message_nowait.assert_called_once_with(321, msg) + + +@pytest.mark.asyncio +async def test_add_future_chain_processing_success_and_immediate(setup_manager): + manager, *_ = setup_manager + main = asyncio.get_event_loop().create_future() + main.uuid = "main" + + f1 = asyncio.get_event_loop().create_future() + f1.uuid = "f1" + f2 = asyncio.get_event_loop().create_future() + f2.uuid = "f2" + + msg = MqttPublishMessage("t", b"p", qos=1, delivery_futures=[f1, f2]) + + with patch("tb_mqtt_client.service.mqtt_manager.future_map.child_resolved") as child_resolved: + await MQTTManager._add_future_chain_processing(main, msg) + res = PublishResult("t", 1, -1, 1, 0) + main.set_result(res) + await asyncio.sleep(0.05) + assert f1.done() and f2.done() + assert f1.result() == res and f2.result() == res + assert child_resolved.call_count == 2 + + main2 = asyncio.get_event_loop().create_future() + main2.uuid = "main2" + main2.set_result(res) + g1 = asyncio.get_event_loop().create_future(); g1.uuid = "g1" + g2 = asyncio.get_event_loop().create_future(); g2.uuid = "g2" + msg2 = MqttPublishMessage("t2", b"q", qos=1, delivery_futures=[g1, g2]) + await MQTTManager._add_future_chain_processing(main2, msg2) + await asyncio.sleep(0.05) + assert g1.done() and g2.done() + + + +@pytest.mark.asyncio +async def test_add_future_chain_processing_cancel_and_exception(setup_manager): + manager, *_ = setup_manager + + cancelled_main = asyncio.get_event_loop().create_future() + cancelled_main.uuid = "cmain" + cancelled_main.cancel() + c1 = 
asyncio.get_event_loop().create_future(); c1.uuid = "c1" + c2 = asyncio.get_event_loop().create_future(); c2.uuid = "c2" + msg = MqttPublishMessage("t", b"p", qos=1, delivery_futures=[c1, c2]) + + with patch("tb_mqtt_client.service.mqtt_manager.future_map.child_resolved"): + await MQTTManager._add_future_chain_processing(cancelled_main, msg) + await asyncio.sleep(0.05) + assert c1.done() and c2.done() + assert isinstance(c1.result(), PublishResult) and isinstance(c2.result(), PublishResult) + + exc_main = asyncio.get_event_loop().create_future() + exc_main.uuid = "emain" + exc_main.set_exception(RuntimeError("boom")) + e1 = asyncio.get_event_loop().create_future(); e1.uuid = "e1" + e2 = asyncio.get_event_loop().create_future(); e2.uuid = "e2" + msg2 = MqttPublishMessage("t2", b"p2", qos=1, delivery_futures=[e1, e2]) + + with patch("tb_mqtt_client.service.mqtt_manager.future_map.child_resolved"): + await MQTTManager._add_future_chain_processing(exc_main, msg2) + await asyncio.sleep(0.05) + assert e1.done() and e2.done() + assert isinstance(e1.result(), PublishResult) and isinstance(e2.result(), PublishResult) + + +@pytest.mark.asyncio +async def test_process_regular_publish_qos1_resolves_delivery_on_puback(setup_manager): + manager, *_ = setup_manager + manager._client._connection = MagicMock() + manager._client._connection.publish.return_value = (555, b"p") + manager._client._persistent_storage = MagicMock() + + d1 = asyncio.get_event_loop().create_future(); + d1.uuid = "d1" + d2 = asyncio.get_event_loop().create_future(); + d2.uuid = "d2" + msg = MqttPublishMessage("t", b"payload", qos=1, delivery_futures=[d1, d2]) + await manager.publish(msg, force=True) + + manager._handle_puback_reason_code(555, 0, {}) + await asyncio.sleep(0.05) + assert d1.done() and d2.done() + assert isinstance(d1.result(), PublishResult) and isinstance(d2.result(), PublishResult) + + +@pytest.mark.asyncio +async def test_on_connect_internal_failure_does_not_set_ready(setup_manager): + manager, *_ = setup_manager + manager._connected_event.set() + manager._on_connect_internal(manager._client, session_present=False, reason_code=128, properties=None) + assert not manager._connected_event.is_set() + + + +@pytest.mark.asyncio +async def test_on_connect_internal_success_triggers_limits_flow(setup_manager): + manager, *_ = setup_manager + called = asyncio.Event() + + async def fake_limits_flow(): + called.set() + + manager._client._connection = type("Conn", (), {})() + + setattr(manager, "_MQTTManager__handle_connect_and_limits", fake_limits_flow) + + manager._on_connect_internal(manager._client, session_present=True, reason_code=0, properties={}) + await asyncio.wait_for(called.wait(), timeout=1.0) + assert manager._connected_event.is_set() + + +@pytest.mark.asyncio +async def test_patch_client_for_retry_logic_assigns_method(setup_manager): + manager, *_ = setup_manager + + async def put_retry(msg: MqttPublishMessage): + return None + + manager.patch_client_for_retry_logic(put_retry) + assert manager._client.put_retry_message is put_retry + + +@pytest.mark.asyncio +async def test_stop_cancels_tasks_and_calls_patch_utils(setup_manager): + manager, *_ = setup_manager + manager._patch_utils.stop_retry_task = AsyncMock() + + with patch.object(type(manager._client), "is_connected", new_callable=PropertyMock, return_value=False): + await manager.stop() + + manager._patch_utils.stop_retry_task.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_connect_tls_creates_default_ssl_context(setup_manager): + manager, *_ = 
setup_manager + + manager._client.connect = AsyncMock() + + with patch.object(type(manager._client), "is_connected", new_callable=PropertyMock, return_value=True), \ + patch.object(ssl, "create_default_context", autospec=True) as create_ctx: + await manager.connect("host", 8883, tls=True, ssl_context=None) + + create_ctx.assert_called_once() + + +@pytest.mark.asyncio +async def test_connect_loop_handles_wait_exception_and_exits_when_connected(setup_manager): + manager, stop_event, *_ = setup_manager + + manager._client.connect = AsyncMock() + + with patch.object(type(manager._client), "is_connected", new_callable=PropertyMock, + side_effect=[False, True]), \ + patch.object(manager._connected_event, "wait", + AsyncMock(side_effect=[Exception("boom"), None])): + await manager.connect("host", 1883, tls=False) + + +@pytest.mark.asyncio +async def test_publish_wait_rate_limits_timeout_requests_then_raises(setup_manager): + manager, *_ = setup_manager + + manager._MQTTManager__rate_limits_retrieved = True + manager._MQTTManager__is_waiting_for_rate_limits_publish = False + manager._backpressure.should_pause = MagicMock(return_value=False) + + with patch("tb_mqtt_client.service.mqtt_manager.await_or_stop", side_effect=asyncio.TimeoutError), \ + patch.object(manager, "_MQTTManager__request_rate_limits", new=AsyncMock()) as req_limits: + with pytest.raises(RuntimeError, match="Timeout waiting for rate limits."): + await manager.publish(MqttPublishMessage("t", b"p"), force=False) + + req_limits.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_disconnect_rate_limit_timeout_sets_default_backpressure_delay(setup_manager): + manager, *_ = setup_manager + + rate_limit = MagicMock(spec=RateLimit) + setattr(manager, "_MQTTManager__rate_limiter", {"messages": rate_limit}) + manager._backpressure = MagicMock() + + with patch("tb_mqtt_client.service.mqtt_manager.run_coroutine_sync", side_effect=TimeoutError): + manager._on_disconnect_internal(manager._client, reason_code=131) + + manager._backpressure.notify_disconnect.assert_called_with(delay_seconds=10) + + +@pytest.mark.asyncio +async def test_add_future_chain_processing_sets_exception_when_child_resolved_raises(setup_manager): + manager, *_ = setup_manager + + main = asyncio.get_event_loop().create_future() + main.uuid = "main" + + f_ok = asyncio.get_event_loop().create_future(); f_ok.uuid = "ok" + f_err = asyncio.get_event_loop().create_future(); f_err.uuid = "err" + msg = MqttPublishMessage("topic", b"payload", qos=1, delivery_futures=[f_ok, f_err]) + + with patch("tb_mqtt_client.service.mqtt_manager.future_map.child_resolved", + side_effect=RuntimeError("boom")): + await MQTTManager._add_future_chain_processing(main, msg) + + main.set_result(PublishResult("topic", 1, -1, len(b"payload"), 0)) + await asyncio.sleep(0.05) + + assert f_ok.done() + assert isinstance(f_ok.result(), PublishResult) + + assert f_err.done() + assert isinstance(f_err.exception(), Exception) + + if __name__ == '__main__': pytest.main([__file__, "-v", "--tb=short"]) From 665fc95dcb00a842485f23cb915128a65cecf408 Mon Sep 17 00:00:00 2001 From: imbeacon Date: Mon, 11 Aug 2025 09:39:32 +0300 Subject: [PATCH 74/74] Added check for ThingsBoard availability before running blackbox tests for case when platform is running locally --- tests/blackbox/conftest.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/blackbox/conftest.py b/tests/blackbox/conftest.py index 93b36ae..707f7a0 100644 --- a/tests/blackbox/conftest.py +++ 
b/tests/blackbox/conftest.py @@ -55,7 +55,7 @@ TB_URL: str = ENV("SDK_BLACKBOX_TB_URL", f"{TB_HTTP_PROTOCOL}://{TB_HOST}:{TB_HTTP_PORT}") REQUEST_TIMEOUT: float = float(ENV("SDK_BLACKBOX_HTTP_TIMEOUT", "30")) -DOCKER_START_TIMEOUT_S: int = int(ENV("SDK_BLACKBOX_TB_START_TIMEOUT", "180")) +TB_START_TIMEOUT: int = int(ENV("SDK_BLACKBOX_TB_START_TIMEOUT", "180")) logger = logging.getLogger("blackbox") logger.setLevel(logging.INFO) @@ -128,7 +128,7 @@ def _wait_for_tb_ready(http: requests.Session) -> None: start = time.time() delay = 1.0 # Try login to ensure full readiness (DB + REST). - while time.time() - start < DOCKER_START_TIMEOUT_S: + while time.time() - start < TB_START_TIMEOUT: try: r = http.post( f"{TB_URL}/api/auth/login", @@ -152,6 +152,7 @@ def start_thingsboard(http: requests.Session) -> Generator[None, None, None]: Ensures instance is ready. If we start the container, we will stop it at session end. """ if not RUN_BLACKBOX: + _wait_for_tb_ready(http) yield return @@ -480,7 +481,7 @@ def device_profile_with_rpc_rule_chain( http: requests.Session, ) -> Generator[dict, None, None]: """ - Creates a device profile that uses the RPC rule chain as default; cleans up afterwards. + Creates a device profile that uses the RPC rule chain as default; cleans up afterward. """ rule_chain = rpc_rule_chain device_profile = get_default_device_profile(test_config["tb_url"], tb_admin_headers)
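With the readiness check now applied even when the fixtures do not start a container, the blackbox suite can run against a ThingsBoard instance that is already up locally. A sketch of one way to drive that from Python; the environment variable names come from tests/blackbox/conftest.py, while the URL value and test path are assumptions to adjust for your setup:

import os

import pytest

# Point the suite at an already running platform; adjust the URL to your instance.
os.environ.setdefault("SDK_BLACKBOX_TB_URL", "http://localhost:8080")
os.environ.setdefault("SDK_BLACKBOX_HTTP_TIMEOUT", "30")
os.environ.setdefault("SDK_BLACKBOX_TB_START_TIMEOUT", "180")

if __name__ == '__main__':
    # conftest.py will poll /api/auth/login until the platform answers or the timeout expires.
    raise SystemExit(pytest.main(["tests/blackbox", "-m", "blackbox", "-v"]))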