From 2e7f91a56f412c852474fc7a9169f7c401f59a40 Mon Sep 17 00:00:00 2001 From: Rohan Date: Sat, 12 Jun 2021 00:40:07 +0000 Subject: [PATCH 1/4] Support for broadcasting in batch_matmul when shapes differ --- python/tvm/relay/frontend/tensorflow_ops.py | 16 ++++++++++------ .../python/frontend/tensorflow/test_forward.py | 17 +++++++++++++++++ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py index c7385565857d..a4bb977ea835 100644 --- a/python/tvm/relay/frontend/tensorflow_ops.py +++ b/python/tvm/relay/frontend/tensorflow_ops.py @@ -1135,10 +1135,6 @@ def _impl(inputs, attr, params, mod): is_static = not check_symbolic_shape(orig_shape_x) - if ndim > 3 and not is_static: - shape_of_x = list_shape_of(inputs[0], ndim) - shape_of_y = list_shape_of(inputs[1], ndim) - # reshape n-dimensional batch matmul into 3d if ndim > 3: outer_dims = [orig_shape_x[i] for i in range(0, len(orig_shape_x) - 2)] @@ -1147,7 +1143,8 @@ def _impl(inputs, attr, params, mod): new_shape_x = (num_outer_elts, orig_shape_x[-2], orig_shape_x[-1]) new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1]) else: # handle dynamic shape (dyn.reshape op) - # new shape = [prod(shape[:-2]), -2, -1] + shape_of_x = list_shape_of(inputs[0], ndim) + shape_of_y = list_shape_of(inputs[1], ndim) new_shape_x = [_op.const(1), shape_of_x[-2], shape_of_x[-1]] new_shape_y = [_op.const(1), shape_of_y[-2], shape_of_y[-1]] for i in range(ndim - 2): @@ -1157,11 +1154,18 @@ def _impl(inputs, attr, params, mod): new_shape_y = _op.concatenate(_op.Tuple(new_shape_y), axis=0) input_x = _op.reshape(input_x, newshape=new_shape_x) - input_y = _op.reshape(input_y, newshape=new_shape_y) + + if np.prod(orig_shape_y) < np.prod(new_shape_y): + input_y = _op.broadcast_to(input_y, new_shape_y) + else: + input_y = _op.reshape(input_y, newshape=new_shape_y) adj_x = attr["adj_x"] adj_y = attr["adj_y"] input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x + shape_y = _infer_shape(input_y, mod) + if len(shape_y) < 3: + input_y = _op.reshape(input_y, (1, orig_shape_y[-2], orig_shape_y[-1])) input_y = _op.transpose(input_y, axes=[0, 2, 1]) if not adj_y else input_y ret = get_relay_op("batch_matmul")(input_x, input_y) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 331553388b48..57497d04706a 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1843,6 +1843,9 @@ def test_forward_batch_matmul(): _test_batch_matmul((1, 2, 3, 4, 5, 6), (1, 2, 3, 4, 6, 5), "float32", True, True) _test_batch_matmul((3, 4, 5, 6), (3, 4, 5, 6), "int32", True, False) _test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), "float32", False, True) + _test_batch_matmul((1, 8, 64, 2), (2, 1), "float32", False, False) + _test_batch_matmul((1, 8, 8, 64), (64, 1), "float32", False, False) + _test_batch_matmul((1, 8, 64), (64, 1), "float32", False, False) @tvm.testing.requires_cuda @@ -1870,6 +1873,20 @@ def test_forward_batch_matmul_dynamic(): (2, 3, 4, 6, 5), "float32", ) + _test_batch_matmul_dynamic( + (None, None, None, 5, 6), + (6, None), + (2, 3, 4, 5, 6), + (6, 1), + "float32", + ) + _test_batch_matmul_dynamic( + (None, 5, 6), + (6, None), + (24, 5, 6), + (6, 1), + "float32", + ) ####################################################################### From d1e97fa019d59ae444ce1b6989ed0b34b858cf19 Mon Sep 17 00:00:00 2001 From: Rohan Date: Tue, 15 Jun 2021 22:34:36 +0000 Subject: [PATCH 2/4] refactor --- python/tvm/relay/frontend/tensorflow_ops.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py index a4bb977ea835..13015dc1fecf 100644 --- a/python/tvm/relay/frontend/tensorflow_ops.py +++ b/python/tvm/relay/frontend/tensorflow_ops.py @@ -1141,7 +1141,11 @@ def _impl(inputs, attr, params, mod): if is_static: num_outer_elts = np.prod(outer_dims) new_shape_x = (num_outer_elts, orig_shape_x[-2], orig_shape_x[-1]) - new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1]) + ndim_y = len(orig_shape_y) + if ndim_y > 2: + new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1]) + elif ndim_y == 2: + new_shape_y = (1, orig_shape_y[-2], orig_shape_y[-1]) else: # handle dynamic shape (dyn.reshape op) shape_of_x = list_shape_of(inputs[0], ndim) shape_of_y = list_shape_of(inputs[1], ndim) @@ -1154,11 +1158,7 @@ def _impl(inputs, attr, params, mod): new_shape_y = _op.concatenate(_op.Tuple(new_shape_y), axis=0) input_x = _op.reshape(input_x, newshape=new_shape_x) - - if np.prod(orig_shape_y) < np.prod(new_shape_y): - input_y = _op.broadcast_to(input_y, new_shape_y) - else: - input_y = _op.reshape(input_y, newshape=new_shape_y) + input_y = _op.reshape(input_y, newshape=new_shape_y) adj_x = attr["adj_x"] adj_y = attr["adj_y"] From 08399461ccaef71a4d82296643411e8aea763bb6 Mon Sep 17 00:00:00 2001 From: Rohan Date: Tue, 15 Jun 2021 23:03:14 +0000 Subject: [PATCH 3/4] refactor logic for reshape in conditional --- python/tvm/relay/frontend/tensorflow_ops.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py index 13015dc1fecf..3cbb961548aa 100644 --- a/python/tvm/relay/frontend/tensorflow_ops.py +++ b/python/tvm/relay/frontend/tensorflow_ops.py @@ -1159,13 +1159,11 @@ def _impl(inputs, attr, params, mod): input_x = _op.reshape(input_x, newshape=new_shape_x) input_y = _op.reshape(input_y, newshape=new_shape_y) - + elif len(orig_shape_y) < 3: + input_y = _op.reshape(input_y, (1, orig_shape_y[-2], orig_shape_y[-1])) adj_x = attr["adj_x"] adj_y = attr["adj_y"] input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x - shape_y = _infer_shape(input_y, mod) - if len(shape_y) < 3: - input_y = _op.reshape(input_y, (1, orig_shape_y[-2], orig_shape_y[-1])) input_y = _op.transpose(input_y, axes=[0, 2, 1]) if not adj_y else input_y ret = get_relay_op("batch_matmul")(input_x, input_y) From 244f3bfccde4c781075d153cb72c688cb5d60849 Mon Sep 17 00:00:00 2001 From: Rohan Date: Tue, 15 Jun 2021 23:08:32 +0000 Subject: [PATCH 4/4] refactor --- python/tvm/relay/frontend/tensorflow_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py index 3cbb961548aa..3c4a9b69ea6e 100644 --- a/python/tvm/relay/frontend/tensorflow_ops.py +++ b/python/tvm/relay/frontend/tensorflow_ops.py @@ -1132,6 +1132,7 @@ def _impl(inputs, attr, params, mod): orig_shape_x = _infer_shape(input_x, mod) orig_shape_y = _infer_shape(input_y, mod) ndim = len(orig_shape_x) + ndim_y = len(orig_shape_y) is_static = not check_symbolic_shape(orig_shape_x) @@ -1141,7 +1142,6 @@ def _impl(inputs, attr, params, mod): if is_static: num_outer_elts = np.prod(outer_dims) new_shape_x = (num_outer_elts, orig_shape_x[-2], orig_shape_x[-1]) - ndim_y = len(orig_shape_y) if ndim_y > 2: new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1]) elif ndim_y == 2: @@ -1159,7 +1159,7 @@ def _impl(inputs, attr, params, mod): input_x = _op.reshape(input_x, newshape=new_shape_x) input_y = _op.reshape(input_y, newshape=new_shape_y) - elif len(orig_shape_y) < 3: + elif ndim_y == 2: input_y = _op.reshape(input_y, (1, orig_shape_y[-2], orig_shape_y[-1])) adj_x = attr["adj_x"] adj_y = attr["adj_y"]