From 3a0eb884c1b52b272520d61b569705590cbd0522 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Tue, 7 Jul 2020 16:35:47 -0400 Subject: [PATCH 1/2] [Opt] [bug] Better aliasing analysis for dead store elimination --- taichi/analysis/data_source_analysis.cpp | 6 +++ taichi/ir/control_flow_graph.cpp | 41 ++++++++++++---- taichi/ir/control_flow_graph.h | 2 + taichi/ir/ir.cpp | 62 ++++++++++++++++++++++++ taichi/ir/ir.h | 1 + tests/python/test_ad_for.py | 12 ++--- 6 files changed, 109 insertions(+), 15 deletions(-) diff --git a/taichi/analysis/data_source_analysis.cpp b/taichi/analysis/data_source_analysis.cpp index 8ef81d47d0caa..e96f7d8640252 100644 --- a/taichi/analysis/data_source_analysis.cpp +++ b/taichi/analysis/data_source_analysis.cpp @@ -26,6 +26,12 @@ std::vector get_load_pointers(Stmt *load_stmt) { } else if (auto stack_acc_adj = load_stmt->cast()) { // This statement loads and stores the adjoint data. return std::vector(1, stack_acc_adj->stack); + } else if (auto stack_push = load_stmt->cast()) { + // This is to make dead store elimination not eliminate consequent pushes. + return std::vector(1, stack_push->stack); + } else if (auto stack_pop = load_stmt->cast()) { + // This is to make dead store elimination not eliminate consequent pops. + return std::vector(1, stack_pop->stack); } else { return std::vector(); } diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index cb92bb04030ef..96ffa41d87382 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -82,8 +82,27 @@ bool CFGNode::contain_variable(const std::unordered_set &var_set, return var_set.find(var) != var_set.end(); } else { // TODO: How to optimize this? + if (var_set.find(var) != var_set.end()) + return true; for (auto set_var : var_set) { - if (irpass::analysis::same_statements(var, set_var)) { + if (definitely_same_address(var, set_var)) { + return true; + } + } + return false; + } +} + +bool CFGNode::may_contain_variable(const std::unordered_set &var_set, + Stmt *var) { + if (var->is() || var->is()) { + return var_set.find(var) != var_set.end(); + } else { + // TODO: How to optimize this? + if (var_set.find(var) != var_set.end()) + return true; + for (auto set_var : var_set) { + if (maybe_same_address(var, set_var)) { return true; } } @@ -290,16 +309,18 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { store_ptr = stack_push->stack; } else if (auto stack_acc_adj = stmt->cast()) { store_ptr = stack_acc_adj->stack; + } else if (stmt->is()) { + store_ptr = stmt; } if (store_ptr) { if (!after_lower_access || (store_ptr->is() || store_ptr->is())) { // After lower_access, we only analyze local variables and stacks. - // Do not eliminate AllocaStmt here. - if (!stmt->is() && + // Do not eliminate AllocaStmt and StackAllocaStmt here. + if (!stmt->is() && !stmt->is() && + !may_contain_variable(live_in_this_node, store_ptr) && (contain_variable(killed_in_this_node, store_ptr) || - (!contain_variable(live_out, store_ptr) && - !contain_variable(live_in_this_node, store_ptr)))) { + !may_contain_variable(live_out, store_ptr))) { // Neither used in other nodes nor used in this node. if (auto atomic = stmt->cast()) { // Weaken the atomic operation to a load. @@ -309,7 +330,6 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { local_load->ret_type = atomic->ret_type; replace_with(i, std::move(local_load), true); // Notice that we have a load here. - killed_in_this_node.erase(atomic->dest); live_in_this_node.insert(atomic->dest); modified = true; continue; @@ -322,7 +342,6 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { global_load->ret_type = atomic->ret_type; replace_with(i, std::move(global_load), true); // Notice that we have a load here. - killed_in_this_node.erase(atomic->dest); live_in_this_node.insert(atomic->dest); modified = true; continue; @@ -335,7 +354,12 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { } else { // A non-eliminated store. killed_in_this_node.insert(store_ptr); - live_in_this_node.erase(store_ptr); + auto old_live_in_this_node = std::move(live_in_this_node); + live_in_this_node.clear(); + for (auto &var : old_live_in_this_node) { + if (!definitely_same_address(store_ptr, var)) + live_in_this_node.insert(var); + } } } } @@ -344,7 +368,6 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { if (!after_lower_access || (load_ptr->is() || load_ptr->is())) { // After lower_access, we only analyze local variables and stacks. - killed_in_this_node.erase(load_ptr); live_in_this_node.insert(load_ptr); } } diff --git a/taichi/ir/control_flow_graph.h b/taichi/ir/control_flow_graph.h index 0465d336c388b..f391b1a95f6f8 100644 --- a/taichi/ir/control_flow_graph.h +++ b/taichi/ir/control_flow_graph.h @@ -53,6 +53,8 @@ class CFGNode { static bool contain_variable(const std::unordered_set &var_set, Stmt *var); + static bool may_contain_variable(const std::unordered_set &var_set, + Stmt *var); void reaching_definition_analysis(bool after_lower_access); bool reach_kill_variable(Stmt *var) const; Stmt *get_store_forwarding_data(Stmt *var, int position) const; diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 1dc5fd4bb0821..0da591995a323 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -24,6 +24,51 @@ CompileConfig &IRNode::get_config() const { return get_kernel()->program.config; } +bool definitely_same_address(Stmt *var1, Stmt *var2) { + // Return true when two statements must be the same address; + // false when two statements can be different addresses. + + // If both stmts are allocas, they have the same address iff var1 == var2. + // If only one of them is an alloca, they can never share the same address. + if (var1 == var2) + return true; + if (!var1 || !var2) + return false; + if (var1->is() || var2->is()) + return false; + if (var1->is() || var2->is()) + return false; + + // TODO(xumingkuan): Put GlobalTemporaryStmt, ThreadLocalPtrStmt and + // BlockLocalPtrStmt into GlobalPtrStmt. + // If both statements are global temps, they have the same address iff they + // have the same offset. If only one of them is a global temp, they can never + // share the same address. + if (var1->is() || var2->is()) { + if (!var1->is() || !var2->is()) + return false; + return var1->as()->offset == + var2->as()->offset; + } + + if (var1->is() || var2->is()) { + if (!var1->is() || !var2->is()) + return false; + return var1->as()->offset == + var2->as()->offset; + } + + if (var1->is() || var2->is()) { + if (!var1->is() || !var2->is()) + return false; + return irpass::analysis::same_statements( + var1->as()->offset, + var2->as()->offset); + } + + return irpass::analysis::same_statements(var1, var2); +} + bool maybe_same_address(Stmt *var1, Stmt *var2) { // Return true when two statements might be the same address; // false when two statements cannot be the same address. @@ -36,6 +81,8 @@ bool maybe_same_address(Stmt *var1, Stmt *var2) { return false; if (var1->is() || var2->is()) return false; + if (var1->is() || var2->is()) + return false; // If both statements are global temps, they have the same address iff they // have the same offset. If only one of them is a global temp, they can never @@ -47,6 +94,21 @@ bool maybe_same_address(Stmt *var1, Stmt *var2) { var2->as()->offset; } + if (var1->is() || var2->is()) { + if (!var1->is() || !var2->is()) + return false; + return var1->as()->offset == + var2->as()->offset; + } + + if (var1->is() || var2->is()) { + if (!var1->is() || !var2->is()) + return false; + return irpass::analysis::same_statements( + var1->as()->offset, + var2->as()->offset); + } + // If both statements are GlobalPtrStmts or GetChStmts, we can check by // SNode::id. TI_ASSERT(var1->width() == 1); diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index a7eab171b6ad1..3f752e6188b4f 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -33,6 +33,7 @@ using ScratchPadOptions = std::vector>; IRBuilder ¤t_ast_builder(); +bool definitely_same_address(Stmt *var1, Stmt *var2); bool maybe_same_address(Stmt *var1, Stmt *var2); struct VectorType { diff --git a/tests/python/test_ad_for.py b/tests/python/test_ad_for.py index 50d7911045dc1..9b23c21ad6db8 100644 --- a/tests/python/test_ad_for.py +++ b/tests/python/test_ad_for.py @@ -10,7 +10,7 @@ def test_ad_sum(): p = ti.var(ti.f32, shape=N, needs_grad=True) @ti.kernel - def comptue_sum(): + def compute_sum(): for i in range(N): ret = 1.0 for j in range(b[i]): @@ -21,13 +21,13 @@ def comptue_sum(): a[i] = 3 b[i] = i - comptue_sum() + compute_sum() for i in range(N): assert p[i] == 3 * b[i] + 1 p.grad[i] = 1 - comptue_sum.grad() + compute_sum.grad() for i in range(N): assert a.grad[i] == b[i] @@ -43,7 +43,7 @@ def test_ad_sum_local_atomic(): p = ti.var(ti.f32, shape=N, needs_grad=True) @ti.kernel - def comptue_sum(): + def compute_sum(): for i in range(N): ret = 1.0 for j in range(b[i]): @@ -54,13 +54,13 @@ def comptue_sum(): a[i] = 3 b[i] = i - comptue_sum() + compute_sum() for i in range(N): assert p[i] == 3 * b[i] + 1 p.grad[i] = 1 - comptue_sum.grad() + compute_sum.grad() for i in range(N): assert a.grad[i] == b[i] From 3f131806758b58a514155711fa45a26536c52b9a Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Tue, 7 Jul 2020 17:01:58 -0400 Subject: [PATCH 2/2] [skip ci] enforce code format --- taichi/ir/ir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 0da591995a323..2ad3db4fcc7e5 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -98,7 +98,7 @@ bool maybe_same_address(Stmt *var1, Stmt *var2) { if (!var1->is() || !var2->is()) return false; return var1->as()->offset == - var2->as()->offset; + var2->as()->offset; } if (var1->is() || var2->is()) {