Skip to content

Commit

Permalink
Fix failed symbol lookup from CUDA module
Browse files Browse the repository at this point in the history
  • Loading branch information
PGZXB committed Nov 22, 2022
1 parent 5191a2c commit e8eb7f0
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 24 deletions.
6 changes: 6 additions & 0 deletions taichi/codegen/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,12 @@ FunctionType CUDAModuleToFunctionConverter::convert(
auto &mod = data.module;
auto &tasks = data.tasks;
#ifdef TI_WITH_CUDA
for (const auto &task : tasks) {
llvm::Function *func = mod->getFunction(task.name);
TI_ASSERT(func);
tlctx_->mark_function_as_cuda_kernel(func, task.block_dim);
}

auto jit = tlctx_->jit.get();
auto cuda_module =
jit->add_module(std::move(mod), executor_->get_config()->gpu_max_reg);
Expand Down
9 changes: 0 additions & 9 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2638,15 +2638,6 @@ LLVMCompiledTask TaskCodeGenLLVM::run_compilation() {
emit_to_module();
eliminate_unused_functions();

if (config.arch == Arch::cuda) {
// CUDA specific metadata
for (const auto &task : offloaded_tasks) {
llvm::Function *func = module->getFunction(task.name);
TI_ASSERT(func);
tlctx->mark_function_as_cuda_kernel(func, task.block_dim);
}
}

return {std::move(offloaded_tasks), std::move(module),
std::move(used_tree_ids), std::move(struct_for_tls_sizes)};
}
Expand Down
15 changes: 8 additions & 7 deletions taichi/runtime/llvm/llvm_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -663,16 +663,17 @@ void TaichiLLVMContext::insert_nvvm_annotation(llvm::Function *func,
float addrspace(1)*,
float addrspace(1)*)* @kernel, !"kernel", i32 1}
*******************************************************************/
auto ctx = get_this_thread_context();
llvm::Metadata *md_args[] = {llvm::ValueAsMetadata::get(func),
MDString::get(*ctx, key),
llvm::ValueAsMetadata::get(get_constant(val))};
auto *llvm_mod = func->getParent();
auto *ctx = &llvm_mod->getContext();
llvm::Metadata *md_args[] = {
llvm::ValueAsMetadata::get(func), MDString::get(*ctx, key),
llvm::ValueAsMetadata::get(llvm::ConstantInt::get(
*ctx, llvm::APInt(sizeof(val) * 8, (uint64)val,
std::is_signed_v<decltype(val)>)))};

MDNode *md_node = MDNode::get(*ctx, md_args);

func->getParent()
->getOrInsertNamedMetadata("nvvm.annotations")
->addOperand(md_node);
llvm_mod->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(md_node);
}

void TaichiLLVMContext::mark_function_as_cuda_kernel(llvm::Function *func,
Expand Down
19 changes: 11 additions & 8 deletions taichi/runtime/llvm/llvm_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,6 @@ class TaichiLLVMContext {
llvm::Module *module,
std::function<bool(const std::string &)> export_indicator);

void mark_function_as_cuda_kernel(llvm::Function *func, int block_dim = 0);

void fetch_this_thread_struct_module();
llvm::Module *get_this_thread_runtime_module();
llvm::Function *get_runtime_function(const std::string &name);
Expand All @@ -139,29 +137,34 @@ class TaichiLLVMContext {

void add_struct_for_func(llvm::Module *module, int tls_size);

static std::string get_struct_for_func_name(int tls_size);

LLVMCompiledKernel link_compiled_tasks(
std::vector<std::unique_ptr<LLVMCompiledTask>> data_list);

static std::string get_struct_for_func_name(int tls_size);

static void mark_function_as_cuda_kernel(llvm::Function *func,
int block_dim = 0);

private:
std::unique_ptr<llvm::Module> clone_module_to_context(
llvm::Module *module,
llvm::LLVMContext *target_context);

void link_module_with_cuda_libdevice(std::unique_ptr<llvm::Module> &module);

static int num_instructions(llvm::Function *func);

void insert_nvvm_annotation(llvm::Function *func, std::string key, int val);

std::unique_ptr<llvm::Module> clone_module_to_this_thread_context(
llvm::Module *module);

ThreadLocalData *get_this_thread_data();

void update_runtime_jit_module(std::unique_ptr<llvm::Module> module);

static int num_instructions(llvm::Function *func);

static void insert_nvvm_annotation(llvm::Function *func,
std::string key,
int val);

std::unordered_map<std::thread::id, std::unique_ptr<ThreadLocalData>>
per_thread_data_;

Expand Down

0 comments on commit e8eb7f0

Please sign in to comment.