Skip to content

Commit

Permalink
[opt] Add ImmediateIRModifier to provide amortized constant-time repl…
Browse files Browse the repository at this point in the history
…ace_usages_with() (taichi-dev#7001)

Issue: taichi-dev#6933

### Brief Summary

`ImmediateIRModifier` aims at replacing `Stmt::replace_usages_with`,
which visits the whole tree for a single replacement.
`ImmediateIRModifier` is currently associated with a pass, visits the
whole tree once at the beginning of that pass, and performs a single
replacement with amortized constant time. It is now used in two most
recent passes, `eliminate_immutable_local_vars` and `scalarize`. More
passes can be modified to leverage it in the future.

After this PR, the profiling result of the script in taichi-dev#6933 shows that
the time of `eliminate_immutable_local_vars` reduces from `0.956 s` to
`0.162 s`, the time of `scalarize` reduces from `3.510 s` to `0.478 s`,
and the total time of `compile` reduces from `8.550 s` to `4.696 s`.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] authored Dec 30, 2022
1 parent c62414f commit 26b81e7
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 54 deletions.
55 changes: 55 additions & 0 deletions taichi/analysis/gather_statement_usages.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"

namespace taichi::lang {

class GatherStatementUsages : public BasicStmtVisitor {
private:
using BasicStmtVisitor::visit;

// maps a stmt to all its usages <stmt, operand>
std::unordered_map<Stmt *, std::vector<std::pair<Stmt *, int>>> stmt_usages_;

public:
explicit GatherStatementUsages() {
invoke_default_visitor = true;
}

void default_visit(Stmt *stmt) {
auto ops = stmt->get_operands();
for (int i = 0; i < ops.size(); i++) {
auto &op = ops[i];
if (op != nullptr) {
stmt_usages_[op].push_back({stmt, i});
}
}
}

void visit(Stmt *stmt) override {
default_visit(stmt);
}

void preprocess_container_stmt(Stmt *stmt) override {
default_visit(stmt);
}

static std::unordered_map<Stmt *, std::vector<std::pair<Stmt *, int>>> run(
IRNode *node) {
GatherStatementUsages pass;
node->accept(&pass);
return pass.stmt_usages_;
}
};

namespace irpass::analysis {

std::unordered_map<Stmt *, std::vector<std::pair<Stmt *, int>>>
gather_statement_usages(IRNode *root) {
return GatherStatementUsages::run(root);
}

} // namespace irpass::analysis

} // namespace taichi::lang
2 changes: 2 additions & 0 deletions taichi/ir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ bool definitely_same_address(Stmt *var1, Stmt *var2);

std::unordered_set<Stmt *> detect_fors_with_break(IRNode *root);
std::unordered_set<Stmt *> detect_loops_with_continue(IRNode *root);
std::unordered_map<Stmt *, std::vector<std::pair<Stmt *, int>>>
gather_statement_usages(IRNode *root);
std::unordered_set<Stmt *> gather_immutable_local_vars(IRNode *root);
std::unordered_set<SNode *> gather_deactivations(IRNode *root);
std::pair<std::unordered_set<SNode *>, std::unordered_set<SNode *>>
Expand Down
14 changes: 13 additions & 1 deletion taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include <thread>
#include <unordered_map>

// #include "taichi/ir/analysis.h"
#include "taichi/ir/analysis.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"

Expand Down Expand Up @@ -496,4 +496,16 @@ void DelayedIRModifier::mark_as_modified() {
modified_ = true;
}

ImmediateIRModifier::ImmediateIRModifier(IRNode *root) {
stmt_usages_ = irpass::analysis::gather_statement_usages(root);
}

void ImmediateIRModifier::replace_usages_with(Stmt *old_stmt, Stmt *new_stmt) {
if (stmt_usages_.find(old_stmt) == stmt_usages_.end())
return;
for (auto &[usage, i] : stmt_usages_.at(old_stmt)) {
usage->set_operand(i, new_stmt);
}
}

} // namespace taichi::lang
13 changes: 13 additions & 0 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,19 @@ class DelayedIRModifier {
void mark_as_modified();
};

// ImmediateIRModifier aims at replacing Stmt::replace_usages_with, which visits
// the whole tree for a single replacement. ImmediateIRModifier is currently
// associated with a pass, visits the whole tree once at the beginning of that
// pass, and performs a single replacement with amortized constant time.
class ImmediateIRModifier {
private:
std::unordered_map<Stmt *, std::vector<std::pair<Stmt *, int>>> stmt_usages_;

public:
explicit ImmediateIRModifier(IRNode *root);
void replace_usages_with(Stmt *old_stmt, Stmt *new_stmt);
};

template <typename T>
inline void StmtFieldManager::operator()(const char *key, T &&value) {
using decay_T = typename std::decay<T>::type;
Expand Down
21 changes: 12 additions & 9 deletions taichi/transforms/eliminate_immutable_local_vars.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,29 @@ class EliminateImmutableLocalVars : public BasicStmtVisitor {
private:
using BasicStmtVisitor::visit;

DelayedIRModifier modifier_;
std::unordered_set<Stmt *> immutable_local_vars_;
std::unordered_map<Stmt *, Stmt *> immutable_local_var_to_value_;
ImmediateIRModifier immediate_modifier_;
DelayedIRModifier delayed_modifier_;

public:
explicit EliminateImmutableLocalVars(
const std::unordered_set<Stmt *> &immutable_local_vars)
: immutable_local_vars_(immutable_local_vars) {
const std::unordered_set<Stmt *> &immutable_local_vars,
IRNode *node)
: immutable_local_vars_(immutable_local_vars), immediate_modifier_(node) {
}

void visit(AllocaStmt *stmt) override {
if (immutable_local_vars_.find(stmt) != immutable_local_vars_.end()) {
modifier_.erase(stmt);
delayed_modifier_.erase(stmt);
}
}

void visit(LocalLoadStmt *stmt) override {
if (immutable_local_vars_.find(stmt->src) != immutable_local_vars_.end()) {
stmt->replace_usages_with(immutable_local_var_to_value_[stmt->src]);
modifier_.erase(stmt);
immediate_modifier_.replace_usages_with(
stmt, immutable_local_var_to_value_[stmt->src]);
delayed_modifier_.erase(stmt);
}
}

Expand All @@ -43,15 +46,15 @@ class EliminateImmutableLocalVars : public BasicStmtVisitor {
TI_ASSERT(immutable_local_var_to_value_.find(stmt->dest) ==
immutable_local_var_to_value_.end());
immutable_local_var_to_value_[stmt->dest] = stmt->val;
modifier_.erase(stmt);
delayed_modifier_.erase(stmt);
}
}

static void run(IRNode *node) {
EliminateImmutableLocalVars pass(
irpass::analysis::gather_immutable_local_vars(node));
irpass::analysis::gather_immutable_local_vars(node), node);
node->accept(&pass);
pass.modifier_.modify_ir();
pass.delayed_modifier_.modify_ir();
}
};

Expand Down
92 changes: 48 additions & 44 deletions taichi/transforms/scalarize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ static bool is_alloca_scalarizable(AllocaStmt *stmt) {

class Scalarize : public BasicStmtVisitor {
public:
DelayedIRModifier modifier_;
ImmediateIRModifier immediate_modifier_;
DelayedIRModifier delayed_modifier_;

explicit Scalarize(IRNode *node) {
explicit Scalarize(IRNode *node) : immediate_modifier_(node) {
node->accept(this);

modifier_.modify_ir();
delayed_modifier_.modify_ir();
}

/*
Expand Down Expand Up @@ -75,12 +76,12 @@ class Scalarize : public BasicStmtVisitor {
auto scalarized_stmt = std::make_unique<T>(matrix_ptr_stmt.get(),
matrix_init_stmt->values[i]);

modifier_.insert_before(stmt, std::move(const_stmt));
modifier_.insert_before(stmt, std::move(matrix_ptr_stmt));
modifier_.insert_before(stmt, std::move(scalarized_stmt));
delayed_modifier_.insert_before(stmt, std::move(const_stmt));
delayed_modifier_.insert_before(stmt, std::move(matrix_ptr_stmt));
delayed_modifier_.insert_before(stmt, std::move(scalarized_stmt));
}

modifier_.erase(stmt);
delayed_modifier_.erase(stmt);
}
}

Expand Down Expand Up @@ -127,19 +128,19 @@ class Scalarize : public BasicStmtVisitor {

matrix_init_values.push_back(scalarized_stmt.get());

modifier_.insert_before(stmt, std::move(const_stmt));
modifier_.insert_before(stmt, std::move(matrix_ptr_stmt));
modifier_.insert_before(stmt, std::move(scalarized_stmt));
delayed_modifier_.insert_before(stmt, std::move(const_stmt));
delayed_modifier_.insert_before(stmt, std::move(matrix_ptr_stmt));
delayed_modifier_.insert_before(stmt, std::move(scalarized_stmt));
}

auto matrix_init_stmt =
std::make_unique<MatrixInitStmt>(matrix_init_values);
matrix_init_stmt->ret_type = src_dtype;

stmt->replace_usages_with(matrix_init_stmt.get());
modifier_.insert_before(stmt, std::move(matrix_init_stmt));
immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get());
delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt));

modifier_.erase(stmt);
delayed_modifier_.erase(stmt);
}
}

Expand Down Expand Up @@ -186,17 +187,17 @@ class Scalarize : public BasicStmtVisitor {
unary_stmt->ret_type = primitive_type;
matrix_init_values.push_back(unary_stmt.get());

modifier_.insert_before(stmt, std::move(unary_stmt));
delayed_modifier_.insert_before(stmt, std::move(unary_stmt));
}

auto matrix_init_stmt =
std::make_unique<MatrixInitStmt>(matrix_init_values);
matrix_init_stmt->ret_type = operand_dtype;

stmt->replace_usages_with(matrix_init_stmt.get());
modifier_.insert_before(stmt, std::move(matrix_init_stmt));
immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get());
delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt));

modifier_.erase(stmt);
delayed_modifier_.erase(stmt);
}
}

Expand Down Expand Up @@ -256,17 +257,17 @@ class Scalarize : public BasicStmtVisitor {
matrix_init_values.push_back(binary_stmt.get());
binary_stmt->ret_type = primitive_type;

modifier_.insert_before(stmt, std::move(binary_stmt));
delayed_modifier_.insert_before(stmt, std::move(binary_stmt));
}

auto matrix_init_stmt =
std::make_unique<MatrixInitStmt>(matrix_init_values);
matrix_init_stmt->ret_type = stmt->ret_type;

stmt->replace_usages_with(matrix_init_stmt.get());
modifier_.insert_before(stmt, std::move(matrix_init_stmt));
immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get());
delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt));

modifier_.erase(stmt);
delayed_modifier_.erase(stmt);
}
}

Expand Down Expand Up @@ -334,8 +335,9 @@ class Scalarize : public BasicStmtVisitor {
if (!merged_string.empty())
merged_contents.push_back(merged_string);

modifier_.insert_before(stmt, Stmt::make<PrintStmt>(merged_contents));
modifier_.erase(stmt);
delayed_modifier_.insert_before(stmt,
Stmt::make<PrintStmt>(merged_contents));
delayed_modifier_.erase(stmt);
}

/*
Expand Down Expand Up @@ -403,19 +405,19 @@ class Scalarize : public BasicStmtVisitor {

matrix_init_values.push_back(atomic_stmt.get());

modifier_.insert_before(stmt, std::move(const_stmt));
modifier_.insert_before(stmt, std::move(matrix_ptr_stmt));
modifier_.insert_before(stmt, std::move(atomic_stmt));
delayed_modifier_.insert_before(stmt, std::move(const_stmt));
delayed_modifier_.insert_before(stmt, std::move(matrix_ptr_stmt));
delayed_modifier_.insert_before(stmt, std::move(atomic_stmt));
}

auto matrix_init_stmt =
std::make_unique<MatrixInitStmt>(matrix_init_values);
matrix_init_stmt->ret_type = stmt->ret_type;

stmt->replace_usages_with(matrix_init_stmt.get());
modifier_.insert_before(stmt, std::move(matrix_init_stmt));
immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get());
delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt));

modifier_.erase(stmt);
delayed_modifier_.erase(stmt);
}
}

Expand Down Expand Up @@ -487,17 +489,17 @@ class Scalarize : public BasicStmtVisitor {
matrix_init_values.push_back(ternary_stmt.get());
ternary_stmt->ret_type = primitive_type;

modifier_.insert_before(stmt, std::move(ternary_stmt));
delayed_modifier_.insert_before(stmt, std::move(ternary_stmt));
}

auto matrix_init_stmt =
std::make_unique<MatrixInitStmt>(matrix_init_values);
matrix_init_stmt->ret_type = stmt->ret_type;

stmt->replace_usages_with(matrix_init_stmt.get());
modifier_.insert_before(stmt, std::move(matrix_init_stmt));
immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get());
delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt));

modifier_.erase(stmt);
delayed_modifier_.erase(stmt);
}
}

Expand All @@ -522,10 +524,10 @@ class Scalarize : public BasicStmtVisitor {
auto arg_load =
std::make_unique<ArgLoadStmt>(stmt->arg_id, ret_type, stmt->is_ptr);

stmt->replace_usages_with(arg_load.get());
immediate_modifier_.replace_usages_with(stmt, arg_load.get());

modifier_.insert_before(stmt, std::move(arg_load));
modifier_.erase(stmt);
delayed_modifier_.insert_before(stmt, std::move(arg_load));
delayed_modifier_.erase(stmt);
}

private:
Expand All @@ -534,15 +536,16 @@ class Scalarize : public BasicStmtVisitor {

class ScalarizePointers : public BasicStmtVisitor {
public:
DelayedIRModifier modifier_;
ImmediateIRModifier immediate_modifier_;
DelayedIRModifier delayed_modifier_;

// { original_alloca_stmt : [scalarized_alloca_stmt0, ...] }
std::unordered_map<Stmt *, std::vector<Stmt *>> scalarized_local_tensor_map_;

explicit ScalarizePointers(IRNode *node) {
explicit ScalarizePointers(IRNode *node) : immediate_modifier_(node) {
node->accept(this);

modifier_.modify_ir();
delayed_modifier_.modify_ir();
}

/*
Expand Down Expand Up @@ -584,10 +587,11 @@ class ScalarizePointers : public BasicStmtVisitor {

scalarized_local_tensor_map_[stmt].push_back(
scalarized_alloca_stmt.get());
modifier_.insert_before(stmt, std::move(scalarized_alloca_stmt));
delayed_modifier_.insert_before(stmt,
std::move(scalarized_alloca_stmt));
}

modifier_.erase(stmt);
delayed_modifier_.erase(stmt);
}
}

Expand Down Expand Up @@ -617,7 +621,7 @@ class ScalarizePointers : public BasicStmtVisitor {
// handled
if (!stmt->offset->is<ConstStmt>()) {
// Removing this line will fail TI_ASSERT in ~DelayedIRModifier()
modifier_.modify_ir();
delayed_modifier_.modify_ir();
throw TaichiSyntaxError(fmt::format(
"{}The index of a Matrix/Vector must be a compile-time constant "
"integer.\n"
Expand All @@ -638,8 +642,8 @@ class ScalarizePointers : public BasicStmtVisitor {
TI_ASSERT(offset < scalarized_alloca_stmts.size());
auto alloca_stmt = scalarized_alloca_stmts[offset];

stmt->replace_usages_with(alloca_stmt);
modifier_.erase(stmt);
immediate_modifier_.replace_usages_with(stmt, alloca_stmt);
delayed_modifier_.erase(stmt);
}
}
}
Expand Down

0 comments on commit 26b81e7

Please sign in to comment.