Skip to content

Commit

Permalink
[VitisAI] update graph_save
Browse files Browse the repository at this point in the history
  • Loading branch information
Chunye Wang committed Jun 8, 2024
1 parent 981893c commit c563855
Showing 1 changed file with 9 additions and 18 deletions.
27 changes: 9 additions & 18 deletions onnxruntime/core/providers/vitisai/imp/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,18 @@ void graph_remove_node(Graph& graph, const NodeInput& node_input) {
}

void graph_save(const Graph& graph, const std::string& filename, const std::string& filename_dat, size_t initializer_size_threshold) {
auto& model = const_cast<Model&>(graph.GetModel());
std::unique_ptr<ONNX_NAMESPACE::ModelProto> model_proto;

auto model_proto = const_cast<onnxruntime::Model&>(graph.GetModel()).ToProto();
auto graph_proto_subgraph = graph.ToGraphProto();
*model_proto->mutable_graph() = *graph_proto_subgraph;
auto& logger = logging::LoggingManager::DefaultLogger();
auto filename_data_relative_path = std::filesystem::path();
auto model = Model::Create(std::move(*model_proto), ToPathString(filename), nullptr, logger);
if (initializer_size_threshold == std::numeric_limits<size_t>::max()) {
model_proto = model.ToProto();
model_proto = model->ToProto();
} else {
model_proto = model.ToGraphProtoWithExternalInitializers(filename_dat, graph.ModelPath().ToPathString(), initializer_size_threshold);
model_proto = model->ToGraphProtoWithExternalInitializers(filename_dat, ToPathString(filename), initializer_size_threshold);
}
auto& metadata = model.MetaData();
auto& metadata = model->MetaData();
if (!metadata.empty()) {
auto metadata_props = model_proto->mutable_metadata_props();
metadata_props->Clear();
Expand All @@ -121,18 +124,6 @@ void graph_save(const Graph& graph, const std::string& filename, const std::stri
*prop->mutable_value() = m.second;
}
}
// use relative path as data storage.
auto graph_proto = model_proto->mutable_graph();
*graph_proto = *graph.ToGraphProto();
for (int i = 0; i < graph_proto->mutable_initializer()->size(); i++) {
auto mutable_external_data = graph_proto->mutable_initializer()->at(i).mutable_external_data();
for (int j = 0; j < mutable_external_data->size(); j++) {
auto& external_data = mutable_external_data->at(j);
if (*external_data.mutable_key() == "location")
*external_data.mutable_value() = std::filesystem::path(*external_data.mutable_value()).filename().u8string();
}
}

std::fstream output(filename, std::ios::out | std::ios::trunc | std::ios::binary);
bool result = model_proto->SerializeToOstream(output);
output << std::flush;
Expand Down

0 comments on commit c563855

Please sign in to comment.