Skip to content

Commit

Permalink
Migrate C Interface API Generation to C++
Browse files Browse the repository at this point in the history
Using the new name transformations added in #9088, the C interface API is now generated in C++ rather than in Python.

Follow up PRs will clean up any remaining name transformation inconsistencies.

Fixes #8792
  • Loading branch information
Mousius committed Sep 24, 2021
1 parent cca7176 commit 7c1bead
Show file tree
Hide file tree
Showing 11 changed files with 385 additions and 116 deletions.
101 changes: 0 additions & 101 deletions python/tvm/micro/interface_api.py

This file was deleted.

15 changes: 14 additions & 1 deletion python/tvm/micro/model_library_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@
import tarfile
import typing

import tvm
from tvm.ir.type import TupleType
from .._ffi import get_global_func
from .interface_api import generate_c_interface_header
from ..contrib import utils
from ..driver import build_module
from ..runtime import ndarray as _nd
from ..relay.backend import executor_factory
from ..relay.backend.name_transforms import to_c_variable_style, prefix_generated_name
from ..relay import param_dict
from ..tir import expr

Expand All @@ -43,6 +44,18 @@ class UnsupportedInModelLibraryFormatError(Exception):
"""Raised when export_model_library_format does not support the given Module tree."""


def generate_c_interface_header(module_name, inputs, outputs, include_path):
"""Generate C Interface header to be included in MLF"""
mangled_name = to_c_variable_style(prefix_generated_name(module_name))
metadata_header = os.path.join(include_path, f"{mangled_name}.h")

interface_c_create = tvm._ffi.get_global_func("runtime.InterfaceCCreate")
interface_c_module = interface_c_create(module_name, inputs, outputs)

with open(metadata_header, "w") as header_file:
header_file.write(interface_c_module.get_source())


def _populate_codegen_dir(mod, codegen_dir: str, module_name: str = None):
"""Populate the codegen sub-directory as part of a Model Library Format export.
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/relay/backend/name_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,18 @@ def to_c_variable_style(original_name: str):
return _backend.ToCVariableStyle(original_name)


def to_c_constant_style(original_name: str):
"""Transform a name to the C constant style assuming it is
appropriately constructed using the prefixing functions
Parameters
----------
original_name : str
Original name to transform
"""
return _backend.ToCConstantStyle(original_name)


def prefix_name(names: Union[List[str], str]):
"""Apply TVM-specific prefix to a function name
Expand Down
15 changes: 8 additions & 7 deletions src/relay/backend/name_transforms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ std::string ToCVariableStyle(const std::string& original_name) {
return variable_name;
}

std::string ToCConstantStyle(const std::string& original_name) {
std::string constant_name = ToCVariableStyle(original_name);

std::transform(constant_name.begin(), constant_name.end(), constant_name.begin(), ::toupper);
return constant_name;
}

std::string CombineNames(const Array<String>& names) {
std::stringstream combine_stream;
ICHECK(!names.empty());
Expand All @@ -77,22 +84,16 @@ std::string CombineNames(const Array<String>& names) {
std::string SanitiseName(const std::string& name) {
ICHECK(!name.empty());

auto multipleSeparators = [](char before, char after) {
return before == '_' && before == after;
};
auto isNotAlnum = [](char c) { return !std::isalnum(c); };
std::string sanitised_input = name;
std::replace_if(sanitised_input.begin(), sanitised_input.end(), isNotAlnum, '_');

sanitised_input.erase(
std::unique(sanitised_input.begin(), sanitised_input.end(), multipleSeparators),
sanitised_input.end());

return sanitised_input;
}

TVM_REGISTER_GLOBAL("relay.backend.ToCFunctionStyle").set_body_typed(ToCFunctionStyle);
TVM_REGISTER_GLOBAL("relay.backend.ToCVariableStyle").set_body_typed(ToCVariableStyle);
TVM_REGISTER_GLOBAL("relay.backend.ToCConstantStyle").set_body_typed(ToCConstantStyle);
TVM_REGISTER_GLOBAL("relay.backend.PrefixName").set_body_typed(PrefixName);
TVM_REGISTER_GLOBAL("relay.backend.PrefixGeneratedName").set_body_typed(PrefixGeneratedName);
TVM_REGISTER_GLOBAL("relay.backend.SanitiseName").set_body_typed(SanitiseName);
Expand Down
11 changes: 11 additions & 0 deletions src/relay/backend/name_transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
* ToCVariableStyle(PrefixGeneratedName(CombineNames({"model", "Devices"})))
* // tvmgen_model_devices
*
* ToCConstantStyle(PrefixGeneratedName(CombineNames({"model", "Devices"})))
* // TVMGEN_MODEL_DEVICES
*
*/

#include <tvm/runtime/container/array.h>
Expand Down Expand Up @@ -68,6 +71,14 @@ std::string ToCFunctionStyle(const std::string& original_name);
*/
std::string ToCVariableStyle(const std::string& original_name);

/*!
* \brief Transform a name to the C constant style assuming it is
* appropriately constructed using the prefixing functions
* \param name Original name
* \return Transformed function in the C constant style
*/
std::string ToCConstantStyle(const std::string& original_name);

/*!
* \brief Combine names together for use as a generated name
* \param names Vector of strings to combine
Expand Down
132 changes: 132 additions & 0 deletions src/target/source/interface_c.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file interface_c.cc
* \brief Generates a C interface header for a given modules inputs and outputs
*/

#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

#include <string>

#include "../../relay/backend/name_transforms.h"

namespace tvm {
namespace codegen {

using runtime::PackedFunc;
using namespace tvm::relay::backend;

class InterfaceCNode : public runtime::ModuleNode {
public:
InterfaceCNode(std::string module_name, Array<String> inputs, Array<String> outputs)
: module_name_(module_name), inputs_(inputs), outputs_(outputs) {}
const char* type_key() const { return "h"; }

std::string GetSource(const std::string& format) final {
std::stringstream code;
std::string mangled_module_name = ToCVariableStyle(PrefixGeneratedName({module_name_}));
std::string header_guard_name = ToCConstantStyle(PrefixGeneratedName({module_name_}));

EmitUpperHeaderGuard(code, header_guard_name);
EmitBrief(code, "Input tensor pointers");
EmitStruct(code, mangled_module_name, "inputs", inputs_);
EmitBrief(code, "Output tensor pointers");
EmitStruct(code, mangled_module_name, "outputs", outputs_);
EmitRunFunction(code, mangled_module_name);
EmitLowerHeaderGuard(code, header_guard_name);

return code.str();
}

PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
return PackedFunc(nullptr);
}

private:
void EmitUpperHeaderGuard(std::stringstream& code_stream, const std::string& header_guard_name) {
code_stream << "#ifndef " << header_guard_name << "_H_\n"
<< "#define " << header_guard_name << "_H_\n"
<< "#include <stdint.h>\n\n"
<< "#ifdef __cplusplus\n"
<< "extern \"C\" {\n"
<< "#endif\n\n";
}

void EmitLowerHeaderGuard(std::stringstream& code_stream, const std::string& header_guard_name) {
code_stream << "\n#ifdef __cplusplus\n"
<< "}\n"
<< "#endif\n\n"
<< "#endif // " << header_guard_name << "_H_\n";
}

void EmitBrief(std::stringstream& code_stream, const std::string& description) {
code_stream << "/*!\n"
<< " * \\brief " << description << " for TVM module \"" << module_name_ << "\" \n"
<< " */\n";
}

void EmitStruct(std::stringstream& code_stream, const std::string& mangled_module_name,
const std::string& suffix, Array<String> properties) {
code_stream << "struct " << mangled_module_name << "_" << suffix << " {\n";

std::vector<std::string> sanitised_properties;
for (const String& property : properties) {
std::string sanitised_property = SanitiseName(property);
ICHECK(std::find(sanitised_properties.begin(), sanitised_properties.end(),
sanitised_property) == sanitised_properties.end())
<< "Sanitized input tensor name clash" << sanitised_property;
code_stream << " void* " << sanitised_property << ";\n";
sanitised_properties.push_back(sanitised_property);
}
code_stream << "};\n\n";
}

void EmitRunFunction(std::stringstream& code_stream, const std::string& mangled_module_name) {
code_stream << "/*!\n"
<< " * \\brief entrypoint function for TVM module \"" << module_name_ << "\"\n"
<< " * \\param inputs Input tensors for the module \n"
<< " * \\param outputs Output tensors for the module \n"
<< " */\n"
<< "int32_t " << mangled_module_name << "_run(\n"
<< " struct " << mangled_module_name << "_inputs* inputs,\n"
<< " struct " << mangled_module_name << "_outputs* outputs\n"
<< ");\n";
}

std::string module_name_;
Array<String> inputs_;
Array<String> outputs_;
};

runtime::Module InterfaceCCreate(std::string module_name, Array<String> inputs,
Array<String> outputs) {
auto n = make_object<InterfaceCNode>(module_name, inputs, outputs);
return runtime::Module(n);
}

TVM_REGISTER_GLOBAL("runtime.InterfaceCCreate").set_body_typed(InterfaceCCreate);

} // namespace codegen
} // namespace tvm
11 changes: 9 additions & 2 deletions tests/cpp/name_transforms_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ TEST(NameTransforms, ToCVariableStyle) {
EXPECT_THROW(ToCVariableStyle(""), InternalError);
}

TEST(NameTransforms, ToCConstantStyle) {
ASSERT_EQ(ToCConstantStyle("TVM_Woof"), "TVM_WOOF");
ASSERT_EQ(ToCConstantStyle("TVM_woof"), "TVM_WOOF");
ASSERT_EQ(ToCConstantStyle("TVM_woof_Woof"), "TVM_WOOF_WOOF");
EXPECT_THROW(ToCConstantStyle(""), InternalError);
}

TEST(NameTransforms, PrefixName) {
ASSERT_EQ(PrefixName({"Woof"}), "TVM_Woof");
ASSERT_EQ(PrefixName({"woof"}), "TVM_woof");
Expand Down Expand Up @@ -69,10 +76,10 @@ TEST(NameTransforms, CombineNames) {
}

TEST(NameTransforms, SanitiseName) {
ASSERT_EQ(SanitiseName("+_+ "), "_");
ASSERT_EQ(SanitiseName("+_+ "), "____");
ASSERT_EQ(SanitiseName("input+"), "input_");
ASSERT_EQ(SanitiseName("input-"), "input_");
ASSERT_EQ(SanitiseName("input++"), "input_");
ASSERT_EQ(SanitiseName("input++"), "input__");
ASSERT_EQ(SanitiseName("woof:1"), "woof_1");
EXPECT_THROW(SanitiseName(""), InternalError);
}
Expand Down
Loading

0 comments on commit 7c1bead

Please sign in to comment.