diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index adf4db4dcf1aa..dd0f3743fabaf 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -400,7 +400,11 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { replace_with(i, std::move(local_load), true); modified = true; continue; - } else if (!is_parallel_executed) { + } else if (!is_parallel_executed || + (atomic->dest->is() && + atomic->dest->as() + ->snodes[0] + ->is_scalar())) { // If this node is parallel executed, we can't weaken a global // atomic operation to a global load. // TODO: we can weaken it if it's element-wise (i.e. never @@ -619,7 +623,9 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) { } } -void ControlFlowGraph::live_variable_analysis(bool after_lower_access) { +void ControlFlowGraph::live_variable_analysis( + bool after_lower_access, + const std::optional &config_opt) { TI_AUTO_PROF; const int num_nodes = size(); std::queue to_visit; @@ -627,14 +633,27 @@ void ControlFlowGraph::live_variable_analysis(bool after_lower_access) { TI_ASSERT(nodes[final_node]->empty()); nodes[final_node]->live_gen.clear(); nodes[final_node]->live_kill.clear(); + + auto in_final_node_live_gen = [&config_opt](const Stmt *stmt) -> bool { + if (stmt->is() || stmt->is()) { + return false; + } + if (auto *gptr = stmt->cast(); + gptr && config_opt.has_value()) { + TI_ASSERT(gptr->snodes.size() == 1); + const bool res = + (config_opt->eliminable_snodes.count(gptr->snodes[0]) == 0); + return res; + } + // A global pointer that may be loaded after this kernel. 
+ return true; + }; if (!after_lower_access) { for (int i = 0; i < num_nodes; i++) { for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) { auto stmt = nodes[i]->block->statements[j].get(); for (auto store_ptr : irpass::analysis::get_store_destination(stmt)) { - if (!store_ptr->is() && - !store_ptr->is()) { - // A global pointer that may be loaded after this kernel. + if (in_final_node_live_gen(store_ptr)) { nodes[final_node]->live_gen.insert(store_ptr); } } @@ -753,9 +772,11 @@ bool ControlFlowGraph::store_to_load_forwarding(bool after_lower_access) { return modified; } -bool ControlFlowGraph::dead_store_elimination(bool after_lower_access) { +bool ControlFlowGraph::dead_store_elimination( + bool after_lower_access, + const std::optional &lva_config_opt) { TI_AUTO_PROF; - live_variable_analysis(after_lower_access); + live_variable_analysis(after_lower_access, lva_config_opt); const int num_nodes = size(); bool modified = false; for (int i = 0; i < num_nodes; i++) { diff --git a/taichi/ir/control_flow_graph.h b/taichi/ir/control_flow_graph.h index 932c38e735215..d82c00959aa4e 100644 --- a/taichi/ir/control_flow_graph.h +++ b/taichi/ir/control_flow_graph.h @@ -1,5 +1,8 @@ #pragma once +#include +#include + #include "taichi/ir/ir.h" TLANG_NAMESPACE_BEGIN @@ -70,6 +73,12 @@ class ControlFlowGraph { void erase(int node_id); public: + struct LiveVarAnalysisConfig { + // This is mostly useful for SFG task-level dead store elimination. SFG may + // detect certain cases where writes to one or more SNodes in a task are + // eliminable. 
+ std::unordered_set eliminable_snodes; + }; std::vector> nodes; const int start_node = 0; int final_node{0}; @@ -85,7 +94,9 @@ class ControlFlowGraph { void print_graph_structure() const; void reaching_definition_analysis(bool after_lower_access); - void live_variable_analysis(bool after_lower_access); + void live_variable_analysis( + bool after_lower_access, + const std::optional &config_opt); void simplify_graph(); @@ -96,7 +107,9 @@ class ControlFlowGraph { bool store_to_load_forwarding(bool after_lower_access); // Also performs identical load elimination. - bool dead_store_elimination(bool after_lower_access); + bool dead_store_elimination( + bool after_lower_access, + const std::optional &lva_config_opt); }; TLANG_NAMESPACE_END diff --git a/taichi/ir/snode.cpp b/taichi/ir/snode.cpp index bedb8cbac8ad1..3490f5da3c89c 100644 --- a/taichi/ir/snode.cpp +++ b/taichi/ir/snode.cpp @@ -120,6 +120,10 @@ bool SNode::is_place() const { return type == SNodeType::place; } +bool SNode::is_scalar() const { + return is_place() && (num_active_indices == 0); +} + bool SNode::has_grad() const { auto adjoint = expr.cast()->adjoint; return is_primal() && adjoint.expr != nullptr && diff --git a/taichi/ir/snode.h b/taichi/ir/snode.h index 6c5c21da081ec..0ad0c695ff652 100644 --- a/taichi/ir/snode.h +++ b/taichi/ir/snode.h @@ -217,6 +217,8 @@ class SNode { bool is_place() const; + bool is_scalar() const; + const Expr &get_expr() const { return expr; } diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index 6a52a701082f6..b483d4b37e71c 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -1,9 +1,12 @@ #pragma once -#include "taichi/ir/ir.h" #include -#include +#include #include +#include + +#include "taichi/ir/control_flow_graph.h" +#include "taichi/ir/ir.h" TLANG_NAMESPACE_BEGIN @@ -19,7 +22,11 @@ void re_id(IRNode *root); void flag_access(IRNode *root); bool die(IRNode *root); bool simplify(IRNode *root, Kernel *kernel = nullptr); -bool 
cfg_optimization(IRNode *root, bool after_lower_access); +bool cfg_optimization( + IRNode *root, + bool after_lower_access, + const std::optional + &lva_config_opt = std::nullopt); bool alg_simp(IRNode *root); bool demote_operations(IRNode *root); bool binary_op_simplify(IRNode *root); diff --git a/taichi/program/async_utils.cpp b/taichi/program/async_utils.cpp index c1d72451baee6..e32d3b8cde2a3 100644 --- a/taichi/program/async_utils.cpp +++ b/taichi/program/async_utils.cpp @@ -127,7 +127,10 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) { if (auto global_store = stmt->cast()) { if (auto ptr = global_store->ptr->cast()) { for (auto &snode : ptr->snodes.data) { - meta.input_states.emplace(snode, AsyncState::Type::value); + if (!snode->is_scalar()) { + // TODO: This is ad-hoc, use value killing analysis + meta.input_states.emplace(snode, AsyncState::Type::value); + } meta.output_states.emplace(snode, AsyncState::Type::value); } } diff --git a/taichi/program/ir_bank.cpp b/taichi/program/ir_bank.cpp index 0059e8438a9b8..16877249aa68b 100644 --- a/taichi/program/ir_bank.cpp +++ b/taichi/program/ir_bank.cpp @@ -189,4 +189,41 @@ IRHandle IRBank::demote_activation(IRHandle handle) { return result; } +std::pair IRBank::optimize_dse( + IRHandle handle, + const std::set &snodes, + bool verbose) { + const OptimizeDseKey key(handle, snodes); + auto &ret_handle = optimize_dse_bank_[key]; + if (!ret_handle.empty()) { + // Already cached + return std::make_pair(ret_handle, true); + } + + std::unique_ptr new_ir = handle.clone(); + + if (verbose) { + TI_INFO(" DSE: before CFG"); + irpass::print(new_ir.get()); + } + ControlFlowGraph::LiveVarAnalysisConfig lva_config; + lva_config.eliminable_snodes = {snodes.begin(), snodes.end()}; + const bool modified = irpass::cfg_optimization( + new_ir.get(), /*after_lower_access=*/false, lva_config); + if (verbose) { + TI_INFO(" DSE: after CFG, modified={}", modified); + irpass::print(new_ir.get()); + } + + if (!modified) 
{ + // Nothing eliminated. Simply delete new_ir when this function returns. + ret_handle = handle; + return std::make_pair(ret_handle, false); + } + + ret_handle = IRHandle(new_ir.get(), get_hash(new_ir.get())); + insert(std::move(new_ir), ret_handle.hash()); + return std::make_pair(ret_handle, false); +} + TLANG_NAMESPACE_END diff --git a/taichi/program/ir_bank.h b/taichi/program/ir_bank.h index 22ebb25fb3693..6e2079aa0cbec 100644 --- a/taichi/program/ir_bank.h +++ b/taichi/program/ir_bank.h @@ -1,3 +1,7 @@ +#include +#include +#include + #include "taichi/program/async_utils.h" TLANG_NAMESPACE_BEGIN @@ -16,6 +20,16 @@ class IRBank { IRHandle demote_activation(IRHandle handle); + // Try running DSE optimization on the IR identified by |handle|. |snodes| + // denotes the set of SNodes whose stores are safe to eliminate. + // + // Returns: + // * IRHandle: the (possibly) DSE-optimized IRHandle + // * bool: whether the result is already cached. + std::pair optimize_dse(IRHandle handle, + const std::set &snodes, + bool verbose); + std::unordered_map meta_bank_; std::unordered_map fusion_meta_bank_; @@ -25,6 +39,41 @@ class IRBank { std::vector> trash_bin; // prevent IR from deleted std::unordered_map, IRHandle> fuse_bank_; std::unordered_map demote_activation_bank_; + + // For DSE optimization, the input key is (IRHandle, [SNode*]). This is + // because it is possible that the same IRHandle may have different sets of + // SNode stores that are eliminable. + struct OptimizeDseKey { + IRHandle task_ir; + // Intentionally use (ordered) set so that hash is deterministic.
+ std::set eliminable_snodes; + + OptimizeDseKey(const IRHandle task_ir, + const std::set &snodes) + : task_ir(task_ir), eliminable_snodes(snodes) { + } + + bool operator==(const OptimizeDseKey &other) const { + return (task_ir == other.task_ir) && + (eliminable_snodes == other.eliminable_snodes); + } + + bool operator!=(const OptimizeDseKey &other) const { + return !(*this == other); + } + + struct Hash { + std::size_t operator()(const OptimizeDseKey &k) const { + std::size_t ret = k.task_ir.hash(); + for (const auto *s : k.eliminable_snodes) { + ret = ret * 100000007UL + reinterpret_cast(s); + } + return ret; + } + }; + }; + std::unordered_map + optimize_dse_bank_; }; TLANG_NAMESPACE_END diff --git a/taichi/program/state_flow_graph.cpp b/taichi/program/state_flow_graph.cpp index 31da118b4992a..9b2f7946abd76 100644 --- a/taichi/program/state_flow_graph.cpp +++ b/taichi/program/state_flow_graph.cpp @@ -1,13 +1,14 @@ #include "taichi/program/state_flow_graph.h" -#include "taichi/ir/transforms.h" -#include "taichi/ir/analysis.h" -#include "taichi/program/async_engine.h" - #include #include #include +#include "taichi/ir/analysis.h" +#include "taichi/ir/transforms.h" +#include "taichi/program/async_engine.h" +#include "taichi/util/statistics.h" + TLANG_NAMESPACE_BEGIN // TODO: rename state to edge since we have not only state flow edges but also @@ -793,55 +794,75 @@ bool StateFlowGraph::optimize_dead_store() { // Dive into this task and erase dead stores auto &task = nodes_[i]; + std::set store_eliminable_snodes; // Try to find unnecessary output state for (auto &s : task->meta->output_states) { + if (s.type != AsyncState::Type::value) { + // Listgen elimination has been handled in optimize_listgen, so we will + // only focus on "value" states. + continue; + } + if (latest_state_owner_[s] == task.get()) { + // Cannot eliminate the latest write, because it may form a state-flow + // with the later kernel launches. 
+ // + // TODO: Add some sort of hints so that the compiler knows that some + // value will never be used? + continue; + } + auto *snode = s.snode; + if (!snode->is_scalar()) { + // TODO: handle non-scalar SNodes, i.e. num_active_indices > 0. + continue; + } bool used = false; for (auto other : task->output_edges[s]) { - if (task->has_state_flow(s, other)) { + if (task->has_state_flow(s, other) && + (other->meta->input_states.count(s) > 0)) { + // Check if this is a RAW dependency. For scalar SNodes, a WAW flow + // edge degrades to a dependency edge. + // + // TODO: This is a hack that only works for scalar SNodes. The proper + // handling would require value killing analysis. used = true; } else { // Note that a dependency edge does not count as an data usage } } // This state is used by some other node, so it cannot be erased - if (used) - continue; - - if (s.type != AsyncState::Type::list && - latest_state_owner_[s] == task.get()) - // Note that list state is special. Since a future list generation - // always comes with ClearList, we can erase the list state even if it - // is latest. + if (used) { continue; + } - // ***************************** - // Erase the state s output. - if (s.type == AsyncState::Type::list && - task->meta->type == OffloadedStmt::TaskType::serial) { - // Try to erase list gen - DelayedIRModifier mod; - - auto new_ir = task->rec.ir_handle.clone(); - irpass::analysis::gather_statements(new_ir.get(), [&](Stmt *stmt) { - // TODO: invoke mod.erase(stmt) when necessary; - return false; - }); - if (mod.modify_ir()) { - // IR modified. Node should be updated.
- auto handle = - IRHandle(new_ir.get(), ir_bank_->get_hash(new_ir.get())); - ir_bank_->insert(std::move(new_ir), handle.hash()); - task->rec.ir_handle = handle; - // task->meta->print(); - task->meta = get_task_meta(ir_bank_, task->rec); - // task->meta->print(); + store_eliminable_snodes.insert(snode); + } - for (auto other : task->output_edges[s]) - other->input_edges[s].erase(task.get()); + // ***************************** + // Erase the state s output. + if (!store_eliminable_snodes.empty()) { + const bool verbose = task->rec.kernel->program.config.verbose; - task->output_edges.erase(s); - modified = true; + const auto dse_result = ir_bank_->optimize_dse( + task->rec.ir_handle, store_eliminable_snodes, verbose); + auto new_handle = dse_result.first; + if (new_handle != task->rec.ir_handle) { + modified = true; + task->rec.ir_handle = new_handle; + task->meta = get_task_meta(ir_bank_, task->rec); + } + bool first_compute = !dse_result.second; + if (first_compute && modified) { + stat.add("sfg_dse_tasks", 1.0); + } + if (first_compute && verbose) { + // Log only for the first time, otherwise we will be overwhelmed very + // quickly... + std::vector snodes_strs; + for (const auto *sn : store_eliminable_snodes) { + snodes_strs.push_back(sn->get_node_type_name_hinted()); } + TI_INFO("SFG DSE: task={} snodes={} optimized?={}", task->string(), + fmt::join(snodes_strs, ", "), modified); } } } @@ -851,19 +872,19 @@ bool StateFlowGraph::optimize_dead_store() { for (int i = 1; i < (int)nodes_.size(); i++) { auto &meta = *nodes_[i]->meta; auto ir = nodes_[i]->rec.ir_handle.ir()->cast(); - if (meta.type == OffloadedStmt::serial && ir->body->statements.empty()) { - to_delete.insert(i); - } else if (meta.type == OffloadedStmt::struct_for && - ir->body->statements.empty()) { - to_delete.insert(i); - } else if (meta.type == OffloadedStmt::range_for && - ir->body->statements.empty()) { + const auto mt = meta.type; + // Do NOT check ir->body->statements first! 
|ir->body| could be null when + // |mt| is not the desired type. + if ((mt == OffloadedStmt::serial || mt == OffloadedStmt::struct_for || + mt == OffloadedStmt::range_for) && + ir->body->statements.empty()) { to_delete.insert(i); } } - if (!to_delete.empty()) + if (!to_delete.empty()) { modified = true; + } delete_nodes(to_delete); diff --git a/taichi/transforms/cfg_optimization.cpp b/taichi/transforms/cfg_optimization.cpp index 6b3758761e6d7..093053cddd325 100644 --- a/taichi/transforms/cfg_optimization.cpp +++ b/taichi/transforms/cfg_optimization.cpp @@ -6,7 +6,11 @@ TLANG_NAMESPACE_BEGIN namespace irpass { -bool cfg_optimization(IRNode *root, bool after_lower_access) { +bool cfg_optimization( + IRNode *root, + bool after_lower_access, + const std::optional + &lva_config_opt) { TI_AUTO_PROF; auto cfg = analysis::build_cfg(root); bool result_modified = false; @@ -15,7 +19,7 @@ bool cfg_optimization(IRNode *root, bool after_lower_access) { cfg->simplify_graph(); if (cfg->store_to_load_forwarding(after_lower_access)) modified = true; - if (cfg->dead_store_elimination(after_lower_access)) + if (cfg->dead_store_elimination(after_lower_access, lva_config_opt)) modified = true; if (modified) result_modified = true; diff --git a/tests/python/test_sfg.py b/tests/python/test_sfg.py index 6334e348cd2e4..1d7a3cc32f9a3 100644 --- a/tests/python/test_sfg.py +++ b/tests/python/test_sfg.py @@ -1,4 +1,5 @@ import taichi as ti +import numpy as np import pytest @@ -58,3 +59,46 @@ def serial_z(): else: assert ys[i] == i + 2 assert xs[i] == 0 + + +@ti.test(require=ti.extension.async_mode, async_mode=True) +def test_sfg_dead_store_elimination(): + ti.init(arch=ti.cpu, async_mode=True) + n = 32 + + x = ti.field(dtype=float, shape=n, needs_grad=True) + total_energy = ti.field(dtype=float, shape=(), needs_grad=True) + unused = ti.field(dtype=float, shape=()) + + @ti.kernel + def gather(): + for i in x: + e = x[i]**2 + total_energy[None] += e + + @ti.kernel + def scatter(): + for i in x:
+ unused[None] += x[i] + + xnp = np.arange(n, dtype=np.float32) + x.from_numpy(xnp) + ti.sync() + + stats = ti.get_kernel_stats() + stats.clear() + + for _ in range(5): + with ti.Tape(total_energy): + gather() + scatter() + + ti.sync() + counters = stats.get_counters() + + # gather() should be DSE'ed + assert counters['sfg_dse_tasks'] > 0 + + x_grad = x.grad.to_numpy() + for i in range(n): + assert ti.approx(x_grad[i]) == 2.0 * i