Skip to content

Commit

Permalink
Move node EP assignment for ORT format into SessionState::FinalizeSes…
Browse files Browse the repository at this point in the history
…sionState() (#10944)

Follow up to #10904.
- Move node EP assignment for ORT format into SessionState::FinalizeSessionState().
- Add unit test for #10904.
- Make convert_onnx_models_to_ort.py optimization level configurable via environment variable.
  • Loading branch information
edgchen1 authored Mar 28, 2022
1 parent 9c6cc01 commit 9371401
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 214 deletions.
142 changes: 114 additions & 28 deletions onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -937,7 +937,7 @@ Status SessionState::CreateSubgraphSessionState() {
for (auto& node : graph_.Nodes()) {
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
const auto& ep = node.GetExecutionProviderType();
if (ep != kCpuExecutionProvider && ep != kCudaExecutionProvider) {
if (!ep.empty() && ep != kCpuExecutionProvider && ep != kCudaExecutionProvider) {
// SessionState is only used when ORT is executing the subgraph. If a non-ORT EP has taken the control flow
// node containing the subgraph it will create whatever state it needs internally.
continue;
Expand Down Expand Up @@ -973,15 +973,36 @@ Status SessionState::LoadFromOrtFormat(const fbs::SessionState& fbs_session_stat
const FbsSessionStateViewer fbs_session_state_viewer{fbs_session_state};
ORT_RETURN_IF_ERROR(fbs_session_state_viewer.Validate());

auto add_kernel_by_hash =
[&kernel_registry_manager, this](const Node& node, HashValue hash) {
// look up KernelCreateInfo with hash and
// - add KernelCreateInfo for node
// - set node's EP from KernelCreateInfo if unset
auto add_kernel_and_set_node_ep_by_hash =
[&kernel_registry_manager, this](Node& node, HashValue hash) {
const KernelCreateInfo* kci = nullptr;
utils::UpdateHashForBackwardsCompatibility(hash);

ORT_RETURN_IF_NOT(kernel_registry_manager.SearchKernelRegistriesByHash(hash, &kci),
"Failed to find kernel def hash (", hash, ") in kernel registries for ",
node.OpType(), "(", node.SinceVersion(), ") node with name '", node.Name(), "'.");
kernel_create_info_map_.emplace(node.Index(), gsl::not_null<const KernelCreateInfo*>(kci));

{
const auto [it, inserted] = kernel_create_info_map_.emplace(node.Index(),
gsl::not_null<const KernelCreateInfo*>(kci));
ORT_RETURN_IF_NOT(inserted,
"Cannot overwrite existing kernel for ",
node.OpType(), "(", node.SinceVersion(), ") node with name '", node.Name(),
"'. Existing kernel def hash: ", it->second->kernel_def->GetHash(),
", new kernel def hash: ", hash, ".");
}

if (node.GetExecutionProviderType().empty()) {
node.SetExecutionProviderType(kci->kernel_def->Provider());
} else {
ORT_RETURN_IF_NOT(node.GetExecutionProviderType() == kci->kernel_def->Provider(),
"Node execution provider type mismatch. Existing: ", node.GetExecutionProviderType(),
", from KernelCreateInfo (via hash lookup): ", kci->kernel_def->Provider());
}

return Status::OK();
};

Expand All @@ -1003,46 +1024,46 @@ Status SessionState::LoadFromOrtFormat(const fbs::SessionState& fbs_session_stat
continue;
}

ORT_RETURN_IF_ERROR(add_kernel_by_hash(*node, node_kernel_info.kernel_def_hash));
ORT_RETURN_IF_ERROR(add_kernel_and_set_node_ep_by_hash(*node, node_kernel_info.kernel_def_hash));
}

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
// process the nodes that were added by replaying any loaded runtime optimizations
for (const auto& [node_index, kernel_def_hash] :
graph_.RuntimeOptimizationReplayCtx().produced_node_index_to_kernel_def_hash) {
const auto* node = graph_.GetNode(node_index);
auto* node = graph_.GetNode(node_index);

// NHWC optimizer may replace a node, so a missing node isn't necessarily an error
// ORT_RETURN_IF(node == nullptr, "Can't find runtime optimization produced node with index ", node_index);

if (node != nullptr) {
ORT_RETURN_IF_ERROR(add_kernel_by_hash(*node, kernel_def_hash));
ORT_RETURN_IF_ERROR(add_kernel_and_set_node_ep_by_hash(*node, kernel_def_hash));
}
}

// lookup the hashes for any nodes we compiled or added during graph partitioning.
// These node indexes for compiled nodes as well as newly added nodes are not in node_indices
// as they were created at runtime.
for (const auto& node : graph_.Nodes()) {
if (kernel_create_info_map_.count(node.Index()) == 0) {
if (node.Domain() == kOnnxDomain || node.Domain() == kMSDomain) {
// two possible places to get hash from
auto kernel_hash = utils::GetHashValueFromStaticKernelHashMap(node.OpType(), node.SinceVersion());
if (!kernel_hash.has_value()) {
kernel_hash = utils::GetInternalNhwcOpHash(node);
}
// Look up the hashes for any nodes we compiled or added during graph partitioning or other runtime optimizations.
// These nodes are not in the original model as they were created at runtime.
for (auto& node : graph_.Nodes()) {
if (kernel_create_info_map_.find(node.Index()) != kernel_create_info_map_.end()) {
continue;
}

if (kernel_hash.has_value()) {
ORT_RETURN_IF_ERROR(add_kernel_by_hash(node, *kernel_hash));
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unable to find kernel hash for node:", node.Name(), " optype:", node.OpType());
}
} else {
const auto hash_info = compiled_kernel_hashes.find(node.OpType());
ORT_RETURN_IF(hash_info == compiled_kernel_hashes.cend(),
"Unable to find compiled kernel hash for node '", node.Name(), "'.");
ORT_RETURN_IF_ERROR(add_kernel_by_hash(node, hash_info->second));
if (node.Domain() == kOnnxDomain || node.Domain() == kMSDomain) {
// two possible places to get hash from
auto kernel_hash = utils::GetHashValueFromStaticKernelHashMap(node.OpType(), node.SinceVersion());
if (!kernel_hash.has_value()) {
kernel_hash = utils::GetInternalNhwcOpHash(node);
}
ORT_RETURN_IF_NOT(kernel_hash.has_value(),
"Unable to find kernel hash for node: '", node.Name(), "' optype: ", node.OpType());

ORT_RETURN_IF_ERROR(add_kernel_and_set_node_ep_by_hash(node, *kernel_hash));
} else {
const auto hash_info = compiled_kernel_hashes.find(node.OpType());
ORT_RETURN_IF(hash_info == compiled_kernel_hashes.cend(),
"Unable to find compiled kernel hash for node '", node.Name(), "'.");

ORT_RETURN_IF_ERROR(add_kernel_and_set_node_ep_by_hash(node, hash_info->second));
}
}
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
Expand Down Expand Up @@ -1088,6 +1109,68 @@ static void ComputeConstantInitializerUseCount(const Graph& graph, std::unordere
}
}

using NodePlacementMap = std::unordered_map<std::string, std::vector<std::string>>;

static Status VerifyEachNodeIsAssignedToAnEpImpl(const Graph& graph, bool is_verbose,
NodePlacementMap& node_placements) {
for (const auto& node : graph.Nodes()) {
const auto& node_provider = node.GetExecutionProviderType();
if (node_provider.empty()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"Could not find an implementation for ",
node.OpType(), "(", node.SinceVersion(), ") node with name '", node.Name(), "'");
}

#if !defined(ORT_MINIMAL_BUILD)
if (is_verbose) { // TODO: should we disable this if the number of nodes is above a certain threshold?
const std::string node_str = node.OpType() + " (" + node.Name() + ")";
node_placements[node_provider].push_back(node_str);
}
#endif // !defined(ORT_MINIMAL_BUILD)

// recurse into subgraphs
if (node.ContainsSubgraph()) {
const auto subgraphs = node.GetSubgraphs();
for (const auto& subgraph : subgraphs) {
ORT_RETURN_IF_ERROR(VerifyEachNodeIsAssignedToAnEpImpl(*subgraph, is_verbose, node_placements));
}
}
}

return Status::OK();
}

static Status VerifyEachNodeIsAssignedToAnEp(const Graph& graph, const logging::Logger& logger) {
NodePlacementMap node_placements{};
#if !defined(ORT_MINIMAL_BUILD)
const bool is_verbose_mode = logger.GetSeverity() == logging::Severity::kVERBOSE;
#else
ORT_UNUSED_PARAMETER(logger);
const bool is_verbose_mode = false;
#endif // !defined(ORT_MINIMAL_BUILD)

ORT_RETURN_IF_ERROR(VerifyEachNodeIsAssignedToAnEpImpl(graph, is_verbose_mode, node_placements));

#if !defined(ORT_MINIMAL_BUILD)
// print placement info
if (is_verbose_mode) {
LOGS(logger, VERBOSE) << "Node placements";
if (node_placements.size() == 1) {
LOGS(logger, VERBOSE) << "All nodes have been placed on [" << node_placements.begin()->first << "].";
} else {
for (const auto& [provider, node_strs] : node_placements) {
std::ostringstream all_nodes_str;
std::copy(node_strs.begin(), node_strs.end(), std::ostream_iterator<std::string>(all_nodes_str, ", "));
LOGS(logger, VERBOSE) << " Provider: [" << provider << "]"
<< ": [" << all_nodes_str.str() << "]";
}
}
}
#endif // !defined(ORT_MINIMAL_BUILD)

return Status::OK();
}

Status SessionState::FinalizeSessionState(const std::basic_string<PATH_CHAR_TYPE>& graph_location,
const KernelRegistryManager& kernel_registry_manager,
const SessionOptions& session_options,
Expand All @@ -1101,8 +1184,11 @@ Status SessionState::FinalizeSessionState(const std::basic_string<PATH_CHAR_TYPE

if (serialized_session_state) {
ORT_RETURN_IF_ERROR(LoadFromOrtFormat(*serialized_session_state, kernel_registry_manager));
// LoadFromOrtFormat() may assign node EPs so check afterwards
ORT_RETURN_IF_ERROR(VerifyEachNodeIsAssignedToAnEp(graph_, logger_));
} else {
#if !defined(ORT_MINIMAL_BUILD)
ORT_RETURN_IF_ERROR(VerifyEachNodeIsAssignedToAnEp(graph_, logger_));
ORT_RETURN_IF_ERROR(PopulateKernelCreateInfo(kernel_registry_manager, saving_ort_format));
#else
ORT_UNUSED_PARAMETER(graph_location);
Expand Down
136 changes: 0 additions & 136 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,68 +102,6 @@ inline std::basic_string<T> GetCurrentTimeString() {
return std::basic_string<T>(time_str);
}

using NodePlacementMap = std::unordered_map<std::string, std::vector<std::string>>;

Status VerifyEachNodeIsAssignedToAnEpImpl(const Graph& graph, bool is_verbose,
NodePlacementMap& node_placements) {
for (const auto& node : graph.Nodes()) {
const auto& node_provider = node.GetExecutionProviderType();
if (node_provider.empty()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"Could not find an implementation for ",
node.OpType(), "(", node.SinceVersion(), ") node with name '", node.Name(), "'");
}

#if !defined(ORT_MINIMAL_BUILD)
if (is_verbose) { // TODO: should we disable this if the number of nodes is above a certain threshold?
const std::string node_str = node.OpType() + " (" + node.Name() + ")";
node_placements[node_provider].push_back(node_str);
}
#endif // !defined(ORT_MINIMAL_BUILD)

// recurse into subgraphs
if (node.ContainsSubgraph()) {
const auto subgraphs = node.GetSubgraphs();
for (const auto& subgraph : subgraphs) {
ORT_RETURN_IF_ERROR(VerifyEachNodeIsAssignedToAnEpImpl(*subgraph, is_verbose, node_placements));
}
}
}

return Status::OK();
}

Status VerifyEachNodeIsAssignedToAnEp(const Graph& graph, const logging::Logger& logger) {
NodePlacementMap node_placements{};
#if !defined(ORT_MINIMAL_BUILD)
const bool is_verbose_mode = logger.GetSeverity() == logging::Severity::kVERBOSE;
#else
ORT_UNUSED_PARAMETER(logger);
const bool is_verbose_mode = false;
#endif // !defined(ORT_MINIMAL_BUILD)

const auto status = VerifyEachNodeIsAssignedToAnEpImpl(graph, is_verbose_mode, node_placements);

#if !defined(ORT_MINIMAL_BUILD)
// print placement info
if (is_verbose_mode) {
LOGS(logger, VERBOSE) << "Node placements";
if (node_placements.size() == 1) {
LOGS(logger, VERBOSE) << "All nodes have been placed on [" << node_placements.begin()->first << "].";
} else {
for (const auto& pr : node_placements) {
std::ostringstream all_nodes_str;
std::copy(pr.second.begin(), pr.second.end(), std::ostream_iterator<std::string>(all_nodes_str, ", "));
LOGS(logger, VERBOSE) << " Provider: [" << pr.first << "]"
<< ": [" << all_nodes_str.str() << "]";
}
}
}
#endif // !defined(ORT_MINIMAL_BUILD)

return status;
}

#if !defined(ORT_MINIMAL_BUILD)

bool AreAllNodesInMainGraphAssignedToOneEp(const Graph& graph, ProviderType provider) {
Expand Down Expand Up @@ -988,8 +926,6 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
// Insert cast node/s.
ORT_RETURN_IF_ERROR_SESSIONID_(insert_cast_transformer.Apply(graph, modified, *session_logger_));

ORT_RETURN_IF_ERROR_SESSIONID_(VerifyEachNodeIsAssignedToAnEp(graph, *session_logger_));

std::vector<std::string> provider_types;
for (auto& provider_ptr : providers) {
provider_types.push_back(provider_ptr->Type());
Expand Down Expand Up @@ -1231,75 +1167,6 @@ Status ApplyOrtFormatModelRuntimeOptimizations(
return Status::OK();
}
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

Status AssignNodesToEpsFromHashesImpl(Graph& graph, const fbs::SessionState& fbs_session_state,
const KernelRegistryManager& kernel_registry_manager) {
using fbs::utils::FbsSessionStateViewer;
const FbsSessionStateViewer fbs_session_state_viewer{fbs_session_state};
ORT_RETURN_IF_ERROR(fbs_session_state_viewer.Validate());

for (auto& node : graph.Nodes()) {
for (auto& [attribute, subgraph] : node.GetAttributeNameToMutableSubgraphMap()) {
const fbs::SessionState* fbs_subgraph_session_state;
ORT_RETURN_IF_ERROR(fbs_session_state_viewer.GetSubgraphSessionState(node.Index(), attribute,
fbs_subgraph_session_state));

ORT_RETURN_IF_ERROR(AssignNodesToEpsFromHashesImpl(*subgraph, *fbs_subgraph_session_state,
kernel_registry_manager));
}
}

const auto set_node_ep = [&](NodeIndex node_idx, HashValue kernel_def_hash) -> Status {
Node* node = graph.GetNode(node_idx);
if (!node || !node->GetExecutionProviderType().empty()) {
return Status::OK();
}

const KernelCreateInfo* kci = nullptr;
ORT_RETURN_IF_NOT(kernel_registry_manager.SearchKernelRegistriesByHash(kernel_def_hash, &kci),
"Failed to find kernel def hash (", kernel_def_hash, ") in kernel registries for ",
node->OpType(), "(", node->SinceVersion(), ") node with name '", node->Name(), "'.");
node->SetExecutionProviderType(kci->kernel_def->Provider());

return Status::OK();
};

for (FbsSessionStateViewer::Index i = 0, end = fbs_session_state_viewer.GetNumNodeKernelInfos(); i < end; ++i) {
const auto node_kernel_info = fbs_session_state_viewer.GetNodeKernelInfo(i);
ORT_RETURN_IF_ERROR(set_node_ep(node_kernel_info.node_index, node_kernel_info.kernel_def_hash));
}

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
for (const auto& [node_index, kernel_def_hash] :
graph.RuntimeOptimizationReplayCtx().produced_node_index_to_kernel_def_hash) {
ORT_RETURN_IF_ERROR(set_node_ep(node_index, kernel_def_hash));
}

// layout transformer which is enabled in extended minimal build can add new nodes.
// The following loop fetches the hash values for these nodes.
for (const auto& node : graph.Nodes()) {
if (node.GetExecutionProviderType().empty()) {
auto kernel_hash = utils::GetHashValueFromStaticKernelHashMap(node.OpType(), node.SinceVersion());
if (!kernel_hash.has_value()) {
kernel_hash = utils::GetInternalNhwcOpHash(node);
}
if (kernel_hash.has_value()) {
ORT_RETURN_IF_ERROR(set_node_ep(node.Index(), kernel_hash.value()));
}
}
}
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

return Status::OK();
}

Status AssignNodesToEpsFromHashes(Graph& graph, const fbs::SessionState& fbs_session_state,
const KernelRegistryManager& kernel_registry_manager,
const logging::Logger& logger) {
ORT_RETURN_IF_ERROR(AssignNodesToEpsFromHashesImpl(graph, fbs_session_state, kernel_registry_manager));
ORT_RETURN_IF_ERROR(VerifyEachNodeIsAssignedToAnEp(graph, logger));
return Status::OK();
}
} // namespace

static void ResolveMemoryPatternFlags(SessionState& session_state) {
Expand Down Expand Up @@ -1518,9 +1385,6 @@ common::Status InferenceSession::Initialize() {
ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_,
cpu_ep));
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

ORT_RETURN_IF_ERROR(AssignNodesToEpsFromHashes(graph, *serialized_session_state, kernel_registry_manager_,
*session_logger_));
}

ORT_RETURN_IF_ERROR_SESSIONID_(
Expand Down
Loading

0 comments on commit 9371401

Please sign in to comment.