Skip to content

Commit

Permalink
[SCHEDULE] Refactor bound inference logic
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Feb 12, 2017
1 parent b8f0ec5 commit a162619
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 115 deletions.
2 changes: 1 addition & 1 deletion include/tvm/schedule_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace schedule {
* \param sch The root schedule to infer all the bounds.
* \return the result bound of the iteration Variable
*/
Map<IterVar, Range> InferBound(Schedule sch);
Map<IterVar, Range> InferBound(const Schedule& sch);

/*!
* \brief Schedule s' dependent operations.
Expand Down
22 changes: 12 additions & 10 deletions src/arithmetic/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,6 @@ TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
.set_dispatch<And>(Binary<And>)
.set_dispatch<Or>(Binary<Or>);


IntSet EvalSet(Expr e,
const std::unordered_map<const Variable*, IntSet>& dom_map) {
return IntSetEvaluator(dom_map).Eval(e);
Expand All @@ -444,17 +443,12 @@ IntSet EvalSet(Expr e,
for (auto kv : dom_map) {
dmap[kv.first->var.as<Variable>()] = kv.second;
}
IntSetEvaluator m(dmap);
return m.Eval(e);
return EvalSet(e, dmap);
}

IntSet EvalSet(Range r,
const Map<IterVar, IntSet>& dom_map) {
std::unordered_map<const Variable*, IntSet> dmap;
for (auto kv : dom_map) {
dmap[kv.first->var.as<Variable>()] = kv.second;
}
IntSetEvaluator m(dmap);
const std::unordered_map<const Variable*, IntSet>& dom_map) {
IntSetEvaluator m(dom_map);
IntSet min_set = m.Eval(r->min);
IntSet ext_set = m.Eval(r->extent).cover_interval();
const Interval& ei = ext_set.as<IntervalSet>()->i;
Expand All @@ -463,13 +457,21 @@ IntSet EvalSet(Range r,
return Combine<Add>(min_set, ext_set);
}

IntSet EvalSet(Range r,
const Map<IterVar, IntSet>& dom_map) {
std::unordered_map<const Variable*, IntSet> dmap;
for (auto kv : dom_map) {
dmap[kv.first->var.as<Variable>()] = kv.second;
}
return EvalSet(r, dmap);
}

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IntervalSet>([](const IntervalSet *op, IRPrinter *p) {
p->stream << "interval-set["
<< "[" << op->i.min << ", "
<< op->i.max << ']';
});


} // namespace arith
} // namespace tvm
3 changes: 3 additions & 0 deletions src/arithmetic/int_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ IntSet EvalSet(Expr e,
*/
IntSet EvalSet(Range r,
const Map<IterVar, IntSet>& dom_map);
IntSet EvalSet(Range r,
const std::unordered_map<const Variable*, IntSet>& dom_map);


/*!
* \brief Create an union set of all sets
Expand Down
Loading

0 comments on commit a162619

Please sign in to comment.