Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Change some passes to mix mode #6695

Merged
merged 1 commit into from
Oct 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/relay/analysis/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class TypeVarTVisitor : public TypeVisitor {
InsertionSet<TypeVar>* bound_type_vars_;
};

class TypeVarEVisitor : private ExprVisitor {
class TypeVarEVisitor : private MixedModeVisitor {
public:
explicit TypeVarEVisitor(const IRModule& mod) : mod_(mod) {}

Expand Down Expand Up @@ -131,6 +131,8 @@ class TypeVarEVisitor : private ExprVisitor {
return CollectAll();
}

using MixedModeVisitor::VisitExpr_;

void VisitExpr_(const FunctionNode* f) final {
for (const auto& tp : f->type_params) {
type_vars_.Insert(tp);
Expand Down Expand Up @@ -159,7 +161,7 @@ class TypeVarEVisitor : private ExprVisitor {
const IRModule& mod_;
};

class VarVisitor : protected ExprVisitor, protected PatternVisitor {
class VarVisitor : protected MixedModeVisitor, protected PatternVisitor {
public:
Array<Var> Free(const Expr& expr) {
this->VisitExpr(expr);
Expand Down Expand Up @@ -204,6 +206,8 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor {
vars_.Insert(v);
}

using MixedModeVisitor::VisitExpr_;

void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef<Var>(var)); }

void VisitExpr_(const FunctionNode* op) final {
Expand Down
16 changes: 7 additions & 9 deletions src/relay/analysis/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace tvm {
namespace relay {

//! brief make sure each Var is bound at most once in a scope.
class WellFormedChecker : private ExprVisitor, PatternVisitor {
class WellFormedChecker : private MixedModeVisitor, PatternVisitor {
public:
Optional<DiagnosticContext> diag_ctx;
Span occurs_in;
Expand Down Expand Up @@ -79,6 +79,8 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor {
total_bound.insert(v);
}

using MixedModeVisitor::VisitExpr_;

void VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
if (current_bound.count(v) == 0) {
Expand Down Expand Up @@ -126,7 +128,7 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor {

// CHECK(call->attrs.defined());
CHECK(call->type_args.defined());
ExprVisitor::VisitExpr_(call);
MixedModeVisitor::VisitExpr_(call);
}

void VisitClause(const Clause& c) final {
Expand All @@ -139,18 +141,14 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor {

void VisitVar(const Var& v) final { Bound(v); }

void VisitExpr(const Expr& e) final {
public:
bool CheckWellFormed(const Expr& e) {
if (auto v = e.as<VarNode>()) {
VisitExpr_(v);
} else {
// this->occurs_in = e->span;
ExprVisitor::VisitExpr(e);
VisitExpr(e);
}
}

public:
bool CheckWellFormed(const Expr& e) {
this->VisitExpr(e);
return well_formed;
}
};
Expand Down
4 changes: 3 additions & 1 deletion src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -517,10 +517,12 @@ TVM_REGISTER_GLOBAL("relay.analysis.post_order_visit").set_body_typed([](Expr ex
});

// Implement bind.
class ExprBinder : public ExprMutator, PatternMutator {
class ExprBinder : public MixedModeMutator, PatternMutator {
public:
explicit ExprBinder(const tvm::Map<Var, Expr>& args_map) : args_map_(args_map) {}

using MixedModeMutator::VisitExpr_;

Expr VisitExpr_(const LetNode* op) final {
CHECK(!args_map_.count(op->var)) << "Cannot bind an internel variable in let";
return ExprMutator::VisitExpr_(op);
Expand Down
6 changes: 4 additions & 2 deletions src/relay/transforms/de_duplicate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace tvm {
namespace relay {

Expr DeDup(const Expr& e) {
class DeDupMutator : public TypeMutator, public ExprMutator, public PatternMutator {
class DeDupMutator : public TypeMutator, public MixedModeMutator, public PatternMutator {
public:
TypeVar Fresh(const TypeVar& tv) {
TypeVar ret = TypeVar(tv->name_hint, tv->kind);
Expand All @@ -47,12 +47,14 @@ Expr DeDup(const Expr& e) {
return ret;
}

Expr VisitExpr(const Expr& e) final {
Expr DispatchVisitExpr(const Expr& e) final {
auto ret = ExprMutator::VisitExpr(e);
ret->checked_type_ = e->checked_type_;
return ret;
}

using MixedModeMutator::VisitExpr_;

Expr VisitExpr_(const VarNode* op) final {
Var v = GetRef<Var>(op);
return rename_.count(v) != 0 ? rename_.at(v) : v;
Expand Down
32 changes: 16 additions & 16 deletions src/relay/transforms/fold_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ TVM_REGISTER_GLOBAL("relay.analysis.check_constant").set_body_typed(ConstantChec

// TODO(tvm-team) consider combine dead-code with constant folder.
// or make a more powerful partial evaluator.
class ConstantFolder : public ExprMutator {
class ConstantFolder : public MixedModeMutator {
public:
explicit ConstantFolder(IRModule module)
: module_(module),
Expand All @@ -89,6 +89,8 @@ class ConstantFolder : public ExprMutator {
cast_op_(Op::Get("cast")),
ndarray_size_op_(Op::Get("ndarray_size")) {}

using MixedModeMutator::VisitExpr_;

Expr VisitExpr_(const LetNode* op) final {
Expr value = this->Mutate(op->value);
if (value.as<ConstantNode>()) {
Expand Down Expand Up @@ -118,7 +120,7 @@ class ConstantFolder : public ExprMutator {
}
}

Expr VisitExpr_(const CallNode* call) final {
Expr Rewrite_(const CallNode* call, const Expr& post) final {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nitpick) You could locally rename this argument from post to res, and then you wouldn't need most of the other changes in the function?

if (inside_primitive) {
return GetRef<Expr>(call);
}
Expand All @@ -127,26 +129,25 @@ class ConstantFolder : public ExprMutator {
std::unordered_set<std::string> skip_list{"zeros_like", "ones_like", "full_like", "full"};

auto origin_args = call->args;
Expr res = ExprMutator::VisitExpr_(call);
call = res.as<CallNode>();
call = post.as<CallNode>();
// We don't constant fold function with zero arguments.
// This is a heuristic that is useful.
// For example it is harmful to fold ones(shape=(4, 5)).
if (call->args.size() == 0) return res;
if (call->args.size() == 0) return post;
const OpNode* op = call->op.as<OpNode>();
if (op == nullptr) return res;
if (op == nullptr) return post;
if (skip_list.count(op->name)) {
return res;
return post;
}
// skip stateful ops.
if (op_stateful.get(GetRef<Op>(op), false)) return res;
if (op_stateful.get(GetRef<Op>(op), false)) return post;
// Try to evaluate shape_of op
if (call->op == shape_of_op_ || call->op == vm_shape_of_op_) {
return EvaluateShapeOf(res, origin_args, call->attrs);
return EvaluateShapeOf(post, origin_args, call->attrs);
}

if (call->op == ndarray_size_op_) {
return EvaluateNdarraySize(res, origin_args, call->attrs);
return EvaluateNdarraySize(post, origin_args, call->attrs);
}

// We should think about potentially constant evaluation over these ops too.
Expand All @@ -162,19 +163,18 @@ class ConstantFolder : public ExprMutator {
}
}
if (all_const_args) {
return ConstEvaluate(res);
return ConstEvaluate(post);
} else {
return res;
return post;
}
}

Expr VisitExpr_(const TupleGetItemNode* op) final {
Expr res = ExprMutator::VisitExpr_(op);
op = res.as<TupleGetItemNode>();
Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final {
op = post.as<TupleGetItemNode>();
if (const auto* tuple = op->tuple.as<TupleNode>()) {
return tuple->fields[op->index];
} else {
return res;
return post;
}
}

Expand Down