Skip to content

Commit

Permalink
Make ConfigSetting generic
Browse files Browse the repository at this point in the history
  • Loading branch information
septatrix committed Nov 21, 2024
1 parent 4447b04 commit 72eb7a9
Showing 1 changed file with 63 additions and 50 deletions.
113 changes: 63 additions & 50 deletions mkosi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from collections.abc import Collection, Iterable, Iterator, Sequence
from contextlib import AbstractContextManager
from pathlib import Path
from typing import Any, Callable, Optional, TypeVar, Union, cast
from typing import Any, Callable, Generic, Optional, TypeVar, Union, cast

from mkosi.distributions import Distribution, detect_distribution
from mkosi.log import ARG_DEBUG, ARG_DEBUG_SANDBOX, ARG_DEBUG_SHELL, Style, die
Expand All @@ -48,11 +48,11 @@
from mkosi.versioncomp import GenericVersion

T = TypeVar("T")
SE = TypeVar("SE", bound=StrEnum)

ConfigParseCallback = Callable[[Optional[str], Optional[Any]], Any]
ConfigMatchCallback = Callable[[str, Any], bool]
ConfigDefaultCallback = Callable[[argparse.Namespace], Any]

ConfigParseCallback = Callable[[Optional[str], Optional[T]], Optional[T]]
ConfigMatchCallback = Callable[[str, T], bool]
ConfigDefaultCallback = Callable[[argparse.Namespace], T]

BUILTIN_CONFIGS = ("mkosi-tools", "mkosi-initrd", "mkosi-vm")

Expand Down Expand Up @@ -676,7 +676,7 @@ def config_match_build_sources(match: str, value: list[ConfigTree]) -> bool:
return Path(match.lstrip("/")) in [tree.target for tree in value if tree.target]


def config_make_list_matcher(parse: Callable[[str], T]) -> ConfigMatchCallback:
def config_make_list_matcher(parse: Callable[[str], T]) -> ConfigMatchCallback[list[T]]:
def config_match_list(match: str, value: list[T]) -> bool:
return parse(match) in value

Expand All @@ -687,7 +687,7 @@ def config_parse_string(value: Optional[str], old: Optional[str]) -> Optional[st
return value or None


def config_make_string_matcher(allow_globs: bool = False) -> ConfigMatchCallback:
def config_make_string_matcher(allow_globs: bool = False) -> ConfigMatchCallback[str]:
def config_match_string(match: str, value: str) -> bool:
if allow_globs:
return fnmatch.fnmatchcase(value, match)
Expand Down Expand Up @@ -906,8 +906,8 @@ def config_default_proxy_url(namespace: argparse.Namespace) -> Optional[str]:
return None


def make_enum_parser(type: type[StrEnum]) -> Callable[[str], StrEnum]:
def parse_enum(value: str) -> StrEnum:
def make_enum_parser(type: type[SE]) -> Callable[[str], SE]:
def parse_enum(value: str) -> SE:
try:
return type(value)
except ValueError:
Expand All @@ -916,17 +916,15 @@ def parse_enum(value: str) -> StrEnum:
return parse_enum


def config_make_enum_parser(type: type[StrEnum]) -> ConfigParseCallback:
def config_parse_enum(value: Optional[str], old: Optional[StrEnum]) -> Optional[StrEnum]:
def config_make_enum_parser(type: type[SE]) -> ConfigParseCallback[SE]:
def config_parse_enum(value: Optional[str], old: Optional[SE]) -> Optional[SE]:
return make_enum_parser(type)(value) if value else None

return config_parse_enum


def config_make_enum_parser_with_boolean(
type: type[StrEnum], *, yes: StrEnum, no: StrEnum
) -> ConfigParseCallback:
def config_parse_enum(value: Optional[str], old: Optional[StrEnum]) -> Optional[StrEnum]:
def config_make_enum_parser_with_boolean(type: type[SE], *, yes: SE, no: SE) -> ConfigParseCallback[SE]:
def config_parse_enum(value: Optional[str], old: Optional[SE]) -> Optional[SE]:
if not value:
return None

Expand All @@ -938,8 +936,8 @@ def config_parse_enum(value: Optional[str], old: Optional[StrEnum]) -> Optional[
return config_parse_enum


def config_make_enum_matcher(type: type[StrEnum]) -> ConfigMatchCallback:
def config_match_enum(match: str, value: StrEnum) -> bool:
def config_make_enum_matcher(type: type[SE]) -> ConfigMatchCallback[SE]:
def config_match_enum(match: str, value: SE) -> bool:
return make_enum_parser(type)(match) == value

return config_match_enum
Expand All @@ -948,11 +946,11 @@ def config_match_enum(match: str, value: StrEnum) -> bool:
def config_make_list_parser(
*,
delimiter: Optional[str] = None,
parse: Callable[[str], Any] = str,
parse: Callable[[str], T] = str, # type: ignore # see mypy#3737
unescape: bool = False,
reset: bool = True,
) -> ConfigParseCallback:
def config_parse_list(value: Optional[str], old: Optional[list[Any]]) -> Optional[list[Any]]:
) -> ConfigParseCallback[list[T]]:
def config_parse_list(value: Optional[str], old: Optional[list[T]]) -> Optional[list[T]]:
new = old.copy() if old else []

if value is None:
Expand Down Expand Up @@ -1010,12 +1008,12 @@ def config_match_version(match: str, value: str) -> bool:
def config_make_dict_parser(
*,
delimiter: Optional[str] = None,
parse: Callable[[str], tuple[str, Any]],
parse: Callable[[str], tuple[str, str]],
unescape: bool = False,
allow_paths: bool = False,
reset: bool = True,
) -> ConfigParseCallback:
def config_parse_dict(value: Optional[str], old: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
) -> ConfigParseCallback[dict[str, str]]:
def config_parse_dict(value: Optional[str], old: Optional[dict[str, str]]) -> Optional[dict[str, str]]:
new = old.copy() if old else {}

if value is None:
Expand Down Expand Up @@ -1104,7 +1102,7 @@ def config_make_path_parser(
expandvars: bool = True,
secret: bool = False,
constants: Sequence[str] = (),
) -> ConfigParseCallback:
) -> ConfigParseCallback[Path]:
def config_parse_path(value: Optional[str], old: Optional[Path]) -> Optional[Path]:
if not value:
return None
Expand All @@ -1127,7 +1125,7 @@ def is_valid_filename(s: str) -> bool:
return not (s == "." or s == ".." or "/" in s)


def config_make_filename_parser(hint: str) -> ConfigParseCallback:
def config_make_filename_parser(hint: str) -> ConfigParseCallback[str]:
def config_parse_filename(value: Optional[str], old: Optional[str]) -> Optional[str]:
if not value:
return None
Expand Down Expand Up @@ -1389,7 +1387,8 @@ def config_parse_artifact_output_list(
if boolean_value is not None:
return ArtifactOutput.compat_yes() if boolean_value else ArtifactOutput.compat_no()

list_value = config_make_list_parser(delimiter=",", parse=make_enum_parser(ArtifactOutput))(value, old)
list_parser = config_make_list_parser(delimiter=",", parse=make_enum_parser(ArtifactOutput))
list_value = list_parser(value, old)
return cast(list[ArtifactOutput], list_value)


Expand All @@ -1403,14 +1402,14 @@ class SettingScope(StrEnum):


@dataclasses.dataclass(frozen=True)
class ConfigSetting:
class ConfigSetting(Generic[T]):
dest: str
section: str
parse: ConfigParseCallback = config_parse_string
match: Optional[ConfigMatchCallback] = None
parse: ConfigParseCallback[T] = config_parse_string # type: ignore # see mypy#3737
match: Optional[ConfigMatchCallback[T]] = None
name: str = ""
default: Any = None
default_factory: Optional[ConfigDefaultCallback] = None
default: Optional[T] = None
default_factory: Optional[ConfigDefaultCallback[T]] = None
default_factory_depends: tuple[str, ...] = tuple()
paths: tuple[str, ...] = ()
recursive_paths: tuple[str, ...] = ()
Expand All @@ -1422,7 +1421,7 @@ class ConfigSetting:
# settings for argparse
short: Optional[str] = None
long: str = ""
choices: Optional[Any] = None
choices: Optional[list[str]] = None
metavar: Optional[str] = None
nargs: Optional[str] = None
const: Optional[Any] = None
Expand Down Expand Up @@ -1529,7 +1528,7 @@ def __call__(
parser.exit()


def dict_with_capitalised_keys_factory(pairs: Any) -> dict[str, Any]:
def dict_with_capitalised_keys_factory(pairs: list[tuple[str, T]]) -> dict[str, T]:
def key_transformer(k: str) -> str:
if (s := SETTINGS_LOOKUP_BY_DEST.get(k)) is not None:
return s.name
Expand Down Expand Up @@ -1634,11 +1633,13 @@ class UKIProfile:
cmdline: list[str]


def make_simple_config_parser(settings: Sequence[ConfigSetting], type: type[Any]) -> Callable[[str], Any]:
def make_simple_config_parser(
settings: Sequence[ConfigSetting[object]], valtype: type[T]
) -> Callable[[str], T]:
lookup_by_name = {s.name: s for s in settings}
lookup_by_dest = {s.dest: s for s in settings}

def finalize_value(config: argparse.Namespace, setting: ConfigSetting) -> None:
def finalize_value(config: argparse.Namespace, setting: ConfigSetting[object]) -> None:
if hasattr(config, setting.dest):
return

Expand All @@ -1654,7 +1655,7 @@ def finalize_value(config: argparse.Namespace, setting: ConfigSetting) -> None:

setattr(config, setting.dest, default)

def parse_simple_config(value: str) -> Any:
def parse_simple_config(value: str) -> T:
path = parse_path(value)
config = argparse.Namespace()

Expand All @@ -1681,7 +1682,9 @@ def parse_simple_config(value: str) -> Any:
for setting in settings:
finalize_value(config, setting)

return type(**{k: v for k, v in vars(config).items() if k in inspect.signature(type).parameters})
return valtype(
**{k: v for k, v in vars(config).items() if k in inspect.signature(valtype).parameters}
)

return parse_simple_config

Expand Down Expand Up @@ -2160,7 +2163,7 @@ def parse_ini(path: Path, only_sections: Collection[str] = ()) -> Iterator[tuple
yield section, "", ""


PE_ADDON_SETTINGS = (
PE_ADDON_SETTINGS: list[ConfigSetting[Any]] = [
ConfigSetting(
dest="output",
section="PEAddon",
Expand All @@ -2172,10 +2175,10 @@ def parse_ini(path: Path, only_sections: Collection[str] = ()) -> Iterator[tuple
section="PEAddon",
parse=config_make_list_parser(delimiter=" "),
),
)
]


UKI_PROFILE_SETTINGS = (
UKI_PROFILE_SETTINGS: list[ConfigSetting[Any]] = [
ConfigSetting(
dest="profile",
section="UKIProfile",
Expand All @@ -2186,10 +2189,10 @@ def parse_ini(path: Path, only_sections: Collection[str] = ()) -> Iterator[tuple
section="UKIProfile",
parse=config_make_list_parser(delimiter=" "),
),
)
]


SETTINGS = (
SETTINGS: list[ConfigSetting[Any]] = [
# Include section
ConfigSetting(
dest="include",
Expand Down Expand Up @@ -3665,7 +3668,7 @@ def parse_ini(path: Path, only_sections: Collection[str] = ()) -> Iterator[tuple
# arguments.
help=argparse.SUPPRESS,
),
)
]
SETTINGS_LOOKUP_BY_NAME = {name: s for s in SETTINGS for name in [s.name, *s.compat_names]}
SETTINGS_LOOKUP_BY_DEST = {s.dest: s for s in SETTINGS}
SETTINGS_LOOKUP_BY_SPECIFIER = {s.specifier: s for s in SETTINGS if s.specifier}
Expand Down Expand Up @@ -4062,17 +4065,27 @@ def parse_new_includes(self) -> None:
with chdir(path if path.is_dir() else Path.cwd()):
self.parse_config_one(path if path.is_file() else Path("."))

def finalize_value(self, setting: ConfigSetting) -> Optional[Any]:
def finalize_value(self, setting: ConfigSetting[T]) -> Optional[T]:
# If a value was specified on the CLI, it always takes priority. If the setting is a collection of
# values, we merge the value from the CLI with the value from the configuration, making sure that the
# value from the CLI always takes priority.
if hasattr(self.cli, setting.dest) and (v := getattr(self.cli, setting.dest)) is not None:
if (v := cast(Optional[T], getattr(self.cli, setting.dest, None))) is not None:
cfg_value = getattr(self.config, setting.dest, None)
if cfg_value is None:
return v

# The instance asserts are pushed down to help mypy/pylance narrow the types.
# Mypy still cannot properly infer that the merged collections conform to T
# so we ignore the return-value error for it.
if isinstance(v, list):
return (getattr(self.config, setting.dest, None) or []) + v
assert isinstance(cfg_value, type(v))
return cfg_value + v # type: ignore[return-value]
elif isinstance(v, dict):
return (getattr(self.config, setting.dest, None) or {}) | v
assert isinstance(cfg_value, type(v))
return cfg_value | v # type: ignore[return-value]
elif isinstance(v, set):
return (getattr(self.config, setting.dest, None) or set()) | v
assert isinstance(cfg_value, type(v))
return cfg_value | v # type: ignore[return-value]
else:
return v

Expand All @@ -4083,7 +4096,7 @@ def finalize_value(self, setting: ConfigSetting) -> Optional[Any]:
if (
not hasattr(self.cli, setting.dest)
and hasattr(self.config, setting.dest)
and (v := getattr(self.config, setting.dest)) is not None
and (v := cast(Optional[T], getattr(self.config, setting.dest))) is not None
):
return v

Expand Down Expand Up @@ -4190,7 +4203,7 @@ def match_config(self, path: Path) -> bool:
return match_triggered is not False

def parse_config_one(self, path: Path, parse_profiles: bool = False, parse_local: bool = False) -> bool:
s: Optional[ConfigSetting] # Make mypy happy
s: Optional[ConfigSetting[object]] # Hint to mypy that we might assign None
extras = path.is_dir()

if path.is_dir():
Expand Down

0 comments on commit 72eb7a9

Please sign in to comment.