diff --git a/src/relay/backend/contrib/ethosn/codegen.cc b/src/relay/backend/contrib/ethosn/codegen.cc index 97b308e51e182..9886957e60eb4 100644 --- a/src/relay/backend/contrib/ethosn/codegen.cc +++ b/src/relay/backend/contrib/ethosn/codegen.cc @@ -606,25 +606,38 @@ std::pair, std::vector> EthosnCompiler::GetInput return std::make_pair(input_order, output_order); } -auto ctx = transform::PassContext::Current(); -auto cfg = ctx -> GetConfig("relay.ext.ethos-n.options").defined() +std::unique_ptr EthosnCompiler::m_Queries; + +EthosnError EthosnCompiler::SupportedSetup() { + if (m_Queries == nullptr) { + auto ctx = transform::PassContext::Current(); + auto cfg = ctx -> GetConfig("relay.ext.ethos-n.options").defined() ? ctx -> GetConfig("relay.ext.ethos-n.options") : AttrsWithDefaultValues(); -auto m_Queries = sl::SupportQueries(sl::GetFwAndHwCapabilities( - sl::EthosNVariantFromString(cfg.value()->variant.c_str()), cfg.value()->sram_size_bytes)); + m_Queries = std::make_unique(sl::GetFwAndHwCapabilities( + sl::EthosNVariantFromString(cfg.value()->variant.c_str()), cfg.value()->sram_size_bytes)); + if (m_Queries == nullptr) { + return EthosnError("Could not initialise Ethos-N compiler isSupported"); + } + } + return EthosnError(); +} TVM_REGISTER_GLOBAL("relay.ethos-n.support.conv2d") .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { Call call = args[0]; ConvolutionParams params; auto err = EthosnAPI::QnnConv2d(call, ¶ms); + err += EthosnCompiler::SupportedSetup(); if (params.is_depthwise) { *rv = !err && - m_Queries.IsDepthwiseConvolutionSupported(params.bias_info, params.weights_info, - params.conv_info, params.activation_info); + EthosnCompiler::GetSupported()->IsDepthwiseConvolutionSupported( + params.bias_info, params.weights_info, + params.conv_info, params.activation_info); } else { - *rv = !err && m_Queries.IsConvolutionSupported(params.bias_info, params.weights_info, - params.conv_info, params.activation_info); + *rv = !err && EthosnCompiler::GetSupported()->IsConvolutionSupported( + params.bias_info, params.weights_info, + params.conv_info, params.activation_info); } }); @@ -633,8 +646,10 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.fc") Call call = args[0]; FullyConnectedParams params; auto err = EthosnAPI::QnnFullyConnected(call, ¶ms); - *rv = !err && m_Queries.IsFullyConnectedSupported(params.bias_info, params.weights_info, - params.fc_info, params.input_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsFullyConnectedSupported( + params.bias_info, params.weights_info, + params.fc_info, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.max_pool2d") @@ -642,7 +657,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.max_pool2d") Call call = args[0]; MaxPool2DParams params; auto err = EthosnAPI::MaxPool2D(call, ¶ms); - *rv = !err && m_Queries.IsPoolingSupported(params.pool_info, params.input_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsPoolingSupported( + params.pool_info, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.avg_pool2d") @@ -650,7 +667,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.avg_pool2d") Call call = args[0]; AvgPool2DParams params; auto err = EthosnAPI::AvgPool2D(call, ¶ms); - *rv = !err && m_Queries.IsPoolingSupported(params.pool_info, params.input_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsPoolingSupported( + params.pool_info, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.reshape") @@ -658,7 +677,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.reshape") Call call = args[0]; ReshapeParams params; auto err = EthosnAPI::Reshape(call, ¶ms); - *rv = !err && m_Queries.IsReshapeSupported(params.new_shape, params.input_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsReshapeSupported( + params.new_shape, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.addition") @@ -666,8 +687,10 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.addition") Call call = args[0]; AdditionParams params; auto err = EthosnAPI::Addition(call, ¶ms); - *rv = !err && m_Queries.IsAdditionSupported(params.lhs_info, params.rhs_info, - params.output_quantization_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsAdditionSupported( + params.lhs_info, params.rhs_info, + params.output_quantization_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.sigmoid") @@ -675,7 +698,8 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.sigmoid") Call call = args[0]; SigmoidParams params; auto err = EthosnAPI::Sigmoid(call, ¶ms); - *rv = !err && m_Queries.IsSigmoidSupported(params.input_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsSigmoidSupported(params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.concatenate") @@ -683,7 +707,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.concatenate") Call call = args[0]; ConcatenateParams params; auto err = EthosnAPI::Concatenate(call, ¶ms); - *rv = !err && m_Queries.IsConcatenationSupported(params.input_infos, params.concat_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsConcatenationSupported( + params.input_infos, params.concat_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.split") @@ -691,7 +717,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.split") Call call = args[0]; SplitParams params; auto err = EthosnAPI::Split(call, ¶ms); - *rv = !err && m_Queries.IsSplitSupported(params.input_info, params.split_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsSplitSupported( + params.input_info, params.split_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.depth_to_space") @@ -699,7 +727,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.depth_to_space") Call call = args[0]; DepthToSpaceParams params; auto err = EthosnAPI::DepthToSpace(call, ¶ms); - *rv = !err && m_Queries.IsDepthToSpaceSupported(params.input_info, params.depth_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsDepthToSpaceSupported( + params.input_info, params.depth_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.support.relu") @@ -707,7 +737,9 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.relu") Call call = args[0]; ReluParams params; auto err = EthosnAPI::Relu(call, ¶ms); - *rv = !err && m_Queries.IsReluSupported(params.relu_info, params.input_info); + err += EthosnCompiler::SupportedSetup(); + *rv = !err && EthosnCompiler::GetSupported()->IsReluSupported( + params.relu_info, params.input_info); }); TVM_REGISTER_GLOBAL("relay.ethos-n.query").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { diff --git a/src/relay/backend/contrib/ethosn/codegen_ethosn.h b/src/relay/backend/contrib/ethosn/codegen_ethosn.h index 63ae7a3e47049..1d6ebfcd2acc6 100644 --- a/src/relay/backend/contrib/ethosn/codegen_ethosn.h +++ b/src/relay/backend/contrib/ethosn/codegen_ethosn.h @@ -287,6 +287,22 @@ class EthosnCompiler { */ static runtime::Module CreateRuntimeModule(const ObjectRef& ref); + /*! + * \brief Initialise the is-supported functionality of the Ethos-N support library + * with the target variant. + * \return Error object + */ + static EthosnError SupportedSetup(); + + /*! + * \brief Return the is-supported API of the Support Library + * \return A reference to the API. + */ + static std::unique_ptr& GetSupported() { + ICHECK(m_Queries != nullptr); + return m_Queries; + } + private: /*! * \brief Compile a single Relay Ethos-N function into an ordered compiled network. @@ -322,6 +338,8 @@ class EthosnCompiler { */ static std::pair, std::vector> GetInputOutputOrder( NetworkWithIDs network, const std::unique_ptr& compiled_network); + + static std::unique_ptr m_Queries; }; runtime::Module CompileEthosn(const ObjectRef& ref) {