From 2a71e2ab9a7fdcb49851ecbb43d1b68e345c9240 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Tue, 13 Oct 2020 10:50:57 +0800 Subject: [PATCH] [refactor] Move OffloadedStmt::TaskType to a separate file (#1946) --- taichi/analysis/clone.cpp | 1 + taichi/backends/cc/codegen_cc.cpp | 1 + taichi/backends/cpu/codegen_cpu.cpp | 1 + taichi/backends/cuda/codegen_cuda.cpp | 1 + taichi/backends/metal/codegen_metal.cpp | 1 + taichi/backends/metal/kernel_manager.cpp | 2 +- taichi/backends/metal/kernel_manager.h | 3 ++- taichi/backends/metal/kernel_util.cpp | 4 ++-- taichi/backends/metal/kernel_util.h | 4 ++-- taichi/backends/opengl/codegen_opengl.cpp | 1 + taichi/backends/opengl/opengl_api.h | 7 +++++-- taichi/backends/opengl/opengl_kernel_util.h | 2 +- taichi/backends/opengl/struct_opengl.cpp | 2 ++ taichi/codegen/codegen_llvm.cpp | 1 + taichi/inc/offloaded_task_type.inc.h | 5 +++++ taichi/ir/frontend.cpp | 3 ++- taichi/ir/offloaded_task_type.cpp | 15 +++++++++++++++ taichi/ir/offloaded_task_type.h | 17 +++++++++++++++++ taichi/ir/snode.cpp | 3 ++- taichi/ir/statements.cpp | 10 +--------- taichi/ir/statements.h | 11 +++-------- taichi/program/async_engine.cpp | 3 ++- taichi/program/async_engine.h | 1 - taichi/program/async_utils.cpp | 16 ++++++++-------- taichi/program/async_utils.h | 12 ++++++++---- taichi/program/ir_bank.cpp | 3 ++- taichi/program/kernel.cpp | 1 + taichi/program/program.cpp | 1 + taichi/program/state_flow_graph.cpp | 8 +++++--- taichi/program/state_flow_graph.h | 1 - taichi/python/export_lang.cpp | 1 + taichi/transforms/alg_simp.cpp | 1 + taichi/transforms/auto_diff.cpp | 5 +++-- taichi/transforms/binary_op_simplify.cpp | 1 + taichi/transforms/constant_fold.cpp | 1 + taichi/transforms/demote_operations.cpp | 1 + taichi/transforms/insert_scratch_pad.cpp | 2 +- taichi/transforms/ir_printer.cpp | 8 ++++---- taichi/transforms/lower_access.cpp | 1 + taichi/transforms/make_block_local.cpp | 1 + taichi/transforms/make_thread_local.cpp | 4 ++-- taichi/transforms/offload.cpp | 1 + taichi/transforms/simplify.cpp | 1 + taichi/transforms/type_check.cpp | 1 + taichi/transforms/variable_optimization.cpp | 2 +- tests/cpp/test_alg_simp.cpp | 1 + tests/cpp/test_same_statements.cpp | 1 + tests/cpp/test_simplify.cpp | 1 + 48 files changed, 118 insertions(+), 57 deletions(-) create mode 100644 taichi/inc/offloaded_task_type.inc.h create mode 100644 taichi/ir/offloaded_task_type.cpp create mode 100644 taichi/ir/offloaded_task_type.h diff --git a/taichi/analysis/clone.cpp b/taichi/analysis/clone.cpp index 0037b86d8ad24..540993a56f27d 100644 --- a/taichi/analysis/clone.cpp +++ b/taichi/analysis/clone.cpp @@ -1,5 +1,6 @@ #include "taichi/ir/ir.h" #include "taichi/ir/analysis.h" +#include "taichi/ir/statements.h" #include "taichi/ir/transforms.h" #include "taichi/ir/visitors.h" #include "taichi/program/program.h" diff --git a/taichi/backends/cc/codegen_cc.cpp b/taichi/backends/cc/codegen_cc.cpp index d7ca78881411d..c94648a69d192 100644 --- a/taichi/backends/cc/codegen_cc.cpp +++ b/taichi/backends/cc/codegen_cc.cpp @@ -2,6 +2,7 @@ #include "cc_kernel.h" #include "cc_layout.h" #include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" #include "taichi/ir/transforms.h" #include "taichi/util/line_appender.h" #include "taichi/util/str.h" diff --git a/taichi/backends/cpu/codegen_cpu.cpp b/taichi/backends/cpu/codegen_cpu.cpp index a3d4266bd18fe..1d2686fa26ef8 100644 --- a/taichi/backends/cpu/codegen_cpu.cpp +++ b/taichi/backends/cpu/codegen_cpu.cpp @@ -6,6 +6,7 @@ #include "taichi/lang_util.h" #include "taichi/program/program.h" #include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" #include "taichi/util/statistics.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/backends/cuda/codegen_cuda.cpp b/taichi/backends/cuda/codegen_cuda.cpp index 1c19ccb581348..0356c3d7ebfad 100644 --- a/taichi/backends/cuda/codegen_cuda.cpp +++ b/taichi/backends/cuda/codegen_cuda.cpp @@ -7,6 +7,7 @@ #include "taichi/util/io.h" #include "taichi/util/statistics.h" #include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" #include "taichi/program/program.h" #include "taichi/lang_util.h" #include "taichi/backends/cuda/cuda_driver.h" diff --git a/taichi/backends/metal/codegen_metal.cpp b/taichi/backends/metal/codegen_metal.cpp index cf9f5570c9480..0c37bac93f43c 100644 --- a/taichi/backends/metal/codegen_metal.cpp +++ b/taichi/backends/metal/codegen_metal.cpp @@ -8,6 +8,7 @@ #include "taichi/backends/metal/env_config.h" #include "taichi/backends/metal/features.h" #include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" #include "taichi/ir/transforms.h" #include "taichi/math/arithmetic.h" #include "taichi/util/line_appender.h" diff --git a/taichi/backends/metal/kernel_manager.cpp b/taichi/backends/metal/kernel_manager.cpp index 2469f0c2d42be..bc2468407ea9d 100644 --- a/taichi/backends/metal/kernel_manager.cpp +++ b/taichi/backends/metal/kernel_manager.cpp @@ -34,7 +34,7 @@ namespace shaders { #include "taichi/backends/metal/shaders/runtime_utils.metal.h" } // namespace shaders -using KernelTaskType = OffloadedStmt::TaskType; +using KernelTaskType = OffloadedTaskType; using BufferEnum = KernelAttributes::Buffers; inline int infer_msl_version(const TaichiKernelAttributes::UsedFeatures &f) { diff --git a/taichi/backends/metal/kernel_manager.h b/taichi/backends/metal/kernel_manager.h index f8570b339850a..7bff7170c55d0 100644 --- a/taichi/backends/metal/kernel_manager.h +++ b/taichi/backends/metal/kernel_manager.h @@ -6,9 +6,10 @@ #include #include "taichi/backends/metal/kernel_util.h" +#include "taichi/backends/metal/struct_metal.h" #include "taichi/lang_util.h" +#include "taichi/program/compile_config.h" #include "taichi/program/kernel_profiler.h" -#include "taichi/backends/metal/struct_metal.h" #include "taichi/system/memory_pool.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/backends/metal/kernel_util.cpp b/taichi/backends/metal/kernel_util.cpp index 20cef30eb0273..994061fbe139c 100644 --- a/taichi/backends/metal/kernel_util.cpp +++ b/taichi/backends/metal/kernel_util.cpp @@ -42,13 +42,13 @@ std::string KernelAttributes::debug_string() const { std::string result; result += fmt::format( "snode->id); } result += ">"; diff --git a/taichi/backends/metal/kernel_util.h b/taichi/backends/metal/kernel_util.h index 7473865e8deb7..3a6c6ad2edde9 100644 --- a/taichi/backends/metal/kernel_util.h +++ b/taichi/backends/metal/kernel_util.h @@ -4,7 +4,7 @@ #include #include -#include "taichi/ir/statements.h" +#include "taichi/ir/offloaded_task_type.h" #include "taichi/backends/metal/data_types.h" // Data structures defined in this file may overlap with some of the Taichi data @@ -43,7 +43,7 @@ struct KernelAttributes { }; std::string name; int num_threads; - OffloadedStmt::TaskType task_type; + OffloadedTaskType task_type; struct RangeForAttributes { // |begin| has differen meanings depending on |const_begin|: diff --git a/taichi/backends/opengl/codegen_opengl.cpp b/taichi/backends/opengl/codegen_opengl.cpp index 02620a8129479..7cd6f85f50b72 100644 --- a/taichi/backends/opengl/codegen_opengl.cpp +++ b/taichi/backends/opengl/codegen_opengl.cpp @@ -7,6 +7,7 @@ #include "taichi/backends/opengl/opengl_data_types.h" #include "taichi/backends/opengl/opengl_kernel_util.h" #include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" #include "taichi/ir/transforms.h" #include "taichi/util/line_appender.h" #include "taichi/util/macros.h" diff --git a/taichi/backends/opengl/opengl_api.h b/taichi/backends/opengl/opengl_api.h index 939390ab0fb21..4e46766992b98 100644 --- a/taichi/backends/opengl/opengl_api.h +++ b/taichi/backends/opengl/opengl_api.h @@ -6,8 +6,11 @@ #include #include -#include "opengl_kernel_util.h" -#include "opengl_kernel_launcher.h" +#include "taichi/backends/opengl/opengl_kernel_util.h" +#include "taichi/backends/opengl/opengl_kernel_launcher.h" +#define TI_RUNTIME_HOST +#include "taichi/program/context.h" +#undef TI_RUNTIME_HOST TLANG_NAMESPACE_BEGIN diff --git a/taichi/backends/opengl/opengl_kernel_util.h b/taichi/backends/opengl/opengl_kernel_util.h index 5dd62233b70a4..ba14f1e237f5d 100644 --- a/taichi/backends/opengl/opengl_kernel_util.h +++ b/taichi/backends/opengl/opengl_kernel_util.h @@ -4,7 +4,7 @@ #include #include -#include "taichi/ir/statements.h" +#include "taichi/ir/snode.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/backends/opengl/struct_opengl.cpp b/taichi/backends/opengl/struct_opengl.cpp index 1470123e771ac..6799d35c1f0cc 100644 --- a/taichi/backends/opengl/struct_opengl.cpp +++ b/taichi/backends/opengl/struct_opengl.cpp @@ -1,5 +1,7 @@ #include "struct_opengl.h" +#include "taichi/ir/snode.h" + TLANG_NAMESPACE_BEGIN namespace opengl { diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 9d378349b0b09..7bd4e484729cd 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1,5 +1,6 @@ #include "taichi/codegen/codegen_llvm.h" +#include "taichi/ir/statements.h" #include "taichi/struct/struct_llvm.h" #include "taichi/util/file_sequence_writer.h" diff --git a/taichi/inc/offloaded_task_type.inc.h b/taichi/inc/offloaded_task_type.inc.h new file mode 100644 index 0000000000000..36ae476708f6e --- /dev/null +++ b/taichi/inc/offloaded_task_type.inc.h @@ -0,0 +1,5 @@ +PER_TASK_TYPE(serial) +PER_TASK_TYPE(range_for) +PER_TASK_TYPE(struct_for) +PER_TASK_TYPE(listgen) +PER_TASK_TYPE(gc) diff --git a/taichi/ir/frontend.cpp b/taichi/ir/frontend.cpp index f3a61e47db293..ab0f5eff1af50 100644 --- a/taichi/ir/frontend.cpp +++ b/taichi/ir/frontend.cpp @@ -1,6 +1,7 @@ // Frontend constructs -#include "frontend.h" +#include "taichi/ir/frontend.h" +#include "taichi/ir/statements.h" #include "taichi/program/program.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/ir/offloaded_task_type.cpp b/taichi/ir/offloaded_task_type.cpp new file mode 100644 index 0000000000000..3203c53796aba --- /dev/null +++ b/taichi/ir/offloaded_task_type.cpp @@ -0,0 +1,15 @@ +#include "taichi/ir/offloaded_task_type.h" + +TLANG_NAMESPACE_BEGIN + +std::string offloaded_task_type_name(OffloadedTaskType tt) { + if (false) { + } +#define PER_TASK_TYPE(x) else if (tt == OffloadedTaskType::x) return #x; +#include "taichi/inc/offloaded_task_type.inc.h" +#undef PER_TASK_TYPE + else + TI_NOT_IMPLEMENTED +} + +TLANG_NAMESPACE_END diff --git a/taichi/ir/offloaded_task_type.h b/taichi/ir/offloaded_task_type.h new file mode 100644 index 0000000000000..e32a338259a99 --- /dev/null +++ b/taichi/ir/offloaded_task_type.h @@ -0,0 +1,17 @@ +#pragma once + +#include "taichi/common/core.h" + +#include + +TLANG_NAMESPACE_BEGIN + +enum class OffloadedTaskType : int { +#define PER_TASK_TYPE(x) x, +#include "taichi/inc/offloaded_task_type.inc.h" +#undef PER_TASK_TYPE +}; + +std::string offloaded_task_type_name(OffloadedTaskType tt); + +TLANG_NAMESPACE_END diff --git a/taichi/ir/snode.cpp b/taichi/ir/snode.cpp index 91ac84608e139..8b17f7c0fd407 100644 --- a/taichi/ir/snode.cpp +++ b/taichi/ir/snode.cpp @@ -2,6 +2,7 @@ #include "taichi/ir/ir.h" #include "taichi/ir/frontend.h" +#include "taichi/ir/statements.h" #include "taichi/backends/cuda/cuda_driver.h" TLANG_NAMESPACE_BEGIN @@ -18,7 +19,7 @@ void set_kernel_args(const std::vector &I, } // namespace -std::atomic SNode::counter = 0; +std::atomic SNode::counter{0}; SNode &SNode::insert_children(SNodeType t) { TI_ASSERT(t != SNodeType::root); diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp index 18b9460b765d5..29b859359456b 100644 --- a/taichi/ir/statements.cpp +++ b/taichi/ir/statements.cpp @@ -317,15 +317,7 @@ std::string OffloadedStmt::task_name() const { // static std::string OffloadedStmt::task_type_name(TaskType tt) { -#define REGISTER_NAME(x) \ - { TaskType::x, #x } - const static std::unordered_map m = { - REGISTER_NAME(serial), REGISTER_NAME(range_for), - REGISTER_NAME(struct_for), REGISTER_NAME(listgen), - REGISTER_NAME(gc), - }; -#undef REGISTER_NAME - return m.find(tt)->second; + return offloaded_task_type_name(tt); } std::unique_ptr OffloadedStmt::clone() const { diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 0c6924bb77049..219e6ff1dbac4 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -1,6 +1,7 @@ #pragma once #include "taichi/ir/ir.h" +#include "taichi/ir/offloaded_task_type.h" #include "taichi/ir/scratch_pad.h" TLANG_NAMESPACE_BEGIN @@ -789,13 +790,7 @@ class GetChStmt : public Stmt { class OffloadedStmt : public Stmt { public: - enum TaskType : int { - serial, - range_for, - struct_for, - listgen, - gc, - }; + using TaskType = OffloadedTaskType; TaskType task_type; SNode *snode; @@ -830,7 +825,7 @@ class OffloadedStmt : public Stmt { static std::string task_type_name(TaskType tt); bool has_body() const { - return task_type != listgen && task_type != gc; + return task_type != TaskType::listgen && task_type != TaskType::gc; } bool is_container_statement() const override { diff --git a/taichi/program/async_engine.cpp b/taichi/program/async_engine.cpp index 2a58e42e89bd6..d0f06f4bf4b85 100644 --- a/taichi/program/async_engine.cpp +++ b/taichi/program/async_engine.cpp @@ -7,8 +7,9 @@ #include "taichi/backends/cpu/codegen_cpu.h" #include "taichi/util/testing.h" #include "taichi/util/statistics.h" -#include "taichi/ir/transforms.h" #include "taichi/ir/analysis.h" +#include "taichi/ir/statements.h" +#include "taichi/ir/transforms.h" #include "taichi/program/extension.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/program/async_engine.h b/taichi/program/async_engine.h index d2e205b62fc6d..8a64c0e17f904 100644 --- a/taichi/program/async_engine.h +++ b/taichi/program/async_engine.h @@ -8,7 +8,6 @@ #include #include "taichi/ir/ir.h" -#include "taichi/ir/statements.h" #include "taichi/lang_util.h" #define TI_RUNTIME_HOST #include "taichi/program/context.h" diff --git a/taichi/program/async_utils.cpp b/taichi/program/async_utils.cpp index 2b59dcbf8c217..5bd85a52a9031 100644 --- a/taichi/program/async_utils.cpp +++ b/taichi/program/async_utils.cpp @@ -38,7 +38,7 @@ bool TaskLaunchRecord::empty() const { void TaskMeta::print() const { fmt::print("TaskMeta\n name {}\n", name); - fmt::print(" type {}\n", OffloadedStmt::task_type_name(type)); + fmt::print(" type {}\n", offloaded_task_type_name(type)); if (snode != nullptr) { fmt::print(" snode {}\n", snode->get_node_type_name_hinted()); } else { @@ -99,8 +99,8 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) { TaskMeta meta; // TODO: this is an abuse since it gathers nothing... auto *root_stmt = t.stmt(); - meta.name = t.kernel->name + "_" + - OffloadedStmt::task_type_name(root_stmt->task_type); + meta.name = + t.kernel->name + "_" + offloaded_task_type_name(root_stmt->task_type); meta.type = root_stmt->task_type; get_meta_input_value_states(root_stmt, &meta); gather_statements(root_stmt, [&](Stmt *stmt) { @@ -160,14 +160,14 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) { } } - if (root_stmt->task_type == OffloadedStmt::listgen) { + if (root_stmt->task_type == OffloadedTaskType::listgen) { TI_ASSERT(root_stmt->snode->parent); meta.snode = root_stmt->snode; meta.input_states.emplace(root_stmt->snode->parent, AsyncState::Type::list); meta.input_states.emplace(root_stmt->snode, AsyncState::Type::list); meta.input_states.emplace(root_stmt->snode, AsyncState::Type::mask); meta.output_states.emplace(root_stmt->snode, AsyncState::Type::list); - } else if (root_stmt->task_type == OffloadedStmt::struct_for) { + } else if (root_stmt->task_type == OffloadedTaskType::struct_for) { meta.snode = root_stmt->snode; meta.input_states.emplace(root_stmt->snode, AsyncState::Type::list); } @@ -197,10 +197,10 @@ TaskFusionMeta get_task_fusion_meta(IRBank *bank, const TaskLaunchRecord &t) { auto *task = t.stmt(); meta.type = task->task_type; - if (task->task_type == OffloadedStmt::struct_for) { + if (task->task_type == OffloadedTaskType::struct_for) { meta.snode = task->snode; meta.block_dim = task->block_dim; - } else if (task->task_type == OffloadedStmt::range_for) { + } else if (task->task_type == OffloadedTaskType::range_for) { // TODO: a few problems with the range-for test condition: // 1. This could incorrectly fuse two range-for kernels that have // different sizes, but then the loop ranges get padded to the same @@ -215,7 +215,7 @@ TaskFusionMeta get_task_fusion_meta(IRBank *bank, const TaskLaunchRecord &t) { } meta.begin_value = task->begin_value; meta.end_value = task->end_value; - } else if (task->task_type != OffloadedStmt::serial) { + } else if (task->task_type != OffloadedTaskType::serial) { // Do not fuse gc/listgen tasks. return fusion_meta_bank[t.ir_handle] = TaskFusionMeta(); } diff --git a/taichi/program/async_utils.h b/taichi/program/async_utils.h index 96ec7517e7eda..42ba24617df35 100644 --- a/taichi/program/async_utils.h +++ b/taichi/program/async_utils.h @@ -5,7 +5,7 @@ #include #include "taichi/ir/snode.h" -#include "taichi/ir/statements.h" +#include "taichi/ir/offloaded_task_type.h" #define TI_RUNTIME_HOST #include "taichi/program/context.h" #undef TI_RUNTIME_HOST @@ -14,6 +14,9 @@ TLANG_NAMESPACE_BEGIN struct TaskMeta; +class IRNode; +class OffloadedStmt; + class IRHandle { public: IRHandle() : ir_(nullptr), hash_(0) { @@ -110,7 +113,7 @@ struct AsyncState { struct TaskFusionMeta { // meta for task fusion - OffloadedStmt::TaskType type{OffloadedStmt::TaskType::serial}; + OffloadedTaskType type{OffloadedTaskType::serial}; SNode *snode{nullptr}; // struct-for only int block_dim{0}; // struct-for only int32 begin_value{0}; // range-for only @@ -172,7 +175,8 @@ struct hash { template <> struct hash { std::size_t operator()(const taichi::lang::TaskFusionMeta &t) const noexcept { - std::size_t result = (t.type << 1) ^ t.fusible ^ (std::size_t)t.kernel; + std::size_t result = + ((std::size_t)t.type << 1) ^ t.fusible ^ (std::size_t)t.kernel; result ^= (std::size_t)t.block_dim * 100000007UL + (std::size_t)t.snode; result ^= ((std::size_t)t.begin_value << 32) ^ t.end_value; return result; @@ -185,7 +189,7 @@ TLANG_NAMESPACE_BEGIN struct TaskMeta { std::string name; - OffloadedStmt::TaskType type{OffloadedStmt::TaskType::serial}; + OffloadedTaskType type{OffloadedTaskType::serial}; SNode *snode{nullptr}; // struct-for and listgen only std::unordered_set input_states; std::unordered_set output_states; diff --git a/taichi/program/ir_bank.cpp b/taichi/program/ir_bank.cpp index 16877249aa68b..ce500e1a1e6bf 100644 --- a/taichi/program/ir_bank.cpp +++ b/taichi/program/ir_bank.cpp @@ -1,7 +1,8 @@ #include "taichi/program/ir_bank.h" -#include "taichi/ir/transforms.h" #include "taichi/ir/analysis.h" +#include "taichi/ir/statements.h" +#include "taichi/ir/transforms.h" #include "taichi/program/kernel.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index ebaeb9a072b71..82850f3cdda61 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -6,6 +6,7 @@ #include "taichi/program/async_engine.h" #include "taichi/codegen/codegen.h" #include "taichi/backends/cuda/cuda_driver.h" +#include "taichi/ir/statements.h" #include "taichi/ir/transforms.h" #include "taichi/util/action_recorder.h" #include "taichi/program/extension.h" diff --git a/taichi/program/program.cpp b/taichi/program/program.cpp index 313e929b41c8b..c6bc59bfb3131 100644 --- a/taichi/program/program.cpp +++ b/taichi/program/program.cpp @@ -2,6 +2,7 @@ #include "program.h" +#include "taichi/ir/statements.h" #include "taichi/program/extension.h" #include "taichi/backends/metal/api.h" #include "taichi/backends/opengl/opengl_api.h" diff --git a/taichi/program/state_flow_graph.cpp b/taichi/program/state_flow_graph.cpp index d1b7b2ec04a36..f90399c176b19 100644 --- a/taichi/program/state_flow_graph.cpp +++ b/taichi/program/state_flow_graph.cpp @@ -6,6 +6,7 @@ #include #include "taichi/ir/analysis.h" +#include "taichi/ir/statements.h" #include "taichi/ir/transforms.h" #include "taichi/program/async_engine.h" #include "taichi/util/statistics.h" @@ -1054,8 +1055,9 @@ bool StateFlowGraph::optimize_dead_store() { const auto mt = meta.type; // Do NOT check ir->body->statements first! |ir->body| could be done when // |mt| is not the desired type. - if ((mt == OffloadedStmt::serial || mt == OffloadedStmt::struct_for || - mt == OffloadedStmt::range_for) && + if ((mt == OffloadedTaskType::serial || + mt == OffloadedTaskType::struct_for || + mt == OffloadedTaskType::range_for) && ir->body->statements.empty()) { to_delete.insert(i + first_pending_task_index_); } @@ -1180,7 +1182,7 @@ bool StateFlowGraph::demote_activation() { auto list_state = AsyncState(snode, AsyncState::Type::list); // TODO: handle serial and range for - if (node->meta->type != OffloadedStmt::struct_for) + if (node->meta->type != OffloadedTaskType::struct_for) continue; if (get_or_insert(node->input_edges, list_state).size() != 1) diff --git a/taichi/program/state_flow_graph.h b/taichi/program/state_flow_graph.h index 883b0b5cbf99f..1bae87a0ff719 100644 --- a/taichi/program/state_flow_graph.h +++ b/taichi/program/state_flow_graph.h @@ -8,7 +8,6 @@ #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "taichi/ir/ir.h" -#include "taichi/ir/statements.h" #include "taichi/lang_util.h" #include "taichi/program/async_utils.h" #include "taichi/program/program.h" diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 270f06fb86986..a377a07bccaf1 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -8,6 +8,7 @@ #include "taichi/ir/frontend.h" #include "taichi/ir/frontend_ir.h" +#include "taichi/ir/statements.h" #include "taichi/program/extension.h" #include "taichi/program/async_engine.h" #include "taichi/common/interface.h" diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index 16b43808f6bca..bcb13c3ed6830 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -1,4 +1,5 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" #include "taichi/ir/transforms.h" #include "taichi/ir/visitors.h" #include "taichi/program/program.h" diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index b541700b6bf72..feff8e8ea89eb 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -1,8 +1,9 @@ +#include "taichi/ir/analysis.h" +#include "taichi/ir/frontend.h" #include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" #include "taichi/ir/transforms.h" -#include "taichi/ir/analysis.h" #include "taichi/ir/visitors.h" -#include "taichi/ir/frontend.h" #include diff --git a/taichi/transforms/binary_op_simplify.cpp b/taichi/transforms/binary_op_simplify.cpp index a07a47a97d55b..a3a478469cb34 100644 --- a/taichi/transforms/binary_op_simplify.cpp +++ b/taichi/transforms/binary_op_simplify.cpp @@ -1,4 +1,5 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" #include "taichi/ir/transforms.h" #include "taichi/ir/visitors.h" #include "taichi/program/program.h" diff --git a/taichi/transforms/constant_fold.cpp b/taichi/transforms/constant_fold.cpp index 8c34c860b3b5a..ebaa5c02613fd 100644 --- a/taichi/transforms/constant_fold.cpp +++ b/taichi/transforms/constant_fold.cpp @@ -5,6 +5,7 @@ #include "taichi/ir/ir.h" #include "taichi/ir/snode.h" +#include "taichi/ir/statements.h" #include "taichi/ir/transforms.h" #include "taichi/ir/visitors.h" #include "taichi/program/program.h" diff --git a/taichi/transforms/demote_operations.cpp b/taichi/transforms/demote_operations.cpp index e63e4ad5f8f3b..cf0aac343ef5e 100644 --- a/taichi/transforms/demote_operations.cpp +++ b/taichi/transforms/demote_operations.cpp @@ -1,4 +1,5 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" #include "taichi/ir/transforms.h" #include "taichi/ir/visitors.h" #include "taichi/program/program.h" diff --git a/taichi/transforms/insert_scratch_pad.cpp b/taichi/transforms/insert_scratch_pad.cpp index 40803ffe647ea..99d8070756cea 100644 --- a/taichi/transforms/insert_scratch_pad.cpp +++ b/taichi/transforms/insert_scratch_pad.cpp @@ -122,7 +122,7 @@ class AccessAnalysis : public BasicStmtVisitor { namespace irpass { std::unique_ptr initialize_scratch_pad(OffloadedStmt *offload) { - TI_ASSERT(offload->task_type == OffloadedStmt::struct_for); + TI_ASSERT(offload->task_type == OffloadedTaskType::struct_for); std::unique_ptr pads; pads = std::make_unique(); if (!offload->scratch_opt.empty()) { diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index 51cce6db5847b..a7428f4f59ea4 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -486,7 +486,7 @@ class IRPrinter : public IRVisitor { void visit(OffloadedStmt *stmt) override { std::string details; - if (stmt->task_type == stmt->range_for) { + if (stmt->task_type == OffloadedTaskType::range_for) { std::string begin_str, end_str; if (stmt->const_begin) { begin_str = std::to_string(stmt->begin_value); @@ -501,17 +501,17 @@ class IRPrinter : public IRVisitor { details = fmt::format("range_for({}, {}) grid_dim={} block_dim={}", begin_str, end_str, stmt->grid_dim, stmt->block_dim); - } else if (stmt->task_type == stmt->struct_for) { + } else if (stmt->task_type == OffloadedTaskType::struct_for) { details = fmt::format("struct_for({}) grid_dim={} block_dim={} bls={}", stmt->snode->get_node_type_name_hinted(), stmt->grid_dim, stmt->block_dim, scratch_pad_info(stmt->scratch_opt)); } - if (stmt->task_type == OffloadedStmt::TaskType::listgen) { + if (stmt->task_type == OffloadedTaskType::listgen) { print("{} = offloaded listgen {}->{}", stmt->name(), stmt->snode->parent->get_node_type_name_hinted(), stmt->snode->get_node_type_name_hinted()); - } else if (stmt->task_type == OffloadedStmt::TaskType::gc) { + } else if (stmt->task_type == OffloadedTaskType::gc) { print("{} = offloaded garbage collect {}", stmt->name(), stmt->snode->get_node_type_name_hinted()); } else { diff --git a/taichi/transforms/lower_access.cpp b/taichi/transforms/lower_access.cpp index fb6ddcdd0cb2e..79285c5519b76 100644 --- a/taichi/transforms/lower_access.cpp +++ b/taichi/transforms/lower_access.cpp @@ -1,4 +1,5 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" #include "taichi/ir/transforms.h" #include "taichi/ir/analysis.h" #include "taichi/ir/visitors.h" diff --git a/taichi/transforms/make_block_local.cpp b/taichi/transforms/make_block_local.cpp index 759de772b027c..49c4446f9e9bd 100644 --- a/taichi/transforms/make_block_local.cpp +++ b/taichi/transforms/make_block_local.cpp @@ -1,4 +1,5 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" #include "taichi/ir/transforms.h" #include "taichi/ir/analysis.h" #include "taichi/ir/visitors.h" diff --git a/taichi/transforms/make_thread_local.cpp b/taichi/transforms/make_thread_local.cpp index 7796051d04527..a0db654c7d1ef 100644 --- a/taichi/transforms/make_thread_local.cpp +++ b/taichi/transforms/make_thread_local.cpp @@ -86,8 +86,8 @@ std::vector find_global_reduction_destinations( } void make_thread_local_offload(OffloadedStmt *offload) { - if (offload->task_type != OffloadedStmt::range_for && - offload->task_type != OffloadedStmt::struct_for) + if (offload->task_type != OffloadedTaskType::range_for && + offload->task_type != OffloadedTaskType::struct_for) return; std::vector valid_reduction_values; diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index dfd6326f757c8..b3f761bd49764 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -1,4 +1,5 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" #include "taichi/ir/transforms.h" #include "taichi/ir/analysis.h" #include "taichi/ir/visitors.h" diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index 61028ba4c7ddc..9ec76b1f99922 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -1,4 +1,5 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" #include "taichi/ir/transforms.h" #include "taichi/ir/analysis.h" #include "taichi/ir/visitors.h" diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index f1690d2aed59a..37b58de7d20e1 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -1,6 +1,7 @@ // Type checking #include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" #include "taichi/ir/transforms.h" #include "taichi/ir/analysis.h" #include "taichi/ir/visitors.h" diff --git a/taichi/transforms/variable_optimization.cpp b/taichi/transforms/variable_optimization.cpp index 40b4fe0c0e76d..4937ddf2e0841 100644 --- a/taichi/transforms/variable_optimization.cpp +++ b/taichi/transforms/variable_optimization.cpp @@ -318,7 +318,7 @@ class GlobalTempOptimize : public VariableOptimize { } void visit(OffloadedStmt *stmt) override { - if (stmt->task_type == stmt->range_for) { + if (stmt->task_type == OffloadedTaskType::range_for) { TI_ASSERT(!maybe_run); if (!stmt->const_begin) { TI_ASSERT(state_machines.find(stmt->begin_offset) != diff --git a/tests/cpp/test_alg_simp.cpp b/tests/cpp/test_alg_simp.cpp index 38e6e3fb41f16..1dfc9301d3017 100644 --- a/tests/cpp/test_alg_simp.cpp +++ b/tests/cpp/test_alg_simp.cpp @@ -1,4 +1,5 @@ #include "taichi/ir/frontend.h" +#include "taichi/ir/statements.h" #include "taichi/ir/transforms.h" #include "taichi/util/testing.h" diff --git a/tests/cpp/test_same_statements.cpp b/tests/cpp/test_same_statements.cpp index a0be1c68ebdbd..16a0700de6dee 100644 --- a/tests/cpp/test_same_statements.cpp +++ b/tests/cpp/test_same_statements.cpp @@ -1,6 +1,7 @@ #include "taichi/ir/frontend.h" #include "taichi/ir/transforms.h" #include "taichi/ir/analysis.h" +#include "taichi/ir/statements.h" #include "taichi/util/testing.h" TLANG_NAMESPACE_BEGIN diff --git a/tests/cpp/test_simplify.cpp b/tests/cpp/test_simplify.cpp index 4dad03a816e99..be84ddbaf091b 100644 --- a/tests/cpp/test_simplify.cpp +++ b/tests/cpp/test_simplify.cpp @@ -1,4 +1,5 @@ #include "taichi/ir/frontend.h" +#include "taichi/ir/statements.h" #include "taichi/ir/transforms.h" #include "taichi/util/testing.h"