Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add get_subcommand function. #341

Merged
merged 16 commits into from
Aug 27, 2024
122 changes: 35 additions & 87 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,9 @@ Subcommands and positional arguments are expressed using the `CliSubCommand` and
annotations can only be applied to required fields (i.e. fields that do not have a default value). Furthermore,
subcommands must be a valid type derived from either a pydantic `BaseModel` or pydantic.dataclasses `dataclass`.

Parsed subcommands can be retrieved from model instances using the `get_subcommand` utility function. If a subcommand is
not required, set the `is_required` flag to `False` to disable raising an error if no subcommand is found.

!!! note
CLI settings subcommands are limited to a single subparser per model. In other words, all subcommands for a model
are grouped under a single subparser; it does not allow for multiple subparsers with each subparser having its own
Expand All @@ -759,114 +762,59 @@ subcommands must be a valid type derived from either a pydantic `BaseModel` or p
```py
import sys

from pydantic import BaseModel, Field
from pydantic.dataclasses import dataclass
from pydantic import BaseModel

from pydantic_settings import (
BaseSettings,
CliPositionalArg,
CliSubCommand,
SettingsError,
get_subcommand,
)


@dataclass
class FooPlugin:
"""git-plugins-foo - Extra deep foo plugin command"""

x_feature: bool = Field(default=False, description='Enable "X" feature')


@dataclass
class BarPlugin:
"""git-plugins-bar - Extra deep bar plugin command"""

y_feature: bool = Field(default=False, description='Enable "Y" feature')


@dataclass
class Plugins:
"""git-plugins - Fake plugins for GIT"""

foo: CliSubCommand[FooPlugin] = Field(description='Foo is fake plugin')

bar: CliSubCommand[BarPlugin] = Field(description='Bar is fake plugin')
class Init(BaseModel):
directory: CliPositionalArg[str]


class Clone(BaseModel):
"""git-clone - Clone a repository into a new directory"""

repository: CliPositionalArg[str] = Field(description='The repo ...')

directory: CliPositionalArg[str] = Field(description='The dir ...')

local: bool = Field(default=False, description='When the repo ...')


class Git(BaseSettings, cli_parse_args=True, cli_prog_name='git'):
"""git - The stupid content tracker"""

clone: CliSubCommand[Clone] = Field(description='Clone a repo ...')

plugins: CliSubCommand[Plugins] = Field(description='Fake GIT plugins')


try:
sys.argv = ['example.py', '--help']
Git()
except SystemExit as e:
print(e)
#> 0
"""
usage: git [-h] {clone,plugins} ...
repository: CliPositionalArg[str]
directory: CliPositionalArg[str]

git - The stupid content tracker

options:
-h, --help show this help message and exit
class Git(BaseSettings, cli_parse_args=True, cli_exit_on_error=False):
clone: CliSubCommand[Clone]
init: CliSubCommand[Init]

subcommands:
{clone,plugins}
clone Clone a repo ...
plugins Fake GIT plugins
"""

# Run without subcommands
sys.argv = ['example.py']
cmd = Git()
assert cmd.model_dump() == {'clone': None, 'init': None}

try:
sys.argv = ['example.py', 'clone', '--help']
Git()
except SystemExit as e:
print(e)
#> 0
"""
usage: git clone [-h] [--local bool] [--shared bool] REPOSITORY DIRECTORY
# Will raise an error since no subcommand was provided
get_subcommand(cmd).model_dump()
except SettingsError as err:
assert str(err) == 'Error: CLI subcommand is required {clone, init}'

git-clone - Clone a repository into a new directory
# Will not raise an error since subcommand is not required
assert get_subcommand(cmd, is_required=False) is None

positional arguments:
REPOSITORY The repo ...
DIRECTORY The dir ...
options:
-h, --help show this help message and exit
--local bool When the repo ... (default: False)
"""

# Run the clone subcommand
sys.argv = ['example.py', 'clone', 'repo', 'dest']
cmd = Git()
assert cmd.model_dump() == {
'clone': {'repository': 'repo', 'directory': 'dest'},
'init': None,
}

try:
sys.argv = ['example.py', 'plugins', 'bar', '--help']
Git()
except SystemExit as e:
print(e)
#> 0
"""
usage: git plugins bar [-h] [--my_feature bool]
git-plugins-bar - Extra deep bar plugin command
options:
-h, --help show this help message and exit
--y_feature bool Enable "Y" feature (default: False)
"""
# Returns the subcommand model instance (in this case, 'clone')
assert get_subcommand(cmd).model_dump() == {
'directory': 'dest',
'repository': 'repo',
}
```

### Customizing the CLI Experience
Expand Down
2 changes: 2 additions & 0 deletions pydantic_settings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
SettingsError,
TomlConfigSettingsSource,
YamlConfigSettingsSource,
get_subcommand,
)
from .version import VERSION

Expand All @@ -38,6 +39,7 @@
'TomlConfigSettingsSource',
'YamlConfigSettingsSource',
'AzureKeyVaultSettingsSource',
'get_subcommand',
'__version__',
)

Expand Down
72 changes: 58 additions & 14 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,53 @@ def error(self, message: str) -> NoReturn:
CliExplicitFlag = Annotated[_CliBoolFlag, _CliExplicitFlag]


def get_subcommand(model: BaseModel, is_required: bool = True, cli_exit_on_error: bool | None = None) -> Any:
"""
Get the subcommand from a model.
Args:
model: The model to get the subcommand from.
is_required: Determines whether a model must have subcommand set and raises error if not
found. Defaults to `True`.
cli_exit_on_error: Determines whether this function exits with error if no subcommand is found.
Defaults to model_config `cli_exit_on_error` value if set. Otherwise, defaults to `True`.
Returns:
The subcommand model if found, otherwise `None`.
Raises:
SystemExit: When no subcommand is found and is_required=`True` and cli_exit_on_error=`True`
(the default).
SettingsError: When no subcommand is found and is_required=`True` and
cli_exit_on_error=`False`.
"""

model_cls = type(model)
if cli_exit_on_error is None and is_model_class(model_cls):
model_default = model.model_config.get('cli_exit_on_error')
if isinstance(model_default, bool):
cli_exit_on_error = model_default
if cli_exit_on_error is None:
cli_exit_on_error = True

subcommands: list[str] = []
for field_name, field_info in _get_model_fields(model_cls).items():
if _CliSubCommand in field_info.metadata:
if getattr(model, field_name) is not None:
return getattr(model, field_name)
subcommands.append(field_name)

if is_required:
error_message = (
f'Error: CLI subcommand is required {{{", ".join(subcommands)}}}'
if subcommands
else 'Error: CLI subcommand is required but no subcommands were found.'
)
raise SystemExit(error_message) if cli_exit_on_error else SettingsError(error_message)

return None


class EnvNoneType(str):
pass

Expand Down Expand Up @@ -763,11 +810,7 @@ class Cfg(BaseSettings):
if type_has_key:
return type_has_key
elif is_model_class(annotation) or is_pydantic_dataclass(annotation):
fields = (
annotation.__pydantic_fields__
if is_pydantic_dataclass(annotation) and hasattr(annotation, '__pydantic_fields__')
else cast(BaseModel, annotation).model_fields
)
fields = _get_model_fields(annotation)
# `case_sensitive is None` is here to be compatible with the old behavior.
# Has to be removed in V3.
if (case_sensitive is None or case_sensitive) and fields.get(key):
Expand Down Expand Up @@ -1376,12 +1419,7 @@ def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str,

def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]]:
positional_args, subcommand_args, optional_args = [], [], []
fields = (
model.__pydantic_fields__
if hasattr(model, '__pydantic_fields__') and is_pydantic_dataclass(model)
else model.model_fields
)
for field_name, field_info in fields.items():
for field_name, field_info in _get_model_fields(model).items():
if _CliSubCommand in field_info.metadata:
if not field_info.is_required():
raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has a default value')
Expand Down Expand Up @@ -1496,9 +1534,7 @@ def _add_parser_args(
sub_models: list[type[BaseModel]] = self._get_sub_models(model, field_name, field_info)
if _CliSubCommand in field_info.metadata:
if subparsers is None:
subparsers = self._add_subparsers(
parser, title='subcommands', dest=f'{arg_prefix}:subcommand', required=self.cli_enforce_required
)
subparsers = self._add_subparsers(parser, title='subcommands', dest=f'{arg_prefix}:subcommand')
self._cli_subcommands[f'{arg_prefix}:subcommand'] = [f'{arg_prefix}{field_name}']
else:
self._cli_subcommands[f'{arg_prefix}:subcommand'].append(f'{arg_prefix}{field_name}')
Expand Down Expand Up @@ -2095,5 +2131,13 @@ def _annotation_enum_name_to_val(annotation: type[Any] | None, name: Any) -> Any
return None


def _get_model_fields(model_cls: type[Any]) -> dict[str, FieldInfo]:
if is_pydantic_dataclass(model_cls) and hasattr(model_cls, '__pydantic_fields__'):
return model_cls.__pydantic_fields__
if is_model_class(model_cls):
return model_cls.model_fields
raise SettingsError(f'Error: {model_cls.__name__} is not subclass of BaseModel or pydantic.dataclasses.dataclass')


def _is_function(obj: Any) -> bool:
return inspect.isfunction(obj) or inspect.isbuiltin(obj) or inspect.isroutine(obj) or inspect.ismethod(obj)
44 changes: 44 additions & 0 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
AliasChoices,
AliasPath,
BaseModel,
ConfigDict,
DirectoryPath,
Discriminator,
Field,
Expand Down Expand Up @@ -59,6 +60,7 @@
CliSubCommand,
DefaultSettingsSource,
SettingsError,
get_subcommand,
)

try:
Expand Down Expand Up @@ -3095,6 +3097,12 @@ class FooPlugin:
class BarPlugin:
my_feature: bool = False

bar = BarPlugin()
with pytest.raises(SystemExit, match='Error: CLI subcommand is required but no subcommands were found.'):
get_subcommand(bar)
with pytest.raises(SettingsError, match='Error: CLI subcommand is required but no subcommands were found.'):
get_subcommand(bar, cli_exit_on_error=False)

@pydantic_dataclasses.dataclass
class Plugins:
foo: CliSubCommand[FooPlugin]
Expand All @@ -3116,26 +3124,62 @@ class Git(BaseSettings):
init: CliSubCommand[Init]
plugins: CliSubCommand[Plugins]

git = Git(_cli_parse_args=[])
assert git.model_dump() == {
'clone': None,
'init': None,
'plugins': None,
}
assert get_subcommand(git, is_required=False) is None
with pytest.raises(SystemExit, match='Error: CLI subcommand is required {clone, init, plugins}'):
get_subcommand(git)
with pytest.raises(SettingsError, match='Error: CLI subcommand is required {clone, init, plugins}'):
get_subcommand(git, cli_exit_on_error=False)

git = Git(_cli_parse_args=['init', '--quiet', 'true', 'dir/path'])
assert git.model_dump() == {
'clone': None,
'init': {'directory': 'dir/path', 'quiet': True, 'bare': False},
'plugins': None,
}
assert get_subcommand(git) == git.init
assert get_subcommand(git, is_required=False) == git.init

git = Git(_cli_parse_args=['clone', 'repo', '.', '--shared', 'true'])
assert git.model_dump() == {
'clone': {'repository': 'repo', 'directory': '.', 'local': False, 'shared': True},
'init': None,
'plugins': None,
}
assert get_subcommand(git) == git.clone
assert get_subcommand(git, is_required=False) == git.clone

git = Git(_cli_parse_args=['plugins', 'bar'])
assert git.model_dump() == {
'clone': None,
'init': None,
'plugins': {'foo': None, 'bar': {'my_feature': False}},
}
assert get_subcommand(git) == git.plugins
assert get_subcommand(git, is_required=False) == git.plugins
assert get_subcommand(get_subcommand(git)) == git.plugins.bar
assert get_subcommand(get_subcommand(git), is_required=False) == git.plugins.bar

class NotModel: ...

with pytest.raises(
SettingsError, match='Error: NotModel is not subclass of BaseModel or pydantic.dataclasses.dataclass'
):
get_subcommand(NotModel())

class NotSettingsConfigDict(BaseModel):
model_config = ConfigDict(cli_exit_on_error='not a bool')

with pytest.raises(SystemExit, match='Error: CLI subcommand is required but no subcommands were found.'):
get_subcommand(NotSettingsConfigDict())

with pytest.raises(SettingsError, match='Error: CLI subcommand is required but no subcommands were found.'):
get_subcommand(NotSettingsConfigDict(), cli_exit_on_error=False)


def test_cli_union_similar_sub_models():
Expand Down
Loading