diff --git a/juju/client/connection.py b/juju/client/connection.py index c842311c8..e79f2ea7e 100644 --- a/juju/client/connection.py +++ b/juju/client/connection.py @@ -1,5 +1,6 @@ # Copyright 2023 Canonical Ltd. # Licensed under the Apache V2, see LICENCE file for details. +from __future__ import annotations import base64 import json @@ -9,37 +10,29 @@ import warnings import weakref from http.client import HTTPSConnection -from typing import Dict, Literal, Optional, Sequence +from typing import Any, Literal, Sequence import macaroonbakery.bakery as bakery import macaroonbakery.httpbakery as httpbakery import websockets from dateutil.parser import parse +from typing_extensions import Self, TypeAlias, overload from juju import errors, jasyncio, tag, utils from juju.client import client from juju.utils import IdQueue from juju.version import CLIENT_VERSION +from .facade import TypeEncoder, _Json, _RichJson from .facade_versions import client_facade_versions, known_unsupported_facades +SpecifiedFacades: TypeAlias = "dict[str, dict[Literal['versions'], Sequence[int]]]" +_WebSocket: TypeAlias = "websockets.legacy.client.WebSocketClientProtocol" + LEVELS = ["TRACE", "DEBUG", "INFO", "WARNING", "ERROR"] log = logging.getLogger("juju.client.connection") -def facade_versions(name, versions): - """facade_versions returns a new object that correctly returns a object in - format expected by the connection facades inspection. - :param name: name of the facade - :param versions: versions to support by the facade - """ - if name.endswith("Facade"): - name = name[: -len("Facade")] - return { - name: {"versions": versions}, - } - - class Monitor: """Monitor helper class for our Connection class. @@ -59,7 +52,7 @@ class Monitor: DISCONNECTING = "disconnecting" DISCONNECTED = "disconnected" - def __init__(self, connection): + def __init__(self, connection: Connection): self.connection = weakref.ref(connection) self.reconnecting = jasyncio.Lock() self.close_called = jasyncio.Event() @@ -117,28 +110,41 @@ class Connection: MAX_FRAME_SIZE = 2**22 "Maximum size for a single frame. Defaults to 4MB." - facades: Dict[str, int] - _specified_facades: Dict[str, Sequence[int]] + facades: dict[str, int] + _specified_facades: dict[str, Sequence[int]] + bakery_client: Any + usertag: str | None + password: str | None + name: str + __request_id__: int + endpoints: list[tuple[str, str]] | None # Set by juju/controller.py + is_debug_log_connection: bool + monitor: Monitor + proxy: Any # Need to find types for this library + max_frame_size: int + _retries: int + _retry_backoff: float + uuid: str | None + messages: IdQueue + _ws: _WebSocket | None @classmethod async def connect( cls, endpoint=None, - uuid=None, - username=None, - password=None, + uuid: str | None = None, + username: str | None = None, + password: str | None = None, cacert=None, bakery_client=None, - max_frame_size=None, + max_frame_size: int | None = None, retries=3, retry_backoff=10, - specified_facades: Optional[ - Dict[str, Dict[Literal["versions"], Sequence[int]]] - ] = None, + specified_facades: SpecifiedFacades | None = None, proxy=None, debug_log_conn=None, debug_log_params={}, - ): + ) -> Self: """Connect to the websocket. If uuid is None, the connection will be to the controller. Otherwise it @@ -270,7 +276,7 @@ def ws(self): return self._ws @property - def username(self): + def username(self) -> str | None: if not self.usertag: return None return self.usertag[len("user-") :] @@ -299,7 +305,7 @@ def _get_ssl(self, cert=None): context.check_hostname = False return context - async def _open(self, endpoint, cacert): + async def _open(self, endpoint, cacert) -> tuple[_WebSocket, str, str, str]: if self.is_debug_log_connection: assert self.uuid url = f"wss://user-{self.username}:{self.password}@{endpoint}/model/{self.uuid}/log" @@ -372,7 +378,7 @@ async def close(self, to_reconnect=False): if self.proxy is not None: self.proxy.close() - async def _recv(self, request_id): + async def _recv(self, request_id: int) -> dict[str, Any]: if not self.is_open: raise websockets.exceptions.ConnectionClosed( websockets.frames.Close( @@ -534,7 +540,19 @@ async def _do_ping(): log.debug("ping failed because of closed connection") pass - async def rpc(self, msg, encoder=None): + @overload + async def rpc( + self, msg: dict[str, _Json], encoder: None = None + ) -> dict[str, _Json]: ... + + @overload + async def rpc( + self, msg: dict[str, _RichJson], encoder: TypeEncoder + ) -> dict[str, _Json]: ... + + async def rpc( + self, msg: dict[str, Any], encoder: json.JSONEncoder | None = None + ) -> dict[str, _Json]: """Make an RPC to the API. The message is encoded as JSON using the given encoder if any. :param msg: Parameters for the call (will be encoded as JSON). @@ -710,7 +728,9 @@ async def _connect(self, endpoints): if len(endpoints) == 0: raise errors.JujuConnectionError("no endpoints to connect to") - async def _try_endpoint(endpoint, cacert, delay): + async def _try_endpoint( + endpoint, cacert, delay + ) -> tuple[_WebSocket, str, str, str]: if delay: await jasyncio.sleep(delay) return await self._open(endpoint, cacert) @@ -722,6 +742,8 @@ async def _try_endpoint(endpoint, cacert, delay): jasyncio.ensure_future(_try_endpoint(endpoint, cacert, 0.1 * i)) for i, (endpoint, cacert) in enumerate(endpoints) ] + result: tuple[_WebSocket, str, str, str] | None = None + for attempt in range(self._retries + 1): for task in jasyncio.as_completed(tasks): try: @@ -744,8 +766,12 @@ async def _try_endpoint(endpoint, cacert, delay): # only executed if inner loop's else did not continue # (i.e., inner loop did break due to successful connection) break + for task in tasks: task.cancel() + + assert result # loop raises or sets the result + self._ws = result[0] self.addr = result[1] self.endpoint = result[2] diff --git a/juju/client/connector.py b/juju/client/connector.py index feb650d8c..c9be0cda4 100644 --- a/juju/client/connector.py +++ b/juju/client/connector.py @@ -1,8 +1,10 @@ # Copyright 2023 Canonical Ltd. # Licensed under the Apache V2, see LICENCE file for details. +from __future__ import annotations import copy import logging +from typing import Any import macaroonbakery.httpbakery as httpbakery from packaging import version @@ -33,9 +35,9 @@ class Connector: def __init__( self, - max_frame_size=None, - bakery_client=None, - jujudata=None, + max_frame_size: int | None = None, + bakery_client: Any | None = None, + jujudata: Any | None = None, ): """Initialize a connector that will use the given parameters by default when making a new connection @@ -52,7 +54,7 @@ def is_connected(self): """Report whether there is a currently connected controller or not""" return self._connection is not None - def connection(self): + def connection(self) -> Connection: """Return the current connection; raises an exception if there is no current connection. """ diff --git a/juju/client/facade.py b/juju/client/facade.py index 08c7a5242..596f97aba 100644 --- a/juju/client/facade.py +++ b/juju/client/facade.py @@ -1,5 +1,6 @@ # Copyright 2023 Canonical Ltd. # Licensed under the Apache V2, see LICENCE file for details. +from __future__ import annotations import argparse import builtins @@ -13,13 +14,22 @@ from collections import defaultdict from glob import glob from pathlib import Path -from typing import Any, Dict, List, Mapping, Sequence +from typing import Any, Mapping, Sequence import packaging.version import typing_inspect +from typing_extensions import TypeAlias from . import codegen +# Plain JSON, what is received from Juju +_JsonLeaf: TypeAlias = "None | bool | int | float | str" +_Json: TypeAlias = "_JsonLeaf|list[_Json]|dict[str, _Json]" + +# Type-enriched JSON, what can be sent to Juju +_RichLeaf: TypeAlias = "_JsonLeaf|Type" +_RichJson: TypeAlias = "_RichLeaf|list[_RichJson]|dict[str, _RichJson]" + _marker = object() JUJU_VERSION = re.compile(r"[0-9]+\.[0-9-]+[\.\-][0-9a-z]+(\.[0-9]+)?") @@ -634,7 +644,7 @@ class {name}Facade(Type): class TypeEncoder(json.JSONEncoder): - def default(self, obj): + def default(self, obj: _RichJson) -> _Json: if isinstance(obj, Type): return obj.serialize() return json.JSONEncoder.default(self, obj) @@ -653,7 +663,7 @@ def __eq__(self, other): return self.__dict__ == other.__dict__ - async def rpc(self, msg): + async def rpc(self, msg: dict[str, _RichJson]) -> _Json: result = await self.connection.rpc(msg, encoder=TypeEncoder) return result @@ -704,13 +714,13 @@ def _parse_nested_list_entry(expr, result_dict): return cls(**d) return None - def serialize(self): + def serialize(self) -> dict[str, _Json]: d = {} for attr, tgt in self._toSchema.items(): d[tgt] = getattr(self, attr) return d - def to_json(self): + def to_json(self) -> str: return json.dumps(self.serialize(), cls=TypeEncoder, sort_keys=True) def __contains__(self, key): @@ -917,8 +927,8 @@ def generate_definitions(schemas): def generate_facades( - schemas: Dict[str, List[Schema]], -) -> Dict[str, Dict[int, codegen.Capture]]: + schemas: dict[str, list[Schema]], +) -> dict[str, dict[int, codegen.Capture]]: captures = defaultdict(codegen.Capture) # Build the Facade classes diff --git a/juju/client/overrides.py b/juju/client/overrides.py index 4b9446a13..1b5692555 100644 --- a/juju/client/overrides.py +++ b/juju/client/overrides.py @@ -1,8 +1,9 @@ # Copyright 2023 Canonical Ltd. # Licensed under the Apache V2, see LICENCE file for details. +from __future__ import annotations import re -from collections import namedtuple +from typing import Any, NamedTuple from . import _client, _definitions from .facade import ReturnMapping, Type, TypeEncoder @@ -22,6 +23,12 @@ ] +class _Change(NamedTuple): + entity: str + type: str + data: dict[str, Any] + + class Delta(Type): """A single websocket delta. @@ -42,12 +49,11 @@ class Delta(Type): _toSchema = {"deltas": "deltas"} _toPy = {"deltas": "deltas"} - def __init__(self, deltas=None): + def __init__(self, deltas: tuple[str, str, dict[str, Any]]): """:param deltas: [str, str, object]""" self.deltas = deltas - Change = namedtuple("Change", "entity type data") - change = Change(*self.deltas) + change = _Change(*self.deltas) self.entity = change.entity self.type = change.type diff --git a/juju/delta.py b/juju/delta.py index b32464e65..c82a20bd5 100644 --- a/juju/delta.py +++ b/juju/delta.py @@ -1,10 +1,12 @@ # Copyright 2023 Canonical Ltd. # Licensed under the Apache V2, see LICENCE file for details. +from __future__ import annotations -from .client import client +from . import model +from .client import client, overrides -def get_entity_delta(d): +def get_entity_delta(d: overrides.Delta): return _delta_types[d.entity](d.deltas) @@ -13,12 +15,14 @@ def get_entity_class(entity_type): class EntityDelta(client.Delta): - def get_id(self): + data: dict[str, str] + + def get_id(self) -> str: return self.data["id"] @classmethod - def get_entity_class(cls): - return None + def get_entity_class(cls) -> type[model.ModelEntity]: + raise NotImplementedError() class ActionDelta(EntityDelta): diff --git a/juju/jasyncio.py b/juju/jasyncio.py index 3590e9390..49d499142 100644 --- a/juju/jasyncio.py +++ b/juju/jasyncio.py @@ -21,6 +21,7 @@ ) from asyncio import ( CancelledError, + Task, create_task, wait, ) @@ -84,7 +85,7 @@ ROOT_LOGGER = logging.getLogger() -def create_task_with_handler(coro, task_name, logger=ROOT_LOGGER): +def create_task_with_handler(coro, task_name, logger=ROOT_LOGGER) -> Task: """Wrapper around "asyncio.create_task" to make sure the task exceptions are handled properly. diff --git a/juju/model.py b/juju/model.py index 12f9edced..ada4ac29b 100644 --- a/juju/model.py +++ b/juju/model.py @@ -19,6 +19,7 @@ from datetime import datetime, timedelta from functools import partial from pathlib import Path +from typing import Any import websockets import yaml @@ -28,6 +29,7 @@ from .bundle import BundleHandler, get_charm_series, is_local_charm from .charmhub import CharmHub from .client import client, connector +from .client.connection import Connection from .client.overrides import Caveat, Macaroon from .constraints import parse as parse_constraints from .controller import ConnectedController, Controller @@ -257,7 +259,15 @@ def get_entity(self, entity_type, entity_id, history_index=-1, connected=True): class ModelEntity: """An object in the Model tree""" - def __init__(self, entity_id, model, history_index=-1, connected=True): + entity_id: str + + def __init__( + self, + entity_id: str, + model: Model, + history_index: int = -1, + connected: bool = True, + ): """Initialize a new entity :param entity_id str: The unique id of the object in the model @@ -279,7 +289,7 @@ def __init__(self, entity_id, model, history_index=-1, connected=True): def __repr__(self): return f'<{type(self).__name__} entity_id="{self.entity_id}">' - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: """Fetch object attributes from the underlying data dict held in the model. @@ -615,7 +625,7 @@ def is_connected(self): """Reports whether the Model is currently connected.""" return self._connector.is_connected() - def connection(self): + def connection(self) -> Connection: """Return the current Connection object. It raises an exception if the Model is disconnected """ @@ -2913,7 +2923,7 @@ async def _get_source_api(self, url): async def wait_for_idle( self, - apps=None, + apps: list[str] | None = None, raise_on_error=True, raise_on_blocked=False, wait_for_active=False, @@ -2923,7 +2933,7 @@ async def wait_for_idle( status=None, wait_for_at_least_units=None, wait_for_exact_units=None, - ): + ) -> None: """Wait for applications in the model to settle into an idle state. :param List[str] apps: Optional list of specific app names to wait on. @@ -3227,16 +3237,16 @@ def make_archive(self, path): zf.close() return path - def _check_type(self, path): + def _check_type(self, path: str) -> str: """Check the path""" - s = os.stat(str(path)) + s = os.stat(path) if stat.S_ISDIR(s.st_mode) or stat.S_ISREG(s.st_mode): return path raise ValueError( "Invalid Charm at %s %s" % (path, "Invalid file type for a charm") ) - def _check_link(self, path): + def _check_link(self, path: str) -> None: link_path = os.readlink(path) if link_path[0] == "/": raise ValueError( @@ -3249,7 +3259,9 @@ def _check_link(self, path): "Invalid charm at %s %s" % (path, "Only internal symlinks are allowed") ) - def _write_symlink(self, zf, link_target, link_path): + def _write_symlink( + self, zf: zipfile.ZipFile, link_target: str, link_path: str + ) -> None: """Package symlinks with appropriate zipfile metadata.""" info = zipfile.ZipInfo() info.filename = link_path @@ -3259,11 +3271,8 @@ def _write_symlink(self, zf, link_target, link_path): info.external_attr = 2716663808 zf.writestr(info, link_target) - def _ignore(self, path): - if path == "build" or path.startswith("build/"): - return True - if path.startswith("."): - return True + def _ignore(self, path: str) -> bool: + return path == "build" or path.startswith("build/") or path.startswith(".") class ModelInfo(ModelEntity): diff --git a/juju/tag.py b/juju/tag.py index 957710288..5057d6398 100644 --- a/juju/tag.py +++ b/juju/tag.py @@ -1,63 +1,64 @@ # Copyright 2023 Canonical Ltd. # Licensed under the Apache V2, see LICENCE file for details. +from __future__ import annotations # TODO: Tags should be a proper class, so that we can distinguish whether # something is already a tag or not. For example, 'user-foo' is a valid # username, but is ambiguous with the already-tagged username 'foo'. -def _prefix(prefix, s): +def _prefix(prefix: str, s: str) -> str: if s and not s.startswith(prefix): return f"{prefix}{s}" return s -def untag(prefix, s): +def untag(prefix: str, s: str) -> str: if s and s.startswith(prefix): return s[len(prefix) :] return s -def cloud(cloud_name): +def cloud(cloud_name: str) -> str: return _prefix("cloud-", cloud_name) -def controller(controller_uuid): +def controller(controller_uuid: str) -> str: return _prefix("controller-", controller_uuid) -def credential(cloud, user, credential_name): +def credential(cloud: str, user: str, credential_name: str) -> str: credential_string = f"{cloud}_{user}_{credential_name}" return _prefix("cloudcred-", credential_string) -def model(model_uuid): +def model(model_uuid: str) -> str: return _prefix("model-", model_uuid) -def machine(machine_id): +def machine(machine_id: str) -> str: return _prefix("machine-", machine_id) -def user(username): +def user(username: str) -> str: return _prefix("user-", username) -def application(app_name): +def application(app_name: str) -> str: return _prefix("application-", app_name) -def storage(app_name): +def storage(app_name: str) -> str: return _prefix("storage-", app_name) -def unit(unit_name): +def unit(unit_name: str) -> str: return _prefix("unit-", unit_name.replace("/", "-")) -def action(action_uuid): +def action(action_uuid: str) -> str: return _prefix("action-", action_uuid) -def space(space_name): +def space(space_name: str) -> str: return _prefix("space-", space_name) diff --git a/juju/utils.py b/juju/utils.py index 1692816f1..710fcc56e 100644 --- a/juju/utils.py +++ b/juju/utils.py @@ -1,13 +1,15 @@ # Copyright 2023 Canonical Ltd. # Licensed under the Apache V2, see LICENCE file for details. +from __future__ import annotations +import asyncio import base64 import os import textwrap import zipfile from collections import defaultdict -from functools import partial from pathlib import Path +from typing import Any import yaml from pyasn1.codec.der.encoder import encode @@ -20,11 +22,11 @@ async def execute_process(*cmd, log=None): """Wrapper around asyncio.create_subprocess_exec.""" - p = await jasyncio.create_subprocess_exec( + p = await asyncio.create_subprocess_exec( *cmd, - stdin=jasyncio.subprocess.PIPE, - stdout=jasyncio.subprocess.PIPE, - stderr=jasyncio.subprocess.PIPE, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, ) stdout, stderr = await p.communicate() if log: @@ -84,7 +86,7 @@ async def read_ssh_key(): """Attempt to read the local juju admin's public ssh key, so that it can be passed on to a model. """ - loop = jasyncio.get_running_loop() + loop = asyncio.get_running_loop() return await loop.run_in_executor(None, _read_ssh_key) @@ -93,20 +95,31 @@ class IdQueue: ID. """ - def __init__(self, maxsize=0): - self._queues = defaultdict(partial(jasyncio.Queue, maxsize)) - - async def get(self, id_): + _queues: dict[int, asyncio.Queue[dict[str, Any] | Exception]] + + def __init__(self): + self._queues = defaultdict(asyncio.Queue) + # FIXME cleanup needed. + # in some cases an Exception is put into the queue. + # if the main coro exits, this exception will be logged as "never awaited" + # we gotta do something about that to keep the output clean. + # + # Additionally, it's conceivable that a response is put in the queue + # and then an exception is put via put_all() + # the reader only ever fetches one item, and exception is "never awaited" + # rewrite put_all to replace the pending response instead. + + async def get(self, id_: int) -> dict[str, Any]: value = await self._queues[id_].get() del self._queues[id_] if isinstance(value, Exception): raise value return value - async def put(self, id_, value): + async def put(self, id_: int, value: dict[str, Any]): await self._queues[id_].put(value) - async def put_all(self, value): + async def put_all(self, value: Exception): for queue in self._queues.values(): await queue.put(value) @@ -120,9 +133,9 @@ async def block_until(*conditions, timeout=None, wait_period=0.5): async def _block(): while not all(c() for c in conditions): - await jasyncio.sleep(wait_period) + await asyncio.sleep(wait_period) - await jasyncio.shield(jasyncio.wait_for(_block(), timeout)) + await asyncio.shield(asyncio.wait_for(_block(), timeout)) async def block_until_with_coroutine( @@ -136,12 +149,12 @@ async def block_until_with_coroutine( async def _block(): while not await condition_coroutine(): - await jasyncio.sleep(wait_period) + await asyncio.sleep(wait_period) - await jasyncio.shield(jasyncio.wait_for(_block(), timeout=timeout)) + await asyncio.shield(asyncio.wait_for(_block(), timeout=timeout)) -async def wait_for_bundle(model, bundle, **kwargs): +async def wait_for_bundle(model, bundle: str | Path, **kwargs) -> None: """Helper to wait for just the apps in a specific bundle. Equivalent to loading the bundle, pulling out the app names, and calling:: @@ -156,8 +169,8 @@ async def wait_for_bundle(model, bundle, **kwargs): bundle = bundle_path / "bundle.yaml" except OSError: pass - bundle = yaml.safe_load(textwrap.dedent(bundle).strip()) - apps = list(bundle.get("applications", bundle.get("services")).keys()) + content: dict[str, Any] = yaml.safe_load(textwrap.dedent(bundle).strip()) + apps = list(content.get("applications", content.get("services")).keys()) await model.wait_for_idle(apps, **kwargs) diff --git a/pyproject.toml b/pyproject.toml index 7b9e36e91..e55113be8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -211,7 +211,8 @@ ignore = [ [tool.pyright] # These are tentative -# include = ["**/*.py"] -pythonVersion = "3.8" # check no python > 3.8 features are used -pythonPlatform = "All" +include = ["**/*.py"] +pythonVersion = "3.8" typeCheckingMode = "strict" +useLibraryCodeForTypes = true +reportGeneralTypeIssues = true