Skip to content

Commit

Permalink
allow async/await syntax in @rules to enable mypy type checking (#8639)
Browse files Browse the repository at this point in the history
### Problem

Fixes #7077.

See #8635 (comment). As of #8330, we annotate all `@rule`s with a single return type, even for rules which call `yield Get(...)` within the method body. MyPy doesn't like this, and requests that rules be given a `Generator` or `Iterable` return type. This blocks type-checking (even partially) for many targets which define rules.

### Solution

- Expand `_RuleVisitor` in `rules.py` to extract `Get` calls from `async def` rules.
- Add an `__await__()` method to `Get` and create `MultiGet` to allow awaiting on multiple `Get`s in parallel (named to match the corresponding rust type).
  - Lists cannot be `await`ed in the way that it is possible to `yield [Get(...) ...]` -- the expression must be *awaitable* -- see https://www.python.org/dev/peps/pep-0492/#await-expression. In this case, `MultiGet` wrapping an iterable of `Get`s and exposing an `__await__()` method is a complete replacement for the previous `yield [Get(...) ...]` syntax.
- Edit `native.py` to allow for `@rule` coroutine methods to `return` at the end instead of `yield`.
- Convert the `@rule`s in `build_files.py` into `async` methods, and type-check the `:build_files` target!

### Result

`@rule`s can now be defined with `async`/`await` syntax:
```python
@rule
async def hydrate_struct(address_mapper: AddressMapper, address: Address) -> HydratedStruct:
  # ...
  address_family: AddressFamily = await Get(AddressFamily, Dir(address.spec_path))
  # ...
  hydrated_inline_dependencies: Tuple[HydratedStruct, ...] = await MultiGet(Get(HydratedStruct, Address, a)
                                                for a in inline_dependencies)
  # ...
  return HydratedStruct(consume_dependencies(struct, args={"address": address}))
```

As a result, plugins and backends that define `@rule`s can now be checked with MyPy!

#### Alternative Solution: Returning `Generator[Any, Any, T]`

In [the first few commits of this PR](https://github.com/pantsbuild/pants/pull/8639/files/43f73f1ac2d1e86301936a895ef08ffe8787d0f7..f7e8534c72965c5e6daa143ccefcdc7428192291), MyPy type annotations were added to `build_files.py` without adding any support for `async` methods. This worked by first adding support for `return` at the end of an `@rule` body in `native.py`, as in this PR, and annotating `@rule`s with the return type `-> Generator[Any, Any, T]`, `T` being the rule's actual return type.

This worked, but required a lot of extra effort to extract the return type `T` from the `Generator[Any, Any, T]` annotation, so this was discarded because `async def` rules require no additional work to extract the output type. For the record, this would have looked like:
```python
@rule
def hydrate_struct(address_mapper: AddressMapper, address: Address) -> Generator[Any, Any, HydratedStruct]:
  # ...
  address_family: AddressFamily
  address_family = yield Get(AddressFamily, Dir(address.spec_path))
  # ...
  hydrated_inline_dependencies: Tuple[HydratedStruct, ...]
  hydrated_inline_dependencies = yield [Get(HydratedStruct, Address, a)
                                                for a in inline_dependencies]
  # ...
  return HydratedStruct(consume_dependencies(struct, args={"address": address}))
```

Note that `x: X = yield Get(X, Y(...))` is not a valid python expression -- this is another benefit of the `async`/`await` approach.

#### Follow-Up Extension: Type-Checked `await Get[X](...)`

Another alternative extending the `await Get()` syntax was proposed in order to automatically type-check the result of the `await` call. This would have looked like:

```python
@rule
async def hydrate_struct(address_mapper: AddressMapper, address: Address) -> HydratedStruct:
  # ...
  address_family = await Get[AddressFamily](Dir(address.spec_path))
  # ...
  hydrated_inline_dependencies = await MultiGet(Get[HydratedStruct](Address, a)
                                                for a in inline_dependencies)
  # ...
  return HydratedStruct(consume_dependencies(struct, args={"address": address}))
```

**Note that the `await Get[X](...)` calls are type-checked to return the type `X`!** This means that mypy can check that later uses of the `address_family` and `hydrated_inline_dependencies` objects above are correct, which it couldn't do without a separate redundant `address_family: AddressFamily` annotation before.

The syntax extension for `Get[X](...)` was reverted in e92fecb in order to reduce complexity of the initial implementation.
  • Loading branch information
cosmicexplorer authored Nov 20, 2019
1 parent aa70d5a commit 9459eca
Show file tree
Hide file tree
Showing 10 changed files with 140 additions and 64 deletions.
10 changes: 5 additions & 5 deletions src/python/pants/base/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Optional, Tuple
from typing import Any, Optional, Tuple, cast

from pants.util.collections import assert_single_element
from pants.util.dirutil import fast_relpath_optional, recursive_dirname
Expand Down Expand Up @@ -201,14 +201,14 @@ def make_glob_patterns(self, address_mapper):
}


def more_specific(spec1: Spec, spec2: Spec) -> Spec:
def more_specific(spec1: Optional[Spec], spec2: Spec) -> Spec:
"""Returns which of the two specs is more specific.
This is useful when a target matches multiple specs, and we want to associate it with
the "most specific" one, which will make the most intuitive sense to the user.
"""
# Note that if either of spec1 or spec2 is None, the other will be returned.
return spec1 if _specificity[type(spec1)] < _specificity[type(spec2)] else spec2
return cast(Spec, spec1) if _specificity[type(spec1)] < _specificity[type(spec2)] else spec2


@frozen_after_init
Expand Down Expand Up @@ -255,11 +255,11 @@ def matches_target_address_pair(self, address, target):
@dataclass(unsafe_hash=True)
class Specs:
"""A collection of Specs representing Spec subclasses, and a SpecsMatcher to filter results."""
dependencies: Tuple
dependencies: Tuple[Spec, ...]
matcher: SpecsMatcher

def __init__(
self, dependencies: Tuple, tags: Optional[Tuple] = None, exclude_patterns: Tuple = ()
self, dependencies: Tuple[Spec, ...], tags: Optional[Tuple] = None, exclude_patterns: Tuple = ()
) -> None:
self.dependencies = tuple(dependencies)
self.matcher = SpecsMatcher(tags=tags, exclude_patterns=exclude_patterns)
Expand Down
1 change: 1 addition & 0 deletions src/python/pants/engine/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ python_library(
'src/python/pants/util:filtering',
'src/python/pants/util:objects',
],
tags = {'partially_type_checked'},
)

python_library(
Expand Down
39 changes: 20 additions & 19 deletions src/python/pants/engine/build_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import MutableMapping, MutableSequence
from dataclasses import dataclass
from os.path import dirname, join
from typing import Dict
from typing import Dict, Tuple

from twitter.common.collections import OrderedSet

Expand All @@ -24,7 +24,7 @@
from pants.engine.objects import Locatable, SerializableFactory, Validatable
from pants.engine.parser import HydratedStruct
from pants.engine.rules import RootRule, rule
from pants.engine.selectors import Get
from pants.engine.selectors import Get, MultiGet
from pants.engine.struct import Struct
from pants.util.objects import TypeConstraintError

Expand All @@ -42,15 +42,15 @@ def _key_func(entry):


@rule
def parse_address_family(address_mapper: AddressMapper, directory: Dir) -> AddressFamily:
async def parse_address_family(address_mapper: AddressMapper, directory: Dir) -> AddressFamily:
"""Given an AddressMapper and a directory, return an AddressFamily.
The AddressFamily may be empty, but it will not be None.
"""
patterns = tuple(join(directory.path, p) for p in address_mapper.build_patterns)
path_globs = PathGlobs(include=patterns, exclude=address_mapper.build_ignore_patterns)
snapshot = yield Get(Snapshot, PathGlobs, path_globs)
files_content = yield Get(FilesContent, Digest, snapshot.directory_digest)
snapshot: Snapshot = await Get(Snapshot, PathGlobs, path_globs)
files_content: FilesContent = await Get(FilesContent, Digest, snapshot.directory_digest)

if not files_content:
raise ResolveError(
Expand All @@ -63,7 +63,7 @@ def parse_address_family(address_mapper: AddressMapper, directory: Dir) -> Addre
filecontent_product.path, filecontent_product.content, address_mapper.parser
)
)
yield AddressFamily.create(directory.path, address_maps)
return AddressFamily.create(directory.path, address_maps)


def _raise_did_you_mean(address_family, name, source=None):
Expand All @@ -82,14 +82,14 @@ def _raise_did_you_mean(address_family, name, source=None):


@rule
def hydrate_struct(address_mapper: AddressMapper, address: Address) -> HydratedStruct:
async def hydrate_struct(address_mapper: AddressMapper, address: Address) -> HydratedStruct:
"""Given an AddressMapper and an Address, resolve a Struct from a BUILD file.
Recursively collects any embedded addressables within the Struct, but will not walk into a
dependencies field, since those should be requested explicitly by rules.
"""

address_family = yield Get(AddressFamily, Dir(address.spec_path))
address_family: AddressFamily = await Get(AddressFamily, Dir(address.spec_path))

struct = address_family.addressables.get(address)
addresses = address_family.addressables
Expand Down Expand Up @@ -131,9 +131,9 @@ def collect_inline_dependencies(item):
collect_inline_dependencies(struct)

# And then hydrate the inline dependencies.
hydrated_inline_dependencies = yield [
hydrated_inline_dependencies: Tuple[HydratedStruct, ...] = await MultiGet(
Get(HydratedStruct, Address, a) for a in inline_dependencies
]
)
dependencies = [d.value for d in hydrated_inline_dependencies]

def maybe_consume(outer_key, value):
Expand All @@ -155,7 +155,8 @@ def maybe_consume(outer_key, value):
return value

# NB: Some pythons throw an UnboundLocalError for `idx` if it is a simple local variable.
maybe_consume.idx = 0
# TODO(#8496): create a decorator for functions which declare a sentinel variable like this!
maybe_consume.idx = 0 # type: ignore

# 'zip' the previously-requested dependencies back together as struct fields.
def consume_dependencies(item, args=None):
Expand All @@ -177,7 +178,7 @@ def consume_dependencies(item, args=None):
hydrated_args[key] = maybe_consume(key, value)
return _hydrate(type(item), address.spec_path, **hydrated_args)

yield HydratedStruct(consume_dependencies(struct, args={"address": address}))
return HydratedStruct(consume_dependencies(struct, args={"address": address}))


def _hydrate(item_type, spec_path, **kwargs):
Expand All @@ -202,7 +203,7 @@ def _hydrate(item_type, spec_path, **kwargs):


@rule
def provenanced_addresses_from_address_families(
async def provenanced_addresses_from_address_families(
address_mapper: AddressMapper, specs: Specs
) -> ProvenancedBuildFileAddresses:
"""Given an AddressMapper and list of Specs, return matching ProvenancedBuildFileAddresses.
Expand All @@ -213,9 +214,9 @@ def provenanced_addresses_from_address_families(
:raises: :class:`AddressLookupError` if no targets are matched for non-SingleAddress specs.
"""
# Capture a Snapshot covering all paths for these Specs, then group by directory.
snapshot = yield Get(Snapshot, PathGlobs, _spec_to_globs(address_mapper, specs))
snapshot = await Get(Snapshot, PathGlobs, _spec_to_globs(address_mapper, specs))
dirnames = {dirname(f) for f in snapshot.files}
address_families = yield [Get(AddressFamily, Dir(d)) for d in dirnames]
address_families = await MultiGet(Get(AddressFamily, Dir(d)) for d in dirnames)
address_family_by_directory = {af.namespace: af for af in address_families}

matched_addresses = OrderedSet()
Expand Down Expand Up @@ -249,10 +250,10 @@ def provenanced_addresses_from_address_families(
)

# NB: This may be empty, as the result of filtering by tag and exclude patterns!
yield ProvenancedBuildFileAddresses(
return ProvenancedBuildFileAddresses(
tuple(
ProvenancedBuildFileAddress(
build_file_address=addr, provenance=addr_to_provenance.get(addr)
build_file_address=addr, provenance=addr_to_provenance[addr]
)
for addr in matched_addresses
)
Expand All @@ -261,7 +262,7 @@ def provenanced_addresses_from_address_families(

@rule
def remove_provenance(pbfas: ProvenancedBuildFileAddresses) -> BuildFileAddresses:
yield BuildFileAddresses(tuple(pbfa.build_file_address for pbfa in pbfas))
return BuildFileAddresses(tuple(pbfa.build_file_address for pbfa in pbfas))


@dataclass(frozen=True)
Expand All @@ -287,7 +288,7 @@ def _spec_to_globs(address_mapper, specs):
return PathGlobs(include=patterns, exclude=address_mapper.build_ignore_patterns)


def create_graph_rules(address_mapper):
def create_graph_rules(address_mapper: AddressMapper):
"""Creates tasks used to parse Structs from BUILD files.
:param address_mapper_key: The subject key for an AddressMapper instance.
Expand Down
6 changes: 3 additions & 3 deletions src/python/pants/engine/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Iterable, Optional, Tuple
from typing import Any, Dict, Iterable, Optional, Tuple

from pants.build_graph.address import BuildFileAddress
from pants.engine.objects import Serializable
Expand Down Expand Up @@ -94,7 +94,7 @@ class AddressFamily:
objects_by_name: Any

@classmethod
def create(cls, spec_path, address_maps):
def create(cls, spec_path, address_maps) -> 'AddressFamily':
"""Creates an address family from the given set of address maps.
:param spec_path: The directory prefix shared by all address_maps.
Expand All @@ -113,7 +113,7 @@ def create(cls, spec_path, address_maps):
.format(spec_path, address_map.path))


objects_by_name = {}
objects_by_name: Dict[str, Tuple[str, Any]] = {}
for address_map in address_maps:
current_path = address_map.path
for name, obj in address_map.objects_by_name.items():
Expand Down
13 changes: 12 additions & 1 deletion src/python/pants/engine/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import sysconfig
import traceback
from contextlib import closing
from types import GeneratorType
from types import CoroutineType, GeneratorType
from typing import Any, NamedTuple, Tuple, Type

import cffi
Expand Down Expand Up @@ -502,9 +502,18 @@ def extern_generator_send(self, context_handle, func, arg):
c.identities_buf([c.identify(g.subject) for g in res]),
)
else:
# TODO: this will soon become obsolete when all @rules are fully mypy-annotated and must
# `return` instead of `yield` at the end!
# Break.
response.tag = self._lib.Broke
response.broke = (c.to_value(res),)
except StopIteration as e:
if not e.args:
raise
# This was a `return` from a generator or coroutine, as opposed to a `StopIteration` raised
# by calling `next()` on an empty iterator.
response.tag = self._lib.Broke
response.broke = (c.to_value(e.value),)
except Exception as e:
# Throw.
response.tag = self._lib.Throw
Expand Down Expand Up @@ -568,6 +577,7 @@ class EngineTypes(NamedTuple):
multi_platform_process_request: TypeId
process_result: TypeId
generator: TypeId
coroutine: TypeId
url_to_fetch: TypeId
string: TypeId
bytes: TypeId
Expand Down Expand Up @@ -925,6 +935,7 @@ def ti(type_obj):
multi_platform_process_request=ti(MultiPlatformExecuteProcessRequest),
process_result=ti(FallibleExecuteProcessResult),
generator=ti(GeneratorType),
coroutine=ti(CoroutineType),
url_to_fetch=ti(UrlToFetch),
string=ti(str),
bytes=ti(bytes),
Expand Down
43 changes: 25 additions & 18 deletions src/python/pants/engine/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from collections.abc import Iterable
from dataclasses import dataclass
from textwrap import dedent
from typing import Any, Callable, Dict, Optional, Tuple, Type
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

import asttokens
from twitter.common.collections import OrderedSet
Expand All @@ -28,11 +28,11 @@


class _RuleVisitor(ast.NodeVisitor):
"""Pull `Get` calls out of an @rule body and validate `yield` statements."""
"""Pull `Get` calls out of an @rule body and validate `yield` or `await` statements."""

def __init__(self, func, func_node, func_source, orig_indent, parents_table):
super().__init__()
self._gets = []
self._gets: List[Get] = []
self._func = func
self._func_node = func_node
self._func_source = func_source
Expand All @@ -41,10 +41,10 @@ def __init__(self, func, func_node, func_source, orig_indent, parents_table):
self._yields_in_assignments = set()

@property
def gets(self):
def gets(self) -> List[Get]:
return self._gets

def _generate_ast_error_message(self, node, msg):
def _generate_ast_error_message(self, node, msg) -> str:
# This is the location info of the start of the decorated @rule.
filename = inspect.getsourcefile(self._func)
source_lines, line_number = inspect.getsourcelines(self._func)
Expand Down Expand Up @@ -90,15 +90,15 @@ def _generate_ast_error_message(self, node, msg):
class YieldVisitError(Exception): pass

@staticmethod
def _maybe_end_of_stmt_list(attr_value):
def _maybe_end_of_stmt_list(attr_value: Optional[Any]) -> Optional[Any]:
"""If `attr_value` is a non-empty iterable, return its final element."""
if (attr_value is not None) and isinstance(attr_value, Iterable):
result = list(attr_value)
if len(result) > 0:
return result[-1]
return None

def _stmt_is_at_end_of_parent_list(self, stmt):
def _stmt_is_at_end_of_parent_list(self, stmt) -> bool:
"""Determine if `stmt` is at the end of a list of statements (i.e. can be an implicit `return`).
If there are any statements following `stmt` at the same level of nesting, this method returns
Expand Down Expand Up @@ -145,18 +145,19 @@ def _stmt_is_at_end_of_parent_list(self, stmt):
def _is_get(self, node):
return isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == Get.__name__

def visit_Call(self, node):
def visit_Call(self, node) -> None:
self.generic_visit(node)
if self._is_get(node):
self._gets.append(Get.extract_constraints(node))

def visit_Assign(self, node):
if isinstance(node.value, ast.Yield):
def visit_Assign(self, node) -> None:
if isinstance(node.value, (ast.Yield, ast.Await)):
self._yields_in_assignments.add(node.value)
self.generic_visit(node)

def visit_Yield(self, node):
def _visit_await_or_yield_compat(self, node, *, is_yield: bool) -> None:
self.generic_visit(node)
if node not in self._yields_in_assignments:
if is_yield and (node not in self._yields_in_assignments):
# The current yield "expr" is the child of an "Expr" "stmt".
expr_for_yield = self._parents_table[node]

Expand Down Expand Up @@ -186,6 +187,12 @@ def visit_Yield(self, node):
supported. See https://github.com/pantsbuild/pants/pull/8227 for progress.
""")))

def visit_Await(self, node) -> None:
self._visit_await_or_yield_compat(node, is_yield=False)

def visit_Yield(self, node) -> None:
self._visit_await_or_yield_compat(node, is_yield=True)


@memoized
def optionable_rule(optionable_factory):
Expand Down Expand Up @@ -256,7 +263,7 @@ def resolve_type(name):
gets = OrderedSet()
rule_func_node = assert_single_element(
node for node in ast.iter_child_nodes(module_ast)
if isinstance(node, ast.FunctionDef) and node.name == func.__name__)
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == func.__name__)

parents_table = {}
for parent in ast.walk(rule_func_node):
Expand Down Expand Up @@ -291,20 +298,20 @@ def resolve_type(name):
return wrapper


class MissingTypeAnnotation(TypeError):
"""Indicates a missing type annotation for an `@rule`."""
class InvalidTypeAnnotation(TypeError):
"""Indicates an incorrect type annotation for an `@rule`."""


class MissingReturnTypeAnnotation(MissingTypeAnnotation):
class MissingReturnTypeAnnotation(InvalidTypeAnnotation):
"""Indicates a missing return type annotation for an `@rule`."""


class MissingParameterTypeAnnotation(MissingTypeAnnotation):
class MissingParameterTypeAnnotation(InvalidTypeAnnotation):
"""Indicates a missing parameter type annotation for an `@rule`."""


def _ensure_type_annotation(
annotation: Any, name: str, empty_value: Any, raise_type: Type[MissingTypeAnnotation]
annotation: Any, name: str, empty_value: Any, raise_type: Type[InvalidTypeAnnotation],
) -> type:
if annotation == empty_value:
raise raise_type(f'{name} is missing a type annotation.')
Expand Down
Loading

0 comments on commit 9459eca

Please sign in to comment.