Skip to content

Commit

Permalink
[Frontend|MXNet] SwapAxis operator support (#5246)
Browse files Browse the repository at this point in the history
* MXNet swap axis

* MXNet swap axis

* swap axis review comment

* swap axis review comment
  • Loading branch information
maheshambule authored Apr 14, 2020
1 parent d2e58ad commit b7545eb
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
12 changes: 12 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,17 @@ def _mx_unravel_index(inputs, attrs):
return _op.unravel_index(inputs[0], shape_expr)


def _mx_swap_axis(inputs, attrs):
assert len(inputs) == 1
dim1 = attrs.get_int('dim1')
dim2 = attrs.get_int('dim2')
shape = _infer_type(inputs[0]).checked_type.shape
axes = list(range(len(shape)))
axes[dim1] = dim2
axes[dim2] = dim1
return _op.transpose(inputs[0], axes=axes)


def _mx_zeros(inputs, attrs):
assert len(inputs) == 0
shape = attrs.get_int_tuple("shape")
Expand Down Expand Up @@ -1813,6 +1824,7 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale):
"slice_axis" : _mx_slice_axis,
"SliceChannel" : _mx_split,
"split" : _mx_split,
"SwapAxis" : _mx_swap_axis,
"expand_dims" : _mx_expand_dims,
"Concat" : _mx_concat,
"concat" : _mx_concat,
Expand Down
13 changes: 13 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,18 @@ def verify(x, shape, dtype):
# verify([0, 1, 2, 5], [2, 2], dtype)


def test_forward_swap_axis():
def _verify_swap_axis(in_shape, out_shape, dim1, dim2):
data = mx.sym.var('data')
mx_sym = mx.sym.swapaxes(data, dim1, dim2)
verify_mxnet_frontend_impl(mx_sym, in_shape, out_shape)

_verify_swap_axis((4, 5), (5, 4), 0, 1)
_verify_swap_axis((2, 4, 4, 5), (2, 5, 4, 4), 1, 3)
# MXNet errors out when dim1 == dim2
# _verify_swap_axis((4, 5), (5, 4), 0, 0)


if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
Expand Down Expand Up @@ -1040,3 +1052,4 @@ def verify(x, shape, dtype):
test_forward_cond()
test_forward_make_loss()
test_forward_unravel_index()
test_forward_swap_axis()

0 comments on commit b7545eb

Please sign in to comment.