Skip to content

Commit

Permalink
Refactor options parsing to validate complex data structures
Browse files Browse the repository at this point in the history
  • Loading branch information
nat-n committed Nov 9, 2024
1 parent fc5b0e3 commit 4b2e9c5
Show file tree
Hide file tree
Showing 8 changed files with 371 additions and 155 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ jobs:
run-tests:
name: Run tests
strategy:
fail-fast: false
matrix:
os: [Ubuntu, MacOS, Windows]
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
Expand Down
17 changes: 12 additions & 5 deletions poethepoet/config/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
Optional,
Sequence,
Type,
TypedDict,
Union,
)

from ..exceptions import ConfigValidationError
from ..options import NoValue, PoeOptions
from .primitives import EmptyDict, EnvDefault

KNOWN_SHELL_INTERPRETERS = (
"posix",
Expand All @@ -24,6 +26,14 @@
)


class IncludeItem(TypedDict):
path: str
cwd: str


IncludeItem.__optional_keys__ = frozenset({"cwd"})


class ConfigPartition:
options: PoeOptions
full_config: Mapping[str, Any]
Expand Down Expand Up @@ -74,9 +84,6 @@ def get(self, key: str, default: Any = NoValue):
return self.options.get(key, default)


EmptyDict: Mapping = MappingProxyType({})


class ProjectConfig(ConfigPartition):
is_primary = True

Expand All @@ -88,10 +95,10 @@ class ConfigOptions(PoeOptions):
default_task_type: str = "cmd"
default_array_task_type: str = "sequence"
default_array_item_task_type: str = "ref"
env: Mapping[str, str] = EmptyDict
env: Mapping[str, Union[str, EnvDefault]] = EmptyDict
envfile: Union[str, Sequence[str]] = tuple()
executor: Mapping[str, str] = MappingProxyType({"type": "auto"})
include: Sequence[str] = tuple()
include: Union[str, Sequence[str], Sequence[IncludeItem]] = tuple()
poetry_command: str = "poe"
poetry_hooks: Mapping[str, str] = EmptyDict
shell_interpreter: Union[str, Sequence[str]] = "posix"
Expand Down
8 changes: 8 additions & 0 deletions poethepoet/config/primitives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from types import MappingProxyType
from typing import Mapping, TypedDict

EmptyDict: Mapping = MappingProxyType({})


class EnvDefault(TypedDict):
default: str
131 changes: 24 additions & 107 deletions poethepoet/options.py → poethepoet/options/__init__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,10 @@
import collections
from __future__ import annotations

from keyword import iskeyword
from typing import (
Any,
Dict,
List,
Literal,
Mapping,
MutableMapping,
Optional,
Sequence,
Tuple,
Type,
Union,
get_args,
get_origin,
)

from .exceptions import ConfigValidationError
from typing import Any, Mapping, Sequence, get_type_hints

from ..exceptions import ConfigValidationError
from .annotations import TypeAnnotation

NoValue = object()

Expand All @@ -26,7 +14,7 @@ class PoeOptions:
A special kind of config object that parses options ...
"""

__annotations: Dict[str, Type]
__annotations: dict[str, TypeAnnotation]

def __init__(self, **options: Any):
for key in self.get_fields():
Expand Down Expand Up @@ -61,13 +49,13 @@ def __getattr__(self, name: str):
@classmethod
def parse(
cls,
source: Union[Mapping[str, Any], list],
source: Mapping[str, Any] | list,
strict: bool = True,
extra_keys: Sequence[str] = tuple(),
):
config_keys = {
key[:-1] if key.endswith("_") and iskeyword(key[:-1]) else key: vtype
for key, vtype in cls.get_fields().items()
key[:-1] if key.endswith("_") and iskeyword(key[:-1]) else key: type_
for key, type_ in cls.get_fields().items()
}
if strict:
for index, item in enumerate(cls.normalize(source, strict)):
Expand Down Expand Up @@ -110,29 +98,8 @@ def _parse_value(
return value_type.parse(value, strict=strict)

if strict:
expected_type: Union[Type, Tuple[Type, ...]] = cls._type_of(value_type)
if not isinstance(value, expected_type):
# Try format expected_type nicely in the error message
if not isinstance(expected_type, tuple):
expected_type = (expected_type,)
formatted_type = " | ".join(
type_.__name__ for type_ in expected_type if type_ is not type(None)
)
raise ConfigValidationError(
f"Option {key!r} should have a value of type: {formatted_type}",
index=index,
)

annotation = cls.get_annotation(key)
if get_origin(annotation) is Literal:
allowed_values = get_args(annotation)
if value not in allowed_values:
raise ConfigValidationError(
f"Option {key!r} must be one of {allowed_values!r}",
index=index,
)

# TODO: validate list/dict contents
for error_msg in value_type.validate((key,), value):
raise ConfigValidationError(error_msg, index=index)

return value

Expand Down Expand Up @@ -171,43 +138,25 @@ def get(self, key: str, default: Any = NoValue) -> Any:
if default is NoValue:
# Fallback to getting getting the zero value for the type of this attribute
# e.g. 0, False, empty list, empty dict, etc
return self.__get_zero_value(key)
annotation = self.get_fields().get(self._resolve_key(key))
assert annotation
return annotation.zero_value()

return default

def __get_zero_value(self, key: str):
type_of_attr = self.type_of(key)
if isinstance(type_of_attr, tuple):
if type(None) in type_of_attr:
# Optional types default to None
return None
type_of_attr = type_of_attr[0]
assert type_of_attr
return type_of_attr()

def __is_optional(self, key: str):
# TODO: precache optional options keys?
type_of_attr = self.type_of(key)
if isinstance(type_of_attr, tuple):
return type(None) in type_of_attr
return False
annotation = self.get_fields().get(self._resolve_key(key))
assert annotation
return annotation.is_optional

def update(self, options_dict: Dict[str, Any]):
def update(self, options_dict: dict[str, Any]):
new_options_dict = {}
for key in self.get_fields().keys():
if key in options_dict:
new_options_dict[key] = options_dict[key]
elif hasattr(self, key):
new_options_dict[key] = getattr(self, key)

@classmethod
def type_of(cls, key: str) -> Optional[Union[Type, Tuple[Type, ...]]]:
return cls._type_of(cls.get_annotation(key))

@classmethod
def get_annotation(cls, key: str) -> Optional[Type]:
return cls.get_fields().get(cls._resolve_key(key))

@classmethod
def _resolve_key(cls, key: str) -> str:
"""
Expand All @@ -219,51 +168,19 @@ def _resolve_key(cls, key: str) -> str:
return key

@classmethod
def _type_of(cls, annotation: Any) -> Union[Type, Tuple[Type, ...]]:
if get_origin(annotation) is Union:
result: List[Type] = []
for component in get_args(annotation):
component_type = cls._type_of(component)
if isinstance(component_type, tuple):
result.extend(component_type)
else:
result.append(component_type)
return tuple(result)

if get_origin(annotation) in (
dict,
Mapping,
MutableMapping,
collections.abc.Mapping,
collections.abc.MutableMapping,
):
return dict

if get_origin(annotation) in (
list,
Sequence,
collections.abc.Sequence,
):
return list

if get_origin(annotation) is Literal:
return tuple({type(arg) for arg in get_args(annotation)})

return annotation

@classmethod
def get_fields(cls) -> Dict[str, Any]:
def get_fields(cls) -> dict[str, TypeAnnotation]:
"""
Recent python versions removed inheritance for __annotations__
so we have to implement it explicitly
"""
if not hasattr(cls, "__annotations"):
annotations = {}
for base_cls in cls.__bases__:
annotations.update(base_cls.__annotations__)
annotations.update(cls.__annotations__)
annotations.update(get_type_hints(base_cls))
annotations.update(get_type_hints(cls))

cls.__annotations = {
key: type_
key: TypeAnnotation.parse(type_)
for key, type_ in annotations.items()
if not key.startswith("_")
}
Expand Down
Loading

0 comments on commit 4b2e9c5

Please sign in to comment.