diff --git a/taichi/backends/metal/codegen_metal.cpp b/taichi/backends/metal/codegen_metal.cpp
index d83fa1158e38c..0b748abfc221b 100644
--- a/taichi/backends/metal/codegen_metal.cpp
+++ b/taichi/backends/metal/codegen_metal.cpp
@@ -3,13 +3,14 @@
 #include
 #include
 
+#include "taichi/backends/metal/api.h"
 #include "taichi/backends/metal/constants.h"
+#include "taichi/backends/metal/env_config.h"
 #include "taichi/backends/metal/features.h"
 #include "taichi/ir/ir.h"
 #include "taichi/ir/transforms.h"
-#include "taichi/util/line_appender.h"
 #include "taichi/math/arithmetic.h"
-#include "taichi/backends/metal/api.h"
+#include "taichi/util/line_appender.h"
 
 TLANG_NAMESPACE_BEGIN
 namespace metal {
@@ -83,13 +84,17 @@ class KernelCodegen : public IRVisitor {
   };
 
  public:
+  struct Config {
+    bool allow_simdgroup = true;
+  };
+
   // TODO(k-ye): Create a Params to hold these ctor params.
   KernelCodegen(const std::string &taichi_kernel_name,
                 const std::string &root_snode_type_name,
                 Kernel *kernel,
                 const CompiledStructs *compiled_structs,
                 PrintStringTable *print_strtab,
-                const CodeGen::Config &config)
+                const Config &config,
+                OffloadedStmt *offloaded)
       : mtl_kernel_prefix_(taichi_kernel_name),
         root_snode_type_name_(root_snode_type_name),
         kernel_(kernel),
@@ -97,7 +102,8 @@ class KernelCodegen : public IRVisitor {
         needs_root_buffer_(compiled_structs_->root_size > 0),
         ctx_attribs_(*kernel_),
         print_strtab_(print_strtab),
-        cgen_config_(config) {
+        cgen_config_(config),
+        offloaded_(offloaded) {
     ti_kernel_attribus_.name = taichi_kernel_name;
     ti_kernel_attribus_.is_jit_evaluator = kernel->is_evaluator;
     // allow_undefined_visitor = true;
@@ -769,7 +775,8 @@ class KernelCodegen : public IRVisitor {
 
   void generate_kernels() {
     SectionGuard sg(this, Section::Kernels);
-    kernel_->ir->accept(this);
+    IRNode *ast = offloaded_ ? offloaded_ : kernel_->ir.get();
+    ast->accept(this);
 
     if (used_features()->sparse) {
       emit("");
@@ -1212,7 +1219,8 @@
   const bool needs_root_buffer_;
   const KernelContextAttributes ctx_attribs_;
   PrintStringTable *const print_strtab_;
-  const CodeGen::Config &cgen_config_;
+  const Config &cgen_config_;
+  OffloadedStmt *const offloaded_;
 
   bool is_top_level_{true};
   int mtl_kernel_count_{0};
@@ -1226,36 +1234,27 @@
 
 }  // namespace
 
-CodeGen::CodeGen(Kernel *kernel,
-                 KernelManager *kernel_mgr,
-                 const CompiledStructs *compiled_structs,
-                 const Config &config)
-    : kernel_(kernel),
-      kernel_mgr_(kernel_mgr),
-      compiled_structs_(compiled_structs),
-      id_(Program::get_kernel_id()),
-      taichi_kernel_name_(fmt::format("mtl_k{:04d}_{}", id_, kernel_->name)),
-      config_(config) {
-}
+FunctionType compile_to_metal_executable(
+    Kernel *kernel,
+    KernelManager *kernel_mgr,
+    const CompiledStructs *compiled_structs,
+    OffloadedStmt *offloaded) {
+  const auto id = Program::get_kernel_id();
+  const auto taichi_kernel_name(
+      fmt::format("mtl_k{:04d}_{}", id, kernel->name));
 
-FunctionType CodeGen::compile() {
-  auto &config = kernel_->program.config;
-  config.demote_dense_struct_fors = true;
-  irpass::compile_to_executable(kernel_->ir.get(), config,
-                                /*vectorize=*/false, kernel_->grad,
-                                /*ad_use_stack=*/true, config.print_ir,
-                                /*lower_global_access=*/true,
-                                /*make_thread_local=*/config.make_thread_local);
+  KernelCodegen::Config cgen_config;
+  cgen_config.allow_simdgroup = EnvConfig::instance().is_simdgroup_enabled();
 
   KernelCodegen codegen(
-      taichi_kernel_name_, kernel_->program.snode_root->node_type_name, kernel_,
-      compiled_structs_, kernel_mgr_->print_strtable(), config_);
+      taichi_kernel_name, kernel->program.snode_root->node_type_name, kernel,
+      compiled_structs, kernel_mgr->print_strtable(), cgen_config, offloaded);
+
   const auto source_code = codegen.run();
-  kernel_mgr_->register_taichi_kernel(taichi_kernel_name_, source_code,
-                                      codegen.ti_kernels_attribs(),
-                                      codegen.kernel_ctx_attribs());
-  return [kernel_mgr = kernel_mgr_,
-          kernel_name = taichi_kernel_name_](Context &ctx) {
+  kernel_mgr->register_taichi_kernel(taichi_kernel_name, source_code,
+                                     codegen.ti_kernels_attribs(),
+                                     codegen.kernel_ctx_attribs());
+  return [kernel_mgr, kernel_name = taichi_kernel_name](Context &ctx) {
    kernel_mgr->launch_taichi_kernel(kernel_name, &ctx);
  };
}
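
Note (illustration, not part of the patch): with the stateful CodeGen class dissolved, each call to compile_to_metal_executable() keeps its per-kernel state (kernel id, generated name) in locals and hands back a launch-only closure, which is what makes it safe to invoke from the async engine's parallel compilation workers. A self-contained toy of that "compile returns a launcher" contract, using stand-in types rather than Taichi's:

    #include <functional>
    #include <iostream>
    #include <string>

    struct Context {};  // stand-in for Taichi's Context
    using FunctionType = std::function<void(Context &)>;

    // Compilation/registration happens once, up front; the returned closure
    // only launches the already-registered kernel.
    FunctionType compile_sketch(const std::string &kernel_name) {
      // ... imagine codegen + kernel registration here ...
      return [kernel_name](Context &) {
        std::cout << "launch " << kernel_name << "\n";
      };
    }

    int main() {
      Context ctx;
      auto f = compile_sketch("mtl_k0000_demo");
      f(ctx);  // can be invoked many times without recompiling
    }
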
diff --git a/taichi/backends/metal/codegen_metal.h b/taichi/backends/metal/codegen_metal.h
index 468180306d973..0783894c4fdd3 100644
--- a/taichi/backends/metal/codegen_metal.h
+++ b/taichi/backends/metal/codegen_metal.h
@@ -16,30 +16,14 @@ TLANG_NAMESPACE_BEGIN
 
 namespace metal {
 
-class CodeGen {
- public:
-  struct Config {
-    bool allow_simdgroup = true;
-  };
-
-  CodeGen(Kernel *kernel,
-          KernelManager *kernel_mgr,
-          const CompiledStructs *compiled_structs,
-          const Config &config);
-
-  FunctionType compile();
-
- private:
-  void lower();
-  FunctionType gen(const SNode &root_snode, KernelManager *runtime);
-
-  Kernel *const kernel_;
-  KernelManager *const kernel_mgr_;
-  const CompiledStructs *const compiled_structs_;
-  const int id_;
-  const std::string taichi_kernel_name_;
-  const Config config_;
-};
+// If |offloaded| is nullptr, this compiles the AST in |kernel|. Otherwise it
+// compiles just |offloaded|. These ASTs must have already been lowered at the
+// CHI level.
+FunctionType compile_to_metal_executable(
+    Kernel *kernel,
+    KernelManager *kernel_mgr,
+    const CompiledStructs *compiled_structs,
+    OffloadedStmt *offloaded = nullptr);
 
 }  // namespace metal
 
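
Note (hedged sketch, not new API — both signatures come straight from the declaration above; the kernel must already be lowered at the CHI level). The two call modes of the new entry point, as wired elsewhere in this patch:

    // Whole-kernel mode: |offloaded| defaults to nullptr, so codegen walks
    // the entire kernel AST (what Program::compile() does in this patch).
    FunctionType compile_whole_kernel(Kernel *kernel,
                                      KernelManager *kernel_mgr,
                                      const CompiledStructs *structs) {
      return metal::compile_to_metal_executable(kernel, kernel_mgr, structs);
    }

    // Single-task mode: codegen emits Metal code for just this OffloadedStmt
    // (what the async engine's compilation workers do).
    FunctionType compile_one_task(Kernel *kernel,
                                  KernelManager *kernel_mgr,
                                  const CompiledStructs *structs,
                                  OffloadedStmt *task) {
      return metal::compile_to_metal_executable(kernel, kernel_mgr, structs,
                                                task);
    }
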
diff --git a/taichi/program/async_engine.cpp b/taichi/program/async_engine.cpp
index 859b49e40b100..571ee1c05ea89 100644
--- a/taichi/program/async_engine.cpp
+++ b/taichi/program/async_engine.cpp
@@ -128,7 +128,7 @@ void ExecutionQueue::enqueue(const TaskLaunchRecord &ker) {
     auto cloned_stmt = ker.ir_handle.clone();
     stmt = cloned_stmt->as<OffloadedStmt>();
 
-    compilation_workers.enqueue([async_func, stmt, kernel]() {
+    compilation_workers.enqueue([async_func, stmt, kernel, this]() {
       {
         // Final lowering
         using namespace irpass;
@@ -143,8 +143,7 @@ void ExecutionQueue::enqueue(const TaskLaunchRecord &ker) {
                      is_extension_supported(config.arch, Extension::bls) &&
                          config.make_block_local);
       }
-      auto codegen = KernelCodeGen::create(kernel->arch, kernel, stmt);
-      auto func = codegen->codegen();
+      auto func = this->compile_to_backend_(*kernel, stmt);
       async_func->set(func);
     });
     ir_bank_->insert_to_trash_bin(std::move(cloned_stmt));
@@ -161,14 +160,18 @@ void ExecutionQueue::synchronize() {
   launch_worker.flush();
 }
 
-ExecutionQueue::ExecutionQueue(IRBank *ir_bank)
+ExecutionQueue::ExecutionQueue(
+    IRBank *ir_bank,
+    const BackendExecCompilationFunc &compile_to_backend)
     : compilation_workers(4),  // TODO: remove 4
       launch_worker(1),
-      ir_bank_(ir_bank) {
+      ir_bank_(ir_bank),
+      compile_to_backend_(compile_to_backend) {
 }
 
-AsyncEngine::AsyncEngine(Program *program)
-    : queue(&ir_bank_),
+AsyncEngine::AsyncEngine(Program *program,
+                         const BackendExecCompilationFunc &compile_to_backend)
+    : queue(&ir_bank_, compile_to_backend),
       program(program),
       sfg(std::make_unique<StateFlowGraph>(&ir_bank_)) {
 }
diff --git a/taichi/program/async_engine.h b/taichi/program/async_engine.h
index 56d6c949541c3..63f336e575dca 100644
--- a/taichi/program/async_engine.h
+++ b/taichi/program/async_engine.h
@@ -1,6 +1,7 @@
 #include
 #include
 #include
+#include <functional>
 #include
 #include
 #include
@@ -70,6 +71,10 @@ class ParallelExecutor {
   std::condition_variable flush_cv_;
 };
 
+// Compiles the offloaded and optimized IR to the target backend's executable.
+using BackendExecCompilationFunc =
+    std::function<FunctionType(Kernel &, OffloadedStmt *)>;
+
 // In charge of (parallel) compilation to binary and (serial) kernel launching
 class ExecutionQueue {
  public:
@@ -78,7 +83,8 @@ class ExecutionQueue {
   ParallelExecutor compilation_workers;  // parallel compilation
   ParallelExecutor launch_worker;        // serial launching
 
-  explicit ExecutionQueue(IRBank *ir_bank);
+  explicit ExecutionQueue(IRBank *ir_bank,
+                          const BackendExecCompilationFunc &compile_to_backend);
 
   void enqueue(const TaskLaunchRecord &ker);
 
@@ -117,6 +123,7 @@ class ExecutionQueue {
   std::unordered_map compiled_funcs_;
 
   IRBank *ir_bank_;  // not owned
+  BackendExecCompilationFunc compile_to_backend_;
 };
 
 // An engine for asynchronous execution and optimization
@@ -130,7 +137,8 @@ class AsyncEngine {
   std::unique_ptr<StateFlowGraph> sfg;
   std::deque<TaskLaunchRecord> task_queue;
 
-  explicit AsyncEngine(Program *program);
+  explicit AsyncEngine(Program *program,
+                       const BackendExecCompilationFunc &compile_to_backend);
 
   bool fuse();  // return true when modified
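
Note (illustration, not part of the patch): the point of BackendExecCompilationFunc is that ExecutionQueue no longer names any backend — Program injects whatever compiler fits the arch. A self-contained toy of the injection pattern (stand-in types, not Taichi code):

    #include <functional>
    #include <iostream>

    struct Kernel { const char *name; };
    struct OffloadedStmt {};
    using FunctionType = std::function<void()>;
    using BackendExecCompilationFunc =
        std::function<FunctionType(Kernel &, OffloadedStmt *)>;

    // The queue is backend-agnostic: it only ever calls the injected functor.
    struct ExecutionQueueSketch {
      BackendExecCompilationFunc compile_to_backend_;
      void enqueue(Kernel &kernel, OffloadedStmt *task) {
        FunctionType func = compile_to_backend_(kernel, task);
        func();  // launch (the real queue caches and defers this)
      }
    };

    int main() {
      // "Program" side: bind the per-arch compiler, e.g. Metal or LLVM codegen.
      ExecutionQueueSketch queue{[](Kernel &k, OffloadedStmt *) -> FunctionType {
        return [&k] { std::cout << "launch " << k.name << "\n"; };
      }};
      Kernel k{"demo"};
      OffloadedStmt task;
      queue.enqueue(k, &task);
    }
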
diff --git a/taichi/program/extension.cpp b/taichi/program/extension.cpp
index 82662c60c30ff..0d3cbbb82ca37 100644
--- a/taichi/program/extension.cpp
+++ b/taichi/program/extension.cpp
@@ -17,7 +17,7 @@ bool is_extension_supported(Arch arch, Extension ext) {
       {Arch::cuda,
        {Extension::sparse, Extension::async_mode, Extension::data64,
         Extension::adstack, Extension::bls, Extension::assertion}},
-      {Arch::metal, {Extension::adstack}},
+      {Arch::metal, {Extension::adstack, Extension::async_mode}},
       {Arch::opengl, {Extension::extfunc}},
       {Arch::cc, {Extension::data64, Extension::extfunc, Extension::adstack}},
   };
diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp
index 6595f9082dfdc..7dde274576b1f 100644
--- a/taichi/program/kernel.cpp
+++ b/taichi/program/kernel.cpp
@@ -69,9 +69,8 @@ void Kernel::compile() {
 void Kernel::lower(bool to_executable) {  // TODO: is a "Lowerer" class
                                           // necessary for each backend?
   TI_ASSERT(!lowered);
-  if (arch_is_cpu(arch) || arch == Arch::cuda) {
+  if (arch_is_cpu(arch) || arch == Arch::cuda || arch == Arch::metal) {
     CurrentKernelGuard _(program, this);
-    auto codegen = KernelCodeGen::create(arch, this);
     auto config = program.config;
     bool verbose = config.print_ir;
     if ((is_accessor && !config.print_accessor_ir) ||
@@ -81,9 +80,9 @@ void Kernel::lower(bool to_executable) {  // TODO: is a "Lowerer" class
     if (to_executable) {
       irpass::compile_to_executable(
          ir.get(), config, /*vectorize*/ arch_is_cpu(arch), grad,
-          /*ad_use_stack*/ true, verbose, /*lower_global_access*/ to_executable,
-          /*make_thread_local*/ config.make_thread_local,
-          /*make_block_local*/
+          /*ad_use_stack=*/true, verbose, /*lower_global_access=*/to_executable,
+          /*make_thread_local=*/config.make_thread_local,
+          /*make_block_local=*/
          is_extension_supported(config.arch, Extension::bls) &&
              config.make_block_local);
    } else {
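
Note (hedged sketch; names come from this patch, simplified with no error handling): with Metal added to both Kernel::lower() and the extension table, Metal kernels now share the CPU/CUDA two-phase pipeline — CHI-level lowering first, backend codegen second:

    FunctionType compile_flow_sketch(Program *prog, Kernel &kernel) {
      kernel.lower();  // phase 1: CHI IR passes, shared by CPU, CUDA and Metal
      // phase 2: per-backend executable generation (dispatcher added below in
      // program.cpp); nullptr means "compile the whole kernel", whereas the
      // async engine passes individual OffloadedStmt tasks instead.
      return prog->compile_to_backend_executable(kernel, /*offloaded=*/nullptr);
    }
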
diff --git a/taichi/program/program.cpp b/taichi/program/program.cpp
index 2c32bd373681b..244543832ef49 100644
--- a/taichi/program/program.cpp
+++ b/taichi/program/program.cpp
@@ -11,7 +11,6 @@
 #include "taichi/backends/cuda/cuda_context.h"
 #endif
 #include "taichi/backends/metal/codegen_metal.h"
-#include "taichi/backends/metal/env_config.h"
 #include "taichi/backends/opengl/codegen_opengl.h"
 #include "taichi/backends/cpu/codegen_cpu.h"
 #include "taichi/struct/struct.h"
@@ -146,8 +145,11 @@ Program::Program(Arch desired_arch) {
 
   if (config.async_mode) {
     TI_WARN("Running in async mode. This is experimental.");
-    TI_ASSERT(arch_is_cpu(config.arch) || config.arch == Arch::cuda);
-    async_engine = std::make_unique<AsyncEngine>(this);
+    TI_ASSERT(is_extension_supported(config.arch, Extension::async_mode));
+    async_engine = std::make_unique<AsyncEngine>(
+        this, [this](Kernel &kernel, OffloadedStmt *offloaded) {
+          return this->compile_to_backend_executable(kernel, offloaded);
+        });
   }
 
   // TODO: allow users to run in debug mode without out-of-bound checks
@@ -196,17 +198,10 @@ FunctionType Program::compile(Kernel &kernel) {
   auto start_t = Time::get_time();
   TI_AUTO_PROF;
   FunctionType ret = nullptr;
-  if (arch_is_cpu(kernel.arch) || kernel.arch == Arch::cuda) {
+  if (arch_is_cpu(kernel.arch) || kernel.arch == Arch::cuda ||
+      kernel.arch == Arch::metal) {
     kernel.lower();
-    auto codegen = KernelCodeGen::create(kernel.arch, &kernel);
-    ret = codegen->compile();
-  } else if (kernel.arch == Arch::metal) {
-    metal::CodeGen::Config cgen_config;
-    cgen_config.allow_simdgroup =
-        metal::EnvConfig::instance().is_simdgroup_enabled();
-    metal::CodeGen codegen(&kernel, metal_kernel_mgr_.get(),
-                           &metal_compiled_structs_.value(), cgen_config);
-    ret = codegen.compile();
+    ret = compile_to_backend_executable(kernel, /*offloaded=*/nullptr);
   } else if (kernel.arch == Arch::opengl) {
     opengl::OpenglCodeGen codegen(kernel.name, &opengl_struct_compiled_.value(),
                                   opengl_kernel_launcher_.get());
@@ -223,6 +218,20 @@ FunctionType Program::compile(Kernel &kernel) {
   return ret;
 }
 
+FunctionType Program::compile_to_backend_executable(Kernel &kernel,
+                                                    OffloadedStmt *offloaded) {
+  if (arch_is_cpu(kernel.arch) || kernel.arch == Arch::cuda) {
+    auto codegen = KernelCodeGen::create(kernel.arch, &kernel, offloaded);
+    return codegen->compile();
+  } else if (kernel.arch == Arch::metal) {
+    return metal::compile_to_metal_executable(&kernel, metal_kernel_mgr_.get(),
+                                              &metal_compiled_structs_.value(),
+                                              offloaded);
+  }
+  TI_NOT_IMPLEMENTED;
+  return nullptr;
+}
+
 // For CPU and CUDA archs only
 void Program::initialize_runtime_system(StructCompiler *scomp) {
   // auto tlctx = llvm_context_host.get();
diff --git a/taichi/program/program.h b/taichi/program/program.h
index 7db40e8123480..59ea056985861 100644
--- a/taichi/program/program.h
+++ b/taichi/program/program.h
@@ -180,8 +180,14 @@ class Program {
   void end_function_definition() {
   }
 
+  // TODO: This function is doing two things: 1) compiling CHI IR, and 2)
+  // offloading it to each backend. We should probably separate the logic?
   FunctionType compile(Kernel &kernel);
 
+  // Just does the per-backend executable compilation without kernel lowering.
+  FunctionType compile_to_backend_executable(Kernel &kernel,
+                                             OffloadedStmt *stmt);
+
   void initialize_runtime_system(StructCompiler *scomp);
 
   void materialize_layout();
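
Note (hedged sketch): both producers now funnel into the same dispatcher — the synchronous path calls it with nullptr after lowering, and the async engine's injected callback calls it per task. The helper below is illustrative only; in the patch itself it is the lambda bound in Program's constructor:

    FunctionType compile_async_task(Program *prog, Kernel &kernel,
                                    OffloadedStmt *task) {
      // |task| was already final-lowered by ExecutionQueue's compilation
      // worker before this point.
      return prog->compile_to_backend_executable(kernel, task);
    }
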
diff --git a/tests/python/test_fuse_dense.py b/tests/python/test_fuse_dense.py
index f5c90338a3906..59fdad84c6771 100644
--- a/tests/python/test_fuse_dense.py
+++ b/tests/python/test_fuse_dense.py
@@ -3,17 +3,17 @@ template_fuse_reduction
 
 
-@ti.archs_with([ti.cpu], async_mode=True)
+@ti.test(require=ti.extension.async_mode, async_mode=True)
 def test_fuse_dense_x2y2z():
     template_fuse_dense_x2y2z(size=100 * 1024**2)
 
 
-@ti.archs_with([ti.cpu], async_mode=True)
+@ti.test(require=ti.extension.async_mode, async_mode=True)
 def test_fuse_reduction():
     template_fuse_reduction(size=10 * 1024**2)
 
 
-@ti.archs_with([ti.cpu], async_mode=True)
+@ti.test(require=ti.extension.async_mode, async_mode=True)
 def test_no_fuse_sigs_mismatch():
     n = 4096
     x = ti.field(ti.i32, shape=(n, ))
diff --git a/tests/python/test_fuse_dynamic.py b/tests/python/test_fuse_dynamic.py
index 7b78168ad79be..216cec4f6f8f6 100644
--- a/tests/python/test_fuse_dynamic.py
+++ b/tests/python/test_fuse_dynamic.py
@@ -53,6 +53,7 @@ def y_to_z():
         assert z[i] == x[i] + 5
 
 
-@ti.archs_with([ti.cpu], async_mode=True)
+@ti.test(require=[ti.extension.async_mode, ti.extension.sparse],
+         async_mode=True)
 def test_fuse_dynamic_x2y2z():
     benchmark_fuse_dynamic_x2y2z()
diff --git a/tests/python/test_sfg.py b/tests/python/test_sfg.py
index 1d7a3cc32f9a3..2b93ad881c8a3 100644
--- a/tests/python/test_sfg.py
+++ b/tests/python/test_sfg.py
@@ -3,7 +3,8 @@
 import pytest
 
 
-@ti.test(require=ti.extension.async_mode, async_mode=True)
+@ti.test(require=[ti.extension.async_mode, ti.extension.sparse],
+         async_mode=True)
 def test_remove_clear_list_from_fused_serial():
     x = ti.field(ti.i32)
     y = ti.field(ti.i32)