diff --git a/cpp_examples/autograd.cpp b/cpp_examples/autograd.cpp index 71b29fba2701f..b03a684fc8d45 100644 --- a/cpp_examples/autograd.cpp +++ b/cpp_examples/autograd.cpp @@ -195,10 +195,26 @@ void autograd() { ctx_ext.set_arg_external_array_with_shape(2, taichi::uint64(ext_c.data()), n, {n}); - (*kernel_init)(config, ctx_init); - (*kernel_forward)(config, ctx_forward); - (*kernel_backward)(config, ctx_backward); - (*kernel_ext)(config, ctx_ext); + { + const auto &compiled_kernel_data = + program.compile_kernel(config, program.get_device_caps(), *kernel_init); + program.launch_kernel(compiled_kernel_data, ctx_init); + } + { + const auto &compiled_kernel_data = program.compile_kernel( + config, program.get_device_caps(), *kernel_forward); + program.launch_kernel(compiled_kernel_data, ctx_forward); + } + { + const auto &compiled_kernel_data = program.compile_kernel( + config, program.get_device_caps(), *kernel_backward); + program.launch_kernel(compiled_kernel_data, ctx_backward); + } + { + const auto &compiled_kernel_data = + program.compile_kernel(config, program.get_device_caps(), *kernel_ext); + program.launch_kernel(compiled_kernel_data, ctx_ext); + } for (int i = 0; i < n; i++) std::cout << ext_a[i] << " "; std::cout << std::endl; diff --git a/cpp_examples/run_snode.cpp b/cpp_examples/run_snode.cpp index cefced65d096d..12a249ac335c5 100644 --- a/cpp_examples/run_snode.cpp +++ b/cpp_examples/run_snode.cpp @@ -134,11 +134,23 @@ void run_snode() { ctx_ext.set_arg_external_array_with_shape(0, taichi::uint64(ext_arr.data()), n, {n}); - (*kernel_init)(config, ctx_init); - (*kernel_ret)(config, ctx_ret); - std::cout << program.fetch_result(0) << std::endl; - (*kernel_ext)(config, ctx_ext); - for (int i = 0; i < n; i++) - std::cout << ext_arr[i] << " "; - std::cout << std::endl; + { + const auto &compiled_kernel_data = + program.compile_kernel(config, program.get_device_caps(), *kernel_init); + program.launch_kernel(compiled_kernel_data, ctx_init); + } + { + const auto &compiled_kernel_data = + program.compile_kernel(config, program.get_device_caps(), *kernel_ret); + program.launch_kernel(compiled_kernel_data, ctx_ret); + std::cout << program.fetch_result(0) << std::endl; + } + { + const auto &compiled_kernel_data = + program.compile_kernel(config, program.get_device_caps(), *kernel_ext); + program.launch_kernel(compiled_kernel_data, ctx_ext); + for (int i = 0; i < n; i++) + std::cout << ext_arr[i] << " "; + std::cout << std::endl; + } } diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index f602f9d3353d9..7ee8e00336381 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -811,7 +811,11 @@ def call_back(): ) try: - t_kernel(launch_ctx) + prog = impl.get_runtime().prog + # Compile kernel (& Online Cache & Offline Cache) + compiled_kernel_data = prog.compile_kernel(prog.config(), prog.get_device_caps(), t_kernel) + # Launch kernel + prog.launch_kernel(compiled_kernel_data, launch_ctx) except Exception as e: e = handle_exception_from_cpp(e) raise e from None diff --git a/taichi/aot/graph_data.cpp b/taichi/aot/graph_data.cpp index dc3f59fa16fd3..2423f864e3cc9 100644 --- a/taichi/aot/graph_data.cpp +++ b/taichi/aot/graph_data.cpp @@ -30,7 +30,10 @@ void CompiledGraph::jit_run( // Compile & Run (JIT): The compilation result will be cached, so don't // worry that the kernels dispatched by this cgraph will be compiled // repeatedly. - (*dispatch.ti_kernel)(compile_config, launch_ctx); + auto *prog = dispatch.ti_kernel->program; + const auto &compiled_kernel_data = prog->compile_kernel( + compile_config, prog->get_device_caps(), *dispatch.ti_kernel); + prog->launch_kernel(compiled_kernel_data, launch_ctx); } } diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index 51a36a812359c..7c1a39c246804 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -40,7 +40,6 @@ Kernel::Kernel(Program &program, this->ir = std::move(ir); this->program = &program; is_accessor = false; - compiled_ = nullptr; ir_is_ast_ = false; // CHI IR if (autodiff_mode == AutodiffMode::kNone) { @@ -56,25 +55,6 @@ Kernel::Kernel(Program &program, } } -void Kernel::compile(const CompileConfig &compile_config) { - compiled_ = program->compile(compile_config, *this); -} - -void Kernel::operator()(const CompileConfig &compile_config, - LaunchContextBuilder &ctx_builder) { - if (!compiled_) { - compile(compile_config); - } - - compiled_(ctx_builder); - - const auto arch = compile_config.arch; - if (compile_config.debug && - (arch_is_cpu(arch) || arch == Arch::cuda || arch == Arch::amdgpu)) { - program->check_runtime_error(); - } -} - LaunchContextBuilder Kernel::make_launch_context() { return LaunchContextBuilder(this); } @@ -166,7 +146,6 @@ void Kernel::init(Program &program, this->program = &program; is_accessor = false; - compiled_ = nullptr; context = std::make_unique(program.compile_config().arch); ir = context->get_root(); ir_is_ast_ = true; diff --git a/taichi/program/kernel.h b/taichi/program/kernel.h index b5876e633c3a9..b1277419aa97e 100644 --- a/taichi/program/kernel.h +++ b/taichi/program/kernel.h @@ -40,11 +40,6 @@ class TI_DLL_EXPORT Kernel : public Callable { return ir_is_ast_; } - void compile(const CompileConfig &compile_config); - - void operator()(const CompileConfig &compile_config, - LaunchContextBuilder &ctx_builder); - LaunchContextBuilder make_launch_context(); template @@ -75,8 +70,6 @@ class TI_DLL_EXPORT Kernel : public Callable { // True if |ir| is a frontend AST. False if it's already offloaded to CHI IR. bool ir_is_ast_{false}; - // The closure that, if invoked, launches the backend kernel (shader) - FunctionType compiled_{nullptr}; mutable std::string kernel_key_; }; diff --git a/taichi/program/program.cpp b/taichi/program/program.cpp index fa6afddb65b1e..f21579baf20b7 100644 --- a/taichi/program/program.cpp +++ b/taichi/program/program.cpp @@ -175,14 +175,24 @@ Function *Program::create_function(const FunctionKey &func_key) { return functions_.back().get(); } -FunctionType Program::compile(const CompileConfig &compile_config, - Kernel &kernel) { +const CompiledKernelData &Program::compile_kernel( + const CompileConfig &compile_config, + const DeviceCapabilityConfig &caps, + const Kernel &kernel_def) { auto start_t = Time::get_time(); TI_AUTO_PROF; - auto ret = program_impl_->compile(compile_config, &kernel); - TI_ASSERT(ret); + auto &mgr = program_impl_->get_kernel_compilation_manager(); + const auto &ckd = mgr.load_or_compile(compile_config, caps, kernel_def); total_compilation_time_ += Time::get_time() - start_t; - return ret; + return ckd; +} + +void Program::launch_kernel(const CompiledKernelData &compiled_kernel_data, + LaunchContextBuilder &ctx) { + program_impl_->get_kernel_launcher().launch_kernel(compiled_kernel_data, ctx); + if (compile_config().debug && arch_uses_llvm(compiled_kernel_data.arch())) { + program_impl_->check_runtime_error(result_buffer); + } } void Program::materialize_runtime() { @@ -248,10 +258,6 @@ SNode *Program::get_snode_root(int tree_id) { return snode_trees_[tree_id]->root(); } -void Program::check_runtime_error() { - program_impl_->check_runtime_error(result_buffer); -} - void Program::synchronize() { program_impl_->synchronize(); } diff --git a/taichi/program/program.h b/taichi/program/program.h index a609d389f10eb..f0a53c40cf94b 100644 --- a/taichi/program/program.h +++ b/taichi/program/program.h @@ -125,11 +125,16 @@ class TI_DLL_EXPORT Program { Function *create_function(const FunctionKey &func_key); - // TODO: This function is doing two things: 1) compiling CHI IR, and 2) - // offloading them to each backend. We should probably separate the logic? - FunctionType compile(const CompileConfig &compile_config, Kernel &kernel); + const CompiledKernelData &compile_kernel(const CompileConfig &compile_config, + const DeviceCapabilityConfig &caps, + const Kernel &kernel_def); - void check_runtime_error(); + void launch_kernel(const CompiledKernelData &compiled_kernel_data, + LaunchContextBuilder &ctx); + + DeviceCapabilityConfig get_device_caps() { + return program_impl_->get_device_caps(); + } Kernel &get_snode_reader(SNode *snode); diff --git a/taichi/program/program_impl.cpp b/taichi/program/program_impl.cpp index e76554eb85d77..597c9abe648f1 100644 --- a/taichi/program/program_impl.cpp +++ b/taichi/program/program_impl.cpp @@ -10,19 +10,6 @@ void ProgramImpl::compile_snode_tree_types(SNodeTree *tree) { TI_NOT_IMPLEMENTED; } -FunctionType ProgramImpl::compile(const CompileConfig &compile_config, - Kernel *kernel) { - // NOTE: Temporary implementation (blocked by cc backend) - // TODO(PGZXB): Final solution: compile -> load_or_compile + launch_kernel - auto &mgr = get_kernel_compilation_manager(); - const auto &compiled = - mgr.load_or_compile(compile_config, get_device_caps(), *kernel); - auto &launcher = get_kernel_launcher(); - return [&launcher, &compiled](LaunchContextBuilder &ctx_builder) { - launcher.launch_kernel(compiled, ctx_builder); - }; -} - void ProgramImpl::dump_cache_data_to_disk() { auto &mgr = get_kernel_compilation_manager(); mgr.clean_offline_cache(offline_cache::string_to_clean_cache_policy( diff --git a/taichi/program/program_impl.h b/taichi/program/program_impl.h index b9244736eb2ca..5a8b8b29b8cc9 100644 --- a/taichi/program/program_impl.h +++ b/taichi/program/program_impl.h @@ -34,12 +34,6 @@ class ProgramImpl { public: explicit ProgramImpl(CompileConfig &config); - /** - * Codegen to specific backend - */ - virtual FunctionType compile(const CompileConfig &compile_config, - Kernel *kernel); - /** * Allocate runtime buffer, e.g result_buffer or backend specific runtime * buffer, e.g. preallocated_device_buffer on CUDA. @@ -173,6 +167,10 @@ class ProgramImpl { KernelLauncher &get_kernel_launcher(); + virtual DeviceCapabilityConfig get_device_caps() { + return {}; + } + protected: virtual std::unique_ptr make_kernel_compiler() = 0; @@ -180,10 +178,6 @@ class ProgramImpl { TI_NOT_IMPLEMENTED; } - virtual DeviceCapabilityConfig get_device_caps() { - return {}; - } - private: std::unique_ptr kernel_com_mgr_; std::unique_ptr kernel_launcher_; diff --git a/taichi/program/snode_rw_accessors_bank.cpp b/taichi/program/snode_rw_accessors_bank.cpp index 2a98454803536..15c5fa3296f67 100644 --- a/taichi/program/snode_rw_accessors_bank.cpp +++ b/taichi/program/snode_rw_accessors_bank.cpp @@ -41,14 +41,18 @@ void SNodeRwAccessorsBank::Accessors::write_float(const std::vector &I, set_kernel_args(I, snode_->num_active_indices, &launch_ctx); launch_ctx.set_arg_float(snode_->num_active_indices, val); prog_->synchronize(); - (*writer_)(prog_->compile_config(), launch_ctx); + const auto &compiled_kernel_data = prog_->compile_kernel( + prog_->compile_config(), prog_->get_device_caps(), *writer_); + prog_->launch_kernel(compiled_kernel_data, launch_ctx); } float64 SNodeRwAccessorsBank::Accessors::read_float(const std::vector &I) { prog_->synchronize(); auto launch_ctx = reader_->make_launch_context(); set_kernel_args(I, snode_->num_active_indices, &launch_ctx); - (*reader_)(prog_->compile_config(), launch_ctx); + const auto &compiled_kernel_data = prog_->compile_kernel( + prog_->compile_config(), prog_->get_device_caps(), *reader_); + prog_->launch_kernel(compiled_kernel_data, launch_ctx); prog_->synchronize(); if (arch_uses_llvm(prog_->compile_config().arch)) { return launch_ctx.get_struct_ret_float({0}); @@ -64,7 +68,9 @@ void SNodeRwAccessorsBank::Accessors::write_int(const std::vector &I, set_kernel_args(I, snode_->num_active_indices, &launch_ctx); launch_ctx.set_arg_int(snode_->num_active_indices, val); prog_->synchronize(); - (*writer_)(prog_->compile_config(), launch_ctx); + const auto &compiled_kernel_data = prog_->compile_kernel( + prog_->compile_config(), prog_->get_device_caps(), *writer_); + prog_->launch_kernel(compiled_kernel_data, launch_ctx); } // for int32 and int64 @@ -74,14 +80,18 @@ void SNodeRwAccessorsBank::Accessors::write_uint(const std::vector &I, set_kernel_args(I, snode_->num_active_indices, &launch_ctx); launch_ctx.set_arg_uint(snode_->num_active_indices, val); prog_->synchronize(); - (*writer_)(prog_->compile_config(), launch_ctx); + const auto &compiled_kernel_data = prog_->compile_kernel( + prog_->compile_config(), prog_->get_device_caps(), *writer_); + prog_->launch_kernel(compiled_kernel_data, launch_ctx); } int64 SNodeRwAccessorsBank::Accessors::read_int(const std::vector &I) { prog_->synchronize(); auto launch_ctx = reader_->make_launch_context(); set_kernel_args(I, snode_->num_active_indices, &launch_ctx); - (*reader_)(prog_->compile_config(), launch_ctx); + const auto &compiled_kernel_data = prog_->compile_kernel( + prog_->compile_config(), prog_->get_device_caps(), *reader_); + prog_->launch_kernel(compiled_kernel_data, launch_ctx); prog_->synchronize(); if (arch_uses_llvm(prog_->compile_config().arch)) { return launch_ctx.get_struct_ret_int({0}); @@ -94,7 +104,9 @@ uint64 SNodeRwAccessorsBank::Accessors::read_uint(const std::vector &I) { prog_->synchronize(); auto launch_ctx = reader_->make_launch_context(); set_kernel_args(I, snode_->num_active_indices, &launch_ctx); - (*reader_)(prog_->compile_config(), launch_ctx); + const auto &compiled_kernel_data = prog_->compile_kernel( + prog_->compile_config(), prog_->get_device_caps(), *reader_); + prog_->launch_kernel(compiled_kernel_data, launch_ctx); prog_->synchronize(); if (arch_uses_llvm(prog_->compile_config().arch)) { return launch_ctx.get_struct_ret_uint({0}); diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index cd80905973b65..abd86b56e3912 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -327,6 +327,12 @@ void export_lang(py::module &m) { .def("insert_snode_access_flag", &ASTBuilder::insert_snode_access_flag) .def("reset_snode_access_flag", &ASTBuilder::reset_snode_access_flag); + py::class_( + m, "DeviceCapabilityConfig"); // NOLINT(bugprone-unused-raii) + + py::class_( + m, "CompiledKernelData"); // NOLINT(bugprone-unused-raii) + py::class_(m, "Program") .def(py::init<>()) .def("config", &Program::compile_config, @@ -443,7 +449,11 @@ void export_lang(py::module &m) { program->fill_ndarray_fast_u32(ndarray, val); }) .def("get_graphics_device", - [](Program *program) { return program->get_graphics_device(); }); + [](Program *program) { return program->get_graphics_device(); }) + .def("compile_kernel", &Program::compile_kernel, + py::return_value_policy::reference) + .def("launch_kernel", &Program::launch_kernel) + .def("get_device_caps", &Program::get_device_caps); py::class_(m, "AotModuleBuilder") .def("add_field", &AotModuleBuilder::add_field) @@ -695,11 +705,7 @@ void export_lang(py::module &m) { [](Kernel *self) -> ASTBuilder * { return &self->context->builder(); }, - py::return_value_policy::reference) - .def("__call__", [](Kernel *kernel, LaunchContextBuilder &launch_ctx) { - py::gil_scoped_release release; - kernel->operator()(kernel->program->compile_config(), launch_ctx); - }); + py::return_value_policy::reference); py::class_(m, "KernelLaunchContext") .def("set_arg_int", &LaunchContextBuilder::set_arg_int) diff --git a/taichi/runtime/program_impls/dx/dx_program.h b/taichi/runtime/program_impls/dx/dx_program.h index a5970b313111e..74e1c87d596a9 100644 --- a/taichi/runtime/program_impls/dx/dx_program.h +++ b/taichi/runtime/program_impls/dx/dx_program.h @@ -78,10 +78,11 @@ class Dx11ProgramImpl : public ProgramImpl { return "1" + std::string(has_buffer_ptr ? "b" : "-"); }; + DeviceCapabilityConfig get_device_caps() override; + protected: std::unique_ptr make_kernel_compiler() override; std::unique_ptr make_kernel_launcher() override; - DeviceCapabilityConfig get_device_caps() override; private: std::shared_ptr device_{nullptr}; diff --git a/taichi/runtime/program_impls/metal/metal_program.h b/taichi/runtime/program_impls/metal/metal_program.h index 12a98dbc796f2..11d8dfb49560d 100644 --- a/taichi/runtime/program_impls/metal/metal_program.h +++ b/taichi/runtime/program_impls/metal/metal_program.h @@ -105,10 +105,11 @@ class MetalProgramImpl : public ProgramImpl { return "1" + std::string(has_buffer_ptr ? "b" : "-"); }; + DeviceCapabilityConfig get_device_caps() override; + protected: std::unique_ptr make_kernel_compiler() override; std::unique_ptr make_kernel_launcher() override; - DeviceCapabilityConfig get_device_caps() override; private: std::unique_ptr embedded_device_{nullptr}; diff --git a/taichi/runtime/program_impls/opengl/opengl_program.h b/taichi/runtime/program_impls/opengl/opengl_program.h index dcc2e0e3e0a40..5e6ade0d6a295 100644 --- a/taichi/runtime/program_impls/opengl/opengl_program.h +++ b/taichi/runtime/program_impls/opengl/opengl_program.h @@ -85,10 +85,11 @@ class OpenglProgramImpl : public ProgramImpl { return "1" + std::string(has_buffer_ptr ? "b" : "-"); }; + DeviceCapabilityConfig get_device_caps() override; + protected: std::unique_ptr make_kernel_compiler() override; std::unique_ptr make_kernel_launcher() override; - DeviceCapabilityConfig get_device_caps() override; private: std::shared_ptr device_{nullptr}; diff --git a/taichi/runtime/program_impls/vulkan/vulkan_program.h b/taichi/runtime/program_impls/vulkan/vulkan_program.h index f5ef89da95de1..1862ae1da2559 100644 --- a/taichi/runtime/program_impls/vulkan/vulkan_program.h +++ b/taichi/runtime/program_impls/vulkan/vulkan_program.h @@ -113,10 +113,11 @@ class VulkanProgramImpl : public ProgramImpl { ~VulkanProgramImpl() override; + DeviceCapabilityConfig get_device_caps() override; + protected: std::unique_ptr make_kernel_compiler() override; std::unique_ptr make_kernel_launcher() override; - DeviceCapabilityConfig get_device_caps() override; private: std::unique_ptr embedded_device_{nullptr}; diff --git a/tests/cpp/backends/dx11_device_test.cpp b/tests/cpp/backends/dx11_device_test.cpp index 0ae993bcb7d17..9f35ab455fefd 100644 --- a/tests/cpp/backends/dx11_device_test.cpp +++ b/tests/cpp/backends/dx11_device_test.cpp @@ -152,7 +152,8 @@ TEST(Dx11ProgramTest, MaterializeRuntimeTest) { auto ker = std::make_unique(*test_prog.prog(), std::move(block)); ker->finalize_rets(); ker->finalize_params(); - program->compile(*program->config, ker.get()); + program->get_kernel_compilation_manager().load_or_compile( + *program->config, program->get_device_caps(), *ker.get()); } } // namespace directx11 diff --git a/tests/cpp/ir/ir_builder_test.cpp b/tests/cpp/ir/ir_builder_test.cpp index a4c299f52078d..18ed18edb827e 100644 --- a/tests/cpp/ir/ir_builder_test.cpp +++ b/tests/cpp/ir/ir_builder_test.cpp @@ -116,7 +116,10 @@ TEST(IRBuilder, ExternalPtr) { auto launch_ctx = ker->make_launch_context(); launch_ctx.set_arg_external_array_with_shape( /*arg_id=*/0, (uint64)array.get(), size, {size}); - (*ker)(test_prog.prog()->compile_config(), launch_ctx); + auto *prog = test_prog.prog(); + const auto &compiled_kernel_data = prog->compile_kernel( + prog->compile_config(), prog->get_device_caps(), *ker); + prog->launch_kernel(compiled_kernel_data, launch_ctx); EXPECT_EQ(array[0], 2); EXPECT_EQ(array[1], 1); EXPECT_EQ(array[2], 42); @@ -131,6 +134,8 @@ TEST(IRBuilder, Ndarray) { Arch arch = Arch::x64; #endif test_prog.setup(arch); + auto *prog = test_prog.prog(); + IRBuilder builder1; int size = 10; @@ -140,7 +145,9 @@ TEST(IRBuilder, Ndarray) { auto ker1 = setup_kernel1(test_prog.prog()); auto launch_ctx1 = ker1->make_launch_context(); launch_ctx1.set_arg_ndarray(/*arg_id=*/0, array); - (*ker1)(test_prog.prog()->compile_config(), launch_ctx1); + const auto &compiled_kernel_data = prog->compile_kernel( + prog->compile_config(), prog->get_device_caps(), *ker1); + prog->launch_kernel(compiled_kernel_data, launch_ctx1); EXPECT_EQ(array.read_int({0}), 2); EXPECT_EQ(array.read_int({1}), 1); EXPECT_EQ(array.read_int({2}), 42); @@ -149,7 +156,9 @@ TEST(IRBuilder, Ndarray) { auto launch_ctx2 = ker2->make_launch_context(); launch_ctx2.set_arg_ndarray(/*arg_id=*/0, array); launch_ctx2.set_arg_int(/*arg_id=*/1, 3); - (*ker2)(test_prog.prog()->compile_config(), launch_ctx2); + const auto &compiled_kernel_data2 = prog->compile_kernel( + prog->compile_config(), prog->get_device_caps(), *ker2); + prog->launch_kernel(compiled_kernel_data2, launch_ctx2); EXPECT_EQ(array.read_int({0}), 2); EXPECT_EQ(array.read_int({1}), 3); EXPECT_EQ(array.read_int({2}), 42); @@ -177,8 +186,10 @@ TEST(IRBuilder, AtomicOp) { auto launch_ctx = ker->make_launch_context(); launch_ctx.set_arg_external_array_with_shape( /*arg_id=*/0, (uint64)array.get(), size, {size}); - (*ker)(test_prog.prog()->compile_config(), launch_ctx); - + auto *prog = test_prog.prog(); + const auto &compiled_kernel_data = prog->compile_kernel( + prog->compile_config(), prog->get_device_caps(), *ker); + prog->launch_kernel(compiled_kernel_data, launch_ctx); EXPECT_EQ(array[0], 3); } } // namespace taichi::lang