From 72eb7a976dcc47207ee968555518d3c253ce9ff6 Mon Sep 17 00:00:00 2001 From: Septatrix <24257556+Septatrix@users.noreply.github.com> Date: Wed, 20 Nov 2024 21:48:16 +0100 Subject: [PATCH] Make ConfigSetting generic --- mkosi/config.py | 113 +++++++++++++++++++++++++++--------------------- 1 file changed, 63 insertions(+), 50 deletions(-) diff --git a/mkosi/config.py b/mkosi/config.py index ca88a1e3ec..6e49bdbb1e 100644 --- a/mkosi/config.py +++ b/mkosi/config.py @@ -28,7 +28,7 @@ from collections.abc import Collection, Iterable, Iterator, Sequence from contextlib import AbstractContextManager from pathlib import Path -from typing import Any, Callable, Optional, TypeVar, Union, cast +from typing import Any, Callable, Generic, Optional, TypeVar, Union, cast from mkosi.distributions import Distribution, detect_distribution from mkosi.log import ARG_DEBUG, ARG_DEBUG_SANDBOX, ARG_DEBUG_SHELL, Style, die @@ -48,11 +48,11 @@ from mkosi.versioncomp import GenericVersion T = TypeVar("T") +SE = TypeVar("SE", bound=StrEnum) -ConfigParseCallback = Callable[[Optional[str], Optional[Any]], Any] -ConfigMatchCallback = Callable[[str, Any], bool] -ConfigDefaultCallback = Callable[[argparse.Namespace], Any] - +ConfigParseCallback = Callable[[Optional[str], Optional[T]], Optional[T]] +ConfigMatchCallback = Callable[[str, T], bool] +ConfigDefaultCallback = Callable[[argparse.Namespace], T] BUILTIN_CONFIGS = ("mkosi-tools", "mkosi-initrd", "mkosi-vm") @@ -676,7 +676,7 @@ def config_match_build_sources(match: str, value: list[ConfigTree]) -> bool: return Path(match.lstrip("/")) in [tree.target for tree in value if tree.target] -def config_make_list_matcher(parse: Callable[[str], T]) -> ConfigMatchCallback: +def config_make_list_matcher(parse: Callable[[str], T]) -> ConfigMatchCallback[list[T]]: def config_match_list(match: str, value: list[T]) -> bool: return parse(match) in value @@ -687,7 +687,7 @@ def config_parse_string(value: Optional[str], old: Optional[str]) -> Optional[st return value or None -def config_make_string_matcher(allow_globs: bool = False) -> ConfigMatchCallback: +def config_make_string_matcher(allow_globs: bool = False) -> ConfigMatchCallback[str]: def config_match_string(match: str, value: str) -> bool: if allow_globs: return fnmatch.fnmatchcase(value, match) @@ -906,8 +906,8 @@ def config_default_proxy_url(namespace: argparse.Namespace) -> Optional[str]: return None -def make_enum_parser(type: type[StrEnum]) -> Callable[[str], StrEnum]: - def parse_enum(value: str) -> StrEnum: +def make_enum_parser(type: type[SE]) -> Callable[[str], SE]: + def parse_enum(value: str) -> SE: try: return type(value) except ValueError: @@ -916,17 +916,15 @@ def parse_enum(value: str) -> StrEnum: return parse_enum -def config_make_enum_parser(type: type[StrEnum]) -> ConfigParseCallback: - def config_parse_enum(value: Optional[str], old: Optional[StrEnum]) -> Optional[StrEnum]: +def config_make_enum_parser(type: type[SE]) -> ConfigParseCallback[SE]: + def config_parse_enum(value: Optional[str], old: Optional[SE]) -> Optional[SE]: return make_enum_parser(type)(value) if value else None return config_parse_enum -def config_make_enum_parser_with_boolean( - type: type[StrEnum], *, yes: StrEnum, no: StrEnum -) -> ConfigParseCallback: - def config_parse_enum(value: Optional[str], old: Optional[StrEnum]) -> Optional[StrEnum]: +def config_make_enum_parser_with_boolean(type: type[SE], *, yes: SE, no: SE) -> ConfigParseCallback[SE]: + def config_parse_enum(value: Optional[str], old: Optional[SE]) -> Optional[SE]: if not value: return None @@ -938,8 +936,8 @@ def config_parse_enum(value: Optional[str], old: Optional[StrEnum]) -> Optional[ return config_parse_enum -def config_make_enum_matcher(type: type[StrEnum]) -> ConfigMatchCallback: - def config_match_enum(match: str, value: StrEnum) -> bool: +def config_make_enum_matcher(type: type[SE]) -> ConfigMatchCallback[SE]: + def config_match_enum(match: str, value: SE) -> bool: return make_enum_parser(type)(match) == value return config_match_enum @@ -948,11 +946,11 @@ def config_match_enum(match: str, value: StrEnum) -> bool: def config_make_list_parser( *, delimiter: Optional[str] = None, - parse: Callable[[str], Any] = str, + parse: Callable[[str], T] = str, # type: ignore # see mypy#3737 unescape: bool = False, reset: bool = True, -) -> ConfigParseCallback: - def config_parse_list(value: Optional[str], old: Optional[list[Any]]) -> Optional[list[Any]]: +) -> ConfigParseCallback[list[T]]: + def config_parse_list(value: Optional[str], old: Optional[list[T]]) -> Optional[list[T]]: new = old.copy() if old else [] if value is None: @@ -1010,12 +1008,12 @@ def config_match_version(match: str, value: str) -> bool: def config_make_dict_parser( *, delimiter: Optional[str] = None, - parse: Callable[[str], tuple[str, Any]], + parse: Callable[[str], tuple[str, str]], unescape: bool = False, allow_paths: bool = False, reset: bool = True, -) -> ConfigParseCallback: - def config_parse_dict(value: Optional[str], old: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]: +) -> ConfigParseCallback[dict[str, str]]: + def config_parse_dict(value: Optional[str], old: Optional[dict[str, str]]) -> Optional[dict[str, str]]: new = old.copy() if old else {} if value is None: @@ -1104,7 +1102,7 @@ def config_make_path_parser( expandvars: bool = True, secret: bool = False, constants: Sequence[str] = (), -) -> ConfigParseCallback: +) -> ConfigParseCallback[Path]: def config_parse_path(value: Optional[str], old: Optional[Path]) -> Optional[Path]: if not value: return None @@ -1127,7 +1125,7 @@ def is_valid_filename(s: str) -> bool: return not (s == "." or s == ".." or "/" in s) -def config_make_filename_parser(hint: str) -> ConfigParseCallback: +def config_make_filename_parser(hint: str) -> ConfigParseCallback[str]: def config_parse_filename(value: Optional[str], old: Optional[str]) -> Optional[str]: if not value: return None @@ -1389,7 +1387,8 @@ def config_parse_artifact_output_list( if boolean_value is not None: return ArtifactOutput.compat_yes() if boolean_value else ArtifactOutput.compat_no() - list_value = config_make_list_parser(delimiter=",", parse=make_enum_parser(ArtifactOutput))(value, old) + list_parser = config_make_list_parser(delimiter=",", parse=make_enum_parser(ArtifactOutput)) + list_value = list_parser(value, old) return cast(list[ArtifactOutput], list_value) @@ -1403,14 +1402,14 @@ class SettingScope(StrEnum): @dataclasses.dataclass(frozen=True) -class ConfigSetting: +class ConfigSetting(Generic[T]): dest: str section: str - parse: ConfigParseCallback = config_parse_string - match: Optional[ConfigMatchCallback] = None + parse: ConfigParseCallback[T] = config_parse_string # type: ignore # see mypy#3737 + match: Optional[ConfigMatchCallback[T]] = None name: str = "" - default: Any = None - default_factory: Optional[ConfigDefaultCallback] = None + default: Optional[T] = None + default_factory: Optional[ConfigDefaultCallback[T]] = None default_factory_depends: tuple[str, ...] = tuple() paths: tuple[str, ...] = () recursive_paths: tuple[str, ...] = () @@ -1422,7 +1421,7 @@ class ConfigSetting: # settings for argparse short: Optional[str] = None long: str = "" - choices: Optional[Any] = None + choices: Optional[list[str]] = None metavar: Optional[str] = None nargs: Optional[str] = None const: Optional[Any] = None @@ -1529,7 +1528,7 @@ def __call__( parser.exit() -def dict_with_capitalised_keys_factory(pairs: Any) -> dict[str, Any]: +def dict_with_capitalised_keys_factory(pairs: list[tuple[str, T]]) -> dict[str, T]: def key_transformer(k: str) -> str: if (s := SETTINGS_LOOKUP_BY_DEST.get(k)) is not None: return s.name @@ -1634,11 +1633,13 @@ class UKIProfile: cmdline: list[str] -def make_simple_config_parser(settings: Sequence[ConfigSetting], type: type[Any]) -> Callable[[str], Any]: +def make_simple_config_parser( + settings: Sequence[ConfigSetting[object]], valtype: type[T] +) -> Callable[[str], T]: lookup_by_name = {s.name: s for s in settings} lookup_by_dest = {s.dest: s for s in settings} - def finalize_value(config: argparse.Namespace, setting: ConfigSetting) -> None: + def finalize_value(config: argparse.Namespace, setting: ConfigSetting[object]) -> None: if hasattr(config, setting.dest): return @@ -1654,7 +1655,7 @@ def finalize_value(config: argparse.Namespace, setting: ConfigSetting) -> None: setattr(config, setting.dest, default) - def parse_simple_config(value: str) -> Any: + def parse_simple_config(value: str) -> T: path = parse_path(value) config = argparse.Namespace() @@ -1681,7 +1682,9 @@ def parse_simple_config(value: str) -> Any: for setting in settings: finalize_value(config, setting) - return type(**{k: v for k, v in vars(config).items() if k in inspect.signature(type).parameters}) + return valtype( + **{k: v for k, v in vars(config).items() if k in inspect.signature(valtype).parameters} + ) return parse_simple_config @@ -2160,7 +2163,7 @@ def parse_ini(path: Path, only_sections: Collection[str] = ()) -> Iterator[tuple yield section, "", "" -PE_ADDON_SETTINGS = ( +PE_ADDON_SETTINGS: list[ConfigSetting[Any]] = [ ConfigSetting( dest="output", section="PEAddon", @@ -2172,10 +2175,10 @@ def parse_ini(path: Path, only_sections: Collection[str] = ()) -> Iterator[tuple section="PEAddon", parse=config_make_list_parser(delimiter=" "), ), -) +] -UKI_PROFILE_SETTINGS = ( +UKI_PROFILE_SETTINGS: list[ConfigSetting[Any]] = [ ConfigSetting( dest="profile", section="UKIProfile", @@ -2186,10 +2189,10 @@ def parse_ini(path: Path, only_sections: Collection[str] = ()) -> Iterator[tuple section="UKIProfile", parse=config_make_list_parser(delimiter=" "), ), -) +] -SETTINGS = ( +SETTINGS: list[ConfigSetting[Any]] = [ # Include section ConfigSetting( dest="include", @@ -3665,7 +3668,7 @@ def parse_ini(path: Path, only_sections: Collection[str] = ()) -> Iterator[tuple # arguments. help=argparse.SUPPRESS, ), -) +] SETTINGS_LOOKUP_BY_NAME = {name: s for s in SETTINGS for name in [s.name, *s.compat_names]} SETTINGS_LOOKUP_BY_DEST = {s.dest: s for s in SETTINGS} SETTINGS_LOOKUP_BY_SPECIFIER = {s.specifier: s for s in SETTINGS if s.specifier} @@ -4062,17 +4065,27 @@ def parse_new_includes(self) -> None: with chdir(path if path.is_dir() else Path.cwd()): self.parse_config_one(path if path.is_file() else Path(".")) - def finalize_value(self, setting: ConfigSetting) -> Optional[Any]: + def finalize_value(self, setting: ConfigSetting[T]) -> Optional[T]: # If a value was specified on the CLI, it always takes priority. If the setting is a collection of # values, we merge the value from the CLI with the value from the configuration, making sure that the # value from the CLI always takes priority. - if hasattr(self.cli, setting.dest) and (v := getattr(self.cli, setting.dest)) is not None: + if (v := cast(Optional[T], getattr(self.cli, setting.dest, None))) is not None: + cfg_value = getattr(self.config, setting.dest, None) + if cfg_value is None: + return v + + # The instance asserts are pushed down to help mypy/pylance narrow the types. + # Mypy still cannot properly infer that the merged collections conform to T + # so we ignore the return-value error for it. if isinstance(v, list): - return (getattr(self.config, setting.dest, None) or []) + v + assert isinstance(cfg_value, type(v)) + return cfg_value + v # type: ignore[return-value] elif isinstance(v, dict): - return (getattr(self.config, setting.dest, None) or {}) | v + assert isinstance(cfg_value, type(v)) + return cfg_value | v # type: ignore[return-value] elif isinstance(v, set): - return (getattr(self.config, setting.dest, None) or set()) | v + assert isinstance(cfg_value, type(v)) + return cfg_value | v # type: ignore[return-value] else: return v @@ -4083,7 +4096,7 @@ def finalize_value(self, setting: ConfigSetting) -> Optional[Any]: if ( not hasattr(self.cli, setting.dest) and hasattr(self.config, setting.dest) - and (v := getattr(self.config, setting.dest)) is not None + and (v := cast(Optional[T], getattr(self.config, setting.dest))) is not None ): return v @@ -4190,7 +4203,7 @@ def match_config(self, path: Path) -> bool: return match_triggered is not False def parse_config_one(self, path: Path, parse_profiles: bool = False, parse_local: bool = False) -> bool: - s: Optional[ConfigSetting] # Make mypy happy + s: Optional[ConfigSetting[object]] # Hint to mypy that we might assign None extras = path.is_dir() if path.is_dir():