From 391fcff7f9358099f0ca568afc75c3fd07110e7a Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Thu, 24 Dec 2020 22:11:12 +0800 Subject: [PATCH 01/10] [async] [IR] More precise same_value analysis --- taichi/analysis/same_statements.cpp | 123 ++++++++++++++++++++++++---- taichi/ir/analysis.h | 41 ++++++++-- taichi/program/async_engine.cpp | 8 +- taichi/program/async_utils.h | 7 ++ taichi/program/state_flow_graph.cpp | 12 ++- 5 files changed, 162 insertions(+), 29 deletions(-) diff --git a/taichi/analysis/same_statements.cpp b/taichi/analysis/same_statements.cpp index 7effc98d73de2..d6f0c89c7fd10 100644 --- a/taichi/analysis/same_statements.cpp +++ b/taichi/analysis/same_statements.cpp @@ -2,6 +2,8 @@ #include "taichi/ir/analysis.h" #include "taichi/ir/statements.h" #include "taichi/ir/visitors.h" +#include "taichi/program/async_utils.h" +#include "taichi/program/ir_bank.h" #include #include #include @@ -18,23 +20,44 @@ class IRNodeComparator : public IRVisitor { bool recursively_check_; bool check_same_value_; + std::unordered_set possibly_modified_states_; + bool all_states_can_be_modified_; + IRBank *ir_bank_; + public: bool same; - explicit IRNodeComparator(IRNode *other_node, - std::optional> id_map, - bool check_same_value) + explicit IRNodeComparator( + IRNode *other_node, + const std::optional> &id_map, + bool check_same_value, + const std::optional> + &possibly_modified_states, + IRBank *ir_bank) : other_node(other_node) { allow_undefined_visitor = true; invoke_default_visitor = true; same = true; if (id_map.has_value()) { recursively_check_ = true; - this->id_map = std::move(id_map.value()); + this->id_map = id_map.value(); } else { recursively_check_ = false; } + if (possibly_modified_states.has_value()) { + TI_ASSERT_INFO(check_same_value, + "The parameter possibly_modified_states " + "is only supported when check_same_value is true"); + TI_ASSERT_INFO(ir_bank, + "The parameter possibly_modified_states " + "requires ir_bank") + all_states_can_be_modified_ 
= false; + this->possibly_modified_states_ = possibly_modified_states.value(); + } else { + all_states_can_be_modified_ = true; + } check_same_value_ = check_same_value; + ir_bank_ = ir_bank; } void map_id(int this_id, int other_id) { @@ -97,22 +120,65 @@ class IRNodeComparator : public IRVisitor { same = false; return; } + auto other = other_node->as(); // If two identical statements can have different values, return false. - if (check_same_value_ && !stmt->is_container_statement() && - !stmt->common_statement_eliminable()) { - same = false; - return; + // TODO: two identical GlobalPtrStmts cannot have different values, + // but GlobalPtrStmt::common_statement_eliminable() is false. + if (check_same_value_ && stmt != other && !stmt->is_container_statement() && + !stmt->common_statement_eliminable() && !stmt->is()) { + if (all_states_can_be_modified_) { + same = false; + return; + } else { + // "break" all branches that do not result in "same = false" + do { + if (auto global_load = stmt->cast()) { + if (auto global_ptr = global_load->ptr->cast()) { + TI_ASSERT(global_ptr->width() == 1); + if (possibly_modified_states_.count(ir_bank_->get_async_state( + global_ptr->snodes[0], AsyncState::Type::value)) == 0) { + break; + } + } + // TODO: other cases? + } else if (auto global_store = stmt->cast()) { + if (auto global_ptr = global_store->ptr->cast()) { + TI_ASSERT(global_ptr->width() == 1); + if (possibly_modified_states_.count(ir_bank_->get_async_state( + global_ptr->snodes[0], AsyncState::Type::value)) == 0) { + break; + } + } + } + same = false; + return; + } while (false); + } } // Note that we do not need to test !stmt2->common_statement_eliminable() // because if this condition does not hold, // same_statements(stmt1, stmt2) returns false anyway. 
// field check - auto other = other_node->as(); - if (!stmt->field_manager.equal(other->field_manager)) { - same = false; - return; + if (check_same_value_ && stmt->is()) { + // Special case: we do not care the "activate" field when checking + // whether two global pointers share the same value. + // And we cannot use irpass::analysis::definitely_same_address() + // directly because that function does not support id_map. + + // TODO: Update this part if GlobalPtrStmt comes to have more fields + TI_ASSERT(stmt->width() == 1); + if (stmt->as()->snodes[0]->id != + other->as()->snodes[0]->id) { + same = false; + return; + } + } else { + if (!stmt->field_manager.equal(other->field_manager)) { + same = false; + return; + } } // operand check @@ -219,8 +285,12 @@ class IRNodeComparator : public IRVisitor { static bool run(IRNode *root1, IRNode *root2, const std::optional> &id_map, - bool check_same_value) { - IRNodeComparator comparator(root2, id_map, check_same_value); + bool check_same_value, + const std::optional> + &possibly_modified_states, + IRBank *ir_bank) { + IRNodeComparator comparator(root2, id_map, check_same_value, + possibly_modified_states, ir_bank); root1->accept(&comparator); return comparator.same; } @@ -270,17 +340,36 @@ bool same_statements( if (!root1 || !root2) return false; return IRNodeComparator::run(root1, root2, id_map, - /*check_same_value=*/false); + /*check_same_value=*/false, std::nullopt, + /*ir_bank=*/nullptr); +} +bool same_value(Stmt *stmt1, + Stmt *stmt2, + const AsyncStateSet &possibly_modified_states, + IRBank *ir_bank, + const std::optional> &id_map) { + // Test if two statements definitely have the same value. 
+ if (stmt1 == stmt2) + return true; + if (!stmt1 || !stmt2) + return false; + return IRNodeComparator::run( + stmt1, stmt2, id_map, /*check_same_value=*/true, + std::make_optional>( + possibly_modified_states.s), + ir_bank); } bool same_value(Stmt *stmt1, Stmt *stmt2, const std::optional> &id_map) { - // Test if two statements must have the same value. + // Test if two statements definitely have the same value. if (stmt1 == stmt2) return true; if (!stmt1 || !stmt2) return false; - return IRNodeComparator::run(stmt1, stmt2, id_map, /*check_same_value=*/true); + return IRNodeComparator::run(stmt1, stmt2, id_map, + /*check_same_value=*/true, std::nullopt, + /*ir_bank=*/nullptr); } } // namespace irpass::analysis diff --git a/taichi/ir/analysis.h b/taichi/ir/analysis.h index 15bc8e612fa65..a4e7724378917 100644 --- a/taichi/ir/analysis.h +++ b/taichi/ir/analysis.h @@ -54,6 +54,7 @@ class ControlFlowGraph; struct TaskMeta; class IRBank; +class AsyncStateSet; // IR Analysis namespace irpass::analysis { @@ -82,20 +83,44 @@ std::vector get_store_destination(Stmt *store_stmt); bool has_store_or_atomic(IRNode *root, const std::vector &vars); std::pair last_store_or_atomic(IRNode *root, Stmt *var); bool maybe_same_address(Stmt *var1, Stmt *var2); -/** Test if root1 and root2 are the same, i.e., have the same type, - * the same operands, the same fields, and the same containing statements. +/** + * Test if root1 and root2 are the same, i.e., have the same type, + * the same operands, the same fields, and the same containing statements. * - * @param id_map - * If id_map is std::nullopt by default, two operands are considered - * the same if they have the same id and do not belong to either root, - * or they belong to root1 and root2 at the same position in the roots. - * Otherwise, this function also recursively check the operands until - * ids in the id_map are reached. 
+ * @param id_map + * If id_map is std::nullopt by default, two operands are considered + * the same if they have the same id and do not belong to either root, + * or they belong to root1 and root2 at the same position in the roots. + * Otherwise, this function also recursively check the operands until + * ids in the id_map are reached. */ bool same_statements( IRNode *root1, IRNode *root2, const std::optional> &id_map = std::nullopt); +/** + * Test if stmt1 and stmt2 definitely have the same value. + * + * @param possibly_modified_states + * Only states in possibly_modified_states can be modified + * between stmt1 and stmt2. + * + * @param id_map + * Same as in same_statements(root1, root2, id_map). + */ +bool same_value( + Stmt *stmt1, + Stmt *stmt2, + const AsyncStateSet &possibly_modified_states, + IRBank *ir_bank, + const std::optional> &id_map = std::nullopt); +/** + * Test if stmt1 and stmt2 definitely have the same value. + * Any global fields can be modified between stmt1 and stmt2. + * + * @param id_map + * Same as in same_statements(root1, root2, id_map). 
+ */ bool same_value( Stmt *stmt1, Stmt *stmt2, diff --git a/taichi/program/async_engine.cpp b/taichi/program/async_engine.cpp index 670bb6cf94cec..ba0f86fc14d80 100644 --- a/taichi/program/async_engine.cpp +++ b/taichi/program/async_engine.cpp @@ -211,8 +211,13 @@ void AsyncEngine::synchronize() { sfg->reid_nodes(); sfg->reid_pending_nodes(); sfg->sort_node_edges(); - TI_TRACE("Synchronizing SFG of {} nodes ({} pending)", sfg->size(), + auto init_size = sfg->size(); + std::cout << std::flush; + TI_INFO("Synchronizing SFG of {} nodes ({} pending)", sfg->size(), sfg->num_pending_tasks()); + std::cout << std::flush; + sfg->print(); + std::cout << std::flush; debug_sfg("initial"); if (program->config.debug) { sfg->verify(); @@ -260,6 +265,7 @@ void AsyncEngine::synchronize() { // Clear SFG debug stats cur_sync_sfg_debug_counter_ = 0; cur_sync_sfg_debug_per_stage_counts_.clear(); + TI_ASSERT(init_size <= 20); } void AsyncEngine::debug_sfg(const std::string &stage) { diff --git a/taichi/program/async_utils.h b/taichi/program/async_utils.h index 10391dbb25b7f..4a80dea42f151 100644 --- a/taichi/program/async_utils.h +++ b/taichi/program/async_utils.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -215,6 +216,12 @@ struct TaskMeta { void print() const; }; +// A wrapper class for the parameter in bool same_value() in analysis.h. 
+class AsyncStateSet { + public: + std::unordered_set s; +}; + class IRBank; TaskMeta *get_task_meta(IRBank *bank, const TaskLaunchRecord &t); diff --git a/taichi/program/state_flow_graph.cpp b/taichi/program/state_flow_graph.cpp index 6a8d06ab12c36..85e8b754c0fce 100644 --- a/taichi/program/state_flow_graph.cpp +++ b/taichi/program/state_flow_graph.cpp @@ -675,6 +675,13 @@ std::unordered_set StateFlowGraph::fuse_range(int begin, int end) { nodes[b]->meta->loop_unique.count(snode) == 0) { return false; } + std::unordered_map offload_map; + offload_map[0] = 0; + std::unordered_set modified_states = + get_task_meta(ir_bank_, nodes[a]->rec)->output_states; + modified_states.merge( + get_task_meta(ir_bank_, nodes[b]->rec)->output_states); + AsyncStateSet modified_states_set{modified_states}; auto same_loop_unique_address = [&](GlobalPtrStmt *ptr1, GlobalPtrStmt *ptr2) { if (!ptr1 || !ptr2) { @@ -686,11 +693,10 @@ std::unordered_set StateFlowGraph::fuse_range(int begin, int end) { TI_ASSERT(nodes[b]->rec.stmt()->id == 0); // Only map the OffloadedStmt to see if both SNodes are loop-unique // on the same statement. 
- std::unordered_map offload_map; - offload_map[0] = 0; for (int i = 0; i < (int)ptr1->indices.size(); i++) { if (!irpass::analysis::same_value( - ptr1->indices[i], ptr2->indices[i], + ptr1->indices[i], ptr2->indices[i], modified_states_set, + ir_bank_, std::make_optional>( offload_map))) { return false; From 64c0adbb2526d0d9302a014e2d7c6e0eadf6065c Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Thu, 24 Dec 2020 22:12:27 +0800 Subject: [PATCH 02/10] Remove debug outputs --- taichi/program/async_engine.cpp | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/taichi/program/async_engine.cpp b/taichi/program/async_engine.cpp index ba0f86fc14d80..670bb6cf94cec 100644 --- a/taichi/program/async_engine.cpp +++ b/taichi/program/async_engine.cpp @@ -211,13 +211,8 @@ void AsyncEngine::synchronize() { sfg->reid_nodes(); sfg->reid_pending_nodes(); sfg->sort_node_edges(); - auto init_size = sfg->size(); - std::cout << std::flush; - TI_INFO("Synchronizing SFG of {} nodes ({} pending)", sfg->size(), + TI_TRACE("Synchronizing SFG of {} nodes ({} pending)", sfg->size(), sfg->num_pending_tasks()); - std::cout << std::flush; - sfg->print(); - std::cout << std::flush; debug_sfg("initial"); if (program->config.debug) { sfg->verify(); @@ -265,7 +260,6 @@ void AsyncEngine::synchronize() { // Clear SFG debug stats cur_sync_sfg_debug_counter_ = 0; cur_sync_sfg_debug_per_stage_counts_.clear(); - TI_ASSERT(init_size <= 20); } void AsyncEngine::debug_sfg(const std::string &stage) { From d05ded9af76b694dfec97a8e955c295e8f322a27 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Sat, 26 Dec 2020 15:31:50 +0800 Subject: [PATCH 03/10] Fix std::unordered_set::merge()... 
(with debug outputs) --- taichi/analysis/same_statements.cpp | 17 ++++++++ taichi/program/async_engine.cpp | 8 +++- taichi/program/ir_bank.cpp | 11 ++++- taichi/program/state_flow_graph.cpp | 65 +++++++++++++++++++++++------ 4 files changed, 86 insertions(+), 15 deletions(-) diff --git a/taichi/analysis/same_statements.cpp b/taichi/analysis/same_statements.cpp index d6f0c89c7fd10..1d3ad8df08f67 100644 --- a/taichi/analysis/same_statements.cpp +++ b/taichi/analysis/same_statements.cpp @@ -115,11 +115,15 @@ class IRNodeComparator : public IRVisitor { } void basic_check(Stmt *stmt) { + if (verbose) + std::cout << "checking " << stmt->id << std::endl; // type check if (typeid(*other_node) != typeid(*stmt)) { same = false; return; } + if (verbose) + std::cout << "qqq" << std::endl; auto other = other_node->as(); // If two identical statements can have different values, return false. @@ -156,12 +160,16 @@ class IRNodeComparator : public IRVisitor { } while (false); } } + if (verbose) + std::cout << "www" << std::endl; // Note that we do not need to test !stmt2->common_statement_eliminable() // because if this condition does not hold, // same_statements(stmt1, stmt2) returns false anyway. // field check if (check_same_value_ && stmt->is()) { + if (verbose) + std::cout << "eee" << std::endl; // Special case: we do not care the "activate" field when checking // whether two global pointers share the same value. 
// And we cannot use irpass::analysis::definitely_same_address() @@ -175,11 +183,15 @@ class IRNodeComparator : public IRVisitor { return; } } else { + if (verbose) + std::cout << "rrr" << std::endl; if (!stmt->field_manager.equal(other->field_manager)) { same = false; return; } } + if (verbose) + std::cout << "ttt" << std::endl; // operand check if (stmt->num_operands() != other->num_operands()) { @@ -195,6 +207,8 @@ class IRNodeComparator : public IRVisitor { continue; check_mapping(stmt->operand(i), other->operand(i)); } + if (verbose) + std::cout << "yyy" << std::endl; map_id(stmt->id, other->id); } @@ -281,6 +295,7 @@ class IRNodeComparator : public IRVisitor { other_node = other; } } + bool verbose{false}; static bool run(IRNode *root1, IRNode *root2, @@ -291,6 +306,8 @@ class IRNodeComparator : public IRVisitor { IRBank *ir_bank) { IRNodeComparator comparator(root2, id_map, check_same_value, possibly_modified_states, ir_bank); +// if (check_same_value && id_map.has_value()) +// comparator.verbose = true; root1->accept(&comparator); return comparator.same; } diff --git a/taichi/program/async_engine.cpp b/taichi/program/async_engine.cpp index e59100622f822..4df38bc8ef62a 100644 --- a/taichi/program/async_engine.cpp +++ b/taichi/program/async_engine.cpp @@ -214,8 +214,13 @@ void AsyncEngine::synchronize() { sfg->reid_nodes(); sfg->reid_pending_nodes(); sfg->sort_node_edges(); - TI_TRACE("Synchronizing SFG of {} nodes ({} pending)", sfg->size(), + auto init_size = sfg->size(); + std::cout << std::flush; + TI_INFO("Synchronizing SFG of {} nodes ({} pending)", sfg->size(), sfg->num_pending_tasks()); + std::cout << std::flush; + sfg->print(); + std::cout << std::flush; debug_sfg("initial"); if (program->config.debug) { sfg->verify(); @@ -263,6 +268,7 @@ void AsyncEngine::synchronize() { // Clear SFG debug stats cur_sync_sfg_debug_counter_ = 0; cur_sync_sfg_debug_per_stage_counts_.clear(); + //TI_ASSERT(init_size <= 20); } void AsyncEngine::debug_sfg(const 
std::string &stage) { diff --git a/taichi/program/ir_bank.cpp b/taichi/program/ir_bank.cpp index 92e7c99146359..c82f320465fb2 100644 --- a/taichi/program/ir_bank.cpp +++ b/taichi/program/ir_bank.cpp @@ -77,12 +77,18 @@ IRHandle IRBank::fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel) { return result; } - TI_TRACE("Begin uncached fusion"); + TI_INFO("Begin uncached fusion"); + std::cout << handle_a.ir()->get_kernel()->name << " " << (handle_a.ir()->as()->has_body() ? handle_a.ir()->as()->body->size() : -1) << std::endl; + std::cout << handle_b.ir()->get_kernel()->name << " " << (handle_b.ir()->as()->has_body() ? handle_b.ir()->as()->body->size() : -1) << std::endl; // We are about to change both |task_a| and |task_b|. Clone them first. auto cloned_task_a = handle_a.clone(); auto cloned_task_b = handle_b.clone(); auto task_a = cloned_task_a->as(); auto task_b = cloned_task_b->as(); +// std::cout << "before: " << std::endl; +// irpass::print((IRNode *)handle_a.ir()); +// irpass::print((IRNode *)handle_b.ir()); +// std::cout << std::flush; // TODO: in certain cases this optimization can be wrong! // Fuse task b into task_a for (int j = 0; j < (int)task_b->body->size(); j++) { @@ -96,6 +102,9 @@ IRHandle IRBank::fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel) { irpass::full_simplify(task_a, /*after_lower_access=*/false, kernel); // For now, re_id is necessary for the hash to be correct. 
irpass::re_id(task_a); +// std::cout << "after: " << std::endl; +// irpass::print(task_a); +// std::cout << std::flush; auto h = get_hash(task_a); result = IRHandle(task_a, h); diff --git a/taichi/program/state_flow_graph.cpp b/taichi/program/state_flow_graph.cpp index 74235d21e683a..9bf42c1b6d13e 100644 --- a/taichi/program/state_flow_graph.cpp +++ b/taichi/program/state_flow_graph.cpp @@ -654,12 +654,41 @@ std::unordered_set StateFlowGraph::fuse_range(int begin, int end) { auto edge_fusible = [&](int a, int b) { TI_PROFILER("edge_fusible"); + TI_ASSERT(a >= 0 && a < nodes.size() && b >= 0 && b < nodes.size()); + //std::cout << "fusible " << nodes[a]->rec.id << " " << nodes[b]->rec.id + // << std::endl; + bool verbose = (nodes[a]->rec.id == 117 && nodes[b]->rec.id == 124); + if (nodes[a]->rec.stmt()->has_body() && nodes[a]->rec.stmt()->body->size() > 100 && + nodes[b]->rec.stmt()->has_body() && nodes[b]->rec.stmt()->body->size() > 100) { + verbose = true; + } else { + } + verbose = false; + if (verbose) { + std::cout << "verbose" << std::endl; + std::cout << nodes[a]->rec.stmt()->get_kernel()->name << " " << nodes[a]->rec.id << " " << (nodes[a]->rec.stmt()->has_body() ? nodes[a]->rec.stmt()->body->size() : -1) << std::endl; + std::cout << nodes[b]->rec.stmt()->get_kernel()->name << " " << nodes[b]->rec.id << " " << (nodes[b]->rec.stmt()->has_body() ? nodes[b]->rec.stmt()->body->size() : -1) << std::endl; + //irpass::print(nodes[a]->rec.stmt()); + //irpass::print(nodes[b]->rec.stmt()); + } + if (verbose) + std::cout << "aaaaa" << std::endl; // Check if a and b are fusible if there is an edge (a, b). 
if (fused[a] || fused[b] || !fusion_meta[a].fusible || fusion_meta[a] != fusion_meta[b]) { return false; } + if (verbose) + std::cout << "bbbbb" << std::endl; if (nodes[a]->meta->type != OffloadedTaskType::serial) { + std::unordered_map offload_map; + offload_map[0] = 0; + std::unordered_set modified_states = + get_task_meta(ir_bank_, nodes[a]->rec)->output_states; + std::unordered_set modified_states_b = + get_task_meta(ir_bank_, nodes[a]->rec)->output_states; + modified_states.insert(modified_states_b.begin(), modified_states_b.end()); + AsyncStateSet modified_states_set{modified_states}; for (auto state_iter = nodes[a]->output_edges.get_state_iterator(); !state_iter.done(); ++state_iter) { auto state = state_iter.get_state(); @@ -676,15 +705,12 @@ std::unordered_set StateFlowGraph::fuse_range(int begin, int end) { if (state_iter.has_edge(nodes[b])) { if (nodes[a]->meta->loop_unique.count(snode) == 0 || nodes[b]->meta->loop_unique.count(snode) == 0) { + if (verbose) { + std::cout << "not loop-unique " + << snode->get_node_type_name_hinted() << std::endl; + } return false; } - std::unordered_map offload_map; - offload_map[0] = 0; - std::unordered_set modified_states = - get_task_meta(ir_bank_, nodes[a]->rec)->output_states; - modified_states.merge( - get_task_meta(ir_bank_, nodes[b]->rec)->output_states); - AsyncStateSet modified_states_set{modified_states}; auto same_loop_unique_address = [&](GlobalPtrStmt *ptr1, GlobalPtrStmt *ptr2) { if (!ptr1 || !ptr2) { @@ -709,14 +735,25 @@ std::unordered_set StateFlowGraph::fuse_range(int begin, int end) { }; if (!same_loop_unique_address(nodes[a]->meta->loop_unique[snode], nodes[b]->meta->loop_unique[snode])) { + if (verbose) { + std::cout << "not loop-unique address " + << snode->get_node_type_name_hinted() << std::endl; + } return false; } } } } + if (verbose) + std::cout << "ccccc" << std::endl; // check if a doesn't have a path to b of length >= 2 auto a_has_path_to_b = has_path[a] & has_path_reverse[b]; 
a_has_path_to_b[a] = a_has_path_to_b[b] = false; + if (verbose) { + if (a_has_path_to_b.none()) { + std::cout << "ddddd" << std::endl; + } + } return a_has_path_to_b.none(); }; @@ -787,12 +824,14 @@ std::unordered_set StateFlowGraph::fuse_range(int begin, int end) { // Fuse no more than one task into task i bool i_updated = false; for (auto &edge : nodes[i]->output_edges.get_all_edges()) { - const int j = edge.second->pending_node_id - begin; - if (j != -1 && edge_fusible(i, j)) { - do_fuse(i, j); - // Iterators of nodes[i]->output_edges may be invalidated - i_updated = true; - break; + if (edge.second->pending()) { + const int j = edge.second->pending_node_id - begin; + if (j >= 0 && j < nodes.size() && edge_fusible(i, j)) { + do_fuse(i, j); + // Iterators of nodes[i]->output_edges may be invalidated + i_updated = true; + break; + } } if (i_updated) { From 112e393b020d49a2a11b78759ce8f365db19d882 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Sat, 26 Dec 2020 15:36:25 +0800 Subject: [PATCH 04/10] Remove debug outputs --- taichi/analysis/same_statements.cpp | 25 ------------------- taichi/program/async_engine.cpp | 7 +----- taichi/program/ir_bank.cpp | 19 +++++++-------- taichi/program/state_flow_graph.cpp | 38 ++--------------------------- 4 files changed, 12 insertions(+), 77 deletions(-) diff --git a/taichi/analysis/same_statements.cpp b/taichi/analysis/same_statements.cpp index 1d3ad8df08f67..c8f11bf8e9bfd 100644 --- a/taichi/analysis/same_statements.cpp +++ b/taichi/analysis/same_statements.cpp @@ -115,15 +115,11 @@ class IRNodeComparator : public IRVisitor { } void basic_check(Stmt *stmt) { - if (verbose) - std::cout << "checking " << stmt->id << std::endl; // type check if (typeid(*other_node) != typeid(*stmt)) { same = false; return; } - if (verbose) - std::cout << "qqq" << std::endl; auto other = other_node->as(); // If two identical statements can have different values, return false. 
@@ -146,30 +142,18 @@ class IRNodeComparator : public IRVisitor { } } // TODO: other cases? - } else if (auto global_store = stmt->cast()) { - if (auto global_ptr = global_store->ptr->cast()) { - TI_ASSERT(global_ptr->width() == 1); - if (possibly_modified_states_.count(ir_bank_->get_async_state( - global_ptr->snodes[0], AsyncState::Type::value)) == 0) { - break; - } - } } same = false; return; } while (false); } } - if (verbose) - std::cout << "www" << std::endl; // Note that we do not need to test !stmt2->common_statement_eliminable() // because if this condition does not hold, // same_statements(stmt1, stmt2) returns false anyway. // field check if (check_same_value_ && stmt->is()) { - if (verbose) - std::cout << "eee" << std::endl; // Special case: we do not care the "activate" field when checking // whether two global pointers share the same value. // And we cannot use irpass::analysis::definitely_same_address() @@ -183,15 +167,11 @@ class IRNodeComparator : public IRVisitor { return; } } else { - if (verbose) - std::cout << "rrr" << std::endl; if (!stmt->field_manager.equal(other->field_manager)) { same = false; return; } } - if (verbose) - std::cout << "ttt" << std::endl; // operand check if (stmt->num_operands() != other->num_operands()) { @@ -207,8 +187,6 @@ class IRNodeComparator : public IRVisitor { continue; check_mapping(stmt->operand(i), other->operand(i)); } - if (verbose) - std::cout << "yyy" << std::endl; map_id(stmt->id, other->id); } @@ -295,7 +273,6 @@ class IRNodeComparator : public IRVisitor { other_node = other; } } - bool verbose{false}; static bool run(IRNode *root1, IRNode *root2, @@ -306,8 +283,6 @@ class IRNodeComparator : public IRVisitor { IRBank *ir_bank) { IRNodeComparator comparator(root2, id_map, check_same_value, possibly_modified_states, ir_bank); -// if (check_same_value && id_map.has_value()) -// comparator.verbose = true; root1->accept(&comparator); return comparator.same; } diff --git a/taichi/program/async_engine.cpp 
b/taichi/program/async_engine.cpp index 4df38bc8ef62a..6c0d13c6bb904 100644 --- a/taichi/program/async_engine.cpp +++ b/taichi/program/async_engine.cpp @@ -215,12 +215,8 @@ void AsyncEngine::synchronize() { sfg->reid_pending_nodes(); sfg->sort_node_edges(); auto init_size = sfg->size(); - std::cout << std::flush; - TI_INFO("Synchronizing SFG of {} nodes ({} pending)", sfg->size(), + TI_TRACE("Synchronizing SFG of {} nodes ({} pending)", sfg->size(), sfg->num_pending_tasks()); - std::cout << std::flush; - sfg->print(); - std::cout << std::flush; debug_sfg("initial"); if (program->config.debug) { sfg->verify(); @@ -268,7 +264,6 @@ void AsyncEngine::synchronize() { // Clear SFG debug stats cur_sync_sfg_debug_counter_ = 0; cur_sync_sfg_debug_per_stage_counts_.clear(); - //TI_ASSERT(init_size <= 20); } void AsyncEngine::debug_sfg(const std::string &stage) { diff --git a/taichi/program/ir_bank.cpp b/taichi/program/ir_bank.cpp index c82f320465fb2..6cac3aec1e5bc 100644 --- a/taichi/program/ir_bank.cpp +++ b/taichi/program/ir_bank.cpp @@ -77,18 +77,20 @@ IRHandle IRBank::fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel) { return result; } - TI_INFO("Begin uncached fusion"); - std::cout << handle_a.ir()->get_kernel()->name << " " << (handle_a.ir()->as()->has_body() ? handle_a.ir()->as()->body->size() : -1) << std::endl; - std::cout << handle_b.ir()->get_kernel()->name << " " << (handle_b.ir()->as()->has_body() ? handle_b.ir()->as()->body->size() : -1) << std::endl; + TI_TRACE("Begin uncached fusion: [{}(size={})] <- [{}(size={})]", + handle_a.ir()->get_kernel()->name, + (handle_a.ir()->as()->has_body() + ? handle_a.ir()->as()->body->size() + : -1), + handle_b.ir()->get_kernel()->name, + (handle_b.ir()->as()->has_body() + ? handle_a.ir()->as()->body->size() + : -1)); // We are about to change both |task_a| and |task_b|. Clone them first. 
auto cloned_task_a = handle_a.clone(); auto cloned_task_b = handle_b.clone(); auto task_a = cloned_task_a->as(); auto task_b = cloned_task_b->as(); -// std::cout << "before: " << std::endl; -// irpass::print((IRNode *)handle_a.ir()); -// irpass::print((IRNode *)handle_b.ir()); -// std::cout << std::flush; // TODO: in certain cases this optimization can be wrong! // Fuse task b into task_a for (int j = 0; j < (int)task_b->body->size(); j++) { @@ -102,9 +104,6 @@ IRHandle IRBank::fuse(IRHandle handle_a, IRHandle handle_b, Kernel *kernel) { irpass::full_simplify(task_a, /*after_lower_access=*/false, kernel); // For now, re_id is necessary for the hash to be correct. irpass::re_id(task_a); -// std::cout << "after: " << std::endl; -// irpass::print(task_a); -// std::cout << std::flush; auto h = get_hash(task_a); result = IRHandle(task_a, h); diff --git a/taichi/program/state_flow_graph.cpp b/taichi/program/state_flow_graph.cpp index 9bf42c1b6d13e..0572a721b2584 100644 --- a/taichi/program/state_flow_graph.cpp +++ b/taichi/program/state_flow_graph.cpp @@ -655,31 +655,11 @@ std::unordered_set StateFlowGraph::fuse_range(int begin, int end) { auto edge_fusible = [&](int a, int b) { TI_PROFILER("edge_fusible"); TI_ASSERT(a >= 0 && a < nodes.size() && b >= 0 && b < nodes.size()); - //std::cout << "fusible " << nodes[a]->rec.id << " " << nodes[b]->rec.id - // << std::endl; - bool verbose = (nodes[a]->rec.id == 117 && nodes[b]->rec.id == 124); - if (nodes[a]->rec.stmt()->has_body() && nodes[a]->rec.stmt()->body->size() > 100 && - nodes[b]->rec.stmt()->has_body() && nodes[b]->rec.stmt()->body->size() > 100) { - verbose = true; - } else { - } - verbose = false; - if (verbose) { - std::cout << "verbose" << std::endl; - std::cout << nodes[a]->rec.stmt()->get_kernel()->name << " " << nodes[a]->rec.id << " " << (nodes[a]->rec.stmt()->has_body() ? 
nodes[a]->rec.stmt()->body->size() : -1) << std::endl; - std::cout << nodes[b]->rec.stmt()->get_kernel()->name << " " << nodes[b]->rec.id << " " << (nodes[b]->rec.stmt()->has_body() ? nodes[b]->rec.stmt()->body->size() : -1) << std::endl; - //irpass::print(nodes[a]->rec.stmt()); - //irpass::print(nodes[b]->rec.stmt()); - } - if (verbose) - std::cout << "aaaaa" << std::endl; // Check if a and b are fusible if there is an edge (a, b). if (fused[a] || fused[b] || !fusion_meta[a].fusible || fusion_meta[a] != fusion_meta[b]) { return false; } - if (verbose) - std::cout << "bbbbb" << std::endl; if (nodes[a]->meta->type != OffloadedTaskType::serial) { std::unordered_map offload_map; offload_map[0] = 0; @@ -687,7 +667,8 @@ std::unordered_set StateFlowGraph::fuse_range(int begin, int end) { get_task_meta(ir_bank_, nodes[a]->rec)->output_states; std::unordered_set modified_states_b = get_task_meta(ir_bank_, nodes[a]->rec)->output_states; - modified_states.insert(modified_states_b.begin(), modified_states_b.end()); + modified_states.insert(modified_states_b.begin(), + modified_states_b.end()); AsyncStateSet modified_states_set{modified_states}; for (auto state_iter = nodes[a]->output_edges.get_state_iterator(); !state_iter.done(); ++state_iter) { @@ -705,10 +686,6 @@ std::unordered_set StateFlowGraph::fuse_range(int begin, int end) { if (state_iter.has_edge(nodes[b])) { if (nodes[a]->meta->loop_unique.count(snode) == 0 || nodes[b]->meta->loop_unique.count(snode) == 0) { - if (verbose) { - std::cout << "not loop-unique " - << snode->get_node_type_name_hinted() << std::endl; - } return false; } auto same_loop_unique_address = [&](GlobalPtrStmt *ptr1, @@ -735,25 +712,14 @@ std::unordered_set StateFlowGraph::fuse_range(int begin, int end) { }; if (!same_loop_unique_address(nodes[a]->meta->loop_unique[snode], nodes[b]->meta->loop_unique[snode])) { - if (verbose) { - std::cout << "not loop-unique address " - << snode->get_node_type_name_hinted() << std::endl; - } return false; } } 
} } - if (verbose) - std::cout << "ccccc" << std::endl; // check if a doesn't have a path to b of length >= 2 auto a_has_path_to_b = has_path[a] & has_path_reverse[b]; a_has_path_to_b[a] = a_has_path_to_b[b] = false; - if (verbose) { - if (a_has_path_to_b.none()) { - std::cout << "ddddd" << std::endl; - } - } return a_has_path_to_b.none(); }; From c465321e80b679ce90cf67da1208f8aa5137cbec Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Sat, 26 Dec 2020 16:08:35 +0800 Subject: [PATCH 05/10] Remove do ... while (false); --- taichi/analysis/same_statements.cpp | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/taichi/analysis/same_statements.cpp b/taichi/analysis/same_statements.cpp index c8f11bf8e9bfd..1f75e2444a580 100644 --- a/taichi/analysis/same_statements.cpp +++ b/taichi/analysis/same_statements.cpp @@ -131,21 +131,21 @@ class IRNodeComparator : public IRVisitor { same = false; return; } else { - // "break" all branches that do not result in "same = false" - do { - if (auto global_load = stmt->cast()) { - if (auto global_ptr = global_load->ptr->cast()) { - TI_ASSERT(global_ptr->width() == 1); - if (possibly_modified_states_.count(ir_bank_->get_async_state( - global_ptr->snodes[0], AsyncState::Type::value)) == 0) { - break; - } + bool same_value = false; + if (auto global_load = stmt->cast()) { + if (auto global_ptr = global_load->ptr->cast()) { + TI_ASSERT(global_ptr->width() == 1); + if (possibly_modified_states_.count(ir_bank_->get_async_state( + global_ptr->snodes[0], AsyncState::Type::value)) == 0) { + same_value = true; } - // TODO: other cases? } + // TODO: other cases? 
+ } + if (!same_value) { same = false; return; - } while (false); + } } } // Note that we do not need to test !stmt2->common_statement_eliminable() From ac0310ecdba817cc339e41285799338bb9d872a5 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Sat, 26 Dec 2020 16:10:06 +0800 Subject: [PATCH 06/10] [skip ci] minor change --- taichi/program/async_engine.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/taichi/program/async_engine.cpp b/taichi/program/async_engine.cpp index 6c0d13c6bb904..e59100622f822 100644 --- a/taichi/program/async_engine.cpp +++ b/taichi/program/async_engine.cpp @@ -214,7 +214,6 @@ void AsyncEngine::synchronize() { sfg->reid_nodes(); sfg->reid_pending_nodes(); sfg->sort_node_edges(); - auto init_size = sfg->size(); TI_TRACE("Synchronizing SFG of {} nodes ({} pending)", sfg->size(), sfg->num_pending_tasks()); debug_sfg("initial"); From 0884c5c5236e900c6fce704935c04ca706dbdc90 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Sat, 26 Dec 2020 16:11:09 +0800 Subject: [PATCH 07/10] [skip ci] minor change --- taichi/program/async_utils.h | 1 - 1 file changed, 1 deletion(-) diff --git a/taichi/program/async_utils.h b/taichi/program/async_utils.h index 4a80dea42f151..ff3b2d93cf6f1 100644 --- a/taichi/program/async_utils.h +++ b/taichi/program/async_utils.h @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include #include From a1843e3c8bb111701f2c9337aadebb49270338ed Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Sat, 26 Dec 2020 03:12:22 -0500 Subject: [PATCH 08/10] [skip ci] enforce code format --- tests/python/test_bit_array_vectorization.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/python/test_bit_array_vectorization.py b/tests/python/test_bit_array_vectorization.py index a3457c25192df..9995575136457 100644 --- a/tests/python/test_bit_array_vectorization.py +++ b/tests/python/test_bit_array_vectorization.py @@ -23,24 +23,21 @@ def test_vectorized_struct_for(): @ti.kernel def init(): for i, j in 
ti.ndrange((boundary_offset, N - boundary_offset), - (boundary_offset, N - boundary_offset)): + (boundary_offset, N - boundary_offset)): x[i, j] = ti.random(dtype=ti.i32) % 2 - @ti.kernel def assign_vectorized(): ti.bit_vectorize(32) for i, j in x: y[i, j] = x[i, j] - @ti.kernel def verify(): for i, j in ti.ndrange((boundary_offset, N - boundary_offset), - (boundary_offset, N - boundary_offset)): + (boundary_offset, N - boundary_offset)): assert y[i, j] == x[i, j] - init() assign_vectorized() verify() From 8b1e7cf8754311bd1fb437cdbbdbdcbcb5ddd365 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Sat, 26 Dec 2020 22:51:57 +0800 Subject: [PATCH 09/10] Apply review --- taichi/analysis/same_statements.cpp | 20 +++++++++++++++----- taichi/ir/analysis.h | 2 +- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/taichi/analysis/same_statements.cpp b/taichi/analysis/same_statements.cpp index 1f75e2444a580..7aef084fe0521 100644 --- a/taichi/analysis/same_statements.cpp +++ b/taichi/analysis/same_statements.cpp @@ -18,6 +18,8 @@ class IRNodeComparator : public IRVisitor { std::unordered_map id_map; bool recursively_check_; + + // Compare if two IRNodes definitely have the same value instead. bool check_same_value_; std::unordered_set possibly_modified_states_; @@ -121,12 +123,23 @@ class IRNodeComparator : public IRVisitor { return; } auto other = other_node->as(); + if (stmt == other) { + return; + } // If two identical statements can have different values, return false. + // TODO: actually the condition should be "can stmt be an operand of + // another statement?" + const bool stmt_has_value = !stmt->is_container_statement(); // TODO: two identical GlobalPtrStmts cannot have different values, // but GlobalPtrStmt::common_statement_eliminable() is false. 
- if (check_same_value_ && stmt != other && !stmt->is_container_statement() && - !stmt->common_statement_eliminable() && !stmt->is()) { + const bool identical_stmts_can_have_different_value = + stmt_has_value && !stmt->common_statement_eliminable() && + !stmt->is(); + // Note that we do not need to test !stmt2->common_statement_eliminable() + // because if this condition does not hold, + // same_value(stmt1, stmt2) returns false anyway. + if (check_same_value_ && identical_stmts_can_have_different_value) { if (all_states_can_be_modified_) { same = false; return; @@ -148,9 +161,6 @@ class IRNodeComparator : public IRVisitor { } } } - // Note that we do not need to test !stmt2->common_statement_eliminable() - // because if this condition does not hold, - // same_statements(stmt1, stmt2) returns false anyway. // field check if (check_same_value_ && stmt->is()) { diff --git a/taichi/ir/analysis.h b/taichi/ir/analysis.h index a4e7724378917..c66db0f98102e 100644 --- a/taichi/ir/analysis.h +++ b/taichi/ir/analysis.h @@ -102,7 +102,7 @@ bool same_statements( * Test if stmt1 and stmt2 definitely have the same value. * * @param possibly_modified_states - * Only states in possibly_modified_states can be modified + * Assumes that only states in possibly_modified_states can be modified * between stmt1 and stmt2. 
* * @param id_map From abda40fe66269c3e6c3cb06724619cde29987ddc Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Sun, 27 Dec 2020 13:44:14 +0800 Subject: [PATCH 10/10] Apply review --- taichi/analysis/same_statements.cpp | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/taichi/analysis/same_statements.cpp b/taichi/analysis/same_statements.cpp index 7aef084fe0521..aff10e0e2897a 100644 --- a/taichi/analysis/same_statements.cpp +++ b/taichi/analysis/same_statements.cpp @@ -6,7 +6,6 @@ #include "taichi/program/ir_bank.h" #include #include -#include TLANG_NAMESPACE_BEGIN @@ -20,6 +19,10 @@ class IRNodeComparator : public IRVisitor { bool recursively_check_; // Compare if two IRNodes definitely have the same value instead. + // When this is true, it's weaker in the sense that we don't require the + // activate field in the GlobalPtrStmt to be the same, but stronger in the + // sense that we require the value to be the same (especially stronger in + // GlobalLoadStmt, RandStmt, etc.). bool check_same_value_; std::unordered_set possibly_modified_states_; @@ -131,8 +134,11 @@ class IRNodeComparator : public IRVisitor { // TODO: actually the condition should be "can stmt be an operand of // another statement?" const bool stmt_has_value = !stmt->is_container_statement(); - // TODO: two identical GlobalPtrStmts cannot have different values, - // but GlobalPtrStmt::common_statement_eliminable() is false. + // TODO: We want to know if two identical statements of the same type as + // stmt can have different values. In most cases, this property is the + // same as Stmt::common_statement_eliminable(). However, two identical + // GlobalPtrStmts cannot have different values, although + // GlobalPtrStmt::common_statement_eliminable() is false.
const bool identical_stmts_can_have_different_value = stmt_has_value && !stmt->common_statement_eliminable() && !stmt->is(); @@ -162,7 +168,6 @@ class IRNodeComparator : public IRVisitor { } } - // field check if (check_same_value_ && stmt->is()) { // Special case: we do not care the "activate" field when checking // whether two global pointers share the same value. @@ -177,6 +182,7 @@ class IRNodeComparator : public IRVisitor { return; } } else { + // field check if (!stmt->field_manager.equal(other->field_manager)) { same = false; return; @@ -291,6 +297,9 @@ class IRNodeComparator : public IRVisitor { const std::optional> &possibly_modified_states, IRBank *ir_bank) { + // We need to distinguish the case of an empty + // std::unordered_set (assuming all SNodes are unchanged) + // and std::nullopt (assuming nothing), so we use std::optional<> here. IRNodeComparator comparator(root2, id_map, check_same_value, possibly_modified_states, ir_bank); root1->accept(&comparator);