Skip to content

Commit

Permalink
Merge branch 'master' into tom/ragged_varient
Browse files Browse the repository at this point in the history
  • Loading branch information
TomWildenhain-Microsoft authored May 7, 2021
2 parents 46b9722 + 229985e commit f26fe0d
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 19 deletions.
64 changes: 64 additions & 0 deletions examples/getting_started.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# SPDX-License-Identifier: Apache-2.0

"""
This example shows how to convert tf functions and keras models using the Python API.
It also demonstrates converting saved_models from the command line.
"""

import tensorflow as tf
import tf2onnx
import numpy as np
import onnxruntime as ort
import os

##################### tf function #####################

@tf.function
def f(a, b):
return a + b

input_signature = [tf.TensorSpec([2, 3], tf.float32), tf.TensorSpec([2, 3], tf.float32)]
onnx_model, _ = tf2onnx.convert.from_function(f, input_signature, opset=13)

a_val = np.ones([2, 3], np.float32)
b_val = np.zeros([2, 3], np.float32)

print("Tensorflow result")
print(f(a_val, b_val).numpy())

print("ORT result")
sess = ort.InferenceSession(onnx_model.SerializeToString())
res = sess.run(None, {'a': a_val, 'b': b_val})
print(res[0])


##################### Keras Model #####################

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(4, activation="relu"))

input_signature = [tf.TensorSpec([3, 3], tf.float32, name='x')]
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=13)

x_val = np.ones((3, 3), np.float32)

print("Keras result")
print(model(x_val).numpy())

print("ORT result")
sess = ort.InferenceSession(onnx_model.SerializeToString())
res = sess.run(None, {'x': x_val})
print(res[0])


##################### Saved Model #####################

model.save("savedmodel")
os.system("python -m tf2onnx.convert --saved-model savedmodel --output model.onnx --opset 13")

print("ORT result")
sess = ort.InferenceSession("model.onnx")
res = sess.run(None, {'dense_input:0': x_val})
print(res[0])

print("Conversion succeeded")
30 changes: 18 additions & 12 deletions tests/backend_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from __future__ import unicode_literals

# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test,import-outside-toplevel
# pylint: disable=wrong-import-position
# pylint: disable=wrong-import-position,invalid-unary-operand-type

import logging
import os
Expand Down Expand Up @@ -106,7 +106,8 @@ def run_backend(self, g, outputs, input_dict, large_model=False, postfix=""):
raise ValueError("unknown backend")
return y

def assert_results_equal(self, expected, actual, rtol, atol, check_value=True, check_shape=True, check_dtype=True):
def assert_results_equal(self, expected, actual, rtol, atol, mtol=None,
check_value=True, check_shape=True, check_dtype=True):
for expected_val, actual_val in zip(expected, actual):
if check_value:
if expected_val.dtype == np.object:
Expand All @@ -115,6 +116,11 @@ def assert_results_equal(self, expected, actual, rtol, atol, check_value=True, c
expected_val_str = decode(expected_val)
self.assertAllEqual(expected_val_str, actual_val)
else:
if mtol is not None:
expected_val = np.minimum(expected_val, mtol)
expected_val = np.maximum(expected_val, -mtol)
actual_val = np.minimum(actual_val, mtol)
actual_val = np.maximum(actual_val, -mtol)
self.assertAllClose(expected_val, actual_val, rtol=rtol, atol=atol)
if check_dtype:
self.assertEqual(expected_val.dtype, actual_val.dtype)
Expand Down Expand Up @@ -295,10 +301,10 @@ def get_dtype(info):
else:
self.assertEqual(onnx_shape, tf2onnx_shape)

def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-07, atol=1e-5,
convert_var_to_const=True, constant_fold=True, check_value=True, check_shape=True,
check_dtype=True, process_args=None, onnx_feed_dict=None, graph_validator=None, as_session=False,
large_model=False, premade_placeholders=False):
def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port,
rtol=1e-07, atol=1e-5, mtol=None, convert_var_to_const=True, constant_fold=True,
check_value=True, check_shape=True, check_dtype=True, process_args=None, onnx_feed_dict=None,
graph_validator=None, as_session=False, large_model=False, premade_placeholders=False):
test_tf = not self.config.skip_tf_tests
test_tflite = not self.config.skip_tflite_tests
run_tfl_consistency_test = test_tf and test_tflite and self.config.run_tfl_consistency_test
Expand Down Expand Up @@ -340,19 +346,19 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
g = optimizer.optimize_graph(g, catch_errors=False)
actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model)

self.assert_results_equal(expected, actual, rtol, atol, check_value, check_shape, check_dtype)
self.assert_results_equal(expected, actual, rtol, atol, mtol, check_value, check_shape, check_dtype)
self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)

if graph_validator:
self.assertTrue(graph_validator(g))

if test_tflite:
tfl_results, tfl_outputs = self.run_tflite(tflite_path, feed_dict)
test_tflite = tfl_results is not None
tfl_res, tfl_outputs = self.run_tflite(tflite_path, feed_dict)
test_tflite = tfl_res is not None

if test_tflite:
if run_tfl_consistency_test:
self.assert_results_equal(expected, tfl_results, rtol, atol, check_value, check_shape, check_dtype)
self.assert_results_equal(expected, tfl_res, rtol, atol, mtol, check_value, check_shape, check_dtype)

tfl_process_args = process_args.copy()
if 'inputs_as_nchw' in tfl_process_args:
Expand All @@ -368,9 +374,9 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
**tfl_process_args)
g = optimizer.optimize_graph(g)
onnx_feed_dict_without_port = {k.split(':')[0]: v for k, v in onnx_feed_dict.items()}
onnx_from_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port, postfix="_from_tflite")
onnx_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port, postfix="_from_tflite")

self.assert_results_equal(tfl_results, onnx_from_tfl_res, rtol, atol, check_value, check_shape, check_dtype)
self.assert_results_equal(tfl_res, onnx_tfl_res, rtol, atol, mtol, check_value, check_shape, check_dtype)
self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)

if graph_validator:
Expand Down
11 changes: 11 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,17 @@ def func(x):
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

@check_tf_min_version("1.14")
@check_opset_min_version(11, "float equality")
def test_div_no_nan(self):
x_val = np.array([1.0, 2.0, -3.0, -4.0, 5.0, 0.0, float("nan"), float("-inf"), float("inf")], dtype=np.float32)
y_val = np.array([1.0, 0.5, 0.0, -4.0, 0.0, 0.0, 0.0, 2.0, 0.0], dtype=np.float32)
def func(x, y):
x_ = tf.math.divide_no_nan(x, y)
return tf.identity(x_, name=_TFOUTPUT)
# TFLite expresses infinity as a value > 1e38
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val}, mtol=1e38)

@check_onnxruntime_incompatibility("Exp")
def test_exp(self):
x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))
Expand Down
10 changes: 9 additions & 1 deletion tests/test_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def run_example(self, name, expected=None):
"..", "examples", name)
if not os.path.exists(full):
raise FileNotFoundError(full)
proc = subprocess.run(('python %s' % full).split(),
proc = subprocess.run(['python', full],
capture_output=True, check=True)
self.assertEqual(0, proc.returncode)
out = proc.stdout.decode('ascii')
Expand Down Expand Up @@ -51,6 +51,14 @@ def test_end2end_tfhub(self):
"Optimizing ONNX model",
"Using opset <onnx, 12>"])

@check_tf_min_version("2.3", "use tf.keras")
@check_opset_min_version(13)
@check_opset_max_version(13)
def test_getting_started(self):
self.run_example(
"getting_started.py",
expected=["Conversion succeeded"])


if __name__ == '__main__':
unittest.main()
12 changes: 12 additions & 0 deletions tf2onnx/onnx_opset/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,18 @@ def version_7(cls, ctx, node, **kwargs):
pass


@tf_op("DivNoNan")
class DivNoNan:
@classmethod
def version_9(cls, ctx, node, **kwargs):
node.type = "Div"
np_dtype = utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[1]))
zero_const = ctx.make_const(utils.make_name("const_zero"), np.array(0, np_dtype)).output[0]
is_zero = ctx.make_node("Equal", [node.input[1], zero_const]).output[0]
where_node = ctx.make_node("Where", [is_zero, zero_const, node.output[0]])
ctx.insert_node_on_output(where_node, node.output[0])


@tf_op("LRN")
class LRN:
@classmethod
Expand Down
4 changes: 2 additions & 2 deletions tf2onnx/tflite_rewriters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

from tf2onnx.tflite_rewriters.tfl_scan_output_rewriter import rewrite_tfl_scan_outputs
from tf2onnx.tflite_rewriters.tfl_qdq_rewriter import rewrite_tfl_qdq
from tf2onnx.tflite_rewriters.tfl_select_zero_mul_rewriter import rewrite_tfl_select_zero_mul
from tf2onnx.tflite_rewriters.tfl_select_zero_rewriter import rewrite_tfl_select_zero

__all__ = [
"rewrite_tfl_scan_outputs",
"rewrite_tfl_qdq",
"rewrite_tfl_select_zero_mul",
"rewrite_tfl_select_zero",
]
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,22 @@


"""
tf2onnx.tflite_rewriters.tfl_select_zero_mul_rewriter - TFLite has a pattern to remove NaN when multiplying by 0
tf2onnx.tflite_rewriters.tfl_select_zero_rewriter - TFLite has a pattern to remove NaN when multiplying/dividing by 0
"""
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher


# pylint: disable=missing-docstring,unused-argument

def rewrite_tfl_select_zero_mul(g, ops):
def rewrite_tfl_select_zero(g, ops):
pattern0 = \
OpTypePattern('TFL_SELECT_V2', name='select', inputs=[
OpTypePattern('TFL_EQUAL', name='equal', inputs=[
OpTypePattern('Const|ConstV2', name='const_eq'),
OpTypePattern('*', name='term_eq'),
], allow_reorder=True),
OpTypePattern('Const|ConstV2', name='const_select'),
OpTypePattern('TFL_MUL', name='mul', inputs=[
OpTypePattern('TFL_MUL|TFL_DIV', name='mul', inputs=[
OpTypePattern('*', name='term_mul1'),
OpTypePattern('*', name='term_mul2'),
]),
Expand Down
2 changes: 1 addition & 1 deletion tf2onnx/tfonnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ def process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_erro
if dequantize:
tfl_rewriters.append(rewrite_tfl_qdq)
tfl_rewriters.append(rewrite_tfl_scan_outputs)
tfl_rewriters.append(rewrite_tfl_select_zero_mul)
tfl_rewriters.append(rewrite_tfl_select_zero)
run_rewriters(g, tfl_rewriters, continue_on_error)
tfl_ops_mapping = handler.tfl_op.create_tfl_to_tf_mapping()
_, _, exceptions = tensorflow_onnx_mapping(g, tfl_ops_mapping, is_tflite=True, dequantize=False)
Expand Down

0 comments on commit f26fe0d

Please sign in to comment.