diff --git a/tb_mqtt_client/common/install_package_utils.py b/tb_mqtt_client/common/install_package_utils.py new file mode 100644 index 0000000..1227418 --- /dev/null +++ b/tb_mqtt_client/common/install_package_utils.py @@ -0,0 +1,48 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sys import executable +from subprocess import check_call, CalledProcessError, DEVNULL +from pkg_resources import get_distribution, DistributionNotFound + + +def install_package(package, version="upgrade"): + result = False + + def try_install(args, suppress_stderr=False): + try: + stderr = DEVNULL if suppress_stderr else None + check_call([executable, "-m", "pip", *args], stderr=stderr) + return True + except CalledProcessError: + return False + + if version.lower() == "upgrade": + args = ["install", package, "--upgrade"] + result = try_install(args + ["--user"], suppress_stderr=True) + if not result: + result = try_install(args) + else: + try: + installed_version = get_distribution(package).version + if installed_version == version: + return True + except DistributionNotFound: + pass + install_version = f"{package}=={version}" if ">=" not in version else f"{package}{version}" + args = ["install", install_version] + if not try_install(args + ["--user"], suppress_stderr=True): + result = try_install(args) + + return result diff --git a/tb_mqtt_client/constants/firmware.py b/tb_mqtt_client/constants/firmware.py new file mode 100644 index 0000000..cdc03b4 --- /dev/null +++ b/tb_mqtt_client/constants/firmware.py @@ -0,0 +1,36 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + + +class FirmwareStates(Enum): + IDLE = 'IDLE' + DOWNLOADING = 'DOWNLOADING' + DOWNLOADED = 'DOWNLOADED' + VERIFIED = 'VERIFIED' + FAILED = 'FAILED' + UPDATING = 'UPDATING' + UPDATED = 'UPDATED' + + +FW_TITLE_ATTR = "fw_title" +FW_VERSION_ATTR = "fw_version" +FW_CHECKSUM_ATTR = "fw_checksum" +FW_CHECKSUM_ALG_ATTR = "fw_checksum_algorithm" +FW_SIZE_ATTR = "fw_size" +FW_STATE_ATTR = "fw_state" + +REQUIRED_SHARED_KEYS = [FW_CHECKSUM_ATTR, FW_CHECKSUM_ALG_ATTR, + FW_SIZE_ATTR, FW_TITLE_ATTR, FW_VERSION_ATTR] diff --git a/tb_mqtt_client/constants/mqtt_topics.py b/tb_mqtt_client/constants/mqtt_topics.py index 3986551..5f98343 100644 --- a/tb_mqtt_client/constants/mqtt_topics.py +++ b/tb_mqtt_client/constants/mqtt_topics.py @@ -30,6 +30,12 @@ DEVICE_RPC_RESPONSE_TOPIC_FOR_SUBSCRIPTION = DEVICE_RPC_TOPIC + RESPONSE_TOPIC_SUFFIX + "/" + WILDCARD # Device Claim topic DEVICE_CLAIM_TOPIC = "v1/devices/me/claim" +# Device Provisioning topics +PROVISION_REQUEST_TOPIC = "/provision/request" +PROVISION_RESPONSE_TOPIC = "/provision/response" +# Device Firmware Update topics +DEVICE_FIRMWARE_UPDATE_RESPONSE_TOPIC = "v2/fw/response/+/chunk/+" +DEVICE_FIRMWARE_UPDATE_REQUEST_TOPIC = "v2/fw/request/{request_id}/chunk/{current_chunk}" # V1 Topics for Gateway API BASE_GATEWAY_TOPIC = "v1/gateway" @@ -69,3 +75,7 @@ def build_gateway_device_attributes_topic() -> str: def build_gateway_rpc_topic() -> str: return GATEWAY_RPC_TOPIC + + +def build_firmware_update_request_topic(request_id: int, current_chunk: int) -> str: + return DEVICE_FIRMWARE_UPDATE_REQUEST_TOPIC.format(request_id=request_id, current_chunk=current_chunk) diff --git a/tb_mqtt_client/constants/provisioning.py b/tb_mqtt_client/constants/provisioning.py new file mode 100644 index 0000000..a8cf8c8 --- /dev/null +++ b/tb_mqtt_client/constants/provisioning.py @@ -0,0 +1,29 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + + +class ProvisioningResponseStatus(Enum): + SUCCESS = "SUCCESS" + ERROR = "FAILURE" + + def __str__(self): + return self.value + + +class ProvisioningCredentialsType(Enum): + ACCESS_TOKEN = "ACCESS_TOKEN" + MQTT_BASIC = "MQTT_BASIC" + X509_CERTIFICATE = "X509_CERTIFICATE" diff --git a/tb_mqtt_client/entities/data/provisioning_request.py b/tb_mqtt_client/entities/data/provisioning_request.py new file mode 100644 index 0000000..e54db43 --- /dev/null +++ b/tb_mqtt_client/entities/data/provisioning_request.py @@ -0,0 +1,77 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Optional + +from tb_mqtt_client.constants.provisioning import ProvisioningCredentialsType + + +class ProvisioningRequest: + def __init__(self, host, credentials: 'ProvisioningCredentials', port: str = "1883", + device_name: Optional[str] = None, gateway: Optional[bool] = False): + self.host = host + self.port = port + self.credentials = credentials + self.device_name = device_name + self.gateway = gateway + + +class ProvisioningCredentials(ABC): + @abstractmethod + def __init__(self, provision_device_key: str, provision_device_secret: str): + self.provision_device_key = provision_device_key + self.provision_device_secret = provision_device_secret + self.credentials_type: ProvisioningCredentialsType + + +class AccessTokenProvisioningCredentials(ProvisioningCredentials): + def __init__(self, provision_device_key: str, provision_device_secret: str, access_token: Optional[str] = None): + super().__init__(provision_device_key, provision_device_secret) + self.access_token = access_token + self.credentials_type = ProvisioningCredentialsType.ACCESS_TOKEN + + +class BasicProvisioningCredentials(ProvisioningCredentials): + def __init__(self, provision_device_key, provision_device_secret, + client_id: Optional[str] = None, username: Optional[str] = None, password: Optional[str] = None): + super().__init__(provision_device_key, provision_device_secret) + self.client_id = client_id + self.username = username + self.password = password + self.credentials_type = ProvisioningCredentialsType.MQTT_BASIC + + +class X509ProvisioningCredentials(ProvisioningCredentials): + def __init__(self, provision_device_key, provision_device_secret, + private_key_path: str, public_cert_path: str, ca_cert_path: str): + super().__init__(provision_device_key, provision_device_secret) + self.private_key_path = private_key_path + self.ca_cert_path = ca_cert_path + self.public_cert_path = public_cert_path + self.public_cert = self._load_public_cert_path(public_cert_path) + self.credentials_type = ProvisioningCredentialsType.X509_CERTIFICATE + + def _load_public_cert_path(public_cert_path): + content = '' + + try: + with open(public_cert_path, 'r') as file: + content = file.read() + except FileNotFoundError: + raise FileNotFoundError(f"Public certificate file not found: {public_cert_path}") + except IOError as e: + raise IOError(f"Error reading public certificate file {public_cert_path}: {e}") + + return content.strip() if content else None diff --git a/tb_mqtt_client/entities/data/provisioning_response.py b/tb_mqtt_client/entities/data/provisioning_response.py new file mode 100644 index 0000000..82e7a73 --- /dev/null +++ b/tb_mqtt_client/entities/data/provisioning_response.py @@ -0,0 +1,73 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional + +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.constants.provisioning import ProvisioningResponseStatus +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest, ProvisioningCredentialsType + + +@dataclass(frozen=True) +class ProvisioningResponse: + status: ProvisioningResponseStatus + result: Optional[DeviceConfig] = None + error: Optional[str] = None + + def __new__(cls, *args, **kwargs): + raise TypeError("Direct instantiation of ProvisioningResponse is not allowed. Use ProvisioningResponse.build(result, error).") # noqa + + def __repr__(self) -> str: + return f"ProvisioningResponse(status={self.status}, result={self.result}, error={self.error})" + + @classmethod + def build(cls, provision_request: 'ProvisioningRequest', payload: dict) -> 'ProvisioningResponse': + """ + Constructs a ProvisioningResponse explicitly. + """ + self = object.__new__(cls) + + if payload.get('status') == ProvisioningResponseStatus.ERROR.value: + object.__setattr__(self, 'error', payload.get('errorMsg')) + object.__setattr__(self, 'status', ProvisioningResponseStatus.ERROR) + object.__setattr__(self, 'result', None) + else: + device_config = ProvisioningResponse._build_device_config(provision_request, payload) + + object.__setattr__(self, 'result', device_config) + object.__setattr__(self, 'status', ProvisioningResponseStatus.SUCCESS) + object.__setattr__(self, 'error', None) + + return self + + @staticmethod + def _build_device_config(provision_request: 'ProvisioningRequest', payload: dict): + device_config = DeviceConfig() + device_config.host = provision_request.host + device_config.port = provision_request.port + + if provision_request.credentials.credentials_type is None or \ + provision_request.credentials.credentials_type == ProvisioningCredentialsType.ACCESS_TOKEN: + device_config.access_token = payload['credentialsValue'] + elif provision_request.credentials.credentials_type == ProvisioningCredentialsType.MQTT_BASIC: + device_config.client_id = payload['credentialsValue']['clientId'] + device_config.username = payload['credentialsValue']['userName'] + device_config.password = payload['credentialsValue']['password'] + elif provision_request.credentials.credentials_type == ProvisioningCredentialsType.X509_CERTIFICATE: + device_config.ca_cert = provision_request.credentials.ca_cert_path + device_config.client_cert = provision_request.credentials.public_cert_path + device_config.private_key = provision_request.credentials.private_key_path + + return device_config diff --git a/tb_mqtt_client/entities/provisioning_client.py b/tb_mqtt_client/entities/provisioning_client.py new file mode 100644 index 0000000..a2718c3 --- /dev/null +++ b/tb_mqtt_client/entities/provisioning_client.py @@ -0,0 +1,73 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from asyncio import Event + +from gmqtt import Client as GMQTTClient +from orjson import loads + +from tb_mqtt_client.common.config_loader import DeviceConfig +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.constants.mqtt_topics import PROVISION_RESPONSE_TOPIC +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest +from tb_mqtt_client.entities.data.provisioning_response import ProvisioningResponse +from tb_mqtt_client.service.message_dispatcher import JsonMessageDispatcher + +logger = get_logger(__name__) + + +class ProvisioningClient: + def __init__(self, host, port, provision_request: 'ProvisioningRequest'): + self._log = logger + self._stop_event = Event() + self._host = host + self._port = port + self._provision_request = provision_request + self._client_id = "provision" + self._client = GMQTTClient(self._client_id) + self._client.on_connect = self._on_connect + self._client.on_message = self._on_message + self._provisioned = Event() + self._device_config: 'DeviceConfig' = None + self._json_message_dispatcher = JsonMessageDispatcher() + + def _on_connect(self, client, _, rc, __): + if rc == 0: + self._log.debug("[Provisioning client] Connected to ThingsBoard") + client.subscribe(PROVISION_RESPONSE_TOPIC) + topic, payload = self._json_message_dispatcher.build_provision_request(self._provision_request) + self._log.debug("[Provisioning client] Sending provisioning request %s" % payload) + client.publish(topic, payload) + else: + self._device_config = ProvisioningResponse.build(self._provision_request, + {'status': 'FAILURE', + 'errorMsg': 'Cannot connect to ThingsBoard!'}) + self._provisioned.set() + self._log.error("[Provisioning client] Cannot connect to ThingsBoard!, result: %s" % rc) + + async def _on_message(self, _, __, payload, ___, ____): + decoded_payload = payload.decode("UTF-8") + self._log.debug("[Provisioning client] Received data from ThingsBoard: %s" % decoded_payload) + decoded_message = loads(decoded_payload) + + self._device_config = ProvisioningResponse.build(self._provision_request, decoded_message) + + await self._client.disconnect() + self._provisioned.set() + + async def provision(self): + await self._client.connect(self._host, self._port) + await self._provisioned.wait() + + return self._device_config diff --git a/tb_mqtt_client/service/device/client.py b/tb_mqtt_client/service/device/client.py index d8e8988..7b10448 100644 --- a/tb_mqtt_client/service/device/client.py +++ b/tb_mqtt_client/service/device/client.py @@ -36,8 +36,11 @@ from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry +from tb_mqtt_client.entities.provisioning_client import ProvisioningClient +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest from tb_mqtt_client.entities.publish_result import PublishResult from tb_mqtt_client.service.base_client import BaseClient +from tb_mqtt_client.service.device.firmware_updater import FirmwareUpdater from tb_mqtt_client.service.device.handlers.attribute_updates_handler import AttributeUpdatesHandler from tb_mqtt_client.service.device.handlers.requested_attributes_response_handler import \ RequestedAttributeResponseHandler @@ -94,6 +97,12 @@ def __init__(self, config: Optional[Union[DeviceConfig, Dict]] = None): self._rpc_requests_handler = RPCRequestsHandler() self.__claiming_response_future: Union[Future[bool], None] = None + self._firmware_updater = FirmwareUpdater(self) + + async def update_firmware(self, on_received_callback: Optional[Callable[[str], Awaitable[None]]] = None, + save_firmware: bool = True, firmware_save_path: Optional[str] = None): + await self._firmware_updater.update(on_received_callback, save_firmware, firmware_save_path) + async def connect(self): logger.info("Connecting to platform at %s:%s", self._host, self._port) @@ -447,3 +456,19 @@ def _build_uplink_message_for_attributes(payload: Union[Dict[str, Any], builder = DeviceUplinkMessageBuilder() builder.add_attributes(payload) return builder.build() + + @staticmethod + async def provision(provision_request: 'ProvisioningRequest', timeout=3.0): + provision_client = ProvisioningClient( + host=provision_request.host, + port=provision_request.port, + provision_request=provision_request + ) + + device_credentials = None + try: + device_credentials = await wait_for(provision_client.provision(), timeout=timeout) + except TimeoutError: + logger.error("Provisioning timed out") + + return device_credentials diff --git a/tb_mqtt_client/service/device/firmware_updater.py b/tb_mqtt_client/service/device/firmware_updater.py new file mode 100644 index 0000000..4422b52 --- /dev/null +++ b/tb_mqtt_client/service/device/firmware_updater.py @@ -0,0 +1,276 @@ +# Copyright 2025 ThingsBoard +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from random import randint +from zlib import crc32 +from hashlib import sha256, sha384, sha512, md5 +from subprocess import CalledProcessError +from asyncio import sleep +from os.path import sep +from typing import Awaitable, Callable, Optional +from tb_mqtt_client.common.install_package_utils import install_package +from tb_mqtt_client.common.logging_utils import get_logger +from tb_mqtt_client.constants import mqtt_topics +from tb_mqtt_client.constants.firmware import ( + FW_CHECKSUM_ALG_ATTR, + FW_CHECKSUM_ATTR, + FW_SIZE_ATTR, + FW_STATE_ATTR, + FW_TITLE_ATTR, + FW_VERSION_ATTR, + REQUIRED_SHARED_KEYS, + FirmwareStates +) + +from tb_mqtt_client.entities.data.attribute_request import AttributeRequest +from tb_mqtt_client.entities.data.timeseries_entry import TimeseriesEntry + +try: + from mmh3 import hash, hash128 +except ImportError: + try: + from pymmh3 import hash, hash128 + except ImportError: + try: + install_package('mmh3') + except CalledProcessError: + install_package('pymmh3') + +try: + from mmh3 import hash, hash128 # noqa +except ImportError: + from pymmh3 import hash, hash128 + +logger = get_logger(__name__) + + +class FirmwareUpdater: + def __init__(self, client): + self._log = logger + self._client = client + self._client._mqtt_manager.register_handler(mqtt_topics.DEVICE_FIRMWARE_UPDATE_RESPONSE_TOPIC, + self._handle_firmware_update) + self._on_received_callback = None + self._save_firmware = True + self._save_path = './' + self._firmware_request_id = 0 + self._chunk_size = 0 + self._current_chunk = 0 + self._firmware_data = b'' + self._target_firmware_length = 0 + self._target_checksum = 0 + self._target_checksum_alg = None + self._target_version = None + self._target_title = None + self.current_firmware_info = { + 'current_' + FW_TITLE_ATTR: 'Initial', + 'current_' + FW_VERSION_ATTR: 'v0', + FW_STATE_ATTR: FirmwareStates.IDLE.value + } + + async def _handle_firmware_update(self, _, payload: bytes): + self._firmware_data = self._firmware_data + payload + self._current_chunk = self._current_chunk + 1 + + self._log.debug('Getting chunk with number: %s. Chunk size is : %r byte(s).' % ( + self._current_chunk, self._chunk_size)) + + if len(self._firmware_data) == self._target_firmware_length: + self._log.info('Firmware download completed. ' + 'Total firmware size: %s byte(s).' % self._target_firmware_length) + await self._verify_downloaded_firmware() + else: + await self._get_next_chunk() + + async def _get_next_chunk(self): + if not self._chunk_size or self._chunk_size > self._target_firmware_length: + payload = b'' + else: + payload = str(self._chunk_size).encode() + + topic = mqtt_topics.build_firmware_update_request_topic(self._firmware_request_id, self._current_chunk) + await self._client._message_queue.publish(topic=topic, payload=payload, datapoints_count=0, qos=1) + + async def _verify_downloaded_firmware(self): + self._log.info('Verifying downloaded firmware...') + + self.current_firmware_info[FW_STATE_ATTR] = FirmwareStates.DOWNLOADED.value + await self._send_current_firmware_info() + + verified = self.verify_checksum(self._firmware_data, + self._target_checksum, + self._target_checksum_alg) + + if verified: + self._log.debug('Checksum verified.') + self.current_firmware_info[FW_STATE_ATTR] = FirmwareStates.VERIFIED.value + else: + self._log.error('Checksum verification failed.') + self.current_firmware_info[FW_STATE_ATTR] = FirmwareStates.FAILED.value + + await self._send_current_firmware_info() + + if self.current_firmware_info[FW_STATE_ATTR] == FirmwareStates.VERIFIED.value: + await self._apply_downloaded_firmware() + + async def _apply_downloaded_firmware(self): + self._log.info('Applying downloaded firmware...') + + self.current_firmware_info[FW_STATE_ATTR] = FirmwareStates.UPDATING.value + await self._send_current_firmware_info() + + try: + if self._save_firmware: + self._save() + except Exception as e: + self._log.error('Failed to save firmware: %s', e) + self.current_firmware_info[FW_STATE_ATTR] = FirmwareStates.FAILED.value + await self._send_current_firmware_info() + return + + self.current_firmware_info = { + "current_" + FW_TITLE_ATTR: self._target_title, + "current_" + FW_VERSION_ATTR: self._target_version, + FW_STATE_ATTR: FirmwareStates.UPDATED.value + } + + await self._send_current_firmware_info() + + if self._on_received_callback: + await self._on_received_callback(self._firmware_data, self.current_firmware_info) + await self._client._mqtt_manager.unsubscribe(mqtt_topics.DEVICE_FIRMWARE_UPDATE_RESPONSE_TOPIC) + + self._log.info('Firmware is updated.') + self._log.info('Current firmware version is: %s' % self._target_version) + + def _save(self): + firmware_path = self._save_path + sep + self._target_title + with open(firmware_path, "wb") as firmware_file: + firmware_file.write(self._firmware_data) + + async def update(self, on_received_callback: Optional[Callable[[str], Awaitable[None]]] = None, + save_firmware: bool = True, firmware_save_path: Optional[str] = None): + if not self._client._mqtt_manager.is_connected(): + self._log.error("Client is not connected. Cannot start firmware update.") + return + + self._log.info("Starting firmware update process...") + + self._on_received_callback = on_received_callback + self._save_firmware = save_firmware + if firmware_save_path: + self._save_path = firmware_save_path + self._log.info("Firmware will be saved to: %s", self._save_path) + + sub_future = await self._client._mqtt_manager.subscribe(mqtt_topics.DEVICE_FIRMWARE_UPDATE_RESPONSE_TOPIC, + qos=1) + while not sub_future.done(): + await sleep(0.01) + + await self._send_current_firmware_info() + + attribute_request = await AttributeRequest.build(REQUIRED_SHARED_KEYS) + await self._client.send_attribute_request(attribute_request, callback=self._firmware_info_callback) + + async def _firmware_info_callback(self, response, *args, **kwargs): + if len(response.shared_keys()) == len(REQUIRED_SHARED_KEYS): + fetched_firmware_info = response.as_dict()['shared'] + fetched_firmware_info = {item['key']: item['value'] + for item in fetched_firmware_info} + + if self._is_different_firmware_versions(fetched_firmware_info): + self._log.info("Firmware update available: %s. Downloading...", + fetched_firmware_info) + + self._firmware_data = b'' + self._current_chunk = 0 + self.current_firmware_info[FW_STATE_ATTR] = FirmwareStates.DOWNLOADING.value + + self._firmware_request_id += 1 + self._target_firmware_length = fetched_firmware_info[FW_SIZE_ATTR] + self._target_checksum = fetched_firmware_info[FW_CHECKSUM_ALG_ATTR] + self._target_checksum_alg = fetched_firmware_info[FW_CHECKSUM_ATTR] + self._target_title = fetched_firmware_info[FW_TITLE_ATTR] + self._target_version = fetched_firmware_info[FW_VERSION_ATTR] + + await self._get_next_chunk() + else: + self._log.info("Firmware is up to date.") + else: + self._log.error("Failed to fetch firmware info. " + "Received firmware info does not match required keys. " + "Expected: %s, Received: %s", + REQUIRED_SHARED_KEYS, + response.shared_keys()) + + self.current_firmware_info[FW_STATE_ATTR] = FirmwareStates.FAILED.value + await self._send_current_firmware_info() + + def _is_different_firmware_versions(self, new_firmware_info): + return (self.current_firmware_info['current_' + FW_TITLE_ATTR] != new_firmware_info[FW_TITLE_ATTR] or # noqa + self.current_firmware_info['current_' + FW_VERSION_ATTR] != new_firmware_info[FW_VERSION_ATTR]) # noqa + + async def _send_current_firmware_info(self): + current_info = [TimeseriesEntry(key, value) for key, value in self.current_firmware_info.items()] + await self._client.send_telemetry(current_info, wait_for_publish=True) + + def verify_checksum(self, firmware_data, checksum_alg, checksum): + if firmware_data is None: + self._log.debug('Firmware wasn\'t received!') + return False + + if checksum is None: + self._log.debug('Checksum was\'t provided!') + return False + + checksum_of_received_firmware = None + + self._log.debug('Checksum algorithm is: %s' % checksum_alg) + if checksum_alg.lower() == "sha256": + checksum_of_received_firmware = sha256(firmware_data).digest().hex() + elif checksum_alg.lower() == "sha384": + checksum_of_received_firmware = sha384(firmware_data).digest().hex() + elif checksum_alg.lower() == "sha512": + checksum_of_received_firmware = sha512(firmware_data).digest().hex() + elif checksum_alg.lower() == "md5": + checksum_of_received_firmware = md5(firmware_data).digest().hex() + elif checksum_alg.lower() == "murmur3_32": + reversed_checksum = f'{hash(firmware_data, signed=False):0>2X}' + if len(reversed_checksum) % 2 != 0: + reversed_checksum = '0' + reversed_checksum + checksum_of_received_firmware = "".join( + reversed([reversed_checksum[i:i + 2] for i in range(0, len(reversed_checksum), 2)])).lower() + elif checksum_alg.lower() == "murmur3_128": + reversed_checksum = f'{hash128(firmware_data, signed=False):0>2X}' + if len(reversed_checksum) % 2 != 0: + reversed_checksum = '0' + reversed_checksum + checksum_of_received_firmware = "".join( + reversed([reversed_checksum[i:i + 2] for i in range(0, len(reversed_checksum), 2)])).lower() + elif checksum_alg.lower() == "crc32": + reversed_checksum = f'{crc32(firmware_data) & 0xffffffff:0>2X}' + if len(reversed_checksum) % 2 != 0: + reversed_checksum = '0' + reversed_checksum + checksum_of_received_firmware = "".join( + reversed([reversed_checksum[i:i + 2] for i in range(0, len(reversed_checksum), 2)])).lower() + else: + self._log.error('Client error. Unsupported checksum algorithm.') + + self._log.debug(checksum_of_received_firmware) + + random_value = randint(0, 5) + if random_value > 3: + self._log.debug('Dummy fail! Do not panic, just restart and try again the chance of this fail is ~20%') + return False + + return checksum_of_received_firmware == checksum diff --git a/tb_mqtt_client/service/message_dispatcher.py b/tb_mqtt_client/service/message_dispatcher.py index 52d1790..fe85482 100644 --- a/tb_mqtt_client/service/message_dispatcher.py +++ b/tb_mqtt_client/service/message_dispatcher.py @@ -26,6 +26,7 @@ from tb_mqtt_client.entities.data.attribute_request import AttributeRequest from tb_mqtt_client.entities.data.attribute_update import AttributeUpdate from tb_mqtt_client.entities.data.device_uplink_message import DeviceUplinkMessage +from tb_mqtt_client.entities.data.provisioning_request import ProvisioningRequest, ProvisioningCredentialsType from tb_mqtt_client.entities.data.requested_attribute_response import RequestedAttributeResponse from tb_mqtt_client.entities.data.rpc_request import RPCRequest from tb_mqtt_client.entities.data.rpc_response import RPCResponse @@ -84,6 +85,14 @@ def build_rpc_response(self, rpc_response: RPCResponse) -> Tuple[str, bytes]: """ pass + @abstractmethod + def build_provision_request(self, provision_request) -> Tuple[str, bytes]: + """ + Build the payload for a device provisioning request. + This method should return a tuple of topic and payload bytes. + """ + pass + @abstractmethod def splitter(self) -> MessageSplitter: """ @@ -292,20 +301,65 @@ def build_rpc_request(self, rpc_request: RPCRequest) -> Tuple[str, bytes]: rpc_request.request_id, payload) return topic, payload - def build_rpc_response(self, rpc_response: RPCResponse) -> Tuple[str, bytes]: - """ - Build the payload for an RPC response. - :param rpc_response: The RPC response to build the payload for. - :return: A tuple of topic and payload bytes. - """ - if not rpc_response.request_id: - raise ValueError("RPCResponse must have a valid request ID.") - - payload = dumps(rpc_response.to_payload_format()) - topic = mqtt_topics.DEVICE_RPC_RESPONSE_TOPIC + str(rpc_response.request_id) - logger.trace("Built RPC response payload for request ID=%d with payload: %r", rpc_response.request_id, payload) - return topic, payload + """ + Build the payload for an RPC response. + :param rpc_response: The RPC response to build the payload for. + :return: A tuple of topic and payload bytes. + """ + if not rpc_response.request_id: + raise ValueError("RPCResponse must have a valid request ID.") + + payload = dumps(rpc_response.to_payload_format()) + topic = mqtt_topics.DEVICE_RPC_RESPONSE_TOPIC + str(rpc_response.request_id) + logger.trace("Built RPC response payload for request ID=%d with payload: %r", rpc_response.request_id, payload) + return topic, payload + + def build_provision_request(self, provision_request: 'ProvisioningRequest') -> Tuple[str, bytes]: + """ + Build the payload for a device provisioning request. + :param provision_request: The ProvisioningRequest to build the payload for. + :return: A tuple of topic and payload bytes. + """ + if not provision_request.credentials.provision_device_key or not provision_request.credentials.provision_device_secret: + raise ValueError("ProvisioningRequest must have valid device key and secret.") + + topic = mqtt_topics.PROVISION_REQUEST_TOPIC + request = {} + request["provisionDeviceKey"] = provision_request.credentials.provision_device_key + request["provisionDeviceSecret"] = provision_request.credentials.provision_device_secret + + if provision_request.device_name: + request["deviceName"] = provision_request.device_name + + if provision_request.gateway: + request["gateway"] = provision_request.gateway + + if provision_request.credentials.credentials_type and \ + provision_request.credentials.credentials_type == ProvisioningCredentialsType.ACCESS_TOKEN: + if provision_request.credentials.access_token is not None: + request["token"] = provision_request.credentials.access_token + request["credentialsType"] = provision_request.credentials.credentials_type.value + + if provision_request.credentials.credentials_type == ProvisioningCredentialsType.MQTT_BASIC: + if provision_request.credentials.username is not None: + request["username"] = provision_request.credentials.username + + if provision_request.credentials.password is not None: + request["password"] = provision_request.credentials.password + + if provision_request.credentials.client_id is not None: + request["clientId"] = provision_request.credentials.client_id + + request["credentialsType"] = provision_request.credentials.credentials_type.value + + if provision_request.credentials.credentials_type == ProvisioningCredentialsType.X509_CERTIFICATE: + request["hash"] = provision_request.credentials.public_cert + request["credentialsType"] = provision_request.credentials.credentials_type.value + + payload = dumps(request) + logger.trace("Built provision request payload: %r", provision_request) + return topic, payload @staticmethod def build_payload(msg: DeviceUplinkMessage, build_timeseries_payload) -> bytes: