diff --git a/taichi/codegen/cc/cc_program.cpp b/taichi/codegen/cc/cc_program.cpp index c6f6c8a5fbfe5..79f8fd01eeb8d 100644 --- a/taichi/codegen/cc/cc_program.cpp +++ b/taichi/codegen/cc/cc_program.cpp @@ -16,7 +16,7 @@ CCProgramImpl::CCProgramImpl(CompileConfig &config) : ProgramImpl(config) { } FunctionType CCProgramImpl::compile(Kernel *kernel) { - CCKernelGen codegen(kernel, this); + CCKernelGen codegen(*config, kernel, this); auto ker = codegen.compile(); auto ker_ptr = ker.get(); this->add_kernel(std::move(ker)); diff --git a/taichi/codegen/cc/codegen_cc.cpp b/taichi/codegen/cc/codegen_cc.cpp index b86b3034473db..c38c9f14e571b 100644 --- a/taichi/codegen/cc/codegen_cc.cpp +++ b/taichi/codegen/cc/codegen_cc.cpp @@ -18,6 +18,15 @@ namespace { std::string get_node_ptr_name(SNode *snode) { return fmt::format("struct Ti_{} *", snode->get_node_type_name_hinted()); } + +static void lower_ast(const CompileConfig &config, Kernel *kernel) { + auto ir = kernel->ir.get(); + irpass::compile_to_executable(ir, config, kernel, + /*autodiff_mode=*/kernel->autodiff_mode, + /*ad_use_stack=*/true, config.print_ir, + /*lower_global_access*/ true); +} + } // namespace class CCTransformer : public IRVisitor { @@ -38,22 +47,11 @@ class CCTransformer : public IRVisitor { } void run() { - this->lower_ast(); emit_header("void Tk_{}(struct Ti_Context *ti_ctx) {{", kernel_->name); kernel_->ir->accept(this); emit("}}"); } - void lower_ast() { - auto ir = kernel_->ir.get(); - auto config = kernel_->program->this_thread_config(); - config.demote_dense_struct_fors = true; - irpass::compile_to_executable(ir, config, kernel_, - /*autodiff_mode=*/kernel_->autodiff_mode, - /*ad_use_stack=*/true, config.print_ir, - /*lower_global_access*/ true); - } - std::string get_source() { return line_appender_header_.lines() + line_appender_.lines(); } @@ -593,6 +591,7 @@ class CCTransformer : public IRVisitor { }; // namespace cccp std::unique_ptr CCKernelGen::compile() { + lower_ast(compile_config_, kernel_); auto layout = cc_program_impl_->get_layout(); CCTransformer tran(kernel_, layout); diff --git a/taichi/codegen/cc/codegen_cc.h b/taichi/codegen/cc/codegen_cc.h index d537e95e20215..bb444c9337f1e 100644 --- a/taichi/codegen/cc/codegen_cc.h +++ b/taichi/codegen/cc/codegen_cc.h @@ -14,13 +14,18 @@ class CCKernel; class CCKernelGen { // Generate corresponding C Source Code for a Taichi Kernel public: - CCKernelGen(Kernel *kernel, CCProgramImpl *cc_program_impl) - : cc_program_impl_(cc_program_impl), kernel_(kernel) { + CCKernelGen(const CompileConfig &compile_config, + Kernel *kernel, + CCProgramImpl *cc_program_impl) + : compile_config_(compile_config), + cc_program_impl_(cc_program_impl), + kernel_(kernel) { } std::unique_ptr compile(); private: + const CompileConfig &compile_config_; CCProgramImpl *cc_program_impl_{nullptr}; Kernel *kernel_; }; diff --git a/taichi/program/compile_config.cpp b/taichi/program/compile_config.cpp index 19c3dca01db5e..bb6c11f59c6b9 100644 --- a/taichi/program/compile_config.cpp +++ b/taichi/program/compile_config.cpp @@ -1,6 +1,7 @@ #include "compile_config.h" #include +#include "taichi/util/offline_cache.h" namespace taichi::lang { @@ -66,4 +67,15 @@ CompileConfig::CompileConfig() { cc_link_cmd = "gcc -shared -fPIC -o '{}' '{}'"; } +void CompileConfig::fit() { + if (debug) { + // TODO: allow users to run in debug mode without out-of-bound checks + check_out_of_bound = true; + } + if (arch == Arch::cc) { + demote_dense_struct_fors = true; + } + offline_cache::disable_offline_cache_if_needed(this); +} + } // namespace taichi::lang diff --git a/taichi/program/compile_config.h b/taichi/program/compile_config.h index 947d7beab585c..3cff78ad5e5d6 100644 --- a/taichi/program/compile_config.h +++ b/taichi/program/compile_config.h @@ -106,6 +106,8 @@ struct CompileConfig { size_t cuda_stack_limit{8192}; CompileConfig(); + + void fit(); }; extern TI_DLL_EXPORT CompileConfig default_compile_config; diff --git a/taichi/program/program.cpp b/taichi/program/program.cpp index ca365ccc34d3d..31244a719f100 100644 --- a/taichi/program/program.cpp +++ b/taichi/program/program.cpp @@ -81,10 +81,7 @@ Program::Program(Arch desired_arch) : snode_rw_accessors_bank_(this) { configs[main_thread_id_] = default_compile_config; configs[main_thread_id_].arch = desired_arch; auto &config = this_thread_config(); - // TODO: allow users to run in debug mode without out-of-bound checks - if (config.debug) - config.check_out_of_bound = true; - offline_cache::disable_offline_cache_if_needed(&config); + config.fit(); profiler = make_profiler(config.arch, config.kernel_profiler); if (arch_uses_llvm(config.arch)) {