diff --git a/taichi/aot/graph_data.cpp b/taichi/aot/graph_data.cpp index acc074c74f8a0..1ebefde10cbde 100644 --- a/taichi/aot/graph_data.cpp +++ b/taichi/aot/graph_data.cpp @@ -1,6 +1,7 @@ #include "taichi/aot/graph_data.h" #include "taichi/program/ndarray.h" #include "taichi/program/texture.h" +#include "taichi/program/kernel.h" #include @@ -12,7 +13,7 @@ void CompiledGraph::run( for (const auto &dispatch : dispatches) { RuntimeContext ctx = ctx_; - TI_ASSERT(dispatch.compiled_kernel); + TI_ASSERT(dispatch.ti_kernel || dispatch.compiled_kernel); // Populate args metadata into RuntimeContext const auto &symbolic_args_ = dispatch.symbolic_args; @@ -77,7 +78,16 @@ void CompiledGraph::run( TI_ERROR("Error in compiled graph: unknown tag {}", ival.tag); } } - dispatch.compiled_kernel->launch(&ctx); + + if (dispatch.compiled_kernel) { + // Run cgraph loaded from AOT module + dispatch.compiled_kernel->launch(&ctx); + } else { + // JIT & Run + TI_ASSERT(dispatch.ti_kernel); + lang::Kernel::LaunchContextBuilder launch_ctx(dispatch.ti_kernel, &ctx); + dispatch.ti_kernel->operator()(launch_ctx); + } } } } // namespace aot diff --git a/taichi/aot/graph_data.h b/taichi/aot/graph_data.h index 0ee1d18cd7028..ed100e27862bf 100644 --- a/taichi/aot/graph_data.h +++ b/taichi/aot/graph_data.h @@ -15,6 +15,7 @@ namespace taichi::lang { class AotModuleBuilder; class Ndarray; class Texture; +class Kernel; namespace aot { // Currently only scalar, matrix and ndarray are supported. @@ -149,6 +150,7 @@ struct CompiledDispatch { std::string kernel_name; std::vector symbolic_args; Kernel *compiled_kernel{nullptr}; + taichi::lang::Kernel *ti_kernel{nullptr}; TI_IO_DEF(kernel_name, symbolic_args); }; diff --git a/taichi/aot/module_builder.cpp b/taichi/aot/module_builder.cpp index df430a3571851..194dd1762dca1 100644 --- a/taichi/aot/module_builder.cpp +++ b/taichi/aot/module_builder.cpp @@ -56,8 +56,12 @@ void AotModuleBuilder::add_graph(const std::string &name, TI_ERROR("Graph {} already exists", name); } // Handle adding kernels separately. + std::unordered_map kernels; for (const auto &dispatch : graph.dispatches) { - add_compiled_kernel(dispatch.kernel_name, dispatch.compiled_kernel); + kernels[dispatch.kernel_name] = dispatch.ti_kernel; + } + for (auto &e : kernels) { + add(e.first, e.second); } graphs_[name] = graph; } diff --git a/taichi/aot/module_builder.h b/taichi/aot/module_builder.h index 53a47ac506d25..0edadfaf7aeb1 100644 --- a/taichi/aot/module_builder.h +++ b/taichi/aot/module_builder.h @@ -64,11 +64,6 @@ class AotModuleBuilder { TI_NOT_IMPLEMENTED; } - virtual void add_compiled_kernel(const std::string &identifier, - aot::Kernel *kernel) { - TI_NOT_IMPLEMENTED; - } - virtual void add_per_backend_tmpl(const std::string &identifier, const std::string &key, Kernel *kernel) { diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp index 376284a97b0a3..62ea16f185959 100644 --- a/taichi/codegen/spirv/spirv_codegen.cpp +++ b/taichi/codegen/spirv/spirv_codegen.cpp @@ -2566,13 +2566,16 @@ void KernelCodegen::run(TaichiKernelAttributes &kernel_attribs, } void lower(Kernel *kernel) { - auto &config = kernel->program->this_thread_config(); - config.demote_dense_struct_fors = true; - irpass::compile_to_executable(kernel->ir.get(), config, kernel, - kernel->autodiff_mode, - /*ad_use_stack=*/false, config.print_ir, - /*lower_global_access=*/true, - /*make_thread_local=*/false); + if (!kernel->lowered()) { + auto &config = kernel->program->this_thread_config(); + config.demote_dense_struct_fors = true; + irpass::compile_to_executable(kernel->ir.get(), config, kernel, + kernel->autodiff_mode, + /*ad_use_stack=*/false, config.print_ir, + /*lower_global_access=*/true, + /*make_thread_local=*/false); + kernel->set_lowered(true); + } } } // namespace spirv diff --git a/taichi/program/graph_builder.cpp b/taichi/program/graph_builder.cpp index 7cde479a831d9..8f9fe9db29e6e 100644 --- a/taichi/program/graph_builder.cpp +++ b/taichi/program/graph_builder.cpp @@ -5,11 +5,11 @@ namespace taichi::lang { void Dispatch::compile( std::vector &compiled_dispatches) { - if (kernel_->compiled_aot_kernel() == nullptr) { - kernel_->compile_to_aot_kernel(); - } - aot::CompiledDispatch dispatch{kernel_->get_name(), symbolic_args_, - kernel_->compiled_aot_kernel()}; + aot::CompiledDispatch dispatch; + dispatch.kernel_name = kernel_->get_name(); + dispatch.symbolic_args = symbolic_args_; + dispatch.ti_kernel = kernel_; + dispatch.compiled_kernel = nullptr; compiled_dispatches.push_back(std::move(dispatch)); } diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index 3c6ee66c769c0..ade148d9a5ddd 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -67,10 +67,6 @@ void Kernel::compile() { compiled_ = program->compile(*this); } -void Kernel::compile_to_aot_kernel() { - compiled_aot_kernel_ = program->make_aot_kernel(*this); -} - void Kernel::lower(bool to_executable) { TI_ASSERT(!lowered_); TI_ASSERT(supports_lowering(arch)); diff --git a/taichi/program/kernel.h b/taichi/program/kernel.h index 89d7b5ea9e76e..726283dfda991 100644 --- a/taichi/program/kernel.h +++ b/taichi/program/kernel.h @@ -86,14 +86,12 @@ class TI_DLL_EXPORT Kernel : public Callable { return lowered_; } - void compile(); - - void compile_to_aot_kernel(); - - aot::Kernel *compiled_aot_kernel() { - return compiled_aot_kernel_.get(); + void set_lowered(bool lowered) { + lowered_ = lowered; } + void compile(); + /** * Lowers |ir| to CHI IR level * @@ -159,10 +157,6 @@ class TI_DLL_EXPORT Kernel : public Callable { bool ir_is_ast_{false}; // The closure that, if invoked, launches the backend kernel (shader) FunctionType compiled_{nullptr}; - // TODO[#5114]: It's kinda redundant to keep both compiled_ (used for JIT - // execution) as well as compiled_aot_kernel_. In fact we'd better unify - // everything around compiled_aot_kernel and rename it. - std::unique_ptr compiled_aot_kernel_{nullptr}; // A flag to record whether |ir| has been fully lowered. // lower initial AST all the way down to a bunch of // OffloadedStmt for async execution TODO(Lin): Check this comment diff --git a/taichi/program/program.h b/taichi/program/program.h index 48f6cd4b35125..6eab3efd21230 100644 --- a/taichi/program/program.h +++ b/taichi/program/program.h @@ -201,10 +201,6 @@ class TI_DLL_EXPORT Program { // TODO(Lin): remove the offloaded parameter FunctionType compile(Kernel &kernel, OffloadedStmt *offloaded = nullptr); - std::unique_ptr make_aot_kernel(Kernel &kernel) { - return program_impl_->make_aot_kernel(kernel); - } - void check_runtime_error(); Kernel &get_snode_reader(SNode *snode); diff --git a/taichi/program/program_impl.h b/taichi/program/program_impl.h index 90d8e40d0fbf4..24bbadd10ed23 100644 --- a/taichi/program/program_impl.h +++ b/taichi/program/program_impl.h @@ -78,13 +78,6 @@ class ProgramImpl { virtual std::unique_ptr make_aot_module_builder( const DeviceCapabilityConfig &caps) = 0; - /** - * Compile a taichi::lang::Kernel to taichi::lang::aot::Kernel. - */ - virtual std::unique_ptr make_aot_kernel(Kernel &kernel) { - TI_NOT_IMPLEMENTED; - } - /** * Dump Offline-cache data to disk */ diff --git a/taichi/runtime/dx12/aot_module_builder_impl.cpp b/taichi/runtime/dx12/aot_module_builder_impl.cpp index 1a07f4519f850..f4c8eb7a925fa 100644 --- a/taichi/runtime/dx12/aot_module_builder_impl.cpp +++ b/taichi/runtime/dx12/aot_module_builder_impl.cpp @@ -28,12 +28,6 @@ void AotModuleBuilderImpl::add_per_backend(const std::string &identifier, compiled_kernel.tasks = compiled_data.tasks; } -void AotModuleBuilderImpl::add_compiled_kernel(const std::string &identifier, - aot::Kernel *kernel) { - // FIXME: implement add_compiled_kernel. - TI_NOT_IMPLEMENTED; -} - void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier, const SNode *rep_snode, bool is_scalar, diff --git a/taichi/runtime/dx12/aot_module_builder_impl.h b/taichi/runtime/dx12/aot_module_builder_impl.h index 763a15eea5611..e063e0b7a366a 100644 --- a/taichi/runtime/dx12/aot_module_builder_impl.h +++ b/taichi/runtime/dx12/aot_module_builder_impl.h @@ -34,9 +34,6 @@ class AotModuleBuilderImpl : public AotModuleBuilder { const std::string &key, Kernel *kernel) override; - void add_compiled_kernel(const std::string &identifier, - aot::Kernel *kernel) override; - LlvmProgramImpl *prog; ModuleDataDX12 module_data; }; diff --git a/taichi/runtime/gfx/aot_module_builder_impl.cpp b/taichi/runtime/gfx/aot_module_builder_impl.cpp index ae36b6756df70..26499456a7527 100644 --- a/taichi/runtime/gfx/aot_module_builder_impl.cpp +++ b/taichi/runtime/gfx/aot_module_builder_impl.cpp @@ -205,16 +205,6 @@ void AotModuleBuilderImpl::add_per_backend(const std::string &identifier, ti_aot_data_.spirv_codes.push_back(compiled.task_spirv_source_codes); } -void AotModuleBuilderImpl::add_compiled_kernel(const std::string &identifier, - aot::Kernel *kernel) { - GfxRuntime::RegisterParams register_params = - static_cast(kernel)->params(); - register_params.kernel_attribs.name = identifier; - ti_aot_data_.kernels.push_back(std::move(register_params.kernel_attribs)); - ti_aot_data_.spirv_codes.push_back( - std::move(register_params.task_spirv_source_codes)); -} - void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier, const SNode *rep_snode, bool is_scalar, diff --git a/taichi/runtime/gfx/aot_module_builder_impl.h b/taichi/runtime/gfx/aot_module_builder_impl.h index 49d82cab55392..487844a7cedaf 100644 --- a/taichi/runtime/gfx/aot_module_builder_impl.h +++ b/taichi/runtime/gfx/aot_module_builder_impl.h @@ -42,9 +42,6 @@ class AotModuleBuilderImpl : public AotModuleBuilder { const std::string &key, Kernel *kernel) override; - void add_compiled_kernel(const std::string &identifier, - aot::Kernel *kernel) override; - std::string write_spv_file(const std::string &output_dir, const TaskAttributes &k, const std::vector &source_code) const; diff --git a/taichi/runtime/llvm/llvm_aot_module_builder.cpp b/taichi/runtime/llvm/llvm_aot_module_builder.cpp index 6a694646ee481..4eb14ffbf5775 100644 --- a/taichi/runtime/llvm/llvm_aot_module_builder.cpp +++ b/taichi/runtime/llvm/llvm_aot_module_builder.cpp @@ -60,18 +60,4 @@ void LlvmAotModuleBuilder::add_field_per_backend(const std::string &identifier, cache_.fields[snode_tree_id] = std::move(field_cache); } -void LlvmAotModuleBuilder::add_compiled_kernel(const std::string &identifier, - aot::Kernel *kernel) { - auto *kernel_impl = dynamic_cast(kernel); - TI_ASSERT(kernel_impl); - if (!kernel_impl->kernel_data_.created_at) { - kernel_impl->kernel_data_.last_used_at = std::time(nullptr); - kernel_impl->kernel_data_.created_at = std::time(nullptr); - } - const std::string &kernel_name = identifier; - if (cache_.kernels.find(kernel_name) == cache_.kernels.end()) { - cache_.kernels[kernel_name] = std::move(kernel_impl->kernel_data_); - } -} - } // namespace taichi::lang diff --git a/taichi/runtime/llvm/llvm_aot_module_builder.h b/taichi/runtime/llvm/llvm_aot_module_builder.h index f00dddc67b9f0..4b731944b8b18 100644 --- a/taichi/runtime/llvm/llvm_aot_module_builder.h +++ b/taichi/runtime/llvm/llvm_aot_module_builder.h @@ -26,9 +26,6 @@ class LlvmAotModuleBuilder : public AotModuleBuilder { int row_num, int column_num) override; - void add_compiled_kernel(const std::string &identifier, - aot::Kernel *kernel) override; - const LlvmOfflineCache &get_cache() { return cache_; } diff --git a/taichi/runtime/program_impls/dx/dx_program.cpp b/taichi/runtime/program_impls/dx/dx_program.cpp index 09ac4daf01cab..9b379bd2ba604 100644 --- a/taichi/runtime/program_impls/dx/dx_program.cpp +++ b/taichi/runtime/program_impls/dx/dx_program.cpp @@ -82,15 +82,6 @@ DeviceAllocation Dx11ProgramImpl::allocate_memory_ndarray( /*export_sharing=*/false}); } -std::unique_ptr Dx11ProgramImpl::make_aot_kernel(Kernel &kernel) { - spirv::lower(&kernel); - std::vector compiled_structs; - gfx::GfxRuntime::RegisterParams kparams = gfx::run_codegen( - &kernel, Arch::dx11, get_compute_device()->get_current_caps(), - compiled_structs); - return std::make_unique(runtime_.get(), std::move(kparams)); -} - } // namespace taichi::lang #endif diff --git a/taichi/runtime/program_impls/dx/dx_program.h b/taichi/runtime/program_impls/dx/dx_program.h index e7b4143764400..0c815e16ed66c 100644 --- a/taichi/runtime/program_impls/dx/dx_program.h +++ b/taichi/runtime/program_impls/dx/dx_program.h @@ -62,8 +62,6 @@ class Dx11ProgramImpl : public ProgramImpl { return snode_tree_mgr_->get_snode_tree_device_ptr(tree_id); } - std::unique_ptr make_aot_kernel(Kernel &kernel) override; - private: std::shared_ptr device_{nullptr}; std::unique_ptr runtime_{nullptr}; diff --git a/taichi/runtime/program_impls/llvm/llvm_program.cpp b/taichi/runtime/program_impls/llvm/llvm_program.cpp index ab2ddc497e6f5..342615297860c 100644 --- a/taichi/runtime/program_impls/llvm/llvm_program.cpp +++ b/taichi/runtime/program_impls/llvm/llvm_program.cpp @@ -109,20 +109,6 @@ std::unique_ptr LlvmProgramImpl::make_aot_module_builder( return nullptr; } -std::unique_ptr LlvmProgramImpl::make_aot_kernel(Kernel &kernel) { - auto compiled_fn = - this->compile(&kernel, nullptr); // Offloaded used in async mode only - - const std::string &kernel_key = kernel.get_cached_kernel_key(); - TI_ASSERT(cache_data_->kernels.count(kernel_key)); - const LlvmOfflineCache::KernelCacheData &kernel_data = - cache_data_->kernels[kernel_key]; - LlvmOfflineCache::KernelCacheData compiled_kernel = kernel_data.clone(); - compiled_kernel.kernel_key = kernel.get_name(); - return std::make_unique(compiled_fn, - std::move(compiled_kernel)); -} - void LlvmProgramImpl::cache_kernel(const std::string &kernel_key, const LLVMCompiledKernel &data, std::vector &&args) { diff --git a/taichi/runtime/program_impls/llvm/llvm_program.h b/taichi/runtime/program_impls/llvm/llvm_program.h index 5e16dc3ea57f0..d5ab1595e9fce 100644 --- a/taichi/runtime/program_impls/llvm/llvm_program.h +++ b/taichi/runtime/program_impls/llvm/llvm_program.h @@ -68,8 +68,6 @@ class LlvmProgramImpl : public ProgramImpl { std::unique_ptr compile_snode_tree_types_impl( SNodeTree *tree); - std::unique_ptr make_aot_kernel(Kernel &kernel) override; - std::unique_ptr make_aot_module_builder( const DeviceCapabilityConfig &caps) override; diff --git a/taichi/runtime/program_impls/opengl/opengl_program.cpp b/taichi/runtime/program_impls/opengl/opengl_program.cpp index 67f9774de8af1..b1e9853fcdb4d 100644 --- a/taichi/runtime/program_impls/opengl/opengl_program.cpp +++ b/taichi/runtime/program_impls/opengl/opengl_program.cpp @@ -83,12 +83,6 @@ DeviceAllocation OpenglProgramImpl::allocate_texture( return runtime_->create_image(params); } -std::unique_ptr OpenglProgramImpl::make_aot_kernel( - Kernel &kernel) { - auto params = get_cache_manager()->load_or_compile(config, &kernel); - return std::make_unique(runtime_.get(), std::move(params)); -} - void OpenglProgramImpl::dump_cache_data_to_disk() { const auto &mgr = get_cache_manager(); mgr->clean_offline_cache(offline_cache::string_to_clean_cache_policy( diff --git a/taichi/runtime/program_impls/opengl/opengl_program.h b/taichi/runtime/program_impls/opengl/opengl_program.h index 73033b559d721..176e9f2840e40 100644 --- a/taichi/runtime/program_impls/opengl/opengl_program.h +++ b/taichi/runtime/program_impls/opengl/opengl_program.h @@ -62,8 +62,6 @@ class OpenglProgramImpl : public ProgramImpl { return snode_tree_mgr_->get_snode_tree_device_ptr(tree_id); } - std::unique_ptr make_aot_kernel(Kernel &kernel) override; - void dump_cache_data_to_disk() override; const std::unique_ptr &get_cache_manager(); diff --git a/taichi/runtime/program_impls/vulkan/vulkan_program.cpp b/taichi/runtime/program_impls/vulkan/vulkan_program.cpp index 566feb3ca481e..c5e1b4a60bd23 100644 --- a/taichi/runtime/program_impls/vulkan/vulkan_program.cpp +++ b/taichi/runtime/program_impls/vulkan/vulkan_program.cpp @@ -205,13 +205,6 @@ DeviceAllocation VulkanProgramImpl::allocate_texture( return vulkan_runtime_->create_image(params); } -std::unique_ptr VulkanProgramImpl::make_aot_kernel( - Kernel &kernel) { - auto params = get_cache_manager()->load_or_compile(config, &kernel); - return std::make_unique(vulkan_runtime_.get(), - std::move(params)); -} - void VulkanProgramImpl::enqueue_compute_op_lambda( std::function op, const std::vector &image_refs) { diff --git a/taichi/runtime/program_impls/vulkan/vulkan_program.h b/taichi/runtime/program_impls/vulkan/vulkan_program.h index c2a33fc25ce1a..c7f877a3a65d8 100644 --- a/taichi/runtime/program_impls/vulkan/vulkan_program.h +++ b/taichi/runtime/program_impls/vulkan/vulkan_program.h @@ -91,8 +91,6 @@ class VulkanProgramImpl : public ProgramImpl { return snode_tree_mgr_->get_snode_tree_device_ptr(tree_id); } - std::unique_ptr make_aot_kernel(Kernel &kernel) override; - void enqueue_compute_op_lambda( std::function op, const std::vector &image_refs) override; diff --git a/tests/python/test_aot.py b/tests/python/test_aot.py index fd1052f3e1e74..19f38651f918f 100644 --- a/tests/python/test_aot.py +++ b/tests/python/test_aot.py @@ -593,7 +593,7 @@ def test(a: ti.types.ndarray(), c: ti.u8): g.run({'a': a, 'c': c}) - m = ti.aot.Module(ti.lang.impl.current_cfg().arch) + m = ti.aot.Module(ti.lang.impl.current_cfg().arch, caps=['spirv_has_int8']) m.add_graph('g_init', g) with tempfile.TemporaryDirectory() as tmpdir: m.save(tmpdir, '')