Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LLVM] Remove the "ret_void" argument of AddFunction #15127

Merged
merged 1 commit into from
Jun 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_amdgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {

void AddFunction(const GlobalVar& gvar, const PrimFunc& f) final {
// add function as void return value
CodeGenLLVM::AddFunctionInternal(gvar, f, true);
CodeGenLLVM::AddFunctionInternal(gvar, f);
function_->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
std::ostringstream attr;
attr << "1," << DetectROCMmaxThreadsPerBlock();
Expand Down
22 changes: 11 additions & 11 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,11 @@ void CodeGenLLVM::InitTarget() {
}

llvm::Function* CodeGenLLVM::DeclareFunction(const GlobalVar& gvar, const PrimFunc& f) {
return this->DeclareFunctionInternal(gvar, f, false);
return this->DeclareFunctionInternal(gvar, f);
}

void CodeGenLLVM::AddFunction(const GlobalVar& gvar, const PrimFunc& f) {
this->AddFunctionInternal(gvar, f, false);
this->AddFunctionInternal(gvar, f);
}

void CodeGenLLVM::InitFuncState() {
Expand All @@ -258,8 +258,7 @@ std::tuple<std::string, llvm::Function::LinkageTypes> CodeGenLLVM::GetLinkage(
return {symbol_name, llvm::Function::PrivateLinkage};
}

llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& func,
bool ret_void) {
llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& func) {
if (auto it = functions_.find(gvar.get()); it != functions_.end()) {
return it->second;
}
Expand All @@ -275,11 +274,9 @@ llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, cons
alias_var_set_.insert(param.get());
}
}
// TODO(tvm-team):
// Update the function type to respect the ret_type field of f.
// Once we allow more flexibility in the PrimFunc.

llvm::FunctionType* ftype =
llvm::FunctionType::get(ret_void ? t_void_ : t_int_, param_types, false);
llvm::FunctionType::get(GetLLVMType(func->ret_type), param_types, false);

auto [symbol_name, linkage_type] = GetLinkage(gvar, func);

Expand All @@ -297,10 +294,10 @@ llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, cons
return function;
}

void CodeGenLLVM::AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f, bool ret_void) {
void CodeGenLLVM::AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f) {
this->InitFuncState();

function_ = DeclareFunctionInternal(gvar, f, ret_void);
function_ = DeclareFunctionInternal(gvar, f);

// set var map and align information
auto arg_it = function_->arg_begin();
Expand Down Expand Up @@ -341,7 +338,10 @@ void CodeGenLLVM::AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f,
#endif

EmitDebugLocation(f->span);
if (ret_void) {

if (IsVoidType(f->ret_type)) {
// All other return types are handled when encountering
// builtin::ret().
builder_->CreateRetVoid();
} else {
builder_->CreateRet(ConstInt32(0));
Expand Down
4 changes: 2 additions & 2 deletions src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,9 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
std::tuple<std::string, llvm::Function::LinkageTypes> GetLinkage(const GlobalVar& gvar,
const PrimFunc& func);

llvm::Function* DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& f, bool ret_void);
llvm::Function* DeclareFunctionInternal(const GlobalVar& gvar, const PrimFunc& f);

void AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f, bool ret_void);
void AddFunctionInternal(const GlobalVar& gvar, const PrimFunc& f);

// Create extern call
llvm::CallInst* CreateCallExtern(llvm::Type* ret, const std::string& name,
Expand Down
4 changes: 2 additions & 2 deletions src/target/llvm/codegen_nvptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ class CodeGenNVPTX : public CodeGenLLVM {
public:
llvm::Function* DeclareFunction(const GlobalVar& gvar, const PrimFunc& f) final {
// add function as void return value
return CodeGenLLVM::DeclareFunctionInternal(gvar, f, true);
return CodeGenLLVM::DeclareFunctionInternal(gvar, f);
}
void AddFunction(const GlobalVar& gvar, const PrimFunc& f) final {
// add function as void return value
CodeGenLLVM::AddFunctionInternal(gvar, f, true);
CodeGenLLVM::AddFunctionInternal(gvar, f);
// annotate as kernel function
llvm::LLVMContext* ctx = llvm_target_->GetContext();
module_->getOrInsertNamedMetadata("nvvm.annotations")
Expand Down