From 78d442717d56d628d84c572670b0f539185d9f15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Mon, 25 Feb 2019 18:40:46 -0800 Subject: [PATCH] [Relay] fix error in ANF (too agressively inline atomic expression and create free variable). (#2665) --- src/relay/pass/to_a_normal_form.cc | 30 ++++++++++++++------- tests/python/relay/test_to_a_normal_form.py | 10 +++++++ 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc index e5da2dee2e03..46a4b92ac9b9 100644 --- a/src/relay/pass/to_a_normal_form.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -256,6 +256,10 @@ bool IsPrimitiveFunction(const Expr& e) { return e.as() && Downcast(e)->IsPrimitive(); } +/* Special care is needed to handle local recursion. + * Fill additionally take a (possibly null) Var argument, + * If it is not null, Fill is required to bind the transformed result to that var. + */ class Fill : ExprFunctor { public: static Expr ToANormalForm(const Expr& e, @@ -307,12 +311,18 @@ class Fill : ExprFunctor { } Expr VisitExpr(const Expr& e) { - Var v = VarNode::make(std::string("x"), IncompleteTypeNode::make(Kind::kType)); - return this->VisitExpr(e, v); + return this->VisitExpr(e, Var()); + } + + Expr Atomic(const Expr& orig, const Expr& now, const Var& v) { + return v.defined() ? GetScope(orig)->ll->Push(v, now) : now; } Expr Compound(const Expr& orig, const Expr& now, const Var& v) { - return GetScope(orig)->ll->Push(v, now); + Var var = v.defined() ? + v : + VarNode::make(std::string("x"), IncompleteTypeNode::make(Kind::kType)); + return GetScope(orig)->ll->Push(var, now); } Expr VisitExpr_(const CallNode* c, const Var& v) final { @@ -389,7 +399,8 @@ class Fill : ExprFunctor { } Expr VisitExpr_(const VarNode* vn, const Var& v) final { - return GetRef(vn); + Expr e = GetRef(vn); + return Atomic(e, e, v); } Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final { @@ -398,15 +409,17 @@ class Fill : ExprFunctor { visited_->insert(gv); mod_->Update(gv, Downcast(relay::ToANormalForm(mod_->Lookup(gv), mod_, visited_))); } - return std::move(gv); + return Atomic(gv, gv, v); } Expr VisitExpr_(const OpNode* op, const Var& v) final { - return GetRef(op); + Expr e = GetRef(op); + return Atomic(e, e, v); } Expr VisitExpr_(const ConstructorNode* c, const Var& v) final { - return GetRef(c); + Expr e = GetRef(c); + return Atomic(e, e, v); } Expr VisitExpr_(const MatchNode* m, const Var& v) final { @@ -418,8 +431,7 @@ class Fill : ExprFunctor { c->lhs, GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs)))); } - Expr r = Compound(e, MatchNode::make(data, clauses), v); - return r; + return Compound(e, MatchNode::make(data, clauses), v); } }; diff --git a/tests/python/relay/test_to_a_normal_form.py b/tests/python/relay/test_to_a_normal_form.py index c15dc8ffc269..392e1769e57d 100644 --- a/tests/python/relay/test_to_a_normal_form.py +++ b/tests/python/relay/test_to_a_normal_form.py @@ -138,6 +138,15 @@ def test_add(): assert count(intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2 assert "let" in mod[add].astext() +def test_let(): + x = relay.Var("x") + y = relay.Var("y") + d = relay.const(4.0, 'float32') + body = relay.Let(y, x, x + y) + body = relay.Let(x, d, body) + check_eval(body, 8) + check_eval(to_a_normal_form(body), 8) + if __name__ == '__main__': test_explicit_bound() test_order() @@ -145,3 +154,4 @@ def test_add(): test_recursion() test_ref() test_add() + test_let()