From 9a23ecd2719384e5215ec8d0c844d16a5bc61438 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Mon, 21 Sep 2020 15:55:56 +0800 Subject: [PATCH] [async] Add element-wise info into TaskMeta for fusion (#1884) * Refactor AsyncStateHash to std::hash * Add TaskMeta::element_wise * Make use of it in fuse() * Fix false positive when task type is serial * [skip ci] Apply suggestions from code review Co-authored-by: Yuanming Hu * [skip ci] enforce code format * retrigger CI Co-authored-by: Yuanming Hu Co-authored-by: Taichi Gardener --- taichi/ir/ir.cpp | 22 +++++++++++ taichi/ir/ir.h | 2 + taichi/program/async_engine.cpp | 9 +++++ taichi/program/async_utils.cpp | 22 +++++++++++ taichi/program/async_utils.h | 60 +++++++++++++++-------------- taichi/program/state_flow_graph.cpp | 39 ++++++++++++++----- taichi/program/state_flow_graph.h | 9 ++--- 7 files changed, 120 insertions(+), 43 deletions(-) diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 02e376ce9a355..2fdecdc556e61 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -406,6 +406,28 @@ GlobalPtrStmt::GlobalPtrStmt(const LaneAttribute &snodes, TI_STMT_REG_FIELDS; } +bool GlobalPtrStmt::is_element_wise(SNode *snode) const { + if (snode == nullptr) { + // check every SNode when "snode" is nullptr + for (const auto &snode_i : snodes.data) { + if (!is_element_wise(snode_i)) { + return false; + } + } + return true; + } + // check if this statement is element-wise on a specific SNode, i.e., argument + // "snode" + for (int i = 0; i < (int)indices.size(); i++) { + if (auto loop_index_i = indices[i]->cast(); + !(loop_index_i && loop_index_i->loop->is() && + loop_index_i->index == snode->physical_index_position[i])) { + return false; + } + } + return true; +} + std::string GlobalPtrExpression::serialize() { std::string s = fmt::format("{}[", var.serialize()); for (int i = 0; i < (int)indices.size(); i++) { diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index 98ae2e470ff86..426b624ece4d1 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -851,6 +851,8 @@ class GlobalPtrStmt : public Stmt { const std::vector &indices, bool activate = true); + bool is_element_wise(SNode *snode) const; + bool has_global_side_effect() const override { return activate; } diff --git a/taichi/program/async_engine.cpp b/taichi/program/async_engine.cpp index 7e0864f18bcd6..bc22a47116c22 100644 --- a/taichi/program/async_engine.cpp +++ b/taichi/program/async_engine.cpp @@ -362,6 +362,15 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) { } } } + for (auto &snode : ptr->snodes.data) { + if (ptr->is_element_wise(snode)) { + if (meta.element_wise.find(snode) == meta.element_wise.end()) { + meta.element_wise[snode] = true; + } + } else { + meta.element_wise[snode] = false; + } + } } if (auto clear_list = stmt->cast()) { meta.output_states.emplace(clear_list->snode, AsyncState::Type::list); diff --git a/taichi/program/async_utils.cpp b/taichi/program/async_utils.cpp index 9fcc6c1ce918a..3112246c2925f 100644 --- a/taichi/program/async_utils.cpp +++ b/taichi/program/async_utils.cpp @@ -55,6 +55,28 @@ void TaskMeta::print() const { } fmt::print("\n"); } + std::vector element_wise_snodes, non_element_wise_snodes; + for (auto s : element_wise) { + if (s.second) { + element_wise_snodes.push_back(s.first); + } else { + non_element_wise_snodes.push_back(s.first); + } + } + if (!element_wise_snodes.empty()) { + fmt::print(" element-wise snodes:\n "); + for (auto s : element_wise_snodes) { + fmt::print("{} ", s->get_node_type_name_hinted()); + } + fmt::print("\n"); + } + if (!non_element_wise_snodes.empty()) { + fmt::print(" non-element-wise snodes:\n "); + for (auto s : non_element_wise_snodes) { + fmt::print("{} ", s->get_node_type_name_hinted()); + } + fmt::print("\n"); + } } TLANG_NAMESPACE_END diff --git a/taichi/program/async_utils.h b/taichi/program/async_utils.h index cc95b33b3d8bc..984592ddb1b44 100644 --- a/taichi/program/async_utils.h +++ b/taichi/program/async_utils.h @@ -50,29 +50,6 @@ class IRHandle { uint64 hash_; }; -TLANG_NAMESPACE_END - -namespace std { -template <> -struct hash { - std::size_t operator()(const taichi::lang::IRHandle &ir_handle) const - noexcept { - return ir_handle.hash(); - } -}; - -template <> -struct hash> { - std::size_t operator()( - const std::pair - &ir_handles) const noexcept { - return ir_handles.first.hash() * 100000007UL + ir_handles.second.hash(); - } -}; -} // namespace std - -TLANG_NAMESPACE_BEGIN - // Records the necessary data for launching an offloaded task. class TaskLaunchRecord { public: @@ -127,19 +104,44 @@ struct AsyncState { } }; -class AsyncStateHash { - public: - size_t operator()(const AsyncState &s) const { - return (uint64)s.snode ^ (uint64)s.type; +TLANG_NAMESPACE_END + +namespace std { +template <> +struct hash { + std::size_t operator()(const taichi::lang::IRHandle &ir_handle) const + noexcept { + return ir_handle.hash(); + } +}; + +template <> +struct hash> { + std::size_t operator()( + const std::pair + &ir_handles) const noexcept { + return ir_handles.first.hash() * 100000007UL + ir_handles.second.hash(); + } +}; + +template <> +struct hash { + std::size_t operator()(const taichi::lang::AsyncState &s) const noexcept { + return (std::size_t)s.snode ^ (std::size_t)s.type; } }; +} // namespace std + +TLANG_NAMESPACE_BEGIN + struct TaskMeta { std::string name; OffloadedStmt::TaskType type{OffloadedStmt::TaskType::serial}; SNode *snode{nullptr}; // struct-for and listgen only - std::unordered_set input_states; - std::unordered_set output_states; + std::unordered_set input_states; + std::unordered_set output_states; + std::unordered_map element_wise; void print() const; }; diff --git a/taichi/program/state_flow_graph.cpp b/taichi/program/state_flow_graph.cpp index c57dffe191c35..74c3e0f029aaa 100644 --- a/taichi/program/state_flow_graph.cpp +++ b/taichi/program/state_flow_graph.cpp @@ -235,11 +235,11 @@ bool StateFlowGraph::fuse() { std::vector has_path, has_path_reverse; std::tie(has_path, has_path_reverse) = compute_transitive_closure(); - // Cache the result that if each pair is fusable by task types. + // Cache the result that if each pair is fusible by task types. // TODO: improve this - auto task_type_fusable = std::make_unique(n); + auto task_type_fusible = std::make_unique(n); for (int i = 0; i < n; i++) { - task_type_fusable[i] = Bitset(n); + task_type_fusible[i] = Bitset(n); } // nodes_[0] is the initial node. for (int i = 1; i < n; i++) { @@ -292,10 +292,10 @@ bool StateFlowGraph::fuse() { // TODO: avoid snode accessors going into async engine const bool is_snode_accessor = (rec_i.kernel->is_accessor || rec_j.kernel->is_accessor); - bool fusable = + bool fusible = (is_same_range_for || is_same_struct_for || are_both_serial) && kernel_args_match && !is_snode_accessor; - task_type_fusable[i][j] = fusable; + task_type_fusible[i][j] = fusible; } } @@ -345,6 +345,29 @@ bool StateFlowGraph::fuse() { auto fused = std::make_unique(n); + auto edge_fusible = [&](int a, int b) { + // Check if a and b are fusible if there is an edge (a, b). + if (fused[a] || fused[b] || !task_type_fusible[a][b]) { + return false; + } + if (nodes_[a]->meta->type == OffloadedStmt::TaskType::serial) { + return true; + } + for (auto &state : nodes_[a]->output_edges) { + if (state.first.type != AsyncState::Type::value) { + // TODO: What checks do we need for edges of mask/list states? + continue; + } + if (state.second.find(nodes_[b].get()) != state.second.end()) { + if (!nodes_[a]->meta->element_wise[state.first.snode] || + !nodes_[b]->meta->element_wise[state.first.snode]) { + return false; + } + } + } + return true; + }; + bool modified = false; while (true) { bool updated = false; @@ -357,9 +380,7 @@ bool StateFlowGraph::fuse() { for (auto &edges : nodes_[i]->output_edges) { for (auto &edge : edges.second) { const int j = edge->node_id; - // TODO: for each pair of edge (i, j), we can only fuse if they - // are both serial or both element-wise. - if (!fused[j] && task_type_fusable[i][j]) { + if (edge_fusible(i, j)) { auto i_has_path_to_j = has_path[i] & has_path_reverse[j]; i_has_path_to_j[i] = i_has_path_to_j[j] = false; // check if i doesn't have a path to j of length >= 2 @@ -381,7 +402,7 @@ bool StateFlowGraph::fuse() { for (int i = 1; i < n; i++) { if (!fused[i]) { for (int j = i + 1; j < n; j++) { - if (!fused[j] && task_type_fusable[i][j] && !has_path[i][j] && + if (!fused[j] && task_type_fusible[i][j] && !has_path[i][j] && !has_path[j][i]) { do_fuse(i, j); fused[i] = fused[j] = true; diff --git a/taichi/program/state_flow_graph.h b/taichi/program/state_flow_graph.h index 895eb2fe3e017..f9c5287dccca3 100644 --- a/taichi/program/state_flow_graph.h +++ b/taichi/program/state_flow_graph.h @@ -18,8 +18,7 @@ class IRBank; class StateFlowGraph { public: struct Node; - using StateToNodeMapping = - std::unordered_map; + using StateToNodeMapping = std::unordered_map; // Each node is a task // Note: after SFG is done, each node here should hold a TaskLaunchRecord. @@ -39,8 +38,8 @@ class StateFlowGraph { // Profiling showed horrible performance using std::unordered_multimap (at // least on Mac with clang-1103.0.32.62)... - std::unordered_map, AsyncStateHash> - output_edges, input_edges; + std::unordered_map> output_edges, + input_edges; std::string string() const; @@ -140,7 +139,7 @@ class StateFlowGraph { Node *initial_node_; // The initial node holds all the initial states. TaskMeta initial_meta_; StateToNodeMapping latest_state_owner_; - std::unordered_map, AsyncStateHash> + std::unordered_map> latest_state_readers_; std::unordered_map task_name_to_launch_ids_; IRBank *ir_bank_;