Skip to content

Commit

Permalink
Merge branch 'profile' of https://github.com/xadupre/tensorflow-onnx
Browse files Browse the repository at this point in the history
…into input2
  • Loading branch information
sdpython committed Aug 27, 2020
2 parents fd99a90 + febb438 commit ee8df44
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 36 deletions.
11 changes: 0 additions & 11 deletions ci_build/azure_pipelines/templates/unit_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,3 @@ steps:
condition: succeededOrFailed()
env:
CI_ONNX_OPSET: '${{ onnx_opset }}'
- bash: |
export TF2ONNX_TEST_BACKEND=$CI_ONNX_BACKEND
export TF2ONNX_TEST_OPSET=$CI_ONNX_OPSET
pip install fire pyinstrument
python benchmarks/profile_conversion_time.py
timeoutInMinutes: 15
displayName: ${{ format('Run profile_conversion_time.py - Opset{0}', onnx_opset) }}
condition: succeededOrFailed()
env:
CI_ONNX_OPSET: '${{ onnx_opset }}'
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def run(self):
version=VersionInfo.version,
description='Tensorflow to ONNX converter',
setup_requires=['pytest-runner'],
tests_require=['graphviz', 'parameterized', 'pytest', 'pytest-cov', 'pyyaml'],
tests_require=['graphviz', 'parameterized', 'pytest', 'pytest-cov', 'pyyaml', 'fire'],
cmdclass=cmdclass,
packages=find_packages(),
author='onnx@microsoft.com',
Expand Down
32 changes: 32 additions & 0 deletions tests/test_profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

"""Unit Tests for Benchmarks."""
import os
import subprocess
from backend_test_base import Tf2OnnxBackendTestBase
from common import check_opset_after_tf_version, unittest_main

# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test,cell-var-from-loop
# pylint: disable=invalid-name
# pylint: enable=invalid-name

class ProfileTests(Tf2OnnxBackendTestBase):

folder = os.path.join(os.path.dirname(__file__), '..', 'tools')

@check_opset_after_tf_version("2.0", 12, "might need Scan")
def test_profile_conversion_time(self):
filename = os.path.join(ProfileTests.folder, 'profile_conversion_time.py')
proc = subprocess.Popen(
["python", filename], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
try:
outs = proc.communicate(timeout=15)[0]
except subprocess.TimeoutExpired:
proc.kill()
return
assert b"Profile complete." in outs


if __name__ == '__main__':
unittest_main()
49 changes: 25 additions & 24 deletions tf2onnx/tfonnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def tensorflow_onnx_mapping(g, ops_mapping):
node.skip_conversion = True
except Exception as ex:
logger.error("Failed to convert node %r (fct=%r)\n%r",
node.name, func, node.summary, exc_info=1)
node.name, func, node.summary, exc_info=1)
exceptions.append(ex)

return mapped_op, unmapped_op, exceptions
Expand Down Expand Up @@ -451,29 +451,30 @@ def compat_handler(ctx, node, **kwargs):

# pre-processing graph rewrites
# bi-directional re-writer should be placed after single directional re-writer
rewriters = [# single directional
rewrite_constant_fold,
rewrite_quantize_and_dequantize,
rewrite_transpose,
rewrite_flatten,
rewrite_random_uniform,
rewrite_random_uniform_fold_const,
rewrite_random_normal,
rewrite_dropout,
rewrite_eye,
rewrite_leakyrelu,
rewrite_thresholded_relu,
rewrite_conv2d_with_pad,
rewrite_single_direction_lstm,
# bi-directional
rewrite_bi_direction_lstm,
rewrite_single_direction_gru,
rewrite_bi_direction_gru,
rewrite_custom_rnn_cell,
rewrite_generic_loop, rewrite_cond,
rewrite_biasadd_with_conv2d,
rewrite_gemm,
]
rewriters = [
# single directional
rewrite_constant_fold,
rewrite_quantize_and_dequantize,
rewrite_transpose,
rewrite_flatten,
rewrite_random_uniform,
rewrite_random_uniform_fold_const,
rewrite_random_normal,
rewrite_dropout,
rewrite_eye,
rewrite_leakyrelu,
rewrite_thresholded_relu,
rewrite_conv2d_with_pad,
rewrite_single_direction_lstm,
# bi-directional
rewrite_bi_direction_lstm,
rewrite_single_direction_gru,
rewrite_bi_direction_gru,
rewrite_custom_rnn_cell,
rewrite_generic_loop, rewrite_cond,
rewrite_biasadd_with_conv2d,
rewrite_gemm,
]

if custom_rewriter is not None:
rewriters.extend(custom_rewriter)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,4 @@ def profile(profiler="none", name="MobileNet", show_all=False,

if __name__ == '__main__':
fire.Fire(profile)
print('Profile complete.')

0 comments on commit ee8df44

Please sign in to comment.