diff --git a/taichi/analysis/gather_statement_usages.cpp b/taichi/analysis/gather_statement_usages.cpp new file mode 100644 index 0000000000000..ba12d58e4a313 --- /dev/null +++ b/taichi/analysis/gather_statement_usages.cpp @@ -0,0 +1,55 @@ +#include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" + +namespace taichi::lang { + +class GatherStatementUsages : public BasicStmtVisitor { + private: + using BasicStmtVisitor::visit; + + // maps a stmt to all its usages + std::unordered_map>> stmt_usages_; + + public: + explicit GatherStatementUsages() { + invoke_default_visitor = true; + } + + void default_visit(Stmt *stmt) { + auto ops = stmt->get_operands(); + for (int i = 0; i < ops.size(); i++) { + auto &op = ops[i]; + if (op != nullptr) { + stmt_usages_[op].push_back({stmt, i}); + } + } + } + + void visit(Stmt *stmt) override { + default_visit(stmt); + } + + void preprocess_container_stmt(Stmt *stmt) override { + default_visit(stmt); + } + + static std::unordered_map>> run( + IRNode *node) { + GatherStatementUsages pass; + node->accept(&pass); + return pass.stmt_usages_; + } +}; + +namespace irpass::analysis { + +std::unordered_map>> +gather_statement_usages(IRNode *root) { + return GatherStatementUsages::run(root); +} + +} // namespace irpass::analysis + +} // namespace taichi::lang diff --git a/taichi/ir/analysis.h b/taichi/ir/analysis.h index 69d68ee4a32d1..b477a9649cbd1 100644 --- a/taichi/ir/analysis.h +++ b/taichi/ir/analysis.h @@ -95,6 +95,8 @@ bool definitely_same_address(Stmt *var1, Stmt *var2); std::unordered_set detect_fors_with_break(IRNode *root); std::unordered_set detect_loops_with_continue(IRNode *root); +std::unordered_map>> +gather_statement_usages(IRNode *root); std::unordered_set gather_immutable_local_vars(IRNode *root); std::unordered_set gather_deactivations(IRNode *root); std::pair, std::unordered_set> diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 6e2f9d3695794..cffdf09443c77 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -4,7 +4,7 @@ #include #include -// #include "taichi/ir/analysis.h" +#include "taichi/ir/analysis.h" #include "taichi/ir/statements.h" #include "taichi/ir/transforms.h" @@ -496,4 +496,16 @@ void DelayedIRModifier::mark_as_modified() { modified_ = true; } +ImmediateIRModifier::ImmediateIRModifier(IRNode *root) { + stmt_usages_ = irpass::analysis::gather_statement_usages(root); +} + +void ImmediateIRModifier::replace_usages_with(Stmt *old_stmt, Stmt *new_stmt) { + if (stmt_usages_.find(old_stmt) == stmt_usages_.end()) + return; + for (auto &[usage, i] : stmt_usages_.at(old_stmt)) { + usage->set_operand(i, new_stmt); + } +} + } // namespace taichi::lang diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index 929dceccc9d02..12f22794e832e 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -609,6 +609,19 @@ class DelayedIRModifier { void mark_as_modified(); }; +// ImmediateIRModifier aims at replacing Stmt::replace_usages_with, which visits +// the whole tree for a single replacement. ImmediateIRModifier is currently +// associated with a pass, visits the whole tree once at the beginning of that +// pass, and performs a single replacement with amortized constant time. +class ImmediateIRModifier { + private: + std::unordered_map>> stmt_usages_; + + public: + explicit ImmediateIRModifier(IRNode *root); + void replace_usages_with(Stmt *old_stmt, Stmt *new_stmt); +}; + template inline void StmtFieldManager::operator()(const char *key, T &&value) { using decay_T = typename std::decay::type; diff --git a/taichi/transforms/eliminate_immutable_local_vars.cpp b/taichi/transforms/eliminate_immutable_local_vars.cpp index 036e96459f574..6fb1533823e5a 100644 --- a/taichi/transforms/eliminate_immutable_local_vars.cpp +++ b/taichi/transforms/eliminate_immutable_local_vars.cpp @@ -15,26 +15,29 @@ class EliminateImmutableLocalVars : public BasicStmtVisitor { private: using BasicStmtVisitor::visit; - DelayedIRModifier modifier_; std::unordered_set immutable_local_vars_; std::unordered_map immutable_local_var_to_value_; + ImmediateIRModifier immediate_modifier_; + DelayedIRModifier delayed_modifier_; public: explicit EliminateImmutableLocalVars( - const std::unordered_set &immutable_local_vars) - : immutable_local_vars_(immutable_local_vars) { + const std::unordered_set &immutable_local_vars, + IRNode *node) + : immutable_local_vars_(immutable_local_vars), immediate_modifier_(node) { } void visit(AllocaStmt *stmt) override { if (immutable_local_vars_.find(stmt) != immutable_local_vars_.end()) { - modifier_.erase(stmt); + delayed_modifier_.erase(stmt); } } void visit(LocalLoadStmt *stmt) override { if (immutable_local_vars_.find(stmt->src) != immutable_local_vars_.end()) { - stmt->replace_usages_with(immutable_local_var_to_value_[stmt->src]); - modifier_.erase(stmt); + immediate_modifier_.replace_usages_with( + stmt, immutable_local_var_to_value_[stmt->src]); + delayed_modifier_.erase(stmt); } } @@ -43,15 +46,15 @@ class EliminateImmutableLocalVars : public BasicStmtVisitor { TI_ASSERT(immutable_local_var_to_value_.find(stmt->dest) == immutable_local_var_to_value_.end()); immutable_local_var_to_value_[stmt->dest] = stmt->val; - modifier_.erase(stmt); + delayed_modifier_.erase(stmt); } } static void run(IRNode *node) { EliminateImmutableLocalVars pass( - irpass::analysis::gather_immutable_local_vars(node)); + irpass::analysis::gather_immutable_local_vars(node), node); node->accept(&pass); - pass.modifier_.modify_ir(); + pass.delayed_modifier_.modify_ir(); } }; diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index bf55ed1538d5c..2c2ecfddee2d9 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -20,12 +20,13 @@ static bool is_alloca_scalarizable(AllocaStmt *stmt) { class Scalarize : public BasicStmtVisitor { public: - DelayedIRModifier modifier_; + ImmediateIRModifier immediate_modifier_; + DelayedIRModifier delayed_modifier_; - explicit Scalarize(IRNode *node) { + explicit Scalarize(IRNode *node) : immediate_modifier_(node) { node->accept(this); - modifier_.modify_ir(); + delayed_modifier_.modify_ir(); } /* @@ -75,12 +76,12 @@ class Scalarize : public BasicStmtVisitor { auto scalarized_stmt = std::make_unique(matrix_ptr_stmt.get(), matrix_init_stmt->values[i]); - modifier_.insert_before(stmt, std::move(const_stmt)); - modifier_.insert_before(stmt, std::move(matrix_ptr_stmt)); - modifier_.insert_before(stmt, std::move(scalarized_stmt)); + delayed_modifier_.insert_before(stmt, std::move(const_stmt)); + delayed_modifier_.insert_before(stmt, std::move(matrix_ptr_stmt)); + delayed_modifier_.insert_before(stmt, std::move(scalarized_stmt)); } - modifier_.erase(stmt); + delayed_modifier_.erase(stmt); } } @@ -127,19 +128,19 @@ class Scalarize : public BasicStmtVisitor { matrix_init_values.push_back(scalarized_stmt.get()); - modifier_.insert_before(stmt, std::move(const_stmt)); - modifier_.insert_before(stmt, std::move(matrix_ptr_stmt)); - modifier_.insert_before(stmt, std::move(scalarized_stmt)); + delayed_modifier_.insert_before(stmt, std::move(const_stmt)); + delayed_modifier_.insert_before(stmt, std::move(matrix_ptr_stmt)); + delayed_modifier_.insert_before(stmt, std::move(scalarized_stmt)); } auto matrix_init_stmt = std::make_unique(matrix_init_values); matrix_init_stmt->ret_type = src_dtype; - stmt->replace_usages_with(matrix_init_stmt.get()); - modifier_.insert_before(stmt, std::move(matrix_init_stmt)); + immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get()); + delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt)); - modifier_.erase(stmt); + delayed_modifier_.erase(stmt); } } @@ -186,17 +187,17 @@ class Scalarize : public BasicStmtVisitor { unary_stmt->ret_type = primitive_type; matrix_init_values.push_back(unary_stmt.get()); - modifier_.insert_before(stmt, std::move(unary_stmt)); + delayed_modifier_.insert_before(stmt, std::move(unary_stmt)); } auto matrix_init_stmt = std::make_unique(matrix_init_values); matrix_init_stmt->ret_type = operand_dtype; - stmt->replace_usages_with(matrix_init_stmt.get()); - modifier_.insert_before(stmt, std::move(matrix_init_stmt)); + immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get()); + delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt)); - modifier_.erase(stmt); + delayed_modifier_.erase(stmt); } } @@ -256,17 +257,17 @@ class Scalarize : public BasicStmtVisitor { matrix_init_values.push_back(binary_stmt.get()); binary_stmt->ret_type = primitive_type; - modifier_.insert_before(stmt, std::move(binary_stmt)); + delayed_modifier_.insert_before(stmt, std::move(binary_stmt)); } auto matrix_init_stmt = std::make_unique(matrix_init_values); matrix_init_stmt->ret_type = stmt->ret_type; - stmt->replace_usages_with(matrix_init_stmt.get()); - modifier_.insert_before(stmt, std::move(matrix_init_stmt)); + immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get()); + delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt)); - modifier_.erase(stmt); + delayed_modifier_.erase(stmt); } } @@ -334,8 +335,9 @@ class Scalarize : public BasicStmtVisitor { if (!merged_string.empty()) merged_contents.push_back(merged_string); - modifier_.insert_before(stmt, Stmt::make(merged_contents)); - modifier_.erase(stmt); + delayed_modifier_.insert_before(stmt, + Stmt::make(merged_contents)); + delayed_modifier_.erase(stmt); } /* @@ -403,19 +405,19 @@ class Scalarize : public BasicStmtVisitor { matrix_init_values.push_back(atomic_stmt.get()); - modifier_.insert_before(stmt, std::move(const_stmt)); - modifier_.insert_before(stmt, std::move(matrix_ptr_stmt)); - modifier_.insert_before(stmt, std::move(atomic_stmt)); + delayed_modifier_.insert_before(stmt, std::move(const_stmt)); + delayed_modifier_.insert_before(stmt, std::move(matrix_ptr_stmt)); + delayed_modifier_.insert_before(stmt, std::move(atomic_stmt)); } auto matrix_init_stmt = std::make_unique(matrix_init_values); matrix_init_stmt->ret_type = stmt->ret_type; - stmt->replace_usages_with(matrix_init_stmt.get()); - modifier_.insert_before(stmt, std::move(matrix_init_stmt)); + immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get()); + delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt)); - modifier_.erase(stmt); + delayed_modifier_.erase(stmt); } } @@ -487,17 +489,17 @@ class Scalarize : public BasicStmtVisitor { matrix_init_values.push_back(ternary_stmt.get()); ternary_stmt->ret_type = primitive_type; - modifier_.insert_before(stmt, std::move(ternary_stmt)); + delayed_modifier_.insert_before(stmt, std::move(ternary_stmt)); } auto matrix_init_stmt = std::make_unique(matrix_init_values); matrix_init_stmt->ret_type = stmt->ret_type; - stmt->replace_usages_with(matrix_init_stmt.get()); - modifier_.insert_before(stmt, std::move(matrix_init_stmt)); + immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get()); + delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt)); - modifier_.erase(stmt); + delayed_modifier_.erase(stmt); } } @@ -522,10 +524,10 @@ class Scalarize : public BasicStmtVisitor { auto arg_load = std::make_unique(stmt->arg_id, ret_type, stmt->is_ptr); - stmt->replace_usages_with(arg_load.get()); + immediate_modifier_.replace_usages_with(stmt, arg_load.get()); - modifier_.insert_before(stmt, std::move(arg_load)); - modifier_.erase(stmt); + delayed_modifier_.insert_before(stmt, std::move(arg_load)); + delayed_modifier_.erase(stmt); } private: @@ -534,15 +536,16 @@ class Scalarize : public BasicStmtVisitor { class ScalarizePointers : public BasicStmtVisitor { public: - DelayedIRModifier modifier_; + ImmediateIRModifier immediate_modifier_; + DelayedIRModifier delayed_modifier_; // { original_alloca_stmt : [scalarized_alloca_stmt0, ...] } std::unordered_map> scalarized_local_tensor_map_; - explicit ScalarizePointers(IRNode *node) { + explicit ScalarizePointers(IRNode *node) : immediate_modifier_(node) { node->accept(this); - modifier_.modify_ir(); + delayed_modifier_.modify_ir(); } /* @@ -584,10 +587,11 @@ class ScalarizePointers : public BasicStmtVisitor { scalarized_local_tensor_map_[stmt].push_back( scalarized_alloca_stmt.get()); - modifier_.insert_before(stmt, std::move(scalarized_alloca_stmt)); + delayed_modifier_.insert_before(stmt, + std::move(scalarized_alloca_stmt)); } - modifier_.erase(stmt); + delayed_modifier_.erase(stmt); } } @@ -617,7 +621,7 @@ class ScalarizePointers : public BasicStmtVisitor { // handled if (!stmt->offset->is()) { // Removing this line will fail TI_ASSERT in ~DelayedIRModifier() - modifier_.modify_ir(); + delayed_modifier_.modify_ir(); throw TaichiSyntaxError(fmt::format( "{}The index of a Matrix/Vector must be a compile-time constant " "integer.\n" @@ -638,8 +642,8 @@ class ScalarizePointers : public BasicStmtVisitor { TI_ASSERT(offset < scalarized_alloca_stmts.size()); auto alloca_stmt = scalarized_alloca_stmts[offset]; - stmt->replace_usages_with(alloca_stmt); - modifier_.erase(stmt); + immediate_modifier_.replace_usages_with(stmt, alloca_stmt); + delayed_modifier_.erase(stmt); } } }