Skip to content

Commit

Permalink
add script to run optimizer on onnx files (#1538)
Browse files Browse the repository at this point in the history
* add script to run optimizer on onnx files

Signed-off-by: Guenther Schmuelling <guschmue@microsoft.com>

* pylint

Signed-off-by: Guenther Schmuelling <guschmue@microsoft.com>
  • Loading branch information
guschmue authored May 27, 2021
1 parent cba8f8d commit 2f9078e
Showing 1 changed file with 67 additions and 0 deletions.
67 changes: 67 additions & 0 deletions tools/onnx-optimize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0


"""
A simple tool to try optimizations on onnx graphs.
This makes use of the fact that tensorflow-onnx internal graph representation is onnx
so all graph, rewrite, matching and utility libaries do work which makes things easy.
"""

# pylint: disable=invalid-name,missing-docstring, unused-argument

import argparse
import logging

import onnx
from onnx import helper

from tf2onnx.graph import GraphUtil
from tf2onnx import logging, optimizer


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("onnx-optimize")


def get_args():
"""Parse commandline."""
parser = argparse.ArgumentParser()
parser.add_argument("--input", required=True, help="onnx input model file")
parser.add_argument("--output", help="output model file")
args = parser.parse_args()
return args


def load_graph(fname):
model_proto = onnx.ModelProto()
with open(fname, "rb") as f:
data = f.read()
model_proto.ParseFromString(data)
g = GraphUtil.create_graph_from_onnx_model(model_proto)
return g, model_proto


def main():
args = get_args()

g, org_model_proto = load_graph(args.input)

g = optimizer.optimize_graph(g)

onnx_graph = g.make_graph(org_model_proto.graph.doc_string + " (+tf2onnx/onnx-optimize)")

kwargs = {"producer_name": org_model_proto.producer_name,
"producer_version": org_model_proto.producer_version,
"opset_imports": org_model_proto.opset_import,
"ir_version": org_model_proto.ir_version}

model_proto = helper.make_model(onnx_graph, **kwargs)

# write onnx graph
if args.output:
with open(args.output, "wb") as f:
f.write(model_proto.SerializeToString())


if __name__ == "__main__":
main()

0 comments on commit 2f9078e

Please sign in to comment.