diff --git a/src/AsyncProducers.cpp b/src/AsyncProducers.cpp index 92012ccfe4c1..352219478923 100644 --- a/src/AsyncProducers.cpp +++ b/src/AsyncProducers.cpp @@ -569,11 +569,67 @@ class InitializeSemaphores : public IRMutator { } }; +// A class to support stmt_uses_vars queries that repeatedly hit the same +// sub-stmts. Used to support TightenProducerConsumerNodes below. +class CachingStmtUsesVars : public IRMutator { + const Scope<> &query; + bool found_use = false; + std::map cache; + + using IRMutator::visit; + Expr visit(const Variable *op) override { + found_use |= query.contains(op->name); + return op; + } + + Expr visit(const Call *op) override { + found_use |= query.contains(op->name); + IRMutator::visit(op); + return op; + } + + Stmt visit(const Provide *op) override { + found_use |= query.contains(op->name); + IRMutator::visit(op); + return op; + } + +public: + CachingStmtUsesVars(const Scope<> &q) + : query(q) { + } + + using IRMutator::mutate; + Stmt mutate(const Stmt &s) override { + auto it = cache.find(s); + if (it != cache.end()) { + found_use |= it->second; + } else { + bool old = found_use; + found_use = false; + Stmt stmt = IRMutator::mutate(s); + if (found_use) { + cache.emplace(s, true); + } else { + cache.emplace(s, false); + } + found_use |= old; + } + return s; + } + + bool check_stmt(const Stmt &s) { + found_use = false; + mutate(s); + return found_use; + } +}; + // Tighten the scope of consume nodes as much as possible to avoid needless synchronization. class TightenProducerConsumerNodes : public IRMutator { using IRMutator::visit; - Stmt make_producer_consumer(const string &name, bool is_producer, Stmt body, const Scope &scope) { + Stmt make_producer_consumer(const string &name, bool is_producer, Stmt body, const Scope<> &scope, CachingStmtUsesVars &uses_vars) { if (const LetStmt *let = body.as()) { Stmt orig = body; // 'orig' is only used to keep a reference to the let @@ -595,7 +651,7 @@ class TightenProducerConsumerNodes : public IRMutator { body = ProducerConsumer::make(name, is_producer, body); } else { // Recurse onto a non-let-node - body = make_producer_consumer(name, is_producer, body, scope); + body = make_producer_consumer(name, is_producer, body, scope, uses_vars); } for (auto it = containing_lets.rbegin(); it != containing_lets.rend(); it++) { @@ -611,7 +667,6 @@ class TightenProducerConsumerNodes : public IRMutator { vector sub_stmts; Stmt rest; do { - Stmt first = block->first; sub_stmts.push_back(block->first); rest = block->rest; block = rest.as(); @@ -619,18 +674,18 @@ class TightenProducerConsumerNodes : public IRMutator { sub_stmts.push_back(rest); for (Stmt &s : sub_stmts) { - if (stmt_uses_vars(s, scope)) { - s = make_producer_consumer(name, is_producer, s, scope); + if (uses_vars.check_stmt(s)) { + s = make_producer_consumer(name, is_producer, s, scope, uses_vars); } } return Block::make(sub_stmts); } else if (const ProducerConsumer *pc = body.as()) { - return ProducerConsumer::make(pc->name, pc->is_producer, make_producer_consumer(name, is_producer, pc->body, scope)); + return ProducerConsumer::make(pc->name, pc->is_producer, make_producer_consumer(name, is_producer, pc->body, scope, uses_vars)); } else if (const Realize *r = body.as()) { return Realize::make(r->name, r->types, r->memory_type, r->bounds, r->condition, - make_producer_consumer(name, is_producer, r->body, scope)); + make_producer_consumer(name, is_producer, r->body, scope, uses_vars)); } else { return ProducerConsumer::make(name, is_producer, body); } @@ -638,17 +693,18 @@ class TightenProducerConsumerNodes : public IRMutator { Stmt visit(const ProducerConsumer *op) override { Stmt body = mutate(op->body); - Scope scope; - scope.push(op->name, 0); + Scope<> scope; + scope.push(op->name); Function f = env.find(op->name)->second; if (f.outputs() == 1) { - scope.push(op->name + ".buffer", 0); + scope.push(op->name + ".buffer"); } else { for (int i = 0; i < f.outputs(); i++) { - scope.push(op->name + "." + std::to_string(i) + ".buffer", 0); + scope.push(op->name + "." + std::to_string(i) + ".buffer"); } } - return make_producer_consumer(op->name, op->is_producer, body, scope); + CachingStmtUsesVars uses_vars{scope}; + return make_producer_consumer(op->name, op->is_producer, body, scope, uses_vars); } const map &env; diff --git a/src/Lower.cpp b/src/Lower.cpp index 560e0353c7a4..ca9dc0950d87 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -307,7 +307,7 @@ void lower_impl(const vector &output_funcs, debug(1) << "Simplifying...\n"; s = simplify(s); s = unify_duplicate_lets(s); - log("Lowering after second simplifcation:", s); + log("Lowering after second simplification:", s); debug(1) << "Reduce prefetch dimension...\n"; s = reduce_prefetch_dimension(s, t);