diff --git a/LICENSE b/LICENSE index 6cbfe64..373c076 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright 2024. ThingsBoard + Copyright 2025 ThingsBoard Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/README.md b/README.md index bdb0878..1551f93 100644 --- a/README.md +++ b/README.md @@ -1,345 +1,88 @@ -# ThingsBoard MQTT and HTTP client Python SDK +# ThingsBoard Python Client SDK 2.0 + [![Join the chat at https://gitter.im/thingsboard/chat](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/thingsboard/chat?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) -ThingsBoard is an open-source IoT platform for data collection, processing, visualization, and device management. -This project is a Python library that provides convenient client SDK for both Device and [Gateway](https://thingsboard.io/docs/reference/gateway-mqtt-api/) APIs. - -SDK supports: -- Unencrypted and encrypted (TLS v1.2) connection -- QoS 0 and 1 (MQTT only) -- Automatic reconnect -- All [Device MQTT](https://thingsboard.io/docs/reference/mqtt-api/) APIs provided by ThingsBoard -- All [Gateway MQTT](https://thingsboard.io/docs/reference/gateway-mqtt-api/) APIs provided by ThingsBoard -- Most [Device HTTP](https://thingsboard.io/docs/reference/http-api/) APIs provided by ThingsBoard -- Device Claiming -- Firmware updates - -The [Device MQTT](https://thingsboard.io/docs/reference/mqtt-api/) API and the [Gateway MQTT](https://thingsboard.io/docs/reference/gateway-mqtt-api/) API are base on the Paho MQTT library. The [Device HTTP](https://thingsboard.io/docs/reference/http-api/) API is based on the Requests library. - -## Installation - -To install using pip: - -```bash -pip3 install tb-mqtt-client -``` - -## Getting Started - -Client initialization and telemetry publishing -### MQTT -```python -from tb_device_mqtt import TBDeviceMqttClient, TBPublishInfo - - -telemetry = {"temperature": 41.9, "enabled": False, "currentFirmwareVersion": "v1.2.2"} - -# Initialize ThingsBoard client -client = TBDeviceMqttClient("127.0.0.1", username="A1_TEST_TOKEN") -# Connect to ThingsBoard -client.connect() -# Sending telemetry without checking the delivery status -client.send_telemetry(telemetry) -# Sending telemetry and checking the delivery status (QoS = 1 by default) -result = client.send_telemetry(telemetry) -# get is a blocking call that awaits delivery status -success = result.get() == TBPublishInfo.TB_ERR_SUCCESS -# Disconnect from ThingsBoard -client.disconnect() - -``` - -### MQTT using TLS - -TLS connection to localhost. See https://thingsboard.io/docs/user-guide/mqtt-over-ssl/ for more information about client and ThingsBoard configuration. - -```python -from tb_device_mqtt import TBDeviceMqttClient -import socket - - -client = TBDeviceMqttClient(socket.gethostname()) -client.connect(tls=True, - ca_certs="mqttserver.pub.pem", - cert_file="mqttclient.nopass.pem") -client.disconnect() - -``` - -### HTTP - -````python -from tb_device_http import TBHTTPDevice - - -client = TBHTTPDevice('https://thingsboard.example.com', 'secret-token') -client.connect() -client.send_telemetry({'temperature': 41.9}) - -```` - -## Using Device APIs - -**TBDeviceMqttClient** provides access to Device MQTT APIs of ThingsBoard platform. It allows to publish telemetry and attribute updates, subscribe to attribute changes, send and receive RPC commands, etc. Use **TBHTTPClient** for the Device HTTP API. -#### Subscription to attributes -You can subscribe to attribute updates from the server. The following example demonstrates how to subscribe to attribute updates from the server. -##### MQTT -```python -import time -from tb_device_mqtt import TBDeviceMqttClient - - -def on_attributes_change(client, result, exception): - if exception is not None: - print("Exception: " + str(exception)) - else: - print(result) - - -client = TBDeviceMqttClient("127.0.0.1", username="A1_TEST_TOKEN") -client.connect() -client.subscribe_to_attribute("uploadFrequency", on_attributes_change) -client.subscribe_to_all_attributes(on_attributes_change) -while True: - time.sleep(1) - -``` - -##### HTTP -Note: The HTTP API only allows a subscription to updates for all attribute. -```python -from tb_device_http import TBHTTPClient - - -client = TBHTTPClient('https://thingsboard.example.com', 'secret-token') - -def callback(data): - print(data) - # ... - -# Subscribe -client.subscribe('attributes', callback) -# Unsubscribe -client.unsubscribe('attributes') - -``` - -#### Telemetry pack sending -You can send multiple telemetry messages at once. The following example demonstrates how to send multiple telemetry messages at once. -##### MQTT -```python - -from tb_device_mqtt import TBDeviceMqttClient, TBPublishInfo -import time - - -telemetry_with_ts = {"ts": int(round(time.time() * 1000)), "values": {"temperature": 42.1, "humidity": 70}} -client = TBDeviceMqttClient("127.0.0.1", username="A1_TEST_TOKEN") -# we set maximum amount of messages sent to send them at the same time. it may stress memory but increases performance -client.max_inflight_messages_set(100) -client.connect() -results = [] -result = True -for i in range(0, 100): - results.append(client.send_telemetry(telemetry_with_ts)) -for tmp_result in results: - result &= tmp_result.get() == TBPublishInfo.TB_ERR_SUCCESS -print("Result " + str(result)) -client.disconnect() - -``` -##### HTTP -Unsupported, the HTTP API does not allow the packing of values. - -#### Request attributes from server -You can request attributes from the server. The following example demonstrates how to request attributes from the server. +ThingsBoard is an open-source IoT platform for data collection, processing, visualization, and device management. +This is the **official Python client SDK** for connecting **Devices** and **Gateways** to ThingsBoard using **MQTT**. -##### MQTT -```python +--- -import time -from tb_device_mqtt import TBDeviceMqttClient +## ✨ Features +- **MQTT 3.1.1 / 5.0** with QoS 0 and 1 +- **Async I/O** built on [`gmqtt`](https://github.com/wialon/gmqtt) for high performance +- **Encrypted connections** (TLS v1.2+) +- **Automatic reconnect** with session persistence +- **All Device MQTT APIs** ([docs](https://thingsboard.io/docs/reference/mqtt-api/)) +- **All Gateway MQTT APIs** ([docs](https://thingsboard.io/docs/reference/gateway-mqtt-api/)) +- **Batching** and **payload splitting** for large telemetry/attribute messages +- **Rate limit awareness** (ThingsBoard-style multi-window limits) +- **Delivery tracking** — know exactly when your message reaches the server +- **Device claiming** +- **Firmware updates** -def on_attributes_change(client,result, exception): - if exception is not None: - print("Exception: " + str(exception)) - else: - print(result) +--- - -client = TBDeviceMqttClient("127.0.0.1", username="A1_TEST_TOKEN") -client.connect() -client.request_attributes(["configuration","targetFirmwareVersion"], callback=on_attributes_change) -while True: - time.sleep(1) - -``` - -##### HTTP -```python -from tb_device_http import TBHTTPClient - - -client = TBHTTPClient('https://thingsboard.example.com', 'secret-token') - -client_keys = ['attr1', 'attr2'] -shared_keys = ['shared1', 'shared2'] -data = client.request_attributes(client_keys=client_keys, shared_keys=shared_keys) - -``` - -#### Respond to server RPC call -You can respond to RPC calls from the server. The following example demonstrates how to respond to RPC calls from the server. -Please install psutil using 'pip install psutil' command before running the example. - -##### MQTT -```python - -try: - import psutil -except ImportError: - print("Please install psutil using 'pip install psutil' command") - exit(1) -import time -import logging -from tb_device_mqtt import TBDeviceMqttClient - -# dependently of request method we send different data back -def on_server_side_rpc_request(client, request_id, request_body): - print(request_id, request_body) - if request_body["method"] == "getCPULoad": - client.send_rpc_reply(request_id, {"CPU percent": psutil.cpu_percent()}) - elif request_body["method"] == "getMemoryUsage": - client.send_rpc_reply(request_id, {"Memory": psutil.virtual_memory().percent}) - -client = TBDeviceMqttClient("127.0.0.1", username="A1_TEST_TOKEN") -client.set_server_side_rpc_request_handler(on_server_side_rpc_request) -client.connect() -while True: - time.sleep(1) - -``` - -##### HTTP -You can use HTTP API client in case you want to use HTTP API instead of MQTT API. -```python -from tb_device_http import TBHTTPClient - - -client = TBHTTPClient('https://thingsboard.example.com', 'secret-token') - -def callback(data): - rpc_id = data['id'] - # ... do something with data['params'] and data['method']... - response_params = {'result': 1} - client.send_rpc(name='rpc_response', rpc_id=rpc_id, params=response_params) - -# Subscribe -client.subscribe('rpc', callback) -# Unsubscribe -client.unsubscribe('rpc') - -``` - -## Using Gateway APIs - -**TBGatewayMqttClient** extends **TBDeviceMqttClient**, thus has access to all it's APIs as a regular device. -Besides, gateway is able to represent multiple devices connected to it. For example, sending telemetry or attributes on behalf of other, constrained, device. See more info about the gateway here: -#### Telemetry and attributes sending -```python -import time -from tb_gateway_mqtt import TBGatewayMqttClient - - -gateway = TBGatewayMqttClient("127.0.0.1", username="TEST_GATEWAY_TOKEN") -gateway.connect() -gateway.gw_connect_device("Test Device A1") - -gateway.gw_send_telemetry("Test Device A1", {"ts": int(round(time.time() * 1000)), "values": {"temperature": 42.2}}) -gateway.gw_send_attributes("Test Device A1", {"firmwareVersion": "2.3.1"}) - -gateway.gw_disconnect_device("Test Device A1") -gateway.disconnect() - -``` -#### Request attributes - -You can request attributes from the server. The following example demonstrates how to request attributes from the server. - -```python -import time -from tb_gateway_mqtt import TBGatewayMqttClient - - -def callback(result, exception): - if exception is not None: - print("Exception: " + str(exception)) - else: - print(result) - - -gateway = TBGatewayMqttClient("127.0.0.1", username="TEST_GATEWAY_TOKEN") -gateway.connect() -gateway.gw_request_shared_attributes("Test Device A1", ["temperature"], callback) - -while True: - time.sleep(1) - -``` -#### Respond to RPC - -You can respond to RPC calls from the server. The following example demonstrates how to respond to RPC calls from the server. -Please install psutil using 'pip install psutil' command before running the example. - -```python -import time - -from tb_gateway_mqtt import TBGatewayMqttClient -try: - import psutil -except ImportError: - print("Please install psutil using 'pip install psutil' command") - exit(1) - - -def rpc_request_response(client, request_id, request_body): - # request body contains id, method and other parameters - print(request_body) - method = request_body["data"]["method"] - device = request_body["device"] - req_id = request_body["data"]["id"] - # dependently of request method we send different data back - if method == 'getCPULoad': - gateway.gw_send_rpc_reply(device, req_id, {"CPU load": psutil.cpu_percent()}) - elif method == 'getMemoryLoad': - gateway.gw_send_rpc_reply(device, req_id, {"Memory": psutil.virtual_memory().percent}) - else: - print('Unknown method: ' + method) - - -gateway = TBGatewayMqttClient("127.0.0.1", username="TEST_GATEWAY_TOKEN") -gateway.connect() -# now rpc_request_response will process rpc requests from servers -gateway.gw_set_server_side_rpc_request_handler(rpc_request_response) -# without device connection it is impossible to get any messages -gateway.gw_connect_device("Test Device A1") -while True: - time.sleep(1) +## 📦 Installation +```bash +pip install tb-mqtt-client ``` -## Other Examples - -There are more examples for both [device](https://github.com/thingsboard/thingsboard-python-client-sdk/tree/master/examples/device) and [gateway](https://github.com/thingsboard/thingsboard-python-client-sdk/tree/master/examples/gateway) in corresponding [folders](https://github.com/thingsboard/thingsboard-python-client-sdk/tree/master/examples). - -## Support - - - [Community chat](https://gitter.im/thingsboard/chat) - - [Q&A forum](https://groups.google.com/forum/#!forum/thingsboard) - - [Stackoverflow](http://stackoverflow.com/questions/tagged/thingsboard) -## Licenses +## 🚀 Getting Started + +This SDK is **async-based**. +You can run it with `asyncio`. + +We provide **ready-to-run examples** for both device and gateway clients. + +--- + +### **Device Examples** +[`examples/device/`](examples/device): +- [Claim to customer](examples/device/claim_device.py) +- [Provision and connect with provisioned credentials](examples/device/client_provisioning.py) +- [Retrieve firmware](examples/device/firmware_update.py) +- [Handle attribute updates from server](examples/device/handle_attribute_updates.py) +- [Handle RPC requests from server](examples/device/handle_rpc_requests.py) +- [Load testing](examples/device/load.py) +- [Operational example](examples/device/operational_example.py) +- [Request attributes from server](examples/device/request_attributes.py) +- [Send attributes to server](examples/device/send_attributes.py) +- [Send client-side RPC to server and retrieve the response](examples/device/send_client_side_rpc.py) +- [Send timeseries to server](examples/device/send_timeseries.py) +- [Connect to server using MQTT over SSL and send encrypted timeseries](examples/device/tls_connect.py) + +--- + +### **Gateway Examples** +[`examples/gateway/`](examples/gateway): +- [Claim device](examples/gateway/claim_device.py) +- [Connect and disconnect device](examples/gateway/connect_and_disconnect_device.py) +- [Handle attribute updates for connected devices](examples/gateway/handle_attribute_updates.py) +- [Handle RPC requests for connected devices](examples/gateway/handle_rpc_requests.py) +- [Load testing](examples/gateway/load.py) +- [Operational example](examples/gateway/operational_example.py) +- [Request attributes for connected devices](examples/gateway/request_attributes.py) +- [Send attributes for connected devices to server](examples/gateway/send_attributes.py) +- [Send timeseries for connected devices to server](examples/gateway/send_timeseries.py) +- [Connect gateway client to server using MQTT over SSL and send encrypted timeseries](examples/gateway/tls_connect.py) + +--- + +## 📚 Documentation +- [ThingsBoard Device MQTT API](https://thingsboard.io/docs/reference/mqtt-api/) +- [ThingsBoard Gateway MQTT API](https://thingsboard.io/docs/reference/gateway-mqtt-api/) + +--- + +## 💬 Support +- [Community chat](https://gitter.im/thingsboard/chat) +- [Q&A forum](https://groups.google.com/forum/#!forum/thingsboard) +- [Stack Overflow](http://stackoverflow.com/questions/tagged/thingsboard) + +## 📄 Licenses This project is released under [Apache 2.0 License](./LICENSE). diff --git a/__init__.py b/__init__.py index 234aa97..6d6dc0d 100644 --- a/__init__.py +++ b/__init__.py @@ -12,6 +12,3 @@ # See the License for the specific language governing permissions and # limitations under the License. # - - -__name__ = "tb_mqtt_client" diff --git a/dev_notes.md b/dev_notes.md new file mode 100644 index 0000000..f03d85e --- /dev/null +++ b/dev_notes.md @@ -0,0 +1,325 @@ +# Developer Notes for ThingsBoard Python Client SDK + +## Overview +This document provides explanations of tricky moments and major points in the ThingsBoard Python Client SDK implementation. It focuses on key areas that require special attention for developers working with or maintaining the codebase. + +## GMQTT Patches + +The SDK uses the GMQTT library for MQTT communication but implements several patches to enhance its functionality: + +### Key Components: +- `tb_mqtt_client/common/gmqtt_patch.py`: Contains the `PatchUtils` class that applies various patches to the GMQTT library. + +### Major Patches: + +1. **Enhanced Disconnect Handling**: + - The original GMQTT disconnect handling is patched to provide better error reporting and recovery. + - `patch_mqtt_handler_disconnect()` modifies how disconnect packets are processed, adding proper reason code extraction. + +2. **Connection Acknowledgment Handling**: + - `patch_handle_connack()` enhances the connection acknowledgment process to better handle MQTT 5.0 properties. + +3. **Connection Lost Recovery**: + - `patch_gmqtt_protocol_connection_lost()` improves how the client handles unexpected disconnections. + - This is critical for maintaining robust connections in unstable network environments. + +4. **PUBACK Handling with Reason Codes**: + - `patch_puback_handling()` adds support for MQTT 5.0 reason codes in publish acknowledgments. + - This allows the client to understand why a message was rejected by the broker. + +5. **Message Storage for Retries**: + - `patch_storage()` modifies how messages are stored to support the retry mechanism. + +## Message Queues + +The SDK implements a sophisticated message queuing system to handle message processing, rate limiting, and retries: + +### Key Components: +- `tb_mqtt_client/common/queue.py`: Implements `AsyncDeque`, a thread-safe asynchronous queue. +- `tb_mqtt_client/service/message_service.py`: Contains the `MessageService` class that manages multiple queues. + +### Queue Types: +1. **Initial Queue**: + - Entry point for all messages before they're routed to specialized queues. + - Implemented in `MessageService._dispatch_initial_queue_loop()`. + +2. **Service Queue**: + - Handles general service messages. + +3. **Device Uplink Messages Queue**: + - Specifically for device telemetry and attribute updates. + +4. **Gateway Uplink Messages Queue**: + - For gateway-specific messages that may contain data for multiple devices. + +5. **Retry by QoS Queue**: + - Special queue for handling QoS 1 messages that need to be retried. + - Implemented in `MessageService._dispatch_retry_by_qos_queue_loop()`. + +### Queue Processing: +- Each queue has its own asynchronous processing loop. +- Messages are processed according to rate limits and QoS requirements. +- The system uses `asyncio` for non-blocking queue operations. + +## QoS Messages Reprocessing + +The SDK implements a robust mechanism for handling QoS 1 messages, ensuring they are delivered at least once: + +### Key Components: +- `tb_mqtt_client/service/mqtt_manager.py`: Contains the `patch_client_for_retry_logic()` method. +- `tb_mqtt_client/common/gmqtt_patch.py`: Implements the retry mechanism in `_retry_loop()`. + +### Retry Process: +1. **Message Tracking**: + - Messages with QoS 1 are tracked until acknowledgment is received. + - If no acknowledgment is received within the timeout, the message is requeued. + +2. **Retry Loop**: + - `PatchUtils._retry_loop()` periodically checks for unacknowledged messages. + - Messages that haven't been acknowledged are put back into the retry queue. + +3. **Integration with Message Service**: + - `MQTTManager.patch_client_for_retry_logic()` connects the GMQTT client with the message service's retry queue. + - `MessageService.put_retry_message()` handles requeuing of messages that need to be retried. + +4. **Publish Monitoring**: + - `MQTTManager._monitor_ack_timeouts()` and `check_pending_publishes()` track publish operations and handle timeouts. + +## Gateway Message Dispatcher + +The gateway part of the SDK includes a message dispatcher for handling various types of events: + +### Key Components: +- `tb_mqtt_client/service/gateway/direct_event_dispatcher.py`: Implements the `DirectEventDispatcher` class. + +### Dispatcher Features: +1. **Event Registration**: + - Handlers can be registered for specific event types using `register()`. + - Multiple handlers can be registered for the same event type. + +2. **Event Dispatching**: + - `dispatch()` method routes events to the appropriate handlers. + - Supports both synchronous and asynchronous handlers. + +3. **Device-Specific Handling**: + - Can route events directly to a specific device session if provided. + - Otherwise, broadcasts to all registered handlers for the event type. + +4. **Error Handling**: + - Catches and logs exceptions in handlers to prevent one handler failure from affecting others. + +## Message Splitting + +The SDK includes mechanisms to handle large messages by splitting them into smaller chunks: + +### Key Components: +- `tb_mqtt_client/service/base_message_splitter.py`: Defines the base interface. +- `tb_mqtt_client/service/gateway/message_splitter.py`: Implements gateway-specific splitting logic. + +### Splitting Strategy: +1. **Size-Based Splitting**: + - Messages are split if they exceed the maximum payload size (default 55KB). + - Each split message maintains the original message's context. + +2. **Datapoint-Based Splitting**: + - Messages can also be split based on the number of datapoints. + - This helps prevent overwhelming the server with too many datapoints in a single message. + +3. **Message grouping**: + - Uplink data messages may be grouped to efficiently use the available payload size. + - This reduces the number of messages sent and optimizes network usage. + +4. **Future Chaining**: + - When a message is split, the futures from the original message are chained to the split messages. + - This ensures that the completion status is properly propagated back to the caller. + +## Rate Limiting + +The SDK implements rate limiting to prevent overwhelming the ThingsBoard server: + +### Key Components: +- `tb_mqtt_client/common/rate_limit/rate_limit.py`: Defines rate limit structures. +- `tb_mqtt_client/common/rate_limit/rate_limiter.py`: Implements the rate limiting logic. + +### Rate Limiting Features: +1. **Server-Provided Limits**: + - The SDK requests rate limits from the server upon connection. + - These limits are then applied to message publishing. + +2. **Separate Device and Gateway Limits**: + - Different rate limits can be applied to device and gateway messages. + +3. **Backpressure Control**: + - When rate limits are exceeded, the system applies backpressure to slow down message production. + + + +## Futures in Device and Gateway Uplink Messages + +### Overview of Futures in the SDK + +Futures are a critical component in the ThingsBoard Python Client SDK for handling asynchronous operations, particularly for message publishing. They provide a way to track the status of message delivery and propagate results back to the caller. + +### Futures in Uplink Messages + +#### Device and Gateway Message Structure + +Both `DeviceUplinkMessage` and `GatewayUplinkMessage` classes contain a `delivery_futures` field, which is a list of `asyncio.Future` objects. These futures are resolved when the message is successfully published to the ThingsBoard server or when an error occurs. + +```python +# From DeviceUplinkMessage +delivery_futures: List[Optional[asyncio.Future[PublishResult]]] +``` + +When a client application sends data to ThingsBoard, it can attach a future to the message to be notified when the message is delivered: + +```python +# Example usage +future = asyncio.get_running_loop().create_future() +message_builder.add_delivery_futures(future) +message = message_builder.build() +# Later, the application can await future to know when the message is delivered +result = await future # PublishResult object +``` + +### Message Splitting and Future Chaining + +One of the trickiest aspects of the SDK is how it handles futures when messages need to be split due to size limitations or datapoint count restrictions. + +#### The Splitting Process + +When a message exceeds the maximum payload size (default 55KB) or the maximum number of datapoints, it needs to be split into smaller chunks. This presents a challenge: how to maintain the relationship between the original message's future and the futures of the split messages? + +#### Future Chaining Mechanism + +The SDK uses a sophisticated future chaining mechanism implemented in the `FutureMap` class: + +```python +# From async_utils.py +class FutureMap: + def __init__(self): + self._child_to_parents: Dict[asyncio.Future, Set[asyncio.Future]] = {} + self._parent_to_remaining: Dict[asyncio.Future, Set[asyncio.Future]] = {} +``` + +This class maintains two mappings: +1. `_child_to_parents`: Maps each child future (from split messages) to its parent futures +2. `_parent_to_remaining`: Maps each parent future to the set of child futures that still need to be resolved + +##### Registration Process + +When a message is split, new futures are created for each split message, and these are registered with the original message's future: + +```python +# In message_splitter.py +shared_future = asyncio.get_running_loop().create_future() +shared_future.uuid = uuid4() +builder.add_delivery_futures(shared_future) + +built = builder.build() +result.append(built) +for parent in parent_futures: + future_map.register(parent, [shared_future]) +``` + +##### Resolution Process + +When a child future is resolved (when a split message is delivered), the `child_resolved` method is called: + +```python +# In FutureMap.child_resolved +def child_resolved(self, child: asyncio.Future): + parents = self._child_to_parents.pop(child, set()) + for parent in parents: + remaining = self._parent_to_remaining.get(parent) + if remaining is not None: + remaining.discard(child) + if not remaining and not parent.done(): + # All children resolved, resolve the parent + all_children = list(remaining) + [child] + results = [] + for f in all_children: + if f.done() and not f.cancelled(): + result = f.result() + if isinstance(result, PublishResult): + results.append(result) + + if results: + parent.set_result(PublishResult.merge(results)) + else: + parent.set_result(None) +``` + +This ensures that the parent future is only resolved when all of its child futures are resolved, and the results are merged. + +#### Differences Between Device and Gateway Splitters + +While the core mechanism is the same, there are some differences in how device and gateway message splitters handle futures: + +1. **Device Message Splitter**: + - Focuses on splitting individual device messages + - Handles timeseries and attributes separately + - Creates a shared future for each batch of data + +2. **Gateway Message Splitter**: + - Handles messages from multiple devices + - Groups data by device name and profile + - Maintains the device context across split messages + +### MQTT Manager's Role + +The `MQTTManager` class plays a crucial role in the future resolution process: + +```python +# In mqtt_manager.py +@staticmethod +async def _add_future_chain_processing(mqtt_future, message: MqttPublishMessage): + def resolve_attached(publish_future: asyncio.Future): + try: + try: + publish_result = publish_future.result() + except asyncio.CancelledError: + publish_result = PublishResult(message.topic, message.qos, -1, len(message.payload), -1) + except Exception as exc: + publish_result = PublishResult(message.topic, message.qos, -1, len(message.payload), -1) + + for i, f in enumerate(message.delivery_futures or []): + if f is not None and not f.done(): + f.set_result(publish_result) + future_map.child_resolved(f) + except Exception as e: + for i, f in enumerate(message.delivery_futures or []): + if f is not None and not f.done(): + f.set_exception(e) +``` + +This method: +1. Attaches a callback to the MQTT publish future +2. When the publish operation completes, it resolves all the delivery futures attached to the message +3. Calls `future_map.child_resolved()` to propagate the resolution up the chain + +### Practical Implications + +This future chaining mechanism has several important implications: + +1. **Transparent Splitting**: Applications don't need to be aware that their messages are being split. They attach a future to the original message and receive a notification when all parts are delivered. + +2. **Error Handling**: If any part of a split message fails to deliver, the error is propagated to the parent future. + +3. **Performance Optimization**: The SDK can optimize message delivery by splitting large messages without breaking the promise to notify the application when delivery is complete. + +4. **Resource Management**: Futures are properly managed and resolved, preventing memory leaks even when messages are split into many parts. + +## Conclusion + +The ThingsBoard Python Client SDK implements several sophisticated mechanisms to ensure reliable message delivery, efficient processing, and robust error handling. +The future mechanism in the ThingsBoard Python Client SDK provides a robust way to handle asynchronous message delivery with transparent message splitting. +Understanding this mechanism is crucial for developers working with the SDK, especially when dealing with large messages or high-throughput scenarios. + +Key areas to be aware of when making changes: +1. The GMQTT patches that enhance the underlying MQTT library +2. The message queue system that manages message flow +3. The QoS message reprocessing mechanism that ensures reliable delivery +4. The gateway message dispatcher that routes events to appropriate handlers +5. The message splitting logic that handles large payloads +6. The rate limiting system that prevents overwhelming the server \ No newline at end of file diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 0000000..978ca26 --- /dev/null +++ b/examples/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is required to make the examples usable in blackbox tests. diff --git a/examples/device/__init__.py b/examples/device/__init__.py new file mode 100644 index 0000000..978ca26 --- /dev/null +++ b/examples/device/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is required to make the examples usable in blackbox tests. diff --git a/examples/device/claim_device.py b/examples/device/claim_device.py new file mode 100644 index 0000000..8c2acaf --- /dev/null +++ b/examples/device/claim_device.py @@ -0,0 +1,68 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Example script to claim a device using ThingsBoard DeviceClient + +import asyncio +import logging + +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.entities.data.claim_request import ClaimRequest +from tb_mqtt_client.service.device.client import DeviceClient + +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.INFO) +logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) + +# Constants for connection +PLATFORM_HOST = "localhost" # Replace with your ThingsBoard host +DEVICE_ACCESS_TOKEN = "YOUR_ACCESS_TOKEN" # Replace with your device's access token + +# Constants for claiming +CLAIMING_DURATION = 120 # Default claiming duration in seconds +CLAIMING_SECRET_KEY = "YOUR_SECRET_KEY" # Replace with your actual secret key + + +async def main(): + # Create device config + config = DeviceConfig() + config.host = PLATFORM_HOST + config.access_token = DEVICE_ACCESS_TOKEN + + # Create device client + client = DeviceClient(config) + await client.connect() + + # Build claim request with secret key and optional duration (in seconds) + claim_request = ClaimRequest.build(secret_key=CLAIMING_SECRET_KEY, duration=CLAIMING_DURATION) + + # Send claim request + result: PublishResult = await client.claim_device(claim_request, + wait_for_publish=True, + timeout=CLAIMING_DURATION + 10) + if result.is_successful(): + logger.info( + f"Claiming request was sent successfully. " + f"Please use the secret key '{CLAIMING_SECRET_KEY}' to claim the device from the dashboard.") + else: + logger.error(f"Failed to send claiming request. Result: {result}") + + await client.stop() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/device/claiming_device_pe_only.py b/examples/device/claiming_device_pe_only.py deleted file mode 100644 index f7bf8e6..0000000 --- a/examples/device/claiming_device_pe_only.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2025. ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from tb_device_mqtt import TBDeviceMqttClient -logging.basicConfig(level=logging.DEBUG) - -THINGSBOARD_HOST = "127.0.0.1" -DEVICE_ACCESS_TOKEN = "DEVICE_ACCESS_TOKEN" - -SECRET_KEY = "DEVICE_SECRET_KEY" # Customer should write this key in device claiming widget -DURATION = 30000 # In milliseconds (30 seconds) - - -def main(): - client = TBDeviceMqttClient(THINGSBOARD_HOST, username=DEVICE_ACCESS_TOKEN) - client.connect() - info = client.claim(secret_key=SECRET_KEY, duration=DURATION) - if info.rc() == 0: - print("Claiming request was sent.") - client.stop() - - -if __name__ == '__main__': - main() diff --git a/examples/device/client_provisioning.py b/examples/device/client_provisioning.py index 20cc3ab..29bef2d 100644 --- a/examples/device/client_provisioning.py +++ b/examples/device/client_provisioning.py @@ -1,70 +1,65 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Example script to device provisioning using the DeviceClient +import asyncio import logging -from tb_device_mqtt import TBDeviceMqttClient, TBPublishInfo -logging.basicConfig(level=logging.DEBUG) +from random import randint +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.provisioning_request import AccessTokenProvisioningCredentials, ProvisioningRequest +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.service.device.client import DeviceClient -def main(): - """ - We can provide the following parameters to provisioning function: - host - required - Host of ThingsBoard - provision_device_key - required - device provision key from device profile - provision_device_secret - required - device provision secret from device profile - port=1883 - not required - MQTT port of ThingsBoard instance - device_name=None - may be generated on ThingsBoard - You may pass here name for device, if this parameter is not assigned, the name will be generated +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.INFO) +logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) - ### Credentials type = ACCESS_TOKEN +device_name = "ProvisionedDevice" - access_token=None - may be generated on ThingsBoard - You may pass here some access token and it will be saved as accessToken for device on ThingsBoard. +provisioning_credentials = AccessTokenProvisioningCredentials( + provision_device_key='YOUR_PROVISION_DEVICE_KEY', + provision_device_secret='YOUR_PROVISION_DEVICE_SECRET', +) +provisioning_request = ProvisioningRequest('localhost', + credentials=provisioning_credentials, + device_name=device_name) - ### Credentials type = MQTT_BASIC +async def main(): + provisioning_response = await DeviceClient.provision(provisioning_request) - client_id=None - not required (if username is not None) - You may pass here client Id for your device and use it later for connecting - username=None - not required (if client id is not None) - You may pass here username for your client and use it later for connecting - password=None - not required - You may pass here password and use it later for connecting + if not provisioning_response: + logger.error(f"Provisioning failed, no response received.") + return - ### Credentials type = X509_CERTIFICATE - hash=None - required (If you want to use this credentials type) - You should pass here public key of the device, generated from mqttserver.jks + if provisioning_response.error is not None: + logger.error(f"Provisioning failed: {provisioning_response.error}") + return - """ + logger.info(f'Provisioned device configuration: {provisioning_response}') - # Call device provisioning, to do this we don't need an instance of the TBDeviceMqttClient to provision device + # Create a DeviceClient instance with the provisioned device configuration + client = DeviceClient(provisioning_response.result) + await client.connect() - THINGSBOARD_HOST = "mqtt.thingsboard.cloud" + # Send single telemetry entry to the provisioned device + await client.send_timeseries(TimeseriesEntry("batteryLevel", randint(0, 100))) - credentials = TBDeviceMqttClient.provision(THINGSBOARD_HOST, "PROVISION_DEVICE_KEY", "PROVISION_DEVICE_SECRET") + await client.stop() - if credentials is not None and credentials.get("status") == "SUCCESS": - username = None - password = None - client_id = None - if credentials["credentialsType"] == "ACCESS_TOKEN": - username = credentials["credentialsValue"] - elif credentials["credentialsType"] == "MQTT_BASIC": - username = credentials["credentialsValue"]["userName"] - password = credentials["credentialsValue"]["password"] - client_id = credentials["credentialsValue"]["clientId"] - client = TBDeviceMqttClient(THINGSBOARD_HOST, username=username, password=password, client_id=client_id) - client.connect() - # Other code - - client.stop() - - -if __name__ == '__main__': - main() +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/device/client_rpc_request.py b/examples/device/client_rpc_request.py deleted file mode 100644 index fb7f204..0000000 --- a/examples/device/client_rpc_request.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2025. ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time -import logging -from tb_device_mqtt import TBDeviceMqttClient -logging.basicConfig(level=logging.INFO) - - -def callback(request_id, resp_body, exception=None): - if exception is not None: - logging.error("Exception: " + str(exception)) - else: - logging.info("request id: {request_id}, response body: {resp_body}".format(request_id=request_id, - resp_body=resp_body)) - - -def main(): - client = TBDeviceMqttClient("127.0.0.1", username="A2_TEST_TOKEN") - - client.connect() - # call "getTime" on server and receive result, then process it with callback - client.send_rpc_call("getTime", {}, callback) - try: - while not client.stopped: - time.sleep(1) - except KeyboardInterrupt: - client.disconnect() - client.stop() - - -if __name__ == '__main__': - main() diff --git a/examples/device/firmware_update.py b/examples/device/firmware_update.py index 0c5b445..7063ef7 100644 --- a/examples/device/firmware_update.py +++ b/examples/device/firmware_update.py @@ -1,37 +1,55 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -import time +# Example script to update firmware using the DeviceClient + +import asyncio import logging -from tb_device_mqtt import TBDeviceMqttClient, FW_STATE_ATTR -logging.basicConfig(level=logging.INFO) +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.service.device.client import DeviceClient + +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.INFO) +logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) + +firmware_received = asyncio.Event() +firmware_update_timeout = 30 + +config = DeviceConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" + + +async def firmware_update_callback(firmware_data, firmware_info): + logger.info(f"Firmware update payload received: {firmware_info}") + firmware_received.set() + +async def main(): -def main(): - client = TBDeviceMqttClient("127.0.0.1", username="A2_TEST_TOKEN") - client.connect() + client = DeviceClient(config) + await client.connect() - client.get_firmware_update() + await client.update_firmware(on_received_callback=firmware_update_callback, save_firmware=False) # Set save_firmware to True if you want to save the firmware data - # Waiting for firmware to be delivered - while not client.current_firmware_info[FW_STATE_ATTR] == 'UPDATED': - time.sleep(1) + await asyncio.wait_for(firmware_received.wait(), timeout=firmware_update_timeout) - client.disconnect() - client.stop() + await client.stop() -if __name__ == '__main__': - main() +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/device/handle_attribute_updates.py b/examples/device/handle_attribute_updates.py new file mode 100644 index 0000000..5da55b9 --- /dev/null +++ b/examples/device/handle_attribute_updates.py @@ -0,0 +1,56 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Example script to handle attribute updates from ThingsBoard using the DeviceClient + +import asyncio +import logging + +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.service.device.client import DeviceClient + +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.INFO) +logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) + +config = DeviceConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" + +async def attribute_update_callback(update: AttributeUpdate): + logger.info("Received attribute update: %r", update) + + +async def main(): + + client = DeviceClient(config) + client.set_attribute_update_callback(attribute_update_callback) + + await client.connect() + logger.info("Waiting for attribute updates... Press Ctrl+C to stop.") + + try: + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + logger.info("Shutting down...") + + await client.stop() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/device/handle_rpc_requests.py b/examples/device/handle_rpc_requests.py new file mode 100644 index 0000000..24da791 --- /dev/null +++ b/examples/device/handle_rpc_requests.py @@ -0,0 +1,61 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Example script to handle RPC requests from ThingsBoard using the DeviceClient + +import asyncio +import logging + +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.service.device.client import DeviceClient + +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.INFO) +logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) + +config = DeviceConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" + +async def rpc_request_callback(request: RPCRequest) -> RPCResponse: + logger.info("Received RPC: %r", request) + + if request.method == "ping": + return RPCResponse.build(request_id=request.request_id, result={"pong": True}) + else: + return RPCResponse.build(request_id=request.request_id, result={"message": "Unknown method"}) + + +async def main(): + + client = DeviceClient(config) + client.set_rpc_request_callback(rpc_request_callback) + + await client.connect() + logger.info("Waiting for RPCs... Press Ctrl+C to stop.") + try: + while True: + await asyncio.sleep(1) + except (KeyboardInterrupt, asyncio.CancelledError): + logger.info("Shutting down...") + + await client.stop() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/device/hardware_specs_sender.py b/examples/device/hardware_specs_sender.py deleted file mode 100644 index 48f0291..0000000 --- a/examples/device/hardware_specs_sender.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2025. ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time -import logging -try: - import psutil -except ImportError: - print("Please install psutil using 'pip install psutil' command") - exit(1) -from tb_device_mqtt import TBDeviceMqttClient -# this example illustrates situation, where client send cpu and memory usage every 5 seconds. -# If client receives an update of uploadFrequency attribute, it changes frequency of other attributes publishing. -# Also client is listening to rpc and responds immediately to corresponding rpc methods from server -# ('getCPULoad' or 'getMemoryUsage') - -logging.basicConfig(level=logging.DEBUG) -uploadFrequency = 5 - - -# this callback changes global variable defining how often telemetry is sent -def on_upload_frequency_change(value, error): - global uploadFrequency - if "uploadFrequency" in value: - uploadFrequency = int(value["uploadFrequency"]) - elif "shared" in value and "uploadFrequency" in value["shared"]: - uploadFrequency = int(value["shared"]["uploadFrequency"]) - - -# dependently of request method we send different data back -def on_server_side_rpc_request(request_id, request_body): - print(client, request_id, request_body) - if request_body["method"] == "getCPULoad": - client.send_rpc_reply(request_id, {"CPU percent": psutil.cpu_percent(interval=0.1)}) - elif request_body["method"] == "getMemoryUsage": - client.send_rpc_reply(request_id, {"Memory": psutil.virtual_memory().percent}) - else: - print("Unknown method: " + request_body["method"]) - client.send_rpc_reply(request_id, "Unknown method: " + request_body["method"]) - - -client = TBDeviceMqttClient("127.0.0.1", username="A2_TEST_TOKEN") -client.set_server_side_rpc_request_handler(on_server_side_rpc_request) -client.connect() -# to fetch the latest setting for upload frequency configured on the server -client.request_attributes(shared_keys=["uploadFrequency"], callback=on_upload_frequency_change) -# to subscribe to future changes of upload frequency -client.subscribe_to_attribute(key="uploadFrequency", callback=on_upload_frequency_change) - - -def main(): - while True: - client.send_telemetry({"cpu": psutil.cpu_percent(), "memory": psutil.virtual_memory().percent}) - print("Sleeping for " + str(uploadFrequency)) - time.sleep(uploadFrequency) - - -if __name__ == '__main__': - main() diff --git a/examples/device/load.py b/examples/device/load.py new file mode 100644 index 0000000..7e5c20a --- /dev/null +++ b/examples/device/load.py @@ -0,0 +1,194 @@ +import asyncio +import logging +import signal +import time +from random import randint + +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.service.device.client import DeviceClient + +# --- Logging setup --- +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.INFO) +logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) + +# --- Constants --- +THINGSBOARD_HOST = "localhost" +ACCESS_TOKEN = "YOUR_ACCESS_TOKEN" # Replace with your actual access token +BATCH_SIZE = 1000 +MAX_PENDING_BATCHES = 100 +FUTURE_TIMEOUT = 1.0 + + +async def attribute_update_callback(update: AttributeUpdate): + logger.info("Received attribute update: %r", update) + + +async def rpc_request_callback(request: RPCRequest): + logger.info("Received RPC request: %r", request) + return RPCResponse(request_id=request.request_id, result={"status": "ok"}) + + +async def collect_futures_with_counts(future_or_result): + """ + Accepts either a Future or a direct PublishResult. + Returns [(Future or result, datapoint_count)] if successful, else []. + """ + if future_or_result is None: + return [] + + if isinstance(future_or_result, PublishResult): + if future_or_result.is_successful(): + return [(future_or_result, future_or_result.datapoints_count)] + return [] + + if isinstance(future_or_result, list): + results = [] + for item in future_or_result: + if isinstance(item, PublishResult): + if item.is_successful(): + results.append((item, item.datapoints_count)) + elif asyncio.isfuture(item) or asyncio.iscoroutine(item): + results.extend(await collect_futures_with_counts(item)) + else: + logger.warning("Unexpected item in list: %r", item) + return results + + if asyncio.isfuture(future_or_result) or asyncio.iscoroutine(future_or_result): + try: + result = await asyncio.wait_for(asyncio.shield(future_or_result), timeout=FUTURE_TIMEOUT) + if isinstance(result, PublishResult): + return [(future_or_result, result.datapoints_count)] + else: + logger.warning("Unexpected publish result: %r", result) + except asyncio.TimeoutError: + logger.warning("Timeout waiting for delivery future.") + except Exception as e: + logger.warning("Failed to get result from future: %s", e) + return [] + + logger.warning("collect_futures_with_counts() got unexpected type: %r", type(future_or_result)) + return [] + + +async def process_pending_futures(pending_futures, delivered_batches, delivered_datapoints): + try: + all_futures = [fut for fut, _ in pending_futures] + done, _ = await asyncio.wait(all_futures, timeout=FUTURE_TIMEOUT, return_when=asyncio.ALL_COMPLETED) + except Exception as e: + logger.warning("Wait exception: %s", e) + return delivered_batches, delivered_datapoints + + for fut, count in pending_futures: + if fut in done and fut.done() and not fut.cancelled(): + try: + result = fut.result() + if isinstance(result, PublishResult) and result.is_successful(): + delivered_batches += 1 + delivered_datapoints += result.datapoints_count + except Exception as e: + logger.warning("Future error: %s", e) + elif not fut.done(): + fut.cancel() + logger.warning("Cancelled future %r due to timeout.", getattr(fut, "uuid", None)) + + return delivered_batches, delivered_datapoints + + +async def main(): + stop_event = asyncio.Event() + + def _shutdown_handler(): + stop_event.set() + + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, _shutdown_handler) + except NotImplementedError: + signal.signal(sig, lambda *_: _shutdown_handler()) + + config = DeviceConfig() + config.host = THINGSBOARD_HOST + config.access_token = ACCESS_TOKEN + + client = DeviceClient(config) + client.set_attribute_update_callback(attribute_update_callback) + client.set_rpc_request_callback(rpc_request_callback) + + await client.connect() + client._message_adapter.datapoints_max_count = 0 + logger.info("Connected to ThingsBoard.") + + sent_batches = 0 + delivered_batches = 0 + delivered_datapoints = 0 + pending_futures = [] + + delivery_start_ts = None + delivery_end_ts = None + + try: + delivery_start_ts = time.perf_counter() + # ts_now = int(datetime.now(UTC).timestamp() * 1000) + # You can add ts=ts_now to try with entities with ts, it will group messages by ts and send them together + entries = [TimeseriesEntry(f"temperature{i}", randint(20, 40)) for i in range(BATCH_SIZE)] + while not stop_event.is_set(): + + try: + fut = await client.send_timeseries(entries, wait_for_publish=False) + future_pairs = await collect_futures_with_counts(fut) + if future_pairs: + pending_futures.extend(future_pairs) + else: + logger.warning("No child delivery futures detected (batch dropped?).") + + sent_batches += 1 + + except Exception as e: + logger.warning("Failed to publish telemetry batch: %s", e) + + if len(pending_futures) >= MAX_PENDING_BATCHES: + delivered_batches, delivered_datapoints = await process_pending_futures( + pending_futures, delivered_batches, delivered_datapoints + ) + pending_futures.clear() + + if sent_batches % 10 == 0: + logger.info("Sent %d batches so far...", sent_batches) + + await asyncio.sleep(0) # yield control efficiently + + finally: + logger.info("Waiting for remaining telemetry batches to be acknowledged...") + + if pending_futures: + delivered_batches, delivered_datapoints = await process_pending_futures( + pending_futures, delivered_batches, delivered_datapoints + ) + delivery_end_ts = time.perf_counter() + + await client.disconnect() + logger.info("Disconnected cleanly.") + + if delivery_start_ts and delivery_end_ts: + delivery_duration = delivery_end_ts - delivery_start_ts + logger.info("Delivered %d batches / %d datapoints in %.6f seconds (%.0f datapoints/sec)", + delivered_batches, delivered_datapoints, delivery_duration, + delivered_datapoints / delivery_duration if delivery_duration > 0 else 0) + else: + logger.warning("No successful delivery occurred.") + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except (KeyboardInterrupt, asyncio.CancelledError): + print("Interrupted by user.") diff --git a/examples/device/operational_example.py b/examples/device/operational_example.py new file mode 100644 index 0000000..57c1994 --- /dev/null +++ b/examples/device/operational_example.py @@ -0,0 +1,219 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +import signal +from random import uniform, randint + +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.service.device.client import DeviceClient + +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.DEBUG) +logging.getLogger("tb_mqtt_client").setLevel(logging.DEBUG) + +DELAY_BETWEEN_DATA_PUBLISH = 1 # seconds + + +async def attribute_update_callback(update: AttributeUpdate): + """ + Callback function to handle attribute updates. + :param update: The attribute update object. + """ + logger.info("Received attribute update: %r", update) + + +async def rpc_request_callback(request: RPCRequest) -> RPCResponse: + """ + Callback function to handle RPC requests. + :param request: The RPC request object. + :return: An RPCResponse object. + """ + logger.info("Received RPC request: %r", request) + + if request.method == "getError": + # Simulate an error response for demonstration purposes + logger.error("Simulated error for method: %s", request.method) + try: + # Simulate some processing that raises an error + raise RuntimeError("Simulated processing error") + except RuntimeError as e: + return RPCResponse.build(request_id=request.request_id, error=e) + else: + response_data = { + "message": f"Response for method {request.method}", + "params": request.params or {} + } + response = RPCResponse.build(request_id=request.request_id, result=response_data) + return response + + +async def rpc_response_callback(response: RPCResponse): + """ + Callback function to handle RPC responses for client side RPC requests. + :param response: The RPC response object. + """ + logger.info("Received RPC response in callback: %r", response) + + +async def attribute_request_callback(requested_attributes_response: RequestedAttributeResponse): + """ + Callback function to handle requested attributes. + :param requested_attributes_response: The requested attribute response object. + """ + logger.info("Received requested attributes response: %r", requested_attributes_response) + + +async def main(): + stop_event = asyncio.Event() + + def _shutdown_handler(): + stop_event.set() + asyncio.get_event_loop().run_until_complete(client.stop()) + + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, _shutdown_handler) # noqa + except NotImplementedError: + # Windows compatibility fallback + signal.signal(sig, lambda *_: _shutdown_handler()) # noqa + + config = DeviceConfig() + config.host = "localhost" + config.access_token = "YOUR_ACCESS_TOKEN" + + client = DeviceClient(config) + client.set_attribute_update_callback(attribute_update_callback) + client.set_rpc_request_callback(rpc_request_callback) + await client.connect() + + logger.info("Connected to ThingsBoard.") + + while not stop_event.is_set(): + # --- Attributes --- + iteration_start = asyncio.get_event_loop().time() + + # 1. Raw dict + raw_dict = { + "firmwareVersion": "1.0.4", + "hardwareModel": "TB-SDK-Device" + } + logger.info("Sending attributes...") + raw_publish_result = await client.send_attributes(raw_dict) + logger.info(f"Raw attributes sent: {raw_dict} with result: {raw_publish_result}") + + # 2. Single AttributeEntry + single_entry = AttributeEntry("mode", "normal") + logger.info("Sending single attribute: %s", single_entry) + single_attribute_publish_result = await client.send_attributes(single_entry) + logger.info(f"Single attribute sent: {single_entry} with result: {single_attribute_publish_result}") + + # 3. List of AttributeEntry + attr_entries = [ + AttributeEntry("maxTemperature", 85), + AttributeEntry("calibrated", True) + ] + logger.info("Sending list of attributes: %s", attr_entries) + attributes_list_publish_result = await client.send_attributes(attr_entries) + logger.info("List of attributes sent: %s with result: %s", attr_entries, attributes_list_publish_result) + + # --- Telemetry --- + + # 1. Raw dict + raw_dict = { + "temperature": round(uniform(20.0, 30.0), 2), + "humidity": 60 + } + logger.info("Sending raw telemetry...") + raw_telemetry_publish_result = await client.send_timeseries(raw_dict) + logger.info(f"Raw telemetry sent: {raw_dict} with result: {raw_telemetry_publish_result}") + + # 2. Single TimeseriesEntry (with ts) + single_entry = TimeseriesEntry("batteryLevel", randint(0, 100)) + logger.info("Sending single telemetry: %s", single_entry) + delivery_future = await client.send_timeseries(single_entry) + logger.info(f"Single telemetry sent: {single_entry} with delivery future: {delivery_future}") + + # 3. List of TimeseriesEntry with mixed timestamps + + telemetry_entries = [] + for i in range(100): + telemetry_entries.append(TimeseriesEntry("temperature%i" % i, i)) + logger.info("Sending list of telemetry entries with mixed timestamps...") + telemetry_list_publish_result = await client.send_timeseries(telemetry_entries) + logger.info("List of telemetry entries sent: %s with result: %s", + len(telemetry_entries) if len(telemetry_entries) > 10 else telemetry_entries, + telemetry_list_publish_result) + + # --- Attribute Request --- + + logger.info("Requesting attributes...") + + attribute_request = await AttributeRequest.build(["uno"], ["client"]) + + logger.info("Sending attribute request: %r", attribute_request) + + await client.send_attribute_request(attribute_request, attribute_request_callback) + + # --- Client-side RPC --- + + logger.info("Sending client side RPC request...") + + rpc_request = await RPCRequest.build("getSomeInformation", {"key1": "value1"}) + + logger.info("Sending RPC request: %r", rpc_request) + + try: + rpc_response = await client.send_rpc_request(rpc_request) + except TimeoutError as e: + logger.error("Timeout while sending RPC request: %s", e) + rpc_response = None + + if rpc_response: + logger.info("RPC response received: %s", rpc_response) + + rpc_request_2 = await RPCRequest.build("getAnotherInformation", {"param": "value"}) + + logger.info("Sending another RPC request: %r", rpc_request_2) + + await client.send_rpc_request(rpc_request_2, rpc_response_callback, wait_for_publish=False) + + try: + logger.info("Waiting before next iteration...") + timeout = DELAY_BETWEEN_DATA_PUBLISH - (asyncio.get_event_loop().time() - iteration_start) + await asyncio.wait_for(stop_event.wait(), timeout=timeout) + except asyncio.TimeoutError: + logger.info("Going to next iteration...") + + logger.info("Disconnected cleanly.") + + +if __name__ == "__main__": + try: + loop = asyncio.get_event_loop() + loop.set_debug(False) # Enable debug mode for asyncio + loop.run_until_complete(main()) + except KeyboardInterrupt: + print("Interrupted by user.") diff --git a/examples/device/request_attributes.py b/examples/device/request_attributes.py index e4529d5..ea60339 100644 --- a/examples/device/request_attributes.py +++ b/examples/device/request_attributes.py @@ -1,47 +1,67 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Example script to request attributes from ThingsBoard using the DeviceClient + +import asyncio import logging -import time -from tb_device_mqtt import TBDeviceMqttClient -logging.basicConfig(level=logging.INFO) +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.service.device.client import DeviceClient + +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.INFO) +logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) + +response_received = asyncio.Event() + +config = DeviceConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" + +async def attribute_request_callback(response: RequestedAttributeResponse): + logger.info("Received attribute response: %r", response) + response_received.set() -def on_attributes_change(result, exception=None): - # This is a callback function that will be called when we receive the response from the server - if exception is not None: - logging.error("Exception: " + str(exception)) - else: - logging.info(result) +async def main(): -def main(): - client = TBDeviceMqttClient("127.0.0.1", username="A2_TEST_TOKEN") - client.connect() - # Sending data to retrieve it later - client.send_attributes({"atr1": "value1", "atr2": "value2"}) - # Requesting attributes - client.request_attributes(["atr1", "atr2"], callback=on_attributes_change) + client = DeviceClient(config) + await client.connect() + + # Send client attribute to have it available for request + await client.send_attributes({"currentTemperature": 22.5}) + + await asyncio.sleep(.1) + + # Request specific attributes + request = await AttributeRequest.build(["targetTemperature"], ["currentTemperature"]) + await client.send_attribute_request(request, attribute_request_callback) + + logger.info("Attribute request sent. Waiting for response...") try: - # Waiting for the callback - while not client.stopped: - time.sleep(1) - except KeyboardInterrupt: - client.disconnect() - client.stop() + await asyncio.wait_for(response_received.wait(), timeout=10) + except (asyncio.CancelledError, TimeoutError): + logger.info("Attribute request cancelled.") + + await client.stop() -if __name__ == '__main__': - main() +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/device/send_attributes.py b/examples/device/send_attributes.py new file mode 100644 index 0000000..04aa74d --- /dev/null +++ b/examples/device/send_attributes.py @@ -0,0 +1,69 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Example script to send attributes to ThingsBoard using the DeviceClient + +import asyncio +import logging + +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.service.device.client import DeviceClient + +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.INFO) +logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) + +config = DeviceConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" + + +async def main(): + + client = DeviceClient(config) + + await client.connect() + + # Send attribute as raw dictionary + raw_attributes = { + "firmwareVersion": "1.0.3", + "hardwareModel": "TB-SDK-Device" + } + logger.info("Sending raw attributes: %s", raw_attributes) + await client.send_attributes(raw_attributes) + logger.info("Raw attributes sent successfully.") + + # Send single attribute entry + single_attribute = AttributeEntry("mode", "normal") + logger.info("Sending single attribute: %s", single_attribute) + await client.send_attributes(single_attribute) + logger.info("Single attribute sent successfully.") + + # Send list of attributes + attribute_list = [ + AttributeEntry("location", "Building A"), + AttributeEntry("status", "active") + ] + logger.info("Sending list of attributes: %s", attribute_list) + await client.send_attributes(attribute_list) + logger.info("List of attributes sent successfully.") + + await client.stop() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/device/send_client_side_rpc.py b/examples/device/send_client_side_rpc.py new file mode 100644 index 0000000..0020a25 --- /dev/null +++ b/examples/device/send_client_side_rpc.py @@ -0,0 +1,69 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Example script to send a client-side RPC request to ThingsBoard using the DeviceClient + +import asyncio +import logging + +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.service.device.client import DeviceClient + +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.INFO) +logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) + +config = DeviceConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" + +response_received = asyncio.Event() + +async def rpc_response_callback(response: RPCResponse): + logger.info("Received RPC response in callback: %r", response) + response_received.set() + +rpc_response = None + +async def main(): + global rpc_response + + client = DeviceClient(config) + await client.connect() + + # Send client-side RPC and wait for response + rpc_request = await RPCRequest.build("getTime", {}) + try: + rpc_response = await client.send_rpc_request(rpc_request) + logger.info("Received RPC response: %r", rpc_response) + except TimeoutError: + logger.info("RPC request timed out") + + # Send client-side RPC with response callback + rpc_request_2 = await RPCRequest.build("getStatus", {"param1": "value1", "param2": "value2"}) + await client.send_rpc_request(rpc_request_2, rpc_response_callback, wait_for_publish=False) + + await asyncio.wait_for(response_received.wait(), timeout=30) + await client.stop() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("Shutting down...") diff --git a/examples/device/send_telemetry_and_attr.py b/examples/device/send_telemetry_and_attr.py deleted file mode 100644 index 1f734db..0000000 --- a/examples/device/send_telemetry_and_attr.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2025. ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from tb_device_mqtt import TBDeviceMqttClient, TBPublishInfo -from time import sleep, time -logging.basicConfig(level=logging.DEBUG) - -telemetry = {"temperature": 41.9, "humidity": 69, "enabled": False, "currentFirmwareVersion": "v1.2.2"} -telemetry_as_array = [{"temperature": 42.0}, {"humidity": 70}, {"enabled": True}, {"currentFirmwareVersion": "v1.2.3"}] -telemetry_with_ts = {"ts": int(round(time() * 1000)), "values": {"temperature": 42.1, "humidity": 70}} -telemetry_with_ts_as_array = [{"ts": 1451649600000, "values": {"temperature": 42.2, "humidity": 71}}, - {"ts": 1451649601000, "values": {"temperature": 42.3, "humidity": 72}}] -attributes = {"sensorModel": "DHT-22", "attribute_2": "value"} - -log = logging.getLogger(__name__) - - -def on_connect(client, userdata, flags, result_code, *extra_params, tb_client): - if result_code == 0: - log.info("Connected to ThingsBoard!") - else: - log.error("Failed to connect to ThingsBoard with result code: %d", result_code) - - -def main(): - client = TBDeviceMqttClient("127.0.0.1", username="A2_TEST_TOKEN") - client.connect(callback=on_connect) - - while not client.is_connected(): - sleep(1) - - # Sending data in async way - client.send_attributes(attributes) - client.send_telemetry(telemetry) - client.send_telemetry(telemetry_as_array, quality_of_service=1) - client.send_telemetry(telemetry_with_ts) - client.send_telemetry(telemetry_with_ts_as_array) - - # Waiting for data to be delivered - result = client.send_attributes(attributes) - log.info("Attribute update sent: " + str(result.rc() == TBPublishInfo.TB_ERR_SUCCESS)) - result = client.send_attributes(attributes) - log.info("Telemetry update sent: " + str(result.rc() == TBPublishInfo.TB_ERR_SUCCESS)) - - client.disconnect() - - -if __name__ == '__main__': - main() diff --git a/examples/device/send_telemetry_pack.py b/examples/device/send_telemetry_pack.py deleted file mode 100644 index 0a426ee..0000000 --- a/examples/device/send_telemetry_pack.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2025. ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from tb_device_mqtt import TBDeviceMqttClient, TBPublishInfo -import time - -logging.basicConfig(level=logging.DEBUG) - -telemetry_with_ts = {"ts": int(round(time.time() * 1000)), "values": {"temperature": 42.1, "humidity": 70}} - - -def main(): - client = TBDeviceMqttClient("127.0.0.1", username="A2_TEST_TOKEN") - client.connect() - - results = [] - result = True - - for i in range(0, 100): - results.append(client.send_telemetry( - {"ts": int(round(time.time() * 1000)), "values": {"temperature": 42.1, "humidity": 70}})) - - for tmp_result in results: - result &= tmp_result.get() == TBPublishInfo.TB_ERR_SUCCESS - - print("Result " + str(result)) - - client.stop() - - -if __name__ == '__main__': - main() diff --git a/examples/device/send_timeseries.py b/examples/device/send_timeseries.py new file mode 100644 index 0000000..2afb2b1 --- /dev/null +++ b/examples/device/send_timeseries.py @@ -0,0 +1,74 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This example demonstrates how to send time series data from a device to ThingsBoard using the DeviceClient. + +import asyncio +import logging +from random import uniform, randint +from time import time + +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.service.device.client import DeviceClient + +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.INFO) +logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) + + +config = DeviceConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" + + +async def main(): + + client = DeviceClient(config) + await client.connect() + + # Send time series as raw dictionary + raw_timeseries = { + "temperature": round(uniform(20.0, 30.0), 2), + "humidity": randint(30, 70) + } + logger.info("Sending raw timeseries: %s", raw_timeseries) + await client.send_timeseries(raw_timeseries) + logger.info("Raw timeseries sent successfully.") + + # Send single time series entry + single_entry = TimeseriesEntry("pressure", round(uniform(950.0, 1050.0), 2)) + logger.info("Sending single timeseries entry: %s", single_entry) + await client.send_timeseries(single_entry) + logger.info("Single timeseries entry sent successfully.") + + # Send a list of time series entries + ts = int(time() * 1000) + entries = [ + TimeseriesEntry("vibration", 0.05, ts), + TimeseriesEntry("speed", 123, ts), + TimeseriesEntry("vibration", 0.01, ts - 1000), + TimeseriesEntry("speed", 120, ts - 1000), + ] + logger.info("Sending list of timeseries entries: %s", entries) + await client.send_timeseries(entries) + logger.info("List of timeseries entries sent successfully.") + + await client.stop() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/device/subscription_to_attrs.py b/examples/device/subscription_to_attrs.py deleted file mode 100644 index 5e5c93f..0000000 --- a/examples/device/subscription_to_attrs.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2025. ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import time - -from tb_device_mqtt import TBDeviceMqttClient -logging.basicConfig(level=logging.DEBUG) - - -def callback(result, *args): - logging.info("Received data: %r", result) - - -def main(): - client = TBDeviceMqttClient("127.0.0.1", username="A2_TEST_TOKEN") - client.connect() - sub_id_1 = client.subscribe_to_attribute("frequency", callback) - sub_id_2 = client.subscribe_to_all_attributes(callback) - try: - while True: - time.sleep(1) - except KeyboardInterrupt: - client.unsubscribe_from_attribute(sub_id_1) - client.unsubscribe_from_attribute(sub_id_2) - client.disconnect() - - -if __name__ == '__main__': - main() diff --git a/examples/device/tls_connect.py b/examples/device/tls_connect.py index 606f9aa..aaa99ca 100644 --- a/examples/device/tls_connect.py +++ b/examples/device/tls_connect.py @@ -1,25 +1,68 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This example demonstrates how to connect to ThingsBoard over SSL using the DeviceClient and send time series. + +import asyncio import logging -from tb_device_mqtt import TBDeviceMqttClient -import socket - -logging.basicConfig(level=logging.DEBUG) -# connecting to localhost -client = TBDeviceMqttClient(socket.gethostname(), username="A2_TEST_TOKEN") -client.connect(tls=True, - ca_certs="mqttserver.pub.pem", - cert_file="mqttclient.nopass.pem") -client.disconnect() +from random import randint + +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.service.device.client import DeviceClient + +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.INFO) +logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) + +PLATFORM_HOST = 'localhost' # Update with your ThingsBoard host +PLATFORM_PORT = 8883 # Default port for MQTT over SSL + +# Update with your CA certificate, client certificate, and client key paths. There are no default files generated. +# You can generate them using the following guides: +# Certificates for server - https://thingsboard.io/docs/user-guide/mqtt-over-ssl/ +# Certificates for client - https://thingsboard.io/docs/user-guide/certificates/?ubuntuThingsboardX509=X509Leaf +CA_CERT_PATH = "ca_cert.pem" # Update with your CA certificate path (Default - ca_cert.pem in the examples directory) +CLIENT_CERT_PATH = "cert.pem" # Update with your client certificate path (Default - cert.pem in the examples directory) +CLIENT_KEY_PATH = "key.pem" # Update with your client key path (Default - key.pem in the examples directory) + + +async def main(): + config = DeviceConfig() + + config.host = PLATFORM_HOST + config.port = PLATFORM_PORT + + config.ca_cert = CA_CERT_PATH + config.client_cert = CLIENT_CERT_PATH + config.private_key = CLIENT_KEY_PATH + + client = DeviceClient(config) + await client.connect() + + # Send telemetry entry + result = await client.send_timeseries(TimeseriesEntry("batteryLevel", randint(0, 100)), + wait_for_publish=True) + if result is not None and result.is_successful(): + logger.info("Telemetry sent successfully") + else: + logger.error(f"Failed to send telemetry: {result}") + + await client.stop() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/gateway/__init__.py b/examples/gateway/__init__.py new file mode 100644 index 0000000..978ca26 --- /dev/null +++ b/examples/gateway/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is required to make the examples usable in blackbox tests. diff --git a/examples/gateway/claim_device.py b/examples/gateway/claim_device.py new file mode 100644 index 0000000..3632c7c --- /dev/null +++ b/examples/gateway/claim_device.py @@ -0,0 +1,85 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.claim_request import ClaimRequest +from tb_mqtt_client.entities.gateway.gateway_claim_request import GatewayClaimRequestBuilder +from tb_mqtt_client.service.gateway.client import GatewayClient + +configure_logging() +logger = get_logger(__name__) + +CLAIMING_SECRET_KEY = "YOUR_CLAIMING_SECRET" # Replace it with your actual claiming secret key +CLAIMING_DURATION_MS = 30000 # Duration in milliseconds for which the claim is valid + + +async def main(): + config = GatewayConfig() + config.host = "localhost" + config.access_token = "YOUR_ACCESS_TOKEN" + client = GatewayClient(config) + + await client.connect() + + # Connecting device + + device_name = "Device for claiming" + logger.info("Connecting device: %s", device_name) + device_session, publish_results = await client.connect_device(device_name, wait_for_publish=True) + + if not device_session: + logger.error("Failed to register device: %s", device_name) + return + + logger.info("Device connected successfully: %s", device_name) + + # Claiming request creation + logger.info("Creating claim request for device: %s", device_name) + + device_claim_request = ClaimRequest.build(CLAIMING_SECRET_KEY, CLAIMING_DURATION_MS) + + # Note: You can claim several devices at once by adding them to the gateway claim request + gateway_claim_request_builder = GatewayClaimRequestBuilder() + gateway_claim_request_builder.add_device_request(device_session, device_claim_request) + + # Send claim request to the platform + publish_results = await client.send_device_claim_request(device_session, + gateway_claim_request_builder.build(), + wait_for_publish=True) + if not publish_results: + logger.error("Failed to send claim request for device: %s", device_name) + return + + logger.info( + "Claim request sent successfully for device: %s, " + "you have %r seconds to claim device using ThingsBoard UI or API.", + device_name, CLAIMING_DURATION_MS / 1000) + + # Disconnect device + logger.info("Disconnecting device: %s", device_name) + await client.disconnect_device(device_session, wait_for_publish=True) + logger.info("Device disconnected successfully: %s", device_name) + + # Disconnect client + await client.disconnect() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("Keyboard interrupt received, shutting down.") diff --git a/examples/gateway/claiming_device_pe_only.py b/examples/gateway/claiming_device_pe_only.py deleted file mode 100644 index 4fdd8a5..0000000 --- a/examples/gateway/claiming_device_pe_only.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2025. ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from tb_gateway_mqtt import TBGatewayMqttClient -logging.basicConfig(level=logging.DEBUG) - -THINGSBOARD_HOST = "127.0.0.1" -GATEWAY_ACCESS_TOKEN = "GATEWAY_ACCESS_TOKEN" - -DEVICE_NAME = "DEVICE_NAME" -SECRET_KEY = "DEVICE_SECRET_KEY" # Customer should write this key in device claiming widget -DURATION = 30000 # In milliseconds (30 seconds) - - -def main(): - client = TBGatewayMqttClient(THINGSBOARD_HOST, username=GATEWAY_ACCESS_TOKEN) - client.connect() - - """ - You are able to provide every parameter or pass claiming request like: - request_example = { - "DEVICE A": { - "secretKey": "DEVICE_A_SECRET_KEY", - "durationMs": "30000" - }, - "DEVICE B": { - "secretKey": "DEVICE_B_SECRET_KEY", - "durationMs": "60000" - } - - info = client.gw_claim(claiming_request=request_example).wait_for_publish() - - """ - - client.gw_connect_device(DEVICE_NAME) - - info = client.gw_claim(device_name=DEVICE_NAME, secret_key=SECRET_KEY, duration=DURATION) - - if info.rc() == 0: - print("Claiming request was sent.") - client.stop() - - -if __name__ == '__main__': - main() diff --git a/examples/gateway/connect_and_disconnect_device.py b/examples/gateway/connect_and_disconnect_device.py new file mode 100644 index 0000000..7ed51be --- /dev/null +++ b/examples/gateway/connect_and_disconnect_device.py @@ -0,0 +1,61 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.service.gateway.client import GatewayClient + +configure_logging() +logger = get_logger(__name__) + +config = GatewayConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" + +device_name = "Test Device B1" +device_profile = "Test devices" + +async def main(): + client = GatewayClient(config) + + await client.connect() + + # Connecting device + logger.info("Connecting device: %s", device_name) + device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) + + if not device_session: + logger.error("Failed to register device: %s", device_name) + return + + logger.info("Device connected successfully: %s", device_name) + + # Disconnect device + logger.info("Disconnecting device: %s", device_name) + await client.disconnect_device(device_session, wait_for_publish=True) + logger.info("Device disconnected successfully: %s", device_name) + + await asyncio.sleep(.1) # Wait for the disconnect to complete + + # Disconnect client + await client.disconnect() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("Keyboard interrupt received, shutting down.") diff --git a/examples/gateway/connect_disconnect_device.py b/examples/gateway/connect_disconnect_device.py deleted file mode 100644 index c6a45c9..0000000 --- a/examples/gateway/connect_disconnect_device.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2025. ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from tb_gateway_mqtt import TBGatewayMqttClient -logging.basicConfig(level=logging.INFO) - - -def main(): - gateway = TBGatewayMqttClient("127.0.0.1", username="TEST_GATEWAY_TOKEN") - gateway.connect() - gateway.gw_connect_device("Example Name") - # device disconnecting will not delete device, gateway just stops receiving messages - gateway.gw_disconnect_device("Example Name") - gateway.disconnect() - - -if __name__ == '__main__': - main() diff --git a/examples/gateway/handle_attribute_updates.py b/examples/gateway/handle_attribute_updates.py new file mode 100644 index 0000000..a007358 --- /dev/null +++ b/examples/gateway/handle_attribute_updates.py @@ -0,0 +1,85 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.service.gateway.client import GatewayClient +from tb_mqtt_client.service.gateway.device_session import DeviceSession + +configure_logging() +logger = get_logger(__name__) + +config = GatewayConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" + +device_name = "Test Device B1" +device_profile = "Test devices" + + +async def attribute_update_handler(device_session: DeviceSession, attribute_update: AttributeUpdate): + """ + Callback to handle attribute updates. + :param device_session: Device session for which attributes were requested. + :param attribute_update: Updated attributes object that contains attribute entries. + """ + logger.info("Received attribute update for device %s: %s", + device_session.device_info.device_name, attribute_update.entries) + + +async def main(): + client = GatewayClient(config) + + await client.connect() + + # Connecting device + logger.info("Connecting device: %s", device_name) + device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) + + if not device_session: + logger.error("Failed to register device: %s", device_name) + return + + # Register callback for requested attributes + client.device_manager.set_attribute_update_callback(device_session.device_info.device_id, attribute_update_handler) + + logger.info("Device connected successfully: %s", device_name) + + # Loop to keep the client running and processing Attribute updates + try: + logger.info("Client loop started, waiting for Attribute updates...") + while True: + await asyncio.sleep(1) # Keep the loop running + except (asyncio.CancelledError, KeyboardInterrupt): + logger.info("Client loop stopped, shutting down.") + + # Disconnect device + logger.info("Disconnecting device: %s", device_name) + await client.disconnect_device(device_session, wait_for_publish=True) + logger.info("Device disconnected successfully: %s", device_name) + + await asyncio.sleep(.1) # Wait for disconnection to complete + + # Disconnect client + await client.disconnect() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("Keyboard interrupt received, shutting down.") diff --git a/examples/gateway/handle_rpc_requests.py b/examples/gateway/handle_rpc_requests.py new file mode 100644 index 0000000..6ffa377 --- /dev/null +++ b/examples/gateway/handle_rpc_requests.py @@ -0,0 +1,105 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse +from tb_mqtt_client.service.gateway.client import GatewayClient +from tb_mqtt_client.service.gateway.device_session import DeviceSession + + +configure_logging() +logger = get_logger(__name__) + +config = GatewayConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" + +device_name = "Test Device B1" +device_profile = "Test devices" + + +async def device_rpc_request_handler(device_session: DeviceSession, + rpc_request: GatewayRPCRequest) -> GatewayRPCResponse: + """ + Callback to handle RPC requests from the device. + :param device_session: Device session for which the request was made. + :param rpc_request: RPC request from the platform. + """ + logger.info("Received RPC request for device %s: %s", device_session.device_info.device_name, rpc_request) + + response_data = { + "status": "success", + "message": f"RPC request '{rpc_request.method}' processed successfully.", + "data": { + "device_name": device_session.device_info.device_name, + "request_id": rpc_request.request_id, + "method": rpc_request.method, + "params": rpc_request.params + } + } + + rpc_response = GatewayRPCResponse.build(device_session.device_info.device_name, + rpc_request.request_id, + response_data) + + logger.info("Sending RPC response for request id %r: %r", rpc_request.request_id, rpc_response) + + return rpc_response + + +async def main(): + client = GatewayClient(config) + + await client.connect() + + # Connecting device + + logger.info("Connecting device: %s", device_name) + device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) + + if not device_session: + logger.error("Failed to register device: %s", device_name) + return + + # Register callback for requested attributes + client.device_manager.set_rpc_request_callback(device_session.device_info.device_id, device_rpc_request_handler) + + logger.info("Device connected successfully: %s", device_name) + + # Loop to keep the client running and processing RPC requests + try: + logger.info("Client loop started, waiting for RPC requests...") + while True: + await asyncio.sleep(1) # Keep the loop running + except (asyncio.CancelledError, KeyboardInterrupt): + logger.info("Client loop stopped, shutting down.") + + # Disconnect device + logger.info("Disconnecting device: %s", device_name) + await client.disconnect_device(device_session, wait_for_publish=True) + logger.info("Device disconnected successfully: %s", device_name) + + # Disconnect client + await client.disconnect() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("Keyboard interrupt received, shutting down.") diff --git a/examples/gateway/load.py b/examples/gateway/load.py new file mode 100644 index 0000000..e4972d1 --- /dev/null +++ b/examples/gateway/load.py @@ -0,0 +1,150 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +import signal +import time +from datetime import UTC, datetime +from random import randint +from typing import List + +from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.service.gateway.client import GatewayClient +from tb_mqtt_client.service.gateway.device_session import DeviceSession + +# --- Logging --- +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.INFO) +logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) + +# --- Constants --- +THINGSBOARD_HOST = "localhost" +ACCESS_TOKEN = "YOUR_ACCESS_TOKEN" # Replace with your actual access token +NUM_DEVICES = 10 +BATCH_SIZE = 1000 +MAX_PENDING = 10 +FUTURE_TIMEOUT = 1.0 +DEVICE_PREFIX = "perf-test-device" +WAIT_FOR_PUBLISH = False # Set to True if you want to wait for publish confirmation +# (can slow down the test, because each message will wait for confirmation) + + +# --- Test logic --- +async def send_batch(client: GatewayClient, session: DeviceSession) -> List[asyncio.Future]: + ts = int(datetime.now(UTC).timestamp() * 1000) + entries = [TimeseriesEntry(f"temp{i}", randint(20, 35), ts=ts) for i in range(BATCH_SIZE)] + result = await client.send_device_timeseries(session, entries, wait_for_publish=WAIT_FOR_PUBLISH) + if isinstance(result, list): + return result + elif result is not None: + return [result] + return [] + + +async def wait_for_futures(futures: List[asyncio.Future]) -> int: + delivered = 0 + if isinstance(futures, list) and futures and isinstance(futures[0], asyncio.Future): + done, _ = await asyncio.wait(futures, timeout=FUTURE_TIMEOUT, return_when=asyncio.ALL_COMPLETED) + else: + done = futures + + for fut in done: + try: + if isinstance(fut, asyncio.Future): + res = fut.result() + else: + res = fut + if isinstance(res, PublishResult) and res.is_successful(): + delivered += res.datapoints_count + except Exception as e: + logger.warning("Future error: %s", e) + return delivered + + +async def main(): + stop_event = asyncio.Event() + + def _shutdown(): + logger.info("Shutting down by signal...") + stop_event.set() + + loop = asyncio.get_event_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, _shutdown) + except NotImplementedError: + signal.signal(sig, lambda *_: _shutdown()) + + config = GatewayConfig() + config.host = THINGSBOARD_HOST + config.access_token = ACCESS_TOKEN + + client = GatewayClient(config) + await client.connect() + + logger.info("Connected to ThingsBoard as gateway") + + # Register test devices + sessions: List[DeviceSession] = [] + for i in range(NUM_DEVICES): + device_name = f"{DEVICE_PREFIX}-{i}" + session, _ = await client.connect_device(device_name, wait_for_publish=False) + sessions.append(session) + + logger.info("Registered %d devices", len(sessions)) + + sent_batches = 0 + delivered_dp = 0 + pending_futures: List[asyncio.Future] = [] + + start = time.perf_counter() + + try: + while not stop_event.is_set(): + for session in sessions: + futures = await send_batch(client, session) + pending_futures.extend(futures) + sent_batches += 1 + + if len(pending_futures) >= MAX_PENDING: + delivered = await wait_for_futures(pending_futures) + delivered_dp += delivered + pending_futures.clear() + logger.info("Delivered datapoints so far: %d", delivered_dp) + + await asyncio.sleep(0) # yield to event loop + + finally: + logger.info("Flushing remaining futures...") + if pending_futures: + delivered_dp += await wait_for_futures(pending_futures) + end = time.perf_counter() + + duration = end - start + logger.info("Sent %d batches across %d devices", sent_batches, NUM_DEVICES) + logger.info("Delivered %d datapoints in %.2f seconds (%.0f datapoints/sec)", + delivered_dp, duration, delivered_dp / duration if duration > 0 else 0) + await client.stop() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except (KeyboardInterrupt, asyncio.CancelledError): + print("Interrupted by user.") diff --git a/examples/gateway/operational_example.py b/examples/gateway/operational_example.py new file mode 100644 index 0000000..a15c9c7 --- /dev/null +++ b/examples/gateway/operational_example.py @@ -0,0 +1,178 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +import signal +from random import uniform, randint +from time import time + +from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse +from tb_mqtt_client.service.gateway.client import GatewayClient +from tb_mqtt_client.service.gateway.device_session import DeviceSession + +# ---- Logging Setup ---- +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.DEBUG) +logging.getLogger("tb_mqtt_client").setLevel(logging.DEBUG) + +# ---- Constants ---- +GATEWAY_HOST = "localhost" +GATEWAY_PORT = 1883 # Default MQTT port, change if needed +GATEWAY_ACCESS_TOKEN = "YOUR_ACCESS_TOKEN" # Replace with your actual access token + +DELAY_BETWEEN_DATA_PUBLISH = 1 # seconds + + +# ---- Callbacks ---- +async def attribute_update_handler(device_session: DeviceSession, update: AttributeUpdate): + logger.info("Received attribute update for %s: %s", + device_session.device_info.device_name, update.entries) + + +async def requested_attributes_handler(device_session: DeviceSession, response: RequestedAttributeResponse): + logger.info("Requested attributes for %s -> client: %r, shared: %r", + device_session.device_info.device_name, + response.client, + response.shared) + + +async def device_rpc_request_handler(device_session: DeviceSession, + rpc_request: GatewayRPCRequest) -> GatewayRPCResponse: + logger.info("Received RPC request for %s: %r", device_session.device_info.device_name, rpc_request) + response_data = { + "status": "success", + "echo_method": rpc_request.method, + "params": rpc_request.params + } + return GatewayRPCResponse.build(device_session.device_info.device_name, rpc_request.request_id, response_data) + + +# ---- Main ---- +async def main(): + stop_event = asyncio.Event() + + def _shutdown_handler(): + stop_event.set() + asyncio.create_task(client.stop()) + + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, _shutdown_handler) + except NotImplementedError: + signal.signal(sig, lambda *_: _shutdown_handler()) + + # ---- Gateway Config ---- + config = GatewayConfig() + config.host = GATEWAY_HOST + config.port = GATEWAY_PORT + config.access_token = GATEWAY_ACCESS_TOKEN + + global client + client = GatewayClient(config) + + await client.connect() + logger.info("Gateway connected to ThingsBoard.") + + # ---- Register Devices ---- + devices = [ + ("Test Device A1", "Test devices"), + ("Test Device B1", "Test devices") + ] + sessions = {} + + for name, profile in devices: + session, _ = await client.connect_device(name, profile, wait_for_publish=True) + if not session: + logger.error("Failed to connect %s", name) + continue + sessions[name] = session + logger.info("Device connected: %s", name) + + # Register callbacks for each device + client.device_manager.set_attribute_update_callback(session.device_info.device_id, attribute_update_handler) + client.device_manager.set_attribute_response_callback(session.device_info.device_id, + requested_attributes_handler) + client.device_manager.set_rpc_request_callback(session.device_info.device_id, device_rpc_request_handler) + + # ---- Main loop ---- + while not stop_event.is_set(): + iteration_start = time() + + for device_name, session in sessions.items(): + logger.info("Publishing data for %s", device_name) + + # --- Attributes --- + raw_attrs = {"firmwareVersion": "2.0.0", "location": "office"} + await client.send_device_attributes(session, raw_attrs, wait_for_publish=True) + + single_attr = AttributeEntry("mode", "auto") + await client.send_device_attributes(session, single_attr, wait_for_publish=True) + + multi_attrs = [ + AttributeEntry("maxTemp", randint(60, 90)), + AttributeEntry("calibrated", True) + ] + await client.send_device_attributes(session, multi_attrs, wait_for_publish=True) + + # --- Telemetry --- + raw_ts = {"temperature": round(uniform(20, 30), 2), "humidity": randint(40, 80)} + await client.send_device_timeseries(session, raw_ts, wait_for_publish=True) + + list_ts = [ + {"ts": int(time() * 1000), "values": {"temperature": 25.5}}, + {"ts": int(time() * 1000) - 1000, "values": {"humidity": 65}} + ] + await client.send_device_timeseries(session, list_ts, wait_for_publish=True) + + ts_entries = [TimeseriesEntry(f"temp_{i}", randint(0, 50)) for i in range(5)] + await client.send_device_timeseries(session, ts_entries, wait_for_publish=True) + + # --- Attribute Request --- + attr_req = await GatewayAttributeRequest.build( + device_session=session, + client_keys=["firmwareVersion", "mode"] + ) + await client.send_device_attributes_request(session, attr_req, wait_for_publish=True) + + # ---- Delay handling ---- + try: + timeout = DELAY_BETWEEN_DATA_PUBLISH - (time() - iteration_start) + if timeout > 0: + await asyncio.wait_for(stop_event.wait(), timeout=timeout) + except asyncio.TimeoutError: + pass + + # ---- Disconnect devices ---- + for session in sessions.values(): + await client.disconnect_device(session, wait_for_publish=True) + await client.disconnect() + logger.info("Gateway disconnected cleanly.") + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("Interrupted by user.") diff --git a/examples/gateway/request_attributes.py b/examples/gateway/request_attributes.py index 765fa54..d8cf8fe 100644 --- a/examples/gateway/request_attributes.py +++ b/examples/gateway/request_attributes.py @@ -1,45 +1,110 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -import logging -import time +import asyncio -from tb_gateway_mqtt import TBGatewayMqttClient -logging.basicConfig(level=logging.INFO) +from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest +from tb_mqtt_client.service.gateway.client import GatewayClient +from tb_mqtt_client.service.gateway.device_session import DeviceSession -def callback(result, exception=None): - if exception is not None: - logging.error("Exception: " + str(exception)) - else: - logging.info(result) +configure_logging() +logger = get_logger(__name__) +config = GatewayConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" -def main(): - gateway = TBGatewayMqttClient("127.0.0.1", username="TEST_GATEWAY_TOKEN") - gateway.connect() - # Requesting attributes - gateway.gw_request_shared_attributes("Example Name", ["temperature"], callback) +device_name = "Test Device B1" +device_profile = "Test devices" - try: - # Waiting for the callback - while not gateway.stopped: - time.sleep(1) - except KeyboardInterrupt: - gateway.disconnect() - gateway.stop() +async def requested_attributes_handler(device_session: DeviceSession, response: RequestedAttributeResponse): + """ + Callback to handle requested attributes. + :param device_session: Device session for which attributes were requested. + :param response: Response containing requested attributes. + """ + logger.info("Received attributes for device %s, client attributes: %r, shared attributes: %r", + device_session.device_info.device_name, + response.client, + response.shared) + + +async def main(): + client = GatewayClient(config) + + await client.connect() + + # Connecting device + + logger.info("Connecting device: %s", device_name) + device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) + + if not device_session: + logger.error("Failed to register device: %s", device_name) + return + + # Register callback for requested attributes + client.device_manager.set_attribute_response_callback(device_session.device_info.device_id, + requested_attributes_handler) + + logger.info("Device connected successfully: %s", device_name) + + # Send attributes to request them later + attributes = [ + AttributeEntry(key="maintenance", value="scheduled"), + AttributeEntry(key="id", value=341), + AttributeEntry(key="location", value="office") + ] + logger.info("Sending attributes: %s", attributes) + await client.send_device_attributes(device_session=device_session, data=attributes, wait_for_publish=True) + logger.info("Attributes sent successfully.") -if __name__ == '__main__': - main() + await asyncio.sleep(1) # Wait for attributes to be processed + + # Request attributes for the device + logger.info("Requesting attributes for device: %s", device_name) + attributes_to_request = ["maintenance", "id", "location"] + attribute_request = await GatewayAttributeRequest.build(device_session=device_session, + client_keys=attributes_to_request) + + await client.send_device_attributes_request(device_session, attribute_request, wait_for_publish=True) + + # Trying to request shared attribute "state" (Response will be empty if not set) + + logger.info("Requesting shared attribute 'state' for device: %s", device_name) + shared_attribute_request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["state"]) + await client.send_device_attributes_request(device_session, shared_attribute_request, wait_for_publish=True) + + await asyncio.sleep(1) # Wait for the response to be processed + + # Disconnect device + logger.info("Disconnecting device: %s", device_name) + await client.disconnect_device(device_session, wait_for_publish=True) + logger.info("Device disconnected successfully: %s", device_name) + + # Disconnect client + await client.disconnect() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("Keyboard interrupt received, shutting down.") diff --git a/examples/gateway/respond_to_rpc.py b/examples/gateway/respond_to_rpc.py deleted file mode 100644 index 2d15096..0000000 --- a/examples/gateway/respond_to_rpc.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2025. ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging.handlers -import time - -from tb_gateway_mqtt import TBGatewayMqttClient -try: - import psutil -except ImportError: - print("Please install psutil using 'pip install psutil' command") - exit(1) -logging.basicConfig(level=logging.INFO) - - -def rpc_request_response(gateway, request_body): - # request body contains id, method and other parameters - logging.info(request_body) - method = request_body["data"]["method"] - device = request_body["device"] - req_id = request_body["data"]["id"] - # dependently of request method we send different data back - if method == 'getCPULoad': - gateway.gw_send_rpc_reply(device, req_id, psutil.cpu_percent(interval=0.1)) - elif method == 'getMemoryLoad': - gateway.gw_send_rpc_reply(device, req_id, psutil.virtual_memory().percent) - else: - print('Unknown method: ' + method) - - -def main(): - gateway = TBGatewayMqttClient("127.0.0.1", username="TEST_GATEWAY_TOKEN") - gateway.connect() - # now rpc_request_response will process rpc requests from servers - gateway.gw_set_server_side_rpc_request_handler(rpc_request_response) - # without device connection it is impossible to get any messages - gateway.gw_connect_device("Test Device A2", "default") - try: - # Waiting for the callback - while not gateway.stopped: - time.sleep(1) - except KeyboardInterrupt: - gateway.disconnect() - gateway.stop() - - -if __name__ == '__main__': - main() diff --git a/examples/gateway/send_attributes.py b/examples/gateway/send_attributes.py new file mode 100644 index 0000000..34443bd --- /dev/null +++ b/examples/gateway/send_attributes.py @@ -0,0 +1,87 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.service.gateway.client import GatewayClient + + +configure_logging() +logger = get_logger(__name__) + +config = GatewayConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" + +device_name = "Test Device B1" +device_profile = "Test devices" + + +async def main(): + client = GatewayClient(config) + + await client.connect() + + # Connecting device + + logger.info("Connecting device: %s", device_name) + device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) + + if not device_session: + logger.error("Failed to register device: %s", device_name) + return + + logger.info("Device connected successfully: %s", device_name) + + # Send attributes as raw dictionary + raw_attributes = { + "maintenance": "scheduled", + "id": 341, + } + logger.info("Sending raw attributes: %s", raw_attributes) + await client.send_device_attributes(device_session=device_session, data=raw_attributes, wait_for_publish=True) + logger.info("Raw timeseries sent successfully.") + + # Send single attribute entry + single_attribute = AttributeEntry(key="location", value="office") + logger.info("Sending single attribute: %s", single_attribute) + await client.send_device_attributes(device_session=device_session, data=single_attribute, wait_for_publish=True) + logger.info("Single attribute sent successfully.") + + # Send multiple attribute entries + multiple_attributes = [ + AttributeEntry(key="status", value="active"), + AttributeEntry(key="version", value="1.0.0") + ] + logger.info("Sending multiple attributes: %s", multiple_attributes) + await client.send_device_attributes(device_session=device_session, data=multiple_attributes, wait_for_publish=True) + logger.info("Multiple attributes sent successfully.") + + # Disconnect device + logger.info("Disconnecting device: %s", device_name) + await client.disconnect_device(device_session, wait_for_publish=True) + logger.info("Device disconnected successfully: %s", device_name) + + # Disconnect client + await client.disconnect() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("Keyboard interrupt received, shutting down.") diff --git a/examples/gateway/send_telemetry_and_attributes.py b/examples/gateway/send_telemetry_and_attributes.py deleted file mode 100644 index 279255c..0000000 --- a/examples/gateway/send_telemetry_and_attributes.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2025. ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time -import logging -from tb_gateway_mqtt import TBGatewayMqttClient - -logging.basicConfig(level=logging.INFO) - - -attributes = {"atr1": 1, "atr2": True, "atr3": "value3"} -telemetry_simple = {"ts": int(round(time.time() * 1000)), "values": {"key1": "11"}} -telemetry_array = [ - {"ts": int(round(time.time() * 1000)), "values": {"key1": "11"}}, - {"ts": int(round(time.time() * 1000)), "values": {"key2": "22"}} -] - - -def main(): - gateway = TBGatewayMqttClient("127.0.0.1", username="TEST_GATEWAY_TOKEN") - # without device connection it is impossible to get any messages - gateway.connect() - - gateway.gw_send_telemetry("Test Device A2", telemetry_simple) - gateway.gw_send_telemetry("Test Device A2", telemetry_array) - gateway.gw_send_attributes("Test Device A2", attributes) - gateway.disconnect() - - -if __name__ == '__main__': - main() diff --git a/examples/gateway/send_timeseries.py b/examples/gateway/send_timeseries.py new file mode 100644 index 0000000..f20304e --- /dev/null +++ b/examples/gateway/send_timeseries.py @@ -0,0 +1,111 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import signal +from time import time + +from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.service.gateway.client import GatewayClient + +configure_logging() +logger = get_logger(__name__) + +config = GatewayConfig() +config.host = "localhost" +config.access_token = "YOUR_ACCESS_TOKEN" + +device_name = "Test Device B1" +device_profile = "Test devices" + +async def main(): + client = GatewayClient(config) + + await client.connect() + + # Connecting device + + logger.info("Connecting device: %s", device_name) + device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) + + if not device_session: + logger.error("Failed to register device: %s", device_name) + return + + logger.info("Device connected successfully: %s", device_name) + stop_event = asyncio.Event() + + def _shutdown_handler(): + stop_event.set() + asyncio.gather(client.stop(), return_exceptions=True) + + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, _shutdown_handler) # noqa + except NotImplementedError: + # Windows compatibility fallback + signal.signal(sig, lambda *_: _shutdown_handler()) # noqa + + loop_counter = 0 + + while not stop_event.is_set(): + loop_counter += 1 + logger.info("Sending timeseries data, iteration: %d", loop_counter) + + # Send time series as raw dictionary + raw_timeseries = { + "temperature": 25.5, + "humidity": 60 + } + logger.info("Sending raw timeseries: %s", raw_timeseries) + await client.send_device_timeseries(device_session=device_session, data=raw_timeseries, wait_for_publish=True) + logger.info("Raw timeseries sent successfully.") + + # Send time series as list of dictionaries + ts = int(time() * 1000) + list_timeseries = [ + {"ts": ts, "values": {"temperature": 26.0, "humidity": 65}}, + {"ts": ts - 1000, "values": {"temperature": 26.5, "humidity": 70}} + ] + logger.info("Sending list of timeseries: %s", list_timeseries) + await client.send_device_timeseries(device_session=device_session, data=list_timeseries, wait_for_publish=True) + logger.info("List of timeseries sent successfully.") + + # Send time series as TimeseriesEntry objects + ts = int(time() * 1000) + timeseries_entries = [ + TimeseriesEntry(key="temperature%i" % i, value=loop_counter, ts=ts) for i in range(20) + ] + logger.info("Sending TimeseriesEntry objects: %s", timeseries_entries) + await client.send_device_timeseries(device_session=device_session, data=timeseries_entries, + wait_for_publish=True) + logger.info("TimeseriesEntry objects sent successfully.") + + try: + logger.info("Waiting before next iteration...") + await asyncio.wait_for(stop_event.wait(), timeout=1) + except asyncio.TimeoutError: + logger.info("Going to next iteration...") + + await client.stop() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("Keyboard interrupt received, shutting down.") diff --git a/examples/gateway/subscribe_to_attributes.py b/examples/gateway/subscribe_to_attributes.py deleted file mode 100644 index 463a361..0000000 --- a/examples/gateway/subscribe_to_attributes.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2025. ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging.handlers -import time - -from tb_gateway_mqtt import TBGatewayMqttClient -logging.basicConfig(level=logging.INFO) - - -def callback(result): - logging.info("Callback for attributes, %r", result) - - -def callback_for_everything(result): - logging.info("Everything goes here, %r", result) - - -def callback_for_specific_attr(result): - logging.info("Specific attribute callback, %r", result) - - -def main(): - gateway = TBGatewayMqttClient("127.0.0.1", username="TEST_GATEWAY_TOKEN") - gateway.connect() - # without device connection it is impossible to get any messages - gateway.gw_connect_device("ImageTest") - - gateway.gw_subscribe_to_all_attributes(callback_for_everything) - - gateway.gw_subscribe_to_attribute("ImageTest", "image", callback_for_specific_attr) - - sub_id = gateway.gw_subscribe_to_all_device_attributes("ImageTest", callback) - gateway.gw_unsubscribe(sub_id) - - try: - # Waiting for the callback - while not gateway.stopped: - time.sleep(1) - except KeyboardInterrupt: - gateway.disconnect() - gateway.stop() - - -if __name__ == '__main__': - main() diff --git a/examples/gateway/tls_connect.py b/examples/gateway/tls_connect.py index 617472f..dfb3ffe 100644 --- a/examples/gateway/tls_connect.py +++ b/examples/gateway/tls_connect.py @@ -1,25 +1,75 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This example demonstrates how to connect to ThingsBoard over SSL using the GatewayClient, +# connect a device, and send telemetry data securely. + +import asyncio import logging -from tb_gateway_mqtt import TBGatewayMqttClient -import socket - -logging.basicConfig(level=logging.DEBUG) -# connecting to localhost -gateway = TBGatewayMqttClient(socket.gethostname()) -gateway.connect(tls=True, - ca_certs="mqttserver.pub.pem", - cert_file="mqttclient.nopass.pem") -gateway.disconnect() +import random + +from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.common.logging_utils import configure_logging, get_logger +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.service.gateway.client import GatewayClient + +configure_logging() +logger = get_logger(__name__) +logger.setLevel(logging.INFO) +logging.getLogger("tb_mqtt_client").setLevel(logging.INFO) + +PLATFORM_HOST = 'localhost' # Update with your ThingsBoard host +PLATFORM_PORT = 8883 # Default port for MQTT over SSL + +# Update with your CA certificate, client certificate, and client key paths. There are no default files generated. +# You can generate them using the following guides: +# Certificates for server - https://thingsboard.io/docs/user-guide/mqtt-over-ssl/ +# Certificates for client - https://thingsboard.io/docs/user-guide/certificates/?ubuntuThingsboardX509=X509Leaf +CA_CERT_PATH = "ca_cert.pem" # Update with your CA certificate path (Default - ca_cert.pem in the examples directory) +CLIENT_CERT_PATH = "cert.pem" # Update with your client certificate path (Default - cert.pem in the examples directory) +CLIENT_KEY_PATH = "key.pem" # Update with your client key path (Default - key.pem in the examples directory) + + +async def main(): + config = GatewayConfig() + + config.host = PLATFORM_HOST + config.port = PLATFORM_PORT + + config.ca_cert = CA_CERT_PATH + config.client_cert = CLIENT_CERT_PATH + config.private_key = CLIENT_KEY_PATH + + client = GatewayClient(config) + await client.connect() + + device_name = "Test Device B1" + device_profile = "Test devices" + logger.info("Connecting device: %s", device_name) + device_session, publish_results = await client.connect_device(device_name, device_profile, wait_for_publish=True) + + # Sending telemetry data to the connected device + list_timeseries = [ + TimeseriesEntry(key="temperature", value=random.randint(20, 35)), + TimeseriesEntry(key="humidity", value=random.randint(40, 80)) + ] + logger.info("Sending list of timeseries: %s", list_timeseries) + await client.send_device_timeseries(device_session=device_session, data=list_timeseries, wait_for_publish=True) + logger.info("List of timeseries sent successfully.") + + await client.stop() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..b5c44ee --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,64 @@ +[project] +name = "thingsboard-python-client-sdk" +version = "2.0.0" +description = "Python MQTT client SDK for ThingsBoard platform" +authors = [ + { name = "ThingsBoard Team", email = "info@thingsboard.io" }, +] +license = "Apache-2.0" +readme = "README.md" +requires-python = ">=3.8" +dependencies = [ + "uvloop>=0.21.0", + "gmqtt>=0.6.10", + "orjson>=0.2.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0", + "pytest-asyncio>=0.21", + "mypy>=1.8", + "black>=24.0", + "isort>=5.12", + "ruff>=0.1.0", +] + +[tool.pytest.ini_options] +minversion = "7.0" +addopts = "-ra -q" +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*", "*Tests"] +python_functions = ["test_*"] +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" + +[tool.black] +line-length = 100 +target-version = ["py38"] +exclude = ''' +/( + \.venv + | build + | dist +)/ +''' + +[tool.isort] +profile = "black" +line_length = 100 + +[tool.mypy] +python_version = "3.12" +ignore_missing_imports = true +disallow_untyped_defs = true +strict_optional = true +warn_unused_ignores = true +warn_return_any = true +no_implicit_optional = true + +[tool.ruff] +line-length = 100 +target-version = "py312" +select = ["E", "F", "I", "B", "W"] \ No newline at end of file diff --git a/sdk_utils.py b/sdk_utils.py deleted file mode 100644 index 07e9a16..0000000 --- a/sdk_utils.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2025. ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from random import randint -from zlib import crc32 -from hashlib import sha256, sha384, sha512, md5 -import logging -from subprocess import CalledProcessError - -from utils import install_package - - -try: - from mmh3 import hash, hash128 -except ImportError: - try: - from pymmh3 import hash, hash128 - except ImportError: - try: - install_package('mmh3') - except CalledProcessError: - install_package('pymmh3') - -try: - from mmh3 import hash, hash128 # noqa -except ImportError: - from pymmh3 import hash, hash128 - -log = logging.getLogger(__name__) - - -def verify_checksum(firmware_data, checksum_alg, checksum): - if firmware_data is None: - log.debug('Firmware wasn\'t received!') - return False - if checksum is None: - log.debug('Checksum was\'t provided!') - return False - checksum_of_received_firmware = None - log.debug('Checksum algorithm is: %s' % checksum_alg) - if checksum_alg.lower() == "sha256": - checksum_of_received_firmware = sha256(firmware_data).digest().hex() - elif checksum_alg.lower() == "sha384": - checksum_of_received_firmware = sha384(firmware_data).digest().hex() - elif checksum_alg.lower() == "sha512": - checksum_of_received_firmware = sha512(firmware_data).digest().hex() - elif checksum_alg.lower() == "md5": - checksum_of_received_firmware = md5(firmware_data).digest().hex() - elif checksum_alg.lower() == "murmur3_32": - reversed_checksum = f'{hash(firmware_data, signed=False):0>2X}' - if len(reversed_checksum) % 2 != 0: - reversed_checksum = '0' + reversed_checksum - checksum_of_received_firmware = "".join( - reversed([reversed_checksum[i:i + 2] for i in range(0, len(reversed_checksum), 2)])).lower() - elif checksum_alg.lower() == "murmur3_128": - reversed_checksum = f'{hash128(firmware_data, signed=False):0>2X}' - if len(reversed_checksum) % 2 != 0: - reversed_checksum = '0' + reversed_checksum - checksum_of_received_firmware = "".join( - reversed([reversed_checksum[i:i + 2] for i in range(0, len(reversed_checksum), 2)])).lower() - elif checksum_alg.lower() == "crc32": - reversed_checksum = f'{crc32(firmware_data) & 0xffffffff:0>2X}' - if len(reversed_checksum) % 2 != 0: - reversed_checksum = '0' + reversed_checksum - checksum_of_received_firmware = "".join( - reversed([reversed_checksum[i:i + 2] for i in range(0, len(reversed_checksum), 2)])).lower() - else: - log.error('Client error. Unsupported checksum algorithm.') - log.debug(checksum_of_received_firmware) - random_value = randint(0, 5) - if random_value > 3: - log.debug('Dummy fail! Do not panic, just restart and try again the chance of this fail is ~20%') - return False - return checksum_of_received_firmware == checksum diff --git a/setup.py b/setup.py index 2413ce8..1118092 100644 --- a/setup.py +++ b/setup.py @@ -1,17 +1,16 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# http://www.apache.org/licenses/LICENSE-2.0 # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from os import path from setuptools import setup @@ -21,7 +20,7 @@ with open(path.join(this_directory, 'README.md')) as f: long_description = f.read() -VERSION = "1.13.9" +VERSION = "2.0" setup( version=VERSION, @@ -33,6 +32,6 @@ url="https://github.com/thingsboard/thingsboard-python-client-sdk", long_description=long_description, long_description_content_type="text/markdown", - python_requires=">=3.9", + python_requires=">=3.12", packages=["."], - install_requires=['tb-paho-mqtt-client>=2.1.2', 'requests>=2.31.0', 'orjson']) + install_requires=['gmqtt', 'orjson']) diff --git a/tb_device_http.py b/tb_device_http.py deleted file mode 100644 index c4a1e5d..0000000 --- a/tb_device_http.py +++ /dev/null @@ -1,465 +0,0 @@ -# Copyright 2025. ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""ThingsBoard HTTP API device module.""" -import threading -import logging -import queue -import time -import typing -from datetime import datetime, timezone -from sdk_utils import verify_checksum - -import requests -from math import ceil - -FW_CHECKSUM_ATTR = "fw_checksum" -FW_CHECKSUM_ALG_ATTR = "fw_checksum_algorithm" -FW_SIZE_ATTR = "fw_size" -FW_TITLE_ATTR = "fw_title" -FW_VERSION_ATTR = "fw_version" - -FW_STATE_ATTR = "fw_state" - -REQUIRED_SHARED_KEYS = [FW_CHECKSUM_ATTR, FW_CHECKSUM_ALG_ATTR, FW_SIZE_ATTR, FW_TITLE_ATTR, FW_VERSION_ATTR] - - -class TBHTTPAPIException(Exception): - """ThingsBoard HTTP Device API Exception class.""" - - -class TBProvisionFailure(TBHTTPAPIException): - """Exception raised if device provisioning failed.""" - - -class TBHTTPDevice: - """ThingsBoard HTTP Device API class. - - :param host: The ThingsBoard hostname. - :param token: The device token. - :param name: A name for this device. The name is only set locally. - """ - - def __init__(self, host: str, token: str, name: str = None, chunk_size: int = 0): - self.__session = requests.Session() - self.__session.headers.update({'Content-Type': 'application/json'}) - self.__config = { - 'host': host, 'token': token, 'name': name, 'timeout': 30 - } - self.__worker = { - 'publish': { - 'queue': queue.Queue(), - 'thread': threading.Thread(target=self.__publish_worker, daemon=True), - 'stop_event': threading.Event() - }, - 'attributes': { - 'thread': threading.Thread(target=self.__subscription_worker, - daemon=True, - kwargs={'endpoint': 'attributes'}), - 'stop_event': threading.Event(), - }, - 'rpc': { - 'thread': threading.Thread(target=self.__subscription_worker, - daemon=True, - kwargs={'endpoint': 'rpc'}), - 'stop_event': threading.Event(), - } - } - self.current_firmware_info = { - "current_fw_title": None, - "current_fw_version": None - } - self.chunk_size = chunk_size - - def __repr__(self): - return f'' - - @property - def host(self) -> str: - """Get the ThingsBoard hostname.""" - return self.__config['host'] - - @property - def name(self) -> str: - """Get the device name.""" - return self.__config['name'] - - @property - def timeout(self) -> int: - """Get the connection timeout.""" - return self.__config['timeout'] - - @property - def api_base_url(self) -> str: - """Get the ThingsBoard API base URL.""" - return f'{self.host}/api/v1/{self.token}' - - @property - def token(self) -> str: - """Get the device token.""" - return self.__config['token'] - - @property - def logger(self) -> logging.Logger: - """Get the logger instance.""" - return logging.getLogger('TBHTTPDevice') - - @property - def log_level(self) -> str: - """Get the log level.""" - levels = {0: 'NOTSET', 10: 'DEBUG', 20: 'INFO', 30: 'WARNING', 40: 'ERROR', 50: 'CRITICAL'} - return levels.get(self.logger.level) - - @log_level.setter - def log_level(self, value: typing.Union[int, str]): - self.logger.setLevel(value) - self.logger.critical('Log level set to %s', self.log_level) - - def __get_firmware_info(self): - response = self.__session.get( - f"{self.__config['host']}/api/v1/{self.__config['token']}/attributes", - params={"sharedKeys": REQUIRED_SHARED_KEYS}).json() - return response.get("shared", {}) - - def __get_firmware(self, fw_info): - chunk_count = ceil(fw_info.get(FW_SIZE_ATTR, 0) / self.chunk_size) if self.chunk_size > 0 else 0 - firmware_data = b'' - for chunk_number in range(chunk_count + 1): - params = {"title": fw_info.get(FW_TITLE_ATTR), - "version": fw_info.get(FW_VERSION_ATTR), - "size": self.chunk_size if self.chunk_size < fw_info.get(FW_SIZE_ATTR, - 0) else fw_info.get( - FW_SIZE_ATTR, 0), - "chunk": chunk_number - } - self.logger.debug(params) - self.logger.debug( - 'Getting chunk with number: %s. Chunk size is : %r byte(s).' % (chunk_number + 1, self.chunk_size)) - response = self.__session.get( - f"{self.__config['host']}/api/v1/{self.__config['token']}/firmware", - params=params) - if response.status_code != 200: - self.logger.error('Received error:') - response.raise_for_status() - return - firmware_data = firmware_data + response.content - return firmware_data - - def __on_firmware_received(self, firmware_info, firmware_data): - with open(firmware_info.get(FW_TITLE_ATTR), "wb") as firmware_file: - firmware_file.write(firmware_data) - - self.logger.info('Firmware is updated!\n Current firmware version is: %s' % firmware_info.get(FW_VERSION_ATTR)) - - def get_firmware_update(self): - self.send_telemetry(self.current_firmware_info) - self.logger.info('Getting firmware info from %s' % self.__config['host']) - - firmware_info = self.__get_firmware_info() - if (firmware_info.get(FW_VERSION_ATTR) is not None and firmware_info.get( - FW_VERSION_ATTR) != self.current_firmware_info.get("current_" + FW_VERSION_ATTR)) \ - or (firmware_info.get(FW_TITLE_ATTR) is not None and firmware_info.get( - FW_TITLE_ATTR) != self.current_firmware_info.get("current_" + FW_TITLE_ATTR)): - self.logger.info('New firmware available!') - - self.current_firmware_info[FW_STATE_ATTR] = "DOWNLOADING" - time.sleep(1) - self.send_telemetry(self.current_firmware_info) - - firmware_data = self.__get_firmware(firmware_info) - - self.current_firmware_info[FW_STATE_ATTR] = "DOWNLOADED" - time.sleep(1) - self.send_telemetry(self.current_firmware_info) - - verification_result = verify_checksum(firmware_data, firmware_info.get(FW_CHECKSUM_ALG_ATTR), - firmware_info.get(FW_CHECKSUM_ATTR)) - - if verification_result: - self.logger.debug('Checksum verified!') - self.current_firmware_info[FW_STATE_ATTR] = "VERIFIED" - time.sleep(1) - self.send_telemetry(self.current_firmware_info) - else: - self.logger.debug('Checksum verification failed!') - self.current_firmware_info[FW_STATE_ATTR] = "FAILED" - time.sleep(1) - self.send_telemetry(self.current_firmware_info) - firmware_data = self.__get_firmware(firmware_info) - return - - self.current_firmware_info[FW_STATE_ATTR] = "UPDATING" - time.sleep(1) - self.send_telemetry(self.current_firmware_info) - - self.__on_firmware_received(firmware_info, firmware_data) - - current_firmware_info = { - "current_" + FW_TITLE_ATTR: firmware_info.get(FW_TITLE_ATTR), - "current_" + FW_VERSION_ATTR: firmware_info.get(FW_VERSION_ATTR), - FW_STATE_ATTR: "UPDATED" - } - time.sleep(1) - self.send_telemetry(current_firmware_info) - - def start_publish_worker(self): - """Start the publish worker thread.""" - self.__worker['publish']['stop_event'].clear() - self.__worker['publish']['thread'].start() - - def stop_publish_worker(self): - """Stop the publish worker thread.""" - self.__worker['publish']['stop_event'].set() - - def __publish_worker(self): - """Publish telemetry data from the queue.""" - logger = self.logger.getChild('worker.publish') - logger.info('Start publisher thread') - logger.debug('Perform connection test before entering worker loop') - if not self.test_connection(): - logger.error('Connection test failed, exit publisher thread') - return - logger.debug('Connection test successful') - while True: - if not self.__worker['publish']['queue'].empty(): - try: - task = self.__worker['publish']['queue'].get(timeout=1) - except queue.Empty: - if self.__worker['publish']['stop_event'].is_set(): - break - continue - - endpoint = task.pop('endpoint') - - try: - self._publish_data(task, endpoint) - except Exception as error: - # ToDo: More precise exception catching - logger.error(error) - task.update({'endpoint': endpoint}) - self.__worker['publish']['queue'].put(task) - time.sleep(1) - else: - logger.debug('Published %s to %s', task, endpoint) - self.__worker['publish']['queue'].task_done() - - time.sleep(.2) - - logger.info('Stop publisher thread.') - - def test_connection(self) -> bool: - """Test connection to the API. - - :return: True if no errors occurred, False otherwise. - """ - self.logger.debug('Start connection test') - success = False - try: - self._publish_data(data={}, endpoint='telemetry') - except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as error: - self.logger.debug(error) - except requests.exceptions.HTTPError as error: - self.logger.debug(error) - status_code = error.response.status_code - if status_code == 401: - self.logger.error('Error 401: Unauthorized. Check if token is correct.') - else: - self.logger.error('Error %s', status_code) - else: - self.logger.debug('Connection test successful') - success = True - finally: - self.logger.debug('End connection test') - return success - - def connect(self) -> bool: - """Publish an empty telemetry data to ThingsBoard to test the connection. - - :return: True if connected, false otherwise. - """ - if self.test_connection(): - self.logger.info('Connected to ThingsBoard') - self.start_publish_worker() - return True - return False - - def _publish_data(self, data: dict, endpoint: str, timeout: int = None) -> dict: - """Send POST data to ThingsBoard. - - :param data: The data dictionary to send. - :param endpoint: The receiving API endpoint. - :param timeout: Override the instance timeout for this request. - """ - response = self.__session.post( - url=f'{self.api_base_url}/{endpoint}', - json=data, - timeout=timeout or self.timeout) - response.raise_for_status() - return response.json() if response.content else {} - - def _get_data(self, params: dict, endpoint: str, timeout: int = None) -> dict: - """Retrieve data with GET from ThingsBoard. - - :param params: A dictionary with the parameters for the request. - :param endpoint: The receiving API endpoint. - :param timeout: Override the instance timeout for this request. - :return: A dictionary with the response from the ThingsBoard instance. - """ - response = self.__session.get( - url=f'{self.api_base_url}/{endpoint}', - params=params, - timeout=timeout or self.timeout) - response.raise_for_status() - return response.json() - - def send_telemetry(self, telemetry: dict, timestamp: datetime = None, queued: bool = True): - """Publish telemetry to ThingsBoard. - - :param telemetry: A dictionary with the telemetry data to send. - :param timestamp: Timestamp to set for the values. If not set the ThingsBoard server uses - the time of reception as timestamp. - :param queued: Add the telemetry to the queue. If False, the data is send immediately. - """ - timestamp = datetime.now() if timestamp is None else timestamp - payload = { - 'ts': int(timestamp.replace(tzinfo=timezone.utc).timestamp() * 1000), - 'values': telemetry, - } - if queued: - payload.update({'endpoint': 'telemetry'}) - self.__worker['publish']['queue'].put(payload) - else: - self._publish_data(payload, 'telemetry') - - def send_attributes(self, attributes: dict): - """Send attributes to ThingsBoard. - - :param attributes: Attributes to send. - """ - self._publish_data(attributes, 'attributes') - - def send_rpc(self, name: str, params: dict = None, rpc_id: int = None) -> dict: - """Send RPC to ThingsBoard and return response. - - :param name: Name of the RPC method. - :param params: Parameter for the RPC. - :param rpc_id: Specify an Id for this RPC. - :return: A dictionary with the response. - """ - endpoint = f'rpc/{rpc_id}' if rpc_id else 'rpc' - return self._publish_data({'method': name, 'params': params or {}}, endpoint) - - def request_attributes(self, client_keys: list = None, shared_keys: list = None) -> dict: - """Request attributes from ThingsBoard. - - :param client_keys: A list of keys for client attributes. - :param shared_keys: A list of keys for shared attributes. - :return: A dictionary with the request attributes. - """ - params = {'client_keys': client_keys, 'shared_keys': shared_keys} - return self._get_data(params=params, endpoint='attributes') - - def __subscription_worker(self, endpoint: str, timeout: int = None): - """Worker thread for subscription to HTTP API endpoints. - - :param endpoint: The endpoint name. - :param timeout: Timeout value in seconds. - """ - logger = self.logger.getChild(f'worker.subscription.{endpoint}') - stop_event = self.__worker[endpoint]['stop_event'] - logger.info('Start subscription to %s updates', endpoint) - - if not self.__worker[endpoint].get('callback'): - logger.warning('No callback set for %s subscription', endpoint) - stop_event.set() - callback = self.__worker[endpoint].get('callback', lambda data: None) - params = { - 'timeout': (timeout or self.timeout) * 1000 - } - url = { - 'attributes': f'{self.api_base_url}/attributes/updates', - 'rpc': f'{self.api_base_url}/rpc' - } - logger.debug('Timeout set to %ss', params['timeout'] / 1000) - while not stop_event.is_set(): - response = self.__session.get(url=url[endpoint], - params=params, - timeout=params['timeout']) - if stop_event.is_set(): - break - if response.status_code == 408: # Request timeout - continue - if response.status_code == 504: # Gateway Timeout - continue # Reconnect - response.raise_for_status() - callback(response.json()) - time.sleep(.1) - - stop_event.clear() - logger.info('Stop subscription to %s updates', endpoint) - - def subscribe(self, endpoint: str, callback: typing.Callable[[dict], None] = None): - """Subscribe to updates from a given endpoint. - - :param endpoint: The endpoint to subscribe. - :param callback: Callback to execute on an update. Takes a dict as only argument. - """ - if endpoint not in ['attributes', 'rpc']: - raise ValueError - if callback: - if not callable(callback): - raise TypeError - self.__worker[endpoint]['callback'] = callback - self.__worker[endpoint]['stop_event'].clear() - self.__worker[endpoint]['thread'].start() - - def unsubscribe(self, endpoint: str): - """Unsubscribe from a given endpoint. - - :param endpoint: The endpoint to unsubscribe. - """ - if endpoint not in ['attributes', 'rpc']: - raise ValueError - self.logger.debug('Set stop event for %s subscription', endpoint) - self.__worker[endpoint]['stop_event'].set() - - @classmethod - def provision(cls, host: str, device_name: str, device_key: str, device_secret: str): - """Initiate device provisioning and return a device instance. - - :param host: The root URL to the ThingsBoard instance. - :param device_name: Name of the device to provision. - :param device_key: Provisioning device key from ThingsBoard. - :param device_secret: Provisioning secret from ThingsBoard. - :return: Instance of :class:`TBHTTPClient` - """ - data = { - 'deviceName': device_name, - 'provisionDeviceKey': device_key, - 'provisionDeviceSecret': device_secret - } - response = requests.post(f'{host}/api/v1/provision', json=data) - response.raise_for_status() - device = response.json() - if device['status'] == 'SUCCESS' and device['credentialsType'] == 'ACCESS_TOKEN': - return cls(host=host, token=device['credentialsValue'], name=device_name) - raise TBProvisionFailure(device) - - -class TBHTTPClient(TBHTTPDevice): - """Legacy class name.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.logger.critical('TBHTTPClient class is deprecated, please use TBHTTPDevice') diff --git a/tb_device_mqtt.py b/tb_device_mqtt.py deleted file mode 100644 index 2bea0af..0000000 --- a/tb_device_mqtt.py +++ /dev/null @@ -1,1642 +0,0 @@ -# Copyright 2025. ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from copy import deepcopy -from inspect import signature -from time import sleep -from importlib import metadata - -from orjson.orjson import OPT_NON_STR_KEYS - -from utils import install_package -from os import environ - -def check_tb_paho_mqtt_installed(): - try: - dists = metadata.distributions() - for dist in dists: - if dist.metadata["Name"].lower() == "tb-paho-mqtt-client": - files = list(dist.files) - for file in files: - if str(file).startswith("paho/mqtt"): - return True - return False - except Exception: - return False - -if not check_tb_paho_mqtt_installed(): - try: - install_package('tb-paho-mqtt-client', version='>=2.1.2') - except Exception as e: - raise ImportError("tb-paho-mqtt-client is not installed, please install it manually.") from e - -import paho.mqtt.client as paho -from paho.mqtt.enums import CallbackAPIVersion -from math import ceil - -try: - from time import monotonic, time as timestamp -except ImportError: - from time import time as timestamp -import ssl -from threading import RLock, Thread -from enum import Enum - -from paho.mqtt.reasoncodes import ReasonCodes -from paho.mqtt.client import MQTT_ERR_QUEUE_SIZE - -from orjson import dumps, loads, JSONDecodeError - -from sdk_utils import verify_checksum - -FW_TITLE_ATTR = "fw_title" -FW_VERSION_ATTR = "fw_version" -FW_CHECKSUM_ATTR = "fw_checksum" -FW_CHECKSUM_ALG_ATTR = "fw_checksum_algorithm" -FW_SIZE_ATTR = "fw_size" -FW_STATE_ATTR = "fw_state" - -REQUIRED_SHARED_KEYS = f"{FW_CHECKSUM_ATTR},{FW_CHECKSUM_ALG_ATTR},{FW_SIZE_ATTR},{FW_TITLE_ATTR},{FW_VERSION_ATTR}" - -RPC_RESPONSE_TOPIC = 'v1/devices/me/rpc/response/' -RPC_REQUEST_TOPIC = 'v1/devices/me/rpc/request/' -ATTRIBUTES_TOPIC = 'v1/devices/me/attributes' -ATTRIBUTES_TOPIC_REQUEST = 'v1/devices/me/attributes/request/' -ATTRIBUTES_TOPIC_RESPONSE = 'v1/devices/me/attributes/response/' -TELEMETRY_TOPIC = 'v1/devices/me/telemetry' -CLAIMING_TOPIC = 'v1/devices/me/claim' -PROVISION_TOPIC_REQUEST = '/provision/request' -PROVISION_TOPIC_RESPONSE = '/provision/response' -log = logging.getLogger('tb_connection') - -RESULT_CODES = { - 1: "incorrect protocol version", - 2: "invalid client identifier", - 3: "server unavailable", - 4: "bad username or password", - 5: "not authorized", -} - - -class TBTimeoutException(Exception): - pass - - -class TBQoSException(Exception): - pass - - -DEFAULT_TIMEOUT = 5 - - -class TBSendMethod(Enum): - SUBSCRIBE = 0 - PUBLISH = 1 - UNSUBSCRIBE = 2 - - -class TBPublishInfo: - TB_ERR_AGAIN = -1 - TB_ERR_SUCCESS = 0 - TB_ERR_NOMEM = 1 - TB_ERR_PROTOCOL = 2 - TB_ERR_INVAL = 3 - TB_ERR_NO_CONN = 4 - TB_ERR_CONN_REFUSED = 5 - TB_ERR_NOT_FOUND = 6 - TB_ERR_CONN_LOST = 7 - TB_ERR_TLS = 8 - TB_ERR_PAYLOAD_SIZE = 9 - TB_ERR_NOT_SUPPORTED = 10 - TB_ERR_AUTH = 11 - TB_ERR_ACL_DENIED = 12 - TB_ERR_UNKNOWN = 13 - TB_ERR_ERRNO = 14 - TB_ERR_QUEUE_SIZE = 15 - - ERRORS_DESCRIPTION = { - -1: 'Previous error repeated.', - 0: 'The operation completed successfully.', - 1: 'Out of memory.', - 2: 'A network protocol error occurred when communicating with the broker.', - 3: 'Invalid function arguments provided.', - 4: 'The client is not currently connected.', - 5: 'The connection was refused.', - 6: 'Entity not found (for example, trying to unsubscribe from a topic not currently subscribed to).', - 7: 'The connection was lost.', - 8: 'A TLS error occurred.', - 9: 'Payload size is too large.', - 10: 'This feature is not supported.', - 11: 'Authorization failed.', - 12: 'Access denied to the specified ACL.', - 13: 'Unknown error.', - 14: 'A system call returned an error.', - 15: 'The queue size was exceeded.', - 16: 'The keepalive time has been exceeded.' - } - - def __init__(self, message_info): - self.message_info = message_info - - # pylint: disable=invalid-name - def rc(self): - if isinstance(self.message_info, list): - for info in self.message_info: - if isinstance(info.rc, ReasonCodes): - if info.rc.value == 0: - continue - return info.rc - else: - if info.rc != 0: - return info.rc - return self.TB_ERR_SUCCESS - else: - if isinstance(self.message_info.rc, ReasonCodes): - return self.message_info.rc.value - return self.message_info.rc - - def mid(self): - if isinstance(self.message_info, list): - return [info.mid for info in self.message_info] - else: - return self.message_info.mid - - def get(self): - if isinstance(self.message_info, list): - try: - for info in self.message_info: - info.wait_for_publish(timeout=1) - except Exception as e: - global log - log = logging.getLogger('tb_connection') - log.error("Error while waiting for publish: %s", e) - else: - self.message_info.wait_for_publish(timeout=1) - return self.rc() - - -class GreedyTokenBucket: - def __init__(self, capacity, duration_sec): - self.capacity = float(capacity) - self.duration = float(duration_sec) - self.tokens = float(capacity) - self.last_updated = monotonic() - - def refill(self): - now = monotonic() - elapsed = now - self.last_updated - refill_rate = self.capacity / self.duration - refill_amount = elapsed * refill_rate - self.tokens = min(self.capacity, self.tokens + refill_amount) - self.last_updated = now - - def can_consume(self, amount=1): - self.refill() - return round(self.tokens, 6) >= round(amount, 6) - - def consume(self, amount=1): - self.refill() - if self.tokens >= amount: - self.tokens -= amount - return True - return False - - def get_remaining_tokens(self): - self.refill() - return self.tokens - - -DEFAULT_RATE_LIMIT_PERCENTAGE = environ.get('TB_DEFAULT_RATE_LIMIT_PERCENTAGE') -if DEFAULT_RATE_LIMIT_PERCENTAGE is None: - DEFAULT_RATE_LIMIT_PERCENTAGE = 80 -else: - try: - DEFAULT_RATE_LIMIT_PERCENTAGE = int(DEFAULT_RATE_LIMIT_PERCENTAGE) - except ValueError: - log.warning("Invalid value for TB_DEFAULT_RATE_LIMIT_PERCENTAGE, using default value of 80%%") - DEFAULT_RATE_LIMIT_PERCENTAGE = 80 - -class RateLimit: - def __init__(self, rate_limit, name=None, percentage=DEFAULT_RATE_LIMIT_PERCENTAGE): - self.__reached_limit_index = 0 - self.__reached_limit_index_time = 0 - self._no_limit = False - self._rate_buckets = {} - self.__lock = RLock() - self._minimal_timeout = DEFAULT_TIMEOUT - self._minimal_limit = float("inf") - - from_dict = isinstance(rate_limit, dict) - self.name = name - self.percentage = percentage - - if from_dict: - self._no_limit = rate_limit.get('no_limit', False) - self.percentage = rate_limit.get('percentage', percentage) - self.name = rate_limit.get('name', name) - - rate_limits = rate_limit.get('rateLimits', {}) - for duration_str, bucket_info in rate_limits.items(): - try: - duration = int(duration_str) - capacity = bucket_info.get("capacity") - tokens = bucket_info.get("tokens") - last_updated = bucket_info.get("last_updated") - - if capacity is None or tokens is None: - continue - - bucket = GreedyTokenBucket(capacity, duration) - bucket.tokens = min(capacity, float(tokens)) - bucket.last_updated = float(last_updated) if last_updated is not None else monotonic() - - self._rate_buckets[duration] = bucket - self._minimal_limit = min(self._minimal_limit, capacity) - self._minimal_timeout = min(self._minimal_timeout, duration + 1) - except Exception as e: - log.warning("Invalid bucket format for duration %s: %s", duration_str, e) - - else: - clean = ''.join(c for c in rate_limit if c not in [' ', ',', ';']) - if clean in ("", "0:0"): - self._no_limit = True - return - - rate_configs = rate_limit.replace(";", ",").split(",") - for rate in rate_configs: - if not rate.strip(): - continue - try: - limit_str, duration_str = rate.strip().split(":") - limit = int(int(limit_str) * self.percentage / 100) - duration = int(duration_str) - bucket = GreedyTokenBucket(limit, duration) - self._rate_buckets[duration] = bucket - self._minimal_limit = min(self._minimal_limit, limit) - self._minimal_timeout = min(self._minimal_timeout, duration + 1) - except Exception as e: - log.warning("Invalid rate limit format '%s': %s", rate, e) - - log.debug("Rate limit %s set to values:", self.name) - for duration, bucket in self._rate_buckets.items(): - log.debug("Window: %ss, Limit: %s", duration, bucket.capacity) - - def increase_rate_limit_counter(self, amount=1): - if self._no_limit: - return - with self.__lock: - for bucket in self._rate_buckets.values(): - bucket.refill() - bucket.tokens = max(0.0, bucket.tokens - amount) - - def check_limit_reached(self, amount=1): - if self._no_limit: - return False - with self.__lock: - for duration, bucket in self._rate_buckets.items(): - if not bucket.can_consume(amount): - return bucket.capacity, duration - - for duration, bucket in self._rate_buckets.items(): - log.debug("%s left tokens: %.2f per %r seconds", - self.name, - bucket.get_remaining_tokens(), - duration) - - return False - - - def get_minimal_limit(self): - return self._minimal_limit if self.has_limit() else 0 - - def get_minimal_timeout(self): - return self._minimal_timeout if self.has_limit() else 0 - - def has_limit(self): - return not self._no_limit - - def set_limit(self, rate_limit, percentage=DEFAULT_RATE_LIMIT_PERCENTAGE): - with self.__lock: - self._minimal_timeout = DEFAULT_TIMEOUT - self._minimal_limit = float("inf") - - old_buckets = deepcopy(self._rate_buckets) - self._rate_buckets = {} - self.percentage = percentage if percentage > 0 else self.percentage - - clean = ''.join(c for c in rate_limit if c not in [' ', ',', ';']) - if clean in ("", "0:0"): - self._no_limit = True - return - - rate_configs = rate_limit.replace(";", ",").split(",") - - for rate in rate_configs: - if not rate.strip(): - continue - try: - limit_str, duration_str = rate.strip().split(":") - duration = int(duration_str) - new_capacity = int(int(limit_str) * self.percentage / 100) - - previous_bucket = old_buckets.get(duration) - new_bucket = GreedyTokenBucket(new_capacity, duration) - - if previous_bucket: - previous_bucket.refill() - used = max(0.0, previous_bucket.capacity - previous_bucket.tokens) - new_tokens = new_capacity - used - new_bucket.tokens = min(new_capacity, max(0.0, new_tokens)) - new_bucket.last_updated = monotonic() - else: - new_bucket.tokens = new_capacity - new_bucket.last_updated = monotonic() - - self._rate_buckets[duration] = new_bucket - self._minimal_limit = min(self._minimal_limit, new_bucket.capacity) - self._minimal_timeout = min(self._minimal_timeout, duration + 1) - - except Exception as e: - log.warning("Invalid rate limit format '%s': %s", rate, e) - - self._no_limit = not bool(self._rate_buckets) - log.debug("Rate limit set to values:") - for duration, bucket in self._rate_buckets.items(): - log.debug("Duration: %ss, Limit: %s", duration, bucket.capacity) - - def reach_limit(self): - if self._no_limit or not self._rate_buckets: - return - - with self.__lock: - durations = sorted(self._rate_buckets.keys()) - current_monotonic = monotonic() - if self.__reached_limit_index_time >= current_monotonic - self._rate_buckets[durations[-1]].duration: - self.__reached_limit_index = 0 - self.__reached_limit_index_time = current_monotonic - if self.__reached_limit_index >= len(durations): - self.__reached_limit_index = 0 - self.__reached_limit_index_time = current_monotonic - - target_duration = durations[self.__reached_limit_index] - bucket = self._rate_buckets[target_duration] - bucket.refill() - bucket.tokens = 0.0 - - self.__reached_limit_index += 1 - log.info("Received disconnection due to rate limit for \"%s\" rate limit, waiting for tokens in bucket for %s seconds", - self.name, - target_duration) - return self.__reached_limit_index, self.__reached_limit_index_time - - @property - def __dict__(self): - rate_limits_dict = {} - for duration, bucket in self._rate_buckets.items(): - rate_limits_dict[str(duration)] = { - "capacity": bucket.capacity, - "tokens": bucket.get_remaining_tokens(), - "last_updated": bucket.last_updated - } - return { - "rateLimits": rate_limits_dict, - "name": self.name, - "percentage": self.percentage, - "no_limit": self._no_limit - } - - @staticmethod - def get_rate_limits_by_host(host, rate_limit, dp_rate_limit): - rate_limit = RateLimit.get_rate_limit_by_host(host, rate_limit) - dp_rate_limit = RateLimit.get_dp_rate_limit_by_host(host, dp_rate_limit) - - return rate_limit, dp_rate_limit - - @staticmethod - def get_rate_limit_by_host(host, rate_limit): - if rate_limit == "DEFAULT_TELEMETRY_RATE_LIMIT": - if "thingsboard.cloud" in host: - rate_limit = "10:1,60:60," - elif "tb" in host and "cloud" in host: - rate_limit = "10:1,60:60," - elif "demo.thingsboard.io" in host: - rate_limit = "10:1,60:60," - else: - rate_limit = "0:0," - elif rate_limit == "DEFAULT_MESSAGES_RATE_LIMIT": - if "thingsboard.cloud" in host: - rate_limit = "10:1,60:60," - elif "tb" in host and "cloud" in host: - rate_limit = "10:1,60:60," - elif "demo.thingsboard.io" in host: - rate_limit = "10:1,60:60," - else: - rate_limit = "0:0," - else: - rate_limit = rate_limit - - return rate_limit - - @staticmethod - def get_dp_rate_limit_by_host(host, dp_rate_limit): - if dp_rate_limit == "DEFAULT_TELEMETRY_DP_RATE_LIMIT": - if "thingsboard.cloud" in host: - dp_rate_limit = "10:1,300:60," - elif "tb" in host and "cloud" in host: - dp_rate_limit = "10:1,300:60," - elif "demo.thingsboard.io" in host: - dp_rate_limit = "10:1,300:60," - else: - dp_rate_limit = "0:0," - else: - dp_rate_limit = dp_rate_limit - - return dp_rate_limit - - -class TBDeviceMqttClient: - """ThingsBoard MQTT client. This class provides interface to send data to ThingsBoard and receive data from""" - - EMPTY_RATE_LIMIT = RateLimit('0:0,', "EMPTY_RATE_LIMIT") - - def __init__(self, host, port=1883, username=None, password=None, quality_of_service=None, client_id="", - chunk_size=0, messages_rate_limit="DEFAULT_MESSAGES_RATE_LIMIT", - telemetry_rate_limit="DEFAULT_TELEMETRY_RATE_LIMIT", - telemetry_dp_rate_limit="DEFAULT_TELEMETRY_DP_RATE_LIMIT", max_payload_size=8196, **kwargs): - # Added for compatibility with old versions - if kwargs.get('rate_limit') is not None or kwargs.get('dp_rate_limit') is not None: - messages_rate_limit = messages_rate_limit if kwargs.get('rate_limit') == "DEFAULT_RATE_LIMIT" else kwargs.get('rate_limit', messages_rate_limit) # noqa - telemetry_rate_limit = telemetry_rate_limit if kwargs.get('rate_limit') == "DEFAULT_RATE_LIMIT" else kwargs.get('rate_limit', telemetry_rate_limit) # noqa - telemetry_dp_rate_limit = telemetry_dp_rate_limit if kwargs.get('dp_rate_limit') == "DEFAULT_RATE_LIMIT" else kwargs.get('dp_rate_limit', telemetry_dp_rate_limit) # noqa - self._client = paho.Client(protocol=5, client_id=client_id, callback_api_version=CallbackAPIVersion.VERSION2) - self.quality_of_service = quality_of_service if quality_of_service is not None else 1 - self.__host = host - self.__port = port - if username == "": - log.warning("Token is not set, connection without TLS won't be established!") - else: - self._client.username_pw_set(username, password=password) - self._lock = RLock() - - self._attr_request_dict = {} - self.stopped = False - self.__is_connected = False - self.__device_on_server_side_rpc_response = None - self.__connect_callback = None - self.__device_max_sub_id = 0 - self.__device_client_rpc_number = 0 - self.__device_sub_dict = {} - self.__device_client_rpc_dict = {} - self.__attr_request_number = 0 - self.__error_logged = 0 - self.max_payload_size = max_payload_size - self.service_configuration_callback = self.on_service_configuration - telemetry_rate_limit, telemetry_dp_rate_limit = RateLimit.get_rate_limits_by_host(self.__host, - telemetry_rate_limit, - telemetry_dp_rate_limit) - messages_rate_limit = RateLimit.get_rate_limit_by_host(self.__host, messages_rate_limit) - - self._messages_rate_limit = RateLimit(messages_rate_limit, "Rate limit for messages") - self._telemetry_rate_limit = RateLimit(telemetry_rate_limit, "Rate limit for telemetry messages") - self._telemetry_dp_rate_limit = RateLimit(telemetry_dp_rate_limit, "Rate limit for telemetry data points") - self.max_inflight_messages_set(self._telemetry_rate_limit.get_minimal_limit()) - self.__attrs_request_timeout = {} - self.__timeout_thread = Thread(target=self.__timeout_check, name="Timeout check thread") - self.__timeout_thread.daemon = True - self.__timeout_thread.start() - self._client.on_connect = self._on_connect - self._client.on_publish = self._on_publish - self._client.on_message = self._on_message - self._client.on_disconnect = self._on_disconnect - self.current_firmware_info = { - "current_" + FW_TITLE_ATTR: "Initial", - "current_" + FW_VERSION_ATTR: "v0", - FW_STATE_ATTR: "IDLE" - } - self.__request_id = 0 - self.__firmware_request_id = 0 - self.__chunk_size = chunk_size - self.firmware_received = False - self.rate_limits_received = False - self.__request_service_configuration_required = False - self.__service_loop = Thread(target=self.__service_loop, name="Service loop", daemon=True) - self.__service_loop.start() - self.__messages_limit_reached_set_time = (0,0) - self.__datapoints_limit_reached_set_time = (0,0) - - def __service_loop(self): - while not self.stopped: - if self.__request_service_configuration_required: - self.request_service_configuration(self.service_configuration_callback) - self.__request_service_configuration_required = False - elif self.firmware_received: - self.current_firmware_info[FW_STATE_ATTR] = "UPDATING" - self.send_telemetry(self.current_firmware_info) - sleep(1) - - self.__on_firmware_received(self.firmware_info.get(FW_VERSION_ATTR)) - - self.current_firmware_info = { - "current_" + FW_TITLE_ATTR: self.firmware_info.get(FW_TITLE_ATTR), - "current_" + FW_VERSION_ATTR: self.firmware_info.get(FW_VERSION_ATTR), - FW_STATE_ATTR: "UPDATED" - } - self.send_telemetry(self.current_firmware_info) - self.firmware_received = False - sleep(0.05) - - def _on_publish(self, client, userdata, mid, rc=None, properties=None): - if isinstance(rc, ReasonCodes) and rc.value != 0: - log.debug("Publish failed with result code %s (%s) ", str(rc.value), rc.getName()) - if rc.value in [151, 131]: - if self.__messages_limit_reached_set_time[1] - monotonic() > self.__messages_limit_reached_set_time[0]: - self.__messages_limit_reached_set_time = self._messages_rate_limit.reach_limit() - if self.__datapoints_limit_reached_set_time[1] - monotonic() > self.__datapoints_limit_reached_set_time[0]: - self._telemetry_dp_rate_limit.reach_limit() - if rc.value == 0: - if self.__messages_limit_reached_set_time[0] > 0 and self.__messages_limit_reached_set_time[1] > 0: - self.__messages_limit_reached_set_time = (0, 0) - if self.__datapoints_limit_reached_set_time[0] > 0 and self.__datapoints_limit_reached_set_time[1] > 0: - self.__datapoints_limit_reached_set_time = (0, 0) - - def _on_disconnect(self, client: paho.Client, userdata, disconnect_flags, reason=None, properties=None): - self.__is_connected = False - with self._client._out_message_mutex: - client._out_packet.clear() - client._out_messages.clear() - client._in_messages.clear() - self.__attr_request_number = 0 - self.__device_max_sub_id = 0 - self.__device_client_rpc_number = 0 - self.__device_sub_dict = {} - self.__device_client_rpc_dict = {} - self.__attrs_request_timeout = {} - result_code = reason.value - if disconnect_flags.is_disconnect_packet_from_server: - log.warning("MQTT client was disconnected by server with reason code %s (%s) ", - str(result_code), reason.getName()) - else: - log.info("MQTT client was disconnected by client with reason code %s (%s) ", - str(result_code), reason.getName()) - log.debug("Client: %s, user data: %s, result code: %s. Description: %s", - str(client), str(userdata), - str(result_code), reason.getName()) - - def _on_connect(self, client, userdata, connect_flags, result_code, properties, *extra_params): - if result_code == 0: - self.__is_connected = True - log.info("MQTT client %r - Connected!", client) - if properties: - log.debug("MQTT client %r - CONACK Properties: %r", client, properties) - config = {} - if hasattr(properties, 'MaximumPacketSize'): - config['maxPayloadSize'] = int(properties.MaximumPacketSize * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) - if hasattr(properties, 'ReceiveMaximum'): - config['maxInflightMessages'] = properties.ReceiveMaximum - if config: - self.on_service_configuration(None, config) - self._subscribe_to_topic(ATTRIBUTES_TOPIC, qos=self.quality_of_service) - self._subscribe_to_topic(ATTRIBUTES_TOPIC + "/response/+", qos=self.quality_of_service) - self._subscribe_to_topic(RPC_REQUEST_TOPIC + '+', qos=self.quality_of_service) - self._subscribe_to_topic(RPC_RESPONSE_TOPIC + '+', qos=self.quality_of_service) - self.__request_service_configuration_required = True - else: - log.error("Connection failed with result code %s (%s) ", - str(result_code.value), result_code.getName()) - - if callable(self.__connect_callback): - sleep(.2) - if "tb_client" in signature(self.__connect_callback).parameters: - self.__connect_callback(client, userdata, connect_flags, result_code, properties, *extra_params, tb_client=self) - else: - self.__connect_callback(client, userdata, connect_flags, result_code, *extra_params) - - if result_code.value in [159, 151]: - log.debug("Connection rate limit reached, waiting before reconnecting...") - sleep(1) # Wait for 1 second before reconnecting, if connection rate limit is reached - log.debug("Reconnecting allowed...") - - def get_firmware_update(self): - self._client.subscribe("v2/fw/response/+") - self.send_telemetry(self.current_firmware_info) - self.__request_firmware_info() - - def __request_firmware_info(self): - self.__request_id = self.__request_id + 1 - self._publish_data({"sharedKeys": REQUIRED_SHARED_KEYS}, - f"v1/devices/me/attributes/request/{self.__request_id}", - 1) - - def is_connected(self): - return self.__is_connected - - def connect(self, callback=None, min_reconnect_delay=1, timeout=120, tls=False, ca_certs=None, cert_file=None, - key_file=None, keepalive=120): - """Connect to ThingsBoard. The callback will be called when the connection is established.""" - if tls: - try: - self._client.tls_set(ca_certs=ca_certs, - certfile=cert_file, - keyfile=key_file, - cert_reqs=ssl.CERT_REQUIRED, - tls_version=ssl.PROTOCOL_TLSv1_2, - ciphers=None) - self._client.tls_insecure_set(False) - except ValueError: - pass - self.reconnect_delay_set(min_reconnect_delay, timeout) - self._client.connect(self.__host, self.__port, keepalive=keepalive) - self._client.loop_start() - self.__connect_callback = callback - - def disconnect(self): - """Disconnect from ThingsBoard.""" - result = self._client.disconnect() - log.debug(self._client) - log.debug("Disconnecting from ThingsBoard") - self.__is_connected = False - self._client.loop_stop() - return result - - def stop(self): - self.stopped = True - - def _on_message(self, client, userdata, message): - update_response_pattern = "v2/fw/response/" + str(self.__firmware_request_id) + "/chunk/" - if message.topic.startswith(update_response_pattern): - firmware_data = message.payload - - self.firmware_data = self.firmware_data + firmware_data - self.__current_chunk = self.__current_chunk + 1 - - log.debug('Getting chunk with number: %s. Chunk size is : %r byte(s).' % ( - self.__current_chunk, self.__chunk_size)) - - if len(self.firmware_data) == self.__target_firmware_length: - self.__process_firmware() - else: - self.__get_firmware() - else: - content = self._decode(message) - self._on_decoded_message(content, message) - - def _on_decoded_message(self, content, message): - if message.topic.startswith(RPC_REQUEST_TOPIC): - self._messages_rate_limit.increase_rate_limit_counter() - request_id = message.topic[len(RPC_REQUEST_TOPIC):len(message.topic)] - if self.__device_on_server_side_rpc_response: - self.__device_on_server_side_rpc_response(request_id, content) - elif message.topic.startswith(RPC_RESPONSE_TOPIC): - self._messages_rate_limit.increase_rate_limit_counter() - with self._lock: - request_id = int(message.topic[len(RPC_RESPONSE_TOPIC):len(message.topic)]) - if self.__device_client_rpc_dict.get(request_id): - callback = self.__device_client_rpc_dict.pop(request_id) - else: - callback = None - if callback is not None: - callback(request_id, content, None) - elif message.topic == ATTRIBUTES_TOPIC: - self._messages_rate_limit.increase_rate_limit_counter() - dict_results = [] - with self._lock: - # callbacks for everything - if self.__device_sub_dict.get("*"): - for subscription_id in self.__device_sub_dict["*"]: - dict_results.append(self.__device_sub_dict["*"][subscription_id]) - # specific callback - keys = content.keys() - keys_list = [] - for key in keys: - keys_list.append(key) - # iterate through message - for key in keys_list: - # find key in our dict - if self.__device_sub_dict.get(key): - for subscription in self.__device_sub_dict[key]: - dict_results.append(self.__device_sub_dict[key][subscription]) - for res in dict_results: - res(content, None) - elif message.topic.startswith(ATTRIBUTES_TOPIC_RESPONSE): - self._messages_rate_limit.increase_rate_limit_counter() - with self._lock: - req_id = int(message.topic[len(ATTRIBUTES_TOPIC + "/response/"):]) - # pop callback and use it - if self._attr_request_dict.get(req_id): - callback = self._attr_request_dict.pop(req_id) - else: - callback = None - if isinstance(callback, tuple): - callback[0](content, None, callback[1]) - elif callback is not None: - callback(content, None) - else: - log.debug("Message received with topic: %s", message.topic) - - if message.topic.startswith("v1/devices/me/attributes"): - self._messages_rate_limit.increase_rate_limit_counter() - self.firmware_info = loads(message.payload) - if "/response/" in message.topic: - self.firmware_info = self.firmware_info.get("shared", {}) if isinstance(self.firmware_info, dict) else {} # noqa - if ((self.firmware_info.get(FW_VERSION_ATTR) is not None - and self.firmware_info.get(FW_VERSION_ATTR) != self.current_firmware_info.get("current_" + FW_VERSION_ATTR)) # noqa - or (self.firmware_info.get(FW_TITLE_ATTR) is not None - and self.firmware_info.get(FW_TITLE_ATTR) != self.current_firmware_info.get("current_" + FW_TITLE_ATTR))): # noqa - log.debug('Firmware is not the same') - self.firmware_data = b'' - self.__current_chunk = 0 - - self.current_firmware_info[FW_STATE_ATTR] = "DOWNLOADING" - self.send_telemetry(self.current_firmware_info) - sleep(1) - - self.__firmware_request_id = self.__firmware_request_id + 1 - self.__target_firmware_length = self.firmware_info[FW_SIZE_ATTR] - self.__chunk_count = 0 if not self.__chunk_size else ceil( - self.firmware_info[FW_SIZE_ATTR] / self.__chunk_size) - self.__get_firmware() - - def __process_firmware(self): - self.current_firmware_info[FW_STATE_ATTR] = "DOWNLOADED" - self.send_telemetry(self.current_firmware_info) - sleep(1) - - verification_result = verify_checksum(self.firmware_data, self.firmware_info.get(FW_CHECKSUM_ALG_ATTR), - self.firmware_info.get(FW_CHECKSUM_ATTR)) - - if verification_result: - log.debug('Checksum verified!') - self.current_firmware_info[FW_STATE_ATTR] = "VERIFIED" - self.send_telemetry(self.current_firmware_info) - sleep(1) - else: - log.debug('Checksum verification failed!') - self.current_firmware_info[FW_STATE_ATTR] = "FAILED" - self.send_telemetry(self.current_firmware_info) - self.__request_firmware_info() - return - self.firmware_received = True - - def __get_firmware(self): - payload = '' if not self.__chunk_size or self.__chunk_size > self.firmware_info.get(FW_SIZE_ATTR, 0) \ - else str(self.__chunk_size).encode() - self._client.publish( - f"v2/fw/request/{self.__firmware_request_id}/chunk/{self.__current_chunk}", - payload=payload, qos=1) - - def __on_firmware_received(self, version_to): - with open(self.firmware_info.get(FW_TITLE_ATTR), "wb") as firmware_file: - firmware_file.write(self.firmware_data) - log.info('Firmware is updated!\n Current firmware version is: %s' % version_to) - - @staticmethod - def _decode(message): - try: - if isinstance(message.payload, bytes): - content = loads(message.payload.decode("utf-8", "ignore")) - else: - content = loads(message.payload) - except JSONDecodeError: - try: - content = message.payload.decode("utf-8", "ignore") - except JSONDecodeError: - content = message.payload - return content - - def max_inflight_messages_set(self, inflight): - """Set the maximum number of messages with QoS>0 that can be a part way through their network flow at once. - Defaults to minimal rate limit. Increasing this value will consume more memory but can increase throughput.""" - if inflight < 0: - log.error("Inflight messages number must be equal or greater than 0") - return - self._client._max_inflight_messages = inflight - - def max_queued_messages_set(self, queue_size): - """Set the maximum number of outgoing messages with QoS>0 that can be pending in the outgoing message queue. - Defaults to 0. 0 means unlimited. When the queue is full, any further outgoing messages would be dropped.""" - if queue_size < 0: - raise ValueError("Invalid queue size.") - - self._client._max_queued_messages = queue_size - - def reconnect_delay_set(self, min_delay=1, max_delay=120): - """The client will automatically retry connection. Between each attempt it will wait a number of seconds - between min_delay and max_delay. When the connection is lost, initially the reconnection attempt is delayed - of min_delay seconds. It’s doubled between subsequent attempt up to max_delay. The delay is reset to min_delay - when the connection complete (e.g. the CONNACK is received, not just the TCP connection is established).""" - self._client.reconnect_delay_set(min_delay, max_delay) - - def send_rpc_reply(self, req_id, resp, quality_of_service=None, wait_for_publish=False): - """Send RPC reply to ThingsBoard. The response will be sent to the RPC_RESPONSE_TOPIC with the request id.""" - quality_of_service = quality_of_service if quality_of_service is not None else self.quality_of_service - if quality_of_service not in (0, 1): - log.error("Quality of service (qos) value must be 0 or 1") - return None - info = self._publish_data(resp, RPC_RESPONSE_TOPIC + req_id, quality_of_service) - if wait_for_publish: - info.get() - - def send_rpc_call(self, method, params, callback): - """Send RPC call to ThingsBoard. The callback will be called when the response is received.""" - with self._lock: - self.__device_client_rpc_number += 1 - self.__device_client_rpc_dict.update({self.__device_client_rpc_number: callback}) - rpc_request_id = self.__device_client_rpc_number - payload = {"method": method, "params": params} - self._publish_data(payload, RPC_REQUEST_TOPIC + str(rpc_request_id), self.quality_of_service) - - def request_service_configuration(self, callback): - self.send_rpc_call("getSessionLimits", {"timeout": 5000}, callback) - - def on_service_configuration(self, _, response, *args, **kwargs): - global log - log = logging.getLogger('tb_connection') - if "error" in response: - log.warning("Timeout while waiting for service configuration!, session will use default configuration.") - self.rate_limits_received = True - return - service_config = response - if not isinstance(service_config, dict) or 'rateLimits' not in service_config: - log.warning("Cannot retrieve service configuration, session will use default configuration.") - log.debug("Received the following response: %r", service_config) - return - if service_config.get("rateLimits"): - rate_limits_config = service_config.get("rateLimits") - - if rate_limits_config.get('messages'): - self._messages_rate_limit.set_limit(rate_limits_config.get('messages')) - else: - self._messages_rate_limit.set_limit('0:0,') - - if rate_limits_config.get('telemetryMessages'): - self._telemetry_rate_limit.set_limit(rate_limits_config.get('telemetryMessages')) - else: - self._telemetry_rate_limit.set_limit('0:0,') - - if rate_limits_config.get('telemetryDataPoints'): - self._telemetry_dp_rate_limit.set_limit(rate_limits_config.get('telemetryDataPoints')) - else: - self._telemetry_dp_rate_limit.set_limit('0:0,') - - if service_config.get('maxInflightMessages'): - use_messages_rate_limit_factor = self._messages_rate_limit.has_limit() - use_telemetry_rate_limit_factor = self._telemetry_rate_limit.has_limit() - service_config_inflight_messages = int(service_config.get('maxInflightMessages', 100)) - if use_messages_rate_limit_factor and use_telemetry_rate_limit_factor: - max_inflight_messages = int(min(self._messages_rate_limit.get_minimal_limit(), - self._telemetry_rate_limit.get_minimal_limit(), - service_config_inflight_messages) * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) - elif use_messages_rate_limit_factor: - max_inflight_messages = int(min(self._messages_rate_limit.get_minimal_limit(), - service_config_inflight_messages) * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) - elif use_telemetry_rate_limit_factor: - max_inflight_messages = int(min(self._telemetry_rate_limit.get_minimal_limit(), - service_config_inflight_messages) * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) - else: - max_inflight_messages = int(service_config.get('maxInflightMessages', 100) * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) - if max_inflight_messages == 0: - max_inflight_messages = 10_000 # No limitation on device queue on transport level - if max_inflight_messages < 1: - max_inflight_messages = 1 - self.max_inflight_messages_set(max_inflight_messages) - - if (not self._messages_rate_limit.has_limit() and - not self._telemetry_rate_limit.has_limit() and - not self._telemetry_dp_rate_limit.has_limit() and - not kwargs.get("gateway_limits_present", False)): - log.debug("No rate limits for device, setting max_queued_messages to 50000") - self.max_queued_messages_set(50000) - else: - log.debug("Rate limits for device, setting max_queued_messages to %r", max_inflight_messages) - self.max_queued_messages_set(max_inflight_messages) - - if service_config.get('maxPayloadSize'): - self.max_payload_size = int(int(service_config.get('maxPayloadSize')) * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) - log.info("Service configuration was successfully retrieved and applied.") - log.info("Current device limits: %r", service_config) - self.rate_limits_received = True - - def set_server_side_rpc_request_handler(self, handler): - """Set the callback that will be called when a server-side RPC is received.""" - self.__device_on_server_side_rpc_response = handler - - def _wait_for_rate_limit_released(self, timeout, message_rate_limit, dp_rate_limit=None, amount=1): - if not message_rate_limit.has_limit() and not (dp_rate_limit is None or dp_rate_limit.has_limit()): - return - start_time = int(monotonic()) - dp_rate_limit_timeout = dp_rate_limit.get_minimal_timeout() if dp_rate_limit is not None else 0 - timeout = max(message_rate_limit.get_minimal_timeout(), dp_rate_limit_timeout, timeout) + 10 - timeout_updated = False - disconnected = False - limit_reached_check = True - log_posted = False - waited = False - while limit_reached_check: - - message_rate_limit_check = message_rate_limit.check_limit_reached() - datapoints_rate_limit_check = dp_rate_limit.check_limit_reached(amount=amount) if dp_rate_limit is not None else False - limit_reached_check = (message_rate_limit_check - or datapoints_rate_limit_check - or not self.is_connected()) - if isinstance(limit_reached_check, tuple) and timeout < limit_reached_check[1]: - timeout = limit_reached_check[1] - if not timeout_updated and limit_reached_check: - timeout += 10 - timeout_updated = True - if self.stopped: - return TBPublishInfo(paho.MQTTMessageInfo(None)) - if not disconnected and not self.is_connected(): - log.warning("Waiting for connection to be established before sending data to ThingsBoard!") - disconnected = True - timeout = max(timeout, 180) + 10 - if int(monotonic()) >= timeout + start_time: - if message_rate_limit_check: - log.warning("Timeout while waiting for rate limit for messages to be released! Rate limit: %r:%r", - message_rate_limit_check, - message_rate_limit_check) - elif datapoints_rate_limit_check: - log.warning("Timeout while waiting for rate limit for data points to be released! Rate limit: %r:%r", - datapoints_rate_limit_check, - datapoints_rate_limit_check) - return TBPublishInfo(paho.MQTTMessageInfo(None)) - if not log_posted and limit_reached_check: - if log.isEnabledFor(logging.DEBUG): - if isinstance(message_rate_limit_check, tuple): - log.debug("Rate limit for messages (%r messages per %r second(s)) - almost reached, waiting for rate limit to be released...", - *message_rate_limit_check) - if isinstance(datapoints_rate_limit_check, tuple): - log.debug("Rate limit for data points (%r data points per %r second(s)) - almost reached, waiting for rate limit to be released...", - *datapoints_rate_limit_check) - waited = True - log_posted = True - if limit_reached_check: - sleep(.005) - if waited: - log.debug("Rate limit released, sending data to ThingsBoard...") - - def _wait_until_current_queued_messages_processed(self): - logger = None - - max_wait_time = 300 - log_interval = 5 - stuck_threshold = 15 - polling_interval = 0.05 - max_inflight = self._client._max_inflight_messages - - if len(self._client._out_messages) < max_inflight or max_inflight == 0: - return - - waiting_start = monotonic() - last_log_time = waiting_start - last_queue_size = len(self._client._out_messages) - last_queue_change_time = waiting_start - - while not self.stopped: - now = monotonic() - elapsed = now - waiting_start - current_queue_size = len(self._client._out_messages) - - if current_queue_size < max_inflight: - return - - if current_queue_size != last_queue_size: - last_queue_size = current_queue_size - last_queue_change_time = now - - if (now - last_queue_change_time > stuck_threshold - and not self._client.is_connected()): - if logger is None: - logger = logging.getLogger('tb_connection') - logger.warning( - "MQTT out_messages queue is stuck (%d messages) and client is disconnected. " - "Clearing queue after %.2f seconds.", - current_queue_size, now - last_queue_change_time - ) - with self._client._out_message_mutex: - self._client._out_packet.clear() - return - - if now - last_log_time >= log_interval: - if logger is None: - logger = logging.getLogger('tb_connection') - logger.debug( - "Waiting for MQTT queue to drain: %d messages (max inflight %d). " - "Elapsed: %.2f s", - current_queue_size, max_inflight, elapsed - ) - last_log_time = now - - if elapsed > max_wait_time: - if logger is None: - logger = logging.getLogger('tb_connection') - logger.warning( - "MQTT wait timeout reached (%.2f s). Queue still has %d messages.", - elapsed, current_queue_size - ) - return - - sleep(polling_interval) - - def _send_request(self, _type, kwargs, timeout=DEFAULT_TIMEOUT, device=None, - msg_rate_limit=None, dp_rate_limit=None): - topic = kwargs['topic'] - if msg_rate_limit is None: - if topic == TELEMETRY_TOPIC or topic ==ATTRIBUTES_TOPIC: - msg_rate_limit = self._telemetry_rate_limit - else: - msg_rate_limit = self._messages_rate_limit - if dp_rate_limit is None: - if topic == TELEMETRY_TOPIC or topic ==ATTRIBUTES_TOPIC: - dp_rate_limit = self._telemetry_dp_rate_limit - else: - dp_rate_limit = self.EMPTY_RATE_LIMIT - if msg_rate_limit.has_limit() or dp_rate_limit.has_limit(): - msg_rate_limit.increase_rate_limit_counter() - is_reached = self._wait_for_rate_limit_released(timeout, msg_rate_limit, dp_rate_limit) - if is_reached: - return is_reached - - if _type == TBSendMethod.PUBLISH: - self.__add_metadata_to_data_dict_from_device(kwargs["payload"]) - return self.__send_publish_with_limitations(kwargs, timeout, device, msg_rate_limit, dp_rate_limit) - elif _type == TBSendMethod.SUBSCRIBE: - return self._client.subscribe(**kwargs) - elif _type == TBSendMethod.UNSUBSCRIBE: - return self._client.unsubscribe(**kwargs) - - def __add_metadata_to_data_dict_from_device(self, data): - if isinstance(data, dict) and ("metadata" in data and isinstance(data["metadata"], dict)): - data["metadata"]["publishedTs"] = int(timestamp() * 1000) - elif isinstance(data, list): - current_time = int(timestamp() * 1000) - for data_item in data: - if isinstance(data_item, dict): - if 'ts' in data_item and ('metadata' in data_item and isinstance(data_item["metadata"], dict)): - data_item["metadata"]["publishedTs"] = current_time - elif isinstance(data, dict): - for key, value in data.items(): - self.__add_metadata_to_data_dict_from_device(value) - - def __get_rate_limits_by_topic(self, topic, device=None, msg_rate_limit=None, dp_rate_limit=None): - if device is not None: - return msg_rate_limit, dp_rate_limit - else: - if topic == TELEMETRY_TOPIC: - return self._telemetry_rate_limit, self._telemetry_dp_rate_limit - else: - return self._messages_rate_limit, None - - def __send_publish_with_limitations(self, kwargs, timeout, device=None, msg_rate_limit: RateLimit = None, - dp_rate_limit: RateLimit = None): - data = kwargs.get("payload") - if isinstance(data, str): - data = loads(data) - topic = kwargs["topic"] - attributes_format = topic.endswith('attributes') - if topic.endswith('telemetry') or attributes_format: - if device is None or data.get(device) is None: - device_split_messages = self._split_message(data, int(dp_rate_limit.get_minimal_limit()), self.max_payload_size) # noqa - if attributes_format: - split_messages = [{'message': msg_data, 'datapoints': len(msg_data)} for split_message in device_split_messages for msg_data in split_message['data']] # noqa - else: - split_messages = [{'message': split_message['data'], 'datapoints': split_message['datapoints']} for split_message in device_split_messages] # noqa - else: - device_data = data.get(device) - device_split_messages = self._split_message(device_data, int(dp_rate_limit.get_minimal_limit()), self.max_payload_size) # noqa - if attributes_format: - split_messages = [{'message': {device: msg_data}, 'datapoints': len(msg_data)} for split_message in device_split_messages for msg_data in split_message['data']] # noqa - else: - split_messages = [{'message': {device: split_message['data']}, 'datapoints': split_message['datapoints']} for split_message in device_split_messages] # noqa - else: - split_messages = [{'message': data, 'datapoints': self._count_datapoints_in_message(data, device)}] - - results = [] - for part in split_messages: - if not part: - continue - self.__send_split_message(results, part, kwargs, timeout, device, msg_rate_limit, dp_rate_limit, topic) - return TBPublishInfo(results) - - def __send_split_message(self, results, part, kwargs, timeout, device, msg_rate_limit, dp_rate_limit, - topic): - if msg_rate_limit.has_limit() or dp_rate_limit.has_limit(): - dp_rate_limit.increase_rate_limit_counter(part['datapoints']) - rate_limited = self._wait_for_rate_limit_released(timeout, - message_rate_limit=msg_rate_limit, - dp_rate_limit=dp_rate_limit, - amount=part['datapoints']) - if rate_limited: - return rate_limited - if msg_rate_limit.has_limit() or dp_rate_limit.has_limit(): - msg_rate_limit.increase_rate_limit_counter() - kwargs["payload"] = dumps(part['message'], option=OPT_NON_STR_KEYS) - if msg_rate_limit.has_limit() or dp_rate_limit.has_limit(): - self._wait_until_current_queued_messages_processed() - if not self.stopped: - if device is not None: - log.debug("Device: %s, Sending message to topic: %s ", device, topic) - if msg_rate_limit.has_limit() or dp_rate_limit.has_limit(): - if part['datapoints'] > 0: - log.debug("Sending message with %i datapoints", part['datapoints']) - if log.isEnabledFor(5) and hasattr(log, 'trace'): - log.trace("Message payload: %r", kwargs["payload"]) - log.debug("Rate limits after sending message: %r", msg_rate_limit.__dict__) - log.debug("Data points rate limits after sending message: %r", dp_rate_limit.__dict__) - else: - if log.isEnabledFor(5) and hasattr(log, 'trace'): - log.trace("Sending message with %r", kwargs["payload"]) - log.debug("Rate limits after sending message: %r", msg_rate_limit.__dict__) - log.debug("Data points rate limits after sending message: %r", dp_rate_limit.__dict__) - result = self._client.publish(**kwargs) - if result.rc == MQTT_ERR_QUEUE_SIZE: - error_appear_counter = 1 - sleep_time = 0.1 # 100 ms, in case of change - change max tries in while loop - while not self.stopped and result.rc == MQTT_ERR_QUEUE_SIZE: - error_appear_counter += 1 - if error_appear_counter > 78: # 78 tries ~ totally 300 seconds for sleep 0.1 - log.warning("Cannot send message to platform in %i seconds, queue size exceeded, current max inflight messages: %r, max queued messages: %r.", # noqa - int(error_appear_counter * sleep_time), - self._client._max_inflight_messages, - self._client._max_queued_messages) - if int(monotonic()) - self.__error_logged > 10: - log.debug("Queue size exceeded, waiting for messages to be processed by paho client.") - self.__error_logged = int(monotonic()) - sleep(sleep_time) # Give some time for paho to process messages - result = self._client.publish(**kwargs) - results.append(result) - - def _subscribe_to_topic(self, topic, qos=None, timeout=DEFAULT_TIMEOUT): - if qos is None: - qos = self.quality_of_service - - waiting_for_connection_message_time = 0 - while not self.is_connected() and not self.stopped: - if self.stopped: - return TBPublishInfo(paho.MQTTMessageInfo(None)) - if monotonic() - waiting_for_connection_message_time > 10.0: - log.warning("Waiting for connection to be established before subscribing for data on ThingsBoard!") - waiting_for_connection_message_time = monotonic() - sleep(0.01) - - return self._send_request(TBSendMethod.SUBSCRIBE, - {"topic": topic, "qos": qos}, - timeout, - msg_rate_limit=self._messages_rate_limit) - - def _publish_data(self, data, topic, qos, timeout=DEFAULT_TIMEOUT, device=None, - msg_rate_limit=None, dp_rate_limit=None): - if qos is None: - qos = self.quality_of_service - if qos not in (0, 1): - log.exception("Quality of service (qos) value must be 0 or 1") - raise TBQoSException("Quality of service (qos) value must be 0 or 1") - - waiting_for_connection_message_time = 0.0 - while not self.is_connected(): - if self.stopped: - return TBPublishInfo(paho.MQTTMessageInfo(None)) - if monotonic() - waiting_for_connection_message_time > 10.0: - log.warning("Waiting for connection to be established before sending data to ThingsBoard!") - waiting_for_connection_message_time = monotonic() - sleep(0.01) - - return self._send_request(TBSendMethod.PUBLISH, {"topic": topic, "payload": data, "qos": qos}, timeout, - device=device, msg_rate_limit=msg_rate_limit, dp_rate_limit=dp_rate_limit) - - def send_telemetry(self, telemetry, quality_of_service=None, wait_for_publish=True): - """Send telemetry to ThingsBoard. The telemetry can be a single dictionary or a list of dictionaries.""" - quality_of_service = quality_of_service if quality_of_service is not None else self.quality_of_service - if not isinstance(telemetry, list) and not (isinstance(telemetry, dict) and telemetry.get("ts") is not None): - telemetry = [telemetry] - return self._publish_data(telemetry, TELEMETRY_TOPIC, quality_of_service, wait_for_publish) - - def send_attributes(self, attributes, quality_of_service=None, wait_for_publish=True): - """Send attributes to ThingsBoard. The attributes can be a single dictionary or a list of dictionaries.""" - quality_of_service = quality_of_service if quality_of_service is not None else self.quality_of_service - return self._publish_data(attributes, ATTRIBUTES_TOPIC, quality_of_service, wait_for_publish) - - def unsubscribe_from_attribute(self, subscription_id): - """Unsubscribe from attribute updates for subscription_id.""" - with self._lock: - for attribute in self.__device_sub_dict: - if self.__device_sub_dict[attribute].get(subscription_id): - del self.__device_sub_dict[attribute][subscription_id] - log.debug("Unsubscribed from %s, subscription id %i", attribute, subscription_id) - if subscription_id == '*': - self.__device_sub_dict = {} - self.__device_sub_dict = dict((k, v) for k, v in self.__device_sub_dict.items() if v) - - def clean_device_sub_dict(self): - self.__device_sub_dict = {} - - def subscribe_to_all_attributes(self, callback): - """Subscribe to all attribute updates. The callback will be called when an attribute update is received.""" - return self.subscribe_to_attribute("*", callback) - - def subscribe_to_attribute(self, key, callback): - """Subscribe to attribute updates for attribute with key. - The callback will be called when an attribute update is received.""" - with self._lock: - self.__device_max_sub_id += 1 - if key not in self.__device_sub_dict: - self.__device_sub_dict.update({key: {self.__device_max_sub_id: callback}}) - else: - self.__device_sub_dict[key].update({self.__device_max_sub_id: callback}) - log.debug("Subscribed to %s with id %i", key, self.__device_max_sub_id) - return self.__device_max_sub_id - - def request_attributes(self, client_keys=None, shared_keys=None, callback=None): - """Request attributes from ThingsBoard. The callback will be called when the response is received.""" - msg = {} - if client_keys: - tmp = "" - for key in client_keys: - tmp += key + "," - tmp = tmp[:len(tmp) - 1] - msg.update({"clientKeys": tmp}) - if shared_keys: - tmp = "" - for key in shared_keys: - tmp += key + "," - tmp = tmp[:len(tmp) - 1] - msg.update({"sharedKeys": tmp}) - - start_processing_attribute_request = int(monotonic()) - - attr_request_number = self._add_attr_request_callback(callback) - - info = self._publish_data(msg, ATTRIBUTES_TOPIC_REQUEST + str(attr_request_number), self.quality_of_service) - - self.__attrs_request_timeout[attr_request_number] = start_processing_attribute_request + 20 - return info - - def _add_attr_request_callback(self, callback): - with self._lock: - self.__attr_request_number += 1 - self._attr_request_dict.update({self.__attr_request_number: callback}) - attr_request_number = self.__attr_request_number - return attr_request_number - - def __timeout_check(self): - while not self.stopped: - current_time = int(monotonic()) - for (attr_request_number, ts) in tuple(self.__attrs_request_timeout.items()): - if current_time < ts: - continue - - with self._lock: - callback = None - if self._attr_request_dict.get(attr_request_number): - callback = self._attr_request_dict.pop(attr_request_number) - - if callback is not None: - if isinstance(callback, tuple): - callback[0](None, TBTimeoutException("Timeout while waiting for a reply for attribute request from ThingsBoard!"), # noqa - callback[1]) - else: - callback(None, TBTimeoutException("Timeout while waiting for a reply for attribute request from ThingsBoard!")) # noqa - - self.__attrs_request_timeout.pop(attr_request_number) - - sleep(0.1) - - def claim(self, secret_key, duration=30000): - """Claim the device in Thingsboard. The duration is in milliseconds.""" - claiming_request = { - "secretKey": secret_key, - "durationMs": duration - } - info = self._publish_data(claiming_request, CLAIMING_TOPIC, self.quality_of_service) - return info - - @staticmethod - def _count_datapoints_in_message(data, device=None): - datapoints = 0 - if device is not None: - if isinstance(data.get(device), list): - for data_object in data[device]: - datapoints += TBDeviceMqttClient._count_datapoints_in_message(data_object) # noqa - elif isinstance(data.get(device), dict): - datapoints += TBDeviceMqttClient._count_datapoints_in_message(data.get(device, data.get('device'))) - else: - datapoints += 1 - else: - if isinstance(data, dict): - datapoints += TBDeviceMqttClient._get_data_points_from_message(data) - elif isinstance(data, list): - for item in data: - datapoints += TBDeviceMqttClient._get_data_points_from_message(item) - else: - datapoints += 1 - return datapoints - - @staticmethod - def _get_data_points_from_message(data): - if isinstance(data, dict): - if data.get("ts") is not None and data.get("values") is not None: - datapoints_in_message_amount = len(data['values']) + len(str(data['values'])) / 1000 - else: - datapoints_in_message_amount = len(data.keys()) + len(str(data)) / 1000 - else: - datapoints_in_message_amount = len(data) + len(str(data)) / 1000 - return int(datapoints_in_message_amount) - - @staticmethod - def provision(host, - provision_device_key, - provision_device_secret, - port=1883, - device_name=None, - access_token=None, - client_id=None, - username=None, - password=None, - hash=None, - gateway=None): - """Provision the device in ThingsBoard. Returns the credentials for the device.""" - provision_request = { - "provisionDeviceKey": provision_device_key, - "provisionDeviceSecret": provision_device_secret - } - - if access_token is not None: - provision_request["token"] = access_token - provision_request["credentialsType"] = "ACCESS_TOKEN" - elif username is not None or password is not None or client_id is not None: - provision_request["username"] = username - provision_request["password"] = password - provision_request["clientId"] = client_id - provision_request["credentialsType"] = "MQTT_BASIC" - elif hash is not None: - provision_request["hash"] = hash - provision_request["credentialsType"] = "X509_CERTIFICATE" - - if device_name is not None: - provision_request["deviceName"] = device_name - - if gateway is not None: - provision_request["gateway"] = gateway - - provisioning_client = ProvisionClient(host=host, port=port, provision_request=provision_request) - provisioning_client.provision() - return provisioning_client.get_credentials() - - @staticmethod - def _split_message(message_pack, datapoints_max_count, max_payload_size): - if not message_pack: - return [] - - split_messages = [] - - if isinstance(message_pack, dict) and message_pack.get('device') and len(message_pack) in [1, 2]: - return [{ - 'data': message_pack, - 'datapoints': TBDeviceMqttClient._count_datapoints_in_message(message_pack), - 'message': message_pack - }] - - if not isinstance(message_pack, list): - message_pack = [message_pack] - - def _get_metadata_repr(metadata): - return tuple(sorted(metadata.items())) if isinstance(metadata, dict) else None - - def estimate_chunk_size(chunk): - if isinstance(chunk, dict) and "values" in chunk: - size = sum(len(str(k)) + len(str(v)) for k, v in chunk["values"].items()) - size += len(str(chunk.get("ts", ""))) - if "metadata" in chunk: - size += sum(len(str(k)) + len(str(v)) for k, v in chunk["metadata"].items()) - return size + 40 - elif isinstance(chunk, dict): - return sum(len(str(k)) + len(str(v)) for k, v in chunk.items()) + 20 - else: - return len(str(chunk)) + 20 - - ts_group_cache = {} - current_message = {"data": [], "datapoints": 0} - current_size = 0 - current_datapoints = 0 - - def flush_current_message(): - nonlocal current_message, current_size, current_datapoints - if current_message["data"]: - split_messages.append(current_message) - current_message = {"data": [], "datapoints": 0} - current_size = 0 - current_datapoints = 0 - - def split_and_add_chunk(chunk, chunk_datapoints): - nonlocal current_message, current_size, current_datapoints - - chunk_size = estimate_chunk_size(chunk) - - if (0 < datapoints_max_count <= current_datapoints + chunk_datapoints) or \ - (current_size + chunk_size > max_payload_size): - flush_current_message() - - if chunk_datapoints > datapoints_max_count > 0 or chunk_size > max_payload_size: - keys = list(chunk.get("values", {}).keys()) if isinstance(chunk, dict) else list(chunk.keys()) - if len(keys) == 1: - flush_current_message() - current_message["data"].append(chunk) - current_message["datapoints"] += chunk_datapoints - current_size += chunk_size - current_datapoints += chunk_datapoints - flush_current_message() - return - - max_step = max(1, datapoints_max_count if datapoints_max_count > 0 else len(keys)) - for i in range(0, len(keys), max_step): - sub_values = ( - {k: chunk["values"][k] for k in keys[i:i + max_step]} if "values" in chunk - else {k: chunk[k] for k in keys[i:i + max_step]} - ) - sub_chunk = {} - if "ts" in chunk: - sub_chunk = {"ts": chunk["ts"], "values": sub_values} - if "metadata" in chunk: - sub_chunk["metadata"] = chunk["metadata"] - else: - sub_chunk = sub_values.copy() - - sub_datapoints = len(sub_values) - sub_size = estimate_chunk_size(sub_chunk) - - if sub_size > max_payload_size or (0 < datapoints_max_count <= sub_datapoints): - flush_current_message() - current_message["data"].append(sub_chunk) - current_message["datapoints"] += sub_datapoints - current_size += sub_size - current_datapoints += sub_datapoints - flush_current_message() - else: - split_and_add_chunk(sub_chunk, sub_datapoints) - return - - current_message["data"].append(chunk) - current_message["datapoints"] += chunk_datapoints - current_size += chunk_size - current_datapoints += chunk_datapoints - - if 0 < datapoints_max_count == current_datapoints: - flush_current_message() - - def add_chunk_to_current_message(chunk, chunk_datapoints): - nonlocal current_message, current_size, current_datapoints - - chunk_size = estimate_chunk_size(chunk) - - if (0 < datapoints_max_count <= chunk_datapoints) or chunk_size > max_payload_size: - split_and_add_chunk(chunk, chunk_datapoints) - return - - if (0 < datapoints_max_count <= current_datapoints + chunk_datapoints) or \ - (current_size + chunk_size > max_payload_size): - flush_current_message() - - current_message["data"].append(chunk) - current_message["datapoints"] += chunk_datapoints - current_size += chunk_size - current_datapoints += chunk_datapoints - - if 0 < datapoints_max_count == current_datapoints: - flush_current_message() - - def flush_ts_group(ts_key, ts, metadata_repr): - nonlocal current_message, current_size, current_datapoints - if ts_key not in ts_group_cache: - return - - values, _, metadata = ts_group_cache.pop(ts_key) - keys = list(values.keys()) - - step = max(1, datapoints_max_count if datapoints_max_count > 0 else len(keys)) - for i in range(0, len(keys), step): - chunk_values = {k: values[k] for k in keys[i:i + step]} - if ts is not None: - chunk = {"ts": ts, "values": chunk_values} - if metadata: - chunk["metadata"] = metadata - else: - chunk = chunk_values.copy() - - chunk_datapoints = len(chunk_values) - chunk_size = estimate_chunk_size(chunk) - - if chunk_size > max_payload_size or (0 < datapoints_max_count <= chunk_datapoints): - flush_current_message() - current_message["data"].append(chunk) - current_message["datapoints"] += chunk_datapoints - current_size += chunk_size - current_datapoints += chunk_datapoints - flush_current_message() - else: - add_chunk_to_current_message(chunk, chunk_datapoints) - - for message in message_pack: - if not isinstance(message, dict): - continue - - ts = message.get("ts") - metadata = message.get("metadata") if isinstance(message.get("metadata"), dict) else None - values = message.get("values") if isinstance(message.get("values"), dict) else \ - message if isinstance(message, dict) else {} - - metadata_repr = _get_metadata_repr(metadata) - ts_key = (ts, metadata_repr) - - for key, value in values.items(): - pair_size = len(str(key)) + len(str(value)) + 4 - if ts_key not in ts_group_cache: - ts_group_cache[ts_key] = ({}, 0, metadata) - - group_values, group_size, group_metadata = ts_group_cache[ts_key] - - can_add = ( - (datapoints_max_count == 0 or len(group_values) < datapoints_max_count) and - (group_size + pair_size <= max_payload_size) - ) - - if can_add: - group_values[key] = value - ts_group_cache[ts_key] = (group_values, group_size + pair_size, group_metadata) - else: - flush_ts_group(ts_key, ts, metadata_repr) - ts_group_cache[ts_key] = ({key: value}, pair_size, metadata) - - for ts_key in list(ts_group_cache.keys()): - ts, metadata_repr = ts_key - flush_ts_group(ts_key, ts, metadata_repr) - - flush_current_message() - return split_messages - - @staticmethod - def _datapoints_limit_reached(datapoints_max_count, current_datapoints_size, current_size): - return 0 < datapoints_max_count <= current_datapoints_size + current_size // 1024 - - @staticmethod - def _payload_size_limit_reached(max_payload_size, current_size, additional_size): - return current_size + additional_size >= max_payload_size - - def add_attrs_request_timeout(self, attr_request_number, timeout): - self.__attrs_request_timeout[attr_request_number] = timeout - - -class ProvisionClient(paho.Client): - PROVISION_REQUEST_TOPIC = "/provision/request" - PROVISION_RESPONSE_TOPIC = "/provision/response" - - def __init__(self, host, port, provision_request): - super().__init__() - self._host = host - self._port = port - self._username = "provision" - self.__credentials = None - self.on_connect = self.__on_connect - self.on_message = self.__on_message - self.__provision_request = provision_request - - def __on_connect(self, client, _, __, rc): # Callback for connect - if rc == 0: - log.info("[Provisioning client] Connected to ThingsBoard ") - client.subscribe(self.PROVISION_RESPONSE_TOPIC) # Subscribe to provisioning response topic - provision_request = dumps(self.__provision_request, option=OPT_NON_STR_KEYS) - log.info("[Provisioning client] Sending provisioning request %s" % provision_request) - client.publish(self.PROVISION_REQUEST_TOPIC, provision_request) # Publishing provisioning request topic - else: - log.info("[Provisioning client] Cannot connect to ThingsBoard!, result: %s" % RESULT_CODES[rc]) - - def __on_message(self, _, __, msg): - decoded_payload = msg.payload.decode("UTF-8") - log.info("[Provisioning client] Received data from ThingsBoard: %s" % decoded_payload) - decoded_message = loads(decoded_payload) - provision_device_status = decoded_message.get("status") - if provision_device_status == "SUCCESS": - self.__credentials = decoded_message - else: - log.error("[Provisioning client] Provisioning was unsuccessful with status %s and message: %s" % ( - provision_device_status, decoded_message["errorMsg"])) - self.disconnect() - - def provision(self): - log.info("[Provisioning client] Connecting to ThingsBoard") - self.__credentials = None - self.connect(self._host, self._port, 60) - self.loop_forever() - - def get_credentials(self): - return self.__credentials diff --git a/tb_gateway_mqtt.py b/tb_gateway_mqtt.py deleted file mode 100644 index 08f7fbd..0000000 --- a/tb_gateway_mqtt.py +++ /dev/null @@ -1,345 +0,0 @@ -# Copyright 2025. ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import logging - -try: - from time import monotonic as time -except ImportError: - from time import time - -from tb_device_mqtt import TBDeviceMqttClient, RateLimit, TBSendMethod - -GATEWAY_ATTRIBUTES_TOPIC = "v1/gateway/attributes" -GATEWAY_TELEMETRY_TOPIC = "v1/gateway/telemetry" -GATEWAY_DISCONNECT_TOPIC = "v1/gateway/disconnect" -GATEWAY_ATTRIBUTES_REQUEST_TOPIC = "v1/gateway/attributes/request" -GATEWAY_ATTRIBUTES_RESPONSE_TOPIC = "v1/gateway/attributes/response" -GATEWAY_MAIN_TOPIC = "v1/gateway/" -GATEWAY_RPC_TOPIC = "v1/gateway/rpc" -GATEWAY_RPC_RESPONSE_TOPIC = "v1/gateway/rpc/response" -GATEWAY_CLAIMING_TOPIC = "v1/gateway/claim" - -log = logging.getLogger("tb_connection") - - -class TBGatewayAPI: - pass - - -class TBGatewayMqttClient(TBDeviceMqttClient): - def __init__(self, host, port=1883, username=None, password=None, gateway=None, quality_of_service=1, client_id="", - messages_rate_limit="DEFAULT_MESSAGES_RATE_LIMIT", - telemetry_rate_limit="DEFAULT_TELEMETRY_RATE_LIMIT", - telemetry_dp_rate_limit="DEFAULT_TELEMETRY_DP_RATE_LIMIT", - device_messages_rate_limit="DEFAULT_MESSAGES_RATE_LIMIT", - device_telemetry_rate_limit="DEFAULT_TELEMETRY_RATE_LIMIT", - device_telemetry_dp_rate_limit="DEFAULT_TELEMETRY_DP_RATE_LIMIT", **kwargs): - # Added for compatibility with the old versions - if kwargs.get('rate_limit') or kwargs.get('dp_rate_limit'): - messages_rate_limit = messages_rate_limit if kwargs.get('rate_limit') == "DEFAULT_RATE_LIMIT" else kwargs.get('rate_limit', messages_rate_limit) # noqa - telemetry_rate_limit = telemetry_rate_limit if kwargs.get('rate_limit') == "DEFAULT_RATE_LIMIT" else kwargs.get('rate_limit', telemetry_rate_limit) # noqa - device_messages_rate_limit = device_messages_rate_limit if kwargs.get('rate_limit') == "DEFAULT_RATE_LIMIT" else kwargs.get('rate_limit', device_messages_rate_limit) # noqa - device_telemetry_rate_limit = device_telemetry_rate_limit if kwargs.get('rate_limit') == "DEFAULT_RATE_LIMIT" else kwargs.get('rate_limit', device_telemetry_rate_limit) # noqa - telemetry_dp_rate_limit = telemetry_dp_rate_limit if kwargs.get('dp_rate_limit') == "DEFAULT_RATE_LIMIT" else kwargs.get('dp_rate_limit', telemetry_dp_rate_limit) # noqa - device_telemetry_dp_rate_limit = device_telemetry_dp_rate_limit if kwargs.get('dp_rate_limit') == "DEFAULT_RATE_LIMIT" else kwargs.get('dp_rate_limit', device_telemetry_dp_rate_limit) # noqa - - super().__init__(host, port, username, password, quality_of_service, client_id, - messages_rate_limit=messages_rate_limit, telemetry_rate_limit=telemetry_rate_limit, - telemetry_dp_rate_limit=telemetry_dp_rate_limit) - - self.__device_telemetry_rate_limit, self.__device_telemetry_dp_rate_limit = RateLimit.get_rate_limits_by_host( - host, device_telemetry_rate_limit, device_telemetry_dp_rate_limit) - self.__device_messages_rate_limit = RateLimit.get_rate_limit_by_host(host, device_messages_rate_limit) - - self._devices_connected_through_gateway_telemetry_messages_rate_limit = RateLimit(self.__device_telemetry_rate_limit, "Rate limit for devices connected through gateway telemetry messages") # noqa - self._devices_connected_through_gateway_telemetry_datapoints_rate_limit = RateLimit(self.__device_telemetry_dp_rate_limit, "Rate limit for devices connected through gateway telemetry data points") # noqa - self._devices_connected_through_gateway_messages_rate_limit = RateLimit(self.__device_messages_rate_limit, "Rate limit for devices connected through gateway messages") # noqa - - self.service_configuration_callback = self.__on_service_configuration - self.quality_of_service = quality_of_service - self.__max_sub_id = 0 - self.__sub_dict = {} - self.__connected_devices = set("*") - self.devices_server_side_rpc_request_handler = None - self.device_disconnect_callback = None - self._client.on_connect = self._on_connect - self._client.on_message = self._on_message - self._client.on_subscribe = self._on_subscribe - self._client._on_unsubscribe = self._on_unsubscribe - self._gw_subscriptions = {} - self.gateway = gateway - - def _on_connect(self, client, userdata, flags, result_code, properties, *extra_params): - super()._on_connect(client, userdata, flags, result_code, properties, *extra_params) - if result_code == 0: - gateway_attributes_topic_sub_id = int(self._subscribe_to_topic(GATEWAY_ATTRIBUTES_TOPIC, qos=1)[1]) - self._add_or_delete_subscription(GATEWAY_ATTRIBUTES_TOPIC, gateway_attributes_topic_sub_id) - - gateway_attributes_resp_sub_id = int(self._subscribe_to_topic(GATEWAY_ATTRIBUTES_RESPONSE_TOPIC, qos=1)[1]) - self._add_or_delete_subscription(GATEWAY_ATTRIBUTES_RESPONSE_TOPIC, gateway_attributes_resp_sub_id) - - gateway_rpc_topic_sub_id = int(self._subscribe_to_topic(GATEWAY_RPC_TOPIC, qos=1)[1]) - self._add_or_delete_subscription(GATEWAY_RPC_TOPIC, gateway_rpc_topic_sub_id) - - def _on_subscribe(self, client, userdata, mid, reasoncodes, properties=None, *extra_params): - subscription = self._gw_subscriptions.get(mid) - if subscription is not None: - if mid == 128: - self._delete_subscription(subscription, mid) - else: - log.debug("Service subscription to topic %s - successfully completed.", subscription) - del self._gw_subscriptions[mid] - - def _delete_subscription(self, topic, subscription_id): - log.error("Service subscription to topic %s - failed.", topic) - if subscription_id in self._gw_subscriptions: - del self._gw_subscriptions[subscription_id] - - def _add_or_delete_subscription(self, topic, subscription_id): - if subscription_id == 128: - self._delete_subscription(topic, subscription_id) - else: - self._gw_subscriptions[subscription_id] = topic - - @staticmethod - def _on_unsubscribe(*args): - log.debug("Unsubscribe callback called with args: %r", args) - - def get_subscriptions_in_progress(self): - return True if self._gw_subscriptions else False - - def _on_message(self, client, userdata, message): - content = self._decode(message) - super()._on_decoded_message(content, message) - self._on_decoded_message(content, message) - - def _on_decoded_message(self, content, message, **kwargs): - if message.topic.startswith(GATEWAY_ATTRIBUTES_RESPONSE_TOPIC): - with self._lock: - req_id = content["id"] - self._devices_connected_through_gateway_messages_rate_limit.increase_rate_limit_counter(1) - # pop callback and use it - if self._attr_request_dict.get(req_id): - callback = self._attr_request_dict.pop(req_id) - if isinstance(callback, tuple): - callback[0](content, None, callback[1]) - else: - callback(content, None) - else: - log.error("Unable to find callback to process attributes response from platform.") - elif message.topic == GATEWAY_ATTRIBUTES_TOPIC: - with self._lock: - # callbacks for everything - if self.__sub_dict.get("*|*"): - for device in self.__sub_dict["*|*"]: - self.__sub_dict["*|*"][device](content) - # callbacks for device. in this case callback executes for all attributes in message - if content.get("device") is None: - return - self._devices_connected_through_gateway_messages_rate_limit.increase_rate_limit_counter(1) - target = content["device"] + "|*" - if self.__sub_dict.get(target): - for device in self.__sub_dict[target]: - self.__sub_dict[target][device](content) - # callback for atr. in this case callback executes for all attributes in message - targets = [content["device"] + "|" + attribute for attribute in content["data"]] - for target in targets: - if self.__sub_dict.get(target): - for device in self.__sub_dict[target]: - self.__sub_dict[target][device](content) - elif message.topic == GATEWAY_RPC_TOPIC: - self._devices_connected_through_gateway_messages_rate_limit.increase_rate_limit_counter(1) - if self.devices_server_side_rpc_request_handler: - self.devices_server_side_rpc_request_handler(self, content) - elif message.topic == GATEWAY_DISCONNECT_TOPIC: - if content.get("reason"): - reason = content["reason"] - log.info("Device \"%s\" disconnected with reason %s", content["device"], content["reason"]) - if reason == 150: # 150 - Rate limit reached - self._devices_connected_through_gateway_messages_rate_limit.reach_limit() - self._devices_connected_through_gateway_telemetry_messages_rate_limit.reach_limit() - self._devices_connected_through_gateway_telemetry_datapoints_rate_limit.reach_limit() - if self.device_disconnect_callback is not None: - self.device_disconnect_callback(self, content) - else: - log.debug("Unexpected message from topic %r, content: %r", message.topic, content) - - def __request_attributes(self, device, keys, callback, type_is_client=False): - - attr_request_number = self._add_attr_request_callback(callback) - msg = {"keys": keys, - "device": device, - "client": type_is_client, - "id": attr_request_number} - info = self._send_device_request(TBSendMethod.PUBLISH, device, topic=GATEWAY_ATTRIBUTES_REQUEST_TOPIC, data=msg, - qos=1) - self.add_attrs_request_timeout(attr_request_number, int(time()) + 20) - return info - - def _send_device_request(self, _type, device_name, **kwargs): - if _type == TBSendMethod.PUBLISH: - topic = kwargs['topic'] - device_msg_rate_limit = self._devices_connected_through_gateway_messages_rate_limit - device_dp_rate_limit = self.EMPTY_RATE_LIMIT - if topic == GATEWAY_TELEMETRY_TOPIC or topic == GATEWAY_ATTRIBUTES_TOPIC: - device_msg_rate_limit = self._devices_connected_through_gateway_telemetry_messages_rate_limit - device_dp_rate_limit = self._devices_connected_through_gateway_telemetry_datapoints_rate_limit - info = self._publish_data(**kwargs, device=device_name, - msg_rate_limit=device_msg_rate_limit, - dp_rate_limit=device_dp_rate_limit) - return info - - def gw_request_shared_attributes(self, device_name, keys, callback): - return self.__request_attributes(device_name, keys, callback, False) - - def gw_request_client_attributes(self, device_name, keys, callback): - return self.__request_attributes(device_name, keys, callback, True) - - def gw_send_attributes(self, device, attributes, quality_of_service=1): - return self._send_device_request(TBSendMethod.PUBLISH, - device, - topic=GATEWAY_ATTRIBUTES_TOPIC, - data={device: attributes}, - qos=quality_of_service) - - def gw_send_telemetry(self, device, telemetry, quality_of_service=1): - if not isinstance(telemetry, list): - telemetry = [telemetry] - - return self._send_device_request(TBSendMethod.PUBLISH, - device, - topic=GATEWAY_TELEMETRY_TOPIC, - data={device: telemetry}, - qos=quality_of_service) - - def gw_connect_device(self, device_name, device_type="default"): - info = self._send_device_request(TBSendMethod.PUBLISH, device_name, topic=GATEWAY_MAIN_TOPIC + "connect", - data={"device": device_name, "type": device_type}, - qos=self.quality_of_service) - - self.__connected_devices.add(device_name) - - log.debug("Connected device %s", device_name) - return info - - def gw_disconnect_device(self, device_name): - info = self._send_device_request(TBSendMethod.PUBLISH, device_name, topic=GATEWAY_MAIN_TOPIC + "disconnect", - data={"device": device_name}, qos=self.quality_of_service) - - if device_name in self.__connected_devices: - self.__connected_devices.remove(device_name) - - log.debug("Disconnected device %s", device_name) - return info - - def gw_subscribe_to_all_attributes(self, callback): - return self.gw_subscribe_to_attribute("*", "*", callback) - - def gw_subscribe_to_all_device_attributes(self, device, callback): - return self.gw_subscribe_to_attribute(device, "*", callback) - - def gw_subscribe_to_attribute(self, device, attribute, callback): - if device not in self.__connected_devices: - log.error("Device %s is not connected", device) - return False - - with self._lock: - self.__max_sub_id += 1 - key = device + "|" + attribute - if key not in self.__sub_dict: - self.__sub_dict.update({key: {device: callback}}) - else: - self.__sub_dict[key].update({device: callback}) - - log.info("Subscribed to %s with id %i for device %s", key, self.__max_sub_id, device) - return self.__max_sub_id - - def gw_unsubscribe(self, subscription_id): - with self._lock: - for attribute in self.__sub_dict: - if self.__sub_dict[attribute].get(subscription_id): - del self.__sub_dict[attribute][subscription_id] - log.info("Unsubscribed from %s, subscription id %r", attribute, subscription_id) - if subscription_id == '*': - self.__sub_dict = {} - - def gw_set_server_side_rpc_request_handler(self, handler): - self.devices_server_side_rpc_request_handler = handler - - def gw_send_rpc_reply(self, device, req_id, resp, quality_of_service=None): - if quality_of_service is None: - quality_of_service = self.quality_of_service - if quality_of_service not in (0, 1): - log.error("Quality of service (qos) value must be 0 or 1") - return None - - info = self._send_device_request(TBSendMethod.PUBLISH, device, topic=GATEWAY_RPC_TOPIC, - data={"device": device, "id": req_id, "data": resp}, - qos=quality_of_service) - return info - - def gw_claim(self, device_name, secret_key, duration, claiming_request=None): - if claiming_request is None: - claiming_request = { - device_name: { - "secretKey": secret_key, - "durationMs": duration - } - } - - info = self._send_device_request(TBSendMethod.PUBLISH, device_name, topic=GATEWAY_CLAIMING_TOPIC, - data=claiming_request, qos=self.quality_of_service) - return info - - def __on_service_configuration(self, _, response, *args, **kwargs): - if "error" in response: - log.warning("Timeout while waiting for service configuration!, session will use default configuration.") - self.rate_limits_received = True - return - service_config = response - gateway_devices_rate_limit_config = service_config.pop('gatewayRateLimits', {}) - gateway_device_itself_rate_limit_config = service_config.pop('rateLimits', {}) - - if gateway_devices_rate_limit_config.get("messages"): - self._devices_connected_through_gateway_messages_rate_limit.set_limit( - gateway_devices_rate_limit_config.get("messages")) - else: - self._devices_connected_through_gateway_messages_rate_limit.set_limit('0:0,') - - if gateway_devices_rate_limit_config.get('telemetryMessages'): - self._devices_connected_through_gateway_telemetry_messages_rate_limit.set_limit( - gateway_devices_rate_limit_config.get('telemetryMessages')) - else: - self._devices_connected_through_gateway_telemetry_messages_rate_limit.set_limit('0:0,') - - if gateway_devices_rate_limit_config.get('telemetryDataPoints'): - self._devices_connected_through_gateway_telemetry_datapoints_rate_limit.set_limit( - gateway_devices_rate_limit_config.get('telemetryDataPoints')) - else: - self._devices_connected_through_gateway_telemetry_datapoints_rate_limit.set_limit('0:0,') - - gateway_limits_present = any( - [self._devices_connected_through_gateway_messages_rate_limit.has_limit(), - self._devices_connected_through_gateway_telemetry_messages_rate_limit.has_limit(), - self._devices_connected_through_gateway_telemetry_datapoints_rate_limit.has_limit()] - ) - - super().on_service_configuration(_, - {'rateLimits': gateway_device_itself_rate_limit_config, **service_config}, - *args, - **kwargs, - gateway_limits_present=gateway_limits_present) - log.info("Current limits for devices connected through the gateway: %r", gateway_devices_rate_limit_config) diff --git a/tb_mqtt_client/__init__.py b/tb_mqtt_client/__init__.py new file mode 100644 index 0000000..fa669aa --- /dev/null +++ b/tb_mqtt_client/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tb_mqtt_client/common/__init__.py b/tb_mqtt_client/common/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tb_mqtt_client/common/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tb_mqtt_client/common/async_utils.py b/tb_mqtt_client/common/async_utils.py new file mode 100644 index 0000000..70cf553 --- /dev/null +++ b/tb_mqtt_client/common/async_utils.py @@ -0,0 +1,137 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import threading +from typing import Union, Optional, Any, List, Set, Dict + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.publish_result import PublishResult + +logger = get_logger(__name__) + + +class FutureMap: + def __init__(self): + self._child_to_parents: Dict[asyncio.Future, Set[asyncio.Future]] = {} + self._parent_to_remaining: Dict[asyncio.Future, Set[asyncio.Future]] = {} + + def register(self, parent: asyncio.Future, children: List[asyncio.Future]): + if parent in self._parent_to_remaining: + self._parent_to_remaining[parent].update(children) + else: + self._parent_to_remaining[parent] = set(children) + + for child in children: + self._child_to_parents.setdefault(child, set()).add(parent) + + def get_parents(self, child: asyncio.Future) -> List[asyncio.Future]: + return list(self._child_to_parents.get(child, [])) + + def child_resolved(self, child: asyncio.Future): + parents = self._child_to_parents.pop(child, set()) + for parent in parents: + remaining = self._parent_to_remaining.get(parent) + if remaining is not None: + remaining.discard(child) + if not remaining and not parent.done(): + all_children = list(remaining) + [child] + results = [] + for f in all_children: + if f.done() and not f.cancelled(): + result = f.result() + if isinstance(result, PublishResult): + results.append(result) + + if results: + parent.set_result(PublishResult.merge(results)) + else: + parent.set_result(None) + self._parent_to_remaining.pop(parent, None) + + +future_map = FutureMap() + + +async def await_or_stop(future_or_coroutine: Union[asyncio.Future, asyncio.Task, Any], + stop_event: asyncio.Event, + timeout: Optional[float]) -> Optional[Any]: + if asyncio.iscoroutine(future_or_coroutine): + main_task = asyncio.create_task(future_or_coroutine) + elif asyncio.isfuture(future_or_coroutine): + if future_or_coroutine.done(): + return future_or_coroutine.result() + main_task = future_or_coroutine + else: + raise TypeError("Expected coroutine or Future/Task") + + stop_task = asyncio.create_task(stop_event.wait()) + + if timeout is not None and timeout < 0: + timeout = None + + try: + done, _ = await asyncio.wait( + [main_task, stop_task], + timeout=timeout, + return_when=asyncio.FIRST_COMPLETED + ) + if main_task in done: + return await main_task + if stop_task in done: + if stop_event.is_set(): + return None + if timeout is not None and not done: + raise asyncio.TimeoutError("Operation timed out") + except asyncio.CancelledError: + return None + finally: + if not stop_task.done(): + stop_task.cancel() + + +def run_coroutine_sync(coroutine, timeout: float = 3.0, raise_on_timeout: bool = False): + """ + Run async coroutine and return its result from a sync function even if event loop is running. + :param coroutine: async function with no arguments (like: lambda: some_async_fn()) + :param timeout: max wait time in seconds + :param raise_on_timeout: if True, raise TimeoutError on timeout; otherwise return None + """ + result_container = {} + event = threading.Event() + + async def wrapper(): + try: + result = await coroutine() + result_container['result'] = result + except Exception as e: + result_container['error'] = e + finally: + event.set() + + loop = asyncio.get_running_loop() + loop.create_task(wrapper()) + + completed = event.wait(timeout=timeout) + + if not completed: + logger.warning("Timeout while waiting for coroutine to finish: %s", coroutine) + if raise_on_timeout: + raise TimeoutError(f"Coroutine {coroutine} did not complete in {timeout} seconds.") + return None + + if 'error' in result_container: + raise result_container['error'] + + return result_container.get('result') diff --git a/tb_mqtt_client/common/config_loader.py b/tb_mqtt_client/common/config_loader.py new file mode 100644 index 0000000..4b56d35 --- /dev/null +++ b/tb_mqtt_client/common/config_loader.py @@ -0,0 +1,120 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Optional + + +class DeviceConfig: + """ + Configuration class for ThingsBoard device clients. + This class loads configuration options from environment variables, allowing for flexible deployment + and easy customization of device connection settings. + """ + + def __init__(self, config=None): + self.host = None + self.port = 1883 + self.access_token: Optional[str] = None + self.username: Optional[str] = None + self.password: Optional[str] = None + self.client_id: Optional[str] = None + self.ca_cert: Optional[str] = None + self.client_cert: Optional[str] = None + self.private_key: Optional[str] = None + self.qos: int = 1 + + if config is not None: + self.host: str = config.get("host", "localhost") + self.port: int = config.get("port", 1883) + self.access_token: Optional[str] = config.get("access_token") + self.username: Optional[str] = config.get("username") + self.password: Optional[str] = config.get("password") + self.client_id: Optional[str] = config.get("client_id") + self.ca_cert: Optional[str] = config.get("ca_cert") + self.client_cert: Optional[str] = config.get("client_cert") + self.private_key: Optional[str] = config.get("private_key") + + self.host: str = os.getenv("TB_HOST", self.host) + self.port: int = int(os.getenv("TB_PORT", self.port)) + + # Authentication options + self.access_token: Optional[str] = os.getenv("TB_ACCESS_TOKEN", self.access_token) + self.username: Optional[str] = os.getenv("TB_USERNAME", self.username) + self.password: Optional[str] = os.getenv("TB_PASSWORD", self.password) + + # Optional + self.client_id: Optional[str] = os.getenv("TB_CLIENT_ID", self.client_id) + + # TLS options + self.ca_cert: Optional[str] = os.getenv("TB_CA_CERT", self.ca_cert) + self.client_cert: Optional[str] = os.getenv("TB_CLIENT_CERT", self.client_cert) + self.private_key: Optional[str] = os.getenv("TB_PRIVATE_KEY", self.private_key) + + # Default values + self.qos: int = int(os.getenv("TB_QOS", self.qos)) + + def use_tls_auth(self) -> bool: + return all([self.ca_cert, self.client_cert, self.private_key]) + + def use_tls(self) -> bool: + return self.ca_cert is not None + + def __repr__(self): + return (f"DeviceConfig(host={self.host}, port={self.port}, " + f"auth={'token' if self.access_token else 'user/pass'} " + f"client_id={self.client_id} " + f"tls={self.use_tls()})") + + +class GatewayConfig(DeviceConfig): + """ + Configuration class for ThingsBoard gateway clients. + This class extends DeviceConfig to include additional options specific to gateways. + """ + + def __init__(self, config=None): + # TODO: REFACTOR, temporary solution for development + super().__init__(config) + + if os.getenv("TB_GW_HOST") is not None: + self.host: str = os.getenv("TB_GW_HOST") + if os.getenv("TB_GW_PORT") is not None: + self.port: int = int(os.getenv("TB_GW_PORT", 1883)) + + if os.getenv("TB_GW_ACCESS_TOKEN") is not None: + self.access_token: Optional[str] = os.getenv("TB_GW_ACCESS_TOKEN") + if os.getenv("TB_GW_USERNAME") is not None: + self.username: Optional[str] = os.getenv("TB_GW_USERNAME") + if os.getenv("TB_GW_PASSWORD") is not None: + self.password: Optional[str] = os.getenv("TB_GW_PASSWORD") + + if os.getenv("TB_GW_CLIENT_ID") is not None: + self.client_id: Optional[str] = os.getenv("TB_GW_CLIENT_ID") + + if os.getenv("TB_GW_CA_CERT") is not None: + self.ca_cert: Optional[str] = os.getenv("TB_GW_CA_CERT") + if os.getenv("TB_GW_CLIENT_CERT") is not None: + self.client_cert: Optional[str] = os.getenv("TB_GW_CLIENT_CERT") + if os.getenv("TB_GW_PRIVATE_KEY") is not None: + self.private_key: Optional[str] = os.getenv("TB_GW_PRIVATE_KEY") + + if os.getenv("TB_GW_QOS") is not None: + self.qos: int = int(os.getenv("TB_GW_QOS", 1)) + + def __repr__(self): + return (f"GatewayConfig(host={self.host}, port={self.port}, " + f"auth={'token' if self.access_token else 'user/pass'} " + f"client_id={self.client_id} " + f"tls={self.use_tls()})") diff --git a/tb_mqtt_client/common/exceptions.py b/tb_mqtt_client/common/exceptions.py new file mode 100644 index 0000000..d91193a --- /dev/null +++ b/tb_mqtt_client/common/exceptions.py @@ -0,0 +1,77 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +from typing import Callable, Dict, List, Optional, Type + +logger = logging.getLogger(__name__) + +ExceptionCallback = Callable[[BaseException, Optional[dict]], None] + + +class ExceptionHandler: + def __init__(self): + self._callbacks: Dict[Type[BaseException], List[ExceptionCallback]] = {} + self._default_callbacks: List[ExceptionCallback] = [] + + def register(self, exc_type: Optional[Type[BaseException]], callback: ExceptionCallback): + """ + Register a callback for a specific exception type or all (if exc_type is None). + """ + if exc_type is None: + self._default_callbacks.append(callback) + else: + self._callbacks.setdefault(exc_type, []).append(callback) + + def handle(self, exc: BaseException, context: Optional[dict] = None): + """ + Dispatch the exception to the appropriate registered callbacks. + """ + handled = False + for exc_type, callbacks in self._callbacks.items(): + if isinstance(exc, exc_type): + for cb in callbacks: + cb(exc, context) + handled = True + if not handled: + for cb in self._default_callbacks: + cb(exc, context) + + def install_asyncio_handler(self, loop: Optional[asyncio.AbstractEventLoop] = None): + """ + Hook into asyncio event loop to catch global task errors. + """ + loop = loop or asyncio.get_event_loop() + + def _asyncio_handler(loop_, context: dict): + exception = context.get("exception") + if exception: + self.handle(exception, context) + else: + logger.error("Unhandled asyncio context: %s", context) + + loop.set_exception_handler(_asyncio_handler) + + +class BackpressureException(Exception): + """ + Exception raised when the client is under backpressure. + This should be used to signal that the client cannot process more messages at the moment. + """ + def __init__(self, message: str = "Client is under backpressure. Please retry later."): + super().__init__(message) + + +exception_handler = ExceptionHandler() diff --git a/tb_mqtt_client/common/gmqtt_patch.py b/tb_mqtt_client/common/gmqtt_patch.py new file mode 100644 index 0000000..735afd0 --- /dev/null +++ b/tb_mqtt_client/common/gmqtt_patch.py @@ -0,0 +1,450 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import heapq +import struct +from collections import defaultdict +from typing import Callable, Tuple, Optional + +from gmqtt import Client +from gmqtt.mqtt.constants import MQTTCommands, MQTTv50, MQTTv311 +from gmqtt.mqtt.handler import MqttPackageHandler, MQTTConnectError +from gmqtt.mqtt.package import PackageFactory +from gmqtt.mqtt.property import Property +from gmqtt.mqtt.protocol import BaseMQTTProtocol, MQTTProtocol +from gmqtt.mqtt.utils import unpack_variable_byte_integer, pack_variable_byte_integer + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage + +logger = get_logger(__name__) + + +class PublishPacket(PackageFactory): + @classmethod + def build_package(cls, message: MqttPublishMessage, protocol, mid: int = None) -> Tuple[int, bytes]: + dup_flag = 1 if message.dup else 0 + + command = MQTTCommands.PUBLISH | (dup_flag << 3) | (message.qos << 1) | (message.retain & 0x1) + + packet = bytearray() + packet.append(command) + + remaining_length = 2 + len(message.topic) + message.payload_size + prop_bytes = cls._build_properties_data(message.properties, protocol_version=protocol.proto_ver) + remaining_length += len(prop_bytes) + + if message.payload_size == 0: + logger.debug("Sending PUBLISH (q%d), '%s' (NULL payload)", message.qos, message.topic) + else: + logger.debug("Sending PUBLISH (q%d), '%s', ... (%d bytes)", message.qos, message.topic, + message.payload_size) + + if message.qos > 0: + remaining_length += 2 + + packet.extend(pack_variable_byte_integer(remaining_length)) + cls._pack_str16(packet, message.topic) + + if message.qos > 0: + if mid is None: + mid = cls.id_generator.next_id() + packet.extend(struct.pack("!H", mid)) + + packet.extend(prop_bytes) + packet.extend(message.payload) + + return mid, packet + + +class PatchUtils: + DISCONNECT_REASON_CODES = { + 0: "Normal disconnection", + 4: "Disconnect with Will Message", + 128: "Unspecified error", + 129: "Malformed Packet", + 130: "Protocol Error", + 131: "Implementation specific error", + 132: "Not authorized", + 133: "Server busy", + 134: "Bad credentials", + 135: "Keep Alive timeout", + 136: "Session taken over", + 137: "Topic Filter invalid", + 138: "Topic Name invalid", + 139: "Receive Maximum exceeded", + 140: "Topic Alias invalid", + 141: "Packet too large", + 142: "Session taken over", + 143: "Quota exceeded", + 144: "Administrative action", + 145: "Payload format invalid", + 146: "Retain not supported", + 147: "QoS not supported", + 148: "Use another server", + 149: "Server moved", + 150: "Shared Subscriptions not supported", + 151: "Connection rate exceeded", + 152: "Administrative action", + 153: "Subscription Identifiers not supported", + 154: "Wildcard Subscriptions not supported" + } + + def __init__(self, client: Optional[Client], stop_event: asyncio.Event, retry_interval: int = 15): + """ + Initialize PatchUtils with a client and retry interval. + + :param client: The MQTT client instance to patch. + :param retry_interval: Interval in seconds to retry connection. + """ + self.client = client + self.retry_interval = retry_interval + self._stop_event = stop_event + self._retry_task = None + + @staticmethod + def parse_mqtt_properties(packet: bytes) -> dict: + """ + Parse MQTT 5.0 properties from a packet. + """ + properties_dict = defaultdict(list) + + try: + properties_len, rest = unpack_variable_byte_integer(packet) + props = rest[:properties_len] # slice out exactly the properties section + + while props: + property_identifier = props[0] + property_obj = Property.factory(id_=property_identifier) + if property_obj is None: + logger.warning(f"Unknown property id={property_identifier}") + break + + result, props = property_obj.loads(props[1:]) + for k, v in result.items(): + properties_dict[k].append(v) + + except Exception as e: + logger.warning("Failed to parse properties: %s", e) + + return dict(properties_dict) + + @staticmethod + def extract_reason_code(packet): + """ + Extract the reason code from a disconnect packet. + + :param packet: The disconnect packet, which can be an object with a reason_code attribute or raw bytes + :return: The reason code if found, None otherwise + """ + reason_code = None + if packet: + if hasattr(packet, 'reason_code'): + reason_code = packet.reason_code + elif isinstance(packet, bytes) and len(packet) >= 2: + reason_code = packet[1] + + return reason_code + + def patch_mqtt_handler_disconnect(patch_utils_instance): + """ + Monkey-patch gmqtt.mqtt.handler.MqttPackageHandler._handle_disconnect_packet to properly + handle server-initiated disconnect messages. + """ + try: + # Store the original method + original_handle_disconnect = MqttPackageHandler._handle_disconnect_packet # noqa + + # Define the patched method + def patched_handle_disconnect_packet(self, cmd, packet): + # Extract reason code + reason_code = 0 + if packet and len(packet) >= 1: + reason_code = packet[0] + + # Parse properties if available + properties = {} + if packet and len(packet) > 1: + try: + properties = PatchUtils.parse_mqtt_properties(packet[1:]) + except Exception as exc: + logger.warning("Failed to parse properties from disconnect packet: %s", exc) + + reason_desc = PatchUtils.DISCONNECT_REASON_CODES.get(reason_code, "Unknown reason") + logger.trace("Server initiated disconnect with reason code: %s (%s)", reason_code, reason_desc) + + # Call the original method to handle reconnection + # But don't call the on_disconnect callback, as we'll do that ourselves + # with the extracted reason_code and properties + self._clear_topics_aliases() + future = asyncio.ensure_future(self.reconnect(delay=True)) + future.add_done_callback(self._handle_exception_in_future) + + # Call the on_disconnect callback with the client, reason_code, properties, and None for exc + # since this is a server-initiated disconnect + self.on_disconnect(self, reason_code, properties, None) + + # Set a flag on the connection object to indicate that on_disconnect has been called + self._connection._on_disconnect_called = True + + # Apply the patch + MqttPackageHandler._handle_disconnect_packet = patched_handle_disconnect_packet + logger.debug("Successfully patched gmqtt.mqtt.handler.MqttPackageHandler._handle_disconnect_packet") + return True + except (ImportError, AttributeError) as e: + logger.warning("Failed to patch gmqtt handler: %s", e) + return False + + def patch_handle_connack(patch_utils_instance): + """ + Fully replaces gmqtt's _handle_connack_packet implementation, skipping internal QoS1 resend behavior + and invoking a custom callback instead of calling the original method. + """ + try: + def new_handle_connack_packet(self, cmd, packet): + try: + self._connected.set() + + (session_present, result_code) = struct.unpack("!BB", packet[:2]) + + if result_code != 0: + self._logger.warning('[CONNACK] %s', hex(result_code)) + self.failed_connections += 1 + if result_code == 1 and self.protocol_version == MQTTv50: + self._logger.info('[CONNACK] Downgrading to MQTT 3.1 protocol version') + MQTTProtocol.proto_ver = MQTTv311 + future = asyncio.ensure_future(self.reconnect(delay=True)) + future.add_done_callback(self._handle_exception_in_future) + return + else: + self._error = MQTTConnectError(result_code) + asyncio.ensure_future(self.reconnect(delay=True)) + return + else: + self.failed_connections = 0 + + if len(packet) > 2: + properties, _ = self._parse_properties(packet[2:]) + if properties is None: + self._error = MQTTConnectError(10) + asyncio.ensure_future(self.disconnect()) + self._connack_properties = properties + self._update_keepalive_if_needed() + else: + properties = {} + + self._logger.debug('[CONNACK] session_present: %s, result: %s', hex(session_present), + hex(result_code)) + + self.on_connect(self, session_present, result_code, self.properties) + + except Exception as e: + logger.exception("Error while handling CONNACK packet: %s", e) + + MqttPackageHandler._handle_connack_packet = new_handle_connack_packet + logger.debug("Successfully patched gmqtt.mqtt.handler._handle_connack_packet (full replacement)") + return True + except (ImportError, AttributeError) as e: + logger.warning("Failed to patch gmqtt handler: %s", e) + return False + + def patch_gmqtt_protocol_connection_lost(patch_utils_instance): + """ + Monkey-patch gmqtt.mqtt.protocol.BaseMQTTProtocol.connection_lost to suppress the + default "[CONN CLOSE NORMALLY]" log message, as we handle disconnect logging in our code. + + Also, patch MQTTProtocol.connection_lost to include the reason code in the DISCONNECT package + and pass the exception to the handler. + """ + try: + + def patched_base_connection_lost(self, exc): + self._connected.clear() + super(BaseMQTTProtocol, self).connection_lost(exc) + + BaseMQTTProtocol.connection_lost = patched_base_connection_lost + + def patched_mqtt_connection_lost(self, exc): + super(MQTTProtocol, self).connection_lost(exc) + reason_code = 0 + properties = {} + + if exc: + if isinstance(exc, ConnectionRefusedError): + reason_code = 135 # Keep Alive timeout + elif isinstance(exc, TimeoutError): + reason_code = 135 # Keep Alive timeout + elif isinstance(exc, ConnectionResetError): + reason_code = 139 # Receive Maximum exceeded + elif isinstance(exc, ConnectionAbortedError): + reason_code = 136 # Session taken over + elif isinstance(exc, PermissionError): + reason_code = 132 # Not authorized + elif isinstance(exc, OSError): + reason_code = 130 # Protocol Error + else: + reason_code = 131 # Implementation specific error + + if hasattr(exc, 'args') and exc.args: + properties['reason_string'] = [str(exc.args[0])] + + payload = struct.pack('!B', reason_code) + + self._connection._disconnect_exc = exc + self._connection._disconnect_properties = properties + + self._connection.put_package((MQTTCommands.DISCONNECT, payload)) + + if self._read_loop_future is not None: + self._read_loop_future.cancel() + self._read_loop_future = None + + self._queue = asyncio.Queue() + + MQTTProtocol.connection_lost = patched_mqtt_connection_lost + + original_call = MqttPackageHandler.__call__ + + def patched_call(self, cmd, packet): + try: + if cmd == MQTTCommands.DISCONNECT and hasattr(self._connection, '_disconnect_exc'): + reason_code = 0 + if packet and len(packet) >= 1: + reason_code = packet[0] + + properties = getattr(self._connection, '_disconnect_properties', {}) + exc = getattr(self._connection, '_disconnect_exc', None) + + if (not hasattr(self._connection, '_on_disconnect_called') + or not self._connection._on_disconnect_called): # noqa + self._clear_topics_aliases() + future = asyncio.ensure_future(self.reconnect(delay=True)) + future.add_done_callback(self._handle_exception_in_future) + self.on_disconnect(self, reason_code, properties, exc) + return None + + return original_call(self, cmd, packet) + except Exception as exception: + logger.error('[ERROR HANDLE PKG]', exc_info=exception) + return None + + MqttPackageHandler.__call__ = patched_call + + logger.debug("Successfully patched gmqtt.mqtt.protocol connection_lost methods") + return True + except (ImportError, AttributeError) as e: + logger.warning("Failed to patch gmqtt protocol: %s", e) + return False + + def patch_puback_handling(self, on_puback_with_reason_and_properties: Callable[[int, int, dict], None]): + original_handler = MqttPackageHandler._handle_puback_packet + + def wrapped_handle_puback(self, cmd, packet): + try: + mid = struct.unpack("!H", packet[:2])[0] + reason_code = 0 + properties = {} + if len(packet) > 2: + reason_code = packet[2] + if len(packet) > 3: + props_payload = packet[3:] + properties = PatchUtils.parse_mqtt_properties(props_payload) + logger.trace("PUBACK received for mid=%r, reason=%r", mid, reason_code) + if hasattr(self._connection, 'persistent_storage'): + self._connection.persistent_storage.remove(mid) + on_puback_with_reason_and_properties(mid, reason_code, properties) + except Exception as e: + logger.exception("Error while handling PUBACK with properties: %s", e) + return original_handler(self, cmd, packet) + + MqttPackageHandler._handle_puback_packet = wrapped_handle_puback + logger.debug("Patched _handle_puback_packet for QoS1 support.") + + def patch_storage(patch_utils_instance): + async def pop_message_with_tm(): + (tm, mid, raw_package) = heapq.heappop(patch_utils_instance.client._persistent_storage._queue) + + patch_utils_instance.client._persistent_storage._check_empty() + return tm, mid, raw_package + + patch_utils_instance.client._persistent_storage.pop_message = pop_message_with_tm + + async def _retry_loop(self): + logger.debug("QoS1 retry loop started.") + self.patch_storage() + try: + while not self._stop_event.is_set(): + current_tm = asyncio.get_event_loop().time() + + for _ in range(100): + if self._stop_event.is_set(): + break + + try: + msg = await asyncio.wait_for(self.client._persistent_storage.pop_message(), timeout=0.1) + except asyncio.TimeoutError: + break + except IndexError: + break + except Exception as e: + logger.warning("Error popping message: %s", e) + break + + if msg is None: + break + + tm, mid, mqtt_msg = msg + + if current_tm - tm > self.retry_interval and self.client.is_connected: + mqtt_msg.dup = True + logger.error("Resending PUBLISH message with mid=%r, topic=%s", mid, mqtt_msg.topic) + + try: + await self.client.put_retry_message( + mqtt_msg) # noqa This method sets in message service to the client + except AttributeError as e: + logger.trace("Failed to resend message with mid=%r: %s", mid, e) + + await asyncio.sleep(self.retry_interval) + except asyncio.CancelledError: + logger.debug("Retry loop cancelled.") + except Exception as e: + logger.exception("Unexpected error in retry loop: %s", e) + finally: + logger.debug("QoS1 retry loop stopped.") + + def start_retry_task(self): + if not self._stop_event.is_set() and not self._retry_task: + self._retry_task = asyncio.create_task(self._retry_loop()) + logger.debug("Retry task started.") + + async def stop_retry_task(self): + if self._retry_task: + self._stop_event.set() + try: + await asyncio.wait_for(self._retry_task, timeout=2) + except asyncio.TimeoutError: + logger.debug("Retry task did not finish in time, cancelling...") + self._retry_task.cancel() + try: + await self._retry_task + except asyncio.CancelledError: + logger.debug("Retry task cancelled.") + self._retry_task = None + self._stop_event.clear() + + def apply(self, on_puback_with_reason_and_properties: Callable[[int, int, dict], None]): + self.patch_puback_handling(on_puback_with_reason_and_properties) + self.start_retry_task() diff --git a/utils.py b/tb_mqtt_client/common/install_package_utils.py similarity index 67% rename from utils.py rename to tb_mqtt_client/common/install_package_utils.py index 8d6b1ec..e2f0438 100644 --- a/utils.py +++ b/tb_mqtt_client/common/install_package_utils.py @@ -1,19 +1,20 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -from sys import executable from subprocess import check_call, CalledProcessError, DEVNULL +from sys import executable + from pkg_resources import get_distribution, DistributionNotFound diff --git a/tb_mqtt_client/common/logging_utils.py b/tb_mqtt_client/common/logging_utils.py new file mode 100644 index 0000000..698bd5e --- /dev/null +++ b/tb_mqtt_client/common/logging_utils.py @@ -0,0 +1,61 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import sys +from typing import Optional + +DEFAULT_LOG_FORMAT = "[%(asctime)s.%(msecs)03d] [%(levelname)s] %(name)s - %(lineno)d - %(message)s" +DEFAULT_DATE_FORMAT = "%Y-%m-%d %H:%M:%S" + +TRACE_LEVEL = 5 +logging.addLevelName(TRACE_LEVEL, "TRACE") + + +class ExtendedLogger(logging.Logger): + """ + Custom logger class that supports TRACE level logging. + """ + + def trace(self, message, *args, **kwargs): + if self.isEnabledFor(TRACE_LEVEL): + self._log(TRACE_LEVEL, message, args, **kwargs) + + +logging.setLoggerClass(ExtendedLogger) + + +def configure_logging(level: int = logging.INFO, + log_format: str = DEFAULT_LOG_FORMAT, + date_format: str = DEFAULT_DATE_FORMAT, + stream=sys.stdout): + """ + Configures the root logger with a stream handler and standardized format. + Should be called once during app startup. + """ + logging.basicConfig( + level=level, + format=log_format, + datefmt=date_format, + handlers=[ + logging.StreamHandler(stream) + ] + ) + + +def get_logger(name: Optional[str] = None) -> ExtendedLogger: + """ + Returns a logger instance with the given name. + """ + return logging.getLogger(name or __name__) # noqa diff --git a/tb_mqtt_client/common/mqtt_message.py b/tb_mqtt_client/common/mqtt_message.py new file mode 100644 index 0000000..922267b --- /dev/null +++ b/tb_mqtt_client/common/mqtt_message.py @@ -0,0 +1,99 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from asyncio import Future +from time import time +from typing import Union, Optional +from uuid import uuid4, UUID + +from gmqtt import Message + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.constants import mqtt_topics +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage + +logger = get_logger(__name__) + + +class MqttPublishMessage(Message): + """ + A custom Publish MQTT message class that extends the gmqtt Message class. + Contains additional information like datapoints, to avoid rate limits exceeding. + """ + + def __init__(self, + topic: str, + payload: Union[bytes, GatewayUplinkMessage, DeviceUplinkMessage], + qos: int = 1, + retain: bool = False, + datapoints: int = 0, + delivery_futures=None, + main_ts: Optional[int] = None, + original_payload=None, + **kwargs): + """ + Initialize the MqttMessage with topic, payload, QoS, retain flag, and datapoints. + """ + self.uuid: UUID = uuid4() + self.payload_size = 0 + self.prepared = False + self.original_payload = original_payload if original_payload is not None else payload + self.main_ts = main_ts if main_ts is not None else int(time() * 1000) + if isinstance(payload, bytes): + super().__init__(topic, payload, qos, retain) + else: + self.payload = self.original_payload + payload.set_main_ts(self.main_ts) + self.topic = topic + self.is_service_message = self.topic not in mqtt_topics.TOPICS_WITH_DATAPOINTS_CHECK + self.is_device_message = not isinstance(original_payload, GatewayUplinkMessage) + self.qos = qos if qos is not None else 1 + if self.qos < 0 or self.qos > 1: + logger.warning(f"Invalid QoS {self.qos} for topic {topic}, using default QoS 1") + self.qos = 1 + self.dup = False + self.retain = retain + self.message_id = None + self.datapoints = datapoints + self.properties = kwargs + self._is_sent = False + self.delivery_futures = delivery_futures + if not delivery_futures: + delivery_future = Future() + delivery_future.uuid = uuid4() + self.delivery_futures = [delivery_future] + if not isinstance(self.delivery_futures, (list, tuple)): + self.delivery_futures = [self.delivery_futures] + for future in self.delivery_futures: + if not hasattr(future, 'uuid'): + future.uuid = uuid4() + logger.trace(f"Created MqttMessage with topic: {topic}, payload type: {type(payload).__name__}, " + f"datapoints: {datapoints}, delivery_future id: {self.delivery_futures[0].uuid}") + + def mark_as_sent(self, message_id: int): + """Mark the message as sent.""" + self.message_id = message_id + self._is_sent = True + + def __str__(self): + return f"MqttPublishMessage(topic={self.topic}, payload_size={self.payload_size}, " \ + f"datapoints={self.datapoints}, qos={self.qos}, retain={self.retain}, " \ + f"main_ts={self.main_ts}, is_device_message={self.is_device_message})" + + def __repr__(self): + return f"MqttPublishMessage(topic={self.topic}, payload={self.payload}, payload_size={self.payload_size}, " \ + f"datapoints={self.datapoints}, qos={self.qos}, retain={self.retain}, " \ + f"main_ts={self.main_ts}, is_device_message={self.is_device_message}, " \ + f"delivery_futures={self.delivery_futures})" diff --git a/tb_mqtt_client/common/provisioning_client.py b/tb_mqtt_client/common/provisioning_client.py new file mode 100644 index 0000000..12affc9 --- /dev/null +++ b/tb_mqtt_client/common/provisioning_client.py @@ -0,0 +1,69 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from asyncio import Event +from typing import Union, Optional + +from gmqtt import Client as GMQTTClient + +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.constants.mqtt_topics import PROVISION_RESPONSE_TOPIC +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest +from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse +from tb_mqtt_client.service.device.message_adapter import JsonMessageAdapter + +logger = get_logger(__name__) + + +class ProvisioningClient: + def __init__(self, host: str, port: int, provision_request: ProvisioningRequest): + self._log = logger + self._stop_event = Event() + self._host = host + self._port = port + self._provision_request = provision_request + self._client_id = "provision" + self._client = GMQTTClient(self._client_id) + self._client.on_connect = self._on_connect + self._client.on_message = self._on_message + self._provisioned = Event() + self._provisioning_response: Optional[ProvisioningResponse] = None + self.__message_adapter = JsonMessageAdapter() + + def _on_connect(self, client, _, rc, __): + if rc == 0: + self._log.debug("[Provisioning client] Connected to ThingsBoard") + client.subscribe(PROVISION_RESPONSE_TOPIC) + provision_message = self.__message_adapter.build_provision_request(self._provision_request) + self._log.debug("[Provisioning client] Sending provisioning request %s", provision_message) + client.publish(provision_message.topic, provision_message.payload) + else: + self._provisioning_response = ProvisioningResponse.build(self._provision_request, + {'status': 'FAILURE', + 'errorMsg': 'Cannot connect to ThingsBoard!'}) + self._provisioned.set() + self._log.error("[Provisioning client] Cannot connect to ThingsBoard!, result: %s" % rc) + + async def _on_message(self, _, __, payload, ___, ____): + self._provisioning_response = self.__message_adapter.parse_provisioning_response(self._provision_request, payload) + + await self._client.disconnect() + self._provisioned.set() + + async def provision(self): + await self._client.connect(self._host, self._port) + await self._provisioned.wait() + + return self._provisioning_response diff --git a/tb_mqtt_client/common/publish_result.py b/tb_mqtt_client/common/publish_result.py new file mode 100644 index 0000000..cda1589 --- /dev/null +++ b/tb_mqtt_client/common/publish_result.py @@ -0,0 +1,81 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + + +class PublishResult: + def __init__(self, topic: str, qos: int, message_id: int, payload_size: int, reason_code: int, + datapoints_count: int = 0): + self.topic = topic + self.qos = qos if qos is not None else 1 + self.message_id = message_id + self.payload_size = payload_size + self.reason_code = reason_code + self.datapoints_count = datapoints_count + + def __repr__(self): + return (f"PublishResult(topic={self.topic}, " + f"qos={self.qos}, " + f"message_id={self.message_id}, " + f"payload_size={self.payload_size}, " + f"reason_code={self.reason_code}, " + f"datapoints_count={self.datapoints_count})") + + def __eq__(self, other): + if not isinstance(other, PublishResult): + return False + return (self.topic == other.topic and + self.qos == other.qos and + self.message_id == other.message_id and + self.payload_size == other.payload_size and + self.reason_code == other.reason_code and + self.datapoints_count == other.datapoints_count) + + def as_dict(self) -> dict: + return { + "topic": self.topic, + "qos": self.qos, + "message_id": self.message_id, + "payload_size": self.payload_size, + "reason_code": self.reason_code, + "datapoints_count": self.datapoints_count + } + + def is_successful(self) -> bool: + """ + Check if the publish operation was successful based on the reason code. + """ + return self.reason_code == 0 + + @staticmethod + def merge(results: List['PublishResult']) -> 'PublishResult': + if not results: + raise ValueError("No publish results to merge.") + + topic = results[0].topic + qos = results[0].qos + message_id = -1 + reason_code = 0 if all(r.reason_code == 0 for r in results) else -1 + payload_size = sum(r.payload_size for r in results) + datapoints_count = sum(r.datapoints_count for r in results) + + return PublishResult( + topic=topic, + qos=qos, + message_id=message_id, + payload_size=payload_size, + reason_code=reason_code, + datapoints_count=datapoints_count + ) diff --git a/tb_mqtt_client/common/queue.py b/tb_mqtt_client/common/queue.py new file mode 100644 index 0000000..edd59bb --- /dev/null +++ b/tb_mqtt_client/common/queue.py @@ -0,0 +1,97 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from collections import deque + + +class AsyncDeque: + def __init__(self, maxlen): + self._deque = deque(maxlen=maxlen) + self._cond = asyncio.Condition() + self._in_queue = set() + + async def put(self, item): + if item.uuid in self._in_queue: + return + async with self._cond: + self._deque.append(item) + self._in_queue.add(item.uuid) + self._cond.notify() + + async def extend(self, items): + async with self._cond: + for item in items: + if item.uuid not in self._in_queue: + self._deque.append(item) + self._in_queue.add(item.uuid) + self._cond.notify() + + async def put_left(self, item): + if item.uuid in self._in_queue: + return + async with self._cond: + self._deque.appendleft(item) + self._in_queue.add(item.uuid) + self._cond.notify() + + async def extend_left(self, items): + async with self._cond: + for item in reversed(items): + if item.uuid not in self._in_queue: + self._deque.appendleft(item) + self._in_queue.add(item.uuid) + self._cond.notify() + + async def get(self): + async with self._cond: + while not self._deque: + await self._cond.wait() + item = self._deque.popleft() + self._in_queue.discard(item.uuid) + return item + + async def peek(self): + async with self._cond: + while not self._deque: + await self._cond.wait() + return self._deque[0] + + async def peek_batch(self, max_count: int): + async with self._cond: + while not self._deque: + await self._cond.wait() + return list(self._deque)[:max_count] + + async def pop_n(self, n: int): + async with self._cond: + items = [] + for _ in range(min(n, len(self._deque))): + item = self._deque.popleft() + self._in_queue.discard(item.uuid) + items.append(item) + return items + + async def reinsert_front(self, item): + async with self._cond: + if item.uuid not in self._in_queue: + self._deque.appendleft(item) + self._in_queue.add(item.uuid) + self._cond.notify() + + def is_empty(self): + return not self._deque + + def size(self): + return len(self._deque) diff --git a/tb_mqtt_client/common/rate_limit/__init__.py b/tb_mqtt_client/common/rate_limit/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tb_mqtt_client/common/rate_limit/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tb_mqtt_client/common/rate_limit/backpressure_controller.py b/tb_mqtt_client/common/rate_limit/backpressure_controller.py new file mode 100644 index 0000000..6dd46a1 --- /dev/null +++ b/tb_mqtt_client/common/rate_limit/backpressure_controller.py @@ -0,0 +1,98 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from asyncio import Event +from datetime import datetime, timedelta, UTC +from typing import Optional, List + +from tb_mqtt_client.common.logging_utils import get_logger + +logger = get_logger(__name__) + + +class BackpressureController: + def __init__(self, main_stop_event: Event): + self.__main_stop_event = main_stop_event + self._pause_until: Optional[datetime] = None + self._default_pause_duration = timedelta(seconds=10) + self._consecutive_quota_exceeded = 0 + self._last_quota_exceeded = datetime.now(UTC) + self._max_backoff_seconds = 3600 # 1 hour + logger.debug("BackpressureController initialized with default pause duration of %s seconds", + self._default_pause_duration.total_seconds()) + + def notify_quota_exceeded(self, delay_seconds: Optional[int] = None): + if self.__main_stop_event.is_set(): + logger.trace("Main stop event is set, not applying backpressure") + return + now = datetime.now(UTC) + # If we've had a quota-exceeded event in the last 60 seconds, increment the counter + if (now - self._last_quota_exceeded).total_seconds() < 60: + self._consecutive_quota_exceeded += 1 + else: + # Reset counter if it's been more than 60 seconds since the last quota exceeded the event + self._consecutive_quota_exceeded = 1 + + self._last_quota_exceeded = now + + # Apply exponential backoff based on consecutive quota-exceeded events + if delay_seconds is None: + # Start with the default duration and apply exponential backoff + backoff_factor = min(2 ** (self._consecutive_quota_exceeded - 1), 10) + delay_seconds = int(self._default_pause_duration.total_seconds() * backoff_factor) + # Cap at max backoff + delay_seconds = min(delay_seconds, self._max_backoff_seconds) + + logger.warning("Applying backpressure for %d seconds (consecutive quota exceeded: %d)", + delay_seconds, self._consecutive_quota_exceeded) + + duration = timedelta(seconds=delay_seconds) + self._pause_until = now + duration + + def notify_disconnect(self, delay_seconds: Optional[int] = None): + if self.__main_stop_event.is_set(): + logger.trace("Main stop event is set, not pausing publishing") + return + if delay_seconds is None: + delay_seconds = int(self._default_pause_duration.total_seconds()) + + duration = timedelta(seconds=delay_seconds) + self._pause_until = datetime.now(UTC) + duration + logger.debug("Pausing publishing for %d seconds due to disconnect", delay_seconds) + + def should_pause(self) -> bool: + if self.__main_stop_event.is_set(): + logger.trace("Main stop event is set, not checking pause state") + return False + if self._pause_until is None: + return False + + now = datetime.now(UTC) + if now < self._pause_until: + remaining = (self._pause_until - now).total_seconds() + if remaining > 10: # Only log if more than 10 seconds remaining + logger.debug("Backpressure active: pausing publishing for %.1f more seconds", remaining) + return True + + # Reset pause state + self._pause_until = None + logger.info("Backpressure released, resuming publishing") + return False + + def clear(self): + if self._pause_until is not None: + logger.info("Clearing backpressure pause") + self._pause_until = None + self._consecutive_quota_exceeded = 0 diff --git a/tb_mqtt_client/common/rate_limit/rate_limit.py b/tb_mqtt_client/common/rate_limit/rate_limit.py new file mode 100644 index 0000000..b597eff --- /dev/null +++ b/tb_mqtt_client/common/rate_limit/rate_limit.py @@ -0,0 +1,227 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +import os +from asyncio import Lock +from time import monotonic + +logger = logging.getLogger(__name__) + +DEFAULT_TIMEOUT = 5 + +try: + DEFAULT_RATE_LIMIT_PERCENTAGE = int(os.getenv('TB_DEFAULT_RATE_LIMIT_PERCENTAGE', 80)) +except ValueError: + logger.warning("Invalid TB_DEFAULT_RATE_LIMIT_PERCENTAGE, falling back to 80%%") + DEFAULT_RATE_LIMIT_PERCENTAGE = 80 + + +class GreedyTokenBucket: + def __init__(self, capacity, duration_sec): + self.capacity = float(capacity) + self.duration = float(duration_sec) + self.tokens = float(capacity) + self.last_updated = monotonic() + + def refill(self): + now = monotonic() + elapsed = now - self.last_updated + rate = self.capacity / self.duration + self.tokens = min(self.capacity, self.tokens + elapsed * rate) + self.last_updated = now + + def can_consume(self, amount=1): + self.refill() + return round(self.tokens, 6) >= round(amount, 6) + + def consume(self, amount=1): + self.refill() + if self.tokens >= amount: + self.tokens -= amount + return True + return False + + def get_remaining_tokens(self): + self.refill() + return self.tokens + + +class RateLimit: + def __init__(self, rate_limit: str, name: str = None, percentage: int = DEFAULT_RATE_LIMIT_PERCENTAGE): + self.name = name + self.percentage = percentage + self._no_limit = False + self._rate_buckets = {} + self._lock = Lock() + self._minimal_timeout = DEFAULT_TIMEOUT + self._minimal_limit = float('inf') + self.__reached_index = 0 + self.__reached_index_time = 0 + self.__required_tokens = None + self._required_tokens_duration = None + self.required_tokens_ready = asyncio.Event() + + self._parse_string(rate_limit) + + def _parse_string(self, rate_limit: str): + if not rate_limit or rate_limit.strip().removesuffix(',') in ("0:0", ""): + self._no_limit = True + return + + entries = rate_limit.replace(";", ",").split(",") + for entry in entries: + if not entry.strip(): + continue + try: + limit_str, dur_str = entry.strip().split(":") + limit = int(int(limit_str) * self.percentage / 100) + duration = int(dur_str) + + bucket = GreedyTokenBucket(limit, duration) + self._rate_buckets[duration] = bucket + self._minimal_limit = min(self._minimal_limit, limit) + self._minimal_timeout = min(self._minimal_timeout, duration + 1) + except Exception as e: + logger.warning("Invalid rate limit format '%s': %s", entry, e) + + self._no_limit = not bool(self._rate_buckets) + + async def check_limit_reached(self, amount=1): + if self._no_limit: + return False + + async with self._lock: + result = False + for dur, bucket in self._rate_buckets.items(): + bucket.refill() + if not result and bucket.tokens < amount: + result = (bucket.capacity, dur) + return result + + async def refill(self): + """Force refill of all token buckets without consuming any tokens.""" + if self._no_limit: + return + async with self._lock: + for duration, bucket in self._rate_buckets.items(): + bucket.refill() + if duration == self._required_tokens_duration: + if bucket.tokens >= self.__required_tokens: + self.required_tokens_ready.set() + self._required_tokens_duration = None + self.__required_tokens = None + + async def try_consume(self, amount=1): + """ + Try to consume tokens from all buckets. + Returns True if all buckets had enough tokens and they were consumed. + Returns False if any bucket didn't have enough tokens. + """ + if self._no_limit: + return None + + async with self._lock: + for bucket in self._rate_buckets.values(): + bucket.refill() + if bucket.tokens < amount: + return bucket.capacity, bucket.duration + + for bucket in self._rate_buckets.values(): + bucket.tokens -= amount + + return None + + async def consume(self, amount=1): + if self._no_limit: + return + async with self._lock: + for bucket in self._rate_buckets.values(): + bucket.consume(amount) + + @property + def minimal_limit(self): + return self._minimal_limit if self.has_limit() else 0 + + @property + def minimal_timeout(self): + return self._minimal_timeout if self.has_limit() else 0 + + def has_limit(self): + return not self._no_limit + + async def reach_limit(self): + if self._no_limit: + return None + + async with self._lock: + durations = sorted(self._rate_buckets.keys()) + now = monotonic() + + if self.__reached_index_time >= now - self._rate_buckets[durations[-1]].duration: + self.__reached_index = 0 + self.__reached_index_time = now + + if self.__reached_index >= len(durations): + self.__reached_index = 0 + self.__reached_index_time = now + + dur = durations[self.__reached_index] + self._rate_buckets[dur].tokens = 0.0 + self.__reached_index += 1 + + logger.info("Rate limit reached for \"%s\". Cooldown for %s seconds", self.name, dur) + return self.__reached_index, self.__reached_index_time, dur + + def to_dict(self): + return { + "name": self.name, + "percentage": self.percentage, + "no_limit": self._no_limit, + "rateLimits": { + str(dur): { + "capacity": b.capacity, + "tokens": b.get_remaining_tokens(), + "last_updated": b.last_updated + } for dur, b in self._rate_buckets.items() + } + } + + async def set_limit(self, rate_limit: str, percentage: int = DEFAULT_RATE_LIMIT_PERCENTAGE): + async with self._lock: + self._rate_buckets.clear() + self._minimal_timeout = DEFAULT_TIMEOUT + self._minimal_limit = float('inf') + self.percentage = percentage + self._parse_string(rate_limit) + + def set_required_tokens(self, duration, expected_tokens): + """ + Set the required tokens for next message processing. + After this call, `required_tokens_ready` event will be set when + the required tokens are available in the rate limit buckets. + """ + self.__required_tokens = expected_tokens + self._required_tokens_duration = duration + + def clear_required_tokens_event(self): + """ + Clear the required tokens ready event. + This should be called when message processing is done + """ + self.required_tokens_ready.clear() + + +EMPTY_RATE_LIMIT = RateLimit("0:0,", name="Empty Rate Limit") diff --git a/tb_mqtt_client/common/rate_limit/rate_limiter.py b/tb_mqtt_client/common/rate_limit/rate_limiter.py new file mode 100644 index 0000000..7fe75e0 --- /dev/null +++ b/tb_mqtt_client/common/rate_limit/rate_limiter.py @@ -0,0 +1,32 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit + + +@dataclass() +class RateLimiter: + message_rate_limit: RateLimit + telemetry_message_rate_limit: RateLimit + telemetry_datapoints_rate_limit: RateLimit + + def values(self): + return [self.message_rate_limit, self.telemetry_message_rate_limit, self.telemetry_datapoints_rate_limit] + + def __repr__(self): + return f"RateLimiter(message_rate_limit={self.message_rate_limit}, " \ + f"telemetry_message_rate_limit={self.telemetry_message_rate_limit}, " \ + f"telemetry_datapoints_rate_limit={self.telemetry_datapoints_rate_limit})" diff --git a/tb_mqtt_client/common/request_id_generator.py b/tb_mqtt_client/common/request_id_generator.py new file mode 100644 index 0000000..a5df480 --- /dev/null +++ b/tb_mqtt_client/common/request_id_generator.py @@ -0,0 +1,69 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from asyncio import Lock + + +class RPCRequestIdProducer: + """ + Singleton-style producer of unique RPC request IDs, + safe for async environments and shared across all SDK services. + """ + + _counter: int = 1 + _lock: Lock = Lock() + + @classmethod + async def get_next(cls) -> int: + """ + Atomically increment and return the next request ID. + """ + async with cls._lock: + current = cls._counter + cls._counter += 1 + return current + + @classmethod + def reset(cls): + """ + Reset the global request ID counter (usually on disconnect). + """ + cls._counter = 1 + + +class AttributeRequestIdProducer: + """ + Singleton-style producer of unique attribute request IDs, + safe for async environments and shared across all SDK services. + """ + + _counter: int = 1 + _lock: Lock = Lock() + + @classmethod + async def get_next(cls) -> int: + """ + Atomically increment and return the next attribute request ID. + """ + async with cls._lock: + current = cls._counter + cls._counter += 1 + return current + + @classmethod + def reset(cls): + """ + Reset the global attribute request ID counter (usually on disconnect). + """ + cls._counter = 1 diff --git a/tb_mqtt_client/constants/__init__.py b/tb_mqtt_client/constants/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tb_mqtt_client/constants/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tb_mqtt_client/constants/firmware.py b/tb_mqtt_client/constants/firmware.py new file mode 100644 index 0000000..cdc03b4 --- /dev/null +++ b/tb_mqtt_client/constants/firmware.py @@ -0,0 +1,36 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + + +class FirmwareStates(Enum): + IDLE = 'IDLE' + DOWNLOADING = 'DOWNLOADING' + DOWNLOADED = 'DOWNLOADED' + VERIFIED = 'VERIFIED' + FAILED = 'FAILED' + UPDATING = 'UPDATING' + UPDATED = 'UPDATED' + + +FW_TITLE_ATTR = "fw_title" +FW_VERSION_ATTR = "fw_version" +FW_CHECKSUM_ATTR = "fw_checksum" +FW_CHECKSUM_ALG_ATTR = "fw_checksum_algorithm" +FW_SIZE_ATTR = "fw_size" +FW_STATE_ATTR = "fw_state" + +REQUIRED_SHARED_KEYS = [FW_CHECKSUM_ATTR, FW_CHECKSUM_ALG_ATTR, + FW_SIZE_ATTR, FW_TITLE_ATTR, FW_VERSION_ATTR] diff --git a/tb_mqtt_client/constants/json_typing.py b/tb_mqtt_client/constants/json_typing.py new file mode 100644 index 0000000..e66b8fb --- /dev/null +++ b/tb_mqtt_client/constants/json_typing.py @@ -0,0 +1,44 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union, List, Dict + +JSONPrimitive = Union[str, int, float, bool, None] +JSONCompatibleType = Union[JSONPrimitive, List["JSONType"], Dict[str, "JSONType"]] + + +def validate_json_compatibility(value: object) -> None: + """ + Validates that the input value is fully JSON-compatible in structure and type. + Raises a ValueError on the first incompatible type encountered. + """ + stack = [(value, "$")] + basic_types = (str, int, float, bool, type(None)) + + while stack: + current, path = stack.pop() + t = type(current) + + if t in basic_types: + continue + if t is list: + for idx, item in enumerate(current): # type: ignore[arg-type] + stack.append((item, f"{path}[{idx}]")) + elif t is dict: + for k, v in current.items(): # type: ignore[union-attr] + if type(k) is not str: + raise ValueError(f"Invalid JSON key at {path}: expected str, got {type(k).__name__} ({k!r})") + stack.append((v, f"{path}.{k}")) + else: + raise ValueError(f"Invalid JSON value at {path}: unsupported - {type(current).__name__} ({current!r})") diff --git a/tb_mqtt_client/constants/mqtt_topics.py b/tb_mqtt_client/constants/mqtt_topics.py new file mode 100644 index 0000000..aca0cca --- /dev/null +++ b/tb_mqtt_client/constants/mqtt_topics.py @@ -0,0 +1,86 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +WILDCARD = "+" +REQUEST_TOPIC_SUFFIX = "/request" +RESPONSE_TOPIC_SUFFIX = "/response" + +# V1 Topics for Device API +DEVICE_TELEMETRY_TOPIC = "v1/devices/me/telemetry" +DEVICE_ATTRIBUTES_TOPIC = "v1/devices/me/attributes" +DEVICE_ATTRIBUTES_REQUEST_TOPIC = DEVICE_ATTRIBUTES_TOPIC + REQUEST_TOPIC_SUFFIX + "/" + "{request_id}" +DEVICE_ATTRIBUTES_RESPONSE_TOPIC = DEVICE_ATTRIBUTES_TOPIC + RESPONSE_TOPIC_SUFFIX + "/" + WILDCARD +# Device RPC topics +DEVICE_RPC_TOPIC = "v1/devices/me/rpc" +DEVICE_RPC_REQUEST_TOPIC = DEVICE_RPC_TOPIC + REQUEST_TOPIC_SUFFIX + "/" +DEVICE_RPC_RESPONSE_TOPIC = DEVICE_RPC_TOPIC + RESPONSE_TOPIC_SUFFIX + "/" +# Device RPC topics for subscription +DEVICE_RPC_REQUEST_TOPIC_FOR_SUBSCRIPTION = DEVICE_RPC_TOPIC + REQUEST_TOPIC_SUFFIX + "/" + WILDCARD +DEVICE_RPC_RESPONSE_TOPIC_FOR_SUBSCRIPTION = DEVICE_RPC_TOPIC + RESPONSE_TOPIC_SUFFIX + "/" + WILDCARD +# Device Claim topic +DEVICE_CLAIM_TOPIC = "v1/devices/me/claim" +# Device Provisioning topics +PROVISION_REQUEST_TOPIC = "/provision/request" +PROVISION_RESPONSE_TOPIC = "/provision/response" +# Device Firmware Update topics +DEVICE_FIRMWARE_UPDATE_RESPONSE_TOPIC = "v2/fw/response/+/chunk/+" +DEVICE_FIRMWARE_UPDATE_REQUEST_TOPIC = "v2/fw/request/{request_id}/chunk/{current_chunk}" + +# V1 Topics for Gateway API +BASE_GATEWAY_TOPIC = "v1/gateway" +GATEWAY_CONNECT_TOPIC = BASE_GATEWAY_TOPIC + "/connect" +GATEWAY_DISCONNECT_TOPIC = BASE_GATEWAY_TOPIC + "/disconnect" +GATEWAY_TELEMETRY_TOPIC = BASE_GATEWAY_TOPIC + "/telemetry" +GATEWAY_ATTRIBUTES_TOPIC = BASE_GATEWAY_TOPIC + "/attributes" +GATEWAY_ATTRIBUTES_REQUEST_TOPIC = GATEWAY_ATTRIBUTES_TOPIC + REQUEST_TOPIC_SUFFIX +GATEWAY_ATTRIBUTES_RESPONSE_TOPIC = GATEWAY_ATTRIBUTES_TOPIC + RESPONSE_TOPIC_SUFFIX +GATEWAY_RPC_TOPIC = BASE_GATEWAY_TOPIC + "/rpc" +GATEWAY_CLAIM_TOPIC = BASE_GATEWAY_TOPIC + "/claim" + +TOPICS_WITH_DATAPOINTS_CHECK = frozenset({ + DEVICE_TELEMETRY_TOPIC, + DEVICE_ATTRIBUTES_TOPIC, + GATEWAY_TELEMETRY_TOPIC, + GATEWAY_ATTRIBUTES_TOPIC +}) + +GATEWAY_TOPICS = frozenset({ + GATEWAY_CONNECT_TOPIC, + GATEWAY_DISCONNECT_TOPIC, + GATEWAY_TELEMETRY_TOPIC, + GATEWAY_ATTRIBUTES_TOPIC, + GATEWAY_ATTRIBUTES_REQUEST_TOPIC, + GATEWAY_ATTRIBUTES_RESPONSE_TOPIC, + GATEWAY_RPC_TOPIC, + GATEWAY_CLAIM_TOPIC +}) + + +# Topic Builders + + +def build_device_attributes_request_topic(request_id: int) -> str: + return DEVICE_ATTRIBUTES_REQUEST_TOPIC.format(request_id=request_id) + + +def build_device_rpc_request_topic(request_id: int) -> str: + return DEVICE_RPC_REQUEST_TOPIC + str(request_id) + + +def build_device_rpc_response_topic(request_id: int) -> str: + return DEVICE_RPC_RESPONSE_TOPIC + str(request_id) + + +def build_firmware_update_request_topic(request_id: int, current_chunk: int) -> str: + return DEVICE_FIRMWARE_UPDATE_REQUEST_TOPIC.format(request_id=request_id, current_chunk=current_chunk) diff --git a/tb_mqtt_client/constants/provisioning.py b/tb_mqtt_client/constants/provisioning.py new file mode 100644 index 0000000..a8cf8c8 --- /dev/null +++ b/tb_mqtt_client/constants/provisioning.py @@ -0,0 +1,29 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + + +class ProvisioningResponseStatus(Enum): + SUCCESS = "SUCCESS" + ERROR = "FAILURE" + + def __str__(self): + return self.value + + +class ProvisioningCredentialsType(Enum): + ACCESS_TOKEN = "ACCESS_TOKEN" + MQTT_BASIC = "MQTT_BASIC" + X509_CERTIFICATE = "X509_CERTIFICATE" diff --git a/tb_mqtt_client/constants/service_keys.py b/tb_mqtt_client/constants/service_keys.py new file mode 100644 index 0000000..e5351a8 --- /dev/null +++ b/tb_mqtt_client/constants/service_keys.py @@ -0,0 +1,16 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +TELEMETRY_TIMESTAMP_PARAMETER = "ts" +TELEMETRY_VALUES_PARAMETER = "values" diff --git a/tb_mqtt_client/entities/__init__.py b/tb_mqtt_client/entities/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tb_mqtt_client/entities/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tb_mqtt_client/entities/data/__init__.py b/tb_mqtt_client/entities/data/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tb_mqtt_client/entities/data/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tb_mqtt_client/entities/data/attribute_entry.py b/tb_mqtt_client/entities/data/attribute_entry.py new file mode 100644 index 0000000..214d83a --- /dev/null +++ b/tb_mqtt_client/entities/data/attribute_entry.py @@ -0,0 +1,35 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tb_mqtt_client.constants.json_typing import JSONCompatibleType +from tb_mqtt_client.entities.data.data_entry import DataEntry + + +class AttributeEntry(DataEntry): + def __init__(self, key: str, value: JSONCompatibleType): + super().__init__(key, value) + + def __repr__(self): + return f"AttributeEntry(key={self.key}, value={self.value})" + + def as_dict(self) -> dict: + return { + "key": self.key, + "value": self.value + } + + def __eq__(self, other): + if not isinstance(other, AttributeEntry): + return False + return self.key == other.key and self.value == other.value diff --git a/tb_mqtt_client/entities/data/attribute_request.py b/tb_mqtt_client/entities/data/attribute_request.py new file mode 100644 index 0000000..5434312 --- /dev/null +++ b/tb_mqtt_client/entities/data/attribute_request.py @@ -0,0 +1,67 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, List, Dict + +from tb_mqtt_client.common.request_id_generator import AttributeRequestIdProducer +from tb_mqtt_client.constants.json_typing import validate_json_compatibility +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType + + +@dataclass(slots=True, frozen=True) +class AttributeRequest(BaseGatewayEvent): + """ + Represents a request for device attributes, with optional client and shared attribute keys. + Automatically assigns a unique request ID via the build() method. + """ + request_id: int + shared_keys: Optional[List[str]] = None + client_keys: Optional[List[str]] = None + event_type: GatewayEventType = GatewayEventType.DEVICE_ATTRIBUTE_REQUEST + + def __new__(self, *args, **kwargs): + raise TypeError( + "Direct instantiation of AttributeRequest is not allowed. Use 'await AttributeRequest.build(...)'.") + + def __repr__(self) -> str: + return f"AttributeRequest(id={self.request_id}, shared_keys={self.shared_keys}, client_keys={self.client_keys})" + + @classmethod + async def build(cls, shared_keys: Optional[List[str]] = None, + client_keys: Optional[List[str]] = None) -> 'AttributeRequest': + """ + Build a new AttributeRequest with a unique request ID, using the global ID generator. + """ + validate_json_compatibility(shared_keys) + validate_json_compatibility(client_keys) + request_id = await AttributeRequestIdProducer.get_next() + self = object.__new__(cls) + object.__setattr__(self, 'request_id', request_id) + object.__setattr__(self, 'shared_keys', shared_keys) + object.__setattr__(self, 'client_keys', client_keys) + object.__setattr__(self, 'event_type', GatewayEventType.DEVICE_ATTRIBUTE_REQUEST) + return self + + def to_payload_format(self) -> Dict[str, str]: + """ + Convert the attribute request into the expected MQTT payload format. + """ + payload = {} + if self.shared_keys is not None: + payload["sharedKeys"] = ','.join(self.shared_keys) + if self.client_keys is not None: + payload["clientKeys"] = ','.join(self.client_keys) + return payload diff --git a/tb_mqtt_client/entities/data/attribute_update.py b/tb_mqtt_client/entities/data/attribute_update.py new file mode 100644 index 0000000..336be4a --- /dev/null +++ b/tb_mqtt_client/entities/data/attribute_update.py @@ -0,0 +1,57 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, Any, List + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType + + +@dataclass(slots=True) +class AttributeUpdate(BaseGatewayEvent): + entries: List[AttributeEntry] + event_type: GatewayEventType = GatewayEventType.DEVICE_ATTRIBUTE_UPDATE + + def __repr__(self): + return f"AttributeUpdate(entries={self.entries})" + + def get(self, key: str, default=None): + for entry in self.entries: + if entry.key == key: + return entry.value + return default + + def keys(self): + return [entry.key for entry in self.entries] + + def values(self): + return [entry.value for entry in self.entries] + + def items(self): + return [(entry.key, entry.value) for entry in self.entries] + + def as_dict(self) -> Dict[str, Any]: + return {entry.key: entry.value for entry in self.entries} + + @classmethod + def _deserialize_from_dict(cls, data: Dict[str, Any]) -> 'AttributeUpdate': + """ + Deserialize dictionary into AttributeUpdate object. + :param data: Dictionary of attribute key-value pairs. + :return: AttributeUpdate instance. + """ + entries = [AttributeEntry(k, v) for k, v in data.items()] + return cls(entries=entries) diff --git a/tb_mqtt_client/entities/data/claim_request.py b/tb_mqtt_client/entities/data/claim_request.py new file mode 100644 index 0000000..27fbb43 --- /dev/null +++ b/tb_mqtt_client/entities/data/claim_request.py @@ -0,0 +1,58 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Dict, Any + + +@dataclass(slots=True, frozen=True) +class ClaimRequest: + """ + Represents a device claim request, as per ThingsBoard MQTT API. + Optionally includes a secret key and a duration (in seconds) for the claim. + Does not include request_id, as it's not required for this operation. + """ + secret_key: Optional[str] = None + duration: Optional[int] = None # seconds + + def __new__(cls, *args, **kwargs): + raise TypeError("Direct instantiation of ClaimRequest is not allowed. Use 'ClaimRequest.build(...)'.") + + def __repr__(self) -> str: + return f"ClaimRequest(secret_key={self.secret_key}, duration={self.duration})" + + @classmethod + def build(cls, secret_key: str, duration: int = 60000) -> 'ClaimRequest': + """ + Safely construct a new ClaimRequest instance. + """ + if not isinstance(secret_key, str): + raise ValueError("Secret key must be a string") + if not isinstance(duration, int) or duration < 0: + raise ValueError("Duration must be a non-negative integer representing seconds") + self = object.__new__(cls) + object.__setattr__(self, 'secret_key', secret_key) + object.__setattr__(self, 'duration', duration) + return self + + def to_payload_format(self) -> Dict[str, Any]: + """ + Convert the claim request to the expected MQTT JSON payload format. + """ + payload = {} + if self.secret_key is not None: + payload["secretKey"] = self.secret_key + if self.duration is not None: + payload["durationMs"] = int(self.duration * 1000) + return payload diff --git a/tb_mqtt_client/entities/data/data_entry.py b/tb_mqtt_client/entities/data/data_entry.py new file mode 100644 index 0000000..73df046 --- /dev/null +++ b/tb_mqtt_client/entities/data/data_entry.py @@ -0,0 +1,69 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional + +from orjson import dumps + +from tb_mqtt_client.constants.json_typing import JSONCompatibleType, validate_json_compatibility + + +class DataEntry: + __slots__ = ("__key", "__value", "__ts", "__size") + + def __init__(self, key: str, value: JSONCompatibleType, ts: Optional[int] = None): + validate_json_compatibility(value) + self.__key = key + self.__value = value + self.__ts = ts + self.__size = self.__estimate_size() + + def __repr__(self): + return f"DataEntry(key={self.key}, value={self.value}, ts={self.ts})" + + def __estimate_size(self) -> int: + if self.__ts is not None: + return len(dumps({"ts": self.__ts, "values": {self.__key: self.__value}})) + return len(dumps({self.__key: self.__value})) + + @property + def size(self) -> int: + return self.__size + + @property + def key(self) -> str: + return self.__key + + @key.setter + def key(self, value: str): + self.__key = value + self.__size = self.__estimate_size() + + @property + def value(self) -> Any: + return self.__value + + @value.setter + def value(self, value: Any): + self.__value = value + self.__size = self.__estimate_size() + + @property + def ts(self) -> Optional[int]: + return self.__ts + + @ts.setter + def ts(self, value: Optional[int]): + self.__ts = value + self.__size = self.__estimate_size() diff --git a/tb_mqtt_client/entities/data/device_uplink_message.py b/tb_mqtt_client/entities/data/device_uplink_message.py new file mode 100644 index 0000000..1360d4b --- /dev/null +++ b/tb_mqtt_client/entities/data/device_uplink_message.py @@ -0,0 +1,184 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from dataclasses import dataclass +from types import MappingProxyType +from typing import List, Optional, Union, OrderedDict, Tuple, Mapping +from uuid import uuid4 + +from tb_mqtt_client.common.logging_utils import get_logger, TRACE_LEVEL +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry + +logger = get_logger(__name__) + +DEFAULT_FIELDS_SIZE = len('{"device_name":"","device_profile":"","attributes":"","timeseries":""}'.encode('utf-8')) + + +@dataclass(slots=True, frozen=True) +class DeviceUplinkMessage: + device_name: Optional[str] + device_profile: Optional[str] + attributes: Tuple[AttributeEntry] + timeseries: Mapping[int, Tuple[TimeseriesEntry]] + delivery_futures: List[Optional[asyncio.Future[PublishResult]]] + _size: int + main_ts: Optional[int] = None + + def __new__(cls, *args, **kwargs): + raise TypeError( + "Direct instantiation of DeviceUplinkMessage is not allowed. " + "Use DeviceUplinkMessageBuilder to construct instances.") + + def __repr__(self): + return (f"DeviceUplinkMessage(device_name={self.device_name}, " + f"device_profile={self.device_profile}, " + f"attributes={self.attributes}, " + f"timeseries={self.timeseries}, " + f"delivery_futures={self.delivery_futures})") + + @classmethod + def build(cls, + device_name: Optional[str], + device_profile: Optional[str], + attributes: List[AttributeEntry], + timeseries: Mapping[int, List[TimeseriesEntry]], + delivery_futures: List[Optional[asyncio.Future]], + size: int, + main_ts: Optional[int]) -> 'DeviceUplinkMessage': + self = object.__new__(cls) + object.__setattr__(self, 'device_name', device_name) + object.__setattr__(self, 'device_profile', device_profile) + object.__setattr__(self, 'attributes', tuple(attributes)) + object.__setattr__(self, 'timeseries', + MappingProxyType({ts: tuple(entries) for ts, entries in timeseries.items()})) + object.__setattr__(self, 'delivery_futures', tuple(delivery_futures)) + object.__setattr__(self, '_size', size) + object.__setattr__(self, 'main_ts', main_ts) + return self + + @property + def size(self) -> int: + return self._size + + def timeseries_datapoint_count(self) -> int: + return sum(len(entries) for entries in self.timeseries.values()) + + def attributes_datapoint_count(self) -> int: + return len(self.attributes) + + def has_attributes(self) -> bool: + return bool(self.attributes) + + def has_timeseries(self) -> bool: + return bool(self.timeseries) + + def get_delivery_futures(self) -> List[Optional[asyncio.Future[PublishResult]]]: + return self.delivery_futures + + def set_main_ts(self, main_ts: int) -> 'DeviceUplinkMessage': + """ + Set the main timestamp for the message. + :param main_ts: The main timestamp to set. + :return: The updated DeviceUplinkMessage instance. + """ + object.__setattr__(self, 'main_ts', main_ts) + return self + + +class DeviceUplinkMessageBuilder: + def __init__(self): + self._device_name: Optional[str] = None + self._device_profile: Optional[str] = None + self._attributes: List[AttributeEntry] = [] + self._timeseries: OrderedDict[int, List[TimeseriesEntry]] = OrderedDict() + self._delivery_futures: List[Optional[asyncio.Future[PublishResult]]] = [] + self.__size = DEFAULT_FIELDS_SIZE + self._main_ts: Optional[int] = None + + def set_device_name(self, device_name: str) -> 'DeviceUplinkMessageBuilder': + self._device_name = device_name + if device_name is not None: + self.__size += len(device_name) + return self + + def set_device_profile(self, profile: str) -> 'DeviceUplinkMessageBuilder': + self._device_profile = profile + if profile is not None: + self.__size += len(profile) + return self + + def add_attributes(self, attributes: Union[AttributeEntry, List[AttributeEntry]]) -> 'DeviceUplinkMessageBuilder': + if not isinstance(attributes, list): + attributes = [attributes] + self._attributes.extend(attributes) + for attribute in attributes: + self.__size += attribute.size + return self + + def add_timeseries(self, + timeseries: Union[TimeseriesEntry, + List[TimeseriesEntry], + OrderedDict[int, List[TimeseriesEntry]]]) -> 'DeviceUplinkMessageBuilder': + if isinstance(timeseries, OrderedDict): + self._timeseries = timeseries + return self + if not isinstance(timeseries, list): + timeseries = [timeseries] + for entry in timeseries: + if entry.ts is not None: + if entry.ts in self._timeseries: + self._timeseries[entry.ts].append(entry) + else: + self._timeseries[entry.ts] = [entry] + else: + if 0 in self._timeseries: + self._timeseries[0].append(entry) + else: + self._timeseries[0] = [entry] + for timeseries_entry in timeseries: + self.__size += timeseries_entry.size + return self + + def add_delivery_futures(self, futures: Union[asyncio.Future[PublishResult], + List[asyncio.Future[PublishResult]]]) -> 'DeviceUplinkMessageBuilder': + if not isinstance(futures, list): + futures = [futures] + if futures: + if logger.isEnabledFor(TRACE_LEVEL): + logger.debug("Created delivery futures: %r", [future.uuid for future in futures]) + self._delivery_futures.extend(futures) + return self + + def set_main_ts(self, main_ts: int) -> 'DeviceUplinkMessageBuilder': + self._main_ts = main_ts + return self + + def build(self) -> DeviceUplinkMessage: + if not self._delivery_futures: + delivery_future = asyncio.get_event_loop().create_future() + delivery_future.uuid = uuid4() + logger.trace("No delivery futures provided, creating a default future: %r", delivery_future.uuid) + self._delivery_futures = [delivery_future] + return DeviceUplinkMessage.build( + device_name=self._device_name, + device_profile=self._device_profile, + attributes=self._attributes, + timeseries=self._timeseries, + delivery_futures=self._delivery_futures, + size=self.__size, + main_ts=self._main_ts + ) diff --git a/tb_mqtt_client/entities/data/provisioning_request.py b/tb_mqtt_client/entities/data/provisioning_request.py new file mode 100644 index 0000000..7493e69 --- /dev/null +++ b/tb_mqtt_client/entities/data/provisioning_request.py @@ -0,0 +1,78 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Optional + +from tb_mqtt_client.constants.provisioning import ProvisioningCredentialsType + + +class ProvisioningCredentials(ABC): + @abstractmethod + def __init__(self, provision_device_key: str, provision_device_secret: str): + self.provision_device_key = provision_device_key + self.provision_device_secret = provision_device_secret + self.credentials_type: ProvisioningCredentialsType + + +class ProvisioningRequest: + def __init__(self, host, credentials: ProvisioningCredentials, port: int = 1883, + device_name: Optional[str] = None, gateway: Optional[bool] = False): + self.host = host + self.port = port + self.credentials = credentials + self.device_name = device_name + self.gateway = gateway + + +class AccessTokenProvisioningCredentials(ProvisioningCredentials): + def __init__(self, provision_device_key: str, provision_device_secret: str, access_token: Optional[str] = None): + super().__init__(provision_device_key, provision_device_secret) + self.access_token = access_token + self.credentials_type = ProvisioningCredentialsType.ACCESS_TOKEN + + +class BasicProvisioningCredentials(ProvisioningCredentials): + def __init__(self, provision_device_key, provision_device_secret, + client_id: Optional[str] = None, username: Optional[str] = None, password: Optional[str] = None): + super().__init__(provision_device_key, provision_device_secret) + self.client_id = client_id + self.username = username + self.password = password + self.credentials_type = ProvisioningCredentialsType.MQTT_BASIC + + +class X509ProvisioningCredentials(ProvisioningCredentials): + def __init__(self, provision_device_key, provision_device_secret, + private_key_path: str, public_cert_path: str, ca_cert_path: str): + super().__init__(provision_device_key, provision_device_secret) + self.private_key_path = private_key_path + self.ca_cert_path = ca_cert_path + self.public_cert_path = public_cert_path + self.public_cert = self._load_public_cert_path(public_cert_path) + self.credentials_type = ProvisioningCredentialsType.X509_CERTIFICATE + + @staticmethod + def _load_public_cert_path(public_cert_path): + content = '' + + try: + with open(public_cert_path, 'r') as file: + content = file.read() + except FileNotFoundError: + raise FileNotFoundError(f"Public certificate file not found: {public_cert_path}") + except IOError as e: + raise IOError(f"Error reading public certificate file {public_cert_path}: {e}") + + return content.strip() if content else None diff --git a/tb_mqtt_client/entities/data/provisioning_response.py b/tb_mqtt_client/entities/data/provisioning_response.py new file mode 100644 index 0000000..fc82f4e --- /dev/null +++ b/tb_mqtt_client/entities/data/provisioning_response.py @@ -0,0 +1,74 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional + +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.constants.provisioning import ProvisioningResponseStatus +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest, ProvisioningCredentialsType + + +@dataclass(frozen=True) +class ProvisioningResponse: + status: ProvisioningResponseStatus + result: Optional[DeviceConfig] = None + error: Optional[str] = None + + def __new__(cls, *args, **kwargs): + raise TypeError( + "Direct instantiation of ProvisioningResponse is not allowed. Use ProvisioningResponse.build(result, error).") # noqa + + def __repr__(self) -> str: + return f"ProvisioningResponse(status={self.status}, result={self.result}, error={self.error})" + + @classmethod + def build(cls, provision_request: 'ProvisioningRequest', payload: dict) -> 'ProvisioningResponse': + """ + Constructs a ProvisioningResponse explicitly. + """ + self = object.__new__(cls) + + if payload.get('status') == ProvisioningResponseStatus.SUCCESS.value: + device_config = ProvisioningResponse._build_device_config(provision_request, payload) + + object.__setattr__(self, 'result', device_config) + object.__setattr__(self, 'status', ProvisioningResponseStatus.SUCCESS) + object.__setattr__(self, 'error', None) + else: + object.__setattr__(self, 'error', payload.get('errorMsg')) + object.__setattr__(self, 'status', ProvisioningResponseStatus.ERROR) + object.__setattr__(self, 'result', None) + + return self + + @staticmethod + def _build_device_config(provision_request: 'ProvisioningRequest', payload: dict): + device_config = DeviceConfig() + device_config.host = provision_request.host + device_config.port = provision_request.port + + if provision_request.credentials.credentials_type is None or \ + provision_request.credentials.credentials_type == ProvisioningCredentialsType.ACCESS_TOKEN: + device_config.access_token = payload['credentialsValue'] + elif provision_request.credentials.credentials_type == ProvisioningCredentialsType.MQTT_BASIC: + device_config.client_id = payload['credentialsValue']['clientId'] + device_config.username = payload['credentialsValue']['userName'] + device_config.password = payload['credentialsValue']['password'] + elif provision_request.credentials.credentials_type == ProvisioningCredentialsType.X509_CERTIFICATE: + device_config.ca_cert = provision_request.credentials.ca_cert_path + device_config.client_cert = provision_request.credentials.public_cert_path + device_config.private_key = provision_request.credentials.private_key_path + + return device_config diff --git a/tb_mqtt_client/entities/data/requested_attribute_response.py b/tb_mqtt_client/entities/data/requested_attribute_response.py new file mode 100644 index 0000000..9678888 --- /dev/null +++ b/tb_mqtt_client/entities/data/requested_attribute_response.py @@ -0,0 +1,92 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, Any, List + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry + + +@dataclass(slots=True, frozen=True) +class RequestedAttributeResponse: + request_id: int + shared: List[AttributeEntry] + client: List[AttributeEntry] + + def __repr__(self): + return f"RequestedAttributeResponse(request_id={self.request_id}, shared={self.shared}, client={self.client})" + + def __getitem__(self, item): + """ + Allows access to values using dictionary-like syntax. + """ + for entry in self.shared: + if entry.key == item: + return entry.value + for entry in self.client: + if entry.key == item: + return entry.value + raise KeyError(f"Key '{item}' not found in shared or client attributes.") + + def shared_keys(self): + return [entry.key for entry in self.shared] + + def client_keys(self): + return [entry.key for entry in self.client] + + def get_shared(self, key: str, default=None): + """ + Get the value of a shared attribute by key. + :param key: The key of the shared attribute. + :param default: Default value if the key is not found. + :return: Value of the shared attribute or default. + """ + for entry in self.shared: + if entry.key == key: + return entry.value + return default + + def get_client(self, key: str, default=None): + """ + Get the value of a client attribute by key. + :param key: The key of the client attribute. + :param default: Default value if the key is not found. + :return: Value of the client attribute or default. + """ + for entry in self.client: + if entry.key == key: + return entry.value + return default + + def as_dict(self) -> Dict[str, Any]: + """ + Convert the AttributeResponse to a dictionary format. + :return: Dictionary representation of the response. + """ + return { + 'shared': [entry.as_dict() for entry in self.shared], + 'client': [entry.as_dict() for entry in self.client] + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'RequestedAttributeResponse': + """ + Deserialize dictionary into AttributeResponse object. + :param data: Dictionary containing 'shared' and 'client' attributes. + :return: AttributeResponse instance. + """ + shared = [AttributeEntry(k, v) for k, v in data.get('shared', {}).items()] + client = [AttributeEntry(k, v) for k, v in data.get('client', {}).items()] + request_id = data.get('request_id', -1) # Default to -1 if not provided + return cls(shared=shared, client=client, request_id=request_id) diff --git a/tb_mqtt_client/entities/data/rpc_request.py b/tb_mqtt_client/entities/data/rpc_request.py new file mode 100644 index 0000000..9300399 --- /dev/null +++ b/tb_mqtt_client/entities/data/rpc_request.py @@ -0,0 +1,80 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Union, Optional, Dict, Any + +from tb_mqtt_client.common.request_id_generator import RPCRequestIdProducer +from tb_mqtt_client.constants.json_typing import validate_json_compatibility, JSONCompatibleType + + +@dataclass(slots=True, frozen=True) +class RPCRequest: + request_id: Union[int, str] + method: str + params: Optional[Any] = None + + def __new__(cls, *args, **kwargs): + raise TypeError("Direct instantiation of RPCRequest is not allowed. Use 'await RPCRequest.build(...)'.") + + def __repr__(self): + return f"RPCRequest(id={self.request_id}, method={self.method}, params={self.params})" + + @classmethod + def _deserialize_from_dict(cls, request_id: int, data: Dict[str, Any]) -> 'RPCRequest': + """ + Constructs an RPCRequest, should be used only for deserialization request from the platform. + """ + if not isinstance(request_id, (int, str)): + raise ValueError("Missing request id in RPC request") + if "method" not in data: + raise ValueError("Missing 'method' in RPC request") + + self = object.__new__(cls) + object.__setattr__(self, 'request_id', request_id) + object.__setattr__(self, 'method', data["method"]) + object.__setattr__(self, 'params', data.get("params")) + return self + + @classmethod + async def build(cls, method: str, params: Optional[JSONCompatibleType] = None) -> 'RPCRequest': + """ + Constructs an RPCRequest with a unique request ID, + using the RPCRequestIdProducer to ensure thread-safe ID generation. + """ + if not isinstance(method, str): + raise ValueError("Method must be a string") + validate_json_compatibility(params) + request_id = await RPCRequestIdProducer.get_next() + self = object.__new__(cls) + object.__setattr__(self, 'request_id', request_id) + object.__setattr__(self, 'method', method) + object.__setattr__(self, 'params', params) + return self + + def to_payload_format(self) -> Dict[str, Any]: + """ + Serializes the RPC request for publishing. + Converts the request to a dictionary format suitable for MQTT payload. + """ + data = { + "id": self.request_id, + "method": self.method + } + if self.params is not None: + data["params"] = self.params + return data + + def __str__(self): + return f"RPCRequest(id={self.request_id}, method={self.method}, params={self.params})" diff --git a/tb_mqtt_client/entities/data/rpc_response.py b/tb_mqtt_client/entities/data/rpc_response.py new file mode 100644 index 0000000..4d9a133 --- /dev/null +++ b/tb_mqtt_client/entities/data/rpc_response.py @@ -0,0 +1,101 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from enum import Enum +from traceback import format_exception +from typing import Union, Optional, Dict, Any + +from tb_mqtt_client.constants.json_typing import validate_json_compatibility, JSONCompatibleType + + +class RPCStatus(Enum): + """ + Enum representing the status of an RPC call. + """ + SUCCESS = "SUCCESS" + ERROR = "ERROR" + TIMEOUT = "TIMEOUT" + NOT_FOUND = "NOT_FOUND" + + def __str__(self): + return self.value + + +@dataclass(slots=True, frozen=True) +class RPCResponse: + """ + Represents a response to the RPC call. + + Attributes: + request_id: Unique identifier of the request being responded to. + result: Optional response payload (Any type allowed). + error: Optional error information if the RPC failed. + """ + request_id: Union[int, str] + status: RPCStatus = None + result: Optional[Any] = None + error: Optional[Union[str, Dict[str, Any]]] = None + + def __new__(cls, *args, **kwargs): + raise TypeError( + "Direct instantiation of RPCResponse is not allowed. Use RPCResponse.build(request_id, result, error).") + + def __repr__(self) -> str: + return f"RPCResponse(request_id={self.request_id}, result={self.result}, error={self.error})" + + @classmethod + def build(cls, request_id: Union[int, str], result: Optional[Any] = None, + error: Optional[Union[str, Dict[str, JSONCompatibleType], BaseException]] = None) -> 'RPCResponse': + """ + Constructs an RPCResponse explicitly. + """ + self = object.__new__(cls) + object.__setattr__(self, 'request_id', request_id) + + if error is not None: + if not isinstance(error, (str, dict, BaseException)): + raise ValueError("Error must be a string, dictionary, or an exception instance") + + object.__setattr__(self, 'status', RPCStatus.ERROR) + + if isinstance(error, BaseException): + try: + raise error + except BaseException as e: + error = { + "message": str(e), + "type": type(e).__name__, + "details": ''.join(format_exception(type(e), e, e.__traceback__)) + } + + validate_json_compatibility(error) + object.__setattr__(self, 'error', error) + + else: + object.__setattr__(self, 'status', RPCStatus.SUCCESS) + object.__setattr__(self, 'error', None) + validate_json_compatibility(result) + + object.__setattr__(self, 'result', result) + return self + + def to_payload_format(self) -> Dict[str, Any]: + """Serializes the RPC response for publishing.""" + data = {} + if self.result is not None: + data["result"] = self.result + if self.error is not None: + data["error"] = self.error + return data diff --git a/tb_mqtt_client/entities/data/timeseries_entry.py b/tb_mqtt_client/entities/data/timeseries_entry.py new file mode 100644 index 0000000..d63f7bf --- /dev/null +++ b/tb_mqtt_client/entities/data/timeseries_entry.py @@ -0,0 +1,40 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from tb_mqtt_client.constants.json_typing import JSONCompatibleType +from tb_mqtt_client.entities.data.data_entry import DataEntry + + +class TimeseriesEntry(DataEntry): + def __init__(self, key: str, value: JSONCompatibleType, ts: Optional[int] = None): + super().__init__(key, value, ts) + + def __repr__(self): + return f"TimeseriesEntry(key={self.key}, value={self.value}, ts={self.ts})" + + def as_dict(self) -> dict: + result = { + "key": self.key, + "value": self.value + } + if self.ts is not None: + result["ts"] = self.ts + return result + + def __eq__(self, other): + if not isinstance(other, TimeseriesEntry): + return False + return self.key == other.key and self.value == other.value and self.ts == other.ts diff --git a/tb_mqtt_client/entities/gateway/__init__.py b/tb_mqtt_client/entities/gateway/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tb_mqtt_client/entities/gateway/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tb_mqtt_client/entities/gateway/base_gateway_event.py b/tb_mqtt_client/entities/gateway/base_gateway_event.py new file mode 100644 index 0000000..63a9679 --- /dev/null +++ b/tb_mqtt_client/entities/gateway/base_gateway_event.py @@ -0,0 +1,35 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType + + +class BaseGatewayEvent: + def __init__(self, event_type: GatewayEventType): + self.__event_type = event_type + self.__device_session = None + + @property + def event_type(self) -> GatewayEventType: + return self.__event_type + + def set_device_session(self, device_session): + self.__device_session = device_session + + @property + def device_session(self): + return self.__device_session + + def __str__(self) -> str: + return self.__repr__() diff --git a/tb_mqtt_client/entities/gateway/device_connect_message.py b/tb_mqtt_client/entities/gateway/device_connect_message.py new file mode 100644 index 0000000..cdec62d --- /dev/null +++ b/tb_mqtt_client/entities/gateway/device_connect_message.py @@ -0,0 +1,59 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict + +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType + + +@dataclass(slots=True, frozen=True) +class DeviceConnectMessage(BaseGatewayEvent): + """ + Represents a device connection message in the ThingsBoard Gateway MQTT client. + This class is used to encapsulate the details of a device connection message. + """ + device_name: str + device_profile: str = 'default' + event_type: GatewayEventType = GatewayEventType.DEVICE_CONNECT + + def __new__(cls, *args, **kwargs): + raise TypeError( + "Direct instantiation of DeviceConnectMessage is not allowed. Use 'await DeviceConnectMessage.build(...)'.") + + def __repr__(self): + return f"DeviceConnectMessage(device_name={self.device_name}, device_profile={self.device_profile})" + + @classmethod + def build(cls, device_name: str, device_profile: str = 'default') -> 'DeviceConnectMessage': + """ + Build a new DeviceConnectMessage with the specified device name and profile. + """ + if not device_name: + raise ValueError("Device name must not be empty.") + self = object.__new__(cls) + object.__setattr__(self, 'device_name', device_name) + object.__setattr__(self, 'device_profile', device_profile) + object.__setattr__(self, 'event_type', GatewayEventType.DEVICE_CONNECT) + return self + + def to_payload_format(self) -> Dict[str, str]: + """ + Convert the device connection message into the expected MQTT payload format. + """ + return { + "device": self.device_name, + "type": self.device_profile + } diff --git a/tb_mqtt_client/entities/gateway/device_disconnect_message.py b/tb_mqtt_client/entities/gateway/device_disconnect_message.py new file mode 100644 index 0000000..2e5bae1 --- /dev/null +++ b/tb_mqtt_client/entities/gateway/device_disconnect_message.py @@ -0,0 +1,57 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import Dict + +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType + + +@dataclass(slots=True, frozen=True) +class DeviceDisconnectMessage(BaseGatewayEvent): + """ + Represents a device disconnection message in the ThingsBoard Gateway MQTT client. + This class is used to encapsulate the details of a device connection message. + """ + device_name: str + event_type: GatewayEventType = GatewayEventType.DEVICE_DISCONNECT + + def __new__(self, *args, **kwargs): + raise TypeError( + "Direct instantiation of DeviceDisconnectMessage is not allowed. Use 'DeviceDisconnectMessage.build(...)'.") + + def __repr__(self): + return f"DeviceDisconnectMessage(device_name={self.device_name})" + + @classmethod + def build(cls, device_name: str) -> 'DeviceDisconnectMessage': + """ + Build a new DeviceDisconnectMessage with the specified device name and profile. + """ + if not device_name: + raise ValueError("Device name must not be empty.") + self = object.__new__(cls) + object.__setattr__(self, 'device_name', device_name) + object.__setattr__(self, 'event_type', GatewayEventType.DEVICE_DISCONNECT) + return self + + def to_payload_format(self) -> Dict[str, str]: + """ + Convert the device connection message into the expected MQTT payload format. + """ + return { + "device": self.device_name + } diff --git a/tb_mqtt_client/entities/gateway/device_info.py b/tb_mqtt_client/entities/gateway/device_info.py new file mode 100644 index 0000000..f013958 --- /dev/null +++ b/tb_mqtt_client/entities/gateway/device_info.py @@ -0,0 +1,90 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uuid +from dataclasses import dataclass, field + + +@dataclass() +class DeviceInfo(object): + device_name: str + device_profile: str + original_name: str = field(init=False) + device_id: uuid.UUID = field(default_factory=uuid.uuid4, init=False) + + _initializing: bool = field(default=True, init=False, repr=False) + + def __post_init__(self): + self.__setattr__("original_name", self.device_name) + self._initializing = False + + def __setattr__(self, key, value): + if not self._initializing: + raise AttributeError( + f"Cannot modify attribute '{key}' of frozen DeviceInfo instance." + "Use rename() method to change device_name.") + else: + super().__setattr__(key, value) + + def rename(self, new_name: str): + if new_name != self.device_name: + super().__setattr__('device_name', new_name) + + @classmethod + def from_dict(cls, data: dict) -> 'DeviceInfo': + original_post_init = cls.__post_init__ + cls.__post_init__ = lambda self: None + instance = cls( + device_name=data['device_name'], + device_profile=data.get('device_profile', 'default') + ) + instance.__setattr__("device_id", uuid.UUID(data['device_id'])) + if 'original_name' in data: + instance.__setattr__("original_name", data['original_name']) + else: + instance.__setattr__("original_name", instance.device_name) + instance._initializing = False + cls.__post_init__ = original_post_init + return instance + + def to_dict(self) -> dict: + return { + "device_name": self.device_name, + "device_profile": self.device_profile, + "device_id": str(self.device_id), + "original_name": self.original_name + } + + def __str__(self) -> str: + return (f"DeviceInfo(device_id={self.device_id}, " + f"device_name={self.device_name}, " + f"device_profile={self.device_profile}, " + f"original_name={self.original_name})") + + def __repr__(self) -> str: + return (f"DeviceInfo(device_id={self.device_id!r}, " + f"device_name={self.device_name!r}, " + f"device_profile={self.device_profile!r}, " + f"original_name={self.original_name!r})") + + def __eq__(self, other): + if not isinstance(other, DeviceInfo): + return NotImplemented + return (self.device_id == other.device_id and + self.device_name == other.device_name and + self.device_profile == other.device_profile and + self.original_name == other.original_name) + + def __hash__(self): + return hash((self.device_id, self.device_name, self.device_profile, self.original_name)) diff --git a/tb_mqtt_client/entities/gateway/device_session_state.py b/tb_mqtt_client/entities/gateway/device_session_state.py new file mode 100644 index 0000000..043eead --- /dev/null +++ b/tb_mqtt_client/entities/gateway/device_session_state.py @@ -0,0 +1,32 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + + +class DeviceSessionState(Enum): + CONNECTED = "CONNECTED" + DISCONNECTED = "DISCONNECTED" + RECONNECTING = "RECONNECTING" + CONNECTION_LOST = "CONNECTION_LOST" + CONNECTION_FAILED = "CONNECTION_FAILED" + + def is_connected(self) -> bool: + """ + Check if the session state indicates a connected state. + + Returns: + bool: True if the session is connected, False otherwise. + """ + return self == DeviceSessionState.CONNECTED diff --git a/tb_mqtt_client/entities/gateway/event_type.py b/tb_mqtt_client/entities/gateway/event_type.py new file mode 100644 index 0000000..0f97890 --- /dev/null +++ b/tb_mqtt_client/entities/gateway/event_type.py @@ -0,0 +1,45 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + + +class GatewayEventType(Enum): + """ + Enum representing different types of gateway events. + Each event type corresponds to a specific action or state change in the gateway. + """ + GATEWAY_CONNECT = "gateway.connect" + GATEWAY_DISCONNECT = "gateway.disconnect" + + DEVICE_CONNECT = "gateway.device.connect" + DEVICE_DISCONNECT = "gateway.device.disconnect" + DEVICE_UPDATE = "gateway.device.update" + + DEVICE_UPLINK = "gateway.device.uplink" + DEVICE_ATTRIBUTE_REQUEST = "gateway.device.attribute.request" + DEVICE_ATTRIBUTE_UPDATE = "gateway.device.attribute.update" + DEVICE_REQUESTED_ATTRIBUTE_RESPONSE = "gateway.device.requested.attribute.response" + + DEVICE_SESSION_STATE_CHANGE = "gateway.device.session.state.change" + DEVICE_RPC_REQUEST = "gateway.device.rpc.request" + DEVICE_RPC_RESPONSE = "gateway.device.rpc.response" + + GATEWAY_CLAIM_REQUEST = "gateway.gateway.claim" + + RPC_REQUEST_RECEIVE = "device.rpc.request" + RPC_RESPONSE_SEND = "device.rpc.response" + + def __str__(self) -> str: + return self.value diff --git a/tb_mqtt_client/entities/gateway/gateway_attribute_request.py b/tb_mqtt_client/entities/gateway/gateway_attribute_request.py new file mode 100644 index 0000000..b3dbc02 --- /dev/null +++ b/tb_mqtt_client/entities/gateway/gateway_attribute_request.py @@ -0,0 +1,92 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, List, Dict, Union + +from tb_mqtt_client.common.request_id_generator import AttributeRequestIdProducer +from tb_mqtt_client.constants.json_typing import validate_json_compatibility +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.service.gateway.device_session import DeviceSession + + +@dataclass(slots=True, frozen=True) +class GatewayAttributeRequest(AttributeRequest): + """ + Represents a request for device attributes, with optional client and shared attribute keys. + Automatically assigns a unique request ID via the build() method. + """ + device_session: DeviceSession = None # type: ignore[assignment] + + def __new__(cls, *args, **kwargs): + raise TypeError( + "Direct instantiation of GatewayAttributeRequest is not allowed. " + "Use 'await GatewayAttributeRequest.build(...)'.") + + def __repr__(self) -> str: + return (f"GatewayAttributeRequest(device_session={self.device_session}, id={self.request_id}, " + f"shared_keys={self.shared_keys}, client_keys={self.client_keys})") + + @classmethod + async def build(cls, device_session: DeviceSession, # noqa + shared_keys: Optional[List[str]] = None, + client_keys: Optional[List[str]] = None) -> 'GatewayAttributeRequest': + """ + Build a new GatewayAttributeRequest with a unique request ID, using the global ID generator. + """ + validate_json_compatibility(shared_keys) + validate_json_compatibility(client_keys) + request_id = await AttributeRequestIdProducer.get_next() + self = object.__new__(cls) + object.__setattr__(self, 'device_session', device_session) + object.__setattr__(self, 'request_id', request_id) + object.__setattr__(self, 'shared_keys', shared_keys) + object.__setattr__(self, 'client_keys', client_keys) + object.__setattr__(self, 'event_type', GatewayEventType.DEVICE_ATTRIBUTE_REQUEST) + return self + + @classmethod + async def from_attribute_request(cls, device_session: DeviceSession, + attribute_request: AttributeRequest) -> 'GatewayAttributeRequest': + """ + Create a GatewayAttributeRequest from an existing AttributeRequest and a DeviceSession. + """ + if not isinstance(attribute_request, AttributeRequest): + raise TypeError("attribute_request must be an instance of AttributeRequest") + self = object.__new__(cls) + object.__setattr__(self, 'device_session', device_session) + object.__setattr__(self, 'request_id', attribute_request.request_id) + object.__setattr__(self, 'shared_keys', attribute_request.shared_keys) + object.__setattr__(self, 'client_keys', attribute_request.client_keys) + object.__setattr__(self, 'event_type', GatewayEventType.DEVICE_ATTRIBUTE_REQUEST) + return self + + def to_payload_format(self) -> Dict[str, Union[str, bool]]: + """ + Convert the attribute request into the expected MQTT payload format. + """ + payload = {"device": self.device_session.device_info.device_name, "id": self.request_id} + single_key_request = (self.client_keys is not None and len(self.client_keys) == 1) or ( + self.shared_keys is not None and len(self.shared_keys) == 1) + request_key = 'key' if single_key_request else 'keys' + if self.client_keys is not None and self.client_keys: + payload['client'] = True + payload[request_key] = self.client_keys[0] if single_key_request else self.client_keys + elif self.shared_keys is not None and self.shared_keys: + # TODO: In current realization on server it is not possible to request values for the both scopes + # TODO: at the same time, recommended to improve the platform API + payload['client'] = False + payload[request_key] = self.shared_keys[0] if single_key_request else self.shared_keys + return payload diff --git a/tb_mqtt_client/entities/gateway/gateway_attribute_update.py b/tb_mqtt_client/entities/gateway/gateway_attribute_update.py new file mode 100644 index 0000000..ae5c0ee --- /dev/null +++ b/tb_mqtt_client/entities/gateway/gateway_attribute_update.py @@ -0,0 +1,45 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType + + +class GatewayAttributeUpdate(AttributeUpdate, BaseGatewayEvent): + """ + Represents an attribute update event for a device connected to a gateway. + This event is used to notify about changes in device shared attributes. + """ + + def __init__(self, device_name: str, + attribute_update: Union[AttributeUpdate, List[AttributeEntry], AttributeEntry]): + super().__init__(GatewayEventType.DEVICE_ATTRIBUTE_UPDATE) + if isinstance(attribute_update, list) and all(isinstance(entry, AttributeEntry) for entry in attribute_update): + attribute_update = AttributeUpdate(entries=attribute_update) + elif isinstance(attribute_update, AttributeEntry): + attribute_update = AttributeUpdate(entries=[attribute_update]) + elif not isinstance(attribute_update, AttributeUpdate): + raise TypeError( + "attribute_update must be an instance of AttributeUpdate, " + "list of AttributeEntry, or a single AttributeEntry.") + self.device_name = device_name + self.entries = attribute_update.entries + self.attribute_update = attribute_update + + def __str__(self) -> str: + return f"GatewayAttributeUpdate(device_name={self.device_name}, attribute_update={self.attribute_update})" diff --git a/tb_mqtt_client/entities/gateway/gateway_claim_request.py b/tb_mqtt_client/entities/gateway/gateway_claim_request.py new file mode 100644 index 0000000..7330a2e --- /dev/null +++ b/tb_mqtt_client/entities/gateway/gateway_claim_request.py @@ -0,0 +1,94 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, Any, Union + +from tb_mqtt_client.entities.data.claim_request import ClaimRequest +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.service.gateway.device_session import DeviceSession + + +@dataclass(slots=True, frozen=True) +class GatewayClaimRequest(BaseGatewayEvent): + devices_requests: Dict[Union[DeviceSession, str], ClaimRequest] = None + event_type: GatewayEventType = GatewayEventType.GATEWAY_CLAIM_REQUEST + + def __new__(cls, *args, **kwargs): + raise TypeError( + "Direct instantiation of GatewayClaimRequest is not allowed. Use 'GatewayClaimRequestBuilder.build(...)'.") + + def __repr__(self) -> str: + return f"GatewayClaimRequest(devices_requests={self.devices_requests})" + + def add_device_request(self, device_name_or_session: Union[DeviceSession, str], claim_request: ClaimRequest): + """ + Add a device claim request to the GatewayClaimRequest. + """ + self.devices_requests[device_name_or_session] = claim_request + + def to_payload_format(self) -> Dict[str, Any]: + """ + Convert the claim request to the expected MQTT JSON payload format. + """ + payload = {} + for device_session, claim_request in self.devices_requests.items(): + device_name = device_session + if isinstance(device_session, DeviceSession): + device_name = device_session.device_info.device_name + payload[device_name] = claim_request.to_payload_format() + return payload + + @classmethod + def build(cls) -> 'GatewayClaimRequest': + """ + Build a new GatewayClaimRequest instance. + """ + self = object.__new__(cls) + object.__setattr__(self, 'devices_requests', {}) + object.__setattr__(self, 'event_type', GatewayEventType.GATEWAY_CLAIM_REQUEST) + return self + + +class GatewayClaimRequestBuilder: + """ + Builder class for GatewayClaimRequest. + Allows adding multiple device claim requests in a fluent interface style. + """ + + def __init__(self): + self._devices_requests: Dict[Union[DeviceSession, str], ClaimRequest] = {} + + def add_device_request(self, device_name_or_session: Union[DeviceSession, str], + device_claim_request: ClaimRequest) -> 'GatewayClaimRequestBuilder': + """ + Add a device claim request to the builder. + """ + if not isinstance(device_name_or_session, (DeviceSession, str)): + raise ValueError( + "device_session must be an instance of DeviceSession or a string representing the device name") + if not isinstance(device_claim_request, ClaimRequest): + raise ValueError("device_claim_request must be an instance of ClaimRequest") + self._devices_requests[device_name_or_session] = device_claim_request + return self + + def build(self) -> GatewayClaimRequest: + """ + Build the GatewayClaimRequest with all added device requests. + """ + gateway_claim_request = GatewayClaimRequest.build() + for device_session, claim_request in self._devices_requests.items(): + gateway_claim_request.add_device_request(device_session, claim_request) + return gateway_claim_request diff --git a/tb_mqtt_client/entities/gateway/gateway_event.py b/tb_mqtt_client/entities/gateway/gateway_event.py new file mode 100644 index 0000000..e67bfb1 --- /dev/null +++ b/tb_mqtt_client/entities/gateway/gateway_event.py @@ -0,0 +1,42 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union + +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.service.gateway.device_session import DeviceSession + + +class GatewayEvent(BaseGatewayEvent): + """ + Base class for all events in the gateway client. + This class can be extended to create specific event types. + """ + + def __init__(self, event_type: GatewayEventType): + super().__init__(event_type) + self.__device_session: Union[DeviceSession, None] = None + + def set_device_session(self, device_session: DeviceSession): + self.__device_session = device_session + + def get_device_session(self) -> Union[DeviceSession, None]: + return self.__device_session + + def __str__(self) -> str: + return f"GatewayEvent(type={self.event_type}, device_session={self.__device_session})" + + def to_dict(self) -> dict: + return {"event_type": self.event_type, "device_session": self.__device_session} diff --git a/tb_mqtt_client/entities/gateway/gateway_requested_attribute_response.py b/tb_mqtt_client/entities/gateway/gateway_requested_attribute_response.py new file mode 100644 index 0000000..712531d --- /dev/null +++ b/tb_mqtt_client/entities/gateway/gateway_requested_attribute_response.py @@ -0,0 +1,96 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, Any, List, Optional + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType + +logger = get_logger(__name__) + + +@dataclass(slots=True, frozen=True) +class GatewayRequestedAttributeResponse(RequestedAttributeResponse, BaseGatewayEvent): + device_name: str = "" + request_id: int = -1 + shared: Optional[List[AttributeEntry]] = None + client: Optional[List[AttributeEntry]] = None + event_type: GatewayEventType = GatewayEventType.DEVICE_REQUESTED_ATTRIBUTE_RESPONSE + + def __post_init__(self): + super(BaseGatewayEvent, self).__setattr__('event_type', GatewayEventType.DEVICE_REQUESTED_ATTRIBUTE_RESPONSE) + + def __repr__(self): + return (f"GatewayRequestedAttributeResponse(device_name={self.device_name},request_id={self.request_id}, " + f"shared={self.shared}, client={self.client})") + + def __getitem__(self, item): + """ + Allows access to values using dictionary-like syntax. + """ + if self.shared is not None: + for entry in self.shared: + if entry.key == item: + return entry.value + if self.client is not None: + for entry in self.client: + if entry.key == item: + return entry.value + raise KeyError(f"Key '{item}' not found in shared or client attributes.") + + def shared_keys(self): + return [entry.key for entry in self.shared] + + def client_keys(self): + return [entry.key for entry in self.client] + + def get_shared(self, key: str, default=None): + """ + Get the value of a shared attribute by key. + :param key: The key of the shared attribute. + :param default: Default value if the key is not found. + :return: Value of the shared attribute or default. + """ + if self.shared is not None: + for entry in self.shared: + if entry.key == key: + return entry.value + return default + + def get_client(self, key: str, default=None): + """ + Get the value of a client attribute by key. + :param key: The key of the client attribute. + :param default: Default value if the key is not found. + :return: Value of the client attribute or default. + """ + if self.client is not None: + for entry in self.client: + if entry.key == key: + return entry.value + return default + + def as_dict(self) -> Dict[str, Any]: + """ + Convert the GatewayRequestedAttributeResponse to a dictionary format. + :return: Dictionary representation of the response. + """ + return { + 'shared': [entry.as_dict() for entry in self.shared if self.shared is not None], + 'client': [entry.as_dict() for entry in self.client if self.client is not None], + } diff --git a/tb_mqtt_client/entities/gateway/gateway_rpc_request.py b/tb_mqtt_client/entities/gateway/gateway_rpc_request.py new file mode 100644 index 0000000..874fe1a --- /dev/null +++ b/tb_mqtt_client/entities/gateway/gateway_rpc_request.py @@ -0,0 +1,64 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Union, Optional, Dict, Any + +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType + + +@dataclass(slots=True, frozen=True) +class GatewayRPCRequest(BaseGatewayEvent): + request_id: int + device_name: str + method: str + params: Optional[Any] = None + event_type: GatewayEventType = GatewayEventType.DEVICE_RPC_REQUEST + + def __new__(cls, *args, **kwargs): + raise TypeError( + "Direct instantiation of GatewayRPCRequest is not allowed. " + "Use 'GatewayRPCRequest._deserialize_from_dict(...)'.") + + def __repr__(self): + return (f"RPCRequest(id={self.request_id}, device_name={self.device_name}, " + f"method={self.method}, params={self.params})") + + def __str__(self): + return (f"RPCRequest(id={self.request_id}, device_name={self.device_name}, " + f"method={self.method}, params={self.params})") + + @classmethod + def _deserialize_from_dict(cls, data: Dict[str, Union[str, Dict[str, Any]]]) -> 'GatewayRPCRequest': + """ + Constructs an GatewayRPCRequest, should be used only for deserialization request from the platform. + """ + if "device" not in data: + raise ValueError("Missing device name in RPC request") + device_name = data["device"] + data = data["data"] + request_id = data["id"] + if not isinstance(request_id, (int, str)): + raise ValueError("Missing request id in RPC request") + if "method" not in data: + raise ValueError("Missing 'method' in RPC request") + + self = object.__new__(cls) + object.__setattr__(self, 'request_id', request_id) + object.__setattr__(self, 'method', data["method"]) + object.__setattr__(self, 'params', data.get("params")) + object.__setattr__(self, 'device_name', device_name) + object.__setattr__(self, 'event_type', GatewayEventType.DEVICE_RPC_REQUEST) + return self diff --git a/tb_mqtt_client/entities/gateway/gateway_rpc_response.py b/tb_mqtt_client/entities/gateway/gateway_rpc_response.py new file mode 100644 index 0000000..45055a9 --- /dev/null +++ b/tb_mqtt_client/entities/gateway/gateway_rpc_response.py @@ -0,0 +1,98 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from traceback import format_exception +from typing import Union, Optional, Dict, Any + +from tb_mqtt_client.constants.json_typing import validate_json_compatibility, JSONCompatibleType +from tb_mqtt_client.entities.data.rpc_response import RPCStatus, RPCResponse +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType + + +@dataclass(slots=True, frozen=True) +class GatewayRPCResponse(RPCResponse, BaseGatewayEvent): + """ + Represents a response to the RPC call. + + Attributes: + device_name: Name of the device for which the RPC response is intended. + request_id: Unique identifier of the request being responded to. + result: Optional response payload (Any type allowed). + error: Optional error information if the RPC failed. + """ + device_name: str = None + event_type: GatewayEventType = GatewayEventType.DEVICE_RPC_RESPONSE + + def __new__(cls, *args, **kwargs): + raise TypeError( + "Direct instantiation of GatewayRPCResponse is not allowed. " + "Use GatewayRPCResponse.build(device_name, request_id, result, error).") + + def __repr__(self) -> str: + return (f"GatewayRPCResponse(device_name={self.device_name}, request_id={self.request_id}, " + f"result={self.result}, error={self.error})") + + @classmethod + def build(cls, # noqa + device_name: str, + request_id: int, + result: Optional[Any] = None, + error: Optional[Union[str, Dict[str, JSONCompatibleType], BaseException]] = None) -> 'GatewayRPCResponse': + """ + Constructs an GatewayRPCResponse explicitly. + """ + if not isinstance(device_name, str) or not device_name: + raise ValueError("Device name must be a non-empty string") + self = object.__new__(cls) + object.__setattr__(self, 'request_id', request_id) + object.__setattr__(self, 'device_name', device_name) + + if error is not None: + if not isinstance(error, (str, dict, BaseException)): + raise ValueError("Error must be a string, dictionary, or an exception instance") + + object.__setattr__(self, 'status', RPCStatus.ERROR) + + if isinstance(error, BaseException): + try: + raise error + except BaseException as e: + error = { + "message": str(e), + "type": type(e).__name__, + "details": ''.join(format_exception(type(e), e, e.__traceback__)) + } + + validate_json_compatibility(error) + object.__setattr__(self, 'error', error) + + else: + object.__setattr__(self, 'status', RPCStatus.SUCCESS) + object.__setattr__(self, 'error', None) + validate_json_compatibility(result) + + object.__setattr__(self, 'result', result) + object.__setattr__(self, 'event_type', GatewayEventType.DEVICE_RPC_RESPONSE) + return self + + def to_payload_format(self) -> Dict[str, Any]: + """Serializes the RPC response for publishing.""" + data = {"device": self.device_name, "id": self.request_id, "data": {}} + if self.result is not None: + data["data"]["result"] = self.result + if self.error is not None: + data["data"]["error"] = self.error + return data diff --git a/tb_mqtt_client/entities/gateway/gateway_uplink_message.py b/tb_mqtt_client/entities/gateway/gateway_uplink_message.py new file mode 100644 index 0000000..a80e44b --- /dev/null +++ b/tb_mqtt_client/entities/gateway/gateway_uplink_message.py @@ -0,0 +1,188 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from dataclasses import dataclass +from types import MappingProxyType +from typing import List, Optional, Union, OrderedDict, Mapping +from uuid import uuid4 + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType + +logger = get_logger(__name__) + +DEFAULT_FIELDS_SIZE = len('{"device_name":"","attributes":"","timeseries":""}'.encode('utf-8')) + + +@dataclass(slots=True, frozen=True) +class GatewayUplinkMessage(DeviceUplinkMessage, BaseGatewayEvent): + device_name: str + device_profile: str + event_type: GatewayEventType = GatewayEventType.DEVICE_UPLINK + + def __new__(cls, *args, **kwargs): + raise TypeError( + "Direct instantiation of GatewayUplinkMessage is not allowed. " + "Use GatewayUplinkMessageBuilder to construct instances.") + + def __repr__(self): + return (f"GatewayUplinkMessage(device_name={self.device_name}, " + f"device_profile={self.device_profile}, " + f"attributes={self.attributes}, " + f"timeseries={self.timeseries}, " + f"delivery_futures={self.delivery_futures})") + + @classmethod + def build(cls, # noqa + device_name: Optional[str], + device_profile: Optional[str], + attributes: List[AttributeEntry], + timeseries: Mapping[int, List[TimeseriesEntry]], + delivery_futures: List[Optional[asyncio.Future]], + size: int, + main_ts: Optional[int]) -> 'GatewayUplinkMessage': + self = object.__new__(cls) + object.__setattr__(self, 'device_name', device_name) + object.__setattr__(self, 'device_profile', device_profile) + object.__setattr__(self, 'attributes', tuple(attributes)) + object.__setattr__(self, 'timeseries', + MappingProxyType({ts: tuple(entries) for ts, entries in timeseries.items()})) + object.__setattr__(self, 'delivery_futures', tuple(delivery_futures)) + object.__setattr__(self, '_size', size) + object.__setattr__(self, 'event_type', GatewayEventType.DEVICE_UPLINK) + object.__setattr__(self, 'main_ts', main_ts) + return self + + @property + def size(self) -> int: + return self._size + + def timeseries_datapoint_count(self) -> int: + return sum(len(entries) for entries in self.timeseries.values()) + + def attributes_datapoint_count(self) -> int: + return len(self.attributes) + + def has_attributes(self) -> bool: + return bool(self.attributes) + + def has_timeseries(self) -> bool: + return bool(self.timeseries) + + def get_delivery_futures(self) -> List[Optional[asyncio.Future[PublishResult]]]: + return self.delivery_futures + + def set_main_ts(self, main_ts: int) -> 'GatewayUplinkMessage': + """ + Set the main timestamp for the message. + :param main_ts: The main timestamp to set. + :return: The updated GatewayUplinkMessage instance. + """ + object.__setattr__(self, 'main_ts', main_ts) + return self + + +class GatewayUplinkMessageBuilder: + def __init__(self): + self._device_name: Optional[str] = None + self._device_profile: Optional[str] = None + self._attributes: List[AttributeEntry] = [] + self._timeseries: OrderedDict[int, List[TimeseriesEntry]] = OrderedDict() + self._delivery_futures: List[Optional[asyncio.Future[PublishResult]]] = [] + self.__size = DEFAULT_FIELDS_SIZE + self._main_ts: Optional[int] = None + + def set_device_name(self, device_name: str) -> 'GatewayUplinkMessageBuilder': + self._device_name = device_name + if device_name is not None: + self.__size += len(device_name) + return self + + def set_device_profile(self, profile: str) -> 'GatewayUplinkMessageBuilder': + self._device_profile = profile + if profile is not None: + self.__size += len(profile) + return self + + def add_attributes(self, attributes: Union[AttributeEntry, List[AttributeEntry]]) -> 'GatewayUplinkMessageBuilder': + if not isinstance(attributes, list): + attributes = [attributes] + self._attributes.extend(attributes) + for attribute in attributes: + self.__size += attribute.size + return self + + def add_timeseries(self, + timeseries: Union[TimeseriesEntry, + List[TimeseriesEntry], + OrderedDict[int, List[TimeseriesEntry]]]) -> 'GatewayUplinkMessageBuilder': + if isinstance(timeseries, OrderedDict): + self._timeseries = timeseries + return self + if not isinstance(timeseries, list): + timeseries = [timeseries] + for entry in timeseries: + if entry.ts is not None: + if entry.ts in self._timeseries: + self._timeseries[entry.ts].append(entry) + else: + self._timeseries[entry.ts] = [entry] + else: + if 0 in self._timeseries: + self._timeseries[0].append(entry) + else: + self._timeseries[0] = [entry] + self.__size += entry.size + return self + + def add_delivery_futures(self, + futures: Union[asyncio.Future[PublishResult], + List[asyncio.Future[PublishResult]]]) -> 'GatewayUplinkMessageBuilder': + if not isinstance(futures, list): + futures = [futures] + if futures: + logger.debug("Created delivery futures: %r", [id(future) for future in futures]) + self._delivery_futures.extend(futures) + return self + + def set_main_ts(self, main_ts: int) -> 'GatewayUplinkMessageBuilder': + self._main_ts = main_ts + return self + + def build(self) -> GatewayUplinkMessage: + if not self._delivery_futures: + delivery_future = asyncio.get_event_loop().create_future() + delivery_future.uuid = uuid4() + self._delivery_futures = [delivery_future] + if not self._device_profile: + self._device_profile = 'default' + self.__size += len(self._device_profile) + return GatewayUplinkMessage.build( # noqa + device_name=self._device_name, + device_profile=self._device_profile, + attributes=self._attributes, + timeseries=self._timeseries, + delivery_futures=self._delivery_futures, + size=self.__size, + main_ts=self._main_ts + ) + + def __len__(self) -> int: + return self.__size diff --git a/tb_mqtt_client/service/__init__.py b/tb_mqtt_client/service/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tb_mqtt_client/service/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tb_mqtt_client/service/base_client.py b/tb_mqtt_client/service/base_client.py new file mode 100644 index 0000000..61df408 --- /dev/null +++ b/tb_mqtt_client/service/base_client.py @@ -0,0 +1,207 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from abc import ABC, abstractmethod +from time import time +from typing import Callable, Awaitable, Dict, Any, Union, List, Optional + +import uvloop + +from tb_mqtt_client.common.exceptions import exception_handler +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.constants.json_typing import JSONCompatibleType +from tb_mqtt_client.constants.service_keys import TELEMETRY_TIMESTAMP_PARAMETER, TELEMETRY_VALUES_PARAMETER +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.claim_request import ClaimRequest +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder, DeviceUplinkMessage +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessageBuilder, GatewayUplinkMessage +from tb_mqtt_client.service.gateway.device_session import DeviceSession + +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +exception_handler.install_asyncio_handler() + + +class BaseClient(ABC): + """ + Abstract base class for clients. + """ + + DEFAULT_TIMEOUT = 3.0 + + def __init__(self, host: str, port: int, client_id: str): + self._host = host + self._port = port + self._client_id = client_id + self._connected = asyncio.Event() + + @abstractmethod + async def connect(self): + """ + Connect to the platform over MQTT. + """ + pass + + @abstractmethod + async def disconnect(self): + """ + Disconnect from the platform. + """ + pass + + @abstractmethod + async def send_timeseries(self, + data: Union[TimeseriesEntry, + List[TimeseriesEntry], + Dict[str, Any], + List[Dict[str, Any]]], + wait_for_publish: bool = True, + timeout: Optional[float] = None) -> Optional[Union[asyncio.Future[PublishResult], + PublishResult, + List[PublishResult], + List[asyncio.Future[PublishResult]]]]: + """ + Sends timeseries data to the ThingsBoard server. + :param data: Timeseries data to send, can be a single TimeseriesEntry, a list of TimeseriesEntries, + a dictionary of key-value pairs, or a list of dictionaries. + :param wait_for_publish: If True, waits for the publish result. + :param timeout: Timeout for waiting for the publish result. + :return: PublishResult or list of PublishResults if wait_for_publish is True, Future or list of Futures if not, + None if no data is sent. + """ + pass + + @abstractmethod + async def send_attributes(self, + attributes: Union[Dict[str, Any], AttributeEntry, list[AttributeEntry]], + wait_for_publish: bool = True, + timeout: Optional[float] = None) -> Union[asyncio.Future[PublishResult], PublishResult]: + """ + Send client attributes. + + :param attributes: Dictionary of attributes or a single AttributeEntry or a list of AttributeEntries. + :param wait_for_publish: If True, wait for the publishing result. Default is True. + :param timeout: Timeout for the publish operation if `wait_for_publish` is True. + In seconds. If less than 0 or None, wait indefinitely. + :return: Future or PublishResult depending on `wait_for_publish`. + """ + pass + + @abstractmethod + async def claim_device(self, claim_request: ClaimRequest) -> Union[asyncio.Future[PublishResult], PublishResult]: + """ + Claim a device using the provided ClaimRequest. + + :param claim_request: The ClaimRequest instance contains secret key and duration. + :return: Future or PublishResult depending on the implementation. + """ + pass + + @abstractmethod + async def send_rpc_response(self, response: RPCResponse): + """ + Send a response to a server-initiated RPC request. + + :param RPCResponse response: The RPC response for sending to the platform. + """ + pass + + @abstractmethod + def set_attribute_update_callback(self, callback: Callable[[AttributeUpdate], Awaitable[None]]): + """ + Set callback to be triggered when a shared attribute update is received. + + :param callback: Coroutine accepting an AttributeUpdate instance. + """ + pass + + @abstractmethod + def set_rpc_request_callback(self, callback: Callable[[str, Dict[str, Any]], Awaitable[Dict[str, Any]]]): + """ + Set callback to be triggered when an RPC request is received. + + :param callback: Coroutine accepting (method, params) and returning result. + """ + pass + + @staticmethod + def _build_uplink_message_for_telemetry(payload: Union[Dict[str, Any], + TimeseriesEntry, + List[TimeseriesEntry], + List[Dict[str, Any]]], + device_session: Optional[DeviceSession] = None, + ) -> Union[DeviceUplinkMessage, GatewayUplinkMessage]: + timeseries_entries = [] + if isinstance(payload, TimeseriesEntry): + timeseries_entries.append(payload) + elif isinstance(payload, dict): + timeseries_entries.extend(BaseClient.__build_timeseries_entry_from_dict(payload)) + elif isinstance(payload, list) and len(payload) > 0: + for item in payload: + if isinstance(item, dict): + timeseries_entries.extend(BaseClient.__build_timeseries_entry_from_dict(item)) + elif isinstance(item, TimeseriesEntry): + timeseries_entries.append(item) + else: + raise ValueError(f"Unsupported item type in telemetry list: {type(item).__name__}") + else: + raise ValueError(f"Unsupported payload type for telemetry: {type(payload).__name__}") + + if device_session: + message_builder = GatewayUplinkMessageBuilder() + message_builder.set_device_name(device_session.device_info.device_name) + message_builder.set_device_profile(device_session.device_info.device_profile) + else: + message_builder = DeviceUplinkMessageBuilder() + message_builder.add_timeseries(timeseries_entries) + return message_builder.build() + + @staticmethod + def __build_timeseries_entry_from_dict(data: Dict[str, JSONCompatibleType]) -> List[TimeseriesEntry]: + result = [] + if TELEMETRY_TIMESTAMP_PARAMETER in data: + ts = data.pop(TELEMETRY_TIMESTAMP_PARAMETER) + values = data.pop(TELEMETRY_VALUES_PARAMETER, {}) + else: + ts = int(time() * 1000) + values = data + + if not isinstance(values, dict): + raise ValueError(f"Expected {TELEMETRY_VALUES_PARAMETER} to be a dict, got {type(values).__name__}") + + for key, value in values.items(): + result.append(TimeseriesEntry(key, value, ts=ts)) + + return result + + @staticmethod + def _build_uplink_message_for_attributes(payload: Union[Dict[str, Any], + AttributeEntry, + List[AttributeEntry]], + device_session=None) -> Union[DeviceUplinkMessage, GatewayUplinkMessage]: + + if isinstance(payload, dict): + payload = [AttributeEntry(k, v) for k, v in payload.items()] + + if device_session: + message_builder = GatewayUplinkMessageBuilder() + message_builder.set_device_name(device_session.device_info.device_name) + message_builder.set_device_profile(device_session.device_info.device_profile) + else: + message_builder = DeviceUplinkMessageBuilder() + message_builder.add_attributes(payload) + return message_builder.build() diff --git a/tb_mqtt_client/service/base_message_splitter.py b/tb_mqtt_client/service/base_message_splitter.py new file mode 100644 index 0000000..6c50dc9 --- /dev/null +++ b/tb_mqtt_client/service/base_message_splitter.py @@ -0,0 +1,64 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import List + +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage + + +class BaseMessageSplitter(ABC): + """ + Base class for message splitters in the ThingsBoard MQTT client. + """ + + @property + @abstractmethod + def max_payload_size(self) -> int: + """ + Returns the maximum payload size for messages. + """ + + @max_payload_size.setter + @abstractmethod + def max_payload_size(self, value: int) -> None: + """ + Sets the maximum payload size for messages. + """ + + @property + @abstractmethod + def max_datapoints(self) -> int: + """ + Returns the maximum number of datapoints allowed in a message. + """ + + @max_datapoints.setter + @abstractmethod + def max_datapoints(self, value: int) -> None: + """ + Sets the maximum number of datapoints allowed in a message. + """ + + @abstractmethod + def split_timeseries(self, *args, **kwargs) -> List[MqttPublishMessage]: + """ + Splits timeseries data + """ + + @abstractmethod + def split_attributes(self, *args, **kwargs) -> List[MqttPublishMessage]: + """ + Splits attributes data + """ diff --git a/tb_mqtt_client/service/device/__init__.py b/tb_mqtt_client/service/device/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tb_mqtt_client/service/device/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py new file mode 100644 index 0000000..13068e1 --- /dev/null +++ b/tb_mqtt_client/service/device/client.py @@ -0,0 +1,470 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ssl +from asyncio import sleep, wait_for, TimeoutError, Event, Future +from random import choices +from string import ascii_uppercase, digits +from typing import Callable, Awaitable, Optional, Dict, Any, Union, List + +from tb_mqtt_client.common.async_utils import await_or_stop +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.common.provisioning_client import ProvisioningClient +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit, DEFAULT_RATE_LIMIT_PERCENTAGE +from tb_mqtt_client.common.rate_limit.rate_limiter import RateLimiter +from tb_mqtt_client.common.request_id_generator import RPCRequestIdProducer +from tb_mqtt_client.constants import mqtt_topics +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.claim_request import ClaimRequest +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest +from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.service.base_client import BaseClient +from tb_mqtt_client.service.device.firmware_updater import FirmwareUpdater +from tb_mqtt_client.service.device.handlers.attribute_updates_handler import AttributeUpdatesHandler +from tb_mqtt_client.service.device.handlers.requested_attributes_response_handler import \ + RequestedAttributeResponseHandler +from tb_mqtt_client.service.device.handlers.rpc_requests_handler import RPCRequestsHandler +from tb_mqtt_client.service.device.handlers.rpc_response_handler import RPCResponseHandler +from tb_mqtt_client.service.device.message_adapter import JsonMessageAdapter, MessageAdapter +from tb_mqtt_client.service.message_service import MessageService +from tb_mqtt_client.service.mqtt_manager import MQTTManager + +logger = get_logger(__name__) + + +class DeviceClient(BaseClient): + def __init__(self, config: Optional[Union[DeviceConfig, Dict]] = None): + self._stop_event = Event() + self._config = None + if isinstance(config, DeviceConfig): + self._config = config + else: + self._config = DeviceConfig() + if isinstance(config, dict): + for key, value in config.items(): + if hasattr(self._config, key) and value is not None: + setattr(self._config, key, value) + + client_id = self._config.client_id or "tb-client-" + ''.join(choices(ascii_uppercase + digits, k=6)) + + super().__init__(self._config.host, self._config.port, client_id) + + self._message_queue: Optional[MessageService] = None + self._message_adapter: MessageAdapter = JsonMessageAdapter(1000, + 1) # Will be updated after connection established + + self._rate_limiter = RateLimiter( + message_rate_limit=RateLimit("0:0,", name="messages"), + telemetry_message_rate_limit=RateLimit("0:0,", name="telemetryMessages"), + telemetry_datapoints_rate_limit=RateLimit("0:0,", name="telemetryDataPoints") + ) + self._gateway_rate_limiter = RateLimiter( + message_rate_limit=RateLimit("0:0,", name="messages"), + telemetry_message_rate_limit=RateLimit("0:0,", name="telemetryMessages"), + telemetry_datapoints_rate_limit=RateLimit("0:0,", name="telemetryDataPoints") + ) + self._ssl_context = None + self.max_payload_size = None + self._max_inflight_messages = 100 + self._max_uplink_message_queue_size = 10000 + self._max_queued_messages = 50000 + + self._rpc_response_handler = RPCResponseHandler() + + self._mqtt_manager = MQTTManager(client_id=self._client_id, + main_stop_event=self._stop_event, + message_adapter=self._message_adapter, + on_connect=self._on_connect, + on_disconnect=self._on_disconnect, + on_publish_result=self.__on_publish_result, + rate_limits_handler=self._handle_rate_limit_response, + rpc_response_handler=self._rpc_response_handler) + + self._requested_attribute_response_handler = RequestedAttributeResponseHandler() + self._attribute_updates_handler = AttributeUpdatesHandler() + self._rpc_requests_handler = RPCRequestsHandler() + + self._firmware_updater = FirmwareUpdater(self) + + async def update_firmware(self, on_received_callback: Optional[Callable[[bytes, dict], Awaitable[None]]] = None, + save_firmware: bool = True, firmware_save_path: Optional[str] = None): + await self._firmware_updater.update(on_received_callback, save_firmware, firmware_save_path) + + async def connect(self): + logger.info("Connecting to platform at %s:%s", self._host, self._port) + + tls = self._config.use_tls() + if tls: + self._ssl_context = ssl.create_default_context() + self._ssl_context.load_verify_locations(self._config.ca_cert) + self._ssl_context.load_cert_chain(certfile=self._config.client_cert, keyfile=self._config.private_key) + + await self._mqtt_manager.connect( + host=self._host, + port=self._port, + username=self._config.access_token or self._config.username, + password=None if self._config.access_token else self._config.password, + tls=tls, + ssl_context=self._ssl_context + ) + + while not self._mqtt_manager.is_connected(): + await self._mqtt_manager.await_ready() + if self._stop_event.is_set(): + return + + # Initialize with default max_payload_size if not set + if self.max_payload_size is None: + self.max_payload_size = 65535 + logger.debug("Using default max_payload_size: %d", self.max_payload_size) + + self._message_adapter = JsonMessageAdapter(self.max_payload_size, + self._rate_limiter.telemetry_datapoints_rate_limit.minimal_limit) + self._message_queue = MessageService( + mqtt_manager=self._mqtt_manager, + main_stop_event=self._stop_event, + device_rate_limiter=self._rate_limiter, + gateway_rate_limiter=self._gateway_rate_limiter, + message_adapter=self._message_adapter, + max_queue_size=self._max_uplink_message_queue_size, + ) + + self._requested_attribute_response_handler.set_message_adapter(self._message_adapter) + self._attribute_updates_handler.set_message_adapter(self._message_adapter) + self._rpc_requests_handler.set_message_adapter(self._message_adapter) + self._rpc_response_handler.set_message_adapter(self._message_adapter) + + async def stop(self): + """ + Stops the client and disconnects from the MQTT broker. + """ + logger.info("Stopping DeviceClient...") + self._stop_event.set() + + for fut, _ in self._rpc_response_handler._pending_rpc_requests.values(): + if not fut.done(): + fut.cancel() + + if self._message_queue: + await self._message_queue.shutdown() + + await self._mqtt_manager.stop() + + logger.info("DeviceClient stopped.") + + async def disconnect(self): + await self._mqtt_manager.disconnect() + + async def send_telemetry(self, *args, **kwargs): + """ + Note: This method is deprecated. Use `send_timeseries` instead. + """ + logger.warning("send_telemetry is deprecated. Use send_timeseries instead.") + return await self.send_timeseries(*args, **kwargs) + + async def send_timeseries( + self, + data: Union[TimeseriesEntry, List[TimeseriesEntry], Dict[str, Any], List[Dict[str, Any]]], + wait_for_publish: bool = True, + timeout: Optional[float] = None + ) -> Optional[Union[PublishResult, List[PublishResult], Future[PublishResult], List[Future[PublishResult]]]]: + """ + Sends timeseries data to the ThingsBoard server. + :param data: Timeseries data to send, can be a single TimeseriesEntry, a list of TimeseriesEntries, + a dictionary of key-value pairs, or a list of dictionaries. + :param wait_for_publish: If True, waits for the publish result. + :param timeout: Timeout for waiting for the publish result. + :return: PublishResult or list of PublishResults if wait_for_publish is True, Future or list of Futures if not, + None if no data is sent. + """ + message = self._build_uplink_message_for_telemetry(data) + mqtt_message = MqttPublishMessage( + topic=mqtt_topics.DEVICE_TELEMETRY_TOPIC, + payload=message, + qos=self._config.qos, + datapoints_count=message.timeseries_datapoint_count() + ) + delivery_future = mqtt_message.delivery_futures + + await self._message_queue.publish(mqtt_message) + + if not wait_for_publish: + return delivery_future + + if isinstance(delivery_future, list): + delivery_future = delivery_future[0] + + logger.info("Delivery future id in device client.send_timeseries: %r", delivery_future.uuid) + + try: + result = await await_or_stop(delivery_future, timeout=1, stop_event=self._stop_event) + except TimeoutError: + logger.warning("Timeout while waiting for telemetry publish result") + result = PublishResult(mqtt_message.topic, self._config.qos, -1, message.size, -1) + return result + + async def send_attributes( + self, + attributes: Union[Dict[str, Any], AttributeEntry, list[AttributeEntry]], + wait_for_publish: bool = True, + timeout: int = BaseClient.DEFAULT_TIMEOUT + ) -> Optional[Union[PublishResult, List[PublishResult], Future[PublishResult], List[Future[PublishResult]]]]: + message = self._build_uplink_message_for_attributes(attributes) + mqtt_message = MqttPublishMessage( + topic=mqtt_topics.DEVICE_ATTRIBUTES_TOPIC, + payload=message, + qos=self._config.qos, + datapoints_count=message.attributes_datapoint_count() + ) + + await self._message_queue.publish(mqtt_message) + + futures = mqtt_message.delivery_futures + + if not futures: + logger.warning("No publish futures were returned from message queue") + return None + + if not wait_for_publish: + return futures + + results = [] + for fut in futures: + try: + result = await await_or_stop(fut, timeout=timeout, stop_event=self._stop_event) + except TimeoutError: + logger.warning("Timeout while waiting for attribute publish result") + result = PublishResult(mqtt_message.topic, self._config.qos, -1, message.size, -1) + results.append(result) + + return results[0] if len(results) == 1 else results + + async def send_rpc_request( + self, + rpc_request: RPCRequest, + callback: Optional[Callable[[RPCResponse], Awaitable[None]]] = None, + wait_for_publish: bool = True, + timeout: Optional[float] = BaseClient.DEFAULT_TIMEOUT + ) -> Optional[Union[RPCResponse, Awaitable[RPCResponse]]]: + request_id = rpc_request.request_id or await RPCRequestIdProducer.get_next() + message_to_send = self._message_adapter.build_rpc_request(rpc_request) + message_to_send.qos = self._config.qos + response_future = self._rpc_response_handler.register_request(request_id, callback) + + await self._message_queue.publish(message_to_send) + + if not wait_for_publish: + return response_future + + try: + return await await_or_stop(response_future, timeout=timeout, stop_event=self._stop_event) + except TimeoutError as e: + if not callback: + raise TimeoutError(f"Timed out waiting for RPC response (requestId={request_id})") + else: + logger.warning("Timed out waiting for RPC response, but callback is set. " + "Callback will be called with None response.") + await self._rpc_response_handler.handle( + mqtt_topics.build_device_rpc_response_topic(rpc_request.request_id), e) + + async def send_rpc_response(self, response: RPCResponse): + mqtt_message = self._message_adapter.build_rpc_response(response) + mqtt_message.qos = self._config.qos + await self._message_queue.publish(mqtt_message) + + delivery_future = mqtt_message.delivery_futures + + if isinstance(delivery_future, list): + delivery_future = delivery_future[0] + try: + return await await_or_stop(delivery_future, timeout=BaseClient.DEFAULT_TIMEOUT, stop_event=self._stop_event) + except TimeoutError: + logger.warning("Timeout while waiting for RPC response publish result") + return PublishResult(mqtt_message.topic, mqtt_message.qos, -1, len(mqtt_message.payload), -1) + + async def send_attribute_request(self, + attribute_request: AttributeRequest, + callback: Callable[[RequestedAttributeResponse], Awaitable[None]]): + await self._requested_attribute_response_handler.register_request(attribute_request, callback) + + mqtt_message = self._message_adapter.build_attribute_request(attribute_request) + mqtt_message.qos = self._config.qos + + await self._message_queue.publish(mqtt_message) + + async def claim_device(self, + claim_request: ClaimRequest, + wait_for_publish: bool = True, + timeout: float = BaseClient.DEFAULT_TIMEOUT) -> Union[Future[PublishResult], PublishResult]: + mqtt_message = self._message_adapter.build_claim_request(claim_request) + mqtt_message.qos = self._config.qos + + delivery_future = mqtt_message.delivery_futures + + await self._message_queue.publish(mqtt_message) + if isinstance(delivery_future, list): + delivery_future = delivery_future[0] + if wait_for_publish: + try: + return await await_or_stop(delivery_future, timeout=timeout, stop_event=self._stop_event) + except TimeoutError: + logger.warning("Timeout while waiting for claiming publish result") + return PublishResult(mqtt_message.topic, 1, -1, mqtt_message.payload_size, -1) + else: + return delivery_future + + def set_attribute_update_callback(self, callback: Callable[[AttributeUpdate], Awaitable[None]]): + self._attribute_updates_handler.set_callback(callback) + + def set_rpc_request_callback(self, callback: Callable[[RPCRequest], Awaitable[RPCResponse]]): + self._rpc_requests_handler.set_callback(callback) + + async def _on_connect(self): + logger.info("Subscribing to attribute and RPC topics") + + sub_future = await self._mqtt_manager.subscribe(mqtt_topics.DEVICE_ATTRIBUTES_TOPIC, qos=1) + while not sub_future.done(): + await sleep(0.01) + if self._stop_event.is_set(): + return + + self._mqtt_manager.register_handler(mqtt_topics.DEVICE_ATTRIBUTES_TOPIC, self._handle_attribute_update) + self._mqtt_manager.register_handler(mqtt_topics.DEVICE_RPC_REQUEST_TOPIC_FOR_SUBSCRIPTION, + self._handle_rpc_request) # noqa + self._mqtt_manager.register_handler(mqtt_topics.DEVICE_ATTRIBUTES_RESPONSE_TOPIC, + self._handle_requested_attribute_response) # noqa + # RPC responses are handled by the RPCResponseHandler, which is already registered + + async def _on_disconnect(self): + logger.info("Device client disconnected.") + self._requested_attribute_response_handler.clear() + + async def _handle_attribute_update(self, topic: str, payload: bytes): + await self._attribute_updates_handler.handle(topic, payload) + + async def _handle_rpc_request(self, topic: str, payload: bytes): + response: RPCResponse = await self._rpc_requests_handler.handle(topic, payload) + if response: + await self.send_rpc_response(response) + + async def _handle_rpc_response(self, topic: str, payload: bytes): + await self._rpc_response_handler.handle(topic, payload) + + async def _handle_requested_attribute_response(self, topic: str, payload: bytes): + await self._requested_attribute_response_handler.handle(topic, payload) + + async def _handle_rate_limit_response(self, response: RPCResponse): # noqa + try: + logger.debug("Received rate limit response payload: %s", response) + + if not isinstance(response.result, dict) or 'rateLimits' not in response.result: + logger.warning("Invalid rate limit response: %r", response) + return None + + rate_limits = response.result.get('rateLimits', {}) + + await self._rate_limiter.message_rate_limit.set_limit(rate_limits.get("messages", "0:0,"), + percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) + await self._rate_limiter.telemetry_message_rate_limit.set_limit( + rate_limits.get("telemetryMessages", "0:0,"), percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) + await self._rate_limiter.telemetry_datapoints_rate_limit.set_limit( + rate_limits.get("telemetryDataPoints", "0:0,"), percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) + + server_inflight = int(response.result.get("maxInflightMessages", 100)) + limits = [rl.minimal_limit for rl in [ + self._rate_limiter.message_rate_limit, + self._rate_limiter.telemetry_message_rate_limit, + ] if rl.has_limit()] + + if limits: + self._max_inflight_messages = int( + min(min(limits), server_inflight) * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) + else: + self._max_inflight_messages = int(server_inflight * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) + if self._max_inflight_messages == 0: + self._max_inflight_messages = 10000 + + if "maxPayloadSize" in response.result: + self.max_payload_size = int(response.result["maxPayloadSize"] * DEFAULT_RATE_LIMIT_PERCENTAGE / 100) + # Update the adapter's splitter with the new max_payload_size + if self._message_adapter is not None: + self._message_adapter.splitter.max_payload_size = self.max_payload_size + logger.debug("Updated adapter's max_payload_size to %d", self.max_payload_size) + else: + # If maxPayloadSize is not provided, keep the default value + logger.debug("No maxPayloadSize in service config, using default: %d", self.max_payload_size) + # Initialize with default max_payload_size if not set + if self.max_payload_size is None: + self.max_payload_size = 65535 + logger.debug("Using default max_payload_size: %d", self.max_payload_size) + # Update the dispatcher's max_payload_size if it's already initialized + if self._message_adapter is not None and hasattr(self._message_adapter, 'splitter'): + self._message_adapter.splitter.max_payload_size = self.max_payload_size + logger.debug("Updated dispatcher's max_payload_size to %d", self.max_payload_size) + + self._message_adapter.splitter.max_datapoints = self._rate_limiter.telemetry_datapoints_rate_limit.minimal_limit # noqa + + if (not self._rate_limiter.message_rate_limit.has_limit() + and not self._rate_limiter.telemetry_message_rate_limit.has_limit() + and not self._rate_limiter.telemetry_datapoints_rate_limit.has_limit()): + self._max_queued_messages = 50000 + logger.debug("No rate limits, setting max_queued_messages to 50000") + else: + self._max_queued_messages = self._max_inflight_messages + logger.debug("With rate limits, setting max_queued_messages to %r", self._max_queued_messages) + + logger.info("Service configuration retrieved and applied.") + logger.info("Parsed device limits: %r", response) + + self._mqtt_manager.set_rate_limits_received() + return True + + except Exception as e: + logger.exception("Failed to parse rate limits from server response: %s", e) + return False + + async def __on_publish_result(self, publish_result: PublishResult): + """ + Callback for handling publish results. + This can be used to handle the result of a publish operation, such as logging or updating state. + """ + if publish_result.is_successful(): + logger.trace("Publish successful: %r", publish_result) + else: + logger.error("Publish failed: %r", publish_result) + + @staticmethod + async def provision(provision_request: 'ProvisioningRequest', timeout=BaseClient.DEFAULT_TIMEOUT) -> Optional[ProvisioningResponse]: + provision_client = ProvisioningClient( + host=provision_request.host, + port=provision_request.port, + provision_request=provision_request + ) + + provisioning_response = None + try: + provisioning_response = await wait_for(provision_client.provision(), timeout=timeout) + except TimeoutError: + logger.error("Provisioning timed out") + + return provisioning_response diff --git a/tb_mqtt_client/service/device/firmware_updater.py b/tb_mqtt_client/service/device/firmware_updater.py new file mode 100644 index 0000000..6ac123e --- /dev/null +++ b/tb_mqtt_client/service/device/firmware_updater.py @@ -0,0 +1,272 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from asyncio import sleep +from hashlib import sha256, sha384, sha512, md5 +from os.path import sep +from random import randint +from subprocess import CalledProcessError +from typing import Awaitable, Callable, Optional +from zlib import crc32 + +from tb_mqtt_client.common.install_package_utils import install_package +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.constants import mqtt_topics +from tb_mqtt_client.constants.firmware import ( + FW_CHECKSUM_ALG_ATTR, + FW_CHECKSUM_ATTR, + FW_SIZE_ATTR, + FW_STATE_ATTR, + FW_TITLE_ATTR, + FW_VERSION_ATTR, + REQUIRED_SHARED_KEYS, + FirmwareStates +) +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry + +try: + from mmh3 import hash, hash128 +except ImportError: + try: + from pymmh3 import hash, hash128 + except ImportError: + try: + install_package('mmh3') + except CalledProcessError: + install_package('pymmh3') + +try: + from mmh3 import hash, hash128 # noqa +except ImportError: + from pymmh3 import hash, hash128 + +logger = get_logger(__name__) + + +class FirmwareUpdater: + def __init__(self, client): + self._log = logger + self._client = client + self._client._mqtt_manager.register_handler(mqtt_topics.DEVICE_FIRMWARE_UPDATE_RESPONSE_TOPIC, + self._handle_firmware_update) + self._on_received_callback = None + self._save_firmware = True + self._save_path = './' + self._firmware_request_id = 0 + self._chunk_size = 0 + self._current_chunk = 0 + self._firmware_data = b'' + self._target_firmware_length = 0 + self._target_checksum = 0 + self._target_checksum_alg = None + self._target_version = None + self._target_title = None + self.current_firmware_info = { + 'current_' + FW_TITLE_ATTR: 'Initial', + 'current_' + FW_VERSION_ATTR: 'v0', + FW_STATE_ATTR: FirmwareStates.IDLE.value + } + + async def _handle_firmware_update(self, _, payload: bytes): + self._firmware_data = self._firmware_data + payload + self._current_chunk = self._current_chunk + 1 + + self._log.debug('Getting chunk with number: %s. Chunk size is : %r byte(s).' % ( + self._current_chunk, self._chunk_size)) + + if len(self._firmware_data) == self._target_firmware_length: + self._log.info('Firmware download completed. ' + 'Total firmware size: %s byte(s).' % self._target_firmware_length) + await self._verify_downloaded_firmware() + else: + await self._get_next_chunk() + + async def _get_next_chunk(self): + if not self._chunk_size or self._chunk_size > self._target_firmware_length: + payload = b'' + else: + payload = str(self._chunk_size).encode() + + topic = mqtt_topics.build_firmware_update_request_topic(self._firmware_request_id, self._current_chunk) + mqtt_message = MqttPublishMessage(topic, payload) + await self._client._message_queue.publish(mqtt_message) + + async def _verify_downloaded_firmware(self): + self._log.info('Verifying downloaded firmware...') + + self.current_firmware_info[FW_STATE_ATTR] = FirmwareStates.DOWNLOADED.value + await self._send_current_firmware_info() + + verified = self.verify_checksum(self._firmware_data, + self._target_checksum_alg, + self._target_checksum) + + if verified: + self._log.debug('Checksum verified.') + self.current_firmware_info[FW_STATE_ATTR] = FirmwareStates.VERIFIED.value + else: + self._log.error('Checksum verification failed.') + self.current_firmware_info[FW_STATE_ATTR] = FirmwareStates.FAILED.value + + await self._send_current_firmware_info() + + if self.current_firmware_info[FW_STATE_ATTR] == FirmwareStates.VERIFIED.value: + await self._apply_downloaded_firmware() + + async def _apply_downloaded_firmware(self): + self._log.info('Applying downloaded firmware...') + + self.current_firmware_info[FW_STATE_ATTR] = FirmwareStates.UPDATING.value + await self._send_current_firmware_info() + + try: + if self._save_firmware: + self._save() + except Exception as e: + self._log.error('Failed to save firmware: %s', e) + self.current_firmware_info[FW_STATE_ATTR] = FirmwareStates.FAILED.value + await self._send_current_firmware_info() + return + + self.current_firmware_info = { + "current_" + FW_TITLE_ATTR: self._target_title, + "current_" + FW_VERSION_ATTR: self._target_version, + FW_STATE_ATTR: FirmwareStates.UPDATED.value + } + + await self._send_current_firmware_info() + + if self._on_received_callback: + await self._on_received_callback(self._firmware_data, self.current_firmware_info) + await self._client._mqtt_manager.unsubscribe(mqtt_topics.DEVICE_FIRMWARE_UPDATE_RESPONSE_TOPIC) + + self._log.info('Firmware is updated.') + self._log.info('Current firmware version is: %s' % self._target_version) + + def _save(self): + firmware_path = self._save_path + sep + self._target_title + with open(firmware_path, "wb") as firmware_file: + firmware_file.write(self._firmware_data) + + async def update(self, on_received_callback: Optional[Callable[[str], Awaitable[None]]] = None, + save_firmware: bool = True, firmware_save_path: Optional[str] = None): + if not self._client._mqtt_manager.is_connected(): + self._log.error("Client is not connected. Cannot start firmware update.") + return + + self._log.info("Starting firmware update process...") + + self._on_received_callback = on_received_callback + self._save_firmware = save_firmware + if firmware_save_path: + self._save_path = firmware_save_path + self._log.info("Firmware will be saved to: %s", self._save_path) + + sub_future = await self._client._mqtt_manager.subscribe(mqtt_topics.DEVICE_FIRMWARE_UPDATE_RESPONSE_TOPIC, + qos=1) + while not sub_future.done(): + await sleep(0.01) + + await self._send_current_firmware_info() + + attribute_request = await AttributeRequest.build(REQUIRED_SHARED_KEYS) + await self._client.send_attribute_request(attribute_request, callback=self._firmware_info_callback) + + async def _firmware_info_callback(self, response, *args, **kwargs): + if len(response.shared_keys()) == len(REQUIRED_SHARED_KEYS): + fetched_firmware_info = response.as_dict()['shared'] + fetched_firmware_info = {item['key']: item['value'] + for item in fetched_firmware_info} + + if self._is_different_firmware_versions(fetched_firmware_info): + self._log.info("Firmware update available: %s. Downloading...", + fetched_firmware_info) + + self._firmware_data = b'' + self._current_chunk = 0 + self.current_firmware_info[FW_STATE_ATTR] = FirmwareStates.DOWNLOADING.value + + self._firmware_request_id += 1 + self._target_firmware_length = fetched_firmware_info[FW_SIZE_ATTR] + self._target_checksum = fetched_firmware_info[FW_CHECKSUM_ATTR] + self._target_checksum_alg = fetched_firmware_info[FW_CHECKSUM_ALG_ATTR] + self._target_title = fetched_firmware_info[FW_TITLE_ATTR] + self._target_version = fetched_firmware_info[FW_VERSION_ATTR] + + await self._get_next_chunk() + else: + self._log.info("Firmware is up to date.") + else: + self._log.error("Failed to fetch firmware info. " + "Received firmware info does not match required keys. " + "Expected: %s, Received: %s", + REQUIRED_SHARED_KEYS, + response.shared_keys()) + + self.current_firmware_info[FW_STATE_ATTR] = FirmwareStates.FAILED.value + await self._send_current_firmware_info() + + def _is_different_firmware_versions(self, new_firmware_info): + return (self.current_firmware_info['current_' + FW_TITLE_ATTR] != new_firmware_info[FW_TITLE_ATTR] or # noqa + self.current_firmware_info['current_' + FW_VERSION_ATTR] != new_firmware_info[FW_VERSION_ATTR]) # noqa + + async def _send_current_firmware_info(self): + current_info = [TimeseriesEntry(key, value) for key, value in self.current_firmware_info.items()] + await self._client.send_timeseries(current_info, wait_for_publish=True) + + def verify_checksum(self, firmware_data, checksum_alg, checksum): + if firmware_data is None: + self._log.debug('Firmware wasn\'t received!') + return False + + if checksum is None: + self._log.debug('Checksum was\'t provided!') + return False + + checksum_of_received_firmware = None + + self._log.debug('Checksum algorithm is: %s' % checksum_alg) + lower_checksum_alg = checksum_alg.lower() + if lower_checksum_alg == "sha256": + checksum_of_received_firmware = sha256(firmware_data).digest().hex() + elif lower_checksum_alg == "sha384": + checksum_of_received_firmware = sha384(firmware_data).digest().hex() + elif lower_checksum_alg == "sha512": + checksum_of_received_firmware = sha512(firmware_data).digest().hex() + elif lower_checksum_alg == "md5": + checksum_of_received_firmware = md5(firmware_data).digest().hex() + elif lower_checksum_alg == "murmur3_32": + reversed_checksum = f'{hash(firmware_data, signed=False):0>2X}' + if len(reversed_checksum) % 2 != 0: + reversed_checksum = '0' + reversed_checksum + checksum_of_received_firmware = "".join( + reversed([reversed_checksum[i:i + 2] for i in range(0, len(reversed_checksum), 2)])).lower() + elif lower_checksum_alg == "murmur3_128": + reversed_checksum = f'{hash128(firmware_data, signed=False):0>2X}' + if len(reversed_checksum) % 2 != 0: + reversed_checksum = '0' + reversed_checksum + checksum_of_received_firmware = "".join( + reversed([reversed_checksum[i:i + 2] for i in range(0, len(reversed_checksum), 2)])).lower() + elif lower_checksum_alg == "crc32": + reversed_checksum = f'{crc32(firmware_data) & 0xffffffff:0>2X}' + if len(reversed_checksum) % 2 != 0: + reversed_checksum = '0' + reversed_checksum + checksum_of_received_firmware = "".join( + reversed([reversed_checksum[i:i + 2] for i in range(0, len(reversed_checksum), 2)])).lower() + else: + self._log.error('Client error. Unsupported checksum algorithm.') + + return checksum_of_received_firmware == checksum diff --git a/tb_mqtt_client/service/device/handlers/__init__.py b/tb_mqtt_client/service/device/handlers/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tb_mqtt_client/service/device/handlers/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tb_mqtt_client/service/device/handlers/attribute_updates_handler.py b/tb_mqtt_client/service/device/handlers/attribute_updates_handler.py new file mode 100644 index 0000000..e740d5e --- /dev/null +++ b/tb_mqtt_client/service/device/handlers/attribute_updates_handler.py @@ -0,0 +1,63 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Awaitable, Callable, Optional + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.service.device.message_adapter import MessageAdapter + +logger = get_logger(__name__) + + +class AttributeUpdatesHandler: + """ + Handles shared attribute update messages from the platform. + """ + + def __init__(self): + self._message_adapter = None + self._callback: Optional[Callable[[AttributeUpdate], Awaitable[None]]] = None + + def set_message_adapter(self, message_adapter: MessageAdapter): + """ + Sets the message adapter for handling incoming messages. + This should be called before any callbacks are set. + + :param message_adapter: An instance of MessageAdapter. + """ + if not isinstance(message_adapter, MessageAdapter): + raise ValueError("message_adapter must be an instance of MessageAdapter.") + self._message_adapter = message_adapter + logger.debug("Message adapter set for AttributeUpdatesHandler.") + + def set_callback(self, callback: Callable[[AttributeUpdate], Awaitable[None]]): + """ + Sets the async callback that will be triggered on shared attribute update. + + :param callback: A coroutine that takes an AttributeUpdate object. + """ + self._callback = callback + + async def handle(self, topic: str, payload: bytes): # noqa + if not self._callback: + logger.debug("No attribute update callback set. Skipping payload.") + return + + try: + data = self._message_adapter.parse_attribute_update(payload) + logger.debug("Handling attribute update: %r", data) + await self._callback(data) + except Exception as e: + logger.exception("Failed to handle attribute update: %s", e) diff --git a/tb_mqtt_client/service/device/handlers/requested_attributes_response_handler.py b/tb_mqtt_client/service/device/handlers/requested_attributes_response_handler.py new file mode 100644 index 0000000..bd2f0c7 --- /dev/null +++ b/tb_mqtt_client/service/device/handlers/requested_attributes_response_handler.py @@ -0,0 +1,102 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Tuple, Awaitable, Callable + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.service.device.message_adapter import MessageAdapter + +logger = get_logger(__name__) + + +class RequestedAttributeResponseHandler: + """ + Handles responses to attribute requests sent to the platform. + """ + + def __init__(self): + self._message_adapter = None + self._pending_attribute_requests: Dict[ + int, Tuple[AttributeRequest, Callable[[RequestedAttributeResponse], Awaitable[None]]]] = {} + + def set_message_adapter(self, message_adapter: MessageAdapter): + """ + Sets the message adapter for handling incoming messages. + This should be called before any requests are registered. + """ + if not isinstance(message_adapter, MessageAdapter): + raise ValueError("message_adapter must be an instance of MessageAdapter.") + self._message_adapter = message_adapter + logger.debug("Message adapter set for RequestedAttributeResponseHandler.") + + async def register_request(self, request: AttributeRequest, + callback: Callable[[RequestedAttributeResponse], Awaitable[None]]): + """ + Called when a request is sent to the platform and a response is awaited. + """ + request_id = request.request_id + if request_id in self._pending_attribute_requests: + raise RuntimeError(f"Request ID {request.request_id} is already registered.") + self._pending_attribute_requests[request.request_id] = (request, callback) + + def unregister_request(self, request_id: int): + """ + Unregister a request by its ID. + This is useful if the request is no longer needed or has timed out. + """ + if request_id in self._pending_attribute_requests: + self._pending_attribute_requests.pop(request_id) + logger.debug("Unregistered attribute request with ID %s", request_id) + else: + logger.debug("Attempted to unregister non-existent request ID %s", request_id) + + async def handle(self, topic: str, payload: bytes): + """ + Handles the incoming attribute request response. + """ + try: + if not self._message_adapter: + logger.error("Message adapter is not initialized. Cannot handle attribute response.") + request_id = topic.split('/')[-1] # Assuming request ID is in the topic + self._pending_attribute_requests.pop(int(request_id), None) + return + + requested_attribute_response = self._message_adapter.parse_requested_attribute_response(topic, payload) + pending_request_details = self._pending_attribute_requests.pop(requested_attribute_response.request_id, + None) + if not pending_request_details: + logger.warning("No future awaiting request ID %s. Ignoring.", + requested_attribute_response.request_id) + return + + request, callback = pending_request_details + + if callback: + logger.trace("Invoking callback for requested attribute response with ID %s", + requested_attribute_response.request_id) + await callback(requested_attribute_response) + else: + logger.error("No callback registered for requested attribute response with ID %s", + requested_attribute_response.request_id) + + except Exception as e: + logger.exception("Failed to handle requested attribute response: %s", e) + + def clear(self): + """ + Clears all pending requests (e.g., on disconnect). + """ + self._pending_attribute_requests.clear() diff --git a/tb_mqtt_client/service/device/handlers/rpc_requests_handler.py b/tb_mqtt_client/service/device/handlers/rpc_requests_handler.py new file mode 100644 index 0000000..e20cc9d --- /dev/null +++ b/tb_mqtt_client/service/device/handlers/rpc_requests_handler.py @@ -0,0 +1,80 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Awaitable, Callable, Optional + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.service.device.message_adapter import MessageAdapter + +logger = get_logger(__name__) + + +class RPCRequestsHandler: + """ + Handles incoming RPC request messages for a device. + """ + + def __init__(self): + self._message_adapter = None + self._callback: Optional[Callable[[RPCRequest], Awaitable[RPCResponse]]] = None + + def set_message_adapter(self, message_adapter: MessageAdapter): + """ + Sets the message adapter for handling incoming messages. + This should be called before any callbacks are set. + :param message_adapter: An instance of MessageAdapter. + """ + if not isinstance(message_adapter, MessageAdapter): + raise ValueError("message_adapter must be an instance of MessageAdapter.") + self._message_adapter = message_adapter + logger.debug("Message adapter set for RPCRequestsHandler.") + + def set_callback(self, callback: Callable[[RPCRequest], Awaitable[RPCResponse]]): + """ + Set the async callback to handle incoming RPC requests. + :param callback: A coroutine that takes an RPCRequest and returns an RPCResponse. + """ + self._callback = callback + + async def handle(self, topic: str, payload: bytes) -> Optional[RPCResponse]: + """ + Process the RPC request and return the response payload and request ID (if possible). + :returns: (request_id, response_dict) or None if failed + """ + if not self._callback: + logger.debug("No RPC request callback set. Skipping RPC handling. " + "You can add set callback using client.set_rpc_request_callback(your_method)") + return None + + if not self._message_adapter: + logger.error("Message adapter is not initialized. Cannot handle RPC request.") + return None + + try: + rpc_request = self._message_adapter.parse_rpc_request(topic, payload) + logger.debug("Handling RPC method id: %i - %s with params: %s", + rpc_request.request_id, rpc_request.method, rpc_request.params) + result = await self._callback(rpc_request) + if not isinstance(result, RPCResponse): + logger.error("RPC callback must return an instance of RPCResponse, got: %s", type(result)) + return None + logger.debug("RPC response for method id: %i - %s with result: %s", + result.request_id, rpc_request.method, result.result) + return result + + except Exception as e: + logger.exception("Failed to process RPC request: %s", e) + return None diff --git a/tb_mqtt_client/service/device/handlers/rpc_response_handler.py b/tb_mqtt_client/service/device/handlers/rpc_response_handler.py new file mode 100644 index 0000000..64955a3 --- /dev/null +++ b/tb_mqtt_client/service/device/handlers/rpc_response_handler.py @@ -0,0 +1,106 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from typing import Dict, Union, Awaitable, Callable, Optional, Tuple +from uuid import uuid4 + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.service.device.message_adapter import MessageAdapter, JsonMessageAdapter + +logger = get_logger(__name__) + + +class RPCResponseHandler: + """ + Handles RPC responses coming from the platform to the client (client-side RPCs responses). + Maintains an internal map of request_id -> asyncio.Future for awaiting RPC results. + """ + + def __init__(self): + self._message_adapter: Optional[MessageAdapter] = None + self._pending_rpc_requests: Dict[Union[str, int], + Tuple[asyncio.Future[RPCResponse], + Optional[Callable[[RPCResponse], Awaitable[None]]]]] = {} + + def set_message_adapter(self, message_adapter: MessageAdapter): + """ + Sets the message adapter for handling incoming messages. + This should be called before any requests are registered. + :param message_adapter: An instance of MessageAdapter. + """ + if not isinstance(message_adapter, MessageAdapter): + raise ValueError("message_adapter must be an instance of MessageAdapter.") + self._message_adapter = message_adapter + logger.debug("Message adapter set for RPCResponseHandler.") + + def register_request(self, request_id: Union[str, int], + callback: Optional[Callable[[RPCResponse], + Awaitable[None]]] = None) -> asyncio.Future[RPCResponse]: + """ + Called when a request is sent to the platform and a response is awaited. + """ + if request_id in self._pending_rpc_requests: + raise RuntimeError(f"Request ID {request_id} is already registered.") + future = asyncio.get_event_loop().create_future() + future.uuid = uuid4() + self._pending_rpc_requests[request_id] = future, callback + return future + + async def handle(self, topic: str, payload: Union[bytes, TimeoutError]): + """ + Handles the incoming RPC response from the platform and fulfills the corresponding future. + The topic is expected to be: v1/devices/me/rpc/response/{request_id} + """ + try: + if not self._message_adapter: + dummy_adapter = JsonMessageAdapter() + rpc_response = dummy_adapter.parse_rpc_response(topic, payload) + else: + rpc_response = self._message_adapter.parse_rpc_response(topic, payload) + + request_details = self._pending_rpc_requests.pop(rpc_response.request_id, None) + if not request_details: + logger.warning("No pending request found for request ID %s. Ignoring response.", + rpc_response.request_id) + return + future, callback = request_details + if not future: + logger.warning("No future awaiting request ID %s. Ignoring.", rpc_response.request_id) + return + if callback: + try: + await callback(rpc_response) + except Exception as e: + logger.exception("Error in callback for request ID %s: %s", rpc_response.request_id, e) + future.set_exception(e) + return + + if rpc_response.error: + future.set_exception(Exception(rpc_response.error)) + else: + future.set_result(rpc_response) + + except Exception as e: + logger.exception("Failed to handle RPC response: %s", e) + + def clear(self): + """ + Clears all pending futures (e.g., on disconnect). + """ + for fut, _ in self._pending_rpc_requests.values(): + if not fut.done(): + fut.cancel() + self._pending_rpc_requests.clear() diff --git a/tb_mqtt_client/service/device/message_adapter.py b/tb_mqtt_client/service/device/message_adapter.py new file mode 100644 index 0000000..76472f2 --- /dev/null +++ b/tb_mqtt_client/service/device/message_adapter.py @@ -0,0 +1,475 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from collections import defaultdict +from datetime import UTC, datetime +from typing import Any, Dict, List, Optional, Union + +from orjson import dumps, loads + +from tb_mqtt_client.common.async_utils import future_map +from tb_mqtt_client.common.logging_utils import get_logger, TRACE_LEVEL +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.constants import mqtt_topics +from tb_mqtt_client.constants.mqtt_topics import DEVICE_TELEMETRY_TOPIC, DEVICE_ATTRIBUTES_TOPIC +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest, ProvisioningCredentialsType +from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.service.device.message_splitter import MessageSplitter + +logger = get_logger(__name__) + + +class MessageAdapter(ABC): + def __init__(self, max_payload_size: Optional[int] = None, max_datapoints: Optional[int] = None): + self._splitter = MessageSplitter(max_payload_size, max_datapoints) + logger.trace("MessageAdapter initialized with max_payload_size=%s, max_datapoints=%s", + max_payload_size, max_datapoints) + + @abstractmethod + def build_uplink_messages( + self, + messages: List[MqttPublishMessage]) -> List[MqttPublishMessage]: + """ + Build a list of topic-payload pairs from the given messages. + Each pair consists of a topic string, payload bytes, the number of datapoints, + and a list of futures for delivery confirmation. + """ + pass + + @abstractmethod + def build_attribute_request(self, request: AttributeRequest) -> MqttPublishMessage: + """ + Build the payload for an attribute request response. + This method should return a tuple of topic and payload bytes. + """ + pass + + @abstractmethod + def build_claim_request(self, claim_request) -> MqttPublishMessage: + """ + Build the payload for a claim request. + This method should return a tuple of topic and payload bytes. + """ + + @abstractmethod + def build_rpc_request(self, rpc_request: RPCRequest) -> MqttPublishMessage: + """ + Build the payload for an RPC request. + This method should return a tuple of topic and payload bytes. + """ + pass + + @abstractmethod + def build_rpc_response(self, rpc_response: RPCResponse) -> MqttPublishMessage: + """ + Build the payload for an RPC response. + This method should return a tuple of topic and payload bytes. + """ + pass + + @abstractmethod + def build_provision_request(self, provision_request) -> MqttPublishMessage: + """ + Build the payload for a device provisioning request. + This method should return a tuple of topic and payload bytes. + """ + pass + + @property + def splitter(self) -> MessageSplitter: + """ + Get the message splitter instance used by this adapter. + """ + return self._splitter + + @abstractmethod + def parse_requested_attribute_response(self, topic: str, payload: bytes) -> RequestedAttributeResponse: + """ + Parse the attribute request response payload into an AttributeRequestResponse. + This method should be implemented to handle the specific format of the topic and payload. + """ + pass + + @abstractmethod + def parse_attribute_update(self, payload: bytes) -> AttributeUpdate: + """ + Parse the attribute update payload into an AttributeUpdate. + This method should be implemented to handle the specific format of the payload. + """ + pass + + @abstractmethod + def parse_rpc_request(self, topic: str, payload: bytes) -> RPCRequest: + """ + Parse the RPC request from the given topic and payload. + This method should be implemented to handle the specific format of the RPC request. + """ + pass + + @abstractmethod + def parse_rpc_response(self, topic: str, payload: Union[bytes, Exception]) -> RPCResponse: + """ + Parse the RPC response from the given topic and payload. + This method should be implemented to handle the specific format of the RPC response. + """ + pass + + @abstractmethod + def parse_provisioning_response(self, + provisioning_request: ProvisioningRequest, + payload: bytes) -> ProvisioningResponse: + """ + Parse the provisioning response from the given payload. + This method should be implemented to handle the specific format of the provisioning response. + """ + pass + + +class JsonMessageAdapter(MessageAdapter): + """ + A concrete implementation of MessageDispatcher that operates with JSON payloads. + """ + + def __init__(self, max_payload_size: Optional[int] = None, max_datapoints: Optional[int] = None): + super().__init__(max_payload_size, max_datapoints) + logger.trace("JsonMessageDispatcher created.") + + def parse_requested_attribute_response(self, topic: str, payload: bytes) -> RequestedAttributeResponse: + """ + Parse the attribute request response payload into a RequestedAttributeResponse. + :param topic: The MQTT topic of the requested attribute response. + :param payload: The raw bytes of the payload. + :return: An instance of RequestedAttributeResponse. + """ + try: + request_id = int(topic.split("/")[-1]) + data = loads(payload) + logger.trace("Parsing attribute request response from payload: %s", data) + if not isinstance(data, dict): + logger.error("Invalid requested attribute response format: expected dict, got %s", + type(data).__name__) + raise ValueError("Invalid requested attribute response format") + data["request_id"] = request_id # Add request_id to the data dictionary + return RequestedAttributeResponse.from_dict(data) + except Exception as e: + logger.error("Failed to parse attribute request response: %s", str(e)) + raise ValueError("Invalid attribute request response format") from e + + def parse_attribute_update(self, payload: bytes) -> AttributeUpdate: + """ + Parse the attribute update payload into an AttributeUpdate. + :param payload: The raw bytes of the payload. + :return: An instance of AttributeUpdate. + """ + try: + data = loads(payload) + logger.trace("Parsing attribute update from payload: %s", data) + return AttributeUpdate._deserialize_from_dict(data) # noqa + except Exception as e: + logger.error("Failed to parse attribute update: %s", str(e)) + raise ValueError("Invalid attribute update format") from e + + def parse_rpc_request(self, topic: str, payload: bytes) -> RPCRequest: + """ + Parse the RPC request from the given topic and payload. + :param topic: The MQTT topic of the RPC request. + :param payload: The raw bytes of the payload. + :return: An instance of RPCRequest. + """ + try: + request_id = int(topic.split("/")[-1]) + parsed = loads(payload) + data = RPCRequest._deserialize_from_dict(request_id, parsed) # noqa + return data + except Exception as e: + logger.error("Failed to parse RPC request: %s", str(e)) + raise ValueError("Invalid RPC request format") from e + + def parse_rpc_response(self, topic: str, payload: Union[bytes, Exception]) -> RPCResponse: + """ + Parse the RPC response from the given topic and payload. + :param topic: The MQTT topic of the RPC response. + :param payload: The raw bytes of the payload. + :return: An instance of RPCResponse. + """ + try: + request_id = int(topic.split("/")[-1]) + if isinstance(payload, Exception): + data = RPCResponse.build(request_id, error=payload) + else: + parsed = loads(payload) + data = RPCResponse.build(request_id, parsed) + return data + except Exception as e: + logger.error("Failed to parse RPC response: %s", str(e)) + raise ValueError("Invalid RPC response format") from e + + def parse_provisioning_response(self, provisioning_request: ProvisioningRequest, + payload: bytes) -> 'ProvisioningResponse': + """ + Parse the provisioning response from the given payload. + :param provisioning_request: The ProvisioningRequest that initiated the provisioning. + :param payload: The raw bytes of the payload. + :return: An instance of ProvisioningResponse. + """ + try: + data = loads(payload) + logger.trace("Parsing provisioning response from payload: %s", data) + return ProvisioningResponse.build(provisioning_request, data) + except Exception as e: + logger.error("Failed to parse provisioning response: %s", str(e)) + return ProvisioningResponse.build(provisioning_request, {"status": "FAILURE", "errorMsg": str(e)}) + + def build_uplink_messages(self, messages: List[MqttPublishMessage]) -> List[MqttPublishMessage]: + if not messages: + logger.trace("No messages to process in build_uplink_messages.") + return [] + + result: List[MqttPublishMessage] = [] + device_groups = defaultdict(list) + qos = messages[0].qos + + for mqtt_msg in messages: + payload = mqtt_msg.payload + if isinstance(payload, DeviceUplinkMessage): + device_groups[payload.device_name].append(mqtt_msg) + logger.trace("Queued DeviceUplinkMessage for device='%s'", payload.device_name) + else: + result.append(mqtt_msg) + logger.debug("Non-DeviceUplinkMessage found, sending as is: %s", type(payload).__name__) + + for device_name, group_msgs in device_groups.items(): + telemetry_msgs = [m for m in group_msgs if m.payload.has_timeseries()] + attr_msgs = [m for m in group_msgs if m.payload.has_attributes()] + + built_child_messages: List[MqttPublishMessage] = [] + + if telemetry_msgs: + ts_messages = [m.payload for m in telemetry_msgs] + for ts_batch in self._splitter.split_timeseries(ts_messages): + payload_bytes = JsonMessageAdapter.build_payload(ts_batch, True) + count = ts_batch.timeseries_datapoint_count() + child_futures = ts_batch.get_delivery_futures() or [] + + mqtt_msg = MqttPublishMessage( + topic=DEVICE_TELEMETRY_TOPIC, + payload=payload_bytes, + qos=qos, + datapoints=count, + delivery_futures=child_futures, + main_ts=ts_batch.main_ts, + original_payload=ts_batch + ) + result.append(mqtt_msg) + built_child_messages.append(mqtt_msg) + if logger.isEnabledFor(TRACE_LEVEL): + logger.trace( + "Built telemetry payload for '%s' with %d datapoints, futures=%r", + device_name, count, [f.uuid for f in child_futures] # noqa + ) + + if attr_msgs: + attr_messages = [m.payload for m in attr_msgs] + for attr_batch in self._splitter.split_attributes(attr_messages): + payload_bytes = JsonMessageAdapter.build_payload(attr_batch, False) + count = len(attr_batch.attributes) + child_futures = attr_batch.get_delivery_futures() or [] + + mqtt_msg = MqttPublishMessage( + topic=DEVICE_ATTRIBUTES_TOPIC, + payload=payload_bytes, + qos=qos, + datapoints=count, + delivery_futures=child_futures, + main_ts=attr_batch.main_ts, + original_payload=attr_batch + ) + result.append(mqtt_msg) + built_child_messages.append(mqtt_msg) + if logger.isEnabledFor(TRACE_LEVEL): + logger.trace( + "Built attribute payload for '%s' with %d attributes, futures=%r", + device_name, count, [f.uuid for f in child_futures] # noqa + ) + + # Register child futures to all original parent futures + parent_futures = [f for m in group_msgs for f in (m.delivery_futures or [])] + for parent in parent_futures: + for child_msg in built_child_messages: + for child in child_msg.delivery_futures or []: + future_map.register(parent, [child]) + + logger.trace("Generated %d topic-payload entries.", len(result)) + return result + + def build_attribute_request(self, request: AttributeRequest) -> MqttPublishMessage: + """ + Build the payload for an attribute request response. + :param request: The AttributeRequest to build the payload for. + :return: A tuple of topic and payload bytes. + """ + if not request.request_id: + raise ValueError("AttributeRequest must have a valid ID.") + + topic = mqtt_topics.build_device_attributes_request_topic(request.request_id) + payload = dumps(request.to_payload_format()) + logger.trace("Built attribute request payload for request: %r", request) + return MqttPublishMessage(topic=topic, payload=payload, qos=1, datapoints=1) + + def build_claim_request(self, claim_request) -> MqttPublishMessage: + """ + Build the payload for a claim request. + :param claim_request: The ClaimRequest to build the payload for. + :return: A tuple of topic and payload bytes. + """ + if not claim_request.secret_key: + raise ValueError("ClaimRequest must have a valid secret key.") + + topic = mqtt_topics.DEVICE_CLAIM_TOPIC + payload = dumps(claim_request.to_payload_format()) + logger.trace("Built claim request payload: %r", claim_request) + return MqttPublishMessage(topic=topic, payload=payload, qos=1, datapoints=1) + + def build_rpc_request(self, rpc_request: RPCRequest) -> MqttPublishMessage: + """ + Build the payload for an RPC request. + :param rpc_request: The RPC request to build the payload for. + :return: A tuple of topic and payload bytes. + """ + if not rpc_request.request_id: + raise ValueError("RPCRequest must have a valid ID.") + + payload = dumps(rpc_request.to_payload_format()) + topic = mqtt_topics.DEVICE_RPC_REQUEST_TOPIC + str(rpc_request.request_id) + logger.trace("Built RPC request payload for request ID=%d with payload: %r", + rpc_request.request_id, payload) + message_to_send = MqttPublishMessage(topic=topic, payload=payload, qos=1, datapoints=1) + return message_to_send + + def build_rpc_response(self, rpc_response: RPCResponse) -> MqttPublishMessage: + """ + Build the payload for an RPC response. + :param rpc_response: The RPC response to build the payload for. + :return: A tuple of topic and payload bytes. + """ + if rpc_response.request_id is None or rpc_response.request_id < 0: + raise ValueError("RPCResponse must have a valid request ID.") + + payload = dumps(rpc_response.to_payload_format()) + topic = mqtt_topics.DEVICE_RPC_RESPONSE_TOPIC + str(rpc_response.request_id) + logger.trace("Built RPC response payload for request ID=%d with payload: %r", rpc_response.request_id, payload) + return MqttPublishMessage(topic=topic, payload=payload, qos=1, datapoints=1) + + def build_provision_request(self, provision_request: 'ProvisioningRequest') -> MqttPublishMessage: + """ + Build the payload for a device provisioning request. + :param provision_request: The ProvisioningRequest to build the payload for. + :return: A tuple of topic and payload bytes. + """ + if (not provision_request.credentials.provision_device_key + or not provision_request.credentials.provision_device_secret): + raise ValueError("ProvisioningRequest must have valid device key and secret.") + + topic = mqtt_topics.PROVISION_REQUEST_TOPIC + request = {} + request["provisionDeviceKey"] = provision_request.credentials.provision_device_key + request["provisionDeviceSecret"] = provision_request.credentials.provision_device_secret + + if provision_request.device_name: + request["deviceName"] = provision_request.device_name + + if provision_request.gateway: + request["gateway"] = provision_request.gateway + + if provision_request.credentials.credentials_type and \ + provision_request.credentials.credentials_type == ProvisioningCredentialsType.ACCESS_TOKEN: + if provision_request.credentials.access_token is not None: + request["token"] = provision_request.credentials.access_token + request["credentialsType"] = provision_request.credentials.credentials_type.value + + if provision_request.credentials.credentials_type == ProvisioningCredentialsType.MQTT_BASIC: + if provision_request.credentials.username is not None: + request["username"] = provision_request.credentials.username + + if provision_request.credentials.password is not None: + request["password"] = provision_request.credentials.password + + if provision_request.credentials.client_id is not None: + request["clientId"] = provision_request.credentials.client_id + + request["credentialsType"] = provision_request.credentials.credentials_type.value + + if provision_request.credentials.credentials_type == ProvisioningCredentialsType.X509_CERTIFICATE: + request["hash"] = provision_request.credentials.public_cert + request["credentialsType"] = provision_request.credentials.credentials_type.value + + payload = dumps(request) + result_msg = MqttPublishMessage( + topic=topic, + payload=payload, + qos=1, + datapoints=1 + ) + logger.trace("Built provision request payload: %r", provision_request) + return result_msg + + @staticmethod + def build_payload(msg: DeviceUplinkMessage, build_timeseries_payload) -> bytes: + result: Union[Dict[str, Any], List[Dict[str, Any]]] = {} + if build_timeseries_payload: + logger.trace("Packing timeseries") + result = JsonMessageAdapter.pack_timeseries(msg) + else: + logger.trace("Packing attributes") + result = JsonMessageAdapter.pack_attributes(msg) + + payload = dumps(result) + logger.trace("Built payload size: %d bytes", len(payload)) + return payload + + @staticmethod + def pack_attributes(msg: DeviceUplinkMessage) -> Dict[str, Any]: + logger.trace("Packing %d attribute(s)", len(msg.attributes)) + return {attr.key: attr.value for attr in msg.attributes} + + @staticmethod + def pack_timeseries(msg: 'DeviceUplinkMessage') -> Union[Dict[str, Any], List[Dict[str, Any]]]: + entries = [e for entries in msg.timeseries.values() for e in entries] + if not entries: + return {} + + all_ts_none = True + for e in entries: + if e.ts is not None: + all_ts_none = False + break + + if all_ts_none: + result = {e.key: e.value for e in entries} + return [{"ts": msg.main_ts, "values": result}] if msg.main_ts is not None else result + + now_ts = msg.main_ts if msg.main_ts is not None else int(datetime.now(UTC).timestamp() * 1000) + grouped = defaultdict(dict) + for e in entries: + ts = e.ts if e.ts is not None else now_ts + grouped[ts][e.key] = e.value + + return [{"ts": ts, "values": values} for ts, values in grouped.items()] diff --git a/tb_mqtt_client/service/device/message_splitter.py b/tb_mqtt_client/service/device/message_splitter.py new file mode 100644 index 0000000..0b088af --- /dev/null +++ b/tb_mqtt_client/service/device/message_splitter.py @@ -0,0 +1,210 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from collections import defaultdict +from typing import List, Optional, Dict, Tuple +from uuid import uuid4 + +from tb_mqtt_client.common.async_utils import future_map +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage, DeviceUplinkMessageBuilder +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.service.base_message_splitter import BaseMessageSplitter + +logger = get_logger(__name__) + + +class MessageSplitter(BaseMessageSplitter): + DEFAULT_MAX_PAYLOAD_SIZE = 55_000 # Default to 55_000 to allow for some overhead + + def __init__(self, max_payload_size: int = DEFAULT_MAX_PAYLOAD_SIZE, max_datapoints: int = 0): + if max_payload_size is not None and max_payload_size > 0: + self._max_payload_size = max_payload_size + else: + self._max_payload_size = self.DEFAULT_MAX_PAYLOAD_SIZE + self._max_datapoints = max_datapoints if max_datapoints is not None and max_datapoints > 0 else 0 + logger.trace("MessageSplitter initialized with max_payload_size=%d, max_datapoints=%d", + self._max_payload_size, self._max_datapoints) + + def split_timeseries(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUplinkMessage]: + logger.trace("Splitting timeseries for %d messages", len(messages)) + + if (len(messages) == 1 + and ((messages[0].attributes_datapoint_count() + messages[0].timeseries_datapoint_count() + <= self._max_datapoints) or self._max_datapoints == 0) + and messages[0].size <= self._max_payload_size): + return messages + + result: List[DeviceUplinkMessage] = [] + + grouped: Dict[Tuple[str, Optional[str]], List[DeviceUplinkMessage]] = defaultdict(list) + for msg in messages: + grouped[(msg.device_name, msg.device_profile)].append(msg) + + for (device_name, device_profile), group_msgs in grouped.items(): + logger.trace("Processing group: device='%s', profile='%s', messages=%d", + device_name, device_profile, len(group_msgs)) + + all_ts_entries: List[TimeseriesEntry] = [] + parent_futures: List[asyncio.Future] = [] + + for msg in group_msgs: + if msg.has_timeseries(): + for ts_group in msg.timeseries.values(): + all_ts_entries.extend(ts_group) + parent_futures.extend(msg.get_delivery_futures() or []) + + builder: Optional[DeviceUplinkMessageBuilder] = None + size = 0 + point_count = 0 + + for ts_kv in all_ts_entries: + exceeds_size = builder and size + ts_kv.size > self._max_payload_size + exceeds_points = 0 < self._max_datapoints <= point_count + + if not builder or exceeds_size or exceeds_points: + if builder: + shared_future = asyncio.get_running_loop().create_future() + shared_future.uuid = uuid4() + builder.add_delivery_futures(shared_future) + + built = builder.build() + result.append(built) + for parent in parent_futures: + future_map.register(parent, [shared_future]) + + logger.trace("Flushed batch with %d datapoints (size=%d)", + built.timeseries_datapoint_count(), size) + builder = DeviceUplinkMessageBuilder() \ + .set_device_name(device_name) \ + .set_device_profile(device_profile) \ + .set_main_ts(group_msgs[0].main_ts if group_msgs else None) + size = 0 + point_count = 0 + + builder.add_timeseries(ts_kv) + size += ts_kv.size + point_count += 1 + + if builder and builder._timeseries: # noqa + shared_future = asyncio.get_running_loop().create_future() + shared_future.uuid = uuid4() + builder.add_delivery_futures(shared_future) + + built = builder.build() + result.append(built) + for parent in parent_futures: + future_map.register(parent, [shared_future]) + + logger.trace("Flushed final batch with %d datapoints (size=%d)", + built.timeseries_datapoint_count(), size) + + logger.trace("Total timeseries batches created: %d", len(result)) + return result + + def split_attributes(self, messages: List[DeviceUplinkMessage]) -> List[DeviceUplinkMessage]: + logger.trace("Splitting attributes for %d messages", len(messages)) + result: List[DeviceUplinkMessage] = [] + + if (len(messages) == 1 + and ((messages[0].attributes_datapoint_count() + messages[0].timeseries_datapoint_count() + <= self._max_datapoints) or self._max_datapoints == 0) + and messages[0].size <= self._max_payload_size): + return messages + + grouped: Dict[Tuple[str, Optional[str]], List[DeviceUplinkMessage]] = defaultdict(list) + for msg in messages: + grouped[(msg.device_name, msg.device_profile)].append(msg) + + for (device_name, device_profile), group_msgs in grouped.items(): + logger.trace("Processing attribute group: device='%s', profile='%s', messages=%d", + device_name, device_profile, len(group_msgs)) + + all_attrs: List[AttributeEntry] = [] + parent_futures: List[asyncio.Future] = [] + + for msg in group_msgs: + if msg.has_attributes(): + all_attrs.extend(msg.attributes) + parent_futures.extend(msg.get_delivery_futures()) + + builder: Optional[DeviceUplinkMessageBuilder] = None + size = 0 + point_count = 0 + + for attr in all_attrs: + exceeds_size = builder and size + attr.size > self._max_payload_size + exceeds_points = 0 < self._max_datapoints <= point_count + + if not builder or exceeds_size or exceeds_points: + if builder and builder._attributes: # noqa + shared_future = asyncio.get_running_loop().create_future() + shared_future.uuid = uuid4() + builder.add_delivery_futures(shared_future) + + built = builder.build() + result.append(built) + for parent in parent_futures: + future_map.register(parent, [shared_future]) + + logger.trace("Flushed attribute batch (count=%d, size=%d)", + len(built.attributes), size) + builder = DeviceUplinkMessageBuilder() \ + .set_device_name(device_name) \ + .set_device_profile(device_profile) \ + .set_main_ts(group_msgs[0].main_ts if group_msgs else None) + size = 0 + point_count = 0 + + builder.add_attributes(attr) + size += attr.size + point_count += 1 + + if builder and builder._attributes: # noqa + shared_future = asyncio.get_running_loop().create_future() + shared_future.uuid = uuid4() + builder.add_delivery_futures(shared_future) + + built = builder.build() + result.append(built) + for parent in parent_futures: + future_map.register(parent, [shared_future]) + + logger.trace("Flushed final attribute batch (count=%d, size=%d)", + len(built.attributes), size) + + logger.trace("Total attribute batches created: %d", len(result)) + return result + + @property + def max_payload_size(self) -> int: + return self._max_payload_size + + @max_payload_size.setter + def max_payload_size(self, value: int): + old = self._max_payload_size + self._max_payload_size = value if value is not None and value > 0 else self.DEFAULT_MAX_PAYLOAD_SIZE + logger.debug("Updated max_payload_size: %d -> %d", old, self._max_payload_size) + + @property + def max_datapoints(self) -> int: + return self._max_datapoints + + @max_datapoints.setter + def max_datapoints(self, value: int): + old = self._max_datapoints + self._max_datapoints = value if value is not None and value > 0 else 0 + logger.debug("Updated max_datapoints: %d -> %d", old, self._max_datapoints) diff --git a/tb_mqtt_client/service/gateway/__init__.py b/tb_mqtt_client/service/gateway/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tb_mqtt_client/service/gateway/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tb_mqtt_client/service/gateway/client.py b/tb_mqtt_client/service/gateway/client.py new file mode 100644 index 0000000..c4e1f96 --- /dev/null +++ b/tb_mqtt_client/service/gateway/client.py @@ -0,0 +1,451 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from asyncio import sleep, Future +from time import monotonic +from typing import Optional, Dict, Union, Tuple, List, Any + +from tb_mqtt_client.common.async_utils import await_or_stop +from tb_mqtt_client.common.config_loader import GatewayConfig +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit, DEFAULT_RATE_LIMIT_PERCENTAGE +from tb_mqtt_client.constants import mqtt_topics +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage +from tb_mqtt_client.entities.gateway.device_disconnect_message import DeviceDisconnectMessage +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest +from tb_mqtt_client.entities.gateway.gateway_claim_request import GatewayClaimRequest +from tb_mqtt_client.service.device.client import DeviceClient +from tb_mqtt_client.service.gateway.device_manager import DeviceManager +from tb_mqtt_client.service.gateway.device_session import DeviceSession +from tb_mqtt_client.service.gateway.direct_event_dispatcher import DirectEventDispatcher +from tb_mqtt_client.service.gateway.gateway_client_interface import GatewayClientInterface +from tb_mqtt_client.service.gateway.handlers.gateway_attribute_updates_handler import GatewayAttributeUpdatesHandler +from tb_mqtt_client.service.gateway.handlers.gateway_requested_attributes_response_handler import \ + GatewayRequestedAttributeResponseHandler +from tb_mqtt_client.service.gateway.handlers.gateway_rpc_handler import GatewayRPCHandler +from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter, JsonGatewayMessageAdapter +from tb_mqtt_client.service.gateway.message_sender import GatewayMessageSender + +logger = get_logger(__name__) + + +class GatewayClient(DeviceClient, GatewayClientInterface): + """ + ThingsBoard Gateway MQTT client implementation. + This class extends DeviceClient and adds gateway-specific functionality. + """ + SUBSCRIPTIONS_TIMEOUT = 1.0 # Timeout for subscribe/unsubscribe operations + OPERATIONAL_TIMEOUT = 5.0 # Timeout for connection events + + def __init__(self, config: Optional[Union[GatewayConfig, Dict]] = None): + """ + Initialize a new GatewayClient instance. + + :param config: Gateway configuration object or dictionary + """ + self._config = config if isinstance(config, GatewayConfig) else GatewayConfig(config) + super().__init__(self._config) + self._mqtt_manager.enable_gateway_mode() + + self.device_manager = DeviceManager() + + self._event_dispatcher: DirectEventDispatcher = DirectEventDispatcher() + self._uplink_message_sender = GatewayMessageSender() + self._event_dispatcher.register(GatewayEventType.DEVICE_CONNECT, + self._uplink_message_sender.send_device_connect) + self._event_dispatcher.register(GatewayEventType.DEVICE_DISCONNECT, + self._uplink_message_sender.send_device_disconnect) + self._event_dispatcher.register(GatewayEventType.DEVICE_UPLINK, + self._uplink_message_sender.send_uplink_message) + self._event_dispatcher.register(GatewayEventType.DEVICE_ATTRIBUTE_REQUEST, + self._uplink_message_sender.send_attributes_request) + self._event_dispatcher.register(GatewayEventType.DEVICE_RPC_RESPONSE, + self._uplink_message_sender.send_rpc_response) + self._event_dispatcher.register(GatewayEventType.GATEWAY_CLAIM_REQUEST, + self._uplink_message_sender.send_claim_request) + + # Default max payload size and datapoints count limit, should be changed after connection established + self._gateway_message_adapter: GatewayMessageAdapter = JsonGatewayMessageAdapter(1000, 1) + self._uplink_message_sender.set_message_adapter(self._gateway_message_adapter) + + self._multiplex_dispatcher = None # Placeholder for multiplex dispatcher, if needed + self._gateway_rpc_handler = GatewayRPCHandler(event_dispatcher=self._event_dispatcher, + message_adapter=self._gateway_message_adapter, + device_manager=self.device_manager, + stop_event=self._stop_event) + self._gateway_attribute_updates_handler = GatewayAttributeUpdatesHandler( + event_dispatcher=self._event_dispatcher, + message_adapter=self._gateway_message_adapter, + device_manager=self.device_manager) + self._gateway_requested_attribute_response_handler = GatewayRequestedAttributeResponseHandler( + event_dispatcher=self._event_dispatcher, + message_adapter=self._gateway_message_adapter, + device_manager=self.device_manager) + + # Gateway-specific rate limits + self._device_messages_rate_limit = RateLimit("0:0,", name="device_messages") + self._device_telemetry_rate_limit = RateLimit("0:0,", name="device_telemetry") + self._device_telemetry_dp_rate_limit = RateLimit("0:0,", name="device_telemetry_datapoints") + + # Callbacks + self._device_attribute_update_callback = None + self._device_rpc_request_callback = None + self._device_disconnect_callback = None + + async def connect(self): + """ + Connect to the platform. + """ + logger.info("Connecting gateway to platform at %s:%s", self._host, self._port) + await super().connect() + self._message_queue.set_gateway_message_adapter(self._gateway_message_adapter) + self._uplink_message_sender.set_message_queue(self._message_queue) + + # Subscribe to gateway-specific topics + await self._subscribe_to_gateway_topics() + + logger.info("Gateway connected to ThingsBoard.") + + async def connect_device(self, + device_name_or_device_connect_message: Union[str, DeviceConnectMessage], + device_profile: str = 'default', + wait_for_publish=False) -> Tuple[ + DeviceSession, + List[Union[PublishResult, Future[PublishResult]]]]: + """ + Connect a device to the gateway. + + :param device_name_or_device_connect_message: Name of the device or a DeviceConnectMessage object + :param device_profile: Profile of the device + :param wait_for_publish: Whether to wait for the publish result + :return: Tuple containing the DeviceSession and an optional PublishResult or list of PublishResults + """ + if not isinstance(device_name_or_device_connect_message, DeviceConnectMessage): + device_name = device_name_or_device_connect_message + device_connect_message = DeviceConnectMessage.build(device_name, device_profile) + else: + device_connect_message = device_name_or_device_connect_message + + logger.info("Connecting device %s with profile %s", + device_connect_message.device_name, + device_connect_message.device_profile) + device_session = self.device_manager.register(device_connect_message.device_name, + device_connect_message.device_profile) + futures = await self._event_dispatcher.dispatch(device_connect_message, qos=self._config.qos) # noqa + + if not futures: + logger.warning("No publish futures were returned from message queue") + return device_session, [] + + if not wait_for_publish: + return device_session, futures + + results = [] + for fut in futures: + try: + result = await await_or_stop(fut, timeout=self.OPERATIONAL_TIMEOUT, stop_event=self._stop_event) + except TimeoutError: + logger.warning("Timeout while waiting for telemetry publish result") + result = PublishResult(mqtt_topics.GATEWAY_CONNECT_TOPIC, self._config.qos, -1, -1, -1) + results.append(result) + + return device_session, results[0] if len(results) == 1 else results + + async def disconnect_device(self, device_session: DeviceSession, wait_for_publish: bool): + """ + Disconnect a device from the gateway. + + :param device_session: The DeviceSession object for the device + :param wait_for_publish: Whether to wait for the publish result + :return: PublishResult or Future[PublishResult] if successful, None if failed + """ + logger.info("Disconnecting device %s", device_session.device_info.device_name) + if not device_session: + logger.warning("No device session provided for disconnecting") + return None + device_disconnect_message = DeviceDisconnectMessage.build(device_session.device_info.device_name) + + futures = await self._event_dispatcher.dispatch(device_disconnect_message, qos=self._config.qos) # noqa + + if not futures: + logger.warning("No publish futures were returned from message queue") + return device_session, [] + + if not wait_for_publish: + return device_session, futures + + results = [] + for fut in futures: + try: + result = await await_or_stop(fut, timeout=self.OPERATIONAL_TIMEOUT, stop_event=self._stop_event) + except TimeoutError: + logger.warning("Timeout while waiting for telemetry publish result") + result = PublishResult(mqtt_topics.GATEWAY_CONNECT_TOPIC, self._config.qos, -1, -1, -1) + results.append(result) + + self.device_manager.unregister(device_session.device_info.device_id) + return device_session, results[0] if len(results) == 1 else results + + async def send_device_timeseries(self, + device_session: DeviceSession, + data: Union[TimeseriesEntry, + List[TimeseriesEntry], + Dict[str, Any], + List[Dict[str, Any]]], + wait_for_publish: bool) -> Optional[Union[ + List[Union[PublishResult, Future[PublishResult]]], + PublishResult, + Future[PublishResult]]]: + """ + Send timeseries data to the platform for a specific device. + :param device_session: The DeviceSession object for the device + :param data: Timeseries data to send, can be a single entry or a list of entries + :param wait_for_publish: Whether to wait for the publish result + :return: List of PublishResults or Future objects, or None if no data was sent + """ + logger.trace("Sending timeseries data for device %s", device_session.device_info.device_name) + + if not device_session or not data: + logger.warning("No device session or data provided for sending timeseries") + return None + + message = self._build_uplink_message_for_telemetry(data, device_session) + futures = await self._event_dispatcher.dispatch(message, qos=self._config.qos) # noqa + + if not futures: + logger.warning("No publish futures were returned from message queue") + return None + + if not wait_for_publish: + return futures[0] if len(futures) == 1 else futures + + results = [] + for fut in futures: + try: + result = await await_or_stop(fut, timeout=self.OPERATIONAL_TIMEOUT, stop_event=self._stop_event) + except TimeoutError: + logger.warning("Timeout while waiting for telemetry publish result") + result = PublishResult(mqtt_topics.GATEWAY_TELEMETRY_TOPIC, self._config.qos, -1, message.size, -1) + results.append(result) + + return results[0] if len(results) == 1 else results + + async def send_device_attributes(self, + device_session: DeviceSession, + data: Union[Dict[str, Any], AttributeEntry, list[AttributeEntry]], + wait_for_publish: bool) -> Optional[Union[ + List[Union[PublishResult, Future[PublishResult]]], + PublishResult, + Future[PublishResult]]]: + """ + Send attributes data to the platform for a specific device. + :param device_session: The DeviceSession object for the device + :param data: Attributes data to send, can be a single entry or a list of entries + :param wait_for_publish: Whether to wait for the publish result + """ + logger.trace("Sending attributes data for device %s", + device_session.device_info.device_name) + if not device_session or not data: + logger.warning("No device session or data provided for sending attributes") + return None + + message = self._build_uplink_message_for_attributes(data, device_session) + futures = await self._event_dispatcher.dispatch(message, qos=self._config.qos) # noqa + if not futures: + logger.warning("No publish futures were returned from message queue") + return None + if not wait_for_publish: + return futures[0] if len(futures) == 1 else futures + results = [] + for fut in futures: + try: + result = await await_or_stop(fut, timeout=self.OPERATIONAL_TIMEOUT, stop_event=self._stop_event) + except TimeoutError: + logger.warning("Timeout while waiting for attributes publish result") + result = PublishResult(mqtt_topics.GATEWAY_ATTRIBUTES_TOPIC, self._config.qos, -1, message.size, -1) + results.append(result) + return results[0] if len(results) == 1 else results + + async def send_device_attributes_request(self, + device_session: DeviceSession, + attribute_request: Union[AttributeRequest, GatewayAttributeRequest], + wait_for_publish: bool) -> Optional[Union[ + List[Union[PublishResult, Future[PublishResult]]], + PublishResult, + Future[PublishResult]]]: + """ + Send a request for device attributes to the platform. + :param device_session: The DeviceSession object for the device + :param attribute_request: Attributes to request, can be a single AttributeRequest or GatewayAttributeRequest + :param wait_for_publish: Whether to wait for the publish result + """ + logger.trace("Sending attributes request for device %s", + device_session.device_info.device_name) + if not device_session or not attribute_request: + logger.warning("No device session or attributes provided for sending attributes request") + return None + if isinstance(attribute_request, AttributeRequest): + attribute_request = await GatewayAttributeRequest.from_attribute_request(device_session=device_session, + attribute_request=attribute_request) # noqa + + await self._gateway_requested_attribute_response_handler.register_request(attribute_request) + futures = await self._event_dispatcher.dispatch(attribute_request, qos=self._config.qos) + + if not futures: + logger.warning("No publish futures were returned from message queue") + return None + + if not wait_for_publish: + return futures[0] if len(futures) == 1 else futures + + results = [] + for fut in futures: + try: + result = await await_or_stop(fut, timeout=self.OPERATIONAL_TIMEOUT, stop_event=self._stop_event) + except TimeoutError: + logger.warning("Timeout while waiting for attributes request publish result") + result = PublishResult(mqtt_topics.GATEWAY_ATTRIBUTES_REQUEST_TOPIC, self._config.qos, -1, -1, -1) + results.append(result) + + return results[0] if len(results) == 1 else results + + async def send_device_claim_request(self, + device_session: DeviceSession, + gateway_claim_request: GatewayClaimRequest, + wait_for_publish: bool) -> Optional[Union[ + List[Union[PublishResult, Future[PublishResult]]], + PublishResult, + Future[PublishResult]]]: + """ + Send a claim request for a device to the platform. + :param device_session: The DeviceSession object for the device + :param gateway_claim_request: Claim request data + :param wait_for_publish: Whether to wait for the publish result + """ + logger.trace("Sending claim request for device %s", device_session.device_info.device_name) + if not device_session or not gateway_claim_request: + logger.warning("No device session or claim request provided for sending claim request") + return None + + futures = await self._event_dispatcher.dispatch(gateway_claim_request, qos=self._config.qos) # noqa + + if not futures: + logger.warning("No publish futures were returned from message queue") + return None + + if not wait_for_publish: + return futures[0] if len(futures) == 1 else futures + + results = [] + for fut in futures: + try: + result = await await_or_stop(fut, timeout=self.OPERATIONAL_TIMEOUT, stop_event=self._stop_event) + except TimeoutError: + logger.warning("Timeout while waiting for claim request publish result") + result = PublishResult(mqtt_topics.GATEWAY_CLAIM_TOPIC, self._config.qos, -1, -1, -1) + results.append(result) + + return results[0] if len(results) == 1 else results + + async def disconnect(self): + """ + Disconnect from the platform. + """ + logger.info("Disconnecting gateway from platform at %s:%s", self._host, self._port) + await self._unsubscribe_from_gateway_topics() + await super().disconnect() + logger.info("Gateway disconnected from ThingsBoard.") + + async def _subscribe_to_gateway_topics(self): + """ + Subscribe to gateway-specific MQTT topics. + """ + logger.info("Subscribing to gateway topics") + + sub_future = await self._mqtt_manager.subscribe(mqtt_topics.GATEWAY_ATTRIBUTES_TOPIC, qos=1) + while not sub_future.done(): + await sleep(0.01) + + sub_future = await self._mqtt_manager.subscribe(mqtt_topics.GATEWAY_ATTRIBUTES_RESPONSE_TOPIC, qos=1) + while not sub_future.done(): + await sleep(0.01) + + sub_future = await self._mqtt_manager.subscribe(mqtt_topics.GATEWAY_RPC_TOPIC, qos=1) + while not sub_future.done(): + await sleep(0.01) + + self._mqtt_manager.register_handler(mqtt_topics.GATEWAY_ATTRIBUTES_TOPIC, + self._gateway_attribute_updates_handler.handle) + self._mqtt_manager.register_handler(mqtt_topics.GATEWAY_RPC_TOPIC, + self._gateway_rpc_handler.handle) + self._mqtt_manager.register_handler(mqtt_topics.GATEWAY_ATTRIBUTES_RESPONSE_TOPIC, + self._gateway_requested_attribute_response_handler.handle) + + async def _unsubscribe_from_gateway_topics(self): + """ + Unsubscribe from gateway-specific MQTT topics. + """ + logger.info("Unsubscribing from gateway topics") + + unsub_future = await self._mqtt_manager.unsubscribe(mqtt_topics.GATEWAY_ATTRIBUTES_TOPIC) + unsubscribe_start_time = monotonic() + while not unsub_future.done(): + if monotonic() - unsubscribe_start_time > self.SUBSCRIPTIONS_TIMEOUT: + logger.warning("Unsubscribe from gateway attributes topic timed out") + break + await sleep(0.01) + + unsub_future = await self._mqtt_manager.unsubscribe(mqtt_topics.GATEWAY_ATTRIBUTES_RESPONSE_TOPIC) + unsubscribe_start_time = monotonic() + while not unsub_future.done(): + if monotonic() - unsubscribe_start_time > self.SUBSCRIPTIONS_TIMEOUT: + logger.warning("Unsubscribe from gateway attribute responses topic timed out") + break + await sleep(0.01) + + unsub_future = await self._mqtt_manager.unsubscribe(mqtt_topics.GATEWAY_RPC_TOPIC) + unsubscribe_start_time = monotonic() + while not unsub_future.done(): + if monotonic() - unsubscribe_start_time > self.SUBSCRIPTIONS_TIMEOUT: + logger.warning("Unsubscribe from gateway rpc topic timed out") + break + await sleep(0.01) + + async def _handle_rate_limit_response(self, response: RPCResponse): # noqa + device_rate_limits_processing_result = await super()._handle_rate_limit_response(response) + try: + gateway_rate_limits = response.result.get('gatewayRateLimits', {}) + + await self._gateway_rate_limiter.message_rate_limit.set_limit(gateway_rate_limits.get('messages', '0:0,'), + percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) + await self._gateway_rate_limiter.telemetry_message_rate_limit.set_limit( + gateway_rate_limits.get('telemetryMessages', '0:0,'), percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) + await self._gateway_rate_limiter.telemetry_datapoints_rate_limit.set_limit( + gateway_rate_limits.get('telemetryDataPoints', '0:0,'), percentage=DEFAULT_RATE_LIMIT_PERCENTAGE) + + self._gateway_message_adapter.splitter.max_payload_size = self.max_payload_size + self._gateway_message_adapter.splitter.max_datapoints = self._gateway_rate_limiter.telemetry_datapoints_rate_limit.minimal_limit # noqa + + self._mqtt_manager.set_gateway_rate_limits_received() + return device_rate_limits_processing_result + + except Exception as e: + logger.exception("Failed to parse rate limits from server response: %s", e) + return False diff --git a/tb_mqtt_client/service/gateway/device_manager.py b/tb_mqtt_client/service/gateway/device_manager.py new file mode 100644 index 0000000..e7d786f --- /dev/null +++ b/tb_mqtt_client/service/gateway/device_manager.py @@ -0,0 +1,114 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Dict, Iterable, Callable, Set +from uuid import UUID + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.entities.gateway.device_info import DeviceInfo +from tb_mqtt_client.service.gateway.device_session import DeviceSession + +logger = get_logger(__name__) + + +class DeviceManager: + def __init__(self): + self._sessions_by_id: Dict[UUID, DeviceSession] = {} + self._ids_by_device_name: Dict[str, UUID] = {} + self._ids_by_original_name: Dict[str, UUID] = {} + self.__connected_devices: Set[DeviceSession] = set() + + def register(self, device_name: str, device_profile: str = "default") -> DeviceSession: + session = self.get_by_name(device_name) + if session: + return session + + device_info = DeviceInfo( + device_name=device_name, + device_profile=device_profile + ) + session = DeviceSession(device_info, self.__state_change_callback) + self._sessions_by_id[device_info.device_id] = session + self._ids_by_device_name[device_name] = device_info.device_id + self._ids_by_original_name[device_info.original_name] = device_info.device_id + session.update_last_seen() + return session + + def unregister(self, device_id: UUID): + session = self._sessions_by_id.pop(device_id, None) + if session: + self._ids_by_device_name.pop(session.device_info.device_name, None) + self._ids_by_original_name.pop(session.device_info.original_name, None) + + def get_by_id(self, device_id: UUID) -> Optional[DeviceSession]: + return self._sessions_by_id.get(device_id) + + def get_by_name(self, device_name: str) -> Optional[DeviceSession]: + device_id = self._ids_by_device_name.get(device_name) + if device_id: + return self._sessions_by_id.get(device_id) + renamed_id = self._ids_by_original_name.get(device_name) + if renamed_id: + return self._sessions_by_id.get(renamed_id) + return None + + def is_connected(self, device_id: UUID) -> bool: + return device_id in self._sessions_by_id + + def all(self) -> Iterable[DeviceSession]: + return self._sessions_by_id.values() + + def rename_device(self, old_name: str, new_name: str): + device_session = self.get_by_name(old_name) + if not device_session: + logger.warning(f"Device with name '{old_name}' not found for renaming to '{new_name}'.") + return + device_session.device_info.rename(new_name) + self._ids_by_device_name.pop(old_name, None) + device_id = device_session.device_info.device_id + self._ids_by_device_name[new_name] = device_id + self._ids_by_original_name[device_session.device_info.original_name] = device_id + + def set_attribute_update_callback(self, device_id: UUID, cb: Callable): + session = self._sessions_by_id.get(device_id) + if session: + session.set_attribute_update_callback(cb) + + def set_attribute_response_callback(self, device_id: UUID, cb: Callable): + session = self._sessions_by_id.get(device_id) + if session: + session.set_attribute_response_callback(cb) + + def set_rpc_request_callback(self, device_id: UUID, cb: Callable): + session = self._sessions_by_id.get(device_id) + if session: + session.set_rpc_request_callback(cb) + + def __state_change_callback(self, device_session: DeviceSession) -> None: + if device_session.state.is_connected() and device_session.device_info.device_id not in self.__connected_devices: + self.__connected_devices.add(device_session) + else: + self.__connected_devices.remove(device_session) + logger.debug(f"Device {device_session.device_info.device_name} state changed to {device_session.state}") + + @property + def connected_devices(self) -> Set[DeviceSession]: + return self.__connected_devices + + @property + def all_devices(self) -> Dict[UUID, DeviceSession]: + return self._sessions_by_id + + def __repr__(self): + return f"DeviceManager({str(list(self._ids_by_device_name.keys()))})" diff --git a/tb_mqtt_client/service/gateway/device_session.py b/tb_mqtt_client/service/gateway/device_session.py new file mode 100644 index 0000000..204e940 --- /dev/null +++ b/tb_mqtt_client/service/gateway/device_session.py @@ -0,0 +1,97 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from dataclasses import dataclass, field +from time import time +from typing import Callable, Awaitable, Optional + +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent +from tb_mqtt_client.entities.gateway.device_info import DeviceInfo +from tb_mqtt_client.entities.gateway.device_session_state import DeviceSessionState +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse + + +@dataclass +class DeviceSession: + device_info: DeviceInfo + _state_change_callback: Optional[Callable[['DeviceSession'], None]] = None + connected_at: float = field(default_factory=lambda: int(time() * 1000)) + last_seen_at: float = field(default_factory=lambda: int(time() * 1000)) + claimed: bool = False + provisioned: bool = False + state: DeviceSessionState = DeviceSessionState.CONNECTED + + attribute_update_callback: Optional[Callable[['DeviceSession', 'AttributeUpdate'], + Optional[Awaitable[None]]]] = None + attribute_response_callback: Optional[Callable[['DeviceSession', 'RequestedAttributeResponse'], + Optional[Awaitable[None]]]] = None + rpc_request_callback: Optional[Callable[['DeviceSession', 'GatewayRPCRequest'], + Optional[Awaitable[Optional['GatewayRPCResponse']]]]] = None + + def update_state(self, new_state: DeviceSessionState): + self.state = new_state + if self._state_change_callback: + self._state_change_callback(self) + + def update_last_seen(self): + self.last_seen_at = int(time() * 1000) + + def set_attribute_update_callback(self, + cb: Callable[['DeviceSession', 'AttributeUpdate'], Optional[Awaitable[None]]]): + self.attribute_update_callback = cb + + def set_attribute_response_callback(self, cb: Callable[ + ['DeviceSession', 'RequestedAttributeResponse'], Optional[Awaitable[None]]]): + self.attribute_response_callback = cb + + def set_rpc_request_callback(self, cb: Callable[ + ['DeviceSession', 'GatewayRPCRequest'], Optional[Awaitable[Optional['GatewayRPCResponse']]]]): + self.rpc_request_callback = cb + + async def handle_event_to_device(self, event: BaseGatewayEvent) -> ( + Optional)[Awaitable[Optional['GatewayRPCResponse']]]: + cb = None + if GatewayEventType.DEVICE_ATTRIBUTE_UPDATE == event.event_type \ + and isinstance(event, AttributeUpdate): + if self.attribute_update_callback: + cb = self.attribute_update_callback + elif GatewayEventType.DEVICE_REQUESTED_ATTRIBUTE_RESPONSE == event.event_type \ + and isinstance(event, RequestedAttributeResponse): + if self.attribute_response_callback: + cb = self.attribute_response_callback + elif GatewayEventType.DEVICE_RPC_REQUEST == event.event_type \ + and isinstance(event, GatewayRPCRequest): + if self.rpc_request_callback: + cb = self.rpc_request_callback + + if cb is None: + return None + + if asyncio.iscoroutinefunction(cb): + return await cb(self, event) + else: + return cb(self, event) + + def __eq__(self, other): + if not isinstance(other, DeviceSession): + return NotImplemented + return self.device_info.device_id == other.device_info.device_id + + def __hash__(self): + return hash(self.device_info) diff --git a/tb_mqtt_client/service/gateway/direct_event_dispatcher.py b/tb_mqtt_client/service/gateway/direct_event_dispatcher.py new file mode 100644 index 0000000..efc17c0 --- /dev/null +++ b/tb_mqtt_client/service/gateway/direct_event_dispatcher.py @@ -0,0 +1,63 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from collections import defaultdict +from typing import Callable, Awaitable, Dict, List, Union + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.gateway.gateway_event import GatewayEvent +from tb_mqtt_client.service.gateway.device_session import DeviceSession + +EventCallback = Union[Callable[..., Awaitable[None]], Callable[..., None]] + +logger = get_logger(__name__) + + +class DirectEventDispatcher: + """ + Direct event dispatcher for handling gateway events. + """ + + def __init__(self): + self._handlers: Dict[GatewayEventType, List[EventCallback]] = defaultdict(list) + self._lock = asyncio.Lock() + + def register(self, event_type: GatewayEventType, callback: EventCallback): + if callback not in self._handlers[event_type]: + self._handlers[event_type].append(callback) + + def unregister(self, event_type: GatewayEventType, callback: EventCallback): + if event_type not in self._handlers: + return + if callback in self._handlers[event_type]: + self._handlers[event_type].remove(callback) + if not self._handlers[event_type]: + del self._handlers[event_type] + + async def dispatch(self, event: GatewayEvent, *args, device_session: DeviceSession = None, **kwargs): + if device_session is not None: + return await device_session.handle_event_to_device(event) + async with self._lock: + callbacks = list(self._handlers.get(event.event_type, [])) + for callback in callbacks: + try: + if asyncio.iscoroutinefunction(callback): + return await callback(event, *args, **kwargs) + else: + return callback(event, *args, **kwargs) + except Exception as e: + logger.error(f"[EventDispatcher] Exception in handler for '{event.event_type}': {e}") + return None diff --git a/tb_mqtt_client/service/gateway/gateway_client_interface.py b/tb_mqtt_client/service/gateway/gateway_client_interface.py new file mode 100644 index 0000000..21e70a1 --- /dev/null +++ b/tb_mqtt_client/service/gateway/gateway_client_interface.py @@ -0,0 +1,74 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from asyncio import Future +from typing import Union, List, Tuple, Dict, Any, Optional + +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest +from tb_mqtt_client.entities.gateway.gateway_claim_request import GatewayClaimRequest +from tb_mqtt_client.service.base_client import BaseClient +from tb_mqtt_client.service.gateway.device_session import DeviceSession + + +class GatewayClientInterface(BaseClient, ABC): + + @abstractmethod + async def connect_device(self, + device_name: str, + device_profile: str, + wait_for_publish: bool) -> Tuple[DeviceSession, + List[Union[PublishResult, Future[PublishResult]]]]: ... + + @abstractmethod + async def disconnect_device(self, + device_session: DeviceSession, + wait_for_publish: bool) -> List[Union[PublishResult, Future[PublishResult]]]: ... + + @abstractmethod + async def send_device_timeseries(self, + device_session: DeviceSession, + data: Union[TimeseriesEntry, + List[TimeseriesEntry], + Dict[str, Any], + List[Dict[str, Any]]], + wait_for_publish: bool) -> Optional[List[Union[PublishResult, + Future[PublishResult]]]]: ... + + @abstractmethod + async def send_device_attributes(self, + device_session: DeviceSession, + attributes: Union[Dict[str, Any], + AttributeEntry, + List[AttributeEntry]], + wait_for_publish: bool) -> Optional[List[Union[PublishResult, + Future[PublishResult]]]]: ... + + @abstractmethod + async def send_device_attributes_request(self, + device_session: DeviceSession, + attributes: Union[AttributeRequest, GatewayAttributeRequest], + wait_for_publish: bool) -> Optional[ + List[Union[PublishResult, Future[PublishResult]]]]: ... + + @abstractmethod + async def send_device_claim_request(self, + device_session: DeviceSession, + gateway_claim_request: GatewayClaimRequest, + wait_for_publish: bool) -> Union[List[Union[PublishResult, + Future[PublishResult]]], None]: ... diff --git a/tb_mqtt_client/service/gateway/handlers/__init__.py b/tb_mqtt_client/service/gateway/handlers/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tb_mqtt_client/service/gateway/handlers/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tb_mqtt_client/service/gateway/handlers/gateway_attribute_updates_handler.py b/tb_mqtt_client/service/gateway/handlers/gateway_attribute_updates_handler.py new file mode 100644 index 0000000..717ed55 --- /dev/null +++ b/tb_mqtt_client/service/gateway/handlers/gateway_attribute_updates_handler.py @@ -0,0 +1,41 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tb_mqtt_client.service.gateway.device_manager import DeviceManager +from tb_mqtt_client.service.gateway.direct_event_dispatcher import DirectEventDispatcher +from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter + + +class GatewayAttributeUpdatesHandler: + """Handles shared attribute updates for devices connected to a gateway.""" + + def __init__(self, + event_dispatcher: DirectEventDispatcher, + message_adapter: GatewayMessageAdapter, + device_manager: DeviceManager): + self.event_dispatcher = event_dispatcher + self.message_adapter = message_adapter + self.device_manager = device_manager + + async def handle(self, topic: str, payload: bytes): + """ + Handles the gateway attribute update event by dispatching the attribute update + """ + data = self.message_adapter.deserialize_to_dict(payload) + gateway_attribute_update = self.message_adapter.parse_attribute_update(data) + device_session = self.device_manager.get_by_name(gateway_attribute_update.device_name) + if device_session: + gateway_attribute_update.set_device_session(device_session) + await self.event_dispatcher.dispatch(gateway_attribute_update.attribute_update, # noqa + device_session=device_session) diff --git a/tb_mqtt_client/service/gateway/handlers/gateway_requested_attributes_response_handler.py b/tb_mqtt_client/service/gateway/handlers/gateway_requested_attributes_response_handler.py new file mode 100644 index 0000000..20377a6 --- /dev/null +++ b/tb_mqtt_client/service/gateway/handlers/gateway_requested_attributes_response_handler.py @@ -0,0 +1,134 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from asyncio import Task +from typing import Dict, Tuple, Coroutine, Callable, Union, TypeAlias, Any + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest +from tb_mqtt_client.entities.gateway.gateway_requested_attribute_response import GatewayRequestedAttributeResponse +from tb_mqtt_client.service.gateway.device_manager import DeviceManager +from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter + +logger = get_logger(__name__) + +AttributeResponseCallback: TypeAlias = Callable[[GatewayRequestedAttributeResponse], Coroutine[Any, Any, None]] + + +class GatewayRequestedAttributeResponseHandler: + """ + Handles responses to attribute requests sent to the platform. + """ + + def __init__(self, event_dispatcher: Any, message_adapter: GatewayMessageAdapter, device_manager: DeviceManager): + self._event_dispatcher = event_dispatcher + self._message_adapter: Union[GatewayMessageAdapter, None] = message_adapter + self._device_manager = device_manager + self._pending_attribute_requests: Dict[Tuple[str, int], Tuple[GatewayAttributeRequest, Union[Task, None]]] = {} + + async def register_request(self, + request: GatewayAttributeRequest, + timeout: int = 30): + """ + Called when a request is sent to the platform and a response is awaited. + """ + request_id = request.request_id + device_name = request.device_session.device_info.device_name + key = (device_name, request_id) + if key in self._pending_attribute_requests: + raise RuntimeError(f"Request ID {request.request_id} is already registered.") + timeout_task = None + if timeout > 0: + timeout_task = asyncio.get_event_loop().call_later(timeout, self._on_timeout, device_name, request_id) + self._pending_attribute_requests[key] = (request, timeout_task) + logger.debug("Registered attribute request with ID %s for device %s", + request_id, device_name) + + def unregister_request(self, device_name: str, request_id: int): + """ + Unregisters a request for device attributes by device name and request ID. + This is useful if the request is no longer needed or has timed out. + """ + key = (device_name, request_id) + if key in self._pending_attribute_requests: + self._pending_attribute_requests.pop(key) + logger.debug("Unregistered attribute request with ID %s for device %s", + request_id, device_name) + else: + logger.debug("Attempted to unregister non-existent request ID %s for device %s", + request_id, device_name) + + async def handle(self, topic: str, payload: bytes): + """ + Handles the incoming attribute request response. + """ + try: + if not self._message_adapter: + logger.error("Message adapter is not initialized. Cannot handle requested attribute response.") + return + deserialized_data = self._message_adapter.deserialize_to_dict(payload) + request_id = deserialized_data.get('id') + device_name = deserialized_data.get('device') + if request_id is None or device_name is None: + logger.error("Received requested attribute response without 'id' or 'device'. ") + return + attribute_request_with_callback = self._pending_attribute_requests.get((device_name, request_id)) + if not attribute_request_with_callback: + logger.warning("No pending request found for request ID %s. Ignoring response.", request_id) + return + attribute_request, timeout_task = attribute_request_with_callback + if timeout_task: + timeout_task.cancel() + requested_attribute_response = self._message_adapter.parse_gateway_requested_attribute_response( + attribute_request, deserialized_data) + device_session = self._device_manager.get_by_name(device_name) + if not device_session: + logger.warning("No device session found for device: %s", device_name) + return + dispatch_task = self._event_dispatcher.dispatch(requested_attribute_response, device_session=device_session) + + logger.trace("Dispatching callback for requested attribute response with ID %s", + requested_attribute_response.request_id) + task = asyncio.create_task(dispatch_task) + task.add_done_callback(self._handle_callback_exception) + + except Exception as e: + logger.exception("Failed to handle requested attribute response: %s", e) + + def _on_timeout(self, device_name: str, request_id: int): + """ + Called when a request times out. + Unregisters the request and logs a warning. + """ + key = (device_name, request_id) + if key in self._pending_attribute_requests: + self._pending_attribute_requests.pop(key) + logger.warning("Request ID %s for device %s has timed out and has been unregistered.", + request_id, device_name) + else: + logger.debug("Attempted to unregister non-existent request ID %s for device %s on timeout", + request_id, device_name) + + def _handle_callback_exception(self, task: asyncio.Task): + try: + task.result() + except Exception as e: + logger.exception("Exception in user-defined requested attribute callback: %s", e) + + def clear(self): + """ + Clears all pending requests (e.g., on disconnect). + """ + self._pending_attribute_requests.clear() diff --git a/tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py b/tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py new file mode 100644 index 0000000..a8dd5f5 --- /dev/null +++ b/tb_mqtt_client/service/gateway/handlers/gateway_rpc_handler.py @@ -0,0 +1,100 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from typing import Awaitable, Callable, Optional + +from tb_mqtt_client.common.async_utils import await_or_stop +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse +from tb_mqtt_client.service.gateway.device_manager import DeviceManager +from tb_mqtt_client.service.gateway.device_session import DeviceSession +from tb_mqtt_client.service.gateway.direct_event_dispatcher import DirectEventDispatcher +from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter + +logger = get_logger(__name__) + + +class GatewayRPCHandler: + """ + Handles incoming RPC request messages for a device connected through the gateway. + """ + + def __init__(self, event_dispatcher: DirectEventDispatcher, + message_adapter: GatewayMessageAdapter, + device_manager: DeviceManager, + stop_event: asyncio.Event): + self._event_dispatcher = event_dispatcher + self._message_adapter = message_adapter + self._device_manager = device_manager + self._stop_event = stop_event + self._callback: Optional[Callable[[GatewayRPCRequest], Awaitable[GatewayRPCResponse]]] = None + self._event_dispatcher.register(GatewayEventType.DEVICE_RPC_REQUEST, self.handle) + + async def handle(self, topic: str, payload: bytes) -> None: + """ + Process the RPC request and return the response for it. + :returns: GatewayRPCResponse or None if failed + """ + + if not self._message_adapter: + logger.error("Message adapter is not initialized. Cannot handle RPC request.") + return None + + rpc_request: Optional[GatewayRPCRequest] = None + result = None + device_session: Optional[DeviceSession] = None + try: + data = self._message_adapter.deserialize_to_dict(payload) + rpc_request = self._message_adapter.parse_rpc_request(topic, data) + device_session = self._device_manager.get_by_name(rpc_request.device_name) + if not device_session: + logger.warning("No device session found for device: %s", rpc_request.device_name) + return None + logger.debug("Handling RPC method id: %i - %s with params: %s", + rpc_request.request_id, rpc_request.method, rpc_request.params) + result = await self._event_dispatcher.dispatch(rpc_request, device_session=device_session) # noqa + if not result: + return None + elif not isinstance(result, GatewayRPCResponse): + raise TypeError("RPC callback must return an instance of GatewayRPCResponse, got: %s", type(result)) + logger.debug("RPC response for device %r method id: %i - %s with result: %s", + rpc_request.device_name, result.request_id, rpc_request.method, result.result) + except Exception as e: + logger.exception("Failed to process RPC request: %s", e) + if rpc_request is None: + return None + result = GatewayRPCResponse.build(device_name=rpc_request.device_name, + request_id=rpc_request.request_id, + error=e) + + if not device_session: + logger.warning("No device session found for device: %s, cannot send RPC response", + rpc_request.device_name) + return None + + future = await self._event_dispatcher.dispatch(result) # noqa + + if not future: + logger.warning( + "No publish futures were returned from message queue for RPC response of device %s, request id %i", + rpc_request.device_name, rpc_request.request_id) + return None + try: + await await_or_stop(future, timeout=1, stop_event=self._stop_event) + except TimeoutError: + logger.warning("RPC response publish timed out for device %s, request id %i", + rpc_request.device_name, rpc_request.request_id) diff --git a/tb_mqtt_client/service/gateway/message_adapter.py b/tb_mqtt_client/service/gateway/message_adapter.py new file mode 100644 index 0000000..210b4ee --- /dev/null +++ b/tb_mqtt_client/service/gateway/message_adapter.py @@ -0,0 +1,417 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import abstractmethod, ABC +from collections import defaultdict +from datetime import datetime, UTC +from typing import List, Dict, Any, Union, Optional + +from orjson import loads, dumps + +from tb_mqtt_client.common.async_utils import future_map +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.constants.mqtt_topics import GATEWAY_TELEMETRY_TOPIC, GATEWAY_ATTRIBUTES_TOPIC, \ + GATEWAY_CONNECT_TOPIC, GATEWAY_DISCONNECT_TOPIC, GATEWAY_ATTRIBUTES_REQUEST_TOPIC, GATEWAY_RPC_TOPIC, \ + GATEWAY_CLAIM_TOPIC +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage +from tb_mqtt_client.entities.gateway.device_disconnect_message import DeviceDisconnectMessage +from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest +from tb_mqtt_client.entities.gateway.gateway_attribute_update import GatewayAttributeUpdate +from tb_mqtt_client.entities.gateway.gateway_claim_request import GatewayClaimRequest +from tb_mqtt_client.entities.gateway.gateway_requested_attribute_response import GatewayRequestedAttributeResponse +from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage +from tb_mqtt_client.service.gateway.message_splitter import GatewayMessageSplitter + +logger = get_logger(__name__) + + +class GatewayMessageAdapter(ABC): + """ + Adapter for converting events to uplink messages and received messages to events. + """ + + def __init__(self, max_payload_size: Optional[int] = None, max_datapoints: Optional[int] = None): + self._splitter = GatewayMessageSplitter(max_payload_size, max_datapoints) + logger.trace("GatewayMessageAdapter initialized with max_payload_size=%s, max_datapoints=%s", + max_payload_size, max_datapoints) + + @property + def splitter(self) -> GatewayMessageSplitter: + """ + Returns the message splitter instance used by this adapter. + This allows for splitting messages into smaller parts if needed. + """ + return self._splitter + + @abstractmethod + def build_uplink_messages( + self, + messages: List[MqttPublishMessage] + ) -> List[MqttPublishMessage]: + """ + Build a list of topic-payload pairs from the given messages. + Each pair consists of a topic string, payload bytes, the number of datapoints, + and a list of futures for delivery confirmation. + """ + pass + + @abstractmethod + def build_device_connect_message_payload(self, + device_connect_message: DeviceConnectMessage, + qos) -> MqttPublishMessage: + """ + Build the payload for a device connect message. + This method should be implemented to handle the specific format of the payload. + """ + pass + + @abstractmethod + def build_device_disconnect_message_payload(self, + device_disconnect_message: DeviceDisconnectMessage, + qos) -> MqttPublishMessage: + """ + Build the payload for a device disconnect message. + This method should be implemented to handle the specific format of the payload. + """ + pass + + @abstractmethod + def build_gateway_attribute_request_payload(self, + attribute_request: GatewayAttributeRequest, + qos) -> MqttPublishMessage: + """ + Build the payload for a gateway attribute request. + This method should be implemented to handle the specific format of the payload. + """ + pass + + @abstractmethod + def build_rpc_response_payload(self, rpc_response: GatewayRPCResponse, qos) -> MqttPublishMessage: + """ + Build the payload for a gateway RPC response. + This method should be implemented to handle the specific format of the payload. + """ + pass + + @abstractmethod + def build_claim_request_payload(self, claim_request: GatewayClaimRequest, qos) -> MqttPublishMessage: + """ + Build the payload for a gateway claim request. + This method should be implemented to handle the specific format of the payload. + """ + pass + + @abstractmethod + def parse_attribute_update(self, data: Dict[str, Any]) -> GatewayAttributeUpdate: + """ + Parse the attribute update payload into an GatewayAttributeUpdate. + This method should be implemented to handle the specific format of the payload. + """ + pass + + @abstractmethod + def parse_gateway_requested_attribute_response(self, + gateway_attribute_request: GatewayAttributeRequest, + data: Dict[str, Any]) -> Union[ + GatewayRequestedAttributeResponse, None]: + """ + Parse the gateway attribute response data into an GatewayAttributeResponse. + This method should be implemented to handle the specific format of the payload. + """ + pass + + @abstractmethod + def parse_rpc_request(self, topic: str, data: Dict[str, Any]) -> GatewayRPCRequest: + """ + Parse the RPC request from the given topic and payload. + This method should be implemented to handle the specific format of the RPC request. + """ + pass + + @abstractmethod + def deserialize_to_dict(self, payload: bytes) -> Dict[str, Any]: + """ + Deserialize incoming payload into a dictionary format, required for parsing responses. + This method should be implemented to handle the specific format of the response. + """ + pass + + +class JsonGatewayMessageAdapter(GatewayMessageAdapter): + """ + JSON implementation of GatewayMessageAdapter. + Builds uplink payloads from uplink message objects and parses JSON payloads into GatewayEvent objects. + """ + + def build_uplink_messages(self, messages: List[MqttPublishMessage]) -> List[MqttPublishMessage]: + if not messages: + logger.trace("No messages to process in build_uplink_messages.") + return [] + + result: List[MqttPublishMessage] = [] + device_groups: Dict[str, List[GatewayUplinkMessage]] = defaultdict(list) + qos = messages[0].qos + + # Group by device + for mqtt_msg in messages: + payload = mqtt_msg.payload + if isinstance(payload, GatewayUplinkMessage): + device_groups[payload.device_name].append(payload) + logger.trace("Queued GatewayUplinkMessage for device='%s'", payload.device_name) + else: + result.append(mqtt_msg) + logger.debug("Non-GatewayUplinkMessage found, sending as is: %s", type(payload).__name__) + + # Process each device group + for device_name, group_msgs in device_groups.items(): + telemetry_msgs = [m for m in group_msgs if m.has_timeseries()] + attr_msgs = [m for m in group_msgs if m.has_attributes()] + built_child_messages: List[MqttPublishMessage] = [] + + if telemetry_msgs: + for ts_batch in self._splitter.split_timeseries(telemetry_msgs): + payload_dict = {device_name: self.pack_timeseries(ts_batch)} + payload_bytes = dumps(payload_dict) + count = ts_batch.timeseries_datapoint_count() + futures = ts_batch.get_delivery_futures() or [] + + mqtt_msg = MqttPublishMessage( + topic=GATEWAY_TELEMETRY_TOPIC, + payload=payload_bytes, + qos=qos, + datapoints=count, + delivery_futures=futures, + main_ts=ts_batch.main_ts, + original_payload=ts_batch + ) + result.append(mqtt_msg) + built_child_messages.append(mqtt_msg) + + if attr_msgs: + for attr_batch in self._splitter.split_attributes(attr_msgs): + payload_dict = {device_name: self.pack_attributes(attr_batch)} + payload_bytes = dumps(payload_dict) + count = attr_batch.attributes_datapoint_count() + futures = attr_batch.get_delivery_futures() or [] + + mqtt_msg = MqttPublishMessage( + topic=GATEWAY_ATTRIBUTES_TOPIC, + payload=payload_bytes, + qos=qos, + datapoints=count, + delivery_futures=futures, + main_ts=attr_batch.main_ts, + original_payload=attr_batch + ) + result.append(mqtt_msg) + built_child_messages.append(mqtt_msg) + + # Link parent futures to child delivery futures + parent_futures = [f for m in messages for f in (m.delivery_futures or []) + if isinstance(m.payload, GatewayUplinkMessage) and m.payload.device_name == device_name] + for parent in parent_futures: + for child_msg in built_child_messages: + for child in child_msg.delivery_futures or []: + future_map.register(parent, [child]) + + logger.trace("Generated %d MqttPublishMessage(s) for gateway uplink.", len(result)) + return result + + def build_device_connect_message_payload(self, + device_connect_message: DeviceConnectMessage, + qos) -> MqttPublishMessage: + """ + Build the payload for a device connect message. + This method serializes the DeviceConnectMessage to JSON format. + """ + try: + payload = dumps(device_connect_message.to_payload_format()) + logger.trace("Built device connect message payload for device='%s'", + device_connect_message.device_name) + return MqttPublishMessage(GATEWAY_CONNECT_TOPIC, payload, qos=1) + except Exception as e: + logger.error("Failed to build device connect message payload: %s", str(e)) + raise ValueError("Invalid device connect message format") from e + + def build_device_disconnect_message_payload(self, + device_disconnect_message: DeviceDisconnectMessage, + qos) -> MqttPublishMessage: + """ + Build the payload for a device disconnect message. + This method serializes the DeviceDisconnectMessage to JSON format. + """ + try: + payload = dumps(device_disconnect_message.to_payload_format()) + logger.trace("Built device disconnect message payload for device='%s'", + device_disconnect_message.device_name) + return MqttPublishMessage(GATEWAY_DISCONNECT_TOPIC, payload, qos) + except Exception as e: + logger.error("Failed to build device disconnect message payload: %s", str(e)) + raise ValueError("Invalid device disconnect message format") from e + + def build_gateway_attribute_request_payload(self, + attribute_request: GatewayAttributeRequest, + qos) -> MqttPublishMessage: + """ + Build the payload for a gateway attribute request. + This method serializes the GatewayAttributeRequest to JSON format. + """ + try: + payload = dumps(attribute_request.to_payload_format()) + logger.trace("Built gateway attribute request payload for device='%s'", + attribute_request.device_session.device_info.device_name) + return MqttPublishMessage(GATEWAY_ATTRIBUTES_REQUEST_TOPIC, payload, qos) + except Exception as e: + logger.error("Failed to build gateway attribute request payload: %s", str(e)) + raise ValueError("Invalid gateway attribute request format") from e + + def build_rpc_response_payload(self, rpc_response: GatewayRPCResponse, qos) -> MqttPublishMessage: + """ + Build the payload for a gateway RPC response. + This method serializes the GatewayRPCResponse to JSON format. + """ + try: + payload = dumps(rpc_response.to_payload_format()) + logger.trace("Built RPC response payload for device='%s', request_id=%i", + rpc_response.device_name, rpc_response.request_id) + return MqttPublishMessage(GATEWAY_RPC_TOPIC, payload, qos) + except Exception as e: + logger.error("Failed to build RPC response payload: %s", str(e)) + raise ValueError("Invalid RPC response format") from e + + def build_claim_request_payload(self, claim_request: GatewayClaimRequest, qos) -> MqttPublishMessage: + """ + Build the payload for a gateway claim request. + This method serializes the GatewayClaimRequest to JSON format. + """ + try: + payload = dumps(claim_request.to_payload_format()) + logger.trace("Built claim request payload for devices: %s", + list(claim_request.devices_requests.keys())) + return MqttPublishMessage(GATEWAY_CLAIM_TOPIC, payload, qos) + except Exception as e: + logger.error("Failed to build claim request payload: %s", str(e)) + raise ValueError("Invalid claim request format") from e + + def parse_attribute_update(self, data: Dict[str, Any]) -> GatewayAttributeUpdate: + try: + device_name = data['device'] + attribute_update = AttributeUpdate._deserialize_from_dict(data['data']) # noqa + return GatewayAttributeUpdate(device_name=device_name, attribute_update=attribute_update) + except Exception as e: + logger.error("Failed to parse attribute update: %s", str(e)) + raise ValueError("Invalid attribute update format") from e + + def parse_gateway_requested_attribute_response(self, + gateway_attribute_request: GatewayAttributeRequest, + data: Dict[str, Any]) -> Union[ + GatewayRequestedAttributeResponse, None]: + """ + Parse the gateway attribute response data into a GatewayRequestedAttributeResponse. + This method extracts the device name, shared and client attributes from the payload. + """ + try: + device_name = data['device'] + client = [] + shared = [] + client_keys_empty = gateway_attribute_request.client_keys is None + shared_keys_empty = gateway_attribute_request.shared_keys is None + if ('value' in data + and not (((not client_keys_empty and len(gateway_attribute_request.client_keys) == 1) + and not gateway_attribute_request.shared_keys) + or ((not shared_keys_empty and len(gateway_attribute_request.shared_keys) == 1) + and not gateway_attribute_request.client_keys))): + # TODO: Skipping case when requested several attributes, but only one is returned, issue on the platform + logger.warning("Received gateway attribute response with single key, but multiply keys expected. " + "Request keys: %s, Response value: %r", + gateway_attribute_request.client_keys + gateway_attribute_request.shared_keys, + data['value']) + client = [] + shared = [] + elif 'value' in data: + if not client_keys_empty and len(gateway_attribute_request.client_keys) == 1: + client = [AttributeEntry(gateway_attribute_request.client_keys[0], data['value'])] + elif not shared_keys_empty and len(gateway_attribute_request.shared_keys) == 1: + shared = [AttributeEntry(gateway_attribute_request.shared_keys[0], data['value'])] + elif 'values' in data: + if not client_keys_empty and len(gateway_attribute_request.client_keys) > 0: + client = [AttributeEntry(k, v) for k, v in data['values'].items() if + k in gateway_attribute_request.client_keys] + if not shared_keys_empty and len(gateway_attribute_request.shared_keys) > 0: + shared = [AttributeEntry(k, v) for k, v in data['values'].items() if + k in gateway_attribute_request.shared_keys] + return GatewayRequestedAttributeResponse(device_name=device_name, + request_id=gateway_attribute_request.request_id, + shared=shared, + client=client) + except Exception as e: + logger.error("Failed to parse gateway requested attribute response: %s", str(e)) + raise ValueError("Invalid gateway requested attribute response format") from e + + def parse_rpc_request(self, topic: str, data: Dict[str, Any]) -> GatewayRPCRequest: + """ + Parse the RPC request from the given topic and payload. + This method deserializes the payload into a GatewayRPCRequest object. + """ + try: + return GatewayRPCRequest._deserialize_from_dict(data) # noqa + except Exception as e: + logger.error("Failed to parse RPC request: %s", str(e)) + raise ValueError("Invalid RPC request format") from e + + def deserialize_to_dict(self, payload: bytes) -> Dict[str, Any]: + """ + Deserialize incoming payload into a dictionary format, required for parsing responses. + This method decodes the payload from bytes to a string and then loads it as a JSON object. + """ + try: + data = loads(payload.decode('utf-8')) + return data + except Exception as e: + logger.error("Failed to deserialize requested attribute response: %s", str(e)) + raise ValueError("Invalid requested attribute response format") from e + + @staticmethod + def pack_attributes(msg: GatewayUplinkMessage) -> Dict[str, Any]: + logger.trace("Packing %d attribute(s)", len(msg.attributes)) + return {attr.key: attr.value for attr in msg.attributes} + + @staticmethod + def pack_timeseries(msg: 'GatewayUplinkMessage') -> Union[Dict[str, Any], List[Dict[str, Any]]]: + entries = [e for entries in msg.timeseries.values() for e in entries] + if not entries: + return {} + + all_ts_none = True + for e in entries: + if e.ts is not None: + all_ts_none = False + break + + if all_ts_none: + result = {e.key: e.value for e in entries} + return [{"ts": msg.main_ts, "values": result}] if msg.main_ts is not None else [result] + + now_ts = msg.main_ts if msg.main_ts is not None else int(datetime.now(UTC).timestamp() * 1000) + grouped = defaultdict(dict) + for e in entries: + ts = e.ts if e.ts is not None else now_ts + grouped[ts][e.key] = e.value + + return [{"ts": ts, "values": values} for ts, values in grouped.items()] diff --git a/tb_mqtt_client/service/gateway/message_sender.py b/tb_mqtt_client/service/gateway/message_sender.py new file mode 100644 index 0000000..1c364a3 --- /dev/null +++ b/tb_mqtt_client/service/gateway/message_sender.py @@ -0,0 +1,177 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from asyncio import Future +from typing import List, Union, Optional + +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.constants import mqtt_topics +from tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage +from tb_mqtt_client.entities.gateway.device_disconnect_message import DeviceDisconnectMessage +from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest +from tb_mqtt_client.entities.gateway.gateway_claim_request import GatewayClaimRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage +from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter +from tb_mqtt_client.service.message_service import MessageService + +logger = get_logger(__name__) + + +class GatewayMessageSender: + """ + Class responsible for sending uplink messages from devices connected to the gateway to the platform. + It handles the serialization of the message and sends to uplink message queue. + """ + + def __init__(self): + self._message_queue: Optional[MessageService] = None + self._message_adapter: Optional[GatewayMessageAdapter] = None + + async def send_uplink_message(self, message: GatewayUplinkMessage, qos=1) -> ( + Optional)[List[Union[PublishResult, Future[PublishResult]]]]: + """ + Sends a list of uplink messages to the platform. + + :param message: List of GatewayUplinkMessage objects to be sent. + :param qos: Quality of Service level for the MQTT message. + :returns: List of PublishResult or Future[PublishResult] if successful, None if failed. + """ + if self._message_queue is None: + logger.error("Cannot send uplink messages. Message queue is not set, do you connected to the platform?") + return None + if not message.has_timeseries() and not message.has_attributes(): + logger.warning("Uplink message does not contain timeseries or attributes, nothing to send.") + return None + futures = [] + if message.has_timeseries(): + mqtt_message = MqttPublishMessage( + topic=mqtt_topics.GATEWAY_TELEMETRY_TOPIC, + payload=message, + qos=qos + ) + await self._message_queue.publish(mqtt_message) + futures.extend(mqtt_message.delivery_futures) + if message.has_attributes(): + mqtt_message = MqttPublishMessage( + topic=mqtt_topics.GATEWAY_ATTRIBUTES_TOPIC, + payload=message, + qos=qos + ) + await self._message_queue.publish(mqtt_message) + futures.extend(mqtt_message.delivery_futures) + return futures + + async def send_device_connect(self, device_connect_message: DeviceConnectMessage, qos=1) -> ( + Optional)[List[Union[PublishResult, Future[PublishResult]]]]: + """ + Sends a device connect message to the platform. + + :param device_connect_message: DeviceConnectMessage object containing the device connection details. + :param qos: Quality of Service level for the MQTT message. + :returns: List of PublishResult or Future[PublishResult] if successful, None if failed. + """ + if self._message_queue is None: + logger.error( + "Cannot send device connect message. Message queue is not set, do you connected to the platform?") + return None + mqtt_message = self._message_adapter.build_device_connect_message_payload( + device_connect_message=device_connect_message, qos=qos) + await self._message_queue.publish(mqtt_message) + return mqtt_message.delivery_futures + + async def send_device_disconnect(self, device_disconnect_message: DeviceDisconnectMessage, qos=1) -> ( + Optional)[List[Union[PublishResult, Future[PublishResult]]]]: + """ + Sends a device disconnect message to the platform. + + :param device_disconnect_message: DeviceDisconnectMessage object containing the device disconnection details. + :param qos: Quality of Service level for the MQTT message. + :returns: List of PublishResult or Future[PublishResult] if successful, None if failed. + """ + if self._message_queue is None: + logger.error( + "Cannot send device disconnect message. Message queue is not set, do you connected to the platform?") + return None + mqtt_message = self._message_adapter.build_device_disconnect_message_payload( + device_disconnect_message=device_disconnect_message, qos=qos) + await self._message_queue.publish(mqtt_message) + return mqtt_message.delivery_futures + + async def send_attributes_request(self, attribute_request: GatewayAttributeRequest, qos=1) -> ( + Optional)[List[Union[PublishResult, Future[PublishResult]]]]: + """ + Sends an attribute request message to the platform. + + :param attribute_request: GatewayAttributeRequest object containing the attributes to be requested. + :param qos: Quality of Service level for the MQTT message. + :returns: List of PublishResult or Future[PublishResult] if successful, None if failed. + """ + if self._message_queue is None: + logger.error("Cannot send attribute request. Message queue is not set, do you connected to the platform?") + return None + mqtt_message = self._message_adapter.build_gateway_attribute_request_payload( + attribute_request=attribute_request, qos=qos) + await self._message_queue.publish(mqtt_message) + return mqtt_message.delivery_futures + + async def send_rpc_response(self, rpc_response: GatewayRPCResponse, qos=1) -> ( + Optional)[List[Union[PublishResult, Future[PublishResult]]]]: + """ + Sends an RPC response message to the platform. + + :param rpc_response: GatewayRPCResponse object containing the RPC response details. + :param qos: Quality of Service level for the MQTT message. + :returns: List of PublishResult or Future[PublishResult] if successful, None if failed. + """ + if self._message_queue is None: + logger.error("Cannot send RPC response. Message queue is not set, do you connected to the platform?") + return None + mqtt_message = self._message_adapter.build_rpc_response_payload(rpc_response=rpc_response, qos=qos) + await self._message_queue.publish(mqtt_message) + return mqtt_message.delivery_futures + + async def send_claim_request(self, claim_request: GatewayClaimRequest, qos=1) -> ( + Optional)[List[Union[PublishResult, Future[PublishResult]]]]: + """ + Sends a claim request message to the platform. + + :param claim_request: GatewayClaimRequest object containing the claim request details. + :param qos: Quality of Service level for the MQTT message. + :returns: List of PublishResult or Future[PublishResult] if successful, None if failed. + """ + if self._message_queue is None: + logger.error("Cannot send claim request. Message queue is not set, do you connected to the platform?") + return None + mqtt_message = self._message_adapter.build_claim_request_payload(claim_request=claim_request, qos=qos) + await self._message_queue.publish(mqtt_message) + return mqtt_message.delivery_futures + + def set_message_queue(self, message_queue: MessageService): + """ + Sets the message queue for sending uplink messages. + + :param message_queue: An instance of MessageQueue to be used for sending messages. + """ + self._message_queue = message_queue + + def set_message_adapter(self, message_adapter: GatewayMessageAdapter): + """ + Sets the message adapter for serializing uplink messages. + + :param message_adapter: An instance of GatewayMessageAdapter to be used for serializing messages. + """ + self._message_adapter = message_adapter diff --git a/tb_mqtt_client/service/gateway/message_splitter.py b/tb_mqtt_client/service/gateway/message_splitter.py new file mode 100644 index 0000000..79a3796 --- /dev/null +++ b/tb_mqtt_client/service/gateway/message_splitter.py @@ -0,0 +1,196 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from collections import defaultdict +from typing import List, Dict, Tuple +from uuid import uuid4 + +from tb_mqtt_client.common.async_utils import future_map +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage, GatewayUplinkMessageBuilder, \ + DEFAULT_FIELDS_SIZE +from tb_mqtt_client.service.base_message_splitter import BaseMessageSplitter + +logger = get_logger(__name__) + + +class GatewayMessageSplitter(BaseMessageSplitter): + DEFAULT_MAX_PAYLOAD_SIZE = 55000 # Default max payload size in bytes + + def __init__(self, max_payload_size: int = 55000, max_datapoints: int = 0): + if max_payload_size is not None and max_payload_size > 0: + self._max_payload_size = max_payload_size + else: + self._max_payload_size = self.DEFAULT_MAX_PAYLOAD_SIZE + self._max_payload_size = self._max_payload_size - DEFAULT_FIELDS_SIZE + self._max_datapoints = max_datapoints if max_datapoints is not None and max_datapoints > 0 else 0 + logger.trace("GatewayMessageSplitter initialized with max_payload_size=%d, max_datapoints=%d", + self._max_payload_size, self._max_datapoints) + + def split_timeseries(self, messages: List[GatewayUplinkMessage]) -> List[GatewayUplinkMessage]: + logger.trace("Splitting gateway timeseries for %d messages", len(messages)) + if (len(messages) == 1 + and ((messages[0].attributes_datapoint_count() + messages[0].timeseries_datapoint_count() + <= self._max_datapoints) or self._max_datapoints == 0) + and messages[0].size <= self._max_payload_size): + return messages + + result: List[GatewayUplinkMessage] = [] + grouped: Dict[Tuple[str, str], List[GatewayUplinkMessage]] = defaultdict(list) + + for msg in messages: + grouped[(msg.device_name, msg.device_profile)].append(msg) + + for (device_name, device_profile), group_msgs in grouped.items(): + all_ts_entries: List[TimeseriesEntry] = [] + parent_futures: List[asyncio.Future] = [] + + for msg in group_msgs: + if msg.has_timeseries(): + for ts_group in msg.timeseries.values(): + all_ts_entries.extend(ts_group) + parent_futures.extend(msg.get_delivery_futures() or []) + + builder = None + size = 0 + point_count = 0 + names_len = len(device_name) + len(device_profile) + + for entry in all_ts_entries: + exceeds_size = builder and size + entry.size > self._max_payload_size - names_len + exceeds_points = 0 < self._max_datapoints <= point_count + + if not builder or exceeds_size or exceeds_points: + if builder: + shared_future = asyncio.get_running_loop().create_future() + shared_future.uuid = uuid4() + builder.add_delivery_futures(shared_future) + + built = builder.build() + result.append(built) + for parent in parent_futures: + future_map.register(parent, [shared_future]) + logger.trace("Flushed timeseries batch: size=%d, count=%d", size, point_count) + + builder = GatewayUplinkMessageBuilder() \ + .set_device_name(device_name) \ + .set_device_profile(device_profile) + size = 0 + point_count = 0 + + builder.add_timeseries(entry) + size += entry.size + point_count += 1 + + if builder and builder._timeseries: # noqa + shared_future = asyncio.get_running_loop().create_future() + shared_future.uuid = uuid4() + builder.add_delivery_futures(shared_future) + built = builder.build() + result.append(built) + for parent in parent_futures: + future_map.register(parent, [shared_future]) + logger.trace("Flushed final timeseries batch: size=%d, count=%d", size, point_count) + + return result + + def split_attributes(self, messages: List[GatewayUplinkMessage]) -> List[GatewayUplinkMessage]: + logger.trace("Splitting gateway attributes for %d messages", len(messages)) + if (len(messages) == 1 + and ((messages[0].attributes_datapoint_count() + messages[0].timeseries_datapoint_count() + <= self._max_datapoints) or self._max_datapoints == 0) + and messages[0].size <= self._max_payload_size): + return messages + + result: List[GatewayUplinkMessage] = [] + grouped: Dict[Tuple[str, str], List[GatewayUplinkMessage]] = defaultdict(list) + + for msg in messages: + grouped[(msg.device_name, msg.device_profile)].append(msg) + + for (device_name, device_profile), group_msgs in grouped.items(): + all_attrs: List[AttributeEntry] = [] + parent_futures: List[asyncio.Future] = [] + + for msg in group_msgs: + if msg.has_attributes(): + all_attrs.extend(msg.attributes) + parent_futures.extend(msg.get_delivery_futures() or []) + + builder = None + size = 0 + point_count = 0 + + for attr in all_attrs: + exceeds_size = builder and size + attr.size > self._max_payload_size + exceeds_points = 0 < self._max_datapoints <= point_count + + if not builder or exceeds_size or exceeds_points: + if builder and builder._attributes: # noqa + shared_future = asyncio.get_running_loop().create_future() + shared_future.uuid = uuid4() + builder.add_delivery_futures(shared_future) + + built = builder.build() + result.append(built) + for parent in parent_futures: + future_map.register(parent, [shared_future]) + logger.trace("Flushed attribute batch: size=%d, count=%d", size, point_count) + + builder = GatewayUplinkMessageBuilder() \ + .set_device_name(device_name) \ + .set_device_profile(device_profile) + size = 0 + point_count = 0 + + builder.add_attributes(attr) + size += attr.size + point_count += 1 + + if builder and builder._attributes: # noqa + shared_future = asyncio.get_running_loop().create_future() + shared_future.uuid = uuid4() + builder.add_delivery_futures(shared_future) + + built = builder.build() + result.append(built) + for parent in parent_futures: + future_map.register(parent, [shared_future]) + logger.trace("Flushed final attribute batch: size=%d, count=%d", size, point_count) + + return result + + @property + def max_payload_size(self) -> int: + return self._max_payload_size + + @max_payload_size.setter + def max_payload_size(self, value: int): + old = self._max_payload_size + self._max_payload_size = value if value > 0 else self.DEFAULT_MAX_PAYLOAD_SIZE + self._max_payload_size = self._max_payload_size - DEFAULT_FIELDS_SIZE + logger.debug("Updated max_payload_size: %d -> %d", old, self._max_payload_size) + + @property + def max_datapoints(self) -> int: + return self._max_datapoints + + @max_datapoints.setter + def max_datapoints(self, value: int): + old = self._max_datapoints + self._max_datapoints = value if value > 0 else 0 + logger.debug("Updated max_datapoints: %d -> %d", old, self._max_datapoints) diff --git a/tb_mqtt_client/service/message_service.py b/tb_mqtt_client/service/message_service.py new file mode 100644 index 0000000..20bb41b --- /dev/null +++ b/tb_mqtt_client/service/message_service.py @@ -0,0 +1,422 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from contextlib import suppress +from typing import List, Optional, Tuple + +from tb_mqtt_client.common.exceptions import BackpressureException +from tb_mqtt_client.common.logging_utils import get_logger, TRACE_LEVEL +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.common.queue import AsyncDeque +from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit, EMPTY_RATE_LIMIT +from tb_mqtt_client.common.rate_limit.rate_limiter import RateLimiter +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessage +from tb_mqtt_client.service.device.message_adapter import MessageAdapter +from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter +from tb_mqtt_client.service.mqtt_manager import MQTTManager + +logger = get_logger(__name__) + + +class MessageService: + _QUEUE_COOLDOWN = 0.01 # seconds to sleep before the next iteration in case of empty queues + + def __init__(self, + mqtt_manager: MQTTManager, + main_stop_event: asyncio.Event, + device_rate_limiter: RateLimiter, + message_adapter: MessageAdapter, + max_queue_size: int = 1000000, + batch_collect_max_time_ms: int = 10, + batch_collect_max_count: int = 500, + gateway_message_adapter: Optional[GatewayMessageAdapter] = None, + gateway_rate_limiter: Optional[RateLimiter] = None): + self._main_stop_event = main_stop_event + self._batch_max_time = batch_collect_max_time_ms / 1000 + self._batch_max_count = batch_collect_max_count + self._mqtt_manager = mqtt_manager + # Patching MQTT client to add an ability to use retry logic + self._mqtt_manager.patch_client_for_retry_logic(self.put_retry_message) + self._device_rate_limiter = device_rate_limiter + self._gateway_rate_limiter = gateway_rate_limiter + self._rate_limit_ready = asyncio.Event() + self._backpressure = self._mqtt_manager.backpressure + self._retry_by_qos_queue: AsyncDeque = AsyncDeque(maxlen=max_queue_size) + self._initial_queue: AsyncDeque = AsyncDeque(maxlen=max_queue_size) + self._service_queue: AsyncDeque = AsyncDeque(maxlen=max_queue_size) + self._service_message_worker = MessageQueueWorker("ServiceMessageWorker", + self._service_queue, + self._main_stop_event, + self._mqtt_manager, + self._device_rate_limiter, + self._gateway_rate_limiter) + self._device_uplink_messages_queue: AsyncDeque = AsyncDeque(maxlen=max_queue_size) + self._device_uplink_message_worker = MessageQueueWorker("DeviceUplinkMessageWorker", + self._device_uplink_messages_queue, + self._main_stop_event, + self._mqtt_manager, + self._device_rate_limiter, + self._gateway_rate_limiter) + self._gateway_uplink_messages_queue: AsyncDeque = AsyncDeque(maxlen=max_queue_size) + self._gateway_uplink_message_worker = MessageQueueWorker("GatewayUplinkMessageWorker", + self._gateway_uplink_messages_queue, + self._main_stop_event, + self._mqtt_manager, + self._device_rate_limiter, + self._gateway_rate_limiter) + self._active = asyncio.Event() + self._wakeup_event = asyncio.Event() + self._can_process_device_uplink_event = asyncio.Event() + self._can_process_device_uplink_event.set() + self._can_process_gateway_uplink_event = asyncio.Event() + self._can_process_gateway_uplink_event.set() + self._active.set() + self._adapter = message_adapter + self._gateway_adapter = gateway_message_adapter + + self._retry_by_qos_task = asyncio.create_task(self._dispatch_retry_by_qos_queue_loop()) + self._initial_queue_task = asyncio.create_task(self._dispatch_initial_queue_loop()) + self._service_queue_task = asyncio.create_task( + self._dispatch_queue_loop(self._service_queue, self._service_message_worker)) + self._device_uplink_messages_queue_task = asyncio.create_task( + self._dispatch_queue_loop(self._device_uplink_messages_queue, self._device_uplink_message_worker)) + self._gateway_uplink_messages_queue_task = asyncio.create_task( + self._dispatch_queue_loop(self._gateway_uplink_messages_queue, self._gateway_uplink_message_worker)) + + self._rate_limit_refill_task = asyncio.create_task(self._rate_limit_refill_loop()) + self.__print_queue_statistics_task = asyncio.create_task(self.print_queues_statistics()) + + async def publish(self, message: MqttPublishMessage) -> Optional[List[asyncio.Future[PublishResult]]]: + """ + Publish a message to the message queue. + :param message: The MqttPublishMessage to publish. + :return: A list of futures for delivery results. + """ + try: + if logger.isEnabledFor(TRACE_LEVEL): + logger.trace( + f"Pushing message to queue with delivery futures: {[f.uuid for f in message.delivery_futures]}") + await self._initial_queue.put(message) + except Exception as e: + logger.error("Failed to push message to queue: %s", e) + for future in message.delivery_futures: + if future and not future.done(): + future.set_result(PublishResult(message.topic, message.qos, -1, len(message.original_payload), -1)) + + async def _dispatch_initial_queue_loop(self): + """ + Loop to process messages from the initial queue and dispatch them to the appropriate queues. + """ + while not self._main_stop_event.is_set() and self._active.is_set(): + try: + if not self._mqtt_manager.is_connected(): + await asyncio.sleep(self._QUEUE_COOLDOWN) + continue + peeked_batch = await self._initial_queue.peek_batch(self._batch_max_count) + + gateway_messages = [] + device_messages = [] + for message in peeked_batch: + if isinstance(message.original_payload, bytes): + # If the message is a raw bytes payload, put it directly into the service queue + await self._service_queue.put(message) + elif isinstance(message.original_payload, GatewayUplinkMessage): + # If the message is a GatewayUplinkMessage, process it with the gateway adapter + gateway_messages.append(message) + elif isinstance(message.original_payload, DeviceUplinkMessage): + # If the message is a DeviceUplinkMessage, process it with the device adapter + device_messages.append(message) + else: + logger.warning("Unknown message type in initial queue: %s", + type(message.original_payload)) + + if gateway_messages: + # Process gateway messages in batches + messages = self._gateway_adapter.build_uplink_messages(gateway_messages) + await self._gateway_uplink_messages_queue.extend(messages) + + if device_messages: + # Process device messages in batches + messages = self._adapter.build_uplink_messages(device_messages) + await self._device_uplink_messages_queue.extend(messages) + + await self._initial_queue.pop_n(len(peeked_batch)) + + if self._initial_queue.is_empty(): + await asyncio.sleep(self._QUEUE_COOLDOWN) + + except asyncio.CancelledError: + break + except Exception as e: + logger.exception("Dispatch loop error: %s", e) + await asyncio.sleep(0.5) + + async def _dispatch_queue_loop(self, queue: AsyncDeque, worker: 'MessageQueueWorker'): + """Loop to process messages from the service queue.""" + while not self._main_stop_event.is_set() and self._active.is_set(): + message: Optional[MqttPublishMessage] = None + try: + if not self._mqtt_manager.is_connected(): + await asyncio.sleep(self._QUEUE_COOLDOWN) + continue + message = await queue.get() + if not message: + await asyncio.sleep(0.01) + continue + logger.trace("Processing message from queue: %s, message_id: %s, message payload: %s", + message.topic, message.message_id, message.original_payload) + expected_duration, expected_tokens, triggered_rate_limit = await worker.process(message) + if triggered_rate_limit: + logger.trace("Reinserting message to the front of the queue: %s, message payload: %s", + message.uuid, message.original_payload) + await queue.reinsert_front(message) + triggered_rate_limit.set_required_tokens(expected_duration, expected_tokens) + await triggered_rate_limit.required_tokens_ready.wait() + + if queue.is_empty(): + await asyncio.sleep(self._QUEUE_COOLDOWN) + + except asyncio.CancelledError: + break + except BackpressureException: + logger.warning("Backpressure exception occurred, re-inserting message to the front of the queue: %s", + message.uuid if message else "Unknown") + if message: + await queue.reinsert_front(message) + await asyncio.sleep(1) + except Exception as e: + logger.exception("Service queue loop error: %s", e) + if message: + await queue.reinsert_front(message) + await asyncio.sleep(1) + + async def _dispatch_retry_by_qos_queue_loop(self): + while not self._main_stop_event.is_set() and self._active.is_set(): + message = None + try: + if not self._mqtt_manager.is_connected(): + await asyncio.sleep(self._QUEUE_COOLDOWN) + continue + message = await self._retry_by_qos_queue.get() + if not message: + await asyncio.sleep(self._QUEUE_COOLDOWN) + continue + + logger.trace("Retrying QoS message: %s", message) + + if isinstance(message.original_payload, bytes): + await self._service_queue.reinsert_front(message) + elif isinstance(message.original_payload, GatewayUplinkMessage): + await self._gateway_uplink_messages_queue.reinsert_front(message) + elif isinstance(message.original_payload, DeviceUplinkMessage): + await self._device_uplink_messages_queue.reinsert_front(message) + else: + logger.warning("Unknown message type in retry queue: %s", + type(message.original_payload)) + + except asyncio.CancelledError: + break + except Exception as e: + logger.exception("Retry QoS dispatch error: %s", e) + if message: + await self._retry_by_qos_queue.reinsert_front(message) + await asyncio.sleep(0.5) + + async def shutdown(self): + logger.debug("Shutting down MessageQueue...") + self._active.clear() + self._wakeup_event.set() + + self._retry_by_qos_task.cancel() + self._initial_queue_task.cancel() + self._service_queue_task.cancel() + self._device_uplink_messages_queue_task.cancel() + self._gateway_uplink_messages_queue_task.cancel() + if self._rate_limit_refill_task: + self._rate_limit_refill_task.cancel() + self.__print_queue_statistics_task.cancel() + with suppress(asyncio.CancelledError): + await self._retry_by_qos_task + await self._initial_queue_task + await self._rate_limit_refill_task + await self.__print_queue_statistics_task + + await self.clear() + + logger.debug("MessageQueue shutdown complete, message queue size: %d", + self._initial_queue.size()) + + @staticmethod + async def _cancel_tasks(tasks: set[asyncio.Task]): + for task in list(tasks): + task.cancel() + with suppress(asyncio.CancelledError): + await asyncio.gather(*tasks, return_exceptions=True) + tasks.clear() + + def is_empty(self): + return self._initial_queue.is_empty() + + async def clear(self): + logger.debug("Clearing message queue...") + for queue in [self._initial_queue, self._service_queue, + self._device_uplink_messages_queue, self._gateway_uplink_messages_queue]: + while not queue.is_empty(): + message: MqttPublishMessage = await queue.get() + for future in message.delivery_futures: + if future and not future.done(): + future.set_result(PublishResult( + topic=message.topic, + qos=message.qos, + message_id=-1, + payload_size=message.payload_size if isinstance(message.payload, + bytes) else message.payload.size, + reason_code=-1 + )) + logger.debug("Message queue cleared.") + + async def _rate_limit_refill_loop(self): + try: + while self._active.is_set(): + await asyncio.sleep(1.0) + await self._refill_rate_limits() + logger.trace("Rate limits refilled, state: %s", self._device_rate_limiter) + except asyncio.CancelledError: + pass + + async def _refill_rate_limits(self): + for rl in self._device_rate_limiter.values(): + if rl: + await rl.refill() + if self._gateway_rate_limiter: + for rl in self._gateway_rate_limiter.values(): + if rl: + await rl.refill() + + def set_gateway_message_adapter(self, message_adapter: GatewayMessageAdapter): + self._gateway_adapter = message_adapter + + async def put_retry_message(self, message: MqttPublishMessage): + await self._retry_by_qos_queue.put(message) + + async def print_queues_statistics(self): + """ + Prints the current statistics of message queues. + """ + while self._active.is_set() and not self._main_stop_event.is_set(): + retry_queue_size = self._retry_by_qos_queue.size() + initial_queue_size = self._initial_queue.size() + service_queue_size = self._service_queue.size() + device_uplink_queue_size = self._device_uplink_messages_queue.size() + gateway_uplink_queue_size = self._gateway_uplink_messages_queue.size() + active = self._active.is_set() + logger.info("MessageQueue Statistics: " + "Initial Queue Size: %d, " + "Service Queue Size: %d, " + "Device Uplink Queue Size: %d, " + "Gateway Uplink Queue Size: %d, " + "Retry Queue Size: %d, " + "Active: %s", + initial_queue_size, + service_queue_size, + device_uplink_queue_size, + gateway_uplink_queue_size, + retry_queue_size, + active) + await asyncio.sleep(60) + + +class MessageQueueWorker: + def __init__(self, + name, + queue: AsyncDeque, + stop_event: asyncio.Event, + mqtt_manager: MQTTManager, + device_rate_limiter: RateLimiter, + gateway_rate_limiter: Optional[RateLimiter] = None): + self.name = name + self._queue = queue + self._stop_event = stop_event + self._mqtt_manager = mqtt_manager + self._device_rate_limiter = device_rate_limiter + self._gateway_rate_limiter = gateway_rate_limiter + + async def process(self, message: MqttPublishMessage) -> Tuple[Optional[int], Optional[int], Optional[RateLimit]]: + message_rate_limit, datapoints_rate_limit = self._get_rate_limits_for_message(message) + if message_rate_limit.has_limit() or datapoints_rate_limit.has_limit(): + triggered_rate_limit_entry, expected_tokens, rate_limit = await self.check_rate_limits_for_message( + datapoints_count=message.datapoints, + message_rate_limit=message_rate_limit, + datapoints_rate_limit=datapoints_rate_limit) + + if triggered_rate_limit_entry is not None: + triggered_duration = triggered_rate_limit_entry[1] + logger.debug("Rate limit %r per %r seconds hit by message with expected tokens: %r.", + triggered_rate_limit_entry[0], triggered_duration, expected_tokens) + return triggered_duration, expected_tokens, rate_limit + + await self._consume_rate_limits_for_message(datapoints_count=message.datapoints, + message_rate_limit=message_rate_limit, + datapoints_rate_limit=datapoints_rate_limit) + await self._mqtt_manager.publish(message) + return None, None, None + + def _get_rate_limits_for_message(self, message: MqttPublishMessage) -> Tuple[RateLimit, RateLimit]: + + message_rate_limit = EMPTY_RATE_LIMIT + datapoints_rate_limit = EMPTY_RATE_LIMIT + + if message.is_device_message: + if message.is_service_message: + message_rate_limit = self._device_rate_limiter.message_rate_limit + else: + message_rate_limit = self._device_rate_limiter.telemetry_message_rate_limit + datapoints_rate_limit = self._device_rate_limiter.telemetry_datapoints_rate_limit + + else: + if message.is_service_message: + message_rate_limit = self._gateway_rate_limiter.message_rate_limit + else: + message_rate_limit = self._gateway_rate_limiter.telemetry_message_rate_limit + datapoints_rate_limit = self._gateway_rate_limiter.telemetry_datapoints_rate_limit + + return message_rate_limit, datapoints_rate_limit + + @staticmethod + async def check_rate_limits_for_message(datapoints_count: int, + message_rate_limit: RateLimit, + datapoints_rate_limit: RateLimit) -> Tuple[Optional[Tuple[int, int]], + int, + Optional[RateLimit]]: + if message_rate_limit and message_rate_limit.has_limit(): + triggered_rate_limit_entry = await message_rate_limit.try_consume(1) + if triggered_rate_limit_entry: + return triggered_rate_limit_entry, 1, message_rate_limit + if datapoints_rate_limit and datapoints_rate_limit.has_limit(): + triggered_rate_limit_entry = await datapoints_rate_limit.try_consume(datapoints_count) + if triggered_rate_limit_entry: + return triggered_rate_limit_entry, datapoints_count, datapoints_rate_limit + return None, 0, None + + @staticmethod + async def _consume_rate_limits_for_message(datapoints_count: int, + message_rate_limit: RateLimit, + datapoints_rate_limit: RateLimit) -> None: + if message_rate_limit and message_rate_limit.has_limit(): + await message_rate_limit.consume(1) + if datapoints_rate_limit and datapoints_rate_limit.has_limit(): + await datapoints_rate_limit.consume(datapoints_count) diff --git a/tb_mqtt_client/service/mqtt_manager.py b/tb_mqtt_client/service/mqtt_manager.py new file mode 100644 index 0000000..5dfe4d2 --- /dev/null +++ b/tb_mqtt_client/service/mqtt_manager.py @@ -0,0 +1,531 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import ssl +from asyncio import sleep +from contextlib import suppress +from time import monotonic +from typing import Optional, Callable, Dict, Union, Tuple, Coroutine, Any +from uuid import uuid4 + +from gmqtt import Client as GMQTTClient, Subscription + +from tb_mqtt_client.common.async_utils import await_or_stop, future_map, run_coroutine_sync +from tb_mqtt_client.common.exceptions import BackpressureException +from tb_mqtt_client.common.gmqtt_patch import PatchUtils, PublishPacket +from tb_mqtt_client.common.logging_utils import get_logger, TRACE_LEVEL +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.common.rate_limit.backpressure_controller import BackpressureController +from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit +from tb_mqtt_client.common.rate_limit.rate_limiter import RateLimiter +from tb_mqtt_client.common.request_id_generator import RPCRequestIdProducer, AttributeRequestIdProducer +from tb_mqtt_client.constants import mqtt_topics +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.service.device.handlers.rpc_response_handler import RPCResponseHandler +from tb_mqtt_client.service.device.message_adapter import MessageAdapter + +logger = get_logger(__name__) + +QUOTA_EXCEEDED = 0x97 # MQTT 5 reason code (151) +IMPLEMENTATION_SPECIFIC_ERROR = 0x83 # MQTT 5 reason code (131) + + +class MQTTManager: + _PUBLISH_TIMEOUT = 10.0 # Default timeout for publish operations + + def __init__( + self, + client_id: str, + main_stop_event: asyncio.Event, + message_adapter: MessageAdapter, + on_connect: Optional[Callable[[], Coroutine[Any, Any, None]]] = None, + on_disconnect: Optional[Callable[[], Coroutine[Any, Any, None]]] = None, + on_publish_result: Optional[Callable[[PublishResult], Coroutine[Any, Any, None]]] = None, + rate_limits_handler: Optional[Callable[[RPCResponse], Coroutine[Any, Any, None]]] = None, + rpc_response_handler: Optional[RPCResponseHandler] = None, + ): + self._main_stop_event = main_stop_event + self._message_adapter = message_adapter + self._patch_utils: PatchUtils = PatchUtils(None, self._main_stop_event, 1) + self._patch_utils.patch_gmqtt_protocol_connection_lost() + self._patch_utils.patch_mqtt_handler_disconnect() + + self._client: GMQTTClient = GMQTTClient(client_id) + self._patch_utils.client = self._client + self._patch_utils.patch_handle_connack() + self._patch_utils.apply(self._handle_puback_reason_code) + self._client.on_connect = self._on_connect_internal + self._client.on_disconnect = self._on_disconnect_internal + self._client.on_message = self._on_message_internal + self._client.on_publish = self._on_publish_internal + self._client.on_subscribe = self._on_subscribe_internal + self._client.on_unsubscribe = self._on_unsubscribe_internal + + self._on_connect_callback = on_connect + self._on_disconnect_callback = on_disconnect + self._on_publish_result_callback = on_publish_result + + self._connected_event = asyncio.Event() + self._handlers: Dict[str, Callable[[str, bytes], Coroutine[Any, Any, None]]] = {} + + self._pending_publishes: Dict[int, Tuple[asyncio.Future[PublishResult], MqttPublishMessage, float]] = {} + self._publish_monitor_task = asyncio.create_task(self._monitor_ack_timeouts()) + + self._pending_subscriptions: Dict[int, asyncio.Future] = {} + self._pending_unsubscriptions: Dict[int, asyncio.Future] = {} + self._rpc_response_handler = rpc_response_handler or RPCResponseHandler() + self.register_handler(mqtt_topics.DEVICE_RPC_RESPONSE_TOPIC_FOR_SUBSCRIPTION, self._rpc_response_handler.handle) + + self._backpressure = BackpressureController(self._main_stop_event) + self.__rate_limits_handler = rate_limits_handler + self.__rate_limits_retrieved = False + self.__gateway_rate_limits_retrieved = False + self.__rate_limiter: Optional[RateLimiter] = None + self.__gateway_rate_limiter: Optional[RateLimiter] = None + self.__is_gateway = False + self.__is_waiting_for_rate_limits_publish = True # True to prevent publishing before rate limits are retrieved + self._rate_limits_ready_event = asyncio.Event() + + async def connect(self, host: str, port: int = 1883, username: Optional[str] = None, + password: Optional[str] = None, tls: bool = False, + keepalive: int = 60, ssl_context: Optional[ssl.SSLContext] = None): + if username: + self._client.set_auth_credentials(username, password) + + if tls: + if ssl_context is None: + ssl_context = ssl.create_default_context() + await self._client.connect(host, port, ssl=ssl_context, keepalive=keepalive, raise_exc=False) + else: + await self._client.connect(host, port, keepalive=keepalive, raise_exc=False) + + while not self._client.is_connected and not self._main_stop_event.is_set(): + try: + await self._connected_event.wait() + break + except Exception as exc: + logger.warning("MQTT connection failed, waiting for connection: %s", str(exc)) + + def is_connected(self) -> bool: + return (self._client.is_connected + and self._connected_event.is_set() + and self.__rate_limits_retrieved + and (not self.__is_gateway or self.__gateway_rate_limits_retrieved)) + + async def disconnect(self): + try: + await self._client.disconnect() + except ConnectionResetError: + logger.debug("Connection reset error during disconnect, ignoring.") + except Exception as e: + logger.error("Error during MQTT disconnect: %s", str(e)) + await asyncio.sleep(0.2) + self._connected_event.clear() + self.__rate_limits_retrieved = False + self.__is_waiting_for_rate_limits_publish = True + self._rate_limits_ready_event.clear() + + async def publish(self, + message: MqttPublishMessage, + force=False): + + if not force: + if not self.__rate_limits_retrieved and not self.__is_waiting_for_rate_limits_publish: + raise RuntimeError("Cannot publish before rate limits are retrieved.") + try: + if not self._rate_limits_ready_event.is_set(): + await await_or_stop(self._rate_limits_ready_event.wait(), self._main_stop_event, timeout=10) + except asyncio.TimeoutError: + if not self.__is_waiting_for_rate_limits_publish: + logger.warning("Timeout waiting for rate limits, requesting them now.") + await self.__request_rate_limits() + raise RuntimeError("Timeout waiting for rate limits.") + + if not force and self._backpressure.should_pause(): + logger.trace("Backpressure active. Publishing suppressed.") + raise BackpressureException("Publishing temporarily paused due to backpressure.") + + if not message.dup: + return await self.process_regular_publish(message, message.qos) + else: + # If a message is a duplicate, it should be sent immediately + if logger.isEnabledFor(TRACE_LEVEL): + logger.trace("Processing duplicate message with topic: %s, qos: %d, payload size: %d", + message.topic, message.qos, len(message.payload)) + + protocol = self._client._connection._protocol # noqa + + if protocol: + try: + mid, rebuilt = PublishPacket.build_package( + message=message, + protocol=protocol, + mid=message.message_id + ) + self._client._connection.send_package(rebuilt) + logger.trace("Retransmitted message mid=%r", mid) + except Exception as e: + logger.warning("Error during retransmission: %s", e) + logger.debug("Failed to retransmit message: %r", message, exc_info=e) + else: + logger.warning("Cannot retransmit, MQTT protocol unavailable.") + self._client._persistent_storage.push_message_nowait(message.message_id, message) # noqa + + async def process_regular_publish(self, message: MqttPublishMessage, qos: int = 1): + + mqtt_future = asyncio.get_event_loop().create_future() + mqtt_future.uuid = uuid4() + if logger.isEnabledFor(TRACE_LEVEL): + logger.trace( + "Publishing message with topic: %s, qos: %d, " + "payload size: %d, mqtt_future id: %r, delivery futures: %r", + message.topic, qos, len(message.payload), + mqtt_future.uuid, [f.uuid for f in message.delivery_futures]) + if message.delivery_futures is not None: + await self._add_future_chain_processing(mqtt_future, message) # noqa + + mid, package = self._client._connection.publish(message) # noqa + + message.mark_as_sent(mid) + + if qos > 0: + logger.trace("Publishing mid=%s, storing publish main future with id: %r", + mid, mqtt_future.uuid) + self._pending_publishes[mid] = (mqtt_future, message, monotonic()) + self._client._persistent_storage.push_message_nowait(mid, message) # noqa + else: + mqtt_future.set_result(PublishResult(message.topic, qos, -1, message.payload_size, 0)) + + async def subscribe(self, topic: Union[str, Subscription], qos: int = 1) -> asyncio.Future: + sub_future = asyncio.get_event_loop().create_future() + sub_future.uuid = uuid4() + subscription = Subscription(topic, qos=qos) if isinstance(topic, str) else topic + + if self.__rate_limiter: + await self.__rate_limiter.message_rate_limit.consume() + mid = self._client._connection.subscribe([subscription]) # noqa + self._pending_subscriptions[mid] = sub_future + return sub_future + + async def unsubscribe(self, topic: str) -> asyncio.Future: + unsubscribe_future = asyncio.get_event_loop().create_future() + unsubscribe_future.uuid = uuid4() + if self.__rate_limiter: + await self.__rate_limiter.message_rate_limit.consume() + mid = self._client._connection.unsubscribe(topic) # noqa + self._pending_unsubscriptions[mid] = unsubscribe_future + return unsubscribe_future + + def register_handler(self, topic_filter: str, handler: Callable[[str, bytes], Coroutine[Any, Any, None]]): + self._handlers[topic_filter] = handler + + def unregister_handler(self, topic_filter: str): + self._handlers.pop(topic_filter, None) + + def _on_connect_internal(self, client, session_present, reason_code, properties): + if reason_code != 0: + logger.error("Failed to connect to platform with reason code: %s", reason_code) + if properties and 'reason_string' in properties: + logger.error("Connection reason: %s", properties['reason_string'][0]) + self._connected_event.clear() + return + logger.info("Connected to the platform.") + logger.debug("Connection session_present: %s, reason code: %s, properties: %s", session_present, reason_code, + properties) + if hasattr(client, '_connection'): + client._connection._on_disconnect_called = False # noqa + self._connected_event.set() + asyncio.create_task(self.__handle_connect_and_limits()) + + async def __handle_connect_and_limits(self): + logger.info("Subscribing to RPC response topics") + sub_future = await self.subscribe(mqtt_topics.DEVICE_RPC_REQUEST_TOPIC_FOR_SUBSCRIPTION, qos=1) + while not sub_future.done(): + await sleep(0.01) + if self._main_stop_event.is_set(): + return + sub_future = await self.subscribe(mqtt_topics.DEVICE_RPC_RESPONSE_TOPIC_FOR_SUBSCRIPTION, qos=1) + while not sub_future.done(): + await sleep(0.01) + if self._main_stop_event.is_set(): + return + logger.debug("Subscribing completed, sending rate limits request") + + await self.__request_rate_limits() + + if self._on_connect_callback: + asyncio.create_task(self._on_connect_callback()) + + def _on_disconnect_internal(self, client, reason_code=None, properties=None, exc=None): # noqa + if isinstance(reason_code, bytes): + # Skipping handling due to duplication, because gmqtt triggers this cb again. + logger.trace("Received bytes reason code: %r", reason_code) + return + self._connected_event.clear() + if reason_code is not None: + reason_desc = PatchUtils.DISCONNECT_REASON_CODES.get(reason_code, "Unknown reason") + logger.info("Disconnected from platform with reason code: %s (%s)", reason_code, reason_desc) + + if properties and 'reason_string' in properties: + logger.info("Disconnect reason: %s", properties['reason_string'][0]) + else: + logger.info("Disconnected from the platform.") + + if exc: + logger.warning("Disconnect exception: %s", exc) + + for mid, (future, mqtt_message, publishing_time) in list(self._pending_publishes.items()): + if not future.done(): + publish_result = PublishResult( + topic=mqtt_message.topic, + qos=mqtt_message.qos, + payload_size=mqtt_message.payload_size, + message_id=-1, + reason_code=reason_code or 0 + ) + future.set_result(publish_result) + logger.warning("Setting publish result for mid=%s: %r", mid, publish_result) + self._pending_publishes.clear() + + RPCRequestIdProducer.reset() + AttributeRequestIdProducer.reset() + self.__rate_limits_retrieved = False + self.__is_waiting_for_rate_limits_publish = True + self._rate_limits_ready_event.clear() + if reason_code == 142: + logger.error("Session was taken over, looks like another client connected with the same credentials.") + self._backpressure.notify_disconnect(delay_seconds=10) + if reason_code in (131, 142, 143, 151): # 131, 142, 151 may be caused by rate limits or issue with the data + reached_time = 1 + if self.__rate_limiter: + for rate_limit in self.__rate_limiter.values(): + if isinstance(rate_limit, RateLimit): + try: + reached_limit = run_coroutine_sync(rate_limit.reach_limit, raise_on_timeout=True) + except TimeoutError: + logger.warning("Timeout while checking rate limit reaching.") + reached_time = 10 # Default to 10 seconds if timeout occurs + break + reached_index, reached_time, reached_duration = reached_limit if reached_limit else (None, None, 1) + self._backpressure.notify_disconnect(delay_seconds=reached_time) + elif reason_code != 0: + # Default disconnect handling + self._backpressure.notify_disconnect(delay_seconds=15) + + self._rpc_response_handler.clear() + if self._on_disconnect_callback: + asyncio.create_task(self._on_disconnect_callback()) + + def _on_message_internal(self, client, topic: str, payload: bytes, qos, properties): + logger.trace("Received message by client %r on topic %s with payload %r, qos %r, properties %r", + client, topic, payload, qos, properties) + for topic_filter, handler in self._handlers.items(): + if self._match_topic(topic_filter, topic): + asyncio.create_task(handler(topic, payload)) + return + + def _on_publish_internal(self, client, mid): + logger.trace("Publish was sent by client %r with mid=%s", client, mid) + + def _handle_puback_reason_code(self, mid: int, reason_code: int, properties: dict): + logger.trace("Handling PUBACK mid=%s with rc %r and properties: %r", + mid, reason_code, properties) + pending_future_data = self._pending_publishes.pop(mid, None) + if pending_future_data is None: + logger.error("Missing future for mid=%s", mid) + return + future, mqtt_message, publishing_time = pending_future_data + publish_result = PublishResult( + topic=mqtt_message.topic, + qos=mqtt_message.qos, + payload_size=mqtt_message.payload_size, + message_id=mid, + reason_code=reason_code, + datapoints_count=mqtt_message.datapoints + ) + if logger.isEnabledFor(TRACE_LEVEL): + logger.trace("Received result for publish future (id: %r): %r", future.uuid, publish_result) + if not future.done(): + future.set_result(publish_result) + else: + logger.warning("Future (id: %r) for mid=%s was already done, skipping setting result", + future.uuid, mid) + + if reason_code == QUOTA_EXCEEDED: + logger.warning("PUBACK received with QUOTA_EXCEEDED for mid=%s", mid) + self._backpressure.notify_quota_exceeded(delay_seconds=10) + elif reason_code == IMPLEMENTATION_SPECIFIC_ERROR: + logger.warning( + "PUBACK received with IMPLEMENTATION_SPECIFIC_ERROR for mid=%s, treating as rate limit reached", mid) + self._backpressure.notify_quota_exceeded( + delay_seconds=15) # Treat implementation specific error as quota exceeded + elif reason_code != 0: + logger.warning("PUBACK received with error code %s for mid=%s", reason_code, mid) + + if self._on_publish_result_callback: + asyncio.create_task(self._on_publish_result_callback(publish_result)) + + def _on_subscribe_internal(self, client, mid, qos, properties): + logger.trace("Received SUBACK by client %r for mid=%s with qos %s, properties %s", + client, mid, qos, properties) + future = self._pending_subscriptions.pop(mid, None) + if future and not future.done(): + future.set_result(mid) + + def _on_unsubscribe_internal(self, client, mid, properties): + logger.trace("Received UNSUBACK by client %r for mid=%s with properties %s", client, mid, properties) + future = self._pending_unsubscriptions.pop(mid, None) + if future and not future.done(): + future.set_result(mid) + + async def await_ready(self, timeout: float = 5.0): + try: + await await_or_stop(self._rate_limits_ready_event.wait(), self._main_stop_event, timeout=timeout) + except asyncio.TimeoutError: + logger.debug("Waiting for rate limits timed out.") + + def set_rate_limits_received(self): + self.__rate_limits_retrieved = True + self.__is_waiting_for_rate_limits_publish = False + if not self.__is_gateway: + self._rate_limits_ready_event.set() + + def enable_gateway_mode(self): + self.__is_gateway = True + + def set_gateway_rate_limits_received(self): + self.__gateway_rate_limits_retrieved = True + self._rate_limits_ready_event.set() + + async def __request_rate_limits(self): + self.__is_waiting_for_rate_limits_publish = True + + logger.debug("Publishing rate limits request to server...") + + request = await RPCRequest.build("getSessionLimits") + mqtt_message: MqttPublishMessage = self._message_adapter.build_rpc_request(request) + response_future = self._rpc_response_handler.register_request(request.request_id, self.__rate_limits_handler) + + try: + await self.publish(mqtt_message, force=True) + await await_or_stop(response_future, self._main_stop_event, timeout=10) + logger.info("Successfully processed rate limits.") + # self.__rate_limits_retrieved = True + # self.__is_waiting_for_rate_limits_publish = False + # self._rate_limits_ready_event.set() + except asyncio.TimeoutError: + logger.warning("Timeout while waiting for rate limits.") + # Keep __is_waiting_for_rate_limits_publish as True to prevent publishing + # until rate limits are retrieved + + @property + def backpressure(self) -> BackpressureController: + return self._backpressure + + @staticmethod + def _match_topic(filter_expression: str, topic: str) -> bool: + filter_parts = filter_expression.split('/') + topic_parts = topic.split('/') + + for i, filter_part in enumerate(filter_parts): + if filter_part == '#': + return True + if i >= len(topic_parts): + return False + if filter_part != '+' and filter_part != topic_parts[i]: + return False + + return len(filter_parts) == len(topic_parts) + + async def _monitor_ack_timeouts(self): + while not self._main_stop_event.is_set(): + now = monotonic() + await self.check_pending_publishes(now) + # TODO: Add logic to handle expired futures, for subscriptions, rpc responses, etc. + await asyncio.sleep(0.1) + await self.check_pending_publishes(monotonic()) + + def patch_client_for_retry_logic(self, + put_retry_message_method: Callable[[MqttPublishMessage], + Coroutine[Any, Any, None]]): + self._client.put_retry_message = put_retry_message_method + + async def check_pending_publishes(self, time_to_check): + expired = [] + for mid, (future, message, timestamp) in list(self._pending_publishes.items()): + if self._main_stop_event.is_set(): + with suppress(asyncio.CancelledError): + future.cancel() + continue + if time_to_check - timestamp > self._PUBLISH_TIMEOUT: + if not future.done(): + logger.warning("Publish timeout: mid=%s, topic=%s", mid, message.topic) + result = PublishResult(message.topic, message.qos, message.payload_size, mid, reason_code=408) + future.set_result(result) + expired.append(mid) + for mid in expired: + self._pending_publishes.pop(mid, None) + + async def stop(self): + """ + Cleanly stop the MQTT manager and background tasks. + """ + if hasattr(self, '_patch_utils'): + await self._patch_utils.stop_retry_task() + + if hasattr(self, '_publish_monitor_task'): + self._publish_monitor_task.cancel() + with suppress(asyncio.CancelledError): + await self._publish_monitor_task + + if self._client.is_connected: + await self._client.disconnect() + + @staticmethod + async def _add_future_chain_processing(mqtt_future, message: MqttPublishMessage): + def resolve_attached(publish_future: asyncio.Future): + try: + try: + publish_result = publish_future.result() + except asyncio.CancelledError: + logger.info("Publish future was cancelled: %r, id: %r", publish_future, publish_future.uuid) + publish_result = PublishResult(message.topic, message.qos, -1, len(message.payload), -1) + except Exception as exc: + logger.warning("Publish failed with exception: %s", exc) + logger.debug("Resolving delivery futures with failure:", exc_info=exc) + publish_result = PublishResult(message.topic, message.qos, -1, len(message.payload), -1) + + for i, f in enumerate(message.delivery_futures or []): + if f is not None and not f.done(): + f.set_result(publish_result) + future_map.child_resolved(f) + logger.trace("Resolved delivery future #%d id=%r with %s, main publish future id: %r", + i, f.uuid, publish_result, publish_future.uuid) + except Exception as e: + logger.error("Error resolving delivery futures: %s", str(e)) + for i, f in enumerate(message.delivery_futures or []): + if f is not None and not f.done(): + f.set_exception(e) + logger.debug("Set exception for delivery future #%d id=%r", i, f.uuid) + + logger.trace("Adding done callback to main publish future: %r, main publish future done state: %r", + mqtt_future.uuid, mqtt_future.done()) + if mqtt_future.done(): + logger.debug("Main publish future is already done, resolving immediately.") + resolve_attached(mqtt_future) + else: + mqtt_future.add_done_callback(resolve_attached) diff --git a/tests/__init__.py b/tests/__init__.py index 8d89b47..cff354e 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,13 +1,13 @@ -# Copyright 2025. ThingsBoard +# Copyright 2025 ThingsBoard # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/blackbox/__init__.py b/tests/blackbox/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tests/blackbox/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/blackbox/conftest.py b/tests/blackbox/conftest.py new file mode 100644 index 0000000..707f7a0 --- /dev/null +++ b/tests/blackbox/conftest.py @@ -0,0 +1,516 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +import os +import subprocess +import time +import uuid +from copy import deepcopy +from typing import Dict, Generator, Optional + +import pytest +import requests +from requests import HTTPError +from requests.adapters import HTTPAdapter +from urllib3.util import Retry + +from tb_mqtt_client.common.config_loader import DeviceConfig, GatewayConfig +from tests.blackbox.constants import RPC_RESPONSE_RULE_CHAIN +from tests.blackbox.rest_helpers import ( + find_related_entity_id, + get_device_info_by_id, + create_device_profile_and_firmware, + get_default_device_profile, +) + +# ---------------------------- +# Environment and configuration +# ---------------------------- + +ENV = os.environ.get +TB_HOST: str = ENV("SDK_BLACKBOX_TB_HOST", "localhost") +TB_HTTP_PORT: int = int(ENV("SDK_BLACKBOX_TB_PORT", "8080")) +TB_MQTT_PORT: int = int(ENV("SDK_BLACKBOX_TB_MQTT_PORT", "1883")) +TENANT_USER: str = ENV("SDK_BLACKBOX_TENANT_USER", "tenant@thingsboard.org") +TENANT_PASS: str = ENV("SDK_BLACKBOX_TENANT_PASS", "tenant") +TB_CONTAINER_NAME: str = ENV("SDK_BLACKBOX_TB_CONTAINER", "tb-ce-sdk-tests") +RUN_WITH_LOCAL_TB: bool = ENV("SDK_RUN_BLACKBOX_TESTS_LOCAL", "false").lower() == "true" +RUN_BLACKBOX: bool = ENV("SDK_RUN_BLACKBOX_TESTS", "false").lower() == "true" + +TB_HTTP_PROTOCOL: str = ENV("SDK_BLACKBOX_TB_HTTP_PROTOCOL", "http") +TB_URL: str = ENV("SDK_BLACKBOX_TB_URL", f"{TB_HTTP_PROTOCOL}://{TB_HOST}:{TB_HTTP_PORT}") + +REQUEST_TIMEOUT: float = float(ENV("SDK_BLACKBOX_HTTP_TIMEOUT", "30")) +TB_START_TIMEOUT: int = int(ENV("SDK_BLACKBOX_TB_START_TIMEOUT", "180")) + +logger = logging.getLogger("blackbox") +logger.setLevel(logging.INFO) + + +def _build_retrying_session() -> requests.Session: + """A single session with retry/backoff for better stability in tests.""" + session = requests.Session() + retries = Retry( + total=5, + connect=5, + read=5, + backoff_factor=0.5, # exponential: 0.5, 1, 2, ... + status_forcelist=(429, 500, 502, 503, 504), + allowed_methods=("GET", "POST", "PUT", "DELETE", "PATCH"), + raise_on_status=False, + ) + adapter = HTTPAdapter(max_retries=retries, pool_connections=20, pool_maxsize=50) + session.mount("http://", adapter) + session.mount("https://", adapter) + return session + + +@pytest.fixture(scope="session") +def http() -> Generator[requests.Session, None, None]: + """Shared HTTP session with retries and connection pooling.""" + session = _build_retrying_session() + try: + yield session + finally: + session.close() + + +def pytest_collection_modifyitems(config, items): + """Skip blackbox tests unless explicitly enabled.""" + if not RUN_BLACKBOX: + skip_marker = pytest.mark.skip( + reason=( + "Blackbox tests disabled. Set SDK_RUN_BLACKBOX_TESTS=true to run. " + "Optionally set SDK_RUN_BLACKBOX_TESTS_LOCAL=true to run against a local ThingsBoard instance." + ) + ) + for item in items: + if "blackbox" in str(item.fspath): + item.add_marker(skip_marker) + + +# ---------------------------- +# ThingsBoard lifecycle helpers +# ---------------------------- + +def _docker_is_running(container_name: str) -> bool: + try: + status = ( + subprocess.check_output( + ["docker", "inspect", "-f", "{{.State.Running}}", container_name], + stderr=subprocess.DEVNULL, + ) + .decode() + .strip() + ) + return status == "true" + except subprocess.CalledProcessError: + return False + + +def _wait_for_tb_ready(http: requests.Session) -> None: + """Wait for TB to accept logins; uses exponential backoff.""" + logger.info("Waiting for ThingsBoard CE to become ready...") + start = time.time() + delay = 1.0 + # Try login to ensure full readiness (DB + REST). + while time.time() - start < TB_START_TIMEOUT: + try: + r = http.post( + f"{TB_URL}/api/auth/login", + json={"username": TENANT_USER, "password": TENANT_PASS}, + timeout=REQUEST_TIMEOUT, + ) + if r.ok: + logger.info("ThingsBoard CE is ready.") + return + except requests.RequestException: + pass + time.sleep(delay) + delay = min(delay * 1.5, 5.0) + pytest.fail("ThingsBoard CE did not become ready in time.") + + +@pytest.fixture(scope="session", autouse=True) +def start_thingsboard(http: requests.Session) -> Generator[None, None, None]: + """ + Start ThingsBoard CE in Docker if blackbox tests are enabled and not running locally. + Ensures instance is ready. If we start the container, we will stop it at session end. + """ + if not RUN_BLACKBOX: + _wait_for_tb_ready(http) + yield + return + + started_by_us = False + + if not RUN_WITH_LOCAL_TB: + if _docker_is_running(TB_CONTAINER_NAME): + logger.info("Using existing ThingsBoard CE container: %s", TB_CONTAINER_NAME) + else: + logger.info("Pulling ThingsBoard CE image (thingsboard/tb-postgres)...") + subprocess.run(["docker", "pull", "thingsboard/tb-postgres"], check=True) + + logger.info("Starting ThingsBoard CE container...") + subprocess.run( + [ + "docker", + "run", + "-d", + "--rm", + "--name", + TB_CONTAINER_NAME, + "-p", + f"{TB_HTTP_PORT}:8080", + "-p", + f"{TB_MQTT_PORT}:1883", + "thingsboard/tb-postgres", + ], + check=True, + ) + started_by_us = True + + _wait_for_tb_ready(http) + + try: + yield + finally: + if started_by_us: + logger.info("Stopping ThingsBoard CE container: %s", TB_CONTAINER_NAME) + subprocess.run(["docker", "stop", TB_CONTAINER_NAME], check=False) + + +# ---------------------------- +# Auth and headers +# ---------------------------- + +@pytest.fixture(scope="session") +def test_config() -> Dict[str, object]: + """Return configuration for blackbox tests.""" + return { + "tb_host": TB_HOST, + "tb_http_port": TB_HTTP_PORT, + "tb_mqtt_port": TB_MQTT_PORT, + "tenant_user": TENANT_USER, + "tenant_pass": TENANT_PASS, + "tb_url": TB_URL, + "timeout": REQUEST_TIMEOUT, + } + + +@pytest.fixture(scope="session") +def tb_admin_token(start_thingsboard, http: requests.Session) -> str: + """Login as tenant admin and return JWT.""" + r = http.post( + f"{TB_URL}/api/auth/login", + json={"username": TENANT_USER, "password": TENANT_PASS}, + timeout=REQUEST_TIMEOUT, + ) + r.raise_for_status() + data = r.json() + # Some TB builds return {"token": "..."}; others include "refreshToken". We only need the token. + return data["token"] + + +@pytest.fixture(scope="session") +def tb_admin_headers(tb_admin_token: str) -> Dict[str, str]: + """Headers for ThingsBoard REST API requests.""" + return { + "X-Authorization": f"Bearer {tb_admin_token}", + "Content-Type": "application/json", + } + + +# ---------------------------- +# Device fixtures +# ---------------------------- + +def _create_device(http: requests.Session, headers: Dict[str, str], name: str, dev_type: str = "default") -> dict: + r = http.post( + f"{TB_URL}/api/device", + headers=headers, + json={"name": name, "type": dev_type}, + timeout=REQUEST_TIMEOUT, + ) + r.raise_for_status() + return r.json() + + +def _get_device_credentials_token(http: requests.Session, headers: Dict[str, str], device_id: str) -> str: + r = http.get(f"{TB_URL}/api/device/{device_id}/credentials", headers=headers, timeout=REQUEST_TIMEOUT) + r.raise_for_status() + return r.json()["credentialsId"] + + +@pytest.fixture() +def device_info(tb_admin_headers: Dict[str, str], http: requests.Session) -> Generator[dict, None, None]: + """Create a unique device for a test; always cleaned up.""" + name = f"pytest-device-{uuid.uuid4().hex[:8]}" + try: + device = _create_device(http, tb_admin_headers, name) + except HTTPError as e: + logger.error("Failed to create device %s: %s", name, e) + pytest.fail("Failed to create test device.") + + try: + yield device + finally: + device_id = device["id"]["id"] + http.delete(f"{TB_URL}/api/device/{device_id}", headers=tb_admin_headers, timeout=REQUEST_TIMEOUT) + + +@pytest.fixture() +def device_config( + device_info: dict, tb_admin_headers: Dict[str, str], http: requests.Session +) -> Generator[DeviceConfig, None, None]: + """Return a ready-to-use DeviceConfig; cleans up the device after use via device_info fixture.""" + device_id = device_info["id"]["id"] + token = _get_device_credentials_token(http, tb_admin_headers, device_id) + + cfg = DeviceConfig() + cfg.host = TB_HOST + cfg.port = TB_MQTT_PORT + cfg.access_token = token + yield cfg + + +# ---------------------------- +# Gateway fixtures +# ---------------------------- + +def _create_gateway_device(http: requests.Session, headers: Dict[str, str], name: str) -> dict: + r = http.post( + f"{TB_URL}/api/device", + headers=headers, + json={"name": name, "type": "default", "additionalInfo": {"gateway": True}}, + timeout=REQUEST_TIMEOUT, + ) + r.raise_for_status() + return r.json() + + +@pytest.fixture() +def gateway_info(tb_admin_headers: Dict[str, str], http: requests.Session) -> Generator[dict, None, None]: + """Create a unique gateway device for a test; always cleaned up.""" + name = f"pytest-gw-device-{uuid.uuid4().hex[:8]}" + try: + gw = _create_gateway_device(http, tb_admin_headers, name) + except HTTPError as e: + logger.error("Failed to create gateway device %s: %s", name, e) + pytest.fail("Failed to create or find test gateway device.") + try: + yield gw + finally: + gw_id = gw["id"]["id"] + # Attempt to delete gateway at teardown (sub-devices handled by gateway_config fixture if created). + http.delete(f"{TB_URL}/api/device/{gw_id}", headers=tb_admin_headers, timeout=REQUEST_TIMEOUT) + + +@pytest.fixture() +def gateway_config( + gateway_info: dict, tb_admin_headers: Dict[str, str], http: requests.Session +) -> Generator[GatewayConfig, None, None]: + """Return a GatewayConfig and clean up gateway + any auto-created sub-device/profile.""" + device_id = gateway_info["id"]["id"] + token = _get_device_credentials_token(http, tb_admin_headers, device_id) + + cfg = GatewayConfig() + cfg.host = TB_HOST + cfg.port = TB_MQTT_PORT + cfg.access_token = token + + # Provide config to the test + yield cfg + + # Cleanup: delete gateway and any auto-created sub-device/profile if present. + try: + subdevice_id: Optional[str] = None + try: + subdevice_id = find_related_entity_id(device_id, TB_URL, tb_admin_headers) + except Exception: + subdevice_id = None + + http.delete(f"{TB_URL}/api/device/{device_id}", headers=tb_admin_headers, timeout=REQUEST_TIMEOUT) + + if subdevice_id: + sub_dev = get_device_info_by_id(subdevice_id, TB_URL, tb_admin_headers) + http.delete(f"{TB_URL}/api/device/{subdevice_id}", headers=tb_admin_headers, timeout=REQUEST_TIMEOUT) + if "deviceProfileId" in sub_dev and "id" in sub_dev["deviceProfileId"]: + http.delete( + f"{TB_URL}/api/deviceProfile/{sub_dev['deviceProfileId']['id']}", + headers=tb_admin_headers, + timeout=REQUEST_TIMEOUT, + ) + except requests.RequestException as e: + logger.warning("Gateway cleanup encountered an error: %s", e) + + +# ---------------------------- +# Firmware/profile fixtures +# ---------------------------- + +@pytest.fixture() +def firmware_profile_and_package( + tb_admin_headers: Dict[str, str], + test_config: Dict[str, object], + http: requests.Session, +) -> Generator[tuple[dict, dict, dict], None, None]: + """ + Creates a device profile and uploads a firmware package. + Returns (device_profile, firmware, firmware_info). + """ + firmware_name = "pytest-firmware" + firmware_version = "1.0.0" + firmware_bytes = b"Firmware binary data for pytest" + + firmware_info = { + "name": firmware_name, + "version": firmware_version, + "data": firmware_bytes, + } + + device_profile, firmware = create_device_profile_and_firmware( + firmware_name=firmware_name, + firmware_version=firmware_version, + firmware_bytes=firmware_bytes, + base_url=test_config["tb_url"], # type: ignore[arg-type] + headers=tb_admin_headers, + http=http, + timeout=REQUEST_TIMEOUT, + ) + + assert device_profile is not None, "Device profile should be created successfully" + assert firmware is not None, "Firmware should be created successfully" + + try: + yield device_profile, firmware, firmware_info + finally: + # Cleanup + http.delete( + f"{test_config['tb_url']}/api/deviceProfile/{device_profile['id']['id']}", # type: ignore[index] + headers=tb_admin_headers, + timeout=REQUEST_TIMEOUT, + ) + http.delete( + f"{test_config['tb_url']}/api/firmware/{firmware['id']['id']}", # type: ignore[index] + headers=tb_admin_headers, + timeout=REQUEST_TIMEOUT, + ) + + +# ---------------------------- +# Rule chain fixtures +# ---------------------------- + +@pytest.fixture() +def rpc_rule_chain( + tb_admin_headers: Dict[str, str], test_config: Dict[str, object], http: requests.Session +) -> Generator[dict, None, None]: + """ + Creates a rule chain based on RPC_RESPONSE_RULE_CHAIN and uploads its metadata. + Cleans up the rule chain afterwards. + """ + rule_chain_info = deepcopy(RPC_RESPONSE_RULE_CHAIN["ruleChain"]) + r = http.post( + f"{test_config['tb_url']}/api/ruleChain", + headers=tb_admin_headers, + json=rule_chain_info, + timeout=REQUEST_TIMEOUT, + ) + r.raise_for_status() + rule_chain = r.json() + rule_chain_id = rule_chain["id"]["id"] + + # Get current metadata, then replace nodes/connections from our template + meta_resp = http.get( + f"{test_config['tb_url']}/api/ruleChain/{rule_chain_id}/metadata", + headers=tb_admin_headers, + timeout=REQUEST_TIMEOUT, + ) + meta_resp.raise_for_status() + + received_metadata = meta_resp.json() + tpl_meta = deepcopy(RPC_RESPONSE_RULE_CHAIN["metadata"]) + + received_metadata["nodes"] = tpl_meta["nodes"] + now_ms = int(time.time() * 1000) + for node in received_metadata["nodes"]: + node["createdTime"] = now_ms + + received_metadata["connections"] = tpl_meta["connections"] + received_metadata["firstNodeIndex"] = tpl_meta["firstNodeIndex"] + received_metadata["ruleChainId"] = {"id": rule_chain_id, "entityType": "RULE_CHAIN"} + + apply_resp = http.post( + f"{test_config['tb_url']}/api/ruleChain/metadata", + headers=tb_admin_headers, + json=received_metadata, + timeout=REQUEST_TIMEOUT, + ) + apply_resp.raise_for_status() + + try: + yield rule_chain + finally: + http.delete( + f"{test_config['tb_url']}/api/ruleChain/{rule_chain_id}", + headers=tb_admin_headers, + timeout=REQUEST_TIMEOUT, + ) + + +@pytest.fixture() +def device_profile_with_rpc_rule_chain( + tb_admin_headers: Dict[str, str], + test_config: Dict[str, object], + rpc_rule_chain: dict, + http: requests.Session, +) -> Generator[dict, None, None]: + """ + Creates a device profile that uses the RPC rule chain as default; cleans up afterward. + """ + rule_chain = rpc_rule_chain + device_profile = get_default_device_profile(test_config["tb_url"], tb_admin_headers) + device_profile["id"] = None + device_profile["createdTime"] = None + device_profile["name"] = "pytest-client-side-rpc-profile" + device_profile["defaultRuleChainId"] = rule_chain["id"] + device_profile.setdefault( + "profileData", + { + "configuration": {"type": "DEFAULT"}, + "transportConfiguration": {"type": "DEFAULT"}, + }, + ) + + r = http.post( + f"{test_config['tb_url']}/api/deviceProfile", + headers=tb_admin_headers, + json=device_profile, + timeout=REQUEST_TIMEOUT, + ) + r.raise_for_status() + created = r.json() + try: + yield created + finally: + device_profile_id = created["id"]["id"] + http.delete( + f"{test_config['tb_url']}/api/deviceProfile/{device_profile_id}", + headers=tb_admin_headers, + timeout=REQUEST_TIMEOUT, + ) diff --git a/tests/blackbox/constants.py b/tests/blackbox/constants.py new file mode 100644 index 0000000..4acdbaa --- /dev/null +++ b/tests/blackbox/constants.py @@ -0,0 +1,81 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +RPC_RESPONSE_RULE_CHAIN = { + "ruleChain": { + "name": "test-client-side-rpc", + "type": "CORE", + "firstRuleNodeId": None, + "root": False, + "debugMode": False, + "configuration": None, + "additionalInfo": { + "description": "" + } + }, + "metadata": { + "ruleChainId": { + }, + "version": 3, + "firstNodeIndex": 0, + "nodes": [ + { + "createdTime": 1754561486165, + "type": "org.thingsboard.rule.engine.filter.TbMsgTypeSwitchNode", + "name": "msg_type_switch", + "debugSettings": None, + "singletonMode": False, + "queueName": None, + "configurationVersion": 0, + "configuration": { + "version": 0 + }, + "externalId": None, + "additionalInfo": { + "description": "", + "layoutX": 277, + "layoutY": 150 + } + }, + { + "createdTime": 1754561486166, + "type": "org.thingsboard.rule.engine.rpc.TbSendRPCReplyNode", + "name": "rpc_reply", + "debugSettings": None, + "singletonMode": False, + "queueName": None, + "configurationVersion": 0, + "configuration": { + "serviceIdMetaDataAttribute": "serviceId", + "sessionIdMetaDataAttribute": "sessionId", + "requestIdMetaDataAttribute": "requestId" + }, + "externalId": None, + "additionalInfo": { + "description": "", + "layoutX": 542, + "layoutY": 153 + } + } + ], + "connections": [ + { + "fromIndex": 0, + "toIndex": 1, + "type": "RPC Request from Device" + } + ], + "ruleChainConnections": None +} +} \ No newline at end of file diff --git a/tests/blackbox/rest_helpers.py b/tests/blackbox/rest_helpers.py new file mode 100644 index 0000000..f515c5f --- /dev/null +++ b/tests/blackbox/rest_helpers.py @@ -0,0 +1,272 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from time import time as now +from typing import Dict, List, Optional, Tuple + +import requests + + +def _check_ok(resp: requests.Response) -> None: + if not resp.ok: + resp.raise_for_status() + + +def get_device_attributes( + device_id: str, base_url: str, headers: Dict[str, str], scope: str, *, http: Optional[requests.Session] = None, + timeout: float = 30.0 +) -> List[dict]: + """Fetch device attributes for a scope.""" + sess = http or requests + url = f"{base_url}/api/plugins/telemetry/DEVICE/{device_id}/values/attributes" + resp = sess.get(url, headers=headers, params={"scope": scope}, timeout=timeout) + _check_ok(resp) + return resp.json() + + +def get_device_timeseries( + device_id: str, + base_url: str, + headers: Dict[str, str], + keys: Optional[List[str]] = None, + *, + http: Optional[requests.Session] = None, + timeout: float = 30.0, +) -> Dict[str, list]: + """Fetch device timeseries for the last 60 seconds, optionally filtered by keys.""" + sess = http or requests + current_ts = int(now() * 1000) + params = { + "startTs": current_ts - 60_000, + "endTs": current_ts, + "useStrictDataTypes": "true", + } + if keys: + params["keys"] = ",".join(keys) + url = f"{base_url}/api/plugins/telemetry/DEVICE/{device_id}/values/timeseries" + resp = sess.get(url, headers=headers, params=params, timeout=timeout) + _check_ok(resp) + return resp.json() + + +def update_shared_attributes( + device_id: str, base_url: str, headers: Dict[str, str], attributes: Dict[str, object], *, + http: Optional[requests.Session] = None, timeout: float = 30.0 +) -> None: + """Update shared scope attributes for a device.""" + sess = http or requests + url = f"{base_url}/api/plugins/telemetry/DEVICE/{device_id}/attributes/SHARED_SCOPE" + resp = sess.post(url, json=attributes, headers=headers, timeout=timeout) + _check_ok(resp) + + +def get_device_info_by_name( + device_name: str, base_url: str, headers: Dict[str, str], *, http: Optional[requests.Session] = None, + timeout: float = 30.0 +) -> dict: + """Look up a device by its exact name.""" + sess = http or requests + url = f"{base_url}/api/tenant/devices" + resp = sess.get(url, headers=headers, params={"deviceName": device_name}, timeout=timeout) + _check_ok(resp) + device = resp.json() + if not device: + raise ValueError(f"Device with name '{device_name}' not found.") + return device + + +def find_related_entity_id( + device_id: str, base_url: str, headers: Dict[str, str], *, http: Optional[requests.Session] = None, + timeout: float = 30.0 +) -> str: + """Return the first related entity ID for a device (if any).""" + sess = http or requests + url = f"{base_url}/api/relations" + resp = sess.get(url, headers=headers, params={"fromId": device_id, "fromType": "DEVICE"}, timeout=timeout) + _check_ok(resp) + relations = resp.json() + if relations: + return relations[0]["to"]["id"] + raise ValueError(f"No related entity found for device ID '{device_id}'.") + + +def get_device_info_by_id( + device_id: str, base_url: str, headers: Dict[str, str], *, http: Optional[requests.Session] = None, + timeout: float = 30.0 +) -> dict: + """Fetch device info by ThingsBoard device UUID.""" + sess = http or requests + url = f"{base_url}/api/device/{device_id}" + resp = sess.get(url, headers=headers, timeout=timeout) + _check_ok(resp) + return resp.json() + + +def get_default_device_profile( + base_url: str, headers: Dict[str, str], *, http: Optional[requests.Session] = None, timeout: float = 30.0 +) -> dict: + """Get the tenant default device profile (info).""" + sess = http or requests + url = f"{base_url}/api/deviceProfileInfo/default" + resp = sess.get(url, headers=headers, timeout=timeout) + _check_ok(resp) + return resp.json() + + +def create_device_profile_with_provisioning( + device_profile_name: str, + base_url: str, + headers: Dict[str, str], + provisioning_device_key: str, + provisioning_device_secret: str, + *, + http: Optional[requests.Session] = None, + timeout: float = 30.0, +) -> dict: + """Create a device profile with ALLOW_CREATE_NEW_DEVICES provisioning.""" + sess = http or requests + device_profile = get_default_device_profile(base_url, headers, http=sess, timeout=timeout) + device_profile.update( + { + "id": None, + "name": device_profile_name, + "createdTime": None, + "provisionType": "ALLOW_CREATE_NEW_DEVICES", + } + ) + profile_data = device_profile.setdefault( + "profileData", + { + "configuration": {"type": "DEFAULT"}, + "transportConfiguration": {"type": "DEFAULT"}, + "alarms": None, + }, + ) + profile_data["provisionConfiguration"] = { + "type": "ALLOW_CREATE_NEW_DEVICES", + "provisionDeviceSecret": provisioning_device_secret, + } + device_profile["provisionDeviceKey"] = provisioning_device_key + + url = f"{base_url}/api/deviceProfile" + resp = sess.post(url, json=device_profile, headers=headers, timeout=timeout) + _check_ok(resp) + return resp.json() + + +def create_device_profile_and_firmware( + firmware_name: str, + firmware_version: str, + firmware_bytes: bytes, + base_url: str, + headers: Dict[str, str], + *, + http: Optional[requests.Session] = None, + timeout: float = 30.0, +) -> Tuple[dict, dict]: + """ + Create a device profile and upload a firmware OTA package. + + Returns: + (device_profile, firmware_ota_package) + """ + sess = http or requests + + # Create device profile + device_profile = get_default_device_profile(base_url, headers, http=sess, timeout=timeout) + device_profile.update({"id": None, "createdTime": None, "name": "pytest-firmware-profile"}) + device_profile.setdefault("profileData", { + "configuration": {"type": "DEFAULT"}, + "transportConfiguration": {"type": "DEFAULT"}, + }) + + resp = sess.post(f"{base_url}/api/deviceProfile", json=device_profile, headers=headers, timeout=timeout) + _check_ok(resp) + device_profile = resp.json() + + # Create OTA package (metadata) + init_ota_package = { + "id": None, + "createdTime": None, + "deviceProfileId": {"entityType": "DEVICE_PROFILE", "id": device_profile["id"]["id"]}, + "type": "FIRMWARE", + "title": firmware_name, + "version": firmware_version, + "tag": f"{firmware_name} {firmware_version}", + "url": None, + "hasData": False, + "fileName": None, + "contentType": None, + "checksumAlgorithm": None, + "checksum": None, + "dataSize": None, + "externalId": None, + "name": firmware_name, + "additionalInfo": {"description": ""}, + } + created_ota_package = sess.post(f"{base_url}/api/otaPackage", json=init_ota_package, headers=headers, + timeout=timeout) + _check_ok(created_ota_package) + initial = created_ota_package.json() + + # Upload firmware data using proper multipart form-data + upload_url = f"{base_url}/api/otaPackage/{initial['id']['id']}?checksumAlgorithm=SHA256" + files = { + "file": ("firmware.bin", firmware_bytes, "application/octet-stream"), + } + # Only auth header required; requests builds Content-Type with boundary. + upload_headers = {k: v for k, v in headers.items() if k.lower() == "x-authorization"} + upload_resp = sess.post(upload_url, headers=upload_headers, files=files, timeout=timeout) + _check_ok(upload_resp) + + return device_profile, upload_resp.json() + + +def save_device( + device: dict, base_url: str, headers: Dict[str, str], *, http: Optional[requests.Session] = None, + timeout: float = 30.0 +) -> dict: + """Create or update a device via REST.""" + sess = http or requests + url = f"{base_url}/api/device" + resp = sess.post(url, json=device, headers=headers, timeout=timeout) + _check_ok(resp) + return resp.json() + + +async def send_rpc_request( + device_id: str, + base_url: str, + headers: Dict[str, str], + request: dict, + *, + http: Optional[requests.Session] = None, + timeout: float = 60.0, +) -> dict: + """ + Send a two-way RPC request to a device. Performs the blocking HTTP call in a thread pool. + """ + sess = http or requests + + def _send() -> dict: + url = f"{base_url}/api/rpc/twoway/{device_id}" + r = sess.post(url, json=request, headers=headers, timeout=timeout) + r.raise_for_status() + return r.json() + + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, _send) diff --git a/tests/blackbox/test_basic_device_examples.py b/tests/blackbox/test_basic_device_examples.py new file mode 100644 index 0000000..4cc4b09 --- /dev/null +++ b/tests/blackbox/test_basic_device_examples.py @@ -0,0 +1,207 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +import pytest +from examples.device import send_attributes +from examples.device import send_timeseries +from examples.device import request_attributes +from examples.device import handle_attribute_updates +from examples.device import handle_rpc_requests +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tests.blackbox.rest_helpers import get_device_attributes, get_device_timeseries, update_shared_attributes, \ + send_rpc_request + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_send_attributes(device_config, device_info, tb_admin_headers, test_config): + expected_attributes = { + "firmwareVersion": "1.0.3", + "hardwareModel": "TB-SDK-Device", + "mode": "normal", + "location": "Building A", + "status": "active" + } + + send_attributes.config = device_config + + await send_attributes.main() + + attrs = get_device_attributes(device_info['id']['id'], test_config['tb_url'], tb_admin_headers, 'CLIENT_SCOPE') + received_keys = {a.get("key") for a in attrs} + for key, value in expected_attributes.items(): + assert any(a.get("key") == key and a.get("value") == value for a in attrs) + assert key in received_keys + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_send_timeseries(device_config, device_info, tb_admin_headers, test_config): + expected_timeseries = { + "temperature": "any", # Accept any value for temperature + "humidity": "any", # Accept any value for humidity + "pressure": "any", # Accept any value for pressure + "vibration": 0.05, + "speed": 123 + } + additional_timeseries = { + "vibration": 0.01, + "speed": 120 + } + + send_timeseries.config = device_config + await send_timeseries.main() + + timeseries = get_device_timeseries(device_info['id']['id'], test_config['tb_url'], tb_admin_headers) + for key, value in expected_timeseries.items(): + assert key in timeseries, f"Expected timeseries key '{key}' not found." + assert any((ts.get("value") == value or ts.get("value", True) == additional_timeseries.get(key, False) or value == "any") for ts in timeseries[key]), f"Expected value '{value}' for key '{key}' not found." + + all_timeseries_keys = set(timeseries.keys()) + all_timeseries = get_device_timeseries(device_info['id']['id'], test_config['tb_url'], + tb_admin_headers, keys=list(all_timeseries_keys)) + for key in all_timeseries_keys: + for ts_entry in all_timeseries[key]: + assert (expected_timeseries[key] == 'any' or ts_entry.get("value") == expected_timeseries[key] or ts_entry.get("value") == additional_timeseries.get(key)), f"Expected value '{expected_timeseries[key]}' for key '{key}' not found in all timeseries." + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_request_attributes(device_config, device_info, tb_admin_headers, test_config): + expected_attributes = { + "currentTemperature": 22.5 + } + + attributes_received = asyncio.Event() + + handler_future = asyncio.Future() + + async def attribute_request_callback(response: RequestedAttributeResponse): + attributes_received.set() + try: + assert response is not None, "Response should not be None" + assert response.request_id != -1, "Request ID should not be -1" + assert response.client, "Client attributes should not be empty" + for expected_attribute in expected_attributes: + assert any(attr.key == expected_attribute and attr.value == expected_attributes[expected_attribute] for attr in response.client), f"Expected attribute '{expected_attribute}' not found in response" + handler_future.set_result(True) + except AssertionError as e: + handler_future.set_exception(e) + + request_attributes.config = device_config + request_attributes.attribute_request_callback = attribute_request_callback + request_attributes.response_received = attributes_received + + await request_attributes.main() + + await asyncio.wait_for(handler_future, timeout=10) + + attrs = get_device_attributes(device_info['id']['id'], test_config['tb_url'], tb_admin_headers, 'CLIENT_SCOPE') + received_keys = {a.get("key") for a in attrs} + for key, value in expected_attributes.items(): + assert any(a.get("key") == key and a.get("value") == value for a in attrs) + assert key in received_keys + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_handle_attribute_updates(device_config, device_info, tb_admin_headers, test_config): + + test_shared_attribute = {"pytest_shared_attribute": "test_value"} + + handler_future = asyncio.Future() + + async def attribute_update_callback(update: AttributeUpdate): + try: + assert update is not None, "Update should not be None" + assert update.keys()[0] == list(test_shared_attribute.keys())[0], "Update key does not match expected key" + assert update.values()[0] == list(test_shared_attribute.values())[0], "Update value does not match expected value" + handler_future.set_result(True) + except AssertionError as e: + if not handler_future.done(): + handler_future.set_exception(e) + + handle_attribute_updates.attribute_update_callback = attribute_update_callback + handle_attribute_updates.config = device_config + + task = asyncio.create_task(handle_attribute_updates.main()) + + await asyncio.sleep(3) + try: + update_shared_attributes( + device_info['id']['id'], + test_config['tb_url'], + tb_admin_headers, + test_shared_attribute + ) + + await asyncio.wait_for(handler_future, timeout=10) + except asyncio.TimeoutError: + pytest.fail("Attribute update was not received within the timeout period") + finally: + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_handle_rpc_requests(device_config, device_info, tb_admin_headers, test_config): + + handler_future = asyncio.Future() + + rpc_request = { + "method": "getDeviceInfo", + "params": {"deviceId": device_info['id']['id']} + } + + async def rpc_request_callback(request: RPCRequest): + try: + assert request is not None, "RPC request should not be None" + assert request.method == "getDeviceInfo", "Unexpected RPC method" + response = RPCResponse.build(request_id=request.request_id, + result={"id": device_info['id'], "name": device_info['name']}) + handler_future.set_result(response) + return response + except AssertionError as e: + if not handler_future.done(): + handler_future.set_exception(e) + + handle_rpc_requests.rpc_request_callback = rpc_request_callback + handle_rpc_requests.config = device_config + + task = asyncio.create_task(handle_rpc_requests.main()) + await asyncio.sleep(1) + + try: + + response_task = asyncio.create_task(send_rpc_request(device_info['id']['id'], test_config['tb_url'], tb_admin_headers, rpc_request)) + + await asyncio.wait_for(handler_future, timeout=10) + + response = await asyncio.wait_for(response_task, timeout=10) + + assert response is not None, "RPC response should not be None" + + except asyncio.TimeoutError: + pytest.fail("RPC request was not received within the timeout period") + except Exception as e: + pytest.fail(f"An unexpected exception occurred: {e}") + finally: + task.cancel() \ No newline at end of file diff --git a/tests/blackbox/test_basic_gateway_examples.py b/tests/blackbox/test_basic_gateway_examples.py new file mode 100644 index 0000000..906dc65 --- /dev/null +++ b/tests/blackbox/test_basic_gateway_examples.py @@ -0,0 +1,359 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import pytest + +from examples.gateway import connect_and_disconnect_device +from examples.gateway import send_timeseries +from examples.gateway import send_attributes +from examples.gateway import request_attributes +from examples.gateway import handle_attribute_updates +from examples.gateway import handle_rpc_requests +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.gateway.gateway_requested_attribute_response import GatewayRequestedAttributeResponse +from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse +from tb_mqtt_client.service.gateway.device_session import DeviceSession +from tests.blackbox.rest_helpers import get_device_info_by_name, get_device_attributes, get_device_timeseries, \ + update_shared_attributes, send_rpc_request + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_gateway_connect_and_disconnect_device(gateway_config, tb_admin_headers, test_config): + device_name = "pytest-gw-subdevice" + device_profile = "pytest-gw-subdevice-profile" + connect_and_disconnect_device.config = gateway_config + connect_and_disconnect_device.device_name = device_name + connect_and_disconnect_device.device_profile = device_profile + + task = asyncio.create_task(connect_and_disconnect_device.main()) + + try: + await asyncio.sleep(3) + device_info = get_device_info_by_name(device_name, test_config["tb_url"], tb_admin_headers) + device_service_attributes = get_device_attributes(device_info["id"]["id"], test_config["tb_url"], tb_admin_headers, "SERVER_SCOPE") + server_attributes = {attr["key"]: attr["value"] for attr in device_service_attributes} + assert device_info['type'] == device_profile + assert server_attributes['active'] + assert abs(server_attributes['lastActivityTime'] - server_attributes["lastConnectTime"]) < 5 + assert server_attributes['lastDisconnectTime'] >= server_attributes["lastConnectTime"] + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_gateway_send_timeseries(gateway_config, tb_admin_headers, test_config): + device_name = "pytest-gw-subdevice-ts" + device_profile = "pytest-gw-subdevice-ts-profile" + + send_timeseries.config = gateway_config + send_timeseries.device_name = device_name + send_timeseries.device_profile = device_profile + + task = asyncio.create_task(send_timeseries.main()) + + try: + await asyncio.sleep(5) + + device_info = get_device_info_by_name( + device_name, test_config["tb_url"], tb_admin_headers + ) + assert device_info is not None, "Subdevice was not created/connected." + + timeseries = get_device_timeseries( + device_info["id"]["id"], test_config["tb_url"], tb_admin_headers + ) + + expected_keys = {"temperature", "humidity"} + for key in expected_keys: + assert key in timeseries, f"Timeseries key '{key}' not found in device data." + + for key in expected_keys: + assert any(isinstance(entry.get("value"), (int, float)) for entry in timeseries[key]), \ + f"No numeric values found for '{key}'." + + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_gateway_send_attributes(gateway_config, tb_admin_headers, test_config): + device_name = "pytest-gw-subdevice-attrs" + device_profile = "pytest-gw-subdevice-attrs-profile" + + send_attributes.config = gateway_config + send_attributes.device_name = device_name + send_attributes.device_profile = device_profile + + task = asyncio.create_task(send_attributes.main()) + + try: + await asyncio.sleep(5) + + device_info = get_device_info_by_name( + device_name, test_config["tb_url"], tb_admin_headers + ) + assert device_info is not None, "Subdevice was not created/connected." + + attrs = get_device_attributes( + device_info["id"]["id"], + test_config["tb_url"], + tb_admin_headers, + scope="CLIENT_SCOPE" + ) + attr_dict = {a["key"]: a["value"] for a in attrs} + + expected_attrs = { + "maintenance": "scheduled", + "id": 341, + "location": "office", + "status": "active", + "version": "1.0.0" + } + + for key, value in expected_attrs.items(): + assert key in attr_dict, f"Attribute '{key}' not found." + assert attr_dict[key] == value, ( + f"Attribute '{key}' expected '{value}', got '{attr_dict[key]}'" + ) + + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_gateway_request_attributes(gateway_config, tb_admin_headers, test_config): + device_name = "pytest-gw-subdevice-attrs-request" + device_profile = "pytest-gw-subdevice-attrs-request-profile" + + expected_attrs = { + "maintenance": "scheduled", + "id": 341, + "location": "office", + } + + handler_future = asyncio.Future() + + async def requested_attributes_handler(device_session, response: GatewayRequestedAttributeResponse): + try: + assert device_session is not None, "Device session should not be None." + assert device_session.device_info is not None, "Device info should not be None." + assert device_session.device_info.device_name == device_name, ( + f"Device name mismatch: expected '{device_name}', got '{device_session.device_info.device_name}'" + ) + assert response is not None, "Response should not be None." + assert isinstance(response.client, list), "Client attributes should be a list." + assert isinstance(response.shared, list), "Shared attributes should be a list." + if response.request_id == 1: # Assuming request_id 1 is for the client attributes + assert len(response.client) > 0, "Client attributes should not be empty." + assert len(response.shared) >= 0, "Shared attributes can be empty." + for attr in response.client: + for key, value in expected_attrs.items(): + if attr.key == key: + assert attr.value == value, ( + f"Attribute '{key}' expected '{value}', got '{attr.value}'" + ) + else: + assert response.client == [], "Client attributes should be empty for shared attribute requests." + assert response.shared == [], "Shared attributes should be empty, because they were not set." + handler_future.set_result(True) + except AssertionError as e: + if not handler_future.done(): + handler_future.set_exception(e) + + request_attributes.config = gateway_config + request_attributes.device_name = device_name + request_attributes.device_profile = device_profile + request_attributes.requested_attributes_handler = requested_attributes_handler + + task = asyncio.create_task(request_attributes.main()) + + try: + await asyncio.wait_for(handler_future, timeout=10) + + device_info = get_device_info_by_name(device_name, test_config["tb_url"], tb_admin_headers) + assert device_info is not None, "Subdevice was not created/connected." + + attrs = get_device_attributes( + device_info["id"]["id"], + test_config["tb_url"], + tb_admin_headers, + scope="CLIENT_SCOPE" + ) + attr_dict = {a["key"]: a["value"] for a in attrs} + + for key, value in expected_attrs.items(): + assert key in attr_dict, f"Attribute '{key}' not found." + assert attr_dict[key] == value, ( + f"Attribute '{key}' expected '{value}', got '{attr_dict[key]}'" + ) + + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_gateway_handle_attribute_updates(gateway_config, tb_admin_headers, test_config): + device_name = "pytest-gw-subdevice-attrs-update" + device_profile = "pytest-gw-subdevice-attrs-update-profile" + + test_shared_attribute = {"pytest_shared_attribute": "test_value"} + + handler_future = asyncio.Future() + + async def attribute_update_handler(device_session: DeviceSession, update: AttributeUpdate): + try: + assert device_session is not None, "Device session should not be None." + assert device_session.device_info is not None, "Device info should not be None." + assert device_session.device_info.device_name == device_name, ( + f"Device name mismatch: expected '{device_name}', got '{device_session.device_info.device_name}'" + ) + assert update is not None, "Update should not be None." + assert update.keys()[0] == list(test_shared_attribute.keys())[0], ( + "Update key does not match expected key." + ) + assert update.values()[0] == list(test_shared_attribute.values())[0], ( + "Update value does not match expected value." + ) + handler_future.set_result(True) + except AssertionError as e: + if not handler_future.done(): + handler_future.set_exception(e) + + handle_attribute_updates.config = gateway_config + handle_attribute_updates.device_name = device_name + handle_attribute_updates.device_profile = device_profile + handle_attribute_updates.attribute_update_handler = attribute_update_handler + + task = asyncio.create_task(handle_attribute_updates.main()) + + try: + await asyncio.sleep(3) + + sub_device_info = get_device_info_by_name(device_name, test_config["tb_url"], tb_admin_headers) + + update_shared_attributes( + sub_device_info["id"]["id"], + test_config["tb_url"], + tb_admin_headers, + test_shared_attribute + ) + + await asyncio.wait_for(handler_future, timeout=10) + + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_gateway_handle_rpc_requests(gateway_config, tb_admin_headers, test_config): + device_name = "pytest-gw-subdevice-rpc" + device_profile = "pytest-gw-subdevice-rpc-profile" + + handler_future = asyncio.Future() + + async def device_rpc_request_handler(device_session: DeviceSession, rpc_request: GatewayRPCRequest): + try: + assert device_session is not None, "Device session should not be None." + assert device_session.device_info is not None, "Device info should not be None." + assert device_session.device_info.device_name == device_name, ( + f"Device name mismatch: expected '{device_name}', got '{device_session.device_info.device_name}'" + ) + assert rpc_request is not None, "RPC request should not be None." + assert rpc_request.method == "testMethod", "RPC method does not match expected method." + response_data = { + "data": { + "device_name": device_session.device_info.device_name, + "request_id": rpc_request.request_id, + "method": rpc_request.method, + "params": rpc_request.params + } + } + rpc_response = GatewayRPCResponse.build( + device_session.device_info.device_name, + rpc_request.request_id, + response_data + ) + handler_future.set_result(True) + return rpc_response + except AssertionError as e: + if not handler_future.done(): + handler_future.set_exception(e) + + handle_rpc_requests.config = gateway_config + handle_rpc_requests.device_name = device_name + handle_rpc_requests.device_profile = device_profile + handle_rpc_requests.device_rpc_request_handler = device_rpc_request_handler + + task = asyncio.create_task(handle_rpc_requests.main()) + await asyncio.sleep(3) + + try: + + sub_device_info = get_device_info_by_name(device_name, test_config["tb_url"], tb_admin_headers) + rpc_request = { + "method": "testMethod", + "params": {"param1": "value1"} + } + + response_task = asyncio.create_task(send_rpc_request(sub_device_info['id']['id'], test_config['tb_url'], tb_admin_headers, rpc_request)) + + await asyncio.wait_for(handler_future, timeout=10) + + response_data = await asyncio.wait_for(response_task, timeout=10) + + assert 'result' in response_data, "RPC response should contain 'result'." + assert 'data' in response_data['result'], "RPC response should contain 'data'." + assert 'device_name' in response_data['result']['data'], "RPC response data should contain 'device_name'." + assert response_data['result']['data']['device_name'] == device_name, "RPC response device_name does not match expected device_name." + assert 'method' in response_data['result']['data'], "RPC response data should contain 'method'." + assert response_data['result']['data']['method'] == "testMethod", "RPC response method does not match expected method." + assert 'request_id' in response_data['result']['data'], "RPC response data should contain 'request_id'." + assert response_data['result']['data']['request_id'] == 0, "RPC response request_id does not match expected request_id." + assert 'params' in response_data['result']['data'], "RPC response data should contain 'params'." + assert response_data['result']['data']['params'] == {"param1": "value1"}, "RPC response params do not match expected params." + + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass diff --git a/tests/blackbox/test_device_client_provisioning.py b/tests/blackbox/test_device_client_provisioning.py new file mode 100644 index 0000000..5450007 --- /dev/null +++ b/tests/blackbox/test_device_client_provisioning.py @@ -0,0 +1,79 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio + +import pytest +import requests + +from examples.device import client_provisioning +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest, AccessTokenProvisioningCredentials +from tests.blackbox.rest_helpers import create_device_profile_with_provisioning, \ + get_device_timeseries, get_device_info_by_name + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_client_provisioning(tb_admin_headers, test_config): + device_name = 'pytest-provisioned-device' + + provisioning_device_key = 'pytest-provisioning-device-key' + provisioning_device_secret = 'pytest-provisioning-device-secret' + + provisioning_credentials = AccessTokenProvisioningCredentials( + provision_device_key=provisioning_device_key, + provision_device_secret=provisioning_device_secret, + ) + provisioning_request = ProvisioningRequest(test_config['tb_host'], + credentials=provisioning_credentials, + device_name=device_name) + + device_profile = create_device_profile_with_provisioning('pytest-provisioning-profile', + test_config['tb_url'], + tb_admin_headers, + provisioning_device_key, + provisioning_device_secret) + provisioned_device_info = None + try: + client_provisioning.provisioning_request = provisioning_request + + try: + await client_provisioning.main() + except Exception as e: + pytest.fail(f"Provisioning failed with exception: {e}") + + provisioned_device_info = get_device_info_by_name(device_name, + test_config['tb_url'], + tb_admin_headers) + + assert provisioned_device_info is not None, "Provisioned device info should not be None" + assert provisioned_device_info['name'] == device_name, "Provisioned device name should match" + assert provisioned_device_info['deviceProfileId']['id'] == device_profile['id']['id'], \ + "Provisioned device profile ID should match" + + timeseries = get_device_timeseries(provisioned_device_info['id']['id'], test_config['tb_url'], tb_admin_headers, + ['batteryLevel']) + assert len(timeseries) == 1, "There should be one timeseries entry" + assert 'batteryLevel' in timeseries, "Timeseries key should be 'batteryLevel'" + assert 'value' in timeseries['batteryLevel'][0], "Timeseries entry should have a value" + finally: + # Cleanup: delete the provisioned device and profile + if provisioned_device_info and 'id' in provisioned_device_info: + device_id = provisioned_device_info['id']['id'] + delete_url = f"{test_config['tb_url']}/api/device/{device_id}" + requests.delete(delete_url, headers=tb_admin_headers) + + if 'id' in device_profile: + profile_id = device_profile['id']['id'] + delete_profile_url = f"{test_config['tb_url']}/api/deviceProfile/{profile_id}" + requests.delete(delete_profile_url, headers=tb_admin_headers) diff --git a/tests/blackbox/test_device_client_side_rpc.py b/tests/blackbox/test_device_client_side_rpc.py new file mode 100644 index 0000000..c853939 --- /dev/null +++ b/tests/blackbox/test_device_client_side_rpc.py @@ -0,0 +1,62 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import pytest + +from examples.device import send_client_side_rpc +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tests.blackbox.rest_helpers import save_device + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_client_side_rpc(device_profile_with_rpc_rule_chain, device_info, device_config, tb_admin_headers, test_config): + handler_future = asyncio.Future() + async def rpc_response_handler(rpc_request): + handler_future.set_result(rpc_request) + + send_client_side_rpc.config = device_config + send_client_side_rpc.rpc_response_callback = rpc_response_handler + + device_profile_id = device_profile_with_rpc_rule_chain['id'] + device_info['deviceProfileId'] = device_profile_id + device = save_device(device_info, test_config['tb_url'], tb_admin_headers) + assert device is not None, "Device should be saved successfully" + + await asyncio.sleep(1) # Ensure the device is saved before starting the RPC request + + task = asyncio.create_task(send_client_side_rpc.main()) + result = None + try: + result = await asyncio.wait_for(handler_future, timeout=30) + except Exception as e: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + first_rpc_response = send_client_side_rpc.rpc_response + second_rpc_response = result + + assert first_rpc_response is not None, "First RPC response should not be None" + assert isinstance(first_rpc_response, RPCResponse), "First RPC response should be an instance of RPCResponse" + assert first_rpc_response.request_id < second_rpc_response.request_id, "First RPC response ID should be less than second RPC response ID" + assert first_rpc_response.result['method'] == 'getTime', "First RPC response method should be 'getTime'" + + assert second_rpc_response is not None, "Second RPC response should not be None" + assert isinstance(second_rpc_response, RPCResponse), "Second RPC response should be an instance of RPCResponse" + assert second_rpc_response.result['method'] == 'getStatus', "Second RPC response method should be 'getStatus'" + assert second_rpc_response.result['params'] == {'param1': 'value1', 'param2': 'value2'}, "Second RPC response params should match expected values" diff --git a/tests/blackbox/test_device_firmware_update.py b/tests/blackbox/test_device_firmware_update.py new file mode 100644 index 0000000..f1c324b --- /dev/null +++ b/tests/blackbox/test_device_firmware_update.py @@ -0,0 +1,64 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from copy import deepcopy + +import pytest + +from examples.device import firmware_update +from tests.blackbox.rest_helpers import save_device + + +@pytest.mark.asyncio +@pytest.mark.blackbox +async def test_device_firmware_update(firmware_profile_and_package, device_info, device_config, tb_admin_headers, test_config): + + handler_future = asyncio.Future() + + async def firmware_update_callback(firmware_data, current_firmware_info): + handler_future.set_result((firmware_data, current_firmware_info)) + + firmware_update.config = device_config + firmware_update.firmware_update_callback = firmware_update_callback + + device_profile, firmware_package, firmware_info = firmware_profile_and_package + + device = deepcopy(device_info) + + device['deviceProfileId'] = device_profile['id'] + device['firmwareId'] = firmware_package['id'] + + save_device(device, test_config['tb_url'], tb_admin_headers) + + await asyncio.sleep(1) # Ensure the device is saved before starting the firmware update + + task = asyncio.create_task(firmware_update.main()) + result = None + try: + result = await asyncio.wait_for(handler_future, timeout=30) + except Exception as e: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + assert result is not None, "Firmware update callback should be called" + assert isinstance(result, tuple), "Result should be a tuple containing firmware data and info" + received_firmware_data, received_firmware_info = result + assert received_firmware_info['current_fw_title'] == firmware_info['name'] + assert received_firmware_info['current_fw_version'] == firmware_info['version'], \ + "Received firmware version should match the package version" + assert received_firmware_info['fw_state'] == 'UPDATED' diff --git a/tests/common/__init__.py b/tests/common/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tests/common/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/common/test_async_utils.py b/tests/common/test_async_utils.py new file mode 100644 index 0000000..ff88b6d --- /dev/null +++ b/tests/common/test_async_utils.py @@ -0,0 +1,208 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import threading +import time + +import pytest + +import tb_mqtt_client.common.async_utils as async_utils_mod +from tb_mqtt_client.common.async_utils import FutureMap, await_or_stop, run_coroutine_sync +from tb_mqtt_client.common.publish_result import PublishResult + + +@pytest.fixture +def fake_loop(monkeypatch): + class _FakeTask: + def __init__(self, thread: threading.Thread): + self._thread = thread + + def done(self): # optional helpers if ever needed + return not self._thread.is_alive() + + class _FakeLoop: + def create_task(self, coro): + t = threading.Thread(target=lambda: asyncio.run(coro), daemon=True) + t.start() + return _FakeTask(t) + + monkeypatch.setattr(async_utils_mod.asyncio, "get_running_loop", lambda: _FakeLoop()) + + return _FakeLoop() + + +@pytest.mark.asyncio +async def test_future_map_register_and_get_parents(): + fm = FutureMap() + parent = asyncio.Future() + child1, child2 = asyncio.Future(), asyncio.Future() + + fm.register(parent, [child1]) + assert fm.get_parents(child1) == [parent] + + # Add more children to same parent + fm.register(parent, [child2]) + assert set(fm.get_parents(child1)) == {parent} + assert set(fm.get_parents(child2)) == {parent} + + +@pytest.mark.asyncio +async def test_future_map_child_resolved_merges_results(): + fm = FutureMap() + parent = asyncio.Future() + child1, child2 = asyncio.Future(), asyncio.Future() + + fm.register(parent, [child1, child2]) + child1.set_result(PublishResult("topic", 1, 1, 1, 0)) + child2.set_result(PublishResult("topic", 1, 1, 1, 0)) + + # Resolving child1 won't complete parent + fm.child_resolved(child1) + assert not parent.done() + + # Resolving last child completes parent with merged PublishResult + fm.child_resolved(child2) + assert parent.done() + result = parent.result() + assert isinstance(result, PublishResult) + assert result.reason_code == 0 + + +@pytest.mark.asyncio +async def test_future_map_child_resolved_no_publish_result(): + fm = FutureMap() + parent = asyncio.Future() + child = asyncio.Future() + + fm.register(parent, [child]) + child.set_result("not publish result") + fm.child_resolved(child) + assert parent.done() + assert parent.result() is None + + +@pytest.mark.asyncio +async def test_future_map_child_resolved_with_cancelled_child(): + fm = FutureMap() + parent = asyncio.Future() + child = asyncio.Future() + fm.register(parent, [child]) + child.cancel() + fm.child_resolved(child) + assert parent.done() + assert parent.result() is None + + +@pytest.mark.asyncio +async def test_await_or_stop_coroutine_finishes_first(): + stop_event = asyncio.Event() + + async def coro(): return 123 + + result = await await_or_stop(coro(), stop_event, timeout=1) + assert result == 123 + + +@pytest.mark.asyncio +async def test_await_or_stop_stop_event_first(): + stop_event = asyncio.Event() + + async def coro(): await asyncio.sleep(0.5) + + asyncio.get_event_loop().call_soon(stop_event.set) + result = await await_or_stop(coro(), stop_event, timeout=1) + assert result is None + + +@pytest.mark.asyncio +async def test_await_or_stop_timeout(): + stop_event = asyncio.Event() + + async def coro(): await asyncio.sleep(1) + + with pytest.raises(asyncio.TimeoutError): + await await_or_stop(coro(), stop_event, timeout=0.01) + + +@pytest.mark.asyncio +async def test_await_or_stop_negative_timeout(): + stop_event = asyncio.Event() + + async def coro(): return "ok" + + result = await await_or_stop(coro(), stop_event, timeout=-1) + assert result == "ok" + + +@pytest.mark.asyncio +async def test_await_or_stop_future_done(): + stop_event = asyncio.Event() + fut = asyncio.Future() + fut.set_result("done") + result = await await_or_stop(fut, stop_event, timeout=1) + assert result == "done" + + +@pytest.mark.asyncio +async def test_await_or_stop_invalid_type(): + stop_event = asyncio.Event() + with pytest.raises(TypeError): + await await_or_stop("not a future", stop_event, timeout=1) + + +@pytest.mark.asyncio +async def test_await_or_stop_cancelled_error(): + stop_event = asyncio.Event() + + async def coro(): + raise asyncio.CancelledError() + + result = await await_or_stop(coro(), stop_event, timeout=1) + assert result is None + + +def test_returns_result_when_coroutine_completes(fake_loop): + async def ok_coro(): + await asyncio.sleep(0.01) + return "done" + + result = run_coroutine_sync(lambda: ok_coro(), timeout=0.5) + assert result == "done" + + +def test_raises_original_exception_from_coroutine(fake_loop): + class CustomError(RuntimeError): + pass + + async def bad_coro(): + await asyncio.sleep(0.01) + raise CustomError("boom") + + with pytest.raises(CustomError, match="boom"): + run_coroutine_sync(lambda: bad_coro(), timeout=0.5) + + +def test_timeout_raises_timeout_error(fake_loop): + async def slow_coro(): + await asyncio.sleep(0.2) + return "too late" + + with pytest.raises(TimeoutError, match=r"did not complete in 0\.05 seconds"): + run_coroutine_sync(lambda: slow_coro(), timeout=0.05, raise_on_timeout=True) + time.sleep(0.25) + + +if __name__ == '__main__': + pytest.main([__file__, "--tb=short", "-v"]) diff --git a/tests/common/test_backpressure_controller.py b/tests/common/test_backpressure_controller.py new file mode 100644 index 0000000..d06d287 --- /dev/null +++ b/tests/common/test_backpressure_controller.py @@ -0,0 +1,108 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from datetime import datetime, timedelta, UTC + +import pytest + +from tb_mqtt_client.common.rate_limit.backpressure_controller import BackpressureController + + +@pytest.fixture +def stop_event(): + return asyncio.Event() + + +@pytest.fixture +def controller(stop_event): + return BackpressureController(stop_event) + + +def test_notify_quota_exceeded_initial(controller): + controller.notify_quota_exceeded() + assert controller._pause_until is not None + assert controller._consecutive_quota_exceeded == 1 + + +def test_notify_quota_exceeded_consecutive(controller): + controller._last_quota_exceeded = datetime.now(UTC) + controller._consecutive_quota_exceeded = 2 + controller.notify_quota_exceeded() + assert controller._consecutive_quota_exceeded == 3 + assert controller._pause_until is not None + + +def test_notify_quota_exceeded_reset_after_60s(controller): + controller._last_quota_exceeded = datetime.now(UTC) - timedelta(seconds=61) + controller._consecutive_quota_exceeded = 5 + controller.notify_quota_exceeded() + assert controller._consecutive_quota_exceeded == 1 + + +def test_notify_quota_exceeded_custom_delay(controller): + controller.notify_quota_exceeded(delay_seconds=30) + assert controller._pause_until is not None + + +def test_notify_quota_exceeded_stop_event_set(stop_event, controller): + stop_event.set() + controller.notify_quota_exceeded() + assert controller._pause_until is None + + +def test_notify_disconnect_default(controller): + controller.notify_disconnect() + assert controller._pause_until is not None + + +def test_notify_disconnect_custom_delay(controller): + controller.notify_disconnect(delay_seconds=25) + assert controller._pause_until is not None + + +def test_notify_disconnect_stop_event_set(stop_event, controller): + stop_event.set() + controller.notify_disconnect() + assert controller._pause_until is None + + +def test_should_pause_active(controller): + controller._pause_until = datetime.now(UTC) + timedelta(seconds=15) + assert controller.should_pause() + + +def test_should_pause_expired(controller): + controller._pause_until = datetime.now(UTC) - timedelta(seconds=1) + assert controller.should_pause() is False + assert controller._pause_until is None + + +def test_should_pause_not_set(controller): + controller._pause_until = None + assert controller.should_pause() is False + + +def test_should_pause_stop_event_set(stop_event, controller): + stop_event.set() + controller._pause_until = datetime.now(UTC) + timedelta(seconds=30) + assert controller.should_pause() is False + + +def test_clear(controller): + controller._pause_until = datetime.now(UTC) + timedelta(seconds=30) + controller._consecutive_quota_exceeded = 5 + controller.clear() + assert controller._pause_until is None + assert controller._consecutive_quota_exceeded == 0 diff --git a/tests/common/test_config_loader.py b/tests/common/test_config_loader.py new file mode 100644 index 0000000..9963e15 --- /dev/null +++ b/tests/common/test_config_loader.py @@ -0,0 +1,135 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +from tb_mqtt_client.common.config_loader import DeviceConfig, GatewayConfig + + +class TestDeviceConfig(unittest.TestCase): + + def test_config_creation_from_dict(self): + config_dict = { + "host": "localhost", + "port": 1883, + "access_token": "test_token", + "username": "test_user", + "password": "test_pass", + "client_id": "test_client", + "ca_cert": "test_ca", + "client_cert": "test_cert", + "private_key": "test_key", + "qos": 1 + } + config = DeviceConfig(config_dict) + self.assertEqual(config.host, "localhost") + self.assertEqual(config.port, 1883) + self.assertEqual(config.access_token, "test_token") + self.assertEqual(config.username, "test_user") + self.assertEqual(config.password, "test_pass") + self.assertEqual(config.client_id, "test_client") + self.assertEqual(config.ca_cert, "test_ca") + self.assertEqual(config.client_cert, "test_cert") + self.assertEqual(config.private_key, "test_key") + self.assertEqual(config.qos, 1) + + def test_loads_default_values_when_env_vars_missing(self): + os.environ.clear() + config = DeviceConfig() + self.assertEqual(config.host, None) + self.assertEqual(config.port, 1883) + self.assertEqual(config.access_token, None) + self.assertEqual(config.username, None) + self.assertEqual(config.password, None) + self.assertEqual(config.client_id, None) + self.assertEqual(config.ca_cert, None) + self.assertEqual(config.client_cert, None) + self.assertEqual(config.private_key, None) + self.assertEqual(config.qos, 1) + + def test_loads_values_from_env_vars(self): + os.environ["TB_HOST"] = "test_host" + os.environ["TB_PORT"] = "8883" + os.environ["TB_ACCESS_TOKEN"] = "test_token" + os.environ["TB_USERNAME"] = "test_user" + os.environ["TB_PASSWORD"] = "test_pass" + os.environ["TB_CLIENT_ID"] = "test_client" + os.environ["TB_CA_CERT"] = "test_ca" + os.environ["TB_CLIENT_CERT"] = "test_cert" + os.environ["TB_PRIVATE_KEY"] = "test_key" + os.environ["TB_QOS"] = "2" + + config = DeviceConfig() + self.assertEqual(config.host, "test_host") + self.assertEqual(config.port, 8883) + self.assertEqual(config.access_token, "test_token") + self.assertEqual(config.username, "test_user") + self.assertEqual(config.password, "test_pass") + self.assertEqual(config.client_id, "test_client") + self.assertEqual(config.ca_cert, "test_ca") + self.assertEqual(config.client_cert, "test_cert") + self.assertEqual(config.private_key, "test_key") + self.assertEqual(config.qos, 2) + + def test_detects_tls_auth_correctly(self): + os.environ["TB_CA_CERT"] = "test_ca" + os.environ["TB_CLIENT_CERT"] = "test_cert" + os.environ["TB_PRIVATE_KEY"] = "test_key" + config = DeviceConfig() + self.assertTrue(config.use_tls_auth()) + + def test_detects_tls_correctly(self): + os.environ["TB_CA_CERT"] = "test_ca" + config = DeviceConfig() + self.assertTrue(config.use_tls()) + + +class TestGatewayConfig(unittest.TestCase): + + def test_loads_gateway_specific_env_vars(self): + os.environ["TB_GW_HOST"] = "gw_host" + os.environ["TB_GW_PORT"] = "8884" + os.environ["TB_GW_ACCESS_TOKEN"] = "gw_token" + os.environ["TB_GW_USERNAME"] = "gw_user" + os.environ["TB_GW_PASSWORD"] = "gw_pass" + os.environ["TB_GW_CLIENT_ID"] = "gw_client" + os.environ["TB_GW_CA_CERT"] = "gw_ca" + os.environ["TB_GW_CLIENT_CERT"] = "gw_cert" + os.environ["TB_GW_PRIVATE_KEY"] = "gw_key" + os.environ["TB_GW_QOS"] = "0" + + config = GatewayConfig() + self.assertEqual(config.host, "gw_host") + self.assertEqual(config.port, 8884) + self.assertEqual(config.access_token, "gw_token") + self.assertEqual(config.username, "gw_user") + self.assertEqual(config.password, "gw_pass") + self.assertEqual(config.client_id, "gw_client") + self.assertEqual(config.ca_cert, "gw_ca") + self.assertEqual(config.client_cert, "gw_cert") + self.assertEqual(config.private_key, "gw_key") + self.assertEqual(config.qos, 0) + + def test_falls_back_to_device_config_when_gateway_env_vars_missing(self): + os.environ.clear() + os.environ["TB_HOST"] = "device_host" + os.environ["TB_PORT"] = "1884" + config = GatewayConfig() + self.assertEqual(config.host, "device_host") + self.assertEqual(config.port, 1884) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/common/test_exceptions.py b/tests/common/test_exceptions.py new file mode 100644 index 0000000..03368a9 --- /dev/null +++ b/tests/common/test_exceptions.py @@ -0,0 +1,156 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from typing import List, Tuple, Optional + +import pytest + +from tb_mqtt_client.common.exceptions import ExceptionHandler, BackpressureException + + +def make_recorder(storage: List[Tuple[str, Optional[dict]]]): + def _recorder(exc: BaseException, context: Optional[dict]): + storage.append((exc.__class__.__name__, context)) + + return _recorder + + +def test_register_and_handle_specific_exception_calls_callback(): + handler = ExceptionHandler() + events: List[Tuple[str, Optional[dict]]] = [] + cb = make_recorder(events) + + handler.register(ValueError, cb) + + err = ValueError("boom") + ctx = {"info": "ctx"} + handler.handle(err, ctx) + + assert events == [("ValueError", ctx)] + + +def test_default_callback_called_when_no_specific_registered(): + handler = ExceptionHandler() + events: List[Tuple[str, Optional[dict]]] = [] + default_cb = make_recorder(events) + + handler.register(None, default_cb) + + err = KeyError("missing") + ctx = {"k": "v"} + handler.handle(err, ctx) + + assert events == [("KeyError", ctx)] + + +def test_specific_takes_precedence_default_not_called(): + handler = ExceptionHandler() + events: List[Tuple[str, Optional[dict]]] = [] + specific_events: List[Tuple[str, Optional[dict]]] = [] + + default_cb = make_recorder(events) + specific_cb = make_recorder(specific_events) + + handler.register(None, default_cb) + handler.register(RuntimeError, specific_cb) + + err = RuntimeError("nope") + ctx = {"tag": 1} + handler.handle(err, ctx) + + assert specific_events == [("RuntimeError", ctx)] + assert events == [] + + +def test_multiple_callbacks_for_same_exception_type_all_called(): + handler = ExceptionHandler() + calls: List[Tuple[str, Optional[dict]]] = [] + + cb1 = make_recorder(calls) + cb2 = make_recorder(calls) + handler.register(IndexError, cb1) + handler.register(IndexError, cb2) + + err = IndexError("oops") + ctx = {"n": 42} + handler.handle(err, ctx) + + assert calls == [("IndexError", ctx), ("IndexError", ctx)] + + +def test_subclass_matching_triggers_all_matching_callbacks_and_skips_default(): + handler = ExceptionHandler() + + calls_exception: List[Tuple[str, Optional[dict]]] = [] + calls_valueerror: List[Tuple[str, Optional[dict]]] = [] + default_calls: List[Tuple[str, Optional[dict]]] = [] + + cb_exception = make_recorder(calls_exception) + cb_valueerror = make_recorder(calls_valueerror) + cb_default = make_recorder(default_calls) + + handler.register(Exception, cb_exception) + handler.register(ValueError, cb_valueerror) + handler.register(None, cb_default) + + err = ValueError("bad value") + ctx = {"a": 1} + handler.handle(err, ctx) + + assert len(calls_exception) == 1 + assert len(calls_valueerror) == 1 + assert calls_exception[0][0] == "ValueError" + assert calls_valueerror[0][0] == "ValueError" + + assert default_calls == [] + + +@pytest.mark.asyncio +async def test_install_asyncio_handler_dispatches_on_loop_exception(event_loop: asyncio.AbstractEventLoop): + """ + Instead of relying on an un-awaited task (flaky under uvloop), + directly call the loop's exception handler with a context containing an exception. + This validates that install_asyncio_handler wires _asyncio_handler correctly, + and that it dispatches to ExceptionHandler.handle(...). + """ + handler = ExceptionHandler() + + calls: List[Tuple[str, Optional[dict]]] = [] + cb = make_recorder(calls) + handler.register(ValueError, cb) + + handler.install_asyncio_handler(event_loop) + + err = ValueError("async oops") + context = {"exception": err, "message": "unit-test"} + event_loop.call_exception_handler(context) + + await asyncio.sleep(0) + + assert calls == [("ValueError", context)] + + +def test_backpressure_exception_default_message(): + exc = BackpressureException() + assert str(exc) == "Client is under backpressure. Please retry later." + + +def test_backpressure_exception_custom_message(): + exc = BackpressureException("Hold up!") + assert str(exc) == "Hold up!" + + +if __name__ == '__main__': + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/common/test_gmqtt_patch.py b/tests/common/test_gmqtt_patch.py new file mode 100644 index 0000000..172860b --- /dev/null +++ b/tests/common/test_gmqtt_patch.py @@ -0,0 +1,434 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import heapq +import struct +import types + +import pytest +from gmqtt.mqtt.constants import MQTTv50, MQTTv311, MQTTCommands +from gmqtt.mqtt.handler import MqttPackageHandler +from gmqtt.mqtt.protocol import MQTTProtocol + +from tb_mqtt_client.common.gmqtt_patch import PatchUtils, PublishPacket +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage + + +def test_parse_mqtt_properties_valid_and_invalid(): + pkt = bytes([1]) + bytes([255]) + assert PatchUtils.parse_mqtt_properties(pkt) == {} + + assert PatchUtils.parse_mqtt_properties(b"\xff") == {} + + +def test_extract_reason_code_all_paths(): + class Obj: + reason_code = 42 + + assert PatchUtils.extract_reason_code(Obj()) == 42 + assert PatchUtils.extract_reason_code(b"\x00*") == 42 + assert PatchUtils.extract_reason_code(b"") is None + assert PatchUtils.extract_reason_code(None) is None + + +def test_patch_puback_handling_and_storage(monkeypatch): + pu = PatchUtils(None, asyncio.Event()) + called = {} + + def on_puback(mid, reason, props): + called["hit"] = (mid, reason, props) + + monkeypatch.setattr("gmqtt.mqtt.handler.MqttPackageHandler._handle_puback_packet", lambda *a, **k: None) + pu.patch_puback_handling(on_puback) + + handler = types.SimpleNamespace( + _connection=types.SimpleNamespace( + persistent_storage=types.SimpleNamespace(remove=lambda m: None) + ) + ) + packet = struct.pack("!HB", 0xABCD, 0x10) + MqttPackageHandler._handle_puback_packet(handler, MQTTCommands.PUBACK, packet) + assert called["hit"][0] == 0xABCD + assert called["hit"][1] == 0x10 + assert isinstance(called["hit"][2], dict) + + client = types.SimpleNamespace( + _persistent_storage=types.SimpleNamespace( + _queue=[(0.0, 1, "raw")], + _check_empty=lambda: None + ) + ) + pu.client = client + pu.patch_storage() + + tm, mid, raw = asyncio.get_event_loop().run_until_complete(client._persistent_storage.pop_message()) + assert mid == 1 and raw == "raw" + + +@pytest.mark.asyncio +async def test_retry_loop_and_task_controls(monkeypatch): + storage_queue = [] + + class Storage: + def _check_empty(self): + pass + + async def pop_message(self): + if not storage_queue: + raise IndexError + return storage_queue.pop(0) + + msgs_sent = [] + + class FakeClient: + is_connected = True + + async def put_retry_message(self, msg): + msgs_sent.append(msg) + + _persistent_storage = Storage() + + pu = PatchUtils(FakeClient(), asyncio.Event(), retry_interval=0) + heapq.heappush(storage_queue, (0.0, 1, types.SimpleNamespace(topic="t", dup=False))) + pu._stop_event.set() + await pu._retry_loop() + assert not msgs_sent + pu._stop_event.clear() + pu.start_retry_task() + assert pu._retry_task is not None + await pu.stop_retry_task() + assert pu._retry_task is None + + pu._retry_task = asyncio.create_task(asyncio.sleep(1)) + await pu.stop_retry_task() + assert pu._retry_task is None + + +def test_apply_calls_patch_and_starts_task(monkeypatch): + pu = PatchUtils(None, asyncio.Event()) + monkeypatch.setattr(pu, "patch_puback_handling", lambda cb: setattr(pu, "_patched", True)) + monkeypatch.setattr(pu, "start_retry_task", lambda: setattr(pu, "_started", True)) + pu.apply(lambda a, b, c: None) + assert pu._patched + assert pu._started + + +def test_build_package_qos1_with_provided_mid(): + msg = MqttPublishMessage(topic="topic", payload=b"PAY", qos=1, retain=True) + msg.dup = True + protocol = types.SimpleNamespace(proto_ver=5) + mid, packet = PublishPacket.build_package(msg, protocol, mid=77) + assert mid == 77 + assert struct.pack("!H", 77) in packet + assert packet[0] & 0x08 + + +def test_build_package_qos1_with_generated_mid(monkeypatch): + msg = MqttPublishMessage(topic="gen", payload=b"PAY", qos=1) + monkeypatch.setattr(PublishPacket, "id_generator", types.SimpleNamespace(next_id=lambda: 1234)) + protocol = types.SimpleNamespace(proto_ver=5) + mid, packet = PublishPacket.build_package(msg, protocol) + assert mid == 1234 + assert struct.pack("!H", 1234) in packet + + +# ------------------------------ +# New tests (extended coverage) +# ------------------------------ + +def test_build_package_qos0_no_mid_empty_payload(): + msg = MqttPublishMessage(topic="t/0", payload=b"", qos=0, retain=False) + protocol = types.SimpleNamespace(proto_ver=5) + mid, packet = PublishPacket.build_package(msg, protocol) + assert mid is None + assert b"t/0" in packet + assert (packet[0] >> 1) & 0b11 == 0 + + +@pytest.mark.asyncio +async def test_patch_mqtt_handler_disconnect_invokes_on_disconnect_and_reconnect(monkeypatch): + assert PatchUtils.patch_mqtt_handler_disconnect(None) is True + + reconnect_called = asyncio.Event() + on_disconnect_args = {} + + async def fake_reconnect(delay=True): + reconnect_called.set() + + def fake_on_disconnect(self_like, reason_code, properties, exc): + on_disconnect_args["rc"] = reason_code + on_disconnect_args["props"] = properties + on_disconnect_args["exc"] = exc + + fake_connection = types.SimpleNamespace(_on_disconnect_called=False) + self_like = types.SimpleNamespace( + _clear_topics_aliases=lambda: None, + reconnect=fake_reconnect, + _handle_exception_in_future=lambda f: None, + on_disconnect=fake_on_disconnect, + _connection=fake_connection, + ) + + packet = bytes([151]) + bytes([0]) + + MqttPackageHandler._handle_disconnect_packet(self_like, MQTTCommands.DISCONNECT, packet) + + await asyncio.wait_for(reconnect_called.wait(), timeout=1.0) + assert on_disconnect_args["rc"] == 151 + assert isinstance(on_disconnect_args["props"], dict) + assert on_disconnect_args["exc"] is None + assert self_like._connection._on_disconnect_called is True + + +@pytest.mark.asyncio +async def test_patch_handle_connack_success_and_error_paths(monkeypatch): + assert PatchUtils.patch_handle_connack(None) is True + + connected_set = {"hit": False} + on_connect_called = {} + + def _connected_set(): + connected_set["hit"] = True + + success_self = types.SimpleNamespace( + _connected=types.SimpleNamespace(set=_connected_set), + _logger=types.SimpleNamespace(warning=lambda *a, **k: None, + info=lambda *a, **k: None, + debug=lambda *a, **k: None), + failed_connections=0, + protocol_version=MQTTv50, + _parse_properties=lambda payload: ({'x': 'y'}, b""), + _update_keepalive_if_needed=lambda: None, + properties={}, + on_connect=lambda *args: on_connect_called.setdefault("ok", args), + disconnect=lambda: asyncio.Future(), + reconnect=lambda delay=True: asyncio.Future(), + _handle_exception_in_future=lambda f: None, + ) + + packet_ok = struct.pack("!BB", 1, 0) + b"\x00" + MqttPackageHandler._handle_connack_packet(success_self, MQTTCommands.CONNACK, packet_ok) + assert connected_set["hit"] is True + assert "ok" in on_connect_called + + downgraded = {} + + fut = asyncio.get_event_loop().create_future() + fut.set_result(None) + + def fake_reconnect(delay=True): + downgraded["called"] = True + f = asyncio.get_event_loop().create_future() + f.set_result(None) + return f + + error_self = types.SimpleNamespace( + _connected=types.SimpleNamespace(set=lambda: None), + _logger=types.SimpleNamespace(warning=lambda *a, **k: None, + info=lambda *a, **k: None, + debug=lambda *a, **k: None), + failed_connections=0, + protocol_version=MQTTv50, + _parse_properties=lambda payload: ({}, b""), + _update_keepalive_if_needed=lambda: None, + properties={}, + on_connect=lambda *a, **k: None, + reconnect=fake_reconnect, + _handle_exception_in_future=lambda f: None, + ) + + packet_err = struct.pack("!BB", 0, 1) + MqttPackageHandler._handle_connack_packet(error_self, MQTTCommands.CONNACK, packet_err) + assert downgraded.get("called") is True + assert MQTTProtocol.proto_ver == MQTTv311 + + +@pytest.mark.asyncio +async def test_patch_gmqtt_protocol_connection_lost_and_handler_call(monkeypatch): + assert PatchUtils.patch_gmqtt_protocol_connection_lost(None) is True + + disconnect_pkgs = [] + + def put_package(pkg): + disconnect_pkgs.append(pkg) + + loop = asyncio.get_event_loop() + read_future = loop.create_future() + + class FakeMQTTProtocol(MQTTProtocol): + def __init__(self): + pass + + proto_self = FakeMQTTProtocol() + proto_self._connected = types.SimpleNamespace(clear=lambda: None) + proto_self._connection = types.SimpleNamespace( + _disconnect_exc=None, + _disconnect_properties=None, + put_package=put_package + ) + proto_self._read_loop_future = read_future + proto_self._queue = None + + proto_self._stream_reader_wr = lambda: types.SimpleNamespace( + feed_eof=lambda: None, + set_exception=lambda exc: None + ) + proto_self._closed = loop.create_future() + proto_self._paused = False + + proto_self.connection_lost(TimeoutError("simulated")) + + assert disconnect_pkgs and disconnect_pkgs[0][0] == MQTTCommands.DISCONNECT + payload = disconnect_pkgs[0][1] + assert isinstance(payload, (bytes, bytearray)) and len(payload) == 1 + + reconnect_flag = asyncio.Event() + on_disc_args = {} + + async def fake_reconnect(delay=True): + reconnect_flag.set() + + handler_self = types.SimpleNamespace( + _connection=proto_self._connection, + _clear_topics_aliases=lambda: None, + reconnect=fake_reconnect, + _handle_exception_in_future=lambda f: None, + on_disconnect=lambda *args: on_disc_args.setdefault("args", args), + ) + + MqttPackageHandler.__call__(handler_self, MQTTCommands.DISCONNECT, b"\x01") + await asyncio.wait_for(reconnect_flag.wait(), timeout=1.0) + assert "args" in on_disc_args + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "exc, expected_reason", + [ + (ConnectionRefusedError("boom"), 135), # Keep Alive timeout + (TimeoutError("boom"), 135), # Keep Alive timeout + (ConnectionResetError("boom"), 139), # Receive Maximum exceeded + (ConnectionAbortedError("boom"), 136), # Session taken over + (PermissionError("boom"), 132), # Not authorized + (OSError("boom"), 130), # Protocol Error + (ValueError("boom"), 131), # Implementation specific error (fallback) + ], +) +async def test_connection_lost_reason_code_mapping(monkeypatch, exc, expected_reason): + assert PatchUtils.patch_gmqtt_protocol_connection_lost(None) is True + + def build_proto(): + disconnect_pkgs = [] + + def put_package(pkg): + disconnect_pkgs.append(pkg) + + loop = asyncio.get_event_loop() + read_future = loop.create_future() + + class FakeMQTTProtocol(MQTTProtocol): + def __init__(self): + pass + + proto = FakeMQTTProtocol() + proto._connected = types.SimpleNamespace(clear=lambda: None) + proto._connection = types.SimpleNamespace( + _disconnect_exc=None, + _disconnect_properties=None, + put_package=put_package + ) + proto._read_loop_future = read_future + proto._queue = None + proto._stream_reader_wr = lambda: types.SimpleNamespace( + feed_eof=lambda: None, + set_exception=lambda e: None + ) + proto._closed = loop.create_future() + proto._paused = False + return proto, disconnect_pkgs + + proto_self, disconnect_pkgs = build_proto() + + proto_self.connection_lost(exc) + + assert disconnect_pkgs and disconnect_pkgs[0][0] == MQTTCommands.DISCONNECT + payload = disconnect_pkgs[0][1] + assert isinstance(payload, (bytes, bytearray)) and len(payload) == 1 + assert payload[0] == expected_reason + + props = getattr(proto_self._connection, "_disconnect_properties", {}) + assert isinstance(props, dict) + assert "reason_string" in props + assert isinstance(props["reason_string"], list) and props["reason_string"][0] == str(exc.args[0]) + + +def test_patch_puback_handling_with_properties(monkeypatch): + pu = PatchUtils(None, asyncio.Event()) + got = {} + + def on_puback(mid, reason, props): + got["mid"] = mid + got["reason"] = reason + got["props"] = props + + monkeypatch.setattr("gmqtt.mqtt.handler.MqttPackageHandler._handle_puback_packet", lambda *a, **k: None) + pu.patch_puback_handling(on_puback) + + from gmqtt.mqtt.property import Property + from gmqtt.mqtt.utils import pack_variable_byte_integer + + reason_prop = Property.factory(name="reason_string").dumps("OK") + props_len_prefix = pack_variable_byte_integer(len(reason_prop)) + props_payload = bytes(props_len_prefix) + bytes(reason_prop) + + packet = struct.pack("!H", 0x1234) + bytes([0x00]) + props_payload + + handler_self = types.SimpleNamespace( + _connection=types.SimpleNamespace( + persistent_storage=types.SimpleNamespace(remove=lambda m: None) + ) + ) + + # Act + MqttPackageHandler._handle_puback_packet(handler_self, MQTTCommands.PUBACK, packet) + + assert got["mid"] == 0x1234 + assert got["reason"] == 0x00 + assert isinstance(got["props"], dict) + assert got["props"].get("reason_string") == ["OK"] + + +def test_parse_properties_reason_string_and_user_property_real(): + from gmqtt.mqtt.property import Property + from gmqtt.mqtt.utils import pack_variable_byte_integer + + rs = Property.factory(name="reason_string").dumps("OK") + up = Property.factory(name="user_property").dumps(("k", "v")) + props = bytes(rs + up) + + packet = bytes(pack_variable_byte_integer(len(props))) + props + parsed = PatchUtils.parse_mqtt_properties(packet) + + assert parsed["reason_string"] == ["OK"] + assert parsed["user_property"] == [("k", "v")] + + +def test_parse_properties_unknown_property_returns_empty(): + packet = b"\x01" + b"\xff" + assert PatchUtils.parse_mqtt_properties(packet) == {} + + +if __name__ == '__main__': + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/common/test_install_package_utils.py b/tests/common/test_install_package_utils.py new file mode 100644 index 0000000..0e2a0e8 --- /dev/null +++ b/tests/common/test_install_package_utils.py @@ -0,0 +1,90 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from subprocess import CalledProcessError +from unittest import mock + +import pytest +from pkg_resources import DistributionNotFound + +from tb_mqtt_client.common.install_package_utils import install_package + + +@pytest.mark.parametrize("version", ["upgrade", "UPGRADE"]) +def test_install_package_upgrade_success(version): + with mock.patch("tb_mqtt_client.common.install_package_utils.check_call", return_value=0) as mock_call: + result = install_package("somepkg", version) + assert result is True + mock_call.assert_called_once() + + +def test_install_package_upgrade_retry_success(): + # First fails with --user, then succeeds without + with mock.patch("tb_mqtt_client.common.install_package_utils.check_call", + side_effect=[CalledProcessError(1, ""), 0]) as mock_call: + result = install_package("somepkg", "upgrade") + assert result is True + assert mock_call.call_count == 2 + + +def test_install_package_upgrade_all_fail(): + with mock.patch("tb_mqtt_client.common.install_package_utils.check_call", + side_effect=CalledProcessError(1, "")) as mock_call: + result = install_package("somepkg", "upgrade") + assert result is False + assert mock_call.call_count == 2 + + +def test_install_package_specific_version_already_installed(): + dist = mock.Mock() + dist.version = "1.2.3" + with mock.patch("tb_mqtt_client.common.install_package_utils.get_distribution", return_value=dist): + result = install_package("somepkg", "1.2.3") + assert result is True + + +@pytest.mark.parametrize("version_input,expected_arg", [ + ("1.2.4", "somepkg==1.2.4"), + (">=1.2.4", "somepkg>=1.2.4"), +]) +def test_install_package_specific_version_install_success(version_input, expected_arg): + with mock.patch("tb_mqtt_client.common.install_package_utils.get_distribution", side_effect=DistributionNotFound): + side_effects = [ + CalledProcessError(1, ["pip"]), + 0 + ] + with mock.patch("tb_mqtt_client.common.install_package_utils.check_call", + side_effect=side_effects) as mock_call: + result = install_package("somepkg", version_input) + assert result is True + calls = [str(call.args) for call in mock_call.call_args_list] + assert any(expected_arg in args for args in calls) + + +def test_install_package_specific_version_retry_success(): + with mock.patch("tb_mqtt_client.common.install_package_utils.get_distribution", side_effect=DistributionNotFound): + with mock.patch("tb_mqtt_client.common.install_package_utils.check_call", + side_effect=[CalledProcessError(1, ""), 0]) as mock_call: + result = install_package("somepkg", "1.0.0") + assert result is True + assert mock_call.call_count == 2 + + +def test_install_package_specific_version_all_fail(): + with mock.patch("tb_mqtt_client.common.install_package_utils.get_distribution", side_effect=DistributionNotFound): + with mock.patch("tb_mqtt_client.common.install_package_utils.check_call", + side_effect=CalledProcessError(1, "")) as mock_call: + result = install_package("somepkg", "1.0.0") + assert result is False + assert mock_call.call_count == 2 diff --git a/tests/common/test_provisioning_client.py b/tests/common/test_provisioning_client.py new file mode 100644 index 0000000..51383ab --- /dev/null +++ b/tests/common/test_provisioning_client.py @@ -0,0 +1,116 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.common.provisioning_client import ProvisioningClient +from tb_mqtt_client.constants.mqtt_topics import PROVISION_RESPONSE_TOPIC +from tb_mqtt_client.constants.provisioning import ProvisioningResponseStatus +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest, AccessTokenProvisioningCredentials +from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse + + +@pytest.fixture +def real_request(): + return ProvisioningRequest( + host="test_host", + credentials=AccessTokenProvisioningCredentials("key", "secret"), + port=1883, + device_name="test_device" + ) + + +@pytest.mark.asyncio +@patch("tb_mqtt_client.common.provisioning_client.GMQTTClient") +@patch("tb_mqtt_client.common.provisioning_client.JsonMessageAdapter") +async def test_successful_provisioning_flow(mock_dispatcher_cls, mock_gmqtt_cls, real_request): + mock_client = AsyncMock() + mock_gmqtt_cls.return_value = mock_client + + mock_adapter = MagicMock() + topic = "provision/topic" + payload = b'{"provision": "data"}' + mock_message = MqttPublishMessage(topic=topic, payload=payload) + mock_adapter.build_provision_request.return_value = mock_message + + mock_provisioning_response = ProvisioningResponse.build(real_request, {"credentialsValue": "config_value", "status": "SUCCESS"}) + mock_adapter.parse_provisioning_response.return_value = mock_provisioning_response + mock_dispatcher_cls.return_value = mock_adapter + + client = ProvisioningClient("test_host", 1883, real_request) + + client._on_connect(mock_client, None, 0, None) + + mock_client.subscribe.assert_called_once_with(PROVISION_RESPONSE_TOPIC) + mock_client.publish.assert_called_once_with(mock_message.topic, mock_message.payload) + + await client._on_message(None, None, b"payload-data", None, None) + + assert client._provisioning_response == mock_provisioning_response + assert client._provisioned.is_set() + mock_client.disconnect.assert_awaited_once() + + +@pytest.mark.asyncio +@patch("tb_mqtt_client.common.provisioning_client.GMQTTClient") +@patch("tb_mqtt_client.common.provisioning_client.JsonMessageAdapter") +async def test_failed_connection(mock_dispatcher_cls, mock_gmqtt_cls, real_request, caplog): + mock_client = AsyncMock() + mock_gmqtt_cls.return_value = mock_client + + client = ProvisioningClient("localhost", 1883, real_request) + + with caplog.at_level("ERROR"): + client._on_connect(mock_client, None, 1, None) + + assert client._provisioning_response is not None + assert client._provisioning_response.status == ProvisioningResponseStatus.ERROR + assert client._provisioned.is_set() + assert "Cannot connect to ThingsBoard!" in caplog.text + + +@pytest.mark.asyncio +@patch("tb_mqtt_client.common.provisioning_client.GMQTTClient") +@patch("tb_mqtt_client.common.provisioning_client.JsonMessageAdapter") +async def test_provision_method_awaits_provisioned(mock_dispatcher_cls, mock_gmqtt_cls, real_request): + mock_client = AsyncMock() + mock_gmqtt_cls.return_value = mock_client + + client = ProvisioningClient("localhost", 1883, real_request) + expected_config = MagicMock() + client._provisioning_response = expected_config + client._provisioned.set() + + result = await client.provision() + + mock_client.connect.assert_awaited_once_with("localhost", 1883) + assert result == expected_config + + +def test_initial_state(real_request): + client = ProvisioningClient("host", 1234, real_request) + + assert client._host == "host" + assert client._port == 1234 + assert client._provision_request == real_request + assert client._client_id == "provision" + assert not client._provisioned.is_set() + assert client._provisioning_response is None + + +if __name__ == '__main__': + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/common/test_publish_result.py b/tests/common/test_publish_result.py new file mode 100644 index 0000000..dbe57cb --- /dev/null +++ b/tests/common/test_publish_result.py @@ -0,0 +1,158 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from tb_mqtt_client.common.publish_result import PublishResult + + +@pytest.fixture +def default_publish_result(): + return PublishResult( + topic="v1/devices/me/telemetry", + qos=1, + message_id=123, + payload_size=256, + reason_code=0 + ) + + +def test_publish_result_attributes(default_publish_result): + assert default_publish_result.topic == "v1/devices/me/telemetry" + assert default_publish_result.qos == 1 + assert default_publish_result.message_id == 123 + assert default_publish_result.payload_size == 256 + assert default_publish_result.reason_code == 0 + + +def test_publish_result_repr(default_publish_result): + result = repr(default_publish_result) + assert isinstance(result, str) + assert "PublishResult" in result + assert "v1/devices/me/telemetry" in result + assert "qos=1" in result + assert "message_id=123" in result + assert "payload_size=256" in result + assert "reason_code=0" in result + assert "datapoints_count=0" in result + + +def test_publish_result_as_dict(default_publish_result): + d = default_publish_result.as_dict() + assert isinstance(d, dict) + assert d == { + "topic": "v1/devices/me/telemetry", + "qos": 1, + "message_id": 123, + "payload_size": 256, + "reason_code": 0, + "datapoints_count": 0 + } + + +def test_publish_request_merge(): + result1 = PublishResult( + topic="v1/devices/me/telemetry", + qos=1, + message_id=123, + payload_size=256, + reason_code=0 + ) + result2 = PublishResult( + topic="v1/devices/me/telemetry", + qos=1, + message_id=124, + payload_size=512, + reason_code=0 + ) + merged_result = PublishResult.merge([result1, result2]) + + assert merged_result.topic == "v1/devices/me/telemetry" + assert merged_result.qos == 1 + assert merged_result.message_id == -1 # Merged results do not have a specific message_id + assert merged_result.payload_size == 768 # Combined payload size + assert merged_result.reason_code == 0 # All successful + + +def test_publish_result_merge_with_empty_list(): + with pytest.raises(ValueError, match="No publish results to merge."): + PublishResult.merge([]) + + +def test_publish_result_is_successful_true(default_publish_result): + assert default_publish_result.is_successful() is True + + +def test_publish_result_is_successful_false(): + result = PublishResult( + topic="v1/devices/me/attributes", + qos=0, + message_id=999, + payload_size=0, + reason_code=128 # Simulated failure + ) + assert result.is_successful() is False + + +def test_publish_result_equality(): + result1 = PublishResult( + topic="v1/devices/me/telemetry", + qos=1, + message_id=123, + payload_size=256, + reason_code=0 + ) + result2 = PublishResult( + topic="v1/devices/me/telemetry", + qos=1, + message_id=123, + payload_size=256, + reason_code=0 + ) + assert result1 == result2 + + +def test_publish_result_inequality(): + result1 = PublishResult( + topic="v1/devices/me/telemetry", + qos=1, + message_id=123, + payload_size=256, + reason_code=0 + ) + result2 = PublishResult( + topic="v1/devices/me/telemetry", + qos=1, + message_id=124, # Different message_id + payload_size=256, + reason_code=0 + ) + assert result1 != result2 + assert result1 != "Not a PublishResult" # Different type comparison + + +@pytest.mark.parametrize("reason_code", [1, 2, 3, 16, 255]) +def test_publish_result_various_failure_codes(reason_code): + result = PublishResult( + topic="v1/devices/me/rpc", + qos=2, + message_id=42, + payload_size=100, + reason_code=reason_code + ) + assert result.is_successful() is False + + +if __name__ == '__main__': + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/common/test_queue.py b/tests/common/test_queue.py new file mode 100644 index 0000000..75b02e5 --- /dev/null +++ b/tests/common/test_queue.py @@ -0,0 +1,158 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from types import SimpleNamespace + +import pytest + +from tb_mqtt_client.common.queue import AsyncDeque + + +@pytest.fixture +def make_item(): + """Factory to create unique or duplicate-like test items.""" + + def _make(uuid): + return SimpleNamespace(uuid=uuid) + + return _make + + +@pytest.mark.asyncio +async def test_put_and_get(make_item): + q = AsyncDeque(maxlen=5) + item = make_item("a") + await q.put(item) + assert q.size() == 1 + assert not q.is_empty() + + got = await q.get() + assert got.uuid == "a" + assert q.is_empty() + + +@pytest.mark.asyncio +async def test_put_duplicate_not_added(make_item): + q = AsyncDeque(maxlen=5) + item = make_item("dup") + await q.put(item) + await q.put(item) # duplicate + assert q.size() == 1 + + +@pytest.mark.asyncio +async def test_extend_and_duplicates(make_item): + q = AsyncDeque(maxlen=5) + items = [make_item(str(i)) for i in range(3)] + await q.extend(items) + assert q.size() == 3 + + # Duplicate extend + await q.extend(items) + assert q.size() == 3 + + +@pytest.mark.asyncio +async def test_put_left_and_duplicate(make_item): + q = AsyncDeque(maxlen=5) + item = make_item("x") + await q.put_left(item) + assert list(await q.peek_batch(1))[0].uuid == "x" + + await q.put_left(item) # duplicate, ignored + assert q.size() == 1 + + +@pytest.mark.asyncio +async def test_extend_left_and_duplicates(make_item): + q = AsyncDeque(maxlen=5) + items = [make_item(str(i)) for i in range(3)] + await q.extend_left(items) + assert [it.uuid for it in await q.peek_batch(3)] == ["0", "1", "2"] + + await q.extend_left(items) # duplicates ignored + assert q.size() == 3 + + +@pytest.mark.asyncio +async def test_peek_and_peek_batch_waits(make_item): + q = AsyncDeque(maxlen=5) + + async def delayed_put(): + await asyncio.sleep(0.01) + await q.put(make_item("peeked")) + + asyncio.create_task(delayed_put()) + first = await q.peek() + assert first.uuid == "peeked" + + batch = await q.peek_batch(1) + assert batch[0].uuid == "peeked" + + +@pytest.mark.asyncio +async def test_pop_n_partial_and_full(make_item): + q = AsyncDeque(maxlen=5) + items = [make_item(str(i)) for i in range(3)] + await q.extend(items) + + # pop more than available + popped = await q.pop_n(5) + assert [it.uuid for it in popped] == ["0", "1", "2"] + assert q.is_empty() + + # fill again + await q.extend(items) + popped2 = await q.pop_n(2) + assert [it.uuid for it in popped2] == ["0", "1"] + assert q.size() == 1 + + +@pytest.mark.asyncio +async def test_reinsert_front_duplicate_and_new(make_item): + q = AsyncDeque(maxlen=5) + item1 = make_item("1") + item2 = make_item("2") + + # Insert new + await q.reinsert_front(item1) + assert q.size() == 1 + assert (await q.peek()).uuid == "1" + + # Try duplicate + await q.reinsert_front(item1) + assert q.size() == 1 + + # Insert another new + await q.reinsert_front(item2) + assert q.size() == 2 + assert (await q.peek()).uuid == "2" + + +@pytest.mark.asyncio +async def test_get_waits_until_item_available(make_item): + q = AsyncDeque(maxlen=5) + + async def delayed_put(): + await asyncio.sleep(0.01) + await q.put(make_item("later")) + + asyncio.create_task(delayed_put()) + got = await q.get() + assert got.uuid == "later" + + +if __name__ == '__main__': + pytest.main([__file__, '-v', '--tb=short']) diff --git a/tests/common/test_rate_limit.py b/tests/common/test_rate_limit.py new file mode 100644 index 0000000..8556f2b --- /dev/null +++ b/tests/common/test_rate_limit.py @@ -0,0 +1,230 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import math +from time import sleep + +import pytest + +from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit, GreedyTokenBucket + + +def test_greedy_token_bucket_can_consume_and_consume(): + bucket = GreedyTokenBucket(10, 1) + assert bucket.can_consume(5) is True + assert bucket.consume(5) is True + assert bucket.tokens <= 5 + + assert bucket.can_consume(6) is False + assert bucket.consume(6) is False + + sleep(1.1) + bucket.refill() + assert round(bucket.tokens) == 10 + + +def test_greedy_token_bucket_get_remaining_tokens_and_refill(): + bucket = GreedyTokenBucket(100, 2) + bucket.consume(50) + before = bucket.get_remaining_tokens() + sleep(0.5) + after = bucket.get_remaining_tokens() + assert after > before + + +@pytest.mark.asyncio +async def test_rate_limit_no_limit_behavior(): + rl = RateLimit("", "test") + assert await rl.check_limit_reached() is False + assert await rl.try_consume() is None + assert await rl.consume() is None + assert rl.minimal_limit == 0 + assert rl.minimal_timeout == 0 + assert rl.has_limit() is False + assert await rl.reach_limit() is None + assert rl.to_dict()["no_limit"] is True + + +@pytest.mark.asyncio +async def test_rate_limit_basic_limit_behavior(): + rl = RateLimit("10:1,20:2", "test") + assert rl.has_limit() is True + assert rl.minimal_limit == 10 * 0.8 + assert rl.minimal_timeout >= 2 + + for _ in range(8): + assert await rl.try_consume() is None + + warning = await rl.try_consume() + assert isinstance(warning, tuple) + assert warning[0] == 8.0 + assert warning[1] == 1.0 + + exceeded = await rl.try_consume() + assert isinstance(exceeded, tuple) + + await rl.consume() + + +@pytest.mark.asyncio +async def test_rate_limit_reach_limit_rotation(): + rl = RateLimit("1:1,1:2", "test") + await rl.reach_limit() + await rl.reach_limit() + result = await rl.reach_limit() + + assert isinstance(result, tuple) + index, timestamp, dur = result + assert index >= 0 + assert dur in (1, 2) + + +@pytest.mark.asyncio +async def test_rate_limit_set_limit_resets_state(): + rl = RateLimit("100:10", "test") + original = rl.to_dict() + await rl.set_limit("10:1", percentage=50) + new_state = rl.to_dict() + + assert new_state["rateLimits"] != original["rateLimits"] + assert rl.percentage == 50 + + +@pytest.mark.asyncio +async def test_rate_limit_invalid_format(caplog): + rl = RateLimit("invalid_format,10:xyz", "bad_config") + assert rl.has_limit() is False or isinstance(rl.to_dict(), dict) + assert any("Invalid rate limit format" in rec.message for rec in caplog.records) + + +@pytest.mark.asyncio +async def test_rate_limit_refill_behavior(): + rl = RateLimit("3:1", "refill-test") + await rl.try_consume() + await rl.try_consume() + await rl.try_consume() + assert (await rl.try_consume()) is not None + + await asyncio.sleep(1.1) + await rl.refill() + assert (await rl.try_consume()) is None + + +@pytest.mark.asyncio +async def test_set_required_tokens_and_clear_event(): + rl = RateLimit("5:1", "token-test") + rl.set_required_tokens(1, 3) + # Initially event should not be set + assert not rl.required_tokens_ready.is_set() + # Force refill to satisfy requirement + for bucket in rl._rate_buckets.values(): + bucket.tokens = 3 + await rl.refill() + assert rl.required_tokens_ready.is_set() + rl.clear_required_tokens_event() + assert not rl.required_tokens_ready.is_set() + + +@pytest.mark.asyncio +async def test_refill_does_not_break_with_no_limit(): + rl = RateLimit("", "no-limit-refill") + # Should not raise and should not change anything + await rl.refill() + assert rl._no_limit + + +@pytest.mark.asyncio +async def test_consume_with_real_limit_changes_tokens(): + rl = RateLimit("5:1", "consume-test") + before = rl._rate_buckets[1].tokens + await rl.consume(1) + after = rl._rate_buckets[1].tokens + assert after < before + + +def test_minimal_limit_and_timeout_no_limit_case(): + rl = RateLimit("", "min-test") + assert rl.minimal_limit == 0 + assert rl.minimal_timeout == 0 + + +@pytest.mark.asyncio +async def test_to_dict_contains_expected_structure(): + rl = RateLimit("5:1", "dict-test") + d = rl.to_dict() + assert "name" in d and d["name"] == "dict-test" + assert "percentage" in d and isinstance(d["percentage"], int) + assert "rateLimits" in d and isinstance(d["rateLimits"], dict) + + +@pytest.mark.asyncio +async def test_set_limit_with_no_entries_sets_no_limit(): + rl = RateLimit("5:1", "reset-test") + await rl.set_limit("0:0,") + assert not rl.has_limit() + + +@pytest.mark.asyncio +async def test_try_consume_returns_tuple_when_not_enough_tokens(): + rl = RateLimit("1:1", "tuple-test") + # Exhaust tokens + await rl.try_consume() + res = await rl.try_consume() + assert isinstance(res, tuple) + assert len(res) == 2 + + +def test_parse_string_with_invalid_entries(): + # This will hit logger.warning branch for parse_string exception + rl = RateLimit("abc:def", "invalid-test") + assert not rl.has_limit() + + +@pytest.mark.asyncio +async def test_reach_limit_when_no_limit_returns_none(): + rl = RateLimit("", "reach-none") + assert await rl.reach_limit() is None + + +@pytest.mark.asyncio +async def test_reach_limit_resets_index_on_duration_window(): + rl = RateLimit("1:1,1:2", "reset-index") + rl._RateLimit__reached_index = 10 # force overflow + rl._RateLimit__reached_index_time = 0 + res = await rl.reach_limit() + assert isinstance(res, tuple) + assert rl._RateLimit__reached_index > 0 + + +def test_env_variable_invalid(monkeypatch): + # Simulate bad env variable + monkeypatch.setenv("TB_DEFAULT_RATE_LIMIT_PERCENTAGE", "bad") + # Reload module under test to re-run env reading + import importlib + import tb_mqtt_client.common.rate_limit.rate_limit as rl_mod + importlib.reload(rl_mod) + assert rl_mod.DEFAULT_RATE_LIMIT_PERCENTAGE == 80 + + +def test_greedy_token_bucket_edge_cases(): + b = GreedyTokenBucket(2, 1) + b.tokens = 0 + assert not b.can_consume(1) + assert not b.consume(5) + assert math.isclose(b.get_remaining_tokens(), b.tokens) + + +if __name__ == '__main__': + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/constants/__init__.py b/tests/constants/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tests/constants/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/constants/test_mqtt_topics.py b/tests/constants/test_mqtt_topics.py new file mode 100644 index 0000000..f8f19e3 --- /dev/null +++ b/tests/constants/test_mqtt_topics.py @@ -0,0 +1,21 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tb_mqtt_client.constants import mqtt_topics + + +def test_device_topic_builders(): + assert mqtt_topics.build_device_attributes_request_topic(42) == "v1/devices/me/attributes/request/42" + assert mqtt_topics.build_device_rpc_request_topic(99) == "v1/devices/me/rpc/request/99" + assert mqtt_topics.build_device_rpc_response_topic(99) == "v1/devices/me/rpc/response/99" diff --git a/tests/entities/__init__.py b/tests/entities/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tests/entities/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/entities/data/__init__.py b/tests/entities/data/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tests/entities/data/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/entities/data/test_attribute_entry.py b/tests/entities/data/test_attribute_entry.py new file mode 100644 index 0000000..a003fd6 --- /dev/null +++ b/tests/entities/data/test_attribute_entry.py @@ -0,0 +1,63 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry + + +@pytest.mark.parametrize("key, value", [ + ("temperature", 25), + ("status", True), + ("label", "sensor-A"), + ("list_val", [1, 2, 3]), + ("dict_val", {"k": "v"}) +]) +def test_attribute_entry_as_dict(key, value): + entry = AttributeEntry(key, value) + expected = {"key": key, "value": value} + assert entry.as_dict() == expected + + +@pytest.mark.parametrize("key, value", [ + ("a", 1), + ("b", "val"), + ("c", {"k": "v"}) +]) +def test_attribute_entry_repr(key, value): + entry = AttributeEntry(key, value) + assert repr(entry) == f"AttributeEntry(key={key}, value={value})" + + +def test_attribute_entry_eq_same(): + e1 = AttributeEntry("k1", 123) + e2 = AttributeEntry("k1", 123) + assert e1 == e2 + + +def test_attribute_entry_eq_diff_key(): + e1 = AttributeEntry("k1", 123) + e2 = AttributeEntry("k2", 123) + assert e1 != e2 + + +def test_attribute_entry_eq_diff_value(): + e1 = AttributeEntry("k1", 123) + e2 = AttributeEntry("k1", 456) + assert e1 != e2 + + +def test_attribute_entry_eq_wrong_type(): + e1 = AttributeEntry("k1", 123) + assert e1 != {"key": "k1", "value": 123} diff --git a/tests/entities/data/test_attribute_request.py b/tests/entities/data/test_attribute_request.py new file mode 100644 index 0000000..700ec15 --- /dev/null +++ b/tests/entities/data/test_attribute_request.py @@ -0,0 +1,85 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock, patch + +import pytest + +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest + + +@pytest.mark.asyncio +@patch("tb_mqtt_client.entities.data.attribute_request.AttributeRequestIdProducer.get_next", new_callable=AsyncMock) +async def test_attribute_request_build_with_keys(mock_get_next): + mock_get_next.return_value = 101 + + request = await AttributeRequest.build( + shared_keys=["shared1", "shared2"], + client_keys=["client1"] + ) + + assert request.request_id == 101 + assert request.shared_keys == ["shared1", "shared2"] + assert request.client_keys == ["client1"] + assert request.to_payload_format() == { + "sharedKeys": "shared1,shared2", + "clientKeys": "client1" + } + assert repr(request) == ( + "AttributeRequest(id=101, shared_keys=['shared1', 'shared2'], client_keys=['client1'])" + ) + + +@pytest.mark.asyncio +@patch("tb_mqtt_client.entities.data.attribute_request.AttributeRequestIdProducer.get_next", new_callable=AsyncMock) +async def test_attribute_request_build_with_none_keys(mock_get_next): + mock_get_next.return_value = 202 + + request = await AttributeRequest.build() + assert request.request_id == 202 + assert request.shared_keys is None + assert request.client_keys is None + assert request.to_payload_format() == {} + assert repr(request) == "AttributeRequest(id=202, shared_keys=None, client_keys=None)" + + +@pytest.mark.asyncio +async def test_attribute_request_direct_instantiation_fails(): + with pytest.raises(TypeError, match="Direct instantiation of AttributeRequest is not allowed"): + AttributeRequest(1, ["s1"], ["c1"]) + + +@pytest.mark.asyncio +async def test_attribute_request_invalid_json_keys(): + with patch("tb_mqtt_client.entities.data.attribute_request.AttributeRequestIdProducer.get_next", + new_callable=AsyncMock) as mock_get_next: + mock_get_next.return_value = 999 + + with pytest.raises(ValueError, match="unsupported - set"): + await AttributeRequest.build(shared_keys={1, 2, 3}) + + with pytest.raises(ValueError, match="expected str, got int"): + await AttributeRequest.build(client_keys=[{"bad_key": {1: "value"}}]) + + +@pytest.mark.asyncio +async def test_attribute_request_invalid_nested_value(): + with patch("tb_mqtt_client.entities.data.attribute_request.AttributeRequestIdProducer.get_next", + new_callable=AsyncMock) as mock_get_next: + mock_get_next.return_value = 1001 + + def dummy(): pass + + with pytest.raises(ValueError, match="unsupported - function"): + await AttributeRequest.build(client_keys=[{"nested": dummy}]) diff --git a/tests/entities/data/test_attribute_update.py b/tests/entities/data/test_attribute_update.py new file mode 100644 index 0000000..87bf108 --- /dev/null +++ b/tests/entities/data/test_attribute_update.py @@ -0,0 +1,84 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate + + +@pytest.fixture +def example_entries(): + return [ + AttributeEntry("temperature", 23), + AttributeEntry("humidity", 50), + AttributeEntry("status", "ok") + ] + + +def test_repr(example_entries): + update = AttributeUpdate(example_entries) + assert repr(update) == f"AttributeUpdate(entries={example_entries})" + + +def test_get_existing_key(example_entries): + update = AttributeUpdate(example_entries) + assert update.get("temperature") == 23 + assert update.get("status") == "ok" + + +def test_get_missing_key_with_default(example_entries): + update = AttributeUpdate(example_entries) + assert update.get("nonexistent") is None + assert update.get("nonexistent", "default") == "default" + + +def test_keys(example_entries): + update = AttributeUpdate(example_entries) + assert update.keys() == ["temperature", "humidity", "status"] + + +def test_values(example_entries): + update = AttributeUpdate(example_entries) + assert update.values() == [23, 50, "ok"] + + +def test_items(example_entries): + update = AttributeUpdate(example_entries) + assert update.items() == [ + ("temperature", 23), + ("humidity", 50), + ("status", "ok") + ] + + +def test_as_dict(example_entries): + update = AttributeUpdate(example_entries) + assert update.as_dict() == { + "temperature": 23, + "humidity": 50, + "status": "ok" + } + + +def test_deserialize_from_dict(): + raw = { + "speed": 100, + "enabled": True + } + update = AttributeUpdate._deserialize_from_dict(raw) + assert isinstance(update, AttributeUpdate) + assert update.as_dict() == raw + assert update.get("speed") == 100 + assert update.get("enabled") is True diff --git a/tests/entities/data/test_data_entry.py b/tests/entities/data/test_data_entry.py new file mode 100644 index 0000000..4a5ac7e --- /dev/null +++ b/tests/entities/data/test_data_entry.py @@ -0,0 +1,95 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import patch + +import pytest + +from tb_mqtt_client.entities.data.data_entry import DataEntry + + +def test_data_entry_initialization_without_timestamp(): + entry = DataEntry("temperature", 22.5) + + assert entry.key == "temperature" + assert entry.value == 22.5 + assert entry.ts is None + assert isinstance(entry.size, int) + assert "DataEntry(key=temperature, value=22.5, ts=None)" == repr(entry) + + +def test_data_entry_initialization_with_timestamp(): + entry = DataEntry("humidity", 60, ts=1717171717) + + assert entry.key == "humidity" + assert entry.value == 60 + assert entry.ts == 1717171717 + assert isinstance(entry.size, int) + assert "DataEntry(key=humidity, value=60, ts=1717171717)" == repr(entry) + + +def test_data_entry_key_setter_recalculates_size(): + entry = DataEntry("initial", True) + original_size = entry.size + entry.key = "updated_key" + + assert entry.key == "updated_key" + assert isinstance(entry.size, int) + assert entry.size != original_size + + +def test_data_entry_value_setter_recalculates_size(): + entry = DataEntry("key", "initial") + original_size = entry.size + entry.value = "updated_value" + + assert entry.value == "updated_value" + assert isinstance(entry.size, int) + assert entry.size != original_size + + +def test_data_entry_ts_setter_recalculates_size(): + entry = DataEntry("voltage", 3.3) + original_size = entry.size + entry.ts = 1710000000 + + assert entry.ts == 1710000000 + assert isinstance(entry.size, int) + assert entry.size != original_size + + +def test_data_entry_raises_on_invalid_json(): + class NonSerializable: + pass + + with pytest.raises(ValueError, match="unsupported - NonSerializable"): + DataEntry("bad", NonSerializable()) + + +@patch("tb_mqtt_client.entities.data.data_entry.dumps") +def test_estimate_size_called_with_ts(mock_dumps): + mock_dumps.return_value = b'{"ts":123,"values":{"x":42}}' + entry = DataEntry("x", 42, ts=123) + + mock_dumps.assert_called_once_with({"ts": 123, "values": {"x": 42}}) + assert entry.size == len(mock_dumps.return_value) + + +@patch("tb_mqtt_client.entities.data.data_entry.dumps") +def test_estimate_size_called_without_ts(mock_dumps): + mock_dumps.return_value = b'{"x":42}' + entry = DataEntry("x", 42) + + mock_dumps.assert_called_once_with({"x": 42}) + assert entry.size == len(mock_dumps.return_value) diff --git a/tests/entities/data/test_device_uplink_message.py b/tests/entities/data/test_device_uplink_message.py new file mode 100644 index 0000000..bd90cc7 --- /dev/null +++ b/tests/entities/data/test_device_uplink_message.py @@ -0,0 +1,156 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from collections import OrderedDict +from types import MappingProxyType + +import pytest + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder, DeviceUplinkMessage, \ + DEFAULT_FIELDS_SIZE +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry + + +@pytest.fixture +def attribute_entry(): + return AttributeEntry(key="temp", value=42) + + +@pytest.fixture +def timeseries_entry(): + return TimeseriesEntry(key="speed", value=88, ts=1234567890) + + +def test_direct_instantiation_forbidden(): + with pytest.raises(TypeError, match="Direct instantiation of DeviceUplinkMessage is not allowed"): + DeviceUplinkMessage(device_name="test", device_profile="default", attributes=(), timeseries={}, + delivery_futures=[], _size=0) + + +def test_build_empty_message(): + builder = DeviceUplinkMessageBuilder() + msg = builder.build() + + assert msg.device_name is None + assert msg.device_profile is None + assert msg.attributes == () + assert isinstance(msg.timeseries, MappingProxyType) + assert dict(msg.timeseries) == {} + assert len(msg.delivery_futures) == 1 + assert isinstance(msg.delivery_futures[0], asyncio.Future) + assert msg.size == DEFAULT_FIELDS_SIZE + assert not msg.has_attributes() + assert not msg.has_timeseries() + assert msg.attributes_datapoint_count() == 0 + assert msg.timeseries_datapoint_count() == 0 + + +def test_set_device_name_and_profile(): + builder = DeviceUplinkMessageBuilder() + msg = builder.set_device_name("device-1").set_device_profile("profile-x").build() + + assert msg.device_name == "device-1" + assert msg.device_profile == "profile-x" + assert msg.size == DEFAULT_FIELDS_SIZE + len("device-1") + len("profile-x") + + +def test_add_single_attribute(attribute_entry): + builder = DeviceUplinkMessageBuilder() + msg = builder.add_attributes(attribute_entry).build() + + assert len(msg.attributes) == 1 + assert msg.attributes[0] == attribute_entry + assert msg.attributes_datapoint_count() == 1 + assert msg.has_attributes() + assert msg.size == DEFAULT_FIELDS_SIZE + attribute_entry.size + + +def test_add_multiple_attributes(attribute_entry): + other = AttributeEntry(key="humidity", value=60) + builder = DeviceUplinkMessageBuilder() + msg = builder.add_attributes([attribute_entry, other]).build() + + assert len(msg.attributes) == 2 + assert msg.attributes == (attribute_entry, other) + expected_size = DEFAULT_FIELDS_SIZE + attribute_entry.size + other.size + assert msg.size == expected_size + + +def test_add_single_timeseries_entry(timeseries_entry): + builder = DeviceUplinkMessageBuilder() + msg = builder.add_timeseries(timeseries_entry).build() + + assert len(msg.timeseries) == 1 + assert timeseries_entry.ts in msg.timeseries + assert msg.timeseries[timeseries_entry.ts] == (timeseries_entry,) + assert msg.has_timeseries() + assert msg.timeseries_datapoint_count() == 1 + assert msg.size == DEFAULT_FIELDS_SIZE + timeseries_entry.size + + +def test_add_multiple_timeseries_entries(timeseries_entry): + entry2 = TimeseriesEntry(key="vibration", value=1.5, ts=timeseries_entry.ts) + builder = DeviceUplinkMessageBuilder() + msg = builder.add_timeseries([timeseries_entry, entry2]).build() + + assert timeseries_entry.ts in msg.timeseries + assert msg.timeseries[timeseries_entry.ts] == (timeseries_entry, entry2) + assert msg.timeseries_datapoint_count() == 2 + assert msg.size == DEFAULT_FIELDS_SIZE + timeseries_entry.size + entry2.size + + +def test_add_timeseries_with_none_ts(): + entry = TimeseriesEntry(key="x", value=1, ts=None) + builder = DeviceUplinkMessageBuilder() + msg = builder.add_timeseries(entry).build() + + assert 0 in msg.timeseries + assert msg.timeseries[0] == (entry,) + assert msg.timeseries_datapoint_count() == 1 + assert msg.size == DEFAULT_FIELDS_SIZE + entry.size + + +def test_add_timeseries_from_ordered_dict(timeseries_entry): + ordered = OrderedDict({timeseries_entry.ts: [timeseries_entry]}) + builder = DeviceUplinkMessageBuilder() + msg = builder.add_timeseries(ordered).build() + + assert msg.timeseries[timeseries_entry.ts] == (timeseries_entry,) + assert msg.timeseries_datapoint_count() == 1 + + +def test_add_delivery_futures(): + future1 = asyncio.Future() + future2 = asyncio.Future() + builder = DeviceUplinkMessageBuilder() + msg = builder.add_delivery_futures([future1, future2]).build() + + assert msg.delivery_futures == (future1, future2) + + +def test_repr_contains_info(attribute_entry, timeseries_entry): + builder = DeviceUplinkMessageBuilder() + msg = builder.set_device_name("test-device") \ + .set_device_profile("dp") \ + .add_attributes(attribute_entry) \ + .add_timeseries(timeseries_entry) \ + .build() + + repr_str = repr(msg) + assert "test-device" in repr_str + assert "dp" in repr_str + assert "AttributeEntry" in repr_str + assert "TimeseriesEntry" in repr_str diff --git a/tests/entities/data/test_provisioning_data.py b/tests/entities/data/test_provisioning_data.py new file mode 100644 index 0000000..d1165c1 --- /dev/null +++ b/tests/entities/data/test_provisioning_data.py @@ -0,0 +1,170 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import patch, mock_open + +import pytest + +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.constants.provisioning import ProvisioningResponseStatus +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest, AccessTokenProvisioningCredentials, \ + BasicProvisioningCredentials, X509ProvisioningCredentials +from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse + + +@pytest.fixture +def access_token_request(): + return ProvisioningRequest( + host="my-host", + port=1883, + credentials=AccessTokenProvisioningCredentials( + provision_device_key="provision_device_key", + provision_device_secret="provision_device_secret", + access_token="access-token123" + ) + ) + + +@pytest.fixture +def mqtt_basic_request(): + return ProvisioningRequest( + host="test-host", + port=8883, + credentials=BasicProvisioningCredentials( + provision_device_key="provision_device_key", + provision_device_secret="provision_device_secret", + client_id="my-client-id", + username="user1", + password="pass123" + ) + ) + + +@pytest.fixture +def x509_request(): + mocked_cert_content = "-----BEGIN CERTIFICATE-----\nABCDEF\n-----END CERTIFICATE-----" + + with patch("builtins.open", mock_open(read_data=mocked_cert_content)): + return ProvisioningRequest( + host="secure-host", + port=8883, + credentials=X509ProvisioningCredentials( + provision_device_key="provision_device_key", + provision_device_secret="provision_device_secret", + private_key_path="/path/to/private.key", + public_cert_path="/path/to/client.crt", + ca_cert_path="/path/to/ca.crt" + ) + ) + + +def test_direct_instantiation_fails(): + with pytest.raises(TypeError, match="Direct instantiation of ProvisioningResponse is not allowed"): + ProvisioningResponse(status=ProvisioningResponseStatus.SUCCESS) + + +def test_error_response(): + access_token_credentials = AccessTokenProvisioningCredentials("provision_device_key", "provision_device_secret") + request = ProvisioningRequest(host="any", port=1883, credentials=access_token_credentials) + payload = {"errorMsg": "Provisioning failed", "status": ProvisioningResponseStatus.ERROR.value} + + response = ProvisioningResponse.build(request, payload) + + assert response.status == ProvisioningResponseStatus.ERROR + assert response.result is None + assert response.error == "Provisioning failed" + + +def test_success_access_token(access_token_request): + payload = {"credentialsValue": "ACCESS-TOKEN-123", "status": "SUCCESS"} + + response = ProvisioningResponse.build(access_token_request, payload) + + assert response.status == ProvisioningResponseStatus.SUCCESS + assert response.error is None + assert isinstance(response.result, DeviceConfig) + assert response.result.access_token == "ACCESS-TOKEN-123" + assert response.result.host == "my-host" + assert response.result.port == 1883 + + +def test_success_mqtt_basic(mqtt_basic_request): + payload = { + "credentialsValue": { + "clientId": "my-client-id", + "userName": "user1", + "password": "pass123" + }, + "status": "SUCCESS" + } + + response = ProvisioningResponse.build(mqtt_basic_request, payload) + + assert response.status == ProvisioningResponseStatus.SUCCESS + assert isinstance(response.result, DeviceConfig) + config = response.result + assert config.client_id == "my-client-id" + assert config.username == "user1" + assert config.password == "pass123" + assert config.host == "test-host" + assert config.port == 8883 + assert response.error is None + + +def test_success_x509(x509_request): + payload = {"credentialsValue": None, "status": "SUCCESS"} # Should be ignored for X509 + + response = ProvisioningResponse.build(x509_request, payload) + + assert response.status == ProvisioningResponseStatus.SUCCESS + config = response.result + assert config.ca_cert == "/path/to/ca.crt" + assert config.client_cert == "/path/to/client.crt" + assert config.private_key == "/path/to/private.key" + assert config.host == "secure-host" + assert config.port == 8883 + assert response.error is None + + +def test_repr_output(access_token_request): + payload = {"credentialsValue": "access-token", "status": "SUCCESS"} + response = ProvisioningResponse.build(access_token_request, payload) + + r = repr(response) + assert r.startswith("ProvisioningResponse(status=SUCCESS") + assert "DeviceConfig(" in r + assert "host=my-host" in r + assert "port=1883" in r + assert "error=None" in r + + +def test_missing_credentials_value_for_access_token(access_token_request): + payload = {"status": "SUCCESS"} # Missing 'credentialsValue' + + with pytest.raises(KeyError): + ProvisioningResponse.build(access_token_request, payload) + + +def test_mqtt_basic_missing_fields(mqtt_basic_request): + payload = {"credentialsValue": {}, "status": "SUCCESS"} # All fields missing + + with pytest.raises(KeyError): + ProvisioningResponse.build(mqtt_basic_request, payload) + + +def test_access_token_none_is_accepted(access_token_request): + payload = {"credentialsValue": None, "status": "SUCCESS"} + response = ProvisioningResponse.build(access_token_request, payload) + + assert response.status == ProvisioningResponseStatus.SUCCESS + assert response.result.access_token is None diff --git a/tests/entities/data/test_requested_attribute_response.py b/tests/entities/data/test_requested_attribute_response.py new file mode 100644 index 0000000..ff6e6bf --- /dev/null +++ b/tests/entities/data/test_requested_attribute_response.py @@ -0,0 +1,127 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse + + +@pytest.fixture +def shared_attrs(): + return [AttributeEntry("shared_key_1", "shared_value_1"), AttributeEntry("shared_key_2", 42)] + + +@pytest.fixture +def client_attrs(): + return [AttributeEntry("client_key_1", "client_value_1"), AttributeEntry("client_key_2", True)] + + +@pytest.fixture +def response(shared_attrs, client_attrs): + return RequestedAttributeResponse(request_id=1001, shared=shared_attrs, client=client_attrs) + + +def test_repr(response): + output = repr(response) + assert "RequestedAttributeResponse" in output + assert "request_id=1001" in output + assert "shared_key_1" in output + assert "client_key_1" in output + + +def test_getitem_shared_key(response): + assert response["shared_key_1"] == "shared_value_1" + assert response["shared_key_2"] == 42 + + +def test_getitem_client_key(response): + assert response["client_key_1"] == "client_value_1" + assert response["client_key_2"] is True + + +def test_getitem_key_error(response): + with pytest.raises(KeyError, match="Key 'nonexistent' not found"): + _ = response["nonexistent"] + + +def test_get_shared_success(shared_attrs, client_attrs): + response = RequestedAttributeResponse(5, shared_attrs, client_attrs) + assert response.get_shared("shared_key_1") == "shared_value_1" + assert response.get_shared("shared_key_2") == 42 + + +def test_get_shared_default(response): + assert response.get_shared("unknown", default="fallback") == "fallback" + assert response.get_shared("unknown") is None + + +def test_get_client_success(response): + assert response.get_client("client_key_1") == "client_value_1" + assert response.get_client("client_key_2") is True + + +def test_get_client_default(response): + assert response.get_client("unknown", default=0) == 0 + assert response.get_client("unknown") is None + + +def test_shared_keys(response): + keys = response.shared_keys() + assert keys == ["shared_key_1", "shared_key_2"] + + +def test_client_keys(response): + keys = response.client_keys() + assert keys == ["client_key_1", "client_key_2"] + + +def test_as_dict(response): + result = response.as_dict() + assert isinstance(result, dict) + assert "shared" in result + assert "client" in result + assert result["shared"][0]["key"] == "shared_key_1" + assert result["client"][1]["value"] is True + + +def test_from_dict_full(): + data = { + "request_id": 77, + "shared": {"temp": 25, "mode": "cool"}, + "client": {"state": "on", "power": 100} + } + resp = RequestedAttributeResponse.from_dict(data) + assert isinstance(resp, RequestedAttributeResponse) + assert resp.request_id == 77 + assert resp.get_shared("temp") == 25 + assert resp.get_client("power") == 100 + + +def test_from_dict_missing_request_id(): + data = { + "shared": {"x": 1}, + "client": {"y": 2} + } + resp = RequestedAttributeResponse.from_dict(data) + assert resp.request_id == -1 + assert resp.get_shared("x") == 1 + assert resp.get_client("y") == 2 + + +def test_from_dict_empty(): + resp = RequestedAttributeResponse.from_dict({}) + assert resp.request_id == -1 + assert resp.shared == [] + assert resp.client == [] diff --git a/tests/entities/data/test_rpc_request.py b/tests/entities/data/test_rpc_request.py new file mode 100644 index 0000000..b3ca496 --- /dev/null +++ b/tests/entities/data/test_rpc_request.py @@ -0,0 +1,91 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import patch + +import pytest + +from tb_mqtt_client.common.request_id_generator import RPCRequestIdProducer +from tb_mqtt_client.entities.data.rpc_request import RPCRequest + + +@pytest.mark.asyncio +async def test_rpc_request_build_valid(): + method = "reboot" + params = {"delay": 5} + with patch.object(RPCRequestIdProducer, "get_next", return_value=42): + request = await RPCRequest.build(method, params) + + assert request.request_id == 42 + assert request.method == method + assert request.params == params + assert request.to_payload_format() == { + "id": 42, + "method": "reboot", + "params": {"delay": 5} + } + + +def test_rpc_request_direct_instantiation_raises(): + with pytest.raises(TypeError, match="Direct instantiation of RPCRequest is not allowed"): + RPCRequest(1, "reboot") + + +@pytest.mark.asyncio +async def test_rpc_request_build_with_none_params(): + with patch.object(RPCRequestIdProducer, "get_next", return_value="abc123"): + request = await RPCRequest.build("ping") + + assert request.params is None + assert request.request_id == "abc123" + assert request.to_payload_format() == { + "id": "abc123", + "method": "ping" + } + + +@pytest.mark.asyncio +async def test_rpc_request_build_invalid_method_type(): + with pytest.raises(ValueError, match="Method must be a string"): + await RPCRequest.build(123, {"key": "value"}) + + +def test_rpc_request_deserialize_valid(): + data = {"method": "getStatus", "params": {"verbose": True}} + req = RPCRequest._deserialize_from_dict(1001, data) + + assert req.request_id == 1001 + assert req.method == "getStatus" + assert req.params == {"verbose": True} + assert req.to_payload_format() == { + "id": 1001, + "method": "getStatus", + "params": {"verbose": True} + } + + +def test_rpc_request_deserialize_missing_id(): + with pytest.raises(ValueError, match="Missing request id"): + RPCRequest._deserialize_from_dict(None, {"method": "test"}) + + +def test_rpc_request_deserialize_missing_method(): + with pytest.raises(ValueError, match="Missing 'method' in RPC request"): + RPCRequest._deserialize_from_dict(1, {}) + + +def test_rpc_request_repr(): + data = {"method": "info", "params": {"a": 1}} + req = RPCRequest._deserialize_from_dict(42, data) + assert repr(req) == "RPCRequest(id=42, method=info, params={'a': 1})" diff --git a/tests/entities/data/test_rpc_response.py b/tests/entities/data/test_rpc_response.py new file mode 100644 index 0000000..b7fa703 --- /dev/null +++ b/tests/entities/data/test_rpc_response.py @@ -0,0 +1,87 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from tb_mqtt_client.entities.data.rpc_response import RPCResponse, RPCStatus + + +class NonSerializable: + pass + + +def test_rpc_status_str(): + assert str(RPCStatus.SUCCESS) == "SUCCESS" + assert str(RPCStatus.ERROR) == "ERROR" + assert str(RPCStatus.TIMEOUT) == "TIMEOUT" + assert str(RPCStatus.NOT_FOUND) == "NOT_FOUND" + + +def test_rpc_response_direct_instantiation_raises(): + with pytest.raises(TypeError): + RPCResponse(1, result="test") + + +def test_rpc_response_success(): + response = RPCResponse.build(request_id=1, result={"key": "value"}) + assert response.request_id == 1 + assert response.status == RPCStatus.SUCCESS + assert response.result == {"key": "value"} + assert response.error is None + assert response.to_payload_format() == {"result": {"key": "value"}} + + +def test_rpc_response_error_with_string(): + response = RPCResponse.build(request_id="req-1", error="Something went wrong") + assert response.request_id == "req-1" + assert response.status == RPCStatus.ERROR + assert response.error == "Something went wrong" + assert response.result is None + assert response.to_payload_format() == {"error": "Something went wrong"} + + +def test_rpc_response_error_with_dict(): + error_dict = {"code": 500, "message": "Internal Error"} + response = RPCResponse.build(request_id="req-2", error=error_dict) + assert response.request_id == "req-2" + assert response.status == RPCStatus.ERROR + assert response.error == error_dict + assert response.result is None + assert response.to_payload_format() == {"error": error_dict} + + +def test_rpc_response_error_with_exception(): + try: + raise ValueError("Something failed") + except ValueError as ex: + response = RPCResponse.build(request_id=42, error=ex) + + assert response.status == RPCStatus.ERROR + assert isinstance(response.error, dict) + assert "message" in response.error + assert "type" in response.error + assert "details" in response.error + assert response.error["type"] == "ValueError" + assert response.error["message"] == "Something failed" + assert "raise ValueError(\"Something failed\")" in response.error["details"] + + +def test_rpc_response_invalid_error_type(): + with pytest.raises(ValueError): + RPCResponse.build(request_id=1, error=NonSerializable()) + + +def test_rpc_response_invalid_result_type(): + with pytest.raises(ValueError): + RPCResponse.build(request_id=1, result=NonSerializable()) diff --git a/tests/entities/data/test_timeseries_entry.py b/tests/entities/data/test_timeseries_entry.py new file mode 100644 index 0000000..ee12b7e --- /dev/null +++ b/tests/entities/data/test_timeseries_entry.py @@ -0,0 +1,75 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry + + +@pytest.mark.parametrize( + "key,value,ts,expected_dict", + [ + ("temperature", 23.5, 1650000000000, {"key": "temperature", "value": 23.5, "ts": 1650000000000}), + ("status", "online", None, {"key": "status", "value": "online"}), + ("flag", True, 0, {"key": "flag", "value": True, "ts": 0}), + ] +) +def test_as_dict(key, value, ts, expected_dict): + entry = TimeseriesEntry(key, value, ts) + assert entry.as_dict() == expected_dict + + +def test_repr_output(): + entry = TimeseriesEntry("humidity", 60, 1650001234567) + assert repr(entry) == "TimeseriesEntry(key=humidity, value=60, ts=1650001234567)" + + +def test_repr_without_ts(): + entry = TimeseriesEntry("humidity", 60) + assert repr(entry) == "TimeseriesEntry(key=humidity, value=60, ts=None)" + + +def test_equality_same_values(): + e1 = TimeseriesEntry("pressure", 101.3, 1650000000000) + e2 = TimeseriesEntry("pressure", 101.3, 1650000000000) + assert e1 == e2 + + +def test_equality_different_key(): + e1 = TimeseriesEntry("a", 1, 123) + e2 = TimeseriesEntry("b", 1, 123) + assert e1 != e2 + + +def test_equality_different_value(): + e1 = TimeseriesEntry("key", 1, 123) + e2 = TimeseriesEntry("key", 2, 123) + assert e1 != e2 + + +def test_equality_different_ts(): + e1 = TimeseriesEntry("key", "value", 123) + e2 = TimeseriesEntry("key", "value", 456) + assert e1 != e2 + + +def test_equality_different_type(): + e1 = TimeseriesEntry("key", "value", 123) + assert e1 != {"key": "key", "value": "value", "ts": 123} + + +def test_default_ts_none(): + entry = TimeseriesEntry("key", "value") + assert entry.ts is None + assert entry.as_dict() == {"key": "key", "value": "value"} diff --git a/tests/entities/gateway/__init__.py b/tests/entities/gateway/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tests/entities/gateway/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/entities/gateway/test_base_gateway_event.py b/tests/entities/gateway/test_base_gateway_event.py new file mode 100644 index 0000000..3206a09 --- /dev/null +++ b/tests/entities/gateway/test_base_gateway_event.py @@ -0,0 +1,61 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from tb_mqtt_client.entities.gateway.base_gateway_event import BaseGatewayEvent +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType + + +class DummyDeviceSession: + """A simple dummy object to mimic a device session.""" + pass + + +def test_initialization_and_event_type(): + # Create event with a specific GatewayEventType + event = BaseGatewayEvent(GatewayEventType.GATEWAY_CONNECT) + + # Ensure the event_type property returns what we set + assert event.event_type == GatewayEventType.GATEWAY_CONNECT + + # device_session should be None initially + assert event.device_session is None + + +def test_set_and_get_device_session(): + event = BaseGatewayEvent(GatewayEventType.GATEWAY_DISCONNECT) + + # Initially None + assert event.device_session is None + + # Set a device session + dummy_session = DummyDeviceSession() + event.set_device_session(dummy_session) + + # Ensure the getter returns what we set + assert event.device_session is dummy_session + + +def test_str_calls_repr(monkeypatch): + event = BaseGatewayEvent(GatewayEventType.GATEWAY_CONNECT) + + # Monkeypatch __repr__ to a known value + monkeypatch.setattr(event, "__repr__", lambda: "mocked_repr") + + # __str__ should return whatever __repr__ returns + assert str(event) == "mocked_repr" + + +if __name__ == '__main__': + pytest.main([__file__, '-v', '--tb=short']) diff --git a/tests/entities/gateway/test_device_connect_message.py b/tests/entities/gateway/test_device_connect_message.py new file mode 100644 index 0000000..f2c9e53 --- /dev/null +++ b/tests/entities/gateway/test_device_connect_message.py @@ -0,0 +1,62 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType + + +def test_direct_instantiation_not_allowed(): + # Direct instantiation should raise a TypeError + with pytest.raises(TypeError) as exc_info: + DeviceConnectMessage("device1") + assert "Direct instantiation" in str(exc_info.value) + + +def test_build_with_empty_device_name_raises(): + # Empty device name should raise ValueError + with pytest.raises(ValueError) as exc_info: + DeviceConnectMessage.build("") + assert "must not be empty" in str(exc_info.value) + + +def test_build_and_repr_and_payload(): + msg = DeviceConnectMessage.build("MyDevice", "MyProfile") + + # Check attributes + assert msg.device_name == "MyDevice" + assert msg.device_profile == "MyProfile" + assert msg.event_type == GatewayEventType.DEVICE_CONNECT + + # __repr__ format check + repr_str = repr(msg) + assert "DeviceConnectMessage" in repr_str + assert "MyDevice" in repr_str + assert "MyProfile" in repr_str + + # Payload format check + payload = msg.to_payload_format() + assert payload == {"device": "MyDevice", "type": "MyProfile"} + + +def test_build_with_default_profile(): + msg = DeviceConnectMessage.build("DefaultDevice") + assert msg.device_profile == "default" + assert msg.to_payload_format() == {"device": "DefaultDevice", "type": "default"} + + +if __name__ == '__main__': + pytest.main([__file__, "--tb=short", "-v"]) diff --git a/tests/entities/gateway/test_device_disconnect_message.py b/tests/entities/gateway/test_device_disconnect_message.py new file mode 100644 index 0000000..7ee78ef --- /dev/null +++ b/tests/entities/gateway/test_device_disconnect_message.py @@ -0,0 +1,58 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from tb_mqtt_client.entities.gateway.device_disconnect_message import DeviceDisconnectMessage +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType + + +def test_direct_instantiation_not_allowed(): + # Directly calling constructor should raise TypeError + with pytest.raises(TypeError) as exc_info: + DeviceDisconnectMessage("DeviceX") + assert "Direct instantiation of DeviceDisconnectMessage" in str(exc_info.value) + + +def test_build_and_attributes(): + # Build using correct way + msg = DeviceDisconnectMessage.build("TestDevice") + assert isinstance(msg, DeviceDisconnectMessage) + assert msg.device_name == "TestDevice" + assert msg.event_type == GatewayEventType.DEVICE_DISCONNECT + # Check __repr__ + assert repr(msg) == "DeviceDisconnectMessage(device_name=TestDevice)" + # Check to_payload_format + assert msg.to_payload_format() == {"device": "TestDevice"} + + +def test_build_with_empty_device_name(): + # Empty device name should raise ValueError + with pytest.raises(ValueError) as exc_info: + DeviceDisconnectMessage.build("") + assert "Device name must not be empty" in str(exc_info.value) + + +def test_frozen_slots_and_equality_behavior(): + msg1 = DeviceDisconnectMessage.build("A") + msg2 = DeviceDisconnectMessage.build("A") + msg3 = DeviceDisconnectMessage.build("B") + + # Equality check works for dataclass with frozen=True + assert msg1 == msg2 + assert msg1 != msg3 + + +if __name__ == '__main__': + pytest.main([__file__, "--tb=short", "-v"]) diff --git a/tests/entities/gateway/test_device_info.py b/tests/entities/gateway/test_device_info.py new file mode 100644 index 0000000..b71053a --- /dev/null +++ b/tests/entities/gateway/test_device_info.py @@ -0,0 +1,112 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uuid + +import pytest + +from tb_mqtt_client.entities.gateway.device_info import DeviceInfo + + +def test_post_init_and_basic_properties(): + """Test that __post_init__ sets original_name and _initializing is False.""" + d = DeviceInfo(device_name="dev1", device_profile="profile1") + assert d.device_name == "dev1" + assert d.original_name == "dev1" + assert isinstance(d.device_id, uuid.UUID) + assert not d._initializing + + +def test_setattr_after_init_raises(): + """Test that modifying attributes after init raises AttributeError.""" + d = DeviceInfo(device_name="dev1", device_profile="p1") + with pytest.raises(AttributeError) as exc: + d.device_name = "new_name" + assert "Cannot modify attribute" in str(exc.value) + + +def test_rename_changes_name(): + """Test rename method changes device_name.""" + d = DeviceInfo(device_name="dev1", device_profile="p1") + old_id = d.device_id + d.rename("new_dev") + assert d.device_name == "new_dev" + assert d.device_id == old_id # ID stays the same + + +def test_rename_same_name_no_change(): + """Renaming to the same name should not alter anything.""" + d = DeviceInfo(device_name="dev1", device_profile="p1") + d.rename("dev1") # No change expected + assert d.device_name == "dev1" + + +def test_from_dict_and_to_dict(): + """Test creating from dict and converting back to dict.""" + dev_id = str(uuid.uuid4()) + data = { + "device_name": "dname", + "device_profile": "prof", + "device_id": dev_id, + "original_name": "orig" + } + d = DeviceInfo.from_dict(data) + assert isinstance(d, DeviceInfo) + assert str(d.device_id) == dev_id + assert d.device_name == "dname" + assert d.original_name == "orig" + + # Check to_dict output + out = d.to_dict() + assert out["device_name"] == "dname" + assert out["device_profile"] == "prof" + assert out["device_id"] == dev_id + assert out["original_name"] == "orig" + + +def test_str_and_repr(): + """Test __str__ and __repr__ produce expected formats.""" + d = DeviceInfo(device_name="dname", device_profile="prof") + s = str(d) + r = repr(d) + assert "DeviceInfo" in s + assert "DeviceInfo" in r + assert str(d.device_id) in s + assert repr(d.device_id) in r + + +def test_eq_and_hash(): + """Test equality and hash implementation.""" + d1 = DeviceInfo(device_name="dname", device_profile="prof") + d2 = DeviceInfo(device_name="dname", device_profile="prof") + # Manually make IDs equal for equality check + object.__setattr__(d2, "device_id", d1.device_id) + object.__setattr__(d2, "original_name", d1.original_name) + + assert d1 == d2 + assert hash(d1) == hash(d2) + + # Different object type should return NotImplemented for eq + assert d1.__eq__("not a DeviceInfo") is NotImplemented + + +def test_hash_works_in_set(): + """Test that DeviceInfo is hashable and works in a set.""" + d = DeviceInfo(device_name="dname", device_profile="prof") + s = {d} + assert d in s + + +if __name__ == '__main__': + pytest.main([__file__, "--tb=short", "-v"]) diff --git a/tests/entities/gateway/test_gateway_attribute_request.py b/tests/entities/gateway/test_gateway_attribute_request.py new file mode 100644 index 0000000..a8b498c --- /dev/null +++ b/tests/entities/gateway/test_gateway_attribute_request.py @@ -0,0 +1,131 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock, patch, AsyncMock + +import pytest + +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest + + +@pytest.mark.asyncio +async def test_direct_instantiation_not_allowed(): + with pytest.raises(TypeError) as exc: + GatewayAttributeRequest() + assert "Direct instantiation" in str(exc.value) + + +@pytest.mark.asyncio +async def test_build_assigns_values_and_repr(): + mock_session = MagicMock() + mock_session.device_info.device_name = "TestDevice" + + with patch("tb_mqtt_client.common.request_id_generator.AttributeRequestIdProducer.get_next", + new=AsyncMock(return_value=123)): + req = await GatewayAttributeRequest.build( + device_session=mock_session, + shared_keys=["s1", "s2"], + client_keys=["c1"] + ) + + # Check attributes set correctly + assert req.device_session is mock_session + assert req.request_id == 123 + assert req.shared_keys == ["s1", "s2"] + assert req.client_keys == ["c1"] + assert req.event_type == GatewayEventType.DEVICE_ATTRIBUTE_REQUEST + + # __repr__ should include these values + r = repr(req) + assert "TestDevice" in r or "device_session" in r + assert "shared_keys" in r + assert "client_keys" in r + + +@pytest.mark.asyncio +async def test_from_attribute_request_success(): + mock_session = MagicMock() + mock_session.device_info.device_name = "MyDev" + + attr_req = await AttributeRequest.build(shared_keys=["sh1"], client_keys=["cl1"]) + + result = await GatewayAttributeRequest.from_attribute_request(mock_session, attr_req) + + assert result.device_session is mock_session + assert result.request_id == attr_req.request_id + assert result.shared_keys == ["sh1"] + assert result.client_keys == ["cl1"] + assert result.event_type == GatewayEventType.DEVICE_ATTRIBUTE_REQUEST + + +@pytest.mark.asyncio +async def test_from_attribute_request_invalid_type(): + mock_session = MagicMock() + with pytest.raises(TypeError) as exc: + await GatewayAttributeRequest.from_attribute_request(mock_session, object()) + assert "must be an instance of AttributeRequest" in str(exc.value) + + +def make_request_for_payload(shared=None, client=None): + mock_session = MagicMock() + mock_session.device_info.device_name = "DeviceX" + + req = object.__new__(GatewayAttributeRequest) + object.__setattr__(req, 'device_session', mock_session) + object.__setattr__(req, 'request_id', 999) + object.__setattr__(req, 'shared_keys', shared) + object.__setattr__(req, 'client_keys', client) + object.__setattr__(req, 'event_type', GatewayEventType.DEVICE_ATTRIBUTE_REQUEST) + return req + + +def test_to_payload_format_client_multiple_keys(): + req = make_request_for_payload(client=["a", "b"]) + payload = req.to_payload_format() + assert payload["device"] == "DeviceX" + assert payload["client"] is True + assert payload["keys"] == ["a", "b"] + + +def test_to_payload_format_client_single_key(): + req = make_request_for_payload(client=["only_one"]) + payload = req.to_payload_format() + assert payload["client"] is True + assert payload["key"] == "only_one" + + +def test_to_payload_format_shared_multiple_keys(): + req = make_request_for_payload(shared=["s1", "s2"]) + payload = req.to_payload_format() + assert payload["client"] is False + assert payload["keys"] == ["s1", "s2"] + + +def test_to_payload_format_shared_single_key(): + req = make_request_for_payload(shared=["s1"]) + payload = req.to_payload_format() + assert payload["client"] is False + assert payload["key"] == "s1" + + +def test_to_payload_format_no_keys(): + req = make_request_for_payload(shared=None, client=None) + payload = req.to_payload_format() + assert set(payload.keys()) == {"device", "id"} + + +if __name__ == '__main__': + pytest.main([__file__, "--tb=short", "-v"]) diff --git a/tests/entities/gateway/test_gateway_attribute_update.py b/tests/entities/gateway/test_gateway_attribute_update.py new file mode 100644 index 0000000..9d7a04c --- /dev/null +++ b/tests/entities/gateway/test_gateway_attribute_update.py @@ -0,0 +1,58 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.gateway.gateway_attribute_update import GatewayAttributeUpdate + + +def test_init_with_list_of_entries(): + entries = [AttributeEntry("k1", "v1"), AttributeEntry("k2", "v2")] + obj = GatewayAttributeUpdate("deviceA", entries) + assert obj.device_name == "deviceA" + assert len(obj.entries) == 2 + assert all(isinstance(e, AttributeEntry) for e in obj.entries) + assert isinstance(obj.attribute_update, AttributeUpdate) + assert str(obj) == f"GatewayAttributeUpdate(device_name=deviceA, attribute_update={obj.attribute_update})" + + +def test_init_with_single_entry(): + entry = AttributeEntry("k", "v") + obj = GatewayAttributeUpdate("deviceB", entry) + assert obj.device_name == "deviceB" + assert len(obj.entries) == 1 + assert obj.entries[0].key == "k" + assert isinstance(obj.attribute_update, AttributeUpdate) + assert "deviceB" in str(obj) + + +def test_init_with_attribute_update(): + update = AttributeUpdate([AttributeEntry("ka", "va")]) + obj = GatewayAttributeUpdate("deviceC", update) + assert obj.device_name == "deviceC" + assert len(obj.entries) == 1 + assert isinstance(obj.attribute_update, AttributeUpdate) + assert "deviceC" in str(obj) + + +def test_init_with_invalid_type(): + with pytest.raises(TypeError) as exc: + GatewayAttributeUpdate("deviceD", {"invalid": "type"}) + assert "attribute_update must be an instance of AttributeUpdate" in str(exc.value) + + +if __name__ == '__main__': + pytest.main([__file__, "--tb=short", "-v"]) diff --git a/tests/entities/gateway/test_gateway_claim_request.py b/tests/entities/gateway/test_gateway_claim_request.py new file mode 100644 index 0000000..16a061f --- /dev/null +++ b/tests/entities/gateway/test_gateway_claim_request.py @@ -0,0 +1,99 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from tb_mqtt_client.entities.data.claim_request import ClaimRequest +from tb_mqtt_client.entities.gateway.device_info import DeviceInfo +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.gateway.gateway_claim_request import ( + GatewayClaimRequest, + GatewayClaimRequestBuilder +) +from tb_mqtt_client.service.gateway.device_session import DeviceSession + + +def test_direct_instantiation_not_allowed(): + with pytest.raises(TypeError) as exc: + GatewayClaimRequest() + assert "Direct instantiation" in str(exc.value) + + +def test_repr_and_add_device_request_with_str(): + claim_req = ClaimRequest.build("secret", 1000) + + req = GatewayClaimRequest.build() + assert isinstance(req.devices_requests, dict) + assert req.event_type == GatewayEventType.GATEWAY_CLAIM_REQUEST + + req.add_device_request("device_1", claim_req) + assert "device_1" in req.devices_requests + assert repr(req) == f"GatewayClaimRequest(devices_requests={req.devices_requests})" + + payload = req.to_payload_format() + assert payload == {"device_1": claim_req.to_payload_format()} + + +def test_add_device_request_with_device_session(): + claim_req = ClaimRequest.build("key2", 2000) + dev_info = DeviceInfo(device_name="devA", device_profile="default") + dev_session = DeviceSession(dev_info) + + req = GatewayClaimRequest.build() + req.add_device_request(dev_session, claim_req) + + payload = req.to_payload_format() + assert payload == {"devA": claim_req.to_payload_format()} + + +def test_builder_add_and_build_with_str(): + claim_req = ClaimRequest.build("builder_secret", 3000) + builder = GatewayClaimRequestBuilder() + builder.add_device_request("devB", claim_req) + + built_req = builder.build() + assert isinstance(built_req, GatewayClaimRequest) + assert built_req.to_payload_format() == {"devB": claim_req.to_payload_format()} + + +def test_builder_add_and_build_with_device_session(): + claim_req = ClaimRequest.build("builder_secret_2", 4000) + dev_info = DeviceInfo(device_name="devC", device_profile="default") + dev_session = DeviceSession(dev_info) + + builder = GatewayClaimRequestBuilder() + builder.add_device_request(dev_session, claim_req) + + built_req = builder.build() + assert built_req.to_payload_format() == {"devC": claim_req.to_payload_format()} + + +@pytest.mark.parametrize("bad_device", [123, object()]) +def test_builder_add_device_request_invalid_device_type(bad_device): + builder = GatewayClaimRequestBuilder() + with pytest.raises(ValueError) as exc: + builder.add_device_request(bad_device, ClaimRequest.build("key", 100)) + assert "DeviceSession or a string" in str(exc.value) + + +@pytest.mark.parametrize("bad_claim", [123, object(), "string"]) +def test_builder_add_device_request_invalid_claim_request(bad_claim): + builder = GatewayClaimRequestBuilder() + with pytest.raises(ValueError) as exc: + builder.add_device_request("devName", bad_claim) + assert "must be an instance of ClaimRequest" in str(exc.value) + + +if __name__ == '__main__': + pytest.main([__file__, "--tb=short", "-v"]) diff --git a/tests/entities/gateway/test_gateway_event.py b/tests/entities/gateway/test_gateway_event.py new file mode 100644 index 0000000..73d05d2 --- /dev/null +++ b/tests/entities/gateway/test_gateway_event.py @@ -0,0 +1,76 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from tb_mqtt_client.entities.gateway.device_info import DeviceInfo +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.gateway.gateway_event import GatewayEvent +from tb_mqtt_client.service.gateway.device_session import DeviceSession + + +def test_initialization_and_event_type(): + # Create an event and check event type + event = GatewayEvent(GatewayEventType.GATEWAY_CONNECT) + assert event.event_type == GatewayEventType.GATEWAY_CONNECT + + # Initially __device_session should be None + assert event.get_device_session() is None + + +def test_set_and_get_device_session(): + event = GatewayEvent(GatewayEventType.GATEWAY_DISCONNECT) + + device_info = DeviceInfo("dummy_device", "default") + dummy_session = DeviceSession(device_info) + event.set_device_session(dummy_session) + + # Ensure get_device_session returns the same object + assert event.get_device_session() is dummy_session + + +def test_str_representation_with_and_without_device_session(): + event = GatewayEvent(GatewayEventType.GATEWAY_CONNECT) + + # Without device session + no_session_str = str(event) + assert "GatewayEvent(type=" in no_session_str + assert "device_session=None" in no_session_str + + # With device session + device_info = DeviceInfo("dummy_device", "default") + dummy_session = DeviceSession(device_info) + event.set_device_session(dummy_session) + session_str = str(event) + assert "DeviceSession" in session_str + + +def test_to_dict_with_and_without_device_session(): + event = GatewayEvent(GatewayEventType.GATEWAY_CONNECT) + + # Without device session + d = event.to_dict() + assert d["event_type"] == GatewayEventType.GATEWAY_CONNECT + assert d["device_session"] is None + + # With device session + device_info = DeviceInfo("dummy_device", "default") + dummy_session = DeviceSession(device_info) + event.set_device_session(dummy_session) + d2 = event.to_dict() + assert d2["device_session"] is dummy_session + + +if __name__ == '__main__': + pytest.main([__file__, '-v', '--tb=short']) diff --git a/tests/entities/gateway/test_gateway_requested_attribute_response.py b/tests/entities/gateway/test_gateway_requested_attribute_response.py new file mode 100644 index 0000000..ca26c15 --- /dev/null +++ b/tests/entities/gateway/test_gateway_requested_attribute_response.py @@ -0,0 +1,129 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.gateway.gateway_requested_attribute_response import GatewayRequestedAttributeResponse + + +def make_entry(key, value): + return AttributeEntry(key=key, value=value) + + +def test_post_init_and_event_type_enforcement(): + resp = GatewayRequestedAttributeResponse(device_name="dev1", request_id=1) + # Even if you pass no event_type, __post_init__ sets it + assert resp.event_type == GatewayEventType.DEVICE_REQUESTED_ATTRIBUTE_RESPONSE + + +def test_repr_contains_expected_fields(): + shared = [make_entry("s1", "v1")] + client = [make_entry("c1", "v2")] + resp = GatewayRequestedAttributeResponse( + device_name="devX", + request_id=42, + shared=shared, + client=client + ) + r = repr(resp) + assert "GatewayRequestedAttributeResponse" in r + assert "devX" in r + assert "42" in r + assert "s1" in r + assert "c1" in r + + +def test_getitem_finds_in_shared_first(): + shared = [make_entry("foo", "bar")] + client = [make_entry("foo", "baz")] # same key, should not be used if found in shared + resp = GatewayRequestedAttributeResponse(shared=shared, client=client) + assert resp["foo"] == "bar" # found in shared first + + +def test_getitem_finds_in_client_if_not_in_shared(): + shared = [make_entry("other", 123)] + client = [make_entry("foo", "baz")] + resp = GatewayRequestedAttributeResponse(shared=shared, client=client) + assert resp["foo"] == "baz" + + +def test_getitem_raises_keyerror_if_not_found(): + resp = GatewayRequestedAttributeResponse( + shared=[make_entry("one", 1)], + client=[make_entry("two", 2)] + ) + with pytest.raises(KeyError) as e: + _ = resp["missing"] + assert "Key 'missing'" in str(e.value) + + +def test_shared_keys_and_client_keys(): + shared = [make_entry("s1", "v1"), make_entry("s2", "v2")] + client = [make_entry("c1", "v3")] + resp = GatewayRequestedAttributeResponse(shared=shared, client=client) + assert resp.shared_keys() == ["s1", "s2"] + assert resp.client_keys() == ["c1"] + + +def test_get_shared_and_client_success(): + shared = [make_entry("sh", "sv")] + client = [make_entry("cl", "cv")] + resp = GatewayRequestedAttributeResponse(shared=shared, client=client) + assert resp.get_shared("sh") == "sv" + assert resp.get_client("cl") == "cv" + + +def test_get_shared_and_client_with_default_on_missing(): + shared = [make_entry("sh", "sv")] + client = [make_entry("cl", "cv")] + resp = GatewayRequestedAttributeResponse(shared=shared, client=client) + assert resp.get_shared("missing", default="def") == "def" + assert resp.get_client("missing", default="def") == "def" + + +def test_get_shared_and_client_when_none(): + resp = GatewayRequestedAttributeResponse(shared=None, client=None) + assert resp.get_shared("any", default="d") == "d" + assert resp.get_client("any", default="d") == "d" + + +def test_as_dict_returns_expected_dict(monkeypatch): + shared_entry = make_entry("s1", "v1") + client_entry = make_entry("c1", "v2") + called = {} + + def fake_as_dict_self(): + called.setdefault("calls", []).append(1) + return {"key": "dummy", "value": "dummy"} + + monkeypatch.setattr(shared_entry, "as_dict", fake_as_dict_self) + monkeypatch.setattr(client_entry, "as_dict", fake_as_dict_self) + + resp = GatewayRequestedAttributeResponse( + shared=[shared_entry], + client=[client_entry] + ) + d = resp.as_dict() + assert "shared" in d + assert "client" in d + assert isinstance(d["shared"], list) + assert isinstance(d["client"], list) + # Ensure our fake_as_dict was called twice + assert len(called["calls"]) == 2 + + +if __name__ == '__main__': + pytest.main([__file__, '-v', '--tb=short']) diff --git a/tests/entities/gateway/test_gateway_rpc_request.py b/tests/entities/gateway/test_gateway_rpc_request.py new file mode 100644 index 0000000..5589b4d --- /dev/null +++ b/tests/entities/gateway/test_gateway_rpc_request.py @@ -0,0 +1,154 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest + + +def test_direct_instantiation_not_allowed(): + """ + Ensures dataclass can't be instantiated directly and guides to deserializer. + """ + with pytest.raises(TypeError) as exc: + GatewayRPCRequest() # type: ignore[call-arg] + assert "Direct instantiation" in str(exc.value) + + +def test_deserialize_with_int_id_and_params(): + """ + Valid deserialization: numeric request_id and explicit params. + """ + payload = { + "device": "pump-1", + "data": { + "id": 42, + "method": "reboot", + "params": {"delay": 5} + } + } + req = GatewayRPCRequest._deserialize_from_dict(payload) + + assert isinstance(req, GatewayRPCRequest) + assert req.request_id == 42 + assert req.device_name == "pump-1" + assert req.method == "reboot" + assert req.params == {"delay": 5} + assert req.event_type == GatewayEventType.DEVICE_RPC_REQUEST + + # __repr__ and __str__ should be identical and informative + expected_repr = "RPCRequest(id=42, device_name=pump-1, method=reboot, params={'delay': 5})" + assert repr(req) == expected_repr + assert str(req) == expected_repr + + +def test_deserialize_with_str_id_and_no_params(): + """ + Valid deserialization: string request_id and missing params (defaults to None). + """ + payload = { + "device": "sensor-A", + "data": { + "id": "rpc-001", + "method": "ping" + # no "params" + } + } + req = GatewayRPCRequest._deserialize_from_dict(payload) + + assert req.request_id == "rpc-001" + assert req.device_name == "sensor-A" + assert req.method == "ping" + assert req.params is None + assert req.event_type == GatewayEventType.DEVICE_RPC_REQUEST + + expected_repr = "RPCRequest(id=rpc-001, device_name=sensor-A, method=ping, params=None)" + assert repr(req) == expected_repr + assert str(req) == expected_repr + + +def test_missing_device_raises_value_error(): + """ + Missing top-level 'device' should raise a clear ValueError. + """ + payload = { + # "device": "x", + "data": {"id": 1, "method": "ping"} + } + with pytest.raises(ValueError) as exc: + GatewayRPCRequest._deserialize_from_dict(payload) + assert "Missing device name" in str(exc.value) + + +def test_missing_data_key_raises_key_error(): + """ + The implementation accesses data['data'] directly; if absent, a KeyError is expected. + """ + payload = { + "device": "dev-x" + # no "data" + } + with pytest.raises(KeyError) as exc: + GatewayRPCRequest._deserialize_from_dict(payload) + assert "'data'" in str(exc.value) # KeyError message includes missing key + + +@pytest.mark.parametrize("bad_id", [None, 3.14, {"n": 1}, ["1"]]) +def test_invalid_request_id_type_raises_value_error(bad_id): + """ + request_id must be int or str; anything else should raise ValueError. + """ + payload = { + "device": "dev-y", + "data": {"id": bad_id, "method": "ping"} + } + with pytest.raises(ValueError) as exc: + GatewayRPCRequest._deserialize_from_dict(payload) + # The current implementation uses a generic message; assert on the key phrase. + assert "request id" in str(exc.value).lower() + + +def test_missing_method_raises_value_error(): + """ + Missing 'method' inside 'data' should raise ValueError. + """ + payload = { + "device": "dev-z", + "data": {"id": 7} # no "method" + } + with pytest.raises(ValueError) as exc: + GatewayRPCRequest._deserialize_from_dict(payload) + assert "Missing 'method'" in str(exc.value) + + +def test_event_type_is_device_rpc_request(): + """ + Sanity check to ensure event_type is set correctly by the deserializer. + """ + payload = { + "device": "gw-1", + "data": { + "id": 101, + "method": "set_threshold", + "params": {"value": 12.3} + } + } + req = GatewayRPCRequest._deserialize_from_dict(payload) + assert req.event_type == GatewayEventType.DEVICE_RPC_REQUEST + + +if __name__ == '__main__': + pytest.main([__file__, "--tb=short", "-v"]) diff --git a/tests/entities/gateway/test_gateway_rpc_response.py b/tests/entities/gateway/test_gateway_rpc_response.py new file mode 100644 index 0000000..0ead467 --- /dev/null +++ b/tests/entities/gateway/test_gateway_rpc_response.py @@ -0,0 +1,177 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import pytest + +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.data.rpc_response import RPCStatus +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse + + +def test_direct_instantiation_not_allowed(): + with pytest.raises(TypeError) as exc: + GatewayRPCResponse() # type: ignore[call-arg] + assert "Direct instantiation of GatewayRPCResponse is not allowed" in str(exc.value) + + +def test_build_success_with_result_only(): + req = GatewayRPCResponse.build( + device_name="pump-1", + request_id=101, + result={"ok": True} + ) + assert isinstance(req, GatewayRPCResponse) + assert req.device_name == "pump-1" + assert req.request_id == 101 + assert req.status == RPCStatus.SUCCESS + assert req.error is None + assert req.result == {"ok": True} + assert req.event_type == GatewayEventType.DEVICE_RPC_RESPONSE + + expected_repr = ( + "GatewayRPCResponse(device_name=pump-1, request_id=101, result={'ok': True}, error=None)" + ) + assert repr(req) == expected_repr + + payload = req.to_payload_format() + assert payload == {"device": "pump-1", "id": 101, "data": {"result": {"ok": True}}} + + +def test_build_with_error_string(): + req = GatewayRPCResponse.build( + device_name="sensor-A", + request_id=7, + error="boom" + ) + assert req.status == RPCStatus.ERROR + assert req.error == "boom" + assert req.result is None + + payload = req.to_payload_format() + assert payload == {"device": "sensor-A", "id": 7, "data": {"error": "boom"}} + + +def test_build_with_error_dict(): + err = {"message": "bad-things", "code": 500} + req = GatewayRPCResponse.build( + device_name="sensor-B", + request_id=8, + error=err + ) + assert req.status == RPCStatus.ERROR + assert req.error == err + assert req.result is None + + payload = req.to_payload_format() + assert payload == {"device": "sensor-B", "id": 8, "data": {"error": err}} + + +def test_build_with_exception_converted_to_struct(): + class CustomError(RuntimeError): + pass + + req = GatewayRPCResponse.build( + device_name="dev-ex", + request_id=999, + error=CustomError("kaput") + ) + + assert req.status == RPCStatus.ERROR + assert isinstance(req.error, dict) + assert req.error.get("message") == "kaput" + assert req.error.get("type") == "CustomError" + assert isinstance(req.error.get("details"), str) + assert "CustomError" in req.error["details"] + + payload = req.to_payload_format() + assert payload["device"] == "dev-ex" + assert payload["id"] == 999 + assert "error" in payload["data"] + assert payload["data"]["error"]["type"] == "CustomError" + + +def test_build_with_result_and_error_includes_both_in_payload(): + req = GatewayRPCResponse.build( + device_name="dev-mix", + request_id=321, + result={"partial": True}, + error="some error" + ) + assert req.status == RPCStatus.ERROR + assert req.result == {"partial": True} + assert req.error == "some error" + + payload = req.to_payload_format() + assert payload == { + "device": "dev-mix", + "id": 321, + "data": {"result": {"partial": True}, "error": "some error"}, + } + + +@pytest.mark.parametrize("bad_device", ["", 123, None, object()]) +def test_build_invalid_device_name_raises_value_error(bad_device): + with pytest.raises(ValueError) as exc: + GatewayRPCResponse.build(device_name=bad_device, request_id=1) # type: ignore[arg-type] + assert "Device name must be a non-empty string" in str(exc.value) + + +@pytest.mark.parametrize("bad_error", [123, 3.14, ["oops"], set(), object()]) +def test_build_invalid_error_type_raises_value_error(bad_error): + with pytest.raises(ValueError) as exc: + GatewayRPCResponse.build(device_name="dev-x", request_id=1, error=bad_error) # type: ignore[arg-type] + assert "Error must be a string, dictionary, or an exception instance" in str(exc.value) + + +def test_build_invalid_result_fails_json_validation(): + class NotJson: + pass + + with pytest.raises(ValueError) as exc: + GatewayRPCResponse.build(device_name="dev-j", request_id=2, result=NotJson()) + assert re.search(r"(json|compatible|type)", str(exc.value), re.IGNORECASE) + + +def test_build_invalid_error_dict_fails_json_validation(): + class NotJson: + pass + + with pytest.raises(ValueError) as exc: + GatewayRPCResponse.build( + device_name="dev-je", + request_id=3, + error={"bad": NotJson()} + ) + assert re.search(r"(json|compatible|type)", str(exc.value), re.IGNORECASE) + + +def test_event_type_is_device_rpc_response(): + req = GatewayRPCResponse.build(device_name="gw-1", request_id=55, result=True) + assert req.event_type == GatewayEventType.DEVICE_RPC_RESPONSE + + +def test_repr_includes_all_fields(): + req = GatewayRPCResponse.build(device_name="repr-dev", request_id=42, result={"x": 1}) + assert repr(req) == "GatewayRPCResponse(device_name=repr-dev, request_id=42, result={'x': 1}, error=None)" + + +def test_immutability_enforced_by_frozen_dataclass(): + req = GatewayRPCResponse.build(device_name="immu", request_id=1, result=None) + with pytest.raises((AttributeError, TypeError)): + setattr(req, "status", RPCStatus.ERROR) + + +if __name__ == '__main__': + pytest.main([__file__, "--tb=short", "-v"]) diff --git a/tests/entities/gateway/test_gateway_uplink_message.py b/tests/entities/gateway/test_gateway_uplink_message.py new file mode 100644 index 0000000..fd7cc40 --- /dev/null +++ b/tests/entities/gateway/test_gateway_uplink_message.py @@ -0,0 +1,152 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from collections import OrderedDict +from uuid import UUID + +import pytest + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.gateway.gateway_uplink_message import ( + GatewayUplinkMessage, + GatewayUplinkMessageBuilder +) + + +@pytest.mark.asyncio +async def test_direct_instantiation_forbidden(): + with pytest.raises(TypeError) as exc: + GatewayUplinkMessage(device_name="dev", device_profile="prof") + assert "Direct instantiation" in str(exc.value) + + +@pytest.mark.asyncio +async def test_repr_and_build_and_properties(): + # Prepare entities + attr1 = AttributeEntry("a1", "v1") + attr2 = AttributeEntry("a2", "v2") + ts1 = TimeseriesEntry("t1", 123) + ts2 = TimeseriesEntry("t2", 456) + + f1 = asyncio.get_event_loop().create_future() + f2 = asyncio.get_event_loop().create_future() + + msg = GatewayUplinkMessage.build( + device_name="dev1", + device_profile="prof1", + attributes=[attr1, attr2], + timeseries=OrderedDict({111: [ts1, ts2]}), + delivery_futures=[f1, f2], + size=123, + main_ts=999, + ) + + # __repr__ + rep = repr(msg) + assert "dev1" in rep and "prof1" in rep + + # Properties + assert msg.device_name == "dev1" + assert msg.device_profile == "prof1" + assert msg.size == 123 + assert msg.event_type == GatewayEventType.DEVICE_UPLINK + + # Datapoint counts + assert msg.timeseries_datapoint_count() == 2 + assert msg.attributes_datapoint_count() == 2 + assert msg.has_attributes() + assert msg.has_timeseries() + + # Futures + futs = msg.get_delivery_futures() + assert futs == (f1, f2) + + # set_main_ts + assert msg.set_main_ts(555).main_ts == 555 + + +@pytest.mark.asyncio +async def test_builder_minimal_and_defaults(): + b = GatewayUplinkMessageBuilder() + b.set_device_name("devname") + b.set_device_profile("profile") + attr = AttributeEntry("akey", "aval") + ts = TimeseriesEntry("tkey", 1) + b.add_attributes(attr) + b.add_timeseries(ts) + + # Delivery future gets added + fut = asyncio.get_event_loop().create_future() + b.add_delivery_futures(fut) + + b.set_main_ts(1000) + msg = b.build() + + assert isinstance(msg, GatewayUplinkMessage) + assert msg.device_name == "devname" + assert msg.device_profile == "profile" + assert msg.main_ts == 1000 + assert msg.has_attributes() + assert msg.has_timeseries() + + +@pytest.mark.asyncio +async def test_builder_defaults_when_not_set(): + # No device_profile and no futures -> builder should create defaults + b = GatewayUplinkMessageBuilder() + b.set_device_name("d1") + + msg = b.build() + assert msg.device_profile == "default" + assert isinstance(msg.get_delivery_futures()[0], asyncio.Future) + # The auto-created future has uuid + assert isinstance(msg.get_delivery_futures()[0].uuid, UUID) + + +@pytest.mark.asyncio +async def test_add_attributes_and_timeseries_variants(): + b = GatewayUplinkMessageBuilder() + b.set_device_name("dev") + attr_list = [AttributeEntry("a1", "v1"), AttributeEntry("a2", "v2")] + ts_list = [TimeseriesEntry("t1", 1), TimeseriesEntry("t2", 2)] + + # Add as list + b.add_attributes(attr_list) + b.add_timeseries(ts_list) + + # Add OrderedDict directly + od = OrderedDict({123: [TimeseriesEntry("k", 42)]}) + b.add_timeseries(od) + + msg = b.build() + assert msg.has_attributes() + assert msg.has_timeseries() + + +@pytest.mark.asyncio +async def test_len_tracks_size(): + b = GatewayUplinkMessageBuilder() + base_len = len(b) + b.set_device_name("abc") + b.set_device_profile("xyz") + b.add_attributes(AttributeEntry("k", "v")) + b.add_timeseries(TimeseriesEntry("ts", 1)) + assert len(b) > base_len + + +if __name__ == '__main__': + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/service/__init__.py b/tests/service/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tests/service/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/service/device/__init__.py b/tests/service/device/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tests/service/device/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/service/device/handlers/__init__.py b/tests/service/device/handlers/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tests/service/device/handlers/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/service/device/handlers/test_attribute_updates_handler.py b/tests/service/device/handlers/test_attribute_updates_handler.py new file mode 100644 index 0000000..8cb294c --- /dev/null +++ b/tests/service/device/handlers/test_attribute_updates_handler.py @@ -0,0 +1,125 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Union + +import pytest + +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest +from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.service.device.handlers.attribute_updates_handler import AttributeUpdatesHandler +from tb_mqtt_client.service.device.message_adapter import MessageAdapter + + +class DummyAdapter(MessageAdapter): + def build_attribute_request(self, request: AttributeRequest) -> MqttPublishMessage: + pass + + def build_claim_request(self, claim_request) -> MqttPublishMessage: + pass + + def build_rpc_request(self, rpc_request: RPCRequest) -> MqttPublishMessage: + pass + + def build_rpc_response(self, rpc_response: RPCResponse) -> MqttPublishMessage: + pass + + def build_provision_request(self, provision_request) -> MqttPublishMessage: + pass + + def parse_requested_attribute_response(self, topic: str, payload: bytes) -> RequestedAttributeResponse: + pass + + def parse_rpc_request(self, topic: str, payload: bytes) -> RPCRequest: + pass + + def parse_rpc_response(self, topic: str, payload: Union[bytes, Exception]) -> RPCResponse: + pass + + def parse_provisioning_response(self, provisioning_request: ProvisioningRequest, + payload: bytes) -> ProvisioningResponse: + pass + + def build_uplink_messages(self, messages: List[MqttPublishMessage]) -> List[MqttPublishMessage]: + pass + + def __init__(self, result=None, raise_exc=False): + super().__init__() + self.result = result + self.raise_exc = raise_exc + + def parse_attribute_update(self, payload: bytes) -> AttributeUpdate: + if self.raise_exc: + raise ValueError("Simulated parse error") + return self.result + + +@pytest.mark.asyncio +async def test_set_message_adapter_and_callback_called(): + handler = AttributeUpdatesHandler() + adapter_result = AttributeUpdate([AttributeEntry("key", "value")]) + + handler.set_message_adapter(DummyAdapter(result=adapter_result)) + + called = {} + + async def cb(update: AttributeUpdate): + called["data"] = update + + handler.set_callback(cb) + + await handler.handle("topic", b"payload") + assert called["data"] == adapter_result + + +def test_set_message_adapter_invalid_type(): + handler = AttributeUpdatesHandler() + with pytest.raises(ValueError) as exc: + handler.set_message_adapter("not-a-message-adapter") + assert "message_adapter must be an instance of MessageAdapter" in str(exc.value) + + +@pytest.mark.asyncio +async def test_handle_no_callback_set(): + handler = AttributeUpdatesHandler() + # even if message_adapter is set, no callback means skip + handler.set_message_adapter(DummyAdapter(result=AttributeUpdate({}))) + # should simply return without error + await handler.handle("topic", b"payload") + + +@pytest.mark.asyncio +async def test_handle_parse_error(): + handler = AttributeUpdatesHandler() + handler.set_message_adapter(DummyAdapter(raise_exc=True)) + + called = {"called": False} + + async def cb(update: AttributeUpdate): + called["called"] = True + + handler.set_callback(cb) + # parse will raise, callback should not be called + await handler.handle("topic", b"payload") + assert not called["called"] + + +if __name__ == '__main__': + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/service/device/handlers/test_requested_attributes_response_handler.py b/tests/service/device/handlers/test_requested_attributes_response_handler.py new file mode 100644 index 0000000..f3a0986 --- /dev/null +++ b/tests/service/device/handlers/test_requested_attributes_response_handler.py @@ -0,0 +1,194 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Union, List + +import pytest + +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest +from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.service.device.handlers.requested_attributes_response_handler import \ + RequestedAttributeResponseHandler +from tb_mqtt_client.service.device.message_adapter import MessageAdapter + + +class DummyMessageAdapter(MessageAdapter): + def build_uplink_messages(self, messages: List[MqttPublishMessage]) -> List[MqttPublishMessage]: + pass + + def build_attribute_request(self, request: AttributeRequest) -> MqttPublishMessage: + pass + + def build_claim_request(self, claim_request) -> MqttPublishMessage: + pass + + def build_rpc_request(self, rpc_request: RPCRequest) -> MqttPublishMessage: + pass + + def build_rpc_response(self, rpc_response: RPCResponse) -> MqttPublishMessage: + pass + + def build_provision_request(self, provision_request) -> MqttPublishMessage: + pass + + def parse_attribute_update(self, payload: bytes) -> AttributeUpdate: + pass + + def parse_rpc_request(self, topic: str, payload: bytes) -> RPCRequest: + pass + + def parse_rpc_response(self, topic: str, payload: Union[bytes, Exception]) -> RPCResponse: + pass + + def parse_provisioning_response(self, provisioning_request: ProvisioningRequest, + payload: bytes) -> ProvisioningResponse: + pass + + def __init__(self, response=None, exc=None): + super().__init__() + self._response = response + self._exc = exc + + def parse_requested_attribute_response(self, topic, payload): + if self._exc: + raise self._exc + return self._response + + +@pytest.mark.asyncio +async def test_set_message_adapter_and_full_handle_flow(): + handler = RequestedAttributeResponseHandler() + # Build real AttributeRequest + request = await AttributeRequest.build(shared_keys=["temp"], client_keys=["c1"]) + + called = {} + + async def cb(resp): + called["val"] = resp["temp"] + + # Register request + await handler.register_request(request, cb) + + # Create real RequestedAttributeResponse with matching ID + resp = RequestedAttributeResponse( + request_id=request.request_id, + shared=[AttributeEntry("temp", 42)], + client=[AttributeEntry("c1", "v1")] + ) + + # Set adapter returning our response + handler.set_message_adapter(DummyMessageAdapter(response=resp)) + + # Handle message + await handler.handle(f"some/topic/{request.request_id}", b"payload") + + assert called["val"] == 42 + assert request.request_id not in handler._pending_attribute_requests + + +def test_set_message_adapter_invalid_type(): + handler = RequestedAttributeResponseHandler() + with pytest.raises(ValueError): + handler.set_message_adapter(object()) + + +@pytest.mark.asyncio +async def test_register_request_duplicate(): + handler = RequestedAttributeResponseHandler() + handler.set_message_adapter(DummyMessageAdapter()) + + req = await AttributeRequest.build() + await handler.register_request(req, lambda r: None) + with pytest.raises(RuntimeError): + await handler.register_request(req, lambda r: None) + + +@pytest.mark.asyncio +async def test_unregister_existing_and_non_existing(): + handler = RequestedAttributeResponseHandler() + handler.set_message_adapter(DummyMessageAdapter()) + + req = await AttributeRequest.build() + await handler.register_request(req, lambda r: None) + + assert req.request_id in handler._pending_attribute_requests + handler.unregister_request(req.request_id) + assert req.request_id not in handler._pending_attribute_requests + + # No error for missing ID + handler.unregister_request(9999) + + +@pytest.mark.asyncio +async def test_handle_without_message_adapter_removes_request(): + handler = RequestedAttributeResponseHandler() + req = await AttributeRequest.build() + await handler.register_request(req, lambda r: None) + await handler.handle(f"topic/{req.request_id}", b"payload") + assert req.request_id not in handler._pending_attribute_requests + + +@pytest.mark.asyncio +async def test_handle_with_no_pending_request(): + handler = RequestedAttributeResponseHandler() + resp = RequestedAttributeResponse( + request_id=999, + shared=[], + client=[] + ) + handler.set_message_adapter(DummyMessageAdapter(response=resp)) + # No request registered → should just return + await handler.handle("topic/999", b"payload") + + +@pytest.mark.asyncio +async def test_handle_with_no_callback(): + handler = RequestedAttributeResponseHandler() + req = await AttributeRequest.build() + await handler.register_request(req, None) + resp = RequestedAttributeResponse( + request_id=req.request_id, + shared=[], + client=[] + ) + handler.set_message_adapter(DummyMessageAdapter(response=resp)) + await handler.handle(f"topic/{req.request_id}", b"payload") + + +@pytest.mark.asyncio +async def test_handle_with_parse_exception(): + handler = RequestedAttributeResponseHandler() + req = await AttributeRequest.build() + await handler.register_request(req, lambda r: None) + handler.set_message_adapter(DummyMessageAdapter(exc=ValueError("bad parse"))) + await handler.handle(f"topic/{req.request_id}", b"payload") + + +def test_clear(): + handler = RequestedAttributeResponseHandler() + handler._pending_attribute_requests[1] = ("req", lambda r: None) + handler.clear() + assert not handler._pending_attribute_requests + + +if __name__ == '__main__': + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/service/device/handlers/test_rpc_requests_handler.py b/tests/service/device/handlers/test_rpc_requests_handler.py new file mode 100644 index 0000000..197bc42 --- /dev/null +++ b/tests/service/device/handlers/test_rpc_requests_handler.py @@ -0,0 +1,226 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union, List + +import pytest + +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest +from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse, RPCStatus +from tb_mqtt_client.service.device.handlers.rpc_requests_handler import RPCRequestsHandler +from tb_mqtt_client.service.device.message_adapter import MessageAdapter + + +class DummyMessageAdapter(MessageAdapter): + def build_uplink_messages(self, messages: List[MqttPublishMessage]) -> List[MqttPublishMessage]: + pass + + def build_attribute_request(self, request: AttributeRequest) -> MqttPublishMessage: + pass + + def build_claim_request(self, claim_request) -> MqttPublishMessage: + pass + + def build_rpc_request(self, rpc_request: RPCRequest) -> MqttPublishMessage: + pass + + def build_rpc_response(self, rpc_response: RPCResponse) -> MqttPublishMessage: + pass + + def build_provision_request(self, provision_request) -> MqttPublishMessage: + pass + + def parse_requested_attribute_response(self, topic: str, payload: bytes) -> RequestedAttributeResponse: + pass + + def parse_attribute_update(self, payload: bytes) -> AttributeUpdate: + pass + + def parse_rpc_response(self, topic: str, payload: Union[bytes, Exception]) -> RPCResponse: + pass + + def parse_provisioning_response(self, provisioning_request: ProvisioningRequest, + payload: bytes) -> ProvisioningResponse: + pass + + def parse_rpc_request(self, topic, payload): + # Simulate payload as dict with required keys + data = {"method": "testMethod", "params": {"foo": "bar"}} + return RPCRequest._deserialize_from_dict(42, data) + + +@pytest.mark.asyncio +async def test_handle_no_callback_returns_none(): + handler = RPCRequestsHandler() + handler.set_message_adapter(DummyMessageAdapter()) + # No callback set + result = await handler.handle("topic", b"payload") + assert result is None + + +@pytest.mark.asyncio +async def test_handle_no_message_adapter_returns_none(): + handler = RPCRequestsHandler() + + async def cb(_): + return RPCResponse.build(1, result={"ok": True}) + + handler.set_callback(cb) + # No message adapter set + result = await handler.handle("topic", b"payload") + assert result is None + + +@pytest.mark.asyncio +async def test_handle_callback_returns_valid_response(): + handler = RPCRequestsHandler() + handler.set_message_adapter(DummyMessageAdapter()) + + async def cb(rpc_request: RPCRequest): + assert isinstance(rpc_request, RPCRequest) + return RPCResponse.build(rpc_request.request_id, result={"success": True}) + + handler.set_callback(cb) + result = await handler.handle("topic", b"payload") + assert isinstance(result, RPCResponse) + assert result.status == RPCStatus.SUCCESS + assert result.result == {"success": True} + assert result.error is None + + +@pytest.mark.asyncio +async def test_handle_callback_returns_invalid_type(): + handler = RPCRequestsHandler() + handler.set_message_adapter(DummyMessageAdapter()) + + async def cb(_): + return "not-an-rpc-response" + + handler.set_callback(cb) + result = await handler.handle("topic", b"payload") + assert result is None + + +@pytest.mark.asyncio +async def test_handle_parse_raises_exception(): + class BadAdapter(DummyMessageAdapter): + def parse_rpc_request(self, topic, payload): + raise RuntimeError("bad parse") + + handler = RPCRequestsHandler() + handler.set_message_adapter(BadAdapter()) + + async def cb(_): + return RPCResponse.build(1, result={}) + + handler.set_callback(cb) + result = await handler.handle("topic", b"payload") + assert result is None + + +def test_set_message_adapter_invalid_type(): + handler = RPCRequestsHandler() + with pytest.raises(ValueError): + handler.set_message_adapter("not-a-message-adapter") + + +def test_rpc_response_build_success_and_repr_and_payload(): + resp = RPCResponse.build(1, result={"x": 1}) + assert resp.status == RPCStatus.SUCCESS + assert resp.result == {"x": 1} + assert resp.error is None + assert "RPCResponse" in repr(resp) + assert resp.to_payload_format() == {"result": {"x": 1}} + + +def test_rpc_response_build_with_str_error(): + resp = RPCResponse.build(1, error="something went wrong") + assert resp.status == RPCStatus.ERROR + assert resp.error == "something went wrong" + assert "error" in resp.to_payload_format() + + +def test_rpc_response_build_with_dict_error(): + err_dict = {"code": 123} + resp = RPCResponse.build(1, error=err_dict) + assert resp.status == RPCStatus.ERROR + assert resp.error == err_dict + + +def test_rpc_response_build_with_exception_error(): + exc = ValueError("bad value") + resp = RPCResponse.build(1, error=exc) + assert resp.status == RPCStatus.ERROR + assert isinstance(resp.error, dict) + assert "message" in resp.error + assert "type" in resp.error + assert "details" in resp.error + + +def test_rpc_response_invalid_error_type(): + with pytest.raises(ValueError): + RPCResponse.build(1, error=object()) + + +def test_rpc_response_direct_init_disallowed(): + with pytest.raises(TypeError): + RPCResponse(1, result={}) + + +@pytest.mark.asyncio +async def test_rpc_request_build_and_str_and_payload_format(): + req = await RPCRequest.build("myMethod", params={"y": 2}) + assert req.method == "myMethod" + assert req.params == {"y": 2} + assert "myMethod" in str(req) + payload = req.to_payload_format() + assert payload["method"] == "myMethod" + assert payload["params"] == {"y": 2} + + +@pytest.mark.asyncio +async def test_rpc_request_build_invalid_method_type(): + with pytest.raises(ValueError): + await RPCRequest.build(123) + + +def test_rpc_request_deserialize_missing_method(): + with pytest.raises(ValueError): + RPCRequest._deserialize_from_dict(1, {}) + + +def test_rpc_request_deserialize_invalid_id_type(): + with pytest.raises(ValueError): + RPCRequest._deserialize_from_dict(object(), {"method": "x"}) + + +def test_rpc_request_direct_init_disallowed(): + with pytest.raises(TypeError): + RPCRequest(1, "method") + + +def test_rpc_request_to_payload_format_without_params(): + req = RPCRequest._deserialize_from_dict(5, {"method": "noParams"}) + payload = req.to_payload_format() + assert "params" not in payload + + +if __name__ == '__main__': + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/service/device/handlers/test_rpc_response_handler.py b/tests/service/device/handlers/test_rpc_response_handler.py new file mode 100644 index 0000000..324b865 --- /dev/null +++ b/tests/service/device/handlers/test_rpc_response_handler.py @@ -0,0 +1,194 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from typing import List + +import pytest + +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest +from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.service.device.handlers.rpc_response_handler import RPCResponseHandler +from tb_mqtt_client.service.device.message_adapter import MessageAdapter, JsonMessageAdapter + + +class DummyMessageAdapter(MessageAdapter): + def build_uplink_messages(self, messages: List[MqttPublishMessage]) -> List[MqttPublishMessage]: + pass + + def build_attribute_request(self, request: AttributeRequest) -> MqttPublishMessage: + pass + + def build_claim_request(self, claim_request) -> MqttPublishMessage: + pass + + def build_rpc_request(self, rpc_request: RPCRequest) -> MqttPublishMessage: + pass + + def build_rpc_response(self, rpc_response: RPCResponse) -> MqttPublishMessage: + pass + + def build_provision_request(self, provision_request) -> MqttPublishMessage: + pass + + def parse_requested_attribute_response(self, topic: str, payload: bytes) -> RequestedAttributeResponse: + pass + + def parse_attribute_update(self, payload: bytes) -> AttributeUpdate: + pass + + def parse_rpc_request(self, topic: str, payload: bytes) -> RPCRequest: + pass + + def parse_provisioning_response(self, provisioning_request: ProvisioningRequest, + payload: bytes) -> ProvisioningResponse: + pass + + def __init__(self, rpc_response): + super().__init__() + self._resp = rpc_response + + def parse_rpc_response(self, topic, payload): + return self._resp + + +@pytest.mark.asyncio +async def test_set_message_adapter_invalid_type(): + handler = RPCResponseHandler() + with pytest.raises(ValueError): + handler.set_message_adapter("not-an-adapter") + + +@pytest.mark.asyncio +async def test_register_and_handle_success_with_callback(): + handler = RPCResponseHandler() + resp = RPCResponse.build(1, result={"ok": True}) + handler.set_message_adapter(DummyMessageAdapter(resp)) + + called = {} + + async def cb(r: RPCResponse): + called["cb"] = r + + fut = handler.register_request(1, callback=cb) + await handler.handle("topic", b"payload") + assert fut.done() + result = fut.result() + assert isinstance(result, RPCResponse) + assert result.result == {"ok": True} + assert called["cb"] == resp + + +@pytest.mark.asyncio +async def test_register_request_duplicate_id_raises(): + handler = RPCResponseHandler() + handler._pending_rpc_requests[5] = (asyncio.get_event_loop().create_future(), None) + with pytest.raises(RuntimeError): + await handler.register_request(5) + + +@pytest.mark.asyncio +async def test_handle_no_message_adapter_uses_json_adapter(): + handler = RPCResponseHandler() + # Build a real RPCResponse and serialize to payload + resp = RPCResponse.build(99, result={"foo": "bar"}) + message = JsonMessageAdapter().build_rpc_response(resp) + # Register request so we can match + fut = handler.register_request(99) + # Should succeed even without explicitly setting message adapter + await handler.handle("v1/devices/me/rpc/response/99", message.payload) + assert fut.done() + assert fut.result().result == {'result': {"foo": "bar"}} + + +@pytest.mark.asyncio +async def test_handle_no_message_adapter_uses_json_adapter_with_error_rpc(): + handler = RPCResponseHandler() + # Build a real RPCResponse and serialize to payload + resp = RPCResponse.build(99, error="Some error occurred") + message = JsonMessageAdapter().build_rpc_response(resp) + # Register request so we can match + fut = handler.register_request(99) + # Should succeed even without explicitly setting message adapter + await handler.handle("v1/devices/me/rpc/response/99", message.payload) + assert fut.done() + assert fut.result().result == {'error': "Some error occurred"} + + +@pytest.mark.asyncio +async def test_handle_response_for_unknown_request_id(): + handler = RPCResponseHandler() + resp = RPCResponse.build("abc", result={"zzz": 1}) + # No registered request with "abc" + handler.set_message_adapter(DummyMessageAdapter(resp)) + await handler.handle("topic", b"payload") # Should log warning and return + # Nothing to assert except no crash + + +@pytest.mark.asyncio +async def test_handle_with_no_future(): + handler = RPCResponseHandler() + resp = RPCResponse.build(777, result={"x": 1}) + handler.set_message_adapter(DummyMessageAdapter(resp)) + handler._pending_rpc_requests[777] = (None, None) + await handler.handle("topic", b"payload") # Should log warning and return + + +@pytest.mark.asyncio +async def test_handle_callback_exception_sets_future_exception(): + handler = RPCResponseHandler() + resp = RPCResponse.build(42, result={"ok": True}) + handler.set_message_adapter(DummyMessageAdapter(resp)) + + async def bad_cb(_): + raise RuntimeError("bad") + + fut = handler.register_request(42, callback=bad_cb) + await handler.handle("topic", b"payload") + assert fut.done() + with pytest.raises(RuntimeError): + fut.result() + + +@pytest.mark.asyncio +async def test_handle_with_error_field_sets_future_exception(): + handler = RPCResponseHandler() + resp = RPCResponse.build(321, error="fail") + handler.set_message_adapter(DummyMessageAdapter(resp)) + fut = handler.register_request(321) + await handler.handle("topic", b"payload") + assert fut.done() + with pytest.raises(Exception) as e: + fut.result() + assert "fail" in str(e.value) + + +@pytest.mark.asyncio +async def test_clear_cancels_pending_futures(): + handler = RPCResponseHandler() + f1 = asyncio.get_event_loop().create_future() + handler._pending_rpc_requests[1] = (f1, None) + handler.clear() + assert f1.cancelled() + assert handler._pending_rpc_requests == {} + + +if __name__ == '__main__': + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/service/device/test_device_client.py b/tests/service/device/test_device_client.py new file mode 100644 index 0000000..73a2ebf --- /dev/null +++ b/tests/service/device/test_device_client.py @@ -0,0 +1,642 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.constants.mqtt_topics import DEVICE_CLAIM_TOPIC +from tb_mqtt_client.entities.data.claim_request import ClaimRequest +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest, AccessTokenProvisioningCredentials +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.service.device.client import DeviceClient + + +@pytest.mark.asyncio +async def test_send_timeseries_with_dict(): + # Setup + client = DeviceClient() + client._mqtt_manager._handle_puback_reason_code = ( + client._mqtt_manager._handle_puback_reason_code.__get__(client._mqtt_manager) + ) + + class DummyConnection: + def __init__(self): + self.published_messages = [] + + def publish(self, *args, **kwargs): + self.published_messages.append(args[0]) + return 1, 1 + + async def fake_subscribe(*args, **kwargs): + fut = asyncio.get_event_loop().create_future() + fut.set_result(1) + return fut + + client._mqtt_manager.subscribe = fake_subscribe + + # Patch MQTTManager.connect to simulate connection without real network + async def fake_connect(**kwargs): + client._mqtt_manager._is_connected = True + client._mqtt_manager._connected_event.set() + + client._mqtt_manager.connect = fake_connect + client._mqtt_manager._rate_limits_ready_event.set() + client._mqtt_manager.is_connected = lambda: True + + client._mqtt_manager._client._connection = DummyConnection() + client._mqtt_manager._client._persistent_storage.push_message_nowait = lambda *args, **kwargs: None + + async def fake_await_ready(): + return + + client._mqtt_manager.await_ready = fake_await_ready + + # Call connect (real queue and adapter will be initialized) + await client.connect() + + # Act: send timeseries + delivery_futures = await client.send_timeseries({"temp": 22}, wait_for_publish=False) + await asyncio.sleep(.1) + + # Trigger PUBACK manually + client._mqtt_manager._handle_puback_reason_code( + 1, + reason_code=0, + properties={} + ) + await asyncio.sleep(1) + + # Await result + result = await asyncio.wait_for(delivery_futures[0], timeout=1) + + mqtt_msg = client._mqtt_manager._client._connection.published_messages[0] + + # Assert + assert isinstance(result, PublishResult) + # The initial mqtt message doesn't contain message_id, + # expected behavior because an initial message can be split to separated messages or grouped with other messages + assert result.message_id == -1 + assert result.topic == mqtt_msg.topic + assert result.payload_size == mqtt_msg.payload_size + assert result.datapoints_count == mqtt_msg.datapoints + assert result.reason_code == 0 + + +@pytest.mark.asyncio +async def test_send_timeseries_timeout(): + client = DeviceClient() + client._message_queue = AsyncMock() + future = asyncio.Future() + client._message_queue.publish.return_value = [future] + result = await client.send_timeseries({"temp": 22}, timeout=0.01) + assert isinstance(result, PublishResult) + assert result.message_id == -1 + + +@pytest.mark.asyncio +async def test_send_attributes_with_dict(): + # Setup + client = DeviceClient() + client._mqtt_manager._handle_puback_reason_code = ( + client._mqtt_manager._handle_puback_reason_code.__get__(client._mqtt_manager) + ) + + class DummyConnection: + def __init__(self): + self.published_messages = [] + + def publish(self, *args, **kwargs): + self.published_messages.append(args[0]) + return 2, 1 # Simulate mid=2, qos=1 + + async def fake_subscribe(*args, **kwargs): + fut = asyncio.get_event_loop().create_future() + fut.set_result(1) + return fut + + client._mqtt_manager.subscribe = fake_subscribe + + async def fake_connect(**kwargs): + client._mqtt_manager._is_connected = True + client._mqtt_manager._connected_event.set() + + client._mqtt_manager.connect = fake_connect + client._mqtt_manager._rate_limits_ready_event.set() + client._mqtt_manager.is_connected = lambda: True + + client._mqtt_manager._client._connection = DummyConnection() + client._mqtt_manager._client._persistent_storage.push_message_nowait = lambda *args, **kwargs: None + + async def fake_await_ready(): + return + + client._mqtt_manager.await_ready = fake_await_ready + + await client.connect() + + # Act: send attributes + delivery_futures = await client.send_attributes({"key": "value"}, wait_for_publish=False) + await asyncio.sleep(.1) + + # Trigger PUBACK manually + client._mqtt_manager._handle_puback_reason_code( + 2, reason_code=0, properties={} + ) + await asyncio.sleep(1) + + result = await asyncio.wait_for(delivery_futures[0], timeout=1) + + mqtt_msg = client._mqtt_manager._client._connection.published_messages[0] + + # Assert + assert isinstance(result, PublishResult) + assert result.message_id == -1 # Split/group logic applies + assert result.topic == mqtt_msg.topic + assert result.payload_size == mqtt_msg.payload_size + assert result.datapoints_count == mqtt_msg.datapoints + assert result.reason_code == 0 + + +@pytest.mark.asyncio +async def test_send_rpc_request_timeout(): + client = DeviceClient() + from tb_mqtt_client.entities.data.rpc_request import RPCRequest + request = await RPCRequest.build(method="get", params={}) + client._rpc_response_handler.register_request = MagicMock(return_value=asyncio.Future()) + client._message_queue = AsyncMock() + client._message_queue.publish.return_value = [] + with pytest.raises(TimeoutError): + await client.send_rpc_request(request, wait_for_publish=True, timeout=0.01) + + +@pytest.mark.asyncio +async def test_send_rpc_response(): + client = DeviceClient() + from tb_mqtt_client.entities.data.rpc_response import RPCResponse + response = RPCResponse.build(123, result={"ok": True}) + client._message_queue = AsyncMock() + await client.send_rpc_response(response) + client._message_queue.publish.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_claim_device_success(): + client = DeviceClient() + client._message_queue = AsyncMock() + + claim = ClaimRequest.build(secret_key="abc") + + fut = asyncio.Future() + fut.set_result(PublishResult(topic=DEVICE_CLAIM_TOPIC, qos=1, message_id=1, payload_size=12, reason_code=0)) + client._message_queue.publish.return_value = [fut] + + result = await client.claim_device(claim) + assert isinstance(result, PublishResult) + assert result.topic == DEVICE_CLAIM_TOPIC + + +@pytest.mark.asyncio +async def test_claim_device_timeout(): + client = DeviceClient() + client._message_queue = AsyncMock() + + claim = ClaimRequest.build(secret_key="abc") + fut = asyncio.Future() + client._message_queue.publish.return_value = [fut] + + result = await client.claim_device(claim, timeout=0.01) + assert isinstance(result, PublishResult) + assert result.message_id == -1 + + +@pytest.mark.asyncio +async def test_claim_device_payload_contains_secret_key(): + client = DeviceClient() + client._message_queue = AsyncMock() + + claim = ClaimRequest.build(secret_key="my-secret") + fut = asyncio.Future() + fut.set_result(PublishResult(topic=DEVICE_CLAIM_TOPIC, qos=1, message_id=3, payload_size=15, reason_code=0)) + client._message_queue.publish.return_value = [fut] + + await client.claim_device(claim) + + client._message_queue.publish.assert_awaited_once() + args, kwargs = client._message_queue.publish.call_args + mqtt_msg = args[0] + assert isinstance(mqtt_msg, MqttPublishMessage) + assert mqtt_msg.topic == DEVICE_CLAIM_TOPIC + assert b"my-secret" in mqtt_msg.payload + + +@pytest.mark.asyncio +async def test_send_attribute_request(): + client = DeviceClient() + from tb_mqtt_client.entities.data.attribute_request import AttributeRequest + request = await AttributeRequest.build(client_keys=["key1"]) + client._requested_attribute_response_handler.register_request = AsyncMock() + client._message_queue = AsyncMock() + await client.send_attribute_request(request, AsyncMock()) + client._message_queue.publish.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_disconnect(): + client = DeviceClient() + client._mqtt_manager.disconnect = AsyncMock() + await client.disconnect() + client._mqtt_manager.disconnect.assert_awaited() + + +@pytest.mark.asyncio +async def test_stop_disconnects_and_shuts_down_queue(): + client = DeviceClient() + client._mqtt_manager.is_connected = lambda: True + client._mqtt_manager.stop = AsyncMock() + client._message_queue = MagicMock() + client._message_queue.shutdown = AsyncMock() + await client.stop() + client._message_queue.shutdown.assert_awaited() + client._mqtt_manager.stop.assert_awaited() + + +@pytest.mark.asyncio +async def test_handle_rate_limit_invalid_response(): + client = DeviceClient() + from tb_mqtt_client.entities.data.rpc_response import RPCResponse + bad_resp = RPCResponse.build(1, result="invalid") + result = await client._handle_rate_limit_response(bad_resp) + assert result is None + + +@pytest.mark.asyncio +async def test_handle_rate_limit_valid_response(): + client = DeviceClient() + from tb_mqtt_client.entities.data.rpc_response import RPCResponse + client._mqtt_manager.set_rate_limits = MagicMock() + payload = { + "rateLimits": { + "messages": "10:1,", + "telemetryMessages": "100:60,", + "telemetryDataPoints": "500:60," + }, + "maxInflightMessages": 200, + "maxPayloadSize": 512 + } + resp = RPCResponse.build(1, result=payload) + result = await client._handle_rate_limit_response(resp) + assert result is True + assert client.max_payload_size <= 512 + + +@pytest.mark.asyncio +async def test_provision_success(monkeypatch): + from tb_mqtt_client.service.device import client as client_module + mock_prov = AsyncMock() + mock_prov.provision.return_value = "creds" + monkeypatch.setattr(client_module, "ProvisioningClient", lambda **kwargs: mock_prov) + credentials = AccessTokenProvisioningCredentials("provision_device_key", "provision_device_secret") + req = ProvisioningRequest(host="localhost", credentials=credentials, port=1883, device_name="dev") + res = await DeviceClient.provision(req) + assert res == "creds" + + +@pytest.mark.asyncio +async def test_connects_to_platform_with_tls_using_device_config(): + config = DeviceConfig() + config.use_tls = MagicMock(return_value=True) + config.ca_cert = "path/to/ca_cert" + config.client_cert = "path/to/client_cert" + config.private_key = "path/to/private_key" + + mqtt_manager = AsyncMock() + client = DeviceClient(config=config) + client._mqtt_manager = mqtt_manager + client._mqtt_manager.connect = AsyncMock() + client._mqtt_manager.is_connected = MagicMock(return_value=True) + client._mqtt_manager.await_ready = AsyncMock() + + with patch("tb_mqtt_client.service.device.client.ssl.create_default_context") as mock_ssl_context: + ssl_context = mock_ssl_context.return_value + ssl_context.load_verify_locations = MagicMock() + ssl_context.load_cert_chain = MagicMock() + await client.connect() + + mock_ssl_context.assert_called_once() + ssl_context.load_verify_locations.assert_called_once_with(config.ca_cert) + ssl_context.load_cert_chain.assert_called_once_with(certfile=config.client_cert, keyfile=config.private_key) + client._mqtt_manager.connect.assert_awaited_once_with( + host=config.host, + port=config.port, + username=config.access_token or config.username, + password=None if config.access_token else config.password, + tls=True, + ssl_context=ssl_context) + + +@pytest.mark.asyncio +async def test_connects_to_platform_with_tls(): + config = DeviceConfig() + config.use_tls = MagicMock(return_value=True) + config.ca_cert = "path/to/ca_cert" + config.client_cert = "path/to/client_cert" + config.private_key = "path/to/private_key" + config.host = "localhost" + config.port = 8883 + config.access_token = "token" + config.username = None + config.password = None + + mqtt_manager = AsyncMock() + client = DeviceClient(config) + client._mqtt_manager = mqtt_manager + + with patch("tb_mqtt_client.service.device.client.ssl.create_default_context") as mock_ssl_context: + ssl_context = mock_ssl_context.return_value + await client.connect() + + mock_ssl_context.assert_called_once() + ssl_context.load_verify_locations.assert_called_once_with(config.ca_cert) + ssl_context.load_cert_chain.assert_called_once_with( + certfile=config.client_cert, + keyfile=config.private_key + ) + mqtt_manager.connect.assert_awaited_once_with( + host=config.host, + port=config.port, + username=config.access_token, + password=None, + tls=True, + ssl_context=ssl_context + ) + + +@pytest.mark.asyncio +async def test_connects_to_platform_without_tls(): + config = DeviceConfig() + config.use_tls = MagicMock(return_value=False) + config.host = "localhost" + config.port = 1883 + config.access_token = "token" + config.username = None + config.password = None + + mqtt_manager = AsyncMock() + client = DeviceClient(config) + client._mqtt_manager = mqtt_manager + + await client.connect() + + mqtt_manager.connect.assert_awaited_once_with( + host=config.host, + port=config.port, + username=config.access_token, + password=None, + tls=False, + ssl_context=None + ) + + +@pytest.mark.asyncio +async def test_stops_if_event_is_set_during_connection(): + config = DeviceConfig() + config.host = "localhost" + config.port = 1883 + config.access_token = "token" + config.username = None + config.password = None + + mqtt_manager = AsyncMock() + mqtt_manager.is_connected.return_value = False + mqtt_manager.await_ready = AsyncMock() + + client = DeviceClient(config) + client._mqtt_manager = mqtt_manager + client._stop_event.set() + + await client.connect() + + mqtt_manager.await_ready.assert_not_awaited() + mqtt_manager.is_connected.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_initializes_adapter_and_queue_after_connection(): + config = DeviceConfig() + config.host = "localhost" + config.port = 1883 + config.access_token = "token" + config.username = None + config.password = None + + mqtt_manager = AsyncMock() + mqtt_manager.is_connected.return_value = True + + with patch("tb_mqtt_client.service.device.client.MessageService") as mock_message_service: + client = DeviceClient(config) + client._mqtt_manager = mqtt_manager + + await client.connect() + + assert client.max_payload_size == 65535 + assert client._message_adapter is not None + assert client._message_queue is not None + + mock_message_service.assert_called_once() + kwargs = mock_message_service.call_args.kwargs + assert kwargs["max_queue_size"] == client._max_uplink_message_queue_size + + +class FakeSplitter: + def __init__(self): + self.max_payload_size = None + + +@pytest.mark.asyncio +async def test_uses_default_max_payload_size_when_not_provided(): + client = DeviceClient() + client.max_payload_size = None + + splitter = FakeSplitter() + client._message_adapter = MagicMock() + client._message_adapter.splitter = splitter + + resp = RPCResponse.build(1, result={"rateLimits": {}}) + await client._handle_rate_limit_response(resp) + + assert client.max_payload_size == 65535 + assert splitter.max_payload_size == 65535 + + +@pytest.mark.asyncio +async def test_does_not_update_adapter_when_not_initialized(): + client = DeviceClient() + client.max_payload_size = None + client._message_adapter = None + + resp = RPCResponse.build(1, result={"rateLimits": {}}) + await client._handle_rate_limit_response(resp) + + assert client.max_payload_size == 65535 + + +@pytest.mark.asyncio +async def test_send_timeseries_without_connect_raises_error(): + client = DeviceClient() + with pytest.raises(AttributeError): + await client.send_timeseries({"temp": 22}) + + +@pytest.mark.asyncio +async def test_handle_attribute_update_calls_handler(): + client = DeviceClient() + mock_handler = AsyncMock() + client._attribute_updates_handler.handle = mock_handler + await client._handle_attribute_update("topic", b'{"key":"value"}') + mock_handler.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_handle_rpc_request_triggers_rpc_response(): + client = DeviceClient() + mock_handler = AsyncMock(return_value=RPCResponse.build(1, result={"res": "ok"})) + client._rpc_requests_handler.handle = mock_handler + client.send_rpc_response = AsyncMock() + await client._handle_rpc_request("topic", b'{"method": "test"}') + client.send_rpc_response.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_handle_rate_limit_response_with_partial_data(): + client = DeviceClient() + response = RPCResponse.build(1, result={"rateLimits": {"messages": "5:1"}}) + result = await client._handle_rate_limit_response(response) + assert result is True + assert client._rate_limiter.message_rate_limit.has_limit() + assert client.max_payload_size == 65535 + + +@pytest.mark.asyncio +async def test_update_firmware_triggers_firmware_updater(): + client = DeviceClient() + updater = AsyncMock() + client._firmware_updater = updater + await client.update_firmware() + updater.update.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_rpc_request_with_callback_executes(): + client = DeviceClient() + rpc = await RPCRequest.build(method="get", params={}) + client._rpc_response_handler.register_request = MagicMock(return_value=AsyncMock()) + client._message_queue = AsyncMock() + client._message_queue.publish.return_value = [] + + callback = AsyncMock() + await client.send_rpc_request(rpc, callback=callback, wait_for_publish=False) + client._rpc_response_handler.register_request.assert_called_once() + + +@pytest.mark.asyncio +async def test_set_attribute_update_callback_sets_handler(): + client = DeviceClient() + cb = AsyncMock() + client.set_attribute_update_callback(cb) + assert client._attribute_updates_handler._callback == cb + + +@pytest.mark.asyncio +async def test_send_timeseries_with_invalid_type_raises(): + client = DeviceClient() + with pytest.raises(ValueError): + await client.send_timeseries("invalid") + + +@pytest.mark.asyncio +async def test_send_attributes_no_wait(): + client = DeviceClient() + client._message_queue = AsyncMock() + client._message_queue.publish.return_value = [asyncio.Future()] + result = await client.send_attributes({"attr": "val"}, wait_for_publish=False) + assert isinstance(result, list) + assert len(result) == 1 + assert isinstance(result[0], asyncio.Future) + + +@pytest.mark.asyncio +async def test_claim_device_returns_future_when_not_waiting(): + client = DeviceClient() + claim = ClaimRequest.build("secret_key") + client._message_queue = AsyncMock() + client._message_queue.publish.return_value = [asyncio.Future()] + result = await client.claim_device(claim, wait_for_publish=False) + assert isinstance(result, asyncio.Future) + + +@pytest.mark.asyncio +async def test_handle_rpc_response_calls_handler(): + client = DeviceClient() + mock = AsyncMock() + client._rpc_response_handler.handle = mock + await client._handle_rpc_response("topic", b"payload") + mock.assert_awaited_once_with("topic", b"payload") + + +@pytest.mark.asyncio +async def test_handle_requested_attribute_response_calls_handler(): + client = DeviceClient() + mock = AsyncMock() + client._requested_attribute_response_handler.handle = mock + await client._handle_requested_attribute_response("topic", b"payload") + mock.assert_awaited_once_with("topic", b"payload") + + +@pytest.mark.asyncio +async def test_on_disconnect_clears_handlers(): + client = DeviceClient() + client._requested_attribute_response_handler.clear = MagicMock() + await client._on_disconnect() + client._requested_attribute_response_handler.clear.assert_called_once() + + +@pytest.mark.asyncio +async def test_set_rpc_request_callback_sets_handler(): + client = DeviceClient() + cb = AsyncMock() + client.set_rpc_request_callback(cb) + assert client._rpc_requests_handler._callback == cb + + +@pytest.mark.asyncio +async def test_provision_timeout(monkeypatch): + from tb_mqtt_client.service.device import client as client_module + mock_prov = AsyncMock() + mock_prov.provision.side_effect = asyncio.TimeoutError() + monkeypatch.setattr(client_module, "ProvisioningClient", lambda **kwargs: mock_prov) + + credentials = AccessTokenProvisioningCredentials("key", "secret") + req = ProvisioningRequest(host="localhost", credentials=credentials, port=1883, device_name="dev") + result = await DeviceClient.provision(req, timeout=0.01) + assert result is None + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/service/device/test_firmware_updater.py b/tests/service/device/test_firmware_updater.py new file mode 100644 index 0000000..c0eb751 --- /dev/null +++ b/tests/service/device/test_firmware_updater.py @@ -0,0 +1,261 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from hashlib import sha256, sha384, sha512, md5 +from unittest.mock import AsyncMock, MagicMock, patch, ANY +from zlib import crc32 + +import pytest + +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.constants.firmware import (FirmwareStates, FW_STATE_ATTR, FW_TITLE_ATTR, FW_VERSION_ATTR, + FW_SIZE_ATTR, FW_CHECKSUM_ALG_ATTR, REQUIRED_SHARED_KEYS, + FW_CHECKSUM_ATTR) +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.service.device.firmware_updater import FirmwareUpdater + + +@pytest.fixture +def mock_client(): + client = MagicMock() + client._mqtt_manager.register_handler = MagicMock() + client._mqtt_manager.subscribe = AsyncMock(return_value=asyncio.Future()) + client._mqtt_manager.subscribe.return_value.set_result(True) + client._mqtt_manager.unsubscribe = AsyncMock() + client._mqtt_manager.is_connected.return_value = True + client._message_queue.publish = AsyncMock() + client.send_timeseries = AsyncMock() + client.send_attribute_request = AsyncMock() + return client + + +@pytest.fixture +def updater(mock_client): + return FirmwareUpdater(mock_client) + + +@pytest.mark.asyncio +async def test_update_success(updater, mock_client): + with patch("tb_mqtt_client.entities.data.attribute_request.AttributeRequest.build", new_callable=AsyncMock), \ + patch.object(updater, "_firmware_info_callback", new=AsyncMock()): + await updater.update() + mock_client._mqtt_manager.subscribe.assert_called_once() + mock_client.send_timeseries.assert_called() + mock_client.send_attribute_request.assert_called_once() + + +@pytest.mark.asyncio +async def test_update_not_connected(updater, mock_client): + mock_client._mqtt_manager.is_connected.return_value = False + await updater.update() + + +@pytest.mark.asyncio +async def test_handle_firmware_update_full(updater): + updater._target_firmware_length = 4 + updater._chunk_size = 4 + updater._firmware_data = b'ab' + payload = b'cd' + with patch.object(updater, '_verify_downloaded_firmware', new=AsyncMock()) as verify: + await updater._handle_firmware_update(None, payload) + verify.assert_awaited_once() + assert updater._firmware_data == b'abcd' + + +@pytest.mark.asyncio +async def test_handle_firmware_update_partial(updater): + updater._target_firmware_length = 10 + updater._chunk_size = 5 + updater._firmware_data = b'123' + payload = b'456' + with patch.object(updater, '_get_next_chunk', new=AsyncMock()) as next_chunk: + await updater._handle_firmware_update(None, payload) + next_chunk.assert_awaited_once() + assert updater._firmware_data.endswith(payload) + + +@pytest.mark.asyncio +async def test_get_next_chunk_valid(updater, mock_client): + updater._chunk_size = 5 + updater._target_firmware_length = 10 + updater._firmware_request_id = 1 + updater._current_chunk = 2 + await updater._get_next_chunk() + mock_client._message_queue.publish.assert_awaited() + + +@pytest.mark.asyncio +async def test_get_next_chunk_empty_payload(updater, mock_client): + updater._chunk_size = 15 + updater._target_firmware_length = 10 + await updater._get_next_chunk() + mock_client._message_queue.publish.assert_awaited() + + +@pytest.mark.asyncio +async def test_verify_downloaded_firmware_success(updater): + updater._firmware_data = b'data' + updater._target_checksum = "8d777f385d3dfec8815d20f7496026dc" + updater._target_checksum_alg = "md5" + updater.current_firmware_info[FW_STATE_ATTR] = FirmwareStates.DOWNLOADING.value + with patch.object(updater, 'verify_checksum', return_value=True), \ + patch.object(updater, '_apply_downloaded_firmware', new=AsyncMock()), \ + patch.object(updater, '_send_current_firmware_info', new=AsyncMock()): + await updater._verify_downloaded_firmware() + assert updater.current_firmware_info[FW_STATE_ATTR] == FirmwareStates.VERIFIED.value + + +@pytest.mark.asyncio +async def test_verify_downloaded_firmware_fail(updater): + updater._firmware_data = b'data' + updater._target_checksum = "wrong" + updater._target_checksum_alg = "md5" + with patch.object(updater, 'verify_checksum', return_value=False), \ + patch.object(updater, '_send_current_firmware_info', new=AsyncMock()): + await updater._verify_downloaded_firmware() + assert updater.current_firmware_info[FW_STATE_ATTR] == FirmwareStates.FAILED.value + + +@pytest.mark.asyncio +async def test_apply_downloaded_firmware_saves_file(tmp_path, updater): + updater._firmware_data = b'binary-firmware' + updater._target_title = 'fw.bin' + updater._target_version = 'v3' + updater._save_path = str(tmp_path) + updater._save_firmware = True + updater._on_received_callback = AsyncMock() + with patch.object(updater, '_send_current_firmware_info', new=AsyncMock()), \ + patch.object(updater._client._mqtt_manager, 'unsubscribe', new=AsyncMock()): + await updater._apply_downloaded_firmware() + assert (tmp_path / 'fw.bin').exists() + + +def test_verify_checksum_md5_valid(updater): + result = updater.verify_checksum(b'data', 'md5', "8d777f385d3dfec8815d20f7496026dc") + assert isinstance(result, bool) + + +def test_verify_checksum_invalid_algorithm(updater): + result = updater.verify_checksum(b'data', 'invalid_alg', "deadbeef") + assert result is False + + +def test_is_different_versions_true(updater): + new_info = {FW_TITLE_ATTR: 'fw', FW_VERSION_ATTR: 'v2'} + assert updater._is_different_firmware_versions(new_info) is True + + +def test_is_different_versions_false(updater): + updater.current_firmware_info['current_' + FW_TITLE_ATTR] = 'fw' + updater.current_firmware_info['current_' + FW_VERSION_ATTR] = 'v2' + new_info = {FW_TITLE_ATTR: 'fw', FW_VERSION_ATTR: 'v2'} + assert updater._is_different_firmware_versions(new_info) is False + + +@pytest.mark.asyncio +async def test_save_firmware_failure_logs_error(updater, caplog): + updater._firmware_data = b'data' + updater._target_title = "fw.bin" + updater._target_version = "v1" + updater._save_firmware = True + with patch.object(updater, '_send_current_firmware_info', new=AsyncMock()), \ + patch.object(updater._client._mqtt_manager, 'unsubscribe', new=AsyncMock()), \ + patch.object(updater, '_save', side_effect=IOError("disk error")): + await updater._apply_downloaded_firmware() + assert "Failed to save firmware" in caplog.text + assert updater.current_firmware_info[FW_STATE_ATTR] == FirmwareStates.FAILED.value + + +@pytest.mark.asyncio +async def test_send_current_firmware_info_calls_send_timeseries(updater, mock_client): + updater.current_firmware_info = { + f"current_{FW_TITLE_ATTR}": "test", + f"current_{FW_VERSION_ATTR}": "1.0", + FW_STATE_ATTR: FirmwareStates.DOWNLOADING.value + } + await updater._send_current_firmware_info() + mock_client.send_timeseries.assert_awaited_once() + args = mock_client.send_timeseries.call_args[0][0] + assert all(isinstance(entry, TimeseriesEntry) for entry in args) + + +@pytest.mark.asyncio +async def test_firmware_info_callback_keys_mismatch(updater, caplog): + response = MagicMock() + response.shared_keys.return_value = ["unexpected_key"] + await updater._firmware_info_callback(response) + assert "does not match required keys" in caplog.text + assert updater.current_firmware_info[FW_STATE_ATTR] == FirmwareStates.FAILED.value + + +@pytest.mark.asyncio +async def test_firmware_info_callback_same_version(updater): + updater.current_firmware_info[f"current_{FW_TITLE_ATTR}"] = "fw" + updater.current_firmware_info[f"current_{FW_VERSION_ATTR}"] = "v1" + + response = MagicMock() + response.shared_keys.return_value = REQUIRED_SHARED_KEYS + response.as_dict.return_value = { + "shared": [{"key": FW_TITLE_ATTR, "value": "fw"}, + {"key": FW_VERSION_ATTR, "value": "v1"}, + {"key": FW_SIZE_ATTR, "value": 123}, + {"key": FW_CHECKSUM_ALG_ATTR, "value": "dummy"}, + {"key": FW_CHECKSUM_ATTR, "value": "dummy"}] + } + + await updater._firmware_info_callback(response) + + +@pytest.mark.asyncio +async def test_firmware_info_callback_triggers_download(updater): + response = MagicMock() + response.shared_keys.return_value = REQUIRED_SHARED_KEYS + response.as_dict.return_value = { + "shared": [{"key": FW_TITLE_ATTR, "value": "new_fw"}, + {"key": FW_VERSION_ATTR, "value": "v2"}, + {"key": FW_SIZE_ATTR, "value": 123}, + {"key": FW_CHECKSUM_ALG_ATTR, "value": "alg"}, + {"key": FW_CHECKSUM_ATTR, "value": "chk"}] + } + + with patch.object(updater, '_get_next_chunk', new=AsyncMock()) as mocked: + await updater._firmware_info_callback(response) + mocked.assert_awaited_once() + assert updater._target_title == "new_fw" + assert updater._target_version == "v2" + + +def test_verify_checksum_null_data(updater): + result = updater.verify_checksum(None, "md5", "abc") + assert not result + + +def test_verify_checksum_null_checksum(updater): + result = updater.verify_checksum(b"data", "md5", None) + assert not result + + +@pytest.mark.parametrize("alg", [ + ("sha256", sha256(b"data").digest().hex()), + ("sha384", sha384(b"data").digest().hex()), + ("sha512", sha512(b"data").digest().hex()), + ("md5", md5(b"data").digest().hex()), + ("crc32", "".join(reversed([f'{crc32(b"data") & 0xffffffff:0>2X}'[i:i + 2] + for i in range(0, len(f'{crc32(b"data") & 0xffffffff:0>2X}'), 2)])).lower()) +]) +def test_verify_checksum_known_algorithms(updater, alg): + name, checksum = alg + with patch("tb_mqtt_client.service.device.firmware_updater.randint", return_value=0): + assert updater.verify_checksum(b"data", name, checksum) is True diff --git a/tests/service/gateway/__init__.py b/tests/service/gateway/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tests/service/gateway/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/service/gateway/handlers/__init__.py b/tests/service/gateway/handlers/__init__.py new file mode 100644 index 0000000..cff354e --- /dev/null +++ b/tests/service/gateway/handlers/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/service/gateway/handlers/test_gateway_attribute_updates_handler.py b/tests/service/gateway/handlers/test_gateway_attribute_updates_handler.py new file mode 100644 index 0000000..f6fd6c1 --- /dev/null +++ b/tests/service/gateway/handlers/test_gateway_attribute_updates_handler.py @@ -0,0 +1,113 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from unittest.mock import MagicMock, AsyncMock + +import pytest + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.gateway.gateway_attribute_update import GatewayAttributeUpdate +from tb_mqtt_client.service.gateway.device_manager import DeviceManager +from tb_mqtt_client.service.gateway.device_session import DeviceSession +from tb_mqtt_client.service.gateway.direct_event_dispatcher import DirectEventDispatcher +from tb_mqtt_client.service.gateway.handlers.gateway_attribute_updates_handler import GatewayAttributeUpdatesHandler +from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter + + +@pytest.mark.asyncio +async def test_handle_with_existing_device(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayAttributeUpdatesHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Mock the device session + device_session = MagicMock(spec=DeviceSession) + device_manager.get_by_name.return_value = device_session + + # Mock the message adapter + payload = b'{"device": "test_device", "data": {"key": "value"}}' + deserialized_data = {"device": "test_device", "data": {"key": "value"}} + message_adapter.deserialize_to_dict.return_value = deserialized_data + + # Create a mock attribute update + attribute_update = AttributeUpdate([AttributeEntry(key="key", value="value")]) + gateway_attribute_update = GatewayAttributeUpdate(device_name="test_device", attribute_update=attribute_update) + message_adapter.parse_attribute_update.return_value = gateway_attribute_update + gateway_attribute_update.set_device_session = AsyncMock() + + # Act + await handler.handle("topic", payload) + + # Assert + message_adapter.deserialize_to_dict.assert_called_once_with(payload) + message_adapter.parse_attribute_update.assert_called_once_with(deserialized_data) + device_manager.get_by_name.assert_called_once_with("test_device") + gateway_attribute_update.set_device_session.assert_called_once_with(device_session) + event_dispatcher.dispatch.assert_awaited_once_with(attribute_update, device_session=device_session) + + +@pytest.mark.asyncio +async def test_handle_with_nonexistent_device(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayAttributeUpdatesHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Mock the device session (nonexistent) + device_manager.get_by_name.return_value = None + + # Mock the message adapter + payload = b'{"device": "nonexistent_device", "data": {"key": "value"}}' + deserialized_data = {"device": "nonexistent_device", "data": {"key": "value"}} + message_adapter.deserialize_to_dict.return_value = deserialized_data + + # Create a mock attribute update + attribute_update = AttributeUpdate([AttributeEntry("key", "value")]) + gateway_attribute_update = GatewayAttributeUpdate(device_name="nonexistent_device", + attribute_update=attribute_update) + message_adapter.parse_attribute_update.return_value = gateway_attribute_update + gateway_attribute_update.set_device_session = AsyncMock() + + # Act + await handler.handle("topic", payload) + + # Assert + message_adapter.deserialize_to_dict.assert_called_once_with(payload) + message_adapter.parse_attribute_update.assert_called_once_with(deserialized_data) + device_manager.get_by_name.assert_called_once_with("nonexistent_device") + # set_device_session should not be called since a device is None + assert not hasattr(gateway_attribute_update, + 'set_device_session') or not gateway_attribute_update.set_device_session.called + event_dispatcher.dispatch.assert_awaited_once_with(attribute_update, device_session=None) + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/service/gateway/handlers/test_gateway_requested_attributes_response_handler.py b/tests/service/gateway/handlers/test_gateway_requested_attributes_response_handler.py new file mode 100644 index 0000000..0badff7 --- /dev/null +++ b/tests/service/gateway/handlers/test_gateway_requested_attributes_response_handler.py @@ -0,0 +1,419 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +from unittest.mock import MagicMock, AsyncMock, patch + +import pytest + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.gateway.device_info import DeviceInfo +from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest +from tb_mqtt_client.entities.gateway.gateway_requested_attribute_response import GatewayRequestedAttributeResponse +from tb_mqtt_client.service.gateway.device_manager import DeviceManager +from tb_mqtt_client.service.gateway.device_session import DeviceSession +from tb_mqtt_client.service.gateway.direct_event_dispatcher import DirectEventDispatcher +from tb_mqtt_client.service.gateway.handlers.gateway_requested_attributes_response_handler import \ + GatewayRequestedAttributeResponseHandler +from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter + + +@pytest.mark.asyncio +async def test_register_request(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayRequestedAttributeResponseHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + + # Create a request + request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["key1", "key2"]) + + # Act + await handler.register_request(request) + + # Assert + assert (request.device_session.device_info.device_name, request.request_id) in handler._pending_attribute_requests + assert handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)][ + 0] == request + + +@pytest.mark.asyncio +async def test_register_request_with_timeout(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayRequestedAttributeResponseHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + + # Create a request + request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["key1", "key2"]) + + # Mock asyncio.get_event_loop().call_later + with patch('asyncio.get_event_loop') as mock_get_loop: + mock_loop = MagicMock() + mock_get_loop.return_value = mock_loop + mock_timeout_task = MagicMock() + mock_loop.call_later.return_value = mock_timeout_task + + # Act + await handler.register_request(request, timeout=10) + + # Assert + assert (request.device_session.device_info.device_name, + request.request_id) in handler._pending_attribute_requests + assert handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)][0] == request # noqa + assert handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)][1] == mock_timeout_task # noqa + mock_loop.call_later.assert_called_once_with(10, handler._on_timeout, + request.device_session.device_info.device_name, request.request_id) + + +@pytest.mark.asyncio +async def test_register_request_duplicate(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayRequestedAttributeResponseHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + + # Create a request + request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["key1", "key2"]) + + # Register the request once + await handler.register_request(request) + + # Act & Assert - should raise RuntimeError when registering the same request ID again + with pytest.raises(RuntimeError): + await handler.register_request(request) + + +@pytest.mark.asyncio +async def test_unregister_request(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayRequestedAttributeResponseHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + + # Create a request + request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["key1", "key2"]) + + # Register the request + handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)] = ( + request, None) + + # Act + handler.unregister_request(request.device_session.device_info.device_name, request.request_id) + + # Assert + assert (request.device_session.device_info.device_name, + request.request_id) not in handler._pending_attribute_requests + + +def test_unregister_nonexistent_request(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayRequestedAttributeResponseHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Act - should not raise an exception + handler.unregister_request("nonexistent_device", 999) + + # Assert + assert ("nonexistent_device", 999) not in handler._pending_attribute_requests + + +@pytest.mark.asyncio +async def test_handle_valid_response(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayRequestedAttributeResponseHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + device_manager.get_by_name.return_value = device_session + + # Create a request + request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["key1", "key2"]) + + # Register the request + timeout_task = MagicMock() + handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)] = ( + request, timeout_task) + + # Mock the message adapter + payload = b'{"device": "test_device", "id": ' + str( + request.request_id).encode() + b', "values": {"key1": "value1"}}' + deserialized_data = {"device": "test_device", "id": request.request_id, "values": {"key1": "value1"}} + message_adapter.deserialize_to_dict.return_value = deserialized_data + + # Create a mock response + response = GatewayRequestedAttributeResponse(device_name="test_device", request_id=request.request_id, + shared=[AttributeEntry("key1", "value1")]) + message_adapter.parse_gateway_requested_attribute_response.return_value = response + + # Mock asyncio.create_task + with patch('asyncio.create_task') as mock_create_task: + mock_task = MagicMock() + mock_create_task.return_value = mock_task + + # Act + await handler.handle("topic", payload) + + # Assert + message_adapter.deserialize_to_dict.assert_called_once_with(payload) + message_adapter.parse_gateway_requested_attribute_response.assert_called_once_with(request, deserialized_data) + device_manager.get_by_name.assert_called_once_with("test_device") + timeout_task.cancel.assert_called_once() + mock_create_task.assert_called_once() + mock_task.add_done_callback.assert_called_once_with(handler._handle_callback_exception) + + +@pytest.mark.asyncio +async def test_handle_missing_fields(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayRequestedAttributeResponseHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Mock the message adapter + payload = b'{"values": {"key1": "value1"}}' # Missing 'device' and 'id' + deserialized_data = {"values": {"key1": "value1"}} + message_adapter.deserialize_to_dict.return_value = deserialized_data + + # Act + await handler.handle("topic", payload) + + # Assert + message_adapter.deserialize_to_dict.assert_called_once_with(payload) + + +@pytest.mark.asyncio +async def test_handle_no_pending_request(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayRequestedAttributeResponseHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Mock the message adapter + payload = b'{"device": "test_device", "id": 999, "values": {"key1": "value1"}}' + deserialized_data = {"device": "test_device", "id": 999, "values": {"key1": "value1"}} + message_adapter.deserialize_to_dict.return_value = deserialized_data + + # Act + result = await handler.handle("topic", payload) + + # Assert + message_adapter.deserialize_to_dict.assert_called_once_with(payload) + assert result is None + + +@pytest.mark.asyncio +async def test_on_timeout(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayRequestedAttributeResponseHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + + # Create a request + request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["key1", "key2"]) + + # Register the request + handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)] = ( + request, None) + + handler._on_timeout(request.device_session.device_info.device_name, request.request_id) + + # Assert + assert (request.device_session.device_info.device_name, + request.request_id) not in handler._pending_attribute_requests + + +def test_handle_callback_exception(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayRequestedAttributeResponseHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Create a mock task that raises an exception + task = MagicMock(spec=asyncio.Task) + task.result.side_effect = Exception("Test exception") + + # Act + handler._handle_callback_exception(task) + + # Assert + task.result.assert_called_once() + + +@pytest.mark.asyncio +async def test_clear(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + + # Create a handler + handler = GatewayRequestedAttributeResponseHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + + # Create a request + request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["key1", "key2"]) + + # Register the request + handler._pending_attribute_requests[(request.device_session.device_info.device_name, request.request_id)] = ( + request, None) + + # Act + handler.clear() + + # Assert + assert len(handler._pending_attribute_requests) == 0 + + +@pytest.mark.asyncio +async def test_handle_no_message_adapter_removes_request(): + adapter = MagicMock(spec=GatewayMessageAdapter) + handler = GatewayRequestedAttributeResponseHandler(event_dispatcher=MagicMock(spec=DirectEventDispatcher), + message_adapter=adapter, + device_manager=MagicMock(spec=DeviceManager)) + handler._pending_attribute_requests["test_device", 5] = (MagicMock(spec=AttributeRequest), AsyncMock()) + topic = "attr/request/5" + await handler.handle(topic, b"{}") + assert 5 not in handler._pending_attribute_requests + + +@pytest.mark.asyncio +async def test_handle_with_no_callback_registered(): + adapter = MagicMock(spec=GatewayMessageAdapter) + handler = GatewayRequestedAttributeResponseHandler(event_dispatcher=MagicMock(spec=DirectEventDispatcher), + message_adapter=adapter, + device_manager=MagicMock(spec=DeviceManager)) + resp = MagicMock(spec=RequestedAttributeResponse) + resp.request_id = 42 + adapter.parse_gateway_requested_attribute_response.return_value = resp + handler._pending_attribute_requests['test_device', 42] = (MagicMock(spec=AttributeRequest), None) + await handler.handle("topic", b"payload") + + +@pytest.mark.asyncio +async def test_handle_with_parsing_exception(): + adapter = MagicMock(spec=GatewayMessageAdapter) + handler = GatewayRequestedAttributeResponseHandler(event_dispatcher=MagicMock(spec=DirectEventDispatcher), + message_adapter=adapter, + device_manager=MagicMock(spec=DeviceManager)) + adapter.parse_gateway_requested_attribute_response.side_effect = RuntimeError("bad parse") + await handler.handle("topic", b"payload") + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/service/gateway/handlers/test_gateway_rpc_handler.py b/tests/service/gateway/handlers/test_gateway_rpc_handler.py new file mode 100644 index 0000000..fa85e90 --- /dev/null +++ b/tests/service/gateway/handlers/test_gateway_rpc_handler.py @@ -0,0 +1,404 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +from unittest.mock import MagicMock, AsyncMock, patch + +import pytest + +from tb_mqtt_client.entities.gateway.device_info import DeviceInfo +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse +from tb_mqtt_client.service.gateway.device_manager import DeviceManager +from tb_mqtt_client.service.gateway.device_session import DeviceSession +from tb_mqtt_client.service.gateway.direct_event_dispatcher import DirectEventDispatcher +from tb_mqtt_client.service.gateway.handlers.gateway_rpc_handler import GatewayRPCHandler +from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter + + +def test_init(): + # Setup + event_dispatcher = MagicMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + stop_event = asyncio.Event() + + # Act + handler = GatewayRPCHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager, + stop_event=stop_event + ) + + # Assert + assert handler._event_dispatcher == event_dispatcher + assert handler._message_adapter == message_adapter + assert handler._device_manager == device_manager + assert handler._stop_event == stop_event + assert handler._callback is None + event_dispatcher.register.assert_called_once_with(GatewayEventType.DEVICE_RPC_REQUEST, handler.handle) + + +@pytest.mark.asyncio +async def test_handle_successful_rpc(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + stop_event = asyncio.Event() + + # Create a handler + handler = GatewayRPCHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager, + stop_event=stop_event + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + device_manager.get_by_name.return_value = device_session + + # Mock the message adapter + topic = "v1/gateway/rpc" + payload = b'{"device": "test_device", "id": 1, "method": "test_method", "params": {"param1": "value1"}}' + deserialized_data = {"device": "test_device", + "data": {"id": 1, "method": "test_method", "params": {"param1": "value1"}}} + message_adapter.deserialize_to_dict.return_value = deserialized_data + + # Create a mock RPC request + rpc_request = GatewayRPCRequest._deserialize_from_dict(deserialized_data) + message_adapter.parse_rpc_request.return_value = rpc_request + + # Create a mock RPC response + rpc_response = GatewayRPCResponse.build(device_name="test_device", request_id=1, result="success") + + # Mock the event dispatcher to return the RPC response + event_dispatcher.dispatch.side_effect = [rpc_response, asyncio.Future()] + + # Mock await_or_stop + with patch('tb_mqtt_client.service.gateway.handlers.gateway_rpc_handler.await_or_stop') as mock_await_or_stop: + # Act + await handler.handle(topic, payload) + + # Assert + message_adapter.deserialize_to_dict.assert_called_once_with(payload) + message_adapter.parse_rpc_request.assert_called_once_with(topic, deserialized_data) + device_manager.get_by_name.assert_called_once_with("test_device") + + # Check that dispatch was called twice - once for the request and once for the response + assert event_dispatcher.dispatch.call_count == 2 + event_dispatcher.dispatch.assert_any_call(rpc_request, device_session=device_session) + event_dispatcher.dispatch.assert_any_call(rpc_response) + + mock_await_or_stop.assert_called_once() + + +@pytest.mark.asyncio +async def test_handle_no_message_adapter(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + device_manager = MagicMock(spec=DeviceManager) + stop_event = asyncio.Event() + + # Create a handler with no message adapter + handler = GatewayRPCHandler( + event_dispatcher=event_dispatcher, + message_adapter=None, + device_manager=device_manager, + stop_event=stop_event + ) + + # Act + result = await handler.handle("topic", b'{}') + + # Assert + assert result is None + + +@pytest.mark.asyncio +async def test_handle_no_device_session(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + stop_event = asyncio.Event() + + # Create a handler + handler = GatewayRPCHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager, + stop_event=stop_event + ) + + # Mock the device manager to return None (no device session) + device_manager.get_by_name.return_value = None + + # Mock the message adapter + topic = "v1/gateway/rpc" + payload = b'{"device": "nonexistent_device", "id": 1, "method": "test_method", "params": {"param1": "value1"}}' + deserialized_data = {"device": "nonexistent_device", + "data": {"id": 1, "method": "test_method", "params": {"param1": "value1"}}} + message_adapter.deserialize_to_dict.return_value = deserialized_data + + # Create a mock RPC request + rpc_request = GatewayRPCRequest._deserialize_from_dict(deserialized_data) + message_adapter.parse_rpc_request.return_value = rpc_request + + # Act + result = await handler.handle(topic, payload) + + # Assert + assert result is None + message_adapter.deserialize_to_dict.assert_called_once_with(payload) + message_adapter.parse_rpc_request.assert_called_once_with(topic, deserialized_data) + device_manager.get_by_name.assert_called_once_with("nonexistent_device") + + +@pytest.mark.asyncio +async def test_handle_no_response_from_callback(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + stop_event = asyncio.Event() + + # Create a handler + handler = GatewayRPCHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager, + stop_event=stop_event + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + device_manager.get_by_name.return_value = device_session + + # Mock the message adapter + topic = "v1/gateway/rpc" + payload = b'{"device": "test_device", "id": 1, "method": "test_method", "params": {"param1": "value1"}}' + deserialized_data = {"device": "test_device", + "data": {"id": 1, "method": "test_method", "params": {"param1": "value1"}}} + message_adapter.deserialize_to_dict.return_value = deserialized_data + + # Create a mock RPC request + rpc_request = GatewayRPCRequest._deserialize_from_dict(deserialized_data) + message_adapter.parse_rpc_request.return_value = rpc_request + + # Mock the event dispatcher to return None (no response from callback) + event_dispatcher.dispatch.return_value = None + + # Act + result = await handler.handle(topic, payload) + + # Assert + assert result is None + message_adapter.deserialize_to_dict.assert_called_once_with(payload) + message_adapter.parse_rpc_request.assert_called_once_with(topic, deserialized_data) + device_manager.get_by_name.assert_called_once_with("test_device") + event_dispatcher.dispatch.assert_called_once_with(rpc_request, device_session=device_session) + + +@pytest.mark.asyncio +async def test_handle_invalid_response_type(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + stop_event = asyncio.Event() + + # Create a handler + handler = GatewayRPCHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager, + stop_event=stop_event + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + device_manager.get_by_name.return_value = device_session + + # Mock the message adapter + topic = "v1/gateway/rpc" + payload = b'{"device": "test_device", "id": 1, "method": "test_method", "params": {"param1": "value1"}}' + deserialized_data = {"device": "test_device", + "data": {"id": 1, "method": "test_method", "params": {"param1": "value1"}}} + message_adapter.deserialize_to_dict.return_value = deserialized_data + + # Create a mock RPC request + rpc_request = GatewayRPCRequest._deserialize_from_dict(deserialized_data) + message_adapter.parse_rpc_request.return_value = rpc_request + + # Mock the event dispatcher to return an invalid response type + event_dispatcher.dispatch.side_effect = ["invalid_response", asyncio.Future()] + + # Act + await handler.handle(topic, payload) + + # Assert + message_adapter.deserialize_to_dict.assert_called_once_with(payload) + message_adapter.parse_rpc_request.assert_called_once_with(topic, deserialized_data) + device_manager.get_by_name.assert_called_once_with("test_device") + event_dispatcher.dispatch.assert_called() + + +@pytest.mark.asyncio +async def test_handle_exception_in_processing(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + stop_event = asyncio.Event() + + # Create a handler + handler = GatewayRPCHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager, + stop_event=stop_event + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + device_manager.get_by_name.return_value = device_session + + # Mock the message adapter to raise an exception + topic = "v1/gateway/rpc" + payload = b'{"device": "test_device", "id": 1, "method": "test_method", "params": {"param1": "value1"}}' + message_adapter.deserialize_to_dict.side_effect = Exception("Test exception") + + # Act + result = await handler.handle(topic, payload) + + # Assert + assert result is None + message_adapter.deserialize_to_dict.assert_called_once_with(payload) + + +@pytest.mark.asyncio +async def test_handle_timeout_in_publish(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + stop_event = asyncio.Event() + + # Create a handler + handler = GatewayRPCHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager, + stop_event=stop_event + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + device_manager.get_by_name.return_value = device_session + + # Mock the message adapter + topic = "v1/gateway/rpc" + payload = b'{"device": "test_device", "id": 1, "method": "test_method", "params": {"param1": "value1"}}' + deserialized_data = {"device": "test_device", + "data": {"id": 1, "method": "test_method", "params": {"param1": "value1"}}} + message_adapter.deserialize_to_dict.return_value = deserialized_data + + # Create a mock RPC request + rpc_request = GatewayRPCRequest._deserialize_from_dict(deserialized_data) + message_adapter.parse_rpc_request.return_value = rpc_request + + # Create a mock RPC response + rpc_response = GatewayRPCResponse.build(device_name="test_device", request_id=1, result="success") + + # Mock the event dispatcher + event_dispatcher.dispatch.side_effect = [rpc_response, asyncio.Future()] + + # Mock await_or_stop to raise TimeoutError + with patch('tb_mqtt_client.service.gateway.handlers.gateway_rpc_handler.await_or_stop') as mock_await_or_stop: + mock_await_or_stop.side_effect = TimeoutError("Timeout") + + # Act + await handler.handle(topic, payload) + + # Assert + message_adapter.deserialize_to_dict.assert_called_once_with(payload) + message_adapter.parse_rpc_request.assert_called_once_with(topic, deserialized_data) + device_manager.get_by_name.assert_called_once_with("test_device") + event_dispatcher.dispatch.assert_any_call(rpc_request, device_session=device_session) + event_dispatcher.dispatch.assert_any_call(rpc_response) + mock_await_or_stop.assert_called_once() + + +@pytest.mark.asyncio +async def test_handle_no_publish_futures(): + # Setup + event_dispatcher = AsyncMock(spec=DirectEventDispatcher) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + device_manager = MagicMock(spec=DeviceManager) + stop_event = asyncio.Event() + + # Create a handler + handler = GatewayRPCHandler( + event_dispatcher=event_dispatcher, + message_adapter=message_adapter, + device_manager=device_manager, + stop_event=stop_event + ) + + # Create a device session + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + device_manager.get_by_name.return_value = device_session + + # Mock the message adapter + topic = "v1/gateway/rpc" + payload = b'{"device": "test_device", "id": 1, "method": "test_method", "params": {"param1": "value1"}}' + deserialized_data = {"device": "test_device", + "data": {"id": 1, "method": "test_method", "params": {"param1": "value1"}}} + message_adapter.deserialize_to_dict.return_value = deserialized_data + + # Create a mock RPC request + rpc_request = GatewayRPCRequest._deserialize_from_dict(deserialized_data) + message_adapter.parse_rpc_request.return_value = rpc_request + + # Create a mock RPC response + rpc_response = GatewayRPCResponse.build(device_name="test_device", request_id=1, result="success") + + # Mock the event dispatcher to return the RPC response but no publish futures + event_dispatcher.dispatch.side_effect = [rpc_response, None] + + # Act + result = await handler.handle(topic, payload) + + # Assert + assert result is None + message_adapter.deserialize_to_dict.assert_called_once_with(payload) + message_adapter.parse_rpc_request.assert_called_once_with(topic, deserialized_data) + device_manager.get_by_name.assert_called_once_with("test_device") + event_dispatcher.dispatch.assert_any_call(rpc_request, device_session=device_session) + event_dispatcher.dispatch.assert_any_call(rpc_response) + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/service/gateway/test_device_manager.py b/tests/service/gateway/test_device_manager.py new file mode 100644 index 0000000..10650e6 --- /dev/null +++ b/tests/service/gateway/test_device_manager.py @@ -0,0 +1,391 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from unittest.mock import MagicMock + +import pytest + +from tb_mqtt_client.entities.gateway.device_info import DeviceInfo +from tb_mqtt_client.service.gateway.device_manager import DeviceManager +from tb_mqtt_client.service.gateway.device_session import DeviceSession + + +def test_register_new_device(): + # Setup + manager = DeviceManager() + + # Act + session = manager.register("test_device", "default") + + # Assert + assert session is not None + assert session.device_info.device_name == "test_device" + assert session.device_info.device_profile == "default" + assert session.device_info.device_id in manager._sessions_by_id + assert "test_device" in manager._ids_by_device_name + assert session.device_info.original_name in manager._ids_by_original_name + + +def test_register_existing_device(): + # Setup + manager = DeviceManager() + first_session = manager.register("test_device", "default") + + # Act + second_session = manager.register("test_device", "custom") + + # Assert + assert first_session is second_session + assert second_session.device_info.device_profile == "default" # Profile should not change + + +def test_unregister_device(): + # Setup + manager = DeviceManager() + session = manager.register("test_device", "default") + device_id = session.device_info.device_id + + # Act + manager.unregister(device_id) + + # Assert + assert device_id not in manager._sessions_by_id + assert "test_device" not in manager._ids_by_device_name + assert session.device_info.original_name not in manager._ids_by_original_name + + +def test_unregister_nonexistent_device(): + # Setup + manager = DeviceManager() + from uuid import uuid4 + nonexistent_id = uuid4() + + # Act & Assert - should not raise an exception + manager.unregister(nonexistent_id) + + +def test_get_by_id(): + # Setup + manager = DeviceManager() + session = manager.register("test_device", "default") + device_id = session.device_info.device_id + + # Act + retrieved_session = manager.get_by_id(device_id) + + # Assert + assert retrieved_session is session + + +def test_get_by_id_nonexistent(): + # Setup + manager = DeviceManager() + from uuid import uuid4 + nonexistent_id = uuid4() + + # Act + retrieved_session = manager.get_by_id(nonexistent_id) + + # Assert + assert retrieved_session is None + + +def test_get_by_name(): + # Setup + manager = DeviceManager() + session = manager.register("test_device", "default") + + # Act + retrieved_session = manager.get_by_name("test_device") + + # Assert + assert retrieved_session is session + + +def test_get_by_original_name(): + # Setup + manager = DeviceManager() + session = manager.register("test_device", "default") + original_name = session.device_info.original_name + + # Rename the device + manager.rename_device("test_device", "new_name") + + # Act + retrieved_session = manager.get_by_name(original_name) + + # Assert + assert retrieved_session is session + + +def test_get_by_name_nonexistent(): + # Setup + manager = DeviceManager() + + # Act + retrieved_session = manager.get_by_name("nonexistent_device") + + # Assert + assert retrieved_session is None + + +def test_is_connected(): + # Setup + manager = DeviceManager() + session = manager.register("test_device", "default") + device_id = session.device_info.device_id + + # Act + is_connected = manager.is_connected(device_id) + + # Assert + assert is_connected is True + + +def test_is_connected_nonexistent(): + # Setup + manager = DeviceManager() + from uuid import uuid4 + nonexistent_id = uuid4() + + # Act + is_connected = manager.is_connected(nonexistent_id) + + # Assert + assert is_connected is False + + +def test_all(): + # Setup + manager = DeviceManager() + session1 = manager.register("device1", "default") + session2 = manager.register("device2", "default") + + # Act + all_sessions = list(manager.all()) + + # Assert + assert len(all_sessions) == 2 + assert session1 in all_sessions + assert session2 in all_sessions + + +def test_rename_device(): + # Setup + manager = DeviceManager() + session = manager.register("old_name", "default") + device_id = session.device_info.device_id + + # Act + manager.rename_device("old_name", "new_name") + + # Assert + assert "old_name" not in manager._ids_by_device_name + assert "new_name" in manager._ids_by_device_name + assert manager._ids_by_device_name["new_name"] == device_id + assert session.device_info.device_name == "new_name" + assert session.device_info.original_name in manager._ids_by_original_name + + +def test_rename_nonexistent_device(): + # Setup + manager = DeviceManager() + + # Act - should not raise an exception + manager.rename_device("nonexistent_device", "new_name") + + # Assert + assert "nonexistent_device" not in manager._ids_by_device_name + assert "new_name" not in manager._ids_by_device_name + + +def test_set_attribute_update_callback(): + # Setup + manager = DeviceManager() + session = manager.register("test_device", "default") + device_id = session.device_info.device_id + callback = MagicMock() + + # Mock the session's set_attribute_update_callback method + session.set_attribute_update_callback = MagicMock() + + # Act + manager.set_attribute_update_callback(device_id, callback) + + # Assert + session.set_attribute_update_callback.assert_called_once_with(callback) + + +def test_set_attribute_update_callback_nonexistent(): + # Setup + manager = DeviceManager() + from uuid import uuid4 + nonexistent_id = uuid4() + callback = MagicMock() + + # Act - should not raise an exception + manager.set_attribute_update_callback(nonexistent_id, callback) + + +def test_set_attribute_response_callback(): + # Setup + manager = DeviceManager() + session = manager.register("test_device", "default") + device_id = session.device_info.device_id + callback = MagicMock() + + # Mock the session's set_attribute_response_callback method + session.set_attribute_response_callback = MagicMock() + + # Act + manager.set_attribute_response_callback(device_id, callback) + + # Assert + session.set_attribute_response_callback.assert_called_once_with(callback) + + +def test_set_attribute_response_callback_nonexistent(): + # Setup + manager = DeviceManager() + from uuid import uuid4 + nonexistent_id = uuid4() + callback = MagicMock() + + # Act - should not raise an exception + manager.set_attribute_response_callback(nonexistent_id, callback) + + +def test_set_rpc_request_callback(): + # Setup + manager = DeviceManager() + session = manager.register("test_device", "default") + device_id = session.device_info.device_id + callback = MagicMock() + + # Mock the session's set_rpc_request_callback method + session.set_rpc_request_callback = MagicMock() + + # Act + manager.set_rpc_request_callback(device_id, callback) + + # Assert + session.set_rpc_request_callback.assert_called_once_with(callback) + + +def test_set_rpc_request_callback_nonexistent(): + # Setup + manager = DeviceManager() + from uuid import uuid4 + nonexistent_id = uuid4() + callback = MagicMock() + + # Act - should not raise an exception + manager.set_rpc_request_callback(nonexistent_id, callback) + + +def test_state_change_callback(): + # Setup + manager = DeviceManager() + + # Create a device session with a mocked state + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info, manager._DeviceManager__state_change_callback) + + # Mock the session's state + session.state = MagicMock() + session.state.is_connected = MagicMock(return_value=True) + + # Act + manager._DeviceManager__state_change_callback(session) + + # Assert + assert session in manager.connected_devices + + +def test_state_change_callback_disconnect(): + # Setup + manager = DeviceManager() + + # Create a device session with a mocked state + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info, manager._DeviceManager__state_change_callback) + + # Add the session to connected devices + manager._DeviceManager__connected_devices.add(session) + + # Mock the session's state + session.state = MagicMock() + session.state.is_connected = MagicMock(return_value=False) + + # Act + manager._DeviceManager__state_change_callback(session) + + # Assert + assert session not in manager.connected_devices + + +def test_connected_devices_property(): + # Setup + manager = DeviceManager() + session1 = manager.register("device1", "default") + session2 = manager.register("device2", "default") + + # Add sessions to connected devices + manager._DeviceManager__connected_devices.add(session1) + manager._DeviceManager__connected_devices.add(session2) + + # Act + connected = manager.connected_devices + + # Assert + assert len(connected) == 2 + assert session1 in connected + assert session2 in connected + + +def test_all_devices_property(): + # Setup + manager = DeviceManager() + session1 = manager.register("device1", "default") + session2 = manager.register("device2", "default") + + # Act + all_devices = manager.all_devices + + # Assert + assert len(all_devices) == 2 + assert session1.device_info.device_id in all_devices + assert session2.device_info.device_id in all_devices + assert all_devices[session1.device_info.device_id] is session1 + assert all_devices[session2.device_info.device_id] is session2 + + +def test_repr(): + # Setup + manager = DeviceManager() + manager.register("device1", "default") + manager.register("device2", "default") + + # Act + repr_string = repr(manager) + + # Assert + assert "DeviceManager" in repr_string + assert "device1" in repr_string + assert "device2" in repr_string + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/service/gateway/test_device_session.py b/tests/service/gateway/test_device_session.py new file mode 100644 index 0000000..88d4cd9 --- /dev/null +++ b/tests/service/gateway/test_device_session.py @@ -0,0 +1,313 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from unittest.mock import MagicMock + +import pytest + +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.gateway.device_info import DeviceInfo +from tb_mqtt_client.entities.gateway.device_session_state import DeviceSessionState +from tb_mqtt_client.entities.gateway.gateway_attribute_update import GatewayAttributeUpdate +from tb_mqtt_client.entities.gateway.gateway_requested_attribute_response import GatewayRequestedAttributeResponse +from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse +from tb_mqtt_client.service.gateway.device_session import DeviceSession + + +def test_init(): + # Setup + device_info = DeviceInfo("test_device", "default") + state_change_callback = MagicMock() + + # Act + session = DeviceSession(device_info, state_change_callback) + + # Assert + assert session.device_info == device_info + assert session._state_change_callback == state_change_callback + assert session.state.is_connected() + assert session.connected_at > 0 + assert session.last_seen_at > 0 + assert session.claimed is False + assert session.provisioned is False + assert session.attribute_update_callback is None + assert session.attribute_response_callback is None + assert session.rpc_request_callback is None + + +def test_update_state(): + # Setup + device_info = DeviceInfo("test_device", "default") + state_change_callback = MagicMock() + session = DeviceSession(device_info, state_change_callback) + + # Act + session.update_state(DeviceSessionState.DISCONNECTED) + + # Assert + assert session.state == DeviceSessionState.DISCONNECTED + state_change_callback.assert_called_once_with(session) + + +def test_update_state_no_callback(): + # Setup + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info) # No callback provided + + # Act - should not raise an exception + session.update_state(DeviceSessionState.DISCONNECTED) + + # Assert + assert session.state == DeviceSessionState.DISCONNECTED + + +def test_update_last_seen(): + # Setup + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info) + old_last_seen = session.last_seen_at + + # Wait a bit to ensure the timestamp changes + import time + time.sleep(0.001) + + # Act + session.update_last_seen() + + # Assert + assert session.last_seen_at > old_last_seen + + +def test_set_attribute_update_callback(): + # Setup + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info) + callback = MagicMock() + + # Act + session.set_attribute_update_callback(callback) + + # Assert + assert session.attribute_update_callback == callback + + +def test_set_attribute_response_callback(): + # Setup + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info) + callback = MagicMock() + + # Act + session.set_attribute_response_callback(callback) + + # Assert + assert session.attribute_response_callback == callback + + +def test_set_rpc_request_callback(): + # Setup + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info) + callback = MagicMock() + + # Act + session.set_rpc_request_callback(callback) + + # Assert + assert session.rpc_request_callback == callback + + +@pytest.mark.asyncio +async def test_handle_event_to_device_attribute_update(): + # Setup + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info) + callback = MagicMock() + session.set_attribute_update_callback(callback) + + # Create an attribute update event + attribute_entry = AttributeEntry(key="key", value="value") + attribute_update = AttributeUpdate([attribute_entry]) + + gateway_attribute_update = GatewayAttributeUpdate(session.device_info.device_name, + attribute_update=attribute_update) + + # Act + result = await session.handle_event_to_device(gateway_attribute_update.attribute_update) + + # Assert + callback.assert_called_once_with(session, attribute_update) + assert result == callback.return_value + + +@pytest.mark.asyncio +async def test_handle_event_to_device_attribute_response(): + # Setup + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info) + callback = MagicMock() + session.set_attribute_response_callback(callback) + + # Create an attribute response event + gateway_requested_attribute_response = GatewayRequestedAttributeResponse(request_id=1, device_name="test_device", + shared=[ + AttributeEntry(key="shared_key", + value="shared_value")], + client=[]) + + # Act + result = await session.handle_event_to_device(gateway_requested_attribute_response) + + # Assert + callback.assert_called_once_with(session, gateway_requested_attribute_response) + assert result == callback.return_value + + +@pytest.mark.asyncio +async def test_handle_event_to_device_rpc_request(): + # Setup + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info) + callback = MagicMock() + session.set_rpc_request_callback(callback) + + # Create an RPC request event + request_dict = { + "device": "test_device", + "data": { + "id": 1, + "method": "test", + "params": {"param": "value"} + } + } + rpc_request = GatewayRPCRequest._deserialize_from_dict(request_dict) + + # Act + result = await session.handle_event_to_device(rpc_request) + + # Assert + callback.assert_called_once_with(session, rpc_request) + assert result == callback.return_value + + +@pytest.mark.asyncio +async def test_handle_event_to_device_no_callback(): + # Setup + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info) + + # Create an attribute update event + attribute_update = AttributeUpdate([AttributeEntry(key="shared_key", value="shared_value")]) + gateway_attribute_update = GatewayAttributeUpdate(session.device_info.device_name, + attribute_update=attribute_update) + + # Act + result = await session.handle_event_to_device(gateway_attribute_update.attribute_update) + + # Assert + assert result is None + + +@pytest.mark.asyncio +async def test_handle_event_to_device_async_callback(): + # Setup + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info) + + rpc_request = { + "device": "test_device", + "data": { + "id": 1, + "method": "test", + "params": {"param": "value"} + } + } + + # Create an async callback + async def async_callback(session, event: GatewayRPCRequest): + return GatewayRPCResponse.build(device_name=event.device_name, request_id=event.request_id, result="success") + + session.set_rpc_request_callback(async_callback) + + # Create an RPC request event + rpc_request = GatewayRPCRequest._deserialize_from_dict(rpc_request) + + # Act + result = await session.handle_event_to_device(rpc_request) + + # Assert + assert isinstance(result, GatewayRPCResponse) + assert result.device_name == "test_device" + assert result.request_id == 1 + assert result.result == "success" + + +@pytest.mark.asyncio +async def test_handle_event_to_device_unsupported_event(): + # Setup + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info) + + # Create a mock event with an unsupported event type + mock_event = MagicMock() + mock_event.event_type = "UNSUPPORTED_EVENT_TYPE" + + # Act + result = await session.handle_event_to_device(mock_event) + + # Assert + assert result is None + + +def test_equality(): + # Setup + device_info1 = DeviceInfo("test_device", "default") + device_info2 = DeviceInfo("test_device", "default") # Same device name, but different UUID + device_info3 = DeviceInfo("other_device", "default") + + session1 = DeviceSession(device_info1) + session2 = DeviceSession(device_info2) + session3 = DeviceSession(device_info3) + + # Act & Assert + assert session1 != session2 # Different UUIDs + assert session1 != session3 + assert session2 != session3 + + # Test equality with the same device_info + session4 = DeviceSession(device_info1) + assert session1 == session4 # Same device_info + + +def test_hash(): + # Setup + device_info = DeviceInfo("test_device", "default") + session = DeviceSession(device_info) + + # Act + hash_value = hash(session) + + # Assert + assert hash_value == hash(device_info) + + # Test that sessions with the same device_info have the same hash + session2 = DeviceSession(device_info) + assert hash(session) == hash(session2) + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/service/gateway/test_direct_event_dispatcher.py b/tests/service/gateway/test_direct_event_dispatcher.py new file mode 100644 index 0000000..5d1d2ea --- /dev/null +++ b/tests/service/gateway/test_direct_event_dispatcher.py @@ -0,0 +1,276 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +from unittest.mock import MagicMock, AsyncMock + +import pytest + +from tb_mqtt_client.entities.gateway.event_type import GatewayEventType +from tb_mqtt_client.entities.gateway.gateway_event import GatewayEvent +from tb_mqtt_client.service.gateway.device_session import DeviceSession +from tb_mqtt_client.service.gateway.direct_event_dispatcher import DirectEventDispatcher + + +def test_init(): + # Setup & Act + dispatcher = DirectEventDispatcher() + + # Assert + assert isinstance(dispatcher._handlers, dict) + assert len(dispatcher._handlers) == 0 + assert isinstance(dispatcher._lock, asyncio.Lock) + + +def test_register(): + # Setup + dispatcher = DirectEventDispatcher() + callback = MagicMock() + event_type = GatewayEventType.DEVICE_CONNECT + + # Act + dispatcher.register(event_type, callback) + + # Assert + assert event_type in dispatcher._handlers + assert callback in dispatcher._handlers[event_type] + assert len(dispatcher._handlers[event_type]) == 1 + + +def test_register_duplicate(): + # Setup + dispatcher = DirectEventDispatcher() + callback = MagicMock() + event_type = GatewayEventType.DEVICE_CONNECT + + # Register the callback twice + dispatcher.register(event_type, callback) + dispatcher.register(event_type, callback) + + # Assert + assert event_type in dispatcher._handlers + assert callback in dispatcher._handlers[event_type] + assert len(dispatcher._handlers[event_type]) == 1 # Should only be added once + + +def test_register_multiple(): + # Setup + dispatcher = DirectEventDispatcher() + callback1 = MagicMock() + callback2 = MagicMock() + event_type = GatewayEventType.DEVICE_CONNECT + + # Act + dispatcher.register(event_type, callback1) + dispatcher.register(event_type, callback2) + + # Assert + assert event_type in dispatcher._handlers + assert callback1 in dispatcher._handlers[event_type] + assert callback2 in dispatcher._handlers[event_type] + assert len(dispatcher._handlers[event_type]) == 2 + + +def test_unregister(): + # Setup + dispatcher = DirectEventDispatcher() + callback = MagicMock() + event_type = GatewayEventType.DEVICE_CONNECT + + # Register and then unregister + dispatcher.register(event_type, callback) + dispatcher.unregister(event_type, callback) + + # Assert + assert event_type not in dispatcher._handlers # Event type should be removed when last callback is unregistered + + +def test_unregister_nonexistent(): + # Setup + dispatcher = DirectEventDispatcher() + callback = MagicMock() + event_type = GatewayEventType.DEVICE_CONNECT + + # Act - should not raise an exception + dispatcher.unregister(event_type, callback) + + # Assert + assert event_type not in dispatcher._handlers + + +def test_unregister_one_of_many(): + # Setup + dispatcher = DirectEventDispatcher() + callback1 = MagicMock() + callback2 = MagicMock() + event_type = GatewayEventType.DEVICE_CONNECT + + # Register both callbacks and unregister one + dispatcher.register(event_type, callback1) + dispatcher.register(event_type, callback2) + dispatcher.unregister(event_type, callback1) + + # Assert + assert event_type in dispatcher._handlers + assert callback1 not in dispatcher._handlers[event_type] + assert callback2 in dispatcher._handlers[event_type] + assert len(dispatcher._handlers[event_type]) == 1 + + +@pytest.mark.asyncio +async def test_dispatch_to_device_session(): + # Setup + dispatcher = DirectEventDispatcher() + device_session = MagicMock(spec=DeviceSession) + device_session.handle_event_to_device = AsyncMock() + + # Create a mock event + event = MagicMock(spec=GatewayEvent) + event.event_type = GatewayEventType.DEVICE_ATTRIBUTE_UPDATE + + # Act + await dispatcher.dispatch(event, device_session=device_session) + + # Assert + device_session.handle_event_to_device.assert_awaited_once_with(event) + + +@pytest.mark.asyncio +async def test_dispatch_to_sync_callback(): + # Setup + dispatcher = DirectEventDispatcher() + callback = MagicMock() + event_type = GatewayEventType.DEVICE_CONNECT + + # Register the callback + dispatcher.register(event_type, callback) + + # Create a mock event + event = MagicMock(spec=GatewayEvent) + event.event_type = event_type + + # Act + result = await dispatcher.dispatch(event) + + # Assert + callback.assert_called_once_with(event) + assert result == callback.return_value + + +@pytest.mark.asyncio +async def test_dispatch_to_async_callback(): + # Setup + dispatcher = DirectEventDispatcher() + callback = AsyncMock() + event_type = GatewayEventType.DEVICE_CONNECT + + # Register the callback + dispatcher.register(event_type, callback) + + # Create a mock event + event = MagicMock(spec=GatewayEvent) + event.event_type = event_type + + # Act + result = await dispatcher.dispatch(event) + + # Assert + callback.assert_awaited_once_with(event) + assert result == callback.return_value + + +@pytest.mark.asyncio +async def test_dispatch_with_args(): + # Setup + dispatcher = DirectEventDispatcher() + callback = MagicMock() + event_type = GatewayEventType.DEVICE_CONNECT + + # Register the callback + dispatcher.register(event_type, callback) + + # Create a mock event + event = MagicMock(spec=GatewayEvent) + event.event_type = event_type + + # Act + await dispatcher.dispatch(event, "arg1", "arg2", kwarg1="value1", kwarg2="value2") + + # Assert + callback.assert_called_once_with(event, "arg1", "arg2", kwarg1="value1", kwarg2="value2") + + +@pytest.mark.asyncio +async def test_dispatch_no_handlers(): + # Setup + dispatcher = DirectEventDispatcher() + + # Create a mock event + event = MagicMock(spec=GatewayEvent) + event.event_type = GatewayEventType.DEVICE_CONNECT + + # Act + result = await dispatcher.dispatch(event) + + # Assert + assert result is None + + +@pytest.mark.asyncio +async def test_dispatch_callback_exception(): + # Setup + dispatcher = DirectEventDispatcher() + callback = MagicMock(side_effect=Exception("Test exception")) + event_type = GatewayEventType.DEVICE_CONNECT + + # Register the callback + dispatcher.register(event_type, callback) + + # Create a mock event + event = MagicMock(spec=GatewayEvent) + event.event_type = event_type + + # Act + result = await dispatcher.dispatch(event) + + # Assert + callback.assert_called_once_with(event) + assert result is None + + +@pytest.mark.asyncio +async def test_dispatch_async_callback_exception(): + # Setup + dispatcher = DirectEventDispatcher() + callback = AsyncMock(side_effect=Exception("Test exception")) + event_type = GatewayEventType.DEVICE_CONNECT + + # Register the callback + dispatcher.register(event_type, callback) + + # Create a mock event + event = MagicMock(spec=GatewayEvent) + event.event_type = event_type + + # Act + result = await dispatcher.dispatch(event) + + # Assert + callback.assert_awaited_once_with(event) + assert result is None + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/service/gateway/test_gateway_client.py b/tests/service/gateway/test_gateway_client.py new file mode 100644 index 0000000..61248b3 --- /dev/null +++ b/tests/service/gateway/test_gateway_client.py @@ -0,0 +1,477 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.constants.mqtt_topics import GATEWAY_CONNECT_TOPIC, GATEWAY_DISCONNECT_TOPIC, \ + GATEWAY_TELEMETRY_TOPIC, GATEWAY_ATTRIBUTES_TOPIC, GATEWAY_ATTRIBUTES_REQUEST_TOPIC, GATEWAY_CLAIM_TOPIC +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.claim_request import ClaimRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage +from tb_mqtt_client.entities.gateway.device_info import DeviceInfo +from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest +from tb_mqtt_client.entities.gateway.gateway_claim_request import GatewayClaimRequestBuilder +from tb_mqtt_client.service.gateway.client import GatewayClient +from tb_mqtt_client.service.gateway.device_session import DeviceSession + + +@pytest.mark.asyncio +async def test_connect_device(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + client._event_dispatcher.dispatch = AsyncMock() + + # Create a future that will be returned by dispatch + future = asyncio.Future() + future.set_result(PublishResult(topic=GATEWAY_CONNECT_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) + client._event_dispatcher.dispatch.return_value = [future] + + # Act + device_session, result = await client.connect_device("test_device", wait_for_publish=True) + + # Assert + assert device_session is not None + assert device_session.device_info.device_name == "test_device" + assert isinstance(result, PublishResult) + assert result.topic == GATEWAY_CONNECT_TOPIC + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_connect_device_with_connect_message(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + client._event_dispatcher.dispatch = AsyncMock() + + # Create a future that will be returned by dispatch + future = asyncio.Future() + future.set_result(PublishResult(topic=GATEWAY_CONNECT_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) + client._event_dispatcher.dispatch.return_value = [future] + + # Create a device connect message + connect_message = DeviceConnectMessage.build("test_device", "custom_profile") + + # Act + device_session, result = await client.connect_device(connect_message, wait_for_publish=True) + + # Assert + assert device_session is not None + assert device_session.device_info.device_name == "test_device" + assert device_session.device_info.device_profile == "custom_profile" + assert isinstance(result, PublishResult) + assert result.topic == GATEWAY_CONNECT_TOPIC + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_connect_device_no_wait(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + + # Create a future that will be returned by dispatch + future = asyncio.Future() + client._event_dispatcher.dispatch.return_value = [future] + + # Act + device_session, futures = await client.connect_device("test_device", wait_for_publish=False) + + # Assert + assert device_session is not None + assert device_session.device_info.device_name == "test_device" + assert isinstance(futures[0], asyncio.Future) + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_disconnect_device(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + + # Create a device session + info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info=info) + + # Create a future that will be returned by dispatch + future = asyncio.Future() + future.set_result( + PublishResult(topic=GATEWAY_DISCONNECT_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) + client._event_dispatcher.dispatch.return_value = [future] + + # Act + session, result = await client.disconnect_device(device_session, wait_for_publish=True) + + # Assert + assert session is not None + assert isinstance(result, PublishResult) + assert result.topic == GATEWAY_DISCONNECT_TOPIC + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_disconnect_device_no_wait(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + + # Create a device session + info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info=info) + + # Create a future that will be returned by dispatch + future = asyncio.Future() + client._event_dispatcher.dispatch.return_value = [future] + + # Act + session, futures = await client.disconnect_device(device_session, wait_for_publish=False) + + # Assert + assert session is not None + assert isinstance(futures[0], asyncio.Future) + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_device_timeseries(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + + # Create a device session + info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info=info) + + # Create a future that will be returned by dispatch + future = asyncio.Future() + future.set_result( + PublishResult(topic=GATEWAY_TELEMETRY_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) + client._event_dispatcher.dispatch.return_value = [future] + + # Act + result = await client.send_device_timeseries(device_session, {"temperature": 25.5}, wait_for_publish=True) + + # Assert + assert isinstance(result, PublishResult) + assert result.topic == GATEWAY_TELEMETRY_TOPIC + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_device_timeseries_with_timeseries_entry(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + + # Create a device session + info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info=info) + + # Create a future that will be returned by dispatch + future = asyncio.Future() + future.set_result( + PublishResult(topic=GATEWAY_TELEMETRY_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) + client._event_dispatcher.dispatch.return_value = [future] + + # Create a timeseries entry + timeseries = TimeseriesEntry("temperature", 25.5) + + # Act + result = await client.send_device_timeseries(device_session, timeseries, wait_for_publish=True) + + # Assert + assert isinstance(result, PublishResult) + assert result.topic == GATEWAY_TELEMETRY_TOPIC + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_device_timeseries_no_wait(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + + # Create a device session + + info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info=info) + + # Create a future that will be returned by dispatch + future = asyncio.Future() + client._event_dispatcher.dispatch.return_value = [future] + + # Act + future = await client.send_device_timeseries(device_session, {"temperature": 25.5}, wait_for_publish=False) + + # Assert + assert isinstance(future, asyncio.Future) + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_device_attributes(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + + # Create a device session + + info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info=info) + + # Create a future that will be returned by dispatch + future = asyncio.Future() + future.set_result( + PublishResult(topic=GATEWAY_ATTRIBUTES_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) + client._event_dispatcher.dispatch.return_value = [future] + + # Act + result = await client.send_device_attributes(device_session, {"firmware_version": "1.0.0"}, wait_for_publish=True) + + # Assert + assert isinstance(result, PublishResult) + assert result.topic == GATEWAY_ATTRIBUTES_TOPIC + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_device_attributes_with_attribute_entry(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + + # Create a device session + + info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info=info) + + # Create a future that will be returned by dispatch + future = asyncio.Future() + future.set_result( + PublishResult(topic=GATEWAY_ATTRIBUTES_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) + client._event_dispatcher.dispatch.return_value = [future] + + # Create an attribute entry + attributes = AttributeEntry("firmware_version", "1.0.0") + + # Act + result = await client.send_device_attributes(device_session, attributes, wait_for_publish=True) + + # Assert + assert isinstance(result, PublishResult) + assert result.topic == GATEWAY_ATTRIBUTES_TOPIC + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_device_attributes_no_wait(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + + # Create a device session + + info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info=info) + + # Create a future that will be returned by dispatch + future = asyncio.Future() + client._event_dispatcher.dispatch.return_value = [future] + + # Act + future = await client.send_device_attributes(device_session, {"firmware_version": "1.0.0"}, wait_for_publish=False) + + # Assert + assert isinstance(future, asyncio.Future) + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_device_attributes_request(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + client._gateway_requested_attribute_response_handler = AsyncMock() + + # Create a device session + + info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info=info) + + # Create a future that will be returned by dispatch + future = asyncio.Future() + future.set_result( + PublishResult(topic=GATEWAY_ATTRIBUTES_REQUEST_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) + client._event_dispatcher.dispatch.return_value = [future] + + # Create a gateway attribute request + request = await GatewayAttributeRequest.build(device_session, shared_keys=["firmware_version"], client_keys=None) + + # Act + result = await client.send_device_attributes_request(device_session, request, wait_for_publish=True) + + # Assert + assert isinstance(result, PublishResult) + assert result.topic == GATEWAY_ATTRIBUTES_REQUEST_TOPIC + client._gateway_requested_attribute_response_handler.register_request.assert_awaited_once() + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_device_attributes_request_no_wait(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + client._gateway_requested_attribute_response_handler = AsyncMock() + + # Create a device session + + info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info=info) + + # Create a future that will be returned by dispatch + future = asyncio.Future() + client._event_dispatcher.dispatch.return_value = [future] + + # Create a gateway attribute request + request = await GatewayAttributeRequest.build(device_session, shared_keys=["firmware_version"], client_keys=None) + + # Act + future = await client.send_device_attributes_request(device_session, request, wait_for_publish=False) + + # Assert + assert isinstance(future, asyncio.Future) + client._gateway_requested_attribute_response_handler.register_request.assert_awaited_once() + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_device_claim_request(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + + # Create a device session + + info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info=info) + + # Create a future that dispatch will return + future = asyncio.Future() + future.set_result(PublishResult(topic=GATEWAY_CLAIM_TOPIC, qos=1, message_id=1, payload_size=100, reason_code=0)) + client._event_dispatcher.dispatch.return_value = [future] + + # Create a gateway claim request + device_claim_request = ClaimRequest.build(secret_key="secret", duration=1000) + gateway_claim_request = GatewayClaimRequestBuilder().add_device_request(device_session, + device_claim_request).build() + + # Act + result = await client.send_device_claim_request(device_session, gateway_claim_request, wait_for_publish=True) + + # Assert + assert isinstance(result, PublishResult) + assert result.topic == GATEWAY_CLAIM_TOPIC + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_device_claim_request_no_wait(): + # Setup + client = GatewayClient() + client._event_dispatcher = AsyncMock() + + # Create a device session + + info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info=info) + + # Create a future that dispatch will return + future = asyncio.Future() + client._event_dispatcher.dispatch.return_value = [future] + + # Create a gateway claim request + device_claim_request = ClaimRequest.build(secret_key="secret", duration=1000) + gateway_claim_request = GatewayClaimRequestBuilder().add_device_request(device_session, + device_claim_request).build() + + # Act + future = await client.send_device_claim_request(device_session, gateway_claim_request, wait_for_publish=False) + + # Assert + assert isinstance(future, asyncio.Future) + client._event_dispatcher.dispatch.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_disconnect(): + # Setup + client = GatewayClient() + client._mqtt_manager = AsyncMock() + client._mqtt_manager.unsubscribe = AsyncMock(return_value=asyncio.Future()) + + # Act + await client.disconnect() + + # Assert + client._mqtt_manager.unsubscribe.assert_awaited() + client._mqtt_manager.disconnect.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_handle_rate_limit_response(): + # Setup + client = GatewayClient() + client._mqtt_manager = MagicMock() + client._gateway_message_adapter = MagicMock() + client._gateway_message_adapter.splitter = MagicMock() + client._gateway_rate_limiter = MagicMock() + client._gateway_rate_limiter.message_rate_limit = AsyncMock() + client._gateway_rate_limiter.telemetry_message_rate_limit = AsyncMock() + client._gateway_rate_limiter.telemetry_datapoints_rate_limit = AsyncMock() + + # Create a response with gateway rate limits + response = RPCResponse.build(1, result={ + 'gatewayRateLimits': { + 'messages': '10:1,', + 'telemetryMessages': '100:60,', + 'telemetryDataPoints': '500:60,' + }, + 'maxPayloadSize': 512 + }) + + # Mock the parent class method + with patch('tb_mqtt_client.service.device.client.DeviceClient._handle_rate_limit_response', return_value=True): + # Act + result = await client._handle_rate_limit_response(response) + + # Assert + assert result is True + client._gateway_rate_limiter.message_rate_limit.set_limit.assert_awaited_once() + client._gateway_rate_limiter.telemetry_message_rate_limit.set_limit.assert_awaited_once() + client._gateway_rate_limiter.telemetry_datapoints_rate_limit.set_limit.assert_awaited_once() + client._mqtt_manager.set_gateway_rate_limits_received.assert_called_once() + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/service/gateway/test_message_adapter.py b/tests/service/gateway/test_message_adapter.py new file mode 100644 index 0000000..b5f2372 --- /dev/null +++ b/tests/service/gateway/test_message_adapter.py @@ -0,0 +1,538 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from datetime import datetime, UTC + +import orjson +import pytest + +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.constants.mqtt_topics import GATEWAY_TELEMETRY_TOPIC, GATEWAY_ATTRIBUTES_TOPIC, \ + GATEWAY_CONNECT_TOPIC, GATEWAY_DISCONNECT_TOPIC, GATEWAY_ATTRIBUTES_REQUEST_TOPIC, GATEWAY_RPC_TOPIC, \ + GATEWAY_CLAIM_TOPIC +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.claim_request import ClaimRequest +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage +from tb_mqtt_client.entities.gateway.device_disconnect_message import DeviceDisconnectMessage +from tb_mqtt_client.entities.gateway.device_info import DeviceInfo +from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest +from tb_mqtt_client.entities.gateway.gateway_attribute_update import GatewayAttributeUpdate +from tb_mqtt_client.entities.gateway.gateway_claim_request import GatewayClaimRequestBuilder +from tb_mqtt_client.entities.gateway.gateway_requested_attribute_response import GatewayRequestedAttributeResponse +from tb_mqtt_client.entities.gateway.gateway_rpc_request import GatewayRPCRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessageBuilder, \ + DEFAULT_FIELDS_SIZE +from tb_mqtt_client.service.gateway.device_session import DeviceSession +from tb_mqtt_client.service.gateway.message_adapter import JsonGatewayMessageAdapter + + +def test_init(): + # Setup & Act + adapter = JsonGatewayMessageAdapter(max_payload_size=1000, max_datapoints=100) + + # Assert + assert adapter.splitter.max_payload_size == 1000 - DEFAULT_FIELDS_SIZE + assert adapter.splitter.max_datapoints == 100 + + +def test_build_device_connect_message_payload(): + # Setup + adapter = JsonGatewayMessageAdapter() + device_connect_message = DeviceConnectMessage.build("test_device", "default") + + # Act + result = adapter.build_device_connect_message_payload(device_connect_message, qos=1) + + # Assert + assert isinstance(result, MqttPublishMessage) + assert result.topic == GATEWAY_CONNECT_TOPIC + assert result.qos == 1 + + # Verify payload content + payload_dict = orjson.loads(result.payload) + assert "device" in payload_dict + assert payload_dict["device"] == "test_device" + assert "type" in payload_dict + assert payload_dict["type"] == "default" + + +def test_build_device_disconnect_message_payload(): + # Setup + adapter = JsonGatewayMessageAdapter() + device_disconnect_message = DeviceDisconnectMessage.build("test_device") + + # Act + result = adapter.build_device_disconnect_message_payload(device_disconnect_message, qos=1) + + # Assert + assert isinstance(result, MqttPublishMessage) + assert result.topic == GATEWAY_DISCONNECT_TOPIC + assert result.qos == 1 + + # Verify payload content + payload_dict = orjson.loads(result.payload) + assert "device" in payload_dict + assert payload_dict["device"] == "test_device" + + +@pytest.mark.asyncio +async def test_build_gateway_attribute_request_payload(): + # Setup + adapter = JsonGatewayMessageAdapter() + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + attribute_request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["key1", "key2"]) + + # Act + result = adapter.build_gateway_attribute_request_payload(attribute_request, qos=1) + + # Assert + assert isinstance(result, MqttPublishMessage) + assert result.topic == GATEWAY_ATTRIBUTES_REQUEST_TOPIC + assert result.qos == 1 + + # Verify payload content + payload_dict = orjson.loads(result.payload) + assert "device" in payload_dict + assert payload_dict["device"] == "test_device" + assert "keys" in payload_dict + assert payload_dict["keys"] == ["key1", "key2"] + assert "id" in payload_dict + assert payload_dict["id"] == attribute_request.request_id + + +def test_build_rpc_response_payload(): + # Setup + adapter = JsonGatewayMessageAdapter() + rpc_response = GatewayRPCResponse.build(device_name="test_device", request_id=1, result="success") + + # Act + result = adapter.build_rpc_response_payload(rpc_response, qos=1) + + # Assert + assert isinstance(result, MqttPublishMessage) + assert result.topic == GATEWAY_RPC_TOPIC + assert result.qos == 1 + + # Verify payload content + payload_dict = orjson.loads(result.payload) + assert "device" in payload_dict + assert payload_dict["device"] == "test_device" + assert "id" in payload_dict + assert payload_dict["id"] == 1 + assert "data" in payload_dict + assert payload_dict["data"] == {"result": "success"} + + +def test_build_claim_request_payload(): + # Setup + adapter = JsonGatewayMessageAdapter() + device_claim_request = ClaimRequest.build(secret_key="secret", duration=1) + gateway_claim_request = (GatewayClaimRequestBuilder() + .add_device_request(device_name_or_session="test_device", + device_claim_request=device_claim_request) + .build()) + + # Act + result = adapter.build_claim_request_payload(gateway_claim_request, qos=1) + + # Assert + assert isinstance(result, MqttPublishMessage) + assert result.topic == GATEWAY_CLAIM_TOPIC + assert result.qos == 1 + + # Verify payload content + payload_dict = orjson.loads(result.payload) + assert "test_device" in payload_dict + assert payload_dict["test_device"]["secretKey"] == "secret" + assert payload_dict["test_device"]["durationMs"] == 1000 + + +def test_parse_attribute_update(): + # Setup + adapter = JsonGatewayMessageAdapter() + data = { + "device": "test_device", + "data": { + "key1": "value1", + "key2": 42 + } + } + + # Act + result = adapter.parse_attribute_update(data) + + # Assert + assert isinstance(result, GatewayAttributeUpdate) + assert result.device_name == "test_device" + assert isinstance(result.attribute_update, AttributeUpdate) + assert result.attribute_update.as_dict() == {"key1": "value1", "key2": 42} + + +def test_parse_attribute_update_invalid_format(): + # Setup + adapter = JsonGatewayMessageAdapter() + data = { + "invalid_format": True + } + + # Act & Assert + with pytest.raises(ValueError): + adapter.parse_attribute_update(data) + + +@pytest.mark.asyncio +async def test_parse_gateway_requested_attribute_response(): + # Setup + adapter = JsonGatewayMessageAdapter() + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + attribute_request = await GatewayAttributeRequest.build(device_session=device_session, client_keys=["key1", "key2"]) + + data = { + "device": "test_device", + "id": attribute_request.request_id, + "values": { + "key1": "value1", + "key2": 42 + } + } + + # Act + result = adapter.parse_gateway_requested_attribute_response(attribute_request, data) + + # Assert + assert isinstance(result, GatewayRequestedAttributeResponse) + assert result.device_name == "test_device" + assert result.request_id == attribute_request.request_id + assert len(result.client) == 2 + assert any(attr.key == "key1" and attr.value == "value1" for attr in result.client) + assert any(attr.key == "key2" and attr.value == 42 for attr in result.client) + assert len(result.shared) == 0 + + +@pytest.mark.asyncio +async def test_parse_gateway_requested_attribute_response_single_value(): + # Setup + adapter = JsonGatewayMessageAdapter() + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + attribute_request = await GatewayAttributeRequest.build(device_session=device_session, client_keys=["key1"]) + + data = { + "device": "test_device", + "id": attribute_request.request_id, + "value": "value1" + } + + # Act + result = adapter.parse_gateway_requested_attribute_response(attribute_request, data) + + # Assert + assert isinstance(result, GatewayRequestedAttributeResponse) + assert result.device_name == "test_device" + assert result.request_id == attribute_request.request_id + assert len(result.client) == 1 + assert result.client[0].key == "key1" + assert result.client[0].value == "value1" + assert len(result.shared) == 0 + + +@pytest.mark.asyncio +async def test_parse_gateway_requested_attribute_response_shared_keys(): + # Setup + adapter = JsonGatewayMessageAdapter() + device_info = DeviceInfo("test_device", "default") + device_session = DeviceSession(device_info) + attribute_request = await GatewayAttributeRequest.build(device_session=device_session, shared_keys=["key1", "key2"]) + + data = { + "device": "test_device", + "id": attribute_request.request_id, + "values": { + "key1": "value1", + "key2": 42 + } + } + + # Act + result = adapter.parse_gateway_requested_attribute_response(attribute_request, data) + + # Assert + assert isinstance(result, GatewayRequestedAttributeResponse) + assert result.device_name == "test_device" + assert result.request_id == attribute_request.request_id + assert len(result.shared) == 2 + assert any(attr.key == "key1" and attr.value == "value1" for attr in result.shared) + assert any(attr.key == "key2" and attr.value == 42 for attr in result.shared) + assert len(result.client) == 0 + + +def test_parse_rpc_request(): + # Setup + adapter = JsonGatewayMessageAdapter() + data = { + "device": "test_device", + "data": { + "id": 1, + "method": "test_method", + "params": { + "param1": "value1", + "param2": 42 + } + } + } + + # Act + result = adapter.parse_rpc_request("v1/gateway/rpc", data) + + # Assert + assert isinstance(result, GatewayRPCRequest) + assert result.device_name == "test_device" + assert result.request_id == 1 + assert result.method == "test_method" + assert result.params == {"param1": "value1", "param2": 42} + + +def test_parse_rpc_request_invalid_format(): + # Setup + adapter = JsonGatewayMessageAdapter() + data = { + "invalid_format": True + } + + # Act & Assert + with pytest.raises(ValueError): + adapter.parse_rpc_request("v1/gateway/rpc", data) + + +def test_deserialize_to_dict(): + # Setup + adapter = JsonGatewayMessageAdapter() + payload = b'{"key": "value", "number": 42}' + + # Act + result = adapter.deserialize_to_dict(payload) + + # Assert + assert isinstance(result, dict) + assert result == {"key": "value", "number": 42} + + +def test_deserialize_to_dict_invalid_format(): + # Setup + adapter = JsonGatewayMessageAdapter() + payload = b'invalid json' + + # Act & Assert + with pytest.raises(ValueError): + adapter.deserialize_to_dict(payload) + + +def test_pack_attributes(): + # Setup + attributes = [ + AttributeEntry("key1", "value1"), + AttributeEntry("key2", 42) + ] + uplink_message = GatewayUplinkMessageBuilder().set_device_name("test_device").add_attributes(attributes).build() + + # Act + result = JsonGatewayMessageAdapter.pack_attributes(uplink_message) + + # Assert + assert isinstance(result, dict) + assert result == {"key1": "value1", "key2": 42} + + +def test_pack_timeseries_no_timestamp(): + # Setup + timeseries = [TimeseriesEntry("temp", 22.5), TimeseriesEntry("humidity", 45)] + uplink_message = GatewayUplinkMessageBuilder().set_device_name("test_device").add_timeseries(timeseries).build() + + # Act + result = JsonGatewayMessageAdapter.pack_timeseries(uplink_message) + + # Assert + assert isinstance(result, list) + assert len(result) == 1 + assert result[0] == {"temp": 22.5, "humidity": 45} + + +def test_pack_timeseries_with_timestamp(): + # Setup + ts = int(datetime.now(UTC).timestamp() * 1000) + timeseries = [TimeseriesEntry("temp", 22.5, ts), TimeseriesEntry("humidity", 45, ts)] + uplink_message = GatewayUplinkMessageBuilder().set_device_name("test_device").add_timeseries(timeseries).build() + + # Act + result = JsonGatewayMessageAdapter.pack_timeseries(uplink_message) + + # Assert + assert isinstance(result, list) + assert len(result) == 1 + assert "ts" in result[0] + assert result[0]["ts"] == ts + assert "values" in result[0] + assert result[0]["values"] == {"temp": 22.5, "humidity": 45} + + +def test_pack_timeseries_with_different_timestamps(): + # Setup + ts1 = int(datetime.now(UTC).timestamp() * 1000) + ts2 = ts1 + 1000 # 1 second later + timeseries = [TimeseriesEntry("temp", 22.5, ts1), TimeseriesEntry("humidity", 45, ts2)] + uplink_message = GatewayUplinkMessageBuilder().set_device_name("test_device").add_timeseries(timeseries).build() + + # Act + result = JsonGatewayMessageAdapter.pack_timeseries(uplink_message) + + # Assert + assert isinstance(result, list) + assert len(result) == 2 + + # Find entries by timestamp + ts1_entry = next((entry for entry in result if entry["ts"] == ts1), None) + ts2_entry = next((entry for entry in result if entry["ts"] == ts2), None) + + assert ts1_entry is not None + assert ts2_entry is not None + assert ts1_entry["values"] == {"temp": 22.5} + assert ts2_entry["values"] == {"humidity": 45} + + +def test_build_uplink_messages_empty(): + # Setup + adapter = JsonGatewayMessageAdapter() + + # Act + result = adapter.build_uplink_messages([]) + + # Assert + assert result == [] + + +def test_build_uplink_messages_non_gateway_message(): + # Setup + adapter = JsonGatewayMessageAdapter() + mqtt_msg = MqttPublishMessage(topic="test/topic", payload=b"test_payload", qos=1) + + # Act + result = adapter.build_uplink_messages([mqtt_msg]) + + # Assert + assert len(result) == 1 + assert result[0] == mqtt_msg + + +def test_build_uplink_messages_with_telemetry(): + # Setup + adapter = JsonGatewayMessageAdapter() + + # Create a gateway uplink message with telemetry + uplink_message = (GatewayUplinkMessageBuilder() + .set_device_name("test_device") + .add_timeseries([TimeseriesEntry("temp", 22.5), TimeseriesEntry("humidity", 45)]) + .build()) + + mqtt_msg = MqttPublishMessage(topic="test/topic", payload=uplink_message, qos=1) + + # Act + result = adapter.build_uplink_messages([mqtt_msg]) + + # Assert + assert len(result) == 1 + assert result[0].topic == GATEWAY_TELEMETRY_TOPIC + assert result[0].qos == 1 + + # Verify payload content + payload_dict = orjson.loads(result[0].payload) + assert "test_device" in payload_dict + assert isinstance(payload_dict["test_device"], list) + assert len(payload_dict["test_device"]) == 1 + assert "values" in payload_dict["test_device"][0] + assert payload_dict["test_device"][0]["values"] == {"temp": 22.5, "humidity": 45} + + +def test_build_uplink_messages_with_attributes(): + # Setup + adapter = JsonGatewayMessageAdapter() + + # Create a gateway uplink message with attributes + uplink_message = GatewayUplinkMessageBuilder().set_device_name("test_device").add_attributes([ + AttributeEntry("key1", "value1"), + AttributeEntry("key2", 42) + ]).build() + + mqtt_msg = MqttPublishMessage(topic="test/topic", payload=uplink_message, qos=1) + + # Act + result = adapter.build_uplink_messages([mqtt_msg]) + + # Assert + assert len(result) == 1 + assert result[0].topic == GATEWAY_ATTRIBUTES_TOPIC + assert result[0].qos == 1 + + # Verify payload content + payload_dict = orjson.loads(result[0].payload) + assert "test_device" in payload_dict + assert payload_dict["test_device"] == {"key1": "value1", "key2": 42} + + +def test_build_uplink_messages_with_both_telemetry_and_attributes(): + # Setup + adapter = JsonGatewayMessageAdapter() + + # Create a gateway uplink message with both telemetry and attributes + uplink_message_builder = GatewayUplinkMessageBuilder().set_device_name("test_device") + uplink_message_builder.add_timeseries([TimeseriesEntry("temp", 22.5), TimeseriesEntry("humidity", 45)]) + uplink_message_builder.add_attributes([AttributeEntry("key1", "value1"), AttributeEntry("key2", 42)]) + uplink_message = uplink_message_builder.build() + + mqtt_msg = MqttPublishMessage(topic="test/topic", payload=uplink_message, qos=1) + + # Act + result = adapter.build_uplink_messages([mqtt_msg]) + + # Assert + assert len(result) == 2 + + # Find telemetry and attribute messages + telemetry_msg = next((msg for msg in result if msg.topic == GATEWAY_TELEMETRY_TOPIC), None) + attribute_msg = next((msg for msg in result if msg.topic == GATEWAY_ATTRIBUTES_TOPIC), None) + + assert telemetry_msg is not None + assert attribute_msg is not None + + # Verify telemetry payload + telemetry_dict = orjson.loads(telemetry_msg.payload) + assert "test_device" in telemetry_dict + assert isinstance(telemetry_dict["test_device"], list) + assert len(telemetry_dict["test_device"]) == 1 + assert "values" in telemetry_dict["test_device"][0] + assert telemetry_dict["test_device"][0]["values"] == {"temp": 22.5, "humidity": 45} + + # Verify attribute payload + attribute_dict = orjson.loads(attribute_msg.payload) + assert "test_device" in attribute_dict + assert attribute_dict["test_device"] == {"key1": "value1", "key2": 42} + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/service/gateway/test_message_sender.py b/tests/service/gateway/test_message_sender.py new file mode 100644 index 0000000..3c83c94 --- /dev/null +++ b/tests/service/gateway/test_message_sender.py @@ -0,0 +1,328 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +from unittest.mock import MagicMock, AsyncMock + +import pytest + +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.constants.mqtt_topics import GATEWAY_TELEMETRY_TOPIC, GATEWAY_ATTRIBUTES_TOPIC, \ + GATEWAY_CONNECT_TOPIC, GATEWAY_DISCONNECT_TOPIC, GATEWAY_ATTRIBUTES_REQUEST_TOPIC, GATEWAY_RPC_TOPIC, \ + GATEWAY_CLAIM_TOPIC +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.gateway.device_connect_message import DeviceConnectMessage +from tb_mqtt_client.entities.gateway.device_disconnect_message import DeviceDisconnectMessage +from tb_mqtt_client.entities.gateway.gateway_attribute_request import GatewayAttributeRequest +from tb_mqtt_client.entities.gateway.gateway_claim_request import GatewayClaimRequest +from tb_mqtt_client.entities.gateway.gateway_rpc_response import GatewayRPCResponse +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessageBuilder +from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter +from tb_mqtt_client.service.gateway.message_sender import GatewayMessageSender +from tb_mqtt_client.service.message_service import MessageService + + +def test_init(): + # Setup & Act + sender = GatewayMessageSender() + + # Assert + assert sender._message_queue is None + assert sender._message_adapter is None + + +def test_set_message_queue(): + # Setup + sender = GatewayMessageSender() + message_queue = MagicMock(spec=MessageService) + + # Act + sender.set_message_queue(message_queue) + + # Assert + assert sender._message_queue == message_queue + + +def test_set_message_adapter(): + # Setup + sender = GatewayMessageSender() + message_adapter = MagicMock(spec=GatewayMessageAdapter) + + # Act + sender.set_message_adapter(message_adapter) + + # Assert + assert sender._message_adapter == message_adapter + + +@pytest.mark.asyncio +async def test_send_uplink_message_with_timeseries(): + # Setup + sender = GatewayMessageSender() + message_queue = AsyncMock(spec=MessageService) + sender.set_message_queue(message_queue) + + # Create an uplink message with timeseries + uplink_message = (GatewayUplinkMessageBuilder() + .set_device_name("test_device") + .add_timeseries([TimeseriesEntry("temp", 22.5)]) + .build()) + + # Mock the publish method to set delivery_futures + async def mock_publish(mqtt_message): + mqtt_message.delivery_futures = [asyncio.Future()] + return [mqtt_message] + + message_queue.publish.side_effect = mock_publish + + # Act + result = await sender.send_uplink_message(uplink_message) + + # Assert + assert result is not None + assert len(result) == 1 + assert isinstance(result[0], asyncio.Future) + message_queue.publish.assert_called_once() + + # Verify the MQTT message + mqtt_message = message_queue.publish.call_args[0][0] + assert mqtt_message.topic == GATEWAY_TELEMETRY_TOPIC + assert mqtt_message.payload == uplink_message + assert mqtt_message.qos == 1 + + +@pytest.mark.asyncio +async def test_send_uplink_message_with_attributes(): + # Setup + sender = GatewayMessageSender() + message_queue = AsyncMock(spec=MessageService) + sender.set_message_queue(message_queue) + + # Create an uplink message with attributes + uplink_message = (GatewayUplinkMessageBuilder() + .set_device_name("test_device") + .add_attributes([AttributeEntry("temperature", 22.5)]) + .build()) + + # Mock the publish method to set delivery_futures + async def mock_publish(mqtt_message): + mqtt_message.delivery_futures = [asyncio.Future()] + return [mqtt_message] + + message_queue.publish.side_effect = mock_publish + + # Act + result = await sender.send_uplink_message(uplink_message) + + # Assert + assert result is not None + assert len(result) == 1 + assert isinstance(result[0], asyncio.Future) + message_queue.publish.assert_called_once() + + # Verify the MQTT message + mqtt_message = message_queue.publish.call_args[0][0] + assert mqtt_message.topic == GATEWAY_ATTRIBUTES_TOPIC + assert mqtt_message.payload == uplink_message + assert mqtt_message.qos == 1 + + +@pytest.mark.asyncio +async def test_send_uplink_message_with_both(): + # Setup + sender = GatewayMessageSender() + message_queue = AsyncMock(spec=MessageService) + sender.set_message_queue(message_queue) + + # Create an uplink message with both timeseries and attributes + uplink_message = (GatewayUplinkMessageBuilder() + .set_device_name("test_device") + .add_timeseries([TimeseriesEntry("temp", 22.5)]) + .add_attributes(AttributeEntry("key1", "value1")) + .build()) + + # Mock the publish method to set delivery_futures + async def mock_publish(mqtt_message): + mqtt_message.delivery_futures = [asyncio.Future()] + return [mqtt_message] + + message_queue.publish.side_effect = mock_publish + + # Act + result = await sender.send_uplink_message(uplink_message) + + # Assert + assert result is not None + assert len(result) == 2 # One for telemetry, one for attributes + assert isinstance(result[0], asyncio.Future) + assert isinstance(result[1], asyncio.Future) + assert message_queue.publish.call_count == 2 + + # Verify the MQTT messages + call_args_list = message_queue.publish.call_args_list + assert call_args_list[0][0][0].topic == GATEWAY_TELEMETRY_TOPIC + assert call_args_list[0][0][0].payload == uplink_message + assert call_args_list[1][0][0].topic == GATEWAY_ATTRIBUTES_TOPIC + assert call_args_list[1][0][0].payload == uplink_message + + +@pytest.mark.asyncio +async def test_send_device_connect(): + # Setup + sender = GatewayMessageSender() + message_queue = AsyncMock(spec=MessageService) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + sender.set_message_queue(message_queue) + sender.set_message_adapter(message_adapter) + + # Create a device connect message + device_connect_message = DeviceConnectMessage.build("test_device", "default") + + # Mock the message adapter + mqtt_message = MqttPublishMessage(GATEWAY_CONNECT_TOPIC, b'{"device":"test_device","type":"default"}', qos=1) + mqtt_message.delivery_futures = [asyncio.Future()] + message_adapter.build_device_connect_message_payload.return_value = mqtt_message + + # Act + result = await sender.send_device_connect(device_connect_message) + + # Assert + assert result is not None + assert len(result) == 1 + assert isinstance(result[0], asyncio.Future) + message_adapter.build_device_connect_message_payload.assert_called_once_with( + device_connect_message=device_connect_message, qos=1) + message_queue.publish.assert_called_once_with(mqtt_message) + + +@pytest.mark.asyncio +async def test_send_device_disconnect(): + # Setup + sender = GatewayMessageSender() + message_queue = AsyncMock(spec=MessageService) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + sender.set_message_queue(message_queue) + sender.set_message_adapter(message_adapter) + + # Create a device disconnect message + device_disconnect_message = DeviceDisconnectMessage.build("test_device") + + # Mock the message adapter + mqtt_message = MqttPublishMessage(GATEWAY_DISCONNECT_TOPIC, b'{"device":"test_device"}', qos=1) + mqtt_message.delivery_futures = [asyncio.Future()] + message_adapter.build_device_disconnect_message_payload.return_value = mqtt_message + + # Act + result = await sender.send_device_disconnect(device_disconnect_message) + + # Assert + assert result is not None + assert len(result) == 1 + assert isinstance(result[0], asyncio.Future) + message_adapter.build_device_disconnect_message_payload.assert_called_once_with( + device_disconnect_message=device_disconnect_message, qos=1) + message_queue.publish.assert_called_once_with(mqtt_message) + + +@pytest.mark.asyncio +async def test_send_attributes_request(): + # Setup + sender = GatewayMessageSender() + message_queue = AsyncMock(spec=MessageService) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + sender.set_message_queue(message_queue) + sender.set_message_adapter(message_adapter) + + # Create an attribute request + attribute_request = MagicMock(spec=GatewayAttributeRequest) + + # Mock the message adapter + mqtt_message = MqttPublishMessage(GATEWAY_ATTRIBUTES_REQUEST_TOPIC, b'{"device":"test_device","keys":["key1"]}', + qos=1) + mqtt_message.delivery_futures = [asyncio.Future()] + message_adapter.build_gateway_attribute_request_payload.return_value = mqtt_message + + # Act + result = await sender.send_attributes_request(attribute_request) + + # Assert + assert result is not None + assert len(result) == 1 + assert isinstance(result[0], asyncio.Future) + message_adapter.build_gateway_attribute_request_payload.assert_called_once_with(attribute_request=attribute_request, + qos=1) + message_queue.publish.assert_called_once_with(mqtt_message) + + +@pytest.mark.asyncio +async def test_send_rpc_response(): + # Setup + sender = GatewayMessageSender() + message_queue = AsyncMock(spec=MessageService) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + sender.set_message_queue(message_queue) + sender.set_message_adapter(message_adapter) + + # Create an RPC response + rpc_response = MagicMock(spec=GatewayRPCResponse) + + # Mock the message adapter + mqtt_message = MqttPublishMessage(GATEWAY_RPC_TOPIC, b'{"device":"test_device","id":1,"data":{"result":"success"}}', + qos=1) + mqtt_message.delivery_futures = [asyncio.Future()] + message_adapter.build_rpc_response_payload.return_value = mqtt_message + + # Act + result = await sender.send_rpc_response(rpc_response) + + # Assert + assert result is not None + assert len(result) == 1 + assert isinstance(result[0], asyncio.Future) + message_adapter.build_rpc_response_payload.assert_called_once_with(rpc_response=rpc_response, qos=1) + message_queue.publish.assert_called_once_with(mqtt_message) + + +@pytest.mark.asyncio +async def test_send_claim_request(): + # Setup + sender = GatewayMessageSender() + message_queue = AsyncMock(spec=MessageService) + message_adapter = MagicMock(spec=GatewayMessageAdapter) + sender.set_message_queue(message_queue) + sender.set_message_adapter(message_adapter) + + # Create a claim request + claim_request = MagicMock(spec=GatewayClaimRequest) + + # Mock the message adapter + mqtt_message = MqttPublishMessage(GATEWAY_CLAIM_TOPIC, b'{"device":"test_device","secretKey":"secret"}', qos=1) + mqtt_message.delivery_futures = [asyncio.Future()] + message_adapter.build_claim_request_payload.return_value = mqtt_message + + # Act + result = await sender.send_claim_request(claim_request) + + # Assert + assert result is not None + assert len(result) == 1 + assert isinstance(result[0], asyncio.Future) + message_adapter.build_claim_request_payload.assert_called_once_with(claim_request=claim_request, qos=1) + message_queue.publish.assert_called_once_with(mqtt_message) + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/service/gateway/test_message_splitter.py b/tests/service/gateway/test_message_splitter.py new file mode 100644 index 0000000..8200f64 --- /dev/null +++ b/tests/service/gateway/test_message_splitter.py @@ -0,0 +1,411 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +from unittest.mock import MagicMock, patch + +import pytest + +from tb_mqtt_client.common.logging_utils import configure_logging +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessageBuilder, \ + DEFAULT_FIELDS_SIZE +from tb_mqtt_client.service.gateway.message_splitter import GatewayMessageSplitter + +configure_logging() + + +def test_init_default(): + # Setup & Act + splitter = GatewayMessageSplitter() + + # Assert + assert splitter.max_payload_size == 55000 - DEFAULT_FIELDS_SIZE + assert splitter.max_datapoints == 0 + + +def test_init_custom(): + # Setup & Act + splitter = GatewayMessageSplitter(max_payload_size=10000, max_datapoints=100) + + # Assert + assert splitter.max_payload_size == 10000 - DEFAULT_FIELDS_SIZE + assert splitter.max_datapoints == 100 + + +def test_init_invalid_values(): + # Setup & Act + splitter = GatewayMessageSplitter(max_payload_size=-1, max_datapoints=-1) + + # Assert + assert splitter.max_payload_size == 55000 - DEFAULT_FIELDS_SIZE # Default value + assert splitter.max_datapoints == 0 # Default value + + +def test_max_payload_size_property(): + # Setup + splitter = GatewayMessageSplitter() + + # Act + splitter.max_payload_size = 20000 + + # Assert + assert splitter.max_payload_size == 20000 - DEFAULT_FIELDS_SIZE + + +def test_max_payload_size_property_invalid(): + # Setup + splitter = GatewayMessageSplitter() + + # Act + splitter.max_payload_size = -1 + + # Assert + assert splitter.max_payload_size == 55000 - DEFAULT_FIELDS_SIZE # Default value + + +def test_max_datapoints_property(): + # Setup + splitter = GatewayMessageSplitter() + + # Act + splitter.max_datapoints = 200 + + # Assert + assert splitter.max_datapoints == 200 + + +def test_max_datapoints_property_invalid(): + # Setup + splitter = GatewayMessageSplitter() + + # Act + splitter.max_datapoints = -1 + + # Assert + assert splitter.max_datapoints == 0 # Default value + + +@pytest.mark.asyncio +async def test_split_timeseries_no_split_needed(): + # Setup + splitter = GatewayMessageSplitter(max_payload_size=10000, max_datapoints=100) + + # Create a message with timeseries that doesn't need splitting + timeseries = TimeseriesEntry("temp", 22.5) + message = GatewayUplinkMessageBuilder().set_device_name("test_device").add_timeseries(timeseries).build() + + # Act + result = splitter.split_timeseries([message]) + + # Assert + assert len(result) == 1 + assert result[0] == message + + +@pytest.mark.asyncio +async def test_split_timeseries_by_size(): + # Setup + splitter = GatewayMessageSplitter(max_payload_size=100 + DEFAULT_FIELDS_SIZE, max_datapoints=100) + + # Create a message with timeseries that needs splitting by size + message_builder = GatewayUplinkMessageBuilder().set_device_name("test_device") + + # Mock the size property to force splitting + timeseries_entry = TimeseriesEntry("temp", 22.5) + entries = [TimeseriesEntry("temp", 22.5) for _ in range(int(60 / timeseries_entry.size + 1))] + entries.extend([TimeseriesEntry("humidity", 45) for _ in range(int(60 / timeseries_entry.size + 1))]) + message_builder.add_timeseries(entries) + message = message_builder.build() + + # Mock the event loop + loop_mock = MagicMock() + future = asyncio.Future() + loop_mock.create_future.return_value = future + + with patch('asyncio.get_running_loop', return_value=loop_mock): + # Act + result = splitter.split_timeseries([message]) + + # Assert + assert len(result) == 2 + assert result[0].device_name == "test_device" + assert result[1].device_name == "test_device" + + # Check that each result has only one of the timeseries entries + assert result[0].size - DEFAULT_FIELDS_SIZE < splitter.max_payload_size + assert result[1].size - DEFAULT_FIELDS_SIZE < splitter.max_payload_size + + +@pytest.mark.asyncio +async def test_split_timeseries_by_datapoints(): + # Setup + splitter = GatewayMessageSplitter(max_payload_size=10000, max_datapoints=1) + + # Create a message with timeseries that needs splitting by datapoints + message_builder = GatewayUplinkMessageBuilder().set_device_name("test_device") + entries = [ + TimeseriesEntry("temp", 22.5), + TimeseriesEntry("humidity", 45) + ] + message_builder.add_timeseries(entries) + message = message_builder.build() + + # Mock the event loop + loop_mock = MagicMock() + future = asyncio.Future() + loop_mock.create_future.return_value = future + + with patch('asyncio.get_running_loop', return_value=loop_mock): + # Act + result = splitter.split_timeseries([message]) + + # Assert + assert len(result) == 2 + assert result[0].device_name == "test_device" + assert result[1].device_name == "test_device" + + # Check that each result has only one of the timeseries entries + assert len(result[0].timeseries[0]) == 1 + assert len(result[1].timeseries[0]) == 1 + + # The entries should be in separate messages + key0 = result[0].timeseries[0][0].key + key1 = result[1].timeseries[0][0].key + assert key0 != key1 + assert key0 == 'temp' + assert key1 == 'humidity' + + +@pytest.mark.asyncio +async def test_split_timeseries_multiple_messages(): + # Setup + splitter = GatewayMessageSplitter(max_payload_size=10000, max_datapoints=100) + + # Create multiple messages with timeseries + message1 = GatewayUplinkMessageBuilder().set_device_name("device1").add_timeseries( + TimeseriesEntry("temp", 22.5) + ).build() + + message2 = GatewayUplinkMessageBuilder().set_device_name("device2").add_timeseries( + TimeseriesEntry("humidity", 45) + ).build() + + # Act + result = splitter.split_timeseries([message1, message2]) + + # Assert + assert len(result) == 2 + + # Check that the messages are for different devices + devices = {msg.device_name for msg in result} + assert devices == {"device1", "device2"} + + # Check that each result has the correct timeseries + for msg in result: + if msg.device_name == "device1": + assert "temp" == msg.timeseries[0][0].key + elif msg.device_name == "device2": + assert "humidity" == msg.timeseries[0][0].key + + +@pytest.mark.asyncio +async def test_split_timeseries_with_delivery_futures(): + # Setup + splitter = GatewayMessageSplitter(max_payload_size=100 + DEFAULT_FIELDS_SIZE, max_datapoints=100) + + # Create a message with enough timeseries entries to exceed the max size + message_builder = GatewayUplinkMessageBuilder().set_device_name("test_device") + sample_entry = TimeseriesEntry("temp", 22.5) + # Create enough "temp" and "humidity" entries to force splitting + entries = [TimeseriesEntry("temp", 22.5) for _ in range(int(60 / sample_entry.size + 1))] + entries.extend([TimeseriesEntry("humidity", 45) for _ in range(int(60 / sample_entry.size + 1))]) + message_builder.add_timeseries(entries) + + # Add a delivery future to the original message + parent_future = asyncio.Future() + message_builder.add_delivery_futures([parent_future]) + + message = message_builder.build() + + # Mock the event loop and future_map + loop_mock = MagicMock() + future = asyncio.Future() + loop_mock.create_future.return_value = future + + with patch('asyncio.get_running_loop', return_value=loop_mock), \ + patch('tb_mqtt_client.common.async_utils.future_map.register') as mock_register: + # Act + result = splitter.split_timeseries([message]) + + # Assert + assert len(result) == 2 + # Ensure futures are linked correctly + assert mock_register.call_count == 2 + mock_register.assert_any_call(parent_future, [future]) + + +@pytest.mark.asyncio +async def test_split_attributes_no_split_needed(): + # Setup + splitter = GatewayMessageSplitter(max_payload_size=10000, max_datapoints=100) + + # Create a message with attributes that doesn't need splitting + message = (GatewayUplinkMessageBuilder() + .set_device_name("test_device") + .add_attributes(AttributeEntry("key1", "value1")) + .build()) + + # Act + result = splitter.split_attributes([message]) + + # Assert + assert len(result) == 1 + assert result[0] == message + + +@pytest.mark.asyncio +async def test_split_attributes_by_size(): + # Setup + splitter = GatewayMessageSplitter(max_payload_size=100 + DEFAULT_FIELDS_SIZE, max_datapoints=100) + + # Enough repeated attributes to exceed 100 bytes (simulate large payload) + attrs = [AttributeEntry(f"key{i}", f"value{i}") for i in range(10)] + + message = GatewayUplinkMessageBuilder().set_device_name("test_device").add_attributes(attrs).build() + + # Mock the event loop + loop_mock = MagicMock() + future = asyncio.Future() + loop_mock.create_future.return_value = future + + with patch('asyncio.get_running_loop', return_value=loop_mock): + # Act + result = splitter.split_attributes([message]) + + # Assert + assert len(result) > 1 + # Ensure each split contains fewer attributes than the original + for part in result: + assert len(part.attributes) < len(attrs) + + +@pytest.mark.asyncio +async def test_split_attributes_by_datapoints(): + # Setup + splitter = GatewayMessageSplitter(max_payload_size=10000, max_datapoints=1) + + # Create a message with attributes that needs splitting by datapoints + message = (GatewayUplinkMessageBuilder() + .set_device_name("test_device") + .add_attributes([AttributeEntry("key1", "value1"), + AttributeEntry("key2", "value2")]) + .build()) + + # Mock the event loop + loop_mock = MagicMock() + future = asyncio.Future() + loop_mock.create_future.return_value = future + + with patch('asyncio.get_running_loop', return_value=loop_mock): + # Act + result = splitter.split_attributes([message]) + + # Assert + assert len(result) == 2 + assert result[0].device_name == "test_device" + assert result[1].device_name == "test_device" + + # Check that each result has only one of the attribute entries + assert len(result[0].attributes) == 1 + assert len(result[1].attributes) == 1 + + # The entries should be in separate messages + keys0 = {attr.key for attr in result[0].attributes} + keys1 = {attr.key for attr in result[1].attributes} + assert len(keys0.intersection(keys1)) == 0 + assert keys0.union(keys1) == {"key1", "key2"} + + +@pytest.mark.asyncio +async def test_split_attributes_multiple_messages(): + # Setup + splitter = GatewayMessageSplitter(max_payload_size=10000, max_datapoints=100) + + # Create multiple messages with attributes + message1 = (GatewayUplinkMessageBuilder() + .set_device_name("device1") + .add_attributes(AttributeEntry("key1", "value1")) + .build()) + + message2 = (GatewayUplinkMessageBuilder() + .set_device_name("device2") + .add_attributes(AttributeEntry("key2", "value2")) + .build()) + + # Act + result = splitter.split_attributes([message1, message2]) + + # Assert + assert len(result) == 2 + + # Check that the messages are for different devices + devices = {msg.device_name for msg in result} + assert devices == {"device1", "device2"} + + # Check that each result has the correct attributes + for msg in result: + if msg.device_name == "device1": + assert msg.attributes[0].key == "key1" + elif msg.device_name == "device2": + assert msg.attributes[0].key == "key2" + + +@pytest.mark.asyncio +async def test_split_attributes_with_delivery_futures(): + # Setup + splitter = GatewayMessageSplitter(max_payload_size=100 + DEFAULT_FIELDS_SIZE, max_datapoints=100) + + # Create enough attributes to exceed the max size + attrs = [AttributeEntry(f"key{i}", f"value{i}") for i in range(10)] + + # Delivery future for the original message + parent_future = asyncio.Future() + + message = (GatewayUplinkMessageBuilder() + .set_device_name("test_device") + .add_attributes(attrs) + .add_delivery_futures([parent_future]) + .build()) + + # Mock the event loop and future_map + loop_mock = MagicMock() + future = asyncio.Future() + loop_mock.create_future.return_value = future + + with patch('asyncio.get_running_loop', return_value=loop_mock), \ + patch('tb_mqtt_client.common.async_utils.future_map.register') as mock_register: + # Act + result = splitter.split_attributes([message]) + + # Assert + assert len(result) > 1 + assert mock_register.call_count == len(result) + mock_register.assert_any_call(parent_future, [future]) + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/service/test_json_message_adapter.py b/tests/service/test_json_message_adapter.py new file mode 100644 index 0000000..a0a0a5c --- /dev/null +++ b/tests/service/test_json_message_adapter.py @@ -0,0 +1,327 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock, patch, mock_open + +import pytest +from orjson import dumps + +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.constants import mqtt_topics +from tb_mqtt_client.constants.mqtt_topics import DEVICE_TELEMETRY_TOPIC, DEVICE_ATTRIBUTES_TOPIC +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate +from tb_mqtt_client.entities.data.claim_request import ClaimRequest +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest, BasicProvisioningCredentials, \ + X509ProvisioningCredentials, AccessTokenProvisioningCredentials +from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse +from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse +from tb_mqtt_client.entities.data.rpc_request import RPCRequest +from tb_mqtt_client.entities.data.rpc_response import RPCResponse +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.service.device.message_adapter import JsonMessageAdapter + + +@pytest.fixture +def dummy_provisioning_request(): + credentials = AccessTokenProvisioningCredentials( + provision_device_key="key", + provision_device_secret="secret", + access_token="token" + ) + return ProvisioningRequest("some-device", credentials, device_name="dev", gateway=False) + + +def build_msg(device="devX", with_attr=False, with_ts=False): + builder = DeviceUplinkMessageBuilder().set_device_name(device) + if with_attr: + builder.add_attributes(AttributeEntry("a", 1)) + if with_ts: + builder.add_timeseries(TimeseriesEntry("t", 2, ts=1234567890)) + return builder.build() + + +@pytest.fixture +def adapter(): + return JsonMessageAdapter() + + +def test_build_attribute_request(adapter): + request = MagicMock(spec=AttributeRequest) + request.request_id = 1 + request.to_payload_format.return_value = {"clientKeys": "temp", "sharedKeys": "shared"} + mqtt_message = adapter.build_attribute_request(request) + assert mqtt_message.topic.endswith("/1") + assert b"clientKeys" in mqtt_message.payload + + +def test_build_attribute_request_invalid(adapter): + request = MagicMock(spec=AttributeRequest) + request.request_id = None + with pytest.raises(ValueError): + adapter.build_attribute_request(request) + + +def test_build_claim_request(adapter): + req = ClaimRequest.build("secretKey") + mqtt_message = adapter.build_claim_request(req) + assert mqtt_message.topic == mqtt_topics.DEVICE_CLAIM_TOPIC + assert b"secretKey" in mqtt_message.payload + + +def test_build_claim_request_invalid(adapter): + with pytest.raises(ValueError): + req = ClaimRequest.build(secret_key=None) # Simulating an invalid request # noqa + + +def test_build_rpc_request(adapter): + request = MagicMock(spec=RPCRequest) + request.request_id = 42 + request.to_payload_format.return_value = {"method": "reboot"} + mqtt_message = adapter.build_rpc_request(request) + assert mqtt_message.topic.endswith("42") + assert b"reboot" in mqtt_message.payload + + +def test_build_rpc_request_invalid(adapter): + request = MagicMock(spec=RPCRequest) + request.request_id = None + with pytest.raises(ValueError): + adapter.build_rpc_request(request) + + +def test_build_rpc_response(adapter): + response = MagicMock(spec=RPCResponse) + response.request_id = 123 + response.to_payload_format.return_value = {"result": "ok"} + mqtt_message = adapter.build_rpc_response(response) + assert mqtt_message.topic.endswith("123") + assert b"ok" in mqtt_message.payload + + +def test_build_rpc_response_invalid(adapter): + response = MagicMock(spec=RPCResponse) + response.request_id = None + with pytest.raises(ValueError): + adapter.build_rpc_response(response) + + +def test_build_provision_request_access_token(adapter): + credentials = AccessTokenProvisioningCredentials("key1", "secret1", access_token="tokenABC") + req = ProvisioningRequest("localhost", credentials, device_name="dev1", gateway=True) + mqtt_message = adapter.build_provision_request(req) + assert mqtt_message.topic == mqtt_topics.PROVISION_REQUEST_TOPIC + assert b"provisionDeviceKey" in mqtt_message.payload + assert b"tokenABC" in mqtt_message.payload + assert b"credentialsType" in mqtt_message.payload + assert b"deviceName" in mqtt_message.payload + assert b"gateway" in mqtt_message.payload + + +def test_build_provision_request_mqtt_basic(adapter): + credentials = BasicProvisioningCredentials("key2", "secret2", client_id="cid", username="user", password="pass") + req = ProvisioningRequest("127.0.0.1", credentials, device_name="dev2", gateway=False) + mqtt_message = adapter.build_provision_request(req) + assert b"clientId" in mqtt_message.payload + assert b"username" in mqtt_message.payload + assert b"password" in mqtt_message.payload + assert b"credentialsType" in mqtt_message.payload + + +def test_build_provision_request_x509(adapter): + cert_path = "/fake/path/cert.pem" + cert_content = "-----BEGIN CERTIFICATE-----\nFAKECERT\n-----END CERTIFICATE-----" + with patch("builtins.open", mock_open(read_data=cert_content)): + credentials = X509ProvisioningCredentials("key3", "secret3", "key.pem", cert_path, "ca.pem") + req = ProvisioningRequest("iot.server", credentials, device_name="dev3") + mqtt_message = adapter.build_provision_request(req) + assert b"hash" in mqtt_message.payload + assert b"credentialsType" in mqtt_message.payload + assert b"FAKECERT" in mqtt_message.payload + + +def test_build_provision_request_x509_file_not_found(adapter): + with patch("builtins.open", side_effect=FileNotFoundError): + with pytest.raises(FileNotFoundError): + X509ProvisioningCredentials("key", "secret", "k.pem", "nonexistent.pem", "ca.pem") + + +def test_parse_attribute_request_response(adapter): + topic = "v1/devices/me/attributes/response/42" + payload = dumps({"shared": {"temp": 22}}) + with patch.object(RequestedAttributeResponse, "from_dict", return_value="ok") as mock: + result = adapter.parse_requested_attribute_response(topic, payload) + assert result == "ok" + mock.assert_called_once() + + +def test_parse_attribute_request_response_invalid(adapter): + topic = "v1/devices/me/attributes/response/bad" + with pytest.raises(ValueError): + adapter.parse_requested_attribute_response(topic, b"invalid") + + +def test_parse_attribute_update(adapter): + payload = dumps({"shared": {"humidity": 60}}) + with patch.object(AttributeUpdate, "_deserialize_from_dict", return_value="AU"): + result = adapter.parse_attribute_update(payload) + assert result == "AU" + + +def test_parse_attribute_update_invalid(adapter): + with pytest.raises(ValueError): + adapter.parse_attribute_update(b"{bad}") + + +def test_parse_rpc_request(adapter): + topic = "v1/devices/me/rpc/request/123" + payload = dumps({"params": {"a": 1}}) + with patch.object(RPCRequest, "_deserialize_from_dict", return_value="REQ"): + assert adapter.parse_rpc_request(topic, payload) == "REQ" + + +def test_parse_rpc_request_invalid(adapter): + topic = "v1/devices/me/rpc/request/NaN" + with pytest.raises(ValueError): + adapter.parse_rpc_request(topic, b"{}") + + +def test_parse_rpc_response(adapter): + topic = "v1/devices/me/rpc/response/999" + payload = dumps({"value": "done"}) + with patch.object(RPCResponse, "build", return_value="RSP"): + assert adapter.parse_rpc_response(topic, payload) == "RSP" + + +def test_parse_rpc_response_with_error(adapter): + topic = "v1/devices/me/rpc/response/888" + error = ValueError("fail") + with patch.object(RPCResponse, "build", return_value="ERR"): + assert adapter.parse_rpc_response(topic, error) == "ERR" + + +def test_parse_rpc_response_invalid(adapter): + topic = "v1/devices/me/rpc/response/NaN" + with pytest.raises(ValueError): + adapter.parse_rpc_response(topic, b"bad") + + +@pytest.mark.asyncio +async def test_build_uplink_payloads_empty(adapter: JsonMessageAdapter): + assert adapter.build_uplink_messages([]) == [] + + +@pytest.mark.asyncio +async def test_build_uplink_payloads_only_attributes(adapter: JsonMessageAdapter): + msg = build_msg(with_attr=True) + initial_mqtt_message = MqttPublishMessage(topic="", payload=msg) + with patch.object(adapter._splitter, "split_attributes", return_value=[msg]): + result = adapter.build_uplink_messages([initial_mqtt_message]) + assert len(result) == 1 + mqtt_message = result[0] + assert mqtt_message.topic == DEVICE_ATTRIBUTES_TOPIC + assert mqtt_message.datapoints == 1 + assert b"a" in mqtt_message.payload + + +@pytest.mark.asyncio +async def test_build_uplink_payloads_only_timeseries(adapter: JsonMessageAdapter): + msg = build_msg(with_ts=True) + initial_mqtt_message = MqttPublishMessage(topic="", payload=msg) + with patch.object(adapter._splitter, "split_timeseries", return_value=[msg]): + result = adapter.build_uplink_messages([initial_mqtt_message]) + assert len(result) == 1 + mqtt_message = result[0] + assert mqtt_message.topic == DEVICE_TELEMETRY_TOPIC + assert mqtt_message.datapoints == 1 + assert b"ts" in mqtt_message.payload + + +@pytest.mark.asyncio +async def test_build_uplink_payloads_both(adapter: JsonMessageAdapter): + msg = build_msg(with_attr=True, with_ts=True) + initial_mqtt_message = MqttPublishMessage(topic="", payload=msg) + + with patch.object(adapter._splitter, "split_attributes", return_value=[msg]), \ + patch.object(adapter._splitter, "split_timeseries", return_value=[msg]): + result = adapter.build_uplink_messages([initial_mqtt_message]) + assert len(result) == 2 + topics = {r.topic for r in result} + assert DEVICE_ATTRIBUTES_TOPIC in topics + assert DEVICE_TELEMETRY_TOPIC in topics + + +def test_build_payload_without_device_name(adapter: JsonMessageAdapter): + builder = DeviceUplinkMessageBuilder().add_attributes(AttributeEntry("x", 9)) + msg = builder.build() + payload = adapter.build_payload(msg, False) + assert isinstance(payload, bytes) + assert b"x" in payload + + +def test_pack_attributes(): + builder = DeviceUplinkMessageBuilder().add_attributes(AttributeEntry("x", 10)) + msg = builder.build() + result = JsonMessageAdapter.pack_attributes(msg) + assert isinstance(result, dict) + assert "x" in result + + +def test_pack_timeseries_no_ts(monkeypatch): + monkeypatch.setattr("tb_mqtt_client.service.device.message_adapter.datetime", MagicMock()) + ts_entry = TimeseriesEntry("temp", 23, ts=None) + builder = DeviceUplinkMessageBuilder().add_timeseries(ts_entry) + msg = builder.build() + packed = JsonMessageAdapter.pack_timeseries(msg) + assert isinstance(packed, dict) + assert ts_entry.key in packed + assert packed[ts_entry.key] == ts_entry.value + + +def test_build_uplink_payloads_error_handling(adapter: JsonMessageAdapter): + with patch("tb_mqtt_client.service.device.message_adapter.DeviceUplinkMessage.has_attributes", + side_effect=Exception("boom")): + msg = build_msg(with_attr=True) + initial_mqtt_message = MqttPublishMessage(topic="", payload=msg) + with pytest.raises(Exception, match="boom"): + adapter.build_uplink_messages([initial_mqtt_message]) + + +def test_parse_provisioning_response_success(adapter, dummy_provisioning_request): + payload_dict = {"status": "SUCCESS", "credentialsType": "ACCESS_TOKEN"} + payload_bytes = dumps(payload_dict) + + with patch.object(ProvisioningResponse, "build", return_value="SUCCESS_RESPONSE") as mock_build: + result = adapter.parse_provisioning_response(dummy_provisioning_request, payload_bytes) + assert result == "SUCCESS_RESPONSE" + mock_build.assert_called_once_with(dummy_provisioning_request, payload_dict) + + +def test_parse_provisioning_response_failure(adapter, dummy_provisioning_request): + broken_bytes = b"{not_json" + + with patch.object(ProvisioningResponse, "build", return_value="FAILURE_RESPONSE") as mock_build: + result = adapter.parse_provisioning_response(dummy_provisioning_request, broken_bytes) + assert result == "FAILURE_RESPONSE" + mock_build.assert_called_once() + args = mock_build.call_args[0] + assert args[0] == dummy_provisioning_request + assert args[1]["status"] == "FAILURE" + assert "errorMsg" in args[1] + + +if __name__ == '__main__': + pytest.main([__file__, "--tb=short", "-v"]) diff --git a/tests/service/test_message_service.py b/tests/service/test_message_service.py new file mode 100644 index 0000000..956cd0f --- /dev/null +++ b/tests/service/test_message_service.py @@ -0,0 +1,1063 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import pytest_asyncio + +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit, EMPTY_RATE_LIMIT +from tb_mqtt_client.common.rate_limit.rate_limiter import RateLimiter +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder +from tb_mqtt_client.entities.gateway.gateway_uplink_message import GatewayUplinkMessageBuilder +from tb_mqtt_client.service.device.message_adapter import MessageAdapter +from tb_mqtt_client.service.gateway.message_adapter import GatewayMessageAdapter +from tb_mqtt_client.service.message_service import MessageService, MessageQueueWorker +from tb_mqtt_client.service.mqtt_manager import MQTTManager + + +@pytest_asyncio.fixture +async def setup_message_service(): + # Create mocks for all dependencies + mqtt_manager = MagicMock(spec=MQTTManager) + main_stop_event = asyncio.Event() + + # Mock rate limiters + device_rate_limiter = MagicMock(spec=RateLimiter) + device_rate_limiter.message_rate_limit = MagicMock(spec=RateLimit) + device_rate_limiter.telemetry_message_rate_limit = MagicMock(spec=RateLimit) + device_rate_limiter.telemetry_datapoints_rate_limit = MagicMock(spec=RateLimit) + device_rate_limiter.values.return_value = [ + device_rate_limiter.message_rate_limit, + device_rate_limiter.telemetry_message_rate_limit, + device_rate_limiter.telemetry_datapoints_rate_limit + ] + + gateway_rate_limiter = MagicMock(spec=RateLimiter) + gateway_rate_limiter.message_rate_limit = MagicMock(spec=RateLimit) + gateway_rate_limiter.telemetry_message_rate_limit = MagicMock(spec=RateLimit) + gateway_rate_limiter.telemetry_datapoints_rate_limit = MagicMock(spec=RateLimit) + gateway_rate_limiter.values.return_value = [ + gateway_rate_limiter.message_rate_limit, + gateway_rate_limiter.telemetry_message_rate_limit, + gateway_rate_limiter.telemetry_datapoints_rate_limit + ] + + # Mock message adapters + message_adapter = MagicMock(spec=MessageAdapter) + gateway_message_adapter = MagicMock(spec=GatewayMessageAdapter) + + # Create the service with patched asyncio.create_task to avoid actual task creation + with patch('asyncio.create_task', new=MagicMock()): + service = MessageService( + mqtt_manager=mqtt_manager, + main_stop_event=main_stop_event, + device_rate_limiter=device_rate_limiter, + message_adapter=message_adapter, + gateway_message_adapter=gateway_message_adapter, + gateway_rate_limiter=gateway_rate_limiter + ) + + # Replace the actual tasks with mocks + service._initial_queue_task = MagicMock() + service._service_queue_task = MagicMock() + service._device_uplink_messages_queue_task = MagicMock() + service._gateway_uplink_messages_queue_task = MagicMock() + service._rate_limit_refill_task = MagicMock() + service.__print_queue_statistics_task = MagicMock() + + yield (service, mqtt_manager, main_stop_event, device_rate_limiter, + message_adapter, gateway_message_adapter, gateway_rate_limiter) + + +@pytest_asyncio.fixture +async def setup_retry_loop_service(setup_message_service): + service, mqtt_manager, main_stop_event, *_ = setup_message_service + service._main_stop_event = main_stop_event + service._active.set() + + # Mock dependencies + service._retry_by_qos_queue = AsyncMock() + service._service_queue = AsyncMock() + service._gateway_uplink_messages_queue = AsyncMock() + service._device_uplink_messages_queue = AsyncMock() + + # Cooldown patch to avoid slow tests + service._QUEUE_COOLDOWN = 0 + + return service, mqtt_manager + + +@pytest.mark.asyncio +async def test_publish_success(setup_message_service): + service, mqtt_manager, _, _, _, _, _ = setup_message_service + + # Mock the queue put method + service._initial_queue.put = AsyncMock() + + # Create a test message + message = MqttPublishMessage("test/topic", b"test_payload") + + # Call the publish method + await service.publish(message) + + # Verify the message was added to the queue + service._initial_queue.put.assert_awaited_once_with(message) + + +@pytest.mark.asyncio +async def test_publish_exception(setup_message_service): + service, mqtt_manager, _, _, _, _, _ = setup_message_service + + # Mock the queue put method to raise an exception + service._initial_queue.put = AsyncMock(side_effect=Exception("Queue error")) + + # Create a test message with a delivery future + future = asyncio.Future() + message = MqttPublishMessage("test/topic", b"test_payload", delivery_futures=[future]) + + # Call the publish method + await service.publish(message) + + # Verify the future was completed with an error result + assert future.done() + result = future.result() + assert isinstance(result, PublishResult) + assert result.reason_code == -1 + + +@pytest.mark.asyncio +async def test_shutdown(): + """Test the shutdown method with real tasks to ensure full coverage.""" + # Create mocks for all dependencies + mqtt_manager = MagicMock(spec=MQTTManager) + main_stop_event = asyncio.Event() + device_rate_limiter = MagicMock(spec=RateLimiter) + device_rate_limiter.values.return_value = [] + gateway_rate_limiter = MagicMock(spec=RateLimiter) + gateway_rate_limiter.values.return_value = [] + message_adapter = MagicMock(spec=MessageAdapter) + gateway_message_adapter = MagicMock(spec=GatewayMessageAdapter) + + # Create a real service with real tasks + with patch('asyncio.create_task', side_effect=asyncio.create_task) as mock_create_task: + # Create simple async functions for the tasks + async def mock_task(): + try: + while True: + await asyncio.sleep(0.1) + except asyncio.CancelledError: + # This is expected during shutdown + pass + + # Patch the methods that would create tasks to return our mock task + with patch.object(MessageService, '_dispatch_initial_queue_loop', return_value=mock_task()), \ + patch.object(MessageService, '_dispatch_queue_loop', return_value=mock_task()), \ + patch.object(MessageService, '_rate_limit_refill_loop', return_value=mock_task()), \ + patch.object(MessageService, 'print_queues_statistics', return_value=mock_task()), \ + patch.object(MessageService, 'clear', new_callable=AsyncMock) as mock_clear: + + # Create the service + service = MessageService( + mqtt_manager=mqtt_manager, + main_stop_event=main_stop_event, + device_rate_limiter=device_rate_limiter, + message_adapter=message_adapter, + gateway_message_adapter=gateway_message_adapter, + gateway_rate_limiter=gateway_rate_limiter + ) + + # Verify that tasks were created + assert mock_create_task.call_count >= 5 + + # Call the shutdown method + await service.shutdown() + + # Verify the active flag was cleared + assert not service._active.is_set() + + # Verify the clear method was called + mock_clear.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_is_empty(setup_message_service): + service, _, _, _, _, _, _ = setup_message_service + + # Mock the queue is_empty method + service._initial_queue.is_empty = MagicMock(return_value=True) + + # Check if the queue is empty + assert service.is_empty() + + # Verify the is_empty method was called + service._initial_queue.is_empty.assert_called_once() + + +@pytest.mark.asyncio +async def test_clear(setup_message_service): + service, _, _, _, _, _, _ = setup_message_service + + # Mock the queue methods + service._initial_queue.is_empty = MagicMock(side_effect=[False, True]) + service._initial_queue.get = AsyncMock(return_value=MqttPublishMessage("topic", b"payload")) + + service._service_queue.is_empty = MagicMock(side_effect=[False, True]) + service._service_queue.get = AsyncMock(return_value=MqttPublishMessage("topic", b"payload")) + + service._device_uplink_messages_queue.is_empty = MagicMock(side_effect=[False, True]) + service._device_uplink_messages_queue.get = AsyncMock(return_value=MqttPublishMessage("topic", b"payload")) + + service._gateway_uplink_messages_queue.is_empty = MagicMock(side_effect=[False, True]) + service._gateway_uplink_messages_queue.get = AsyncMock(return_value=MqttPublishMessage("topic", b"payload")) + + # Call the clear method + await service.clear() + + # Verify the get method was called for each queue + service._initial_queue.get.assert_awaited_once() + service._service_queue.get.assert_awaited_once() + service._device_uplink_messages_queue.get.assert_awaited_once() + service._gateway_uplink_messages_queue.get.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_set_gateway_message_adapter(setup_message_service): + service, _, _, _, _, _, gateway_message_adapter = setup_message_service + + # Create a new mock adapter + new_adapter = MagicMock(spec=GatewayMessageAdapter) + + # Set the new adapter + service.set_gateway_message_adapter(new_adapter) + + # Verify the adapter was set + assert service._gateway_adapter == new_adapter + + +@pytest.mark.asyncio +async def test_dispatch_initial_queue_loop_service_message(setup_message_service): + service, mqtt_manager, main_stop_event, _, _, _, _ = setup_message_service + + # Set up the test to run once and then exit + main_stop_event.is_set = MagicMock(side_effect=[False, True]) + + # Mock the queue methods + service._initial_queue.peek_batch = AsyncMock(return_value=[ + MqttPublishMessage("topic", b"raw_payload") + ]) + service._initial_queue.pop_n = AsyncMock() + service._initial_queue.is_empty = MagicMock(return_value=True) + + service._service_queue.put = AsyncMock() + + # Run the dispatch loop + with patch('asyncio.sleep', new=AsyncMock()): + await service._dispatch_initial_queue_loop() + + # Verify the message was put in the service queue + service._service_queue.put.assert_awaited_once() + service._initial_queue.pop_n.assert_awaited_once_with(1) + + +@pytest.mark.asyncio +async def test_dispatch_initial_queue_loop_device_message(setup_message_service): + service, mqtt_manager, main_stop_event, _, message_adapter, _, _ = setup_message_service + + # Set up the test to run once and then exit + main_stop_event.is_set = MagicMock(side_effect=[False, True]) + + # Create a device uplink message using the builder + device_uplink = DeviceUplinkMessageBuilder().set_device_name("test-device").build() + device_message = MqttPublishMessage("topic", device_uplink) + + # Mock the queue methods + service._initial_queue.peek_batch = AsyncMock(return_value=[device_message]) + service._initial_queue.pop_n = AsyncMock() + service._initial_queue.is_empty = MagicMock(return_value=True) + + service._device_uplink_messages_queue.extend = AsyncMock() + + # Mock the adapter + processed_messages = [MqttPublishMessage("processed_topic", b"processed_payload")] + message_adapter.build_uplink_messages.return_value = processed_messages + + # Run the dispatch loop + with patch('asyncio.sleep', new=AsyncMock()): + await service._dispatch_initial_queue_loop() + + # Verify the adapter was called and messages were added to the device queue + message_adapter.build_uplink_messages.assert_called_once_with([device_message]) + service._device_uplink_messages_queue.extend.assert_awaited_once_with(processed_messages) + service._initial_queue.pop_n.assert_awaited_once_with(1) + + +@pytest.mark.asyncio +async def test_dispatch_initial_queue_loop_gateway_message(setup_message_service): + service, mqtt_manager, main_stop_event, _, message_adapter, gateway_message_adapter, _ = setup_message_service + + # Set up the test to run once and then exit + main_stop_event.is_set = MagicMock(side_effect=[False, True]) + + # Create a gateway uplink message using the builder + gateway_uplink = GatewayUplinkMessageBuilder().set_device_name("test-device").build() + gateway_message = MqttPublishMessage("topic", gateway_uplink) + + # Mock the queue methods + service._initial_queue.peek_batch = AsyncMock(return_value=[gateway_message]) + service._initial_queue.pop_n = AsyncMock() + service._initial_queue.is_empty = MagicMock(return_value=True) + + service._gateway_uplink_messages_queue.extend = AsyncMock() + + # Mock the adapter + processed_messages = [MqttPublishMessage("processed_topic", b"processed_payload")] + gateway_message_adapter.build_uplink_messages.return_value = processed_messages + + # Run the dispatch loop + with patch('asyncio.sleep', new=AsyncMock()): + await service._dispatch_initial_queue_loop() + + # Verify the adapter was called and messages were added to the gateway queue + gateway_message_adapter.build_uplink_messages.assert_called_once_with([gateway_message]) + service._gateway_uplink_messages_queue.extend.assert_awaited_once_with(processed_messages) + service._initial_queue.pop_n.assert_awaited_once_with(1) + + +@pytest.mark.asyncio +async def test_dispatch_initial_queue_loop_exception(setup_message_service): + service, mqtt_manager, main_stop_event, _, _, _, _ = setup_message_service + + # Set up the test to run once and then exit + main_stop_event.is_set = MagicMock(side_effect=[False, True]) + + # Mock the queue methods to raise an exception + service._initial_queue.peek_batch = AsyncMock(side_effect=Exception("Test exception")) + + # Run the dispatch loop + with patch('asyncio.sleep', new=AsyncMock()): + await service._dispatch_initial_queue_loop() + + # The test passes if no exception is raised + + +@pytest.mark.asyncio +async def test_dispatch_queue_loop(setup_message_service): + service, mqtt_manager, main_stop_event, _, _, _, _ = setup_message_service + + # Set up the test to run once and then exit + main_stop_event.is_set = MagicMock(side_effect=[False, True]) + + # Create a mock queue and worker + queue = AsyncMock() + worker = MagicMock(spec=MessageQueueWorker) + + # Mock the queue methods + message = MqttPublishMessage("topic", b"payload") + queue.get = AsyncMock(return_value=message) + queue.is_empty = MagicMock(return_value=True) + + # Mock the worker process method + worker.process = AsyncMock(return_value=(None, None, None)) + + # Run the dispatch loop + with patch('asyncio.sleep', new=AsyncMock()): + await service._dispatch_queue_loop(queue, worker) + + # Verify the worker process method was called + worker.process.assert_awaited_once_with(message) + + +@pytest.mark.asyncio +async def test_dispatch_queue_loop_rate_limit_triggered(setup_message_service): + service, mqtt_manager, main_stop_event, _, _, _, _ = setup_message_service + + # Set up the test to run once and then exit + main_stop_event.is_set = MagicMock(side_effect=[False, True]) + + # Create a mock queue and worker + queue = AsyncMock() + worker = MagicMock(spec=MessageQueueWorker) + + # Mock the queue methods + message = MqttPublishMessage("topic", b"payload") + queue.get = AsyncMock(return_value=message) + queue.is_empty = MagicMock(return_value=True) + queue.reinsert_front = AsyncMock() + + # Mock the worker process method to trigger rate limit + rate_limit = MagicMock(spec=RateLimit) + rate_limit.required_tokens_ready = asyncio.Event() + rate_limit.required_tokens_ready.set() # Set it to avoid waiting in the test + worker.process = AsyncMock(return_value=(10, 5, rate_limit)) + + # Run the dispatch loop + with patch('asyncio.sleep', new=AsyncMock()): + await service._dispatch_queue_loop(queue, worker) + + # Verify the worker process method was called + worker.process.assert_awaited_once_with(message) + + # Verify the message was reinserted to the front of the queue + queue.reinsert_front.assert_awaited_once_with(message) + + # Verify the rate limit was set + rate_limit.set_required_tokens.assert_called_once_with(10, 5) + + +@pytest.mark.asyncio +async def test_dispatch_queue_loop_empty_message(setup_message_service): + service, mqtt_manager, main_stop_event, _, _, _, _ = setup_message_service + + # Set up the test to run once and then exit + main_stop_event.is_set = MagicMock(side_effect=[False, True]) + + # Create a mock queue and worker + queue = AsyncMock() + worker = MagicMock(spec=MessageQueueWorker) + + # Mock the queue methods to return None + queue.get = AsyncMock(return_value=None) + + # Run the dispatch loop + with patch('asyncio.sleep', new=AsyncMock()): + await service._dispatch_queue_loop(queue, worker) + + # Verify the worker process method was not called + worker.process.assert_not_called() + + +@pytest.mark.asyncio +async def test_dispatch_queue_loop_exception(setup_message_service): + service, mqtt_manager, main_stop_event, _, _, _, _ = setup_message_service + + # Set up the test to run once and then exit + main_stop_event.is_set = MagicMock(side_effect=[False, True]) + + # Create a mock queue and worker + queue = AsyncMock() + worker = MagicMock(spec=MessageQueueWorker) + + # Mock the queue methods to raise an exception + queue.get = AsyncMock(side_effect=Exception("Test exception")) + + # Run the dispatch loop + with patch('asyncio.sleep', new=AsyncMock()): + await service._dispatch_queue_loop(queue, worker) + + # The test passes if no exception is raised + + +@pytest.mark.asyncio +async def test_rate_limit_refill_loop(setup_message_service): + service, mqtt_manager, _, device_rate_limiter, _, _, gateway_rate_limiter = setup_message_service + + # Set up the test to run once and then exit + service._active.is_set = MagicMock(side_effect=[True, False]) + + # Mock the refill methods + service._refill_rate_limits = AsyncMock() + + # Run the refill loop + with patch('asyncio.sleep', new=AsyncMock()): + await service._rate_limit_refill_loop() + + # Verify the refill method was called + service._refill_rate_limits.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_refill_rate_limits(setup_message_service): + service, mqtt_manager, _, device_rate_limiter, _, _, gateway_rate_limiter = setup_message_service + + # Mock the rate limit refill methods + device_rate_limit = MagicMock(spec=RateLimit) + device_rate_limit.refill = AsyncMock() + device_rate_limiter.values.return_value = [device_rate_limit] + + gateway_rate_limit = MagicMock(spec=RateLimit) + gateway_rate_limit.refill = AsyncMock() + gateway_rate_limiter.values.return_value = [gateway_rate_limit] + + # Call the refill method + await service._refill_rate_limits() + + # Verify the refill methods were called + device_rate_limit.refill.assert_awaited_once() + gateway_rate_limit.refill.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_print_queue_statistics(setup_message_service): + service, mqtt_manager, main_stop_event, _, _, _, _ = setup_message_service + + # Create a patched version of the print_queue_statistics method to avoid infinite loop + original_print_queue_statistics = service.print_queues_statistics # noqa + + async def patched_print_queue_statistics(): + # Just run the body of the loop once + initial_queue_size = service._initial_queue.size() # noqa + service_queue_size = service._service_queue.size() # noqa + device_uplink_queue_size = service._device_uplink_messages_queue.size() # noqa + gateway_uplink_queue_size = service._gateway_uplink_messages_queue.size() # noqa + # We don't need to log anything in the test + + # Replace the method with our patched version + service.print_queues_statistics = patched_print_queue_statistics + + # Mock the queue size methods + service._initial_queue.size = MagicMock(return_value=1) + service._service_queue.size = MagicMock(return_value=2) + service._device_uplink_messages_queue.size = MagicMock(return_value=3) + service._gateway_uplink_messages_queue.size = MagicMock(return_value=4) + + # Run the statistics method + await service.print_queues_statistics() + + # Verify the size methods were called + service._initial_queue.size.assert_called_once() + service._service_queue.size.assert_called_once() + service._device_uplink_messages_queue.size.assert_called_once() + service._gateway_uplink_messages_queue.size.assert_called_once() + + +@pytest.mark.asyncio +async def test_cancel_tasks(): + # Create mock tasks + task1 = MagicMock(spec=asyncio.Task) + task2 = MagicMock(spec=asyncio.Task) + tasks = {task1, task2} + + # Call the cancel_tasks method + with patch('asyncio.gather', new=AsyncMock()): + await MessageService._cancel_tasks(tasks) + + # Verify the tasks were canceled + task1.cancel.assert_called_once() + task2.cancel.assert_called_once() + assert len(tasks) == 0 + + +# Tests for MessageQueueWorker class +@pytest.mark.asyncio +async def test_message_queue_worker_process_no_rate_limits(): + # Create mocks + queue = AsyncMock() + stop_event = asyncio.Event() + mqtt_manager = MagicMock(spec=MQTTManager) + device_rate_limiter = MagicMock(spec=RateLimiter) + gateway_rate_limiter = MagicMock(spec=RateLimiter) + + # Create the worker + worker = MessageQueueWorker( + "TestWorker", + queue, + stop_event, + mqtt_manager, + device_rate_limiter, + gateway_rate_limiter + ) + + # Mock the rate limit methods + worker._get_rate_limits_for_message = MagicMock(return_value=(EMPTY_RATE_LIMIT, EMPTY_RATE_LIMIT)) + + # Create a test message + message = MqttPublishMessage("test/topic", b"test_payload") + + # Call the process method + result = await worker.process(message) + + # Verify the mqtt_manager.publish method was called + mqtt_manager.publish.assert_awaited_once_with(message) + + # Verify the result + assert result == (None, None, None) + + +@pytest.mark.asyncio +async def test_message_queue_worker_process_with_rate_limits_not_triggered(): + # Create mocks + queue = AsyncMock() + stop_event = asyncio.Event() + mqtt_manager = MagicMock(spec=MQTTManager) + device_rate_limiter = MagicMock(spec=RateLimiter) + gateway_rate_limiter = MagicMock(spec=RateLimiter) + + # Create the worker + worker = MessageQueueWorker( + "TestWorker", + queue, + stop_event, + mqtt_manager, + device_rate_limiter, + gateway_rate_limiter + ) + + # Create rate limits + message_rate_limit = MagicMock(spec=RateLimit) + message_rate_limit.has_limit.return_value = True + datapoints_rate_limit = MagicMock(spec=RateLimit) + datapoints_rate_limit.has_limit.return_value = True + + # Mock the rate limit methods + worker._get_rate_limits_for_message = MagicMock(return_value=(message_rate_limit, datapoints_rate_limit)) + worker.check_rate_limits_for_message = AsyncMock(return_value=(None, 0, None)) + worker._consume_rate_limits_for_message = AsyncMock() + + # Create a test message + message = MqttPublishMessage("test/topic", b"test_payload") + + # Call the process method + result = await worker.process(message) + + # Verify the rate limit methods were called + worker.check_rate_limits_for_message.assert_awaited_once_with( + datapoints_count=message.datapoints, + message_rate_limit=message_rate_limit, + datapoints_rate_limit=datapoints_rate_limit + ) + worker._consume_rate_limits_for_message.assert_awaited_once_with( + datapoints_count=message.datapoints, + message_rate_limit=message_rate_limit, + datapoints_rate_limit=datapoints_rate_limit + ) + + # Verify the mqtt_manager.publish method was called + mqtt_manager.publish.assert_awaited_once_with(message) + + # Verify the result + assert result == (None, None, None) + + +@pytest.mark.asyncio +async def test_message_queue_worker_process_with_rate_limits_triggered(): + # Create mocks + queue = AsyncMock() + stop_event = asyncio.Event() + mqtt_manager = MagicMock(spec=MQTTManager) + device_rate_limiter = MagicMock(spec=RateLimiter) + gateway_rate_limiter = MagicMock(spec=RateLimiter) + + # Create the worker + worker = MessageQueueWorker( + "TestWorker", + queue, + stop_event, + mqtt_manager, + device_rate_limiter, + gateway_rate_limiter + ) + + # Create rate limits + message_rate_limit = MagicMock(spec=RateLimit) + message_rate_limit.has_limit.return_value = True + datapoints_rate_limit = MagicMock(spec=RateLimit) + datapoints_rate_limit.has_limit.return_value = True + + # Mock the rate limit methods + worker._get_rate_limits_for_message = MagicMock(return_value=(message_rate_limit, datapoints_rate_limit)) + + # Set up the rate limit to be triggered + triggered_rate_limit_entry = (10, 5) # (tokens, duration) + expected_tokens = 5 + worker.check_rate_limits_for_message = AsyncMock( + return_value=(triggered_rate_limit_entry, expected_tokens, message_rate_limit)) + + # Create a test message + message = MqttPublishMessage("test/topic", b"test_payload") + + # Call the process method + result = await worker.process(message) + + # Verify the rate limit methods were called + worker.check_rate_limits_for_message.assert_awaited_once_with( + datapoints_count=message.datapoints, + message_rate_limit=message_rate_limit, + datapoints_rate_limit=datapoints_rate_limit + ) + + # Verify the mqtt_manager.publish method was not called + mqtt_manager.publish.assert_not_awaited() + + # Verify the result + assert result == (5, 5, message_rate_limit) + + +@pytest.mark.asyncio +async def test_message_queue_worker_get_rate_limits_for_message_device_service(): + # Create mocks + queue = AsyncMock() + stop_event = asyncio.Event() + mqtt_manager = MagicMock(spec=MQTTManager) + + # Create device rate limiter with proper attributes + device_rate_limiter = MagicMock(spec=RateLimiter) + device_rate_limiter.message_rate_limit = MagicMock(spec=RateLimit) + device_rate_limiter.telemetry_message_rate_limit = MagicMock(spec=RateLimit) + device_rate_limiter.telemetry_datapoints_rate_limit = MagicMock(spec=RateLimit) + + # Create gateway rate limiter with proper attributes + gateway_rate_limiter = MagicMock(spec=RateLimiter) + gateway_rate_limiter.message_rate_limit = MagicMock(spec=RateLimit) + gateway_rate_limiter.telemetry_message_rate_limit = MagicMock(spec=RateLimit) + gateway_rate_limiter.telemetry_datapoints_rate_limit = MagicMock(spec=RateLimit) + + # Create the worker + worker = MessageQueueWorker( + "TestWorker", + queue, + stop_event, + mqtt_manager, + device_rate_limiter, + gateway_rate_limiter + ) + + # Create a test message + message = MqttPublishMessage("test/topic", b"test_payload") + message.is_device_message = True + message.is_service_message = True + + # Call the method + message_rate_limit, datapoints_rate_limit = worker._get_rate_limits_for_message(message) + + # Verify the correct rate limits were returned + assert message_rate_limit == device_rate_limiter.message_rate_limit + assert datapoints_rate_limit == EMPTY_RATE_LIMIT + + +@pytest.mark.asyncio +async def test_message_queue_worker_get_rate_limits_for_message_device_telemetry(): + # Create mocks + queue = AsyncMock() + stop_event = asyncio.Event() + mqtt_manager = MagicMock(spec=MQTTManager) + + # Create device rate limiter with proper attributes + device_rate_limiter = MagicMock(spec=RateLimiter) + device_rate_limiter.message_rate_limit = MagicMock(spec=RateLimit) + device_rate_limiter.telemetry_message_rate_limit = MagicMock(spec=RateLimit) + device_rate_limiter.telemetry_datapoints_rate_limit = MagicMock(spec=RateLimit) + + # Create gateway rate limiter with proper attributes + gateway_rate_limiter = MagicMock(spec=RateLimiter) + gateway_rate_limiter.message_rate_limit = MagicMock(spec=RateLimit) + gateway_rate_limiter.telemetry_message_rate_limit = MagicMock(spec=RateLimit) + gateway_rate_limiter.telemetry_datapoints_rate_limit = MagicMock(spec=RateLimit) + + # Create the worker + worker = MessageQueueWorker( + "TestWorker", + queue, + stop_event, + mqtt_manager, + device_rate_limiter, + gateway_rate_limiter + ) + + # Create a test message + message = MqttPublishMessage("test/topic", b"test_payload") + message.is_device_message = True + message.is_service_message = False + + # Call the method + message_rate_limit, datapoints_rate_limit = worker._get_rate_limits_for_message(message) + + # Verify the correct rate limits were returned + assert message_rate_limit == device_rate_limiter.telemetry_message_rate_limit + assert datapoints_rate_limit == device_rate_limiter.telemetry_datapoints_rate_limit + + +@pytest.mark.asyncio +async def test_message_queue_worker_get_rate_limits_for_message_gateway_service(): + # Create mocks + queue = AsyncMock() + stop_event = asyncio.Event() + mqtt_manager = MagicMock(spec=MQTTManager) + + # Create device rate limiter with proper attributes + device_rate_limiter = MagicMock(spec=RateLimiter) + device_rate_limiter.message_rate_limit = MagicMock(spec=RateLimit) + device_rate_limiter.telemetry_message_rate_limit = MagicMock(spec=RateLimit) + device_rate_limiter.telemetry_datapoints_rate_limit = MagicMock(spec=RateLimit) + + # Create gateway rate limiter with proper attributes + gateway_rate_limiter = MagicMock(spec=RateLimiter) + gateway_rate_limiter.message_rate_limit = MagicMock(spec=RateLimit) + gateway_rate_limiter.telemetry_message_rate_limit = MagicMock(spec=RateLimit) + gateway_rate_limiter.telemetry_datapoints_rate_limit = MagicMock(spec=RateLimit) + + # Create the worker + worker = MessageQueueWorker( + "TestWorker", + queue, + stop_event, + mqtt_manager, + device_rate_limiter, + gateway_rate_limiter + ) + + # Create a test message + message = MqttPublishMessage("test/topic", b"test_payload") + message.is_device_message = False + message.is_service_message = True + + # Call the method + message_rate_limit, datapoints_rate_limit = worker._get_rate_limits_for_message(message) + + # Verify the correct rate limits were returned + assert message_rate_limit == gateway_rate_limiter.message_rate_limit + assert datapoints_rate_limit == EMPTY_RATE_LIMIT + + +@pytest.mark.asyncio +async def test_message_queue_worker_get_rate_limits_for_message_gateway_telemetry(): + # Create mocks + queue = AsyncMock() + stop_event = asyncio.Event() + mqtt_manager = MagicMock(spec=MQTTManager) + + # Create device rate limiter with proper attributes + device_rate_limiter = MagicMock(spec=RateLimiter) + device_rate_limiter.message_rate_limit = MagicMock(spec=RateLimit) + device_rate_limiter.telemetry_message_rate_limit = MagicMock(spec=RateLimit) + device_rate_limiter.telemetry_datapoints_rate_limit = MagicMock(spec=RateLimit) + + # Create gateway rate limiter with proper attributes + gateway_rate_limiter = MagicMock(spec=RateLimiter) + gateway_rate_limiter.message_rate_limit = MagicMock(spec=RateLimit) + gateway_rate_limiter.telemetry_message_rate_limit = MagicMock(spec=RateLimit) + gateway_rate_limiter.telemetry_datapoints_rate_limit = MagicMock(spec=RateLimit) + + # Create the worker + worker = MessageQueueWorker( + "TestWorker", + queue, + stop_event, + mqtt_manager, + device_rate_limiter, + gateway_rate_limiter + ) + + # Create a test message + message = MqttPublishMessage("test/topic", b"test_payload") + message.is_device_message = False + message.is_service_message = False + + # Call the method + message_rate_limit, datapoints_rate_limit = worker._get_rate_limits_for_message(message) + + # Verify the correct rate limits were returned + assert message_rate_limit == gateway_rate_limiter.telemetry_message_rate_limit + assert datapoints_rate_limit == gateway_rate_limiter.telemetry_datapoints_rate_limit + + +@pytest.mark.asyncio +async def test_message_queue_worker_check_rate_limits_for_message_message_limit_triggered(): + # Create a mock message rate limit that will trigger + message_rate_limit = MagicMock(spec=RateLimit) + message_rate_limit.has_limit.return_value = True + message_rate_limit.try_consume = AsyncMock(return_value=(10, 5)) # (tokens, duration) + + # Create a mock datapoints rate limit that won't be checked + datapoints_rate_limit = MagicMock(spec=RateLimit) + datapoints_rate_limit.has_limit.return_value = True + + # Call the method + result = await MessageQueueWorker.check_rate_limits_for_message( + datapoints_count=5, + message_rate_limit=message_rate_limit, + datapoints_rate_limit=datapoints_rate_limit + ) + + # Verify the result + assert result == ((10, 5), 1, message_rate_limit) + + # Verify the message rate limit was checked + message_rate_limit.try_consume.assert_awaited_once_with(1) + + # Verify the datapoints rate limit was not checked + datapoints_rate_limit.try_consume.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_message_queue_worker_check_rate_limits_for_message_datapoints_limit_triggered(): + # Create a mock message rate limit that won't trigger + message_rate_limit = MagicMock(spec=RateLimit) + message_rate_limit.has_limit.return_value = True + message_rate_limit.try_consume = AsyncMock(return_value=None) + + # Create a mock datapoints rate limit that will trigger + datapoints_rate_limit = MagicMock(spec=RateLimit) + datapoints_rate_limit.has_limit.return_value = True + datapoints_rate_limit.try_consume = AsyncMock(return_value=(20, 10)) # (tokens, duration) + + # Call the method + result = await MessageQueueWorker.check_rate_limits_for_message( + datapoints_count=5, + message_rate_limit=message_rate_limit, + datapoints_rate_limit=datapoints_rate_limit + ) + + # Verify the result + assert result == ((20, 10), 5, datapoints_rate_limit) + + # Verify both rate limits were checked + message_rate_limit.try_consume.assert_awaited_once_with(1) + datapoints_rate_limit.try_consume.assert_awaited_once_with(5) + + +@pytest.mark.asyncio +async def test_message_queue_worker_check_rate_limits_for_message_no_limits_triggered(): + # Create mock rate limits that won't trigger + message_rate_limit = MagicMock(spec=RateLimit) + message_rate_limit.has_limit.return_value = True + message_rate_limit.try_consume = AsyncMock(return_value=None) + + datapoints_rate_limit = MagicMock(spec=RateLimit) + datapoints_rate_limit.has_limit.return_value = True + datapoints_rate_limit.try_consume = AsyncMock(return_value=None) + + # Call the method + result = await MessageQueueWorker.check_rate_limits_for_message( + datapoints_count=5, + message_rate_limit=message_rate_limit, + datapoints_rate_limit=datapoints_rate_limit + ) + + # Verify the result + assert result == (None, 0, None) + + # Verify both rate limits were checked + message_rate_limit.try_consume.assert_awaited_once_with(1) + datapoints_rate_limit.try_consume.assert_awaited_once_with(5) + + +@pytest.mark.asyncio +async def test_message_queue_worker_consume_rate_limits_for_message(): + # Create mock rate limits + message_rate_limit = MagicMock(spec=RateLimit) + message_rate_limit.has_limit.return_value = True + message_rate_limit.consume = AsyncMock() + + datapoints_rate_limit = MagicMock(spec=RateLimit) + datapoints_rate_limit.has_limit.return_value = True + datapoints_rate_limit.consume = AsyncMock() + + # Call the method + await MessageQueueWorker._consume_rate_limits_for_message( + datapoints_count=5, + message_rate_limit=message_rate_limit, + datapoints_rate_limit=datapoints_rate_limit + ) + + # Verify both rate limits were consumed + message_rate_limit.consume.assert_awaited_once_with(1) + datapoints_rate_limit.consume.assert_awaited_once_with(5) + + +@pytest.mark.asyncio +async def test_message_queue_worker_consume_rate_limits_for_message_no_limits(): + # Create mock rate limits with no limits + message_rate_limit = MagicMock(spec=RateLimit) + message_rate_limit.has_limit.return_value = False + + datapoints_rate_limit = MagicMock(spec=RateLimit) + datapoints_rate_limit.has_limit.return_value = False + + # Call the method + await MessageQueueWorker._consume_rate_limits_for_message( + datapoints_count=5, + message_rate_limit=message_rate_limit, + datapoints_rate_limit=datapoints_rate_limit + ) + + # Verify no rate limits were consumed + message_rate_limit.consume.assert_not_awaited() + datapoints_rate_limit.consume.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_retry_loop_with_bytes_message(setup_retry_loop_service): + service, mqtt_manager = setup_retry_loop_service + mqtt_manager.is_connected.return_value = True + + message = MagicMock() + message.original_payload = b"binary" + + service._retry_by_qos_queue.get = AsyncMock(side_effect=[message, asyncio.CancelledError()]) + + await service._dispatch_retry_by_qos_queue_loop() + + service._service_queue.reinsert_front.assert_awaited_once_with(message) + + +@pytest.mark.asyncio +async def test_retry_loop_with_gateway_uplink_message(setup_retry_loop_service): + service, mqtt_manager = setup_retry_loop_service + mqtt_manager.is_connected.return_value = True + + message = MagicMock() + message.original_payload = GatewayUplinkMessageBuilder().set_device_name("test").build() + + service._retry_by_qos_queue.get = AsyncMock(side_effect=[message, asyncio.CancelledError()]) + + await service._dispatch_retry_by_qos_queue_loop() + + service._gateway_uplink_messages_queue.reinsert_front.assert_awaited_once_with(message) + + +@pytest.mark.asyncio +async def test_retry_loop_with_device_uplink_message(setup_retry_loop_service): + service, mqtt_manager = setup_retry_loop_service + mqtt_manager.is_connected.return_value = True + + message = MagicMock() + message.original_payload = DeviceUplinkMessageBuilder().set_device_name("test-device").build() + + service._retry_by_qos_queue.get = AsyncMock(side_effect=[message, asyncio.CancelledError()]) + + await service._dispatch_retry_by_qos_queue_loop() + + service._device_uplink_messages_queue.reinsert_front.assert_awaited_once_with(message) + + +@pytest.mark.asyncio +async def test_retry_loop_with_cancelled_error(setup_retry_loop_service): + service, mqtt_manager = setup_retry_loop_service + mqtt_manager.is_connected.return_value = True + + service._retry_by_qos_queue.get = AsyncMock(side_effect=asyncio.CancelledError()) + + await service._dispatch_retry_by_qos_queue_loop() + + +@pytest.mark.asyncio +async def test_retry_loop_with_not_connected_and_empty_message(setup_retry_loop_service): + service, mqtt_manager = setup_retry_loop_service + + mqtt_manager.is_connected.side_effect = [False, True] + + service._retry_by_qos_queue.get = AsyncMock(side_effect=[None, asyncio.CancelledError()]) + + service._main_stop_event.is_set = MagicMock(side_effect=[False, True]) + + with patch("asyncio.sleep", new=AsyncMock()): + await service._dispatch_retry_by_qos_queue_loop() + + +if __name__ == '__main__': + pytest.main([__file__, '-v', '--tb=short']) diff --git a/tests/service/test_message_splitter.py b/tests/service/test_message_splitter.py new file mode 100644 index 0000000..ae94627 --- /dev/null +++ b/tests/service/test_message_splitter.py @@ -0,0 +1,231 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from unittest.mock import MagicMock, patch + +import pytest + +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.entities.data.attribute_entry import AttributeEntry +from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessageBuilder +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.service.device.message_adapter import JsonMessageAdapter +from tb_mqtt_client.service.device.message_splitter import MessageSplitter + + +@pytest.fixture +def splitter(): + return MessageSplitter(max_payload_size=100, max_datapoints=3) + + +# Positive cases +def test_single_small_timeseries_pass_through(splitter): + msg = MagicMock() + msg.has_timeseries.return_value = True + msg.size = 50 + msg.attributes_datapoint_count.return_value = 0 + msg.timeseries_datapoint_count.return_value = 2 + result = splitter.split_timeseries([msg]) + assert result == [msg] + + +def test_single_small_attributes_pass_through(splitter): + msg = MagicMock() + msg.has_attributes.return_value = True + msg.size = 50 + msg.attributes_datapoint_count.return_value = 2 + msg.timeseries_datapoint_count.return_value = 0 + result = splitter.split_attributes([msg]) + assert result == [msg] + + +# Negative test: invalid payload size and datapoints +def test_invalid_config_defaults(): + splitter = MessageSplitter(max_payload_size=-10, max_datapoints=-100) + assert splitter.max_payload_size == 55000 + assert splitter.max_datapoints == 0 + + +# Negative test: empty message list +def test_empty_list_returns_empty(splitter): + assert splitter.split_timeseries([]) == [] + assert splitter.split_attributes([]) == [] + + +# Negative test: message without required fields +def test_malformed_message_handling(splitter): + msg_ts = MagicMock() + msg_ts.has_timeseries.side_effect = Exception("Malformed TS field") + msg_ts.attributes_datapoint_count.return_value = 0 + msg_ts.timeseries_datapoint_count.return_value = 0 + msg_ts.size = 200 + + with pytest.raises(Exception, match="Malformed TS field"): + splitter.split_timeseries([msg_ts]) + + msg_attr = MagicMock() + msg_attr.has_attributes.side_effect = Exception("Malformed Attr field") + msg_attr.attributes_datapoint_count.return_value = 0 + msg_attr.timeseries_datapoint_count.return_value = 0 + msg_attr.size = 200 + + with pytest.raises(Exception, match="Malformed Attr field"): + splitter.split_attributes([msg_attr]) + + +# Negative test: builder fails on build() +@patch("tb_mqtt_client.service.device.message_splitter.DeviceUplinkMessageBuilder") +def test_builder_failure_during_split_raises(mock_builder_class): + async def run_test(): + entry = MagicMock() + entry.size = 10 + + message = MagicMock() + message.device_name = "dev" + message.device_profile = "prof" + message.has_timeseries.return_value = True + message.timeseries = {"temp": [entry] * 5} + message.get_delivery_futures.return_value = [] + message.attributes_datapoint_count.return_value = 0 + message.timeseries_datapoint_count.return_value = 5 + message.size = 100 + + builder_instance = MagicMock() + builder_instance.set_device_name.return_value = builder_instance + builder_instance.set_device_profile.return_value = builder_instance + builder_instance.set_main_ts.return_value = builder_instance + + builder_instance._timeseries = [] + + def add_ts_side_effect(ts): + builder_instance._timeseries.append(ts) + + builder_instance.add_timeseries.side_effect = add_ts_side_effect + builder_instance.add_delivery_futures.return_value = None + builder_instance.build.side_effect = RuntimeError("build failed") + + mock_builder_class.return_value = builder_instance + + splitter = MessageSplitter(max_payload_size=20, max_datapoints=2) + + with pytest.raises(RuntimeError, match="build failed"): + splitter.split_timeseries([message]) + + asyncio.run(run_test()) + + +# Property validation +def test_payload_setter_validation(): + s = MessageSplitter() + s.max_payload_size = 12345 + assert s.max_payload_size == 12345 + s.max_payload_size = 0 + assert s.max_payload_size == 55000 + + +def test_datapoint_setter_validation(): + s = MessageSplitter() + s.max_datapoints = 99 + assert s.max_datapoints == 99 + s.max_datapoints = 0 + assert s.max_datapoints == 0 + s.max_datapoints = -5 + assert s.max_datapoints == 0 + + +@pytest.mark.asyncio +async def test_split_attributes_grouping(): + dispatcher = JsonMessageAdapter(max_payload_size=200, max_datapoints=5) + + builder1 = DeviceUplinkMessageBuilder().set_device_name("deviceA").set_device_profile("default") + builder2 = DeviceUplinkMessageBuilder().set_device_name("deviceA").set_device_profile("default") + + for i in range(3): + builder1.add_attributes(AttributeEntry(f"key_{i}", i)) + + for i in range(3, 6): + builder2.add_attributes(AttributeEntry(f"key_{i}", i)) + + messages = [builder1.build(), builder2.build()] + result = dispatcher.splitter.split_attributes(messages) + + assert len(result) == 2 + total_attrs = sum(len(msg.attributes) for msg in result) + assert total_attrs == 6 + assert all(msg.device_name == "deviceA" for msg in result) + for msg in result: + for fut in msg.get_delivery_futures(): + fut.set_result(PublishResult("test/topic", 1, 1, 100, 0)) + + +@pytest.mark.asyncio +async def test_split_attributes_different_devices_not_grouped(): + dispatcher = JsonMessageAdapter(max_payload_size=200, max_datapoints=100) + + builder1 = DeviceUplinkMessageBuilder().set_device_name("deviceA") + builder2 = DeviceUplinkMessageBuilder().set_device_name("deviceB") + + for i in range(3): + builder1.add_attributes(AttributeEntry(f"key_{i}", i)) + builder2.add_attributes(AttributeEntry(f"key_{i + 3}", i + 3)) + + result = dispatcher.splitter.split_attributes([builder1.build(), builder2.build()]) + + assert len(result) == 2 + assert result[0].device_name != result[1].device_name + for msg in result: + for fut in msg.get_delivery_futures(): + fut.set_result(PublishResult("test/topic", 1, 1, 100, 0)) + + +@patch("tb_mqtt_client.service.device.message_splitter.future_map.register") +@pytest.mark.asyncio +async def test_split_timeseries_registers_futures_and_batches_correctly(mock_register): + splitter = MessageSplitter(max_payload_size=100, max_datapoints=2) + + builder = DeviceUplinkMessageBuilder().set_device_name("deviceX").set_device_profile("profileY") + entry1 = TimeseriesEntry("temp", 1, 100) + entry2 = TimeseriesEntry("humidity", 2, 100) + entry3 = TimeseriesEntry("pressure", 3, 100) + + builder.add_timeseries(entry1) + builder.add_timeseries(entry2) + builder.add_timeseries(entry3) + + parent_future = asyncio.Future() + builder.add_delivery_futures(parent_future) + original_msg = builder.build() + + result = splitter.split_timeseries([original_msg]) + + # Should be split due to datapoints=3 > max_datapoints=2 → 2 batches + assert len(result) == 2 + total_points = sum(m.timeseries_datapoint_count() for m in result) + assert total_points == 3 + + # Validate that register was called for the shared future for each batch + assert mock_register.call_count == 2 + for call in mock_register.call_args_list: + args, _ = call + assert args[0] is parent_future + assert isinstance(args[1], list) + assert len(args[1]) == 1 + shared_future = args[1][0] + assert isinstance(shared_future, asyncio.Future) + assert hasattr(shared_future, "uuid") + + +if __name__ == '__main__': + pytest.main([__file__, '-v', '--tb=short']) diff --git a/tests/service/test_mqtt_manager.py b/tests/service/test_mqtt_manager.py new file mode 100644 index 0000000..4ee9cf0 --- /dev/null +++ b/tests/service/test_mqtt_manager.py @@ -0,0 +1,665 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import ssl +from time import monotonic +from unittest.mock import AsyncMock, MagicMock, patch, call, PropertyMock + +import pytest +import pytest_asyncio + +from tb_mqtt_client.common.exceptions import BackpressureException +from tb_mqtt_client.common.mqtt_message import MqttPublishMessage +from tb_mqtt_client.common.publish_result import PublishResult +from tb_mqtt_client.common.rate_limit.rate_limit import RateLimit +from tb_mqtt_client.common.rate_limit.rate_limiter import RateLimiter +from tb_mqtt_client.service.device.handlers.rpc_response_handler import RPCResponseHandler +from tb_mqtt_client.service.device.message_adapter import MessageAdapter +from tb_mqtt_client.service.mqtt_manager import MQTTManager, IMPLEMENTATION_SPECIFIC_ERROR, QUOTA_EXCEEDED + + +@pytest_asyncio.fixture +async def setup_manager(): + stop_event = asyncio.Event() + message_adapter = MagicMock(spec=MessageAdapter) + on_connect = AsyncMock() + on_disconnect = AsyncMock() + on_publish_result = AsyncMock() + rate_limits_handler = AsyncMock() + rpc_response_handler = MagicMock(spec=RPCResponseHandler) + + manager = MQTTManager( + client_id="test-client", + main_stop_event=stop_event, + message_adapter=message_adapter, + on_connect=on_connect, + on_disconnect=on_disconnect, + on_publish_result=on_publish_result, + rate_limits_handler=rate_limits_handler, + rpc_response_handler=rpc_response_handler + ) + return (manager, stop_event, message_adapter, on_connect, on_disconnect, + on_publish_result, rate_limits_handler, rpc_response_handler) + + +@pytest.mark.asyncio +async def test_is_connected_returns_false_if_not_ready(setup_manager): + manager, *_ = setup_manager + assert not manager.is_connected() + + +@pytest.mark.asyncio +async def test_register_and_unregister_handler(setup_manager): + manager, *_ = setup_manager + + async def dummy(topic, payload): pass + + manager.register_handler("topic/+", dummy) + assert "topic/+" in manager._handlers + manager.unregister_handler("topic/+") + assert "topic/+" not in manager._handlers + + +@pytest.mark.asyncio +async def test_on_disconnect_internal_abnormal_disconnect(setup_manager): + manager, *_ = setup_manager + + fut1 = asyncio.Future() + fut2 = asyncio.Future() + manager._pending_publishes[101] = (fut1, MqttPublishMessage("topic1", b"test"), monotonic()) + manager._pending_publishes[102] = (fut2, MqttPublishMessage("topic2", b"test"), monotonic()) + + manager._backpressure = MagicMock() + manager._backpressure.notify_disconnect = MagicMock() + + manager._on_disconnect_internal(manager._client, reason_code=1) + + assert fut1.done() and fut2.done() + + manager._backpressure.notify_disconnect.assert_called() + + +@pytest.mark.asyncio +async def test_handle_puback_reason_code_unknown_id(setup_manager): + manager, *_ = setup_manager + manager._handle_puback_reason_code(999, 0, {}) + + +@pytest.mark.asyncio +async def test_on_message_internal_handler_exception(setup_manager): + manager, *_ = setup_manager + + async def bad_handler(topic, payload): + raise ValueError("oops") + + manager.register_handler("test/topic", bad_handler) + manager._on_message_internal(manager._client, "test/topic", b"{}", 0, {}) + await asyncio.sleep(0.05) # Let async task run + + +def test_match_topic_full_wildcard(): + assert MQTTManager._match_topic("#", "any/depth/of/topic") + + +@pytest.mark.asyncio +async def test_publish_fails_without_rate_limits(setup_manager): + manager, *_ = setup_manager + manager._MQTTManager__rate_limits_retrieved = False + manager._MQTTManager__is_waiting_for_rate_limits_publish = False + with pytest.raises(RuntimeError, match="Cannot publish before rate limits are retrieved."): + mqtt_message = MqttPublishMessage("topic", b"payload") + await manager.publish(mqtt_message, force=False) + + +@pytest.mark.asyncio +async def test_publish_force_bypasses_limits(setup_manager): + manager, *_ = setup_manager + manager._MQTTManager__rate_limits_retrieved = True + manager._MQTTManager__is_waiting_for_rate_limits_publish = False + manager._rate_limits_ready_event.set() + + manager._client._connection = MagicMock() + manager._client._connection.publish.return_value = (10, b"packet") + manager._client._persistent_storage = MagicMock() + + mqtt_publish_message = MqttPublishMessage("topic", b"payload", qos=1) + await manager.publish(mqtt_publish_message, force=True) + assert manager._client._connection.publish.call_count == 1 + + +@pytest.mark.asyncio +async def test_publish_backpressure_blocks(setup_manager): + manager, *_ = setup_manager + manager._MQTTManager__rate_limits_retrieved = True + manager._MQTTManager__is_waiting_for_rate_limits_publish = False + manager._rate_limits_ready_event.set() + manager._backpressure = MagicMock() + manager._backpressure.should_pause.return_value = True + + with pytest.raises(BackpressureException): + await manager.publish(MqttPublishMessage("t", b"x"), force=False) + + +@pytest.mark.asyncio +async def test_on_disconnect_internal_clears_futures(setup_manager): + manager, *_ = setup_manager + fut = asyncio.Future() + manager._pending_publishes[42] = (fut, MqttPublishMessage("topic", b"payload"), monotonic()) + manager._on_disconnect_internal(manager._client, reason_code=0) + assert not manager._pending_publishes + assert fut.done() + assert isinstance(fut.result(), PublishResult) + + +@pytest.mark.asyncio +async def test_on_message_internal_triggers_handler(setup_manager): + manager, *_ = setup_manager + called = asyncio.Event() + + async def dummy_handler(topic, payload): + called.set() + + manager.register_handler("foo/bar", dummy_handler) + manager._on_message_internal(manager._client, "foo/bar", b"123", 1, {}) + await asyncio.wait_for(called.wait(), timeout=1) + + +@pytest.mark.asyncio +async def test_handle_puback_reason_code(setup_manager): + manager, *_ = setup_manager + fut = asyncio.Future() + fut.uuid = "test-future" + manager._pending_publishes[123] = (fut, MqttPublishMessage("topic", b"payload", delivery_futures=[fut]), + monotonic()) + manager._handle_puback_reason_code(123, 0, {}) + assert fut.done() + assert fut.result().message_id == 123 + + +@pytest.mark.asyncio +async def test_await_ready_timeout(setup_manager): + manager, stop_event, *_ = setup_manager + with patch("tb_mqtt_client.service.mqtt_manager.await_or_stop", side_effect=asyncio.TimeoutError): + await manager.await_ready(timeout=0.01) + + +@pytest.mark.asyncio +async def test_set_rate_limits_allows_ready(setup_manager): + manager, *_ = setup_manager + manager.set_rate_limits_received() + assert manager._rate_limits_ready_event.is_set() + + +@pytest.mark.asyncio +async def test_set_rate_limits_gateway_requires_gateway_limits(setup_manager): + manager, *_ = setup_manager + manager.enable_gateway_mode() + manager.set_rate_limits_received() + assert not manager._rate_limits_ready_event.is_set() + manager.set_gateway_rate_limits_received() + assert manager._rate_limits_ready_event.is_set() + + +@pytest.mark.asyncio +async def test_match_topic_logic(): + assert MQTTManager._match_topic("foo/+", "foo/bar") + assert not MQTTManager._match_topic("foo/bar", "foo/bar/baz") + assert MQTTManager._match_topic("foo/#", "foo/bar/baz") + + +@pytest.mark.asyncio +async def test_check_pending_publishes_timeout(setup_manager): + manager, *_ = setup_manager + fut = asyncio.Future() + manager._pending_publishes[1] = (fut, MqttPublishMessage("topic", b"payload"), monotonic() - 20) + await manager.check_pending_publishes(monotonic()) + assert fut.done() + assert fut.result().reason_code == 408 + + +@pytest.mark.asyncio +async def test_check_pending_publishes_cancelled_when_stopping(setup_manager): + manager, stop_event, *_ = setup_manager + fut = asyncio.Future() + manager._pending_publishes[5] = (fut, MqttPublishMessage("t", b"p"), monotonic()) + stop_event.set() + await manager.check_pending_publishes(monotonic()) + assert fut.cancelled() + stop_event.clear() + + +@pytest.mark.asyncio +async def test_disconnect_swallows_reset_error(setup_manager): + manager, *_ = setup_manager + with patch.object(manager._client, "disconnect", side_effect=ConnectionResetError): + await manager.disconnect() + + +@pytest.mark.asyncio +async def test_subscribe_adds_future(setup_manager): + manager, *_ = setup_manager + manager._client._connection = MagicMock() + manager._client._connection.subscribe.return_value = 42 + + mock_rate_limit = AsyncMock() + rate_limiter = RateLimiter(mock_rate_limit, MagicMock(), MagicMock()) + setattr(manager, "_MQTTManager__rate_limiter", rate_limiter) + + fut = await manager.subscribe("topic", qos=1) + await asyncio.sleep(0.1) + + assert 42 in manager._pending_subscriptions + assert isinstance(fut, asyncio.Future) + mock_rate_limit.consume.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_unsubscribe_adds_future(setup_manager): + manager, *_ = setup_manager + manager._client._connection = MagicMock() + manager._client._connection.unsubscribe.return_value = 77 + + mock_rate_limit = AsyncMock() + rate_limiter = RateLimiter(mock_rate_limit, MagicMock(), MagicMock()) + setattr(manager, "_MQTTManager__rate_limiter", rate_limiter) + + fut = await manager.unsubscribe("topic") + await asyncio.sleep(0.1) + + assert 77 in manager._pending_unsubscriptions + assert isinstance(fut, asyncio.Future) + mock_rate_limit.consume.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_publish_qos_zero_sets_result_immediately(setup_manager): + manager, *_ = setup_manager + manager._MQTTManager__rate_limits_retrieved = True + manager._MQTTManager__is_waiting_for_rate_limits_publish = False + manager._rate_limits_ready_event.set() + + manager._client._connection = MagicMock() + manager._client._connection.publish.return_value = (99, b"packet") + manager._client._persistent_storage = MagicMock() + future = asyncio.Future() + + await manager.publish(MqttPublishMessage("topic", b"payload", qos=0, delivery_futures=future), force=True) + await asyncio.sleep(0.05) + assert future.done() + assert future.result() == PublishResult("topic", 0, -1, 7, 0) + + +@pytest.mark.asyncio +async def test_on_subscribe_internal_sets_future(setup_manager): + manager, *_ = setup_manager + future = asyncio.Future() + manager._pending_subscriptions[5] = future + manager._on_subscribe_internal(manager._client, 5, 1, {}) + assert future.done() + assert future.result() == 5 + + +@pytest.mark.asyncio +async def test_on_unsubscribe_internal_sets_future(setup_manager): + manager, *_ = setup_manager + future = asyncio.Future() + manager._pending_unsubscriptions[11] = future + manager._on_unsubscribe_internal(manager._client, 11, {}) + assert future.done() + assert future.result() == 11 + + +@pytest.mark.asyncio +async def test_handle_puback_reason_code_errors(setup_manager): + manager, *_ = setup_manager + + f1 = asyncio.Future() + manager._pending_publishes[1] = (f1, MqttPublishMessage("topic", b"payload"), 0) + manager._handle_puback_reason_code(1, IMPLEMENTATION_SPECIFIC_ERROR, {}) + assert f1.result().reason_code == IMPLEMENTATION_SPECIFIC_ERROR + + f2 = asyncio.Future() + manager._pending_publishes[2] = (f2, MqttPublishMessage("topic", b"payload"), 0) + manager._handle_puback_reason_code(2, QUOTA_EXCEEDED, {}) + assert f2.result().reason_code == QUOTA_EXCEEDED + + manager._handle_puback_reason_code(9999, 1, {}) + + +@pytest.mark.asyncio +async def test_request_rate_limits_timeout(setup_manager): + manager, stop_event, _, _, _, _, rate_handler, _ = setup_manager + adapter = manager._message_adapter + + req_mock = MagicMock() + req_mock.request_id = "req-id" + + adapter.build_rpc_request.return_value = MqttPublishMessage("topic", b"payload") + + manager._client._connection = MagicMock() + manager._client._connection.publish.return_value = (999, b"fake_packet") + manager._client._persistent_storage = MagicMock() + manager._client._persistent_storage.push_message_nowait = MagicMock() + + future = asyncio.Future() + future.set_result(None) + manager._rpc_response_handler.register_request.return_value = future + + with patch("tb_mqtt_client.entities.data.rpc_request.RPCRequest.build", return_value=req_mock): + await manager._MQTTManager__request_rate_limits() + manager.set_rate_limits_received() + assert manager._rate_limits_ready_event.is_set() + + +@pytest.mark.asyncio +async def test_request_rate_limits_real_timeout_branch(setup_manager): + manager, *_ = setup_manager + manager._MQTTManager__is_waiting_for_rate_limits_publish = False + + manager._message_adapter.build_rpc_request.return_value = MqttPublishMessage("t", b"p") + manager._client._connection = MagicMock() + manager._client._connection.publish.return_value = (1, b"p") + manager._client._persistent_storage = MagicMock() + + never = asyncio.Future() + manager._rpc_response_handler.register_request.return_value = never + + with patch("tb_mqtt_client.service.mqtt_manager.await_or_stop", side_effect=asyncio.TimeoutError): + await manager._MQTTManager__request_rate_limits() + + assert manager._MQTTManager__is_waiting_for_rate_limits_publish is True + assert not manager._rate_limits_ready_event.is_set() + + +@pytest.mark.asyncio +async def test_monitor_ack_timeouts_stops_gracefully(setup_manager): + manager, stop_event, *_ = setup_manager + stop_event.set() + await manager._monitor_ack_timeouts() + stop_event.clear() + + +@pytest.mark.asyncio +async def test_match_topic_exact_match_and_failures(): + assert MQTTManager._match_topic("a/b/c", "a/b/c") + assert not MQTTManager._match_topic("a/b/c", "a/b") + assert not MQTTManager._match_topic("a/+/c", "a/x") + + +@patch("tb_mqtt_client.service.mqtt_manager.run_coroutine_sync") +@pytest.mark.asyncio +async def test_disconnect_reason_code_142_triggers_special_flow(mock_run_sync, setup_manager): + mock_run_sync.return_value = (None, 1, 1) + manager, *_ = setup_manager + manager._client = MagicMock() + manager._backpressure = MagicMock() + manager._on_disconnect_callback = AsyncMock() + manager._run_coroutine_sync = MagicMock(return_value=(None, 1, 1)) + + rate_limit = MagicMock(spec=RateLimit) + manager._MQTTManager__rate_limiter = {"messages": rate_limit} + + fut = asyncio.Future() + manager._pending_publishes[42] = (fut, MqttPublishMessage("topic", b"payload"), 0) + + manager._on_disconnect_internal(manager._client, reason_code=142) + await asyncio.sleep(0.05) + + assert fut.done() + manager._backpressure.notify_disconnect.assert_has_calls([ + call(delay_seconds=10), + call(delay_seconds=1), + ]) + manager._on_disconnect_callback.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_duplicate_publish_path_retransmits_and_persists(setup_manager): + manager, *_ = setup_manager + conn = MagicMock() + protocol = object() + conn._protocol = protocol + manager._client._connection = conn + manager._client._persistent_storage = MagicMock() + + msg = MqttPublishMessage("topic/dup", b"payload", qos=1) + msg.dup = True + msg.message_id = 321 + + with patch("tb_mqtt_client.service.mqtt_manager.PublishPacket.build_package", + return_value=(321, b"rebuilt")) as mock_build: + await manager.publish(msg, force=True) + + mock_build.assert_called_once() + conn.send_package.assert_called_once_with(b"rebuilt") + manager._client._persistent_storage.push_message_nowait.assert_called_once_with(321, msg) + + +@pytest.mark.asyncio +async def test_add_future_chain_processing_success_and_immediate(setup_manager): + manager, *_ = setup_manager + main = asyncio.get_event_loop().create_future() + main.uuid = "main" + + f1 = asyncio.get_event_loop().create_future() + f1.uuid = "f1" + f2 = asyncio.get_event_loop().create_future() + f2.uuid = "f2" + + msg = MqttPublishMessage("t", b"p", qos=1, delivery_futures=[f1, f2]) + + with patch("tb_mqtt_client.service.mqtt_manager.future_map.child_resolved") as child_resolved: + await MQTTManager._add_future_chain_processing(main, msg) + res = PublishResult("t", 1, -1, 1, 0) + main.set_result(res) + await asyncio.sleep(0.05) + assert f1.done() and f2.done() + assert f1.result() == res and f2.result() == res + assert child_resolved.call_count == 2 + + main2 = asyncio.get_event_loop().create_future() + main2.uuid = "main2" + main2.set_result(res) + g1 = asyncio.get_event_loop().create_future(); g1.uuid = "g1" + g2 = asyncio.get_event_loop().create_future(); g2.uuid = "g2" + msg2 = MqttPublishMessage("t2", b"q", qos=1, delivery_futures=[g1, g2]) + await MQTTManager._add_future_chain_processing(main2, msg2) + await asyncio.sleep(0.05) + assert g1.done() and g2.done() + + + +@pytest.mark.asyncio +async def test_add_future_chain_processing_cancel_and_exception(setup_manager): + manager, *_ = setup_manager + + cancelled_main = asyncio.get_event_loop().create_future() + cancelled_main.uuid = "cmain" + cancelled_main.cancel() + c1 = asyncio.get_event_loop().create_future(); c1.uuid = "c1" + c2 = asyncio.get_event_loop().create_future(); c2.uuid = "c2" + msg = MqttPublishMessage("t", b"p", qos=1, delivery_futures=[c1, c2]) + + with patch("tb_mqtt_client.service.mqtt_manager.future_map.child_resolved"): + await MQTTManager._add_future_chain_processing(cancelled_main, msg) + await asyncio.sleep(0.05) + assert c1.done() and c2.done() + assert isinstance(c1.result(), PublishResult) and isinstance(c2.result(), PublishResult) + + exc_main = asyncio.get_event_loop().create_future() + exc_main.uuid = "emain" + exc_main.set_exception(RuntimeError("boom")) + e1 = asyncio.get_event_loop().create_future(); e1.uuid = "e1" + e2 = asyncio.get_event_loop().create_future(); e2.uuid = "e2" + msg2 = MqttPublishMessage("t2", b"p2", qos=1, delivery_futures=[e1, e2]) + + with patch("tb_mqtt_client.service.mqtt_manager.future_map.child_resolved"): + await MQTTManager._add_future_chain_processing(exc_main, msg2) + await asyncio.sleep(0.05) + assert e1.done() and e2.done() + assert isinstance(e1.result(), PublishResult) and isinstance(e2.result(), PublishResult) + + +@pytest.mark.asyncio +async def test_process_regular_publish_qos1_resolves_delivery_on_puback(setup_manager): + manager, *_ = setup_manager + manager._client._connection = MagicMock() + manager._client._connection.publish.return_value = (555, b"p") + manager._client._persistent_storage = MagicMock() + + d1 = asyncio.get_event_loop().create_future(); + d1.uuid = "d1" + d2 = asyncio.get_event_loop().create_future(); + d2.uuid = "d2" + msg = MqttPublishMessage("t", b"payload", qos=1, delivery_futures=[d1, d2]) + await manager.publish(msg, force=True) + + manager._handle_puback_reason_code(555, 0, {}) + await asyncio.sleep(0.05) + assert d1.done() and d2.done() + assert isinstance(d1.result(), PublishResult) and isinstance(d2.result(), PublishResult) + + +@pytest.mark.asyncio +async def test_on_connect_internal_failure_does_not_set_ready(setup_manager): + manager, *_ = setup_manager + manager._connected_event.set() + manager._on_connect_internal(manager._client, session_present=False, reason_code=128, properties=None) + assert not manager._connected_event.is_set() + + + +@pytest.mark.asyncio +async def test_on_connect_internal_success_triggers_limits_flow(setup_manager): + manager, *_ = setup_manager + called = asyncio.Event() + + async def fake_limits_flow(): + called.set() + + manager._client._connection = type("Conn", (), {})() + + setattr(manager, "_MQTTManager__handle_connect_and_limits", fake_limits_flow) + + manager._on_connect_internal(manager._client, session_present=True, reason_code=0, properties={}) + await asyncio.wait_for(called.wait(), timeout=1.0) + assert manager._connected_event.is_set() + + +@pytest.mark.asyncio +async def test_patch_client_for_retry_logic_assigns_method(setup_manager): + manager, *_ = setup_manager + + async def put_retry(msg: MqttPublishMessage): + return None + + manager.patch_client_for_retry_logic(put_retry) + assert manager._client.put_retry_message is put_retry + + +@pytest.mark.asyncio +async def test_stop_cancels_tasks_and_calls_patch_utils(setup_manager): + manager, *_ = setup_manager + manager._patch_utils.stop_retry_task = AsyncMock() + + with patch.object(type(manager._client), "is_connected", new_callable=PropertyMock, return_value=False): + await manager.stop() + + manager._patch_utils.stop_retry_task.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_connect_tls_creates_default_ssl_context(setup_manager): + manager, *_ = setup_manager + + manager._client.connect = AsyncMock() + + with patch.object(type(manager._client), "is_connected", new_callable=PropertyMock, return_value=True), \ + patch.object(ssl, "create_default_context", autospec=True) as create_ctx: + await manager.connect("host", 8883, tls=True, ssl_context=None) + + create_ctx.assert_called_once() + + +@pytest.mark.asyncio +async def test_connect_loop_handles_wait_exception_and_exits_when_connected(setup_manager): + manager, stop_event, *_ = setup_manager + + manager._client.connect = AsyncMock() + + with patch.object(type(manager._client), "is_connected", new_callable=PropertyMock, + side_effect=[False, True]), \ + patch.object(manager._connected_event, "wait", + AsyncMock(side_effect=[Exception("boom"), None])): + await manager.connect("host", 1883, tls=False) + + +@pytest.mark.asyncio +async def test_publish_wait_rate_limits_timeout_requests_then_raises(setup_manager): + manager, *_ = setup_manager + + manager._MQTTManager__rate_limits_retrieved = True + manager._MQTTManager__is_waiting_for_rate_limits_publish = False + manager._backpressure.should_pause = MagicMock(return_value=False) + + with patch("tb_mqtt_client.service.mqtt_manager.await_or_stop", side_effect=asyncio.TimeoutError), \ + patch.object(manager, "_MQTTManager__request_rate_limits", new=AsyncMock()) as req_limits: + with pytest.raises(RuntimeError, match="Timeout waiting for rate limits."): + await manager.publish(MqttPublishMessage("t", b"p"), force=False) + + req_limits.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_disconnect_rate_limit_timeout_sets_default_backpressure_delay(setup_manager): + manager, *_ = setup_manager + + rate_limit = MagicMock(spec=RateLimit) + setattr(manager, "_MQTTManager__rate_limiter", {"messages": rate_limit}) + manager._backpressure = MagicMock() + + with patch("tb_mqtt_client.service.mqtt_manager.run_coroutine_sync", side_effect=TimeoutError): + manager._on_disconnect_internal(manager._client, reason_code=131) + + manager._backpressure.notify_disconnect.assert_called_with(delay_seconds=10) + + +@pytest.mark.asyncio +async def test_add_future_chain_processing_sets_exception_when_child_resolved_raises(setup_manager): + manager, *_ = setup_manager + + main = asyncio.get_event_loop().create_future() + main.uuid = "main" + + f_ok = asyncio.get_event_loop().create_future(); f_ok.uuid = "ok" + f_err = asyncio.get_event_loop().create_future(); f_err.uuid = "err" + msg = MqttPublishMessage("topic", b"payload", qos=1, delivery_futures=[f_ok, f_err]) + + with patch("tb_mqtt_client.service.mqtt_manager.future_map.child_resolved", + side_effect=RuntimeError("boom")): + await MQTTManager._add_future_chain_processing(main, msg) + + main.set_result(PublishResult("topic", 1, -1, len(b"payload"), 0)) + await asyncio.sleep(0.05) + + assert f_ok.done() + assert isinstance(f_ok.result(), PublishResult) + + assert f_err.done() + assert isinstance(f_err.exception(), Exception) + + +if __name__ == '__main__': + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/tb_device_mqtt_client_tests.py b/tests/tb_device_mqtt_client_tests.py deleted file mode 100644 index b8bef22..0000000 --- a/tests/tb_device_mqtt_client_tests.py +++ /dev/null @@ -1,436 +0,0 @@ -# Copyright 2025. ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from time import sleep - -from tb_device_mqtt import TBDeviceMqttClient - - -class TBDeviceMqttClientTests(unittest.TestCase): - """ - Before running tests, do the next steps: - 1. Create device "Example Name" in ThingsBoard - 2. Add shared attribute "attr" with value "hello" to created device - 3. Add client attribute "atr3" with value "value3" to created device - """ - - client = None - - shared_attribute_name = 'attr' - shared_attribute_value = 'hello' - - client_attribute_name = 'atr3' - client_attribute_value = 'value3' - - request_attributes_result = None - subscribe_to_attribute = None - subscribe_to_attribute_all = None - - @classmethod - def setUpClass(cls) -> None: - cls.client = TBDeviceMqttClient('127.0.0.1', 1883, 'TEST_DEVICE_TOKEN') - cls.client.connect(timeout=1) - - @classmethod - def tearDownClass(cls) -> None: - cls.client.disconnect() - - @staticmethod - def on_attributes_change_callback(result, exception=None): - if exception is not None: - TBDeviceMqttClientTests.request_attributes_result = exception - else: - TBDeviceMqttClientTests.request_attributes_result = result - - @staticmethod - def callback_for_specific_attr(result, *args): - TBDeviceMqttClientTests.subscribe_to_attribute = result - - @staticmethod - def callback_for_everything(result, *args): - TBDeviceMqttClientTests.subscribe_to_attribute_all = result - - def test_request_attributes(self): - self.client.request_attributes(shared_keys=[self.shared_attribute_name], - callback=self.on_attributes_change_callback) - sleep(3) - self.assertEqual(self.request_attributes_result, - {'shared': {self.shared_attribute_name: self.shared_attribute_value}}) - - self.client.request_attributes(client_keys=[self.client_attribute_name], - callback=self.on_attributes_change_callback) - sleep(3) - self.assertEqual(self.request_attributes_result, - {'client': {self.client_attribute_name: self.client_attribute_value}}) - - def test_send_telemetry_and_attr(self): - telemetry = {"temperature": 41.9, "humidity": 69, "enabled": False, "currentFirmwareVersion": "v1.2.2"} - self.assertEqual(self.client.send_telemetry(telemetry, 0).get(), 0) - - attributes = {"sensorModel": "DHT-22", self.client_attribute_name: self.client_attribute_value} - self.assertEqual(self.client.send_attributes(attributes, 0).get(), 0) - - def test_subscribe_to_attrs(self): - sub_id_1 = self.client.subscribe_to_attribute(self.shared_attribute_name, self.callback_for_specific_attr) - sub_id_2 = self.client.subscribe_to_all_attributes(self.callback_for_everything) - - sleep(1) - value = input("Updated attribute value: ") - - self.assertEqual(self.subscribe_to_attribute_all, {self.shared_attribute_name: value}) - self.assertEqual(self.subscribe_to_attribute, {self.shared_attribute_name: value}) - - self.client.unsubscribe_from_attribute(sub_id_1) - self.client.unsubscribe_from_attribute(sub_id_2) - - -class TestSplitMessageVariants(unittest.TestCase): - - def test_empty_input(self): - self.assertEqual(TBDeviceMqttClient._split_message([], 10, 100), []) - - def test_rpc_payload(self): - rpc_payload = {'device': 'dev1'} - result = TBDeviceMqttClient._split_message(rpc_payload, 10, 100) - self.assertEqual(len(result), 1) - self.assertEqual(result[0]['data'], rpc_payload) - - def test_single_value_message(self): - msg = [{'ts': 1, 'values': {'a': 1}}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - self.assertEqual(result[0]['data'][0]['values'], {'a': 1}) - - def test_timestamp_change_split(self): - msg = [{'ts': 1, 'values': {'a': 1}}, {'ts': 2, 'values': {'b': 2}}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - self.assertEqual(len(result), 2) - - def test_exceeding_datapoint_limit(self): - msg = [{'ts': 1, 'values': {f'k{i}': i for i in range(10)}}] - result = TBDeviceMqttClient._split_message(msg, 5, 1000) - self.assertGreaterEqual(len(result), 2) - - def test_message_with_metadata(self): - msg = [{'ts': 1, 'values': {'a': 1}, 'metadata': {'unit': 'C'}}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - self.assertIn('metadata', result[0]['data'][0]) - - def test_large_payload_split(self): - msg = [{'ts': 1, 'values': {f'key{i}': 'v'*50 for i in range(5)}}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - self.assertGreater(len(result), 1) - - def test_metadata_present_with_ts(self): - msg = [{'ts': 123456789, 'values': {'temperature': 25}, 'metadata': {'unit': 'C'}}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - self.assertIn('metadata', result[0]['data'][0]) - self.assertEqual(result[0]['data'][0]['metadata'], {'unit': 'C'}) - - def test_metadata_ignored_without_ts(self): - msg = [{'values': {'temperature': 25}, 'metadata': {'unit': 'C'}}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - self.assertTrue(all('metadata' not in entry for r in result for entry in r['data'])) - - def test_grouping_same_ts_exceeds_datapoint_limit(self): - msg = [ - {'ts': 1, 'values': {'a': 1, 'b': 2, 'c': 3}}, - {'ts': 1, 'values': {'d': 4, 'e': 5}} - ] - result = TBDeviceMqttClient._split_message(msg, datapoints_max_count=3, max_payload_size=1000) - self.assertGreater(len(result), 1) - total_keys = set() - for part in result: - for d in part['data']: - total_keys.update(d['values'].keys()) - self.assertEqual(total_keys, {'a', 'b', 'c', 'd', 'e'}) - - def test_grouping_same_ts_exceeds_payload_limit(self): - msg = [ - {'ts': 1, 'values': {'a': 'x'*30}}, - {'ts': 1, 'values': {'b': 'y'*30}}, - {'ts': 1, 'values': {'c': 'z'*30}} - ] - result = TBDeviceMqttClient._split_message(msg, datapoints_max_count=10, max_payload_size=64) - self.assertGreater(len(result), 1) - all_keys = set() - for r in result: - for d in r['data']: - all_keys.update(d['values'].keys()) - self.assertEqual(all_keys, {'a', 'b', 'c'}) - - def test_individual_messages_not_grouped_if_payload_limit_exceeded(self): - msg = [ - {'ts': 1, 'values': {'a': 'x' * 100}}, - {'ts': 1, 'values': {'b': 'y' * 100}}, - {'ts': 1, 'values': {'c': 'z' * 100}} - ] - result = TBDeviceMqttClient._split_message(msg, datapoints_max_count=10, max_payload_size=110) - self.assertEqual(len(result), 3) - - grouped_keys = [] - for r in result: - for d in r['data']: - grouped_keys.extend(d['values'].keys()) - - self.assertEqual(set(grouped_keys), {'a', 'b', 'c'}) - for r in result: - for d in r['data']: - self.assertLessEqual(sum(len(k) + len(str(v)) for k, v in d['values'].items()), 110) - - def test_partial_grouping_due_to_payload_limit(self): - msg = [ - {'ts': 1, 'values': {'a': 'x' * 10, 'b': 'y' * 10}}, # should be grouped together - {'ts': 1, 'values': {'c': 'z' * 100}}, # should be on its own due to size - {'ts': 1, 'values': {'d': 'w' * 10, 'e': 'q' * 10}} # should be grouped again - ] - result = TBDeviceMqttClient._split_message(msg, datapoints_max_count=10, max_payload_size=64) - self.assertEqual(len(result), 3) - - result_keys = [] - for r in result: - keys = [] - for d in r['data']: - keys.extend(d['values'].keys()) - result_keys.append(set(keys)) - - self.assertIn({'a', 'b'}, result_keys) - self.assertIn({'c'}, result_keys) - self.assertIn({'d', 'e'}, result_keys) - - def test_partial_grouping_due_to_datapoint_limit(self): - msg = [ - {'ts': 1, 'values': {'a': 1, 'b': 2}}, # grouped - {'ts': 1, 'values': {'c': 3, 'd': 4}}, # grouped - {'ts': 1, 'values': {'e': 5}}, # forced into next group due to datapoint limit - {'ts': 1, 'values': {'f': 6}}, # grouped with above - ] - # Max datapoints per message is 4 (after subtracting 1 in implementation) - result = TBDeviceMqttClient._split_message(msg, datapoints_max_count=5, max_payload_size=1000) - self.assertEqual(len(result), 2) - - all_keys = [] - for r in result: - keys = set() - for d in r['data']: - keys.update(d['values'].keys()) - all_keys.append(keys) - - # First group should contain a, b, c, d (4 datapoints) - self.assertIn({'a', 'b', 'c', 'd'}, all_keys) - # Second group should contain e, f (2 datapoints) - self.assertIn({'e', 'f'}, all_keys) - - def test_values_included_only_when_ts_present(self): - msg = [{'values': {'a': 1, 'b': 2}}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - self.assertEqual(result[0]['data'][0], {'a': 1, 'b': 2}) - - def test_missing_values_field_uses_whole_message(self): - msg = [{'ts': 123, 'a': 1, 'b': 2}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - self.assertIn('values', result[0]['data'][0]) - self.assertEqual(result[0]['data'][0]['values'], {'a': 1, 'b': 2, 'ts': 123}) - - def test_metadata_conflict_same_ts_no_grouping(self): - msg = [ - {'ts': 1, 'values': {'a': 1}, 'metadata': {'unit': 'C'}}, - {'ts': 1, 'values': {'b': 2}, 'metadata': {'unit': 'F'}} - ] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - - self.assertEqual(len(result), 2) - - metadata_sets = [d['data'][0].get('metadata') for d in result] - self.assertIn({'unit': 'C'}, metadata_sets) - self.assertIn({'unit': 'F'}, metadata_sets) - - value_keys_sets = [set(d['data'][0]['values'].keys()) for d in result] - self.assertIn({'a'}, value_keys_sets) - self.assertIn({'b'}, value_keys_sets) - - def test_non_dict_message_is_skipped(self): - msg = [{'ts': 1, 'values': {'a': 1}}, 'this_is_not_a_dict', {'ts': 1, 'values': {'b': 2}}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - self.assertEqual(len(result), 1) - values = result[0]['data'][0]['values'] - self.assertEqual(set(values.keys()), {'a', 'b'}) - - def test_multiple_dicts_without_ts_values_metadata(self): - msg = [{'a': 1}, {'b': 2}, {'c': 3}] - result = TBDeviceMqttClient._split_message(msg, 10, 1000) - self.assertEqual(len(result), 1) # grouped - combined_keys = {} - for d in result[0]['data']: - combined_keys.update(d) - self.assertEqual(set(combined_keys.keys()), {'a', 'b', 'c'}) - - def test_multiple_dicts_without_ts_split_by_payload(self): - msg = [{'a': 'x' * 60}, {'b': 'y' * 60}, {'c': 'z' * 60}] - result = TBDeviceMqttClient._split_message(msg, datapoints_max_count=10, max_payload_size=64) - self.assertEqual(len(result), 3) # each too large to group - keys = [list(r['data'][0].keys())[0] for r in result] - self.assertEqual(set(keys), {'a', 'b', 'c'}) - - def test_mixed_dicts_with_and_without_ts(self): - msg = [ - {'ts': 1, 'values': {'a': 1}}, - {'b': 2}, - {'ts': 1, 'values': {'c': 3}}, - {'d': 4} - ] - result = TBDeviceMqttClient._split_message(msg, 10, 1000) - # Should split into at least 2 chunks: one for ts=1 and one for ts=None - self.assertGreaterEqual(len(result), 2) - - ts_chunks = [d for r in result for d in r['data'] if 'ts' in d] - raw_chunks = [d for r in result for d in r['data'] if 'ts' not in d] - - ts_keys = set() - for d in ts_chunks: - ts_keys.update(d['values'].keys()) - - raw_keys = set() - for d in raw_chunks: - raw_keys.update(d.keys()) - - self.assertEqual(ts_keys, {'a', 'c'}) - self.assertEqual(raw_keys, {'b', 'd'}) - - def test_complex_mixed_messages(self): - msg = [ - {'ts': 1, 'values': {'a': 1}}, - {'ts': 1, 'values': {'b': 2}, 'metadata': {'unit': 'C'}}, - - {'ts': 2, 'values': {'c1': 1, 'c2': 2, 'c3': 3}}, - {'ts': 2, 'values': {'c4': 4, 'c5': 5}}, - - {'ts': 3, 'values': {'x': 'x' * 60}}, - {'ts': 3, 'values': {'y': 'y' * 60}}, - - {'m1': 1, 'm2': 2}, - - 123, - - {'m3': 'a' * 100}, - {'m4': 'b' * 100}, - - {'k1': 1, 'k2': 2, 'k3': 3}, - {'k4': 4, 'k5': 5}, - - {'ts': 1, 'values': {'z': 99}}, - ] - - result = TBDeviceMqttClient._split_message(msg, datapoints_max_count=4, max_payload_size=64) - - all_ts_groups = {} - raw_chunks = [] - for r in result: - for entry in r['data']: - if isinstance(entry, dict) and 'ts' in entry: - ts = entry['ts'] - if ts not in all_ts_groups: - all_ts_groups[ts] = set() - all_ts_groups[ts].update(entry.get('values', {}).keys()) - if 'metadata' in entry: - self.assertIsInstance(entry['metadata'], dict) - else: - raw_chunks.append(entry) - - self.assertIn(1, all_ts_groups) - self.assertEqual(all_ts_groups[1], {'a', 'b', 'z'}) - - self.assertIn(2, all_ts_groups) - self.assertEqual(all_ts_groups[2], {'c1', 'c2', 'c3', 'c4', 'c5'}) - - self.assertIn(3, all_ts_groups) - self.assertEqual(all_ts_groups[3], {'x', 'y'}) - - all_raw_keys = [set(entry.keys()) for entry in raw_chunks] - - expected_raw_key_sets = [ - {'m1', 'm2'}, - {'m3'}, - {'m4'}, - {'k1', 'k2', 'k3'}, - {'k4', 'k5'} - ] - - for expected_keys in expected_raw_key_sets: - self.assertIn(expected_keys, all_raw_keys) - - for r in raw_chunks: - self.assertLessEqual(len(r), 4) # Max datapoints = 4 - - total_size = sum(len(k) + len(str(v)) for k, v in r.items()) - - if len(r) > 1: - self.assertLessEqual(total_size, 64) - if total_size > 64: - self.assertEqual(len(r), 1) - - self.assertGreaterEqual(len(result), 8) - - def test_empty_values_should_skip_or_include_empty(self): - msg = [{'ts': 1, 'values': {}}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - self.assertEqual(result, []) - - def test_duplicate_keys_within_same_ts(self): - msg = [{'ts': 1, 'values': {'a': 1}}, {'ts': 1, 'values': {'a': 2}}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - values = {} - for d in result[0]['data']: - values.update(d['values']) - self.assertEqual(values['a'], 2) # Last value wins - - def test_partial_metadata_presence(self): - msg = [{'ts': 1, 'values': {'a': 1}, 'metadata': {'unit': 'C'}}, - {'ts': 1, 'values': {'b': 2}}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - for d in result[0]['data']: - if d['values'].keys() == {'a', 'b'} or d['values'].keys() == {'b', 'a'}: - self.assertIn('metadata', d) - - def test_non_dict_metadata_should_be_ignored(self): - msg = [{'ts': 1, 'values': {'a': 1}, 'metadata': ['invalid']}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - self.assertTrue(all( - isinstance(d.get('metadata', {}), dict) or 'metadata' not in d - for r in result for d in r['data'] - )) - - def test_non_list_message_pack_single_dict_raw(self): - msg = {'a': 1, 'b': 2} - result = TBDeviceMqttClient._split_message(msg, 10, 100) - self.assertEqual(result[0]['data'][0], msg) - - def test_nested_value_object_should_count_size_correctly(self): - msg = [{'ts': 1, 'values': {'a': {'nested': 'structure'}}}] - result = TBDeviceMqttClient._split_message(msg, 10, 1000) - total_size = sum(len(k) + len(str(v)) for r in result for d in r['data'] for k, v in d['values'].items()) - self.assertGreater(total_size, 0) - - def test_raw_duplicate_keys_overwrite_behavior(self): - msg = [{'a': 1}, {'a': 2}] - result = TBDeviceMqttClient._split_message(msg, 10, 100) - all_data = {} - for r in result: - for d in r['data']: - all_data.update(d) - self.assertEqual(all_data['a'], 2) # Last value wins - - -if __name__ == '__main__': - unittest.main('tb_device_mqtt_client_tests') diff --git a/tests/tb_gateway_mqtt_client_tests.py b/tests/tb_gateway_mqtt_client_tests.py deleted file mode 100644 index 6893183..0000000 --- a/tests/tb_gateway_mqtt_client_tests.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright 2025. ThingsBoard -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from time import sleep, time - -from tb_gateway_mqtt import TBGatewayMqttClient - - -class TBGatewayMqttClientTests(unittest.TestCase): - """ - Before running tests, do the next steps: - 1. Create device "Example Name" in ThingsBoard - 2. Add shared attribute "attr" with value "hello" to created device - """ - - client = None - - device_name = 'Example Name' - shared_attr_name = 'attr' - shared_attr_value = 'hello' - - request_attributes_result = None - subscribe_to_attribute = None - subscribe_to_attribute_all = None - subscribe_to_device_attribute_all = None - - @classmethod - def setUpClass(cls) -> None: - cls.client = TBGatewayMqttClient('127.0.0.1', 1883, 'TEST_GATEWAY_TOKEN') - cls.client.connect(timeout=1) - - @classmethod - def tearDownClass(cls) -> None: - cls.client.disconnect() - - @staticmethod - def request_attributes_callback(result, exception=None): - if exception is not None: - TBGatewayMqttClientTests.request_attributes_result = exception - else: - TBGatewayMqttClientTests.request_attributes_result = result - - @staticmethod - def callback(result): - TBGatewayMqttClientTests.subscribe_to_device_attribute_all = result - - @staticmethod - def callback_for_everything(result): - TBGatewayMqttClientTests.subscribe_to_attribute_all = result - - @staticmethod - def callback_for_specific_attr(result): - TBGatewayMqttClientTests.subscribe_to_attribute = result - - def test_connect_disconnect_device(self): - self.assertEqual(self.client.gw_connect_device(self.device_name).rc, 0) - self.assertEqual(self.client.gw_disconnect_device(self.device_name).rc, 0) - - def test_request_attributes(self): - self.client.gw_request_shared_attributes(self.device_name, [self.shared_attr_name], - self.request_attributes_callback) - sleep(3) - self.assertEqual(self.request_attributes_result, - {'id': 1, 'device': self.device_name, 'value': self.shared_attr_value}) - - def test_send_telemetry_and_attributes(self): - attributes = {"atr1": 1, "atr2": True, "atr3": "value3"} - telemetry = {"ts": int(round(time() * 1000)), "values": {"key1": "11"}} - self.assertEqual(self.client.gw_send_attributes(self.device_name, attributes).get(), 0) - self.assertEqual(self.client.gw_send_telemetry(self.device_name, telemetry).get(), 0) - - def test_subscribe_to_attributes(self): - self.client.gw_connect_device(self.device_name) - - self.client.gw_subscribe_to_all_attributes(self.callback_for_everything) - self.client.gw_subscribe_to_attribute(self.device_name, self.shared_attr_name, self.callback_for_specific_attr) - sub_id = self.client.gw_subscribe_to_all_device_attributes(self.device_name, self.callback) - - sleep(1) - value = input("Updated attribute value: ") - - self.assertEqual(self.subscribe_to_attribute, - {'device': self.device_name, 'data': {self.shared_attr_name: value}}) - self.assertEqual(self.subscribe_to_attribute_all, - {'device': self.device_name, 'data': {self.shared_attr_name: value}}) - self.assertEqual(self.subscribe_to_device_attribute_all, - {'device': self.device_name, 'data': {self.shared_attr_name: value}}) - - self.client.gw_unsubscribe(sub_id) - - -if __name__ == '__main__': - unittest.main('tb_gateway_mqtt_client_tests')