Skip to content

Commit

Permalink
Move SplitKernels to a common file
Browse files Browse the repository at this point in the history
  • Loading branch information
echuraev committed May 12, 2021
1 parent 4938d70 commit 8791c68
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 62 deletions.
25 changes: 4 additions & 21 deletions src/runtime/metal/metal_module.mm
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <mutex>
#include <string>
#include "../file_utils.h"
#include "../source_utils.h"
#include "../meta_data.h"
#include "../pack_args.h"
#include "../thread_storage_scope.h"
Expand Down Expand Up @@ -74,27 +75,6 @@ void SaveToBinary(dmlc::Stream* stream) final {
}
}

std::unordered_map<std::string, std::string> SplitKernels(std::string source) const {
std::unordered_map<std::string, std::string> split_kernels;
if (source.size()) {
std::string del{"// Function: "};
size_t begin = source.find(del);
size_t end = begin;
while (end != std::string::npos) {
begin += del.size();
end = source.find('\n', begin);
std::string func_name = source.substr(begin, end - begin);
begin = ++end;
end = source.find(del, begin);
std::string func_source =
source.substr(begin, (end == std::string::npos) ? end : end - begin);
split_kernels.insert({func_name, func_source});
begin = end;
}
}
return split_kernels;
}

// get a from primary context in device_id
id<MTLComputePipelineState> GetPipelineState(size_t device_id, const std::string& func_name) {
metal::MetalWorkspace* w = metal::MetalWorkspace::Global();
Expand All @@ -112,6 +92,9 @@ void SaveToBinary(dmlc::Stream* stream) final {
id<MTLLibrary> lib = nil;
std::string source;
auto kernel = parsed_kernels_.find(func_name);
// If we cannot find this kernel in parsed_kernels_, it means that all kernels going together
// without explicit separator. In this case we use data_ with all kernels. It done for backward
// compatibility.
if (kernel != parsed_kernels_.end())
source = kernel->second;
else
Expand Down
8 changes: 0 additions & 8 deletions src/runtime/opencl/opencl_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,14 +326,6 @@ class OpenCLModuleNode : public ModuleNode {
cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t,
const std::string& func_name, const KTRefEntry& e);

/*
* \brief Splits the provided serialized source file into separate
* source for each kernel primitive.
* \param source The serialized program source file (fmt: cl)
* \return Mapping from primitive name to kernel source
*/
std::unordered_map<std::string, std::string> SplitKernels(std::string source) const;

private:
// The workspace, need to keep reference to use it in destructor.
// In case of static destruction order problem.
Expand Down
39 changes: 6 additions & 33 deletions src/runtime/opencl/opencl_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <unordered_map>
#include <vector>

#include "../source_utils.h"
#include "opencl_common.h"

namespace tvm {
Expand Down Expand Up @@ -188,6 +189,11 @@ void OpenCLModuleNode::Init() {

// split into source artifacts for each kernel
parsed_kernels_ = SplitKernels(GetSource("cl"));
ICHECK(!parsed_kernels_.empty()) << "The OpenCL module expects a kernel delimited "
<< "source from code generation, but no kernel "
<< "delimiter was found.";
ICHECK_EQ(workspace_->num_registered_kernels, parsed_kernels_.size())
<< "The number of registered kernels does not match number of parsed kernel sources";
// zero initialize cl_program pointers for each device kernel
for (auto& kv : parsed_kernels_) {
programs_.insert({kv.first, std::vector<cl_program>(workspace_->devices.size(), nullptr)});
Expand Down Expand Up @@ -242,39 +248,6 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre
return kernel;
}

std::unordered_map<std::string, std::string> OpenCLModuleNode::SplitKernels(
std::string source) const {
std::unordered_map<std::string, std::string> split_kernels;
if (source.size()) {
std::string del{"// Function: "};
size_t end;
size_t begin = source.find(del);
ICHECK(begin != std::string::npos) << "The OpenCL module expects a kernel delimited "
<< "source from code generation, but no kernel "
<< "delimiter was found.";
for (size_t num_kernels = 0; num_kernels < workspace_->num_registered_kernels; num_kernels++) {
begin += del.size();
end = source.find('\n', begin);
std::string func_name = source.substr(begin, end - begin);
begin = ++end;
// std::string::substr returns either start of next kernel
// or std::string::npos, in the latter case substr returns
// all characters until the end of the source string.
end = source.find(del, begin);
std::string func_source =
source.substr(begin, (end == std::string::npos) ? end : end - begin);
split_kernels.insert({func_name, func_source});
begin = end;
if (end == std::string::npos) {
break;
}
}
}
ICHECK_EQ(workspace_->num_registered_kernels, split_kernels.size())
<< "The number of registered kernels does not match number of parsed kernel sources";
return split_kernels;
}

Module OpenCLModuleCreate(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
auto n = make_object<OpenCLModuleNode>(data, fmt, fmap, source);
Expand Down
50 changes: 50 additions & 0 deletions src/runtime/source_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file source_utils.cc
*/
#include "source_utils.h"

namespace tvm {
namespace runtime {

std::unordered_map<std::string, std::string> SplitKernels(std::string source,
std::string delimiter) {
std::unordered_map<std::string, std::string> split_kernels;
if (source.size()) {
size_t begin = source.find(delimiter);
size_t end = begin;
while (end != std::string::npos) {
begin += delimiter.size();
end = source.find('\n', begin);
std::string func_name = source.substr(begin, end - begin);
begin = ++end;
end = source.find(delimiter, begin);
std::string func_source =
source.substr(begin, (end == std::string::npos) ? end : end - begin);
split_kernels.insert({func_name, func_source});
begin = end;
}
}
return split_kernels;
}
} // namespace runtime
} // namespace tvm

44 changes: 44 additions & 0 deletions src/runtime/source_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file source_utils.h
* \brief Minimum source manipulation utils for runtime.
*/

#ifndef TVM_RUNTIME_SOURCE_UTILS_H_
#define TVM_RUNTIME_SOURCE_UTILS_H_

#include <string>
#include <unordered_map>

namespace tvm {
namespace runtime {
/*!
* \brief Split the source file on separate kernels by specified delimiter.
* \param source The source code of the kernels.
* \param delimiter The delimiter which is using for splitting kernels.
* \return Mapping from primitive name to kernel source
*/
std::unordered_map<std::string, std::string> SplitKernels(std::string source, std::string delimiter = "// Function: ");
}
} // namespace tvm

#endif // TVM_RUNTIME_SOURCE_UTILS_H_

0 comments on commit 8791c68

Please sign in to comment.