Skip to content

Commit

Permalink
[refactor] Remove dependencies on Program::this_thread_config() in sp…
Browse files Browse the repository at this point in the history
…irv::lower (taichi-dev#7134)

Issue: taichi-dev#7002
  • Loading branch information
PGZXB authored and quadpixels committed May 13, 2023
1 parent 49adf3e commit 7e62aa3
Show file tree
Hide file tree
Showing 8 changed files with 15 additions and 11 deletions.
2 changes: 1 addition & 1 deletion taichi/cache/gfx/cache_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ CacheManager::CacheManager(Params &&init_params)
CompiledKernelData CacheManager::load_or_compile(CompileConfig *config,
Kernel *kernel) {
if (kernel->is_evaluator) {
spirv::lower(kernel);
spirv::lower(*config, kernel);
return gfx::run_codegen(kernel, runtime_->get_ti_device()->arch(),
runtime_->get_ti_device()->get_caps(),
compiled_structs_, *config);
Expand Down
4 changes: 1 addition & 3 deletions taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2604,10 +2604,8 @@ void KernelCodegen::run(TaichiKernelAttributes &kernel_attribs,
kernel_attribs.is_jit_evaluator = params_.kernel->is_evaluator;
}

void lower(Kernel *kernel) {
void lower(const CompileConfig &config, Kernel *kernel) {
if (!kernel->lowered()) {
auto &config = kernel->program->this_thread_config();
config.demote_dense_struct_fors = true;
irpass::compile_to_executable(kernel->ir.get(), config, kernel,
kernel->autodiff_mode,
/*ad_use_stack=*/false, config.print_ir,
Expand Down
2 changes: 1 addition & 1 deletion taichi/codegen/spirv/spirv_codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Kernel;

namespace spirv {

void lower(Kernel *kernel);
void lower(const CompileConfig &config, Kernel *kernel);

class KernelCodegen {
public:
Expand Down
3 changes: 2 additions & 1 deletion taichi/program/compile_config.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "compile_config.h"

#include <thread>
#include "taichi/rhi/arch.h"
#include "taichi/util/offline_cache.h"

namespace taichi::lang {
Expand Down Expand Up @@ -72,7 +73,7 @@ void CompileConfig::fit() {
// TODO: allow users to run in debug mode without out-of-bound checks
check_out_of_bound = true;
}
if (arch == Arch::cc) {
if (arch == Arch::cc || arch_uses_spirv(arch)) {
demote_dense_struct_fors = true;
}
offline_cache::disable_offline_cache_if_needed(this);
Expand Down
4 changes: 4 additions & 0 deletions taichi/rhi/arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ bool arch_is_gpu(Arch arch) {
return !arch_is_cpu(arch);
}

bool arch_uses_spirv(Arch arch) {
return arch == Arch::opengl || arch == Arch::vulkan || arch == Arch::dx11;
}

Arch host_arch() {
#if defined(TI_ARCH_x64)
return Arch::x64;
Expand Down
2 changes: 2 additions & 0 deletions taichi/rhi/arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ bool arch_uses_llvm(Arch arch);

bool arch_is_gpu(Arch arch);

bool arch_uses_spirv(Arch arch);

Arch host_arch();

bool arch_use_host_memory(Arch arch);
Expand Down
7 changes: 3 additions & 4 deletions taichi/runtime/gfx/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ AotModuleBuilderImpl::try_get_kernel_register_params(

void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
Kernel *kernel) {
spirv::lower(kernel);
auto compiled = run_codegen(kernel, this->device_api_backend_, caps_,
spirv::lower(config_, kernel);
auto compiled = run_codegen(kernel, device_api_backend_, caps_,
compiled_structs_, config_);
compiled.kernel_attribs.name = identifier;
ti_aot_data_.kernels.push_back(compiled.kernel_attribs);
Expand Down Expand Up @@ -243,10 +243,9 @@ void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier,
void AotModuleBuilderImpl::add_per_backend_tmpl(const std::string &identifier,
const std::string &key,
Kernel *kernel) {
spirv::lower(kernel);
spirv::lower(config_, kernel);
auto compiled = run_codegen(kernel, device_api_backend_, caps_,
compiled_structs_, config_);

compiled.kernel_attribs.name = identifier + "|" + key;
ti_aot_data_.kernels.push_back(compiled.kernel_attribs);
ti_aot_data_.spirv_codes.push_back(compiled.task_spirv_source_codes);
Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/program_impls/dx/dx_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Dx11ProgramImpl::Dx11ProgramImpl(CompileConfig &config) : ProgramImpl(config) {
}

FunctionType Dx11ProgramImpl::compile(Kernel *kernel) {
spirv::lower(kernel);
spirv::lower(*config, kernel);
return directx11::compile_to_executable(kernel, runtime_.get(), *config,
snode_tree_mgr_.get());
}
Expand Down

0 comments on commit 7e62aa3

Please sign in to comment.