From bba3b6db29bc717e453b61cccad6e669ef47dfd1 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Fri, 22 Feb 2019 21:49:52 -0800 Subject: [PATCH 1/4] save --- src/relay/pass/to_a_normal_form.cc | 82 +++++++++++---------- tests/python/relay/test_to_a_normal_form.py | 10 +++ 2 files changed, 55 insertions(+), 37 deletions(-) diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc index 53e2c1c594f8..2cd4caa97831 100644 --- a/src/relay/pass/to_a_normal_form.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -256,7 +256,8 @@ bool IsPrimitiveFunction(const Expr& e) { return e.as() && Downcast(e)->IsPrimitive(); } -class Fill : ExprFunctor { +using FlushTo = std::shared_ptr; // if it is defined, always flush current expr into it. +class Fill : ExprFunctor { public: static Expr ToANormalForm(const Expr& e, const Module& m, @@ -299,69 +300,73 @@ class Fill : ExprFunctor { return node_scope_->at(h->value); } - Expr VisitExpr(const Expr& e, const Var& v) final { + Expr VisitExpr(const Expr& e, const FlushTo& ft) final { if (memo.count(e) == 0) { - memo.insert({e, ExprFunctor::VisitExpr(e, v)}); + memo.insert({e, ExprFunctor::VisitExpr(e, ft)}); } return memo.at(e); } Expr VisitExpr(const Expr& e) { - Var v = VarNode::make(std::string("x"), IncompleteTypeNode::make(Kind::kType)); - return this->VisitExpr(e, v); + return this->VisitExpr(e, FlushTo()); } - Expr Compound(const Expr& orig, const Expr& now, const Var& v) { + Expr Atomic(const Expr& orig, const Expr& now, const FlushTo& ft) { + return ft ? GetScope(orig)->ll->Push(*ft, now) : now; + } + + Expr Compound(const Expr& orig, const Expr& now, const FlushTo& ft) { + Var v = ft ? *ft : VarNode::make(std::string("x"), IncompleteTypeNode::make(Kind::kType)); return GetScope(orig)->ll->Push(v, now); } - Expr VisitExpr_(const CallNode* c, const Var& v) final { + Expr VisitExpr_(const CallNode* c, const FlushTo& ft) final { Expr e = GetRef(c); std::vector args; for (const auto& a : c->args) { args.push_back(VisitExpr(a)); } - return Compound(e, CallNode::make(VisitExpr(c->op), args, c->attrs, c->type_args), v); + return Compound(e, CallNode::make(VisitExpr(c->op), args, c->attrs, c->type_args), ft); } - Expr VisitExpr_(const TupleNode* t, const Var& v) final { + Expr VisitExpr_(const TupleNode* t, const FlushTo& ft) final { Expr e = GetRef(t); std::vector fields; for (const auto& a : t->fields) { fields.push_back(VisitExpr(a)); } - return Compound(e, TupleNode::make(fields), v); + return Compound(e, TupleNode::make(fields), ft); } - Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final { + Expr VisitExpr_(const TupleGetItemNode* t, const FlushTo& ft) final { Expr e = GetRef(t); - return Compound(e, TupleGetItemNode::make(VisitExpr(t->tuple), t->index), v); + return Compound(e, TupleGetItemNode::make(VisitExpr(t->tuple), t->index), ft); } - Expr VisitExpr_(const RefCreateNode* r, const Var& v) final { + Expr VisitExpr_(const RefCreateNode* r, const FlushTo& ft) final { Expr e = GetRef(r); - return Compound(e, RefCreateNode::make(VisitExpr(r->value)), v); + return Compound(e, RefCreateNode::make(VisitExpr(r->value)), ft); } - Expr VisitExpr_(const RefReadNode* r, const Var& v) final { + Expr VisitExpr_(const RefReadNode* r, const FlushTo& ft) final { Expr e = GetRef(r); - return Compound(e, RefReadNode::make(VisitExpr(r->ref)), v); + return Compound(e, RefReadNode::make(VisitExpr(r->ref)), ft); } - Expr VisitExpr_(const RefWriteNode* r, const Var& v) final { + Expr VisitExpr_(const RefWriteNode* r, const FlushTo& ft) final { Expr e = GetRef(r); - return Compound(e, RefWriteNode::make(VisitExpr(r->ref), VisitExpr(r->value)), v); + return Compound(e, RefWriteNode::make(VisitExpr(r->ref), VisitExpr(r->value)), ft); } - Expr VisitExpr_(const IfNode* i, const Var& v) final { + Expr VisitExpr_(const IfNode* i, const FlushTo& ft) final { Expr e = GetRef(i); Expr ret = IfNode::make(VisitExpr(i->cond), GetSubScope(e, 1)->ll->Get(VisitExpr(i->true_branch)), GetSubScope(e, 2)->ll->Get(VisitExpr(i->false_branch))); - return Compound(e, ret, v); + return Compound(e, ret, ft); } - Expr VisitExpr_(const FunctionNode* f, const Var& v) final { + Expr VisitExpr_(const FunctionNode* f, const FlushTo& ft) final { Expr e = GetRef(f); Expr ret; if (IsPrimitiveFunction(e)) { @@ -373,43 +378,46 @@ class Fill : ExprFunctor { f->type_params, f->attrs); } - return Compound(e, ret, v); + return Compound(e, ret, ft); } - Expr VisitExpr_(const LetNode* l, const Var& v) final { + Expr VisitExpr_(const LetNode* l, const FlushTo& ft) final { Expr e = GetRef(l); - VisitExpr(l->value, l->var); + VisitExpr(l->value, std::make_shared(l->var)); Expr ret = GetSubScope(e, 0)->ll->Get(VisitExpr(l->body)); - return Compound(e, ret, v); + return Compound(e, ret, ft); } - Expr VisitExpr_(const ConstantNode* c, const Var& v) final { + Expr VisitExpr_(const ConstantNode* c, const FlushTo& ft) final { Expr e = GetRef(c); - return Compound(e, e, v); + return Compound(e, e, ft); } - Expr VisitExpr_(const VarNode* vn, const Var& v) final { - return GetRef(vn); + Expr VisitExpr_(const VarNode* vn, const FlushTo& ft) final { + Expr e = GetRef(vn); + return Atomic(e, e, ft); } - Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final { + Expr VisitExpr_(const GlobalVarNode* gvn, const FlushTo& ft) final { GlobalVar gv = GetRef(gvn); if (visited_->count(gv) == 0) { visited_->insert(gv); mod_->Update(gv, Downcast(relay::ToANormalForm(mod_->Lookup(gv), mod_, visited_))); } - return gv; + return Atomic(gv, gv, ft); } - Expr VisitExpr_(const OpNode* op, const Var& v) final { - return GetRef(op); + Expr VisitExpr_(const OpNode* op, const FlushTo& ft) final { + Expr e = GetRef(op); + return Atomic(e, e, ft); } - Expr VisitExpr_(const ConstructorNode* c, const Var& v) final { - return GetRef(c); + Expr VisitExpr_(const ConstructorNode* c, const FlushTo& ft) final { + Expr e = GetRef(c); + return Atomic(e, e, ft); } - Expr VisitExpr_(const MatchNode* m, const Var& v) final { + Expr VisitExpr_(const MatchNode* m, const FlushTo& ft) final { Expr e = GetRef(m); Expr data = VisitExpr(m->data); std::vector clauses; @@ -418,7 +426,7 @@ class Fill : ExprFunctor { c->lhs, GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs)))); } - Expr r = Compound(e, MatchNode::make(data, clauses), v); + Expr r = Compound(e, MatchNode::make(data, clauses), ft); return r; } }; diff --git a/tests/python/relay/test_to_a_normal_form.py b/tests/python/relay/test_to_a_normal_form.py index c15dc8ffc269..392e1769e57d 100644 --- a/tests/python/relay/test_to_a_normal_form.py +++ b/tests/python/relay/test_to_a_normal_form.py @@ -138,6 +138,15 @@ def test_add(): assert count(intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2 assert "let" in mod[add].astext() +def test_let(): + x = relay.Var("x") + y = relay.Var("y") + d = relay.const(4.0, 'float32') + body = relay.Let(y, x, x + y) + body = relay.Let(x, d, body) + check_eval(body, 8) + check_eval(to_a_normal_form(body), 8) + if __name__ == '__main__': test_explicit_bound() test_order() @@ -145,3 +154,4 @@ def test_add(): test_recursion() test_ref() test_add() + test_let() From c2e5371d7644eff1027f6a2c56fa672927d1f226 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Fri, 22 Feb 2019 21:53:39 -0800 Subject: [PATCH 2/4] lint --- src/relay/pass/to_a_normal_form.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc index 2cd4caa97831..1814da95ebe3 100644 --- a/src/relay/pass/to_a_normal_form.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -256,7 +256,7 @@ bool IsPrimitiveFunction(const Expr& e) { return e.as() && Downcast(e)->IsPrimitive(); } -using FlushTo = std::shared_ptr; // if it is defined, always flush current expr into it. +using FlushTo = std::shared_ptr; // if it is defined, always flush current expr into it. class Fill : ExprFunctor { public: static Expr ToANormalForm(const Expr& e, From cf4fd9592a8a2c183371241949747e36ff040207 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Mon, 25 Feb 2019 12:44:01 -0800 Subject: [PATCH 3/4] address comment --- src/relay/pass/to_a_normal_form.cc | 86 +++++++++++++++--------------- 1 file changed, 44 insertions(+), 42 deletions(-) diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc index 1814da95ebe3..285456436c5d 100644 --- a/src/relay/pass/to_a_normal_form.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -256,8 +256,11 @@ bool IsPrimitiveFunction(const Expr& e) { return e.as() && Downcast(e)->IsPrimitive(); } -using FlushTo = std::shared_ptr; // if it is defined, always flush current expr into it. -class Fill : ExprFunctor { +/* Special care is needed to handle local recursion. + * Fill additionally take a (possibly null) Var argument, + * If it is not null, Fill is required to bind the transformed result to that var. + */ +class Fill : ExprFunctor { public: static Expr ToANormalForm(const Expr& e, const Module& m, @@ -300,73 +303,73 @@ class Fill : ExprFunctor { return node_scope_->at(h->value); } - Expr VisitExpr(const Expr& e, const FlushTo& ft) final { + Expr VisitExpr(const Expr& e, const Var& v) final { if (memo.count(e) == 0) { - memo.insert({e, ExprFunctor::VisitExpr(e, ft)}); + memo.insert({e, ExprFunctor::VisitExpr(e, v)}); } return memo.at(e); } Expr VisitExpr(const Expr& e) { - return this->VisitExpr(e, FlushTo()); + return this->VisitExpr(e, Var()); } - Expr Atomic(const Expr& orig, const Expr& now, const FlushTo& ft) { - return ft ? GetScope(orig)->ll->Push(*ft, now) : now; + Expr Atomic(const Expr& orig, const Expr& now, const Var& v) { + return v.defined() ? GetScope(orig)->ll->Push(v, now) : now; } - Expr Compound(const Expr& orig, const Expr& now, const FlushTo& ft) { - Var v = ft ? *ft : VarNode::make(std::string("x"), IncompleteTypeNode::make(Kind::kType)); - return GetScope(orig)->ll->Push(v, now); + Expr Compound(const Expr& orig, const Expr& now, const Var& v) { + Var var = v.defined() ? v : VarNode::make(std::string("x"), IncompleteTypeNode::make(Kind::kType)); + return GetScope(orig)->ll->Push(var, now); } - Expr VisitExpr_(const CallNode* c, const FlushTo& ft) final { + Expr VisitExpr_(const CallNode* c, const Var& v) final { Expr e = GetRef(c); std::vector args; for (const auto& a : c->args) { args.push_back(VisitExpr(a)); } - return Compound(e, CallNode::make(VisitExpr(c->op), args, c->attrs, c->type_args), ft); + return Compound(e, CallNode::make(VisitExpr(c->op), args, c->attrs, c->type_args), v); } - Expr VisitExpr_(const TupleNode* t, const FlushTo& ft) final { + Expr VisitExpr_(const TupleNode* t, const Var& v) final { Expr e = GetRef(t); std::vector fields; for (const auto& a : t->fields) { fields.push_back(VisitExpr(a)); } - return Compound(e, TupleNode::make(fields), ft); + return Compound(e, TupleNode::make(fields), v); } - Expr VisitExpr_(const TupleGetItemNode* t, const FlushTo& ft) final { + Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final { Expr e = GetRef(t); - return Compound(e, TupleGetItemNode::make(VisitExpr(t->tuple), t->index), ft); + return Compound(e, TupleGetItemNode::make(VisitExpr(t->tuple), t->index), v); } - Expr VisitExpr_(const RefCreateNode* r, const FlushTo& ft) final { + Expr VisitExpr_(const RefCreateNode* r, const Var& v) final { Expr e = GetRef(r); - return Compound(e, RefCreateNode::make(VisitExpr(r->value)), ft); + return Compound(e, RefCreateNode::make(VisitExpr(r->value)), v); } - Expr VisitExpr_(const RefReadNode* r, const FlushTo& ft) final { + Expr VisitExpr_(const RefReadNode* r, const Var& v) final { Expr e = GetRef(r); - return Compound(e, RefReadNode::make(VisitExpr(r->ref)), ft); + return Compound(e, RefReadNode::make(VisitExpr(r->ref)), v); } - Expr VisitExpr_(const RefWriteNode* r, const FlushTo& ft) final { + Expr VisitExpr_(const RefWriteNode* r, const Var& v) final { Expr e = GetRef(r); - return Compound(e, RefWriteNode::make(VisitExpr(r->ref), VisitExpr(r->value)), ft); + return Compound(e, RefWriteNode::make(VisitExpr(r->ref), VisitExpr(r->value)), v); } - Expr VisitExpr_(const IfNode* i, const FlushTo& ft) final { + Expr VisitExpr_(const IfNode* i, const Var& v) final { Expr e = GetRef(i); Expr ret = IfNode::make(VisitExpr(i->cond), GetSubScope(e, 1)->ll->Get(VisitExpr(i->true_branch)), GetSubScope(e, 2)->ll->Get(VisitExpr(i->false_branch))); - return Compound(e, ret, ft); + return Compound(e, ret, v); } - Expr VisitExpr_(const FunctionNode* f, const FlushTo& ft) final { + Expr VisitExpr_(const FunctionNode* f, const Var& v) final { Expr e = GetRef(f); Expr ret; if (IsPrimitiveFunction(e)) { @@ -378,46 +381,46 @@ class Fill : ExprFunctor { f->type_params, f->attrs); } - return Compound(e, ret, ft); + return Compound(e, ret, v); } - Expr VisitExpr_(const LetNode* l, const FlushTo& ft) final { + Expr VisitExpr_(const LetNode* l, const Var& v) final { Expr e = GetRef(l); - VisitExpr(l->value, std::make_shared(l->var)); + VisitExpr(l->value, l->var); Expr ret = GetSubScope(e, 0)->ll->Get(VisitExpr(l->body)); - return Compound(e, ret, ft); + return Compound(e, ret, v); } - Expr VisitExpr_(const ConstantNode* c, const FlushTo& ft) final { + Expr VisitExpr_(const ConstantNode* c, const Var& v) final { Expr e = GetRef(c); - return Compound(e, e, ft); + return Compound(e, e, v); } - Expr VisitExpr_(const VarNode* vn, const FlushTo& ft) final { + Expr VisitExpr_(const VarNode* vn, const Var& v) final { Expr e = GetRef(vn); - return Atomic(e, e, ft); + return Atomic(e, e, v); } - Expr VisitExpr_(const GlobalVarNode* gvn, const FlushTo& ft) final { + Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final { GlobalVar gv = GetRef(gvn); if (visited_->count(gv) == 0) { visited_->insert(gv); mod_->Update(gv, Downcast(relay::ToANormalForm(mod_->Lookup(gv), mod_, visited_))); } - return Atomic(gv, gv, ft); + return Atomic(gv, gv, v); } - Expr VisitExpr_(const OpNode* op, const FlushTo& ft) final { + Expr VisitExpr_(const OpNode* op, const Var& v) final { Expr e = GetRef(op); - return Atomic(e, e, ft); + return Atomic(e, e, v); } - Expr VisitExpr_(const ConstructorNode* c, const FlushTo& ft) final { + Expr VisitExpr_(const ConstructorNode* c, const Var& v) final { Expr e = GetRef(c); - return Atomic(e, e, ft); + return Atomic(e, e, v); } - Expr VisitExpr_(const MatchNode* m, const FlushTo& ft) final { + Expr VisitExpr_(const MatchNode* m, const Var& v) final { Expr e = GetRef(m); Expr data = VisitExpr(m->data); std::vector clauses; @@ -426,8 +429,7 @@ class Fill : ExprFunctor { c->lhs, GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs)))); } - Expr r = Compound(e, MatchNode::make(data, clauses), ft); - return r; + return Compound(e, MatchNode::make(data, clauses), v); } }; From bd95f519a981d24aaf1bedf50bb00128b1bfba3b Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Mon, 25 Feb 2019 13:53:54 -0800 Subject: [PATCH 4/4] lint --- src/relay/pass/to_a_normal_form.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc index 285456436c5d..46a4b92ac9b9 100644 --- a/src/relay/pass/to_a_normal_form.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -319,7 +319,9 @@ class Fill : ExprFunctor { } Expr Compound(const Expr& orig, const Expr& now, const Var& v) { - Var var = v.defined() ? v : VarNode::make(std::string("x"), IncompleteTypeNode::make(Kind::kType)); + Var var = v.defined() ? + v : + VarNode::make(std::string("x"), IncompleteTypeNode::make(Kind::kType)); return GetScope(orig)->ll->Push(var, now); }