Generate kernel host call (#916)
* generate kernel host call
SunNy820828449 authored Aug 31, 2022
1 parent 3d688db commit af52b7a
Showing 27 changed files with 279 additions and 246 deletions.
9 changes: 6 additions & 3 deletions cinn/auto_schedule/measure/simple_runner_test.cc
@@ -105,12 +105,15 @@ TEST_F(TestSimpleRunner, MeasureWithSpecifiedArgs) {

TEST_F(TestSimpleRunner, TimeMeasured) {
// set up a BuildResult object with one instruction of the `sleep` function
auto sleep_fn = [](void*, int32_t) { std::this_thread::sleep_for(std::chrono::microseconds(100)); };
void (*sleep_fn)(void*, int32_t) = [](void*, int32_t) -> void {
std::this_thread::sleep_for(std::chrono::microseconds(100));
};
BuildResult build_result;
build_result.compiled_scope = nullptr;
std::vector<std::unique_ptr<Instruction>> instructions;
instructions.emplace_back(new Instruction(target, nullptr, {}, {"empty_placeholder"}, "sleep_fn"));
instructions.back()->SetLoweredFunc(sleep_fn);
instructions.emplace_back(
new Instruction(common::DefaultHostTarget(), nullptr, {}, {"empty_placeholder"}, "sleep_fn"));
instructions.back()->SetLoweredFunc(reinterpret_cast<void*>(sleep_fn));
instructions.back()->Finalize();
build_result.runtime_program.reset(new hlir::framework::Program(nullptr, std::move(instructions)));

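The test update above follows the new SetLoweredFunc(void*) interface: the sleep stub is declared as a plain function pointer so it can be passed through a void*. A minimal self-contained sketch of the same pattern, with illustrative names that are not part of the repository:

#include <chrono>
#include <cstdint>
#include <thread>

// Stand-in for Instruction::SetLoweredFunc(void*) followed by a run: store an untyped
// pointer, then cast it back to the known host-function signature before calling it.
void RunLoweredFunc(void* fn, void* args, int32_t num_args) {
  auto typed_fn = reinterpret_cast<void (*)(void*, int32_t)>(fn);
  typed_fn(args, num_args);
}

int main() {
  // A capture-less lambda converts to a plain function pointer, which is what the
  // reinterpret_cast<void*> in the test above relies on.
  void (*sleep_fn)(void*, int32_t) = [](void*, int32_t) -> void {
    std::this_thread::sleep_for(std::chrono::microseconds(100));
  };
  RunLoweredFunc(reinterpret_cast<void*>(sleep_fn), nullptr, 0);
  return 0;
}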
7 changes: 3 additions & 4 deletions cinn/backends/codegen_cuda_dev_test.cc
@@ -1500,8 +1500,6 @@ TEST(CodeGenCUDA, jit_host_call_cuda_kernel) {
LOG(INFO) << "fn_kernel: " << fn_kernel;

GlobalSymbolRegistry::Global().RegisterFn("fn_kernel_ptr_", reinterpret_cast<void*>(&fn_kernel));
GlobalSymbolRegistry::Global().RegisterVar("fn_kernel_stream_ptr_", stream);

// compile host
{
auto jit = SimpleJIT::Create();
@@ -1527,11 +1525,12 @@ TEST(CodeGenCUDA, jit_host_call_cuda_kernel) {
CUDA_CALL(cudaDeviceSynchronize());

// call the kernel
auto comp = reinterpret_cast<void (*)(cinn_pod_value_t*, int)>(fn_ptr);
auto comp = reinterpret_cast<void (*)(void*, int, void*)>(fn_ptr);

auto args = common::ArgsBuilder().Add(M.as_int32()).Add(A_buf).Add(B_buf).Add(C_buf).Build();

comp(args.data(), args.size());
void* stream = nullptr;
comp(args.data(), args.size(), stream);

CUDA_CALL(cudaDeviceSynchronize());

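The changed lines in this test capture the new host-side calling convention: the stub no longer takes cinn_pod_value_t* directly, and the CUDA stream is an explicit third argument instead of a fn_kernel_stream_ptr_ global registered with the JIT. A hedged caller-side sketch (fn_ptr and args come from the surrounding test; a null stream is assumed to select the default CUDA stream):

// Old convention: two arguments, stream resolved through a registered global symbol.
//   auto comp = reinterpret_cast<void (*)(cinn_pod_value_t*, int)>(fn_ptr);
//   comp(args.data(), args.size());
// New convention: the stream is threaded through the call explicitly.
auto comp = reinterpret_cast<void (*)(void*, int, void*)>(fn_ptr);
comp(args.data(), args.size(), /*stream=*/nullptr);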
151 changes: 71 additions & 80 deletions cinn/backends/codegen_cuda_host.cc
@@ -16,7 +16,9 @@

#include <algorithm>
#include <string>
#include <unordered_map>

#include "cinn/backends/codegen_cuda_util.h"
#include "cinn/backends/extern_func_emitter_builtin.h"
#include "cinn/backends/extern_func_jit_register.h"
#include "cinn/backends/llvm/llvm_util.h"
@@ -28,32 +30,9 @@ namespace backends {
const int kArgsArrayMaxLen = 20;

llvm::Value* CodeGenCUDA_Host::LowerGPUKernelLauncher(const ir::_LoweredFunc_* func) {
CHECK(func->cuda_axis_info.valid());
/* The current function definiton is
* void fn(cinn_pod_value_t* args, int num_args) {
* Call(fn_kernel, args, num_args);
* }
* will lower to
* void fn(cinn_pod_value_t* args, int num_args) { // num_args is ignored here.
* // NOTE the num_args is unnecessary here, but it should follow the pattern of CINN function.
* cinn_call_cuda_kernel(fn_kernel_ptr, args, grid_dim, block_dim, fn_kernel_stream);
* }
*
* NOTE the global variables related to CUDA in LLVM module are
* 1. fn_kernel_ptr, the pointer to the compiled kernel function returned by CUDA driver
* 2. fn_kernel_stream, the CUDA stream this kernel should launch on.
*/

// hard-code here to verify it is a simple CUDA host function.
// @{
auto body = func->body;
auto* block = body.As<ir::Block>();
CHECK(block);

CHECK_EQ(block->stmts.size(), 1UL);
auto* call = block->stmts[0].As<ir::Call>();
CHECK(call);
// @}
auto body = func->body;
auto* call_ir = body.As<ir::Call>();
CHECK(call_ir);

// Create the function
// @{
@@ -66,70 +45,82 @@ llvm::Value* CodeGenCUDA_Host::LowerGPUKernelLauncher(const ir::_LoweredFunc_* f
std::transform(function->arg_begin(), function->arg_end(), std::back_inserter(ll_function_args), [](auto& arg) {
return std::addressof(arg);
});
// @}

llvm::BasicBlock* entry = llvm::BasicBlock::Create(
/*Context=*/b_->getContext(),
/*Name=*/"entry",
/*Parent=*/function,
/*InsertBefore=*/nullptr);
b_->SetInsertPoint(entry);
// @}

// Get the arguments of the function.
// @{
auto* ll_args = ll_function_args[0];
auto* ll_args_count = ll_function_args[1];
CHECK_EQ(ll_args->getType(), ll_cinn_pod_p_ty()); // cinn_pod_value_t* args
CHECK_EQ(ll_args_count->getType(), ll_int32_ty()); // int32

auto* ll_num_args_copied = b_->CreateAlloca(ll_int32_ty(), nullptr, "num_args_copied");
Store(ll_args_count, ll_num_args_copied);
SetVar(std::string(ll_num_args_copied->getName()), ll_num_args_copied);

const std::string& func_arg0_name = func->args[0].name();
CHECK(LLVM_WillVarLowerAsPointer(func_arg0_name))
<< "Variable [" << func_arg0_name << "] should have a name like someting will be lower to a pointer";
SetVar(func->args[0].var_arg()->name, ll_args);
// @}

const std::string kernel_ptr_global_var_name = GenKernelPtrVarName(func->name);
const std::string kernel_stream_global_var_name = GenKernelStreamVarName(func->name);
// set global variables to reference the [kernel_ptr] and [kernel_stream] for this kernel
SetVar(kernel_ptr_global_var_name, m_->getOrInsertGlobal(kernel_ptr_global_var_name, ll_void_p_ty()));
SetVar(kernel_stream_global_var_name, m_->getOrInsertGlobal(kernel_stream_global_var_name, ll_void_p_ty()));

{ // create a new Call node for the ExternFunctionEmitter
Var args_var(func->args[0].var_arg()->name, type_of<cinn_pod_value_t*>()); // pass *args directly to kernel
Var kernel_fn_ptr_var(kernel_ptr_global_var_name, type_of<void*>());
Var kernel_stream_var(kernel_stream_global_var_name, type_of<void*>());

auto new_call_node = ir::Call::Make(Void(),
runtime::intrinsic::call_cuda_kernel,
{
kernel_fn_ptr_var, // kernel_fn
args_var, // args
Var(std::string(ll_num_args_copied->getName()), type_of<int32_t>()),
Expr(func->cuda_axis_info.grid_dim(0)), // grid_x
Expr(func->cuda_axis_info.grid_dim(1)), // grid_y
Expr(func->cuda_axis_info.grid_dim(2)), // grid_z
Expr(func->cuda_axis_info.block_dim(0)), // block_x
Expr(func->cuda_axis_info.block_dim(1)), // block_y
Expr(func->cuda_axis_info.block_dim(2)), // block_z
kernel_stream_var // stream
},
{},
ir::CallType::Extern,
ir::FunctionRef(),
0);

auto emitter_id = ExternFuncID{backend_llvm_host, runtime::intrinsic::call_cuda_kernel};
const auto& fn_name = ExternFunctionEmitterRegistry::Global().Lookup(emitter_id);
CHECK(!fn_name.empty()) << "No extern function emitter called " << emitter_id;
ExternFunctionLLVMEmitter emitter(fn_name);
emitter.BindCodeGen(this);
emitter.Emit(new_call_node.As<ir::Call>());
auto* kernel_args = ll_function_args[0];
auto* kernel_args_count = ll_function_args[1];
llvm::Value* kernel_stream = nullptr;
if (ll_function_args.size() == 3) {
kernel_stream = ll_function_args[2];
CHECK_EQ(kernel_stream->getType(), ll_void_p_ty()); // void* stream
}
CHECK_EQ(kernel_args->getType(), ll_void_p_ty()); // void* args
CHECK_EQ(kernel_args_count->getType(), ll_int32_ty()); // int32

std::unordered_map<std::string, llvm::Value*> global_args = {
{KERNEL_ARGS, kernel_args}, {KERNEL_ARGS_NUM, kernel_args_count}, {KERNEL_STREAM, kernel_stream}};

auto ret_type = CinnTypeToLLVMType(Void(), m_);
std::vector<llvm::Type*> args_type;
for (auto r_arg : call_ir->read_args) {
if (r_arg.is_var()) {
if (r_arg.as_var()->type().is_string()) {
args_type.push_back(CinnTypeToLLVMType(type_of<void*>(), m_));
} else if (r_arg.as_var()->type().is_cpp_handle()) {
args_type.push_back(CinnTypeToLLVMType(type_of<void*>(), m_));
} else if (r_arg.as_var()->type().is_int(32)) {
args_type.push_back(CinnTypeToLLVMType(type_of<int32_t>(), m_));
} else {
CINN_NOT_IMPLEMENTED;
}
} else {
if (r_arg.type().is_bool()) {
args_type.push_back(CinnTypeToLLVMType(type_of<bool>(), m_));
} else if (r_arg.type().is_int(32)) {
args_type.push_back(CinnTypeToLLVMType(type_of<int32_t>(), m_));
} else if (r_arg.type().is_float(32)) {
args_type.push_back(CinnTypeToLLVMType(type_of<float>(), m_));
} else {
CINN_NOT_IMPLEMENTED;
}
}
}
auto func_type = llvm::FunctionType::get(ret_type, args_type, false);
auto call_func = m_->getOrInsertFunction(call_ir->name, func_type);

std::vector<llvm::Value*> call_args;
for (auto& r_arg : call_ir->read_args) {
if (r_arg.is_var()) {
if (r_arg.as_var()->type().is_string()) {
auto kvalue = m_->getOrInsertGlobal(r_arg.as_var()->name + "_ptr_", b_->getInt8PtrTy());
call_args.push_back(b_->CreateLoad(b_->getInt8PtrTy(), kvalue, r_arg.as_var()->name + "_ptr_load"));
} else if (r_arg.as_var()->type().is_cpp_handle() || r_arg.as_var()->type().is_int(32)) {
CHECK(global_args.count(r_arg.as_var()->name));
call_args.push_back(global_args[r_arg.as_var()->name]);
} else {
CINN_NOT_IMPLEMENTED;
}
} else {
if (r_arg.type().is_bool()) {
call_args.push_back(b_->getInt1(r_arg.as_bool()));
} else if (r_arg.type().is_int(32)) {
call_args.push_back(b_->getInt32(r_arg.as_int32()));
} else if (r_arg.type().is_float(32)) {
call_args.push_back(llvm::ConstantFP::get(b_->getFloatTy(), llvm::APFloat(r_arg.as_float())));
} else {
CINN_NOT_IMPLEMENTED;
}
}
}

b_->CreateCall(call_func, call_args);
RetVoid();

return function;
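Taken together, the rewritten LowerGPUKernelLauncher no longer hard-codes the launcher body. It reads the ir::Call already placed in the host function by the module-split pass, maps the kernel_args / kernel_args_num / kernel_stream operands to the stub's own LLVM parameters via global_args, loads string operands (the kernel name) from <name>_ptr_ globals, and materializes scalar operands as constants. A rough C-style sketch of the stub this emits for a kernel fn, with placeholder launch dimensions rather than the real values:

// fn_kernel_ptr_ stands for the global symbol that Compiler::CompileCudaModule
// registers with the CUDA driver function handle for fn_kernel.
extern void* fn_kernel_ptr_;

void fn(void* kernel_args, int kernel_args_num, void* kernel_stream) {
  cinn_call_cuda_kernel(fn_kernel_ptr_, kernel_args, kernel_args_num,
                        /*grid_x=*/1, /*grid_y=*/1, /*grid_z=*/1,
                        /*block_x=*/1, /*block_y=*/1, /*block_z=*/1,
                        kernel_stream);
}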
13 changes: 1 addition & 12 deletions cinn/backends/codegen_cuda_host.h
@@ -34,19 +34,8 @@ class CodeGenCUDA_Host : public CodeGenLLVM {
explicit CodeGenCUDA_Host(llvm::Module *m, llvm::IRBuilder<> *b, const std::shared_ptr<SymbolTable> &vars = nullptr)
: CodeGenLLVM(m, b, vars) {}

static std::string GenKernelPtrVarName(const std::string &kernel_name) { return kernel_name + "_kernel_ptr_"; }
static std::string GenKernelStreamVarName(const std::string &kernel_name) {
return kernel_name + "_kernel_stream_ptr_";
}

using CodeGenLLVM::Visit;

llvm::Value *Visit(const ir::_LoweredFunc_ *func) override {
if (func->is_gpu_host()) {
return LowerGPUKernelLauncher(func);
}
return CodeGenLLVM::Visit(func);
}
llvm::Value *Visit(const ir::_LoweredFunc_ *func) override { return LowerGPUKernelLauncher(func); }

private:
/**
4 changes: 0 additions & 4 deletions cinn/backends/codegen_cuda_util.cc
@@ -26,9 +26,5 @@ std::tuple<ir::Module, ir::Module> SplitCudaAndHostModule(ir::Module module) {
return visitor(&expr);
}

bool detail::CollectHostFunctionVisitor::IsCudaFunction(const ir::_LoweredFunc_* func) {
return func->device_api == ir::DeviceAPI::GPU;
}

} // namespace backends
} // namespace cinn
49 changes: 33 additions & 16 deletions cinn/backends/codegen_cuda_util.h
@@ -28,6 +28,10 @@
namespace cinn {
namespace backends {

#define KERNEL_ARGS "kernel_args"
#define KERNEL_ARGS_NUM "kernel_args_num"
#define KERNEL_STREAM "kernel_stream"

/**
 * Split a CINN Module into two separate modules, one contains the host functions, the other contains the device
* kernels.
@@ -83,21 +87,36 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
* \endcode
*/
Expr CreateHostFunctionGivenDeviceKernel(const ir::_LoweredFunc_* func) {
std::vector<Expr> args;
// std::vector<Expr> args;
// NOTE the suffix `__ptr` makes this argument lower to a pointer in LLVM backend.
args.push_back(Var("args__ptr", type_of<cinn_pod_value_t*>()));
args.push_back(Var("num_args", type_of<int32_t>()));

auto call =
ir::Call::Make(Void(), GenDeviceKernelName(func->name), args, {}, ir::CallType::Extern, ir::FunctionRef(), 0);
Expr body = ir::Block::Make({call});

std::vector<ir::Argument> host_func_args;
host_func_args.emplace_back(args[0], ir::Argument::IO::kOutput);
host_func_args.emplace_back(args[1], ir::Argument::IO::kOutput);
auto host_func = ir::_LoweredFunc_::Make(func->name, host_func_args, body, {});
host_func->cuda_axis_info = func->cuda_axis_info;
return host_func;
// args.push_back(Var("args__ptr", type_of<cinn_pod_value_t*>()));
// args.push_back(Var("num_args", type_of<int32_t>()));
ir::Var kernel_ptr(GenDeviceKernelName(func->name), type_of<std::string>());
ir::Var kernel_args(KERNEL_ARGS, type_of<void*>());
ir::Var kernel_args_num(KERNEL_ARGS_NUM, type_of<int>());
ir::Var kernel_stream(KERNEL_STREAM, type_of<void*>());

auto call_extern_api = ir::Call::Make(Void(),
runtime::intrinsic::call_cuda_kernel,
{kernel_ptr,
kernel_args,
kernel_args_num,
Expr(func->cuda_axis_info.grid_dim(0)), // grid_x
Expr(func->cuda_axis_info.grid_dim(1)), // grid_y
Expr(func->cuda_axis_info.grid_dim(2)), // grid_z
Expr(func->cuda_axis_info.block_dim(0)), // block_x
Expr(func->cuda_axis_info.block_dim(1)), // block_y
Expr(func->cuda_axis_info.block_dim(2)), // block_z
kernel_stream},
{},
ir::CallType::Extern,
ir::FunctionRef(),
0);
std::vector<ir::Argument> arguments = {ir::Argument(kernel_args, ir::Argument::IO::kOutput),
ir::Argument(kernel_args_num, ir::Argument::IO::kInput),
ir::Argument(kernel_stream, ir::Argument::IO::kOutput)};

return ir::_LoweredFunc_::Make(func->name, arguments, call_extern_api, {});
}

Expr CreateDeviceFunctionGivenDeviceKernel(Expr expr) {
@@ -109,8 +128,6 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {

inline std::string GenDeviceKernelName(const std::string& fn) { return fn + "_kernel"; }

bool IsCudaFunction(const ir::_LoweredFunc_* func);

private:
ir::Module::Builder host_module_builder;
ir::Module::Builder device_module_builder;
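The KERNEL_ARGS / KERNEL_ARGS_NUM / KERNEL_STREAM macros are the contract between this pass and CodeGenCUDA_Host: the host function's arguments are created here under those names, and LowerGPUKernelLauncher later resolves ir::Call operands with those names back to the stub's own parameters through its global_args map. A hedged pseudo-IR view of the host function produced for a device kernel fn_kernel (grid_x, block_x, and so on stand in for the values taken from func->cuda_axis_info):

function fn (kernel_args, kernel_args_num, kernel_stream)
{
  call_cuda_kernel(fn_kernel, kernel_args, kernel_args_num,
                   grid_x, grid_y, grid_z,
                   block_x, block_y, block_z,
                   kernel_stream)
}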
11 changes: 5 additions & 6 deletions cinn/backends/compiler.cc
@@ -35,9 +35,9 @@ using ir::Module;

static constexpr int DebugLogMaxLen = 30000;

void Compiler::Build(const Module& module, const std::string& code, void* stream) {
void Compiler::Build(const Module& module, const std::string& code) {
if (target_.arch == Target::Arch::NVGPU) {
CompileCudaModule(module, code, stream);
CompileCudaModule(module, code);
} else if (target_.arch == Target::Arch::X86) {
CompileX86Module(module);
} else {
@@ -72,7 +72,7 @@ void Compiler::BuildDefault(const Module& module) {
}
}

void Compiler::CompileCudaModule(const Module& module, const std::string& code, void* stream) {
void Compiler::CompileCudaModule(const Module& module, const std::string& code) {
#ifdef CINN_WITH_CUDA
auto _host_module_device_module_ = SplitCudaAndHostModule(module); // NOLINT
auto& host_module = std::get<0>(_host_module_device_module_);
@@ -117,7 +117,6 @@ void Compiler::CompileCudaModule(const Module& module, const std::string& code,
CHECK(fn_kernel);

symbols.RegisterVar(kernel_fn_name + "_ptr_", reinterpret_cast<void*>(fn_kernel));
symbols.RegisterVar(kernel_fn_name + "_stream_ptr_", static_cast<cudaStream_t>(stream));
}

engine_ = ExecutionEngine::Create(ExecutionOptions(), std::move(symbols));
@@ -132,10 +131,10 @@ void Compiler::CompileX86Module(const Module& module) { engine_->Link<CodeGenX86

void Compiler::ExportObject(const std::string& path) { engine_->ExportObject(path); }

lower_func_ptr_t Compiler::Lookup(absl::string_view fn_name) {
void* Compiler::Lookup(absl::string_view fn_name) {
CHECK(engine_);
if (engine_->Lookup(fn_name) != nullptr) {
return reinterpret_cast<lower_func_ptr_t>(engine_->Lookup(fn_name));
return engine_->Lookup(fn_name);
}
return nullptr;
}
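Because the host stub's signature is no longer fixed to lower_func_ptr_t, Lookup now returns an untyped pointer and the caller performs the cast. A caller-side sketch with illustrative names (compiler, args, and the looked-up function name are assumptions; the three-argument cast matches the new host-stub ABI shown earlier):

void* fn_ptr = compiler->Lookup("fn");  // nullptr if the symbol does not exist
CHECK(fn_ptr);
auto host_fn = reinterpret_cast<void (*)(void*, int, void*)>(fn_ptr);
host_fn(args.data(), static_cast<int>(args.size()), /*stream=*/nullptr);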
6 changes: 3 additions & 3 deletions cinn/backends/compiler.h
@@ -39,7 +39,7 @@ class Compiler final {
/**
* Compile and link to a CINN module.
*/
void Build(const ir::Module& module, const std::string& code = "", void* stream = nullptr);
void Build(const ir::Module& module, const std::string& code = "");

void ExportObject(const std::string& path);

@@ -51,10 +51,10 @@
* Retrieve a function by \p fn_name.
* @return function address or null if not exists.
*/
lower_func_ptr_t Lookup(absl::string_view fn_name);
void* Lookup(absl::string_view fn_name);

private:
void CompileCudaModule(const ir::Module& module, const std::string& code = "", void* stream = nullptr);
void CompileCudaModule(const ir::Module& module, const std::string& code = "");

void CompileX86Module(const ir::Module& module);
