Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: type hints for the connnection class #1189

Merged
merged 10 commits into from
Nov 19, 2024
77 changes: 49 additions & 28 deletions juju/client/connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright 2023 Canonical Ltd.
# Licensed under the Apache V2, see LICENCE file for details.
from __future__ import annotations

import base64
import json
Expand All @@ -9,37 +10,28 @@
import warnings
import weakref
from http.client import HTTPSConnection
from typing import Dict, Literal, Optional, Sequence
from typing import Any, Literal, Sequence

import macaroonbakery.bakery as bakery
import macaroonbakery.httpbakery as httpbakery
import websockets
from dateutil.parser import parse
from typing_extensions import Self, TypeAlias, overload

from juju import errors, jasyncio, tag, utils
from juju.client import client
from juju.utils import IdQueue
from juju.version import CLIENT_VERSION

from .facade import _JSON, _RICH_JSON, TypeEncoder
from .facade_versions import client_facade_versions, known_unsupported_facades

SPECIFIED_FACADES: TypeAlias = dict[str, dict[Literal["versions"], Sequence[int]]]

dimaqq marked this conversation as resolved.
Show resolved Hide resolved
LEVELS = ["TRACE", "DEBUG", "INFO", "WARNING", "ERROR"]
log = logging.getLogger("juju.client.connection")


def facade_versions(name, versions):
"""facade_versions returns a new object that correctly returns a object in
format expected by the connection facades inspection.
:param name: name of the facade
:param versions: versions to support by the facade
"""
if name.endswith("Facade"):
name = name[: -len("Facade")]
return {
name: {"versions": versions},
}


class Monitor:
"""Monitor helper class for our Connection class.

Expand All @@ -59,7 +51,7 @@ class Monitor:
DISCONNECTING = "disconnecting"
DISCONNECTED = "disconnected"

def __init__(self, connection):
def __init__(self, connection: Connection):
self.connection = weakref.ref(connection)
self.reconnecting = jasyncio.Lock()
self.close_called = jasyncio.Event()
Expand Down Expand Up @@ -117,28 +109,40 @@ class Connection:

MAX_FRAME_SIZE = 2**22
"Maximum size for a single frame. Defaults to 4MB."
facades: Dict[str, int]
_specified_facades: Dict[str, Sequence[int]]
facades: dict[str, int]
_specified_facades: dict[str, Sequence[int]]
bakery_client: Any
usertag: str | None
password: str | None
name: str
__request_id__: int
endpoints: list[tuple[str, str]] | None # Set by juju/controller.py
is_debug_log_connection: bool
monitor: Monitor
proxy: Any # Need to find types for this library
max_frame_size: int
_retries: int
_retry_backoff: float
uuid: str | None
messages: IdQueue

@classmethod
async def connect(
cls,
endpoint=None,
uuid=None,
username=None,
password=None,
uuid: str | None = None,
username: str | None = None,
password: str | None = None,
cacert=None,
bakery_client=None,
max_frame_size=None,
max_frame_size: int | None = None,
retries=3,
retry_backoff=10,
specified_facades: Optional[
Dict[str, Dict[Literal["versions"], Sequence[int]]]
] = None,
specified_facades: SPECIFIED_FACADES | None = None,
proxy=None,
debug_log_conn=None,
debug_log_params={},
):
) -> Self:
"""Connect to the websocket.

If uuid is None, the connection will be to the controller. Otherwise it
Expand Down Expand Up @@ -270,7 +274,7 @@ def ws(self):
return self._ws

@property
def username(self):
def username(self) -> str | None:
if not self.usertag:
return None
return self.usertag[len("user-") :]
Expand Down Expand Up @@ -372,7 +376,7 @@ async def close(self, to_reconnect=False):
if self.proxy is not None:
self.proxy.close()

async def _recv(self, request_id):
async def _recv(self, request_id: int) -> dict[str, Any]:
if not self.is_open:
raise websockets.exceptions.ConnectionClosed(
websockets.frames.Close(
Expand Down Expand Up @@ -534,7 +538,19 @@ async def _do_ping():
log.debug("ping failed because of closed connection")
pass

async def rpc(self, msg, encoder=None):
@overload
async def rpc(
self, msg: dict[str, _JSON], encoder: None = None
) -> dict[str, _JSON]: ...

@overload
async def rpc(
self, msg: dict[str, _RICH_JSON], encoder: TypeEncoder
) -> dict[str, _JSON]: ...

async def rpc(
self, msg: dict[str, Any], encoder: json.JSONEncoder | None = None
) -> dict[str, Any]:
dimaqq marked this conversation as resolved.
Show resolved Hide resolved
"""Make an RPC to the API. The message is encoded as JSON
using the given encoder if any.
:param msg: Parameters for the call (will be encoded as JSON).
Expand Down Expand Up @@ -744,8 +760,13 @@ async def _try_endpoint(endpoint, cacert, delay):
# only executed if inner loop's else did not continue
# (i.e., inner loop did break due to successful connection)
break
else:
# impossible, work around https://github.com/microsoft/pyright/issues/8791
assert False # noqa: B011

dimaqq marked this conversation as resolved.
Show resolved Hide resolved
for task in tasks:
task.cancel()

self._ws = result[0]
self.addr = result[1]
self.endpoint = result[2]
Expand Down
10 changes: 6 additions & 4 deletions juju/client/connector.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Copyright 2023 Canonical Ltd.
# Licensed under the Apache V2, see LICENCE file for details.
from __future__ import annotations

import copy
import logging
from typing import Any

import macaroonbakery.httpbakery as httpbakery
from packaging import version
Expand Down Expand Up @@ -33,9 +35,9 @@ class Connector:

def __init__(
self,
max_frame_size=None,
bakery_client=None,
jujudata=None,
max_frame_size: int | None = None,
bakery_client: Any | None = None,
jujudata: Any | None = None,
):
"""Initialize a connector that will use the given parameters
by default when making a new connection
Expand All @@ -52,7 +54,7 @@ def is_connected(self):
"""Report whether there is a currently connected controller or not"""
return self._connection is not None

def connection(self):
def connection(self) -> Connection:
"""Return the current connection; raises an exception if there
is no current connection.
"""
Expand Down
24 changes: 17 additions & 7 deletions juju/client/facade.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright 2023 Canonical Ltd.
# Licensed under the Apache V2, see LICENCE file for details.
from __future__ import annotations

import argparse
import builtins
Expand All @@ -13,13 +14,22 @@
from collections import defaultdict
from glob import glob
from pathlib import Path
from typing import Any, Dict, List, Mapping, Sequence
from typing import Any, Mapping, Sequence

import packaging.version
import typing_inspect
from typing_extensions import TypeAlias

from . import codegen

# Plain JSON, what is received from Juju
_JSON_LEAF: TypeAlias = None | bool | int | float | str
dimaqq marked this conversation as resolved.
Show resolved Hide resolved
_JSON: TypeAlias = "_JSON_LEAF|list[_JSON]|dict[str, _JSON]"

# Type-enriched JSON, what can be sent to Juju
_RICH_LEAF: TypeAlias = "_JSON_LEAF|Type"
_RICH_JSON: TypeAlias = "_RICH_LEAF|list[_RICH_JSON]|dict[str, _RICH_JSON]"

_marker = object()

JUJU_VERSION = re.compile(r"[0-9]+\.[0-9-]+[\.\-][0-9a-z]+(\.[0-9]+)?")
Expand Down Expand Up @@ -634,7 +644,7 @@ class {name}Facade(Type):


class TypeEncoder(json.JSONEncoder):
def default(self, obj):
def default(self, obj: _RICH_JSON) -> _JSON:
if isinstance(obj, Type):
return obj.serialize()
return json.JSONEncoder.default(self, obj)
Expand All @@ -653,7 +663,7 @@ def __eq__(self, other):

return self.__dict__ == other.__dict__

async def rpc(self, msg):
async def rpc(self, msg: dict[str, _RICH_JSON]) -> _JSON:
result = await self.connection.rpc(msg, encoder=TypeEncoder)
return result

Expand Down Expand Up @@ -704,13 +714,13 @@ def _parse_nested_list_entry(expr, result_dict):
return cls(**d)
return None

def serialize(self):
def serialize(self) -> dict[str, _JSON]:
d = {}
for attr, tgt in self._toSchema.items():
d[tgt] = getattr(self, attr)
return d

def to_json(self):
def to_json(self) -> str:
return json.dumps(self.serialize(), cls=TypeEncoder, sort_keys=True)

def __contains__(self, key):
Expand Down Expand Up @@ -917,8 +927,8 @@ def generate_definitions(schemas):


def generate_facades(
schemas: Dict[str, List[Schema]],
) -> Dict[str, Dict[int, codegen.Capture]]:
schemas: dict[str, list[Schema]],
) -> dict[str, dict[int, codegen.Capture]]:
captures = defaultdict(codegen.Capture)

# Build the Facade classes
Expand Down
14 changes: 10 additions & 4 deletions juju/client/overrides.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright 2023 Canonical Ltd.
# Licensed under the Apache V2, see LICENCE file for details.
from __future__ import annotations

import re
from collections import namedtuple
from typing import Any, NamedTuple

from . import _client, _definitions
from .facade import ReturnMapping, Type, TypeEncoder
Expand All @@ -22,6 +23,12 @@
]


class _Change(NamedTuple):
entity: str
type: str
data: dict[str, Any]


class Delta(Type):
"""A single websocket delta.

Expand All @@ -42,12 +49,11 @@ class Delta(Type):
_toSchema = {"deltas": "deltas"}
_toPy = {"deltas": "deltas"}

def __init__(self, deltas=None):
def __init__(self, deltas: tuple[str, str, dict[str, Any]]):
""":param deltas: [str, str, object]"""
self.deltas = deltas

Change = namedtuple("Change", "entity type data")
change = Change(*self.deltas)
change = _Change(*self.deltas)

self.entity = change.entity
self.type = change.type
Expand Down
14 changes: 9 additions & 5 deletions juju/delta.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# Copyright 2023 Canonical Ltd.
# Licensed under the Apache V2, see LICENCE file for details.
from __future__ import annotations

from .client import client
from . import model
from .client import client, overrides


def get_entity_delta(d):
def get_entity_delta(d: overrides.Delta):
return _delta_types[d.entity](d.deltas)


Expand All @@ -13,12 +15,14 @@ def get_entity_class(entity_type):


class EntityDelta(client.Delta):
def get_id(self):
data: dict[str, str]

def get_id(self) -> str:
return self.data["id"]

@classmethod
def get_entity_class(cls):
return None
def get_entity_class(cls) -> type[model.ModelEntity]:
raise NotImplementedError()


class ActionDelta(EntityDelta):
Expand Down
3 changes: 2 additions & 1 deletion juju/jasyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from asyncio import (
CancelledError,
Task,
create_task,
wait,
)
Expand Down Expand Up @@ -84,7 +85,7 @@
ROOT_LOGGER = logging.getLogger()


def create_task_with_handler(coro, task_name, logger=ROOT_LOGGER):
def create_task_with_handler(coro, task_name, logger=ROOT_LOGGER) -> Task:
"""Wrapper around "asyncio.create_task" to make sure the task
exceptions are handled properly.

Expand Down
Loading