Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a script to profile the conversion of a model #1077

Merged
merged 18 commits into from
Sep 30, 2020
Merged
37 changes: 37 additions & 0 deletions tests/test_profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

"""Unit Tests for Benchmarks."""
import os
import subprocess
from backend_test_base import Tf2OnnxBackendTestBase
from common import (
check_opset_min_version, check_tf_min_version,
unittest_main, check_onnxruntime_min_version
)

# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test,cell-var-from-loop
# pylint: disable=invalid-name
# pylint: enable=invalid-name

class ProfileTests(Tf2OnnxBackendTestBase):

folder = os.path.join(os.path.dirname(__file__), '..', 'tools')

@check_tf_min_version("2.0")
@check_opset_min_version(12)
@check_onnxruntime_min_version('1.4.0')
def test_profile_conversion_time(self):
filename = os.path.join(ProfileTests.folder, 'profile_conversion_time.py')
proc = subprocess.Popen(
["python", filename], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
try:
outs = proc.communicate(timeout=15)[0]
except subprocess.TimeoutExpired:
proc.kill()
return
assert b"Profile complete." in outs or outs == b''


if __name__ == '__main__':
unittest_main()
2 changes: 1 addition & 1 deletion tf2onnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def get_tensor_value(self, as_list=True):
when as_list=True, return 1, type is <class 'int'>.
"""
if not self.is_const():
raise ValueError("get tensor value: {} must be Const".format(self.name))
raise ValueError("get tensor value: '{}' must be Const".format(self.name))

t = self.get_attr("value")
if t:
Expand Down
2 changes: 1 addition & 1 deletion tf2onnx/onnx_opset/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def version_1(cls, ctx, node, **kwargs):
# const we make it an attribute.
seed = node.get_attr("seed")
node.set_attr("seed", float(seed.f))
if len(node.input) > 0:
if len(node.input) > 0 and node.inputs[0].is_const():
shape = node.inputs[0].get_tensor_value()
ctx.remove_input(node, node.input[0], 0)
node.set_attr("shape", shape)
Expand Down
36 changes: 26 additions & 10 deletions tf2onnx/tfonnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ def tensorflow_onnx_mapping(g, ops_mapping):
func(g, node, **kwargs)
node.skip_conversion = True
except Exception as ex:
logger.error("Failed to convert node %s\n%s", node.name, node.summary, exc_info=1)
logger.error("Failed to convert node %r (fct=%r)\n%r",
node.name, func, node.summary, exc_info=1)
exceptions.append(ex)

return mapped_op, unmapped_op, exceptions
Expand Down Expand Up @@ -453,15 +454,30 @@ def compat_handler(ctx, node, **kwargs):

# pre-processing graph rewrites
# bi-directional re-writer should be placed after single directional re-writer
rewriters = [rewrite_constant_fold, rewrite_quantize_and_dequantize, rewrite_transpose, rewrite_flatten,
rewrite_random_uniform, rewrite_random_uniform_fold_const,
rewrite_random_normal, rewrite_dropout, rewrite_eye,
rewrite_leakyrelu, rewrite_thresholded_relu, rewrite_conv2d_with_pad,
rewrite_single_direction_lstm, rewrite_bi_direction_lstm,
rewrite_single_direction_gru, rewrite_bi_direction_gru,
rewrite_custom_rnn_cell, rewrite_generic_loop, rewrite_cond,
rewrite_biasadd_with_conv2d, rewrite_gemm
]
rewriters = [
# single directional
rewrite_constant_fold,
rewrite_quantize_and_dequantize,
rewrite_transpose,
rewrite_flatten,
rewrite_random_uniform,
rewrite_random_uniform_fold_const,
rewrite_random_normal,
rewrite_dropout,
rewrite_eye,
rewrite_leakyrelu,
rewrite_thresholded_relu,
rewrite_conv2d_with_pad,
rewrite_single_direction_lstm,
# bi-directional
rewrite_bi_direction_lstm,
rewrite_single_direction_gru,
rewrite_bi_direction_gru,
rewrite_custom_rnn_cell,
rewrite_generic_loop, rewrite_cond,
rewrite_biasadd_with_conv2d,
rewrite_gemm,
]

if custom_rewriter is not None:
rewriters.extend(custom_rewriter)
Expand Down
116 changes: 116 additions & 0 deletions tools/profile_conversion_time.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# coding: utf-8
"""
Profiles the conversion of a Keras model.
"""
import sys
import cProfile
from pstats import SortKey, Stats
import io
import argparse
import tensorflow as tf
from tensorflow.keras.applications import MobileNet, EfficientNetB2
from tf2onnx import tfonnx
try:
from pyinstrument import Profiler
except ImportError:
Profiler = None


def spy_model(name):
"Creates the model."
with tf.compat.v1.Session(graph=tf.Graph()) as session:
if name == "MobileNet":
model = MobileNet()
elif name == "EfficientNetB2":
model = EfficientNetB2()
else:
raise ValueError("Unknown model name %r." % name)

graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
sess=session,
input_graph_def=session.graph_def,
output_node_names=[model.output.op.name])

return graph_def, model


def spy_convert(graph_def, model):
"Converts the model."
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def=graph_def, name='')

def spy_convert_in():
return tfonnx.process_tf_graph(
tf_graph=graph, input_names=[model.input.name],
output_names=[model.output.name])

spy_convert_in()


def create(name):
"Creates the model."
graph_def, model = spy_model(name)
return graph_def, model


def convert(graph_def, model):
"Converts the model."
spy_convert(graph_def, model)


def profile(profiler="none", name="MobileNet", show_all=False):
"""
Profiles the conversion of a model.

:param profiler: one among none, spy, pyinstrument, cProfile
:param name: model to profile, MobileNet, EfficientNetB2
:param showall: used by pyinstrument to show all functions
"""
print("create(%r, %r)" % (profiler, name))
graph_def, model = create(name)
print("profile(%r, %r)" % (profiler, name))
if profiler == 'none':
convert(graph_def, model)
elif profiler == "spy":
# py-spy record -r 10 -o profile.svg -- python conversion_time.py spy
convert(graph_def, model)
elif profiler == "pyinstrument":
if Profiler is None:
raise ImportError("pyinstrument is not installed")
profiler = Profiler(interval=0.0001)
profiler.start()
convert(graph_def, model)
profiler.stop()
print(profiler.output_text(unicode=False, color=False, show_all=show_all))
elif profiler == "cProfile":
pr = cProfile.Profile()
pr.enable()
convert(graph_def, model)
pr.disable()
s = io.StringIO()
sortby = SortKey.CUMULATIVE
ps = Stats(pr, stream=s).sort_stats(sortby)
ps.print_stats()
print(s.getvalue())
else:
raise ValueError("Unknown profiler %r." % profiler)


def main(args):
parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('--profiler', default='none',
choices=['none', 'spy', 'pyinstrument', 'cProfile'],
help='a profiler')
parser.add_argument('--name', default="MobileNet",
choices=['MobileNet', 'EfficientNetB2'],
help="a model")
parser.add_argument('--showall', type=bool, default=False,
help="used by pyinstrument to show all functions")
res = parser.parse_args(args)
profile(res.profiler, res.name, res.showall)


if __name__ == '__main__':
print('Begin profiling with', sys.argv[1:])
main(sys.argv[1:])
print('Profile complete.')