diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc
index 8a6f4fdd16e39..c629fe91d64a8 100644
--- a/onnxruntime/core/framework/session_state.cc
+++ b/onnxruntime/core/framework/session_state.cc
@@ -937,7 +937,7 @@ Status SessionState::CreateSubgraphSessionState() {
   for (auto& node : graph_.Nodes()) {
     for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
       const auto& ep = node.GetExecutionProviderType();
-      if (ep != kCpuExecutionProvider && ep != kCudaExecutionProvider) {
+      if (!ep.empty() && ep != kCpuExecutionProvider && ep != kCudaExecutionProvider) {
         // SessionState is only used when ORT is executing the subgraph. If a non-ORT EP has taken the control flow
         // node containing the subgraph it will create whatever state it needs internally.
         continue;
@@ -973,15 +973,36 @@ Status SessionState::LoadFromOrtFormat(const fbs::SessionState& fbs_session_stat
   const FbsSessionStateViewer fbs_session_state_viewer{fbs_session_state};
   ORT_RETURN_IF_ERROR(fbs_session_state_viewer.Validate());
 
-  auto add_kernel_by_hash =
-      [&kernel_registry_manager, this](const Node& node, HashValue hash) {
+  // look up KernelCreateInfo with hash and
+  // - add KernelCreateInfo for node
+  // - set node's EP from KernelCreateInfo if unset
+  auto add_kernel_and_set_node_ep_by_hash =
+      [&kernel_registry_manager, this](Node& node, HashValue hash) {
         const KernelCreateInfo* kci = nullptr;
         utils::UpdateHashForBackwardsCompatibility(hash);
 
         ORT_RETURN_IF_NOT(kernel_registry_manager.SearchKernelRegistriesByHash(hash, &kci),
                           "Failed to find kernel def hash (", hash, ") in kernel registries for ",
                           node.OpType(), "(", node.SinceVersion(), ") node with name '", node.Name(), "'.");
-        kernel_create_info_map_.emplace(node.Index(), gsl::not_null<const KernelCreateInfo*>(kci));
+
+        {
+          const auto [it, inserted] = kernel_create_info_map_.emplace(node.Index(),
+                                                                      gsl::not_null<const KernelCreateInfo*>(kci));
+          ORT_RETURN_IF_NOT(inserted,
+                            "Cannot overwrite existing kernel for ",
+                            node.OpType(), "(", node.SinceVersion(), ") node with name '", node.Name(),
+                            "'. Existing kernel def hash: ", it->second->kernel_def->GetHash(),
+                            ", new kernel def hash: ", hash, ".");
+        }
+
+        if (node.GetExecutionProviderType().empty()) {
+          node.SetExecutionProviderType(kci->kernel_def->Provider());
+        } else {
+          ORT_RETURN_IF_NOT(node.GetExecutionProviderType() == kci->kernel_def->Provider(),
+                            "Node execution provider type mismatch. Existing: ", node.GetExecutionProviderType(),
+                            ", from KernelCreateInfo (via hash lookup): ", kci->kernel_def->Provider());
+        }
+
         return Status::OK();
       };
 
@@ -1003,46 +1024,46 @@ Status SessionState::LoadFromOrtFormat(const fbs::SessionState& fbs_session_stat
       continue;
     }
 
-    ORT_RETURN_IF_ERROR(add_kernel_by_hash(*node, node_kernel_info.kernel_def_hash));
+    ORT_RETURN_IF_ERROR(add_kernel_and_set_node_ep_by_hash(*node, node_kernel_info.kernel_def_hash));
   }
 
 #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
   // process the nodes that were added by replaying any loaded runtime optimizations
   for (const auto& [node_index, kernel_def_hash] :
        graph_.RuntimeOptimizationReplayCtx().produced_node_index_to_kernel_def_hash) {
-    const auto* node = graph_.GetNode(node_index);
+    auto* node = graph_.GetNode(node_index);
 
     // NHWC optimizer may replace a node, so a missing node isn't necessarily an error
     // ORT_RETURN_IF(node == nullptr, "Can't find runtime optimization produced node with index ", node_index);
     if (node != nullptr) {
-      ORT_RETURN_IF_ERROR(add_kernel_by_hash(*node, kernel_def_hash));
+      ORT_RETURN_IF_ERROR(add_kernel_and_set_node_ep_by_hash(*node, kernel_def_hash));
     }
   }
 
-  // lookup the hashes for any nodes we compiled or added during graph partitioning.
-  // These node indexes for compiled nodes as well as newly added nodes are not in node_indices
-  // as they were created at runtime.
-  for (const auto& node : graph_.Nodes()) {
-    if (kernel_create_info_map_.count(node.Index()) == 0) {
-      if (node.Domain() == kOnnxDomain || node.Domain() == kMSDomain) {
-        // two possible places to get hash from
-        auto kernel_hash = utils::GetHashValueFromStaticKernelHashMap(node.OpType(), node.SinceVersion());
-        if (!kernel_hash.has_value()) {
-          kernel_hash = utils::GetInternalNhwcOpHash(node);
-        }
+  // Look up the hashes for any nodes we compiled or added during graph partitioning or other runtime optimizations.
+  // These nodes are not in the original model as they were created at runtime.
+  for (auto& node : graph_.Nodes()) {
+    if (kernel_create_info_map_.find(node.Index()) != kernel_create_info_map_.end()) {
+      continue;
+    }
 
-        if (kernel_hash.has_value()) {
-          ORT_RETURN_IF_ERROR(add_kernel_by_hash(node, *kernel_hash));
-        } else {
-          return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unable to find kernel hash for node:", node.Name(), " optype:", node.OpType());
-        }
-      } else {
-        const auto hash_info = compiled_kernel_hashes.find(node.OpType());
-        ORT_RETURN_IF(hash_info == compiled_kernel_hashes.cend(),
-                      "Unable to find compiled kernel hash for node '", node.Name(), "'.");
-        ORT_RETURN_IF_ERROR(add_kernel_by_hash(node, hash_info->second));
+    if (node.Domain() == kOnnxDomain || node.Domain() == kMSDomain) {
+      // two possible places to get hash from
+      auto kernel_hash = utils::GetHashValueFromStaticKernelHashMap(node.OpType(), node.SinceVersion());
+      if (!kernel_hash.has_value()) {
+        kernel_hash = utils::GetInternalNhwcOpHash(node);
      }
+
+      ORT_RETURN_IF_NOT(kernel_hash.has_value(),
+                        "Unable to find kernel hash for node: '", node.Name(), "' optype: ", node.OpType());
+
+      ORT_RETURN_IF_ERROR(add_kernel_and_set_node_ep_by_hash(node, *kernel_hash));
+    } else {
+      const auto hash_info = compiled_kernel_hashes.find(node.OpType());
+      ORT_RETURN_IF(hash_info == compiled_kernel_hashes.cend(),
+                    "Unable to find compiled kernel hash for node '", node.Name(), "'.");
+
+      ORT_RETURN_IF_ERROR(add_kernel_and_set_node_ep_by_hash(node, hash_info->second));
     }
   }
 #endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
@@ -1088,6 +1109,68 @@ static void ComputeConstantInitializerUseCount(const Graph& graph, std::unordere
   }
 }
 
+using NodePlacementMap = std::unordered_map<std::string, std::vector<std::string>>;
+
+static Status VerifyEachNodeIsAssignedToAnEpImpl(const Graph& graph, bool is_verbose,
+                                                 NodePlacementMap& node_placements) {
+  for (const auto& node : graph.Nodes()) {
+    const auto& node_provider = node.GetExecutionProviderType();
+    if (node_provider.empty()) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
+                             "Could not find an implementation for ",
+                             node.OpType(), "(", node.SinceVersion(), ") node with name '", node.Name(), "'");
+    }
+
+#if !defined(ORT_MINIMAL_BUILD)
+    if (is_verbose) {  // TODO: should we disable this if the number of nodes is above a certain threshold?
+      const std::string node_str = node.OpType() + " (" + node.Name() + ")";
+      node_placements[node_provider].push_back(node_str);
+    }
+#endif  // !defined(ORT_MINIMAL_BUILD)
+
+    // recurse into subgraphs
+    if (node.ContainsSubgraph()) {
+      const auto subgraphs = node.GetSubgraphs();
+      for (const auto& subgraph : subgraphs) {
+        ORT_RETURN_IF_ERROR(VerifyEachNodeIsAssignedToAnEpImpl(*subgraph, is_verbose, node_placements));
+      }
+    }
+  }
+
+  return Status::OK();
+}
+
+static Status VerifyEachNodeIsAssignedToAnEp(const Graph& graph, const logging::Logger& logger) {
+  NodePlacementMap node_placements{};
+#if !defined(ORT_MINIMAL_BUILD)
+  const bool is_verbose_mode = logger.GetSeverity() == logging::Severity::kVERBOSE;
+#else
+  ORT_UNUSED_PARAMETER(logger);
+  const bool is_verbose_mode = false;
+#endif  // !defined(ORT_MINIMAL_BUILD)
+
+  ORT_RETURN_IF_ERROR(VerifyEachNodeIsAssignedToAnEpImpl(graph, is_verbose_mode, node_placements));
+
+#if !defined(ORT_MINIMAL_BUILD)
+  // print placement info
+  if (is_verbose_mode) {
+    LOGS(logger, VERBOSE) << "Node placements";
+    if (node_placements.size() == 1) {
+      LOGS(logger, VERBOSE) << "All nodes have been placed on [" << node_placements.begin()->first << "].";
+    } else {
+      for (const auto& [provider, node_strs] : node_placements) {
+        std::ostringstream all_nodes_str;
+        std::copy(node_strs.begin(), node_strs.end(), std::ostream_iterator<std::string>(all_nodes_str, ", "));
+        LOGS(logger, VERBOSE) << " Provider: [" << provider << "]"
+                              << ": [" << all_nodes_str.str() << "]";
+      }
+    }
+  }
+#endif  // !defined(ORT_MINIMAL_BUILD)
+
+  return Status::OK();
+}
+
 Status SessionState::FinalizeSessionState(const std::basic_string<PATH_CHAR_TYPE>& graph_location,
                                           const KernelRegistryManager& kernel_registry_manager,
                                           const SessionOptions& session_options,
@@ -1101,8 +1184,11 @@ Status SessionState::FinalizeSessionState(const std::basic_string
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ ... @@ GetCurrentTimeString() {
   return std::basic_string<T>(time_str);
 }
 
-using NodePlacementMap = std::unordered_map<std::string, std::vector<std::string>>;
-
-Status VerifyEachNodeIsAssignedToAnEpImpl(const Graph& graph, bool is_verbose,
-                                          NodePlacementMap& node_placements) {
-  for (const auto& node : graph.Nodes()) {
-    const auto& node_provider = node.GetExecutionProviderType();
-    if (node_provider.empty()) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
-                             "Could not find an implementation for ",
-                             node.OpType(), "(", node.SinceVersion(), ") node with name '", node.Name(), "'");
-    }
-
-#if !defined(ORT_MINIMAL_BUILD)
-    if (is_verbose) {  // TODO: should we disable this if the number of nodes is above a certain threshold?
-      const std::string node_str = node.OpType() + " (" + node.Name() + ")";
-      node_placements[node_provider].push_back(node_str);
-    }
-#endif  // !defined(ORT_MINIMAL_BUILD)
-
-    // recurse into subgraphs
-    if (node.ContainsSubgraph()) {
-      const auto subgraphs = node.GetSubgraphs();
-      for (const auto& subgraph : subgraphs) {
-        ORT_RETURN_IF_ERROR(VerifyEachNodeIsAssignedToAnEpImpl(*subgraph, is_verbose, node_placements));
-      }
-    }
-  }
-
-  return Status::OK();
-}
-
-Status VerifyEachNodeIsAssignedToAnEp(const Graph& graph, const logging::Logger& logger) {
-  NodePlacementMap node_placements{};
-#if !defined(ORT_MINIMAL_BUILD)
-  const bool is_verbose_mode = logger.GetSeverity() == logging::Severity::kVERBOSE;
-#else
-  ORT_UNUSED_PARAMETER(logger);
-  const bool is_verbose_mode = false;
-#endif  // !defined(ORT_MINIMAL_BUILD)
-
-  const auto status = VerifyEachNodeIsAssignedToAnEpImpl(graph, is_verbose_mode, node_placements);
-
-#if !defined(ORT_MINIMAL_BUILD)
-  // print placement info
-  if (is_verbose_mode) {
-    LOGS(logger, VERBOSE) << "Node placements";
-    if (node_placements.size() == 1) {
-      LOGS(logger, VERBOSE) << "All nodes have been placed on [" << node_placements.begin()->first << "].";
-    } else {
-      for (const auto& pr : node_placements) {
-        std::ostringstream all_nodes_str;
-        std::copy(pr.second.begin(), pr.second.end(), std::ostream_iterator<std::string>(all_nodes_str, ", "));
-        LOGS(logger, VERBOSE) << " Provider: [" << pr.first << "]"
-                              << ": [" << all_nodes_str.str() << "]";
-      }
-    }
-  }
-#endif  // !defined(ORT_MINIMAL_BUILD)
-
-  return status;
-}
-
 #if !defined(ORT_MINIMAL_BUILD)
 
 bool AreAllNodesInMainGraphAssignedToOneEp(const Graph& graph, ProviderType provider) {
@@ -988,8 +926,6 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
   // Insert cast node/s.
   ORT_RETURN_IF_ERROR_SESSIONID_(insert_cast_transformer.Apply(graph, modified, *session_logger_));
 
-  ORT_RETURN_IF_ERROR_SESSIONID_(VerifyEachNodeIsAssignedToAnEp(graph, *session_logger_));
-
   std::vector<std::string> provider_types;
   for (auto& provider_ptr : providers) {
     provider_types.push_back(provider_ptr->Type());
@@ -1231,75 +1167,6 @@ Status ApplyOrtFormatModelRuntimeOptimizations(
   return Status::OK();
 }
 #endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
-
-Status AssignNodesToEpsFromHashesImpl(Graph& graph, const fbs::SessionState& fbs_session_state,
-                                      const KernelRegistryManager& kernel_registry_manager) {
-  using fbs::utils::FbsSessionStateViewer;
-  const FbsSessionStateViewer fbs_session_state_viewer{fbs_session_state};
-  ORT_RETURN_IF_ERROR(fbs_session_state_viewer.Validate());
-
-  for (auto& node : graph.Nodes()) {
-    for (auto& [attribute, subgraph] : node.GetAttributeNameToMutableSubgraphMap()) {
-      const fbs::SessionState* fbs_subgraph_session_state;
-      ORT_RETURN_IF_ERROR(fbs_session_state_viewer.GetSubgraphSessionState(node.Index(), attribute,
-                                                                           fbs_subgraph_session_state));
-
-      ORT_RETURN_IF_ERROR(AssignNodesToEpsFromHashesImpl(*subgraph, *fbs_subgraph_session_state,
-                                                         kernel_registry_manager));
-    }
-  }
-
-  const auto set_node_ep = [&](NodeIndex node_idx, HashValue kernel_def_hash) -> Status {
-    Node* node = graph.GetNode(node_idx);
-    if (!node || !node->GetExecutionProviderType().empty()) {
-      return Status::OK();
-    }
-
-    const KernelCreateInfo* kci = nullptr;
-    ORT_RETURN_IF_NOT(kernel_registry_manager.SearchKernelRegistriesByHash(kernel_def_hash, &kci),
-                      "Failed to find kernel def hash (", kernel_def_hash, ") in kernel registries for ",
-                      node->OpType(), "(", node->SinceVersion(), ") node with name '", node->Name(), "'.");
-    node->SetExecutionProviderType(kci->kernel_def->Provider());
-
-    return Status::OK();
-  };
-
-  for (FbsSessionStateViewer::Index i = 0, end = fbs_session_state_viewer.GetNumNodeKernelInfos(); i < end; ++i) {
-    const auto node_kernel_info = fbs_session_state_viewer.GetNodeKernelInfo(i);
-    ORT_RETURN_IF_ERROR(set_node_ep(node_kernel_info.node_index, node_kernel_info.kernel_def_hash));
-  }
-
-#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
-  for (const auto& [node_index, kernel_def_hash] :
-       graph.RuntimeOptimizationReplayCtx().produced_node_index_to_kernel_def_hash) {
-    ORT_RETURN_IF_ERROR(set_node_ep(node_index, kernel_def_hash));
-  }
-
-  // layout transformer which is enabled in extended minimal build can add new nodes.
-  // The following loop fetches the hash values for these nodes.
-  for (const auto& node : graph.Nodes()) {
-    if (node.GetExecutionProviderType().empty()) {
-      auto kernel_hash = utils::GetHashValueFromStaticKernelHashMap(node.OpType(), node.SinceVersion());
-      if (!kernel_hash.has_value()) {
-        kernel_hash = utils::GetInternalNhwcOpHash(node);
-      }
-      if (kernel_hash.has_value()) {
-        ORT_RETURN_IF_ERROR(set_node_ep(node.Index(), kernel_hash.value()));
-      }
-    }
-  }
-#endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
-
-  return Status::OK();
-}
-
-Status AssignNodesToEpsFromHashes(Graph& graph, const fbs::SessionState& fbs_session_state,
-                                  const KernelRegistryManager& kernel_registry_manager,
-                                  const logging::Logger& logger) {
-  ORT_RETURN_IF_ERROR(AssignNodesToEpsFromHashesImpl(graph, fbs_session_state, kernel_registry_manager));
-  ORT_RETURN_IF_ERROR(VerifyEachNodeIsAssignedToAnEp(graph, logger));
-  return Status::OK();
-}
 }  // namespace
 
 static void ResolveMemoryPatternFlags(SessionState& session_state) {
@@ -1518,9 +1385,6 @@ common::Status InferenceSession::Initialize() {
         ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_,
                                                 cpu_ep));
 #endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
-
-      ORT_RETURN_IF_ERROR(AssignNodesToEpsFromHashes(graph, *serialized_session_state, kernel_registry_manager_,
-                                                     *session_logger_));
     }
 
     ORT_RETURN_IF_ERROR_SESSIONID_(
diff --git a/onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc b/onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc
index eb01ec687b620..2d8457721f39f 100644
--- a/onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc
+++ b/onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc
@@ -211,7 +211,7 @@ void SaveAndLoadRuntimeOptimizationsForModel(
     const GraphOpCountsCheckerFn& graph_op_counts_checker_for_replay) {
   auto run_test = [&](bool do_save) {
     // the two versions of the saved runtime optimizations file should be the same
-    // the one without the ".generated" suffix is checked in and the other is generated by the test
+    // the one with the ".generated" suffix is generated by the test and the other is checked in
     const PathString saved_runtime_optimizations_model_path = do_save ?
                                                                   ort_model_with_runtime_opt_path + ORT_TSTR(".generated") :
                                                                   ort_model_with_runtime_opt_path;
@@ -255,52 +255,37 @@ void SaveAndLoadRuntimeOptimizationsForModel(
 }
 
 // if level 3 optimizations are enabled the NHWC transformer should convert the QLinearConv nodes to use channels_last
-void CheckNhwcTransformerIsApplied() {
-  const auto saved_runtime_optimizations_model_path =
-      ORT_TSTR("testdata/transform/runtime_optimization/qdq_convs.runtime_optimizations.ort");
+void CheckNhwcTransformerIsApplied(const PathString& ort_model_path,
+                                   const GraphOpCountsCheckerFn& graph_op_counts_checker) {
+  SCOPED_TRACE(MakeString("ORT format model: ", ToUTF8String(ort_model_path)));
 
   // load and replay runtime optimizations
-  {
-    SessionOptions so{};
-    ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsConfigLoadModelFormat, "ORT"));
-    so.graph_optimization_level = TransformerLevel::Level3;
-
-    GraphCheckerFn checker_fn = [](const Graph& graph) {
-      for (const auto& node : graph.Nodes()) {
-        if (node.OpType() == "QLinearConv") {
-          EXPECT_EQ(node.Domain(), kMSDomain);
-          bool has_channels_last_set = false;
-          for (const auto& attr : node.GetAttributes()) {
-            if (attr.first == "channels_last") {
-              EXPECT_EQ(attr.second.i(), 1);
-              has_channels_last_set = true;
-              break;
-            }
+  SessionOptions so{};
+  ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsConfigLoadModelFormat, "ORT"));
+  so.graph_optimization_level = TransformerLevel::Level3;
+
+  GraphCheckerFn graph_checker = [](const Graph& graph) {
+    for (const auto& node : graph.Nodes()) {
+      if (node.OpType() == "QLinearConv") {
+        EXPECT_EQ(node.Domain(), kMSDomain);
+        bool has_channels_last_set = false;
+        for (const auto& attr : node.GetAttributes()) {
+          if (attr.first == "channels_last") {
+            EXPECT_EQ(attr.second.i(), 1);
+            has_channels_last_set = true;
+            break;
          }
-          EXPECT_TRUE(has_channels_last_set);
        }
+        EXPECT_TRUE(has_channels_last_set);
      }
-    };
+    }
+  };
 
-    ASSERT_NO_FATAL_FAILURE(LoadAndInitializeSession(
-        so, saved_runtime_optimizations_model_path,
-        [](const OpCountMap& loaded_ops, const OpCountMap& initialized_ops) {
-          constexpr int n = 3;  // expected number of QDQ Convs to fuse
-
-          EXPECT_EQ(loaded_ops,
-                    (OpCountMap{{"DequantizeLinear", n * 3},
-                                {"QuantizeLinear", n},
-                                {"Conv", n}}));
-
-          // should have internal version of QLinearConv that runs NHWC, and transposes around each of those nodes
-          // for the layout conversion.
-          EXPECT_EQ(initialized_ops,
-                    (OpCountMap{{"Transpose", 6},
-                                {"com.microsoft.QLinearConv", n}}));
-        },
-        checker_fn));
-  }
-}
+  ASSERT_NO_FATAL_FAILURE(LoadAndInitializeSession(
+      so, ort_model_path,
+      graph_op_counts_checker,
+      graph_checker));
+};
 
 }  // namespace
 
 TEST(GraphRuntimeOptimizationTest, QDQConv) {
@@ -340,7 +325,46 @@ TEST(GraphRuntimeOptimizationTest, ConvActivation) {
 }
 
 TEST(GraphRuntimeOptimizationTest, TestNhwcTransformer) {
-  CheckNhwcTransformerIsApplied();
+  CheckNhwcTransformerIsApplied(
+      ORT_TSTR("testdata/transform/runtime_optimization/qdq_convs.runtime_optimizations.ort"),
+      [](const OpCountMap& loaded_ops, const OpCountMap& initialized_ops) {
+        constexpr int n = 3;  // expected number of QDQ Convs to fuse
+
+        EXPECT_EQ(loaded_ops,
+                  (OpCountMap{{"DequantizeLinear", n * 3},
+                              {"QuantizeLinear", n},
+                              {"Conv", n}}));
+
+        // should have internal version of QLinearConv that runs NHWC, and transposes around each of those nodes
+        // for the layout conversion.
+        EXPECT_EQ(initialized_ops,
+                  (OpCountMap{{"Transpose", n * 2},
+                              {"com.microsoft.QLinearConv", n}}));
+      });
+}
+
+TEST(GraphRuntimeOptimizationTest, TestNhwcTransformerDirectlyUpdatesQLinearConv) {
+  CheckNhwcTransformerIsApplied(
+      // ORT format model that contains QLinearConv nodes
+      // to generate:
+      // - set environment variable ORT_CONVERT_ONNX_MODELS_TO_ORT_OPTIMIZATION_LEVEL=extended
+      // - run:
+      //     python -m onnxruntime.tools.convert_onnx_models_to_ort
+      //         --optimization_style Fixed
+      //         testdata/transform/runtime_optimization/qdq_convs.onnx
+      ORT_TSTR("testdata/transform/runtime_optimization/qdq_convs.extended.ort"),
+      [](const OpCountMap& loaded_ops, const OpCountMap& initialized_ops) {
+        constexpr int n = 3;  // expected number of QLinearConvs
+
+        EXPECT_EQ(loaded_ops,
+                  (OpCountMap{{"QLinearConv", n}}));
+
+        // should have internal version of QLinearConv that runs NHWC, and transposes around each of those nodes
+        // for the layout conversion.
+        EXPECT_EQ(initialized_ops,
+                  (OpCountMap{{"Transpose", n * 2},
+                              {"com.microsoft.QLinearConv", n}}));
+      });
 }
 
 #if !defined(ORT_MINIMAL_BUILD)
diff --git a/onnxruntime/test/testdata/transform/runtime_optimization/qdq_convs.extended.ort b/onnxruntime/test/testdata/transform/runtime_optimization/qdq_convs.extended.ort
new file mode 100644
index 0000000000000..82fd4d708d4f0
Binary files /dev/null and b/onnxruntime/test/testdata/transform/runtime_optimization/qdq_convs.extended.ort differ
diff --git a/tools/python/util/convert_onnx_models_to_ort.py b/tools/python/util/convert_onnx_models_to_ort.py
index 0645bac1b48f5..e798cd3a57a0a 100644
--- a/tools/python/util/convert_onnx_models_to_ort.py
+++ b/tools/python/util/convert_onnx_models_to_ort.py
@@ -21,16 +21,18 @@ class OptimizationStyle(enum.Enum):
     Runtime = 1
 
 
-def _optimization_suffix(optimization_style: OptimizationStyle, suffix: str):
-    return "{}{}".format(".with_runtime_opt" if optimization_style == OptimizationStyle.Runtime else "",
-                         suffix)
+def _optimization_suffix(optimization_level_str: str, optimization_style: OptimizationStyle, suffix: str):
+    return "{}{}{}".format(f".{optimization_level_str}" if optimization_level_str != "all" else "",
+                           ".with_runtime_opt" if optimization_style == OptimizationStyle.Runtime else "",
+                           suffix)
 
 
 def _create_config_file_path(model_path_or_dir: pathlib.Path,
+                             optimization_level_str: str,
                              optimization_style: OptimizationStyle,
                              enable_type_reduction: bool):
     config_name = "{}{}".format('required_operators_and_types' if enable_type_reduction else 'required_operators',
-                                _optimization_suffix(optimization_style, ".config"))
+                                _optimization_suffix(optimization_level_str, optimization_style, ".config"))
     if model_path_or_dir.is_dir():
         return model_path_or_dir / config_name
 
     return model_path_or_dir.with_suffix(f".{config_name}")
@@ -99,14 +101,15 @@ def is_model_file_to_convert(file_path: pathlib.Path):
             (output_dir / relative_model_path).parent.mkdir(parents=True, exist_ok=True)
 
             ort_target_path = (output_dir / relative_model_path).with_suffix(
-                _optimization_suffix(optimization_style, ".ort"))
+                _optimization_suffix(optimization_level_str, optimization_style, ".ort"))
 
             if create_optimized_onnx_model:
                 # Create an ONNX file with the same optimization level that will be used for the ORT format file.
                 # This allows the ONNX equivalent of the ORT format model to be easily viewed in Netron.
                 # If runtime optimizations are saved in the ORT format model, there may be some difference in the
                 # graphs at runtime between the ORT format model and this saved ONNX model.
-                optimized_target_path = (output_dir / relative_model_path).with_suffix(".optimized.onnx")
+                optimized_target_path = (output_dir / relative_model_path).with_suffix(
+                    _optimization_suffix(optimization_level_str, optimization_style, ".optimized.onnx"))
 
                 so = _create_session_options(optimization_level, optimized_target_path, custom_op_library,
                                              session_options_config_entries)
 
                 if optimization_style == OptimizationStyle.Runtime:
@@ -209,7 +212,9 @@ def convert_onnx_models_to_ort():
     args = parse_args()
 
     optimization_styles = [OptimizationStyle[style_str] for style_str in args.optimization_style]
-    optimization_level_str = 'all'
+    # setting optimization level is not expected to be needed by typical users, but it can be set with this
+    # environment variable
+    optimization_level_str = os.getenv("ORT_CONVERT_ONNX_MODELS_TO_ORT_OPTIMIZATION_LEVEL", "all")
 
     model_path_or_dir = args.model_path_or_dir.resolve()
     custom_op_library = args.custom_op_library.resolve() if args.custom_op_library else None
@@ -269,7 +274,8 @@ def convert_onnx_models_to_ort():
         print("Generating config file from ORT format models with optimization style '{}' and level '{}'".format(
             optimization_style.name, optimization_level_str))
 
-        config_file = _create_config_file_path(model_path_or_dir, optimization_style, args.enable_type_reduction)
+        config_file = _create_config_file_path(model_path_or_dir, optimization_level_str, optimization_style,
+                                               args.enable_type_reduction)
 
         create_config_from_models(converted_models, config_file, args.enable_type_reduction)