Skip to content

Commit

Permalink
Refactor2023: Remove dependencies on Program::this_thread_config() in…
Browse files Browse the repository at this point in the history
… codegen on llvm backends
  • Loading branch information
PGZXB committed Dec 6, 2022
1 parent 187f206 commit 423a29c
Show file tree
Hide file tree
Showing 17 changed files with 95 additions and 104 deletions.
3 changes: 2 additions & 1 deletion taichi/codegen/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ LLVMCompiledKernel KernelCodeGen::compile_kernel_to_module() {
auto offload =
irpass::analysis::clone(offloads[i].get(), offloads[i]->get_kernel());
irpass::re_id(offload.get());
auto new_data = this->compile_task(nullptr, offload->as<OffloadedStmt>());
auto new_data =
this->compile_task(&config, nullptr, offload->as<OffloadedStmt>());
data[i] = std::make_unique<LLVMCompiledTask>(std::move(new_data));
};
if (kernel->is_evaluator) {
Expand Down
10 changes: 9 additions & 1 deletion taichi/codegen/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ class KernelCodeGen {

virtual ~KernelCodeGen() = default;

static std::unique_ptr<KernelCodeGen> create(const CompileConfig *compile_config, Kernel *kernel);
static std::unique_ptr<KernelCodeGen> create(
const CompileConfig *compile_config,
Kernel *kernel);

virtual FunctionType compile_to_function() = 0;
virtual bool supports_offline_cache() const {
Expand All @@ -57,6 +59,7 @@ class KernelCodeGen {
virtual LLVMCompiledKernel compile_kernel_to_module();

// Backend hook that turns one offloaded statement into an LLVMCompiledTask.
// The compile config is passed in explicitly by the caller. This base version
// is only a stub (TI_NOT_IMPLEMENTED); the LLVM backends provide overrides.
//
// config: compile options for this task; it has no default and must be
//         supplied by the caller.
// module: optional LLVM module, defaults to nullptr.
// stmt:   the offloaded statement to compile, defaults to nullptr.
//
// NOTE(review): default arguments on a virtual function are selected by the
// static type of the call expression, so overrides should keep the same
// defaults as this declaration.
virtual LLVMCompiledTask compile_task(
const CompileConfig *config,
std::unique_ptr<llvm::Module> &&module = nullptr,
OffloadedStmt *stmt = nullptr){TI_NOT_IMPLEMENTED}

Expand All @@ -65,6 +68,11 @@ class KernelCodeGen {
void cache_kernel(const std::string &kernel_key,
const LLVMCompiledKernel &data);
#endif
protected:
// Read-only accessor for the compile config pointer stored in compile_config_.
// The pointer is returned as-is with no null check; compile_config_ is
// default-initialized to nullptr, so the result is only valid once it has
// been set (presumably by the constructor -- TODO confirm).
const CompileConfig *get_compile_config() const {
return compile_config_;
}

private:
const CompileConfig *compile_config_{nullptr};
};
Expand Down
2 changes: 1 addition & 1 deletion taichi/codegen/codegen_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

namespace taichi::lang {

inline bool codegen_vector_type(CompileConfig *config) {
inline bool codegen_vector_type(const CompileConfig *config) {
if (config->real_matrix && !config->real_matrix_scalarize) {
return true;
}
Expand Down
24 changes: 8 additions & 16 deletions taichi/codegen/cpu/codegen_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ class TaskCodeGenCPU : public TaskCodeGenLLVM {
public:
using IRVisitor::visit;

TaskCodeGenCPU(Kernel *kernel, IRNode *ir)
: TaskCodeGenLLVM(kernel, ir, nullptr) {
TaskCodeGenCPU(const CompileConfig *config, Kernel *kernel, IRNode *ir)
: TaskCodeGenLLVM(config, kernel, ir, nullptr) {
TI_AUTO_PROF
}

Expand Down Expand Up @@ -56,7 +56,7 @@ class TaskCodeGenCPU : public TaskCodeGenLLVM {
auto [begin, end] = get_range_for_bounds(stmt);

// adaptive block_dim
if (prog->this_thread_config().cpu_block_dim_adaptive) {
if (compile_config->cpu_block_dim_adaptive) {
int num_items = (stmt->end_value - stmt->begin_value) / std::abs(step);
int num_threads = stmt->num_cpu_threads;
int items_per_thread = std::max(1, num_items / (num_threads * 32));
Expand Down Expand Up @@ -166,8 +166,7 @@ class TaskCodeGenCPU : public TaskCodeGenLLVM {
create_bls_buffer(stmt);
using Type = OffloadedStmt::TaskType;
auto offloaded_task_name = init_offloaded_task_function(stmt);
if (prog->this_thread_config().kernel_profiler &&
arch_is_cpu(prog->this_thread_config().arch)) {
if (compile_config->kernel_profiler && arch_is_cpu(compile_config->arch)) {
call("LLVMRuntime_profiler_start", get_runtime(),
builder->CreateGlobalStringPtr(offloaded_task_name));
}
Expand All @@ -190,8 +189,7 @@ class TaskCodeGenCPU : public TaskCodeGenLLVM {
} else {
TI_NOT_IMPLEMENTED
}
if (prog->this_thread_config().kernel_profiler &&
arch_is_cpu(prog->this_thread_config().arch)) {
if (compile_config->kernel_profiler && arch_is_cpu(compile_config->arch)) {
llvm::IRBuilderBase::InsertPointGuard guard(*builder);
builder->SetInsertPoint(final_block);
call("LLVMRuntime_profiler_stop", get_runtime());
Expand All @@ -216,13 +214,6 @@ class TaskCodeGenCPU : public TaskCodeGenLLVM {
} // namespace

#ifdef TI_WITH_LLVM
// static
std::unique_ptr<TaskCodeGenLLVM> KernelCodeGenCPU::make_codegen_llvm(
Kernel *kernel,
IRNode *ir) {
return std::make_unique<TaskCodeGenCPU>(kernel, ir);
}

FunctionType CPUModuleToFunctionConverter::convert(
const std::string &kernel_name,
const std::vector<LlvmLaunchArgInfo> &args,
Expand Down Expand Up @@ -264,17 +255,18 @@ FunctionType CPUModuleToFunctionConverter::convert(
}

LLVMCompiledTask KernelCodeGenCPU::compile_task(
const CompileConfig *config,
std::unique_ptr<llvm::Module> &&module,
OffloadedStmt *stmt) {
TaskCodeGenCPU gen(kernel, stmt);
TaskCodeGenCPU gen(config, kernel, stmt);
return gen.run_compilation();
}
#endif // TI_WITH_LLVM

FunctionType KernelCodeGenCPU::compile_to_function() {
TI_AUTO_PROF;
auto *llvm_prog = get_llvm_program(prog);
const auto &config = prog->this_thread_config();
const auto &config = *get_compile_config();
auto *tlctx = llvm_prog->get_llvm_context(config.arch);

CPUModuleToFunctionConverter converter(
Expand Down
4 changes: 1 addition & 3 deletions taichi/codegen/cpu/codegen_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@ class KernelCodeGenCPU : public KernelCodeGen {

// TODO: Stop defining these macro guards in the headers
#ifdef TI_WITH_LLVM
static std::unique_ptr<TaskCodeGenLLVM> make_codegen_llvm(Kernel *kernel,
IRNode *ir);

// The CPU code generator always reports offline-cache support.
bool supports_offline_cache() const override {
return true;
}
// CPU override of compile_task (declaration only; defined out of line).
// config has no default and must be supplied by the caller; module and stmt
// both default to nullptr.
LLVMCompiledTask compile_task(
const CompileConfig *config,
std::unique_ptr<llvm::Module> &&module = nullptr,
OffloadedStmt *stmt = nullptr) override;

Expand Down
30 changes: 12 additions & 18 deletions taichi/codegen/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
public:
using IRVisitor::visit;

explicit TaskCodeGenCUDA(Kernel *kernel, IRNode *ir = nullptr)
: TaskCodeGenLLVM(kernel, ir) {
explicit TaskCodeGenCUDA(const CompileConfig *config,
Kernel *kernel,
IRNode *ir = nullptr)
: TaskCodeGenLLVM(config, kernel, ir) {
}

llvm::Value *create_print(std::string tag,
Expand Down Expand Up @@ -106,7 +108,7 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
auto elem_type = dtype->get_element_type();
for (int i = 0; i < dtype->get_num_elements(); ++i) {
llvm::Value *elem_value;
if (codegen_vector_type(&prog->this_thread_config())) {
if (codegen_vector_type(compile_config)) {
TI_ASSERT(llvm::dyn_cast<llvm::VectorType>(value_type));
elem_value = builder->CreateExtractElement(value, i);
} else {
Expand Down Expand Up @@ -364,7 +366,7 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
init_offloaded_task_function(stmt, "gather_list");
call("gc_parallel_0", get_context(), snode_id);
finalize_offloaded_task_function();
current_task->grid_dim = prog->this_thread_config().saturating_grid_dim;
current_task->grid_dim = compile_config->saturating_grid_dim;
current_task->block_dim = 64;
offloaded_tasks.push_back(*current_task);
current_task = nullptr;
Expand All @@ -382,7 +384,7 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
init_offloaded_task_function(stmt, "zero_fill");
call("gc_parallel_2", get_context(), snode_id);
finalize_offloaded_task_function();
current_task->grid_dim = prog->this_thread_config().saturating_grid_dim;
current_task->grid_dim = compile_config->saturating_grid_dim;
current_task->block_dim = 64;
offloaded_tasks.push_back(*current_task);
current_task = nullptr;
Expand All @@ -394,7 +396,7 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
init_offloaded_task_function(stmt, "gather_list");
call("gc_rc_parallel_0", get_context());
finalize_offloaded_task_function();
current_task->grid_dim = prog->this_thread_config().saturating_grid_dim;
current_task->grid_dim = compile_config->saturating_grid_dim;
current_task->block_dim = 64;
offloaded_tasks.push_back(*current_task);
current_task = nullptr;
Expand All @@ -412,7 +414,7 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
init_offloaded_task_function(stmt, "zero_fill");
call("gc_rc_parallel_2", get_context());
finalize_offloaded_task_function();
current_task->grid_dim = prog->this_thread_config().saturating_grid_dim;
current_task->grid_dim = compile_config->saturating_grid_dim;
current_task->block_dim = 64;
offloaded_tasks.push_back(*current_task);
current_task = nullptr;
Expand Down Expand Up @@ -584,26 +586,18 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
}
};

#ifdef TI_WITH_LLVM
// static
std::unique_ptr<TaskCodeGenLLVM> KernelCodeGenCUDA::make_codegen_llvm(
Kernel *kernel,
IRNode *ir) {
return std::make_unique<TaskCodeGenCUDA>(kernel, ir);
}
#endif // TI_WITH_LLVM

LLVMCompiledTask KernelCodeGenCUDA::compile_task(
const CompileConfig *config,
std::unique_ptr<llvm::Module> &&module,
OffloadedStmt *stmt) {
TaskCodeGenCUDA gen(kernel, stmt);
TaskCodeGenCUDA gen(config, kernel, stmt);
return gen.run_compilation();
}

FunctionType KernelCodeGenCUDA::compile_to_function() {
TI_AUTO_PROF
auto *llvm_prog = get_llvm_program(prog);
const auto &config = prog->this_thread_config();
const auto &config = *get_compile_config();
auto *tlctx = llvm_prog->get_llvm_context(config.arch);

CUDAModuleToFunctionConverter converter{tlctx,
Expand Down
3 changes: 1 addition & 2 deletions taichi/codegen/cuda/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@ class KernelCodeGenCUDA : public KernelCodeGen {

// TODO: Stop defining these macro guards in the headers
#ifdef TI_WITH_LLVM
static std::unique_ptr<TaskCodeGenLLVM> make_codegen_llvm(Kernel *kernel,
IRNode *ir);
// CUDA override of compile_task (declaration only; defined out of line).
// config has no default and must be supplied by the caller; module and stmt
// both default to nullptr.
LLVMCompiledTask compile_task(
const CompileConfig *config,
std::unique_ptr<llvm::Module> &&module = nullptr,
OffloadedStmt *stmt = nullptr) override;
#endif // TI_WITH_LLVM
Expand Down
27 changes: 12 additions & 15 deletions taichi/codegen/dx12/codegen_dx12.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ class TaskCodeGenLLVMDX12 : public TaskCodeGenLLVM {
public:
using IRVisitor::visit;

TaskCodeGenLLVMDX12(Kernel *kernel, IRNode *ir)
: TaskCodeGenLLVM(kernel, ir, nullptr) {
TaskCodeGenLLVMDX12(const CompileConfig *config, Kernel *kernel, IRNode *ir)
: TaskCodeGenLLVM(config, kernel, ir, nullptr) {
TI_AUTO_PROF
}

Expand Down Expand Up @@ -149,8 +149,7 @@ class TaskCodeGenLLVMDX12 : public TaskCodeGenLLVM {
create_bls_buffer(stmt);
using Type = OffloadedStmt::TaskType;
auto offloaded_task_name = init_offloaded_task_function(stmt);
if (prog->this_thread_config().kernel_profiler &&
arch_is_cpu(prog->this_thread_config().arch)) {
if (compile_config->kernel_profiler && arch_is_cpu(compile_config->arch)) {
call(
builder.get(), "LLVMRuntime_profiler_start",
{get_runtime(), builder->CreateGlobalStringPtr(offloaded_task_name)});
Expand All @@ -172,8 +171,7 @@ class TaskCodeGenLLVMDX12 : public TaskCodeGenLLVM {
} else {
TI_NOT_IMPLEMENTED
}
if (prog->this_thread_config().kernel_profiler &&
arch_is_cpu(prog->this_thread_config().arch)) {
if (compile_config->kernel_profiler && arch_is_cpu(compile_config->arch)) {
llvm::IRBuilderBase::InsertPointGuard guard(*builder);
builder->SetInsertPoint(final_block);
call(builder.get(), "LLVMRuntime_profiler_stop", {get_runtime()});
Expand Down Expand Up @@ -201,6 +199,7 @@ class TaskCodeGenLLVMDX12 : public TaskCodeGenLLVM {

static std::vector<uint8_t> generate_dxil_from_llvm(
LLVMCompiledTask &compiled_data,
const CompileConfig *config,
taichi::lang::Kernel *kernel) {
// generate dxil from llvm ir.
auto offloaded_local = compiled_data.tasks;
Expand All @@ -209,21 +208,18 @@ static std::vector<uint8_t> generate_dxil_from_llvm(
llvm::Function *func = module->getFunction(task.name);
TI_ASSERT(func);
directx12::mark_function_as_cs_entry(func);
directx12::set_num_threads(
func, kernel->program->this_thread_config().default_gpu_block_dim, 1,
1);
directx12::set_num_threads(func, config->default_gpu_block_dim, 1, 1);
// FIXME: save task.block_dim like
// tlctx->mark_function_as_cuda_kernel(func, task.block_dim);
}
auto dx_container = directx12::global_optimize_module(
module, kernel->program->this_thread_config());
auto dx_container = directx12::global_optimize_module(module, *config);
// validate and sign dx container.
return directx12::validate_and_sign(dx_container);
}

KernelCodeGenDX12::CompileResult KernelCodeGenDX12::compile() {
TI_AUTO_PROF;
auto &config = prog->this_thread_config();
auto &config = *get_compile_config();
std::string kernel_key = get_hashed_offline_cache_key(&config, kernel);
kernel->set_kernel_key_for_cache(kernel_key);

Expand All @@ -240,10 +236,10 @@ KernelCodeGenDX12::CompileResult KernelCodeGenDX12::compile() {
irpass::analysis::clone(offloads[i].get(), offloads[i]->get_kernel());
irpass::re_id(offload.get());
auto *offload_stmt = offload->as<OffloadedStmt>();
auto new_data = compile_task(nullptr, offload_stmt);
auto new_data = compile_task(&config, nullptr, offload_stmt);

Result.task_dxil_source_codes.emplace_back(
generate_dxil_from_llvm(new_data, kernel));
generate_dxil_from_llvm(new_data, &config, kernel));
aot::CompiledOffloadedTask task;
// FIXME: build all fields for task.
task.name = fmt::format("{}_{}_{}", kernel->get_name(),
Expand All @@ -257,9 +253,10 @@ KernelCodeGenDX12::CompileResult KernelCodeGenDX12::compile() {
}

LLVMCompiledTask KernelCodeGenDX12::compile_task(
const CompileConfig *config,
std::unique_ptr<llvm::Module> &&module,
OffloadedStmt *stmt) {
TaskCodeGenLLVMDX12 gen(kernel, stmt);
TaskCodeGenLLVMDX12 gen(config, kernel, stmt);
return gen.run_compilation();
}
#endif // TI_WITH_LLVM
Expand Down
1 change: 1 addition & 0 deletions taichi/codegen/dx12/codegen_dx12.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class KernelCodeGenDX12 : public KernelCodeGen {
CompileResult compile();
#ifdef TI_WITH_LLVM
// DX12 override of compile_task (declaration only; defined out of line).
// config has no default and must be supplied by the caller; module and stmt
// both default to nullptr.
LLVMCompiledTask compile_task(
const CompileConfig *config,
std::unique_ptr<llvm::Module> &&module = nullptr,
OffloadedStmt *stmt = nullptr) override;
#endif
Expand Down
2 changes: 1 addition & 1 deletion taichi/codegen/dx12/dx12_global_optimize_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ GlobalVariable *createGlobalVariableForResource(Module &M,
}

std::vector<uint8_t> global_optimize_module(llvm::Module *module,
CompileConfig &config) {
const CompileConfig &config) {
TI_AUTO_PROF
if (llvm::verifyModule(*module, &llvm::errs())) {
module->print(llvm::errs(), nullptr);
Expand Down
4 changes: 2 additions & 2 deletions taichi/codegen/dx12/dx12_llvm_passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ llvm::GlobalVariable *createGlobalVariableForResource(llvm::Module &M,
llvm::Type *Ty);

std::vector<uint8_t> global_optimize_module(llvm::Module *module,
CompileConfig &config);
const CompileConfig &config);

extern const char *NumWorkGroupsCBName;

Expand All @@ -46,6 +46,6 @@ ModulePass *createTaichiRuntimeContextLowerPass();
void initializeTaichiIntrinsicLowerPass(PassRegistry &);

/// Pass to lower taichi intrinsic into DXIL intrinsic.
ModulePass *createTaichiIntrinsicLowerPass(taichi::lang::CompileConfig *config);
ModulePass *createTaichiIntrinsicLowerPass(const taichi::lang::CompileConfig *config);

} // namespace llvm
6 changes: 3 additions & 3 deletions taichi/codegen/dx12/dx12_lower_intrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,14 @@ class TaichiIntrinsicLower : public ModulePass {
return true;
}

TaichiIntrinsicLower(taichi::lang::CompileConfig *config = nullptr)
TaichiIntrinsicLower(const taichi::lang::CompileConfig *config = nullptr)
: ModulePass(ID), config(config) {
initializeTaichiIntrinsicLowerPass(*PassRegistry::getPassRegistry());
}

static char ID; // Pass identification.
private:
taichi::lang::CompileConfig *config;
const taichi::lang::CompileConfig *config;
};
char TaichiIntrinsicLower::ID = 0;

Expand All @@ -116,6 +116,6 @@ INITIALIZE_PASS(TaichiIntrinsicLower,
false)

llvm::ModulePass *llvm::createTaichiIntrinsicLowerPass(
taichi::lang::CompileConfig *config) {
const taichi::lang::CompileConfig *config) {
return new TaichiIntrinsicLower(config);
}
Loading

0 comments on commit 423a29c

Please sign in to comment.