From b0b51fdee1a04e74d6f944d5647ba739ee527695 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Wed, 5 May 2021 12:45:06 +0300 Subject: [PATCH 1/5] [METAL] Split kernels and compile them separately Refactor Metal module to build each kernel separately. This should help to avoid a potential problem when the generated blockSize is more than maxTotalThreadsPerThreadgroup. --- src/runtime/metal/metal_module.mm | 95 +++++++++++++++++++----------- src/target/source/codegen_metal.cc | 25 ++++---- 2 files changed, 75 insertions(+), 45 deletions(-) diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index e22caa21a81e..7df1da6ce03d 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -43,7 +43,9 @@ public: explicit MetalModuleNode(std::string data, std::string fmt, std::unordered_map fmap, std::string source) - : data_(data), fmt_(fmt), fmap_(fmap), source_(source) {} + : data_(data), fmt_(fmt), fmap_(fmap), source_(source) { + parsed_kernels_ = SplitKernels(GetSource(fmt_)); + } const char* type_key() const final { return "metal"; } PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; @@ -71,6 +73,31 @@ void SaveToBinary(dmlc::Stream* stream) final { return ""; } } + + std::unordered_map SplitKernels(std::string source) const { + std::unordered_map split_kernels; + if (source.size()) { + std::string del{"// Function: "}; + size_t end = 0; + size_t begin = source.find(del); + ICHECK(begin != std::string::npos) << "The Metal module expects a kernel delimited " + << "source from code generation, but no kernel " + << "delimiter was found."; + while (end != std::string::npos) { + begin += del.size(); + end = source.find('\n', begin); + std::string func_name = source.substr(begin, end - begin); + begin = ++end; + end = source.find(del, begin); + std::string func_source = + source.substr(begin, (end == std::string::npos) ?
end : end - begin); + split_kernels.insert({func_name, func_source}); + begin = end; + } + } + return split_kernels; + } + // get a from primary context in device_id id GetPipelineState(size_t device_id, const std::string& func_name) { metal::MetalWorkspace* w = metal::MetalWorkspace::Global(); @@ -85,37 +112,37 @@ void SaveToBinary(dmlc::Stream* stream) final { if (it != e.smap.end()) return it->second; // compile NSError* err_msg = nil; - if (e.lib == nil) { - if (fmt_ == "metal") { - MTLCompileOptions* opts = [MTLCompileOptions alloc]; - opts.languageVersion = MTLLanguageVersion2_3; - opts.fastMathEnabled = YES; - // opts = nil; - e.lib = [w->devices[device_id] - newLibraryWithSource:[NSString stringWithUTF8String:data_.c_str()] - options:opts - error:&err_msg]; - [opts dealloc]; - if (e.lib == nil) { - LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String]; - } - if (err_msg != nil) { - LOG(INFO) << "Warning: " << [[err_msg localizedDescription] UTF8String]; - } - } else { - // Build from library. 
- auto q = dispatch_queue_create("q", DISPATCH_QUEUE_SERIAL); - auto data = dispatch_data_create(data_.c_str(), data_.length(), q, - ^{ - }); - e.lib = [w->devices[device_id] newLibraryWithData:data error:&err_msg]; - if (err_msg != nil || e.lib == nil) { - LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String]; - } + id lib = nil; + if (fmt_ == "metal") { + MTLCompileOptions* opts = [MTLCompileOptions alloc]; + opts.languageVersion = MTLLanguageVersion2_3; + opts.fastMathEnabled = YES; + // opts = nil; + std::string source = parsed_kernels_[func_name]; + lib = + [w->devices[device_id] newLibraryWithSource:[NSString stringWithUTF8String:source.c_str()] + options:opts + error:&err_msg]; + [opts dealloc]; + if (lib == nil) { + LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String]; + } + if (err_msg != nil) { + LOG(INFO) << "Warning: " << [[err_msg localizedDescription] UTF8String]; + } + } else { + // Build from library. 
+ auto q = dispatch_queue_create("q", DISPATCH_QUEUE_SERIAL); + std::string source = parsed_kernels_[func_name]; + auto data = dispatch_data_create(source.c_str(), source.length(), q, + ^{ + }); + lib = [w->devices[device_id] newLibraryWithData:data error:&err_msg]; + if (err_msg != nil || lib == nil) { + LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String]; } } - id f = - [e.lib newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]]; + id f = [lib newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]]; ICHECK(f != nil) << "cannot find function " << func_name; id state = [w->devices[device_id] newComputePipelineStateWithFunction:f error:&err_msg]; @@ -123,6 +150,7 @@ void SaveToBinary(dmlc::Stream* stream) final { << " for function " << func_name << [[err_msg localizedDescription] UTF8String]; [f release]; + [lib release]; // The state.threadExecutionWidth can change dynamically according // to the resource constraint in kernel, so it is not strictly hold // Turn of warp aware optimziation for now. @@ -135,13 +163,10 @@ void SaveToBinary(dmlc::Stream* stream) final { private: // device specific entry struct DeviceEntry { - // library - id lib = nil; // state cache; - std::unordered_map > smap; + std::unordered_map> smap; ~DeviceEntry() { - if (lib != nil) [lib release]; for (auto&& kv : smap) { [kv.second release]; } @@ -159,6 +184,8 @@ void SaveToBinary(dmlc::Stream* stream) final { std::vector finfo_; // internal mutex when updating the module std::mutex mutex_; + // parsed kernel data + std::unordered_map parsed_kernels_; }; // a wrapped function class to get packed func. 
diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 270d81f218ea..9a496cd7db4c 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -305,27 +305,30 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT runtime::Module BuildMetal(IRModule mod, Target target) { using tvm::runtime::Registry; bool output_ssa = false; - CodeGenMetal cg; - cg.Init(output_ssa); + std::stringstream code; + std::stringstream source; + std::string fmt = "metal"; for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodeGenMetal: Can only take PrimFunc"; + code << "// Function: " << kv.first->name_hint << std::endl; + CodeGenMetal cg; + cg.Init(output_ssa); auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); + std::string fsource = cg.Finish(); + if (const auto* f = Registry::Get("tvm_callback_metal_compile")) { + source << fsource; + fsource = (*f)(fsource).operator std::string(); + fmt = "metallib"; + } + code << fsource; } - std::string code = cg.Finish(); - std::string fmt = "metal"; - std::string source = ""; - if (const auto* f = Registry::Get("tvm_callback_metal_compile")) { - source = code; - code = (*f)(code).operator std::string(); - fmt = "metallib"; - } - return MetalModuleCreate(code, fmt, ExtractFuncInfo(mod), source); + return MetalModuleCreate(code.str(), fmt, ExtractFuncInfo(mod), source.str()); } TVM_REGISTER_GLOBAL("target.build.metal").set_body_typed(BuildMetal); From 4938d70ba36e1127b411fdd961938c475a07b936 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Thu, 6 May 2021 08:14:52 +0300 Subject: [PATCH 2/5] Add backward compatible behaviour --- src/runtime/metal/metal_module.mm | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git 
a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 7df1da6ce03d..e8f14d5d357c 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -78,11 +78,8 @@ void SaveToBinary(dmlc::Stream* stream) final { std::unordered_map split_kernels; if (source.size()) { std::string del{"// Function: "}; - size_t end = 0; size_t begin = source.find(del); - ICHECK(begin != std::string::npos) << "The Metal module expects a kernel delimited " - << "source from code generation, but no kernel " - << "delimiter was found."; + size_t end = begin; while (end != std::string::npos) { begin += del.size(); end = source.find('\n', begin); @@ -113,12 +110,17 @@ void SaveToBinary(dmlc::Stream* stream) final { // compile NSError* err_msg = nil; id lib = nil; + std::string source; + auto kernel = parsed_kernels_.find(func_name); + if (kernel != parsed_kernels_.end()) + source = kernel->second; + else + source = data_; if (fmt_ == "metal") { MTLCompileOptions* opts = [MTLCompileOptions alloc]; opts.languageVersion = MTLLanguageVersion2_3; opts.fastMathEnabled = YES; // opts = nil; - std::string source = parsed_kernels_[func_name]; lib = [w->devices[device_id] newLibraryWithSource:[NSString stringWithUTF8String:source.c_str()] options:opts @@ -133,7 +135,6 @@ void SaveToBinary(dmlc::Stream* stream) final { } else { // Build from library. 
auto q = dispatch_queue_create("q", DISPATCH_QUEUE_SERIAL); - std::string source = parsed_kernels_[func_name]; auto data = dispatch_data_create(source.c_str(), source.length(), q, ^{ }); From 070bba4d40adb347b85bf696be01c15f6c0aac63 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Wed, 12 May 2021 10:20:29 +0300 Subject: [PATCH 3/5] Move SplitKernels to a common file --- src/runtime/metal/metal_module.mm | 25 +++------------ src/runtime/opencl/opencl_common.h | 8 ----- src/runtime/opencl/opencl_module.cc | 39 ++++------------------- src/runtime/source_utils.cc | 49 +++++++++++++++++++++++++++++ src/runtime/source_utils.h | 44 ++++++++++++++++++++++++++ 5 files changed, 103 insertions(+), 62 deletions(-) create mode 100644 src/runtime/source_utils.cc create mode 100644 src/runtime/source_utils.h diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index e8f14d5d357c..9ec26ec35f21 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -30,6 +30,7 @@ #include "../file_utils.h" #include "../meta_data.h" #include "../pack_args.h" +#include "../source_utils.h" #include "../thread_storage_scope.h" #include "metal_common.h" @@ -74,27 +75,6 @@ void SaveToBinary(dmlc::Stream* stream) final { } } - std::unordered_map SplitKernels(std::string source) const { - std::unordered_map split_kernels; - if (source.size()) { - std::string del{"// Function: "}; - size_t begin = source.find(del); - size_t end = begin; - while (end != std::string::npos) { - begin += del.size(); - end = source.find('\n', begin); - std::string func_name = source.substr(begin, end - begin); - begin = ++end; - end = source.find(del, begin); - std::string func_source = - source.substr(begin, (end == std::string::npos) ? 
end : end - begin); - split_kernels.insert({func_name, func_source}); - begin = end; - } - } - return split_kernels; - } - // get a from primary context in device_id id GetPipelineState(size_t device_id, const std::string& func_name) { metal::MetalWorkspace* w = metal::MetalWorkspace::Global(); @@ -112,6 +92,9 @@ void SaveToBinary(dmlc::Stream* stream) final { id lib = nil; std::string source; auto kernel = parsed_kernels_.find(func_name); + // If the kernel is not found in parsed_kernels_, the kernels were emitted together without an + // explicit separator. In that case we fall back to data_, which contains all kernels. This is + // done for backward compatibility. if (kernel != parsed_kernels_.end()) source = kernel->second; else diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index 93420feec805..ad2040af8cd5 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -326,14 +326,6 @@ class OpenCLModuleNode : public ModuleNode { cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, const std::string& func_name, const KTRefEntry& e); - /* - * \brief Splits the provided serialized source file into separate - * source for each kernel primitive. - * \param source The serialized program source file (fmt: cl) - * \return Mapping from primitive name to kernel source - */ - std::unordered_map SplitKernels(std::string source) const; - private: // The workspace, need to keep reference to use it in destructor. // In case of static destruction order problem.
diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 6543b1de460c..40aa66651f7b 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -29,6 +29,7 @@ #include #include +#include "../source_utils.h" #include "opencl_common.h" namespace tvm { @@ -188,6 +189,11 @@ void OpenCLModuleNode::Init() { // split into source artifacts for each kernel parsed_kernels_ = SplitKernels(GetSource("cl")); + ICHECK(!parsed_kernels_.empty()) << "The OpenCL module expects a kernel delimited " + << "source from code generation, but no kernel " + << "delimiter was found."; + ICHECK_EQ(workspace_->num_registered_kernels, parsed_kernels_.size()) + << "The number of registered kernels does not match number of parsed kernel sources"; // zero initialize cl_program pointers for each device kernel for (auto& kv : parsed_kernels_) { programs_.insert({kv.first, std::vector(workspace_->devices.size(), nullptr)}); @@ -242,39 +248,6 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre return kernel; } -std::unordered_map OpenCLModuleNode::SplitKernels( - std::string source) const { - std::unordered_map split_kernels; - if (source.size()) { - std::string del{"// Function: "}; - size_t end; - size_t begin = source.find(del); - ICHECK(begin != std::string::npos) << "The OpenCL module expects a kernel delimited " - << "source from code generation, but no kernel " - << "delimiter was found."; - for (size_t num_kernels = 0; num_kernels < workspace_->num_registered_kernels; num_kernels++) { - begin += del.size(); - end = source.find('\n', begin); - std::string func_name = source.substr(begin, end - begin); - begin = ++end; - // std::string::substr returns either start of next kernel - // or std::string::npos, in the latter case substr returns - // all characters until the end of the source string. 
- end = source.find(del, begin); - std::string func_source = - source.substr(begin, (end == std::string::npos) ? end : end - begin); - split_kernels.insert({func_name, func_source}); - begin = end; - if (end == std::string::npos) { - break; - } - } - } - ICHECK_EQ(workspace_->num_registered_kernels, split_kernels.size()) - << "The number of registered kernels does not match number of parsed kernel sources"; - return split_kernels; -} - Module OpenCLModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string source) { auto n = make_object(data, fmt, fmap, source); diff --git a/src/runtime/source_utils.cc b/src/runtime/source_utils.cc new file mode 100644 index 000000000000..e1cf94e52e18 --- /dev/null +++ b/src/runtime/source_utils.cc @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file source_utils.cc + */ +#include "source_utils.h" + +namespace tvm { +namespace runtime { + +std::unordered_map<std::string, std::string> SplitKernels(std::string source, + std::string delimiter) { + std::unordered_map<std::string, std::string> split_kernels; + if (!source.empty() && !delimiter.empty()) { + size_t begin = source.find(delimiter); + size_t end = begin; + while (end != std::string::npos) { + begin += delimiter.size(); + end = source.find('\n', begin); + std::string func_name = source.substr(begin, end - begin); + begin = (end == std::string::npos) ? source.size() : end + 1; // name may end the source + end = source.find(delimiter, begin); + std::string func_source = + source.substr(begin, (end == std::string::npos) ? end : end - begin); + split_kernels.insert({func_name, func_source}); + begin = end; + } + } + return split_kernels; +} +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/source_utils.h b/src/runtime/source_utils.h new file mode 100644 index 000000000000..5476585b945c --- /dev/null +++ b/src/runtime/source_utils.h @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file source_utils.h + * \brief Minimum source manipulation utils for runtime.
+ */ + +#ifndef TVM_RUNTIME_SOURCE_UTILS_H_ +#define TVM_RUNTIME_SOURCE_UTILS_H_ + +#include <string> +#include <unordered_map> + +namespace tvm { +namespace runtime { +/*! + * \brief Split the source file into separate kernels using the specified delimiter. + * \param source The source code of the kernels. + * \param delimiter The marker preceding each kernel name; the name runs to the end of that line. + * \return Mapping from primitive name to kernel source; empty if no delimiter is found. + */ +std::unordered_map<std::string, std::string> SplitKernels(std::string source, + std::string delimiter = "// Function: "); +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_SOURCE_UTILS_H_ From 8041a29174d76b587e89ba457409d4193f69cc6b Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Wed, 12 May 2021 12:44:46 +0300 Subject: [PATCH 4/5] Remove using GetSource --- src/runtime/metal/metal_module.mm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 9ec26ec35f21..2920c60449d1 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -45,7 +45,7 @@ explicit MetalModuleNode(std::string data, std::string fmt, std::unordered_map fmap, std::string source) : data_(data), fmt_(fmt), fmap_(fmap), source_(source) { - parsed_kernels_ = SplitKernels(GetSource(fmt_)); + parsed_kernels_ = SplitKernels(data); } const char* type_key() const final { return "metal"; } From df55c33380d2794ae5a9db142baf5727aec2ea61 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Thu, 13 May 2021 09:06:19 +0300 Subject: [PATCH 5/5] Update the bundles --- apps/android_camera/app/src/main/jni/tvm_runtime.h | 1 + apps/android_rpc/app/src/main/jni/tvm_runtime.h | 1 + golang/src/tvm_runtime_pack.cc | 1 + 3 files changed, 3 insertions(+) diff --git a/apps/android_camera/app/src/main/jni/tvm_runtime.h b/apps/android_camera/app/src/main/jni/tvm_runtime.h index f3c7efd08b5c..07a812c4b840 100644 --- a/apps/android_camera/app/src/main/jni/tvm_runtime.h +++
b/apps/android_camera/app/src/main/jni/tvm_runtime.h @@ -57,6 +57,7 @@ #ifdef TVM_OPENCL_RUNTIME #include "../src/runtime/opencl/opencl_device_api.cc" #include "../src/runtime/opencl/opencl_module.cc" +#include "../src/runtime/source_utils.cc" #endif #ifdef TVM_VULKAN_RUNTIME diff --git a/apps/android_rpc/app/src/main/jni/tvm_runtime.h b/apps/android_rpc/app/src/main/jni/tvm_runtime.h index 5dcd823929ca..c0bd7070412a 100644 --- a/apps/android_rpc/app/src/main/jni/tvm_runtime.h +++ b/apps/android_rpc/app/src/main/jni/tvm_runtime.h @@ -62,6 +62,7 @@ #ifdef TVM_OPENCL_RUNTIME #include "../src/runtime/opencl/opencl_device_api.cc" #include "../src/runtime/opencl/opencl_module.cc" +#include "../src/runtime/source_utils.cc" #endif #ifdef TVM_VULKAN_RUNTIME diff --git a/golang/src/tvm_runtime_pack.cc b/golang/src/tvm_runtime_pack.cc index e6d8e74ae0f9..c2add6a36734 100644 --- a/golang/src/tvm_runtime_pack.cc +++ b/golang/src/tvm_runtime_pack.cc @@ -68,3 +68,4 @@ // Uncomment the following lines to enable OpenCL // #include "../../src/runtime/opencl/opencl_device_api.cc" // #include "../../src/runtime/opencl/opencl_module.cc" +// #include "../../src/runtime/source_utils.cc"