Skip to content

Commit

Permalink
Fix nightly tests (#1712)
Browse files Browse the repository at this point in the history
Signed-off-by: Tom Wildenhain <tomwi@microsoft.com>
  • Loading branch information
TomWildenhain-Microsoft authored Sep 13, 2021
1 parent 1c60588 commit 496b65d
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
4 changes: 3 additions & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@
"group_nodes_by_type",
"test_ms_domain",
"check_node_domain",
"check_op_count"
"check_op_count",
"check_gru_count",
"check_lstm_count",
]


Expand Down
5 changes: 3 additions & 2 deletions tests/run_pretrained_models.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ googlenet_v4_slim:
rtol: 0.1

mobilenet_v3_large_float:
tf_min_version: 1.14 # explicit_paddings for Conv2D
url: https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large_224_1.0_float.tgz
model: v3-large_224_1.0_float/v3-large_224_1.0_float.pb
input_get: get_beach
Expand Down Expand Up @@ -428,7 +429,7 @@ faster_rcnn_inception_v2_coco:
- num_detections:0

keras_resnet50:
tf_min_version: 2.1
tf_min_version: 2.2
disabled: false
url: module://tensorflow.keras.applications.resnet50/ResNet50
model: ResNet50
Expand All @@ -440,7 +441,7 @@ keras_resnet50:
- Identity:0

keras_mobilenet_v2:
tf_min_version: 2.1
tf_min_version: 2.2
disabled: false
url: module://tensorflow.keras.applications.mobilenet_v2/MobileNetV2
model: MobileNetV2
Expand Down
6 changes: 4 additions & 2 deletions tests/test_gru.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from backend_test_base import Tf2OnnxBackendTestBase
from common import unittest_main, check_gru_count, check_opset_after_tf_version, check_op_count, check_tf_min_version
from common import * # pylint: disable=wildcard-import,unused-wildcard-import
from tf2onnx.tf_loader import is_tf2

# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test,cell-var-from-loop
Expand Down Expand Up @@ -607,6 +607,7 @@ def func(x):
graph_validator=lambda g: check_gru_count(g, 1))

@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
@skip_tf_versions(["2.1"], "TF fails to correctly add output_2 node.")
def test_dynamic_multi_bigru_with_same_input_hidden_size(self):
batch_size = 10
x_val = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype=np.float32)
Expand Down Expand Up @@ -660,6 +661,7 @@ def func(x):
# graph_validator=lambda g: check_gru_count(g, 2))

@check_opset_after_tf_version("1.15", 10, "might need ReverseV2")
@skip_tf_versions(["2.1"], "TF fails to correctly add output_2 node.")
def test_dynamic_multi_bigru_with_same_input_seq_len(self):
units = 5
batch_size = 10
Expand Down Expand Up @@ -714,7 +716,7 @@ def func(x, y1, y2):
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
# graph_validator=lambda g: check_gru_count(g, 2))

@check_tf_min_version("2.0")
@check_tf_min_version("2.2")
def test_keras_gru(self):
in_shape = [10, 3]
x_val = np.random.uniform(size=[2, 10, 3]).astype(np.float32)
Expand Down

0 comments on commit 496b65d

Please sign in to comment.