diff --git a/apps/microtvm/ethosu/include/tvm_ethosu_runtime.h b/apps/microtvm/ethosu/include/tvm_ethosu_runtime.h index 06188baaa9ee8..8352fa56981e1 100644 --- a/apps/microtvm/ethosu/include/tvm_ethosu_runtime.h +++ b/apps/microtvm/ethosu/include/tvm_ethosu_runtime.h @@ -24,7 +24,14 @@ #include #include -int32_t TVMEthosULaunch(struct ethosu_driver* resource_handle, void* cms_data, size_t cms_data_size, +typedef void tvm_device_ethos_u_t; + +int32_t TVMEthosULaunch(tvm_device_ethos_u_t* resource_handle, void* cms_data, size_t cms_data_size, uint64_t* base_addrs, size_t* base_addrs_size, int num_tensors); +int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context); +int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context); +int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context); +int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context); + #endif // TVM_RUNTIME_CONTRIB_ETHOSU_ETHOSU_RUNTIME_H_ diff --git a/apps/microtvm/ethosu/src/tvm_ethosu_runtime.c b/apps/microtvm/ethosu/src/tvm_ethosu_runtime.c index 6b7399b674069..8e506021521f4 100644 --- a/apps/microtvm/ethosu/src/tvm_ethosu_runtime.c +++ b/apps/microtvm/ethosu/src/tvm_ethosu_runtime.c @@ -21,8 +21,9 @@ #include -int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms_data_size, +int32_t TVMEthosULaunch(tvm_device_ethos_u_t* context, void* cms_data, size_t cms_data_size, uint64_t* base_addrs, size_t* base_addrs_size, int num_tensors) { + struct ethosu_driver* driver = (struct ethosu_driver*)context; int32_t result = ethosu_invoke(driver, cms_data, cms_data_size, base_addrs, base_addrs_size, num_tensors); @@ -32,3 +33,8 @@ int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms } return 0; } + +int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context) {} +int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context) {} +int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context) {} +int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context) {} diff --git a/python/tvm/relay/backend/executor_factory.py b/python/tvm/relay/backend/executor_factory.py index db33c1b7844af..7b96dd87604e8 100644 --- a/python/tvm/relay/backend/executor_factory.py +++ b/python/tvm/relay/backend/executor_factory.py @@ -75,6 +75,8 @@ class AOTExecutorFactoryModule(ExecutorFactoryModule): ---------- ir_mod : :py:class:`~tvm.IRModule` The IR module to build. + lowered_ir_mods : dict[Target, IRModule] + The IR modules lowered per Target. target : tvm.Target The Target used to build this module. libmod : tvm.Module @@ -89,8 +91,19 @@ class AOTExecutorFactoryModule(ExecutorFactoryModule): List of devices used in the module """ - def __init__(self, ir_mod, target, libmod, libmod_name, params, function_metadata, devices): + def __init__( + self, + ir_mod, + lowered_ir_mods, + target, + libmod, + libmod_name, + params, + function_metadata, + devices, + ): self.ir_mod = ir_mod + self.lowered_ir_mods = lowered_ir_mods self.target = target self.lib = libmod self.libmod_name = libmod_name @@ -136,7 +149,14 @@ class GraphExecutorFactoryModule(ExecutorFactoryModule): """ def __init__( - self, ir_mod, target, graph_json_str, libmod, libmod_name, params, function_metadata + self, + ir_mod, + target, + graph_json_str, + libmod, + libmod_name, + params, + function_metadata, ): assert isinstance(graph_json_str, string_types) fcreate = get_global_func("tvm.graph_executor_factory.create") diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 38c9a40b15bbb..b66d5fbec8c22 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -102,6 +102,7 @@ def __init__(self): self._get_params_func = self.mod["get_params"] self._get_function_metadata = self.mod["get_function_metadata"] self._get_devices = self.mod["get_devices"] + self._get_irmodule = self.mod["get_irmodule"] def build( self, mod, target=None, target_host=None, params=None, executor="graph", mod_name=None @@ -249,6 +250,10 @@ def get_params(self): ret[key] = value.data return ret + def get_irmodule(self): + """Returns the Target IRModule's post-lowering""" + return self._get_irmodule() + @register_func("tvm.relay.module_export_library") def _module_export(module, file_name): # fcompile, addons, kwargs? @@ -376,10 +381,18 @@ def build(ir_mod, target=None, target_host=None, params=None, mod_name="default" ) func_metadata = bld_mod.get_function_metadata() devices = bld_mod.get_devices() + lowered_ir_mods = bld_mod.get_irmodule() if executor == "aot": executor_factory = _executor_factory.AOTExecutorFactoryModule( - ir_mod, target, runtime_mod, mod_name, params, func_metadata, devices + ir_mod, + lowered_ir_mods, + target, + runtime_mod, + mod_name, + params, + func_metadata, + devices, ) elif executor == "graph": executor_factory = _executor_factory.GraphExecutorFactoryModule( diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index c240ec8b45f93..fde1de061bfa4 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -349,11 +349,19 @@ class AOTExecutorCodegen : public MixedModeVisitor { GlobalVar global_var = call_lowered_props.lowered_func; bool has_c_device_api_context = device_contexts_.count(global_var) != 0; if (has_c_device_api_context) { + tir::Var context = device_contexts_.Get(global_var).value(); args.push_back(device_contexts_[global_var]); - } - tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args)); - create_func_call_stmts.push_back(func_call); + tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args)); + create_func_call_stmts.push_back(tir::SeqStmt({ + GenerateDeviceHook(context, "Open"), + func_call, + GenerateDeviceHook(context, "Close"), + })); + } else { + tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args)); + create_func_call_stmts.push_back(func_call); + } tir::Stmt body = tir::SeqStmt(create_func_call_stmts); stmts_.push_back(body); @@ -417,7 +425,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { device_context_var = (*pair).second; } else { main_signature_.push_back(device_context_var); - devices_.push_back(context_name); + devices_.Set(context_name, device_context_var); target_contexts.Set(target_kind.value(), device_context_var); } @@ -426,6 +434,44 @@ class AOTExecutorCodegen : public MixedModeVisitor { } } + /** + * \brief Generates a call to a given hook for all Devices found for C Device API + * \param Name of hook to generate statements for + * \return Statement with function calls for each device + */ + tir::Stmt GenerateAllDeviceHook(const String& hook) { + std::vector device_hooks; + for (const auto& it : devices_) { + const String& device_name = it.first; + const tir::Var& context = it.second; + Array sections = {"Device", device_name, hook}; + String device_hook_name = ToCFunctionStyle(PrefixName(sections)); + + tir::Evaluate device_hook(tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), + {tvm::tir::StringImm(device_hook_name), context})); + device_hooks.push_back(device_hook); + } + return tir::SeqStmt(device_hooks); + } + + /** + * \brief Generates a call to a given hook for a single Device function + * \param Var Device context to call hook on + * \param Name of hook to generate statements for + * \return Statement with function call to Device API + */ + tir::Stmt GenerateDeviceHook(const tir::Var& context, const String& hook) { + const auto& it = std::find_if(std::begin(devices_), std::end(devices_), [&](const auto& it) { + return it.second->name_hint == context->name_hint; + }); + const String& device_name = (*it).first; + Array sections = {"Device", device_name, hook}; + String device_hook = ToCFunctionStyle(PrefixName(sections)); + + return tir::Evaluate(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), + {tvm::tir::StringImm(device_hook), context})); + } + /*! * Utility function to string together different arguments */ @@ -587,8 +633,12 @@ class AOTExecutorCodegen : public MixedModeVisitor { dict_attrs.Set("global_symbol", run_func_name); dict_attrs.Set("runner_function", Bool(true)); + tir::Stmt device_activations = GenerateAllDeviceHook("Activate"); + tir::Stmt device_deactivations = GenerateAllDeviceHook("Deactivate"); + tir::Stmt final_body = tir::SeqStmt({device_activations, body, device_deactivations}); + // Make the PrimFunc - return tir::PrimFunc(main_signature_, body, VoidType(), Map(), + return tir::PrimFunc(main_signature_, final_body, VoidType(), Map(), DictAttrs(dict_attrs)); } @@ -597,8 +647,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { runtime::Module* mod_; /*! \brief list of input expressions (i.e., variable passed by the user) */ std::vector input_vars_; - /*! \brief list of device contexts used */ - std::vector devices_; + /*! \brief map of device contexts variables */ + Map devices_; /*! \brief map of GlobalVars to C Device API contexts */ Map device_contexts_; /*! \brief input and output variables belonging to the main function signature */ @@ -779,7 +829,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { std::transform(input_vars_.begin(), input_vars_.end(), input_var_names.begin(), [](Var input_var) -> String { return input_var->name_hint(); }); - ret.metadata = runtime::Metadata(input_var_names, devices_, return_sid_.size(), + ret.metadata = runtime::Metadata(input_var_names, ListDevices(), return_sid_.size(), runtime::kTvmExecutorAot, mod_name); return ret; } @@ -788,7 +838,12 @@ class AOTExecutorCodegen : public MixedModeVisitor { * \brief Get list of devices found * \return List of devices */ - Array ListDevices() { return devices_; } + Array ListDevices() { + std::vector device_names(devices_.size()); + std::transform(devices_.begin(), devices_.end(), device_names.begin(), + [](const auto& it) -> String { return it.first; }); + return device_names; + } }; // namespace backend class AOTExecutorCodegenModule : public runtime::ModuleNode { diff --git a/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.c b/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.c index 6b7399b674069..8e506021521f4 100644 --- a/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.c +++ b/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.c @@ -21,8 +21,9 @@ #include -int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms_data_size, +int32_t TVMEthosULaunch(tvm_device_ethos_u_t* context, void* cms_data, size_t cms_data_size, uint64_t* base_addrs, size_t* base_addrs_size, int num_tensors) { + struct ethosu_driver* driver = (struct ethosu_driver*)context; int32_t result = ethosu_invoke(driver, cms_data, cms_data_size, base_addrs, base_addrs_size, num_tensors); @@ -32,3 +33,8 @@ int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms } return 0; } + +int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context) {} +int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context) {} +int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context) {} +int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context) {} diff --git a/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.h b/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.h index d62afc4c69efc..31d17557aa84b 100644 --- a/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.h +++ b/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.h @@ -24,7 +24,14 @@ #include #include -int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t cms_data_size, +typedef void tvm_device_ethos_u_t; + +int32_t TVMEthosULaunch(tvm_device_ethos_u_t* resource_handle, void* cms_data, size_t cms_data_size, uint64_t* base_addrs, size_t* base_addrs_size, int num_tensors); +int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context); +int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context); +int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context); +int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context); + #endif // TVM_RUNTIME_CONTRIB_ETHOSU_BARE_METAL_TVM_ETHOSU_RUNTIME_H_ diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 94ecaba280adc..e2bbb24c55d34 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -17,6 +17,7 @@ from collections import OrderedDict import sys +import re import numpy as np import pytest @@ -693,5 +694,108 @@ def @main(%data: Tensor[(1, 4, 4, 4), float32], %weight: Tensor[(4, 4, 3, 3), fl assert source.count("TVMBackendAllocWorkspace") == 3 +def test_device_api_hooks(): + """Check for Device API hooks""" + + # Ideally we should have a sample Target registered here + # but we're going to re-use this for now + pytest.importorskip("ethosu.vela") + import tensorflow as tf + import tflite.Model + + from tests.python.contrib.test_ethosu import infra + from tvm.relay.op.contrib.ethosu import partition_for_ethosu + + def create_tflite_graph(): + tf.config.run_functions_eagerly(True) + + class Model(tf.Module): + @tf.function + def tf_function(self, x): + return tf.nn.max_pool(x, [1, 2], [1, 2], "SAME") + + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple([1, 3, 4, 3])) + yield [data.astype(np.float32)] + + model = Model() + concrete_func = model.tf_function.get_concrete_function( + tf.TensorSpec([1, 3, 4, 3], dtype=tf.float32) + ) + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + relay_module, params = relay.frontend.from_tflite( + tflite_model, + shape_dict={"x": [1, 3, 4, 3]}, + dtype_dict={"x": "int8"}, + ) + mod = partition_for_ethosu(relay_module, params) + + # Generate reference data + input_data, output_data = infra.generate_ref_data_tflite(tflite_graph) + + compiled_models = infra.build_source( + mod, + input_data, + output_data, + ) + main_ir_module = list(compiled_models[0].executor_factory.lowered_ir_mods.values())[0] + main_func = main_ir_module["run_model"] + + # Activate Device + assert ( + str(main_func.body[0][0].value) + == "@tir.call_extern(" + + '"TVMDeviceEthosUActivate",' + + " device_context_ethos_u: handle," + + " dtype=int32)" + ) + # Open Device + assert ( + str(main_func.body[1].body.body[0][0][0].value) + == "@tir.call_extern(" + + '"TVMDeviceEthosUOpen",' + + " device_context_ethos_u: handle," + + " dtype=int32)" + ) + # Device Call + assert ( + str(main_func.body[1].body.body[0][0][1].value) + == "@tir.call_extern(" + + '"tvmgen_default_ethos_u_main_0",' + + " input: handle, output: handle," + + " device_context_ethos_u: handle," + + " dtype=int32)" + ) + # Close Device + assert ( + str(main_func.body[1].body.body[0][0][2].value) + == "@tir.call_extern(" + + '"TVMDeviceEthosUClose",' + + " device_context_ethos_u: handle," + + " dtype=int32)" + ) + # Deactivate Device + assert ( + str(main_func.body[2][0].value) + == "@tir.call_extern(" + + '"TVMDeviceEthosUDeactivate",' + + " device_context_ethos_u: handle," + + " dtype=int32)" + ) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))