diff --git a/taichi/cache/gfx/cache_manager.cpp b/taichi/cache/gfx/cache_manager.cpp index bdb5fca6ab2b4..23dd4ec98beee 100644 --- a/taichi/cache/gfx/cache_manager.cpp +++ b/taichi/cache/gfx/cache_manager.cpp @@ -162,7 +162,7 @@ CacheManager::CacheManager(Params &&init_params) CompiledKernelData CacheManager::load_or_compile(CompileConfig *config, Kernel *kernel) { if (kernel->is_evaluator) { - spirv::lower(kernel); + spirv::lower(*config, kernel); return gfx::run_codegen(kernel, runtime_->get_ti_device()->arch(), runtime_->get_ti_device()->get_caps(), compiled_structs_, *config); diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp index 6c3d0fe120b73..656e428af6241 100644 --- a/taichi/codegen/spirv/spirv_codegen.cpp +++ b/taichi/codegen/spirv/spirv_codegen.cpp @@ -2604,10 +2604,8 @@ void KernelCodegen::run(TaichiKernelAttributes &kernel_attribs, kernel_attribs.is_jit_evaluator = params_.kernel->is_evaluator; } -void lower(Kernel *kernel) { +void lower(const CompileConfig &config, Kernel *kernel) { if (!kernel->lowered()) { - auto &config = kernel->program->this_thread_config(); - config.demote_dense_struct_fors = true; irpass::compile_to_executable(kernel->ir.get(), config, kernel, kernel->autodiff_mode, /*ad_use_stack=*/false, config.print_ir, diff --git a/taichi/codegen/spirv/spirv_codegen.h b/taichi/codegen/spirv/spirv_codegen.h index 74e94a7cd91b8..bc93d61d9420b 100644 --- a/taichi/codegen/spirv/spirv_codegen.h +++ b/taichi/codegen/spirv/spirv_codegen.h @@ -14,7 +14,7 @@ class Kernel; namespace spirv { -void lower(Kernel *kernel); +void lower(const CompileConfig &config, Kernel *kernel); class KernelCodegen { public: diff --git a/taichi/program/compile_config.cpp b/taichi/program/compile_config.cpp index bb6c11f59c6b9..0cd5f88c16351 100644 --- a/taichi/program/compile_config.cpp +++ b/taichi/program/compile_config.cpp @@ -1,6 +1,7 @@ #include "compile_config.h" #include +#include "taichi/rhi/arch.h" #include "taichi/util/offline_cache.h" namespace taichi::lang { @@ -72,7 +73,7 @@ void CompileConfig::fit() { // TODO: allow users to run in debug mode without out-of-bound checks check_out_of_bound = true; } - if (arch == Arch::cc) { + if (arch == Arch::cc || arch_uses_spirv(arch)) { demote_dense_struct_fors = true; } offline_cache::disable_offline_cache_if_needed(this); diff --git a/taichi/rhi/arch.cpp b/taichi/rhi/arch.cpp index 06361d86cac3a..a441b9345f08f 100644 --- a/taichi/rhi/arch.cpp +++ b/taichi/rhi/arch.cpp @@ -56,6 +56,10 @@ bool arch_is_gpu(Arch arch) { return !arch_is_cpu(arch); } +bool arch_uses_spirv(Arch arch) { + return arch == Arch::opengl || arch == Arch::vulkan || arch == Arch::dx11; +} + Arch host_arch() { #if defined(TI_ARCH_x64) return Arch::x64; diff --git a/taichi/rhi/arch.h b/taichi/rhi/arch.h index 47e74ef3acbb0..58071f8abb846 100644 --- a/taichi/rhi/arch.h +++ b/taichi/rhi/arch.h @@ -24,6 +24,8 @@ bool arch_uses_llvm(Arch arch); bool arch_is_gpu(Arch arch); +bool arch_uses_spirv(Arch arch); + Arch host_arch(); bool arch_use_host_memory(Arch arch); diff --git a/taichi/runtime/gfx/aot_module_builder_impl.cpp b/taichi/runtime/gfx/aot_module_builder_impl.cpp index d907274567ef5..e572db6e504ae 100644 --- a/taichi/runtime/gfx/aot_module_builder_impl.cpp +++ b/taichi/runtime/gfx/aot_module_builder_impl.cpp @@ -201,8 +201,8 @@ AotModuleBuilderImpl::try_get_kernel_register_params( void AotModuleBuilderImpl::add_per_backend(const std::string &identifier, Kernel *kernel) { - spirv::lower(kernel); - auto compiled = run_codegen(kernel, this->device_api_backend_, caps_, + spirv::lower(config_, kernel); + auto compiled = run_codegen(kernel, device_api_backend_, caps_, compiled_structs_, config_); compiled.kernel_attribs.name = identifier; ti_aot_data_.kernels.push_back(compiled.kernel_attribs); @@ -243,10 +243,9 @@ void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier, void AotModuleBuilderImpl::add_per_backend_tmpl(const std::string &identifier, const std::string &key, Kernel *kernel) { - spirv::lower(kernel); + spirv::lower(config_, kernel); auto compiled = run_codegen(kernel, device_api_backend_, caps_, compiled_structs_, config_); - compiled.kernel_attribs.name = identifier + "|" + key; ti_aot_data_.kernels.push_back(compiled.kernel_attribs); ti_aot_data_.spirv_codes.push_back(compiled.task_spirv_source_codes); diff --git a/taichi/runtime/program_impls/dx/dx_program.cpp b/taichi/runtime/program_impls/dx/dx_program.cpp index 1a63c5ef4ec63..fe45d897281ed 100644 --- a/taichi/runtime/program_impls/dx/dx_program.cpp +++ b/taichi/runtime/program_impls/dx/dx_program.cpp @@ -27,7 +27,7 @@ Dx11ProgramImpl::Dx11ProgramImpl(CompileConfig &config) : ProgramImpl(config) { } FunctionType Dx11ProgramImpl::compile(Kernel *kernel) { - spirv::lower(kernel); + spirv::lower(*config, kernel); return directx11::compile_to_executable(kernel, runtime_.get(), *config, snode_tree_mgr_.get()); }