Skip to content

Commit

Permalink
Fix client ID and subscribed topics (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
qqaatw authored Dec 29, 2023
1 parent ad6083b commit 503872e
Show file tree
Hide file tree
Showing 11 changed files with 88 additions and 116 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest]
python-version: ['3.8', '3.10', '3.11']
python-version: ['3.9', '3.10', '3.11']

steps:
- uses: actions/checkout@v2
Expand Down
11 changes: 8 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@ repos:
args: [--markdown-linebreak-ext=md]
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/psf/black
rev: 23.7.0
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.1.8
hooks:
- id: black
# Run the linter.
#- id: ruff
# args: [ --fix ]
# Run the formatter.
- id: ruff-format
53 changes: 15 additions & 38 deletions JciHitachi/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,10 +347,13 @@ def task_id(self) -> int:
Returns
-------
int
Serial number counted from 0.
Serial number counted from 0, with maximum 999.
"""

self._task_id += 1
if self._task_id >= 1000:
self._task_id = 1

return self._task_id

def _sync_peripherals_availablity(self) -> None:
Expand Down Expand Up @@ -914,7 +917,6 @@ def __init__(
self._things: dict[str, AWSThing] = {}
self._aws_tokens: Optional[aws_connection.AWSTokens] = None
self._aws_identity: Optional[aws_connection.AWSIdentity] = None
self._host_identity_id: Optional[str] = None
self._task_id: int = 0

@property
Expand Down Expand Up @@ -980,24 +982,6 @@ def login(self) -> None:
self._aws_tokens = conn.aws_tokens
conn_status, self._aws_identity = conn.get_data()

conn = aws_connection.ListSubUser(
self._aws_tokens, print_response=self.print_response
)
conn_status, conn_json = conn.get_data()

if conn_status == "OK":
for user in conn_json["results"]["FamilyMemberList"]:
if user["isHost"]:
self._host_identity_id = user["userId"]
break
assert (
self._host_identity_id is not None
), "Host is not found in the user list"
else:
raise RuntimeError(
f"An error occurred when listing account users: {conn_status}"
)

conn = aws_connection.GetAllDevice(
self._aws_tokens, print_response=self.print_response
)
Expand Down Expand Up @@ -1028,13 +1012,13 @@ def get_credential_callable():
self._mqtt = aws_connection.JciHitachiAWSMqttConnection(
get_credential_callable, print_response=self.print_response
)
self._mqtt.configure()
self._mqtt.configure(self._aws_identity.identity_id)

if not self._mqtt.connect(
self._host_identity_id, self._shadow_names, thing_names
self._aws_identity.host_identity_id, self._shadow_names, thing_names
):
raise RuntimeError(
f"An error occurred when connecting to MQTT endpoint."
"An error occurred when connecting to MQTT endpoint."
)

# status
Expand Down Expand Up @@ -1175,7 +1159,7 @@ def refresh_status(

if refresh_support_code:
self._mqtt.publish(
self._host_identity_id,
self._aws_identity.host_identity_id,
thing.thing_name,
"support",
self._mqtt_timeout,
Expand All @@ -1184,7 +1168,10 @@ def refresh_status(
self._mqtt.publish_shadow(thing.thing_name, "get", shadow_name="info")

self._mqtt.publish(
self._host_identity_id, thing.thing_name, "status", self._mqtt_timeout
self._aws_identity.host_identity_id,
thing.thing_name,
"status",
self._mqtt_timeout,
)

# execute
Expand Down Expand Up @@ -1324,9 +1311,7 @@ def set_status(
"enableQAMode": "qa",
}

if (
False
): # status_name in shadow_publish_mapping: # TODO: replace False cond after shadow function is completed.
if False: # status_name in shadow_publish_mapping: # TODO: replace False cond after shadow function is completed.
shadow_publish_schema = {}
if (
shadow_publish_mapping[status_name] == "filter"
Expand Down Expand Up @@ -1356,22 +1341,14 @@ def set_status(
return False

self._mqtt.publish(
self._host_identity_id,
self._aws_identity.host_identity_id,
thing.thing_name,
"control",
self._mqtt_timeout,
{
"Condition": {
"ThingName": thing.thing_name,
"Index": 0,
"Geofencing": {
"Arrive": None,
"Leave": None,
},
},
status_name: status_value,
"TaskID": self.task_id,
"Timestamp": time.time(),
"Timestamp": int(time.time()),
},
)

Expand Down
44 changes: 22 additions & 22 deletions JciHitachi/aws_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,16 @@
import logging
import threading
import time
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from random import random
from random import random, choices
from typing import Callable, Optional, Union

import awscrt
import httpx
from awsiot import iotshadow, mqtt_connection_builder

from .model import JciHitachiAWSStatus, JciHitachiAWSStatusSupport
from .utility import to_thread

AWS_REGION = "ap-northeast-1"
AWS_COGNITO_IDP_ENDPOINT = f"cognito-idp.{AWS_REGION}.amazonaws.com"
Expand All @@ -43,6 +41,7 @@ class AWSTokens:
@dataclass
class AWSIdentity:
identity_id: str
host_identity_id: str
user_name: str
user_attributes: dict

Expand Down Expand Up @@ -204,7 +203,7 @@ def login(self, use_refresh_token: bool = False) -> tuple(str, AWSTokens):
"""

# https://docs.aws.amazon.com/cognito-user-identity-pools/latest/APIReference/API_InitiateAuth.html
if use_refresh_token and self._aws_tokens != None:
if use_refresh_token and self._aws_tokens is not None:
login_json_data = {
"AuthFlow": "REFRESH_TOKEN_AUTH",
"AuthParameters": {
Expand Down Expand Up @@ -312,6 +311,7 @@ def get_data(self):
}
aws_identity = AWSIdentity(
identity_id=user_attributes["custom:cognito_identity_id"],
host_identity_id=user_attributes["custom:host_identity_id"],
user_name=response["Username"],
user_attributes=user_attributes,
)
Expand Down Expand Up @@ -653,7 +653,7 @@ def _on_message(self, topic, payload, dup, qos, retain, **kwargs):
return

def _on_connection_interrupted(self, connection, error, **kwargs):
_LOGGER.error("MQTT connection was interrupted with exception {error}.")
_LOGGER.error(f"MQTT connection was interrupted with exception {error}")
self._mqtt_events.mqtt_error = error.__class__.__name__
self._mqtt_events.mqtt_error_event.set()

Expand Down Expand Up @@ -681,11 +681,11 @@ def on_resubscribe_complete(resubscribe_future):
_LOGGER.info("Resubscribed successfully.")
return

async def _wrap_async(self, identifier: str, fn: Callable, timeout: float) -> str:
async def _wrap_async(self, identifier: str, fn: Callable) -> str:
await asyncio.sleep(
random() / 2
) # randomly wait 0~0.5 seconds to prevent messages flooding to the broker.
await asyncio.wait_for(to_thread(fn), timeout)
await asyncio.to_thread(fn)
return identifier

def disconnect(self) -> None:
Expand All @@ -694,7 +694,7 @@ def disconnect(self) -> None:
if self._mqttc is not None:
self._mqttc.disconnect()

def configure(self) -> None:
def configure(self, identity_id) -> None:
"""Configure MQTT."""

cred_provider = awscrt.auth.AwsCredentialsProvider.new_delegate(
Expand All @@ -708,7 +708,7 @@ def configure(self) -> None:
cred_provider,
client_bootstrap=client_bootstrap,
endpoint=AWS_MQTT_ENDPOINT,
client_id=str(uuid.uuid4()),
client_id=f"{identity_id}_{''.join(choices('abcdef0123456789', k=16))}", # {identityid}_{64bit_hex}
on_connection_interrupted=self._on_connection_interrupted,
on_connection_resumed=self._on_connection_resumed,
)
Expand Down Expand Up @@ -750,7 +750,7 @@ def connect(

try:
subscribe_future, _ = self._mqttc.subscribe(
f"{host_identity_id}/#", QOS, callback=self._on_publish
f"{host_identity_id}/+/+/response", QOS, callback=self._on_publish
)
subscribe_future.result()

Expand Down Expand Up @@ -861,11 +861,11 @@ def fn():
publish_future, _ = self._mqttc.publish(
support_topic, json.dumps(default_payload), QOS
)
publish_future.result()
self._mqtt_events.device_support_event[thing_name].wait()
publish_future.result(timeout)
self._mqtt_events.device_support_event[thing_name].wait(timeout)

self._execution_pools.support_execution_pool.append(
self._wrap_async(thing_name, fn, timeout)
self._wrap_async(thing_name, fn)
)
elif publish_type == "status":
status_topic = f"{host_identity_id}/{thing_name}/status/request"
Expand All @@ -878,11 +878,11 @@ def fn():
publish_future, _ = self._mqttc.publish(
status_topic, json.dumps(default_payload), QOS
)
publish_future.result()
self._mqtt_events.device_status_event[thing_name].wait()
publish_future.result(timeout)
self._mqtt_events.device_status_event[thing_name].wait(timeout)

self._execution_pools.status_execution_pool.append(
self._wrap_async(thing_name, fn, timeout)
self._wrap_async(thing_name, fn)
)
elif publish_type == "control":
control_topic = f"{host_identity_id}/{thing_name}/control/request"
Expand All @@ -895,11 +895,11 @@ def fn():
publish_future, _ = self._mqttc.publish(
control_topic, json.dumps(payload), QOS
)
publish_future.result()
self._mqtt_events.device_control_event[thing_name].wait()
publish_future.result(timeout)
self._mqtt_events.device_control_event[thing_name].wait(timeout)

self._execution_pools.control_execution_pool.append(
self._wrap_async(thing_name, fn, timeout)
self._wrap_async(thing_name, fn)
)

else:
Expand Down Expand Up @@ -995,11 +995,11 @@ def fn():
),
qos=QOS,
)
publish_future.result()
self._mqtt_events.device_shadow_event[thing_name].wait()
publish_future.result(timeout)
self._mqtt_events.device_shadow_event[thing_name].wait(timeout)

self._execution_pools.shadow_execution_pool.append(
self._wrap_async(thing_name, fn, timeout)
self._wrap_async(thing_name, fn)
)

def execute(
Expand Down
1 change: 0 additions & 1 deletion JciHitachi/connection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import os
import ssl

import httpx

Expand Down
11 changes: 4 additions & 7 deletions JciHitachi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2196,13 +2196,6 @@ class JciHitachiAWSStatusSupport:
Status retrieved from `JciHitachiAWSMqttConnection` _on_publish() callback.
"""

extended_mapping = {
"FirmwareId": None,
"Model": "model",
"Brand": "brand",
"FindMe": None,
}

device_type_mapping = JciHitachiAWSStatus.device_type_mapping

def __init__(self, raw_status: dict) -> None:
Expand All @@ -2217,6 +2210,10 @@ def __repr__(self) -> str:

def _preprocess(self, status):
status = status.copy()

if status.get("Error", 0) != 0:
return status

# device type
status["DeviceType"] = self.device_type_mapping[status["DeviceType"]]

Expand Down
26 changes: 2 additions & 24 deletions JciHitachi/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,29 +111,7 @@ def extract_bytes(v, start, end): # pragma: no cover
Extracted value.
"""

assert (
start > end and end >= 0
), "Starting byte must be greater than ending byte, \
assert start > end and end >= 0, "Starting byte must be greater than ending byte, \
and ending byte must be greater than zero : \
{}, {}".format(
start, end
)
{}, {}".format(start, end)
return cast_bytes(v >> end * 8, start - end)


# Copied from https://github.com/python/cpython/blob/main/Lib/asyncio/threads.py
# TODO: Remove this once we upgrade the minimally supported Python version to 3.9.
async def to_thread(func, /, *args, **kwargs): # pragma: no cover
"""Asynchronously run function *func* in a separate thread.
Any *args and **kwargs supplied for this function are directly passed
to *func*. Also, the current :class:`contextvars.Context` is propagated,
allowing context variables from the main thread to be accessed in the
separate thread.
Return a coroutine that can be awaited to get the eventual result of *func*.
"""
loop = events.get_running_loop()
ctx = contextvars.copy_context()
func_call = functools.partial(ctx.run, func, *args, **kwargs)
return await loop.run_in_executor(None, func_call)
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
awsiotsdk==1.15.4
awsiotsdk==1.20.0
httpx
paho-mqtt
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
long_description = f.read()

install_requires = [
"awsiotsdk==1.15.4",
"awsiotsdk==1.20.0",
"httpx",
"paho-mqtt",
]
Expand All @@ -32,7 +32,6 @@
"Documentation": "https://libjcihitachi.readthedocs.io/en/latest/",
},
classifiers=[
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
Expand Down
Loading

0 comments on commit 503872e

Please sign in to comment.