Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various improvements to config validation logic and error handling in general #251

Merged
merged 2 commits into from
Nov 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions poethepoet/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,10 @@ def resolve_task(self, allow_hidden: bool = False) -> Optional["PoeTask"]:

task = tuple(self.ui["task"])
if not task:
self.print_help(info="No task specified.")
try:
self.print_help(info="No task specified.")
except PoeException as error:
self.print_help(error=error)
return None

task_name = task[0]
Expand Down Expand Up @@ -184,7 +187,12 @@ def run_task_graph(self, task: "PoeTask") -> Optional[int]:
from .task.graph import TaskExecutionGraph

context = self.get_run_context(multistage=True)
graph = TaskExecutionGraph(task, context)
try:
graph = TaskExecutionGraph(task, context)
except PoeException as error:
self.print_help(error=error)
return 1

plan = graph.get_execution_plan()

for stage in plan:
Expand Down Expand Up @@ -240,7 +248,9 @@ def print_help(
task_name: (
(
content.get("help", ""),
PoeTaskArgs.get_help_content(content.get("args")),
PoeTaskArgs.get_help_content(
content.get("args"), task_name, suppress_errors=bool(error)
),
)
if isinstance(content, dict)
else ("", tuple())
Expand Down
17 changes: 12 additions & 5 deletions poethepoet/config/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
Optional,
Sequence,
Type,
TypedDict,
Union,
)

from ..exceptions import ConfigValidationError
from ..options import NoValue, PoeOptions
from .primitives import EmptyDict, EnvDefault

KNOWN_SHELL_INTERPRETERS = (
"posix",
Expand All @@ -24,6 +26,14 @@
)


class IncludeItem(TypedDict):
path: str
cwd: str


IncludeItem.__optional_keys__ = frozenset({"cwd"})


class ConfigPartition:
options: PoeOptions
full_config: Mapping[str, Any]
Expand Down Expand Up @@ -74,9 +84,6 @@ def get(self, key: str, default: Any = NoValue):
return self.options.get(key, default)


EmptyDict: Mapping = MappingProxyType({})


class ProjectConfig(ConfigPartition):
is_primary = True

Expand All @@ -88,10 +95,10 @@ class ConfigOptions(PoeOptions):
default_task_type: str = "cmd"
default_array_task_type: str = "sequence"
default_array_item_task_type: str = "ref"
env: Mapping[str, str] = EmptyDict
env: Mapping[str, Union[str, EnvDefault]] = EmptyDict
envfile: Union[str, Sequence[str]] = tuple()
executor: Mapping[str, str] = MappingProxyType({"type": "auto"})
include: Sequence[str] = tuple()
include: Union[str, Sequence[str], Sequence[IncludeItem]] = tuple()
poetry_command: str = "poe"
poetry_hooks: Mapping[str, str] = EmptyDict
shell_interpreter: Union[str, Sequence[str]] = "posix"
Expand Down
8 changes: 8 additions & 0 deletions poethepoet/config/primitives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from types import MappingProxyType
from typing import Mapping, TypedDict

EmptyDict: Mapping = MappingProxyType({})


class EnvDefault(TypedDict):
default: str
131 changes: 24 additions & 107 deletions poethepoet/options.py → poethepoet/options/__init__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,10 @@
import collections
from __future__ import annotations

from keyword import iskeyword
from typing import (
Any,
Dict,
List,
Literal,
Mapping,
MutableMapping,
Optional,
Sequence,
Tuple,
Type,
Union,
get_args,
get_origin,
)

from .exceptions import ConfigValidationError
from typing import Any, Mapping, Sequence, get_type_hints

from ..exceptions import ConfigValidationError
from .annotations import TypeAnnotation

NoValue = object()

Expand All @@ -26,7 +14,7 @@ class PoeOptions:
A special kind of config object that parses options ...
"""

__annotations: Dict[str, Type]
__annotations: dict[str, TypeAnnotation]

def __init__(self, **options: Any):
for key in self.get_fields():
Expand Down Expand Up @@ -61,13 +49,13 @@ def __getattr__(self, name: str):
@classmethod
def parse(
cls,
source: Union[Mapping[str, Any], list],
source: Mapping[str, Any] | list,
strict: bool = True,
extra_keys: Sequence[str] = tuple(),
):
config_keys = {
key[:-1] if key.endswith("_") and iskeyword(key[:-1]) else key: vtype
for key, vtype in cls.get_fields().items()
key[:-1] if key.endswith("_") and iskeyword(key[:-1]) else key: type_
for key, type_ in cls.get_fields().items()
}
if strict:
for index, item in enumerate(cls.normalize(source, strict)):
Expand Down Expand Up @@ -110,29 +98,8 @@ def _parse_value(
return value_type.parse(value, strict=strict)

if strict:
expected_type: Union[Type, Tuple[Type, ...]] = cls._type_of(value_type)
if not isinstance(value, expected_type):
# Try format expected_type nicely in the error message
if not isinstance(expected_type, tuple):
expected_type = (expected_type,)
formatted_type = " | ".join(
type_.__name__ for type_ in expected_type if type_ is not type(None)
)
raise ConfigValidationError(
f"Option {key!r} should have a value of type: {formatted_type}",
index=index,
)

annotation = cls.get_annotation(key)
if get_origin(annotation) is Literal:
allowed_values = get_args(annotation)
if value not in allowed_values:
raise ConfigValidationError(
f"Option {key!r} must be one of {allowed_values!r}",
index=index,
)

# TODO: validate list/dict contents
for error_msg in value_type.validate((key,), value):
raise ConfigValidationError(error_msg, index=index)

return value

Expand Down Expand Up @@ -171,43 +138,25 @@ def get(self, key: str, default: Any = NoValue) -> Any:
if default is NoValue:
# Fallback to getting getting the zero value for the type of this attribute
# e.g. 0, False, empty list, empty dict, etc
return self.__get_zero_value(key)
annotation = self.get_fields().get(self._resolve_key(key))
assert annotation
return annotation.zero_value()

return default

def __get_zero_value(self, key: str):
type_of_attr = self.type_of(key)
if isinstance(type_of_attr, tuple):
if type(None) in type_of_attr:
# Optional types default to None
return None
type_of_attr = type_of_attr[0]
assert type_of_attr
return type_of_attr()

def __is_optional(self, key: str):
# TODO: precache optional options keys?
type_of_attr = self.type_of(key)
if isinstance(type_of_attr, tuple):
return type(None) in type_of_attr
return False
annotation = self.get_fields().get(self._resolve_key(key))
assert annotation
return annotation.is_optional

def update(self, options_dict: Dict[str, Any]):
def update(self, options_dict: dict[str, Any]):
new_options_dict = {}
for key in self.get_fields().keys():
if key in options_dict:
new_options_dict[key] = options_dict[key]
elif hasattr(self, key):
new_options_dict[key] = getattr(self, key)

@classmethod
def type_of(cls, key: str) -> Optional[Union[Type, Tuple[Type, ...]]]:
return cls._type_of(cls.get_annotation(key))

@classmethod
def get_annotation(cls, key: str) -> Optional[Type]:
return cls.get_fields().get(cls._resolve_key(key))

@classmethod
def _resolve_key(cls, key: str) -> str:
"""
Expand All @@ -219,51 +168,19 @@ def _resolve_key(cls, key: str) -> str:
return key

@classmethod
def _type_of(cls, annotation: Any) -> Union[Type, Tuple[Type, ...]]:
if get_origin(annotation) is Union:
result: List[Type] = []
for component in get_args(annotation):
component_type = cls._type_of(component)
if isinstance(component_type, tuple):
result.extend(component_type)
else:
result.append(component_type)
return tuple(result)

if get_origin(annotation) in (
dict,
Mapping,
MutableMapping,
collections.abc.Mapping,
collections.abc.MutableMapping,
):
return dict

if get_origin(annotation) in (
list,
Sequence,
collections.abc.Sequence,
):
return list

if get_origin(annotation) is Literal:
return tuple({type(arg) for arg in get_args(annotation)})

return annotation

@classmethod
def get_fields(cls) -> Dict[str, Any]:
def get_fields(cls) -> dict[str, TypeAnnotation]:
"""
Recent python versions removed inheritance for __annotations__
so we have to implement it explicitly
"""
if not hasattr(cls, "__annotations"):
annotations = {}
for base_cls in cls.__bases__:
annotations.update(base_cls.__annotations__)
annotations.update(cls.__annotations__)
annotations.update(get_type_hints(base_cls))
annotations.update(get_type_hints(cls))

cls.__annotations = {
key: type_
key: TypeAnnotation.parse(type_)
for key, type_ in annotations.items()
if not key.startswith("_")
}
Expand Down
Loading