Skip to content

Commit

Permalink
parametrize: expand grouped parametrizations. (#21027)
Browse files Browse the repository at this point in the history
This adds support for having parametrized fields within a parametrize
group.

Example:
```python
 __defaults__({
    (python_sources, python_tests): dict(
        **parametrize(
            "py310-compat",
            resolve="service-a",
            interpreter_constraints=[
                "CPython == 3.9.*",
                "CPython == 3.10.*",
            ]
        ),
        **parametrize(
            "py39-compat",
            resolve=parametrize(
                "service-b",
                "service-c",
                "service-d",
            ),
            interpreter_constraints=[
                "CPython == 3.9.*",
            ]
        )
    )
})
```

## Notice

Parametrize groups does not work well with defaults for `all`, as the
parametrized fields will be added to all targets and will result in
errors unless it's a field that all targets support.
  • Loading branch information
kaos authored Jun 19, 2024
1 parent d3dcd89 commit 3f17936
Show file tree
Hide file tree
Showing 10 changed files with 229 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,8 @@ python_test(
)
```
(Using `parametrize` on grouped fields is also supported. For instance, if there is two resolves to use for Python 3.10, these can be provided within the py310 group: `**parametrize("py310", interpreter_constraints=[">=3.10,<3.11"], resolve=parametrize("lock-b", "lock-c"))`.)
The targets' addresses will have `@key=value` at the end, as shown above. Run [`pants list dir:`](../project-introspection.mdx) in the directory of the parametrized target to see all parametrized target addresses, and [`pants peek dir:`](../project-introspection.mdx) to see all their metadata.
Generally, you can use the address without the `@` suffix as an alias to all the parametrized targets. For example, `pants test example:tests` will run all the targets in parallel. Use the more precise address if you only want to use one parameter value, e.g. `pants test example:tests@shell=bash`.
Expand Down
4 changes: 4 additions & 0 deletions docs/notes/2.23.x.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ for inspiration. To opt into the feature set the flag

### Backends

#### BUILD

Support for parametrizing grouped parametrizations. (e.g. `**parametrize(resolve=parametrize("a", "b"), ...)`). This works for `___defaults__` as well, as long as it is specified per target type rather than using `all`.

#### Docker

Docker inference is improved. Pants can now make inferences by target address for targets supporting `pants package`, and `file` targets can be included by filename. See the [documentation on Docker dependency inference](https://www.pantsbuild.org/2.23/docs/docker#dependency-inference-support) for details
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,9 @@ def generate_tgt(
normalized_proj_name = canonicalize_project_name(project_name)
tgt_overrides = overrides.pop(normalized_proj_name, {})
if Dependencies.alias in tgt_overrides:
tgt_overrides[Dependencies.alias] = list(tgt_overrides[Dependencies.alias]) + req_deps
tgt_overrides = tgt_overrides | {
Dependencies.alias: list(tgt_overrides[Dependencies.alias]) + req_deps
}

return PythonRequirementTarget(
{
Expand Down
64 changes: 64 additions & 0 deletions src/python/pants/engine/internals/build_files_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,70 @@ def test_default_parametrized_groups(target_adaptor_rule_runner: RuleRunner) ->
)


def test_default_parametrized_groups_with_parametrizations(
target_adaptor_rule_runner: RuleRunner,
) -> None:
target_adaptor_rule_runner.write_files(
{
"src/BUILD": dedent(
"""
__defaults__({
mock_tgt: dict(
**parametrize(
"py310-compat",
resolve="service-a",
tags=[
"CPython == 3.9.*",
"CPython == 3.10.*",
]
),
**parametrize(
"py39-compat",
resolve=parametrize(
"service-b",
"service-c",
"service-d",
),
tags=[
"CPython == 3.9.*",
]
)
)
})
mock_tgt()
"""
),
}
)
address = Address("src")
target_adaptor = target_adaptor_rule_runner.request(
TargetAdaptor,
[TargetAdaptorRequest(address, description_of_origin="tests")],
)
targets = tuple(Parametrize.expand(address, target_adaptor.kwargs))
assert targets == (
(
address.parametrize(dict(parametrize="py310-compat")),
dict(
tags=("CPython == 3.9.*", "CPython == 3.10.*"),
resolve="service-a",
),
),
(
address.parametrize(dict(parametrize="py39-compat", resolve="service-b")),
dict(tags=("CPython == 3.9.*",), resolve="service-b"),
),
(
address.parametrize(dict(parametrize="py39-compat", resolve="service-c")),
dict(tags=("CPython == 3.9.*",), resolve="service-c"),
),
(
address.parametrize(dict(parametrize="py39-compat", resolve="service-d")),
dict(tags=("CPython == 3.9.*",), resolve="service-d"),
),
)


def test_augment_target_field_defaults(target_adaptor_rule_runner: RuleRunner) -> None:
target_adaptor_rule_runner.write_files(
{
Expand Down
5 changes: 3 additions & 2 deletions src/python/pants/engine/internals/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
FrozenSet,
Iterable,
Iterator,
Mapping,
NamedTuple,
Sequence,
Type,
Expand Down Expand Up @@ -275,7 +276,7 @@ async def _parametrized_target_generators_with_templates(
target_type: type[TargetGenerator],
generator_fields: dict[str, Any],
union_membership: UnionMembership,
) -> list[tuple[TargetGenerator, dict[str, Any]]]:
) -> list[tuple[TargetGenerator, Mapping[str, Any]]]:
# Pre-load field values from defaults for the target type being generated.
if hasattr(target_type, "generated_target_cls"):
family = await Get(AddressFamily, AddressFamilyDir(address.spec_path))
Expand Down Expand Up @@ -468,7 +469,7 @@ def _create_target(
address: Address,
target_type: type[_TargetType],
target_adaptor: TargetAdaptor,
field_values: dict[str, Any],
field_values: Mapping[str, Any],
union_membership: UnionMembership,
name_explicitly_set: bool | None = None,
) -> _TargetType:
Expand Down
102 changes: 73 additions & 29 deletions src/python/pants/engine/internals/parametrize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import dataclasses
import itertools
import operator
from collections import defaultdict
from collections import abc, defaultdict
from dataclasses import dataclass
from enum import Enum
from functools import reduce
Expand Down Expand Up @@ -140,32 +140,42 @@ def group_name(self) -> str:

@classmethod
def expand(
cls, address: Address, fields: dict[str, Any | Parametrize]
) -> Iterator[tuple[Address, dict[str, Any]]]:
cls,
address: Address,
fields: Mapping[str, Any | Parametrize],
) -> Iterator[tuple[Address, Mapping[str, Any]]]:
"""Produces the cartesian product of fields for the given possibly-Parametrized fields.
Only one level of expansion is performed: if individual field values might also contain
Parametrize instances (in particular: an `overrides` field), expanding those will require
separate calls.
Parametrized groups are expanded however (that is: any `parametrize` field values in a
`**parametrize()` group are also expanded).
"""
try:
parametrizations = cls._collect_parametrizations(fields)
cls._check_parametrizations(parametrizations)
parametrized: list[list[tuple[str, str, Any]]] = [
[
(field_name, alias, field_value)
for alias, field_value in v.to_parameters().items()
]
for field_name, v in parametrizations.get(None, ())
]
parametrized_groups: list[tuple[str, str, Parametrize]] = [
("parametrize", group_name, vs[0][1])
for group_name, vs in parametrizations.items()
if group_name is not None
]
yield from cls._expand(address, fields)
except Exception as e:
raise Exception(f"Failed to parametrize `{address}`:\n {e}") from e

@classmethod
def _expand(
cls,
address: Address,
fields: Mapping[str, Any | Parametrize],
_parametrization_group_prefix: str = "",
) -> Iterator[tuple[Address, Mapping[str, Any]]]:
parametrizations = cls._collect_parametrizations(fields)
cls._check_parametrizations(parametrizations)
parametrized: list[list[tuple[str, str, Any]]] = [
[(field_name, alias, field_value) for alias, field_value in v.to_parameters().items()]
for field_name, v in parametrizations.get(None, ())
]
parametrized_groups: list[tuple[str, str, Parametrize]] = [
("parametrize", (_parametrization_group_prefix + group_name), vs[0][1])
for group_name, vs in parametrizations.items()
if group_name is not None
]
parameters = address.parameters
non_parametrized = tuple(
(field_name, field_value)
Expand Down Expand Up @@ -193,17 +203,24 @@ def expand(
field_name: alias for field_name, alias, _ in parametrized_args
}
# There will be at most one group per cross product.
group_kwargs: Mapping[str, Any] = next(
parametrize_group: Parametrize | None = next(
(
field_value.kwargs
field_value
for _, _, field_value in parametrized_args
if isinstance(field_value, Parametrize) and field_value.is_group
),
{},
None,
)
# Exclude fields from parametrize group from address parameters.
for k in group_kwargs.keys() & parameters.keys():
expanded_parameters.pop(k, None)
if parametrize_group is not None:
# Exclude fields from parametrize group from address parameters.
for k in parametrize_group.kwargs.keys() & parameters.keys():
expanded_parameters.pop(k, None)
expand_recursively = any(
isinstance(group_value, Parametrize)
for group_value in parametrize_group.kwargs.values()
)
else:
expand_recursively = False

parametrized_args_fields = tuple(
(field_name, field_value)
Expand All @@ -212,14 +229,36 @@ def expand(
if not (isinstance(field_value, Parametrize) and field_value.is_group)
)
expanded_fields: dict[str, Any] = dict(non_parametrized + parametrized_args_fields)
expanded_fields.update(group_kwargs)

expanded_address = address.parametrize(expanded_parameters, replace=True)
yield expanded_address, expanded_fields

if expand_recursively:
assert parametrize_group is not None # Type narrowing to satisfy mypy.
# Expand nested parametrize within a parametrized group.
for grouped_address, grouped_fields in cls._expand(
expanded_address,
parametrize_group.kwargs,
_parametrization_group_prefix=_parametrization_group_prefix
+ parametrize_group.group_name
+ "-",
):
cls._check_conflicting(
{
name
for name in grouped_fields.keys()
if isinstance(fields.get(name), Parametrize)
}
)
yield expanded_address.parametrize(
grouped_address.parameters
), expanded_fields | dict(grouped_fields)
else:
if parametrize_group is not None:
expanded_fields |= dict(parametrize_group.kwargs)
yield expanded_address, expanded_fields

@staticmethod
def _collect_parametrizations(
fields: dict[str, Any | Parametrize]
fields: Mapping[str, Any | Parametrize]
) -> Mapping[str | None, list[tuple[str, Parametrize]]]:
parametrizations = defaultdict(list)
for field_name, v in fields.items():
Expand All @@ -246,13 +285,18 @@ def _check_parametrizations(
if group_name is not None
for field_name in groups[0][1].kwargs.keys()
}
conflicting = parametrize_field_names.intersection(parametrize_field_names_from_groups)
Parametrize._check_conflicting(
parametrize_field_names.intersection(parametrize_field_names_from_groups)
)

@staticmethod
def _check_conflicting(conflicting: abc.Collection[str]) -> None:
if conflicting:
raise ValueError(
softwrap(
f"""
Conflicting parametrizations for {pluralize(len(conflicting), "field", include_count=False)}:
{', '.join(sorted(conflicting))}
{', '.join(map(repr, sorted(conflicting)))}
"""
)
)
Expand Down
51 changes: 48 additions & 3 deletions src/python/pants/engine/internals/parametrize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,20 @@ def test_bad_group_name(exception_str: str, args: list[Any], kwargs: dict[str, A
**Parametrize("B", f1=3, f2=4),
},
),
(
[
("a@parametrize=A,x=x1", {"f0": "c", "f1": 1, "f2": 2, "x": 1}),
("a@parametrize=A,x=x2", {"f0": "c", "f1": 1, "f2": 2, "x": 2}),
("a@parametrize=B,x=x1", {"f0": "c", "f1": 3, "f2": 4, "x": 1}),
("a@parametrize=B,x=x2", {"f0": "c", "f1": 3, "f2": 4, "x": 2}),
],
{
"f0": "c",
"x": Parametrize(x1=1, x2=2),
**Parametrize("A", f1=1, f2=2),
**Parametrize("B", f1=3, f2=4),
},
),
(
[
("a@parametrize=A", {"f": 1}),
Expand All @@ -113,6 +127,23 @@ def test_bad_group_name(exception_str: str, args: list[Any], kwargs: dict[str, A
**Parametrize("C", g=[]),
),
),
(
# Nested parametrization groups with parametrize!
[
("a@c=sub_c3,parametrize=root-sub2", {"a": 2, "b": 0, "c": "val3"}),
("a@c=val2,parametrize=root-sub2", {"a": 2, "b": 0, "c": "val2"}),
("a@parametrize=root-sub1", {"a": 1, "b": 1}),
],
dict( # type: ignore[arg-type]
b=0,
**Parametrize( # type: ignore[arg-type]
"root",
a=1,
**Parametrize("sub1", b=1),
**Parametrize("sub2", a=2, c=Parametrize("val2", sub_c3="val3")),
),
),
),
],
)
def test_expand(
Expand All @@ -123,7 +154,6 @@ def test_expand(
(address.spec, result_fields)
for address, result_fields in Parametrize.expand(Address("a"), fields)
),
key=lambda value: value[0],
)


Expand Down Expand Up @@ -176,7 +206,6 @@ def test_expand_existing_parameters(
Address("a", parameters=parameters), fields
)
),
key=lambda value: value[0],
)


Expand All @@ -193,7 +222,23 @@ def test_expand_existing_parameters(
**Parametrize("A", f=1), # type: ignore[arg-type]
**Parametrize("B", g=2, x=3),
),
"Failed to parametrize `a:a`:\n Conflicting parametrizations for fields: f, g",
"Failed to parametrize `a:a`:\n Conflicting parametrizations for fields: 'f', 'g'",
),
(
dict(
f=Parametrize("x", "y"),
g=Parametrize("x", "y"),
h=Parametrize("x", "y"),
x=5,
z=6,
**Parametrize(
"root",
**Parametrize("A", f=1, h=4), # type: ignore[arg-type]
**Parametrize("B", g=2, x=3),
),
),
# We only catch fields from one nested group at a time.
"Failed to parametrize `a:a`:\n Conflicting parametrizations for fields: 'f', 'h'",
),
(
dict(
Expand Down
Loading

0 comments on commit 3f17936

Please sign in to comment.