From 47924f54a9254501956334a1a43ef9295baa68ef Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Tue, 27 Aug 2024 01:37:05 -0600 Subject: [PATCH] Add get_subcommand function. (#341) --- docs/index.md | 122 ++++++++++------------------------ pydantic_settings/__init__.py | 2 + pydantic_settings/sources.py | 72 ++++++++++++++++---- tests/test_settings.py | 44 ++++++++++++ 4 files changed, 139 insertions(+), 101 deletions(-) diff --git a/docs/index.md b/docs/index.md index 63459ea..5e3cf28 100644 --- a/docs/index.md +++ b/docs/index.md @@ -747,6 +747,9 @@ Subcommands and positional arguments are expressed using the `CliSubCommand` and annotations can only be applied to required fields (i.e. fields that do not have a default value). Furthermore, subcommands must be a valid type derived from either a pydantic `BaseModel` or pydantic.dataclasses `dataclass`. +Parsed subcommands can be retrieved from model instances using the `get_subcommand` utility function. If a subcommand is +not required, set the `is_required` flag to `False` to disable raising an error if no subcommand is found. + !!! note CLI settings subcommands are limited to a single subparser per model. In other words, all subcommands for a model are grouped under a single subparser; it does not allow for multiple subparsers with each subparser having its own @@ -759,114 +762,59 @@ subcommands must be a valid type derived from either a pydantic `BaseModel` or p ```py import sys -from pydantic import BaseModel, Field -from pydantic.dataclasses import dataclass +from pydantic import BaseModel from pydantic_settings import ( BaseSettings, CliPositionalArg, CliSubCommand, + SettingsError, + get_subcommand, ) -@dataclass -class FooPlugin: - """git-plugins-foo - Extra deep foo plugin command""" - - x_feature: bool = Field(default=False, description='Enable "X" feature') - - -@dataclass -class BarPlugin: - """git-plugins-bar - Extra deep bar plugin command""" - - y_feature: bool = Field(default=False, description='Enable "Y" feature') - - -@dataclass -class Plugins: - """git-plugins - Fake plugins for GIT""" - - foo: CliSubCommand[FooPlugin] = Field(description='Foo is fake plugin') - - bar: CliSubCommand[BarPlugin] = Field(description='Bar is fake plugin') +class Init(BaseModel): + directory: CliPositionalArg[str] class Clone(BaseModel): - """git-clone - Clone a repository into a new directory""" - - repository: CliPositionalArg[str] = Field(description='The repo ...') - - directory: CliPositionalArg[str] = Field(description='The dir ...') - - local: bool = Field(default=False, description='When the repo ...') - - -class Git(BaseSettings, cli_parse_args=True, cli_prog_name='git'): - """git - The stupid content tracker""" - - clone: CliSubCommand[Clone] = Field(description='Clone a repo ...') - - plugins: CliSubCommand[Plugins] = Field(description='Fake GIT plugins') - - -try: - sys.argv = ['example.py', '--help'] - Git() -except SystemExit as e: - print(e) - #> 0 -""" -usage: git [-h] {clone,plugins} ... + repository: CliPositionalArg[str] + directory: CliPositionalArg[str] -git - The stupid content tracker -options: - -h, --help show this help message and exit +class Git(BaseSettings, cli_parse_args=True, cli_exit_on_error=False): + clone: CliSubCommand[Clone] + init: CliSubCommand[Init] -subcommands: - {clone,plugins} - clone Clone a repo ... - plugins Fake GIT plugins -""" +# Run without subcommands +sys.argv = ['example.py'] +cmd = Git() +assert cmd.model_dump() == {'clone': None, 'init': None} try: - sys.argv = ['example.py', 'clone', '--help'] - Git() -except SystemExit as e: - print(e) - #> 0 -""" -usage: git clone [-h] [--local bool] [--shared bool] REPOSITORY DIRECTORY + # Will raise an error since no subcommand was provided + get_subcommand(cmd).model_dump() +except SettingsError as err: + assert str(err) == 'Error: CLI subcommand is required {clone, init}' -git-clone - Clone a repository into a new directory +# Will not raise an error since subcommand is not required +assert get_subcommand(cmd, is_required=False) is None -positional arguments: - REPOSITORY The repo ... - DIRECTORY The dir ... - -options: - -h, --help show this help message and exit - --local bool When the repo ... (default: False) -""" +# Run the clone subcommand +sys.argv = ['example.py', 'clone', 'repo', 'dest'] +cmd = Git() +assert cmd.model_dump() == { + 'clone': {'repository': 'repo', 'directory': 'dest'}, + 'init': None, +} -try: - sys.argv = ['example.py', 'plugins', 'bar', '--help'] - Git() -except SystemExit as e: - print(e) - #> 0 -""" -usage: git plugins bar [-h] [--my_feature bool] - -git-plugins-bar - Extra deep bar plugin command - -options: - -h, --help show this help message and exit - --y_feature bool Enable "Y" feature (default: False) -""" +# Returns the subcommand model instance (in this case, 'clone') +assert get_subcommand(cmd).model_dump() == { + 'directory': 'dest', + 'repository': 'repo', +} ``` ### Customizing the CLI Experience diff --git a/pydantic_settings/__init__.py b/pydantic_settings/__init__.py index 5f979ea..696276d 100644 --- a/pydantic_settings/__init__.py +++ b/pydantic_settings/__init__.py @@ -16,6 +16,7 @@ SettingsError, TomlConfigSettingsSource, YamlConfigSettingsSource, + get_subcommand, ) from .version import VERSION @@ -38,6 +39,7 @@ 'TomlConfigSettingsSource', 'YamlConfigSettingsSource', 'AzureKeyVaultSettingsSource', + 'get_subcommand', '__version__', ) diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index 90e5b37..f3791f0 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -156,6 +156,53 @@ def error(self, message: str) -> NoReturn: CliExplicitFlag = Annotated[_CliBoolFlag, _CliExplicitFlag] +def get_subcommand(model: BaseModel, is_required: bool = True, cli_exit_on_error: bool | None = None) -> Any: + """ + Get the subcommand from a model. + + Args: + model: The model to get the subcommand from. + is_required: Determines whether a model must have subcommand set and raises error if not + found. Defaults to `True`. + cli_exit_on_error: Determines whether this function exits with error if no subcommand is found. + Defaults to model_config `cli_exit_on_error` value if set. Otherwise, defaults to `True`. + + Returns: + The subcommand model if found, otherwise `None`. + + Raises: + SystemExit: When no subcommand is found and is_required=`True` and cli_exit_on_error=`True` + (the default). + SettingsError: When no subcommand is found and is_required=`True` and + cli_exit_on_error=`False`. + """ + + model_cls = type(model) + if cli_exit_on_error is None and is_model_class(model_cls): + model_default = model.model_config.get('cli_exit_on_error') + if isinstance(model_default, bool): + cli_exit_on_error = model_default + if cli_exit_on_error is None: + cli_exit_on_error = True + + subcommands: list[str] = [] + for field_name, field_info in _get_model_fields(model_cls).items(): + if _CliSubCommand in field_info.metadata: + if getattr(model, field_name) is not None: + return getattr(model, field_name) + subcommands.append(field_name) + + if is_required: + error_message = ( + f'Error: CLI subcommand is required {{{", ".join(subcommands)}}}' + if subcommands + else 'Error: CLI subcommand is required but no subcommands were found.' + ) + raise SystemExit(error_message) if cli_exit_on_error else SettingsError(error_message) + + return None + + class EnvNoneType(str): pass @@ -763,11 +810,7 @@ class Cfg(BaseSettings): if type_has_key: return type_has_key elif is_model_class(annotation) or is_pydantic_dataclass(annotation): - fields = ( - annotation.__pydantic_fields__ - if is_pydantic_dataclass(annotation) and hasattr(annotation, '__pydantic_fields__') - else cast(BaseModel, annotation).model_fields - ) + fields = _get_model_fields(annotation) # `case_sensitive is None` is here to be compatible with the old behavior. # Has to be removed in V3. if (case_sensitive is None or case_sensitive) and fields.get(key): @@ -1376,12 +1419,7 @@ def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str, def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]]: positional_args, subcommand_args, optional_args = [], [], [] - fields = ( - model.__pydantic_fields__ - if hasattr(model, '__pydantic_fields__') and is_pydantic_dataclass(model) - else model.model_fields - ) - for field_name, field_info in fields.items(): + for field_name, field_info in _get_model_fields(model).items(): if _CliSubCommand in field_info.metadata: if not field_info.is_required(): raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has a default value') @@ -1496,9 +1534,7 @@ def _add_parser_args( sub_models: list[type[BaseModel]] = self._get_sub_models(model, field_name, field_info) if _CliSubCommand in field_info.metadata: if subparsers is None: - subparsers = self._add_subparsers( - parser, title='subcommands', dest=f'{arg_prefix}:subcommand', required=self.cli_enforce_required - ) + subparsers = self._add_subparsers(parser, title='subcommands', dest=f'{arg_prefix}:subcommand') self._cli_subcommands[f'{arg_prefix}:subcommand'] = [f'{arg_prefix}{field_name}'] else: self._cli_subcommands[f'{arg_prefix}:subcommand'].append(f'{arg_prefix}{field_name}') @@ -2095,5 +2131,13 @@ def _annotation_enum_name_to_val(annotation: type[Any] | None, name: Any) -> Any return None +def _get_model_fields(model_cls: type[Any]) -> dict[str, FieldInfo]: + if is_pydantic_dataclass(model_cls) and hasattr(model_cls, '__pydantic_fields__'): + return model_cls.__pydantic_fields__ + if is_model_class(model_cls): + return model_cls.model_fields + raise SettingsError(f'Error: {model_cls.__name__} is not subclass of BaseModel or pydantic.dataclasses.dataclass') + + def _is_function(obj: Any) -> bool: return inspect.isfunction(obj) or inspect.isbuiltin(obj) or inspect.isroutine(obj) or inspect.ismethod(obj) diff --git a/tests/test_settings.py b/tests/test_settings.py index 0ba96a0..185e3ae 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -20,6 +20,7 @@ AliasChoices, AliasPath, BaseModel, + ConfigDict, DirectoryPath, Discriminator, Field, @@ -59,6 +60,7 @@ CliSubCommand, DefaultSettingsSource, SettingsError, + get_subcommand, ) try: @@ -3095,6 +3097,12 @@ class FooPlugin: class BarPlugin: my_feature: bool = False + bar = BarPlugin() + with pytest.raises(SystemExit, match='Error: CLI subcommand is required but no subcommands were found.'): + get_subcommand(bar) + with pytest.raises(SettingsError, match='Error: CLI subcommand is required but no subcommands were found.'): + get_subcommand(bar, cli_exit_on_error=False) + @pydantic_dataclasses.dataclass class Plugins: foo: CliSubCommand[FooPlugin] @@ -3116,12 +3124,26 @@ class Git(BaseSettings): init: CliSubCommand[Init] plugins: CliSubCommand[Plugins] + git = Git(_cli_parse_args=[]) + assert git.model_dump() == { + 'clone': None, + 'init': None, + 'plugins': None, + } + assert get_subcommand(git, is_required=False) is None + with pytest.raises(SystemExit, match='Error: CLI subcommand is required {clone, init, plugins}'): + get_subcommand(git) + with pytest.raises(SettingsError, match='Error: CLI subcommand is required {clone, init, plugins}'): + get_subcommand(git, cli_exit_on_error=False) + git = Git(_cli_parse_args=['init', '--quiet', 'true', 'dir/path']) assert git.model_dump() == { 'clone': None, 'init': {'directory': 'dir/path', 'quiet': True, 'bare': False}, 'plugins': None, } + assert get_subcommand(git) == git.init + assert get_subcommand(git, is_required=False) == git.init git = Git(_cli_parse_args=['clone', 'repo', '.', '--shared', 'true']) assert git.model_dump() == { @@ -3129,6 +3151,8 @@ class Git(BaseSettings): 'init': None, 'plugins': None, } + assert get_subcommand(git) == git.clone + assert get_subcommand(git, is_required=False) == git.clone git = Git(_cli_parse_args=['plugins', 'bar']) assert git.model_dump() == { @@ -3136,6 +3160,26 @@ class Git(BaseSettings): 'init': None, 'plugins': {'foo': None, 'bar': {'my_feature': False}}, } + assert get_subcommand(git) == git.plugins + assert get_subcommand(git, is_required=False) == git.plugins + assert get_subcommand(get_subcommand(git)) == git.plugins.bar + assert get_subcommand(get_subcommand(git), is_required=False) == git.plugins.bar + + class NotModel: ... + + with pytest.raises( + SettingsError, match='Error: NotModel is not subclass of BaseModel or pydantic.dataclasses.dataclass' + ): + get_subcommand(NotModel()) + + class NotSettingsConfigDict(BaseModel): + model_config = ConfigDict(cli_exit_on_error='not a bool') + + with pytest.raises(SystemExit, match='Error: CLI subcommand is required but no subcommands were found.'): + get_subcommand(NotSettingsConfigDict()) + + with pytest.raises(SettingsError, match='Error: CLI subcommand is required but no subcommands were found.'): + get_subcommand(NotSettingsConfigDict(), cli_exit_on_error=False) def test_cli_union_similar_sub_models():