diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 1085e904c386..b272ead9737d 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -2335,6 +2335,14 @@ def _mx_npi_concatenate(inputs, attrs): return _op.concatenate(tuple(inputs), axis=int(axis)) +def _mx_npi_stack(inputs, attrs): + axis = attrs.get_str("axis", "0") + if axis == "None": + return _op.reshape(_op.stack(tuple(inputs), axis=0), (-1,)) + else: + return _op.stack(tuple(inputs), axis=int(axis)) + + def _mx_npx_reshape(inputs, attrs): shape = attrs.get_int_tuple("newshape") reverse = attrs.get_bool("reverse", False) @@ -2700,6 +2708,7 @@ def _mx_npi_where_rscalar(inputs, attrs): "_npi_less_equal": _mx_compare(_op.less_equal, _rename), "_npi_tanh": _rename(_op.tanh), "_npi_true_divide_scalar": _binop_scalar(_op.divide), + "_npi_stack": _mx_npi_stack, } # set identity list diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index d3be8c0506ba..537349e073e1 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -2012,6 +2012,34 @@ def test_forward_npi_concatenate(data_shape1, data_shape2, axis, dtype, target, tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5) +@pytest.mark.parametrize( + "data_shape1, data_shape2, axis", + [ + ((3,), (3,), 0), + ((3,), (3,), -1), + ((1, 3, 2), (1, 3, 2), 2), + ((1, 3, 3), (1, 3, 3), 1), + ((1, 3), (1, 3), 0), + ], +) +@pytest.mark.parametrize("dtype", ["float64", "float32", "int64", "int32"]) +@tvm.testing.parametrize_targets +@pytest.mark.parametrize("kind", ["graph", "vm", "debug"]) +def test_forward_npi_stack(data_shape1, data_shape2, axis, dtype, target, ctx, kind): + data_np1 = np.random.uniform(size=data_shape1).astype(dtype) + data_np2 = np.random.uniform(size=data_shape2).astype(dtype) + data1 = mx.sym.var("data1") + data2 = mx.sym.var("data2") + ref_res = mx.np.stack([mx.np.array(data_np1), mx.np.array(data_np2)], axis=axis) + mx_sym = mx.sym.np.stack([data1.as_np_ndarray(), data2.as_np_ndarray()], axis=axis) + mod, _ = relay.frontend.from_mxnet( + mx_sym, shape={"data1": data_shape1, "data2": data_shape2}, dtype=dtype + ) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(data_np1, data_np2) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5) + + @pytest.mark.parametrize("data_shape", [(2, 2, 2), (2, 7, 2), (2, 2, 2, 1, 2, 3, 1), (1, 8)]) @pytest.mark.parametrize("dtype", ["float64", "float32", "int64", "int32", "bool"]) @tvm.testing.parametrize_targets