From 1cf9fff45f661336160fe4980de99a8388778c41 Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Wed, 10 Jul 2024 18:10:13 -0600 Subject: [PATCH 01/12] Add get_subcommand function. --- pydantic_settings/__init__.py | 2 ++ pydantic_settings/sources.py | 13 +++++++++++++ tests/test_settings.py | 20 +++++++++++++++++++- 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/pydantic_settings/__init__.py b/pydantic_settings/__init__.py index d70ccc8a..494ba011 100644 --- a/pydantic_settings/__init__.py +++ b/pydantic_settings/__init__.py @@ -12,6 +12,7 @@ SecretsSettingsSource, TomlConfigSettingsSource, YamlConfigSettingsSource, + get_subcommand, ) from .version import VERSION @@ -30,6 +31,7 @@ 'SettingsConfigDict', 'TomlConfigSettingsSource', 'YamlConfigSettingsSource', + 'get_subcommand', '__version__', ) diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index ec9f604c..ea2b371b 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -110,6 +110,19 @@ class _CliInternalArgParser(ArgumentParser): CliPositionalArg = Annotated[T, _CliPositionalArg] +def get_subcommand(model: BaseModel, is_required: bool=False) -> Any: + subcommands: list[str] = [] + fields = model.__pydantic_fields__ if is_pydantic_dataclass(type(model)) else model.model_fields + for field_name, field_info in fields.items(): + if _CliSubCommand in field_info.metadata: + if getattr(model, field_name) is not None: + return getattr(model, field_name) + subcommands.append(field_name) + if is_required: + raise SettingsError(f'CLI subcommand is required {{{", ".join(subcommands)}}}') + return None + + class EnvNoneType(str): pass diff --git a/tests/test_settings.py b/tests/test_settings.py index f0888e3b..13464c81 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -49,7 +49,7 @@ TomlConfigSettingsSource, YamlConfigSettingsSource, ) -from pydantic_settings.sources import CliPositionalArg, CliSettingsSource, CliSubCommand, SettingsError +from pydantic_settings.sources import CliPositionalArg, CliSettingsSource, CliSubCommand, SettingsError, get_subcommand try: import dotenv @@ -2784,12 +2784,24 @@ class Git(BaseSettings): init: CliSubCommand[Init] plugins: CliSubCommand[Plugins] + git = Git(_cli_parse_args=[]) + assert git.model_dump() == { + 'clone': None, + 'init': None, + 'plugins': None, + } + assert get_subcommand(git) == None + with pytest.raises(SettingsError): + get_subcommand(git, is_required=True) + git = Git(_cli_parse_args=['init', '--quiet', 'true', 'dir/path']) assert git.model_dump() == { 'clone': None, 'init': {'directory': 'dir/path', 'quiet': True, 'bare': False}, 'plugins': None, } + assert get_subcommand(git) == git.init + assert get_subcommand(git, is_required=True) == git.init git = Git(_cli_parse_args=['clone', 'repo', '.', '--shared', 'true']) assert git.model_dump() == { @@ -2797,6 +2809,8 @@ class Git(BaseSettings): 'init': None, 'plugins': None, } + assert get_subcommand(git) == git.clone + assert get_subcommand(git, is_required=True) == git.clone git = Git(_cli_parse_args=['plugins', 'bar']) assert git.model_dump() == { @@ -2804,6 +2818,10 @@ class Git(BaseSettings): 'init': None, 'plugins': {'foo': None, 'bar': {'my_feature': False}}, } + assert get_subcommand(git) == git.plugins + assert get_subcommand(git, is_required=True) == git.plugins + assert get_subcommand(get_subcommand(git)) == git.plugins.bar + assert get_subcommand(get_subcommand(git), is_required=True) == git.plugins.bar def test_cli_union_similar_sub_models(): From 3de927b4eeba2eb30023784a269bbb82efe69d99 Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Wed, 10 Jul 2024 18:12:06 -0600 Subject: [PATCH 02/12] Lint. --- tests/test_settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_settings.py b/tests/test_settings.py index 13464c81..a6fc4820 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -2790,7 +2790,7 @@ class Git(BaseSettings): 'init': None, 'plugins': None, } - assert get_subcommand(git) == None + assert get_subcommand(git) is None with pytest.raises(SettingsError): get_subcommand(git, is_required=True) From d23aaa79e89612b39ad6763667466bd1014fda8c Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Fri, 12 Jul 2024 14:35:07 -0600 Subject: [PATCH 03/12] Doc updates. --- docs/index.md | 97 ++++++++++-------------------------- pydantic_settings/sources.py | 5 +- 2 files changed, 29 insertions(+), 73 deletions(-) diff --git a/docs/index.md b/docs/index.md index a89a0abc..a32b23a2 100644 --- a/docs/index.md +++ b/docs/index.md @@ -713,6 +713,9 @@ Subcommands and positional arguments are expressed using the `CliSubCommand` and annotations can only be applied to required fields (i.e. fields that do not have a default value). Furthermore, subcommands must be a valid type derived from either a pydantic `BaseModel` or pydantic.dataclasses `dataclass`. +Parsed subcommands can be retrieved from model instances using the `get_subcommand` function. If a subcommand is +required, use the `is_required` flag to raise a `SettingsError` if no subcommand is found. + !!! note CLI settings subcommands are limited to a single subparser per model. In other words, all subcommands for a model are grouped under a single subparser; it does not allow for multiple subparsers with each subparser having its own @@ -726,113 +729,65 @@ subcommands must be a valid type derived from either a pydantic `BaseModel` or p import sys from pydantic import BaseModel, Field -from pydantic.dataclasses import dataclass from pydantic_settings import ( BaseSettings, CliPositionalArg, CliSubCommand, + get_subcommand, ) -@dataclass -class FooPlugin: +class FooPlugin(BaseModel): """git-plugins-foo - Extra deep foo plugin command""" - x_feature: bool = Field(default=False, description='Enable "X" feature') - - -@dataclass -class BarPlugin: - """git-plugins-bar - Extra deep bar plugin command""" + feature: bool = Field(default=False, description='Enable feature') - y_feature: bool = Field(default=False, description='Enable "Y" feature') + def plugin_main(self) -> BaseModel: + """Run this method if selected and return self""" + return self -@dataclass -class Plugins: +class Plugins(BaseModel): """git-plugins - Fake plugins for GIT""" foo: CliSubCommand[FooPlugin] = Field(description='Foo is fake plugin') - bar: CliSubCommand[BarPlugin] = Field(description='Bar is fake plugin') + def sub_main(self) -> BaseModel: + """Run this method if selected and return fake plugin command""" + return get_subcommand(self, is_required=True).plugin_main() class Clone(BaseModel): """git-clone - Clone a repository into a new directory""" repository: CliPositionalArg[str] = Field(description='The repo ...') - directory: CliPositionalArg[str] = Field(description='The dir ...') - local: bool = Field(default=False, description='When the repo ...') + def sub_main(self) -> BaseModel: + """Run this method if selected and return self""" + return self + class Git(BaseSettings, cli_parse_args=True, cli_prog_name='git'): """git - The stupid content tracker""" clone: CliSubCommand[Clone] = Field(description='Clone a repo ...') - plugins: CliSubCommand[Plugins] = Field(description='Fake GIT plugins') + def main(self) -> BaseModel: + """Run CLI main method and return sub cmd""" + return get_subcommand(self, is_required=True).sub_main() -try: - sys.argv = ['example.py', '--help'] - Git() -except SystemExit as e: - print(e) - #> 0 -""" -usage: git [-h] {clone,plugins} ... - -git - The stupid content tracker - -options: - -h, --help show this help message and exit - -subcommands: - {clone,plugins} - clone Clone a repo ... - plugins Fake GIT plugins -""" +sys.argv = ['example.py', 'clone', 'repo', 'dest'] +print(Git().main().model_dump()) +#> {'repository': 'repo', 'directory': 'dest', 'local': False} -try: - sys.argv = ['example.py', 'clone', '--help'] - Git() -except SystemExit as e: - print(e) - #> 0 -""" -usage: git clone [-h] [--local bool] [--shared bool] REPOSITORY DIRECTORY - -git-clone - Clone a repository into a new directory - -positional arguments: - REPOSITORY The repo ... - DIRECTORY The dir ... - -options: - -h, --help show this help message and exit - --local bool When the repo ... (default: False) -""" - - -try: - sys.argv = ['example.py', 'plugins', 'bar', '--help'] - Git() -except SystemExit as e: - print(e) - #> 0 -""" -usage: git plugins bar [-h] [--my_feature bool] - -git-plugins-bar - Extra deep bar plugin command - -options: - -h, --help show this help message and exit - --y_feature bool Enable "Y" feature (default: False) -""" +sys.argv = ['example.py', 'plugins', 'foo', '--feature', 'true'] +print(Git().main().model_dump()) +#> {'feature': True} ``` ### Customizing the CLI Experience diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index ea2b371b..eb9a81d6 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -110,9 +110,10 @@ class _CliInternalArgParser(ArgumentParser): CliPositionalArg = Annotated[T, _CliPositionalArg] -def get_subcommand(model: BaseModel, is_required: bool=False) -> Any: +def get_subcommand(model: BaseModel, is_required: bool = False) -> Any: + model_cls = type(model) subcommands: list[str] = [] - fields = model.__pydantic_fields__ if is_pydantic_dataclass(type(model)) else model.model_fields + fields = model_cls.__pydantic_fields__ if is_pydantic_dataclass(model_cls) else model.model_fields for field_name, field_info in fields.items(): if _CliSubCommand in field_info.metadata: if getattr(model, field_name) is not None: From e0937cc191cf3e056decaa28af61496666dcbc0b Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Fri, 12 Jul 2024 15:02:13 -0600 Subject: [PATCH 04/12] More docs and is_exit_on_error option. --- pydantic_settings/sources.py | 25 +++++++++++++++++++++++-- tests/test_settings.py | 4 +++- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index eb9a81d6..53647c7c 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -110,7 +110,27 @@ class _CliInternalArgParser(ArgumentParser): CliPositionalArg = Annotated[T, _CliPositionalArg] -def get_subcommand(model: BaseModel, is_required: bool = False) -> Any: +def get_subcommand(model: BaseModel, is_required: bool = False, is_exit_on_error: bool = True) -> Any: + """ + Get the subcommand from a model. + + Args: + model: The model to get the subcommand from. + is_required: Determines whether a model must have subcommand set and raises error if not + found. Defaults to `False`. + is_exit_on_error: Determines whether this function exits with error if no subcommand is found. + Defaults to `True`. + + Returns: + The subcommand model if found, otherwise `None`. + + Raises: + SystemExit: When no subcommand is found and is_required=`True` and is_exit_on_error=`True` + (the default). + SettingsError: When no subcommand is found and is_required=`True` and + is_exit_on_error=`False`. + """ + model_cls = type(model) subcommands: list[str] = [] fields = model_cls.__pydantic_fields__ if is_pydantic_dataclass(model_cls) else model.model_fields @@ -120,7 +140,8 @@ def get_subcommand(model: BaseModel, is_required: bool = False) -> Any: return getattr(model, field_name) subcommands.append(field_name) if is_required: - raise SettingsError(f'CLI subcommand is required {{{", ".join(subcommands)}}}') + error_message = f'CLI subcommand is required {{{", ".join(subcommands)}}}' + raise SystemExit(error_message) if is_exit_on_error else SettingsError(error_message) return None diff --git a/tests/test_settings.py b/tests/test_settings.py index a6fc4820..52c56532 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -2791,8 +2791,10 @@ class Git(BaseSettings): 'plugins': None, } assert get_subcommand(git) is None - with pytest.raises(SettingsError): + with pytest.raises(SystemExit, match='CLI subcommand is required {clone, init, plugins}'): get_subcommand(git, is_required=True) + with pytest.raises(SettingsError, match='CLI subcommand is required {clone, init, plugins}'): + get_subcommand(git, is_required=True, is_exit_on_error=False) git = Git(_cli_parse_args=['init', '--quiet', 'true', 'dir/path']) assert git.model_dump() == { From 2f91feb11e830fe6e173bdd163812ac67223c4c8 Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Fri, 12 Jul 2024 15:05:00 -0600 Subject: [PATCH 05/12] Cleanup. --- docs/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/index.md b/docs/index.md index a32b23a2..3a17660e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -714,7 +714,7 @@ annotations can only be applied to required fields (i.e. fields that do not have subcommands must be a valid type derived from either a pydantic `BaseModel` or pydantic.dataclasses `dataclass`. Parsed subcommands can be retrieved from model instances using the `get_subcommand` function. If a subcommand is -required, use the `is_required` flag to raise a `SettingsError` if no subcommand is found. +required, use the `is_required` flag to raise an error if no subcommand is found. !!! note CLI settings subcommands are limited to a single subparser per model. In other words, all subcommands for a model From 2ca135206e4285d54c283ca4cc66a5c5dbb98cd3 Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Fri, 12 Jul 2024 15:15:27 -0600 Subject: [PATCH 06/12] Update error message string. --- pydantic_settings/sources.py | 2 +- tests/test_settings.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index 53647c7c..1292d319 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -140,7 +140,7 @@ def get_subcommand(model: BaseModel, is_required: bool = False, is_exit_on_error return getattr(model, field_name) subcommands.append(field_name) if is_required: - error_message = f'CLI subcommand is required {{{", ".join(subcommands)}}}' + error_message = f'Error: CLI subcommand is required {{{", ".join(subcommands)}}}' raise SystemExit(error_message) if is_exit_on_error else SettingsError(error_message) return None diff --git a/tests/test_settings.py b/tests/test_settings.py index 52c56532..10c9f1e5 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -2791,9 +2791,9 @@ class Git(BaseSettings): 'plugins': None, } assert get_subcommand(git) is None - with pytest.raises(SystemExit, match='CLI subcommand is required {clone, init, plugins}'): + with pytest.raises(SystemExit, match='Error: CLI subcommand is required {clone, init, plugins}'): get_subcommand(git, is_required=True) - with pytest.raises(SettingsError, match='CLI subcommand is required {clone, init, plugins}'): + with pytest.raises(SettingsError, match='Error: CLI subcommand is required {clone, init, plugins}'): get_subcommand(git, is_required=True, is_exit_on_error=False) git = Git(_cli_parse_args=['init', '--quiet', 'true', 'dir/path']) From 6db12342d6435fa16a0f5830fa9b176d48ea6991 Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Thu, 1 Aug 2024 21:00:31 -0600 Subject: [PATCH 07/12] Change is_required default to True. --- pydantic_settings/sources.py | 4 ++-- tests/test_settings.py | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index 4e58c639..c5a82458 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -139,14 +139,14 @@ def error(self, message: str) -> NoReturn: CliPositionalArg = Annotated[T, _CliPositionalArg] -def get_subcommand(model: BaseModel, is_required: bool = False, is_exit_on_error: bool = True) -> Any: +def get_subcommand(model: BaseModel, is_required: bool = True, is_exit_on_error: bool = True) -> Any: """ Get the subcommand from a model. Args: model: The model to get the subcommand from. is_required: Determines whether a model must have subcommand set and raises error if not - found. Defaults to `False`. + found. Defaults to `True`. is_exit_on_error: Determines whether this function exits with error if no subcommand is found. Defaults to `True`. diff --git a/tests/test_settings.py b/tests/test_settings.py index ad2b3769..9ca1eae7 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -2783,11 +2783,11 @@ class Git(BaseSettings): 'init': None, 'plugins': None, } - assert get_subcommand(git) is None + assert get_subcommand(git, is_required=False) is None with pytest.raises(SystemExit, match='Error: CLI subcommand is required {clone, init, plugins}'): - get_subcommand(git, is_required=True) + get_subcommand(git) with pytest.raises(SettingsError, match='Error: CLI subcommand is required {clone, init, plugins}'): - get_subcommand(git, is_required=True, is_exit_on_error=False) + get_subcommand(git, is_exit_on_error=False) git = Git(_cli_parse_args=['init', '--quiet', 'true', 'dir/path']) assert git.model_dump() == { @@ -2796,7 +2796,7 @@ class Git(BaseSettings): 'plugins': None, } assert get_subcommand(git) == git.init - assert get_subcommand(git, is_required=True) == git.init + assert get_subcommand(git, is_required=False) == git.init git = Git(_cli_parse_args=['clone', 'repo', '.', '--shared', 'true']) assert git.model_dump() == { @@ -2805,7 +2805,7 @@ class Git(BaseSettings): 'plugins': None, } assert get_subcommand(git) == git.clone - assert get_subcommand(git, is_required=True) == git.clone + assert get_subcommand(git, is_required=False) == git.clone git = Git(_cli_parse_args=['plugins', 'bar']) assert git.model_dump() == { @@ -2814,9 +2814,9 @@ class Git(BaseSettings): 'plugins': {'foo': None, 'bar': {'my_feature': False}}, } assert get_subcommand(git) == git.plugins - assert get_subcommand(git, is_required=True) == git.plugins + assert get_subcommand(git, is_required=False) == git.plugins assert get_subcommand(get_subcommand(git)) == git.plugins.bar - assert get_subcommand(get_subcommand(git), is_required=True) == git.plugins.bar + assert get_subcommand(get_subcommand(git), is_required=False) == git.plugins.bar def test_cli_union_similar_sub_models(): From b37974f27d84ec3b21dfc2b9d853a9dc2e28d572 Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Thu, 1 Aug 2024 21:04:40 -0600 Subject: [PATCH 08/12] Docs update. --- docs/index.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/index.md b/docs/index.md index b40496bd..8d36b76f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -714,7 +714,7 @@ annotations can only be applied to required fields (i.e. fields that do not have subcommands must be a valid type derived from either a pydantic `BaseModel` or pydantic.dataclasses `dataclass`. Parsed subcommands can be retrieved from model instances using the `get_subcommand` function. If a subcommand is -required, use the `is_required` flag to raise an error if no subcommand is found. +not required, set the `is_required` flag to `False` to disable raising an error if no subcommand is found. !!! note CLI settings subcommands are limited to a single subparser per model. In other words, all subcommands for a model @@ -755,7 +755,7 @@ class Plugins(BaseModel): def sub_main(self) -> BaseModel: """Run this method if selected and return fake plugin command""" - return get_subcommand(self, is_required=True).plugin_main() + return get_subcommand(self).plugin_main() class Clone(BaseModel): @@ -778,7 +778,7 @@ class Git(BaseSettings, cli_parse_args=True, cli_prog_name='git'): def main(self) -> BaseModel: """Run CLI main method and return sub cmd""" - return get_subcommand(self, is_required=True).sub_main() + return get_subcommand(self).sub_main() sys.argv = ['example.py', 'clone', 'repo', 'dest'] From 0c1bdf93247ea8afb9811f74b9c7742b867de571 Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Thu, 1 Aug 2024 21:17:36 -0600 Subject: [PATCH 09/12] Avoid #356 issue. --- pydantic_settings/sources.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index c5a82458..1c15e85b 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -162,7 +162,11 @@ def get_subcommand(model: BaseModel, is_required: bool = True, is_exit_on_error: model_cls = type(model) subcommands: list[str] = [] - fields = model_cls.__pydantic_fields__ if is_pydantic_dataclass(model_cls) else model.model_fields + fields = ( + model_cls.__pydantic_fields__ + if is_pydantic_dataclass(model_cls) and hasattr(model_cls, '__pydantic_fields__') + else model.model_fields + ) for field_name, field_info in fields.items(): if _CliSubCommand in field_info.metadata: if getattr(model, field_name) is not None: From 8bc5194b803af77d744024bc6f3cc8a76950e7e8 Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Thu, 1 Aug 2024 21:23:15 -0600 Subject: [PATCH 10/12] Lint. --- pydantic_settings/sources.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index 1c15e85b..bacad934 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -164,7 +164,7 @@ def get_subcommand(model: BaseModel, is_required: bool = True, is_exit_on_error: subcommands: list[str] = [] fields = ( model_cls.__pydantic_fields__ - if is_pydantic_dataclass(model_cls) and hasattr(model_cls, '__pydantic_fields__') + if hasattr(model_cls, '__pydantic_fields__') and is_pydantic_dataclass(model_cls) else model.model_fields ) for field_name, field_info in fields.items(): From 1c3d6ce227dc086c6cad733b57be30f6c8b84b56 Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Thu, 1 Aug 2024 21:38:41 -0600 Subject: [PATCH 11/12] Handle no subcommands case. --- pydantic_settings/sources.py | 8 ++++++-- tests/test_settings.py | 6 ++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index bacad934..db8a7e49 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -165,7 +165,7 @@ def get_subcommand(model: BaseModel, is_required: bool = True, is_exit_on_error: fields = ( model_cls.__pydantic_fields__ if hasattr(model_cls, '__pydantic_fields__') and is_pydantic_dataclass(model_cls) - else model.model_fields + else model_cls.model_fields ) for field_name, field_info in fields.items(): if _CliSubCommand in field_info.metadata: @@ -173,7 +173,11 @@ def get_subcommand(model: BaseModel, is_required: bool = True, is_exit_on_error: return getattr(model, field_name) subcommands.append(field_name) if is_required: - error_message = f'Error: CLI subcommand is required {{{", ".join(subcommands)}}}' + error_message = ( + f'Error: CLI subcommand is required {{{", ".join(subcommands)}}}' + if subcommands + else 'Error: CLI subcommand is required but no subcommands were found.' + ) raise SystemExit(error_message) if is_exit_on_error else SettingsError(error_message) return None diff --git a/tests/test_settings.py b/tests/test_settings.py index 9ca1eae7..ee4f61f6 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -2756,6 +2756,12 @@ class FooPlugin: class BarPlugin: my_feature: bool = False + bar = BarPlugin() + with pytest.raises(SystemExit, match='Error: CLI subcommand is required but no subcommands were found.'): + get_subcommand(bar) + with pytest.raises(SettingsError, match='Error: CLI subcommand is required but no subcommands were found.'): + get_subcommand(bar, is_exit_on_error=False) + @pydantic_dataclasses.dataclass class Plugins: foo: CliSubCommand[FooPlugin] From 28bed19035e8d458ff006e8051dbb08adbea068c Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Fri, 23 Aug 2024 09:38:09 -0600 Subject: [PATCH 12/12] Review updates. --- docs/index.md | 71 ++++++++++++++++-------------------- pydantic_settings/sources.py | 49 +++++++++++++------------ tests/test_settings.py | 21 ++++++++++- 3 files changed, 77 insertions(+), 64 deletions(-) diff --git a/docs/index.md b/docs/index.md index 5da9ce04..5e3cf281 100644 --- a/docs/index.md +++ b/docs/index.md @@ -747,7 +747,7 @@ Subcommands and positional arguments are expressed using the `CliSubCommand` and annotations can only be applied to required fields (i.e. fields that do not have a default value). Furthermore, subcommands must be a valid type derived from either a pydantic `BaseModel` or pydantic.dataclasses `dataclass`. -Parsed subcommands can be retrieved from model instances using the `get_subcommand` function. If a subcommand is +Parsed subcommands can be retrieved from model instances using the `get_subcommand` utility function. If a subcommand is not required, set the `is_required` flag to `False` to disable raising an error if no subcommand is found. !!! note @@ -762,66 +762,59 @@ not required, set the `is_required` flag to `False` to disable raising an error ```py import sys -from pydantic import BaseModel, Field +from pydantic import BaseModel from pydantic_settings import ( BaseSettings, CliPositionalArg, CliSubCommand, + SettingsError, get_subcommand, ) -class FooPlugin(BaseModel): - """git-plugins-foo - Extra deep foo plugin command""" - - feature: bool = Field(default=False, description='Enable feature') - - def plugin_main(self) -> BaseModel: - """Run this method if selected and return self""" - return self - - -class Plugins(BaseModel): - """git-plugins - Fake plugins for GIT""" - - foo: CliSubCommand[FooPlugin] = Field(description='Foo is fake plugin') - - def sub_main(self) -> BaseModel: - """Run this method if selected and return fake plugin command""" - return get_subcommand(self).plugin_main() +class Init(BaseModel): + directory: CliPositionalArg[str] class Clone(BaseModel): - """git-clone - Clone a repository into a new directory""" + repository: CliPositionalArg[str] + directory: CliPositionalArg[str] - repository: CliPositionalArg[str] = Field(description='The repo ...') - directory: CliPositionalArg[str] = Field(description='The dir ...') - local: bool = Field(default=False, description='When the repo ...') - def sub_main(self) -> BaseModel: - """Run this method if selected and return self""" - return self +class Git(BaseSettings, cli_parse_args=True, cli_exit_on_error=False): + clone: CliSubCommand[Clone] + init: CliSubCommand[Init] -class Git(BaseSettings, cli_parse_args=True, cli_prog_name='git'): - """git - The stupid content tracker""" +# Run without subcommands +sys.argv = ['example.py'] +cmd = Git() +assert cmd.model_dump() == {'clone': None, 'init': None} - clone: CliSubCommand[Clone] = Field(description='Clone a repo ...') - plugins: CliSubCommand[Plugins] = Field(description='Fake GIT plugins') +try: + # Will raise an error since no subcommand was provided + get_subcommand(cmd).model_dump() +except SettingsError as err: + assert str(err) == 'Error: CLI subcommand is required {clone, init}' - def main(self) -> BaseModel: - """Run CLI main method and return sub cmd""" - return get_subcommand(self).sub_main() +# Will not raise an error since subcommand is not required +assert get_subcommand(cmd, is_required=False) is None +# Run the clone subcommand sys.argv = ['example.py', 'clone', 'repo', 'dest'] -print(Git().main().model_dump()) -#> {'repository': 'repo', 'directory': 'dest', 'local': False} +cmd = Git() +assert cmd.model_dump() == { + 'clone': {'repository': 'repo', 'directory': 'dest'}, + 'init': None, +} -sys.argv = ['example.py', 'plugins', 'foo', '--feature', 'true'] -print(Git().main().model_dump()) -#> {'feature': True} +# Returns the subcommand model instance (in this case, 'clone') +assert get_subcommand(cmd).model_dump() == { + 'directory': 'dest', + 'repository': 'repo', +} ``` ### Customizing the CLI Experience diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index 8270b5de..cb96b4db 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -156,7 +156,7 @@ def error(self, message: str) -> NoReturn: CliExplicitFlag = Annotated[_CliBoolFlag, _CliExplicitFlag] -def get_subcommand(model: BaseModel, is_required: bool = True, is_exit_on_error: bool = True) -> Any: +def get_subcommand(model: BaseModel, is_required: bool = True, cli_exit_on_error: bool | None = None) -> Any: """ Get the subcommand from a model. @@ -164,38 +164,42 @@ def get_subcommand(model: BaseModel, is_required: bool = True, is_exit_on_error: model: The model to get the subcommand from. is_required: Determines whether a model must have subcommand set and raises error if not found. Defaults to `True`. - is_exit_on_error: Determines whether this function exits with error if no subcommand is found. - Defaults to `True`. + cli_exit_on_error: Determines whether this function exits with error if no subcommand is found. + Defaults to model_config `cli_exit_on_error` value if set. Otherwise, defaults to `True`. Returns: The subcommand model if found, otherwise `None`. Raises: - SystemExit: When no subcommand is found and is_required=`True` and is_exit_on_error=`True` + SystemExit: When no subcommand is found and is_required=`True` and cli_exit_on_error=`True` (the default). SettingsError: When no subcommand is found and is_required=`True` and - is_exit_on_error=`False`. + cli_exit_on_error=`False`. """ model_cls = type(model) + if cli_exit_on_error is None and is_model_class(model_cls): + model_default = model.model_config.get('cli_exit_on_error') + if isinstance(model_default, bool): + cli_exit_on_error = model_default + if cli_exit_on_error is None: + cli_exit_on_error = True + subcommands: list[str] = [] - fields = ( - model_cls.__pydantic_fields__ - if hasattr(model_cls, '__pydantic_fields__') and is_pydantic_dataclass(model_cls) - else model_cls.model_fields - ) - for field_name, field_info in fields.items(): + for field_name, field_info in _get_model_fields(model_cls).items(): if _CliSubCommand in field_info.metadata: if getattr(model, field_name) is not None: return getattr(model, field_name) subcommands.append(field_name) + if is_required: error_message = ( f'Error: CLI subcommand is required {{{", ".join(subcommands)}}}' if subcommands else 'Error: CLI subcommand is required but no subcommands were found.' ) - raise SystemExit(error_message) if is_exit_on_error else SettingsError(error_message) + raise SystemExit(error_message) if cli_exit_on_error else SettingsError(error_message) + return None @@ -806,11 +810,7 @@ class Cfg(BaseSettings): if type_has_key: return type_has_key elif is_model_class(annotation) or is_pydantic_dataclass(annotation): - fields = ( - annotation.__pydantic_fields__ - if is_pydantic_dataclass(annotation) and hasattr(annotation, '__pydantic_fields__') - else cast(BaseModel, annotation).model_fields - ) + fields = _get_model_fields(annotation) # `case_sensitive is None` is here to be compatible with the old behavior. # Has to be removed in V3. if (case_sensitive is None or case_sensitive) and fields.get(key): @@ -1419,12 +1419,7 @@ def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str, def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]]: positional_args, subcommand_args, optional_args = [], [], [] - fields = ( - model.__pydantic_fields__ - if hasattr(model, '__pydantic_fields__') and is_pydantic_dataclass(model) - else model.model_fields - ) - for field_name, field_info in fields.items(): + for field_name, field_info in _get_model_fields(model).items(): if _CliSubCommand in field_info.metadata: if not field_info.is_required(): raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has a default value') @@ -2133,3 +2128,11 @@ def _annotation_enum_name_to_val(annotation: type[Any] | None, name: Any) -> Any if name in tuple(val.name for val in type_): return type_[name] return None + + +def _get_model_fields(model_cls: type[Any]) -> dict[str, FieldInfo]: + if is_pydantic_dataclass(model_cls) and hasattr(model_cls, '__pydantic_fields__'): + return model_cls.__pydantic_fields__ + if is_model_class(model_cls): + return model_cls.model_fields + raise SettingsError(f'Error: {model_cls.__name__} is not subclass of BaseModel or pydantic.dataclasses.dataclass') diff --git a/tests/test_settings.py b/tests/test_settings.py index bef08f70..ed3dfd37 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -19,6 +19,7 @@ AliasChoices, AliasPath, BaseModel, + ConfigDict, DirectoryPath, Discriminator, Field, @@ -3101,7 +3102,7 @@ class BarPlugin: with pytest.raises(SystemExit, match='Error: CLI subcommand is required but no subcommands were found.'): get_subcommand(bar) with pytest.raises(SettingsError, match='Error: CLI subcommand is required but no subcommands were found.'): - get_subcommand(bar, is_exit_on_error=False) + get_subcommand(bar, cli_exit_on_error=False) @pydantic_dataclasses.dataclass class Plugins: @@ -3134,7 +3135,7 @@ class Git(BaseSettings): with pytest.raises(SystemExit, match='Error: CLI subcommand is required {clone, init, plugins}'): get_subcommand(git) with pytest.raises(SettingsError, match='Error: CLI subcommand is required {clone, init, plugins}'): - get_subcommand(git, is_exit_on_error=False) + get_subcommand(git, cli_exit_on_error=False) git = Git(_cli_parse_args=['init', '--quiet', 'true', 'dir/path']) assert git.model_dump() == { @@ -3165,6 +3166,22 @@ class Git(BaseSettings): assert get_subcommand(get_subcommand(git)) == git.plugins.bar assert get_subcommand(get_subcommand(git), is_required=False) == git.plugins.bar + class NotModel: ... + + with pytest.raises( + SettingsError, match='Error: NotModel is not subclass of BaseModel or pydantic.dataclasses.dataclass' + ): + get_subcommand(NotModel()) + + class NotSettingsConfigDict(BaseModel): + model_config = ConfigDict(cli_exit_on_error='not a bool') + + with pytest.raises(SystemExit, match='Error: CLI subcommand is required but no subcommands were found.'): + get_subcommand(NotSettingsConfigDict()) + + with pytest.raises(SettingsError, match='Error: CLI subcommand is required but no subcommands were found.'): + get_subcommand(NotSettingsConfigDict(), cli_exit_on_error=False) + def test_cli_union_similar_sub_models(): class ChildA(BaseModel):