From 5c07413cd46692b2b23188feea690fd6b631e713 Mon Sep 17 00:00:00 2001 From: Ziheng Jiang Date: Sat, 11 Feb 2017 21:55:57 -0800 Subject: [PATCH] [PASS] Change IRVisitor interfaces to function override (#42) * [PASS] Change IRVisitor interfaces to function override * [PASS] Change IRMutator interfaces to overloadable function --- include/tvm/ir_mutator.h | 84 ++++++- include/tvm/ir_visitor.h | 35 ++- src/pass/ir_mutator.cc | 488 ++++++++++++++++++++++++--------------- src/pass/ir_visitor.cc | 232 +++++++++++-------- 4 files changed, 549 insertions(+), 290 deletions(-) diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h index eea6a3343f37..c428232698e8 100644 --- a/include/tvm/ir_mutator.h +++ b/include/tvm/ir_mutator.h @@ -16,7 +16,8 @@ namespace ir { /*! * \brief a base class for mutator to iterative mutate the IR * - * This IRMutator is implemented via IRFunctor instead of Visitor Pattern. + * This IRMutator is implemented via Visitor Pattern. + * Also you can implement via IRFunctor. * This enables easy extensions of possible new Node. * It also makes changing return types easier. * @@ -54,20 +55,91 @@ class IRMutator { static FMutateStmt& vtable_stmt(); // NOLINT(*) // Set of overloadable functions // The underscore allows Mutate not to be shadowed by inheritance + virtual Stmt Mutate_(const Variable* op, const Stmt& s); virtual Stmt Mutate_(const LetStmt* op, const Stmt& s); virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s); + virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s); virtual Stmt Mutate_(const For* op, const Stmt& s); - virtual Stmt Mutate_(const Provide* op, const Stmt& s); virtual Stmt Mutate_(const Allocate* op, const Stmt& s); - virtual Stmt Mutate_(const Realize* op, const Stmt& s); + virtual Stmt Mutate_(const Load* op, const Stmt& s); virtual Stmt Mutate_(const Store* op, const Stmt& s); + virtual Stmt Mutate_(const Let* op, const Stmt& s); virtual Stmt Mutate_(const Free* op, const Stmt& s); - virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s); + virtual Stmt Mutate_(const Call* op, const Stmt& s); + virtual Stmt Mutate_(const Add* op, const Stmt& e); + virtual Stmt Mutate_(const Sub* op, const Stmt& e); + virtual Stmt Mutate_(const Mul* op, const Stmt& e); + virtual Stmt Mutate_(const Div* op, const Stmt& e); + virtual Stmt Mutate_(const Mod* op, const Stmt& e); + virtual Stmt Mutate_(const Min* op, const Stmt& e); + virtual Stmt Mutate_(const Max* op, const Stmt& e); + virtual Stmt Mutate_(const EQ* op, const Stmt& e); + virtual Stmt Mutate_(const NE* op, const Stmt& e); + virtual Stmt Mutate_(const LT* op, const Stmt& e); + virtual Stmt Mutate_(const LE* op, const Stmt& e); + virtual Stmt Mutate_(const GT* op, const Stmt& e); + virtual Stmt Mutate_(const GE* op, const Stmt& e); + virtual Stmt Mutate_(const And* op, const Stmt& e); + virtual Stmt Mutate_(const Or* op, const Stmt& e); + virtual Stmt Mutate_(const Reduce* op, const Stmt& s); + virtual Stmt Mutate_(const Cast* op, const Stmt& s); + virtual Stmt Mutate_(const Not* op, const Stmt& s); + virtual Stmt Mutate_(const Select* op, const Stmt& s); + virtual Stmt Mutate_(const Ramp* op, const Stmt& s); + virtual Stmt Mutate_(const Broadcast* op, const Stmt& e); + virtual Stmt Mutate_(const AssertStmt* op, const Stmt& e); + virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& e); + virtual Stmt Mutate_(const Provide* op, const Stmt& e); + virtual Stmt Mutate_(const Realize* op, const Stmt& s); virtual Stmt Mutate_(const Block* op, const Stmt& s); - virtual Expr Mutate_(const Call* op, const Expr& e); - virtual Expr Mutate_(const Load* op, const Expr& s); + virtual Stmt Mutate_(const Evaluate* op, const Stmt& e); + virtual Stmt Mutate_(const IntImm* op, const Stmt& e); + virtual Stmt Mutate_(const UIntImm* op, const Stmt& e); + virtual Stmt Mutate_(const FloatImm* op, const Stmt& e); + virtual Stmt Mutate_(const StringImm* op, const Stmt& e); + virtual Expr Mutate_(const Variable* op, const Expr& e); + virtual Expr Mutate_(const LetStmt* op, const Expr& e); + virtual Expr Mutate_(const AttrStmt* op, const Expr& e); + virtual Expr Mutate_(const IfThenElse* op, const Expr& e); + virtual Expr Mutate_(const For* op, const Expr& e); + virtual Expr Mutate_(const Allocate* op, const Expr& e); + virtual Expr Mutate_(const Load* op, const Expr& e); + virtual Expr Mutate_(const Store* op, const Expr& e); virtual Expr Mutate_(const Let* op, const Expr& e); + virtual Expr Mutate_(const Free* op, const Expr& e); + virtual Expr Mutate_(const Call* op, const Expr& e); + virtual Expr Mutate_(const Add* op, const Expr& e); + virtual Expr Mutate_(const Sub* op, const Expr& e); + virtual Expr Mutate_(const Mul* op, const Expr& e); + virtual Expr Mutate_(const Div* op, const Expr& e); + virtual Expr Mutate_(const Mod* op, const Expr& e); + virtual Expr Mutate_(const Min* op, const Expr& e); + virtual Expr Mutate_(const Max* op, const Expr& e); + virtual Expr Mutate_(const EQ* op, const Expr& e); + virtual Expr Mutate_(const NE* op, const Expr& e); + virtual Expr Mutate_(const LT* op, const Expr& e); + virtual Expr Mutate_(const LE* op, const Expr& e); + virtual Expr Mutate_(const GT* op, const Expr& e); + virtual Expr Mutate_(const GE* op, const Expr& e); + virtual Expr Mutate_(const And* op, const Expr& e); + virtual Expr Mutate_(const Or* op, const Expr& e); + virtual Expr Mutate_(const Reduce* op, const Expr& e); + virtual Expr Mutate_(const Cast* op, const Expr& e); + virtual Expr Mutate_(const Not* op, const Expr& e); + virtual Expr Mutate_(const Select* op, const Expr& e); + virtual Expr Mutate_(const Ramp* op, const Expr& e); + virtual Expr Mutate_(const Broadcast* op, const Expr& e); + virtual Expr Mutate_(const AssertStmt* op, const Expr& e); + virtual Expr Mutate_(const ProducerConsumer* op, const Expr& e); + virtual Expr Mutate_(const Provide* op, const Expr& e); + virtual Expr Mutate_(const Realize* op, const Expr& e); + virtual Expr Mutate_(const Block* op, const Expr& e); + virtual Expr Mutate_(const Evaluate* op, const Expr& e); + virtual Expr Mutate_(const IntImm* op, const Expr& e); + virtual Expr Mutate_(const UIntImm* op, const Expr& e); + virtual Expr Mutate_(const FloatImm* op, const Expr& e); + virtual Expr Mutate_(const StringImm* op, const Expr& e); }; /*! diff --git a/include/tvm/ir_visitor.h b/include/tvm/ir_visitor.h index e5711f65ff86..6bfbce25a0df 100644 --- a/include/tvm/ir_visitor.h +++ b/include/tvm/ir_visitor.h @@ -36,16 +36,47 @@ class IRVisitor { static FVisit& vtable(); // overloadable visit function. virtual void Visit_(const Variable* op); - virtual void Visit_(const AttrStmt* op); virtual void Visit_(const LetStmt* op); + virtual void Visit_(const AttrStmt* op); + virtual void Visit_(const IfThenElse* op); virtual void Visit_(const For* op); virtual void Visit_(const Allocate* op); - virtual void Visit_(const IfThenElse* op); virtual void Visit_(const Load* op); virtual void Visit_(const Store* op); virtual void Visit_(const Let* op); virtual void Visit_(const Free* op); virtual void Visit_(const Call* op); + virtual void Visit_(const Add* op); + virtual void Visit_(const Sub* op); + virtual void Visit_(const Mul* op); + virtual void Visit_(const Div* op); + virtual void Visit_(const Mod* op); + virtual void Visit_(const Min* op); + virtual void Visit_(const Max* op); + virtual void Visit_(const EQ* op); + virtual void Visit_(const NE* op); + virtual void Visit_(const LT* op); + virtual void Visit_(const LE* op); + virtual void Visit_(const GT* op); + virtual void Visit_(const GE* op); + virtual void Visit_(const And* op); + virtual void Visit_(const Or* op); + virtual void Visit_(const Reduce* op); + virtual void Visit_(const Cast* op); + virtual void Visit_(const Not* op); + virtual void Visit_(const Select* op); + virtual void Visit_(const Ramp* op); + virtual void Visit_(const Broadcast* op); + virtual void Visit_(const AssertStmt* op); + virtual void Visit_(const ProducerConsumer* op); + virtual void Visit_(const Provide* op); + virtual void Visit_(const Realize* op); + virtual void Visit_(const Block* op); + virtual void Visit_(const Evaluate* op); + virtual void Visit_(const IntImm* op); + virtual void Visit_(const UIntImm* op); + virtual void Visit_(const FloatImm* op); + virtual void Visit_(const StringImm* op); }; /*! diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index f10c9c089f1d..07f2b6d21b28 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -16,11 +16,6 @@ IRMutator::FMutateStmt& IRMutator::vtable_stmt() { // NOLINT(*) static FMutateStmt inst; return inst; } -// const expr -inline Expr ReturnSelfExpr(const NodeRef&, const Expr& e, IRMutator*) { - return e; -} - inline Array MutateArray(Array arr, IRMutator *m) { std::vector new_arr(arr.size()); bool changed = false; @@ -58,47 +53,33 @@ inline Array MutateIterVarArr(Array rdom, IRMutator *m) { } } + +// Mutate Stmt + #define DISPATCH_TO_MUTATE_STMT(OP) \ set_dispatch([](const OP* op, const Stmt& s, IRMutator* m) { \ return m->Mutate_(op, s); \ }) -#define DISPATCH_TO_MUTATE_EXPR(OP) \ - set_dispatch([](const OP* op, const Expr& e, IRMutator* m) { \ - return m->Mutate_(op, e); \ - }) - -TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) -.DISPATCH_TO_MUTATE_STMT(LetStmt) -.DISPATCH_TO_MUTATE_STMT(AttrStmt) -.DISPATCH_TO_MUTATE_STMT(Provide) -.DISPATCH_TO_MUTATE_STMT(Realize) -.DISPATCH_TO_MUTATE_STMT(Store) -.DISPATCH_TO_MUTATE_STMT(IfThenElse) -.DISPATCH_TO_MUTATE_STMT(For) -.DISPATCH_TO_MUTATE_STMT(Allocate) -.DISPATCH_TO_MUTATE_STMT(Block) -.DISPATCH_TO_MUTATE_STMT(Free); - -Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) { +Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) { Expr value = this->Mutate(op->value); Stmt body = this->Mutate(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { return s; } else { - return LetStmt::make(op->var, value, body); + return AttrStmt::make(op->node, op->type_key, value, body); } } -Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) { +Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) { Expr value = this->Mutate(op->value); Stmt body = this->Mutate(op->body); if (value.same_as(op->value) && body.same_as(op->body)) { return s; } else { - return AttrStmt::make(op->node, op->type_key, value, body); + return LetStmt::make(op->var, value, body); } } @@ -143,6 +124,36 @@ Stmt IRMutator::Mutate_(const Allocate* op, const Stmt& s) { } } +Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) { + Expr condition = this->Mutate(op->condition); + Stmt then_case = this->Mutate(op->then_case); + Stmt else_case; + if (else_case.defined()) { + else_case = this->Mutate(op->else_case); + } + if (condition.same_as(op->condition) && + then_case.same_as(op->then_case) && + else_case.same_as(op->else_case)) { + return s; + } else { + return IfThenElse::make(condition, then_case, else_case); + } +} + +Stmt IRMutator::Mutate_(const Load *op, const Stmt& s) { + return s; +} + +Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) { + Expr value = this->Mutate(op->value); + Expr index = this->Mutate(op->index); + if (value.same_as(op->value) && index.same_as(op->index)) { + return s; + } else { + return Store::make(op->buffer_var, value, index); + } +} + Stmt IRMutator::Mutate_(const Provide* op, const Stmt& s) { auto new_args = MutateArray(op->args, this); auto new_value = this->Mutate(op->value); @@ -183,63 +194,137 @@ Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) { } } -Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) { - Expr value = this->Mutate(op->value); - Expr index = this->Mutate(op->index); - if (value.same_as(op->value) && index.same_as(op->index)) { +Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) { + Stmt first = this->Mutate(op->first); + Stmt rest = this->Mutate(op->rest); + if (first.same_as(op->first) && + rest.same_as(op->rest)) { return s; } else { - return Store::make(op->buffer_var, value, index); + return Block::make(first, rest); } } -Stmt IRMutator::Mutate_(const Free *op, const Stmt& s) { - return s; -} - -Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) { +Stmt IRMutator::Mutate_(const AssertStmt *op, const Stmt& s) { Expr condition = this->Mutate(op->condition); - Stmt then_case = this->Mutate(op->then_case); - Stmt else_case; - if (else_case.defined()) { - else_case = this->Mutate(op->else_case); - } - if (condition.same_as(op->condition) && - then_case.same_as(op->then_case) && - else_case.same_as(op->else_case)) { + Expr message = this->Mutate(op->message); + + if (condition.same_as(op->condition) && message.same_as(op->message)) { return s; } else { - return IfThenElse::make(condition, then_case, else_case); + return AssertStmt::make(condition, message); } } -Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) { - Stmt first = this->Mutate(op->first); - Stmt rest = this->Mutate(op->rest); - if (first.same_as(op->first) && - rest.same_as(op->rest)) { +Stmt IRMutator::Mutate_(const ProducerConsumer *op, const Stmt& s) { + Stmt body = this->Mutate(op->body); + if (body.same_as(op->body)) { return s; } else { - return Block::make(first, rest); + return ProducerConsumer::make(op->func, op->is_producer, body); } } -TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) -.DISPATCH_TO_MUTATE_EXPR(Call) -.DISPATCH_TO_MUTATE_EXPR(Let) -.DISPATCH_TO_MUTATE_EXPR(Load) -.DISPATCH_TO_MUTATE_EXPR(Variable); - -Expr IRMutator::Mutate_(const Call* op, const Expr& e) { - auto new_args = MutateArray(op->args, this); - if (op->args.same_as(new_args)) { - return e; +Stmt IRMutator::Mutate_(const Evaluate *op, const Stmt& s) { + Expr v = this->Mutate(op->value); + if (v.same_as(op->value)) { + return s; } else { - return Call::make(op->type, op->name, new_args, op->call_type, - op->func, op->value_index); + return Evaluate::make(v); } } +#define DEFINE_OP_RETURN_SELF_STMT_MUTATE_(OP) \ + Stmt IRMutator::Mutate_(const OP *op, const Stmt& s) { \ + return s; \ + } + +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Variable) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Let) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Free) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Call) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Add) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Sub) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Mul) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Div) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Mod) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Min) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Max) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(EQ) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(NE) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(LT) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(LE) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(GT) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(GE) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(And) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Or) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Reduce) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Cast) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Not) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Select) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Ramp) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Broadcast) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(IntImm) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(UIntImm) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(FloatImm) +DEFINE_OP_RETURN_SELF_STMT_MUTATE_(StringImm) + +TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) +.DISPATCH_TO_MUTATE_STMT(Variable) +.DISPATCH_TO_MUTATE_STMT(LetStmt) +.DISPATCH_TO_MUTATE_STMT(AttrStmt) +.DISPATCH_TO_MUTATE_STMT(IfThenElse) +.DISPATCH_TO_MUTATE_STMT(For) +.DISPATCH_TO_MUTATE_STMT(Allocate) +.DISPATCH_TO_MUTATE_STMT(Load) +.DISPATCH_TO_MUTATE_STMT(Store) +.DISPATCH_TO_MUTATE_STMT(Let) +.DISPATCH_TO_MUTATE_STMT(Free) +.DISPATCH_TO_MUTATE_STMT(Call) +.DISPATCH_TO_MUTATE_STMT(Add) +.DISPATCH_TO_MUTATE_STMT(Sub) +.DISPATCH_TO_MUTATE_STMT(Mul) +.DISPATCH_TO_MUTATE_STMT(Div) +.DISPATCH_TO_MUTATE_STMT(Mod) +.DISPATCH_TO_MUTATE_STMT(Min) +.DISPATCH_TO_MUTATE_STMT(Max) +.DISPATCH_TO_MUTATE_STMT(EQ) +.DISPATCH_TO_MUTATE_STMT(NE) +.DISPATCH_TO_MUTATE_STMT(LT) +.DISPATCH_TO_MUTATE_STMT(LE) +.DISPATCH_TO_MUTATE_STMT(GT) +.DISPATCH_TO_MUTATE_STMT(GE) +.DISPATCH_TO_MUTATE_STMT(And) +.DISPATCH_TO_MUTATE_STMT(Or) +.DISPATCH_TO_MUTATE_STMT(Reduce) +.DISPATCH_TO_MUTATE_STMT(Cast) +.DISPATCH_TO_MUTATE_STMT(Not) +.DISPATCH_TO_MUTATE_STMT(Select) +.DISPATCH_TO_MUTATE_STMT(Ramp) +.DISPATCH_TO_MUTATE_STMT(Broadcast) +.DISPATCH_TO_MUTATE_STMT(AssertStmt) +.DISPATCH_TO_MUTATE_STMT(ProducerConsumer) +.DISPATCH_TO_MUTATE_STMT(Provide) +.DISPATCH_TO_MUTATE_STMT(Realize) +.DISPATCH_TO_MUTATE_STMT(Block) +.DISPATCH_TO_MUTATE_STMT(Evaluate) +.DISPATCH_TO_MUTATE_STMT(IntImm) +.DISPATCH_TO_MUTATE_STMT(UIntImm) +.DISPATCH_TO_MUTATE_STMT(FloatImm) +.DISPATCH_TO_MUTATE_STMT(StringImm); + + +// Mutate Expr + +#define DISPATCH_TO_MUTATE_EXPR(OP) \ + set_dispatch([](const OP* op, const Expr& e, IRMutator* m) { \ + return m->Mutate_(op, e); \ + }) + +Expr IRMutator::Mutate_(const Variable *op, const Expr& e) { + return e; +} + Expr IRMutator::Mutate_(const Load *op, const Expr& e) { Expr index = this->Mutate(op->index); if (index.same_as(op->index)) { @@ -249,11 +334,6 @@ Expr IRMutator::Mutate_(const Load *op, const Expr& e) { } } - -Expr IRMutator::Mutate_(const Variable *op, const Expr& e) { - return e; -} - Expr IRMutator::Mutate_(const Let *op, const Expr& e) { Expr value = this->Mutate(op->value); Expr body = this->Mutate(op->body); @@ -265,130 +345,172 @@ Expr IRMutator::Mutate_(const Let *op, const Expr& e) { } } -TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) -.set_dispatch([](const Reduce* op, const Expr& e, IRMutator* m) { - Array new_axis = MutateIterVarArr(op->axis, m); - Expr new_source = m->Mutate(op->source); - if (op->axis.same_as(new_axis) && - op->source.same_as(new_source)) { - return e; - } else { - return Reduce::make(op->op, new_source, new_axis); - } - }); +Expr IRMutator::Mutate_(const Call* op, const Expr& e) { + auto new_args = MutateArray(op->args, this); + if (op->args.same_as(new_args)) { + return e; + } else { + return Call::make(op->type, op->name, new_args, op->call_type, + op->func, op->value_index); + } +} -TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) -.set_dispatch(ReturnSelfExpr) -.set_dispatch(ReturnSelfExpr) -.set_dispatch(ReturnSelfExpr) -.set_dispatch(ReturnSelfExpr); +#define DEFINE_BIOP_EXPR_MUTATE_(OP) \ + Expr IRMutator::Mutate_(const OP* op, const Expr& e) { \ + Expr a = this->Mutate(op->a); \ + Expr b = this->Mutate(op->b); \ + if (a.same_as(op->a) && \ + b.same_as(op->b)) { \ + return e; \ + } else { \ + return OP::make(a, b); \ + } \ + } -TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) -.set_dispatch([](const Cast* op, const Expr& e, IRMutator* m) { - Expr value = m->Mutate(op->value); - if (value.same_as(op->value)) { - return e; - } else { - return Cast::make(op->type, value); - } - }); - -// binary operator -template -inline Expr Binary(const T* op, const Expr& e, IRMutator* m) { - Expr a = m->Mutate(op->a); - Expr b = m->Mutate(op->b); - if (a.same_as(op->a) && - b.same_as(op->b)) { +DEFINE_BIOP_EXPR_MUTATE_(Add) +DEFINE_BIOP_EXPR_MUTATE_(Sub) +DEFINE_BIOP_EXPR_MUTATE_(Mul) +DEFINE_BIOP_EXPR_MUTATE_(Div) +DEFINE_BIOP_EXPR_MUTATE_(Mod) +DEFINE_BIOP_EXPR_MUTATE_(Min) +DEFINE_BIOP_EXPR_MUTATE_(Max) +DEFINE_BIOP_EXPR_MUTATE_(EQ) +DEFINE_BIOP_EXPR_MUTATE_(NE) +DEFINE_BIOP_EXPR_MUTATE_(LT) +DEFINE_BIOP_EXPR_MUTATE_(LE) +DEFINE_BIOP_EXPR_MUTATE_(GT) +DEFINE_BIOP_EXPR_MUTATE_(GE) +DEFINE_BIOP_EXPR_MUTATE_(And) +DEFINE_BIOP_EXPR_MUTATE_(Or) + +Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) { + Array new_axis = MutateIterVarArr(op->axis, this); + Expr new_source = this->Mutate(op->source); + if (op->axis.same_as(new_axis) && + op->source.same_as(new_source)) { return e; } else { - return T::make(a, b); + return Reduce::make(op->op, new_source, new_axis); } } -TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) -.set_dispatch(Binary) -.set_dispatch(Binary) -.set_dispatch(Binary) -.set_dispatch
(Binary
) -.set_dispatch(Binary) -.set_dispatch(Binary) -.set_dispatch(Binary) -.set_dispatch(Binary) -.set_dispatch(Binary) -.set_dispatch(Binary) -.set_dispatch(Binary) -.set_dispatch(Binary) -.set_dispatch(Binary) -.set_dispatch(Binary) -.set_dispatch(Binary); +Expr IRMutator::Mutate_(const Cast *op, const Expr& e) { + Expr value = this->Mutate(op->value); + if (value.same_as(op->value)) { + return e; + } else { + return Cast::make(op->type, value); + } +} + +Expr IRMutator::Mutate_(const Not *op, const Expr& e) { + Expr a = this->Mutate(op->a); + if (a.same_as(op->a)) { + return e; + } else { + return Not::make(a); + } +} + +Expr IRMutator::Mutate_(const Select *op, const Expr& e) { + Expr cond = this->Mutate(op->condition); + Expr t = this->Mutate(op->true_value); + Expr f = this->Mutate(op->false_value); + if (cond.same_as(op->condition) && + t.same_as(op->true_value) && + f.same_as(op->false_value)) { + return e; + } else { + return Select::make(cond, t, f); + } +} + +Expr IRMutator::Mutate_(const Ramp *op, const Expr& e) { + Expr base = this->Mutate(op->base); + Expr stride = this->Mutate(op->stride); + if (base.same_as(op->base) && + stride.same_as(op->stride)) { + return e; + } else { + return Ramp::make(base, stride, op->lanes); + } +} + +Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) { + Expr value = this->Mutate(op->value); + if (value.same_as(op->value)) { + return e; + } else { + return Broadcast::make(value, op->lanes); + } +} + +#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \ + Expr IRMutator::Mutate_(const OP *op, const Expr& e) { \ + return e; \ + } + +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(LetStmt) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(AttrStmt) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(For) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IfThenElse) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Allocate) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Store) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Free) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(AssertStmt) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(ProducerConsumer) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Provide) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Realize) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Block) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Evaluate) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImm) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImm) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImm) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImm) TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) -.set_dispatch([](const Not* op, const Expr& e, IRMutator* m) { - Expr a = m->Mutate(op->a); - if (a.same_as(op->a)) { - return e; - } else { - return Not::make(a); - } - }) -.set_dispatch([](const Select *op, IRVisitor* v) { - v->Visit(op->condition); - v->Visit(op->true_value); - v->Visit(op->false_value); - }) -.set_dispatch([](const Ramp *op, IRVisitor* v) { - v->Visit(op->base); - v->Visit(op->stride); - }) -.set_dispatch([](const Broadcast *op, IRVisitor* v) { - v->Visit(op->value); - }); +void IRVisitor::Visit_(const Select* op) { + this->Visit(op->condition); + this->Visit(op->true_value); + this->Visit(op->false_value); +} -TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) -.set_dispatch([](const AssertStmt *op, IRVisitor* v) { - v->Visit(op->condition); - v->Visit(op->message); - }) -.set_dispatch([](const ProducerConsumer *op, IRVisitor* v) { - v->Visit(op->body); - }) -.set_dispatch([](const Provide *op, IRVisitor* v) { - VisitArray(op->args, v); - v->Visit(op->value); - }) -.set_dispatch([](const Realize *op, IRVisitor* v) { +void IRVisitor::Visit_(const Ramp *op) { + this->Visit(op->base); + this->Visit(op->stride); +} + +void IRVisitor::Visit_(const Broadcast *op) { + this->Visit(op->value); +} + +void IRVisitor::Visit_(const AssertStmt *op) { + this->Visit(op->condition); + this->Visit(op->message); +} + +void IRVisitor::Visit_(const ProducerConsumer *op) { + this->Visit(op->body); +} + +void IRVisitor::Visit_(const Provide *op) { + VisitArray(op->args, this); + this->Visit(op->value); +} + +void IRVisitor::Visit_(const Realize *op) { // Mutate the bounds - for (size_t i = 0; i < op->bounds.size(); i++) { - v->Visit(op->bounds[i]->min); - v->Visit(op->bounds[i]->extent); - } - - v->Visit(op->body); - v->Visit(op->condition); - }) -.set_dispatch([](const Block *op, IRVisitor* v) { - v->Visit(op->first); - v->Visit(op->rest); - }) -.set_dispatch([](const Evaluate *op, IRVisitor* v) { - v->Visit(op->value); - }); + for (size_t i = 0; i < op->bounds.size(); i++) { + this->Visit(op->bounds[i]->min); + this->Visit(op->bounds[i]->extent); + } + + this->Visit(op->body); + this->Visit(op->condition); +} + +void IRVisitor::Visit_(const Block *op) { + this->Visit(op->first); + this->Visit(op->rest); +} + +void IRVisitor::Visit_(const Evaluate *op) { + this->Visit(op->value); +} + +#define DEFINE_OP_NO_VISIT_(OP) \ + void IRVisitor::Visit_(const OP* op) {} + +DEFINE_OP_NO_VISIT_(IntImm) +DEFINE_OP_NO_VISIT_(UIntImm) +DEFINE_OP_NO_VISIT_(FloatImm) +DEFINE_OP_NO_VISIT_(StringImm) + +#define DISPATCH_TO_VISIT(OP) \ + set_dispatch([](const OP* op, IRVisitor* v) { \ + v->Visit_(op); \ + }) + +TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) +.DISPATCH_TO_VISIT(Variable) +.DISPATCH_TO_VISIT(LetStmt) +.DISPATCH_TO_VISIT(AttrStmt) +.DISPATCH_TO_VISIT(IfThenElse) +.DISPATCH_TO_VISIT(For) +.DISPATCH_TO_VISIT(Allocate) +.DISPATCH_TO_VISIT(Load) +.DISPATCH_TO_VISIT(Store) +.DISPATCH_TO_VISIT(Let) +.DISPATCH_TO_VISIT(Free) +.DISPATCH_TO_VISIT(Call) +.DISPATCH_TO_VISIT(Add) +.DISPATCH_TO_VISIT(Sub) +.DISPATCH_TO_VISIT(Mul) +.DISPATCH_TO_VISIT(Div) +.DISPATCH_TO_VISIT(Mod) +.DISPATCH_TO_VISIT(Min) +.DISPATCH_TO_VISIT(Max) +.DISPATCH_TO_VISIT(EQ) +.DISPATCH_TO_VISIT(NE) +.DISPATCH_TO_VISIT(LT) +.DISPATCH_TO_VISIT(LE) +.DISPATCH_TO_VISIT(GT) +.DISPATCH_TO_VISIT(GE) +.DISPATCH_TO_VISIT(And) +.DISPATCH_TO_VISIT(Or) +.DISPATCH_TO_VISIT(Reduce) +.DISPATCH_TO_VISIT(Cast) +.DISPATCH_TO_VISIT(Not) +.DISPATCH_TO_VISIT(Select) +.DISPATCH_TO_VISIT(Ramp) +.DISPATCH_TO_VISIT(Broadcast) +.DISPATCH_TO_VISIT(AssertStmt) +.DISPATCH_TO_VISIT(ProducerConsumer) +.DISPATCH_TO_VISIT(Provide) +.DISPATCH_TO_VISIT(Realize) +.DISPATCH_TO_VISIT(Block) +.DISPATCH_TO_VISIT(Evaluate) +.DISPATCH_TO_VISIT(IntImm) +.DISPATCH_TO_VISIT(UIntImm) +.DISPATCH_TO_VISIT(FloatImm) +.DISPATCH_TO_VISIT(StringImm); } // namespace ir } // namespace tvm