From 6991f937e8384f3e605a4ba70308fcacf29b85eb Mon Sep 17 00:00:00 2001 From: Insop Song Date: Thu, 7 Jan 2021 22:29:15 -0800 Subject: [PATCH] Add test case for _npi_advanced_indexing_multiple - TODO: need to find a proper symbol for comparison - currently test function is NOT valid --- tests/python/frontend/mxnet/test_forward.py | 23 +++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 537349e073e1..8a1763266a7d 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -1935,6 +1935,29 @@ def verify(data_shape, axis, use_length, length): verify((2, 3, 4), 2, True, np.array([[3, 4, 2], [1, 2, 1]]).astype("int32")) +@pytest.mark.parametrize( + "data_shape, row_sel, col", + [ + ((5, 7), (0, 1, 2, 3, 4,), 2), + ], +) +@pytest.mark.parametrize("dtype", ["float64", "float32"]) +@tvm.testing.parametrize_targets +@pytest.mark.parametrize("kind", ["graph", "vm", "debug"]) +def test_forward_npi_advanced_indexing_multiple(data_shape, row_sel, col, dtype, target, ctx, kind): + data_np = np.random.uniform(size=data_shape).astype(dtype) + data = mx.sym.var("data") + ref_res = mx.np.array(data_np)[row_sel, col] + + # TODO need to add the proper symbol operator + mx_sym = mx.sym.np.(data.as_np_ndarray()[row_sel, col]) + mod, _ = relay.frontend.from_mxnet(mx_sym, shape={"data": data_shape}, dtype=dtype) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(data_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5) + tvm.testing.assert_allclose(ref_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5) + + @pytest.mark.skipif(not hasattr(mx.sym.np, "pad"), reason="mx.sym.np.pad hasn't been publish yet") @pytest.mark.parametrize( "data_shape, pad_width",