Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for CliMutuallyExclusiveGroup. #473

Merged
merged 6 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,44 @@ For `BaseModel` and `pydantic.dataclasses.dataclass` types, `CliApp.run` will in
The alias generator for kebab case does not propagate to subcommands or submodels and will have to be manually set
in these cases.

### Mutually Exclusive Groups

CLI mutually exclusive groups can be created by inheriting from the `CliMutuallyExclusiveGroup` class.

!!! note
A `CliMutuallyExclusiveGroup` cannot be used in a union or contain nested models.

```py
from typing import Optional

from pydantic import BaseModel

from pydantic_settings import CliApp, CliMutuallyExclusiveGroup, SettingsError


class Circle(CliMutuallyExclusiveGroup):
radius: Optional[float] = None
diameter: Optional[float] = None
perimeter: Optional[float] = None


class Settings(BaseModel):
circle: Circle


try:
CliApp.run(
Settings,
cli_args=['--circle.radius=1', '--circle.diameter=2'],
cli_exit_on_error=False,
)
except SettingsError as e:
print(e)
"""
error parsing CLI: argument --circle.diameter: not allowed with argument --circle.radius
"""
```

### Customizing the CLI Experience

The below flags can be used to customise the CLI experience to your needs.
Expand Down
2 changes: 2 additions & 0 deletions pydantic_settings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
AzureKeyVaultSettingsSource,
CliExplicitFlag,
CliImplicitFlag,
CliMutuallyExclusiveGroup,
CliPositionalArg,
CliSettingsSource,
CliSubCommand,
Expand Down Expand Up @@ -34,6 +35,7 @@
'CliPositionalArg',
'CliExplicitFlag',
'CliImplicitFlag',
'CliMutuallyExclusiveGroup',
'InitSettingsSource',
'JsonConfigSettingsSource',
'PyprojectTomlConfigSettingsSource',
Expand Down
51 changes: 45 additions & 6 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ def error(self, message: str) -> NoReturn:
super().error(message)


class CliMutuallyExclusiveGroup(BaseModel):
pass


T = TypeVar('T')
CliSubCommand = Annotated[Union[T, None], _CliSubCommand]
CliPositionalArg = Annotated[T, _CliPositionalArg]
Expand Down Expand Up @@ -1483,7 +1487,7 @@ def _connect_parser_method(
if (
parser_method is not None
and self.case_sensitive is False
and method_name == 'parsed_args_method'
and method_name == 'parse_args_method'
and isinstance(self._root_parser, _CliInternalArgParser)
):

Expand Down Expand Up @@ -1515,6 +1519,26 @@ def none_parser_method(*args: Any, **kwargs: Any) -> Any:
else:
return parser_method

def _connect_group_method(self, add_argument_group_method: Callable[..., Any] | None) -> Callable[..., Any]:
add_argument_group = self._connect_parser_method(add_argument_group_method, 'add_argument_group_method')

def add_group_method(parser: Any, **kwargs: Any) -> Any:
if not kwargs.pop('_is_cli_mutually_exclusive_group'):
kwargs.pop('required')
return add_argument_group(parser, **kwargs)
else:
main_group_kwargs = {arg: kwargs.pop(arg) for arg in ['title', 'description'] if arg in kwargs}
main_group_kwargs['title'] += ' (mutually exclusive)'
group = add_argument_group(parser, **main_group_kwargs)
if not hasattr(group, 'add_mutually_exclusive_group'):
raise SettingsError(
'cannot connect CLI settings source root parser: '
'group object is missing add_mutually_exclusive_group but is needed for connecting'
)
return group.add_mutually_exclusive_group(**kwargs)

return add_group_method

def _connect_root_parser(
self,
root_parser: T,
Expand All @@ -1531,9 +1555,9 @@ def _parse_known_args(*args: Any, **kwargs: Any) -> Namespace:
self._root_parser = root_parser
if parse_args_method is None:
parse_args_method = _parse_known_args if self.cli_ignore_unknown_args else ArgumentParser.parse_args
self._parse_args = self._connect_parser_method(parse_args_method, 'parsed_args_method')
self._parse_args = self._connect_parser_method(parse_args_method, 'parse_args_method')
self._add_argument = self._connect_parser_method(add_argument_method, 'add_argument_method')
self._add_argument_group = self._connect_parser_method(add_argument_group_method, 'add_argument_group_method')
self._add_group = self._connect_group_method(add_argument_group_method)
self._add_parser = self._connect_parser_method(add_parser_method, 'add_parser_method')
self._add_subparsers = self._connect_parser_method(add_subparsers_method, 'add_subparsers_method')
self._formatter_class = formatter_class
Expand Down Expand Up @@ -1665,6 +1689,7 @@ def _add_parser_args(
if is_parser_submodel:
self._add_parser_submodels(
parser,
model,
sub_models,
added_args,
arg_prefix,
Expand All @@ -1680,7 +1705,7 @@ def _add_parser_args(
elif not is_alias_path_only:
if group is not None:
if isinstance(group, dict):
group = self._add_argument_group(parser, **group)
group = self._add_group(parser, **group)
added_args += list(arg_names)
self._add_argument(group, *(f'{flag_prefix[:len(name)]}{name}' for name in arg_names), **kwargs)
else:
Expand Down Expand Up @@ -1724,6 +1749,7 @@ def _get_arg_names(
def _add_parser_submodels(
self,
parser: Any,
model: type[BaseModel],
sub_models: list[type[BaseModel]],
added_args: list[str],
arg_prefix: str,
Expand All @@ -1736,10 +1762,23 @@ def _add_parser_submodels(
alias_names: tuple[str, ...],
model_default: Any,
) -> None:
if issubclass(model, CliMutuallyExclusiveGroup):
# Argparse has deprecated "calling add_argument_group() or add_mutually_exclusive_group() on a
# mutually exclusive group" (https://docs.python.org/3/library/argparse.html#mutual-exclusion).
# Since nested models result in a group add, raise an exception for nested models in a mutually
# exclusive group.
raise SettingsError('cannot have nested models in a CliMutuallyExclusiveGroup')

model_group: Any = None
model_group_kwargs: dict[str, Any] = {}
model_group_kwargs['title'] = f'{arg_names[0]} options'
model_group_kwargs['description'] = field_info.description
model_group_kwargs['required'] = kwargs['required']
model_group_kwargs['_is_cli_mutually_exclusive_group'] = any(
issubclass(model, CliMutuallyExclusiveGroup) for model in sub_models
)
if model_group_kwargs['_is_cli_mutually_exclusive_group'] and len(sub_models) > 1:
raise SettingsError('cannot use union with CliMutuallyExclusiveGroup')
if self.cli_use_class_docs_for_groups and len(sub_models) == 1:
model_group_kwargs['description'] = None if sub_models[0].__doc__ is None else dedent(sub_models[0].__doc__)

Expand All @@ -1762,7 +1801,7 @@ def _add_parser_submodels(
if not self.cli_avoid_json:
added_args.append(arg_names[0])
kwargs['help'] = f'set {arg_names[0]} from JSON string'
model_group = self._add_argument_group(parser, **model_group_kwargs)
model_group = self._add_group(parser, **model_group_kwargs)
self._add_argument(model_group, *(f'{flag_prefix}{name}' for name in arg_names), **kwargs)
for model in sub_models:
self._add_parser_args(
Expand All @@ -1788,7 +1827,7 @@ def _add_parser_alias_paths(
if alias_path_args:
context = parser
if group is not None:
context = self._add_argument_group(parser, **group) if isinstance(group, dict) else group
context = self._add_group(parser, **group) if isinstance(group, dict) else group
is_nested_alias_path = arg_prefix.endswith('.')
arg_prefix = arg_prefix[:-1] if is_nested_alias_path else arg_prefix
for name, metavar in alias_path_args.items():
Expand Down
Loading
Loading