Skip to content

Commit

Permalink
Make ConfigSetting generic
Browse files Browse the repository at this point in the history
  • Loading branch information
septatrix committed Nov 20, 2024
1 parent 4447b04 commit f1882f2
Showing 1 changed file with 48 additions and 51 deletions.
99 changes: 48 additions & 51 deletions mkosi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from collections.abc import Collection, Iterable, Iterator, Sequence
from contextlib import AbstractContextManager
from pathlib import Path
from typing import Any, Callable, Optional, TypeVar, Union, cast
from typing import Any, Callable, Generic, Optional, TypeVar, Union, cast

from mkosi.distributions import Distribution, detect_distribution
from mkosi.log import ARG_DEBUG, ARG_DEBUG_SANDBOX, ARG_DEBUG_SHELL, Style, die
Expand All @@ -48,11 +48,11 @@
from mkosi.versioncomp import GenericVersion

T = TypeVar("T")
SE = TypeVar("SE", bound=StrEnum)

ConfigParseCallback = Callable[[Optional[str], Optional[Any]], Any]
ConfigMatchCallback = Callable[[str, Any], bool]
ConfigDefaultCallback = Callable[[argparse.Namespace], Any]

ConfigParseCallback = Callable[[Optional[str], Optional[T]], Optional[T]]
ConfigMatchCallback = Callable[[str, T], bool]
ConfigDefaultCallback = Callable[[argparse.Namespace], T]

BUILTIN_CONFIGS = ("mkosi-tools", "mkosi-initrd", "mkosi-vm")

Expand Down Expand Up @@ -676,7 +676,7 @@ def config_match_build_sources(match: str, value: list[ConfigTree]) -> bool:
return Path(match.lstrip("/")) in [tree.target for tree in value if tree.target]


def config_make_list_matcher(parse: Callable[[str], T]) -> ConfigMatchCallback:
def config_make_list_matcher(parse: Callable[[str], T]) -> ConfigMatchCallback[list[T]]:
def config_match_list(match: str, value: list[T]) -> bool:
return parse(match) in value

Expand All @@ -687,7 +687,7 @@ def config_parse_string(value: Optional[str], old: Optional[str]) -> Optional[st
return value or None


def config_make_string_matcher(allow_globs: bool = False) -> ConfigMatchCallback:
def config_make_string_matcher(allow_globs: bool = False) -> ConfigMatchCallback[str]:
def config_match_string(match: str, value: str) -> bool:
if allow_globs:
return fnmatch.fnmatchcase(value, match)
Expand Down Expand Up @@ -906,8 +906,8 @@ def config_default_proxy_url(namespace: argparse.Namespace) -> Optional[str]:
return None


def make_enum_parser(type: type[StrEnum]) -> Callable[[str], StrEnum]:
def parse_enum(value: str) -> StrEnum:
def make_enum_parser(type: type[SE]) -> Callable[[str], SE]:
def parse_enum(value: str) -> SE:
try:
return type(value)
except ValueError:
Expand All @@ -916,17 +916,17 @@ def parse_enum(value: str) -> StrEnum:
return parse_enum


def config_make_enum_parser(type: type[StrEnum]) -> ConfigParseCallback:
def config_parse_enum(value: Optional[str], old: Optional[StrEnum]) -> Optional[StrEnum]:
def config_make_enum_parser(type: type[SE]) -> ConfigParseCallback[SE]:
def config_parse_enum(value: Optional[str], old: Optional[SE]) -> Optional[SE]:
return make_enum_parser(type)(value) if value else None

return config_parse_enum


def config_make_enum_parser_with_boolean(
type: type[StrEnum], *, yes: StrEnum, no: StrEnum
) -> ConfigParseCallback:
def config_parse_enum(value: Optional[str], old: Optional[StrEnum]) -> Optional[StrEnum]:
type: type[SE], *, yes: SE, no: SE
) -> ConfigParseCallback[SE]:
def config_parse_enum(value: Optional[str], old: Optional[SE]) -> Optional[SE]:
if not value:
return None

Expand All @@ -938,8 +938,8 @@ def config_parse_enum(value: Optional[str], old: Optional[StrEnum]) -> Optional[
return config_parse_enum


def config_make_enum_matcher(type: type[StrEnum]) -> ConfigMatchCallback:
def config_match_enum(match: str, value: StrEnum) -> bool:
def config_make_enum_matcher(type: type[SE]) -> ConfigMatchCallback[SE]:
def config_match_enum(match: str, value: SE) -> bool:
return make_enum_parser(type)(match) == value

return config_match_enum
Expand All @@ -948,13 +948,11 @@ def config_match_enum(match: str, value: StrEnum) -> bool:
def config_make_list_parser(
*,
delimiter: Optional[str] = None,
parse: Callable[[str], Any] = str,
parse: Callable[[str], T] = str, # type: ignore # see mypy#3737
unescape: bool = False,
reset: bool = True,
) -> ConfigParseCallback:
def config_parse_list(value: Optional[str], old: Optional[list[Any]]) -> Optional[list[Any]]:
new = old.copy() if old else []

) -> ConfigParseCallback[list[T]]:
def config_parse_list(value: Optional[str], old: Optional[list[T]]) -> Optional[list[T]]:
if value is None:
return []

Expand All @@ -975,11 +973,10 @@ def config_parse_list(value: Optional[str], old: Optional[list[Any]]) -> Optiona
if reset and len(values) == 1 and values[0] == "":
return None

return new + [parse(v) for v in values if v]
return (old.copy() if old else []) + [parse(v) for v in values if v]

return config_parse_list


def config_match_version(match: str, value: str) -> bool:
version = GenericVersion(value)

Expand All @@ -1006,16 +1003,15 @@ def config_match_version(match: str, value: str) -> bool:

return True


def config_make_dict_parser(
*,
delimiter: Optional[str] = None,
parse: Callable[[str], tuple[str, Any]],
parse: Callable[[str], tuple[str, str]],
unescape: bool = False,
allow_paths: bool = False,
reset: bool = True,
) -> ConfigParseCallback:
def config_parse_dict(value: Optional[str], old: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
) -> ConfigParseCallback[dict[str, str]]:
def config_parse_dict(value: Optional[str], old: Optional[dict[str, str]]) -> Optional[dict[str, str]]:
new = old.copy() if old else {}

if value is None:
Expand Down Expand Up @@ -1104,7 +1100,7 @@ def config_make_path_parser(
expandvars: bool = True,
secret: bool = False,
constants: Sequence[str] = (),
) -> ConfigParseCallback:
) -> ConfigParseCallback[Path]:
def config_parse_path(value: Optional[str], old: Optional[Path]) -> Optional[Path]:
if not value:
return None
Expand All @@ -1127,7 +1123,7 @@ def is_valid_filename(s: str) -> bool:
return not (s == "." or s == ".." or "/" in s)


def config_make_filename_parser(hint: str) -> ConfigParseCallback:
def config_make_filename_parser(hint: str) -> ConfigParseCallback[str]:
def config_parse_filename(value: Optional[str], old: Optional[str]) -> Optional[str]:
if not value:
return None
Expand Down Expand Up @@ -1389,7 +1385,8 @@ def config_parse_artifact_output_list(
if boolean_value is not None:
return ArtifactOutput.compat_yes() if boolean_value else ArtifactOutput.compat_no()

list_value = config_make_list_parser(delimiter=",", parse=make_enum_parser(ArtifactOutput))(value, old)
list_parser = config_make_list_parser(delimiter=",", parse=make_enum_parser(ArtifactOutput))
list_value = list_parser(value, old)
return cast(list[ArtifactOutput], list_value)


Expand All @@ -1403,14 +1400,14 @@ class SettingScope(StrEnum):


@dataclasses.dataclass(frozen=True)
class ConfigSetting:
class ConfigSetting(Generic[T]):
dest: str
section: str
parse: ConfigParseCallback = config_parse_string
match: Optional[ConfigMatchCallback] = None
parse: ConfigParseCallback[T] = config_parse_string # type: ignore # see mypy#3737
match: Optional[ConfigMatchCallback[T]] = None
name: str = ""
default: Any = None
default_factory: Optional[ConfigDefaultCallback] = None
default: Optional[T] = None
default_factory: Optional[ConfigDefaultCallback[T]] = None
default_factory_depends: tuple[str, ...] = tuple()
paths: tuple[str, ...] = ()
recursive_paths: tuple[str, ...] = ()
Expand All @@ -1422,7 +1419,7 @@ class ConfigSetting:
# settings for argparse
short: Optional[str] = None
long: str = ""
choices: Optional[Any] = None
choices: Optional[list[Any]] = None
metavar: Optional[str] = None
nargs: Optional[str] = None
const: Optional[Any] = None
Expand Down Expand Up @@ -1529,7 +1526,7 @@ def __call__(
parser.exit()


def dict_with_capitalised_keys_factory(pairs: Any) -> dict[str, Any]:
def dict_with_capitalised_keys_factory(pairs: list[tuple[str, T]]) -> dict[str, T]:
def key_transformer(k: str) -> str:
if (s := SETTINGS_LOOKUP_BY_DEST.get(k)) is not None:
return s.name
Expand Down Expand Up @@ -1634,11 +1631,11 @@ class UKIProfile:
cmdline: list[str]


def make_simple_config_parser(settings: Sequence[ConfigSetting], type: type[Any]) -> Callable[[str], Any]:
def make_simple_config_parser(settings: Sequence[ConfigSetting[Any]], type_: type[T]) -> Callable[[str], T]:
lookup_by_name = {s.name: s for s in settings}
lookup_by_dest = {s.dest: s for s in settings}

def finalize_value(config: argparse.Namespace, setting: ConfigSetting) -> None:
def finalize_value(config: argparse.Namespace, setting: ConfigSetting[Any]) -> None:
if hasattr(config, setting.dest):
return

Expand All @@ -1654,7 +1651,7 @@ def finalize_value(config: argparse.Namespace, setting: ConfigSetting) -> None:

setattr(config, setting.dest, default)

def parse_simple_config(value: str) -> Any:
def parse_simple_config(value: str) -> T:
path = parse_path(value)
config = argparse.Namespace()

Expand All @@ -1681,7 +1678,7 @@ def parse_simple_config(value: str) -> Any:
for setting in settings:
finalize_value(config, setting)

return type(**{k: v for k, v in vars(config).items() if k in inspect.signature(type).parameters})
return type_(**{k: v for k, v in vars(config).items() if k in inspect.signature(type_).parameters})

return parse_simple_config

Expand Down Expand Up @@ -2160,7 +2157,7 @@ def parse_ini(path: Path, only_sections: Collection[str] = ()) -> Iterator[tuple
yield section, "", ""


PE_ADDON_SETTINGS = (
PE_ADDON_SETTINGS: tuple[ConfigSetting[Any], ...] = (
ConfigSetting(
dest="output",
section="PEAddon",
Expand All @@ -2175,7 +2172,7 @@ def parse_ini(path: Path, only_sections: Collection[str] = ()) -> Iterator[tuple
)


UKI_PROFILE_SETTINGS = (
UKI_PROFILE_SETTINGS: tuple[ConfigSetting[Any], ...] = (
ConfigSetting(
dest="profile",
section="UKIProfile",
Expand All @@ -2189,7 +2186,7 @@ def parse_ini(path: Path, only_sections: Collection[str] = ()) -> Iterator[tuple
)


SETTINGS = (
SETTINGS: tuple[ConfigSetting[Any], ...] = (
# Include section
ConfigSetting(
dest="include",
Expand Down Expand Up @@ -4062,19 +4059,19 @@ def parse_new_includes(self) -> None:
with chdir(path if path.is_dir() else Path.cwd()):
self.parse_config_one(path if path.is_file() else Path("."))

def finalize_value(self, setting: ConfigSetting) -> Optional[Any]:
def finalize_value(self, setting: ConfigSetting[T]) -> Optional[T]:
# If a value was specified on the CLI, it always takes priority. If the setting is a collection of
# values, we merge the value from the CLI with the value from the configuration, making sure that the
# value from the CLI always takes priority.
if hasattr(self.cli, setting.dest) and (v := getattr(self.cli, setting.dest)) is not None:
if isinstance(v, list):
return (getattr(self.config, setting.dest, None) or []) + v
return (getattr(self.config, setting.dest, None) or []) + v # type: ignore
elif isinstance(v, dict):
return (getattr(self.config, setting.dest, None) or {}) | v
return (getattr(self.config, setting.dest, None) or {}) | v # type: ignore
elif isinstance(v, set):
return (getattr(self.config, setting.dest, None) or set()) | v
return (getattr(self.config, setting.dest, None) or set()) | v # type: ignore
else:
return v
return v # type: ignore

# If the setting was assigned the empty string on the CLI, we don't use any value configured in the
# configuration file. Additionally, if the setting is a collection of values, we won't use any
Expand All @@ -4085,7 +4082,7 @@ def finalize_value(self, setting: ConfigSetting) -> Optional[Any]:
and hasattr(self.config, setting.dest)
and (v := getattr(self.config, setting.dest)) is not None
):
return v
return v # type: ignore

if (hasattr(self.cli, setting.dest) or hasattr(self.config, setting.dest)) and isinstance(
setting.parse(None, None), (dict, list, set)
Expand Down Expand Up @@ -4190,7 +4187,7 @@ def match_config(self, path: Path) -> bool:
return match_triggered is not False

def parse_config_one(self, path: Path, parse_profiles: bool = False, parse_local: bool = False) -> bool:
s: Optional[ConfigSetting] # Make mypy happy
s: Optional[ConfigSetting[Any]] # Hint to mypy that this may be None
extras = path.is_dir()

if path.is_dir():
Expand Down

0 comments on commit f1882f2

Please sign in to comment.