Skip to content

Commit

Permalink
Merge pull request #187 from DanCardin/dc/one-or-more
Browse files Browse the repository at this point in the history
fix: Required positional arguments in native parser.
  • Loading branch information
DanCardin authored Nov 26, 2024
2 parents 98f5bfa + 98c80d5 commit 5b5c94b
Show file tree
Hide file tree
Showing 10 changed files with 144 additions and 36 deletions.
8 changes: 5 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

### 0.26.0

- Add `Default` object with associated fallback semantics for sequences of default handlers.
- Add `ValueFrom` for handling default_factory lazily, as well as arbitrary function dispatch.
- Add `State` as object accessible to invoke, Arg.parse, and ValueFrom.callable for sharing
- feat: Add `Default` object with associated fallback semantics for sequences of default handlers.
- feat: Add `ValueFrom` for handling default_factory lazily, as well as arbitrary function dispatch.
- feat: Add `State` as object accessible to invoke, Arg.parse, and ValueFrom.callable for sharing
state amongst different stages of argument parsing.
- fix: Skip non-init fields in dataclasses.
- fix: Required positional arguments in the native parser.
- fix: Infer num_args on unbounded sequence options (e.g. `list[list[str]]` annotation on an option).

## 0.25

Expand Down
44 changes: 21 additions & 23 deletions src/cappa/arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,9 @@ def normalize(
long = infer_long(self, type_view, field_name, default_long)
choices = infer_choices(self, type_view)
action = action or infer_action(self, type_view, long, default)
num_args = infer_num_args(self, type_view, field_name, action, long)
num_args = infer_num_args(
type_view, field_name, arg=self, action=action, long=long
)
required = infer_required(self, default)

parse = infer_parse(self, type_view, state=state)
Expand Down Expand Up @@ -525,20 +527,21 @@ def infer_action(


def infer_num_args(
arg: Arg,
type_view: TypeView,
field_name: str,
action: ArgActionType,
long,
arg: Arg | None = None,
action: ArgActionType | None = None,
long=None,
) -> int:
if arg.num_args is not None:
return arg.num_args
if arg:
if arg.num_args is not None:
return arg.num_args

if arg.parse:
return 1
if arg.parse:
return 1

if ArgAction.is_non_value_consuming(action):
return 0
if ArgAction.is_non_value_consuming(action):
return 0

if type_view.is_union:
# Recursively determine the `num_args` value of each variant. Use the value
Expand All @@ -550,11 +553,7 @@ def infer_num_args(
continue

num_args = infer_num_args(
arg,
type_arg,
field_name,
action,
long,
type_arg, field_name, arg=arg, action=action, long=long
)

distinct_num_args.add(num_args)
Expand All @@ -574,15 +573,17 @@ def infer_num_args(
f"On field '{field_name}', mismatch of arity between union variants. {invalid_kinds}."
)

is_positional = not arg.short and not long
if type_view.is_subclass_of((list, set)) and is_positional:
return -1

if type_view.is_tuple and not type_view.is_variadic_tuple:
return len(type_view.args)

if type_view.is_variadic_tuple and is_positional:
is_positional = arg is None or (not arg.short and not long)
is_sequence = type_view.is_variadic_tuple or type_view.is_subclass_of((list, set))
if is_positional and is_sequence:
return -1

# Options with outer-types as sequences should determine the num_args from the inner type.
if type_view.is_subclass_of((list, set, tuple)):
return infer_num_args(type_view.inner_types[0], field_name)
return 1


Expand Down Expand Up @@ -659,9 +660,6 @@ def infer_value_name(arg: Arg, field_name: str, num_args: int | None) -> str:
if arg.value_name is not Empty:
return arg.value_name

if num_args == -1:
return f"{field_name} ..."

if num_args and num_args > 1:
return " ".join([field_name] * num_args)

Expand Down
4 changes: 2 additions & 2 deletions src/cappa/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,14 +232,14 @@ def add_argument(
num_args = backend_num_args(arg.num_args)

kwargs: dict[str, typing.Any] = {
"dest": dest_prefix + arg.field_name,
"dest": dest_prefix + assert_type(arg.field_name, str),
"help": arg.help,
"metavar": arg.value_name,
"action": get_action(arg),
"default": argparse.SUPPRESS,
}

if not is_positional and arg.required:
if not is_positional and arg.required and assert_type(arg.num_args, int) >= 0:
kwargs["required"] = arg.required

if num_args is not None and not ArgAction.is_non_value_consuming(arg.action):
Expand Down
6 changes: 6 additions & 0 deletions src/cappa/help.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,16 @@ def format_arg_name(arg: Arg | Subcommand, delimiter, *, n=0) -> str:
if not arg.is_option:
arg_names = arg_names.upper()

if arg.num_args == -1:
arg_names = f"{arg_names} ..."

text = f"[cappa.arg]{arg_names}[/cappa.arg]"

if arg.is_option and has_value:
name = typing.cast(str, arg.value_name).upper()
if arg.num_args == -1:
name = f"{name} ..."

text = f"{text} [cappa.arg.name]{name}[/cappa.arg.name]"

if not arg.required:
Expand Down
11 changes: 9 additions & 2 deletions src/cappa/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,11 +597,18 @@ def consume_arg(
arg=arg,
)
else:
if orig_num_args > 0 and len(result) != orig_num_args:
missing_arg_requirement = (
# Positive fixed-arg amount
(orig_num_args > 0 and len(result) != orig_num_args)
# Unbounded but required arg
or (orig_num_args < 0 and arg.required and not result)
)
if missing_arg_requirement:
quoted_result = [f"'{r}'" for r in result]
names_str = arg.names_str("/")

message = f"Argument '{names_str}' requires {orig_num_args} values, found {len(result)}"
num_args_value = "at least one" if orig_num_args < 0 else orig_num_args
message = f"Argument '{names_str}' requires {num_args_value} values, found {len(result)}"
if quoted_result:
message += f" ({', '.join(quoted_result)} so far)"
raise BadArgumentError(
Expand Down
41 changes: 41 additions & 0 deletions tests/arg/test_num_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import annotations

from dataclasses import dataclass

from typing_extensions import Annotated

import cappa
from tests.utils import parse


def test_unbounded_list_option():
@dataclass
class Args:
unbounded: Annotated[list[list[int]], cappa.Arg(short=True)]

result = parse(Args, "-u", "0")
assert result == Args([[0]])

result = parse(Args, "-u", "0", "1", "2", "3", "4", "5")
assert result == Args([[0, 1, 2, 3, 4, 5]])


def test_unbounded_set_option():
@dataclass
class Args:
unbounded: Annotated[set[tuple[int, int]], cappa.Arg(short=True)]

result = parse(Args, "-u", "0", "1")
assert result == Args({(0, 1)})


def test_unbounded_tuple_option():
@dataclass
class Args:
unbounded: Annotated[tuple[list[int], ...], cappa.Arg(short=True)]

result = parse(Args, "-u", "0")
assert result == Args(([0],))

result = parse(Args, "-u", "0", "1", "2", "3", "4", "5")
assert result == Args(([0, 1, 2, 3, 4, 5],))
30 changes: 30 additions & 0 deletions tests/arg/test_required.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Union

Expand Down Expand Up @@ -74,3 +76,31 @@ class Example:

result = parse(Example, "-c", "c", backend=backend)
assert result == Example(c="c")


@backends
def test_required_unbounded_list(backend):
@dataclass
class Example:
c: Annotated[list[str], cappa.Arg(required=True)]

with pytest.raises(cappa.Exit) as e:
parse(Example, backend=backend)
assert "require" in str(e.value.message)

result = parse(Example, "c", backend=backend)
assert result == Example(["c"])


@backends
def test_required_option(backend):
@dataclass
class Example:
c: Annotated[list[str], cappa.Arg(short=True, required=True)]

with pytest.raises(cappa.Exit) as e:
parse(Example, backend=backend)
assert "the following arguments are required: -c" == str(e.value.message).lower()

result = parse(Example, "-c", "c", backend=backend)
assert result == Example(["c"])
14 changes: 9 additions & 5 deletions tests/help/test_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
def test_argument_name(capsys):
@dataclass
class Args:
name: Annotated[str, cappa.Arg(value_name="string-name", help="more")]
short: Annotated[str, cappa.Arg(short=True, value_name="optional-string")]
name: Annotated[str, cappa.Arg(value_name="sname", help="more")]
short: Annotated[str, cappa.Arg(short=True, value_name="ostr")]
unbounded: Annotated[list[list[str]], cappa.Arg(short=True, value_name="UNB")]
unbounded_pos: Annotated[list[str], cappa.Arg(value_name="upos", help="lots")]

with pytest.raises(cappa.HelpExit) as e:
parse(Args, "--help")
Expand All @@ -25,13 +27,15 @@ class Args:

assert out == textwrap.dedent(
"""\
Usage: args -s OPTIONAL-STRING STRING-NAME [-h] [--completion COMPLETION]
Usage: args -s OSTR -u UNB ... SNAME UPOS ... [-h] [--completion COMPLETION]
Options
-s OPTIONAL-STRING
-s OSTR
-u UNB ...
Arguments
STRING-NAME more
SNAME more
UPOS ... lots
Help
[-h, --help] Show this message and exit.
Expand Down
20 changes: 20 additions & 0 deletions tests/parser/test_unbounded_num_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import dataclass, field
from typing import Optional

import pytest
from typing_extensions import Annotated

import cappa
Expand Down Expand Up @@ -34,3 +35,22 @@ class Args:
# Or the -- separator can be used to terminate it
t1 = parse(Args, "-a", "1", "2", "3", "--", "foo", backend=backend)
assert t1 == Args(a=["1", "2", "3"], foo="foo")


@backends
def test_unbounded_positional_args(backend):
@dataclass
class Args:
a: list[str]

with pytest.raises(cappa.Exit) as e:
parse(Args, backend=backend)
error = str(e.value.message).lower()

if backend:
assert error == "the following arguments are required: a"
else:
assert error == "argument 'a' requires at least one values, found 0"

t1 = parse(Args, "a", backend=backend)
assert t1 == Args(["a"])
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 5b5c94b

Please sign in to comment.