Skip to content

Commit

Permalink
[Relay][VM][Interpreter] Enable first-class constructors in VM and in…
Browse files Browse the repository at this point in the history
…terpreter via eta expansion (#4218)

* Fix constructor pretty printing

* Make Module::HasDef name consistent with API

* Add VM constructor compilation via eta expansion

* Lint

* Fix CI

* Fix failing test

* Address comment

* Retrigger CI

* Retrigger CI
  • Loading branch information
weberlo authored and icemelon committed Nov 15, 2019
1 parent 3f6b3db commit 2c5c4da
Show file tree
Hide file tree
Showing 14 changed files with 276 additions and 92 deletions.
14 changes: 7 additions & 7 deletions include/tvm/relay/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ class ModuleNode : public RelayNode {
*/
TVM_DLL bool ContainGlobalVar(const std::string& name) const;

/*!
* \brief Check if the global_type_var_map_ contains a global type variable.
* \param name The variable name.
* \returns true if contains, otherise false.
*/
TVM_DLL bool ContainGlobalTypeVar(const std::string& name) const;

/*!
* \brief Lookup a global function by its variable.
* \param str The unique string specifying the global variable.
Expand Down Expand Up @@ -198,13 +205,6 @@ class ModuleNode : public RelayNode {
*/
TVM_DLL TypeData LookupDef(const std::string& var) const;

/*!
* \brief Check if a global type definition exists
* \param var The name of the global type definition.
* \return Whether the definition exists.
*/
TVM_DLL bool HasDef(const std::string& var) const;

/*!
* \brief Look up a constructor by its tag.
* \param tag The tag for the constructor.
Expand Down
9 changes: 6 additions & 3 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -552,17 +552,20 @@ TVM_DLL Pass Legalize(const std::string& legalize_map_attr_name = "FTVMLegalize"
TVM_DLL Pass CanonicalizeCast();

/*!
* \brief Add abstraction over a function
* \brief Add abstraction over a constructor or global variable bound to a function.
*
* For example: `square` is transformed to
* `fun x -> square x`.
* `fn (%x: int32) -> int32 { square(x) }`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion
* for more details.
*
* \param expand_constructor Whether to expand constructors.
* \param expand_global_var Whether to expand global variables.
*
* \return The pass.
*/
TVM_DLL Pass EtaExpand();
TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var);

/*!
* \brief Print the IR for a module to help debugging.
Expand Down
15 changes: 3 additions & 12 deletions python/tvm/relay/std/prelude.rly
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,9 @@ def @sum(%xs: List[Tensor[(), int32]]) {
/*
* Concatenates two lists.
*/

def @concat[A](%xs: List[A], %ys: List[A]) -> List[A] {
let %updater = fn(%x: A, %xss: List[A]) -> List[A] {
Cons(%x, %xss)
};
@foldr(%updater, %ys, %xs)
// TODO(weberlo): write it like below, once VM constructor compilation is fixed
// @foldr(Cons, %ys, %xs)
@foldr(Cons, %ys, %xs)
}

/*
Expand Down Expand Up @@ -199,12 +195,7 @@ def @zip[A, B](%xs: List[A], %ys: List[B]) -> List[(A, B)] {
* Reverses a list.
*/
def @rev[A](%xs: List[A]) -> List[A] {
let %updater = fn(%xss: List[A], %x: A) -> List[A] {
Cons(%x, %xss)
};
@foldl(%updater, Nil, %xs)
// TODO(weberlo): write it like below, once VM constructor compilation is fixed
// @foldl(@flip(Cons), Nil, %xs)
@foldl(@flip(Cons), Nil, %xs)
}

/*
Expand Down
15 changes: 12 additions & 3 deletions python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,15 +529,23 @@ def ToCPS(expr, mod=None):
return _transform.to_cps(expr, mod)


def EtaExpand():
"""Add abstraction over a function
def EtaExpand(expand_constructor=False, expand_global_var=False):
"""Add abstraction over a constructor or global variable bound to a function
Parameters
----------
expand_constructor: bool
Whether to expand constructors.
expand_global_var: bool
Whether to expand global variables.
Returns
-------
ret: tvm.relay.Pass
The registered pass that eta expands an expression.
"""
return _transform.EtaExpand()
return _transform.EtaExpand(expand_constructor, expand_global_var)


def ToGraphNormalForm():
Expand Down Expand Up @@ -959,6 +967,7 @@ def create_function_pass(pass_arg):
return create_function_pass(pass_func)
return create_function_pass


@function_pass(opt_level=1)
class ChangeBatch:
"""
Expand Down
11 changes: 11 additions & 0 deletions src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/debug.h>
#include <tvm/relay/feature.h>
Expand Down Expand Up @@ -789,6 +790,16 @@ CreateInterpreter(
Module mod,
DLContext context,
Target target) {
if (mod.defined()) {
// eta expand to support constructors in argument position
transform::Sequential seq({
transform::EtaExpand(
/* expand_constructor */ true, /* expand_global_var */ false)});
transform::PassContext pass_ctx = transform::PassContext::Current();
tvm::With<transform::PassContext> ctx(pass_ctx);
mod = seq(mod);
}

auto intrp = std::make_shared<Interpreter>(mod, context, target);
auto packed = [intrp](Expr expr) {
auto f = DetectFeature(expr);
Expand Down
4 changes: 4 additions & 0 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,10 @@ Module VMCompiler::OptimizeModule(const Module& mod, const TargetsMap& targets)
pass_seqs.push_back(transform::Legalize());
}

// eta expand to support constructors in argument position
pass_seqs.push_back(transform::EtaExpand(
/* expand_constructor */ true, /* expand_global_var */ false));

pass_seqs.push_back(transform::SimplifyInference());
PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
Expr expr = args[0];
Expand Down
22 changes: 16 additions & 6 deletions src/relay/backend/vm/lambda_lift.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ Function MarkClosure(const Function& func) {
* We will lift a function out into a global which takes the set of the free
* vars and then return the new created function.
*/
struct LambdaLifter : ExprMutator {
Module module_;
class LambdaLifter : public ExprMutator {
public:
explicit LambdaLifter(const Module& module) : module_(module) {}

Expr VisitExpr_(const FunctionNode* func_node) final {
Expand Down Expand Up @@ -100,8 +100,8 @@ struct LambdaLifter : ExprMutator {
// The "inner" function should be used to generate the
// code for the closure.
Function lifted_func;
if (free_vars.size() == 0) {
lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, free_type_vars);
if (free_vars.size() == 0 && free_type_vars.size() == 0) {
lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, body->type_params);
} else {
lifted_func =
FunctionNode::make(free_vars, body, func->func_type_annotation(), free_type_vars);
Expand All @@ -114,8 +114,15 @@ struct LambdaLifter : ExprMutator {
auto name = GenerateName(lifted_func);
auto global = GlobalVarNode::make(name);

// Add the lifted function to the module.
module_->Add(global, lifted_func);
if (module_->ContainGlobalVar(name)) {
const auto existing_func = module_->Lookup(name);
CHECK(AlphaEqual(lifted_func, existing_func)) << "lifted function hash collision";
// If an identical function already exists, use its global var.
global = module_->GetGlobalVar(name);
} else {
// Add the lifted function to the module.
module_->Add(global, lifted_func);
}

if (free_vars.size() == 0) {
return std::move(global);
Expand Down Expand Up @@ -145,6 +152,9 @@ struct LambdaLifter : ExprMutator {
}
return module_;
}

private:
Module module_;
};

} // namespace vm
Expand Down
2 changes: 1 addition & 1 deletion src/relay/ir/alpha_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class AlphaEqualHandler:
}
if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
for (const auto& p : lhsm->type_definitions) {
if (!rhsm->HasDef(p.first->var->name_hint) ||
if (!rhsm->ContainGlobalTypeVar(p.first->var->name_hint) ||
!Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) {
return false;
}
Expand Down
12 changes: 6 additions & 6 deletions src/relay/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ bool ModuleNode::ContainGlobalVar(const std::string& name) const {
return global_var_map_.find(name) != global_var_map_.end();
}

bool ModuleNode::ContainGlobalTypeVar(const std::string& name) const {
return global_type_var_map_.find(name) != global_type_var_map_.end();
}

GlobalVar ModuleNode::GetGlobalVar(const std::string& name) const {
auto it = global_var_map_.find(name);
CHECK(it != global_var_map_.end())
Expand Down Expand Up @@ -239,11 +243,6 @@ TypeData ModuleNode::LookupDef(const std::string& name) const {
return this->LookupDef(id);
}

bool ModuleNode::HasDef(const std::string& name) const {
auto it = global_type_var_map_.find(name);
return it != global_type_var_map_.end();
}

Constructor ModuleNode::LookupTag(const int32_t tag) {
auto it = constructor_tag_map_.find(tag);
CHECK(it != constructor_tag_map_.end())
Expand Down Expand Up @@ -336,7 +335,8 @@ TVM_REGISTER_API("relay._module.Module_Add")
} else if (val->IsInstance<GlobalVarNode>()) {
GlobalVar gv = Downcast<GlobalVar>(val);
auto mod_copy = Module(make_node<ModuleNode>(*mod.operator->()));
mod_copy = transform::EtaExpand()(mod_copy);
mod_copy = transform::EtaExpand(
/* expand_constructor */ false, /* expand_global_var */ true)(mod_copy);
auto func = mod_copy->Lookup(gv->name_hint);
mod->Add(var, Downcast<Function>(func), update);
} else {
Expand Down
6 changes: 5 additions & 1 deletion src/relay/ir/pretty_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ class PrettyPrinter :
Doc VisitExpr_(const ConstructorNode* n) final {
Doc doc;
doc << n->name_hint;
if (n->inputs.size() != 0) {
if (in_adt_def_ && n->inputs.size() != 0) {
doc << "(";
std::vector<Doc> inputs;
for (Type input : n->inputs) {
Expand Down Expand Up @@ -775,6 +775,7 @@ class PrettyPrinter :
}

Doc VisitType_(const TypeDataNode* node) final {
in_adt_def_ = true;
Doc doc;
doc << "type " << Print(node->header);

Expand Down Expand Up @@ -802,6 +803,7 @@ class PrettyPrinter :
adt_body << ",";
}
doc << Brace(adt_body);
in_adt_def_ = false;
return doc;
}

Expand Down Expand Up @@ -876,6 +878,8 @@ class PrettyPrinter :
TextMetaDataContext meta_;
/*! \brief counter of temporary variable */
size_t temp_var_counter_{0};
/*! \brief whether the printer is currently in an ADT definition */
bool in_adt_def_;
/*! \brief arena for dependency graph */
common::Arena arena_;
/*! \brief dependency graph of the expr */
Expand Down
Loading

0 comments on commit 2c5c4da

Please sign in to comment.