Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OpenCL] Refactor cl_program generation #7834

Merged
merged 6 commits into from
May 1, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions src/runtime/opencl/opencl_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,13 @@ class OpenCLModuleNode : public ModuleNode {
cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t,
const std::string& func_name, const KTRefEntry& e);

/*
* \brief Splits the provided serialized source file into separate
* source for each kernel primitive.
* \param source The serialized program source file (fmt: cl)
csullivan marked this conversation as resolved.
Show resolved Hide resolved
*/
std::unordered_map<std::string, std::string> SplitKernels(std::string source) const;

private:
// The workspace, need to keep reference to use it in destructor.
// In case of static destruction order problem.
Expand All @@ -329,14 +336,14 @@ class OpenCLModuleNode : public ModuleNode {
std::mutex build_lock_;
// The OpenCL source.
std::string source_;
// the binary data
cl_program program_{nullptr};
// build info
std::vector<bool> device_built_flag_;
// Mapping from primitive name to cl program for each device.
std::unordered_map<std::string, std::vector<cl_program>> programs_;
// kernel id cache
std::unordered_map<std::string, KTRefEntry> kid_map_;
// kernels build so far.
std::vector<cl_kernel> kernels_;
// parsed kernel data
std::unordered_map<std::string, std::string> parsed_kernels_;
};

} // namespace runtime
Expand Down
78 changes: 60 additions & 18 deletions src/runtime/opencl/opencl_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,13 @@ OpenCLModuleNode::~OpenCLModuleNode() {
for (cl_kernel k : kernels_) {
OPENCL_CALL(clReleaseKernel(k));
}
if (program_) {
OPENCL_CALL(clReleaseProgram(program_));
// free the programs
for (auto& kv : programs_) {
for (auto& program : kv.second) {
if (program) {
OPENCL_CALL(clReleaseProgram(program));
}
}
}
}

Expand Down Expand Up @@ -166,7 +171,6 @@ std::string OpenCLModuleNode::GetSource(const std::string& format) {
void OpenCLModuleNode::Init() {
workspace_ = GetGlobalWorkspace();
workspace_->Init();
device_built_flag_.resize(workspace_->devices.size(), false);
// initialize the kernel id, need to lock global table.
std::lock_guard<std::mutex> lock(workspace_->mu);
for (const auto& kv : fmap_) {
Expand All @@ -181,56 +185,94 @@ void OpenCLModuleNode::Init() {
e.version = workspace_->timestamp++;
kid_map_[key] = e;
}

// split into source artifacts for each kernel
parsed_kernels_ = SplitKernels(GetSource("cl"));
// zero initialize cl_program pointers for each device kernel
for (auto& kv : parsed_kernels_) {
programs_.insert({kv.first, std::vector<cl_program>(workspace_->devices.size(), nullptr)});
}
}

cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t,
const std::string& func_name, const KTRefEntry& e) {
std::lock_guard<std::mutex> lock(build_lock_);
int device_id = t->device.device_id;
if (!device_built_flag_[device_id]) {
if (programs_[func_name][device_id] == nullptr) {
// create program
if (fmt_ == "cl") {
if (program_ == nullptr) {
const char* s = data_.c_str();
size_t len = data_.length();
cl_int err;
program_ = clCreateProgramWithSource(w->context, 1, &s, &len, &err);
OPENCL_CHECK_ERROR(err);
}
const char* s = parsed_kernels_[func_name].c_str();
size_t len = parsed_kernels_[func_name].length();
cl_int err;
programs_[func_name][device_id] = clCreateProgramWithSource(w->context, 1, &s, &len, &err);
OPENCL_CHECK_ERROR(err);
} else if (fmt_ == "xclbin" || fmt_ == "awsxclbin" || fmt_ == "aocx") {
const unsigned char* s = (const unsigned char*)data_.c_str();
size_t len = data_.length();
cl_int err;
cl_device_id dev = w->devices[device_id];
program_ = clCreateProgramWithBinary(w->context, 1, &dev, &len, &s, NULL, &err);
programs_[func_name][device_id] =
clCreateProgramWithBinary(w->context, 1, &dev, &len, &s, NULL, &err);
OPENCL_CHECK_ERROR(err);
} else {
LOG(FATAL) << "Unknown OpenCL format " << fmt_;
}
// build program
cl_int err;
cl_device_id dev = w->devices[device_id];
err = clBuildProgram(program_, 1, &dev, nullptr, nullptr, nullptr);
err = clBuildProgram(programs_[func_name][device_id], 1, &dev, nullptr, nullptr, nullptr);
if (err != CL_SUCCESS) {
size_t len;
std::string log;
clGetProgramBuildInfo(program_, dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, &len);
clGetProgramBuildInfo(programs_[func_name][device_id], dev, CL_PROGRAM_BUILD_LOG, 0, nullptr,
&len);
log.resize(len);
clGetProgramBuildInfo(program_, dev, CL_PROGRAM_BUILD_LOG, len, &log[0], nullptr);
LOG(FATAL) << "OpenCL build error for device=" << dev << log;
clGetProgramBuildInfo(programs_[func_name][device_id], dev, CL_PROGRAM_BUILD_LOG, len,
&log[0], nullptr);
LOG(FATAL) << "OpenCL build error for device=" << dev << "\n" << log;
}
device_built_flag_[device_id] = true;
}
// build kernel
cl_int err;
cl_kernel kernel = clCreateKernel(program_, func_name.c_str(), &err);
cl_kernel kernel = clCreateKernel(programs_[func_name][device_id], func_name.c_str(), &err);
OPENCL_CHECK_ERROR(err);
t->kernel_table[e.kernel_id].kernel = kernel;
t->kernel_table[e.kernel_id].version = e.version;
kernels_.push_back(kernel);
return kernel;
}

std::unordered_map<std::string, std::string> OpenCLModuleNode::SplitKernels(
std::string source) const {
std::unordered_map<std::string, std::string> split_kernels;
if (source.size()) {
std::string del{"// Function: "};
size_t end;
size_t begin = source.find(del);
ICHECK(begin != std::string::npos) << "The OpenCL module expects a kernel delimited "
<< "source from code generation, but no kernel "
<< "delimiter was found.";
while (true) {
csullivan marked this conversation as resolved.
Show resolved Hide resolved
begin += del.size();
end = source.find('\n', begin);
std::string func_name = source.substr(begin, end - begin);
begin = ++end;
// std::string::substr returns either start of next kernel
// or std::string::npos, in the latter case substr returns
// all characters until the end of the source string.
end = source.find(del, begin);
std::string func_source =
source.substr(begin, (end == std::string::npos) ? end : end - begin);
split_kernels.insert({func_name, func_source});
begin = end;
if (end == std::string::npos) {
break;
}
}
}
return split_kernels;
}

Module OpenCLModuleCreate(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
auto n = make_object<OpenCLModuleNode>(data, fmt, fmap, source);
Expand Down
18 changes: 11 additions & 7 deletions src/target/source/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,23 +283,27 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // N
runtime::Module BuildOpenCL(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenOpenCL cg;
cg.Init(output_ssa);

std::stringstream code;
const auto* fpostproc = Registry::Get("tvm_callback_opencl_postproc");
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only take PrimFunc";
code << "// Function: " << kv.first->name_hint << std::endl;
CodeGenOpenCL cg;
cg.Init(output_ssa);
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f);
std::string fsource = cg.Finish();
if (fpostproc) {
fsource = (*fpostproc)(fsource).operator std::string();
}
code << fsource;
}

std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_opencl_postproc")) {
code = (*f)(code).operator std::string();
}
return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(mod), code);
return OpenCLModuleCreate(code.str(), "cl", ExtractFuncInfo(mod), code.str());
}

TVM_REGISTER_GLOBAL("target.build.opencl").set_body_typed(BuildOpenCL);
Expand Down