Skip to content

Commit

Permalink
feat: add validator argument (#12)
Browse files Browse the repository at this point in the history
This change adds the `validator` argument to the `Env.var` method to
allow passing in a validator function for the retrieved value.
  • Loading branch information
P403n1x87 authored Jul 5, 2022
1 parent 3dd7a4c commit 2a39ed2
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 3 deletions.
19 changes: 16 additions & 3 deletions envier/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
type, # type: Union[object, Type[T]]
name, # type: str
parser=None, # type: Optional[Callable[[str], T]]
validator=None, # type: Optional[Callable[[T], None]]
map=None, # type: Optional[MapType]
default=NoDefault, # type: Union[T, NoDefaultType]
deprecations=None, # type: Optional[List[DeprecationInfo]]
Expand All @@ -63,11 +64,12 @@ def __init__(
self.type = type
self.name = name
self.parser = parser
self.validator = validator
self.map = map
self.default = default
self.deprecations = deprecations

def __call__(self, env, prefix):
def _retrieve(self, env, prefix):
# type: (Env, str) -> T
source = env.source

Expand Down Expand Up @@ -142,6 +144,15 @@ def __call__(self, env, prefix):

return self.type(raw) # type: ignore[call-arg,operator]

def __call__(self, env, prefix):
# type: (Env, str) -> T
value = self._retrieve(env, prefix)

if self.validator is not None:
self.validator(value)

return value


class DerivedVariable(Generic[T]):
def __init__(self, type, derivation):
Expand Down Expand Up @@ -234,25 +245,27 @@ def var(
type, # type: Type[T]
name, # type: str
parser=None, # type: Optional[Callable[[str], T]]
validator=None, # type: Optional[Callable[[T], None]]
map=None, # type: Optional[MapType]
default=NoDefault, # type: Union[T, NoDefaultType]
deprecations=None, # type: Optional[List[DeprecationInfo]]
):
# type: (...) -> EnvVariable[T]
return EnvVariable(type, name, parser, map, default, deprecations)
return EnvVariable(type, name, parser, validator, map, default, deprecations)

@classmethod
def v(
cls,
type, # type: Union[object, Type[T]]
name, # type: str
parser=None, # type: Optional[Callable[[str], T]]
validator=None, # type: Optional[Callable[[T], None]]
map=None, # type: Optional[MapType]
default=NoDefault, # type: Union[T, NoDefaultType]
deprecations=None, # type: Optional[List[DeprecationInfo]]
):
# type: (...) -> EnvVariable[T]
return EnvVariable(type, name, parser, map, default, deprecations)
return EnvVariable(type, name, parser, validator, map, default, deprecations)

@classmethod
def der(cls, type, derivation):
Expand Down
26 changes: 26 additions & 0 deletions tests/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,3 +315,29 @@ class DictConfig(Env):
foo = Env.der(Optional[int], lambda _: value)

assert DictConfig().foo is value


@pytest.mark.parametrize(
"value,exc",
[
(0, None),
(512, None),
(-1, ValueError),
(513, ValueError),
],
)
def test_env_validator(monkeypatch, value, exc):
monkeypatch.setenv("FOO", str(value))

class Config(Env):
def validate(value):
if not (0 <= value <= 512):
raise ValueError("Value must be between 0 and 512")

foo = Env.var(int, "FOO", validator=validate)

if exc is not None:
with pytest.raises(exc):
Config()
else:
assert Config().foo == value

0 comments on commit 2a39ed2

Please sign in to comment.