Skip to content

Commit

Permalink
Merge pull request #43 from alan-turing-institute/linting
Browse files Browse the repository at this point in the history
Add additional linting checks
  • Loading branch information
jemrobinson authored May 31, 2024
2 parents d66a77a + 2d34589 commit fe41f39
Show file tree
Hide file tree
Showing 28 changed files with 600 additions and 427 deletions.
2 changes: 1 addition & 1 deletion apricot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .patches import LDAPString # noqa: F401

__all__ = [
"ApricotServer",
"__version__",
"__version_info__",
"ApricotServer",
]
35 changes: 28 additions & 7 deletions apricot/apricot_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import inspect
import sys
from typing import Any, cast
from typing import Any, Self, cast

from twisted.internet import reactor, task
from twisted.internet.endpoints import quoteStringArgument, serverFromString
Expand All @@ -13,8 +15,10 @@


class ApricotServer:
"""The Apricot server running via Twisted."""

def __init__(
self,
self: Self,
backend: OAuthBackend,
client_id: str,
client_secret: str,
Expand All @@ -32,6 +36,23 @@ def __init__(
tls_private_key: str | None = None,
**kwargs: Any,
) -> None:
"""Initialise an ApricotServer.
@param backend: An OAuth backend,
@param client_id: An OAuth client ID
@param client_secret: An OAuth client secret
@param domain: The OAuth domain
@param port: Port to expose LDAP on
@param background_refresh: Whether to refresh the LDAP tree in the background
@param debug: Enable debug output
@param enable_mirrored_groups: Create a mirrored LDAP group-of-groups for each group-of-users
@param redis_host: Host for a Redis cache (if used)
@param redis_port: Port for a Redis cache (if used)
@param refresh_interval: Interval after which the LDAP information is stale
@param tls_port: Port to expose LDAPS on
@param tls_certificate: TLS certificate for LDAPS
@param tls_private_key: TLS private key for LDAPS
"""
self.debug = debug

# Log to stdout
Expand All @@ -41,7 +62,7 @@ def __init__(
uid_cache: UidCache
if redis_host and redis_port:
log.msg(
f"Using a Redis user-id cache at host '{redis_host}' on port '{redis_port}'."
f"Using a Redis user-id cache at host '{redis_host}' on port '{redis_port}'.",
)
uid_cache = RedisCache(redis_host=redis_host, redis_port=redis_port)
else:
Expand All @@ -54,7 +75,7 @@ def __init__(
log.msg(f"Creating an OAuthClient for {backend}.")
oauth_backend = OAuthClientMap[backend]
oauth_backend_args = inspect.getfullargspec(
oauth_backend.__init__ # type: ignore
oauth_backend.__init__, # type: ignore[misc]
).args
oauth_client = oauth_backend(
client_id=client_id,
Expand All @@ -81,7 +102,7 @@ def __init__(
if background_refresh:
if self.debug:
log.msg(
f"Starting background refresh (interval={factory.adaptor.refresh_interval})"
f"Starting background refresh (interval={factory.adaptor.refresh_interval})",
)
loop = task.LoopingCall(factory.adaptor.refresh)
loop.start(factory.adaptor.refresh_interval)
Expand Down Expand Up @@ -111,8 +132,8 @@ def __init__(
# Load the Twisted reactor
self.reactor = cast(IReactorCore, reactor)

def run(self) -> None:
"""Start the Twisted reactor"""
def run(self: Self) -> None:
"""Start the Twisted reactor."""
if self.debug:
log.msg("Starting the Twisted reactor.")
self.reactor.run()
19 changes: 13 additions & 6 deletions apricot/cache/local_cache.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
from __future__ import annotations

from typing import Self

from .uid_cache import UidCache


class LocalCache(UidCache):
def __init__(self) -> None:
"""Implementation of UidCache using an in-memory dictionary."""

def __init__(self: Self) -> None:
"""Initialise a RedisCache."""
self.cache: dict[str, int] = {}

def get(self, identifier: str) -> int | None:
def get(self: Self, identifier: str) -> int | None:
return self.cache.get(identifier, None)

def keys(self) -> list[str]:
return [str(k) for k in self.cache.keys()]
def keys(self: Self) -> list[str]:
return [str(k) for k in self.cache]

def set(self, identifier: str, uid_value: int) -> None:
def set(self: Self, identifier: str, uid_value: int) -> None:
self.cache[identifier] = uid_value

def values(self, keys: list[str]) -> list[int]:
def values(self: Self, keys: list[str]) -> list[int]:
return [v for k, v in self.cache.items() if k in keys]
35 changes: 22 additions & 13 deletions apricot/cache/redis_cache.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,45 @@
from typing import cast
from __future__ import annotations

from typing import Self, cast

import redis

from .uid_cache import UidCache


class RedisCache(UidCache):
def __init__(self, redis_host: str, redis_port: int) -> None:
"""Implementation of UidCache using a Redis backend."""

def __init__(self: Self, redis_host: str, redis_port: int) -> None:
"""Initialise a RedisCache.
@param redis_host: Host for the Redis cache
@param redis_port: Port for the Redis cache
"""
self.redis_host = redis_host
self.redis_port = redis_port
self.cache_: "redis.Redis[str]" | None = None # noqa: UP037
self.cache_: redis.Redis[str] | None = None

@property
def cache(self) -> "redis.Redis[str]":
"""
Lazy-load the cache on request
"""
def cache(self: Self) -> redis.Redis[str]:
"""Lazy-load the cache on request."""
if not self.cache_:
self.cache_ = redis.Redis(
host=self.redis_host, port=self.redis_port, decode_responses=True
host=self.redis_host,
port=self.redis_port,
decode_responses=True,
)
return self.cache_

def get(self, identifier: str) -> int | None:
def get(self: Self, identifier: str) -> int | None:
value = self.cache.get(identifier)
return None if value is None else int(value)

def keys(self) -> list[str]:
return [str(k) for k in self.cache.keys()]
def keys(self: Self) -> list[str]:
return [str(k) for k in self.cache.keys()] # noqa: SIM118

def set(self, identifier: str, uid_value: int) -> None:
def set(self: Self, identifier: str, uid_value: int) -> None:
self.cache.set(identifier, uid_value)

def values(self, keys: list[str]) -> list[int]:
def values(self: Self, keys: list[str]) -> list[int]:
return [int(cast(str, v)) for v in self.cache.mget(keys)]
74 changes: 31 additions & 43 deletions apricot/cache/uid_cache.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,49 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import cast
from typing import Self, cast


class UidCache(ABC):
"""Abstract cache for storing UIDs."""

@abstractmethod
def get(self, identifier: str) -> int | None:
"""
Get the UID for a given identifier, returning None if it does not exist
"""
pass
def get(self: Self, identifier: str) -> int | None:
"""Get the UID for a given identifier, returning None if it does not exist."""

@abstractmethod
def keys(self) -> list[str]:
"""
Get list of cached keys
"""
pass
def keys(self: Self) -> list[str]:
"""Get list of cached keys."""

@abstractmethod
def set(self, identifier: str, uid_value: int) -> None:
"""
Set the UID for a given identifier
"""
pass
def set(self: Self, identifier: str, uid_value: int) -> None:
"""Set the UID for a given identifier."""

@abstractmethod
def values(self, keys: list[str]) -> list[int]:
"""
Get list of cached values corresponding to requested keys
"""
pass
def values(self: Self, keys: list[str]) -> list[int]:
"""Get list of cached values corresponding to requested keys."""

def get_group_uid(self, identifier: str) -> int:
"""
Get UID for a group, constructing one if necessary
def get_group_uid(self: Self, identifier: str) -> int:
"""Get UID for a group, constructing one if necessary.
@param identifier: Identifier for group needing a UID
"""
return self.get_uid(identifier, category="group", min_value=3000)

def get_user_uid(self, identifier: str) -> int:
"""
Get UID for a user, constructing one if necessary
def get_user_uid(self: Self, identifier: str) -> int:
"""Get UID for a user, constructing one if necessary.
@param identifier: Identifier for user needing a UID
"""
return self.get_uid(identifier, category="user", min_value=2000)

def get_uid(
self, identifier: str, category: str, min_value: int | None = None
self: Self,
identifier: str,
category: str,
min_value: int | None = None,
) -> int:
"""
Get UID, constructing one if necessary.
"""Get UID, constructing one if necessary.
@param identifier: Identifier for object needing a UID
@param category: Category the object belongs to
Expand All @@ -60,14 +52,13 @@ def get_uid(
identifier_ = f"{category}-{identifier}"
uid = self.get(identifier_)
if not uid:
min_value = min_value if min_value else 0
min_value = min_value or 0
next_uid = max(self._get_max_uid(category) + 1, min_value)
self.set(identifier_, next_uid)
return cast(int, self.get(identifier_))

def _get_max_uid(self, category: str | None) -> int:
"""
Get maximum UID for a given category
def _get_max_uid(self: Self, category: str | None) -> int:
"""Get maximum UID for a given category.
@param category: Category to check UIDs for
"""
Expand All @@ -78,27 +69,24 @@ def _get_max_uid(self, category: str | None) -> int:
values = [*self.values(keys), -999]
return max(values)

def overwrite_group_uid(self, identifier: str, uid: int) -> None:
"""
Set UID for a group, overwriting the existing value if there is one
def overwrite_group_uid(self: Self, identifier: str, uid: int) -> None:
"""Set UID for a group, overwriting the existing value if there is one.
@param identifier: Identifier for group
@param uid: Desired UID
"""
return self.overwrite_uid(identifier, category="group", uid=uid)

def overwrite_user_uid(self, identifier: str, uid: int) -> None:
"""
Get UID for a user, constructing one if necessary
def overwrite_user_uid(self: Self, identifier: str, uid: int) -> None:
"""Get UID for a user, constructing one if necessary.
@param identifier: Identifier for user
@param uid: Desired UID
"""
return self.overwrite_uid(identifier, category="user", uid=uid)

def overwrite_uid(self, identifier: str, category: str, uid: int) -> None:
"""
Set UID, overwriting the existing one if necessary.
def overwrite_uid(self: Self, identifier: str, category: str, uid: int) -> None:
"""Set UID, overwriting the existing one if necessary.
@param identifier: Identifier for object
@param category: Category the object belongs to
Expand Down
32 changes: 18 additions & 14 deletions apricot/ldap/oauth_ldap_entry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import cast
from __future__ import annotations

from typing import Self, cast

from ldaptor.inmemory import ReadOnlyInMemoryLDAPEntry
from ldaptor.protocols.ldap.distinguishedname import (
Expand All @@ -16,17 +18,18 @@


class OAuthLDAPEntry(ReadOnlyInMemoryLDAPEntry):
"""An LDAP entry that represents a view of an OAuth object."""

dn: DistinguishedName
attributes: LDAPAttributeDict

def __init__(
self,
self: Self,
dn: DistinguishedName | str,
attributes: LDAPAttributeDict,
oauth_client: OAuthClient | None = None,
) -> None:
"""
Initialize the object.
"""Initialize the object.
@param dn: Distinguished Name of the object
@param attributes: Attributes of the object.
Expand All @@ -37,7 +40,7 @@ def __init__(
dn = DistinguishedName(stringValue=dn)
super().__init__(dn, attributes)

def __str__(self) -> str:
def __str__(self: Self) -> str:
output = bytes(self.toWire()).decode("utf-8")
for child in self._children.values():
try:
Expand All @@ -52,18 +55,19 @@ def __str__(self) -> str:
return output

@property
def oauth_client(self) -> OAuthClient:
if not self.oauth_client_:
if hasattr(self._parent, "oauth_client"):
self.oauth_client_ = self._parent.oauth_client
def oauth_client(self: Self) -> OAuthClient:
if not self.oauth_client_ and hasattr(self._parent, "oauth_client"):
self.oauth_client_ = self._parent.oauth_client
if not isinstance(self.oauth_client_, OAuthClient):
msg = f"OAuthClient is of incorrect type {type(self.oauth_client_)}"
raise TypeError(msg)
return self.oauth_client_

def add_child(
self, rdn: RelativeDistinguishedName | str, attributes: LDAPAttributeDict
) -> "OAuthLDAPEntry":
self: Self,
rdn: RelativeDistinguishedName | str,
attributes: LDAPAttributeDict,
) -> OAuthLDAPEntry:
if isinstance(rdn, str):
rdn = RelativeDistinguishedName(stringValue=rdn)
try:
Expand All @@ -73,8 +77,8 @@ def add_child(
output = self._children[rdn.getText()]
return cast(OAuthLDAPEntry, output)

def bind(self, password: bytes) -> defer.Deferred["OAuthLDAPEntry"]:
def _bind(password: bytes) -> "OAuthLDAPEntry":
def bind(self: Self, password: bytes) -> defer.Deferred[OAuthLDAPEntry]:
def _bind(password: bytes) -> OAuthLDAPEntry:
oauth_username = next(iter(self.get("oauth_username", "unknown")))
s_password = password.decode("utf-8")
if self.oauth_client.verify(username=oauth_username, password=s_password):
Expand All @@ -84,5 +88,5 @@ def _bind(password: bytes) -> "OAuthLDAPEntry":

return defer.maybeDeferred(_bind, password)

def list_children(self) -> "list[OAuthLDAPEntry]":
def list_children(self: Self) -> list[OAuthLDAPEntry]:
return [cast(OAuthLDAPEntry, entry) for entry in self._children.values()]
Loading

0 comments on commit fe41f39

Please sign in to comment.