diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index b4377119e4c7..d74a529595a2 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -315,6 +315,14 @@ class OpenCLModuleNode : public ModuleNode { cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, const std::string& func_name, const KTRefEntry& e); + /* + * \brief Splits the provided serialized source file into separate + * source for each kernel primitive. + * \param source The serialized program source file (fmt: cl) + * \return Mapping from primitive name to kernel source + */ + std::unordered_map SplitKernels(std::string source) const; + private: // The workspace, need to keep reference to use it in destructor. // In case of static destruction order problem. @@ -329,14 +337,14 @@ class OpenCLModuleNode : public ModuleNode { std::mutex build_lock_; // The OpenCL source. std::string source_; - // the binary data - cl_program program_{nullptr}; - // build info - std::vector device_built_flag_; + // Mapping from primitive name to cl program for each device. + std::unordered_map> programs_; // kernel id cache std::unordered_map kid_map_; // kernels build so far. std::vector kernels_; + // parsed kernel data + std::unordered_map parsed_kernels_; }; } // namespace runtime diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 8c22c3c8cb23..6543b1de460c 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -105,8 +105,13 @@ OpenCLModuleNode::~OpenCLModuleNode() { for (cl_kernel k : kernels_) { OPENCL_CALL(clReleaseKernel(k)); } - if (program_) { - OPENCL_CALL(clReleaseProgram(program_)); + // free the programs + for (auto& kv : programs_) { + for (auto& program : kv.second) { + if (program) { + OPENCL_CALL(clReleaseProgram(program)); + } + } } } @@ -166,7 +171,6 @@ std::string OpenCLModuleNode::GetSource(const std::string& format) { void OpenCLModuleNode::Init() { workspace_ = GetGlobalWorkspace(); workspace_->Init(); - device_built_flag_.resize(workspace_->devices.size(), false); // initialize the kernel id, need to lock global table. std::lock_guard lock(workspace_->mu); for (const auto& kv : fmap_) { @@ -181,28 +185,34 @@ void OpenCLModuleNode::Init() { e.version = workspace_->timestamp++; kid_map_[key] = e; } + + // split into source artifacts for each kernel + parsed_kernels_ = SplitKernels(GetSource("cl")); + // zero initialize cl_program pointers for each device kernel + for (auto& kv : parsed_kernels_) { + programs_.insert({kv.first, std::vector(workspace_->devices.size(), nullptr)}); + } } cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, const std::string& func_name, const KTRefEntry& e) { std::lock_guard lock(build_lock_); int device_id = t->device.device_id; - if (!device_built_flag_[device_id]) { + if (programs_[func_name][device_id] == nullptr) { // create program if (fmt_ == "cl") { - if (program_ == nullptr) { - const char* s = data_.c_str(); - size_t len = data_.length(); - cl_int err; - program_ = clCreateProgramWithSource(w->context, 1, &s, &len, &err); - OPENCL_CHECK_ERROR(err); - } + const char* s = parsed_kernels_[func_name].c_str(); + size_t len = parsed_kernels_[func_name].length(); + cl_int err; + programs_[func_name][device_id] = clCreateProgramWithSource(w->context, 1, &s, &len, &err); + OPENCL_CHECK_ERROR(err); } else if (fmt_ == "xclbin" || fmt_ == "awsxclbin" || fmt_ == "aocx") { const unsigned char* s = (const unsigned char*)data_.c_str(); size_t len = data_.length(); cl_int err; cl_device_id dev = w->devices[device_id]; - program_ = clCreateProgramWithBinary(w->context, 1, &dev, &len, &s, NULL, &err); + programs_[func_name][device_id] = + clCreateProgramWithBinary(w->context, 1, &dev, &len, &s, NULL, &err); OPENCL_CHECK_ERROR(err); } else { LOG(FATAL) << "Unknown OpenCL format " << fmt_; @@ -210,20 +220,21 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre // build program cl_int err; cl_device_id dev = w->devices[device_id]; - err = clBuildProgram(program_, 1, &dev, nullptr, nullptr, nullptr); + err = clBuildProgram(programs_[func_name][device_id], 1, &dev, nullptr, nullptr, nullptr); if (err != CL_SUCCESS) { size_t len; std::string log; - clGetProgramBuildInfo(program_, dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, &len); + clGetProgramBuildInfo(programs_[func_name][device_id], dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, + &len); log.resize(len); - clGetProgramBuildInfo(program_, dev, CL_PROGRAM_BUILD_LOG, len, &log[0], nullptr); - LOG(FATAL) << "OpenCL build error for device=" << dev << log; + clGetProgramBuildInfo(programs_[func_name][device_id], dev, CL_PROGRAM_BUILD_LOG, len, + &log[0], nullptr); + LOG(FATAL) << "OpenCL build error for device=" << dev << "\n" << log; } - device_built_flag_[device_id] = true; } // build kernel cl_int err; - cl_kernel kernel = clCreateKernel(program_, func_name.c_str(), &err); + cl_kernel kernel = clCreateKernel(programs_[func_name][device_id], func_name.c_str(), &err); OPENCL_CHECK_ERROR(err); t->kernel_table[e.kernel_id].kernel = kernel; t->kernel_table[e.kernel_id].version = e.version; @@ -231,6 +242,39 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre return kernel; } +std::unordered_map OpenCLModuleNode::SplitKernels( + std::string source) const { + std::unordered_map split_kernels; + if (source.size()) { + std::string del{"// Function: "}; + size_t end; + size_t begin = source.find(del); + ICHECK(begin != std::string::npos) << "The OpenCL module expects a kernel delimited " + << "source from code generation, but no kernel " + << "delimiter was found."; + for (size_t num_kernels = 0; num_kernels < workspace_->num_registered_kernels; num_kernels++) { + begin += del.size(); + end = source.find('\n', begin); + std::string func_name = source.substr(begin, end - begin); + begin = ++end; + // std::string::substr returns either start of next kernel + // or std::string::npos, in the latter case substr returns + // all characters until the end of the source string. + end = source.find(del, begin); + std::string func_source = + source.substr(begin, (end == std::string::npos) ? end : end - begin); + split_kernels.insert({func_name, func_source}); + begin = end; + if (end == std::string::npos) { + break; + } + } + } + ICHECK_EQ(workspace_->num_registered_kernels, split_kernels.size()) + << "The number of registered kernels does not match number of parsed kernel sources"; + return split_kernels; +} + Module OpenCLModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string source) { auto n = make_object(data, fmt, fmap, source); diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index f72f3f265511..edb614d9c122 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -283,23 +283,27 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // N runtime::Module BuildOpenCL(IRModule mod, Target target) { using tvm::runtime::Registry; bool output_ssa = false; - CodeGenOpenCL cg; - cg.Init(output_ssa); + std::stringstream code; + const auto* fpostproc = Registry::Get("tvm_callback_opencl_postproc"); for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodeGenOpenCL: Can only take PrimFunc"; + code << "// Function: " << kv.first->name_hint << std::endl; + CodeGenOpenCL cg; + cg.Init(output_ssa); auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); + std::string fsource = cg.Finish(); + if (fpostproc) { + fsource = (*fpostproc)(fsource).operator std::string(); + } + code << fsource; } - std::string code = cg.Finish(); - if (const auto* f = Registry::Get("tvm_callback_opencl_postproc")) { - code = (*f)(code).operator std::string(); - } - return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(mod), code); + return OpenCLModuleCreate(code.str(), "cl", ExtractFuncInfo(mod), code.str()); } TVM_REGISTER_GLOBAL("target.build.opencl").set_body_typed(BuildOpenCL);