From 23b7caccebecbfeeefe4f7a0b4612c9caa4433cd Mon Sep 17 00:00:00 2001 From: MatteoBouvierVidium Date: Tue, 26 Nov 2024 10:59:11 +0100 Subject: [PATCH 1/7] =?UTF-8?q?=E2=9C=A8=20implement=20Option=20with=20opt?= =?UTF-8?q?ional=20user=20value?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_annotated.py | 39 +++++++++++++ typer/core.py | 2 + typer/main.py | 120 +++++++++++++++++++++++++++++++++++----- 3 files changed, 146 insertions(+), 15 deletions(-) diff --git a/tests/test_annotated.py b/tests/test_annotated.py index 09072b3ae1..0b5371e348 100644 --- a/tests/test_annotated.py +++ b/tests/test_annotated.py @@ -1,3 +1,5 @@ +from enum import StrEnum, auto + import typer from typer.testing import CliRunner from typing_extensions import Annotated @@ -76,3 +78,40 @@ def cmd(force: Annotated[bool, typer.Option("--force")] = False): result = runner.invoke(app, ["--force"]) assert result.exit_code == 0, result.output assert "Forcing operation" in result.output + + +def test_annotated_option_accepts_optional_value(): + class OptEnum(StrEnum): + val1 = auto() + val2 = auto() + + app = typer.Typer() + + @app.command() + def cmd(opt: Annotated[bool | OptEnum, typer.Option()] = OptEnum.val1): + if opt is False: + print("False") + elif opt is True: + print("True") + else: + print(opt.value) + + result = runner.invoke(app) + assert result.exit_code == 0, result.output + assert "False" in result.output + + result = runner.invoke(app, ["--opt"]) + assert result.exit_code == 0, result.output + assert "val1" in result.output + + result = runner.invoke(app, ["--opt", "val1"]) + assert result.exit_code == 0, result.output + assert "val1" in result.output + + result = runner.invoke(app, ["--opt", "val2"]) + assert result.exit_code == 0, result.output + assert "val2" in result.output + + result = runner.invoke(app, ["--opt", "val3"]) + assert result.exit_code != 0 + assert "Invalid value for '--opt': 'val3' is not one of" in result.output diff --git a/typer/core.py b/typer/core.py index 8ec8b4b95d..f95f095779 100644 --- a/typer/core.py +++ b/typer/core.py @@ -420,6 +420,7 @@ def __init__( prompt_required: bool = True, hide_input: bool = False, is_flag: Optional[bool] = None, + flag_value: Optional[Any] = None, multiple: bool = False, count: bool = False, allow_from_autoenv: bool = True, @@ -446,6 +447,7 @@ def __init__( confirmation_prompt=confirmation_prompt, hide_input=hide_input, is_flag=is_flag, + flag_value=flag_value, multiple=multiple, count=count, allow_from_autoenv=allow_from_autoenv, diff --git a/typer/main.py b/typer/main.py index 55d865c780..b95a893874 100644 --- a/typer/main.py +++ b/typer/main.py @@ -619,30 +619,48 @@ def get_command_from_info( return command -def determine_type_convertor(type_: Any) -> Optional[Callable[[Any], Any]]: +def determine_type_convertor( + type_: Any, skip_bool: bool = False +) -> Optional[Callable[[Any], Any]]: convertor: Optional[Callable[[Any], Any]] = None if lenient_issubclass(type_, Path): - convertor = param_path_convertor + convertor = generate_path_convertor(skip_bool) if lenient_issubclass(type_, Enum): - convertor = generate_enum_convertor(type_) + convertor = generate_enum_convertor(type_, skip_bool) return convertor -def param_path_convertor(value: Optional[str] = None) -> Optional[Path]: - if value is not None: +def generate_path_convertor( + skip_bool: bool = False, +) -> Callable[[Any], Union[None, bool, Path]]: + def convertor(value: Optional[str] = None) -> Union[None, bool, Path]: + if value is None: + return None + + if isinstance(value, bool) and skip_bool: + return value + return Path(value) - return None + + return convertor -def generate_enum_convertor(enum: Type[Enum]) -> Callable[[Any], Any]: +def generate_enum_convertor( + enum: Type[Enum], skip_bool: bool = False +) -> Callable[[Any], Union[None, bool, Enum]]: val_map = {str(val.value): val for val in enum} - def convertor(value: Any) -> Any: - if value is not None: - val = str(value) - if val in val_map: - key = val_map[val] - return enum(key) + def convertor(value: Any) -> Union[None, bool, Enum]: + if value is None: + return None + + if isinstance(value, bool) and skip_bool: + return value + + val = str(value) + if val in val_map: + key = val_map[val] + return enum(key) return convertor @@ -809,6 +827,57 @@ def lenient_issubclass( return isinstance(cls, type) and issubclass(cls, class_or_tuple) +class ClickTypeUnion(click.ParamType): + def __init__(self, *types: click.ParamType) -> None: + self._types: tuple[click.ParamType, ...] = types + self.name: str = "|".join(t.name for t in types) + + def to_info_dict(self) -> Dict[str, Any]: + info_dict: Dict[str, Any] = {} + for t in self._types: + info_dict |= t.to_info_dict() + + return info_dict + + def get_metavar(self, param: click.Parameter) -> Optional[str]: + metavar_union: list[str] = [] + for t in self._types: + metavar = t.get_metavar(param) + if metavar is not None: + metavar_union.append(metavar) + + if not len(metavar_union): + return None + + return "|".join(metavar_union) + + def get_missing_message(self, param: click.Parameter) -> Optional[str]: + message_union: list[str] = [] + for t in self._types: + message = t.get_missing_message(param) + if message is not None: + message_union.append(message) + + if not len(message_union): + return None + + return "\n".join(message_union) + + def convert( + self, value: Any, param: Optional[click.Parameter], ctx: Optional[click.Context] + ) -> Any: + fail_messages: list[str] = [] + + for t in self._types: + try: + return t.convert(value, param, ctx) + + except click.BadParameter as e: + fail_messages.append(e.message) + + self.fail(" or ".join(fail_messages), param, ctx) + + def get_click_param( param: ParamMeta, ) -> Tuple[Union[click.Argument, click.Option], Any]: @@ -836,10 +905,12 @@ def get_click_param( else: annotation = str main_type = annotation + secondary_type: bool | None = None is_list = False is_tuple = False parameter_type: Any = None is_flag = None + flag_value: Any = None origin = get_origin(main_type) if origin is not None: @@ -850,7 +921,17 @@ def get_click_param( if type_ is NoneType: continue types.append(type_) - assert len(types) == 1, "Typer Currently doesn't support Union types" + + if len(types) == 1: + main_type = types[0] + + else: + types = sorted(types, key=lambda t: t is bool) + main_type, secondary_type, *union_types = types + assert ( + not len(union_types) and secondary_type is bool + ), "Typer Currently doesn't support Union types" + main_type = types[0] origin = get_origin(main_type) # Handle Tuples and Lists @@ -875,7 +956,7 @@ def get_click_param( parameter_type = get_click_type( annotation=main_type, parameter_info=parameter_info ) - convertor = determine_type_convertor(main_type) + convertor = determine_type_convertor(main_type, skip_bool=secondary_type is bool) if is_list: convertor = generate_list_convertor( convertor=convertor, default_value=default_value @@ -888,6 +969,14 @@ def get_click_param( # Click doesn't accept a flag of type bool, only None, and then it sets it # to bool internally parameter_type = None + + elif secondary_type is bool: + is_flag = False + flag_value = default_value + default_value = False + assert parameter_type is not None + parameter_type = ClickTypeUnion(parameter_type, click.BOOL) + default_option_name = get_command_name(param.name) if is_flag: default_option_declaration = ( @@ -910,6 +999,7 @@ def get_click_param( prompt_required=parameter_info.prompt_required, hide_input=parameter_info.hide_input, is_flag=is_flag, + flag_value=flag_value, multiple=is_list, count=parameter_info.count, allow_from_autoenv=parameter_info.allow_from_autoenv, From 1262385220d3baad0a5f18e2af593a7b34a2f0c6 Mon Sep 17 00:00:00 2001 From: MatteoBouvierVidium Date: Tue, 26 Nov 2024 12:10:02 +0100 Subject: [PATCH 2/7] =?UTF-8?q?=F0=9F=9A=A8=20Add=20tests=20for=20option?= =?UTF-8?q?=20with=20optional=20value?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_annotated.py | 96 ++++++++++++++++++++++++++--------------- typer/main.py | 90 +++++++++++++++++++++++++++++++++++--- 2 files changed, 146 insertions(+), 40 deletions(-) diff --git a/tests/test_annotated.py b/tests/test_annotated.py index 0b5371e348..b94292c588 100644 --- a/tests/test_annotated.py +++ b/tests/test_annotated.py @@ -80,38 +80,64 @@ def cmd(force: Annotated[bool, typer.Option("--force")] = False): assert "Forcing operation" in result.output -def test_annotated_option_accepts_optional_value(): - class OptEnum(StrEnum): - val1 = auto() - val2 = auto() - - app = typer.Typer() - - @app.command() - def cmd(opt: Annotated[bool | OptEnum, typer.Option()] = OptEnum.val1): - if opt is False: - print("False") - elif opt is True: - print("True") - else: - print(opt.value) - - result = runner.invoke(app) - assert result.exit_code == 0, result.output - assert "False" in result.output - - result = runner.invoke(app, ["--opt"]) - assert result.exit_code == 0, result.output - assert "val1" in result.output - - result = runner.invoke(app, ["--opt", "val1"]) - assert result.exit_code == 0, result.output - assert "val1" in result.output - - result = runner.invoke(app, ["--opt", "val2"]) - assert result.exit_code == 0, result.output - assert "val2" in result.output - - result = runner.invoke(app, ["--opt", "val3"]) - assert result.exit_code != 0 - assert "Invalid value for '--opt': 'val3' is not one of" in result.output +class TestAnnotatedOptionAcceptsOptionalValue: + def test_enum(self): + app = typer.Typer() + + class OptEnum(StrEnum): + val1 = auto() + val2 = auto() + + @app.command() + def cmd(opt: Annotated[bool | OptEnum, typer.Option()] = OptEnum.val1): + if opt is False: + print("False") + elif opt is True: + print("True") + else: + print(opt.value) + + result = runner.invoke(app) + assert result.exit_code == 0, result.output + assert "False" in result.output + + result = runner.invoke(app, ["--opt"]) + assert result.exit_code == 0, result.output + assert "val1" in result.output + + result = runner.invoke(app, ["--opt", "val1"]) + assert result.exit_code == 0, result.output + assert "val1" in result.output + + result = runner.invoke(app, ["--opt", "val2"]) + assert result.exit_code == 0, result.output + assert "val2" in result.output + + result = runner.invoke(app, ["--opt", "val3"]) + assert result.exit_code != 0 + assert "Invalid value for '--opt': 'val3' is not one of" in result.output + + def test_int(self): + app = typer.Typer() + + @app.command() + def cmd(opt: Annotated[bool | int, typer.Option()] = 1): + print(opt) + + result = runner.invoke(app) + assert result.exit_code == 0, result.output + assert "False" in result.output + + result = runner.invoke(app, ["--opt"]) + assert result.exit_code == 0, result.output + assert "1" in result.output + + result = runner.invoke(app, ["--opt", "2"]) + assert result.exit_code == 0, result.output + assert "2" in result.output + + result = runner.invoke(app, ["--opt", "test"]) + assert result.exit_code != 0 + assert ( + "Invalid value for '--opt': 'test' is not a valid integer" in result.output + ) diff --git a/typer/main.py b/typer/main.py index b95a893874..7d3d44221b 100644 --- a/typer/main.py +++ b/typer/main.py @@ -8,14 +8,25 @@ from datetime import datetime from enum import Enum from functools import update_wrapper +from gettext import gettext from pathlib import Path from traceback import FrameSummary, StackSummary from types import TracebackType -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + Type, + Union, +) from uuid import UUID import click -from typing_extensions import get_args, get_origin +from typing_extensions import get_args, get_origin, override from ._typing import is_union from .completion import get_completion_inspect_parameters @@ -832,6 +843,11 @@ def __init__(self, *types: click.ParamType) -> None: self._types: tuple[click.ParamType, ...] = types self.name: str = "|".join(t.name for t in types) + @override + def __repr__(self) -> str: + return "|".join(repr(t) for t in self._types) + + @override def to_info_dict(self) -> Dict[str, Any]: info_dict: Dict[str, Any] = {} for t in self._types: @@ -839,6 +855,7 @@ def to_info_dict(self) -> Dict[str, Any]: return info_dict + @override def get_metavar(self, param: click.Parameter) -> Optional[str]: metavar_union: list[str] = [] for t in self._types: @@ -851,6 +868,7 @@ def get_metavar(self, param: click.Parameter) -> Optional[str]: return "|".join(metavar_union) + @override def get_missing_message(self, param: click.Parameter) -> Optional[str]: message_union: list[str] = [] for t in self._types: @@ -863,6 +881,7 @@ def get_missing_message(self, param: click.Parameter) -> Optional[str]: return "\n".join(message_union) + @override def convert( self, value: Any, param: Optional[click.Parameter], ctx: Optional[click.Context] ) -> Any: @@ -873,9 +892,68 @@ def convert( return t.convert(value, param, ctx) except click.BadParameter as e: - fail_messages.append(e.message) + if not getattr(t, "union_ignore_fail_message", False): + fail_messages.append(e.message) + + self.fail(" and ".join(fail_messages), param, ctx) + + +class BoolLiteral(click.types.BoolParamType): + union_ignore_fail_message: bool = True + name: str = "boolean literal" + + @override + def convert( + self, value: Any, param: Optional[click.Parameter], ctx: Optional[click.Context] + ) -> Any: + value_ = str(value) + norm = value_.strip().lower() + + # do not cast "1" + if norm in {"True", "true", "t", "yes", "y", "on"}: + return True + + # do not cast "0" + if norm in {"False", "false", "f", "no", "n", "off"}: + return False + + self.fail( + gettext("{value!r} is not a valid boolean literal.").format(value=value_), + param, + ctx, + ) - self.fail(" or ".join(fail_messages), param, ctx) + @override + def __repr__(self) -> str: + return "BOOL(Literal)" + + +class BoolInteger(click.ParamType): + union_ignore_fail_message: bool = True + name: str = "boolean integer" + + @override + def convert( + self, value: Any, param: Optional[click.Parameter], ctx: Optional[click.Context] + ) -> Any: + value_ = str(value) + norm = value_.strip() + + if norm == "1": + return True + + if norm == "0": + return False + + self.fail( + gettext("{value!r} is not a valid boolean integer.").format(value=value_), + param, + ctx, + ) + + @override + def __repr__(self) -> str: + return "BOOL(int)" def get_click_param( @@ -975,7 +1053,9 @@ def get_click_param( flag_value = default_value default_value = False assert parameter_type is not None - parameter_type = ClickTypeUnion(parameter_type, click.BOOL) + parameter_type = ClickTypeUnion( + BoolLiteral(), parameter_type, BoolInteger() + ) default_option_name = get_command_name(param.name) if is_flag: From 0846450313a26cfcca4115aef9c06afe0889c4f1 Mon Sep 17 00:00:00 2001 From: MatteoBouvierVidium Date: Tue, 26 Nov 2024 13:39:56 +0100 Subject: [PATCH 3/7] =?UTF-8?q?=F0=9F=93=9A=20Update=20docs=20for=20Option?= =?UTF-8?q?'s=20optional=20value?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/tutorial/options/optional_value.md | 65 +++++++++++++++++++ .../options/optional_value/tutorial001.py | 19 ++++++ .../options/optional_value/tutorial001_an.py | 22 +++++++ mkdocs.yml | 1 + 4 files changed, 107 insertions(+) create mode 100644 docs/tutorial/options/optional_value.md create mode 100644 docs_src/options/optional_value/tutorial001.py create mode 100644 docs_src/options/optional_value/tutorial001_an.py diff --git a/docs/tutorial/options/optional_value.md b/docs/tutorial/options/optional_value.md new file mode 100644 index 0000000000..db5340f150 --- /dev/null +++ b/docs/tutorial/options/optional_value.md @@ -0,0 +1,65 @@ +# Optional value for CLI Options + +As in Click, providing a value to a *CLI option* can be made optional, in which case a default value will be used instead. + +To make a *CLI option*'s value optional, you can annotate it as a *Union* of types *bool* and the parameter type. + +/// info + +You can create a type Union by importing *Union* from the typing module. + +For example `Union[bool, str]` represents a type that is either a boolean or a string. + +You can also use the equivalent notation `bool | str` + +/// + +Let's add a *CLI option* `--tone` with optional value: + +{* docs_src/options/optional_value/tutorial001_an.py hl[5] *} + +Now, there are three possible configurations: + +* `--greeting` is not used, the parameter will receive a value of `False`. +``` +python main.py +``` + +* `--greeting` is supplied with a value, the parameter will receive the string representation of that value. +``` +python main.py --greeting +``` + +* `--greeting` is used with no value, the parameter will receive the default `formal` value. +``` +python main.py --greeting +``` + + +And test it: + +
+ +```console +$ python main.py Camila Gutiérrez + +// We didn't pass the greeting CLI option, we get no greeting + + +// Now update it to pass the --greeting CLI option with default value +$ python main.py Camila Gutiérrez --greeting + +Hello Camila Gutiérrez + +// The above is equivalent to passing the --greeting CLI option with value `formal` +$ python main.py Camila Gutiérrez --greeting formal + +Hi Camila ! + +// But you can select another value +$ python main.py Camila Gutiérrez --greeting casual + +Hi Camila ! +``` + +
diff --git a/docs_src/options/optional_value/tutorial001.py b/docs_src/options/optional_value/tutorial001.py new file mode 100644 index 0000000000..d731afaece --- /dev/null +++ b/docs_src/options/optional_value/tutorial001.py @@ -0,0 +1,19 @@ +import typer + + +def main(name: str, lastname: str, greeting: bool | str = "formal"): + if not greeting: + return + + if greeting == "formal": + print(f"Hello {name} {lastname}") + + elif greeting == "casual": + print(f"Hi {name} !") + + else: + raise ValueError(f"Invalid greeting '{greeting}'") + + +if __name__ == "__main__": + typer.run(main) diff --git a/docs_src/options/optional_value/tutorial001_an.py b/docs_src/options/optional_value/tutorial001_an.py new file mode 100644 index 0000000000..ecd097d066 --- /dev/null +++ b/docs_src/options/optional_value/tutorial001_an.py @@ -0,0 +1,22 @@ +import typer +from typing_extensions import Annotated + + +def main( + name: str, lastname: str, greeting: Annotated[bool | str, typer.Option()] = "formal" +): + if not greeting: + return + + if greeting == "formal": + print(f"Hello {name} {lastname}") + + elif greeting == "casual": + print(f"Hi {name} !") + + else: + raise ValueError(f"Invalid greeting '{greeting}'") + + +if __name__ == "__main__": + typer.run(main) diff --git a/mkdocs.yml b/mkdocs.yml index 042d7ad116..e1bbfc6d59 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -89,6 +89,7 @@ nav: - tutorial/options/index.md - tutorial/options/help.md - tutorial/options/required.md + - tutorial/options/optional_value.md - tutorial/options/prompt.md - tutorial/options/password.md - tutorial/options/name.md From e8b83d68dfbc87b33d0baa7cd3f0ae26e392e425 Mon Sep 17 00:00:00 2001 From: MatteoBouvierVidium Date: Tue, 26 Nov 2024 15:06:15 +0100 Subject: [PATCH 4/7] =?UTF-8?q?=F0=9F=90=9B=20Fix=20help=20message=20for?= =?UTF-8?q?=20Option=20with=20optional=20value?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/tutorial/options/optional_value.md | 2 +- typer/main.py | 124 +++++++----------------- 2 files changed, 34 insertions(+), 92 deletions(-) diff --git a/docs/tutorial/options/optional_value.md b/docs/tutorial/options/optional_value.md index db5340f150..d200819e5c 100644 --- a/docs/tutorial/options/optional_value.md +++ b/docs/tutorial/options/optional_value.md @@ -54,7 +54,7 @@ Hello Camila Gutiérrez // The above is equivalent to passing the --greeting CLI option with value `formal` $ python main.py Camila Gutiérrez --greeting formal -Hi Camila ! +Hello Camila Gutiérrez // But you can select another value $ python main.py Camila Gutiérrez --greeting casual diff --git a/typer/main.py b/typer/main.py index 7d3d44221b..814b2034b2 100644 --- a/typer/main.py +++ b/typer/main.py @@ -8,7 +8,6 @@ from datetime import datetime from enum import Enum from functools import update_wrapper -from gettext import gettext from pathlib import Path from traceback import FrameSummary, StackSummary from types import TracebackType @@ -838,122 +837,67 @@ def lenient_issubclass( return isinstance(cls, type) and issubclass(cls, class_or_tuple) -class ClickTypeUnion(click.ParamType): - def __init__(self, *types: click.ParamType) -> None: - self._types: tuple[click.ParamType, ...] = types - self.name: str = "|".join(t.name for t in types) +class DefaultOption(click.ParamType): + def __init__(self, type_: click.ParamType, default: Any) -> None: + self._type: click.ParamType = type_ + self._default: Any = default + self.name: str = f"BOOLEAN|{type_.name}" @override def __repr__(self) -> str: - return "|".join(repr(t) for t in self._types) + return f"DefaultOption({self._type})" @override def to_info_dict(self) -> Dict[str, Any]: - info_dict: Dict[str, Any] = {} - for t in self._types: - info_dict |= t.to_info_dict() - - return info_dict + return self._type.to_info_dict() @override def get_metavar(self, param: click.Parameter) -> Optional[str]: - metavar_union: list[str] = [] - for t in self._types: - metavar = t.get_metavar(param) - if metavar is not None: - metavar_union.append(metavar) - - if not len(metavar_union): - return None - - return "|".join(metavar_union) + return self._type.get_metavar(param) @override def get_missing_message(self, param: click.Parameter) -> Optional[str]: - message_union: list[str] = [] - for t in self._types: - message = t.get_missing_message(param) - if message is not None: - message_union.append(message) - - if not len(message_union): - return None - - return "\n".join(message_union) - - @override - def convert( - self, value: Any, param: Optional[click.Parameter], ctx: Optional[click.Context] - ) -> Any: - fail_messages: list[str] = [] - - for t in self._types: - try: - return t.convert(value, param, ctx) - - except click.BadParameter as e: - if not getattr(t, "union_ignore_fail_message", False): - fail_messages.append(e.message) - - self.fail(" and ".join(fail_messages), param, ctx) - - -class BoolLiteral(click.types.BoolParamType): - union_ignore_fail_message: bool = True - name: str = "boolean literal" + return self._type.get_missing_message(param) @override def convert( self, value: Any, param: Optional[click.Parameter], ctx: Optional[click.Context] ) -> Any: - value_ = str(value) - norm = value_.strip().lower() + str_value = str(value).strip().lower() - # do not cast "1" - if norm in {"True", "true", "t", "yes", "y", "on"}: - return True + if str_value in {"True", "true", "t", "yes", "y", "on"}: + return self._default - # do not cast "0" - if norm in {"False", "false", "f", "no", "n", "off"}: + if str_value in {"False", "false", "f", "no", "n", "off"}: return False - self.fail( - gettext("{value!r} is not a valid boolean literal.").format(value=value_), - param, - ctx, - ) + if isinstance(value, DefaultFalse): + return False - @override - def __repr__(self) -> str: - return "BOOL(Literal)" + try: + return self._type.convert(value, param, ctx) + except click.BadParameter as e: + fail = e -class BoolInteger(click.ParamType): - union_ignore_fail_message: bool = True - name: str = "boolean integer" + if str_value == "1": + return self._default - @override - def convert( - self, value: Any, param: Optional[click.Parameter], ctx: Optional[click.Context] - ) -> Any: - value_ = str(value) - norm = value_.strip() + if str_value == "0": + return False - if norm == "1": - return True + raise fail - if norm == "0": - return False - self.fail( - gettext("{value!r} is not a valid boolean integer.").format(value=value_), - param, - ctx, - ) +class DefaultFalse: + def __init__(self, value: Any) -> None: + self._value = value - @override def __repr__(self) -> str: - return "BOOL(int)" + return f"False ({repr(self._value)})" + + def __str__(self) -> str: + return f"False ({str(self._value)})" def get_click_param( @@ -1051,11 +995,9 @@ def get_click_param( elif secondary_type is bool: is_flag = False flag_value = default_value - default_value = False assert parameter_type is not None - parameter_type = ClickTypeUnion( - BoolLiteral(), parameter_type, BoolInteger() - ) + parameter_type = DefaultOption(parameter_type, default=default_value) + default_value = DefaultFalse(default_value) default_option_name = get_command_name(param.name) if is_flag: From 8bb363d1662e5091c3bdb33e2b06e8f14cbd0419 Mon Sep 17 00:00:00 2001 From: MatteoBouvierVidium Date: Tue, 26 Nov 2024 15:28:03 +0100 Subject: [PATCH 5/7] =?UTF-8?q?=F0=9F=90=9B=20Fix=20mypy=20errors?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- typer/main.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/typer/main.py b/typer/main.py index 814b2034b2..9ece09e037 100644 --- a/typer/main.py +++ b/typer/main.py @@ -672,6 +672,8 @@ def convertor(value: Any) -> Union[None, bool, Enum]: key = val_map[val] return enum(key) + return None + return convertor @@ -927,7 +929,7 @@ def get_click_param( else: annotation = str main_type = annotation - secondary_type: bool | None = None + secondary_type: type[bool] | None = None is_list = False is_tuple = False parameter_type: Any = None From 9e7afc39c18ed1c46ca21e5de89fd93a0bf0127c Mon Sep 17 00:00:00 2001 From: MatteoBouvierVidium Date: Tue, 26 Nov 2024 15:37:46 +0100 Subject: [PATCH 6/7] =?UTF-8?q?=F0=9F=90=9B=20Fix=20typing=20&=20imports?= =?UTF-8?q?=20for=20python=20<=203.11?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_annotated.py | 13 +++++++------ typer/main.py | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/test_annotated.py b/tests/test_annotated.py index b94292c588..25622d0ebf 100644 --- a/tests/test_annotated.py +++ b/tests/test_annotated.py @@ -1,4 +1,5 @@ -from enum import StrEnum, auto +from enum import Enum +from typing import Union import typer from typer.testing import CliRunner @@ -84,12 +85,12 @@ class TestAnnotatedOptionAcceptsOptionalValue: def test_enum(self): app = typer.Typer() - class OptEnum(StrEnum): - val1 = auto() - val2 = auto() + class OptEnum(str, Enum): + val1 = "val1" + val2 = "val2" @app.command() - def cmd(opt: Annotated[bool | OptEnum, typer.Option()] = OptEnum.val1): + def cmd(opt: Annotated[Union[bool, OptEnum], typer.Option()] = OptEnum.val1): if opt is False: print("False") elif opt is True: @@ -121,7 +122,7 @@ def test_int(self): app = typer.Typer() @app.command() - def cmd(opt: Annotated[bool | int, typer.Option()] = 1): + def cmd(opt: Annotated[Union[bool, int], typer.Option()] = 1): print(opt) result = runner.invoke(app) diff --git a/typer/main.py b/typer/main.py index 9ece09e037..53270b119e 100644 --- a/typer/main.py +++ b/typer/main.py @@ -929,7 +929,7 @@ def get_click_param( else: annotation = str main_type = annotation - secondary_type: type[bool] | None = None + secondary_type: Union[Type[bool], None] = None is_list = False is_tuple = False parameter_type: Any = None From 37b74b44bbbb6a79ec2104885fd2ad9d58c109f9 Mon Sep 17 00:00:00 2001 From: MatteoBouvierVidium Date: Tue, 26 Nov 2024 23:23:38 +0100 Subject: [PATCH 7/7] =?UTF-8?q?=F0=9F=9A=A8=20Get=20full=20test=20coverage?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_annotated.py | 66 ---------------------- tests/test_type_conversion.py | 101 +++++++++++++++++++++++++++++++++- typer/main.py | 31 +++-------- 3 files changed, 107 insertions(+), 91 deletions(-) diff --git a/tests/test_annotated.py b/tests/test_annotated.py index 25622d0ebf..09072b3ae1 100644 --- a/tests/test_annotated.py +++ b/tests/test_annotated.py @@ -1,6 +1,3 @@ -from enum import Enum -from typing import Union - import typer from typer.testing import CliRunner from typing_extensions import Annotated @@ -79,66 +76,3 @@ def cmd(force: Annotated[bool, typer.Option("--force")] = False): result = runner.invoke(app, ["--force"]) assert result.exit_code == 0, result.output assert "Forcing operation" in result.output - - -class TestAnnotatedOptionAcceptsOptionalValue: - def test_enum(self): - app = typer.Typer() - - class OptEnum(str, Enum): - val1 = "val1" - val2 = "val2" - - @app.command() - def cmd(opt: Annotated[Union[bool, OptEnum], typer.Option()] = OptEnum.val1): - if opt is False: - print("False") - elif opt is True: - print("True") - else: - print(opt.value) - - result = runner.invoke(app) - assert result.exit_code == 0, result.output - assert "False" in result.output - - result = runner.invoke(app, ["--opt"]) - assert result.exit_code == 0, result.output - assert "val1" in result.output - - result = runner.invoke(app, ["--opt", "val1"]) - assert result.exit_code == 0, result.output - assert "val1" in result.output - - result = runner.invoke(app, ["--opt", "val2"]) - assert result.exit_code == 0, result.output - assert "val2" in result.output - - result = runner.invoke(app, ["--opt", "val3"]) - assert result.exit_code != 0 - assert "Invalid value for '--opt': 'val3' is not one of" in result.output - - def test_int(self): - app = typer.Typer() - - @app.command() - def cmd(opt: Annotated[Union[bool, int], typer.Option()] = 1): - print(opt) - - result = runner.invoke(app) - assert result.exit_code == 0, result.output - assert "False" in result.output - - result = runner.invoke(app, ["--opt"]) - assert result.exit_code == 0, result.output - assert "1" in result.output - - result = runner.invoke(app, ["--opt", "2"]) - assert result.exit_code == 0, result.output - assert "2" in result.output - - result = runner.invoke(app, ["--opt", "test"]) - assert result.exit_code != 0 - assert ( - "Invalid value for '--opt': 'test' is not a valid integer" in result.output - ) diff --git a/tests/test_type_conversion.py b/tests/test_type_conversion.py index 904a686d2e..0102ac1d2d 100644 --- a/tests/test_type_conversion.py +++ b/tests/test_type_conversion.py @@ -1,11 +1,12 @@ from enum import Enum from pathlib import Path -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Tuple, Union import click import pytest import typer from typer.testing import CliRunner +from typing_extensions import Annotated from .utils import needs_py310 @@ -169,3 +170,101 @@ def custom_click_type( result = runner.invoke(app, ["0x56"]) assert result.exit_code == 0 + + +class TestOptionAcceptsOptionalValue: + def test_enum(self): + app = typer.Typer() + + class OptEnum(str, Enum): + val1 = "val1" + val2 = "val2" + + @app.command() + def cmd(opt: Annotated[Union[bool, OptEnum], typer.Option()] = OptEnum.val1): + if opt is False: + print("False") + + else: + print(opt.value) + + result = runner.invoke(app) + assert result.exit_code == 0, result.output + assert "False" in result.output + + result = runner.invoke(app, ["--opt"]) + assert result.exit_code == 0, result.output + assert "val1" in result.output + + result = runner.invoke(app, ["--opt", "val1"]) + assert result.exit_code == 0, result.output + assert "val1" in result.output + + result = runner.invoke(app, ["--opt", "val2"]) + assert result.exit_code == 0, result.output + assert "val2" in result.output + + result = runner.invoke(app, ["--opt", "val3"]) + assert result.exit_code != 0 + assert "Invalid value for '--opt': 'val3' is not one of" in result.output + + result = runner.invoke(app, ["--opt", "0"]) + assert result.exit_code == 0, result.output + assert "False" in result.output + + result = runner.invoke(app, ["--opt", "1"]) + assert result.exit_code == 0, result.output + assert "val1" in result.output + + def test_int(self): + app = typer.Typer() + + @app.command() + def cmd(opt: Annotated[Union[bool, int], typer.Option()] = 1): + print(opt) + + result = runner.invoke(app) + assert result.exit_code == 0, result.output + assert "False" in result.output + + result = runner.invoke(app, ["--opt"]) + assert result.exit_code == 0, result.output + assert "1" in result.output + + result = runner.invoke(app, ["--opt", "2"]) + assert result.exit_code == 0, result.output + assert "2" in result.output + + result = runner.invoke(app, ["--opt", "test"]) + assert result.exit_code != 0 + assert ( + "Invalid value for '--opt': 'test' is not a valid integer" in result.output + ) + + result = runner.invoke(app, ["--opt", "true"]) + assert result.exit_code == 0, result.output + assert "1" in result.output + + result = runner.invoke(app, ["--opt", "off"]) + assert result.exit_code == 0, result.output + assert "False" in result.output + + def test_path(self): + app = typer.Typer() + + @app.command() + def cmd(opt: Annotated[Union[bool, Path], typer.Option()] = Path(".")): + if isinstance(opt, Path): + print((opt / "file.py").as_posix()) + + result = runner.invoke(app, ["--opt"]) + assert result.exit_code == 0, result.output + assert "file.py" in result.output + + result = runner.invoke(app, ["--opt", "/test/path/file.py"]) + assert result.exit_code == 0, result.output + assert "/test/path/file.py" in result.output + + result = runner.invoke(app, ["--opt", "False"]) + assert result.exit_code == 0, result.output + assert "file.py" not in result.output diff --git a/typer/main.py b/typer/main.py index 53270b119e..93bb0b378f 100644 --- a/typer/main.py +++ b/typer/main.py @@ -660,19 +660,21 @@ def generate_enum_convertor( ) -> Callable[[Any], Union[None, bool, Enum]]: val_map = {str(val.value): val for val in enum} - def convertor(value: Any) -> Union[None, bool, Enum]: - if value is None: - return None - + def convertor(value: Any) -> Union[bool, Enum]: if isinstance(value, bool) and skip_bool: return value + if isinstance(value, enum): + return value + val = str(value) if val in val_map: key = val_map[val] return enum(key) - return None + raise click.BadParameter( + f"Invalid value '{value}' for enum '{enum.__name__}'" + ) # pragma: no cover return convertor @@ -845,22 +847,6 @@ def __init__(self, type_: click.ParamType, default: Any) -> None: self._default: Any = default self.name: str = f"BOOLEAN|{type_.name}" - @override - def __repr__(self) -> str: - return f"DefaultOption({self._type})" - - @override - def to_info_dict(self) -> Dict[str, Any]: - return self._type.to_info_dict() - - @override - def get_metavar(self, param: click.Parameter) -> Optional[str]: - return self._type.get_metavar(param) - - @override - def get_missing_message(self, param: click.Parameter) -> Optional[str]: - return self._type.get_missing_message(param) - @override def convert( self, value: Any, param: Optional[click.Parameter], ctx: Optional[click.Context] @@ -895,9 +881,6 @@ class DefaultFalse: def __init__(self, value: Any) -> None: self._value = value - def __repr__(self) -> str: - return f"False ({repr(self._value)})" - def __str__(self) -> str: return f"False ({str(self._value)})"