Update convert_onnx_models_to_ort.py to support runtime optimizations. (microsoft#10765)

Add runtime optimization support to ONNX -> ORT format conversion script.
Replace `--optimization_level`, `--use_nnapi`, and `--use_coreml` with a new `--optimization_style` option.
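For illustration, assuming the script's usual module entry point (`onnxruntime.tools.convert_onnx_models_to_ort`), a conversion would now look something like `python -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style Runtime <model dir>`, where `Runtime` stands in for whichever style names the script accepts.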
edgchen1 authored and lavanyax committed Mar 29, 2022
1 parent df86dbf commit c63f453
Showing 16 changed files with 403 additions and 255 deletions.
1 change: 1 addition & 0 deletions cmake/onnxruntime_python.cmake
@@ -380,6 +380,7 @@ file(GLOB onnxruntime_python_datasets_data CONFIGURE_DEPENDS
set(onnxruntime_mobile_util_srcs
${REPO_ROOT}/tools/python/util/check_onnx_model_mobile_usability.py
${REPO_ROOT}/tools/python/util/convert_onnx_models_to_ort.py
+${REPO_ROOT}/tools/python/util/file_utils.py
${REPO_ROOT}/tools/python/util/logger.py
${REPO_ROOT}/tools/python/util/make_dynamic_shape_fixed.py
${REPO_ROOT}/tools/python/util/onnx_model_utils.py
include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
@@ -54,6 +54,7 @@ static const char* const kOrtSessionOptionsDisableQuantQDQ = "session.disable_qu
// other factors like whether the model was created using Quantization Aware Training or Post Training Quantization.
// As such, it's best to test to determine if enabling this works well for your scenario.
// The default value is "0"
+// Available since version 1.11.
static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enable_quant_qdq_cleanup";

// Enable or disable gelu approximation in graph optimization. "0": disable; "1": enable. The default is "0".
@@ -80,25 +81,18 @@ static const char* const kOrtSessionOptionsConfigUseORTModelBytesDirectly = "ses

// This should only be specified when exporting an ORT format model for use on a different platform.
// If the ORT format model will be used on ARM platforms set to "1". For other platforms set to "0"
+// Available since version 1.11.
static const char* const kOrtSessionOptionsQDQIsInt8Allowed = "session.qdqisint8allowed";

-// Save information for replaying graph optimizations later instead of applying them directly.
-//
-// When an ONNX model is loaded, ORT can perform various optimizations on the graph.
-// However, when an ORT format model is loaded, the logic to perform these optimizations may not be available because
-// this scenario must be supported by minimal builds.
-// When loading an ONNX model, ORT can optionally save the effects of some optimizations for later replay in an ORT
-// format model. These are known as "runtime optimizations" - in an ORT format model, they happen at runtime.
-//
-// Note: This option is only applicable when loading an ONNX model and saving an ORT format model.
-//
-// Note: Runtime optimizations are only supported for certain optimizations at the extended level or higher.
-// Unsupported optimizations at those levels are not applied at all, while optimizations at other levels are applied
-// directly.
-//
-// "0": disabled, "1": enabled
-// The default is "0".
-static const char* const kOrtSessionOptionsConfigSaveRuntimeOptimizations = "optimization.save_runtime_optimizations";
+// Specifies how minimal build graph optimizations are handled in a full build.
+// These optimizations are at the extended level or higher.
+// Possible values and their effects are:
+// "save": Save runtime optimizations when saving an ORT format model.
+// "apply": Only apply optimizations available in a minimal build.
+// ""/<unspecified>: Apply optimizations available in a full build.
+// Available since version 1.11.
+static const char* const kOrtSessionOptionsConfigMinimalBuildOptimizations =
+"optimization.minimal_build_optimizations";

// Note: The options specific to an EP should be specified prior to appending that EP to the session options object in
// order for them to take effect.
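A sketch of how the new config key could be exercised through the public C++ API; the `Ort::SessionOptions` calls and the `session.save_model_format` key come from the existing API surface, and the exact combination shown is an assumption rather than part of this diff:

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions so;

  // Write the loaded model back out in ORT format.
  so.SetOptimizedModelFilePath(ORT_TSTR("model.ort"));
  so.AddConfigEntry("session.save_model_format", "ORT");

  // New key from this change: record extended-level (and higher) optimizations
  // as runtime optimizations in the saved ORT format model.
  so.AddConfigEntry("optimization.minimal_build_optimizations", "save");

  // Creating the session loads model.onnx, optimizes it, and writes model.ort.
  Ort::Session session(env, ORT_TSTR("model.onnx"), so);
  return 0;
}
```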
4 changes: 2 additions & 2 deletions onnxruntime/core/framework/config_options.cc
@@ -22,8 +22,8 @@ bool ConfigOptions::TryGetConfigEntry(const std::string& config_key, std::string
return found;
}

-const std::string ConfigOptions::GetConfigOrDefault(const std::string& config_key,
-const std::string& default_value) const noexcept {
+std::string ConfigOptions::GetConfigOrDefault(const std::string& config_key,
+const std::string& default_value) const noexcept {
return GetConfigEntry(config_key).value_or(default_value);
}

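On the `GetConfigOrDefault` signature change above: a `const std::string` return value is a const temporary that cannot be moved from, so call sites such as `auto v = options.GetConfigOrDefault(key, def);` are forced to copy. Dropping the top-level `const` from the return type preserves behavior while allowing the move.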
8 changes: 4 additions & 4 deletions onnxruntime/core/framework/config_options.h
@@ -12,9 +12,9 @@
namespace onnxruntime {

/**
-* Configuration options that can be used by any struct by inheriting this class.
-* Provides infrastructure to add/get config entries
-*/
+ * Configuration options that can be used by any struct by inheriting this class.
+ * Provides infrastructure to add/get config entries
+ */
struct ConfigOptions {
std::unordered_map<std::string, std::string> configurations;

@@ -29,7 +29,7 @@ struct ConfigOptions {

// Get the config string in this instance of ConfigOptions using the given config_key
// If there is no such config, the given default string will be returned
-const std::string GetConfigOrDefault(const std::string& config_key, const std::string& default_value) const noexcept;
+std::string GetConfigOrDefault(const std::string& config_key, const std::string& default_value) const noexcept;

// Add a config pair (config_key, config_value) to this instance of ConfigOptions
Status AddConfigEntry(const char* config_key, const char* config_value) noexcept;
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -286,7 +286,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
const IExecutionProvider& cpu_execution_provider,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable) {
InlinedVector<std::unique_ptr<GraphTransformer>> transformers;
-bool saving = std::holds_alternative<SatRuntimeOptimizationSaveContext>(apply_context);
+const bool saving = std::holds_alternative<SatRuntimeOptimizationSaveContext>(apply_context);

switch (level) {
case TransformerLevel::Level1:
105 changes: 75 additions & 30 deletions onnxruntime/core/session/inference_session.cc
@@ -163,10 +163,10 @@ Status VerifyEachNodeIsAssignedToAnEp(const Graph& graph, const logging::Logger&

return status;
}
-} // namespace

#if !defined(ORT_MINIMAL_BUILD)
-static bool AreAllNodesInMainGraphAssignedToOneEp(const Graph& graph, ProviderType provider) {
+
+bool AreAllNodesInMainGraphAssignedToOneEp(const Graph& graph, ProviderType provider) {
for (const auto& node : graph.Nodes()) {
const auto& node_provider = node.GetExecutionProviderType();

@@ -178,7 +178,7 @@ static bool AreAllNodesInMainGraphAssignedToOneEp(const Graph& graph, ProviderTy
return true;
}

-static bool HasControlflowNodes(const Graph& graph) {
+bool HasControlflowNodes(const Graph& graph) {
for (const auto& node : graph.Nodes()) {
if (node.ContainsSubgraph()) {
return true;
@@ -187,7 +187,40 @@ static bool HasControlflowNodes(const Graph& graph) {

return false;
}
-#endif
+
+Status GetMinimalBuildOptimizationHandling(
+std::string_view config_value, bool saving_ort_format,
+InferenceSession::MinimalBuildOptimizationHandling& minimal_build_optimization_handling) {
+if (config_value == "save") {
+if (saving_ort_format) {
+minimal_build_optimization_handling =
+InferenceSession::MinimalBuildOptimizationHandling::SaveMinimalBuildRuntimeOptimizations;
+return Status::OK();
+}
+return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+kOrtSessionOptionsConfigMinimalBuildOptimizations,
+" value of 'save' is only valid when saving an ORT format model.");
+}
+
+if (config_value == "apply") {
+minimal_build_optimization_handling =
+InferenceSession::MinimalBuildOptimizationHandling::OnlyApplyMinimalBuildOptimizations;
+return Status::OK();
+}
+
+if (config_value.empty()) {
+minimal_build_optimization_handling =
+InferenceSession::MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations;
+return Status::OK();
+}
+
+return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+"Invalid value for ", kOrtSessionOptionsConfigMinimalBuildOptimizations, ": ", config_value);
+};
+
+#endif // !defined(ORT_MINIMAL_BUILD)
+
+} // namespace

std::atomic<uint32_t> InferenceSession::global_session_id_{1};

@@ -1402,14 +1435,17 @@ common::Status InferenceSession::Initialize() {

#if !defined(ORT_MINIMAL_BUILD)
if (!loading_ort_format) {
-const bool saving_runtime_optimizations =
-saving_ort_format &&
-session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigSaveRuntimeOptimizations,
-"0") == "1";
+const auto minimal_build_opt_config_value = session_options_.config_options.GetConfigOrDefault(
+kOrtSessionOptionsConfigMinimalBuildOptimizations, "");
+MinimalBuildOptimizationHandling minimal_build_optimization_handling{};
+ORT_RETURN_IF_ERROR_SESSIONID_(GetMinimalBuildOptimizationHandling(minimal_build_opt_config_value,
+saving_ort_format,
+minimal_build_optimization_handling));
+
// add predefined transformers
ORT_RETURN_IF_ERROR_SESSIONID_(AddPredefinedTransformers(graph_transformation_mgr_,
session_options_.graph_optimization_level,
-saving_runtime_optimizations));
+minimal_build_optimization_handling));

// apply any transformations to the main graph and any subgraphs
ORT_RETURN_IF_ERROR_SESSIONID_(TransformGraph(graph, graph_transformation_mgr_,
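Per the validation added above, requesting "save" without also saving an ORT format model now fails session initialization with INVALID_ARGUMENT. A sketch of that failure mode via the C++ API (which surfaces non-OK statuses as `Ort::Exception`):

```cpp
#include <iostream>
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions so;
  // "save" is only valid when an ORT format model is being saved; no optimized
  // model output is configured here, so initialization should reject it.
  so.AddConfigEntry("optimization.minimal_build_optimizations", "save");
  try {
    Ort::Session session(env, ORT_TSTR("model.onnx"), so);
  } catch (const Ort::Exception& e) {
    std::cerr << "Session creation failed as expected: " << e.what() << "\n";
  }
  return 0;
}
```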
Expand All @@ -1436,19 +1472,19 @@ common::Status InferenceSession::Initialize() {
// Return error status as we don't want the session initialization to complete successfully
// if the user has requested usage of CUDA Graph feature and we cannot honor that.
ORT_RETURN_IF_ERROR_SESSIONID_(
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"This session cannot use the CUDA Graph feature as requested by the user "
" as the model has control flow nodes which can't be supported by CUDA Graphs."));
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"This session cannot use the CUDA Graph feature as requested by the user "
" as the model has control flow nodes which can't be supported by CUDA Graphs."));
} else if (!AreAllNodesInMainGraphAssignedToOneEp(graph, onnxruntime::kCudaExecutionProvider)) {
LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user "
<< " as all the graph nodes have not been partitioned to the CUDA EP.";

// Return error status as we don't want the session initialization to complete successfully
// if the user has requested usage of CUDA Graph feature and we cannot honor that.
ORT_RETURN_IF_ERROR_SESSIONID_(
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"This session cannot use the CUDA Graph feature as requested by the user "
" as all the graph nodes have not been partitioned to the CUDA EP."));
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"This session cannot use the CUDA Graph feature as requested by the user "
" as all the graph nodes have not been partitioned to the CUDA EP."));

} else {
LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user.";
@@ -1875,11 +1911,11 @@ Status InferenceSession::Run(const RunOptions& run_options,

// Check if this Run() is simply going to be a CUDA Graph replay.
if (cached_execution_provider_for_graph_replay_.IsGraphCaptured()) {
-LOGS(*session_logger_, INFO) << "Replaying the captured "
-<< cached_execution_provider_for_graph_replay_.Type()
-<< " CUDA Graph for this model with tag: " << run_options.run_tag;
-++current_num_runs_;
-ORT_RETURN_IF_ERROR_SESSIONID_(cached_execution_provider_for_graph_replay_.ReplayGraph());
+LOGS(*session_logger_, INFO) << "Replaying the captured "
+<< cached_execution_provider_for_graph_replay_.Type()
+<< " CUDA Graph for this model with tag: " << run_options.run_tag;
+++current_num_runs_;
+ORT_RETURN_IF_ERROR_SESSIONID_(cached_execution_provider_for_graph_replay_.ReplayGraph());
} else {
std::vector<IExecutionProvider*> exec_providers_to_stop;
exec_providers_to_stop.reserve(execution_providers_.NumProviders());
@@ -1951,13 +1987,13 @@ Status InferenceSession::Run(const RunOptions& run_options,
}
#endif

-// execute the graph
+// execute the graph
#ifdef DEBUG_NODE_INPUTS_OUTPUTS
session_state_->IncrementGraphExecutionCounter();
#endif
ORT_CHECK_AND_SET_RETVAL(utils::ExecuteGraph(*session_state_, feeds_fetches_manager, feeds, *p_fetches,
-session_options_.execution_mode, run_options.terminate, run_logger,
-run_options.only_execute_path_to_fetches));
+session_options_.execution_mode, run_options.terminate, run_logger,
+run_options.only_execute_path_to_fetches));
}
ORT_CATCH(const std::exception& e) {
ORT_HANDLE_EXCEPTION([&]() {
@@ -2010,7 +2046,7 @@ Status InferenceSession::Run(const RunOptions& run_options,
// are needed before replaying the captured graph, here run the inference again
// to capture the graph, so that users just need one session run to capture
// the graph.
-if (retval.IsOK() && cached_execution_provider_for_graph_replay_.IsGraphCaptureEnabled() &&
+if (retval.IsOK() && cached_execution_provider_for_graph_replay_.IsGraphCaptureEnabled() &&
!cached_execution_provider_for_graph_replay_.IsGraphCaptured()) {
LOGS(*session_logger_, INFO) << "Start the second Run() to capture the graph. "
"The first one is for necessary memory allocation;"
@@ -2361,21 +2397,30 @@ void InferenceSession::InitLogger(logging::LoggingManager* logging_manager) {
#if !defined(ORT_MINIMAL_BUILD)

// Registers all the predefined transformers with transformer manager
-common::Status InferenceSession::AddPredefinedTransformers(GraphTransformerManager& transformer_manager,
-TransformerLevel graph_optimization_level,
-bool saving_runtime_optimizations) const {
+common::Status InferenceSession::AddPredefinedTransformers(
+GraphTransformerManager& transformer_manager,
+TransformerLevel graph_optimization_level,
+MinimalBuildOptimizationHandling minimal_build_optimization_handling) const {
const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider);
for (int i = static_cast<int>(TransformerLevel::Level1); i <= static_cast<int>(TransformerLevel::MaxLevel); i++) {
TransformerLevel level = static_cast<TransformerLevel>(i);
if (graph_optimization_level >= level) {
// Generate and register transformers for level
auto transformers_to_register = [&]() {
-if (!saving_runtime_optimizations || level == TransformerLevel::Level1) {
+const bool use_full_build_optimizations =
+level == TransformerLevel::Level1 ||
+minimal_build_optimization_handling == MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations;
+
+if (use_full_build_optimizations) {
return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep,
optimizers_to_disable_);
} else {
-SatRuntimeOptimizationSaveContext save_context{kernel_registry_manager_};
-return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, save_context, cpu_ep,
+const auto sat_context =
+minimal_build_optimization_handling ==
+MinimalBuildOptimizationHandling::SaveMinimalBuildRuntimeOptimizations
+? SatApplyContextVariant{SatRuntimeOptimizationSaveContext{kernel_registry_manager_}}
+: SatApplyContextVariant{SatDirectApplicationContext{}};
+return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep,
optimizers_to_disable_);
}
}();
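To make the "apply" handling concrete, a hedged sketch of a full-build session constrained to the optimizations a minimal build could apply directly, which can be useful for previewing minimal-build behavior without producing an ORT format model:

```cpp
#include <onnxruntime_cxx_api.h>

// Run at the extended optimization level, but limit extended-and-higher
// optimizations to the subset a minimal build can also apply directly.
Ort::SessionOptions MakeMinimalBuildLikeOptions() {
  Ort::SessionOptions so;
  so.SetGraphOptimizationLevel(ORT_ENABLE_EXTENDED);
  so.AddConfigEntry("optimization.minimal_build_optimizations", "apply");
  return so;
}
```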
25 changes: 22 additions & 3 deletions onnxruntime/core/session/inference_session.h
@@ -107,6 +107,23 @@ struct ModelMetadata {

class InferenceSession {
public:
+#if !defined(ORT_MINIMAL_BUILD)
+
+/**
+* How minimal build graph optimizations should be handled in a full build.
+* Note: These only apply to optimizations at the extended level or higher.
+*/
+enum class MinimalBuildOptimizationHandling {
+/** Run full build optimizations. The default behavior. */
+ApplyFullBuildOptimizations,
+/** Save minimal build optimizations as runtime optimizations in an ORT format model. */
+SaveMinimalBuildRuntimeOptimizations,
+/** Only run minimal build optimizations. */
+OnlyApplyMinimalBuildOptimizations,
+};
+
+#endif
+
/**
Create a new InferenceSession
@param session_options Session options.
@@ -444,6 +461,7 @@ class InferenceSession {

protected:
#if !defined(ORT_MINIMAL_BUILD)
+
/**
* Load an ONNX model.
* @param protobuf object corresponding to the model file. model_proto will be copied by the API.
@@ -583,9 +601,10 @@ class InferenceSession {
void ShrinkMemoryArenas(const std::vector<AllocatorPtr>& arenas_to_shrink);

#if !defined(ORT_MINIMAL_BUILD)
-virtual common::Status AddPredefinedTransformers(GraphTransformerManager& transformer_manager,
-TransformerLevel graph_optimization_level,
-bool saving_runtime_optimizations) const;
+virtual common::Status AddPredefinedTransformers(
+GraphTransformerManager& transformer_manager,
+TransformerLevel graph_optimization_level,
+MinimalBuildOptimizationHandling minimal_build_optimization_handling) const;

common::Status TransformGraph(onnxruntime::Graph& graph,
const onnxruntime::GraphTransformerManager& graph_transformer_mgr,
9 changes: 9 additions & 0 deletions onnxruntime/test/framework/test_utils.h
@@ -93,6 +93,15 @@ using OpCountMap = std::map<std::string, int>;
// Helper function to check that the graph transformations have been successfully applied.
OpCountMap CountOpsInGraph(const Graph& graph, bool recurse_into_subgraphs = true);

+// Gets the op count from the OpCountMap.
+// Can be called with a const OpCountMap, unlike OpCountMap::operator[].
+inline int OpCount(const OpCountMap& op_count_map, const std::string& op_type) {
+if (auto it = op_count_map.find(op_type); it != op_count_map.end()) {
+return it->second;
+}
+return 0;
+}
+
#if !defined(DISABLE_SPARSE_TENSORS)
void SparseIndicesChecker(const ONNX_NAMESPACE::TensorProto& indices_proto, gsl::span<const int64_t> expected_indicies);
#endif // DISABLE_SPARSE_TENSORS
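A usage sketch for the new helper; the `graph` object and the GTest assertion are assumed from the surrounding test fixtures, and the op type name is illustrative:

```cpp
// Unlike OpCountMap::operator[], OpCount accepts a const map and does not
// insert default-constructed entries for missing op types.
const OpCountMap op_counts = CountOpsInGraph(graph);
EXPECT_EQ(OpCount(op_counts, "QuantizeLinear"), 0);
```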