diff --git a/src/python/pants/base/specs.py b/src/python/pants/base/specs.py index 88f97128039..070279692ca 100644 --- a/src/python/pants/base/specs.py +++ b/src/python/pants/base/specs.py @@ -5,7 +5,7 @@ import re from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Optional, Tuple +from typing import Any, Optional, Tuple, cast from pants.util.collections import assert_single_element from pants.util.dirutil import fast_relpath_optional, recursive_dirname @@ -201,14 +201,14 @@ def make_glob_patterns(self, address_mapper): } -def more_specific(spec1: Spec, spec2: Spec) -> Spec: +def more_specific(spec1: Optional[Spec], spec2: Spec) -> Spec: """Returns which of the two specs is more specific. This is useful when a target matches multiple specs, and we want to associate it with the "most specific" one, which will make the most intuitive sense to the user. """ # Note that if either of spec1 or spec2 is None, the other will be returned. - return spec1 if _specificity[type(spec1)] < _specificity[type(spec2)] else spec2 + return cast(Spec, spec1) if _specificity[type(spec1)] < _specificity[type(spec2)] else spec2 @frozen_after_init @@ -255,11 +255,11 @@ def matches_target_address_pair(self, address, target): @dataclass(unsafe_hash=True) class Specs: """A collection of Specs representing Spec subclasses, and a SpecsMatcher to filter results.""" - dependencies: Tuple + dependencies: Tuple[Spec, ...] matcher: SpecsMatcher def __init__( - self, dependencies: Tuple, tags: Optional[Tuple] = None, exclude_patterns: Tuple = () + self, dependencies: Tuple[Spec, ...], tags: Optional[Tuple] = None, exclude_patterns: Tuple = () ) -> None: self.dependencies = tuple(dependencies) self.matcher = SpecsMatcher(tags=tags, exclude_patterns=exclude_patterns) diff --git a/src/python/pants/engine/BUILD b/src/python/pants/engine/BUILD index 941c2de6f94..8595f878315 100644 --- a/src/python/pants/engine/BUILD +++ b/src/python/pants/engine/BUILD @@ -89,6 +89,7 @@ python_library( 'src/python/pants/util:filtering', 'src/python/pants/util:objects', ], + tags = {'partially_type_checked'}, ) python_library( diff --git a/src/python/pants/engine/build_files.py b/src/python/pants/engine/build_files.py index f6d2e38423c..95ffab6f23d 100644 --- a/src/python/pants/engine/build_files.py +++ b/src/python/pants/engine/build_files.py @@ -5,7 +5,7 @@ from collections.abc import MutableMapping, MutableSequence from dataclasses import dataclass from os.path import dirname, join -from typing import Dict +from typing import Dict, Tuple from twitter.common.collections import OrderedSet @@ -24,7 +24,7 @@ from pants.engine.objects import Locatable, SerializableFactory, Validatable from pants.engine.parser import HydratedStruct from pants.engine.rules import RootRule, rule -from pants.engine.selectors import Get +from pants.engine.selectors import Get, MultiGet from pants.engine.struct import Struct from pants.util.objects import TypeConstraintError @@ -42,15 +42,15 @@ def _key_func(entry): @rule -def parse_address_family(address_mapper: AddressMapper, directory: Dir) -> AddressFamily: +async def parse_address_family(address_mapper: AddressMapper, directory: Dir) -> AddressFamily: """Given an AddressMapper and a directory, return an AddressFamily. The AddressFamily may be empty, but it will not be None. """ patterns = tuple(join(directory.path, p) for p in address_mapper.build_patterns) path_globs = PathGlobs(include=patterns, exclude=address_mapper.build_ignore_patterns) - snapshot = yield Get(Snapshot, PathGlobs, path_globs) - files_content = yield Get(FilesContent, Digest, snapshot.directory_digest) + snapshot: Snapshot = await Get(Snapshot, PathGlobs, path_globs) + files_content: FilesContent = await Get(FilesContent, Digest, snapshot.directory_digest) if not files_content: raise ResolveError( @@ -63,7 +63,7 @@ def parse_address_family(address_mapper: AddressMapper, directory: Dir) -> Addre filecontent_product.path, filecontent_product.content, address_mapper.parser ) ) - yield AddressFamily.create(directory.path, address_maps) + return AddressFamily.create(directory.path, address_maps) def _raise_did_you_mean(address_family, name, source=None): @@ -82,14 +82,14 @@ def _raise_did_you_mean(address_family, name, source=None): @rule -def hydrate_struct(address_mapper: AddressMapper, address: Address) -> HydratedStruct: +async def hydrate_struct(address_mapper: AddressMapper, address: Address) -> HydratedStruct: """Given an AddressMapper and an Address, resolve a Struct from a BUILD file. Recursively collects any embedded addressables within the Struct, but will not walk into a dependencies field, since those should be requested explicitly by rules. """ - address_family = yield Get(AddressFamily, Dir(address.spec_path)) + address_family: AddressFamily = await Get(AddressFamily, Dir(address.spec_path)) struct = address_family.addressables.get(address) addresses = address_family.addressables @@ -131,9 +131,9 @@ def collect_inline_dependencies(item): collect_inline_dependencies(struct) # And then hydrate the inline dependencies. - hydrated_inline_dependencies = yield [ + hydrated_inline_dependencies: Tuple[HydratedStruct, ...] = await MultiGet( Get(HydratedStruct, Address, a) for a in inline_dependencies - ] + ) dependencies = [d.value for d in hydrated_inline_dependencies] def maybe_consume(outer_key, value): @@ -155,7 +155,8 @@ def maybe_consume(outer_key, value): return value # NB: Some pythons throw an UnboundLocalError for `idx` if it is a simple local variable. - maybe_consume.idx = 0 + # TODO(#8496): create a decorator for functions which declare a sentinel variable like this! + maybe_consume.idx = 0 # type: ignore # 'zip' the previously-requested dependencies back together as struct fields. def consume_dependencies(item, args=None): @@ -177,7 +178,7 @@ def consume_dependencies(item, args=None): hydrated_args[key] = maybe_consume(key, value) return _hydrate(type(item), address.spec_path, **hydrated_args) - yield HydratedStruct(consume_dependencies(struct, args={"address": address})) + return HydratedStruct(consume_dependencies(struct, args={"address": address})) def _hydrate(item_type, spec_path, **kwargs): @@ -202,7 +203,7 @@ def _hydrate(item_type, spec_path, **kwargs): @rule -def provenanced_addresses_from_address_families( +async def provenanced_addresses_from_address_families( address_mapper: AddressMapper, specs: Specs ) -> ProvenancedBuildFileAddresses: """Given an AddressMapper and list of Specs, return matching ProvenancedBuildFileAddresses. @@ -213,9 +214,9 @@ def provenanced_addresses_from_address_families( :raises: :class:`AddressLookupError` if no targets are matched for non-SingleAddress specs. """ # Capture a Snapshot covering all paths for these Specs, then group by directory. - snapshot = yield Get(Snapshot, PathGlobs, _spec_to_globs(address_mapper, specs)) + snapshot = await Get(Snapshot, PathGlobs, _spec_to_globs(address_mapper, specs)) dirnames = {dirname(f) for f in snapshot.files} - address_families = yield [Get(AddressFamily, Dir(d)) for d in dirnames] + address_families = await MultiGet(Get(AddressFamily, Dir(d)) for d in dirnames) address_family_by_directory = {af.namespace: af for af in address_families} matched_addresses = OrderedSet() @@ -249,10 +250,10 @@ def provenanced_addresses_from_address_families( ) # NB: This may be empty, as the result of filtering by tag and exclude patterns! - yield ProvenancedBuildFileAddresses( + return ProvenancedBuildFileAddresses( tuple( ProvenancedBuildFileAddress( - build_file_address=addr, provenance=addr_to_provenance.get(addr) + build_file_address=addr, provenance=addr_to_provenance[addr] ) for addr in matched_addresses ) @@ -261,7 +262,7 @@ def provenanced_addresses_from_address_families( @rule def remove_provenance(pbfas: ProvenancedBuildFileAddresses) -> BuildFileAddresses: - yield BuildFileAddresses(tuple(pbfa.build_file_address for pbfa in pbfas)) + return BuildFileAddresses(tuple(pbfa.build_file_address for pbfa in pbfas)) @dataclass(frozen=True) @@ -287,7 +288,7 @@ def _spec_to_globs(address_mapper, specs): return PathGlobs(include=patterns, exclude=address_mapper.build_ignore_patterns) -def create_graph_rules(address_mapper): +def create_graph_rules(address_mapper: AddressMapper): """Creates tasks used to parse Structs from BUILD files. :param address_mapper_key: The subject key for an AddressMapper instance. diff --git a/src/python/pants/engine/mapper.py b/src/python/pants/engine/mapper.py index 7390e2afd20..79e18c649ec 100644 --- a/src/python/pants/engine/mapper.py +++ b/src/python/pants/engine/mapper.py @@ -4,7 +4,7 @@ import re from collections import OrderedDict from dataclasses import dataclass -from typing import Any, Iterable, Optional, Tuple +from typing import Any, Dict, Iterable, Optional, Tuple from pants.build_graph.address import BuildFileAddress from pants.engine.objects import Serializable @@ -94,7 +94,7 @@ class AddressFamily: objects_by_name: Any @classmethod - def create(cls, spec_path, address_maps): + def create(cls, spec_path, address_maps) -> 'AddressFamily': """Creates an address family from the given set of address maps. :param spec_path: The directory prefix shared by all address_maps. @@ -113,7 +113,7 @@ def create(cls, spec_path, address_maps): .format(spec_path, address_map.path)) - objects_by_name = {} + objects_by_name: Dict[str, Tuple[str, Any]] = {} for address_map in address_maps: current_path = address_map.path for name, obj in address_map.objects_by_name.items(): diff --git a/src/python/pants/engine/native.py b/src/python/pants/engine/native.py index 692d2ed4efa..3d0633c2eba 100644 --- a/src/python/pants/engine/native.py +++ b/src/python/pants/engine/native.py @@ -9,7 +9,7 @@ import sysconfig import traceback from contextlib import closing -from types import GeneratorType +from types import CoroutineType, GeneratorType from typing import Any, NamedTuple, Tuple, Type import cffi @@ -502,9 +502,18 @@ def extern_generator_send(self, context_handle, func, arg): c.identities_buf([c.identify(g.subject) for g in res]), ) else: + # TODO: this will soon become obsolete when all @rules are fully mypy-annotated and must + # `return` instead of `yield` at the end! # Break. response.tag = self._lib.Broke response.broke = (c.to_value(res),) + except StopIteration as e: + if not e.args: + raise + # This was a `return` from a generator or coroutine, as opposed to a `StopIteration` raised + # by calling `next()` on an empty iterator. + response.tag = self._lib.Broke + response.broke = (c.to_value(e.value),) except Exception as e: # Throw. response.tag = self._lib.Throw @@ -568,6 +577,7 @@ class EngineTypes(NamedTuple): multi_platform_process_request: TypeId process_result: TypeId generator: TypeId + coroutine: TypeId url_to_fetch: TypeId string: TypeId bytes: TypeId @@ -925,6 +935,7 @@ def ti(type_obj): multi_platform_process_request=ti(MultiPlatformExecuteProcessRequest), process_result=ti(FallibleExecuteProcessResult), generator=ti(GeneratorType), + coroutine=ti(CoroutineType), url_to_fetch=ti(UrlToFetch), string=ti(str), bytes=ti(bytes), diff --git a/src/python/pants/engine/rules.py b/src/python/pants/engine/rules.py index 9846aee0f65..888f94382f4 100644 --- a/src/python/pants/engine/rules.py +++ b/src/python/pants/engine/rules.py @@ -12,7 +12,7 @@ from collections.abc import Iterable from dataclasses import dataclass from textwrap import dedent -from typing import Any, Callable, Dict, Optional, Tuple, Type +from typing import Any, Callable, Dict, List, Optional, Tuple, Type import asttokens from twitter.common.collections import OrderedSet @@ -28,11 +28,11 @@ class _RuleVisitor(ast.NodeVisitor): - """Pull `Get` calls out of an @rule body and validate `yield` statements.""" + """Pull `Get` calls out of an @rule body and validate `yield` or `await` statements.""" def __init__(self, func, func_node, func_source, orig_indent, parents_table): super().__init__() - self._gets = [] + self._gets: List[Get] = [] self._func = func self._func_node = func_node self._func_source = func_source @@ -41,10 +41,10 @@ def __init__(self, func, func_node, func_source, orig_indent, parents_table): self._yields_in_assignments = set() @property - def gets(self): + def gets(self) -> List[Get]: return self._gets - def _generate_ast_error_message(self, node, msg): + def _generate_ast_error_message(self, node, msg) -> str: # This is the location info of the start of the decorated @rule. filename = inspect.getsourcefile(self._func) source_lines, line_number = inspect.getsourcelines(self._func) @@ -90,7 +90,7 @@ def _generate_ast_error_message(self, node, msg): class YieldVisitError(Exception): pass @staticmethod - def _maybe_end_of_stmt_list(attr_value): + def _maybe_end_of_stmt_list(attr_value: Optional[Any]) -> Optional[Any]: """If `attr_value` is a non-empty iterable, return its final element.""" if (attr_value is not None) and isinstance(attr_value, Iterable): result = list(attr_value) @@ -98,7 +98,7 @@ def _maybe_end_of_stmt_list(attr_value): return result[-1] return None - def _stmt_is_at_end_of_parent_list(self, stmt): + def _stmt_is_at_end_of_parent_list(self, stmt) -> bool: """Determine if `stmt` is at the end of a list of statements (i.e. can be an implicit `return`). If there are any statements following `stmt` at the same level of nesting, this method returns @@ -145,18 +145,19 @@ def _stmt_is_at_end_of_parent_list(self, stmt): def _is_get(self, node): return isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == Get.__name__ - def visit_Call(self, node): + def visit_Call(self, node) -> None: + self.generic_visit(node) if self._is_get(node): self._gets.append(Get.extract_constraints(node)) - def visit_Assign(self, node): - if isinstance(node.value, ast.Yield): + def visit_Assign(self, node) -> None: + if isinstance(node.value, (ast.Yield, ast.Await)): self._yields_in_assignments.add(node.value) self.generic_visit(node) - def visit_Yield(self, node): + def _visit_await_or_yield_compat(self, node, *, is_yield: bool) -> None: self.generic_visit(node) - if node not in self._yields_in_assignments: + if is_yield and (node not in self._yields_in_assignments): # The current yield "expr" is the child of an "Expr" "stmt". expr_for_yield = self._parents_table[node] @@ -186,6 +187,12 @@ def visit_Yield(self, node): supported. See https://github.com/pantsbuild/pants/pull/8227 for progress. """))) + def visit_Await(self, node) -> None: + self._visit_await_or_yield_compat(node, is_yield=False) + + def visit_Yield(self, node) -> None: + self._visit_await_or_yield_compat(node, is_yield=True) + @memoized def optionable_rule(optionable_factory): @@ -256,7 +263,7 @@ def resolve_type(name): gets = OrderedSet() rule_func_node = assert_single_element( node for node in ast.iter_child_nodes(module_ast) - if isinstance(node, ast.FunctionDef) and node.name == func.__name__) + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == func.__name__) parents_table = {} for parent in ast.walk(rule_func_node): @@ -291,20 +298,20 @@ def resolve_type(name): return wrapper -class MissingTypeAnnotation(TypeError): - """Indicates a missing type annotation for an `@rule`.""" +class InvalidTypeAnnotation(TypeError): + """Indicates an incorrect type annotation for an `@rule`.""" -class MissingReturnTypeAnnotation(MissingTypeAnnotation): +class MissingReturnTypeAnnotation(InvalidTypeAnnotation): """Indicates a missing return type annotation for an `@rule`.""" -class MissingParameterTypeAnnotation(MissingTypeAnnotation): +class MissingParameterTypeAnnotation(InvalidTypeAnnotation): """Indicates a missing parameter type annotation for an `@rule`.""" def _ensure_type_annotation( - annotation: Any, name: str, empty_value: Any, raise_type: Type[MissingTypeAnnotation] + annotation: Any, name: str, empty_value: Any, raise_type: Type[InvalidTypeAnnotation], ) -> type: if annotation == empty_value: raise raise_type(f'{name} is missing a type annotation.') diff --git a/src/python/pants/engine/selectors.py b/src/python/pants/engine/selectors.py index b27b72abcb7..31643c1611d 100644 --- a/src/python/pants/engine/selectors.py +++ b/src/python/pants/engine/selectors.py @@ -4,7 +4,7 @@ import ast from dataclasses import dataclass from textwrap import dedent -from typing import Any, Tuple, Type +from typing import Any, Generator, Iterable, Tuple, Type, cast from pants.util.meta import frozen_after_init from pants.util.objects import TypeConstraint @@ -16,13 +16,36 @@ class Get: """Experimental synchronous generator API. May be called equivalently as either: - # verbose form: Get(product_type, subject_declared_type, subject) - # shorthand form: Get(product_type, subject_type(subject)) + # verbose form: Get(product, subject_declared_type, subject) + # shorthand form: Get(product, subject_declared_type()) """ product: Type subject_declared_type: Type subject: Any + def __await__(self) -> Generator[Any, Any, Any]: + """Allow a Get to be `await`ed within an `async` method, returning a strongly-typed result. + + The `yield`ed value `self` is interpreted by the engine within `extern_generator_send()` in + `native.py`. This class will yield a single Get instance, which is converted into + `PyGeneratorResponse::Get` from `externs.rs` via the python `cffi` library and the rust + `cbindgen` crate. + + This is how this method is eventually called: + - When the engine calls an `async def` method decorated with `@rule`, an instance of + `types.CoroutineType` is created. + - The engine will call `.send(None)` on the coroutine, which will either: + - raise StopIteration with a value (if the coroutine `return`s), or + - return a `Get` instance to the engine (if the rule instead called `await Get(...)`). + - The engine will fulfill the `Get` request to produce `x`, then call `.send(x)` and repeat the + above until StopIteration. + + See more information about implementing this method at + https://www.python.org/dev/peps/pep-0492/#await-expression. + """ + result = yield self + return result + def __init__(self, *args: Any) -> None: if len(args) not in (2, 3): raise ValueError( @@ -34,9 +57,9 @@ def __init__(self, *args: Any) -> None: if isinstance(subject, (type, TypeConstraint)): raise TypeError(dedent("""\ The two-argument form of Get does not accept a type as its second argument. - + args were: Get({args!r}) - + Get.create_statically_for_rule_graph() should be used to generate a Get() for the `input_gets` field of a rule. If you are using a `yield Get(...)` in a rule and a type was intended, use the 3-argument version: @@ -69,8 +92,6 @@ def render_args(): if len(call_node.args) == 2: product_type, subject_constructor = call_node.args if not isinstance(product_type, ast.Name) or not isinstance(subject_constructor, ast.Call): - # TODO(#7114): describe what types of objects are expected in the get call, not just the - # argument names. After #7114 this will be easier because they will just be types! raise ValueError( 'Two arg form of {} expected (product_type, subject_type(subject)), but ' 'got: ({})'.format(Get.__name__, render_args())) @@ -87,7 +108,7 @@ def render_args(): 'got: ({})'.format(Get.__name__, render_args())) @classmethod - def create_statically_for_rule_graph(cls, product_type, subject_type): + def create_statically_for_rule_graph(cls, product_type, subject_type) -> 'Get': """Construct a `Get` with a None value. This method is used to help make it explicit which `Get` instances are parsed from @rule bodies @@ -96,6 +117,34 @@ def create_statically_for_rule_graph(cls, product_type, subject_type): return cls(product_type, subject_type, None) +@frozen_after_init +@dataclass(unsafe_hash=True) +class MultiGet: + """Can be constructed with an iterable of `Get()`s and `await`ed to evaluate them in parallel.""" + gets: Tuple[Get, ...] + + def __await__(self) -> Generator[Any, Any, Tuple[Any, ...]]: + """Yield a tuple of Get instances with the same subject/product type pairs all at once. + + The `yield`ed value `self.gets` is interpreted by the engine within `extern_generator_send()` in + `native.py`. This class will yield a tuple of Get instances, which is converted into + `PyGeneratorResponse::GetMulti` from `externs.rs`. + + The engine will fulfill these Get instances in parallel, and return a tuple of T + instances to this method, which then returns this tuple to the `@rule` which called + `await MultiGet(Get(T, ...) for ... in ...)`. + """ + result = yield self.gets + return cast(Tuple[Any, ...], result) + + def __init__(self, gets: Iterable[Get]) -> None: + """Create a MultiGet from a generator expression. + + This constructor will infer this class's _Product parameter from the input `gets`. + """ + self.gets = tuple(gets) + + @frozen_after_init @dataclass(unsafe_hash=True) class Params: diff --git a/src/python/pants/testutil/engine/util.py b/src/python/pants/testutil/engine/util.py index 5f30cfcf7a2..6428c103ee6 100644 --- a/src/python/pants/testutil/engine/util.py +++ b/src/python/pants/testutil/engine/util.py @@ -5,7 +5,7 @@ import re from dataclasses import dataclass from io import StringIO -from types import GeneratorType +from types import CoroutineType, GeneratorType from typing import Any, Callable, Optional, Sequence, Type from colors import blue, green, red @@ -81,7 +81,7 @@ def run_rule( task_rule.input_gets, mock_gets)) res = rule(*(rule_args or ())) - if not isinstance(res, GeneratorType): + if not isinstance(res, (CoroutineType, GeneratorType)): return res def get(product, subject): @@ -98,13 +98,17 @@ def get(product, subject): rule_coroutine = res rule_input = None while True: - res = rule_coroutine.send(rule_input) - if isinstance(res, Get): - rule_input = get(res.product, res.subject) - elif type(res) in (tuple, list): - rule_input = [get(g.product, g.subject) for g in res] - else: - return res + try: + res = rule_coroutine.send(rule_input) + if isinstance(res, Get): + rule_input = get(res.product, res.subject) + elif type(res) in (tuple, list): + rule_input = [get(g.product, g.subject) for g in res] + else: + return res + except StopIteration as e: + if e.args: + return e.value def init_native(): diff --git a/src/rust/engine/src/nodes.rs b/src/rust/engine/src/nodes.rs index ab93a896155..86f9da9dbc7 100644 --- a/src/rust/engine/src/nodes.rs +++ b/src/rust/engine/src/nodes.rs @@ -1067,7 +1067,9 @@ impl WrappedNode for Task { }) .then(move |task_result| match task_result { Ok(val) => match externs::get_type_for(&val) { - t if t == context.core.types.generator => Self::generate(context, params, entry, val), + t if t == context.core.types.generator || t == context.core.types.coroutine => { + Self::generate(context, params, entry, val) + } t if t == product => ok(val), _ => err(throw(&format!( "{:?} returned a result value that did not satisfy its constraints: {:?}", diff --git a/src/rust/engine/src/types.rs b/src/rust/engine/src/types.rs index 8c8e202a6e9..229a4381a04 100644 --- a/src/rust/engine/src/types.rs +++ b/src/rust/engine/src/types.rs @@ -26,6 +26,7 @@ pub struct Types { pub multi_platform_process_request: TypeId, pub process_result: TypeId, pub generator: TypeId, + pub coroutine: TypeId, pub url_to_fetch: TypeId, pub string: TypeId, pub bytes: TypeId,