Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Passes] Iterative A-normal Traversals #7374

Merged
merged 5 commits into from
Feb 2, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,10 @@ void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_l
}
}
}

void ExpandANormalForm(const LetNode* op, std::function<void(const LetNode*)> pre_visit,
std::function<void(const LetNode*)> post_visit);

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_FUNCTOR_H_
12 changes: 12 additions & 0 deletions src/relay/analysis/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,18 @@ class TypeVarEVisitor : private MixedModeVisitor {
ExprVisitor::VisitExpr_(f);
}

void VisitExpr_(const LetNode* op) final {
auto pre_visit = [this](const LetNode* op) {
this->VisitExpr(op->var);
this->VisitExpr(op->value);
};
auto post_visit = [this](const LetNode* op) {
this->VisitExpr(op->body);
this->visit_counter_[op] += 1;
};
ExpandANormalForm(op, pre_visit, post_visit);
}

void VisitExpr_(const ConstructorNode* cn) final {
// for constructors, type vars will be bound in the module
auto data = mod_->LookupTypeDef(cn->belong_to);
Expand Down
22 changes: 22 additions & 0 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -532,5 +532,27 @@ TVM_REGISTER_GLOBAL("relay.ir.Bind").set_body([](TVMArgs args, TVMRetValue* ret)
*ret = Bind(Downcast<Type>(input), args[1]);
}
});

void ExpandANormalForm(const LetNode* op, std::function<void(const LetNode*)> pre_visit,
std::function<void(const LetNode*)> post_visit) {
std::stack<const LetNode*> stack;
stack.push(op);
bool is_anormal = true;
while (is_anormal) {
const LetNode* current_op = stack.top();
pre_visit(current_op);
if (const LetNode* new_op = current_op->body.as<LetNode>()) {
stack.push(new_op);
} else {
is_anormal = false;
}
}
while (stack.size()) {
const LetNode* current_op = stack.top();
stack.pop();
post_visit(current_op);
}
}

} // namespace relay
} // namespace tvm
19 changes: 16 additions & 3 deletions src/relay/transforms/de_duplicate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>

#include <stack>

namespace tvm {
namespace relay {

Expand Down Expand Up @@ -61,8 +63,19 @@ Expr DeDup(const Expr& e) {
}

Expr VisitExpr_(const LetNode* op) final {
Var v = Fresh(op->var);
return Let(v, VisitExpr(op->value), VisitExpr(op->body));
std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> new_vars;
auto pre_visit = [this, &new_vars](const LetNode* op) {
Expr expr = GetRef<Expr>(op);
new_vars[expr] = Fresh(op->var);
// Rely on the Memoizer to cache pre-visit values
VisitExpr(op->value);
mbrookhart marked this conversation as resolved.
Show resolved Hide resolved
};
auto post_visit = [this, &new_vars](const LetNode* op) {
Expr expr = GetRef<Expr>(op);
memo_[expr] = Let(new_vars[expr], VisitExpr(op->value), VisitExpr(op->body));
mbrookhart marked this conversation as resolved.
Show resolved Hide resolved
};
ExpandANormalForm(op, pre_visit, post_visit);
return memo_[GetRef<Expr>(op)];
}

Type VisitType(const Type& t) final { return t.defined() ? TypeMutator::VisitType(t) : t; }
Expand Down Expand Up @@ -99,7 +112,7 @@ Expr DeDup(const Expr& e) {
ICHECK(WellFormed(ret));
ICHECK_EQ(FreeVars(e).size(), FreeVars(ret).size());
return ret;
}
} // namespace relay

TVM_REGISTER_GLOBAL("relay._transform.dedup").set_body_typed(DeDup);

Expand Down
36 changes: 25 additions & 11 deletions src/relay/transforms/fold_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,19 +92,33 @@ class ConstantFolder : public MixedModeMutator {
using MixedModeMutator::VisitExpr_;

Expr VisitExpr_(const LetNode* op) final {
Expr value = this->Mutate(op->value);
if (value.as<ConstantNode>()) {
memo_[op->var] = value;
return this->Mutate(op->body);
} else {
Var var = Downcast<Var>(this->Mutate(op->var));
Expr body = this->Mutate(op->body);
if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Expr>(op);
auto pre_visit = [this](const LetNode* op) {
// Rely on the Memoizer to cache pre-visit values
Expr value = this->Mutate(op->value);
if (value.as<ConstantNode>()) {
memo_[op->var] = value;
mbrookhart marked this conversation as resolved.
Show resolved Hide resolved
} else {
return Let(var, value, body);
this->Mutate(op->var);
}
}
};
auto post_visit = [this](const LetNode* op) {
Expr expr = GetRef<Expr>(op);
// Rely on the Memoizer to cache pre-visit values
Expr value = this->Mutate(op->value);
if (value.as<ConstantNode>()) {
memo_[expr] = this->Mutate(op->body);
mbrookhart marked this conversation as resolved.
Show resolved Hide resolved
} else {
Var var = Downcast<Var>(this->Mutate(op->var));
Expr body = this->Mutate(op->body);
if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) {
memo_[expr] = expr;
mbrookhart marked this conversation as resolved.
Show resolved Hide resolved
} else {
memo_[expr] = Let(var, value, body);
mbrookhart marked this conversation as resolved.
Show resolved Hide resolved
}
}
};
ExpandANormalForm(op, pre_visit, post_visit);
return memo_[GetRef<Expr>(op)];
}

bool inside_primitive = false;
Expand Down
52 changes: 43 additions & 9 deletions src/relay/transforms/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -315,11 +315,20 @@ class IndexedForwardGraph::Creator : private ExprVisitor {

void VisitExpr_(const LetNode* op) final {
// do not fuse through let.
this->Update(op->var, nullptr, kOpaque);
this->Update(op->value, nullptr, kOpaque);
this->Update(op->body, nullptr, kOpaque);
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
auto pre_visit = [this](const LetNode* op) {
// Rely on the Memoizer to cache pre-visit values
this->Update(op->var, nullptr, kOpaque);
this->Update(op->value, nullptr, kOpaque);
this->Update(op->body, nullptr, kOpaque);
VisitExpr(op->var);
VisitExpr(op->value);
};
auto post_visit = [this](const LetNode* op) {
VisitExpr(op->body);
visit_counter_[op] += 1;
this->AddNode(op);
mbrookhart marked this conversation as resolved.
Show resolved Hide resolved
};
ExpandANormalForm(op, pre_visit, post_visit);
}

void VisitExpr_(const IfNode* op) final {
Expand Down Expand Up @@ -797,7 +806,7 @@ std::vector<GraphPartitioner::Group*> GraphPartitioner::Partition(
return std::move(groups_);
}

class FuseMutator : private ExprMutator {
class FuseMutator : private MixedModeMutator {
public:
// Run the transform
Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth) {
Expand All @@ -814,6 +823,8 @@ class FuseMutator : private ExprMutator {
}

private:
using MixedModeMutator::VisitExpr_;

/*! \brief Temporary information from each group. */
struct GroupInfo {
public:
Expand Down Expand Up @@ -853,7 +864,7 @@ class FuseMutator : private ExprMutator {
}

// Transform calls.
Expr VisitExpr_(const CallNode* call) {
Expr Rewrite_(const CallNode* call, const Expr& post) {
if (call->op.as<OpNode>()) {
static auto fnoncomputational = Op::GetAttrMap<TNonComputational>("TNonComputational");

Expand Down Expand Up @@ -886,7 +897,7 @@ class FuseMutator : private ExprMutator {
}
}

Expr VisitExpr_(const TupleNode* tuple) {
Expr Rewrite_(const TupleNode* tuple, const Expr& post) {
auto* ret_group = gmap_.at(tuple)->FindRoot();
if (ret_group->root_ref == tuple) {
return ExprMutator::VisitExpr_(tuple);
Expand All @@ -896,7 +907,7 @@ class FuseMutator : private ExprMutator {
return Tuple(new_fields);
}

Expr VisitExpr_(const TupleGetItemNode* tuple_get) {
Expr Rewrite_(const TupleGetItemNode* tuple_get, const Expr& post) {
auto* ret_group = gmap_.at(tuple_get)->FindRoot();
auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0];
auto new_node = TupleGetItem(new_tuple, tuple_get->index);
Expand All @@ -913,6 +924,29 @@ class FuseMutator : private ExprMutator {
return std::move(new_node);
}

Expr VisitExpr_(const LetNode* op) final {
auto pre_visit = [this](const LetNode* op) {
// Rely on the Memoizer to cache pre-visit values
this->VisitExpr(op->var);
this->VisitExpr(op->value);
};
auto post_visit = [this](const LetNode* op) {
// Rely on the Memoizer to cache pre-visit values
Var var = Downcast<Var>(VisitExpr(op->var));
Expr value = VisitExpr(op->value);
// Visit body and cache the op
Expr body = VisitExpr(op->body);
auto expr = GetRef<Expr>(op);
if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) {
memo_[expr] = expr;
} else {
memo_[expr] = Let(var, value, body);
}
};
ExpandANormalForm(op, pre_visit, post_visit);
return memo_[GetRef<Expr>(op)];
}

Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
// If the function has no call, it is not a primitive function.
struct HasCallVisitor : ExprVisitor {
Expand Down
70 changes: 52 additions & 18 deletions src/relay/transforms/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -341,26 +341,34 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
Type VisitExpr_(const OpNode* op) final { return op->op_type; }

Type VisitExpr_(const LetNode* let) final {
// if the definition is a function literal, permit recursion
bool is_functional_literal = let->value.as<FunctionNode>() != nullptr;
Type let_type = IncompleteType(Kind::kType);

if (is_functional_literal) {
let_type = GetType(let->var);
type_map_[let->var].checked_type = let_type;
}
auto pre_visit = [this](const LetNode* op) {
// if the definition is a function literal, permit recursion
bool is_functional_literal = op->value.as<FunctionNode>() != nullptr;
Type let_type = IncompleteType(Kind::kType);

if (is_functional_literal) {
let_type = GetType(op->var);
type_map_[op->var].checked_type = let_type;
}

if (let->var->type_annotation.defined()) {
let_type = Unify(let_type, let->var->type_annotation, let->span);
}
if (op->var->type_annotation.defined()) {
let_type = Unify(let_type, op->var->type_annotation, op->span);
}

Type vtype = GetType(let->value);
let_type = Unify(let_type, vtype, let->span);
Type vtype = GetType(op->value);
let_type = Unify(let_type, vtype, op->span);

ICHECK(is_functional_literal || !type_map_.count(let->var));
// NOTE: no scoping is necessary because var are unique in program
type_map_[let->var].checked_type = let_type;
return GetType(let->body);
ICHECK(is_functional_literal || !type_map_.count(op->var));
// NOTE: no scoping is necessary because var are unique in program
type_map_[op->var].checked_type = let_type;
};
auto post_visit = [this](const LetNode* op) {
Expr expr = GetRef<Expr>(op);
memo_[expr] = GetType(op->body);
type_map_[expr].checked_type = memo_[expr];
};
ExpandANormalForm(let, pre_visit, post_visit);
return memo_[GetRef<Expr>(let)];
}

Type VisitExpr_(const IfNode* ite) final {
Expand Down Expand Up @@ -603,7 +611,21 @@ class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator {

Expr Rewrite_(const CallNode* op, const Expr& post) final { return AttachCheckedType(op, post); }

Expr VisitExpr_(const LetNode* op) final { return AttachCheckedType(op); }
Expr VisitExpr_(const LetNode* op) final {
auto pre_visit = [this](const LetNode* op) {
this->VisitExpr(op->var);
this->VisitExpr(op->value);
};
auto post_visit = [this](const LetNode* op) {
Expr expr = GetRef<Expr>(op);
Var var = Downcast<Var>(VisitExpr(op->var));
Expr value = VisitExpr(op->value);
Expr body = VisitExpr(op->body);
memo_[expr] = AttachCheckedType(op, Let(var, value, body));
};
ExpandANormalForm(op, pre_visit, post_visit);
return memo_[GetRef<Expr>(op)];
}

Expr VisitExpr_(const IfNode* op) final { return AttachCheckedType(op); }

Expand Down Expand Up @@ -738,6 +760,7 @@ Expr TypeInferencer::Infer(GlobalVar var, Function function) {
}

struct AllCheckTypePopulated : MixedModeVisitor {
using MixedModeVisitor::VisitExpr_;
void DispatchExprVisit(const Expr& e) {
if (e.as<OpNode>()) {
return;
Expand All @@ -751,6 +774,17 @@ struct AllCheckTypePopulated : MixedModeVisitor {
ICHECK(e->checked_type_.defined()) << "Expression: " << e;
return ExprVisitor::VisitExpr(e);
}
void VisitExpr_(const LetNode* op) final {
auto pre_visit = [this](const LetNode* op) {
this->VisitExpr(op->var);
this->VisitExpr(op->value);
};
auto post_visit = [this](const LetNode* op) {
this->VisitExpr(op->body);
this->visit_counter_[op] += 1;
};
ExpandANormalForm(op, pre_visit, post_visit);
}
};

void EnsureCheckedType(const Expr& e) { AllCheckTypePopulated().VisitExpr(e); }
Expand Down