Skip to content

Commit

Permalink
Pylint
Browse files Browse the repository at this point in the history
Signed-off-by: Tom Wildenhain <tomwi@microsoft.com>
  • Loading branch information
TomWildenhain-Microsoft committed May 13, 2021
1 parent c76c0b2 commit 37eb068
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 27 deletions.
13 changes: 8 additions & 5 deletions tests/backend_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,14 @@ def run_onnxruntime(self, model_path, inputs, output_names, use_custom_ops=False
results = m.run(output_names, inputs)
return results

def run_backend(self, g, outputs, input_dict, large_model=False, postfix=""):
def run_backend(self, g, outputs, input_dict, large_model=False, postfix="", use_custom_ops=False):
tensor_storage = ExternalTensorStorage() if large_model else None
model_proto = g.make_model("test", external_tensor_storage=tensor_storage)
model_path = self.save_onnx_model(model_proto, input_dict, external_tensor_storage=tensor_storage,
postfix=postfix)

if self.config.backend == "onnxruntime":
y = self.run_onnxruntime(model_path, input_dict, outputs)
y = self.run_onnxruntime(model_path, input_dict, outputs, use_custom_ops)
elif self.config.backend == "caffe2":
y = self.run_onnxcaffe2(model_proto, input_dict)
else:
Expand Down Expand Up @@ -307,7 +307,8 @@ def get_dtype(info):
def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port,
rtol=1e-07, atol=1e-5, mtol=None, convert_var_to_const=True, constant_fold=True,
check_value=True, check_shape=True, check_dtype=True, process_args=None, onnx_feed_dict=None,
graph_validator=None, as_session=False, large_model=False, premade_placeholders=False):
graph_validator=None, as_session=False, large_model=False, premade_placeholders=False,
use_custom_ops=False):
test_tf = not self.config.skip_tf_tests
test_tflite = not self.config.skip_tflite_tests
run_tfl_consistency_test = test_tf and test_tflite and self.config.run_tfl_consistency_test
Expand Down Expand Up @@ -347,7 +348,8 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
initialized_tables=initialized_tables,
**process_args)
g = optimizer.optimize_graph(g, catch_errors=False)
actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model)
actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model,
use_custom_ops=use_custom_ops)

self.assert_results_equal(expected, actual, rtol, atol, mtol, check_value, check_shape, check_dtype)
self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)
Expand Down Expand Up @@ -377,7 +379,8 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
**tfl_process_args)
g = optimizer.optimize_graph(g)
onnx_feed_dict_without_port = {k.split(':')[0]: v for k, v in onnx_feed_dict.items()}
onnx_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port, postfix="_from_tflite")
onnx_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port,
postfix="_from_tflite", use_custom_ops=use_custom_ops)

self.assert_results_equal(tfl_res, onnx_tfl_res, rtol, atol, mtol, check_value, check_shape, check_dtype)
self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)
Expand Down
25 changes: 14 additions & 11 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@

import numpy as np
import tensorflow as tf
from onnx import helper

from common import check_tf_min_version, unittest_main, requires_custom_ops
from common import check_tf_min_version, unittest_main, requires_custom_ops, check_opset_min_version
from backend_test_base import Tf2OnnxBackendTestBase
import tf2onnx

Expand Down Expand Up @@ -80,21 +81,24 @@ def test_keras_api_large(self):

@requires_custom_ops()
@check_tf_min_version("2.0")
@check_opset_min_version(11, "SparseToDense")
def test_keras_hashtable(self):

featCols = [tf.feature_column.numeric_column("f_inp", dtype=tf.float32),
feature_cols = [
tf.feature_column.numeric_column("f_inp", dtype=tf.float32),
tf.feature_column.indicator_column(
tf.feature_column.categorical_column_with_vocabulary_list("s_inp", ["a", "b", "z"], num_oov_buckets=1)
)]
featureLayer = tf.keras.layers.DenseFeatures(featCols)
)
]
feature_layer = tf.keras.layers.DenseFeatures(feature_cols)

inputDict = {}
inputDict["f_inp"] = tf.keras.Input(name="f_inp", shape=(1,), dtype=tf.float32)
inputDict["s_inp"] = tf.keras.Input(name="s_inp", shape=(1,), dtype=tf.string)
input_dict = {}
input_dict["f_inp"] = tf.keras.Input(name="f_inp", shape=(1,), dtype=tf.float32)
input_dict["s_inp"] = tf.keras.Input(name="s_inp", shape=(1,), dtype=tf.string)

inputs = [input for input in inputDict.values()]
standardFeatures = featureLayer(inputDict)
hidden1 = tf.keras.layers.Dense(512, activation='relu')(standardFeatures)
inputs = list(input_dict.values())
standard_features = feature_layer(input_dict)
hidden1 = tf.keras.layers.Dense(512, activation='relu')(standard_features)
output = tf.keras.layers.Dense(10, activation='softmax')(hidden1)
model = tf.keras.Model(inputs=inputs, outputs=output)
model.compile(optimizer='adam', loss=tf.keras.losses.mean_squared_error)
Expand All @@ -106,7 +110,6 @@ def test_keras_hashtable(self):
tf.TensorSpec((None, 1), tf.string, name="s_inp"))
output_path = os.path.join(self.test_data_directory, "model.onnx")

from onnx import helper
model_proto, _ = tf2onnx.convert.from_keras(
model, input_signature=spec, opset=self.config.opset, output_path=output_path,
extra_opset=[helper.make_opsetid("ai.onnx.contrib", 1)])
Expand Down
13 changes: 2 additions & 11 deletions tests/test_string_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,8 @@ def func(text):
def _run_test_case(self, func, output_names_with_port, feed_dict, **kwargs):
extra_opset = [utils.make_opsetid(constants.CONTRIB_OPS_DOMAIN, 1)]
process_args = {"extra_opset": extra_opset}
return self.run_test_case(func, feed_dict, [], output_names_with_port, process_args=process_args, **kwargs)

def run_onnxruntime(self, model_path, inputs, output_names):
"""Run test against onnxruntime backend."""
from onnxruntime_customops import get_library_path
import onnxruntime as rt
opt = rt.SessionOptions()
opt.register_custom_ops_library(get_library_path())
m = rt.InferenceSession(model_path, opt)
results = m.run(output_names, inputs)
return results
return self.run_test_case(func, feed_dict, [], output_names_with_port,
use_custom_ops=True, process_args=process_args, **kwargs)

@requires_custom_ops("WordpieceTokenizer")
@check_tf_min_version("2.0", "tensorflow_text")
Expand Down

0 comments on commit 37eb068

Please sign in to comment.