Skip to content

Commit

Permalink
Refactor ansible config get
Browse files Browse the repository at this point in the history
- change implementation to raise KeyError by default for missing keys
- allow setting a default value, including None

Follow-Up: #5 (comment)
  • Loading branch information
ssbarnea committed Jun 28, 2021
1 parent a2d6afd commit 2dc452e
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/ansible_compat/prerun.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ansible_compat.loaders import yaml_from_file

_logger = logging.getLogger(__name__)
SENTINEL = object()


def check_ansible_presence(exit_on_error: bool = False) -> Tuple[str, str]:
Expand Down Expand Up @@ -335,7 +336,9 @@ def _update_env(varname: str, value: List[str], default: str = "") -> None:
_logger.info("Added %s=%s", varname, value_str)


def ansible_config_get(key: str, kind: Type[Any] = str) -> Union[str, List[str], None]:
def ansible_config_get(
key: str, kind: Type[Any] = str, default: object = SENTINEL
) -> Union[str, List[str], None, object]:
"""Return configuration item from ansible config."""
env = os.environ.copy()
# Avoid possible ANSI garbage
Expand All @@ -361,8 +364,11 @@ def ansible_config_get(key: str, kind: Type[Any] = str) -> Union[str, List[str],
raise RuntimeError(f"Unexpected data read for {key}: {val}")
return val
else:
raise RuntimeError("Unknown data type.")
return None
raise NotImplementedError("Unknown data type %s." % kind)

if default == SENTINEL:
raise KeyError(key)
return default


def require_collection( # noqa: C901
Expand Down
21 changes: 21 additions & 0 deletions test/test_prerun.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,27 @@ def test_ansible_config_get() -> None:
assert len(paths) > 0


@pytest.mark.parametrize(
"default",
(
(None,),
(123,),
),
)
def test_ansible_config_get_default(default: object) -> None:
"""Check that config get returns default when appropriate."""
result = prerun.ansible_config_get("NON_EXISTING_OPTION", default=default)
assert result is default


def test_ansible_config_get_raise() -> None:
"""Check that config get raise if key is not found."""
key = "NON_EXISTING_OPTION"
with pytest.raises(KeyError) as exc:
prerun.ansible_config_get(key)
assert exc.value.args[0] == key


def test_install_collection() -> None:
"""Check that valid collection installs do not fail."""
prerun.install_collection("containers.podman:>=1.0")
Expand Down

0 comments on commit 2dc452e

Please sign in to comment.