Skip to content

Commit

Permalink
[Relay] fix error in ANF (too agressively inline atomic expression an…
Browse files Browse the repository at this point in the history
…d create free variable). (apache#2665)
  • Loading branch information
MarisaKirisame authored and wweic committed Mar 9, 2019
1 parent 741b222 commit 78d4427
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 9 deletions.
30 changes: 21 additions & 9 deletions src/relay/pass/to_a_normal_form.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,10 @@ bool IsPrimitiveFunction(const Expr& e) {
return e.as<FunctionNode>() && Downcast<Function>(e)->IsPrimitive();
}

/* Special care is needed to handle local recursion.
* Fill additionally take a (possibly null) Var argument,
* If it is not null, Fill is required to bind the transformed result to that var.
*/
class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
public:
static Expr ToANormalForm(const Expr& e,
Expand Down Expand Up @@ -307,12 +311,18 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
}

Expr VisitExpr(const Expr& e) {
Var v = VarNode::make(std::string("x"), IncompleteTypeNode::make(Kind::kType));
return this->VisitExpr(e, v);
return this->VisitExpr(e, Var());
}

Expr Atomic(const Expr& orig, const Expr& now, const Var& v) {
return v.defined() ? GetScope(orig)->ll->Push(v, now) : now;
}

Expr Compound(const Expr& orig, const Expr& now, const Var& v) {
return GetScope(orig)->ll->Push(v, now);
Var var = v.defined() ?
v :
VarNode::make(std::string("x"), IncompleteTypeNode::make(Kind::kType));
return GetScope(orig)->ll->Push(var, now);
}

Expr VisitExpr_(const CallNode* c, const Var& v) final {
Expand Down Expand Up @@ -389,7 +399,8 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
}

Expr VisitExpr_(const VarNode* vn, const Var& v) final {
return GetRef<Expr>(vn);
Expr e = GetRef<Expr>(vn);
return Atomic(e, e, v);
}

Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final {
Expand All @@ -398,15 +409,17 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
visited_->insert(gv);
mod_->Update(gv, Downcast<Function>(relay::ToANormalForm(mod_->Lookup(gv), mod_, visited_)));
}
return std::move(gv);
return Atomic(gv, gv, v);
}

Expr VisitExpr_(const OpNode* op, const Var& v) final {
return GetRef<Expr>(op);
Expr e = GetRef<Expr>(op);
return Atomic(e, e, v);
}

Expr VisitExpr_(const ConstructorNode* c, const Var& v) final {
return GetRef<Expr>(c);
Expr e = GetRef<Expr>(c);
return Atomic(e, e, v);
}

Expr VisitExpr_(const MatchNode* m, const Var& v) final {
Expand All @@ -418,8 +431,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
c->lhs,
GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs))));
}
Expr r = Compound(e, MatchNode::make(data, clauses), v);
return r;
return Compound(e, MatchNode::make(data, clauses), v);
}
};

Expand Down
10 changes: 10 additions & 0 deletions tests/python/relay/test_to_a_normal_form.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,20 @@ def test_add():
assert count(intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2
assert "let" in mod[add].astext()

def test_let():
x = relay.Var("x")
y = relay.Var("y")
d = relay.const(4.0, 'float32')
body = relay.Let(y, x, x + y)
body = relay.Let(x, d, body)
check_eval(body, 8)
check_eval(to_a_normal_form(body), 8)

if __name__ == '__main__':
test_explicit_bound()
test_order()
test_if()
test_recursion()
test_ref()
test_add()
test_let()

0 comments on commit 78d4427

Please sign in to comment.