Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ConfigSetting generic #3215

Merged
merged 1 commit into from
Nov 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 67 additions & 52 deletions mkosi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from collections.abc import Collection, Iterable, Iterator, Sequence
from contextlib import AbstractContextManager
from pathlib import Path
from typing import Any, Callable, Optional, TypeVar, Union, cast
from typing import Any, Callable, Generic, Optional, TypeVar, Union, cast

from mkosi.distributions import Distribution, detect_distribution
from mkosi.log import ARG_DEBUG, ARG_DEBUG_SANDBOX, ARG_DEBUG_SHELL, Style, die
Expand All @@ -48,11 +48,11 @@
from mkosi.versioncomp import GenericVersion

T = TypeVar("T")
SE = TypeVar("SE", bound=StrEnum)

ConfigParseCallback = Callable[[Optional[str], Optional[Any]], Any]
ConfigMatchCallback = Callable[[str, Any], bool]
ConfigDefaultCallback = Callable[[argparse.Namespace], Any]

ConfigParseCallback = Callable[[Optional[str], Optional[T]], Optional[T]]
ConfigMatchCallback = Callable[[str, T], bool]
ConfigDefaultCallback = Callable[[argparse.Namespace], T]

BUILTIN_CONFIGS = ("mkosi-tools", "mkosi-initrd", "mkosi-vm")

Expand Down Expand Up @@ -676,7 +676,7 @@ def config_match_build_sources(match: str, value: list[ConfigTree]) -> bool:
return Path(match.lstrip("/")) in [tree.target for tree in value if tree.target]


def config_make_list_matcher(parse: Callable[[str], T]) -> ConfigMatchCallback:
def config_make_list_matcher(parse: Callable[[str], T]) -> ConfigMatchCallback[list[T]]:
def config_match_list(match: str, value: list[T]) -> bool:
return parse(match) in value

Expand All @@ -687,7 +687,7 @@ def config_parse_string(value: Optional[str], old: Optional[str]) -> Optional[st
return value or None


def config_make_string_matcher(allow_globs: bool = False) -> ConfigMatchCallback:
def config_make_string_matcher(allow_globs: bool = False) -> ConfigMatchCallback[str]:
def config_match_string(match: str, value: str) -> bool:
if allow_globs:
return fnmatch.fnmatchcase(value, match)
Expand Down Expand Up @@ -906,8 +906,8 @@ def config_default_proxy_url(namespace: argparse.Namespace) -> Optional[str]:
return None


def make_enum_parser(type: type[StrEnum]) -> Callable[[str], StrEnum]:
def parse_enum(value: str) -> StrEnum:
def make_enum_parser(type: type[SE]) -> Callable[[str], SE]:
def parse_enum(value: str) -> SE:
try:
return type(value)
except ValueError:
Expand All @@ -916,17 +916,15 @@ def parse_enum(value: str) -> StrEnum:
return parse_enum


def config_make_enum_parser(type: type[StrEnum]) -> ConfigParseCallback:
def config_parse_enum(value: Optional[str], old: Optional[StrEnum]) -> Optional[StrEnum]:
def config_make_enum_parser(type: type[SE]) -> ConfigParseCallback[SE]:
def config_parse_enum(value: Optional[str], old: Optional[SE]) -> Optional[SE]:
return make_enum_parser(type)(value) if value else None

return config_parse_enum


def config_make_enum_parser_with_boolean(
type: type[StrEnum], *, yes: StrEnum, no: StrEnum
) -> ConfigParseCallback:
def config_parse_enum(value: Optional[str], old: Optional[StrEnum]) -> Optional[StrEnum]:
def config_make_enum_parser_with_boolean(type: type[SE], *, yes: SE, no: SE) -> ConfigParseCallback[SE]:
def config_parse_enum(value: Optional[str], old: Optional[SE]) -> Optional[SE]:
if not value:
return None

Expand All @@ -938,8 +936,8 @@ def config_parse_enum(value: Optional[str], old: Optional[StrEnum]) -> Optional[
return config_parse_enum


def config_make_enum_matcher(type: type[StrEnum]) -> ConfigMatchCallback:
def config_match_enum(match: str, value: StrEnum) -> bool:
def config_make_enum_matcher(type: type[SE]) -> ConfigMatchCallback[SE]:
def config_match_enum(match: str, value: SE) -> bool:
return make_enum_parser(type)(match) == value

return config_match_enum
Expand All @@ -948,11 +946,11 @@ def config_match_enum(match: str, value: StrEnum) -> bool:
def config_make_list_parser(
*,
delimiter: Optional[str] = None,
parse: Callable[[str], Any] = str,
parse: Callable[[str], T] = str, # type: ignore # see mypy#3737
behrmann marked this conversation as resolved.
Show resolved Hide resolved
unescape: bool = False,
reset: bool = True,
) -> ConfigParseCallback:
def config_parse_list(value: Optional[str], old: Optional[list[Any]]) -> Optional[list[Any]]:
) -> ConfigParseCallback[list[T]]:
def config_parse_list(value: Optional[str], old: Optional[list[T]]) -> Optional[list[T]]:
new = old.copy() if old else []

if value is None:
Expand Down Expand Up @@ -1010,12 +1008,12 @@ def config_match_version(match: str, value: str) -> bool:
def config_make_dict_parser(
*,
delimiter: Optional[str] = None,
parse: Callable[[str], tuple[str, Any]],
parse: Callable[[str], tuple[str, str]],
unescape: bool = False,
allow_paths: bool = False,
reset: bool = True,
) -> ConfigParseCallback:
def config_parse_dict(value: Optional[str], old: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
) -> ConfigParseCallback[dict[str, str]]:
def config_parse_dict(value: Optional[str], old: Optional[dict[str, str]]) -> Optional[dict[str, str]]:
septatrix marked this conversation as resolved.
Show resolved Hide resolved
new = old.copy() if old else {}

if value is None:
Expand Down Expand Up @@ -1104,7 +1102,7 @@ def config_make_path_parser(
expandvars: bool = True,
secret: bool = False,
constants: Sequence[str] = (),
) -> ConfigParseCallback:
) -> ConfigParseCallback[Path]:
def config_parse_path(value: Optional[str], old: Optional[Path]) -> Optional[Path]:
if not value:
return None
Expand All @@ -1127,7 +1125,7 @@ def is_valid_filename(s: str) -> bool:
return not (s == "." or s == ".." or "/" in s)


def config_make_filename_parser(hint: str) -> ConfigParseCallback:
def config_make_filename_parser(hint: str) -> ConfigParseCallback[str]:
def config_parse_filename(value: Optional[str], old: Optional[str]) -> Optional[str]:
if not value:
return None
Expand Down Expand Up @@ -1389,7 +1387,8 @@ def config_parse_artifact_output_list(
if boolean_value is not None:
return ArtifactOutput.compat_yes() if boolean_value else ArtifactOutput.compat_no()

list_value = config_make_list_parser(delimiter=",", parse=make_enum_parser(ArtifactOutput))(value, old)
list_parser = config_make_list_parser(delimiter=",", parse=make_enum_parser(ArtifactOutput))
list_value = list_parser(value, old)
return cast(list[ArtifactOutput], list_value)


Expand All @@ -1403,14 +1402,14 @@ class SettingScope(StrEnum):


@dataclasses.dataclass(frozen=True)
class ConfigSetting:
class ConfigSetting(Generic[T]):
dest: str
section: str
parse: ConfigParseCallback = config_parse_string
match: Optional[ConfigMatchCallback] = None
parse: ConfigParseCallback[T] = config_parse_string # type: ignore # see mypy#3737
match: Optional[ConfigMatchCallback[T]] = None
name: str = ""
default: Any = None
default_factory: Optional[ConfigDefaultCallback] = None
default: Optional[T] = None
default_factory: Optional[ConfigDefaultCallback[T]] = None
default_factory_depends: tuple[str, ...] = tuple()
paths: tuple[str, ...] = ()
recursive_paths: tuple[str, ...] = ()
Expand All @@ -1422,7 +1421,7 @@ class ConfigSetting:
# settings for argparse
short: Optional[str] = None
long: str = ""
choices: Optional[Any] = None
choices: Optional[list[str]] = None
metavar: Optional[str] = None
nargs: Optional[str] = None
const: Optional[Any] = None
Expand Down Expand Up @@ -1529,7 +1528,7 @@ def __call__(
parser.exit()


def dict_with_capitalised_keys_factory(pairs: Any) -> dict[str, Any]:
def dict_with_capitalised_keys_factory(pairs: list[tuple[str, T]]) -> dict[str, T]:
def key_transformer(k: str) -> str:
if (s := SETTINGS_LOOKUP_BY_DEST.get(k)) is not None:
return s.name
Expand Down Expand Up @@ -1634,11 +1633,14 @@ class UKIProfile:
cmdline: list[str]


def make_simple_config_parser(settings: Sequence[ConfigSetting], type: type[Any]) -> Callable[[str], Any]:
def make_simple_config_parser(
settings: Sequence[ConfigSetting[object]],
valtype: type[T],
) -> Callable[[str], T]:
lookup_by_name = {s.name: s for s in settings}
lookup_by_dest = {s.dest: s for s in settings}

def finalize_value(config: argparse.Namespace, setting: ConfigSetting) -> None:
def finalize_value(config: argparse.Namespace, setting: ConfigSetting[object]) -> None:
if hasattr(config, setting.dest):
return

Expand All @@ -1654,7 +1656,7 @@ def finalize_value(config: argparse.Namespace, setting: ConfigSetting) -> None:

setattr(config, setting.dest, default)

def parse_simple_config(value: str) -> Any:
def parse_simple_config(value: str) -> T:
path = parse_path(value)
config = argparse.Namespace()

Expand All @@ -1681,7 +1683,9 @@ def parse_simple_config(value: str) -> Any:
for setting in settings:
finalize_value(config, setting)

return type(**{k: v for k, v in vars(config).items() if k in inspect.signature(type).parameters})
return valtype(
**{k: v for k, v in vars(config).items() if k in inspect.signature(valtype).parameters}
)

return parse_simple_config

Expand Down Expand Up @@ -2160,7 +2164,7 @@ def parse_ini(path: Path, only_sections: Collection[str] = ()) -> Iterator[tuple
yield section, "", ""


PE_ADDON_SETTINGS = (
PE_ADDON_SETTINGS: list[ConfigSetting[Any]] = [
DaanDeMeyer marked this conversation as resolved.
Show resolved Hide resolved
ConfigSetting(
dest="output",
section="PEAddon",
Expand All @@ -2172,10 +2176,10 @@ def parse_ini(path: Path, only_sections: Collection[str] = ()) -> Iterator[tuple
section="PEAddon",
parse=config_make_list_parser(delimiter=" "),
),
)
]


UKI_PROFILE_SETTINGS = (
UKI_PROFILE_SETTINGS: list[ConfigSetting[Any]] = [
ConfigSetting(
dest="profile",
section="UKIProfile",
Expand All @@ -2186,10 +2190,10 @@ def parse_ini(path: Path, only_sections: Collection[str] = ()) -> Iterator[tuple
section="UKIProfile",
parse=config_make_list_parser(delimiter=" "),
),
)
]


SETTINGS = (
SETTINGS: list[ConfigSetting[Any]] = [
# Include section
ConfigSetting(
dest="include",
Expand Down Expand Up @@ -3665,7 +3669,7 @@ def parse_ini(path: Path, only_sections: Collection[str] = ()) -> Iterator[tuple
# arguments.
help=argparse.SUPPRESS,
),
)
]
SETTINGS_LOOKUP_BY_NAME = {name: s for s in SETTINGS for name in [s.name, *s.compat_names]}
SETTINGS_LOOKUP_BY_DEST = {s.dest: s for s in SETTINGS}
SETTINGS_LOOKUP_BY_SPECIFIER = {s.specifier: s for s in SETTINGS if s.specifier}
Expand Down Expand Up @@ -4065,19 +4069,30 @@ def parse_new_includes(self) -> None:
with chdir(path if path.is_dir() else Path.cwd()):
self.parse_config_one(path if path.is_file() else Path("."))

def finalize_value(self, setting: ConfigSetting) -> Optional[Any]:
def finalize_value(self, setting: ConfigSetting[T]) -> Optional[T]:
# If a value was specified on the CLI, it always takes priority. If the setting is a collection of
# values, we merge the value from the CLI with the value from the configuration, making sure that the
# value from the CLI always takes priority.
if (v := getattr(self.cli, setting.dest, None)) is not None:
if getattr(self.cli, f"{setting.dest}_was_none", False):
if (v := cast(Optional[T], getattr(self.cli, setting.dest, None))) is not None:
cfg_value = getattr(self.config, setting.dest, None)
# We either have no corresponding value in the config files
# or the values was assigned the empty string on the CLI
# and should thus be treated as a reset and override of the value from the config file.
if cfg_value is None or getattr(self.cli, f"{setting.dest}_was_none", False):
return v
elif isinstance(v, list):
return (getattr(self.config, setting.dest, None) or []) + v

# The instance asserts are pushed down to help mypy/pylance narrow the types.
# Mypy still cannot properly infer that the merged collections conform to T
# so we ignore the return-value error for it.
if isinstance(v, list):
assert isinstance(cfg_value, type(v))
return cfg_value + v # type: ignore[return-value]
elif isinstance(v, dict):
return (getattr(self.config, setting.dest, None) or {}) | v
assert isinstance(cfg_value, type(v))
return cfg_value | v # type: ignore[return-value]
elif isinstance(v, set):
return (getattr(self.config, setting.dest, None) or set()) | v
assert isinstance(cfg_value, type(v))
return cfg_value | v # type: ignore[return-value]
else:
return v

Expand All @@ -4088,7 +4103,7 @@ def finalize_value(self, setting: ConfigSetting) -> Optional[Any]:
if (
not hasattr(self.cli, setting.dest)
and hasattr(self.config, setting.dest)
and (v := getattr(self.config, setting.dest)) is not None
and (v := cast(Optional[T], getattr(self.config, setting.dest))) is not None
):
return v

Expand Down Expand Up @@ -4195,7 +4210,7 @@ def match_config(self, path: Path) -> bool:
return match_triggered is not False

def parse_config_one(self, path: Path, parse_profiles: bool = False, parse_local: bool = False) -> bool:
s: Optional[ConfigSetting] # Make mypy happy
s: Optional[ConfigSetting[object]] # Hint to mypy that we might assign None
extras = path.is_dir()

if path.is_dir():
Expand Down
Loading