Skip to content

Commit

Permalink
Add and use support for statically declared Gets.
Browse files Browse the repository at this point in the history
  • Loading branch information
stuhood committed Mar 29, 2018
1 parent ae6a218 commit c4b40a3
Show file tree
Hide file tree
Showing 8 changed files with 192 additions and 42 deletions.
1 change: 1 addition & 0 deletions src/python/pants/engine/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@
Tasks* tasks_create(void);
void tasks_task_begin(Tasks*, Function, TypeConstraint);
void tasks_add_get(Tasks*, TypeConstraint, TypeId);
void tasks_add_select(Tasks*, TypeConstraint);
void tasks_add_select_variant(Tasks*, TypeConstraint, Buffer);
void tasks_add_select_dependencies(Tasks*, TypeConstraint, TypeConstraint, Buffer, TypeIdBuffer);
Expand Down
60 changes: 54 additions & 6 deletions src/python/pants/engine/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,47 @@
from __future__ import (absolute_import, division, generators, nested_scopes, print_function,
unicode_literals, with_statement)

import ast
import inspect
import logging
from abc import abstractproperty
from collections import OrderedDict

from twitter.common.collections import OrderedSet

from pants.engine.addressable import Exactly
from pants.engine.selectors import type_or_constraint_repr
from pants.engine.selectors import Get, type_or_constraint_repr
from pants.util.meta import AbstractClass
from pants.util.objects import datatype


logger = logging.getLogger(__name__)


class _RuleVisitor(ast.NodeVisitor):
def __init__(self):
super(_RuleVisitor, self).__init__()
self.gets = []

def visit_Call(self, node):
if not isinstance(node.func, ast.Name) or node.func.id != Get.__name__:
return

# TODO: Validation.
if len(node.args) == 2:
product_type, subject_constructor = node.args
if not isinstance(product_type, ast.Name) or not isinstance(subject_constructor, ast.Call):
raise Exception('TODO: Implement validation of Get shapes.')
self.gets.append((product_type.id, subject_constructor.func.id))
elif len(node.args) == 3:
product_type, subject_type, _ = node.args
if not isinstance(product_type, ast.Name) or not isinstance(subject_type, ast.Name):
raise Exception('TODO: Implement validation of Get shapes.')
self.gets.append((product_type.id, subject_type.id))
else:
raise Exception('Invalid {}: {}'.format(Get.__name__, node.args))


def rule(output_type, input_selectors):
"""A @decorator that declares that a particular static function may be used as a TaskRule.
Expand All @@ -29,7 +55,20 @@ def rule(output_type, input_selectors):
to the @decorated function.
"""
def wrapper(func):
func._rule = TaskRule(output_type, input_selectors, func)
caller_frame = inspect.stack()[1][0]
module_ast = ast.parse(inspect.getsource(func))

def resolve(name):
return caller_frame.f_globals.get(name) or caller_frame.f_builtins.get(name)

gets = []
for node in ast.iter_child_nodes(module_ast):
if isinstance(node, ast.FunctionDef) and node.name == func.__name__:
rule_visitor = _RuleVisitor()
rule_visitor.visit(node)
gets.extend(Get(resolve(p), resolve(s)) for p, s in rule_visitor.gets)

func._rule = TaskRule(output_type, input_selectors, func, input_gets=gets)
return func
return wrapper

Expand All @@ -50,10 +89,13 @@ def input_selectors(self):
"""Collection of input selectors."""


class TaskRule(datatype('TaskRule', ['output_constraint', 'input_selectors', 'func']), Rule):
"""A Rule that runs a task function when all of its input selectors are satisfied."""
class TaskRule(datatype('TaskRule', ['output_constraint', 'input_selectors', 'input_gets', 'func']), Rule):
"""A Rule that runs a task function when all of its input selectors are satisfied.
TODO: Make input_gets non-optional when more/all rules are using them.
"""

def __new__(cls, output_type, input_selectors, func):
def __new__(cls, output_type, input_selectors, func, input_gets=None):
# Validate result type.
if isinstance(output_type, Exactly):
constraint = output_type
Expand All @@ -68,8 +110,14 @@ def __new__(cls, output_type, input_selectors, func):
raise TypeError("Expected a list of Selectors for rule `{}`, got: {}".format(
func.__name__, type(input_selectors)))

# Validate gets.
input_gets = [] if input_gets is None else input_gets
if not isinstance(input_gets, list):
raise TypeError("Expected a list of Gets for rule `{}`, got: {}".format(
func.__name__, type(input_gets)))

# Create.
return super(TaskRule, cls).__new__(cls, constraint, tuple(input_selectors), func)
return super(TaskRule, cls).__new__(cls, constraint, tuple(input_selectors), tuple(input_gets), func)

def __str__(self):
return '({}, {!r}, {})'.format(type_or_constraint_repr(self.output_constraint),
Expand Down
7 changes: 5 additions & 2 deletions src/python/pants/engine/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,9 @@ def _register_singleton(self, output_constraint, rule):

def _register_task(self, output_constraint, rule):
"""Register the given TaskRule with the native scheduler."""
input_selects = rule.input_selectors
func = rule.func
self._native.lib.tasks_task_begin(self._tasks, Function(self._to_key(func)), output_constraint)
for selector in input_selects:
for selector in rule.input_selectors:
selector_type = type(selector)
product_constraint = self._to_constraint(selector.product)
if selector_type is Select:
Expand All @@ -218,6 +217,10 @@ def _register_task(self, output_constraint, rule):
self._to_constraint(selector.input_product))
else:
raise ValueError('Unrecognized Selector type: {}'.format(selector))
for get in rule.input_gets:
self._native.lib.tasks_add_get(self._tasks,
self._to_constraint(get.product),
TypeId(self._to_id(get.subject)))
self._native.lib.tasks_task_end(self._tasks)

def visualize_graph_to_file(self, execution_request, filename):
Expand Down
7 changes: 7 additions & 0 deletions src/rust/engine/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,13 @@ pub extern "C" fn tasks_task_begin(
})
}

#[no_mangle]
pub extern "C" fn tasks_add_get(tasks_ptr: *mut Tasks, product: TypeConstraint, subject: TypeId) {
with_tasks(tasks_ptr, |tasks| {
tasks.add_get(product, subject);
})
}

#[no_mangle]
pub extern "C" fn tasks_add_select(tasks_ptr: *mut Tasks, product: TypeConstraint) {
with_tasks(tasks_ptr, |tasks| {
Expand Down
67 changes: 42 additions & 25 deletions src/rust/engine/src/nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::collections::BTreeMap;
use std::fmt;
use std::os::unix::ffi::OsStrExt;
use std::path::{Path, PathBuf};
use std::sync::Arc;

use futures::future::{self, Future};
use tempdir::TempDir;
Expand Down Expand Up @@ -124,6 +125,21 @@ impl Select {
}
}

pub fn new_with_entries(
product: TypeConstraint,
subject: Key,
variants: Variants,
entries: rule_graph::Entries,
) -> Select {
let selector = selectors::Select::without_variant(product);
Select {
selector: selector,
subject: subject,
variants: variants,
entries: entries,
}
}

pub fn new_with_selector(
selector: selectors::Select,
subject: Key,
Expand Down Expand Up @@ -356,7 +372,7 @@ impl Select {
product: self.product().clone(),
variants: self.variants.clone(),
task: task,
entry: entry.clone(),
entry: Arc::new(entry.clone()),
})
})
.collect::<Vec<NodeFuture<Value>>>()
Expand Down Expand Up @@ -883,7 +899,7 @@ pub struct Task {
product: TypeConstraint,
variants: Variants,
task: tasks::Task,
entry: rule_graph::Entry,
entry: Arc<rule_graph::Entry>,
}

impl Task {
Expand All @@ -909,30 +925,25 @@ impl Task {
}
}

///
/// TODO: Merge with `get` once all edges are statically declared.
///
fn gen_get(context: &Context, gets: Vec<externs::Get>) -> NodeFuture<Vec<Value>> {
fn gen_get(
context: &Context,
entry: Arc<rule_graph::Entry>,
gets: Vec<externs::Get>,
) -> NodeFuture<Vec<Value>> {
let get_futures = gets
.into_iter()
.map(|get| {
let externs::Get(constraint, subject) = get;
let selector = selectors::Select::without_variant(constraint.clone());
let edges_res = context
let externs::Get(product, subject) = get;
let entries = context
.core
.rule_graph
.find_root_edges(*subject.type_id(), selectors::Selector::Select(selector))
.ok_or_else(|| {
throw(&format!(
"No rules were available to compute {} for {}",
externs::key_to_str(&constraint.0),
externs::key_to_str(&subject)
))
});
let context = context.clone();
future::result(edges_res).and_then(move |edges| {
Select::new(constraint, subject, Default::default(), &edges).run(context.clone())
})
.edges_for_inner(&entry)
.expect("edges for task exist.")
.entries_for(&rule_graph::SelectKey::JustGet(selectors::Get {
product: product,
subject: subject.type_id().clone(),
}));
Select::new_with_entries(product, subject, Default::default(), entries).run(context.clone())
})
.collect::<Vec<_>>();
future::join_all(get_futures).to_boxed()
Expand All @@ -942,15 +953,20 @@ impl Task {
/// Given a python generator Value, loop to request the generator's dependencies until
/// it completes with a result Value.
///
fn generate(context: Context, generator: Value) -> NodeFuture<Value> {
fn generate(
context: Context,
entry: Arc<rule_graph::Entry>,
generator: Value,
) -> NodeFuture<Value> {
future::loop_fn(externs::eval("None").unwrap(), move |input| {
let context = context.clone();
let entry = entry.clone();
future::result(externs::generator_send(&generator, &input)).and_then(move |response| {
match response {
externs::GeneratorResponse::Get(get) => Self::gen_get(&context, vec![get])
externs::GeneratorResponse::Get(get) => Self::gen_get(&context, entry, vec![get])
.map(|vs| future::Loop::Continue(vs.into_iter().next().unwrap()))
.to_boxed() as BoxFuture<_, _>,
externs::GeneratorResponse::GetMulti(gets) => Self::gen_get(&context, gets)
externs::GeneratorResponse::GetMulti(gets) => Self::gen_get(&context, entry, gets)
.map(|vs| future::Loop::Continue(externs::store_list(vs.iter().collect(), false)))
.to_boxed() as BoxFuture<_, _>,
externs::GeneratorResponse::Break(val) => {
Expand All @@ -976,6 +992,7 @@ impl Node for Task {
);

let func = self.task.func.clone();
let entry = self.entry.clone();
deps
.then(move |deps_result| match deps_result {
Ok(deps) => externs::call(&externs::val_for(&func.0), &deps),
Expand All @@ -984,7 +1001,7 @@ impl Node for Task {
.then(move |task_result| match task_result {
Ok(val) => {
if externs::satisfied_by(&context.core.types.generator, &val) {
Self::generate(context, val)
Self::generate(context, entry, val)
} else {
ok(val)
}
Expand Down
Loading

0 comments on commit c4b40a3

Please sign in to comment.