From c63f453b54340af943b3f4ea1a71ba2077ff5161 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Mon, 14 Mar 2022 16:50:41 -0700 Subject: [PATCH] Update convert_onnx_models_to_ort.py to support runtime optimizations. (#10765) Add runtime optimization support to ONNX -> ORT format conversion script. Replace `--optimization_level`, `--use_nnapi`, and `--use_coreml` with a new `--optimization_style` option. --- cmake/onnxruntime_python.cmake | 1 + .../onnxruntime_session_options_config_keys.h | 28 +-- onnxruntime/core/framework/config_options.cc | 4 +- onnxruntime/core/framework/config_options.h | 8 +- .../core/optimizer/graph_transformer_utils.cc | 2 +- onnxruntime/core/session/inference_session.cc | 105 +++++--- onnxruntime/core/session/inference_session.h | 25 +- onnxruntime/test/framework/test_utils.h | 9 + .../graph_runtime_optimization_test.cc | 48 +++- .../core/session/training_session.cc | 11 +- .../core/session/training_session.h | 7 +- tools/python/create_reduced_build_config.py | 51 ++-- .../util/check_onnx_model_mobile_usability.py | 17 +- .../python/util/convert_onnx_models_to_ort.py | 230 ++++++++++-------- tools/python/util/file_utils.py | 46 ++++ tools/python/util/ort_format_model/utils.py | 66 ++--- 16 files changed, 403 insertions(+), 255 deletions(-) create mode 100644 tools/python/util/file_utils.py diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 416d50a0b26ea..fa9c6410013d8 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -380,6 +380,7 @@ file(GLOB onnxruntime_python_datasets_data CONFIGURE_DEPENDS set(onnxruntime_mobile_util_srcs ${REPO_ROOT}/tools/python/util/check_onnx_model_mobile_usability.py ${REPO_ROOT}/tools/python/util/convert_onnx_models_to_ort.py + ${REPO_ROOT}/tools/python/util/file_utils.py ${REPO_ROOT}/tools/python/util/logger.py ${REPO_ROOT}/tools/python/util/make_dynamic_shape_fixed.py ${REPO_ROOT}/tools/python/util/onnx_model_utils.py diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 457d20f9a152f..70fd33b2c4838 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -54,6 +54,7 @@ static const char* const kOrtSessionOptionsDisableQuantQDQ = "session.disable_qu // other factors like whether the model was created using Quantization Aware Training or Post Training Quantization. // As such, it's best to test to determine if enabling this works well for your scenario. // The default value is "0" +// Available since version 1.11. static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enable_quant_qdq_cleanup"; // Enable or disable gelu approximation in graph optimization. "0": disable; "1": enable. The default is "0". @@ -80,25 +81,18 @@ static const char* const kOrtSessionOptionsConfigUseORTModelBytesDirectly = "ses // This should only be specified when exporting an ORT format model for use on a different platform. // If the ORT format model will be used on ARM platforms set to "1". For other platforms set to "0" +// Available since version 1.11. static const char* const kOrtSessionOptionsQDQIsInt8Allowed = "session.qdqisint8allowed"; -// Save information for replaying graph optimizations later instead of applying them directly. 
-// -// When an ONNX model is loaded, ORT can perform various optimizations on the graph. -// However, when an ORT format model is loaded, the logic to perform these optimizations may not be available because -// this scenario must be supported by minimal builds. -// When loading an ONNX model, ORT can optionally save the effects of some optimizations for later replay in an ORT -// format model. These are known as "runtime optimizations" - in an ORT format model, they happen at runtime. -// -// Note: This option is only applicable when loading an ONNX model and saving an ORT format model. -// -// Note: Runtime optimizations are only supported for certain optimizations at the extended level or higher. -// Unsupported optimizations at those levels are not applied at all, while optimizations at other levels are applied -// directly. -// -// "0": disabled, "1": enabled -// The default is "0". -static const char* const kOrtSessionOptionsConfigSaveRuntimeOptimizations = "optimization.save_runtime_optimizations"; +// Specifies how minimal build graph optimizations are handled in a full build. +// These optimizations are at the extended level or higher. +// Possible values and their effects are: +// "save": Save runtime optimizations when saving an ORT format model. +// "apply": Only apply optimizations available in a minimal build. +// ""/<unspecified>: Apply optimizations available in a full build. +// Available since version 1.11. +static const char* const kOrtSessionOptionsConfigMinimalBuildOptimizations = "optimization.minimal_build_optimizations"; // Note: The options specific to an EP should be specified prior to appending that EP to the session options object in // order for them to take effect. diff --git a/onnxruntime/core/framework/config_options.cc b/onnxruntime/core/framework/config_options.cc index 05ab8627dd3df..d74989bb478a5 100644 --- a/onnxruntime/core/framework/config_options.cc +++ b/onnxruntime/core/framework/config_options.cc @@ -22,8 +22,8 @@ bool ConfigOptions::TryGetConfigEntry(const std::string& config_key, std::string return found; } -const std::string ConfigOptions::GetConfigOrDefault(const std::string& config_key, - const std::string& default_value) const noexcept { +std::string ConfigOptions::GetConfigOrDefault(const std::string& config_key, + const std::string& default_value) const noexcept { return GetConfigEntry(config_key).value_or(default_value); } diff --git a/onnxruntime/core/framework/config_options.h b/onnxruntime/core/framework/config_options.h index 24f682bdad429..e70261797fd43 100644 --- a/onnxruntime/core/framework/config_options.h +++ b/onnxruntime/core/framework/config_options.h @@ -12,9 +12,9 @@ namespace onnxruntime { /** - * Configuration options that can be used by any struct by inheriting this class.
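Illustrative example (not part of this change): the new kOrtSessionOptionsConfigMinimalBuildOptimizations entry documented above is set through the session options config entries, in the same way the conversion script further down in this patch sets it. A minimal Python sketch, with placeholder model file names, of exporting an ORT format model with saved runtime optimizations:

    import onnxruntime as ort

    so = ort.SessionOptions()
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    # Save an ORT format model; "save" records extended-level optimizations as runtime
    # optimizations and is only valid when saving an ORT format model.
    so.add_session_config_entry("session.save_model_format", "ORT")
    so.add_session_config_entry("optimization.minimal_build_optimizations", "save")
    so.optimized_model_filepath = "model.with_runtime_opt.ort"
    # Creating the session applies the optimizations and writes the output model.
    _ = ort.InferenceSession("model.onnx", sess_options=so, providers=["CPUExecutionProvider"])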
 * Provides infrastructure to add/get config entries + */ struct ConfigOptions { std::unordered_map<std::string, std::string> configurations; @@ -29,7 +29,7 @@ struct ConfigOptions { // Get the config string in this instance of ConfigOptions using the given config_key // If there is no such config, the given default string will be returned - const std::string GetConfigOrDefault(const std::string& config_key, const std::string& default_value) const noexcept; + std::string GetConfigOrDefault(const std::string& config_key, const std::string& default_value) const noexcept; // Add a config pair (config_key, config_value) to this instance of ConfigOptions Status AddConfigEntry(const char* config_key, const char* config_value) noexcept; diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 21216dedc1c55..67d81a21cbec1 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -286,7 +286,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB const IExecutionProvider& cpu_execution_provider, const InlinedHashSet<std::string>& rules_and_transformers_to_disable) { InlinedVector<std::unique_ptr<GraphTransformer>> transformers; - bool saving = std::holds_alternative<SatRuntimeOptimizationSaveContext>(apply_context); + const bool saving = std::holds_alternative<SatRuntimeOptimizationSaveContext>(apply_context); switch (level) { case TransformerLevel::Level1: diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 0fe1972ab81f1..f665a449471c1 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -163,10 +163,10 @@ Status VerifyEachNodeIsAssignedToAnEp(const Graph& graph, const logging::Logger& return status; } -} // namespace #if !defined(ORT_MINIMAL_BUILD) -static bool AreAllNodesInMainGraphAssignedToOneEp(const Graph& graph, ProviderType provider) { + +bool AreAllNodesInMainGraphAssignedToOneEp(const Graph& graph, ProviderType provider) { for (const auto& node : graph.Nodes()) { const auto& node_provider = node.GetExecutionProviderType(); @@ -178,7 +178,7 @@ static bool AreAllNodesInMainGraphAssignedToOneEp(const Graph& graph, ProviderTy return true; } -static bool HasControlflowNodes(const Graph& graph) { +bool HasControlflowNodes(const Graph& graph) { for (const auto& node : graph.Nodes()) { if (node.ContainsSubgraph()) { return true; @@ -187,7 +187,40 @@ static bool HasControlflowNodes(const Graph& graph) { return false; } -#endif + +Status GetMinimalBuildOptimizationHandling( std::string_view config_value, bool saving_ort_format, InferenceSession::MinimalBuildOptimizationHandling& minimal_build_optimization_handling) { if (config_value == "save") { if (saving_ort_format) { minimal_build_optimization_handling = InferenceSession::MinimalBuildOptimizationHandling::SaveMinimalBuildRuntimeOptimizations; return Status::OK(); } return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, kOrtSessionOptionsConfigMinimalBuildOptimizations, " value of 'save' is only valid when saving an ORT format model."); } if (config_value == "apply") { minimal_build_optimization_handling = InferenceSession::MinimalBuildOptimizationHandling::OnlyApplyMinimalBuildOptimizations; return Status::OK(); } if (config_value.empty()) { minimal_build_optimization_handling = InferenceSession::MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations; return Status::OK(); } return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid value for ",
kOrtSessionOptionsConfigMinimalBuildOptimizations, ": ", config_value); +}; + +#endif // !defined(ORT_MINIMAL_BUILD) + +} // namespace std::atomic InferenceSession::global_session_id_{1}; @@ -1402,14 +1435,17 @@ common::Status InferenceSession::Initialize() { #if !defined(ORT_MINIMAL_BUILD) if (!loading_ort_format) { - const bool saving_runtime_optimizations = - saving_ort_format && - session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigSaveRuntimeOptimizations, - "0") == "1"; + const auto minimal_build_opt_config_value = session_options_.config_options.GetConfigOrDefault( + kOrtSessionOptionsConfigMinimalBuildOptimizations, ""); + MinimalBuildOptimizationHandling minimal_build_optimization_handling{}; + ORT_RETURN_IF_ERROR_SESSIONID_(GetMinimalBuildOptimizationHandling(minimal_build_opt_config_value, + saving_ort_format, + minimal_build_optimization_handling)); + // add predefined transformers ORT_RETURN_IF_ERROR_SESSIONID_(AddPredefinedTransformers(graph_transformation_mgr_, session_options_.graph_optimization_level, - saving_runtime_optimizations)); + minimal_build_optimization_handling)); // apply any transformations to the main graph and any subgraphs ORT_RETURN_IF_ERROR_SESSIONID_(TransformGraph(graph, graph_transformation_mgr_, @@ -1436,9 +1472,9 @@ common::Status InferenceSession::Initialize() { // Return error status as we don't want the session initialization to complete successfully // if the user has requested usage of CUDA Graph feature and we cannot honor that. ORT_RETURN_IF_ERROR_SESSIONID_( - ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - " as the model has control flow nodes which can't be supported by CUDA Graphs.")); + ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "This session cannot use the CUDA Graph feature as requested by the user " + " as the model has control flow nodes which can't be supported by CUDA Graphs.")); } else if (!AreAllNodesInMainGraphAssignedToOneEp(graph, onnxruntime::kCudaExecutionProvider)) { LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " << " as all the graph nodes have not been partitioned to the CUDA EP."; @@ -1446,9 +1482,9 @@ common::Status InferenceSession::Initialize() { // Return error status as we don't want the session initialization to complete successfully // if the user has requested usage of CUDA Graph feature and we cannot honor that. ORT_RETURN_IF_ERROR_SESSIONID_( - ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - " as all the graph nodes have not been partitioned to the CUDA EP.")); + ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "This session cannot use the CUDA Graph feature as requested by the user " + " as all the graph nodes have not been partitioned to the CUDA EP.")); } else { LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user."; @@ -1875,11 +1911,11 @@ Status InferenceSession::Run(const RunOptions& run_options, // Check if this Run() is simply going to be a CUDA Graph replay. 
if (cached_execution_provider_for_graph_replay_.IsGraphCaptured()) { - LOGS(*session_logger_, INFO) << "Replaying the captured " - << cached_execution_provider_for_graph_replay_.Type() - << " CUDA Graph for this model with tag: " << run_options.run_tag; - ++current_num_runs_; - ORT_RETURN_IF_ERROR_SESSIONID_(cached_execution_provider_for_graph_replay_.ReplayGraph()); + LOGS(*session_logger_, INFO) << "Replaying the captured " + << cached_execution_provider_for_graph_replay_.Type() + << " CUDA Graph for this model with tag: " << run_options.run_tag; + ++current_num_runs_; + ORT_RETURN_IF_ERROR_SESSIONID_(cached_execution_provider_for_graph_replay_.ReplayGraph()); } else { std::vector exec_providers_to_stop; exec_providers_to_stop.reserve(execution_providers_.NumProviders()); @@ -1951,13 +1987,13 @@ Status InferenceSession::Run(const RunOptions& run_options, } #endif - // execute the graph + // execute the graph #ifdef DEBUG_NODE_INPUTS_OUTPUTS session_state_->IncrementGraphExecutionCounter(); #endif ORT_CHECK_AND_SET_RETVAL(utils::ExecuteGraph(*session_state_, feeds_fetches_manager, feeds, *p_fetches, - session_options_.execution_mode, run_options.terminate, run_logger, - run_options.only_execute_path_to_fetches)); + session_options_.execution_mode, run_options.terminate, run_logger, + run_options.only_execute_path_to_fetches)); } ORT_CATCH(const std::exception& e) { ORT_HANDLE_EXCEPTION([&]() { @@ -2010,7 +2046,7 @@ Status InferenceSession::Run(const RunOptions& run_options, // are needed before replaying the captured graph, here run the inference again // to capture the graph, so that users just need one session run to capture // the graph. - if (retval.IsOK() && cached_execution_provider_for_graph_replay_.IsGraphCaptureEnabled() && + if (retval.IsOK() && cached_execution_provider_for_graph_replay_.IsGraphCaptureEnabled() && !cached_execution_provider_for_graph_replay_.IsGraphCaptured()) { LOGS(*session_logger_, INFO) << "Start the second Run() to capture the graph. 
" "The first one is for necessary memory allocation;" @@ -2361,21 +2397,30 @@ void InferenceSession::InitLogger(logging::LoggingManager* logging_manager) { #if !defined(ORT_MINIMAL_BUILD) // Registers all the predefined transformers with transformer manager -common::Status InferenceSession::AddPredefinedTransformers(GraphTransformerManager& transformer_manager, - TransformerLevel graph_optimization_level, - bool saving_runtime_optimizations) const { +common::Status InferenceSession::AddPredefinedTransformers( + GraphTransformerManager& transformer_manager, + TransformerLevel graph_optimization_level, + MinimalBuildOptimizationHandling minimal_build_optimization_handling) const { const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); for (int i = static_cast(TransformerLevel::Level1); i <= static_cast(TransformerLevel::MaxLevel); i++) { TransformerLevel level = static_cast(i); if (graph_optimization_level >= level) { // Generate and register transformers for level auto transformers_to_register = [&]() { - if (!saving_runtime_optimizations || level == TransformerLevel::Level1) { + const bool use_full_build_optimizations = + level == TransformerLevel::Level1 || + minimal_build_optimization_handling == MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations; + + if (use_full_build_optimizations) { return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, optimizers_to_disable_); } else { - SatRuntimeOptimizationSaveContext save_context{kernel_registry_manager_}; - return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, save_context, cpu_ep, + const auto sat_context = + minimal_build_optimization_handling == + MinimalBuildOptimizationHandling::SaveMinimalBuildRuntimeOptimizations + ? SatApplyContextVariant{SatRuntimeOptimizationSaveContext{kernel_registry_manager_}} + : SatApplyContextVariant{SatDirectApplicationContext{}}; + return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep, optimizers_to_disable_); } }(); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 1f8b26cbe5061..4c4cc884bdb90 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -107,6 +107,23 @@ struct ModelMetadata { class InferenceSession { public: +#if !defined(ORT_MINIMAL_BUILD) + + /** + * How minimal build graph optimizations should be handled in a full build. + * Note: These only apply to optimizations at the extended level or higher. + */ + enum class MinimalBuildOptimizationHandling { + /** Run full build optimizations. The default behavior. */ + ApplyFullBuildOptimizations, + /** Save minimal build optimizations as runtime optimizations in an ORT format model. */ + SaveMinimalBuildRuntimeOptimizations, + /** Only run minimal build optimizations. */ + OnlyApplyMinimalBuildOptimizations, + }; + +#endif + /** Create a new InferenceSession @param session_options Session options. @@ -444,6 +461,7 @@ class InferenceSession { protected: #if !defined(ORT_MINIMAL_BUILD) + /** * Load an ONNX model. * @param protobuf object corresponding to the model file. model_proto will be copied by the API. 
@@ -583,9 +601,10 @@ class InferenceSession { void ShrinkMemoryArenas(const std::vector& arenas_to_shrink); #if !defined(ORT_MINIMAL_BUILD) - virtual common::Status AddPredefinedTransformers(GraphTransformerManager& transformer_manager, - TransformerLevel graph_optimization_level, - bool saving_runtime_optimizations) const; + virtual common::Status AddPredefinedTransformers( + GraphTransformerManager& transformer_manager, + TransformerLevel graph_optimization_level, + MinimalBuildOptimizationHandling minimal_build_optimization_handling) const; common::Status TransformGraph(onnxruntime::Graph& graph, const onnxruntime::GraphTransformerManager& graph_transformer_mgr, diff --git a/onnxruntime/test/framework/test_utils.h b/onnxruntime/test/framework/test_utils.h index 12ae831b3b833..1e97f44629857 100644 --- a/onnxruntime/test/framework/test_utils.h +++ b/onnxruntime/test/framework/test_utils.h @@ -93,6 +93,15 @@ using OpCountMap = std::map; // Helper function to check that the graph transformations have been successfully applied. OpCountMap CountOpsInGraph(const Graph& graph, bool recurse_into_subgraphs = true); +// Gets the op count from the OpCountMap. +// Can be called with a const OpCountMap, unlike OpCountMap::operator[]. +inline int OpCount(const OpCountMap& op_count_map, const std::string& op_type) { + if (auto it = op_count_map.find(op_type); it != op_count_map.end()) { + return it->second; + } + return 0; +} + #if !defined(DISABLE_SPARSE_TENSORS) void SparseIndicesChecker(const ONNX_NAMESPACE::TensorProto& indices_proto, gsl::span expected_indicies); #endif // DISABLE_SPARSE_TENSORS diff --git a/onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc b/onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc index 6b179b49c78b6..eb01ec687b620 100644 --- a/onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc +++ b/onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc @@ -185,7 +185,7 @@ using GraphCheckerFn = std::function; void LoadAndInitializeSession(const SessionOptions& so, const PathString& input_model_path, const GraphOpCountsCheckerFn& graph_op_count_checker_fn, - const GraphCheckerFn* graph_checker_fn = nullptr) { + const GraphCheckerFn& graph_checker_fn = {}) { InferenceSessionWrapper session{so, GetEnvironment()}; ASSERT_STATUS_OK(session.Load(input_model_path)); @@ -196,10 +196,12 @@ void LoadAndInitializeSession(const SessionOptions& so, const PathString& input_ const auto initialized_ops = CountOpsInGraph(session.GetGraph()); - graph_op_count_checker_fn(loaded_ops, initialized_ops); + if (graph_op_count_checker_fn) { + graph_op_count_checker_fn(loaded_ops, initialized_ops); + } if (graph_checker_fn) { - (*graph_checker_fn)(session.GetGraph()); + graph_checker_fn(session.GetGraph()); } } @@ -223,7 +225,7 @@ void SaveAndLoadRuntimeOptimizationsForModel( if (do_save) { SessionOptions so{}; ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsConfigSaveModelFormat, "ORT")); - ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsConfigSaveRuntimeOptimizations, "1")); + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsConfigMinimalBuildOptimizations, "save")); so.graph_optimization_level = TransformerLevel::Level2; so.optimized_model_filepath = saved_runtime_optimizations_model_path; @@ -296,7 +298,7 @@ void CheckNhwcTransformerIsApplied() { (OpCountMap{{"Transpose", 6}, {"com.microsoft.QLinearConv", n}})); }, - 
&checker_fn)); + checker_fn)); } } } // namespace @@ -341,6 +343,42 @@ TEST(GraphRuntimeOptimizationTest, TestNhwcTransformer) { CheckNhwcTransformerIsApplied(); } +#if !defined(ORT_MINIMAL_BUILD) +TEST(GraphRuntimeOptimizationTest, TestOnlyApplyMinimalBuildOptimizations) { + // This test assumes that AttentionFusion is not included in the minimal build optimizations. + // Update it if that changes. + + // When setting the option to only apply minimal build optimizations, verify that AttentionFusion does not run. + { + SessionOptions so{}; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsConfigMinimalBuildOptimizations, "apply")); + so.graph_optimization_level = TransformerLevel::Level2; + + LoadAndInitializeSession( + so, + ORT_TSTR("testdata/transform/fusion/attention_int32_mask.onnx"), + [](const OpCountMap& /*initialized_ops*/, const OpCountMap& loaded_ops) { + // expect no fused node + EXPECT_EQ(OpCount(loaded_ops, "com.microsoft.Attention"), 0); + }); + } + + // Otherwise, it should run. + { + SessionOptions so{}; + so.graph_optimization_level = TransformerLevel::Level2; + + LoadAndInitializeSession( + so, + ORT_TSTR("testdata/transform/fusion/attention_int32_mask.onnx"), + [](const OpCountMap& /*initialized_ops*/, const OpCountMap& loaded_ops) { + // expect fused node + EXPECT_EQ(OpCount(loaded_ops, "com.microsoft.Attention"), 1); + }); + } +} +#endif // !defined(ORT_MINIMAL_BUILD) + #endif // !defined(DISABLE_CONTRIB_OPS) } // namespace onnxruntime::test diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc index 30dd2a23f2902..60b278ac7cddf 100644 --- a/orttraining/orttraining/core/session/training_session.cc +++ b/orttraining/orttraining/core/session/training_session.cc @@ -754,10 +754,13 @@ void TrainingSession::AddPreTrainingTransformers(const IExecutionProvider& execu } // Registers all the predefined transformers with transformer manager -Status TrainingSession::AddPredefinedTransformers(GraphTransformerManager& transformer_manager, - TransformerLevel graph_optimization_level, - bool saving_runtime_optimizations) const { - ORT_RETURN_IF(saving_runtime_optimizations, "Saving runtime optimizations is not supported by TrainingSession."); +Status TrainingSession::AddPredefinedTransformers( + GraphTransformerManager& transformer_manager, + TransformerLevel graph_optimization_level, + MinimalBuildOptimizationHandling minimal_build_optimization_handling) const { + ORT_RETURN_IF_NOT( + minimal_build_optimization_handling == MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations, + "Only applying full build optimizations is supported by TrainingSession."); ORT_RETURN_IF_NOT(graph_optimization_level <= TransformerLevel::MaxLevel, "Exceeded max transformer level. 
Current level is set to " + diff --git a/orttraining/orttraining/core/session/training_session.h b/orttraining/orttraining/core/session/training_session.h index 1150241dae1c7..f6bf8985835ab 100644 --- a/orttraining/orttraining/core/session/training_session.h +++ b/orttraining/orttraining/core/session/training_session.h @@ -485,9 +485,10 @@ class TrainingSession : public InferenceSession { TransformerLevel graph_optimization_level = TransformerLevel::MaxLevel); /** override the parent method in inference session for training specific transformers */ - common::Status AddPredefinedTransformers(GraphTransformerManager& transformer_manager, - TransformerLevel graph_optimization_level, - bool saving_runtime_optimizations) const override; + common::Status AddPredefinedTransformers( + GraphTransformerManager& transformer_manager, + TransformerLevel graph_optimization_level, + MinimalBuildOptimizationHandling minimal_build_optimization_handling) const override; /** Perform auto-diff to add backward graph into the model. @param weights_to_train a set of weights to be training. diff --git a/tools/python/create_reduced_build_config.py b/tools/python/create_reduced_build_config.py index 26caf694ef657..9062c977658f9 100644 --- a/tools/python/create_reduced_build_config.py +++ b/tools/python/create_reduced_build_config.py @@ -3,10 +3,18 @@ # Licensed under the MIT License. import argparse -import os import onnx import pathlib import sys +import typing + +from util.file_utils import files_from_file_or_dir, path_match_suffix_ignore_case + + +def _get_suffix_match_predicate(suffix: str): + def predicate(file_path: pathlib.Path): + return path_match_suffix_ignore_case(file_path, suffix) + return predicate def _extract_ops_from_onnx_graph(graph, operators, domain_opset_map): @@ -51,39 +59,29 @@ def _process_onnx_model(model_path, required_ops): _extract_ops_from_onnx_graph(model.graph, required_ops, domain_opset_map) -def _extract_ops_from_onnx_model(model_path_or_dir): - '''Extract ops from a single ONNX model, or all ONNX models found by recursing model_path_or_dir''' - - if not os.path.exists(model_path_or_dir): - raise ValueError('Path to model/s does not exist: {}'.format(model_path_or_dir)) +def _extract_ops_from_onnx_model(model_files: typing.Iterable[pathlib.Path]): + '''Extract ops from ONNX models''' required_ops = {} - if os.path.isfile(model_path_or_dir): - _process_onnx_model(model_path_or_dir, required_ops) - else: - for root, _, files in os.walk(model_path_or_dir): - for file in files: - if file.lower().endswith('.onnx'): - model_path = os.path.join(root, file) - _process_onnx_model(model_path, required_ops) + for model_file in model_files: + if not model_file.is_file(): + raise ValueError(f"Path is not a file: '{model_file}'") + _process_onnx_model(model_file, required_ops) return required_ops -def create_config_from_onnx_models(model_path_or_dir: str, output_file: str): - - required_ops = _extract_ops_from_onnx_model(model_path_or_dir) +def create_config_from_onnx_models(model_files: typing.Iterable[pathlib.Path], output_file: pathlib.Path): - directory, filename = os.path.split(output_file) - if not filename: - raise RuntimeError("Invalid output path for configuation: {}".format(output_file)) + required_ops = _extract_ops_from_onnx_model(model_files) - if not os.path.exists(directory): - os.makedirs(directory) + output_file.parent.mkdir(parents=True, exist_ok=True) with open(output_file, 'w') as out: - out.write("# Generated from ONNX models path of {}\n".format(model_path_or_dir)) + 
out.write("# Generated from ONNX model/s:\n") + for model_file in sorted(model_files): + out.write(f"# - {model_file}\n") for domain in sorted(required_ops.keys()): for opset in sorted(required_ops[domain].keys()): @@ -129,10 +127,13 @@ def main(): config_path = config_path.joinpath(filename) if args.format == 'ONNX': - create_config_from_onnx_models(model_path_or_dir, config_path) + model_files = files_from_file_or_dir(model_path_or_dir, _get_suffix_match_predicate(".onnx")) + create_config_from_onnx_models(model_files, config_path) else: from util.ort_format_model import create_config_from_models as create_config_from_ort_models - create_config_from_ort_models(model_path_or_dir, config_path, args.enable_type_reduction) + + model_files = files_from_file_or_dir(model_path_or_dir, _get_suffix_match_predicate(".ort")) + create_config_from_ort_models(model_files, config_path, args.enable_type_reduction) # Debug code to validate that the config parsing matches # from util import parse_config diff --git a/tools/python/util/check_onnx_model_mobile_usability.py b/tools/python/util/check_onnx_model_mobile_usability.py index 3f62c53b9f9d2..7042d4cd2d018 100644 --- a/tools/python/util/check_onnx_model_mobile_usability.py +++ b/tools/python/util/check_onnx_model_mobile_usability.py @@ -42,17 +42,18 @@ def check_usability(): try_eps = usability_checker.analyze_model(args.model_path, skip_optimize=False, logger=logger) check_model_can_use_ort_mobile_pkg.run_check(args.model_path, args.config_path, logger) - logger.info("Run `python -m onnxruntime.tools.convert_onnx_models_to_ort ...` to convert the ONNX model to " - "ORT format. By default, the conversion tool will create an ORT format model optimized to " - "'basic' level (with a .basic.ort file extension) for use with NNAPI or CoreML, " - "and an ORT format model optimized to 'all' level (with a .all.ort file extension) for use with " - "the CPU EP.") + logger.info("Run `python -m onnxruntime.tools.convert_onnx_models_to_ort ...` to convert the ONNX model to ORT " + "format. " + "By default, the conversion tool will create an ORT format model with saved optimizations which can " + "potentially be applied at runtime (with a .with_runtime_opt.ort file extension) for use with NNAPI " + "or CoreML, and a fully optimized ORT format model (with a .ort file extension) for use with the CPU " + "EP.") if try_eps: logger.info("As NNAPI or CoreML may provide benefits with this model it is recommended to compare the " - "performance of the .basic.ort model using the NNAPI EP on Android, and the " - "CoreML EP on iOS, against the performance of the .all.ort model using the CPU EP.") + "performance of the .with_runtime_opt.ort model using the NNAPI EP on Android, and the " + "CoreML EP on iOS, against the performance of the .ort model using the CPU EP.") else: - logger.info("For optimal performance the .all.ort model should be used with the CPU EP. ") + logger.info("For optimal performance the .ort model should be used with the CPU EP. ") if __name__ == '__main__': diff --git a/tools/python/util/convert_onnx_models_to_ort.py b/tools/python/util/convert_onnx_models_to_ort.py index 48c7d6552b9d2..0645bac1b48f5 100644 --- a/tools/python/util/convert_onnx_models_to_ort.py +++ b/tools/python/util/convert_onnx_models_to_ort.py @@ -3,44 +3,37 @@ # Licensed under the MIT License. 
import argparse +import contextlib +import enum import os import pathlib +import tempfile import typing import onnxruntime as ort -from .ort_format_model import create_config_from_models +from .file_utils import files_from_file_or_dir, path_match_suffix_ignore_case from .onnx_model_utils import get_optimization_level +from .ort_format_model import create_config_from_models -def _path_match_suffix_ignore_case(path: typing.Union[pathlib.Path, str], suffix: str): - if not isinstance(path, str): - path = str(path) - return path.casefold().endswith(suffix.casefold()) - +class OptimizationStyle(enum.Enum): + Fixed = 0 + Runtime = 1 -def _onnx_model_path_to_ort_model_path(onnx_model_path: pathlib.Path, optimization_level_str: str): - assert onnx_model_path.is_file() and _path_match_suffix_ignore_case(onnx_model_path, ".onnx") - return onnx_model_path.with_suffix(".{}.ort".format(optimization_level_str)) +def _optimization_suffix(optimization_style: OptimizationStyle, suffix: str): + return "{}{}".format(".with_runtime_opt" if optimization_style == OptimizationStyle.Runtime else "", + suffix) -def _create_config_file_from_ort_models(onnx_model_path_or_dir: pathlib.Path, optimization_level: str, - enable_type_reduction: bool): - if onnx_model_path_or_dir.is_dir(): - # model directory - model_path_or_dir = onnx_model_path_or_dir - config_path = None # default path in model directory - else: - # single model - model_path_or_dir = _onnx_model_path_to_ort_model_path(onnx_model_path_or_dir, optimization_level) - suffix = f'.{optimization_level}.config' - config_suffix = ".{}{}".format( - 'required_operators_and_types' if enable_type_reduction else 'required_operators', suffix) - config_path = model_path_or_dir.with_suffix(config_suffix) - create_config_from_models(model_path_or_dir=str(model_path_or_dir), - output_file=str(config_path) if config_path is not None else None, - enable_type_reduction=enable_type_reduction, - optimization_level=optimization_level) +def _create_config_file_path(model_path_or_dir: pathlib.Path, + optimization_style: OptimizationStyle, + enable_type_reduction: bool): + config_name = "{}{}".format('required_operators_and_types' if enable_type_reduction else 'required_operators', + _optimization_suffix(optimization_style, ".config")) + if model_path_or_dir.is_dir(): + return model_path_or_dir / config_name + return model_path_or_dir.with_suffix(f".{config_name}") def _create_session_options(optimization_level: ort.GraphOptimizationLevel, @@ -60,31 +53,33 @@ def _create_session_options(optimization_level: ort.GraphOptimizationLevel, return so -def _convert(model_path_or_dir: pathlib.Path, optimization_level_str: str, use_nnapi: bool, use_coreml: bool, +def _convert(model_path_or_dir: pathlib.Path, output_dir: typing.Optional[pathlib.Path], + optimization_level_str: str, optimization_style: OptimizationStyle, custom_op_library: pathlib.Path, create_optimized_onnx_model: bool, allow_conversion_failures: bool, - target_platform: str, session_options_config_entries: typing.Dict[str, str]): + target_platform: str, session_options_config_entries: typing.Dict[str, str]) \ + -> typing.List[pathlib.Path]: + + model_dir = model_path_or_dir if model_path_or_dir.is_dir() else model_path_or_dir.parent + output_dir = output_dir or model_dir optimization_level = get_optimization_level(optimization_level_str) - models = [] - if model_path_or_dir.is_file() and _path_match_suffix_ignore_case(model_path_or_dir, ".onnx"): - models.append(model_path_or_dir) - elif model_path_or_dir.is_dir(): - for 
root, _, files in os.walk(model_path_or_dir): - for file in files: - if _path_match_suffix_ignore_case(file, ".onnx"): - models.append(pathlib.Path(root, file)) + def is_model_file_to_convert(file_path: pathlib.Path): + if not path_match_suffix_ignore_case(file_path, ".onnx"): + return False + # ignore any files with an extension of .optimized.onnx which are presumably from previous executions + # of this script + if path_match_suffix_ignore_case(file_path, ".optimized.onnx"): + print(f"Ignoring '{file_path}'") + return False + return True + + models = files_from_file_or_dir(model_path_or_dir, is_model_file_to_convert) if len(models) == 0: - raise ValueError("No .onnx files were found in '{}'".format(model_path_or_dir)) + raise ValueError("No model files were found in '{}'".format(model_path_or_dir)) providers = ['CPUExecutionProvider'] - if use_nnapi: - # providers are priority based, so register NNAPI first - providers.insert(0, 'NnapiExecutionProvider') - if use_coreml: - # providers are priority based, so register CoreML first - providers.insert(0, 'CoreMLExecutionProvider') # if the optimization level is 'all' we manually exclude the NCHWc transformer. It's not applicable to ARM # devices, and creates a device specific model which won't run on all hardware. @@ -94,26 +89,29 @@ def _convert(model_path_or_dir: pathlib.Path, optimization_level_str: str, use_n if optimization_level == ort.GraphOptimizationLevel.ORT_ENABLE_ALL and target_platform != 'amd64': optimizer_filter = ['NchwcTransformer'] - num_failures = 0 + converted_models = [] for model in models: try: - # ignore any files with an extension of .optimized.onnx which are presumably from previous executions - # of this script - if _path_match_suffix_ignore_case(model, ".optimized.onnx"): - print("Ignoring '{}'".format(model)) - continue + relative_model_path = model.relative_to(model_dir) - # create .ort file in same dir as original onnx model - ort_target_path = _onnx_model_path_to_ort_model_path(model, optimization_level_str) + (output_dir / relative_model_path).parent.mkdir(parents=True, exist_ok=True) + + ort_target_path = (output_dir / relative_model_path).with_suffix( + _optimization_suffix(optimization_style, ".ort")) if create_optimized_onnx_model: - # Create an ONNX file with the same optimizations that will be used for the ORT format file. + # Create an ONNX file with the same optimization level that will be used for the ORT format file. # This allows the ONNX equivalent of the ORT format model to be easily viewed in Netron. - optimized_target_path = model.with_suffix(".{}.optimized.onnx".format(optimization_level_str)) + # If runtime optimizations are saved in the ORT format model, there may be some difference in the + # graphs at runtime between the ORT format model and this saved ONNX model. + optimized_target_path = (output_dir / relative_model_path).with_suffix(".optimized.onnx") so = _create_session_options(optimization_level, optimized_target_path, custom_op_library, session_options_config_entries) + if optimization_style == OptimizationStyle.Runtime: + # Limit the optimizations to those that can run in a model with runtime optimizations. 
+ so.add_session_config_entry('optimization.minimal_build_optimizations', 'apply') print("Saving optimized ONNX model {} to {}".format(model, optimized_target_path)) _ = ort.InferenceSession(str(model), sess_options=so, providers=providers, @@ -123,11 +121,15 @@ def _convert(model_path_or_dir: pathlib.Path, optimization_level_str: str, use_n so = _create_session_options(optimization_level, ort_target_path, custom_op_library, session_options_config_entries) so.add_session_config_entry('session.save_model_format', 'ORT') + if optimization_style == OptimizationStyle.Runtime: + so.add_session_config_entry('optimization.minimal_build_optimizations', 'save') print("Converting optimized ONNX model {} to ORT format model {}".format(model, ort_target_path)) _ = ort.InferenceSession(str(model), sess_options=so, providers=providers, disabled_optimizers=optimizer_filter) + converted_models.append(ort_target_path) + # orig_size = os.path.getsize(onnx_target_path) # new_size = os.path.getsize(ort_target_path) # print("Serialized {} to {}. Sizes: orig={} new={} diff={} new:old={:.4f}:1.0".format( @@ -136,9 +138,10 @@ def _convert(model_path_or_dir: pathlib.Path, optimization_level_str: str, use_n print("Error converting {}: {}".format(model, e)) if not allow_conversion_failures: raise - num_failures += 1 - print("Converted {} models. {} failures.".format(len(models), num_failures)) + print("Converted {}/{} models successfully.".format(len(converted_models), len(models))) + + return converted_models def parse_args(): @@ -146,38 +149,28 @@ def parse_args(): os.path.basename(__file__), description='''Convert the ONNX format model/s in the provided directory to ORT format models. All files with a `.onnx` extension will be processed. For each one, an ORT format model will be created in the - same directory. A configuration file will also be created called `required_operators.config`, and will contain - the list of required operators for all converted models. - This configuration file should be used as input to the minimal build via the `--include_ops_by_config` - parameter. + same directory. A configuration file will also be created containing the list of required operators for all + converted models. This configuration file should be used as input to the minimal build via the + `--include_ops_by_config` parameter. ''' ) - parser.add_argument('--use_nnapi', action='store_true', - help='Enable the NNAPI Execution Provider when creating models and determining required ' - 'operators. Note that this will limit the optimizations possible on nodes that the ' - 'NNAPI execution provider takes, in order to preserve those nodes in the ORT format ' - 'model.') - - parser.add_argument('--use_coreml', action='store_true', - help='Enable the CoreML Execution Provider when creating models and determining required ' - 'operators. Note that this will limit the optimizations possible on nodes that the ' - 'CoreML execution provider takes, in order to preserve those nodes in the ORT format ' - 'model.') - - parser.add_argument('--optimization_level', default=['basic', 'all'], nargs='+', - choices=['disable', 'basic', 'extended', 'all'], - help="Level to optimize ONNX model with, prior to converting to ORT format model. " - "These map to the onnxruntime.GraphOptimizationLevel values. " - "If the level is 'all' the NCHWc transformer is manually disabled as it contains device " - "specific logic, so the ORT format model must be generated on the device it will run on. 
" - "Additionally, the NCHWc optimizations are not applicable to ARM devices. " - "Multiple values can be provided. A model produced with 'all' is optimal for usage with " - "just the CPU Execution Provider. A model produced with 'basic' is required for usage " - "with the NNAPI or CoreML Execution Providers. " - "The filename for the ORT format model will contain the optimization level that was used " - "to create it." - ) + parser.add_argument('--optimization_style', + nargs='+', + default=[OptimizationStyle.Fixed.name, OptimizationStyle.Runtime.name], + choices=[e.name for e in OptimizationStyle], + help="Style of optimization to perform on the ORT format model. " + "Multiple values may be provided. The conversion will run once for each value. " + "The general guidance is to use models optimized with " + f"'{OptimizationStyle.Runtime.name}' style when using NNAPI or CoreML and " + f"'{OptimizationStyle.Fixed.name}' style otherwise. " + f"'{OptimizationStyle.Fixed.name}': Run optimizations directly before saving the ORT " + "format model. This bakes in any platform-specific optimizations. " + f"'{OptimizationStyle.Runtime.name}': Run basic optimizations directly and save certain " + "other optimizations to be applied at runtime if possible. This is useful when using a " + "compiling EP like NNAPI or CoreML that may run an unknown (at model conversion time) " + "number of nodes. The saved optimizations can further optimize nodes not assigned to the " + "compiling EP at runtime.") parser.add_argument('--enable_type_reduction', action='store_true', help='Add operator specific type information to the configuration file to potentially reduce ' @@ -188,7 +181,7 @@ def parse_args(): parser.add_argument('--save_optimized_onnx_model', action='store_true', help='Save the optimized version of each ONNX model. ' - 'This will have the same optimizations applied as the ORT format model.') + 'This will have the same level of optimizations applied as the ORT format model.') parser.add_argument('--allow_conversion_failures', action='store_true', help='Whether to proceed after encountering model conversion failures.') @@ -200,13 +193,14 @@ def parse_args(): parser.add_argument('--target_platform', type=str, default=None, choices=['arm', 'amd64'], help='Specify the target platform where the exported model will be used. ' - 'This parameter can be used to choose between platform specific options, ' - 'such as QDQIsInt8Allowed(arm), NCHWc (amd64) and NHWC (arm/amd64) format different ' - 'optimizer level options,etc.') + 'This parameter can be used to choose between platform-specific options, ' + 'such as QDQIsInt8Allowed(arm), NCHWc (amd64) and NHWC (arm/amd64) format, different ' + 'optimizer level options, etc.') parser.add_argument('model_path_or_dir', type=pathlib.Path, help='Provide path to ONNX model or directory containing ONNX model/s to convert. 
' - 'All files with a .onnx extension, including in subdirectories, will be processed.') + 'All files with a .onnx extension, including those in subdirectories, will be ' + 'processed.') return parser.parse_args() @@ -214,6 +208,8 @@ def parse_args(): def convert_onnx_models_to_ort(): args = parse_args() + optimization_styles = [OptimizationStyle[style_str] for style_str in args.optimization_style] + optimization_level_str = 'all' model_path_or_dir = args.model_path_or_dir.resolve() custom_op_library = args.custom_op_library.resolve() if args.custom_op_library else None @@ -223,12 +219,6 @@ def convert_onnx_models_to_ort(): if custom_op_library and not custom_op_library.is_file(): raise FileNotFoundError("Unable to find custom operator library '{}'".format(custom_op_library)) - if args.use_nnapi and 'NnapiExecutionProvider' not in ort.get_available_providers(): - raise ValueError('The NNAPI Execution Provider was not included in this build of ONNX Runtime.') - - if args.use_coreml and 'CoreMLExecutionProvider' not in ort.get_available_providers(): - raise ValueError('The CoreML Execution Provider was not included in this build of ONNX Runtime.') - session_options_config_entries = {} if args.nnapi_partitioning_stop_ops is not None: @@ -239,13 +229,49 @@ def convert_onnx_models_to_ort(): else: session_options_config_entries["session.qdqisint8allowed"] = "0" - for optimization_level in args.optimization_level: - print(f"Converting models and creating configuration file for optimization level '{optimization_level}'") - _convert(model_path_or_dir, optimization_level, args.use_nnapi, args.use_coreml, custom_op_library, - args.save_optimized_onnx_model, args.allow_conversion_failures, args.target_platform, - session_options_config_entries) - - _create_config_file_from_ort_models(model_path_or_dir, optimization_level, args.enable_type_reduction) + for optimization_style in optimization_styles: + print("Converting models with optimization style '{}' and level '{}'".format( + optimization_style.name, optimization_level_str)) + + converted_models = _convert( + model_path_or_dir=model_path_or_dir, output_dir=None, + optimization_level_str=optimization_level_str, optimization_style=optimization_style, + custom_op_library=custom_op_library, + create_optimized_onnx_model=args.save_optimized_onnx_model, + allow_conversion_failures=args.allow_conversion_failures, + target_platform=args.target_platform, + session_options_config_entries=session_options_config_entries) + + with contextlib.ExitStack() as context_stack: + if optimization_style == OptimizationStyle.Runtime: + # Convert models again without runtime optimizations. + # Runtime optimizations may not end up being applied, so we need to use both converted models with and + # without runtime optimizations to get a complete set of ops that may be needed for the config file. + model_dir = model_path_or_dir if model_path_or_dir.is_dir() else model_path_or_dir.parent + temp_output_dir = context_stack.enter_context( + tempfile.TemporaryDirectory(dir=model_dir, suffix=".without_runtime_opt")) + session_options_config_entries_for_second_conversion = session_options_config_entries.copy() + # Limit the optimizations to those that can run in a model with runtime optimizations. + session_options_config_entries_for_second_conversion[ + "optimization.minimal_build_optimizations"] = "apply" + + print("Converting models again without runtime optimizations to generate a complete config file. 
" + "These converted models are temporary and will be deleted.") + converted_models += _convert( + model_path_or_dir=model_path_or_dir, output_dir=temp_output_dir, + optimization_level_str=optimization_level_str, optimization_style=OptimizationStyle.Fixed, + custom_op_library=custom_op_library, + create_optimized_onnx_model=False, # not useful as they would be created in a temp directory + allow_conversion_failures=args.allow_conversion_failures, + target_platform=args.target_platform, + session_options_config_entries=session_options_config_entries_for_second_conversion) + + print("Generating config file from ORT format models with optimization style '{}' and level '{}'".format( + optimization_style.name, optimization_level_str)) + + config_file = _create_config_file_path(model_path_or_dir, optimization_style, args.enable_type_reduction) + + create_config_from_models(converted_models, config_file, args.enable_type_reduction) if __name__ == '__main__': diff --git a/tools/python/util/file_utils.py b/tools/python/util/file_utils.py new file mode 100644 index 0000000000000..73505b73369bb --- /dev/null +++ b/tools/python/util/file_utils.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import pathlib +import typing +import os + + +def path_match_suffix_ignore_case(path: typing.Union[pathlib.Path, str], suffix: str) -> bool: + ''' + Returns whether `path` ends in `suffix`, ignoring case. + ''' + if not isinstance(path, str): + path = str(path) + return path.casefold().endswith(suffix.casefold()) + + +def files_from_file_or_dir(file_or_dir_path: typing.Union[pathlib.Path, str], + predicate: typing.Callable[[pathlib.Path], bool] = lambda _: True) \ + -> typing.List[pathlib.Path]: + ''' + Gets the files in `file_or_dir_path` satisfying `predicate`. + If `file_or_dir_path` is a file, the single file is considered. Otherwise, all files in the directory are + considered. + :param file_or_dir_path: Path to a file or directory. + :param predicate: Predicate to determine if a file is included. + :return: A list of files. + ''' + if not isinstance(file_or_dir_path, pathlib.Path): + file_or_dir_path = pathlib.Path(file_or_dir_path) + + selected_files = [] + + def process_file(file_path: pathlib.Path): + if predicate(file_path): + selected_files.append(file_path) + + if file_or_dir_path.is_dir(): + for root, _, files in os.walk(file_or_dir_path): + for file in files: + file_path = pathlib.Path(root, file) + process_file(file_path) + else: + process_file(file_or_dir_path) + + return selected_files diff --git a/tools/python/util/ort_format_model/utils.py b/tools/python/util/ort_format_model/utils.py index a6d3c2c8682bc..2be004dc9cfaf 100644 --- a/tools/python/util/ort_format_model/utils.py +++ b/tools/python/util/ort_format_model/utils.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
-import os +import pathlib import typing from .operator_type_usage_processors import OperatorTypeUsageManager @@ -11,72 +11,36 @@ log = get_logger("ort_format_model.utils") -def _extract_ops_and_types_from_ort_models(model_path_or_dir: str, enable_type_reduction: bool, - optimization_level: str = None): - if not os.path.exists(model_path_or_dir): - raise ValueError('Path to model/s does not exist: {}'.format(model_path_or_dir)) - +def _extract_ops_and_types_from_ort_models(model_files: typing.Iterable[pathlib.Path], enable_type_reduction: bool): required_ops = {} op_type_usage_manager = OperatorTypeUsageManager() if enable_type_reduction else None - suffix = f'.{optimization_level}.ort' if optimization_level else '.ort' - if os.path.isfile(model_path_or_dir): - if model_path_or_dir.lower().endswith(suffix): - model_processor = OrtFormatModelProcessor(model_path_or_dir, required_ops, op_type_usage_manager) - model_processor.process() # this updates required_ops and op_type_processors - log.info('Processed {}'.format(model_path_or_dir)) - else: - log.debug('Skipped {}'.format(model_path_or_dir)) - else: - for root, _, files in os.walk(model_path_or_dir): - for file in files: - model_path = os.path.join(root, file) - if file.lower().endswith(suffix): - model_processor = OrtFormatModelProcessor(model_path, required_ops, op_type_usage_manager) - model_processor.process() # this updates required_ops and op_type_processors - log.info('Processed {}'.format(model_path)) - else: - log.debug('Skipped {}'.format(model_path)) + for model_file in model_files: + if not model_file.is_file(): + raise ValueError(f"Path is not a file: '{model_file}'") + model_processor = OrtFormatModelProcessor(str(model_file), required_ops, op_type_usage_manager) + model_processor.process() # this updates required_ops and op_type_processors return required_ops, op_type_usage_manager -def create_config_from_models(model_path_or_dir: str, output_file: str = None, enable_type_reduction: bool = True, - optimization_level: typing.Optional[str] = None): +def create_config_from_models(model_files: typing.Iterable[pathlib.Path], output_file: pathlib.Path, + enable_type_reduction: bool): ''' Create a configuration file with required operators and optionally required types. - :param model_path_or_dir: Path to recursively search for ORT format models, or to a single ORT format model. + :param model_files: Model files to use to generate the configuration file. :param output_file: File to write configuration to. - Defaults to creating required_operators[_and_types].config in the model_path_or_dir directory. :param enable_type_reduction: Include required type information for individual operators in the configuration. - :param optimization_level: Filter files and adjust default output_file based on the optimization level. If set, - looks for '..ort' as the file suffix. Uses '..config' as the config - file suffix. - When we convert models we include the optimization level in the filename. When creating the configuration - we only want to create it for the specific optimization level so that we don't include irrelevant operators. 
''' - required_ops, op_type_processors = _extract_ops_and_types_from_ort_models(model_path_or_dir, enable_type_reduction, - optimization_level) - - if output_file: - directory, filename = os.path.split(output_file) - if not filename: - raise RuntimeError("Invalid output path for configuration: {}".format(output_file)) - - if directory and not os.path.exists(directory): - os.makedirs(directory) - else: - dir = model_path_or_dir - if os.path.isfile(model_path_or_dir): - dir = os.path.dirname(model_path_or_dir) + required_ops, op_type_processors = _extract_ops_and_types_from_ort_models(model_files, enable_type_reduction) - suffix = f'.{optimization_level}.config' if optimization_level else '.config' - output_file = os.path.join( - dir, ('required_operators_and_types' if enable_type_reduction else 'required_operators') + suffix) + output_file.parent.mkdir(parents=True, exist_ok=True) with open(output_file, 'w') as out: - out.write("# Generated from model/s in {}\n".format(model_path_or_dir)) + out.write("# Generated from model/s:\n") + for model_file in sorted(model_files): + out.write(f"# - {model_file}\n") for domain in sorted(required_ops.keys()): for opset in sorted(required_ops[domain].keys()):
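Usage illustration (not part of this change; paths are placeholders): with this patch the converter is driven by --optimization_style rather than --optimization_level, --use_nnapi, or --use_coreml.

    # Default: convert each .onnx file under the given path once per style, producing
    #   model.ort                   (Fixed style: optimizations applied directly)
    #   model.with_runtime_opt.ort  (Runtime style: runtime optimizations saved for NNAPI/CoreML scenarios)
    # plus a required_operators[.with_runtime_opt].config file for each style.
    python -m onnxruntime.tools.convert_onnx_models_to_ort /path/to/models

    # Runtime style only, with per-operator type information in the config file
    # (required_operators_and_types.with_runtime_opt.config) for type-reduced minimal builds.
    python -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style Runtime --enable_type_reduction /path/to/models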