Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints for libcoap.py #462

Merged
merged 6 commits into from
Mar 18, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,17 @@ no_implicit_optional = true
warn_return_any = true
warn_unreachable = true

[mypy-pytradfri.api.libcoap_api]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true

[mypy-pytradfri.device.base_controller]
check_untyped_defs = true
disallow_incomplete_defs = true
Expand Down
72 changes: 42 additions & 30 deletions pytradfri/api/libcoap_api.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""COAP implementation."""
from __future__ import annotations

from functools import wraps
import json
import logging
import subprocess
from time import time
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast, overload

from aiocoap import Message

from ..command import Command, T
from ..error import ClientError, RequestError, RequestTimeout, ServerError
Expand All @@ -21,25 +22,34 @@
class APIFactory:
"""APIFactory."""

def __init__(self, host, psk_id="pytradfri", psk=None, timeout=10):
def __init__(
self,
host: str,
psk_id: str = "pytradfri",
psk: str | None = None,
timeout: int = 10,
) -> None:
"""Create object of class."""
self._host = host
self._psk_id = psk_id
self._psk = psk
self._timeout = timeout # seconds

@property
def psk(self):
def psk(self) -> str | None:
"""Return psk."""
return self._psk

@psk.setter
def psk(self, value):
def psk(self, value: str) -> None:
"""Set psk."""
self._psk = value

def _base_command(self, method: str) -> list[str]:
"""Return base command."""
if self._psk is None:
MartinHjelmare marked this conversation as resolved.
Show resolved Hide resolved
raise RuntimeError("You must enter a PSK.")

return [
"coap-client",
"-u",
Expand All @@ -52,7 +62,9 @@ def _base_command(self, method: str) -> list[str]:
method,
]

def _execute(self, api_command, *, timeout=None):
def _execute(
self, api_command: Command[T], *, timeout: int | None = None
) -> T | None:
"""Execute the command."""

if api_command.observe:
Expand All @@ -71,7 +83,7 @@ def _execute(self, api_command, *, timeout=None):

command = self._base_command(method)

kwargs = {
kwargs: dict[str, Any] = {
"stderr": subprocess.DEVNULL,
"timeout": proc_timeout,
"universal_newlines": True,
Expand All @@ -98,7 +110,21 @@ def _execute(self, api_command, *, timeout=None):
api_command.process_result(_process_output(return_value, parse_json))
return api_command.result

def request(self, api_commands, *, timeout=None):
@overload
def request(self, api_commands: Command[T], timeout: int | None = None) -> T | None:
"""Make a request. Timeout is in seconds."""
...

@overload
def request(
self, api_commands: list[Command[T]], timeout: int | None = None
) -> list[Optional[T]] | None:
"""Make a request. Timeout is in seconds."""
...

def request(
self, api_commands: Command[T] | list[Command[T]], timeout: int | None = None
) -> T | list[Optional[T]] | None:
"""Make a request. Timeout is in seconds."""
if not isinstance(api_commands, list):
MartinHjelmare marked this conversation as resolved.
Show resolved Hide resolved
return self._execute(api_commands, timeout=timeout)
Expand Down Expand Up @@ -169,7 +195,7 @@ def read_stdout() -> str:
api_command.process_result(_process_output(output))
output = ""

def generate_psk(self, security_key):
def generate_psk(self, security_key: str) -> str:
"""Generate and set a psk from the security key."""
if not self._psk:
# Backup the real identity.
Expand All @@ -180,15 +206,18 @@ def generate_psk(self, security_key):
self._psk = security_key

# Ask the Gateway to generate the psk for the identity.
self._psk = self.request(Gateway().generate_psk(existing_psk_id))
command: list[Command[str]] = [Gateway().generate_psk(existing_psk_id)]
MartinHjelmare marked this conversation as resolved.
Show resolved Hide resolved
self._psk = cast(str, self.request(command))

# Restore the real identity.
self._psk_id = existing_psk_id

return self._psk


def _process_output(output, parse_json=True):
def _process_output(
output: Message, parse_json: bool = True
) -> list[Any] | dict[Any, Any] | str | None:
"""Process output."""
output = output.strip()
_LOGGER.debug("Received: %s", output)
Expand All @@ -206,22 +235,5 @@ def _process_output(output, parse_json=True):
if output.startswith(SERVER_ERROR_PREFIX):
raise ServerError(output)
if not parse_json:
return output
return json.loads(output)


def retry_timeout(api, retries=3):
MartinHjelmare marked this conversation as resolved.
Show resolved Hide resolved
"""Retry API call when a timeout occurs."""

@wraps(api)
def retry_api(*args, **kwargs):
"""Retrying API."""
for i in range(1, retries + 1):
try:
return api(*args, **kwargs)
except RequestTimeout:
if i == retries:
raise
return None

return retry_api
return cast(str, output)
return cast(Union[Dict[Any, Any], List[Any]], json.loads(output))
61 changes: 3 additions & 58 deletions tests/api/test_libcoap_api.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,10 @@
"""Test API utilities."""
import json

import pytest

from pytradfri import RequestTimeout
from pytradfri.api.libcoap_api import APIFactory, retry_timeout
from pytradfri.api.libcoap_api import APIFactory
from pytradfri.gateway import Gateway


def test_retry_timeout_passes_args():
"""Test passing args."""
calls = []

def api(*args, **kwargs):
"""Mock api."""
calls.append((args, kwargs))

retry_api = retry_timeout(api)

retry_api(1, 2, hello="world")
assert len(calls) == 1
args, kwargs = calls[0]
assert args == (1, 2)
assert kwargs == {"hello": "world"}


def test_retry_timeout_retries_timeouts():
"""Test retrying timeout."""
calls = []

def api(*args, **kwargs):
"""Mock api."""
calls.append((args, kwargs))

if len(calls) == 1:
raise RequestTimeout()

retry_api = retry_timeout(api, retries=2)

retry_api(1, 2, hello="world")
assert len(calls) == 2


def test_retry_timeout_raises_after_max_retries():
"""Test retrying timeout."""
calls = []

def api(*args, **kwargs):
"""Mock api."""
calls.append((args, kwargs))

raise RequestTimeout()

retry_api = retry_timeout(api, retries=5)

with pytest.raises(RequestTimeout):
retry_api(1, 2, hello="world")

assert len(calls) == 5


def test_constructor_timeout_passed_to_subprocess(monkeypatch):
"""Test that original timeout is passed to subprocess."""
capture = {}
Expand All @@ -70,7 +15,7 @@ def capture_args(*args, **kwargs):

monkeypatch.setattr("subprocess.check_output", capture_args)

api = APIFactory("anything", timeout=20)
api = APIFactory("anything", timeout=20, psk="abc")
api.request(Gateway().get_devices())
assert capture["timeout"] == 20

Expand All @@ -85,6 +30,6 @@ def capture_args(*args, **kwargs):

monkeypatch.setattr("subprocess.check_output", capture_args)

api = APIFactory("anything")
api = APIFactory("anything", psk="abc")
api.request(Gateway().get_devices(), timeout=1)
assert capture["timeout"] == 1