Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[llvm] [refactor] (Decomp of #5251 2/n) Make modulegen a virtual function and let LLVMCompiledData replace ModuleGenValue #5353

Merged
merged 5 commits into from
Jul 7, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions taichi/codegen/codegen.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
// Driver class for kernel code generators.

#pragma once

#include "taichi/ir/ir.h"
#include "taichi/program/program.h"

#ifdef TI_WITH_LLVM
#include "llvm/IR/Module.h"
#include "taichi/codegen/llvm/codegen_llvm.h"
#include "taichi/runtime/llvm/launch_arg_info.h"
#include "taichi/codegen/llvm/llvm_codegen_utils.h"
#endif
TLANG_NAMESPACE_BEGIN

class KernelCodeGen {
Expand All @@ -22,6 +27,13 @@ class KernelCodeGen {
Stmt *stmt = nullptr);

virtual FunctionType codegen() = 0;
#ifdef TI_WITH_LLVM
  // Generates LLVM IR for this kernel and returns it as an LLVMCompiledData
  // (offloaded tasks + owning llvm::Module) without wrapping it into a
  // host-callable FunctionType. Backends that support module-level codegen
  // override this; the base implementation aborts via TI_NOT_IMPLEMENTED.
  //
  // @param module Optional existing module to emit into; nullptr lets the
  //               backend create a fresh one.
  // @param stmt   Optional offloaded statement to codegen; nullptr means the
  //               whole kernel. (Semantics are backend-defined — confirm per
  //               override.)
  // @return       The compiled tasks together with the generated module.
  virtual LLVMCompiledData modulegen(
      std::unique_ptr<llvm::Module> &&module = nullptr,
      OffloadedStmt *stmt = nullptr) {
    TI_NOT_IMPLEMENTED
  }
#endif
};

TLANG_NAMESPACE_END
4 changes: 2 additions & 2 deletions taichi/codegen/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
CUDAModuleToFunctionConverter converter{tlctx,
llvm_prog->get_runtime_executor()};

return converter.convert(this->kernel, std::move(compiled_res.llvm_module),
std::move(compiled_res.offloaded_tasks));
return converter.convert(this->kernel, std::move(compiled_res.module),
std::move(compiled_res.tasks));
}

llvm::Value *create_print(std::string tag,
Expand Down
20 changes: 10 additions & 10 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2327,14 +2327,14 @@ void CodeGenLLVM::emit_to_module() {
ir->accept(this);
}

CodeGenLLVM::CompiledData CodeGenLLVM::run_compilation() {
LLVMCompiledData CodeGenLLVM::run_compilation() {
const auto &config = prog->config;
std::string kernel_key =
get_hashed_offline_cache_key(&kernel->program->config, kernel);
kernel->set_kernel_key_for_cache(kernel_key);
if (config.offline_cache && !config.async_mode &&
this->supports_offline_cache() && !kernel->is_evaluator) {
CompiledData res;
LLVMCompiledData res;
const bool ok = maybe_read_compilation_from_cache(kernel_key, &res);
if (ok) {
return res;
Expand All @@ -2353,15 +2353,15 @@ CodeGenLLVM::CompiledData CodeGenLLVM::run_compilation() {
cache_module(kernel_key);
}

CompiledData res;
res.offloaded_tasks = std::move(this->offloaded_tasks);
res.llvm_module = std::move(this->module);
LLVMCompiledData res;
res.tasks = std::move(this->offloaded_tasks);
res.module = std::move(this->module);
return res;
}

bool CodeGenLLVM::maybe_read_compilation_from_cache(
const std::string &kernel_key,
CompiledData *data) {
LLVMCompiledData *data) {
const auto &config = prog->config;
auto reader =
LlvmOfflineCacheFileReader::make(config.offline_cache_file_path);
Expand All @@ -2384,8 +2384,8 @@ bool CodeGenLLVM::maybe_read_compilation_from_cache(
t.grid_dim = task.grid_dim;
}
kernel->set_from_offline_cache();
data->offloaded_tasks = std::move(this->offloaded_tasks);
data->llvm_module = std::move(this->module);
data->tasks = std::move(this->offloaded_tasks);
data->module = std::move(this->module);
return true;
}

Expand All @@ -2394,8 +2394,8 @@ FunctionType CodeGenLLVM::gen() {

ModuleToFunctionConverter converter{
tlctx, get_llvm_program(prog)->get_runtime_executor()};
return converter.convert(kernel, std::move(compiled_res.llvm_module),
std::move(compiled_res.offloaded_tasks));
return converter.convert(kernel, std::move(compiled_res.module),
std::move(compiled_res.tasks));
}

llvm::Value *CodeGenLLVM::create_xlogue(std::unique_ptr<Block> &block) {
Expand Down
17 changes: 9 additions & 8 deletions taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ class FunctionCreationGuard {
~FunctionCreationGuard();
};

struct LLVMCompiledData {
jim19930609 marked this conversation as resolved.
Show resolved Hide resolved
std::vector<OffloadedTask> tasks;
std::unique_ptr<llvm::Module> module{nullptr};
};

class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
public:
Kernel *kernel;
Expand Down Expand Up @@ -121,18 +126,14 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void eliminate_unused_functions();

struct CompiledData {
std::vector<OffloadedTask> offloaded_tasks;
std::unique_ptr<llvm::Module> llvm_module{nullptr};
};
/**
* @brief Runs the codegen and produces the compiled result.
*
* After this call, `module` and `offloaded_tasks` will be moved.
* After this call, `module` and `tasks` will be moved.
*
* @return CompiledData
* @return LLVMCompiledData
*/
CompiledData run_compilation();
LLVMCompiledData run_compilation();

// TODO: This function relies largely on `run_compilation()`. Name it better.
virtual FunctionType gen();
Expand Down Expand Up @@ -406,7 +407,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

private:
bool maybe_read_compilation_from_cache(const std::string &kernel_key,
CompiledData *data);
LLVMCompiledData *data);

void cache_module(const std::string &kernel_key);
};
Expand Down
16 changes: 8 additions & 8 deletions taichi/codegen/wasm/codegen_wasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,28 +246,28 @@ FunctionType CodeGenWASM::codegen() {
return CodeGenLLVMWASM(kernel, ir).gen();
}

std::unique_ptr<ModuleGenValue> CodeGenWASM::modulegen(
std::unique_ptr<llvm::Module> &&module) {
LLVMCompiledData CodeGenWASM::modulegen(std::unique_ptr<llvm::Module> &&module,
OffloadedStmt *stmt) {
bool init_flag = module == nullptr;
std::vector<std::string> name_list;

std::vector<OffloadedTask> name_list;
auto gen = std::make_unique<CodeGenLLVMWASM>(kernel, ir, std::move(module));

name_list.push_back(gen->init_taichi_kernel_function());
name_list.emplace_back(nullptr);
name_list[0].name = gen->init_taichi_kernel_function();
gen->emit_to_module();
gen->finalize_taichi_kernel_function();

// TODO: move the following functions to dump process in AOT.
if (init_flag) {
for (auto &name : kPreloadedFuncNames) {
name_list.emplace_back(name);
name_list.emplace_back(nullptr);
name_list.back().name = name;
}
}

gen->tlctx->jit->global_optimize_module(gen->module.get());

return std::make_unique<ModuleGenValue>(std::move(gen->module), name_list);
return {name_list, std::move(gen->module)};
}

} // namespace lang
} // namespace taichi
17 changes: 3 additions & 14 deletions taichi/codegen/wasm/codegen_wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,6 @@
namespace taichi {
namespace lang {

#ifdef TI_WITH_LLVM
class ModuleGenValue {
public:
ModuleGenValue(std::unique_ptr<llvm::Module> module,
const std::vector<std::string> &name_list)
: module(std::move(module)), name_list(name_list) {
}
std::unique_ptr<llvm::Module> module;
std::vector<std::string> name_list;
};
#endif

class CodeGenWASM : public KernelCodeGen {
public:
CodeGenWASM(Kernel *kernel, IRNode *ir = nullptr)
Expand All @@ -32,8 +20,9 @@ class CodeGenWASM : public KernelCodeGen {
FunctionType codegen() override;

#ifdef TI_WITH_LLVM
std::unique_ptr<ModuleGenValue> modulegen(
std::unique_ptr<llvm::Module> &&module); // AOT Module Gen
LLVMCompiledData modulegen(
std::unique_ptr<llvm::Module> &&module = nullptr,
OffloadedStmt *stmt = nullptr) override; // AOT Module Gen
#endif
};

Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/cpu/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace taichi {
namespace lang {
namespace cpu {

CodeGenLLVM::CompiledData AotModuleBuilderImpl::compile_kernel(Kernel *kernel) {
LLVMCompiledData AotModuleBuilderImpl::compile_kernel(Kernel *kernel) {
auto cgen = CodeGenCPU::make_codegen_llvm(kernel, /*ir=*/nullptr);
return cgen->run_compilation();
}
Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/cpu/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class AotModuleBuilderImpl : public LlvmAotModuleBuilder {
}

private:
CodeGenLLVM::CompiledData compile_kernel(Kernel *kernel) override;
LLVMCompiledData compile_kernel(Kernel *kernel) override;
};

} // namespace cpu
Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/cuda/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace taichi {
namespace lang {
namespace cuda {

CodeGenLLVM::CompiledData AotModuleBuilderImpl::compile_kernel(Kernel *kernel) {
LLVMCompiledData AotModuleBuilderImpl::compile_kernel(Kernel *kernel) {
auto cgen = CodeGenCUDA::make_codegen_llvm(kernel, /*ir=*/nullptr);
return cgen->run_compilation();
}
Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/cuda/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class AotModuleBuilderImpl : public LlvmAotModuleBuilder {
}

private:
CodeGenLLVM::CompiledData compile_kernel(Kernel *kernel) override;
LLVMCompiledData compile_kernel(Kernel *kernel) override;
};

} // namespace cuda
Expand Down
6 changes: 3 additions & 3 deletions taichi/runtime/llvm/llvm_aot_module_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ void LlvmAotModuleBuilder::add_per_backend(const std::string &identifier,
auto compiled = compile_kernel(kernel);
LlvmOfflineCache::KernelCacheData kcache;
kcache.kernel_key = identifier;
kcache.module = compiled.llvm_module.get();
kcache.owned_module = std::move(compiled.llvm_module);
const auto &tasks = compiled.offloaded_tasks;
kcache.module = compiled.module.get();
kcache.owned_module = std::move(compiled.module);
const auto &tasks = compiled.tasks;
kcache.args = infer_launch_args(kernel);
kcache.offloaded_task_list.resize(tasks.size());
std::transform(tasks.begin(), tasks.end(), kcache.offloaded_task_list.begin(),
Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/llvm/llvm_aot_module_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class LlvmAotModuleBuilder : public AotModuleBuilder {

protected:
void add_per_backend(const std::string &identifier, Kernel *kernel) override;
virtual CodeGenLLVM::CompiledData compile_kernel(Kernel *kernel) = 0;
virtual LLVMCompiledData compile_kernel(Kernel *kernel) = 0;

void add_field_per_backend(const std::string &identifier,
const SNode *rep_snode,
Expand Down
6 changes: 3 additions & 3 deletions taichi/runtime/wasm/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ void AotModuleBuilderImpl::dump(const std::string &output_dir,
// Adds one kernel to the WASM AOT module being built.
//
// A single llvm::Module is shared across all kernels: the builder's current
// module is handed to modulegen(), which appends this kernel into it (or
// creates it on the first call), and ownership is taken back afterwards.
//
// @param identifier Kernel identifier (unused by the WASM backend here —
//                   TODO confirm whether it should key name_list_).
// @param kernel     The kernel to compile into the shared module.
void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
                                           Kernel *kernel) {
  auto compiled = CodeGenWASM(kernel, nullptr).modulegen(std::move(module_));
  // Reclaim the (possibly freshly created) module for the next kernel.
  module_ = std::move(compiled.module);

  // Record every generated function name for the AOT dump step.
  for (auto &task : compiled.tasks) {
    name_list_.push_back(task.name);
  }
}

void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier,
Expand Down