Add support for hydrating sources with the Target API (#9306)
This fleshes out how we use `AsyncField`s, which are much more complex than `PrimitiveField`s.

## Result

```python
sources = await Get[SourcesResult](SourcesRequest, tgt.get(Sources).request)
print(sources.snapshot.files)
```

This also works:

```python
if tgt.has_fields([PythonLibrarySources]):
  sources1 = await Get[SourcesResult](SourcesRequest, tgt.get(PythonLibrarySources).request)
  sources2 = await Get[SourcesResult](SourcesRequest, tgt.get(Sources).request)
  assert sources1 == sources2
```

`PythonSources` and its subclasses will validate that all resulting files end in `.py` (new behavior). `PythonLibrarySources` and `PythonTestsSources` will use the previous default globs. `PythonBinarySources` will enforce that `sources` has 0 or 1 files (previous behavior).

## Solution

### Ensuring support for subclassed `AsyncField`s

With the Target API, we allow new targets to subclass `Field`s for custom behavior. For example, `PythonLibrarySources` might use the default globs of `*.py` whereas `PythonTestsSources` might use the default globs of `test_*.py`.

To allow these custom subclasses of `Field`s, we added support in #9286 for substituting the subclass for the original parent class. For example, `my_python_library.get(Sources) == my_python_library.get(PythonSources) == my_python_library.get(PythonLibrarySources)`.
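The substitution described above can be sketched in plain Python. This is a minimal illustration and not the actual Pants implementation: the `Field` and `Target` classes below are simplified stand-ins, and the `issubclass`-based lookup is an assumption about how such a `get()` could work.

```python
from typing import Any, Dict, Iterable, Type


class Field:
    """Simplified stand-in for the Target API's `Field` (hypothetical)."""

    def __init__(self, value: Any) -> None:
        self.value = value


class Sources(Field):
    pass


class PythonSources(Sources):
    pass


class PythonLibrarySources(PythonSources):
    pass


class Target:
    """Hypothetical target that stores one instance per registered field type."""

    def __init__(self, *fields: Field) -> None:
        self._fields: Dict[Type[Field], Field] = {type(f): f for f in fields}

    def get(self, field_type: Type[Field]) -> Field:
        # Try an exact match first, then fall back to any registered subclass,
        # so asking for `Sources` returns the `PythonLibrarySources` instance.
        if field_type in self._fields:
            return self._fields[field_type]
        for registered_type, field in self._fields.items():
            if issubclass(registered_type, field_type):
                return field
        raise KeyError(f"No field of type {field_type.__name__} is registered.")

    def has_fields(self, field_types: Iterable[Type[Field]]) -> bool:
        # True only if every requested type (or a subclass of it) is registered.
        return all(
            any(issubclass(registered, requested) for registered in self._fields)
            for requested in field_types
        )


my_python_library = Target(PythonLibrarySources(("a.py",)))
assert my_python_library.get(Sources) is my_python_library.get(PythonLibrarySources)
```

Because the lookup returns the same instance whichever class is requested, callers that only know about `Sources` still get the subclass's custom behavior.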

This works great with `PrimitiveField` but is tricky to implement with `AsyncField` due to the engine not supporting subclasses.

Originally, I tried to achieve this extensibility through a union, which would allow the engine to have multiple ways to get a common result type like `SourcesResult`. However, this introduced multiple paths in the rule graph for computing the same product, e.g. `Sources->SourcesResult`, `PythonSources->SourcesResult`, etc., which made the graph ambiguous.

Instead, each `AsyncField` should define a `Request` dataclass that simply wraps the underlying `AsyncField`. This leaves only one path in the rule graph, `SourcesRequest -> SourcesResult`, while still allowing custom behavior through the field that the `SourcesRequest` wraps. The hydration rule then calls standardized extension points provided by that field.

**This means that the onus is on the `AsyncField` author to expose certain entry points for customizing the field's behavior.** For example, `Sources` defines the entry points of `default_globs` and `validate_snapshot()`. `Dependencies` might provide entry points like `inject_dependencies()` and `validate_dependencies()` (these are only possibilities, not a commitment).
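The request-wrapper pattern can be sketched without the engine. Everything below is a simplified assumption for illustration: `Snapshot` is reduced to a tuple of paths, `hydrate_sources()` is a plain function standing in for the engine rule, `available_files` replaces real filesystem access, and glob matching uses `fnmatch` rather than the engine's glob handling.

```python
from dataclasses import dataclass
from fnmatch import fnmatchcase
from pathlib import PurePath
from typing import ClassVar, Iterable, Optional, Tuple


@dataclass(frozen=True)
class Snapshot:
    """Stand-in for the engine's Snapshot: only the matched file paths."""

    files: Tuple[str, ...]


class Sources:
    # Entry point 1: globs to use when the `sources` field was left unset.
    default_globs: ClassVar[Optional[Tuple[str, ...]]] = None

    def __init__(self, raw_value: Optional[Iterable[str]], *, address: str) -> None:
        self.raw_value = tuple(raw_value) if raw_value is not None else None
        self.address = address

    # Entry point 2: subclasses override this and raise on invalid snapshots.
    def validate_snapshot(self, snapshot: Snapshot) -> None:
        pass

    @property
    def request(self) -> "SourcesRequest":
        return SourcesRequest(self)


@dataclass(frozen=True)
class SourcesRequest:
    """The single request type; it only wraps the underlying field."""

    field: Sources


@dataclass(frozen=True)
class SourcesResult:
    snapshot: Snapshot


class PythonSources(Sources):
    def validate_snapshot(self, snapshot: Snapshot) -> None:
        non_python_files = [fp for fp in snapshot.files if PurePath(fp).suffix != ".py"]
        if non_python_files:
            raise ValueError(f"Target {self.address} has non-Python sources: {non_python_files}")


class PythonLibrarySources(PythonSources):
    default_globs = ("*.py",)


def hydrate_sources(request: SourcesRequest, available_files: Iterable[str]) -> SourcesResult:
    """One code path for every `Sources` subclass; behavior varies via entry points."""
    field = request.field
    globs = field.raw_value if field.raw_value is not None else (field.default_globs or ())
    files = tuple(
        sorted(fp for fp in available_files if any(fnmatchcase(fp, glob) for glob in globs))
    )
    snapshot = Snapshot(files)
    field.validate_snapshot(snapshot)
    return SourcesResult(snapshot)
```

Only one function ever maps `SourcesRequest` to `SourcesResult`, so a rule graph modeled this way has a single path to the product; the subclass contributes nothing but its `default_globs` and `validate_snapshot()` overrides.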

While this approach has lots of boilerplate and less extensibility than `PrimitiveField`s, it solves the graph ambiguity and still allows for subclassing an `AsyncField`.

### Fixing `__eq__` for `Field`s

The previous naive implementation resulted in `Field`s only comparing their classvar `alias`, rather than their actual underlying values. This meant that the engine would cache values when it should not have.
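The caching problem can be reproduced in isolation. The sketch below is an illustration, not the previous Pants code: `NaiveField` and `FixedField` are hypothetical classes, and a plain dict stands in for the engine's memoization.

```python
from typing import Any, ClassVar, Dict


class NaiveField:
    alias: ClassVar[str] = "sources"

    def __init__(self, value: Any) -> None:
        self.value = value

    # Compares only the class-level alias, ignoring the instance's value.
    def __eq__(self, other: object) -> bool:
        if not isinstance(other, NaiveField):
            return NotImplemented
        return self.alias == other.alias

    def __hash__(self) -> int:
        return hash(self.alias)


class FixedField:
    alias: ClassVar[str] = "sources"

    def __init__(self, value: Any) -> None:
        self.value = value

    # Includes the underlying value, so different values are different keys.
    def __eq__(self, other: object) -> bool:
        if not isinstance(other, FixedField):
            return NotImplemented
        return (self.alias, self.value) == (other.alias, other.value)

    def __hash__(self) -> int:
        return hash((self.alias, self.value))


# A dict keyed on the field plays the role of a memoizing cache.
naive_cache: Dict[NaiveField, str] = {NaiveField(("a.py",)): "snapshot of a.py"}
fixed_cache: Dict[FixedField, str] = {FixedField(("a.py",)): "snapshot of a.py"}

# A field with different sources wrongly hits the entry for `a.py`.
stale = naive_cache.get(NaiveField(("b.py",)))
# With value-based equality the lookup misses, forcing a fresh computation.
fresh = fixed_cache.get(FixedField(("b.py",)))
```

Any cache keyed on equality and hashing behaves this way, which is why value-based `__eq__` and `__hash__` matter for inputs to a memoizing engine.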

This tweaks how we use dataclasses to ensure that the engine works correctly with `AsyncField`s.
Eric-Arellano authored Mar 16, 2020
1 parent 747d637 commit 50eb697
Showing 5 changed files with 239 additions and 44 deletions.
22 changes: 11 additions & 11 deletions src/python/pants/backend/python/rules/targets.py
@@ -5,13 +5,14 @@
 from typing import Any, ClassVar, Optional

 from pants.build_graph.address import Address
+from pants.engine.fs import Snapshot
 from pants.engine.objects import union
 from pants.engine.target import (
     COMMON_TARGET_FIELDS,
     BoolField,
+    ImmutableValue,
     PrimitiveField,
     Sources,
-    SourcesResult,
     StringField,
     StringOrStringSequenceField,
     Target,
@@ -20,12 +21,11 @@

 @union
 class PythonSources(Sources):
-    @classmethod
-    def validate_result(cls, result: SourcesResult) -> None:
-        non_python_files = [fp for fp in result.snapshot.files if not PurePath(fp).suffix == ".py"]
+    def validate_snapshot(self, snapshot: Snapshot) -> None:
+        non_python_files = [fp for fp in snapshot.files if not PurePath(fp).suffix == ".py"]
         if non_python_files:
             raise ValueError(
-                f"Target {result.address} has non-Python sources in its `sources` field: "
+                f"Target {self.address} has non-Python sources in its `sources` field: "
                 f"{non_python_files}"
             )

@@ -39,13 +39,13 @@ class PythonTestsSources(PythonSources):


 class PythonBinarySources(PythonSources):
-    @classmethod
-    def validate_result(cls, result: SourcesResult) -> None:
-        super().validate_result(result)
-        if len(result.snapshot.files) not in [0, 1]:
+    def validate_snapshot(self, snapshot: Snapshot) -> None:
+        super().validate_snapshot(snapshot)
+        if len(snapshot.files) not in [0, 1]:
             raise ValueError(
                 "Binary targets must have only 0 or 1 source files. Any additional files should "
-                "be put in a `python_library` which is added to `dependencies`"
+                f"be put in a `python_library` which is added to `dependencies`. The target "
+                f"{self.address} had {len(snapshot.files)} sources: {snapshot.files}."
             )


@@ -65,7 +65,7 @@ class Compatibility(StringOrStringSequenceField):
 class Provides(PrimitiveField):
     alias: ClassVar = "provides"

-    def hydrate(self, raw_value: Optional[Any], *, address: Address) -> Any:
+    def hydrate(self, raw_value: Optional[Any], *, address: Address) -> ImmutableValue:
         return raw_value


68 changes: 67 additions & 1 deletion src/python/pants/backend/python/rules/targets_test.py
@@ -3,8 +3,19 @@

 import pytest

-from pants.backend.python.rules.targets import Timeout
+from pants.backend.python.rules.targets import (
+    PythonBinarySources,
+    PythonLibrarySources,
+    PythonSources,
+    PythonTestsSources,
+    Timeout,
+)
 from pants.build_graph.address import Address
+from pants.engine.rules import RootRule
+from pants.engine.scheduler import ExecutionError
+from pants.engine.target import SourcesRequest, SourcesResult
+from pants.engine.target import rules as target_rules
+from pants.testutil.test_base import TestBase


 def test_timeout_validation() -> None:
@@ -13,3 +24,58 @@ def test_timeout_validation() -> None:
     with pytest.raises(ValueError):
         Timeout(0, address=Address.parse(":tests"))
     assert Timeout(5, address=Address.parse(":tests")).value == 5
+
+
+class TestPythonSources(TestBase):
+    PYTHON_SRC_FILES = ("f1.py", "f2.py")
+    PYTHON_TEST_FILES = ("conftest.py", "test_f1.py", "f1_test.py")
+
+    @classmethod
+    def rules(cls):
+        return [*target_rules(), RootRule(SourcesRequest)]
+
+    def test_python_sources_validation(self) -> None:
+        files = ("f.js", "f.hs", "f.txt", "f.py")
+        self.create_files(path="", files=files)
+        sources = PythonSources(files, address=Address.parse(":lib"))
+        assert sources.sanitized_raw_value == files
+        with pytest.raises(ExecutionError) as exc:
+            self.request_single_product(SourcesResult, sources.request)
+        assert "non-Python sources" in str(exc)
+        assert "f.hs" in str(exc)
+
+        # Also check that we support valid sources
+        valid_sources = PythonSources(["f.py"], address=Address.parse(":lib"))
+        assert valid_sources.sanitized_raw_value == ("f.py",)
+        assert self.request_single_product(SourcesResult, valid_sources.request).snapshot.files == (
+            "f.py",
+        )
+
+    def test_python_binary_sources_validation(self) -> None:
+        self.create_files(path="", files=["f1.py", "f2.py"])
+        address = Address.parse(":binary")
+
+        zero_sources = PythonBinarySources(None, address=address)
+        assert self.request_single_product(SourcesResult, zero_sources.request).snapshot.files == ()
+
+        one_source = PythonBinarySources(["f1.py"], address=address)
+        assert self.request_single_product(SourcesResult, one_source.request).snapshot.files == (
+            "f1.py",
+        )
+
+        multiple_sources = PythonBinarySources(["f1.py", "f2.py"], address=address)
+        with pytest.raises(ExecutionError) as exc:
+            self.request_single_product(SourcesResult, multiple_sources.request)
+        assert "//:binary had 2 sources" in str(exc)
+
+    def test_python_library_sources_default_globs(self) -> None:
+        self.create_files(path="", files=[*self.PYTHON_SRC_FILES, *self.PYTHON_TEST_FILES])
+        sources = PythonLibrarySources(None, address=Address.parse(":lib"))
+        result = self.request_single_product(SourcesResult, sources.request)
+        assert result.snapshot.files == self.PYTHON_SRC_FILES
+
+    def test_python_tests_sources_default_globs(self) -> None:
+        self.create_files(path="", files=[*self.PYTHON_SRC_FILES, *self.PYTHON_TEST_FILES])
+        sources = PythonTestsSources(None, address=Address.parse(":tests"))
+        result = self.request_single_product(SourcesResult, sources.request)
+        assert set(result.snapshot.files) == set(self.PYTHON_TEST_FILES)