diff --git a/taichi/analysis/gather_func_store_dests.cpp b/taichi/analysis/gather_func_store_dests.cpp new file mode 100644 index 00000000000000..bdf7211fa0c9b7 --- /dev/null +++ b/taichi/analysis/gather_func_store_dests.cpp @@ -0,0 +1,103 @@ +#include +#include "taichi/ir/analysis.h" +#include "taichi/ir/statements.h" +#include "taichi/ir/visitors.h" +#include "taichi/ir/control_flow_graph.h" +#include "taichi/program/function.h" + +namespace taichi::lang { + +/* IR visitor that, for each Function reached through FuncCallStmt, gathers the statements returned by irpass::analysis::get_store_destination over the function body and records them in Function::store_dests. Call cycles are handled with the dfn, low and stack bookkeeping kept in TarjanData. */ class GatherFuncStoreDests : public BasicStmtVisitor { + private: + /* destinations gathered so far by this visitor instance */ std::unordered_set results_; + /* function whose body is being visited; nullptr when the visited IR is not a function body */ Function *current_func_; + /* Traversal state shared by all visitor instances of one analysis run. func_dfn: index assigned when a function is first entered; func_low: lowest index found reachable from the function so far; func_in_stack and func_stack: functions entered but not yet popped. */ struct TarjanData { + std::unordered_map func_dfn; + std::unordered_map func_low; + std::unordered_set func_in_stack; + std::stack func_stack; + }; + TarjanData &tarjan_data_; + + /* Visits the body of func, which must not have been entered before, and returns the destinations gathered for it, including those merged in from callees. */ static std::unordered_set run(Function *func, + TarjanData &tarjan_data) { + TI_ASSERT(tarjan_data.func_dfn.count(func) == 0); + /* NOTE(review): the index comes from func_dfn.size(); this relies on the right-hand side being evaluated before func_dfn[func] inserts the key (ordering guaranteed since C++17) - confirm the build standard. */ tarjan_data.func_dfn[func] = tarjan_data.func_low[func] = + tarjan_data.func_dfn.size(); + tarjan_data.func_in_stack.insert(func); + tarjan_data.func_stack.push(func); + GatherFuncStoreDests searcher(func, tarjan_data); + func->ir->accept(&searcher); + /* func_low == func_dfn: pop every function down to and including func, and insert the set gathered for func into the store_dests of each popped function. NOTE(review): store_dests is only inserted into here and never cleared in this file, while compile_to_offloads runs this analysis twice - confirm entries from the first run stay valid. */ if (tarjan_data.func_low[func] == tarjan_data.func_dfn[func]) { + while (true) { + auto top = tarjan_data.func_stack.top(); + tarjan_data.func_stack.pop(); + tarjan_data.func_in_stack.erase(top); + top->store_dests.insert(searcher.results_.begin(), + searcher.results_.end()); + if (top == func) { + break; + } + } + } + return searcher.results_; + } + + /* Visits IR that is not a function body: current_func_ is nullptr, so nothing is gathered for ir itself and only callees are followed. */ static void run(IRNode *ir, TarjanData &tarjan_data) { + GatherFuncStoreDests searcher(nullptr, tarjan_data); + ir->accept(&searcher); + } + + public: + using BasicStmtVisitor::visit; + + /* func may be nullptr, see current_func_. NOTE(review): the two flags below presumably route statements without a dedicated overload to visit(Stmt *) - confirm against BasicStmtVisitor. */ GatherFuncStoreDests(Function *func, TarjanData &tarjan_data) + : current_func_(func), tarjan_data_(tarjan_data) { + allow_undefined_visitor = true; + invoke_default_visitor = true; + } + + /* Generic statement: inside a function body, add its store destinations to results_; otherwise do nothing. */ void visit(Stmt *stmt) override { + if (!current_func_) { + return; + } + auto result = 
irpass::analysis::get_store_destination(stmt); + results_.insert(result.begin(), result.end()); + } + + /* Call statement. Without a current function the callee is only entered if it has not been entered yet. Otherwise: a callee not entered yet is visited and its result merged, and func_low of the current function is lowered to func_low of the callee; a callee still on the stack only lowers func_low to the dfn of the callee; a callee already popped contributes its recorded store_dests. */ void visit(FuncCallStmt *stmt) override { + auto func = stmt->func; + if (!current_func_) { + if (!tarjan_data_.func_dfn.count(func)) { + run(func, tarjan_data_); + } + return; + } + if (!tarjan_data_.func_dfn.count(func)) { + auto result = run(func, tarjan_data_); + results_.merge(result); + tarjan_data_.func_low[current_func_] = std::min( + tarjan_data_.func_low[current_func_], tarjan_data_.func_low[func]); + } else if (tarjan_data_.func_in_stack.count(func)) { + tarjan_data_.func_low[current_func_] = std::min( + tarjan_data_.func_low[current_func_], tarjan_data_.func_dfn[func]); + } else { + const auto &dests = func->store_dests; + results_.insert(dests.begin(), dests.end()); + } + } + + /* Entry point: runs the traversal over ir with fresh TarjanData. */ static void run(IRNode *ir) { + TarjanData tarjan_data; + run(ir, tarjan_data); + } +}; + +namespace irpass::analysis { +/* Fills Function::store_dests for the functions reached from ir through FuncCallStmt. */ void gather_func_store_dests(IRNode *ir) { + GatherFuncStoreDests::run(ir); +} + +} // namespace irpass::analysis + +} // namespace taichi::lang diff --git a/taichi/ir/analysis.h b/taichi/ir/analysis.h index 48eb0b4cae1281..996f63c14825c0 100644 --- a/taichi/ir/analysis.h +++ b/taichi/ir/analysis.h @@ -213,7 +213,7 @@ std::unique_ptr initialize_mesh_local_attribute( OffloadedStmt *offload, bool auto_mesh_local, const CompileConfig &config); - +/* Gathers store destinations of called functions into Function::store_dests; defined in taichi/analysis/gather_func_store_dests.cpp. */ void gather_func_store_dests(IRNode *ir); } // namespace analysis } // namespace irpass } // namespace taichi::lang diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 7f777f547a3015..5f8bd03fa22e99 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -6,6 +6,7 @@ #include "taichi/ir/analysis.h" #include "taichi/ir/statements.h" #include "taichi/system/profiler.h" +#include "taichi/program/function.h" namespace taichi::lang { @@ -161,10 +162,6 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { // [Intra-block Search] int 
last_def_position = -1; for (int i = position - 1; i >= begin_location; i--) { - if (block->statements[i]->is()) { - return nullptr; - } - // Find previous store stmt to the same dest_addr, stop at the closest one. // store_ptr: prev-store dest_addr for (auto store_ptr : @@ -216,10 +213,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { } // Check if store_stmt will ever influence the value of var - auto may_contain_address = [](Stmt *store_stmt, Stmt *var) { - if (store_stmt->is()) { - return true; - } + /* NOTE(review): the capture list changed from [] to [&], but no outer name is used in the visible part of the lambda - confirm the capture is needed. */ auto may_contain_address = [&](Stmt *store_stmt, Stmt *var) { for (auto store_ptr : irpass::analysis::get_store_destination(store_stmt)) { if (var->is() && !store_ptr->is()) { // check for aliased address with var @@ -698,6 +692,7 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { if (stmt->is()) { killed_in_this_node.clear(); live_load_in_this_node.clear(); + /* both sets were just cleared; skip the store-destination handling that follows for this statement */ continue; } auto store_ptrs = irpass::analysis::get_store_destination(stmt); @@ -979,8 +974,7 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) { for (int i = 0; i < num_nodes; i++) { for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) { auto stmt = nodes[i]->block->statements[j].get(); - if (stmt->is() || - (stmt->is() && + if ((stmt->is() && stmt->as()->origin->is()) || (!after_lower_access && (stmt->is() || stmt->is() || @@ -991,6 +985,9 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) { // TODO: unify them // A global pointer that may contain some data before this kernel. 
nodes[start_node]->reach_gen.insert(stmt); + } else if (auto func_call = stmt->cast()) { + /* a call statement adds the recorded store_dests of its callee to reach_gen of the start node */ const auto &dests = func_call->func->store_dests; + nodes[start_node]->reach_gen.insert(dests.begin(), dests.end()); } } } diff --git a/taichi/ir/control_flow_graph.h b/taichi/ir/control_flow_graph.h index f51ec732e2b176..324e9a93313936 100644 --- a/taichi/ir/control_flow_graph.h +++ b/taichi/ir/control_flow_graph.h @@ -7,6 +7,7 @@ namespace taichi::lang { +class Function; /** * A basic block in control-flow graph. * A CFGNode contains a reference to a part of the CHI IR, or more precisely, @@ -113,6 +114,8 @@ class ControlFlowGraph { const int start_node = 0; int final_node{0}; + std::unordered_map> func_store_dests; + template CFGNode *push_back(Args &&...args) { nodes.emplace_back(std::make_unique(std::forward(args)...)); diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp index 3bbe33893dab58..bec9808a7a7f15 100644 --- a/taichi/ir/statements.cpp +++ b/taichi/ir/statements.cpp @@ -2,6 +2,7 @@ #include "taichi/ir/statements.h" #include "taichi/util/bit.h" #include "taichi/program/kernel.h" +#include "taichi/program/function.h" namespace taichi::lang { @@ -278,6 +279,19 @@ FuncCallStmt::FuncCallStmt(Function *func, const std::vector &args) TI_STMT_REG_FIELDS; } +stmt_refs FuncCallStmt::get_store_destination() const { + std::vector ret; + for (auto &arg : args) { + if (auto ref = arg->cast()) { + ret.push_back(ref->var); + } else if (arg->ret_type.is_pointer()) { + ret.push_back(arg); + } + } + ret.insert(ret.end(), func->store_dests.begin(), func->store_dests.end()); + return ret; +} + WhileStmt::WhileStmt(std::unique_ptr &&body) : mask(nullptr), body(std::move(body)) { this->body->set_parent_stmt(this); diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 57280c28a6c750..2de685936bf182 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -1062,7 +1062,7 @@ class MeshForStmt : public Stmt { /** * Call an inline Taichi function. 
*/ -class FuncCallStmt : public Stmt { +class FuncCallStmt : public Stmt, public ir_traits::Store { public: Function *func; std::vector args; @@ -1074,6 +1074,13 @@ class FuncCallStmt : public Stmt { return global_side_effect; } + // IR Trait: Store + /* Defined in statements.cpp. For each argument it adds ref->var when the cast in the definition succeeds, or the argument itself when its ret_type is a pointer; then it appends every entry of func->store_dests. */ stmt_refs get_store_destination() const override; + + /* Always returns nullptr: no stored value is reported for a call. */ Stmt *get_store_data() const override { + return nullptr; + } + TI_STMT_DEF_FIELDS(ret_type, func, args); TI_DEFINE_ACCEPT_AND_CLONE }; diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index fede01a04b80ec..51030dc1e7aa56 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -17,6 +17,7 @@ #include "taichi/transforms/demote_mesh_statements.h" #include "taichi/transforms/simplify.h" #include "taichi/common/trait.h" +#include "taichi/program/function.h" namespace taichi::lang { @@ -158,10 +159,7 @@ std::unordered_map detect_external_ptr_access_in_task( OffloadedStmt *offload); // compile_to_offloads does the basic compilation to create all the offloaded -// tasks of a Taichi kernel. It's worth pointing out that this doesn't demote -// dense struct fors. This is a necessary workaround to prevent the async -// engine from fusing incompatible offloaded tasks. TODO(Lin): check this -// comment +// tasks of a Taichi kernel. void compile_to_offloads(IRNode *ir, const CompileConfig &config, const Kernel *kernel, @@ -190,16 +188,17 @@ void compile_to_executable(IRNode *ir, bool make_thread_local = false, bool make_block_local = false, bool start_from_ast = true); -// Compile a function with some basic optimizations, so that the number of -// statements is reduced before inlining. 
+// Compile a function with some basic optimizations void compile_function(IRNode *ir, const CompileConfig &config, Function *func, AutodiffMode autodiff_mode, bool verbose, - bool start_from_ast); + /* stage the function IR should reach; in the definition, work for a stage the function has already reached is skipped */ Function::IRStage target_stage); -void compile_taichi_functions(IRNode *ir, const CompileConfig &compile_config); +/* Compiles the Taichi functions called from ir up to target_stage. */ void compile_taichi_functions(IRNode *ir, + const CompileConfig &compile_config, + Function::IRStage target_stage); } // namespace irpass } // namespace taichi::lang diff --git a/taichi/program/function.cpp b/taichi/program/function.cpp index c4de4c4056aa14..b06370089e0ca1 100644 --- a/taichi/program/function.cpp +++ b/taichi/program/function.cpp @@ -14,7 +14,7 @@ Function::Function(Program *program, const FunctionKey &func_key) void Function::set_function_body(const std::function &func) { context = std::make_unique(program->compile_config().arch); ir = context->get_root(); - ir_type_ = IRType::AST; + ir_stage_ = IRStage::AST; func(); finalize_params(); @@ -29,7 +29,7 @@ void Function::set_function_body(const std::function &func) { void Function::set_function_body(std::unique_ptr func_body) { ir = std::move(func_body); - ir_type_ = IRType::InitialIR; + ir_stage_ = IRStage::InitialIR; } std::string Function::get_name() const { diff --git a/taichi/program/function.h b/taichi/program/function.h index 712f79f57dba32..15d3ba57cdd948 100644 --- a/taichi/program/function.h +++ b/taichi/program/function.h @@ -1,15 +1,23 @@ #pragma once +#include #include "taichi/program/callable.h" #include "taichi/program/function_key.h" namespace taichi::lang { class Program; +class Stmt; class Function : public Callable { public: - enum class IRType { None, AST, InitialIR, OptimizedIR }; + /* Compilation stage of the function IR. The numeric order is significant: callers compare stages with < and >=. */ enum class IRStage : int { + None = 0, + AST = 1, + InitialIR = 2, + BeforeLowerAccess = 3, + OptimizedIR = 4 + }; FunctionKey func_key; @@ -28,16 +36,18 @@ class Function : public Callable { return ast_serialization_data_; } - void set_ir_type(IRType type) { - ir_type_ = type; + void 
set_ir_stage(IRStage type) { + ir_stage_ = type; } - IRType ir_type() const { - return ir_type_; + IRStage ir_stage() const { + return ir_stage_; } + /* Statements this function may store to; filled by irpass::analysis::gather_func_store_dests. */ std::unordered_set store_dests; + private: - IRType ir_type_{IRType::None}; + IRStage ir_stage_{IRStage::None}; std::optional ast_serialization_data_; // For generating AST-Key }; diff --git a/taichi/transforms/compile_taichi_functions.cpp b/taichi/transforms/compile_taichi_functions.cpp index 25d9e1d28df6c7..180d191f49dd02 100644 --- a/taichi/transforms/compile_taichi_functions.cpp +++ b/taichi/transforms/compile_taichi_functions.cpp @@ -10,39 +10,42 @@ class CompileTaichiFunctions : public BasicStmtVisitor { public: using BasicStmtVisitor::visit; - explicit CompileTaichiFunctions(const CompileConfig &compile_config) - : compile_config_(compile_config) { + /* target_stage: stage that every called function is compiled up to */ CompileTaichiFunctions(const CompileConfig &compile_config, + Function::IRStage target_stage) + : compile_config_(compile_config), target_stage_(target_stage) { } void visit(FuncCallStmt *stmt) override { - using IRType = Function::IRType; auto *func = stmt->func; - const auto ir_type = func->ir_type(); - if (ir_type != IRType::OptimizedIR) { - TI_ASSERT(ir_type == IRType::AST || ir_type == IRType::InitialIR); - func->set_ir_type(IRType::OptimizedIR); + const auto ir_type = func->ir_stage(); + /* Compile only callees below the target stage, then descend into the callee IR. NOTE(review): recursion on recursive functions ends only because compile_function raises the stage before the callee IR is visited; that holds for the BeforeLowerAccess and OptimizedIR targets used in compile_to_offloads, while a lower target would leave the stage unchanged. */ if (ir_type < target_stage_) { irpass::compile_function(func->ir.get(), compile_config_, func, /*autodiff_mode=*/AutodiffMode::kNone, /*verbose=*/compile_config_.print_ir, - /*start_from_ast=*/ir_type == IRType::AST); + target_stage_); func->ir->accept(this); } } - static void run(IRNode *ir, const CompileConfig &compile_config) { - CompileTaichiFunctions ctf{compile_config}; + static void run(IRNode *ir, + const CompileConfig &compile_config, + Function::IRStage target_stage) { + CompileTaichiFunctions ctf{compile_config, target_stage}; ir->accept(&ctf); } private: const CompileConfig &compile_config_; + Function::IRStage target_stage_; }; namespace irpass { -void 
compile_taichi_functions(IRNode *ir, const CompileConfig &compile_config) { +/* Runs CompileTaichiFunctions over ir so that called functions reach target_stage. */ void compile_taichi_functions(IRNode *ir, + const CompileConfig &compile_config, + Function::IRStage target_stage) { TI_AUTO_PROF; - CompileTaichiFunctions::run(ir, compile_config); + CompileTaichiFunctions::run(ir, compile_config, target_stage); } } // namespace irpass diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 64d03ac686cdcb..4c770881706ed1 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -44,7 +44,11 @@ void compile_to_offloads(IRNode *ir, print("Lowered"); } - irpass::compile_taichi_functions(ir, config); + /* called functions are first brought to BeforeLowerAccess and their store destinations gathered, then brought to OptimizedIR and gathered again */ irpass::compile_taichi_functions(ir, config, + Function::IRStage::BeforeLowerAccess); + irpass::analysis::gather_func_store_dests(ir); + irpass::compile_taichi_functions(ir, config, Function::IRStage::OptimizedIR); + irpass::analysis::gather_func_store_dests(ir); irpass::eliminate_immutable_local_vars(ir); print("Immutable local vars eliminated"); @@ -330,54 +334,63 @@ void compile_function(IRNode *ir, Function *func, AutodiffMode autodiff_mode, bool verbose, - bool start_from_ast) { + Function::IRStage target_stage) { TI_AUTO_PROF; + /* stage on entry; it is read once and decides which of the two stage blocks below still have to run */ auto current_stage = func->ir_stage(); auto print = make_pass_printer(verbose, func->get_name(), ir); print("Initial IR"); - if (autodiff_mode == AutodiffMode::kReverse) { - irpass::reverse_segments(ir); - print("Segment reversed (for autodiff)"); - } + /* First block: bring the IR up to BeforeLowerAccess (segment reversal for reverse autodiff, AST lowering when the entry stage is below InitialIR, scalarization when enabled in config). */ if (target_stage >= Function::IRStage::BeforeLowerAccess && + current_stage < Function::IRStage::BeforeLowerAccess) { + if (autodiff_mode == AutodiffMode::kReverse) { + irpass::reverse_segments(ir); + print("Segment reversed (for autodiff)"); + } - if (start_from_ast) { - irpass::frontend_type_check(ir); - irpass::lower_ast(ir); - print("Lowered"); - } + if (current_stage < Function::IRStage::InitialIR) { + irpass::frontend_type_check(ir); + irpass::lower_ast(ir); + print("Lowered"); + 
} - if (config.real_matrix_scalarize) { - if (irpass::scalarize(ir)) { - // Remove redundant MatrixInitStmt inserted during scalarization - irpass::die(ir); - print("Scalarized"); + if (config.real_matrix_scalarize) { + if (irpass::scalarize(ir)) { + // Remove redundant MatrixInitStmt inserted during scalarization + irpass::die(ir); + print("Scalarized"); + } } + func->set_ir_stage(Function::IRStage::BeforeLowerAccess); } - irpass::lower_access(ir, config, {{}, true}); - print("Access lowered"); - irpass::analysis::verify(ir); + /* Second block: lower accesses and run the remaining passes, then mark the function as OptimizedIR. */ if (target_stage >= Function::IRStage::OptimizedIR && + current_stage < Function::IRStage::OptimizedIR) { + irpass::lower_access(ir, config, {{}, true}); + print("Access lowered"); + irpass::analysis::verify(ir); - irpass::die(ir); - print("DIE"); - irpass::analysis::verify(ir); + irpass::die(ir); + print("DIE"); + irpass::analysis::verify(ir); - irpass::flag_access(ir); - print("Access flagged III"); - irpass::analysis::verify(ir); + irpass::flag_access(ir); + print("Access flagged III"); + irpass::analysis::verify(ir); - irpass::type_check(ir, config); - print("Typechecked"); + irpass::type_check(ir, config); + print("Typechecked"); - irpass::demote_operations(ir, config); - print("Operations demoted"); + irpass::demote_operations(ir, config); + print("Operations demoted"); - irpass::full_simplify( - ir, config, - {false, autodiff_mode != AutodiffMode::kNone, func->get_name(), verbose}); - print("Simplified"); - irpass::analysis::verify(ir); + /* NOTE(review): the first field of the options changed from false to true; presumably it is the after_lower_access flag - confirm against the full_simplify argument type. */ irpass::full_simplify(ir, config, + {true, autodiff_mode != AutodiffMode::kNone, + func->get_name(), verbose}); + print("Simplified"); + irpass::analysis::verify(ir); + func->set_ir_stage(Function::IRStage::OptimizedIR); + } } } // namespace irpass