From 97ccde6e9194cbe05e57f2bc32a7076cae04ce9e Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 21 Jan 2017 21:12:48 -0800 Subject: [PATCH] Fix the visitor --- include/tvm/ir_mutator.h | 18 +++-- include/tvm/ir_visitor.h | 1 + src/pass/ir_mutator.cc | 160 +++++++++++++++++++++------------------ src/pass/ir_visitor.cc | 15 ++-- 4 files changed, 108 insertions(+), 86 deletions(-) diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h index 44345fbf7c773..b57bca25eb493 100644 --- a/include/tvm/ir_mutator.h +++ b/include/tvm/ir_mutator.h @@ -9,6 +9,7 @@ #include #include #include "./expr.h" +#include "./ir.h" namespace tvm { namespace ir { @@ -53,11 +54,18 @@ class IRMutator { static FMutateStmt& vtable_stmt(); // NOLINT(*) // Set of overloadable functions // The underscore allows Mutate not to be shadowed by inheritance - virtual Stmt Mutate_(const LetStmt* op, const Stmt& stmt); - virtual Stmt Mutate_(const AttrStmt* op, const Stmt& stmt); - virtual Stmt Mutate_(const Provide* op, const Stmt& stmt); - virtual Stmt Mutate_(const Realize* op, const Stmt& stmt); - virtual Expr Mutate_(const Call* op, const Expr& expr); + virtual Stmt Mutate_(const LetStmt* op, const Stmt& s); + virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s); + virtual Stmt Mutate_(const For* op, const Stmt& s); + virtual Stmt Mutate_(const Provide* op, const Stmt& s); + virtual Stmt Mutate_(const Allocate* op, const Stmt& s); + virtual Stmt Mutate_(const Realize* op, const Stmt& s); + virtual Stmt Mutate_(const Store* op, const Stmt& s); + virtual Stmt Mutate_(const Free* op, const Stmt& s); + virtual Expr Mutate_(const Call* op, const Expr& e); + virtual Expr Mutate_(const Load* op, const Expr& s); + virtual Expr Mutate_(const Variable* op, const Expr& e); + virtual Expr Mutate_(const Let* op, const Expr& e); }; /*! diff --git a/include/tvm/ir_visitor.h b/include/tvm/ir_visitor.h index 351320841fc05..c942ae7b3700e 100644 --- a/include/tvm/ir_visitor.h +++ b/include/tvm/ir_visitor.h @@ -36,6 +36,7 @@ class IRVisitor { static FVisit& vtable(); // overloadable visit function. virtual void Visit_(const Variable* op); + virtual void Visit_(const AttrStmt* op); virtual void Visit_(const LetStmt* op); virtual void Visit_(const For* op); virtual void Visit_(const Allocate* op); diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index f55b01c676c2b..d2e055ce6bc42 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -72,7 +72,9 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) .DISPATCH_TO_MUTATE_STMT(LetStmt) .DISPATCH_TO_MUTATE_STMT(AttrStmt) .DISPATCH_TO_MUTATE_STMT(Provide) -.DISPATCH_TO_MUTATE_STMT(Realize); +.DISPATCH_TO_MUTATE_STMT(Realize) +.DISPATCH_TO_MUTATE_STMT(Store) +.DISPATCH_TO_MUTATE_STMT(Free); Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) { Expr value = this->Mutate(op->value); @@ -96,6 +98,47 @@ Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) { } } +Stmt IRMutator::Mutate_(const For *op, const Stmt& s) { + Expr min = this->Mutate(op->min); + Expr extent = this->Mutate(op->extent); + Stmt body = this->Mutate(op->body); + if (min.same_as(op->min) && + extent.same_as(op->extent) && + body.same_as(op->body)) { + return s; + } else { + return For::make( + op->loop_var, min, extent, op->for_type, op->device_api, body); + } +} + +Stmt IRMutator::Mutate_(const Allocate* op, const Stmt& s) { + IRMutator* m = this; + std::vector new_extents; + bool all_extents_unmodified = true; + for (size_t i = 0; i < op->extents.size(); i++) { + new_extents.push_back(m->Mutate(op->extents[i])); + all_extents_unmodified &= new_extents[i].same_as(op->extents[i]); + } + Stmt body = m->Mutate(op->body); + Expr condition = m->Mutate(op->condition); + Expr new_expr; + if (op->new_expr.defined()) { + new_expr = m->Mutate(op->new_expr); + } + if (all_extents_unmodified && + body.same_as(op->body) && + condition.same_as(op->condition) && + new_expr.same_as(op->new_expr)) { + return s; + } else { + return Allocate::make( + op->buffer_var, op->type, + new_extents, condition, body, + new_expr, op->free_function); + } +} + Stmt IRMutator::Mutate_(const Provide* op, const Stmt& s) { auto new_args = MutateArray(op->args, this); auto new_value = this->Mutate(op->value); @@ -136,8 +179,25 @@ Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) { } } +Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) { + Expr value = this->Mutate(op->value); + Expr index = this->Mutate(op->index); + if (value.same_as(op->value) && index.same_as(op->index)) { + return s; + } else { + return Store::make(op->buffer_var, value, index); + } +} + +Stmt IRMutator::Mutate_(const Free *op, const Stmt& s) { + return s; +} + TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) -.DISPATCH_TO_MUTATE_EXPR(Call); +.DISPATCH_TO_MUTATE_EXPR(Call) +.DISPATCH_TO_MUTATE_EXPR(Let) +.DISPATCH_TO_MUTATE_EXPR(Load) +.DISPATCH_TO_MUTATE_EXPR(Variable); Expr IRMutator::Mutate_(const Call* op, const Expr& e) { auto new_args = MutateArray(op->args, this); @@ -149,6 +209,31 @@ Expr IRMutator::Mutate_(const Call* op, const Expr& e) { } } +Expr IRMutator::Mutate_(const Load *op, const Expr& e) { + Expr index = this->Mutate(op->index); + if (index.same_as(op->index)) { + return e; + } else { + return Load::make(op->type, op->buffer_var, index); + } +} + + +Expr IRMutator::Mutate_(const Variable *op, const Expr& e) { + return e; +} + +Expr IRMutator::Mutate_(const Let *op, const Expr& e) { + Expr value = this->Mutate(op->value); + Expr body = this->Mutate(op->body); + if (value.same_as(op->value) && + body.same_as(op->body)) { + return e; + } else { + return Let::make(op->var, value, body); + } +} + TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) .set_dispatch([](const Reduce* op, const Expr& e, IRMutator* m) { Array new_rdom = MutateRDom(op->rdom, m); @@ -165,8 +250,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) .set_dispatch(ReturnSelfExpr) .set_dispatch(ReturnSelfExpr) .set_dispatch(ReturnSelfExpr) -.set_dispatch(ReturnSelfExpr) -.set_dispatch(ReturnSelfExpr); +.set_dispatch(ReturnSelfExpr); TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) .set_dispatch([](const Cast* op, const Expr& e, IRMutator* m) { @@ -229,14 +313,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) return Select::make(cond, t, f); } }) -.set_dispatch([](const Load *op, const Expr& e, IRMutator* m) { - Expr index = m->Mutate(op->index); - if (index.same_as(op->index)) { - return e; - } else { - return Load::make(op->type, op->buffer_var, index); - } - }) .set_dispatch([](const Ramp *op, const Expr& e, IRMutator* m) { Expr base = m->Mutate(op->base); Expr stride = m->Mutate(op->stride); @@ -254,16 +330,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) } else { return Broadcast::make(value, op->lanes); } - }) -.set_dispatch([](const Let *op, const Expr& e, IRMutator* m) { - Expr value = m->Mutate(op->value); - Expr body = m->Mutate(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { - return e; - } else { - return Let::make(op->var, value, body); - } }); TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) @@ -285,56 +351,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) return ProducerConsumer::make(op->func, op->is_producer, body); } }) -.set_dispatch([](const For *op, const Stmt& s, IRMutator* m) { - Expr min = m->Mutate(op->min); - Expr extent = m->Mutate(op->extent); - Stmt body = m->Mutate(op->body); - if (min.same_as(op->min) && - extent.same_as(op->extent) && - body.same_as(op->body)) { - return s; - } else { - return For::make( - op->loop_var, min, extent, op->for_type, op->device_api, body); - } - }) -.set_dispatch([](const Store *op, const Stmt& s, IRMutator* m) { - Expr value = m->Mutate(op->value); - Expr index = m->Mutate(op->index); - if (value.same_as(op->value) && index.same_as(op->index)) { - return s; - } else { - return Store::make(op->buffer_var, value, index); - } - }) -.set_dispatch([](const Allocate *op, const Stmt& s, IRMutator* m) { - std::vector new_extents; - bool all_extents_unmodified = true; - for (size_t i = 0; i < op->extents.size(); i++) { - new_extents.push_back(m->Mutate(op->extents[i])); - all_extents_unmodified &= new_extents[i].same_as(op->extents[i]); - } - Stmt body = m->Mutate(op->body); - Expr condition = m->Mutate(op->condition); - Expr new_expr; - if (op->new_expr.defined()) { - new_expr = m->Mutate(op->new_expr); - } - if (all_extents_unmodified && - body.same_as(op->body) && - condition.same_as(op->condition) && - new_expr.same_as(op->new_expr)) { - return s; - } else { - return Allocate::make( - op->buffer_var, op->type, - new_extents, condition, body, - new_expr, op->free_function); - } - }) -.set_dispatch([](const Free *op, const Stmt& s, IRMutator* m) { - return s; - }) .set_dispatch([](const Block *op, const Stmt& s, IRMutator* m) { Stmt first = m->Mutate(op->first); Stmt rest = m->Mutate(op->rest); diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc index b8fbeacbd8bbc..348dadc85aa4b 100644 --- a/src/pass/ir_visitor.cc +++ b/src/pass/ir_visitor.cc @@ -73,6 +73,11 @@ void IRVisitor::Visit_(const LetStmt *op) { this->Visit(op->body); } +void IRVisitor::Visit_(const AttrStmt* op) { + this->Visit(op->value); + this->Visit(op->body); +} + void IRVisitor::Visit_(const For *op) { IRVisitor* v = this; v->Visit(op->min); @@ -112,15 +117,7 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) .set_dispatch([](const Reduce* op, IRVisitor* v) { VisitRDom(op->rdom, v); v->Visit(op->source); - }); - -TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) -.set_dispatch([](const AttrStmt* op, IRVisitor* v) { - v->Visit(op->value); - v->Visit(op->body); - }); - -TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) + }) .set_dispatch(NoOp) .set_dispatch(NoOp) .set_dispatch(NoOp)