From 9c185bdb4c627ebae4ed136fbe74892d25c8c9bf Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi <86472128+ashutosh-arm@users.noreply.github.com> Date: Mon, 20 Sep 2021 00:14:17 +0100 Subject: [PATCH] [3/10] Moved TIR generation from Python to C++ for CMSIS-NN (#8951) * [CMSIS-NN] Moved TIR Generation to C++ * Deleted self import for cmsisnn Change-Id: I2cdcd7a90aa4749877c48bc6c7c4d27328856860 * Reusing CodeGenC VistiExpr for softmax Change-Id: Ie41b695fa06468cd3b0bfe428c360e98438a9180 --- python/tvm/relay/backend/__init__.py | 1 - python/tvm/relay/backend/contrib/__init__.py | 18 -- .../relay/backend/contrib/cmsisnn/__init__.py | 18 -- .../relay/backend/contrib/cmsisnn/codegen.py | 134 ----------- python/tvm/relay/op/contrib/cmsisnn.py | 2 +- .../contrib/cmsisnn/codegen_cmsisnn.cc | 208 +++--------------- .../backend/contrib/cmsisnn/relay_to_tir.cc | 140 ++++++++++++ .../backend/contrib/cmsisnn/tir_to_runtime.cc | 138 ++++++++++++ 8 files changed, 306 insertions(+), 353 deletions(-) delete mode 100644 python/tvm/relay/backend/contrib/__init__.py delete mode 100644 python/tvm/relay/backend/contrib/cmsisnn/__init__.py delete mode 100644 python/tvm/relay/backend/contrib/cmsisnn/codegen.py create mode 100644 src/relay/backend/contrib/cmsisnn/relay_to_tir.cc create mode 100644 src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc diff --git a/python/tvm/relay/backend/__init__.py b/python/tvm/relay/backend/__init__.py index b84e215fa581..4fc2b63748db 100644 --- a/python/tvm/relay/backend/__init__.py +++ b/python/tvm/relay/backend/__init__.py @@ -16,4 +16,3 @@ # under the License. """Backend codegen modules for relay.""" from . import compile_engine -from .contrib import cmsisnn diff --git a/python/tvm/relay/backend/contrib/__init__.py b/python/tvm/relay/backend/contrib/__init__.py deleted file mode 100644 index bfc5b79bb2ee..000000000000 --- a/python/tvm/relay/backend/contrib/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""External backend codegen modules for Relay.""" -from . import cmsisnn diff --git a/python/tvm/relay/backend/contrib/cmsisnn/__init__.py b/python/tvm/relay/backend/contrib/cmsisnn/__init__.py deleted file mode 100644 index cc6873f9fda6..000000000000 --- a/python/tvm/relay/backend/contrib/cmsisnn/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""CMSIS-NN codegen modules for relay.""" -from . import codegen diff --git a/python/tvm/relay/backend/contrib/cmsisnn/codegen.py b/python/tvm/relay/backend/contrib/cmsisnn/codegen.py deleted file mode 100644 index ef08f5eb317d..000000000000 --- a/python/tvm/relay/backend/contrib/cmsisnn/codegen.py +++ /dev/null @@ -1,134 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Codegen for CMSIS-NN""" -import tvm -from tvm import relay -from tvm.relay.expr_functor import ExprVisitor - - -class GenerateTIR(ExprVisitor): - """Generates TIR module containing TIR primfuncs corresponding to the Relay operators. - Note: Relay operator to primfunc mapping may not be 1:1. - """ - - def __init__(self, name): - super().__init__() - self.name = name - self.tir_mod = None - self.scale = 1.0 / 256 - - def call_contains_op(self, call, op_name): - if not isinstance(call.op, tvm.ir.op.Op): - return False - if call.op.name != op_name: - return False - return True - - def is_quantized_softmax(self, call): - """Checks for the following relay sequence - a = qnn.dequantize(in, scale, zero_point) - b = nn.softmax(a) - c = qnn.quantize(c, scale, zero_point) - """ - if not self.call_contains_op(call, "qnn.quantize"): - return False - softmax_call = call.args[0] - if not self.call_contains_op(softmax_call, "nn.softmax"): - return False - dequantize_call = softmax_call.args[0] - if not self.call_contains_op(dequantize_call, "qnn.dequantize"): - return False - self.scale = dequantize_call.args[1].data.numpy().item(0) - return True - - def emit_softmax_tir(self, call): - """Generates TIR extern_call for softmax""" - shape = call.checked_type.shape # NHWC - dtype = call.checked_type.dtype - ir_builder = tvm.tir.ir_builder.create() - in_buf = tvm.tir.decl_buffer(shape=shape, dtype=dtype) - out_buf = tvm.tir.decl_buffer(shape=shape, dtype=dtype) - - trailing_dim = len(shape) - 1 - num_rows = 1 - for dim in range(trailing_dim): - num_rows *= shape[dim] - row_size = shape[trailing_dim] - ir_builder.emit( - tvm.tir.call_extern( - dtype, - "arm_softmax_s8", - in_buf.data, - num_rows, - row_size, - self.scale, - out_buf.data, - ) - ) - prim_func = tvm.tir.PrimFunc([in_buf, out_buf], ir_builder.get()) - prim_func = prim_func.with_attr("global_symbol", self.name) - prim_func = prim_func.with_attr("tir.noalias", True) - self.tir_mod = tvm.IRModule({self.name: prim_func}) - - def visit_call(self, call): - """Iterates over the relay operators within relay external function""" - super().visit_call(call) - if self.is_quantized_softmax(call): - self.emit_softmax_tir(call) - - def generate_tir(self, func): - self.visit(func) - return self.tir_mod - - -def relay_to_tir(name, func): - """Lower a Relay function to TIR for the CMSIS-NN target. - - The Relay function should only contain operations supported - by the CMSIS-NN target. This is enforced by the graph partitioner - for CMSIS-NN. - - Parameters - ---------- - name: str - Name of the external relay function - func : tvm.relay.Function - The Relay function to lower. - - Returns - ------- - mod : tvm.IRModule - The lowered TIR module. - - """ - return GenerateTIR(name).generate_tir(func) - - -@tvm.register_func("relay.ext.cmsisnn") -def cmsisnn_compiler(relay_func): - """It compiles Relay's external function into equivalent TIR - and subsequently converts that into 'c' code. During the 'c' - code generation, it embeds CMSIS-NN APIs for the corresponding - operators. - """ - mod = tvm.IRModule() - mod["main"] = relay_func - mod = relay.transform.InferType()(mod) - func_name = relay_func.attrs["global_symbol"] - tir_mod = relay_to_tir(func_name, mod["main"]) - cmsisnn_runtime = tvm._ffi.get_global_func("runtime.module.cmsisnn.create") - return cmsisnn_runtime(tir_mod) diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py index f1153c6a8575..b74a09c4fef2 100644 --- a/python/tvm/relay/op/contrib/cmsisnn.py +++ b/python/tvm/relay/op/contrib/cmsisnn.py @@ -80,5 +80,5 @@ def check_quantized_softmax(extract): ) return [ - ("cmsisnn.qnn_softmax", softmax_pattern(), check_quantized_softmax), + ("cmsisnn.quantized_softmax", softmax_pattern(), check_quantized_softmax), ] diff --git a/src/relay/backend/contrib/cmsisnn/codegen_cmsisnn.cc b/src/relay/backend/contrib/cmsisnn/codegen_cmsisnn.cc index d2e498a52ce4..c8094109771b 100644 --- a/src/relay/backend/contrib/cmsisnn/codegen_cmsisnn.cc +++ b/src/relay/backend/contrib/cmsisnn/codegen_cmsisnn.cc @@ -16,190 +16,36 @@ * specific language governing permissions and limitations * under the License. */ -#include -#include -#include -#include -#include -#include - -#include "../../../../runtime/file_utils.h" -#include "../../../../target/source/codegen_c.h" -#include "../../../qnn/utils.h" +#include +#include +#include namespace tvm { -namespace runtime { - -using namespace tir; - -class CodeGenCMSISNN : public tvm::codegen::CodeGenC { - public: - void Init(bool output_ssa) { - decl_stream << "#include \n"; - decl_stream << "#include \n"; - decl_stream << "#include \n"; - decl_stream << "#include \n"; - decl_stream << "#include \n"; - CodeGenC::Init(output_ssa); - } - - /*! - * \brief Emit code that offloads a subgraph to the Cortex-M - * - * \return string of code that offloads a subgraph to the Cortex-M - */ - void AddFunction(const PrimFunc& prim_func) { - PrintExternCPrefix(stream); - CodeGenC::AddFunction(prim_func); - PrintExternCPostfix(stream); - } - - private: - void VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) - if (!op->op.same_as(builtin::call_extern())) { - return; - } - std::string cmsis_func_name = op->args[0].as()->value; - if (cmsis_func_name == "arm_softmax_s8") { - EmitSoftmax(op); - } - return; - } - - /*! * \brief Creates a cplusplus guard prefix for extern "C" printing */ - void PrintExternCPrefix(std::ostringstream& ss) { - PrintIndent(); - ss << "#ifdef __cplusplus\n"; - ss << "extern \"C\" {\n"; - ss << "#endif\n"; - } - - /*! * \brief Creates a cplusplus guard postfix for extern "C" printing */ - void PrintExternCPostfix(std::ostringstream& ss) { - PrintIndent(); - ss << "#ifdef __cplusplus\n"; - ss << "}\n"; - ss << "#endif\n"; - } - - /*! * \brief Emits CMSIS-NN code block for softmax */ - void EmitSoftmax(const CallNode* op) { - // @tir.call_extern("arm_softmax_s8", buffer_0, num_rows, row_size, scale, buffer_1, dtype=int8) - std::string cmsis_func_name = op->args[0].as()->value; - int32_t num_rows = op->args[2].as()->value; - int32_t row_size = op->args[3].as()->value; - float quant_scale = op->args[4].as()->value; - - // calculate multiplier and shift for CMSIS-NN softmax API - // Note: tfl micro assumptions - // TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8); - // TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128); - // TF_LITE_ENSURE(context, output->params.scale == 1.f / 256); - double beta = 1.0; - int32_t input_bits = 5; - double beta_multiplier = (beta * quant_scale * (1 << (31 - input_bits))); - beta_multiplier = std::min(beta_multiplier, (1ll << 31) - 1.0); - auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(beta_multiplier); - int32_t mult = std::get<0>(mult_shift_pair); - int32_t shift = std::get<1>(mult_shift_pair); - int32_t diff_min = (1 << 5) - 1; - diff_min <<= (31 - 5); - diff_min >>= shift; - diff_min *= -1; - - PrintIndent(); - stream << "int32_t num_rows = " << num_rows << ";\n"; - PrintIndent(); - stream << "int32_t row_size = " << row_size << ";\n"; - PrintIndent(); - stream << "int32_t mult = " << mult << ";\n"; - PrintIndent(); - stream << "int32_t shift = " << shift << ";\n"; - PrintIndent(); - stream << "int32_t diff_min = " << diff_min << ";\n"; - PrintIndent(); - stream << cmsis_func_name << "(buffer,"; - PrintIndent(); - stream << " num_rows, row_size, mult, shift, diff_min, buffer1);\n"; - PrintIndent(); - stream << "return;\n"; - } -}; - -class CMSISNNModuleNode : public runtime::ModuleNode { - public: - CMSISNNModuleNode(const std::string& code, const std::string& fmt, - const Array& func_names) - : code_(code), fmt_(fmt), func_names_(func_names) {} - - std::string GetSource(const std::string& format) final { return code_; } - - const char* type_key() const { return "c"; } - - PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { - if (name == "get_symbol") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_[0]; }); - } else if (name == "get_func_names") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_; }); - } else { - return PackedFunc(nullptr); - } - } - - void SaveToFile(const std::string& file_name, const std::string& format) final { - std::string fmt = GetFileFormat(file_name, format); - std::string meta_file = GetMetaFilePath(file_name); - if (fmt == "c" || fmt == "cu") { - ICHECK_NE(code_.length(), 0); - SaveBinaryToFile(file_name, code_); - } else { - ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; - } - } - - protected: - std::string code_; - std::string fmt_; - Array func_names_; -}; - -class CMSISNNModule : public Module { - public: - CMSISNNModule() {} - explicit CMSISNNModule(ObjectPtr n) : Module(n) {} - inline CMSISNNModuleNode* operator->(); - inline const CMSISNNModuleNode* operator->() const; -}; - -inline CMSISNNModuleNode* CMSISNNModule::operator->() { - return static_cast(get_mutable()); -} - -static Module CMSISNNModuleNodeCreate(IRModule mod) { - bool output_ssa = false; - CodeGenCMSISNN cg; - Array function_names; - cg.Init(output_ssa); - ICHECK(mod->functions.size() == 1) << "Supports modules with single PrimFunc."; - for (auto kv : mod->functions) { - ICHECK(kv.second->IsInstance()) << "CodegenCHost: Can only take PrimFunc"; - auto f = Downcast(kv.second); - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) - << "CodeGenCHost: Expect PrimFunc to have the global_symbol attribute"; - function_names.push_back(global_symbol.value()); - cg.AddFunction(f); - } - std::string code = cg.Finish(); - auto n = make_object(code, "c", function_names); - return Module(n); +namespace relay { +namespace contrib { +namespace cmsisnn { + +transform::Pass RelayToTIR(); + +runtime::Module CompileCMSISNN(const ObjectRef& ref) { + IRModule relay_mod; + Function relay_func = Downcast(ref); + auto func_name = relay_func->GetAttr(tvm::attr::kGlobalSymbol); + GlobalVar var = GlobalVar(func_name.value()); + relay_mod->Add(var, relay_func); + relay_mod = transform::InferType()(relay_mod); + + Array pass_seqs{transform::InferType(), RelayToTIR()}; + transform::Sequential seq(pass_seqs); + IRModule tir_mod = seq(relay_mod); + + const auto* pf = runtime::Registry::Get("runtime.CMSISNNModuleNodeCreate"); + return (*pf)(tir_mod); } -TVM_REGISTER_GLOBAL("runtime.module.cmsisnn.create").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = CMSISNNModuleNodeCreate(args[0]); -}); +TVM_REGISTER_GLOBAL("relay.ext.cmsisnn").set_body_typed(CompileCMSISNN); -} // namespace runtime +} // namespace cmsisnn +} // namespace contrib +} // namespace relay } // namespace tvm diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc new file mode 100644 index 000000000000..7c1728ce0ed5 --- /dev/null +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -0,0 +1,140 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include +#include +#include + +#include "../../../qnn/utils.h" + +namespace tvm { +namespace relay { +namespace contrib { +namespace cmsisnn { + +class RelayToTIR : public MixedModeVisitor { + public: + explicit RelayToTIR(String func_name) : func_name_(func_name) {} + + private: + void emit_softmax_tir(const Expr& expr) { + auto* quantize_call = expr.as(); + auto* softmax_call = quantize_call->args[0].as(); + auto* dequant_call = softmax_call->args[0].as(); + auto* scale_const = dequant_call->args[1].as(); + const float quant_scale = static_cast(scale_const->data->data)[0]; + + // assuming layout as NHWC + auto shape = quantize_call->type_as()->shape; + int trailing_dim = shape.size() - 1; + int row_size = shape[trailing_dim].as()->value; + int num_rows = 1; + for (int i = 0; i < trailing_dim; ++i) { + num_rows *= shape[i].as()->value; + } + + // calculate multiplier and shift for CMSIS-NN softmax API + // Note: TensorFlow Lite Micro assumptions + // Output zero point and scale are fixed to -128 and 1 / 256 + // https://github.com/tensorflow/tflite-micro/blob/d97cd0908d8cf5021e9d86f05a49888bee28c2a4/tensorflow/lite/micro/kernels/softmax_common.cc#L47 + double beta = 1.0; + int32_t input_bits = 5; + double beta_multiplier = (beta * quant_scale * (1 << (31 - input_bits))); + beta_multiplier = std::min(beta_multiplier, (1ll << 31) - 1.0); + auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(beta_multiplier); + int32_t mult = std::get<0>(mult_shift_pair); + int32_t shift = std::get<1>(mult_shift_pair); + int32_t diff_min = (1 << 5) - 1; + diff_min <<= (31 - 5); + diff_min >>= shift; + diff_min *= -1; + + auto in_var = tir::Var("input", DataType::Handle(8)); + auto out_var = tir::Var("output", DataType::Handle(8)); + + Array func_signature{in_var, out_var}; + + tvm::Array args = { + tir::StringImm("arm_softmax_s8"), in_var, + IntImm(DataType::Int(32), num_rows), IntImm(DataType::Int(32), row_size), + IntImm(DataType::Int(32), mult), IntImm(DataType::Int(32), shift), + IntImm(DataType::Int(32), diff_min), out_var}; + tir::Stmt body = + tir::Evaluate(tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), args)); + + Map dict_attrs; + dict_attrs.Set("global_symbol", func_name_); + dict_attrs.Set("tir.noalias", Bool(true)); + + primfunc_ = tir::PrimFunc(func_signature, body, VoidType(), Map(), + DictAttrs(dict_attrs)); + } + + void VisitExpr_(const CallNode* call) final { + auto* func = call->op.as(); + if (func == nullptr) { + return; + } + + auto comp_name = func->GetAttr(attr::kComposite); + if (comp_name.defined() && comp_name == "cmsisnn.quantized_softmax") { + emit_softmax_tir(func->body); + } + } + + public: + String func_name_; + tir::PrimFunc primfunc_; +}; + +IRModule GenerateTIR(IRModule mod) { + String func_name; + Function func; + + // Obtain external Relay Function that needs to be translated into TIR + ICHECK(mod->functions.size() == 1) << "Supports modules with single external Relay function."; + for (auto kv : mod->functions) { + func = Downcast(kv.second); + func_name = func->GetAttr(tvm::attr::kGlobalSymbol).value(); + } + + // Prepare PrimFunc from Relay Function + auto relay_to_tir = RelayToTIR(func_name); + relay_to_tir.VisitExpr(func->body); + + // Build the TIR IRModule from the generated PrimFunc + Map var_func_map; + var_func_map.Set(GlobalVar(func_name), relay_to_tir.primfunc_); + return IRModule(var_func_map); +} + +transform::Pass RelayToTIR() { + runtime::TypedPackedFunc pass_func = + [=](IRModule m, transform::PassContext pc) { return GenerateTIR(m); }; + return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIR", {}); +} + +} // namespace cmsisnn +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc new file mode 100644 index 000000000000..fb612e70311b --- /dev/null +++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include +#include + +#include "../../../../runtime/file_utils.h" +#include "../../../../target/source/codegen_c.h" + +namespace tvm { +namespace codegen { + +using namespace tir; + +class CodeGenCMSISNN : public CodeGenC { + public: + void Init(bool output_ssa) { + decl_stream << "#include \n"; + decl_stream << "#include \n"; + decl_stream << "#include \n"; + decl_stream << "#include \n"; + decl_stream << "#include \n"; + CodeGenC::Init(output_ssa); + } + + /*! + * \brief Emit code that offloads a subgraph to the Cortex-M + * + * \return string of code that offloads a subgraph to the Cortex-M + */ + void AddFunction(const PrimFunc& prim_func) { + PrintExternCPrefix(stream); + CodeGenC::AddFunction(prim_func); + PrintExternCPostfix(stream); + } + + private: + /*! * \brief Creates a cplusplus guard prefix for extern "C" printing */ + void PrintExternCPrefix(std::ostringstream& ss) { + PrintIndent(); + ss << "#ifdef __cplusplus\n"; + ss << "extern \"C\" {\n"; + ss << "#endif\n"; + } + + /*! * \brief Creates a cplusplus guard postfix for extern "C" printing */ + void PrintExternCPostfix(std::ostringstream& ss) { + PrintIndent(); + ss << "#ifdef __cplusplus\n"; + ss << "}\n"; + ss << "#endif\n"; + } +}; + +class CMSISNNModuleNode : public runtime::ModuleNode { + public: + CMSISNNModuleNode(const std::string& code, const std::string& fmt, + const Array& func_names) + : code_(code), fmt_(fmt), func_names_(func_names) {} + + std::string GetSource(const std::string& format) final { return code_; } + + const char* type_key() const { return "c"; } + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { + if (name == "get_symbol") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_[0]; }); + } else if (name == "get_func_names") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_; }); + } else { + return PackedFunc(nullptr); + } + } + + void SaveToFile(const std::string& file_name, const std::string& format) final { + std::string fmt = runtime::GetFileFormat(file_name, format); + std::string meta_file = runtime::GetMetaFilePath(file_name); + if (fmt == "c") { + ICHECK_NE(code_.length(), 0); + runtime::SaveBinaryToFile(file_name, code_); + } else { + ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; + } + } + + protected: + std::string code_; + std::string fmt_; + Array func_names_; +}; + +static runtime::Module CMSISNNModuleNodeCreate(IRModule mod) { + bool output_ssa = false; + CodeGenCMSISNN cg; + Array function_names; + cg.Init(output_ssa); + ICHECK(mod->functions.size() == 1) << "Supports modules with single PrimFunc."; + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) << "CodegenCMSISNN: Can only take PrimFunc"; + auto f = Downcast(kv.second); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.defined()) + << "CodeGenCHost: Expect PrimFunc to have the global_symbol attribute"; + function_names.push_back(global_symbol.value()); + cg.AddFunction(f); + } + std::string code = cg.Finish(); + auto n = make_object(code, "c", function_names); + return runtime::Module(n); +} + +TVM_REGISTER_GLOBAL("runtime.CMSISNNModuleNodeCreate").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = CMSISNNModuleNodeCreate(args[0]); +}); + +} // namespace codegen +} // namespace tvm