From f94942db8ef81f91ac82159f35a9e0565fde8ab9 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Tue, 5 Jan 2021 15:30:20 -0800 Subject: [PATCH] Allow condition in if op to be an array. (#7215) --- python/tvm/relay/frontend/onnx.py | 3 +++ tests/python/frontend/onnx/test_forward.py | 15 ++++++++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 62396d839dc9..4c9996bc855a 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2266,6 +2266,9 @@ class If(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): cond = inputs[0] + # Convert array to bool if needed. + if len(infer_shape(cond)) > 0: + cond = _op.take(cond, _expr.const(0, dtype="int64")) then_branch = attr.get("then_branch", None) else_branch = attr.get("else_branch", None) assert then_branch is not None and else_branch is not None diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 0f7fda7301cd..df35a7e9bb56 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -3969,8 +3969,7 @@ def test_loop(): verify_count_loop() -@tvm.testing.uses_gpu -def test_if(): +def verify_if(cond_array): # Given a bool scalar input cond. # return constant tensor x if cond is True, otherwise return constant tensor y. then_out = onnx.helper.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, [5]) @@ -4007,7 +4006,10 @@ def test_if(): ) if_model = onnx.helper.make_model(if_graph) - cond = np.array(1).astype("bool") + if cond_array: + cond = np.array([1]).astype("bool") + else: + cond = np.array(1).astype("bool") correct_out = x if cond else y for target, ctx in tvm.testing.enabled_targets(): @@ -4016,6 +4018,13 @@ def test_if(): tvm.testing.assert_allclose(correct_out[i], tvm_out[i], rtol=1e-05, atol=1e-05) +@tvm.testing.uses_gpu +def test_if(): + # Confirm that if works with cond as an array or scalar. + verify_if(cond_array=False) + verify_if(cond_array=True) + + @tvm.testing.uses_gpu def test_size(): def verify_size(indata):