Skip to content

Commit

Permalink
[llvm] [refactor] (Decomp of #5251 2/n) Make modulegen a virtual func…
Browse files Browse the repository at this point in the history
…tion and let LLVMCompiledData replace ModuleGenValue (#5353)

* [llvm] [refactor] (Decomp of #5251 2/n) Make modulegen a virtual function and let LLVMCompiledData replace ModuleGenValue

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add ifdef

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
lin-hitonami and pre-commit-ci[bot] authored Jul 7, 2022
1 parent 04ad121 commit 74399c4
Show file tree
Hide file tree
Showing 13 changed files with 57 additions and 55 deletions.
16 changes: 14 additions & 2 deletions taichi/codegen/codegen.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
// Driver class for kernel code generators.

#pragma once

#include "taichi/ir/ir.h"
#include "taichi/program/program.h"

#ifdef TI_WITH_LLVM
#include "llvm/IR/Module.h"
#include "taichi/codegen/llvm/codegen_llvm.h"
#include "taichi/runtime/llvm/launch_arg_info.h"
#include "taichi/codegen/llvm/llvm_codegen_utils.h"
#endif
TLANG_NAMESPACE_BEGIN

class KernelCodeGen {
Expand All @@ -22,6 +27,13 @@ class KernelCodeGen {
Stmt *stmt = nullptr);

virtual FunctionType codegen() = 0;
#ifdef TI_WITH_LLVM
virtual LLVMCompiledData modulegen(
std::unique_ptr<llvm::Module> &&module = nullptr,
OffloadedStmt *stmt = nullptr) {
TI_NOT_IMPLEMENTED
}
#endif
};

TLANG_NAMESPACE_END
4 changes: 2 additions & 2 deletions taichi/codegen/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
CUDAModuleToFunctionConverter converter{tlctx,
llvm_prog->get_runtime_executor()};

return converter.convert(this->kernel, std::move(compiled_res.llvm_module),
std::move(compiled_res.offloaded_tasks));
return converter.convert(this->kernel, std::move(compiled_res.module),
std::move(compiled_res.tasks));
}

llvm::Value *create_print(std::string tag,
Expand Down
20 changes: 10 additions & 10 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2313,14 +2313,14 @@ void CodeGenLLVM::emit_to_module() {
ir->accept(this);
}

CodeGenLLVM::CompiledData CodeGenLLVM::run_compilation() {
LLVMCompiledData CodeGenLLVM::run_compilation() {
const auto &config = prog->config;
std::string kernel_key =
get_hashed_offline_cache_key(&kernel->program->config, kernel);
kernel->set_kernel_key_for_cache(kernel_key);
if (config.offline_cache && !config.async_mode &&
this->supports_offline_cache() && !kernel->is_evaluator) {
CompiledData res;
LLVMCompiledData res;
const bool ok = maybe_read_compilation_from_cache(kernel_key, &res);
if (ok) {
return res;
Expand All @@ -2339,15 +2339,15 @@ CodeGenLLVM::CompiledData CodeGenLLVM::run_compilation() {
cache_module(kernel_key);
}

CompiledData res;
res.offloaded_tasks = std::move(this->offloaded_tasks);
res.llvm_module = std::move(this->module);
LLVMCompiledData res;
res.tasks = std::move(this->offloaded_tasks);
res.module = std::move(this->module);
return res;
}

bool CodeGenLLVM::maybe_read_compilation_from_cache(
const std::string &kernel_key,
CompiledData *data) {
LLVMCompiledData *data) {
const auto &config = prog->config;
auto reader =
LlvmOfflineCacheFileReader::make(config.offline_cache_file_path);
Expand All @@ -2369,8 +2369,8 @@ bool CodeGenLLVM::maybe_read_compilation_from_cache(
t.grid_dim = task.grid_dim;
}
kernel->set_from_offline_cache();
data->offloaded_tasks = std::move(this->offloaded_tasks);
data->llvm_module = std::move(this->module);
data->tasks = std::move(this->offloaded_tasks);
data->module = std::move(this->module);
return true;
}

Expand All @@ -2379,8 +2379,8 @@ FunctionType CodeGenLLVM::gen() {

ModuleToFunctionConverter converter{
tlctx, get_llvm_program(prog)->get_runtime_executor()};
return converter.convert(kernel, std::move(compiled_res.llvm_module),
std::move(compiled_res.offloaded_tasks));
return converter.convert(kernel, std::move(compiled_res.module),
std::move(compiled_res.tasks));
}

llvm::Value *CodeGenLLVM::create_xlogue(std::unique_ptr<Block> &block) {
Expand Down
17 changes: 9 additions & 8 deletions taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ class FunctionCreationGuard {
~FunctionCreationGuard();
};

struct LLVMCompiledData {
std::vector<OffloadedTask> tasks;
std::unique_ptr<llvm::Module> module{nullptr};
};

class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
public:
Kernel *kernel;
Expand Down Expand Up @@ -119,18 +124,14 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void eliminate_unused_functions();

struct CompiledData {
std::vector<OffloadedTask> offloaded_tasks;
std::unique_ptr<llvm::Module> llvm_module{nullptr};
};
/**
* @brief Runs the codegen and produces the compiled result.
*
* After this call, `module` and `offloaded_tasks` will be moved.
* After this call, `module` and `tasks` will be moved.
*
* @return CompiledData
* @return LLVMCompiledData
*/
CompiledData run_compilation();
LLVMCompiledData run_compilation();

// TODO: This function relies largely on `run_compilation()`. Name it better.
virtual FunctionType gen();
Expand Down Expand Up @@ -404,7 +405,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

private:
bool maybe_read_compilation_from_cache(const std::string &kernel_key,
CompiledData *data);
LLVMCompiledData *data);

void cache_module(const std::string &kernel_key);
};
Expand Down
16 changes: 8 additions & 8 deletions taichi/codegen/wasm/codegen_wasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,28 +246,28 @@ FunctionType CodeGenWASM::codegen() {
return CodeGenLLVMWASM(kernel, ir).gen();
}

std::unique_ptr<ModuleGenValue> CodeGenWASM::modulegen(
std::unique_ptr<llvm::Module> &&module) {
LLVMCompiledData CodeGenWASM::modulegen(std::unique_ptr<llvm::Module> &&module,
OffloadedStmt *stmt) {
bool init_flag = module == nullptr;
std::vector<std::string> name_list;

std::vector<OffloadedTask> name_list;
auto gen = std::make_unique<CodeGenLLVMWASM>(kernel, ir, std::move(module));

name_list.push_back(gen->init_taichi_kernel_function());
name_list.emplace_back(nullptr);
name_list[0].name = gen->init_taichi_kernel_function();
gen->emit_to_module();
gen->finalize_taichi_kernel_function();

// TODO: move the following functions to dump process in AOT.
if (init_flag) {
for (auto &name : kPreloadedFuncNames) {
name_list.emplace_back(name);
name_list.emplace_back(nullptr);
name_list.back().name = name;
}
}

gen->tlctx->jit->global_optimize_module(gen->module.get());

return std::make_unique<ModuleGenValue>(std::move(gen->module), name_list);
return {name_list, std::move(gen->module)};
}

} // namespace lang
} // namespace taichi
17 changes: 3 additions & 14 deletions taichi/codegen/wasm/codegen_wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,6 @@
namespace taichi {
namespace lang {

#ifdef TI_WITH_LLVM
class ModuleGenValue {
public:
ModuleGenValue(std::unique_ptr<llvm::Module> module,
const std::vector<std::string> &name_list)
: module(std::move(module)), name_list(name_list) {
}
std::unique_ptr<llvm::Module> module;
std::vector<std::string> name_list;
};
#endif

class CodeGenWASM : public KernelCodeGen {
public:
CodeGenWASM(Kernel *kernel, IRNode *ir = nullptr)
Expand All @@ -32,8 +20,9 @@ class CodeGenWASM : public KernelCodeGen {
FunctionType codegen() override;

#ifdef TI_WITH_LLVM
std::unique_ptr<ModuleGenValue> modulegen(
std::unique_ptr<llvm::Module> &&module); // AOT Module Gen
LLVMCompiledData modulegen(
std::unique_ptr<llvm::Module> &&module = nullptr,
OffloadedStmt *stmt = nullptr) override; // AOT Module Gen
#endif
};

Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/cpu/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace taichi {
namespace lang {
namespace cpu {

CodeGenLLVM::CompiledData AotModuleBuilderImpl::compile_kernel(Kernel *kernel) {
LLVMCompiledData AotModuleBuilderImpl::compile_kernel(Kernel *kernel) {
auto cgen = CodeGenCPU::make_codegen_llvm(kernel, /*ir=*/nullptr);
return cgen->run_compilation();
}
Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/cpu/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class AotModuleBuilderImpl : public LlvmAotModuleBuilder {
}

private:
CodeGenLLVM::CompiledData compile_kernel(Kernel *kernel) override;
LLVMCompiledData compile_kernel(Kernel *kernel) override;
};

} // namespace cpu
Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/cuda/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace taichi {
namespace lang {
namespace cuda {

CodeGenLLVM::CompiledData AotModuleBuilderImpl::compile_kernel(Kernel *kernel) {
LLVMCompiledData AotModuleBuilderImpl::compile_kernel(Kernel *kernel) {
auto cgen = CodeGenCUDA::make_codegen_llvm(kernel, /*ir=*/nullptr);
return cgen->run_compilation();
}
Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/cuda/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class AotModuleBuilderImpl : public LlvmAotModuleBuilder {
}

private:
CodeGenLLVM::CompiledData compile_kernel(Kernel *kernel) override;
LLVMCompiledData compile_kernel(Kernel *kernel) override;
};

} // namespace cuda
Expand Down
6 changes: 3 additions & 3 deletions taichi/runtime/llvm/llvm_aot_module_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ void LlvmAotModuleBuilder::add_per_backend(const std::string &identifier,
auto compiled = compile_kernel(kernel);
LlvmOfflineCache::KernelCacheData kcache;
kcache.kernel_key = identifier;
kcache.module = compiled.llvm_module.get();
kcache.owned_module = std::move(compiled.llvm_module);
kcache.module = compiled.module.get();
kcache.owned_module = std::move(compiled.module);
kcache.args = infer_launch_args(kernel);
kcache.offloaded_task_list = std::move(compiled.offloaded_tasks);
kcache.offloaded_task_list = std::move(compiled.tasks);
cache_.kernels[identifier] = std::move(kcache);
}

Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/llvm/llvm_aot_module_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class LlvmAotModuleBuilder : public AotModuleBuilder {

protected:
void add_per_backend(const std::string &identifier, Kernel *kernel) override;
virtual CodeGenLLVM::CompiledData compile_kernel(Kernel *kernel) = 0;
virtual LLVMCompiledData compile_kernel(Kernel *kernel) = 0;

void add_field_per_backend(const std::string &identifier,
const SNode *rep_snode,
Expand Down
6 changes: 3 additions & 3 deletions taichi/runtime/wasm/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ void AotModuleBuilderImpl::dump(const std::string &output_dir,
void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
Kernel *kernel) {
auto module_info = CodeGenWASM(kernel, nullptr).modulegen(std::move(module_));
module_ = std::move(module_info->module);
module_ = std::move(module_info.module);

for (auto &name : module_info->name_list)
name_list_.push_back(name);
for (auto &task : module_info.tasks)
name_list_.push_back(task.name);
}

void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier,
Expand Down

0 comments on commit 74399c4

Please sign in to comment.