From 25b9a10a44bcbb78d86bf9b881f7d6892f70ea25 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Fri, 25 Sep 2020 12:08:35 +0800 Subject: [PATCH] [async] Cache demote_activation (#1889) * [async] Cache demote_activation * [skip ci] Apply suggestions from code review Co-authored-by: Yuanming Hu * [skip ci] enforce code format * [skip ci] Edit comment Co-authored-by: Yuanming Hu Co-authored-by: Taichi Gardener --- taichi/program/async_engine.cpp | 90 +++++++++++++++++++++++++++ taichi/program/async_engine.h | 3 + taichi/program/async_utils.h | 4 ++ taichi/program/state_flow_graph.cpp | 94 ++++------------------------- 4 files changed, 108 insertions(+), 83 deletions(-) diff --git a/taichi/program/async_engine.cpp b/taichi/program/async_engine.cpp index 4f7093453073e..c111de9f02e87 100644 --- a/taichi/program/async_engine.cpp +++ b/taichi/program/async_engine.cpp @@ -106,6 +106,96 @@ IRHandle IRBank::fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel) { return result; } +// TODO: make this an IR pass +class ConstExprPropagation { + public: + static std::unordered_set run( + Block *block, + const std::function &is_const_seed) { + std::unordered_set const_stmts; + + auto is_const = [&](Stmt *stmt) { + if (is_const_seed(stmt)) { + return true; + } else { + return const_stmts.find(stmt) != const_stmts.end(); + } + }; + + for (auto &s : block->statements) { + if (is_const(s.get())) { + const_stmts.insert(s.get()); + } else if (auto binary = s->cast()) { + if (is_const(binary->lhs) && is_const(binary->rhs)) { + const_stmts.insert(s.get()); + } + } else if (auto unary = s->cast()) { + if (is_const(unary->operand)) { + const_stmts.insert(s.get()); + } + } else { + // TODO: ... + } + } + + return const_stmts; + } +}; + +IRHandle IRBank::demote_activation(IRHandle handle) { + auto &result = demote_activation_bank_[handle]; + if (!result.empty()) { + return result; + } + + std::unique_ptr new_ir = handle.clone(); + + OffloadedStmt *offload = new_ir->as(); + Block *body = offload->body.get(); + + auto snode = offload->snode; + TI_ASSERT(snode != nullptr); + + // TODO: for now we only deal with the top level. Is there an easy way to + // extend this part? + auto consts = ConstExprPropagation::run(body, [](Stmt *stmt) { + if (stmt->is()) { + return true; + } else if (stmt->is()) + return true; + return false; + }); + + bool demoted = false; + for (int k = 0; k < (int)body->statements.size(); k++) { + Stmt *stmt = body->statements[k].get(); + if (auto ptr = stmt->cast(); ptr && ptr->activate) { + bool can_demote = true; + // TODO: test input mask? + for (auto ind : ptr->indices) { + if (consts.find(ind) == consts.end()) { + // non-constant index + can_demote = false; + } + } + if (can_demote) { + ptr->activate = false; + demoted = true; + } + } + } + + if (!demoted) { + // Nothing demoted. Simply delete new_ir when this function returns. + result = handle; + return result; + } + + result = IRHandle(new_ir.get(), get_hash(new_ir.get())); + insert(std::move(new_ir), result.hash()); + return result; +} + ParallelExecutor::ParallelExecutor(int num_threads) : num_threads(num_threads), status(ExecutorStatus::uninitialized), diff --git a/taichi/program/async_engine.h b/taichi/program/async_engine.h index 921720a73d2e3..e7760aa84f6d2 100644 --- a/taichi/program/async_engine.h +++ b/taichi/program/async_engine.h @@ -31,6 +31,8 @@ class IRBank { // Fuse handle_b into handle_a IRHandle fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel); + IRHandle demote_activation(IRHandle handle); + std::unordered_map meta_bank_; std::unordered_map fusion_meta_bank_; @@ -39,6 +41,7 @@ class IRBank { std::unordered_map> ir_bank_; std::vector> trash_bin; // prevent IR from deleted std::unordered_map, IRHandle> fuse_bank_; + std::unordered_map demote_activation_bank_; }; class ParallelExecutor { diff --git a/taichi/program/async_utils.h b/taichi/program/async_utils.h index 2fcd4ec8b6f0a..96ec7517e7eda 100644 --- a/taichi/program/async_utils.h +++ b/taichi/program/async_utils.h @@ -41,6 +41,10 @@ class IRHandle { return hash_ == other_ir_handle.hash_; } + bool operator!=(const IRHandle &other_ir_handle) const { + return !(*this == other_ir_handle); + } + bool operator<(const IRHandle &other_ir_handle) const { return hash_ < other_ir_handle.hash_; } diff --git a/taichi/program/state_flow_graph.cpp b/taichi/program/state_flow_graph.cpp index 6ed6ddd0cf235..43875c7ee8fd2 100644 --- a/taichi/program/state_flow_graph.cpp +++ b/taichi/program/state_flow_graph.cpp @@ -935,42 +935,6 @@ void StateFlowGraph::verify() { topo_sort_nodes(); } -// TODO: make this an IR pass -class ConstExprPropagation { - public: - static std::unordered_set run( - Block *block, - const std::function &is_const_seed) { - std::unordered_set const_stmts; - - auto is_const = [&](Stmt *stmt) { - if (is_const_seed(stmt)) { - return true; - } else { - return const_stmts.find(stmt) != const_stmts.end(); - } - }; - - for (auto &s : block->statements) { - if (is_const(s.get())) { - const_stmts.insert(s.get()); - } else if (auto binary = s->cast()) { - if (is_const(binary->lhs) && is_const(binary->rhs)) { - const_stmts.insert(s.get()); - } - } else if (auto unary = s->cast()) { - if (is_const(unary->operand)) { - const_stmts.insert(s.get()); - } - } else { - // TODO: ... - } - } - - return const_stmts; - } -}; - bool StateFlowGraph::demote_activation() { bool modified = false; @@ -1001,55 +965,19 @@ bool StateFlowGraph::demote_activation() { if (nodes.size() <= 1) continue; - auto snode = nodes[0]->meta->snode; - - auto list_state = AsyncState(snode, AsyncState::Type::list); - - TI_ASSERT(snode != nullptr); - - std::unique_ptr new_ir = nodes[0]->rec.ir_handle.clone(); - - OffloadedStmt *offload = new_ir->as(); - Block *body = offload->body.get(); - - // TODO: for now we only deal with the top level. Is there an easy way to - // extend this part? - auto consts = ConstExprPropagation::run(body, [](Stmt *stmt) { - if (stmt->is()) { - return true; - } else if (stmt->is()) - return true; - return false; - }); - - bool demoted = false; - for (int k = 0; k < (int)body->statements.size(); k++) { - Stmt *stmt = body->statements[k].get(); - if (auto ptr = stmt->cast(); ptr && ptr->activate) { - bool can_demote = true; - // TODO: test input mask? - for (auto ind : ptr->indices) { - if (consts.find(ind) == consts.end()) { - // non-constant index - can_demote = false; - } - } - if (can_demote) { - modified = true; - ptr->activate = false; - demoted = true; - } - } - } - // TODO: cache this part - auto new_handle = IRHandle(new_ir.get(), ir_bank_->get_hash(new_ir.get())); - ir_bank_->insert(std::move(new_ir), new_handle.hash()); - auto new_meta = get_task_meta(ir_bank_, nodes[0]->rec); - if (demoted) { - for (int j = 1; j < (int)nodes.size(); j++) { + auto new_handle = ir_bank_->demote_activation(nodes[0]->rec.ir_handle); + if (new_handle != nodes[0]->rec.ir_handle) { + modified = true; + nodes[1]->rec.ir_handle = new_handle; + nodes[1]->meta = get_task_meta(ir_bank_, nodes[1]->rec); + for (int j = 2; j < (int)nodes.size(); j++) { nodes[j]->rec.ir_handle = new_handle; - nodes[j]->meta = new_meta; + nodes[j]->meta = nodes[1]->meta; } + // For every "demote_activation" call, we only optimize for a single key + // in std::map, std::vector> tasks + // since the graph probably needs to be rebuild after demoting + // part of the tasks. break; } }