Skip to content

Commit

Permalink
Fix the visitor
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Jan 22, 2017
1 parent 4551021 commit 97ccde6
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 86 deletions.
18 changes: 13 additions & 5 deletions include/tvm/ir_mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <tvm/ir_functor.h>
#include <unordered_map>
#include "./expr.h"
#include "./ir.h"

namespace tvm {
namespace ir {
Expand Down Expand Up @@ -53,11 +54,18 @@ class IRMutator {
static FMutateStmt& vtable_stmt(); // NOLINT(*)
// Set of overloadable functions
// The underscore allows Mutate not to be shadowed by inheritance
virtual Stmt Mutate_(const LetStmt* op, const Stmt& stmt);
virtual Stmt Mutate_(const AttrStmt* op, const Stmt& stmt);
virtual Stmt Mutate_(const Provide* op, const Stmt& stmt);
virtual Stmt Mutate_(const Realize* op, const Stmt& stmt);
virtual Expr Mutate_(const Call* op, const Expr& expr);
virtual Stmt Mutate_(const LetStmt* op, const Stmt& s);
virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s);
virtual Stmt Mutate_(const For* op, const Stmt& s);
virtual Stmt Mutate_(const Provide* op, const Stmt& s);
virtual Stmt Mutate_(const Allocate* op, const Stmt& s);
virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Store* op, const Stmt& s);
virtual Stmt Mutate_(const Free* op, const Stmt& s);
virtual Expr Mutate_(const Call* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& s);
virtual Expr Mutate_(const Variable* op, const Expr& e);
virtual Expr Mutate_(const Let* op, const Expr& e);
};

/*!
Expand Down
1 change: 1 addition & 0 deletions include/tvm/ir_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class IRVisitor {
static FVisit& vtable();
// overloadable visit function.
virtual void Visit_(const Variable* op);
virtual void Visit_(const AttrStmt* op);
virtual void Visit_(const LetStmt* op);
virtual void Visit_(const For* op);
virtual void Visit_(const Allocate* op);
Expand Down
160 changes: 88 additions & 72 deletions src/pass/ir_mutator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.DISPATCH_TO_MUTATE_STMT(LetStmt)
.DISPATCH_TO_MUTATE_STMT(AttrStmt)
.DISPATCH_TO_MUTATE_STMT(Provide)
.DISPATCH_TO_MUTATE_STMT(Realize);
.DISPATCH_TO_MUTATE_STMT(Realize)
.DISPATCH_TO_MUTATE_STMT(Store)
.DISPATCH_TO_MUTATE_STMT(Free);

Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) {
Expr value = this->Mutate(op->value);
Expand All @@ -96,6 +98,47 @@ Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) {
}
}

Stmt IRMutator::Mutate_(const For *op, const Stmt& s) {
Expr min = this->Mutate(op->min);
Expr extent = this->Mutate(op->extent);
Stmt body = this->Mutate(op->body);
if (min.same_as(op->min) &&
extent.same_as(op->extent) &&
body.same_as(op->body)) {
return s;
} else {
return For::make(
op->loop_var, min, extent, op->for_type, op->device_api, body);
}
}

Stmt IRMutator::Mutate_(const Allocate* op, const Stmt& s) {
IRMutator* m = this;
std::vector<Expr> new_extents;
bool all_extents_unmodified = true;
for (size_t i = 0; i < op->extents.size(); i++) {
new_extents.push_back(m->Mutate(op->extents[i]));
all_extents_unmodified &= new_extents[i].same_as(op->extents[i]);
}
Stmt body = m->Mutate(op->body);
Expr condition = m->Mutate(op->condition);
Expr new_expr;
if (op->new_expr.defined()) {
new_expr = m->Mutate(op->new_expr);
}
if (all_extents_unmodified &&
body.same_as(op->body) &&
condition.same_as(op->condition) &&
new_expr.same_as(op->new_expr)) {
return s;
} else {
return Allocate::make(
op->buffer_var, op->type,
new_extents, condition, body,
new_expr, op->free_function);
}
}

Stmt IRMutator::Mutate_(const Provide* op, const Stmt& s) {
auto new_args = MutateArray(op->args, this);
auto new_value = this->Mutate(op->value);
Expand Down Expand Up @@ -136,8 +179,25 @@ Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) {
}
}

Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) {
Expr value = this->Mutate(op->value);
Expr index = this->Mutate(op->index);
if (value.same_as(op->value) && index.same_as(op->index)) {
return s;
} else {
return Store::make(op->buffer_var, value, index);
}
}

Stmt IRMutator::Mutate_(const Free *op, const Stmt& s) {
return s;
}

TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.DISPATCH_TO_MUTATE_EXPR(Call);
.DISPATCH_TO_MUTATE_EXPR(Call)
.DISPATCH_TO_MUTATE_EXPR(Let)
.DISPATCH_TO_MUTATE_EXPR(Load)
.DISPATCH_TO_MUTATE_EXPR(Variable);

Expr IRMutator::Mutate_(const Call* op, const Expr& e) {
auto new_args = MutateArray(op->args, this);
Expand All @@ -149,6 +209,31 @@ Expr IRMutator::Mutate_(const Call* op, const Expr& e) {
}
}

Expr IRMutator::Mutate_(const Load *op, const Expr& e) {
Expr index = this->Mutate(op->index);
if (index.same_as(op->index)) {
return e;
} else {
return Load::make(op->type, op->buffer_var, index);
}
}


Expr IRMutator::Mutate_(const Variable *op, const Expr& e) {
return e;
}

Expr IRMutator::Mutate_(const Let *op, const Expr& e) {
Expr value = this->Mutate(op->value);
Expr body = this->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return e;
} else {
return Let::make(op->var, value, body);
}
}

TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<Reduce>([](const Reduce* op, const Expr& e, IRMutator* m) {
Array<IterVar> new_rdom = MutateRDom(op->rdom, m);
Expand All @@ -165,8 +250,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<IntImm>(ReturnSelfExpr)
.set_dispatch<UIntImm>(ReturnSelfExpr)
.set_dispatch<FloatImm>(ReturnSelfExpr)
.set_dispatch<StringImm>(ReturnSelfExpr)
.set_dispatch<Variable>(ReturnSelfExpr);
.set_dispatch<StringImm>(ReturnSelfExpr);

TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.set_dispatch<Cast>([](const Cast* op, const Expr& e, IRMutator* m) {
Expand Down Expand Up @@ -229,14 +313,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
return Select::make(cond, t, f);
}
})
.set_dispatch<Load>([](const Load *op, const Expr& e, IRMutator* m) {
Expr index = m->Mutate(op->index);
if (index.same_as(op->index)) {
return e;
} else {
return Load::make(op->type, op->buffer_var, index);
}
})
.set_dispatch<Ramp>([](const Ramp *op, const Expr& e, IRMutator* m) {
Expr base = m->Mutate(op->base);
Expr stride = m->Mutate(op->stride);
Expand All @@ -254,16 +330,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
} else {
return Broadcast::make(value, op->lanes);
}
})
.set_dispatch<Let>([](const Let *op, const Expr& e, IRMutator* m) {
Expr value = m->Mutate(op->value);
Expr body = m->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return e;
} else {
return Let::make(op->var, value, body);
}
});

TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
Expand All @@ -285,56 +351,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
return ProducerConsumer::make(op->func, op->is_producer, body);
}
})
.set_dispatch<For>([](const For *op, const Stmt& s, IRMutator* m) {
Expr min = m->Mutate(op->min);
Expr extent = m->Mutate(op->extent);
Stmt body = m->Mutate(op->body);
if (min.same_as(op->min) &&
extent.same_as(op->extent) &&
body.same_as(op->body)) {
return s;
} else {
return For::make(
op->loop_var, min, extent, op->for_type, op->device_api, body);
}
})
.set_dispatch<Store>([](const Store *op, const Stmt& s, IRMutator* m) {
Expr value = m->Mutate(op->value);
Expr index = m->Mutate(op->index);
if (value.same_as(op->value) && index.same_as(op->index)) {
return s;
} else {
return Store::make(op->buffer_var, value, index);
}
})
.set_dispatch<Allocate>([](const Allocate *op, const Stmt& s, IRMutator* m) {
std::vector<Expr> new_extents;
bool all_extents_unmodified = true;
for (size_t i = 0; i < op->extents.size(); i++) {
new_extents.push_back(m->Mutate(op->extents[i]));
all_extents_unmodified &= new_extents[i].same_as(op->extents[i]);
}
Stmt body = m->Mutate(op->body);
Expr condition = m->Mutate(op->condition);
Expr new_expr;
if (op->new_expr.defined()) {
new_expr = m->Mutate(op->new_expr);
}
if (all_extents_unmodified &&
body.same_as(op->body) &&
condition.same_as(op->condition) &&
new_expr.same_as(op->new_expr)) {
return s;
} else {
return Allocate::make(
op->buffer_var, op->type,
new_extents, condition, body,
new_expr, op->free_function);
}
})
.set_dispatch<Free>([](const Free *op, const Stmt& s, IRMutator* m) {
return s;
})
.set_dispatch<Block>([](const Block *op, const Stmt& s, IRMutator* m) {
Stmt first = m->Mutate(op->first);
Stmt rest = m->Mutate(op->rest);
Expand Down
15 changes: 6 additions & 9 deletions src/pass/ir_visitor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ void IRVisitor::Visit_(const LetStmt *op) {
this->Visit(op->body);
}

void IRVisitor::Visit_(const AttrStmt* op) {
this->Visit(op->value);
this->Visit(op->body);
}

void IRVisitor::Visit_(const For *op) {
IRVisitor* v = this;
v->Visit(op->min);
Expand Down Expand Up @@ -112,15 +117,7 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Reduce>([](const Reduce* op, IRVisitor* v) {
VisitRDom(op->rdom, v);
v->Visit(op->source);
});

TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<AttrStmt>([](const AttrStmt* op, IRVisitor* v) {
v->Visit(op->value);
v->Visit(op->body);
});

TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
})
.set_dispatch<IntImm>(NoOp)
.set_dispatch<UIntImm>(NoOp)
.set_dispatch<FloatImm>(NoOp)
Expand Down

0 comments on commit 97ccde6

Please sign in to comment.