Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Store Address on V2 Target and pass it to Fields during validation #9300

Merged
merged 1 commit into from
Mar 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 39 additions & 10 deletions src/python/pants/engine/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,29 @@

from typing_extensions import final

from pants.build_graph.address import Address
from pants.engine.rules import UnionMembership
from pants.util.meta import frozen_after_init


@frozen_after_init
@dataclass(unsafe_hash=True)
@dataclass(unsafe_hash=True) # type: ignore[misc] # MyPy doesn't like the abstract __init__()
class Field(ABC):
alias: ClassVar[str]
raw_value: Optional[Any] # None indicates that the field was not explicitly defined

# This is a little weird to have an abstract __init__(). We do this to ensure that all
# subclasses have this exact type signature for their constructor.
#
# Normally, with dataclasses, each constructor parameter would instead be specified via a
# dataclass field declaration. But, we don't want to declare `address` as an actual attribute
# because not all subclasses will need to store the value. Instead, we only care that the
# constructor accepts `address` so that the `Field` can use it in validation, and can
# optionally store the value if it wants to.
@abstractmethod
def __init__(self, raw_value: Optional[Any], *, address: Address) -> None:
pass

def __repr__(self) -> str:
return f"{self.__class__}(alias={repr(self.alias)}, raw_value={self.raw_value})"

Expand All @@ -40,7 +53,7 @@ class ZipSafe(PrimitiveField):
raw_value: Optional[bool]
value: bool

def hydrate(self) -> bool:
def hydrate(self, *, address: Address) -> bool:
if self.raw_value is None:
return True
return self.raw_value
Expand All @@ -49,12 +62,15 @@ def hydrate(self) -> bool:
value: Any

@final
def __init__(self, raw_value: Optional[Any]) -> None:
def __init__(self, raw_value: Optional[Any], *, address: Address) -> None:
self.raw_value = raw_value
self.value = self.hydrate()
# NB: we do not store the `address` as an attribute of the class. We only use the
# `address` parameter for eager validation of the field to be able to generate more
# helpful error messages.
self.value = self.hydrate(address=address)

@abstractmethod
def hydrate(self) -> Any:
def hydrate(self, *, address: Address) -> Any:
"""Convert `self.raw_value` into `self.value`.

You should perform any validation and/or hydration here. For example, you may want to check
Expand Down Expand Up @@ -108,6 +124,13 @@ def rules():
sources = await Get[SourcesResult](Sources, my_tgt.get(Sources))
"""

address: Address

@final
def __init__(self, raw_value: Optional[Any], *, address: Address) -> None:
self.raw_value = raw_value
self.address = address

def __str__(self) -> str:
return f"{self.alias}={repr(self.raw_value)}"

Expand All @@ -125,6 +148,7 @@ class Target(ABC):
core_fields: ClassVar[Tuple[Type[Field], ...]]

# These get calculated in the constructor
address: Address
plugin_fields: Tuple[Type[Field], ...]
field_values: Dict[Type[Field], Field]

Expand All @@ -133,8 +157,10 @@ def __init__(
self,
unhydrated_values: Dict[str, Any],
*,
address: Address,
union_membership: Optional[UnionMembership] = None,
) -> None:
self.address = address
self.plugin_fields = cast(
Tuple[Type[Field], ...],
(
Expand All @@ -149,13 +175,14 @@ def __init__(
for alias, value in unhydrated_values.items():
if alias not in aliases_to_field_types:
raise ValueError(
f"Unrecognized field `{alias}={value}` for target type `{self.alias}`."
f"Unrecognized field `{alias}={value}` for target {address} with target "
f"type `{self.alias}`."
)
field_type = aliases_to_field_types[alias]
self.field_values[field_type] = field_type(value)
self.field_values[field_type] = field_type(value, address=address)
# For undefined fields, mark the raw value as None.
for field_type in set(self.field_types) - set(self.field_values.keys()):
self.field_values[field_type] = field_type(raw_value=None)
self.field_values[field_type] = field_type(raw_value=None, address=address)

@final
@property
Expand All @@ -174,6 +201,7 @@ class PluginField:
def __repr__(self) -> str:
return (
f"{self.__class__}("
f"address={self.address},"
f"alias={repr(self.alias)}, "
f"core_fields={list(self.core_fields)}, "
f"plugin_fields={list(self.plugin_fields)}, "
Expand All @@ -183,7 +211,8 @@ def __repr__(self) -> str:

def __str__(self) -> str:
fields = ", ".join(str(field) for field in self.field_values.values())
return f"{self.alias}({fields})"
address = f"address=\"{self.address}\"{', ' if fields else ''}"
return f"{self.alias}({address}{fields})"

@final
def _find_registered_field_subclass(self, requested_field: Type[_F]) -> Optional[Type[_F]]:
Expand Down Expand Up @@ -234,7 +263,7 @@ class BoolField(PrimitiveField):
value: bool
default: ClassVar[bool]

def hydrate(self) -> bool:
def hydrate(self, *, address: Address) -> bool:
if self.raw_value is None:
return self.default
# TODO: consider type checking `raw_value` via `isinstance`. Here, we assume that it's
Expand Down
66 changes: 45 additions & 21 deletions src/python/pants/engine/target_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import pytest

from pants.build_graph.address import Address
from pants.engine.fs import EMPTY_DIRECTORY_DIGEST, PathGlobs, Snapshot
from pants.engine.rules import UnionMembership, rule
from pants.engine.selectors import Get
Expand All @@ -21,7 +22,7 @@ class HaskellGhcExtensions(PrimitiveField):
raw_value: Optional[List[str]]
value: List[str]

def hydrate(self) -> List[str]:
def hydrate(self, *, address: Address) -> List[str]:
if self.raw_value is None:
return []
# Add some arbitrary validation to test that hydration/validation works properly.
Expand All @@ -31,7 +32,7 @@ def hydrate(self) -> List[str]:
if bad_extensions:
raise ValueError(
f"All elements of `{self.alias}` must be prefixed by `Ghc`. Received "
f"{bad_extensions}."
f"{bad_extensions} for target {address}."
)
return self.raw_value

Expand All @@ -54,7 +55,10 @@ async def hydrate_haskell_sources(sources: HaskellSources) -> HaskellSourcesResu
# Validate after hydration
non_haskell_sources = [fp for fp in result.files if PurePath(fp).suffix != ".hs"]
if non_haskell_sources:
raise ValueError(f"Received non-Haskell sources in {sources.alias}: {non_haskell_sources}.")
raise ValueError(
f"Received non-Haskell sources in {sources.alias} for target {sources.address}: "
f"{non_haskell_sources}."
)
return HaskellSourcesResult(result)


Expand All @@ -65,19 +69,24 @@ class HaskellTarget(Target):

def test_invalid_fields_rejected() -> None:
with pytest.raises(ValueError) as exc:
HaskellTarget({"invalid_field": True})
assert "Unrecognized field `invalid_field=True` for target type `haskell`." in str(exc)
HaskellTarget({"invalid_field": True}, address=Address.parse(":lib"))
assert (
"Unrecognized field `invalid_field=True` for target //:lib with target type `haskell`."
in str(exc)
)


def test_get_primitive_field() -> None:
extensions = ["GhcExistentialQuantification"]
extensions_field = HaskellTarget({HaskellGhcExtensions.alias: extensions}).get(
HaskellGhcExtensions
)
extensions_field = HaskellTarget(
{HaskellGhcExtensions.alias: extensions}, address=Address.parse(":lib")
).get(HaskellGhcExtensions)
assert extensions_field.raw_value == extensions
assert extensions_field.value == extensions

default_extensions_field = HaskellTarget({}).get(HaskellGhcExtensions)
default_extensions_field = HaskellTarget({}, address=Address.parse(":default")).get(
HaskellGhcExtensions
)
assert default_extensions_field.raw_value is None
assert default_extensions_field.value == []

Expand All @@ -86,7 +95,9 @@ def test_get_async_field() -> None:
def hydrate_field(
*, raw_source_files: List[str], hydrated_source_files: Tuple[str, ...]
) -> HaskellSourcesResult:
sources_field = HaskellTarget({HaskellSources.alias: raw_source_files}).get(HaskellSources)
sources_field = HaskellTarget(
{HaskellSources.alias: raw_source_files}, address=Address.parse(":lib")
).get(HaskellSources)
assert sources_field.raw_value == raw_source_files
result: HaskellSourcesResult = run_rule(
hydrate_haskell_sources,
Expand Down Expand Up @@ -123,14 +134,15 @@ def hydrate_field(
with pytest.raises(ValueError) as exc:
hydrate_field(raw_source_files=["*.js"], hydrated_source_files=("not_haskell.js",))
assert "Received non-Haskell sources" in str(exc)
assert "//:lib" in str(exc)


def test_has_fields() -> None:
class UnrelatedField(BoolField):
alias: ClassVar = "unrelated"
default: ClassVar = False

tgt = HaskellTarget({})
tgt = HaskellTarget({}, address=Address.parse(":lib"))
assert tgt.has_fields([]) is True
assert tgt.has_fields([HaskellGhcExtensions]) is True
assert tgt.has_fields([UnrelatedField]) is False
Expand All @@ -140,9 +152,11 @@ class UnrelatedField(BoolField):
def test_primitive_field_hydration_is_eager() -> None:
with pytest.raises(ValueError) as exc:
HaskellTarget(
{HaskellGhcExtensions.alias: ["GhcExistentialQuantification", "DoesNotStartWithGhc"]}
{HaskellGhcExtensions.alias: ["GhcExistentialQuantification", "DoesNotStartWithGhc"]},
address=Address.parse(":bad_extension"),
)
assert "must be prefixed by `Ghc`" in str(exc)
assert "//:bad_extension" in str(exc)


def test_add_custom_fields() -> None:
Expand All @@ -152,13 +166,17 @@ class CustomField(BoolField):

union_membership = UnionMembership({HaskellTarget.PluginField: OrderedSet([CustomField])})
tgt_values = {CustomField.alias: True}
tgt = HaskellTarget(tgt_values, union_membership=union_membership)
tgt = HaskellTarget(
tgt_values, address=Address.parse(":lib"), union_membership=union_membership
)
assert tgt.field_types == (HaskellGhcExtensions, HaskellSources, CustomField)
assert tgt.core_fields == (HaskellGhcExtensions, HaskellSources)
assert tgt.plugin_fields == (CustomField,)
assert tgt.get(CustomField).value is True

default_tgt = HaskellTarget({}, union_membership=union_membership)
default_tgt = HaskellTarget(
{}, address=Address.parse(":default"), union_membership=union_membership
)
assert default_tgt.get(CustomField).value is False


Expand All @@ -177,16 +195,18 @@ class CustomHaskellGhcExtensions(HaskellGhcExtensions):
banned_extensions: ClassVar = ["GhcBanned"]
default_extensions: ClassVar = ["GhcCustomExtension"]

def hydrate(self) -> List[str]:
def hydrate(self, *, address: Address) -> List[str]:
# Ensure that we avoid certain problematic extensions and always use some defaults.
specified_extensions = super().hydrate()
specified_extensions = super().hydrate(address=address)
banned = [
extension
for extension in specified_extensions
if extension in self.banned_extensions
]
if banned:
raise ValueError(f"Banned extensions used for {self.alias}: {banned}.")
raise ValueError(
f"Banned extensions used for {self.alias} on target {address}: {banned}."
)
return [*specified_extensions, *self.default_extensions]

class CustomHaskellTarget(Target):
Expand All @@ -195,7 +215,9 @@ class CustomHaskellTarget(Target):
{*HaskellTarget.core_fields, CustomHaskellGhcExtensions} - {HaskellGhcExtensions}
)

custom_tgt = CustomHaskellTarget({HaskellGhcExtensions.alias: ["GhcNormalExtension"]})
custom_tgt = CustomHaskellTarget(
{HaskellGhcExtensions.alias: ["GhcNormalExtension"]}, address=Address.parse(":custom")
)

assert custom_tgt.has_fields([HaskellGhcExtensions]) is True
assert custom_tgt.has_fields([CustomHaskellGhcExtensions]) is True
Expand All @@ -204,7 +226,7 @@ class CustomHaskellTarget(Target):
# Ensure that subclasses not defined on a target are not accepted. This allows us to, for
# example, filter every target with `PythonSources` (or a subclass) and to ignore targets with
# only `Sources`.
normal_tgt = HaskellTarget({})
normal_tgt = HaskellTarget({}, address=Address.parse(":normal"))
assert normal_tgt.has_fields([HaskellGhcExtensions]) is True
assert normal_tgt.has_fields([CustomHaskellGhcExtensions]) is False

Expand All @@ -216,13 +238,15 @@ class CustomHaskellTarget(Target):

# Check custom default value
assert (
CustomHaskellTarget({}).get(HaskellGhcExtensions).value
CustomHaskellTarget({}, address=Address.parse(":default")).get(HaskellGhcExtensions).value
== CustomHaskellGhcExtensions.default_extensions
)

# Custom validation
with pytest.raises(ValueError) as exc:
CustomHaskellTarget(
{HaskellGhcExtensions.alias: CustomHaskellGhcExtensions.banned_extensions}
{HaskellGhcExtensions.alias: CustomHaskellGhcExtensions.banned_extensions},
address=Address.parse(":invalid"),
)
assert str(CustomHaskellGhcExtensions.banned_extensions) in str(exc)
assert "//:invalid" in str(exc)