From 397c93f6d59c0f874c0076c78898a15f6c3b76b7 Mon Sep 17 00:00:00 2001 From: "Gabriele N. Tornetta" Date: Mon, 21 Oct 2024 16:35:20 +0100 Subject: [PATCH] feat: add items iterator (#38) We extend the public API with an iterator that allows retrieving all the configuration items. This returns a mapping between attribute paths and environment variable instances. We also attach the full variable name to each EnvVariable instance, which is accessible via the ``full_name`` property. --- envier/env.py | 120 ++++++++++++++++++++++++++++++++-------------- tests/test_env.py | 50 ++++++++++++++++++- 2 files changed, 132 insertions(+), 38 deletions(-) diff --git a/envier/env.py b/envier/env.py index 3543691..eccab8d 100644 --- a/envier/env.py +++ b/envier/env.py @@ -1,3 +1,4 @@ +from collections import deque from collections import namedtuple import os import typing as t @@ -68,6 +69,12 @@ def __init__( self.help_type = help_type self.help_default = help_default + self._full_name = _normalized(name) # Will be set by the EnvMeta metaclass + + @property + def full_name(self) -> str: + return f"_{self._full_name}" if self.private else self._full_name + def _cast(self, _type: t.Any, raw: str, env: "Env") -> t.Any: if _type is bool: return t.cast(T, raw.lower() in env.__truthy__) @@ -100,9 +107,7 @@ def _cast(self, _type: t.Any, raw: str, env: "Env") -> t.Any: def _retrieve(self, env: "Env", prefix: str) -> T: source = env.source - full_name = prefix + _normalized(self.name) - if self.private: - full_name = f"_{full_name}" + full_name = self.full_name raw = source.get(full_name.format(**env.dynamic)) if raw is None and self.deprecations: for name, deprecated_when, removed_when in self.deprecations: @@ -167,10 +172,8 @@ def __call__(self, env: "Env", prefix: str) -> T: try: self.validator(value) except ValueError as e: - full_name = prefix + _normalized(self.name) - raise ValueError( - "Invalid value for environment variable %s: %s" % (full_name, e) - ) + msg = f"Invalid value for environment variable {self.full_name}: {e}" + raise ValueError(msg) return value @@ -191,7 +194,22 @@ def __call__(self, env: "Env") -> T: return value -class Env(object): +class EnvMeta(type): + def __new__( + cls, name: str, bases: t.Tuple[t.Type], ns: t.Dict[str, t.Any] + ) -> t.Any: + env = t.cast("Env", super().__new__(cls, name, bases, ns)) + + prefix = ns.get("__prefix__") + if prefix: + for v in env.values(recursive=True): + if isinstance(v, EnvVariable): + v._full_name = f"{_normalized(prefix)}_{v._full_name}".upper() + + return env + + +class Env(metaclass=EnvMeta): """Env base class. This class is meant to be subclassed. The configuration is declared by using @@ -336,26 +354,42 @@ def d( return DerivedVariable(type, derivation) @classmethod - def keys(cls) -> t.Iterator[str]: - """Return the names of all the items.""" - return ( - k - for k, v in cls.__dict__.items() - if isinstance(v, (EnvVariable, DerivedVariable)) - or isinstance(v, type) - and issubclass(v, Env) - ) + def items( + cls, recursive: bool = False, include_derived: bool = False + ) -> t.Iterator[t.Tuple[str, t.Union[EnvVariable, DerivedVariable]]]: + classes = (EnvVariable, DerivedVariable) if include_derived else (EnvVariable,) + q: t.Deque[t.Tuple[t.Tuple[str], t.Type["Env"]]] = deque() + path: t.Tuple[str] = tuple() # type: ignore[assignment] + q.append((path, cls)) + while q: + path, env = q.popleft() + for k, v in env.__dict__.items(): + if isinstance(v, classes): + yield ( + ".".join((*path, k)), + t.cast(t.Union[EnvVariable, DerivedVariable], v), + ) + elif isinstance(v, type) and issubclass(v, Env) and recursive: + item_name = getattr(v, "__item__", k) + if item_name is None: + item_name = k + q.append(((*path, item_name), v)) # type: ignore[arg-type] @classmethod - def values(cls) -> t.Iterator[t.Union[EnvVariable, DerivedVariable, t.Type["Env"]]]: - """Return the names of all the items.""" - return ( - v - for v in cls.__dict__.values() - if isinstance(v, (EnvVariable, DerivedVariable)) - or isinstance(v, type) - and issubclass(v, Env) - ) + def keys( + cls, recursive: bool = False, include_derived: bool = False + ) -> t.Iterator[str]: + """Return the name of all the configuration items.""" + for k, _ in cls.items(recursive, include_derived): + yield k + + @classmethod + def values( + cls, recursive: bool = False, include_derived: bool = False + ) -> t.Iterator[t.Union[EnvVariable, DerivedVariable, t.Type["Env"]]]: + """Return the value of all the configuration items.""" + for _, v in cls.items(recursive, include_derived): + yield v @classmethod def include( @@ -371,14 +405,6 @@ def include( operation would result in some variables being overwritten. This can be disabled by setting the ``overwrite`` argument to ``True``. """ - if namespace is not None: - if not overwrite and hasattr(cls, namespace): - raise ValueError("Namespace already in use: {}".format(namespace)) - - setattr(cls, namespace, env_spec) - - return None - # Pick only the attributes that define variables. to_include = { k: v @@ -387,14 +413,36 @@ def include( or isinstance(v, type) and issubclass(v, Env) } - if not overwrite: overlap = set(cls.__dict__.keys()) & set(to_include.keys()) if overlap: raise ValueError("Configuration clashes detected: {}".format(overlap)) + own_prefix = _normalized(getattr(cls, "__prefix__", "")) + + if namespace is not None: + if not overwrite and hasattr(cls, namespace): + raise ValueError("Namespace already in use: {}".format(namespace)) + + if getattr(cls, namespace, None) is not env_spec: + setattr(cls, namespace, env_spec) + + if own_prefix: + for _, v in to_include.items(): + if isinstance(v, EnvVariable): + v._full_name = f"{own_prefix}_{v._full_name}" + + return None + + other_prefix = getattr(env_spec, "__prefix__", "") for k, v in to_include.items(): - setattr(cls, k, v) + if getattr(cls, k, None) is not v: + setattr(cls, k, v) + if isinstance(v, EnvVariable): + if other_prefix: + v._full_name = v._full_name[len(other_prefix) + 1 :] # noqa + if own_prefix: + v._full_name = f"{own_prefix}_{v._full_name}" @classmethod def help_info( diff --git a/tests/test_env.py b/tests/test_env.py index 72ccf7e..d16e78a 100644 --- a/tests/test_env.py +++ b/tests/test_env.py @@ -161,7 +161,12 @@ class GlobalConfig(Env): service = ServiceConfig config = GlobalConfig() - assert set(config.keys()) == {"debug_mode", "service"} + assert set(config.keys()) == {"debug_mode"} + assert set(config.keys(recursive=True)) == { + "debug_mode", + "service.host", + "service.port", + } assert config.service.port == 8080 @@ -178,11 +183,23 @@ class ServiceConfig(Env): host = Env.var(str, "host", default="localhost") port = Env.var(int, "port", default=3000) + _private = Env.var(int, "private", default=42, private=True) config = GlobalConfig() - assert set(config.keys()) == {"debug_mode", "service"} + assert set(config.keys()) == {"debug_mode"} + assert set(config.keys(recursive=True)) == { + "debug_mode", + "service.host", + "service.port", + "service._private", + } assert config.service.port == 8080 + assert GlobalConfig.debug_mode.full_name == "MYAPP_DEBUG" + assert GlobalConfig.service.host.full_name == "MYAPP_SERVICE_HOST" + assert GlobalConfig.service.port.full_name == "MYAPP_SERVICE_PORT" + assert GlobalConfig.service._private.full_name == "_MYAPP_SERVICE_PRIVATE" + def test_env_include(): class GlobalConfig(Env): @@ -383,3 +400,32 @@ class Config(Env): ("_PRIVATE_FOO", "int", "42", ""), ("PUBLIC_FOO", "int", "42", ""), } + + assert Config.private.full_name == "_PRIVATE_FOO" + + +def test_env_items(monkeypatch): + monkeypatch.setenv("MYAPP_SERVICE_PORT", "8080") + + class GlobalConfig(Env): + __prefix__ = "myapp" + + debug_mode = Env.var(bool, "debug", default=False) + + class ServiceConfig(Env): + __item__ = __prefix__ = "service" + + host = Env.var(str, "host", default="localhost") + port = Env.var(int, "port", default=3000) + _private = Env.var(int, "private", default=42, private=True) + + items = list(GlobalConfig.items()) + assert items == [("debug_mode", GlobalConfig.debug_mode)] + + items = list(GlobalConfig.items(recursive=True)) + assert items == [ + ("debug_mode", GlobalConfig.debug_mode), + ("service.host", GlobalConfig.ServiceConfig.host), + ("service.port", GlobalConfig.ServiceConfig.port), + ("service._private", GlobalConfig.ServiceConfig._private), + ]