diff --git a/include/tvm/schedule_pass.h b/include/tvm/schedule_pass.h
index c4e82cde139b6..b3f64db1eee96 100644
--- a/include/tvm/schedule_pass.h
+++ b/include/tvm/schedule_pass.h
@@ -22,7 +22,7 @@ namespace schedule {
  * \param sch The root schedule to infer all the bounds.
  * \return the result bound of the iteration Variable
  */
-Map<IterVar, Range> InferBound(Schedule sch);
+Map<IterVar, Range> InferBound(const Schedule& sch);
 
 /*!
  * \brief Schedule s' dependent operations.
diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc
index d60504f2c51ea..a805943358955 100644
--- a/src/arithmetic/int_set.cc
+++ b/src/arithmetic/int_set.cc
@@ -432,7 +432,6 @@ TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable)
 .set_dispatch<Min>(Binary<Min>)
 .set_dispatch<Max>(Binary<Max>);
 
-
 IntSet EvalSet(Expr e,
                const std::unordered_map<const Variable*, IntSet>& dom_map) {
   return IntSetEvaluator(dom_map).Eval(e);
@@ -444,17 +443,12 @@ IntSet EvalSet(Expr e,
   for (auto kv : dom_map) {
     dmap[kv.first->var.as<Variable>()] = kv.second;
   }
-  IntSetEvaluator m(dmap);
-  return m.Eval(e);
+  return EvalSet(e, dmap);
 }
 
 IntSet EvalSet(Range r,
-               const Map<IterVar, IntSet>& dom_map) {
-  std::unordered_map<const Variable*, IntSet> dmap;
-  for (auto kv : dom_map) {
-    dmap[kv.first->var.as<Variable>()] = kv.second;
-  }
-  IntSetEvaluator m(dmap);
+               const std::unordered_map<const Variable*, IntSet>& dom_map) {
+  IntSetEvaluator m(dom_map);
   IntSet min_set = m.Eval(r->min);
   IntSet ext_set = m.Eval(r->extent).cover_interval();
   const Interval& ei = ext_set.as<IntervalSet>()->i;
@@ -463,6 +457,15 @@ IntSet EvalSet(Range r,
   return Combine<Add>(min_set, ext_set);
 }
 
+IntSet EvalSet(Range r,
+               const Map<IterVar, IntSet>& dom_map) {
+  std::unordered_map<const Variable*, IntSet> dmap;
+  for (auto kv : dom_map) {
+    dmap[kv.first->var.as<Variable>()] = kv.second;
+  }
+  return EvalSet(r, dmap);
+}
+
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 .set_dispatch<IntervalSet>([](const IntervalSet *op, IRPrinter *p) {
     p->stream << "interval-set["
@@ -470,6 +473,5 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
               << op->i.max << ']';
   });
 
-
 }  // namespace arith
 }  // namespace tvm
diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h
index 979d138af9e2c..f5de7450194bc 100644
--- a/src/arithmetic/int_set.h
+++ b/src/arithmetic/int_set.h
@@ -103,6 +103,9 @@ IntSet EvalSet(Expr e,
  */
 IntSet EvalSet(Range r,
                const Map<IterVar, IntSet>& dom_map);
+IntSet EvalSet(Range r,
+               const std::unordered_map<const Variable*, IntSet>& dom_map);
+
 /*!
  * \brief Create an union set of all sets
diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc
index 9fe530b6767a1..1d452047af702 100644
--- a/src/schedule/bound.cc
+++ b/src/schedule/bound.cc
@@ -7,6 +7,8 @@
 #include <tvm/ir.h>
 #include <tvm/ir_visitor.h>
 #include <tvm/schedule_pass.h>
+#include <unordered_map>
+#include <unordered_set>
 #include "./graph.h"
 #include "../arithmetic/int_set.h"
 #include "../runtime/thread_storage_scope.h"
@@ -131,7 +133,6 @@ void PassUp(const FuseNode* s,
   }
 }
 
-
 void PassUp(const RebaseNode* s,
             const std::unordered_map<IterVar, Range>& dom_map,
             const IntSet& rebased,
@@ -180,82 +181,69 @@ void PassUp(const Stage& s,
   }
 }
 
-/*!
- * \brief Pass the bound of tensor read
- *  to the corresponding bound of the IterVar of operation
- * \param tensor The tensor to be passed.
- * \param dim_bounds The read index set on each dimension.
- * \param The result IterVar bound .
- */
-void PassToOperation(
-    const Tensor& tensor,
-    const std::vector<IntSet>& dim_bounds,
-    std::unordered_map<IterVar, std::vector<IntSet> >* result) {
-  // This is a push style operation, given output bound, push to the op IterVar bound.
-  // It cannot handle complicated cases where op bound is coupled with bounds of
-  // all of its outputs, without having a simple communicative union relation.
-  //
-  // Eventually, we need to change the inference to be a Pull style inference
-  if (tensor->op.as<ComputeOpNode>()) {
-    auto root_iter_vars = tensor->op->root_iter_vars();
-    const ComputeOpNode* op = tensor->op.as<ComputeOpNode>();
-    CHECK_EQ(op->axis.size() + op->reduce_axis.size(), root_iter_vars.size());
-    for (size_t i = 0; i < op->axis.size(); ++i) {
-      (*result)[op->axis[i]].push_back(dim_bounds[i]);
-    }
-    // reduction.
-    for (size_t i = 0; i < op->reduce_axis.size(); ++i) {
-      (*result)[op->reduce_axis[i]].push_back(
-          IntSet::range(op->reduce_axis[i]->dom));
-    }
-  } else {
-    LOG(FATAL) << "unknown operation mode " << tensor->op->type_key();
-  }
-}
+
+/*! \brief temporary data structure to store Tensor domain */
+struct TensorDom {
+  // constructor
+  explicit TensorDom(int ndim)
+      : data(ndim) {}
+  /*! \brief The domain data */
+  std::vector<std::vector<IntSet> > data;
+};
 
 /*!
- * \brief Recursively propagate bound
- * \param post_order The propagation order.
+ * \brief Propagate bound to target
  * \param dom_map The domain map to be propagated
+ * \param out The tensor domains to be updated
  * \return The result bound
  */
-std::unordered_map<IterVar, IntSet>
-BoundProp(const Array<Operation>& post_order,
-          std::unordered_map<IterVar, std::vector<IntSet> > *p_state) {
-  std::unordered_map<IterVar, IntSet> result;
-
-  for (size_t i = post_order.size(); i != 0; --i) {
-    Operation op = post_order[i - 1];
-    if (op.as<ComputeOpNode>()) {
-      for (auto iv : op->root_iter_vars()) {
-        CHECK(p_state->count(iv))
-            << "Bound of root operator must exists";
-        CHECK(!result.count(iv));
-        result[iv] = Union(p_state->at(iv));
-      }
-      auto fvisit = [p_state, &result](const NodeRef& n) {
-        auto *call = n.as<ir::Call>();
-        if (call != nullptr && call->func.defined()) {
-          Tensor t = Operation(call->func.node_).output(call->value_index);
-          if (t->op.defined() && !t->op.as<PlaceholderOpNode>()) {
-            std::vector<IntSet> arg_bounds;
-            for (size_t i = 0; i < t.ndim(); ++i) {
-              arg_bounds.push_back(EvalSet(call->args[i], result));
-            }
-            PassToOperation(t, arg_bounds, p_state);
+void BoundProp(const Operation& op,
+               const std::unordered_map<const Variable*, IntSet>& dom_map,
+               std::unordered_map<Tensor, TensorDom>* out) {
+  if (op.as<ComputeOpNode>()) {
+    auto fvisit = [&dom_map, out](const NodeRef& n) {
+      auto *call = n.as<ir::Call>();
+      if (call != nullptr && call->func.defined()) {
+        Tensor t = Operation(call->func.node_).output(call->value_index);
+        if (t->op.defined() && out->count(t)) {
+          TensorDom& dom = out->at(t);
+          for (size_t i = 0; i < t.ndim(); ++i) {
+            dom.data[i].push_back(EvalSet(call->args[i], dom_map));
           }
         }
-      };
-      ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
-    } else if (op.as<PlaceholderOpNode>()) {
-      // do nothing
-    } else {
-      LOG(FATAL) << "unknown operation mode " << op->type_key();
-    }
+      }
+    };
+    ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
+  } else if (op.as<PlaceholderOpNode>()) {
+    // do nothing
+  } else {
+    LOG(FATAL) << "unknown operation mode " << op->type_key();
   }
-  return result;
 }
 
+void InferOpBound(const Operation& op,
+                  const std::unordered_map<Tensor, TensorDom>& tmap,
+                  std::unordered_map<IterVar, Range>* rmap) {
+  if (op.as<ComputeOpNode>()) {
+    auto root_iter_vars = op->root_iter_vars();
+    const ComputeOpNode* compute = op.as<ComputeOpNode>();
+    const TensorDom& tdom = tmap.at(op.output(0));
+
+    for (size_t i = 0; i < compute->axis.size(); ++i) {
+      Range r = arith::Union(tdom.data[i]).cover_range(compute->axis[i]->dom);
+      CHECK(!rmap->count(compute->axis[i]));
+      (*rmap)[compute->axis[i]] = r;
+    }
+    for (size_t i = 0; i < compute->reduce_axis.size(); ++i) {
+      CHECK(!rmap->count(compute->reduce_axis[i]));
+      (*rmap)[compute->reduce_axis[i]] = compute->reduce_axis[i]->dom;
+    }
+  } else if (op.as<PlaceholderOpNode>()) {
+    // do nothing
+  } else {
+    LOG(FATAL) << "unknown operation mode " << op->type_key();
+  }
+}
 
 // check if scope
 inline bool
 ScopeRelax(const IterVar& iv, const std::string& scope) {
@@ -267,8 +255,18 @@ inline bool ScopeRelax(const IterVar& iv, const std::string& scope) {
   return StorageScope::make(scope).rank <= ThreadScope::make(iv->thread_tag).rank;
 }
 
-void InferBound(const Stage& stage,
-                std::unordered_map<IterVar, Range>* rmap) {
+// The map between a tensor and the operations it feeds to
+using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
+
+// AttachPath maps op -> a list of IterVars
+// that represents the loop nest the op sits in, from innermost to outermost
+using AttachPath = Map<Operation, Array<IterVar> >;
+
+
+void InferRootBound(const Stage& stage,
+                    const FeedGraph& feed_graph,
+                    const AttachPath& attach_path,
+                    std::unordered_map<IterVar, Range>* rmap) {
   if (stage->attach_type == kInline) return;
   if (stage->attach_type == kRoot || stage->attach_type == kNone) {
     auto root_iter_vars = stage->op->root_iter_vars();
@@ -277,15 +275,46 @@ void InferBound(const Stage& stage,
     for (auto iv : root_iter_vars) {
       CHECK(!rmap->count(iv));
       (*rmap)[iv] = iv->dom;
     }
+    return;
   }
+  // Infer root bounds for the attached node.
+  CHECK_EQ(stage->attach_type, kScope);
+  Stage parent = stage->attach_stage;
+  CHECK(parent.defined());
 
-  if (stage->attach_type == kScope) {
-    Stage parent = stage->attach_stage;
-    CHECK(parent.defined());
-    auto g = CreateReadGraph({parent->op});
-    auto post_order = PostDFSOrder({parent->op}, g);
-    std::unordered_map<IterVar, IntSet> up_state;
+  // The tensor domain.
+  std::unordered_map<Tensor, TensorDom> tmap;
+  // consumers other than parent
+  std::unordered_set<Operation> consumers;
+  // initialize the result
+  bool direct_consume_by_parent = false;
+  for (int i = 0; i < stage->op->num_outputs(); ++i) {
+    Tensor t = stage->op.output(i);
+    tmap.emplace(t, TensorDom(t.ndim()));
+    auto it = feed_graph.find(t);
+    if (it != feed_graph.end()) {
+      for (const Operation& op : it->second) {
+        if (op != parent->op) {
+          consumers.insert(op);
+        } else {
+          direct_consume_by_parent = true;
+        }
+      }
+    }
+  }
+  // The relax set
+  // This specifies the iteration variables that need to be relaxed
+  // from the already inferred bounds.
+  std::unordered_map<const Variable*, IntSet> relax_set;
+  for (IterVar iv : attach_path.at(stage->op)) {
+    if (ScopeRelax(iv, stage->scope)) {
+      relax_set[iv->var.get()] = IntSet::range(rmap->at(iv));
+    }
+  }
+  if (direct_consume_by_parent) {
+    // Bound inference logic in parent.
+    std::unordered_map<IterVar, IntSet> up_state;
     bool fix_value = true;
     for (auto iv : parent->leaf_iter_vars) {
       Range vrange = rmap->at(iv);
@@ -305,48 +334,104 @@ void InferBound(const Stage& stage,
        fix_value = false;
      }
    }
-    // get the bound of the root IterVars given the current condition
+    // get the bound of the root IterVars given current location.
     PassUp(parent, *rmap, &up_state);
-    std::unordered_map<IterVar, std::vector<IntSet> > bp_state;
+
+    std::unordered_map<const Variable*, IntSet> dom_map;
     for (auto iv : parent->op->root_iter_vars()) {
-      CHECK(up_state.count(iv));
-      bp_state[iv] = {up_state.at(iv)};
+      Range r = up_state.at(iv).cover_range(iv->dom);
+      if (relax_set.size() != 0) {
+        dom_map[iv->var.get()] = EvalSet(r, relax_set);
+      } else {
+        dom_map[iv->var.get()] = IntSet::range(r);
+      }
     }
-    auto result = BoundProp(post_order, &bp_state);
-
-    // Set relaxation for the threads in parent.
-    Map<IterVar, IntSet> relax_set;
-    Stage s = stage;
-    while (s->attach_type == kScope) {
-      s = s->attach_stage;
-      for (auto iv : s->leaf_iter_vars) {
-        if (ScopeRelax(iv, stage->scope)) {
-          relax_set.Set(iv, IntSet::range(rmap->at(iv)));
-        }
+    // prop from parent.
+    BoundProp(parent->op, dom_map, &tmap);
+  }
+  // Bound prop by other consumers.
+  // To explain the general logic, consider the example:
+  //
+  // for (i_outer, 0, 10) {
+  //   producer
+  //
+  //   for (i_inner, 0, 4) {
+  //     consumer op
+  //   }
+  // }
+  // - Get the root domain of each consumer op, say it is [i_inner + i_outer*8, extent=4)
+  // - We need to relax the bound, since the current producer is attached at i_outer
+  // - The attach_path of the consumer is [i_inner, i_outer], so we know [i_inner] needs to be relaxed
+  // - Traverse the attach_path until reaching the producer's attachment point, setting these as relaxed.
+  for (const Operation& op : consumers) {
+    std::unordered_map<const Variable*, IntSet> dom_map;
+    bool found = false;
+    for (IterVar iv : attach_path.at(op)) {
+      if (iv == stage->attach_ivar) {
+        found = true; break;
+      }
+      Range vrange = rmap->at(iv);
+      CHECK(is_zero(vrange->min))
+          << "InferBound requires every leaf iter var's min equals 0, "
+          << "call schedule.normalize to achieve this.";
+      relax_set[iv->var.get()] = IntSet::range(vrange);
+    }
+    CHECK(found)
+        << "Invalid Schedule, cannot find the producer " << stage->op
+        << " along the loop nest specified by compute_at of consumer " << op;
+    for (auto iv : op->root_iter_vars()) {
+      Range r = rmap->at(iv);
+      dom_map[iv->var.get()] = EvalSet(r, relax_set);
+    }
+    BoundProp(op, dom_map, &tmap);
+  }
+  InferOpBound(stage->op, tmap, rmap);
+}
 
-  for (auto iv : stage->op->root_iter_vars()) {
-    CHECK(result.count(iv));
-    CHECK(!rmap->count(iv));
-    Range r = result.at(iv).cover_range(iv->dom);
-    if (relax_set.size() != 0) {
-      r = EvalSet(r, relax_set).cover_range(iv->dom);
-    }
-    (*rmap)[iv] = r;
+FeedGraph CreateFeedGraph(const Schedule& sch) {
+  auto g = CreateReadGraph(sch->roots);
+  FeedGraph fg;
+  for (auto kv : g) {
+    for (Tensor t : kv.second) {
+      fg[t].push_back(kv.first);
     }
   }
-  // get range of all child iter vars.
-  PassDown(stage, rmap);
+  return fg;
 }
 
+// Create AttachPath that maps op -> a list of IterVars
+// that represents the loop nest the op sits in, from innermost to outermost
+AttachPath CreateAttachPath(const Schedule& sch) {
+  AttachPath ret;
+  for (Stage stage : sch->stages) {
+    Array<IterVar> path;
+    for (Stage s = stage; s->attach_type == kScope;) {
+      IterVar attach_ivar = s->attach_ivar;
+      s = s->attach_stage;
+      bool start_attach = false;
+      for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
+        IterVar iv = s->leaf_iter_vars[i - 1];
+        if (iv == attach_ivar) start_attach = true;
+        if (start_attach) path.push_back(iv);
+      }
+      CHECK(start_attach)
+          << "Invalid Schedule: cannot find attach point " << attach_ivar
+          << " in the schedule of " << s->op;
+    }
+    ret.Set(stage->op, path);
+  }
+  return ret;
+}
 
-Map<IterVar, Range> InferBound(Schedule sch) {
+Map<IterVar, Range> InferBound(const Schedule& sch) {
+  FeedGraph feed_graph = CreateFeedGraph(sch);
+  AttachPath attach_path = CreateAttachPath(sch);
   std::unordered_map<IterVar, Range> ret;
-  // reverse post DFS order, from out most stage to the innermost
   for (size_t i = sch->stages.size(); i != 0; --i) {
-    Stage stage = sch->stages[i - 1];
-    InferBound(stage, &ret);
+    const Stage& stage = sch->stages[i - 1];
+    InferRootBound(stage, feed_graph, attach_path, &ret);
+    // pass down to get bound of all iter vars.
+    PassDown(stage, &ret);
  }
  return Map<IterVar, Range>(ret.begin(), ret.end());
}
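For context on how the refactored entry point is meant to be used: the feed graph and attach paths are now built once inside `InferBound`, so a caller only needs the schedule. A minimal sketch follows; the `DumpBounds` helper and its setup are illustrative assumptions, not part of this change:

```cpp
#include <tvm/schedule_pass.h>

// Hypothetical debugging helper: run bound inference on a schedule and
// print the inferred Range for every IterVar. InferBound internally
// builds the feed graph and attach paths, then walks stages from the
// outermost to the innermost, so a single call is enough.
void DumpBounds(const tvm::Schedule& sch) {
  tvm::Map<tvm::IterVar, tvm::Range> bounds = tvm::schedule::InferBound(sch);
  for (const auto& kv : bounds) {
    LOG(INFO) << kv.first << " -> " << kv.second;
  }
}
```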
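The same relaxation pattern appears twice in `InferRootBound` (once for direct consumption by the parent, once for other consumers): take each loop variable between the consumer and the producer's attachment point, widen it to its full inferred range, then evaluate the consumer's root range under that map. A sketch of the pattern in isolation, assuming `iv` is one such loop variable and `root_iv`/`r` are a consumer root axis and its range; the free-standing function and its names are illustrative, since in the pass everything comes from `rmap` and `attach_path`:

```cpp
#include <unordered_map>
#include "../arithmetic/int_set.h"  // internal header, as used by bound.cc

// Relax range r over loop variable iv: because the producer lives outside
// iv's loop, every value of iv can observe the producer's output, so iv
// must be treated as free when computing the producer's root bound.
tvm::Range RelaxOverLoop(
    const tvm::Range& r,
    const tvm::IterVar& iv,
    const tvm::IterVar& root_iv,
    const std::unordered_map<tvm::IterVar, tvm::Range>& rmap) {
  using tvm::arith::IntSet;
  std::unordered_map<const tvm::Variable*, IntSet> relax_set;
  // iv ranges over its full inferred bound, so its whole range is possible.
  relax_set[iv->var.get()] = IntSet::range(rmap.at(iv));
  // EvalSet unions r over all values of iv; cover_range snaps the
  // resulting set back to a concrete Range for the root axis.
  return tvm::arith::EvalSet(r, relax_set).cover_range(root_iv->dom);
}
```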