From 1a78643a534b51b8b9327b354069c9b33585f9f8 Mon Sep 17 00:00:00 2001 From: Chris Sidebottom Date: Fri, 24 Sep 2021 15:41:19 +0000 Subject: [PATCH] Migrate C Interface API Generation to C++ Using the new name transformations added in #9088, the C interface API is now generated in C++ rather than in Python. Follow up PRs will clean up any remaining name transformation inconsistencies. Fixes #8792 --- python/tvm/micro/interface_api.py | 101 ----------- python/tvm/micro/model_library_format.py | 15 +- python/tvm/relay/backend/name_transforms.py | 12 ++ src/relay/backend/name_transforms.cc | 15 +- src/relay/backend/name_transforms.h | 11 ++ src/target/source/interface_c.cc | 132 ++++++++++++++ tests/cpp/name_transforms_test.cc | 11 +- tests/cpp/target/source/interface_c_test.cc | 184 ++++++++++++++++++++ tests/micro/zephyr/test_zephyr_aot.py | 2 +- tests/python/relay/aot/test_crt_aot.py | 4 +- tests/python/relay/test_name_transforms.py | 14 +- 11 files changed, 385 insertions(+), 116 deletions(-) delete mode 100644 python/tvm/micro/interface_api.py create mode 100644 src/target/source/interface_c.cc create mode 100644 tests/cpp/target/source/interface_c_test.cc diff --git a/python/tvm/micro/interface_api.py b/python/tvm/micro/interface_api.py deleted file mode 100644 index 5a4841f39f7cb..0000000000000 --- a/python/tvm/micro/interface_api.py +++ /dev/null @@ -1,101 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Defines functions for generating a C interface header""" - -# TODO: Currently the Interface API header is generated in Python but the source it references -# is generated in C++. These should be consolidated to generate both header and source in C++ -# and avoid re-implementing logic, such as name sanitising, in the two different languages. -# See https://github.com/apache/tvm/issues/8792 . - -import os -import re - -from tvm.relay.backend.utils import mangle_module_name - - -def _emit_brief(header_file, module_name, description): - header_file.write("/*!\n") - header_file.write(f' * \\brief {description} for TVM module "{module_name}" \n') - header_file.write(" */\n") - - -def generate_c_interface_header(module_name, inputs, outputs, output_path): - """Generates a C interface header for a given modules inputs and outputs - - Parameters - ---------- - module_name : str - Name of the module to be used in defining structs and naming the header - inputs : list[str] - List of module input names to be placed in generated structs - outputs : list[str] - List of module output names to be placed in generated structs - output_path : str - Path to the output folder to generate the header into - - Returns - ------- - str : - Name of the generated file. - """ - mangled_name = mangle_module_name(module_name) - metadata_header = os.path.join(output_path, f"{mangled_name}.h") - with open(metadata_header, "w") as header_file: - header_file.write( - f"#ifndef {mangled_name.upper()}_H_\n" - f"#define {mangled_name.upper()}_H_\n\n" - "#include \n\n" - "#ifdef __cplusplus\n" - 'extern "C" {\n' - "#endif\n\n" - ) - - _emit_brief(header_file, module_name, "Input tensor pointers") - header_file.write(f"struct {mangled_name}_inputs {{\n") - sanitized_names = [] - for input_name in inputs: - sanitized_input_name = re.sub(r"\W", "_", input_name) - if sanitized_input_name in sanitized_names: - raise ValueError(f"Sanitized input tensor name clash: {sanitized_input_name}") - sanitized_names.append(sanitized_input_name) - header_file.write(f" void* {sanitized_input_name};\n") - header_file.write("};\n\n") - - _emit_brief(header_file, module_name, "Output tensor pointers") - header_file.write(f"struct {mangled_name}_outputs {{\n") - for output_name in outputs: - header_file.write(f" void* {output_name};\n") - header_file.write("};\n\n") - - header_file.write( - "/*!\n" - f' * \\brief entrypoint function for TVM module "{module_name}"\n' - " * \\param inputs Input tensors for the module \n" - " * \\param outputs Output tensors for the module \n" - " */\n" - f"int32_t {mangled_name}_run(\n" - f" struct {mangled_name}_inputs* inputs,\n" - f" struct {mangled_name}_outputs* outputs\n" - ");\n" - ) - - header_file.write( - "\n#ifdef __cplusplus\n}\n#endif\n\n" f"#endif // {mangled_name.upper()}_H_\n" - ) - - return metadata_header diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index ed44a3336a521..012e4d78819cc 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -25,13 +25,14 @@ import tarfile import typing +import tvm from tvm.ir.type import TupleType from .._ffi import get_global_func -from .interface_api import generate_c_interface_header from ..contrib import utils from ..driver import build_module from ..runtime import ndarray as _nd from ..relay.backend import executor_factory +from ..relay.backend.name_transforms import to_c_variable_style, prefix_generated_name from ..relay import param_dict from ..tir import expr @@ -43,6 +44,18 @@ class UnsupportedInModelLibraryFormatError(Exception): """Raised when export_model_library_format does not support the given Module tree.""" +def generate_c_interface_header(module_name, inputs, outputs, include_path): + """Generate C Interface header to be included in MLF""" + mangled_name = to_c_variable_style(prefix_generated_name(module_name)) + metadata_header = os.path.join(include_path, f"{mangled_name}.h") + + interface_c_create = tvm._ffi.get_global_func("runtime.InterfaceCCreate") + interface_c_module = interface_c_create(module_name, inputs, outputs) + + with open(metadata_header, "w") as header_file: + header_file.write(interface_c_module.get_source()) + + def _populate_codegen_dir(mod, codegen_dir: str, module_name: str = None): """Populate the codegen sub-directory as part of a Model Library Format export. diff --git a/python/tvm/relay/backend/name_transforms.py b/python/tvm/relay/backend/name_transforms.py index a7bfa9847edd3..d4fadc9cd3c6b 100644 --- a/python/tvm/relay/backend/name_transforms.py +++ b/python/tvm/relay/backend/name_transforms.py @@ -47,6 +47,18 @@ def to_c_variable_style(original_name: str): return _backend.ToCVariableStyle(original_name) +def to_c_constant_style(original_name: str): + """Transform a name to the C constant style assuming it is + appropriately constructed using the prefixing functions + + Parameters + ---------- + original_name : str + Original name to transform + """ + return _backend.ToCConstantStyle(original_name) + + def prefix_name(names: Union[List[str], str]): """Apply TVM-specific prefix to a function name diff --git a/src/relay/backend/name_transforms.cc b/src/relay/backend/name_transforms.cc index f09e6a6346a58..c0b24126e6fc7 100644 --- a/src/relay/backend/name_transforms.cc +++ b/src/relay/backend/name_transforms.cc @@ -60,6 +60,13 @@ std::string ToCVariableStyle(const std::string& original_name) { return variable_name; } +std::string ToCConstantStyle(const std::string& original_name) { + std::string constant_name = ToCVariableStyle(original_name); + + std::transform(constant_name.begin(), constant_name.end(), constant_name.begin(), ::toupper); + return constant_name; +} + std::string CombineNames(const Array& names) { std::stringstream combine_stream; ICHECK(!names.empty()); @@ -77,22 +84,16 @@ std::string CombineNames(const Array& names) { std::string SanitiseName(const std::string& name) { ICHECK(!name.empty()); - auto multipleSeparators = [](char before, char after) { - return before == '_' && before == after; - }; auto isNotAlnum = [](char c) { return !std::isalnum(c); }; std::string sanitised_input = name; std::replace_if(sanitised_input.begin(), sanitised_input.end(), isNotAlnum, '_'); - sanitised_input.erase( - std::unique(sanitised_input.begin(), sanitised_input.end(), multipleSeparators), - sanitised_input.end()); - return sanitised_input; } TVM_REGISTER_GLOBAL("relay.backend.ToCFunctionStyle").set_body_typed(ToCFunctionStyle); TVM_REGISTER_GLOBAL("relay.backend.ToCVariableStyle").set_body_typed(ToCVariableStyle); +TVM_REGISTER_GLOBAL("relay.backend.ToCConstantStyle").set_body_typed(ToCConstantStyle); TVM_REGISTER_GLOBAL("relay.backend.PrefixName").set_body_typed(PrefixName); TVM_REGISTER_GLOBAL("relay.backend.PrefixGeneratedName").set_body_typed(PrefixGeneratedName); TVM_REGISTER_GLOBAL("relay.backend.SanitiseName").set_body_typed(SanitiseName); diff --git a/src/relay/backend/name_transforms.h b/src/relay/backend/name_transforms.h index f94b472f88c88..2bf2616d65e52 100644 --- a/src/relay/backend/name_transforms.h +++ b/src/relay/backend/name_transforms.h @@ -35,6 +35,9 @@ * ToCVariableStyle(PrefixGeneratedName(CombineNames({"model", "Devices"}))) * // tvmgen_model_devices * + * ToCConstantStyle(PrefixGeneratedName(CombineNames({"model", "Devices"}))) + * // TVMGEN_MODEL_DEVICES + * */ #include @@ -68,6 +71,14 @@ std::string ToCFunctionStyle(const std::string& original_name); */ std::string ToCVariableStyle(const std::string& original_name); +/*! + * \brief Transform a name to the C constant style assuming it is + * appropriately constructed using the prefixing functions + * \param name Original name + * \return Transformed function in the C constant style + */ +std::string ToCConstantStyle(const std::string& original_name); + /*! * \brief Combine names together for use as a generated name * \param names Vector of strings to combine diff --git a/src/target/source/interface_c.cc b/src/target/source/interface_c.cc new file mode 100644 index 0000000000000..cec840c4bf2cd --- /dev/null +++ b/src/target/source/interface_c.cc @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file interface_c.cc + * \brief Generates a C interface header for a given modules inputs and outputs + */ + +#include +#include +#include +#include +#include + +#include + +#include "../../relay/backend/name_transforms.h" + +namespace tvm { +namespace codegen { + +using runtime::PackedFunc; +using namespace tvm::relay::backend; + +class InterfaceCNode : public runtime::ModuleNode { + public: + InterfaceCNode(std::string module_name, Array inputs, Array outputs) + : module_name_(module_name), inputs_(inputs), outputs_(outputs) {} + const char* type_key() const { return "h"; } + + std::string GetSource(const std::string& format) final { + std::stringstream code; + std::string mangled_module_name = ToCVariableStyle(PrefixGeneratedName({module_name_})); + std::string header_guard_name = ToCConstantStyle(PrefixGeneratedName({module_name_})); + + EmitUpperHeaderGuard(code, header_guard_name); + EmitBrief(code, "Input tensor pointers"); + EmitStruct(code, mangled_module_name, "inputs", inputs_); + EmitBrief(code, "Output tensor pointers"); + EmitStruct(code, mangled_module_name, "outputs", outputs_); + EmitRunFunction(code, mangled_module_name); + EmitLowerHeaderGuard(code, header_guard_name); + + return code.str(); + } + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { + return PackedFunc(nullptr); + } + + private: + void EmitUpperHeaderGuard(std::stringstream& code_stream, const std::string& header_guard_name) { + code_stream << "#ifndef " << header_guard_name << "_H_\n" + << "#define " << header_guard_name << "_H_\n" + << "#include \n\n" + << "#ifdef __cplusplus\n" + << "extern \"C\" {\n" + << "#endif\n\n"; + } + + void EmitLowerHeaderGuard(std::stringstream& code_stream, const std::string& header_guard_name) { + code_stream << "\n#ifdef __cplusplus\n" + << "}\n" + << "#endif\n\n" + << "#endif // " << header_guard_name << "_H_\n"; + } + + void EmitBrief(std::stringstream& code_stream, const std::string& description) { + code_stream << "/*!\n" + << " * \\brief " << description << " for TVM module \"" << module_name_ << "\" \n" + << " */\n"; + } + + void EmitStruct(std::stringstream& code_stream, const std::string& mangled_module_name, + const std::string& suffix, Array properties) { + code_stream << "struct " << mangled_module_name << "_" << suffix << " {\n"; + + std::vector sanitised_properties; + for (const String& property : properties) { + std::string sanitised_property = SanitiseName(property); + ICHECK(std::find(sanitised_properties.begin(), sanitised_properties.end(), + sanitised_property) == sanitised_properties.end()) + << "Sanitized input tensor name clash" << sanitised_property; + code_stream << " void* " << sanitised_property << ";\n"; + sanitised_properties.push_back(sanitised_property); + } + code_stream << "};\n\n"; + } + + void EmitRunFunction(std::stringstream& code_stream, const std::string& mangled_module_name) { + code_stream << "/*!\n" + << " * \\brief entrypoint function for TVM module \"" << module_name_ << "\"\n" + << " * \\param inputs Input tensors for the module \n" + << " * \\param outputs Output tensors for the module \n" + << " */\n" + << "int32_t " << mangled_module_name << "_run(\n" + << " struct " << mangled_module_name << "_inputs* inputs,\n" + << " struct " << mangled_module_name << "_outputs* outputs\n" + << ");\n"; + } + + std::string module_name_; + Array inputs_; + Array outputs_; +}; + +runtime::Module InterfaceCCreate(std::string module_name, Array inputs, + Array outputs) { + auto n = make_object(module_name, inputs, outputs); + return runtime::Module(n); +} + +TVM_REGISTER_GLOBAL("runtime.InterfaceCCreate").set_body_typed(InterfaceCCreate); + +} // namespace codegen +} // namespace tvm diff --git a/tests/cpp/name_transforms_test.cc b/tests/cpp/name_transforms_test.cc index bcad3f9db0916..a60da86fb8ff2 100644 --- a/tests/cpp/name_transforms_test.cc +++ b/tests/cpp/name_transforms_test.cc @@ -40,6 +40,13 @@ TEST(NameTransforms, ToCVariableStyle) { EXPECT_THROW(ToCVariableStyle(""), InternalError); } +TEST(NameTransforms, ToCConstantStyle) { + ASSERT_EQ(ToCConstantStyle("TVM_Woof"), "TVM_WOOF"); + ASSERT_EQ(ToCConstantStyle("TVM_woof"), "TVM_WOOF"); + ASSERT_EQ(ToCConstantStyle("TVM_woof_Woof"), "TVM_WOOF_WOOF"); + EXPECT_THROW(ToCConstantStyle(""), InternalError); +} + TEST(NameTransforms, PrefixName) { ASSERT_EQ(PrefixName({"Woof"}), "TVM_Woof"); ASSERT_EQ(PrefixName({"woof"}), "TVM_woof"); @@ -69,10 +76,10 @@ TEST(NameTransforms, CombineNames) { } TEST(NameTransforms, SanitiseName) { - ASSERT_EQ(SanitiseName("+_+ "), "_"); + ASSERT_EQ(SanitiseName("+_+ "), "____"); ASSERT_EQ(SanitiseName("input+"), "input_"); ASSERT_EQ(SanitiseName("input-"), "input_"); - ASSERT_EQ(SanitiseName("input++"), "input_"); + ASSERT_EQ(SanitiseName("input++"), "input__"); ASSERT_EQ(SanitiseName("woof:1"), "woof_1"); EXPECT_THROW(SanitiseName(""), InternalError); } diff --git a/tests/cpp/target/source/interface_c_test.cc b/tests/cpp/target/source/interface_c_test.cc new file mode 100644 index 0000000000000..1f7870a28c7c2 --- /dev/null +++ b/tests/cpp/target/source/interface_c_test.cc @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +using ::testing::HasSubstr; + +namespace tvm { +namespace codegen { + +runtime::Module InterfaceCCreate(std::string module_name, Array inputs, + Array outputs); + +namespace { + +TEST(InterfaceAPI, ContainsHeaderGuards) { + std::stringstream upper_header_guard; + std::stringstream lower_header_guard; + + upper_header_guard << "#ifndef TVMGEN_ULTIMATE_CAT_SPOTTER_H_\n" + << "#define TVMGEN_ULTIMATE_CAT_SPOTTER_H_\n" + << "#include \n\n" + << "#ifdef __cplusplus\n" + << "extern \"C\" {\n" + << "#endif\n\n"; + + lower_header_guard << "\n#ifdef __cplusplus\n" + << "}\n" + << "#endif\n\n" + << "#endif // TVMGEN_ULTIMATE_CAT_SPOTTER_H_\n"; + + runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}); + std::string header_source = test_module->GetSource(); + + ASSERT_THAT(header_source, HasSubstr(upper_header_guard.str())); + ASSERT_THAT(header_source, HasSubstr(lower_header_guard.str())); +} + +TEST(InterfaceAPI, ContainsRunFunction) { + std::stringstream run_function; + + run_function << "/*!\n" + << " * \\brief entrypoint function for TVM module \"ultimate_cat_spotter\"\n" + << " * \\param inputs Input tensors for the module \n" + << " * \\param outputs Output tensors for the module \n" + << " */\n" + << "int32_t tvmgen_ultimate_cat_spotter_run(\n" + << " struct tvmgen_ultimate_cat_spotter_inputs* inputs,\n" + << " struct tvmgen_ultimate_cat_spotter_outputs* outputs\n" + << ");\n"; + + runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}); + std::string header_source = test_module->GetSource(); + + ASSERT_THAT(header_source, HasSubstr(run_function.str())); +} + +TEST(InterfaceAPI, ContainsInputStructSingle) { + std::stringstream input_struct; + + input_struct << "/*!\n" + << " * \\brief Input tensor pointers for TVM module \"ultimate_cat_spotter\" \n" + << " */\n" + << "struct tvmgen_ultimate_cat_spotter_inputs {\n" + << " void* input;\n" + << "};\n\n"; + + runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}); + std::string header_source = test_module->GetSource(); + + ASSERT_THAT(header_source, HasSubstr(input_struct.str())); +} + +TEST(InterfaceAPI, ContainsInputStructMany) { + std::stringstream input_struct; + + input_struct << "struct tvmgen_ultimate_cat_spotter_inputs {\n" + << " void* input1;\n" + << " void* input2;\n" + << "};\n\n"; + + runtime::Module test_module = + InterfaceCCreate("ultimate_cat_spotter", {"input1", "input2"}, {"output"}); + std::string header_source = test_module->GetSource(); + + ASSERT_THAT(header_source, HasSubstr(input_struct.str())); +} + +TEST(InterfaceAPI, ContainsInputStructSanitised) { + std::stringstream input_struct; + + input_struct << "struct tvmgen_ultimate_cat_spotter_inputs {\n" + << " void* input_1;\n" + << " void* input_2;\n" + << "};\n\n"; + + runtime::Module test_module = + InterfaceCCreate("ultimate_cat_spotter", {"input+1", "input+2"}, {"output"}); + std::string header_source = test_module->GetSource(); + + ASSERT_THAT(header_source, HasSubstr(input_struct.str())); +} + +TEST(InterfaceAPI, ContainsInputStructClash) { + runtime::Module test_module = + InterfaceCCreate("ultimate_cat_spotter", {"input+", "input-"}, {"output"}); + ASSERT_THROW(test_module->GetSource(), InternalError); +} + +TEST(InterfaceAPI, ContainsOutputStructSingle) { + std::stringstream output_struct; + + output_struct << "/*!\n" + << " * \\brief Output tensor pointers for TVM module \"ultimate_cat_spotter\" \n" + << " */\n" + << "struct tvmgen_ultimate_cat_spotter_outputs {\n" + << " void* output;\n" + << "};\n\n"; + + runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}); + std::string header_source = test_module->GetSource(); + + ASSERT_THAT(header_source, HasSubstr(output_struct.str())); +} + +TEST(InterfaceAPI, ContainsOutputStructMany) { + std::stringstream output_struct; + + output_struct << "struct tvmgen_ultimate_cat_spotter_outputs {\n" + << " void* output1;\n" + << " void* output2;\n" + << "};\n\n"; + + runtime::Module test_module = + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output1", "output2"}); + std::string header_source = test_module->GetSource(); + + ASSERT_THAT(header_source, HasSubstr(output_struct.str())); +} + +TEST(InterfaceAPI, ContainsOutputStructSanitised) { + std::stringstream output_struct; + + output_struct << "struct tvmgen_ultimate_cat_spotter_outputs {\n" + << " void* output_1;\n" + << " void* output_2;\n" + << "};\n\n"; + + runtime::Module test_module = + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+1", "output-2"}); + std::string header_source = test_module->GetSource(); + + ASSERT_THAT(header_source, HasSubstr(output_struct.str())); +} + +TEST(InterfaceAPI, ContainsOutputStructClash) { + runtime::Module test_module = + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+", "output-"}); + ASSERT_THROW(test_module->GetSource(), InternalError); +} + +} // namespace +} // namespace codegen +} // namespace tvm \ No newline at end of file diff --git a/tests/micro/zephyr/test_zephyr_aot.py b/tests/micro/zephyr/test_zephyr_aot.py index f03b8ecce6d04..543b15c61fcf1 100644 --- a/tests/micro/zephyr/test_zephyr_aot.py +++ b/tests/micro/zephyr/test_zephyr_aot.py @@ -31,7 +31,7 @@ import tvm.relay as relay from tvm.contrib.download import download_testdata -from tvm.micro.interface_api import generate_c_interface_header +from tvm.micro.model_library_format import generate_c_interface_header import test_utils diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 9961cd567fbe6..4408202e47c43 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -22,7 +22,7 @@ import pytest import tvm -from tvm import relay +from tvm import relay, TVMError from tvm.ir.module import IRModule from tvm.relay import testing, transform from tvm.relay.testing import byoc @@ -604,7 +604,7 @@ def test_name_sanitiser_name_clash(): inputs = {"input::-1": x_data, "input::-2": y_data, "input:--2": t_data} output_list = generate_ref_data(func, inputs) - with pytest.raises(ValueError, match="Sanitized input tensor name clash"): + with pytest.raises(TVMError, match="Sanitized input tensor name clash"): compile_and_run( AOTTestModel(module=IRModule.from_expr(func), inputs=inputs, outputs=output_list), test_runner, diff --git a/tests/python/relay/test_name_transforms.py b/tests/python/relay/test_name_transforms.py index 34ea100392f5a..1909b912bc171 100644 --- a/tests/python/relay/test_name_transforms.py +++ b/tests/python/relay/test_name_transforms.py @@ -19,6 +19,7 @@ from tvm.relay.backend.name_transforms import ( to_c_function_style, to_c_variable_style, + to_c_constant_style, prefix_name, prefix_generated_name, sanitise_name, @@ -45,6 +46,15 @@ def test_to_c_variable_style(): to_c_variable_style("") +def test_to_c_constant_style(): + assert to_c_constant_style("TVM_Woof") == "TVM_WOOF" + assert to_c_constant_style("TVM_woof") == "TVM_WOOF" + assert to_c_constant_style("TVM_woof_Woof") == "TVM_WOOF_WOOF" + + with pytest.raises(TVMError): + to_c_constant_style("") + + def test_prefix_name(): assert prefix_name("Woof") == "TVM_Woof" assert prefix_name(["Woof"]) == "TVM_Woof" @@ -75,10 +85,10 @@ def test_prefix_generated_name(): def test_sanitise_name(): - assert sanitise_name("+_+ ") == "_" + assert sanitise_name("+_+ ") == "____" assert sanitise_name("input+") == "input_" assert sanitise_name("input-") == "input_" - assert sanitise_name("input++") == "input_" + assert sanitise_name("input++") == "input__" assert sanitise_name("woof:1") == "woof_1" with pytest.raises(TVMError):