Skip to content

Commit

Permalink
Use a caching version of stmt_uses_vars in TightenProducerConsumer no…
Browse files Browse the repository at this point in the history
…des (#8102)

We were making a very large number stmt_uses_vars queries that covered
the same sub-stmts. I solved it by adding a cache.

Speeds up local laplacian lowering by 10% by basically removing this
pass from the profile.

Also a drive-by typo fix in Lower.cpp
  • Loading branch information
abadams authored Feb 26, 2024
1 parent 4399ed8 commit aae84f6
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 13 deletions.
80 changes: 68 additions & 12 deletions src/AsyncProducers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -569,11 +569,67 @@ class InitializeSemaphores : public IRMutator {
}
};

// A class to support stmt_uses_vars queries that repeatedly hit the same
// sub-stmts. Used to support TightenProducerConsumerNodes below.
class CachingStmtUsesVars : public IRMutator {
const Scope<> &query;
bool found_use = false;
std::map<Stmt, bool> cache;

using IRMutator::visit;
Expr visit(const Variable *op) override {
found_use |= query.contains(op->name);
return op;
}

Expr visit(const Call *op) override {
found_use |= query.contains(op->name);
IRMutator::visit(op);
return op;
}

Stmt visit(const Provide *op) override {
found_use |= query.contains(op->name);
IRMutator::visit(op);
return op;
}

public:
CachingStmtUsesVars(const Scope<> &q)
: query(q) {
}

using IRMutator::mutate;
Stmt mutate(const Stmt &s) override {
auto it = cache.find(s);
if (it != cache.end()) {
found_use |= it->second;
} else {
bool old = found_use;
found_use = false;
Stmt stmt = IRMutator::mutate(s);
if (found_use) {
cache.emplace(s, true);
} else {
cache.emplace(s, false);
}
found_use |= old;
}
return s;
}

bool check_stmt(const Stmt &s) {
found_use = false;
mutate(s);
return found_use;
}
};

// Tighten the scope of consume nodes as much as possible to avoid needless synchronization.
class TightenProducerConsumerNodes : public IRMutator {
using IRMutator::visit;

Stmt make_producer_consumer(const string &name, bool is_producer, Stmt body, const Scope<int> &scope) {
Stmt make_producer_consumer(const string &name, bool is_producer, Stmt body, const Scope<> &scope, CachingStmtUsesVars &uses_vars) {
if (const LetStmt *let = body.as<LetStmt>()) {
Stmt orig = body;
// 'orig' is only used to keep a reference to the let
Expand All @@ -595,7 +651,7 @@ class TightenProducerConsumerNodes : public IRMutator {
body = ProducerConsumer::make(name, is_producer, body);
} else {
// Recurse onto a non-let-node
body = make_producer_consumer(name, is_producer, body, scope);
body = make_producer_consumer(name, is_producer, body, scope, uses_vars);
}

for (auto it = containing_lets.rbegin(); it != containing_lets.rend(); it++) {
Expand All @@ -611,44 +667,44 @@ class TightenProducerConsumerNodes : public IRMutator {
vector<Stmt> sub_stmts;
Stmt rest;
do {
Stmt first = block->first;
sub_stmts.push_back(block->first);
rest = block->rest;
block = rest.as<Block>();
} while (block);
sub_stmts.push_back(rest);

for (Stmt &s : sub_stmts) {
if (stmt_uses_vars(s, scope)) {
s = make_producer_consumer(name, is_producer, s, scope);
if (uses_vars.check_stmt(s)) {
s = make_producer_consumer(name, is_producer, s, scope, uses_vars);
}
}

return Block::make(sub_stmts);
} else if (const ProducerConsumer *pc = body.as<ProducerConsumer>()) {
return ProducerConsumer::make(pc->name, pc->is_producer, make_producer_consumer(name, is_producer, pc->body, scope));
return ProducerConsumer::make(pc->name, pc->is_producer, make_producer_consumer(name, is_producer, pc->body, scope, uses_vars));
} else if (const Realize *r = body.as<Realize>()) {
return Realize::make(r->name, r->types, r->memory_type,
r->bounds, r->condition,
make_producer_consumer(name, is_producer, r->body, scope));
make_producer_consumer(name, is_producer, r->body, scope, uses_vars));
} else {
return ProducerConsumer::make(name, is_producer, body);
}
}

Stmt visit(const ProducerConsumer *op) override {
Stmt body = mutate(op->body);
Scope<int> scope;
scope.push(op->name, 0);
Scope<> scope;
scope.push(op->name);
Function f = env.find(op->name)->second;
if (f.outputs() == 1) {
scope.push(op->name + ".buffer", 0);
scope.push(op->name + ".buffer");
} else {
for (int i = 0; i < f.outputs(); i++) {
scope.push(op->name + "." + std::to_string(i) + ".buffer", 0);
scope.push(op->name + "." + std::to_string(i) + ".buffer");
}
}
return make_producer_consumer(op->name, op->is_producer, body, scope);
CachingStmtUsesVars uses_vars{scope};
return make_producer_consumer(op->name, op->is_producer, body, scope, uses_vars);
}

const map<string, Function> &env;
Expand Down
2 changes: 1 addition & 1 deletion src/Lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ void lower_impl(const vector<Function> &output_funcs,
debug(1) << "Simplifying...\n";
s = simplify(s);
s = unify_duplicate_lets(s);
log("Lowering after second simplifcation:", s);
log("Lowering after second simplification:", s);

debug(1) << "Reduce prefetch dimension...\n";
s = reduce_prefetch_dimension(s, t);
Expand Down

0 comments on commit aae84f6

Please sign in to comment.