From 423a29c7156e829e62739e047fa383bb4b36518d Mon Sep 17 00:00:00 2001
From: PGZXB
Date: Tue, 6 Dec 2022 18:52:10 +0800
Subject: [PATCH] Refactor2023: Remove dependencies on
 Program::this_thread_config() in codegen on llvm backends

---
 taichi/codegen/codegen.cpp                    |  3 +-
 taichi/codegen/codegen.h                      | 10 +++-
 taichi/codegen/codegen_utils.h                |  2 +-
 taichi/codegen/cpu/codegen_cpu.cpp            | 24 +++------
 taichi/codegen/cpu/codegen_cpu.h              |  4 +-
 taichi/codegen/cuda/codegen_cuda.cpp          | 30 +++++------
 taichi/codegen/cuda/codegen_cuda.h            |  3 +-
 taichi/codegen/dx12/codegen_dx12.cpp          | 27 +++++-----
 taichi/codegen/dx12/codegen_dx12.h            |  1 +
 .../dx12/dx12_global_optimize_module.cpp      |  2 +-
 taichi/codegen/dx12/dx12_llvm_passes.h        |  4 +-
 taichi/codegen/dx12/dx12_lower_intrinsic.cpp  |  6 +--
 taichi/codegen/llvm/codegen_llvm.cpp          | 49 +++++++++----------
 taichi/codegen/llvm/codegen_llvm.h            |  6 ++-
 taichi/codegen/llvm/codegen_llvm_quant.cpp    |  4 +-
 taichi/codegen/wasm/codegen_wasm.cpp          | 23 +++++----
 taichi/codegen/wasm/codegen_wasm.h            |  1 +
 17 files changed, 95 insertions(+), 104 deletions(-)

diff --git a/taichi/codegen/codegen.cpp b/taichi/codegen/codegen.cpp
index 7be5c69904f2c..a16cfce470388 100644
--- a/taichi/codegen/codegen.cpp
+++ b/taichi/codegen/codegen.cpp
@@ -118,7 +118,8 @@ LLVMCompiledKernel KernelCodeGen::compile_kernel_to_module() {
     auto offload =
         irpass::analysis::clone(offloads[i].get(), offloads[i]->get_kernel());
     irpass::re_id(offload.get());
-    auto new_data = this->compile_task(nullptr, offload->as<OffloadedStmt>());
+    auto new_data =
+        this->compile_task(&config, nullptr, offload->as<OffloadedStmt>());
     data[i] = std::make_unique<LLVMCompiledTask>(std::move(new_data));
   };
   if (kernel->is_evaluator) {
diff --git a/taichi/codegen/codegen.h b/taichi/codegen/codegen.h
index b4a1d740083fb..073abada79356 100644
--- a/taichi/codegen/codegen.h
+++ b/taichi/codegen/codegen.h
@@ -46,7 +46,9 @@ class KernelCodeGen {
 
   virtual ~KernelCodeGen() = default;
 
-  static std::unique_ptr<KernelCodeGen> create(const CompileConfig *compile_config, Kernel *kernel);
+  static std::unique_ptr<KernelCodeGen> create(
+      const CompileConfig *compile_config,
+      Kernel *kernel);
 
   virtual FunctionType compile_to_function() = 0;
   virtual bool supports_offline_cache() const {
@@ -57,6 +59,7 @@ class KernelCodeGen {
   virtual LLVMCompiledKernel compile_kernel_to_module();
 
   virtual LLVMCompiledTask compile_task(
+      const CompileConfig *config,
       std::unique_ptr<llvm::Module> &&module = nullptr,
       OffloadedStmt *stmt = nullptr){TI_NOT_IMPLEMENTED}
 
@@ -65,6 +68,11 @@ class KernelCodeGen {
   void cache_kernel(const std::string &kernel_key,
                     const LLVMCompiledKernel &data);
 #endif
+ protected:
+  const CompileConfig *get_compile_config() const {
+    return compile_config_;
+  }
+
  private:
   const CompileConfig *compile_config_{nullptr};
 };
diff --git a/taichi/codegen/codegen_utils.h b/taichi/codegen/codegen_utils.h
index 11dc80d5c41c2..f94be39a054e4 100644
--- a/taichi/codegen/codegen_utils.h
+++ b/taichi/codegen/codegen_utils.h
@@ -3,7 +3,7 @@
 
 namespace taichi::lang {
 
-inline bool codegen_vector_type(CompileConfig *config) {
+inline bool codegen_vector_type(const CompileConfig *config) {
   if (config->real_matrix && !config->real_matrix_scalarize) {
     return true;
   }
diff --git a/taichi/codegen/cpu/codegen_cpu.cpp b/taichi/codegen/cpu/codegen_cpu.cpp
index 8813ef7793e74..9812555f4e5d6 100644
--- a/taichi/codegen/cpu/codegen_cpu.cpp
+++ b/taichi/codegen/cpu/codegen_cpu.cpp
@@ -18,8 +18,8 @@ class TaskCodeGenCPU : public TaskCodeGenLLVM {
  public:
   using IRVisitor::visit;
 
-  TaskCodeGenCPU(Kernel *kernel, IRNode *ir)
-      : TaskCodeGenLLVM(kernel, ir, nullptr) {
+  TaskCodeGenCPU(const CompileConfig *config, Kernel *kernel, IRNode *ir)
+      : TaskCodeGenLLVM(config, kernel, ir, nullptr) {
     TI_AUTO_PROF
   }
 
@@ -56,7 +56,7 @@ class TaskCodeGenCPU : public TaskCodeGenLLVM {
     auto [begin, end] = get_range_for_bounds(stmt);
 
     // adaptive block_dim
-    if (prog->this_thread_config().cpu_block_dim_adaptive) {
+    if (compile_config->cpu_block_dim_adaptive) {
      int num_items = (stmt->end_value - stmt->begin_value) / std::abs(step);
       int num_threads = stmt->num_cpu_threads;
       int items_per_thread = std::max(1, num_items / (num_threads * 32));
@@ -166,8 +166,7 @@ class TaskCodeGenCPU : public TaskCodeGenLLVM {
     create_bls_buffer(stmt);
     using Type = OffloadedStmt::TaskType;
     auto offloaded_task_name = init_offloaded_task_function(stmt);
-    if (prog->this_thread_config().kernel_profiler &&
-        arch_is_cpu(prog->this_thread_config().arch)) {
+    if (compile_config->kernel_profiler && arch_is_cpu(compile_config->arch)) {
       call("LLVMRuntime_profiler_start", get_runtime(),
            builder->CreateGlobalStringPtr(offloaded_task_name));
     }
@@ -190,8 +189,7 @@ class TaskCodeGenCPU : public TaskCodeGenLLVM {
     } else {
       TI_NOT_IMPLEMENTED
     }
-    if (prog->this_thread_config().kernel_profiler &&
-        arch_is_cpu(prog->this_thread_config().arch)) {
+    if (compile_config->kernel_profiler && arch_is_cpu(compile_config->arch)) {
       llvm::IRBuilderBase::InsertPointGuard guard(*builder);
       builder->SetInsertPoint(final_block);
       call("LLVMRuntime_profiler_stop", get_runtime());
@@ -216,13 +214,6 @@ class TaskCodeGenCPU : public TaskCodeGenLLVM {
 }  // namespace
 
 #ifdef TI_WITH_LLVM
-// static
-std::unique_ptr<TaskCodeGenLLVM> KernelCodeGenCPU::make_codegen_llvm(
-    Kernel *kernel,
-    IRNode *ir) {
-  return std::make_unique<TaskCodeGenCPU>(kernel, ir);
-}
-
 FunctionType CPUModuleToFunctionConverter::convert(
     const std::string &kernel_name,
     const std::vector<LlvmLaunchArgInfo> &args,
@@ -264,9 +255,10 @@ FunctionType CPUModuleToFunctionConverter::convert(
 }
 
 LLVMCompiledTask KernelCodeGenCPU::compile_task(
+    const CompileConfig *config,
     std::unique_ptr<llvm::Module> &&module,
     OffloadedStmt *stmt) {
-  TaskCodeGenCPU gen(kernel, stmt);
+  TaskCodeGenCPU gen(config, kernel, stmt);
   return gen.run_compilation();
 }
 #endif  // TI_WITH_LLVM
@@ -274,7 +266,7 @@ LLVMCompiledTask KernelCodeGenCPU::compile_task(
 FunctionType KernelCodeGenCPU::compile_to_function() {
   TI_AUTO_PROF;
   auto *llvm_prog = get_llvm_program(prog);
-  const auto &config = prog->this_thread_config();
+  const auto &config = *get_compile_config();
   auto *tlctx = llvm_prog->get_llvm_context(config.arch);
 
   CPUModuleToFunctionConverter converter(
diff --git a/taichi/codegen/cpu/codegen_cpu.h b/taichi/codegen/cpu/codegen_cpu.h
index 060f80a88909b..02cb0c9e1efa0 100644
--- a/taichi/codegen/cpu/codegen_cpu.h
+++ b/taichi/codegen/cpu/codegen_cpu.h
@@ -17,13 +17,11 @@ class KernelCodeGenCPU : public KernelCodeGen {
 
   // TODO: Stop defining this macro guards in the headers
 #ifdef TI_WITH_LLVM
-  static std::unique_ptr<TaskCodeGenLLVM> make_codegen_llvm(Kernel *kernel,
-                                                            IRNode *ir);
-
   bool supports_offline_cache() const override {
     return true;
   }
   LLVMCompiledTask compile_task(
+      const CompileConfig *config,
       std::unique_ptr<llvm::Module> &&module = nullptr,
       OffloadedStmt *stmt = nullptr) override;
 
diff --git a/taichi/codegen/cuda/codegen_cuda.cpp b/taichi/codegen/cuda/codegen_cuda.cpp
index 80576c2e17ce5..0b60929087bb6 100644
--- a/taichi/codegen/cuda/codegen_cuda.cpp
+++ b/taichi/codegen/cuda/codegen_cuda.cpp
@@ -30,8 +30,10 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
  public:
   using IRVisitor::visit;
 
-  explicit TaskCodeGenCUDA(Kernel *kernel, IRNode *ir = nullptr)
-      : TaskCodeGenLLVM(kernel, ir) {
+  explicit TaskCodeGenCUDA(const CompileConfig *config,
+                           Kernel *kernel,
+                           IRNode *ir = nullptr)
+      : TaskCodeGenLLVM(config, kernel, ir) {
   }
 
   llvm::Value *create_print(std::string tag,
@@ -106,7 +108,7 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
     auto elem_type = dtype->get_element_type();
     for (int i = 0; i < dtype->get_num_elements(); ++i) {
       llvm::Value *elem_value;
-      if (codegen_vector_type(&prog->this_thread_config())) {
+      if (codegen_vector_type(compile_config)) {
         TI_ASSERT(llvm::dyn_cast<llvm::VectorType>(value_type));
         elem_value = builder->CreateExtractElement(value, i);
       } else {
@@ -364,7 +366,7 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
       init_offloaded_task_function(stmt, "gather_list");
       call("gc_parallel_0", get_context(), snode_id);
       finalize_offloaded_task_function();
-      current_task->grid_dim = prog->this_thread_config().saturating_grid_dim;
+      current_task->grid_dim = compile_config->saturating_grid_dim;
       current_task->block_dim = 64;
       offloaded_tasks.push_back(*current_task);
       current_task = nullptr;
@@ -382,7 +384,7 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
       init_offloaded_task_function(stmt, "zero_fill");
       call("gc_parallel_2", get_context(), snode_id);
       finalize_offloaded_task_function();
-      current_task->grid_dim = prog->this_thread_config().saturating_grid_dim;
+      current_task->grid_dim = compile_config->saturating_grid_dim;
       current_task->block_dim = 64;
       offloaded_tasks.push_back(*current_task);
       current_task = nullptr;
@@ -394,7 +396,7 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
       init_offloaded_task_function(stmt, "gather_list");
       call("gc_rc_parallel_0", get_context());
       finalize_offloaded_task_function();
-      current_task->grid_dim = prog->this_thread_config().saturating_grid_dim;
+      current_task->grid_dim = compile_config->saturating_grid_dim;
       current_task->block_dim = 64;
       offloaded_tasks.push_back(*current_task);
       current_task = nullptr;
@@ -412,7 +414,7 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
       init_offloaded_task_function(stmt, "zero_fill");
       call("gc_rc_parallel_2", get_context());
       finalize_offloaded_task_function();
-      current_task->grid_dim = prog->this_thread_config().saturating_grid_dim;
+      current_task->grid_dim = compile_config->saturating_grid_dim;
       current_task->block_dim = 64;
       offloaded_tasks.push_back(*current_task);
       current_task = nullptr;
@@ -584,26 +586,18 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
   }
 };
 
-#ifdef TI_WITH_LLVM
-// static
-std::unique_ptr<TaskCodeGenLLVM> KernelCodeGenCUDA::make_codegen_llvm(
-    Kernel *kernel,
-    IRNode *ir) {
-  return std::make_unique<TaskCodeGenCUDA>(kernel, ir);
-}
-#endif  // TI_WITH_LLVM
-
 LLVMCompiledTask KernelCodeGenCUDA::compile_task(
+    const CompileConfig *config,
     std::unique_ptr<llvm::Module> &&module,
     OffloadedStmt *stmt) {
-  TaskCodeGenCUDA gen(kernel, stmt);
+  TaskCodeGenCUDA gen(config, kernel, stmt);
   return gen.run_compilation();
 }
 
 FunctionType KernelCodeGenCUDA::compile_to_function() {
   TI_AUTO_PROF
   auto *llvm_prog = get_llvm_program(prog);
-  const auto &config = prog->this_thread_config();
+  const auto &config = *get_compile_config();
   auto *tlctx = llvm_prog->get_llvm_context(config.arch);
 
   CUDAModuleToFunctionConverter converter{tlctx,
diff --git a/taichi/codegen/cuda/codegen_cuda.h b/taichi/codegen/cuda/codegen_cuda.h
index c636156e7866c..d8e4192915b86 100644
--- a/taichi/codegen/cuda/codegen_cuda.h
+++ b/taichi/codegen/cuda/codegen_cuda.h
@@ -16,9 +16,8 @@ class KernelCodeGenCUDA : public KernelCodeGen {
 
   // TODO: Stop defining this macro guards in the headers
 #ifdef TI_WITH_LLVM
-  static std::unique_ptr<TaskCodeGenLLVM> make_codegen_llvm(Kernel *kernel,
-                                                            IRNode *ir);
   LLVMCompiledTask compile_task(
+      const CompileConfig *config,
       std::unique_ptr<llvm::Module> &&module = nullptr,
       OffloadedStmt *stmt = nullptr) override;
 #endif  // TI_WITH_LLVM
diff --git a/taichi/codegen/dx12/codegen_dx12.cpp b/taichi/codegen/dx12/codegen_dx12.cpp
index 10419243008ac..99a419ae22a36 100644
--- a/taichi/codegen/dx12/codegen_dx12.cpp
+++ b/taichi/codegen/dx12/codegen_dx12.cpp
@@ -21,8 +21,8 @@ class TaskCodeGenLLVMDX12 : public TaskCodeGenLLVM {
  public:
   using IRVisitor::visit;
 
-  TaskCodeGenLLVMDX12(Kernel *kernel, IRNode *ir)
-      : TaskCodeGenLLVM(kernel, ir, nullptr) {
+  TaskCodeGenLLVMDX12(const CompileConfig *config, Kernel *kernel, IRNode *ir)
+      : TaskCodeGenLLVM(config, kernel, ir, nullptr) {
     TI_AUTO_PROF
   }
 
@@ -149,8 +149,7 @@ class TaskCodeGenLLVMDX12 : public TaskCodeGenLLVM {
     create_bls_buffer(stmt);
     using Type = OffloadedStmt::TaskType;
     auto offloaded_task_name = init_offloaded_task_function(stmt);
-    if (prog->this_thread_config().kernel_profiler &&
-        arch_is_cpu(prog->this_thread_config().arch)) {
+    if (compile_config->kernel_profiler && arch_is_cpu(compile_config->arch)) {
       call(
           builder.get(), "LLVMRuntime_profiler_start",
           {get_runtime(), builder->CreateGlobalStringPtr(offloaded_task_name)});
@@ -172,8 +171,7 @@ class TaskCodeGenLLVMDX12 : public TaskCodeGenLLVM {
     } else {
       TI_NOT_IMPLEMENTED
     }
-    if (prog->this_thread_config().kernel_profiler &&
-        arch_is_cpu(prog->this_thread_config().arch)) {
+    if (compile_config->kernel_profiler && arch_is_cpu(compile_config->arch)) {
       llvm::IRBuilderBase::InsertPointGuard guard(*builder);
       builder->SetInsertPoint(final_block);
       call(builder.get(), "LLVMRuntime_profiler_stop", {get_runtime()});
@@ -201,6 +199,7 @@ class TaskCodeGenLLVMDX12 : public TaskCodeGenLLVM {
 
 static std::vector<uint8_t> generate_dxil_from_llvm(
     LLVMCompiledTask &compiled_data,
+    const CompileConfig *config,
     taichi::lang::Kernel *kernel) {
   // generate dxil from llvm ir.
   auto offloaded_local = compiled_data.tasks;
@@ -209,21 +208,18 @@ static std::vector<uint8_t> generate_dxil_from_llvm(
     llvm::Function *func = module->getFunction(task.name);
     TI_ASSERT(func);
     directx12::mark_function_as_cs_entry(func);
-    directx12::set_num_threads(
-        func, kernel->program->this_thread_config().default_gpu_block_dim, 1,
-        1);
+    directx12::set_num_threads(func, config->default_gpu_block_dim, 1, 1);
     // FIXME: save task.block_dim like
     // tlctx->mark_function_as_cuda_kernel(func, task.block_dim);
   }
 
-  auto dx_container = directx12::global_optimize_module(
-      module, kernel->program->this_thread_config());
+  auto dx_container = directx12::global_optimize_module(module, *config);
   // validate and sign dx container.
   return directx12::validate_and_sign(dx_container);
 }
 
 KernelCodeGenDX12::CompileResult KernelCodeGenDX12::compile() {
   TI_AUTO_PROF;
-  auto &config = prog->this_thread_config();
+  auto &config = *get_compile_config();
   std::string kernel_key = get_hashed_offline_cache_key(&config, kernel);
   kernel->set_kernel_key_for_cache(kernel_key);
@@ -240,10 +236,10 @@ KernelCodeGenDX12::CompileResult KernelCodeGenDX12::compile() {
         irpass::analysis::clone(offloads[i].get(), offloads[i]->get_kernel());
     irpass::re_id(offload.get());
     auto *offload_stmt = offload->as<OffloadedStmt>();
-    auto new_data = compile_task(nullptr, offload_stmt);
+    auto new_data = compile_task(&config, nullptr, offload_stmt);
 
     Result.task_dxil_source_codes.emplace_back(
-        generate_dxil_from_llvm(new_data, kernel));
+        generate_dxil_from_llvm(new_data, &config, kernel));
     aot::CompiledOffloadedTask task;
     // FIXME: build all fields for task.
     task.name = fmt::format("{}_{}_{}", kernel->get_name(),
@@ -257,9 +253,10 @@ KernelCodeGenDX12::CompileResult KernelCodeGenDX12::compile() {
 }
 
 LLVMCompiledTask KernelCodeGenDX12::compile_task(
+    const CompileConfig *config,
     std::unique_ptr<llvm::Module> &&module,
     OffloadedStmt *stmt) {
-  TaskCodeGenLLVMDX12 gen(kernel, stmt);
+  TaskCodeGenLLVMDX12 gen(config, kernel, stmt);
   return gen.run_compilation();
 }
 #endif  // TI_WITH_LLVM
diff --git a/taichi/codegen/dx12/codegen_dx12.h b/taichi/codegen/dx12/codegen_dx12.h
index 7aa07c0216606..356be6b4bc4c1 100644
--- a/taichi/codegen/dx12/codegen_dx12.h
+++ b/taichi/codegen/dx12/codegen_dx12.h
@@ -24,6 +24,7 @@ class KernelCodeGenDX12 : public KernelCodeGen {
   CompileResult compile();
 #ifdef TI_WITH_LLVM
   LLVMCompiledTask compile_task(
+      const CompileConfig *config,
       std::unique_ptr<llvm::Module> &&module = nullptr,
       OffloadedStmt *stmt = nullptr) override;
 #endif
diff --git a/taichi/codegen/dx12/dx12_global_optimize_module.cpp b/taichi/codegen/dx12/dx12_global_optimize_module.cpp
index 86edc34af7231..2e44c51a965a5 100644
--- a/taichi/codegen/dx12/dx12_global_optimize_module.cpp
+++ b/taichi/codegen/dx12/dx12_global_optimize_module.cpp
@@ -65,7 +65,7 @@ GlobalVariable *createGlobalVariableForResource(Module &M,
 }
 
 std::vector<uint8_t> global_optimize_module(llvm::Module *module,
-                                            CompileConfig &config) {
+                                            const CompileConfig &config) {
   TI_AUTO_PROF
   if (llvm::verifyModule(*module, &llvm::errs())) {
     module->print(llvm::errs(), nullptr);
diff --git a/taichi/codegen/dx12/dx12_llvm_passes.h b/taichi/codegen/dx12/dx12_llvm_passes.h
index 0da5b3e36a497..30d22ddb3bb3b 100644
--- a/taichi/codegen/dx12/dx12_llvm_passes.h
+++ b/taichi/codegen/dx12/dx12_llvm_passes.h
@@ -24,7 +24,7 @@ llvm::GlobalVariable *createGlobalVariableForResource(llvm::Module &M,
                                                       llvm::Type *Ty);
 
 std::vector<uint8_t> global_optimize_module(llvm::Module *module,
-                                            CompileConfig &config);
+                                            const CompileConfig &config);
 
 extern const char *NumWorkGroupsCBName;
 
@@ -46,6 +46,6 @@ ModulePass *createTaichiRuntimeContextLowerPass();
 void initializeTaichiIntrinsicLowerPass(PassRegistry &);
 
 /// Pass to lower taichi intrinsic into DXIL intrinsic.
-ModulePass *createTaichiIntrinsicLowerPass(taichi::lang::CompileConfig *config);
+ModulePass *createTaichiIntrinsicLowerPass(const taichi::lang::CompileConfig *config);
 
 } // namespace llvm
diff --git a/taichi/codegen/dx12/dx12_lower_intrinsic.cpp b/taichi/codegen/dx12/dx12_lower_intrinsic.cpp
index 2a694ca04af49..8467bde528bfb 100644
--- a/taichi/codegen/dx12/dx12_lower_intrinsic.cpp
+++ b/taichi/codegen/dx12/dx12_lower_intrinsic.cpp
@@ -96,14 +96,14 @@ class TaichiIntrinsicLower : public ModulePass {
     return true;
   }
 
-  TaichiIntrinsicLower(taichi::lang::CompileConfig *config = nullptr)
+  TaichiIntrinsicLower(const taichi::lang::CompileConfig *config = nullptr)
       : ModulePass(ID), config(config) {
     initializeTaichiIntrinsicLowerPass(*PassRegistry::getPassRegistry());
   }
   static char ID;  // Pass identification.
 
  private:
-  taichi::lang::CompileConfig *config;
+  const taichi::lang::CompileConfig *config;
 };
 
 char TaichiIntrinsicLower::ID = 0;
@@ -116,6 +116,6 @@ INITIALIZE_PASS(TaichiIntrinsicLower,
                 false)
 
 llvm::ModulePass *llvm::createTaichiIntrinsicLowerPass(
-    taichi::lang::CompileConfig *config) {
+    const taichi::lang::CompileConfig *config) {
   return new TaichiIntrinsicLower(config);
 }
diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp
index c94b168997336..f3a2388763103 100644
--- a/taichi/codegen/llvm/codegen_llvm.cpp
+++ b/taichi/codegen/llvm/codegen_llvm.cpp
@@ -301,19 +301,19 @@ void TaskCodeGenLLVM::emit_struct_meta_base(const std::string &name,
           snode->get_snode_tree_id()));
 }
 
-TaskCodeGenLLVM::TaskCodeGenLLVM(Kernel *kernel,
+TaskCodeGenLLVM::TaskCodeGenLLVM(const CompileConfig *compile_config,
+                                 Kernel *kernel,
                                  IRNode *ir,
                                  std::unique_ptr<llvm::Module> &&module)
     // TODO: simplify LLVMModuleBuilder ctor input
-    : LLVMModuleBuilder(
-          module == nullptr
-              ? get_llvm_program(kernel->program)
-                    ->get_llvm_context(
-                        kernel->program->this_thread_config().arch)
-                    ->new_module("kernel")
-              : std::move(module),
-          get_llvm_program(kernel->program)
-              ->get_llvm_context(kernel->program->this_thread_config().arch)),
+    : LLVMModuleBuilder(module == nullptr
+                            ? get_llvm_program(kernel->program)
+                                  ->get_llvm_context(compile_config->arch)
+                                  ->new_module("kernel")
+                            : std::move(module),
+                        get_llvm_program(kernel->program)
+                            ->get_llvm_context(compile_config->arch)),
+      compile_config(compile_config),
       kernel(kernel),
       ir(ir),
       prog(kernel->program) {
@@ -556,8 +556,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
       llvm_val[stmt] =
           builder->CreateFAdd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
 #if defined(__clang__) || defined(__GNUC__)
-    } else if (prog->this_thread_config().debug &&
-               is_integral(stmt->ret_type)) {
+    } else if (compile_config->debug && is_integral(stmt->ret_type)) {
       llvm_val[stmt] =
           call("debug_add_" + stmt->ret_type->to_string(), get_arg(0),
                llvm_val[stmt->lhs], llvm_val[stmt->rhs],
@@ -572,8 +571,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
       llvm_val[stmt] =
           builder->CreateFSub(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
 #if defined(__clang__) || defined(__GNUC__)
-    } else if (prog->this_thread_config().debug &&
-               is_integral(stmt->ret_type)) {
+    } else if (compile_config->debug && is_integral(stmt->ret_type)) {
       llvm_val[stmt] =
           call("debug_sub_" + stmt->ret_type->to_string(), get_arg(0),
                llvm_val[stmt->lhs], llvm_val[stmt->rhs],
@@ -588,8 +586,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
       llvm_val[stmt] =
           builder->CreateFMul(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
 #if defined(__clang__) || defined(__GNUC__)
-    } else if (prog->this_thread_config().debug &&
-               is_integral(stmt->ret_type)) {
+    } else if (compile_config->debug && is_integral(stmt->ret_type)) {
       llvm_val[stmt] =
           call("debug_mul_" + stmt->ret_type->to_string(), get_arg(0),
                llvm_val[stmt->lhs], llvm_val[stmt->rhs],
@@ -624,7 +621,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
           builder->CreateXor(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
     } else if (op == BinaryOpType::bit_shl) {
 #if defined(__clang__) || defined(__GNUC__)
-      if (prog->this_thread_config().debug && is_integral(stmt->ret_type)) {
+      if (compile_config->debug && is_integral(stmt->ret_type)) {
        llvm_val[stmt] =
             call("debug_shl_" + stmt->ret_type->to_string(), get_arg(0),
                  llvm_val[stmt->lhs], llvm_val[stmt->rhs],
@@ -887,9 +884,8 @@ void TaskCodeGenLLVM::visit(IfStmt *if_stmt) {
 llvm::Value *TaskCodeGenLLVM::create_print(std::string tag,
                                            DataType dt,
                                            llvm::Value *value) {
-  if (!arch_is_cpu(prog->this_thread_config().arch)) {
-    TI_WARN("print not supported on arch {}",
-            arch_name(prog->this_thread_config().arch));
+  if (!arch_is_cpu(compile_config->arch)) {
+    TI_WARN("print not supported on arch {}", arch_name(compile_config->arch));
     return nullptr;
   }
   std::vector<llvm::Value *> args;
@@ -966,7 +962,7 @@ void TaskCodeGenLLVM::visit(PrintStmt *stmt) {
       auto dtype = arg_stmt->ret_type->cast<TensorType>();
       auto elem_type = dtype->get_element_type();
       for (int i = 0; i < dtype->get_num_elements(); ++i) {
-        if (codegen_vector_type(&prog->this_thread_config())) {
+        if (codegen_vector_type(compile_config)) {
           TI_ASSERT(llvm::dyn_cast<llvm::VectorType>(value_type));
           auto elem = builder->CreateExtractElement(value, i);
           args.push_back(value_for_printf(elem, elem_type));
@@ -2003,7 +1999,7 @@ void TaskCodeGenLLVM::finalize_offloaded_task_function() {
   builder->SetInsertPoint(entry_block);
   builder->CreateBr(func_body_bb);
 
-  if (prog->this_thread_config().print_kernel_llvm_ir) {
+  if (compile_config->print_kernel_llvm_ir) {
     static FileSequenceWriter writer("taichi_kernel_generic_llvm_ir_{:04d}.ll",
                                      "unoptimized LLVM IR (generic)");
     writer.write(module.get());
@@ -2205,7 +2201,7 @@ void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt,
   auto exec_cond = tlctx->get_constant(true);
   auto coord_object = RuntimeObject(kLLVMPhysicalCoordinatesName, this,
                                     builder.get(), new_coordinates);
-  if (!prog->this_thread_config().packed) {
+  if (!compile_config->packed) {
     for (int i = 0; i < leaf_block->num_active_indices; i++) {
       auto j = leaf_block->physical_index_position[i];
       if (!bit::is_power_of_two(
@@ -2549,7 +2545,7 @@ void TaskCodeGenLLVM::visit(MatrixInitStmt *stmt) {
   llvm::Value *vec = llvm::UndefValue::get(type);
   for (int i = 0; i < stmt->values.size(); ++i) {
     auto *elem = llvm_val[stmt->values[i]];
-    if (codegen_vector_type(&prog->this_thread_config())) {
+    if (codegen_vector_type(compile_config)) {
       TI_ASSERT(llvm::dyn_cast<llvm::VectorType>(type));
       vec = builder->CreateInsertElement(vec, elem, i);
     } else {
@@ -2578,8 +2574,7 @@ FunctionCreationGuard TaskCodeGenLLVM::get_function_creation_guard(
 }
 
 void TaskCodeGenLLVM::initialize_context() {
-  tlctx =
-      get_llvm_program(prog)->get_llvm_context(prog->this_thread_config().arch);
+  tlctx = get_llvm_program(prog)->get_llvm_context(compile_config->arch);
   llvm_context = tlctx->get_this_thread_context();
   builder = std::make_unique<llvm::IRBuilder<>>(*llvm_context);
 }
@@ -2649,7 +2644,7 @@ void TaskCodeGenLLVM::emit_to_module() {
 
 LLVMCompiledTask TaskCodeGenLLVM::run_compilation() {
   // Final lowering
-  auto config = kernel->program->this_thread_config();
+  const auto &config = *compile_config;
 
   kernel->offload_to_executable(ir);
   emit_to_module();
diff --git a/taichi/codegen/llvm/codegen_llvm.h b/taichi/codegen/llvm/codegen_llvm.h
index 478364a275ce4..26deb3c1c969d 100644
--- a/taichi/codegen/llvm/codegen_llvm.h
+++ b/taichi/codegen/llvm/codegen_llvm.h
@@ -33,6 +33,7 @@ class FunctionCreationGuard {
 
 class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
  public:
+  const CompileConfig *compile_config{nullptr};
   Kernel *kernel;
   IRNode *ir;
   Program *prog;
@@ -69,12 +70,13 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
   using IRVisitor::visit;
   using LLVMModuleBuilder::call;
 
-  explicit TaskCodeGenLLVM(Kernel *kernel,
+  explicit TaskCodeGenLLVM(const CompileConfig *config,
+                           Kernel *kernel,
                            IRNode *ir = nullptr,
                            std::unique_ptr<llvm::Module> &&module = nullptr);
 
   Arch current_arch() {
-    return prog->this_thread_config().arch;
+    return compile_config->arch;
   }
 
   void initialize_context();
diff --git a/taichi/codegen/llvm/codegen_llvm_quant.cpp b/taichi/codegen/llvm/codegen_llvm_quant.cpp
index 909d1f67a7ca8..06f32c822b44d 100644
--- a/taichi/codegen/llvm/codegen_llvm_quant.cpp
+++ b/taichi/codegen/llvm/codegen_llvm_quant.cpp
@@ -100,7 +100,7 @@ void TaskCodeGenLLVM::store_masked(llvm::Value *ptr,
     return;
   }
   uint64 full_mask = (~(uint64)0) >> (64 - ty->getIntegerBitWidth());
-  if ((!atomic || prog->this_thread_config().quant_opt_atomic_demotion) &&
+  if ((!atomic || compile_config->quant_opt_atomic_demotion) &&
       ((mask & full_mask) == full_mask)) {
     builder->CreateStore(value, ptr);
     return;
@@ -154,7 +154,7 @@ void TaskCodeGenLLVM::visit(BitStructStoreStmt *stmt) {
     }
   }
   bool store_all_components = false;
-  if (prog->this_thread_config().quant_opt_atomic_demotion &&
+  if (compile_config->quant_opt_atomic_demotion &&
       stmt->ch_ids.size() == num_non_exponent_children) {
     stmt->is_atomic = false;
     store_all_components = true;
diff --git a/taichi/codegen/wasm/codegen_wasm.cpp b/taichi/codegen/wasm/codegen_wasm.cpp
index e0171bcfaba56..07d356e07a9b1 100644
--- a/taichi/codegen/wasm/codegen_wasm.cpp
+++ b/taichi/codegen/wasm/codegen_wasm.cpp
@@ -23,10 +23,11 @@ class TaskCodeGenWASM : public TaskCodeGenLLVM {
  public:
   using IRVisitor::visit;
 
-  TaskCodeGenWASM(Kernel *kernel,
+  TaskCodeGenWASM(const CompileConfig *config,
+                  Kernel *kernel,
                   IRNode *ir,
                   std::unique_ptr<llvm::Module> &&M = nullptr)
-      : TaskCodeGenLLVM(kernel, ir, std::move(M)) {
+      : TaskCodeGenLLVM(config, kernel, ir, std::move(M)) {
     TI_AUTO_PROF
   }
 
@@ -201,7 +202,7 @@ class TaskCodeGenWASM : public TaskCodeGenLLVM {
     builder->SetInsertPoint(entry_block);
     builder->CreateBr(func_body_bb);
 
-    if (prog->this_thread_config().print_kernel_llvm_ir) {
+    if (compile_config->print_kernel_llvm_ir) {
       static FileSequenceWriter writer(
           "taichi_kernel_generic_llvm_ir_{:04d}.ll",
           "unoptimized LLVM IR (generic)");
@@ -212,7 +213,7 @@ class TaskCodeGenWASM : public TaskCodeGenLLVM {
 
   LLVMCompiledTask run_compilation() override {
     // lower kernel
-    irpass::ast_to_ir(kernel->program->this_thread_config(), *kernel);
+    irpass::ast_to_ir(*compile_config, *kernel);
 
     // emit_to_module
     auto offloaded_task_name = init_taichi_kernel_function();
@@ -239,7 +240,7 @@ FunctionType KernelCodeGenWASM::compile_to_function() {
   TI_AUTO_PROF
   auto linked = compile_kernel_to_module();
   auto *tlctx =
-      get_llvm_program(prog)->get_llvm_context(prog->this_thread_config().arch);
+      get_llvm_program(prog)->get_llvm_context(get_compile_config()->arch);
   tlctx->create_jit_module(std::move(linked.module));
   auto kernel_symbol = tlctx->lookup_function_pointer(linked.tasks[0].name);
   return [=](RuntimeContext &context) {
@@ -250,12 +251,14 @@ FunctionType KernelCodeGenWASM::compile_to_function() {
 }
 
 LLVMCompiledTask KernelCodeGenWASM::compile_task(
+    const CompileConfig *config,
     std::unique_ptr<llvm::Module> &&module,
     OffloadedStmt *stmt) {
   kernel->offload_to_executable(ir);
   bool init_flag = module == nullptr;
   std::vector<OffloadedTask> name_list;
-  auto gen = std::make_unique<TaskCodeGenWASM>(kernel, ir, std::move(module));
+  auto gen =
+      std::make_unique<TaskCodeGenWASM>(config, kernel, ir, std::move(module));
 
   name_list.emplace_back(nullptr);
   name_list[0].name = gen->init_taichi_kernel_function();
@@ -276,11 +279,11 @@ LLVMCompiledTask KernelCodeGenWASM::compile_task(
 }
 
 LLVMCompiledKernel KernelCodeGenWASM::compile_kernel_to_module() {
-  auto *tlctx =
-      get_llvm_program(prog)->get_llvm_context(prog->this_thread_config().arch);
-  irpass::ast_to_ir(kernel->program->this_thread_config(), *kernel, false);
+  const auto &config = *get_compile_config();
+  auto *tlctx = get_llvm_program(prog)->get_llvm_context(config.arch);
+  irpass::ast_to_ir(config, *kernel, false);
 
-  auto res = compile_task();
+  auto res = compile_task(&config);
   std::vector<std::unique_ptr<LLVMCompiledTask>> data;
   data.push_back(std::make_unique<LLVMCompiledTask>(std::move(res)));
   return tlctx->link_compiled_tasks(std::move(data));
diff --git a/taichi/codegen/wasm/codegen_wasm.h b/taichi/codegen/wasm/codegen_wasm.h
index bb7846a85ed94..3cd550b4017e0 100644
--- a/taichi/codegen/wasm/codegen_wasm.h
+++ b/taichi/codegen/wasm/codegen_wasm.h
@@ -21,6 +21,7 @@ class KernelCodeGenWASM : public KernelCodeGen {
 
 #ifdef TI_WITH_LLVM
   LLVMCompiledTask compile_task(
+      const CompileConfig *config,
       std::unique_ptr<llvm::Module> &&module = nullptr,
       OffloadedStmt *stmt = nullptr) override;  // AOT Module Gen
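
Usage sketch (illustrative, not part of the diff): after this refactor the CompileConfig travels explicitly from the caller into codegen instead of being re-fetched from Program::this_thread_config() at each use site. Assuming the caller already holds a config, a kernel, and a cloned offload statement (placeholder variables below), the resulting call pattern looks roughly like:

    // The caller owns the config and passes it down; KernelCodeGen stores the
    // pointer and exposes it to subclasses via get_compile_config().
    auto codegen = KernelCodeGen::create(config, kernel);

    // Whole-kernel path: reads the stored config internally.
    FunctionType fn = codegen->compile_to_function();

    // Per-offload path: the config is now an explicit parameter, so each
    // TaskCodeGen* constructor receives it instead of touching
    // kernel->program->this_thread_config().
    LLVMCompiledTask task =
        codegen->compile_task(config, /*module=*/nullptr, offload_stmt);

This keeps task compilation independent of the Program's thread-local state, which is the prerequisite for caching and compiling kernels off the main thread.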