From 672e14f8db3d5a9d4ea63835b8558e4e85d01f16 Mon Sep 17 00:00:00 2001 From: maheshambule Date: Thu, 2 Apr 2020 21:09:23 +0530 Subject: [PATCH 1/4] MXNet swap axis --- python/tvm/relay/frontend/mxnet.py | 3 ++- python/tvm/relay/frontend/nnvm_common.py | 11 +++++++++++ tests/python/frontend/mxnet/test_forward.py | 12 ++++++++++++ 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index b918f9b1adc5..2c0e1ec01410 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -37,7 +37,7 @@ from .common import get_name as _get_name from .nnvm_common import _rename, _binop_scalar, _rbinop_scalar, _reduce from .nnvm_common import _arg_reduce, _init_op, _softmax_op, _cast -from .nnvm_common import _clip, _transpose, _upsampling +from .nnvm_common import _clip, _transpose, _upsampling, _swap_axis from .nnvm_common import _elemwise_sum, _reshape from .nnvm_common import _warn_not_used from .mxnet_qnn_op_utils import quantize_mxnet_min_max, \ @@ -1790,6 +1790,7 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): "Cast" : _cast, "clip" : _clip, "transpose" : _transpose, + "SwapAxis" : _swap_axis, "UpSampling" : _upsampling, "add_n" : _elemwise_sum, # MXNet specific implementations diff --git a/python/tvm/relay/frontend/nnvm_common.py b/python/tvm/relay/frontend/nnvm_common.py index 072c7ad3be39..08a9853391ae 100644 --- a/python/tvm/relay/frontend/nnvm_common.py +++ b/python/tvm/relay/frontend/nnvm_common.py @@ -108,6 +108,17 @@ def _transpose(inputs, attrs): return _op.transpose(inputs[0], axes=axes) +def _swap_axis(inputs, attrs): + assert len(inputs) == 1 + dim1 = attrs.get_int('dim1') + dim2 = attrs.get_int('dim2') + shape = _infer_type(inputs[0]).checked_type.shape + axes = list(range(len(shape))) + axes[dim1] = dim2 + axes[dim2] = dim1 + return _op.transpose(inputs[0], axes=axes) + + def _upsampling(inputs, attrs): scale = attrs.get_int("scale") return _op.nn.upsampling(inputs[0], scale_h=scale, scale_w=scale) diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 102905a408db..0d68f240c6a6 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -975,6 +975,18 @@ def verify(x, shape, dtype): # verify([0, 1, 2, 5], [2, 2], dtype) +def test_forward_swap_axis(): + def _verify_swap_axis(in_shape, out_shape, dim1, dim2): + data = mx.sym.var('data') + mx_sym = mx.sym.swapaxes(data, dim1, dim2) + verify_mxnet_frontend_impl(mx_sym, in_shape, out_shape) + + _verify_swap_axis((4, 5), (5, 4), 0, 1) + _verify_swap_axis((2, 4, 4, 5), (2, 5, 4, 4), 1, 3) + # MXNet errors out when dim1 == dim2 + # _verify_swap_axis((4, 5), (5, 4), 0, 0) + + if __name__ == '__main__': test_forward_mlp() test_forward_vgg() From 7dd26271121f7453a0855c1a7cc28b83e01ccd74 Mon Sep 17 00:00:00 2001 From: maheshambule Date: Thu, 2 Apr 2020 21:12:23 +0530 Subject: [PATCH 2/4] MXNet swap axis --- python/tvm/relay/frontend/mxnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 2c0e1ec01410..2b469771e29e 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -1790,7 +1790,7 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): "Cast" : _cast, "clip" : _clip, "transpose" : _transpose, - "SwapAxis" : _swap_axis, + "SwapAxis" : _swap_axis, "UpSampling" : _upsampling, "add_n" : _elemwise_sum, # MXNet specific implementations From c96531b29c6b8451bf42c794593504465e948117 Mon Sep 17 00:00:00 2001 From: maheshambule Date: Mon, 13 Apr 2020 18:45:44 +0530 Subject: [PATCH 3/4] swap axis review comment --- python/tvm/relay/frontend/mxnet.py | 15 +++++++++++++-- python/tvm/relay/frontend/nnvm_common.py | 11 ----------- tests/python/frontend/mxnet/test_forward.py | 3 ++- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 2b469771e29e..cdc8e7e2d478 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -37,7 +37,7 @@ from .common import get_name as _get_name from .nnvm_common import _rename, _binop_scalar, _rbinop_scalar, _reduce from .nnvm_common import _arg_reduce, _init_op, _softmax_op, _cast -from .nnvm_common import _clip, _transpose, _upsampling, _swap_axis +from .nnvm_common import _clip, _transpose, _upsampling from .nnvm_common import _elemwise_sum, _reshape from .nnvm_common import _warn_not_used from .mxnet_qnn_op_utils import quantize_mxnet_min_max, \ @@ -127,6 +127,17 @@ def _mx_unravel_index(inputs, attrs): return _op.unravel_index(inputs[0], shape_expr) +def _mx_swap_axis(inputs, attrs): + assert len(inputs) == 1 + dim1 = attrs.get_int('dim1') + dim2 = attrs.get_int('dim2') + shape = _infer_type(inputs[0]).checked_type.shape + axes = list(range(len(shape))) + axes[dim1] = dim2 + axes[dim2] = dim1 + return _op.transpose(inputs[0], axes=axes) + + def _mx_zeros(inputs, attrs): assert len(inputs) == 0 shape = attrs.get_int_tuple("shape") @@ -1790,7 +1801,7 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): "Cast" : _cast, "clip" : _clip, "transpose" : _transpose, - "SwapAxis" : _swap_axis, + "SwapAxis" : _mx_swap_axis, "UpSampling" : _upsampling, "add_n" : _elemwise_sum, # MXNet specific implementations diff --git a/python/tvm/relay/frontend/nnvm_common.py b/python/tvm/relay/frontend/nnvm_common.py index 08a9853391ae..072c7ad3be39 100644 --- a/python/tvm/relay/frontend/nnvm_common.py +++ b/python/tvm/relay/frontend/nnvm_common.py @@ -108,17 +108,6 @@ def _transpose(inputs, attrs): return _op.transpose(inputs[0], axes=axes) -def _swap_axis(inputs, attrs): - assert len(inputs) == 1 - dim1 = attrs.get_int('dim1') - dim2 = attrs.get_int('dim2') - shape = _infer_type(inputs[0]).checked_type.shape - axes = list(range(len(shape))) - axes[dim1] = dim2 - axes[dim2] = dim1 - return _op.transpose(inputs[0], axes=axes) - - def _upsampling(inputs, attrs): scale = attrs.get_int("scale") return _op.nn.upsampling(inputs[0], scale_h=scale, scale_w=scale) diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 0d68f240c6a6..dda6be60bafd 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -1042,4 +1042,5 @@ def _verify_swap_axis(in_shape, out_shape, dim1, dim2): test_forward_deconvolution() test_forward_cond() test_forward_make_loss() - test_forward_unravel_index() \ No newline at end of file + test_forward_unravel_index() + test_forward_swap_axis() \ No newline at end of file From 77f7de2db97cc5537d30546b0f59f0b86386e4d3 Mon Sep 17 00:00:00 2001 From: maheshambule Date: Mon, 13 Apr 2020 18:47:39 +0530 Subject: [PATCH 4/4] swap axis review comment --- python/tvm/relay/frontend/mxnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index cdc8e7e2d478..e5c505933d94 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -1801,7 +1801,6 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): "Cast" : _cast, "clip" : _clip, "transpose" : _transpose, - "SwapAxis" : _mx_swap_axis, "UpSampling" : _upsampling, "add_n" : _elemwise_sum, # MXNet specific implementations @@ -1825,6 +1824,7 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): "slice_axis" : _mx_slice_axis, "SliceChannel" : _mx_split, "split" : _mx_split, + "SwapAxis" : _mx_swap_axis, "expand_dims" : _mx_expand_dims, "Concat" : _mx_concat, "concat" : _mx_concat,