diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 8a1763266a7d..eee0e95d2505 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -1938,7 +1938,17 @@ def verify(data_shape, axis, use_length, length): @pytest.mark.parametrize( "data_shape, row_sel, col", [ - ((5, 7), (0, 1, 2, 3, 4,), 2), + ( + (5, 7), + ( + 0, + 1, + 2, + 3, + 4, + ), + 2, + ), ], ) @pytest.mark.parametrize("dtype", ["float64", "float32"]) @@ -1946,12 +1956,16 @@ def verify(data_shape, axis, use_length, length): @pytest.mark.parametrize("kind", ["graph", "vm", "debug"]) def test_forward_npi_advanced_indexing_multiple(data_shape, row_sel, col, dtype, target, ctx, kind): data_np = np.random.uniform(size=data_shape).astype(dtype) - data = mx.sym.var("data") ref_res = mx.np.array(data_np)[row_sel, col] - # TODO need to add the proper symbol operator - mx_sym = mx.sym.np.(data.as_np_ndarray()[row_sel, col]) - mod, _ = relay.frontend.from_mxnet(mx_sym, shape={"data": data_shape}, dtype=dtype) + row_sel_sym = mx.sym.var("row_sel").as_np_ndarray() + data_sym = mx.sym.var("data").as_np_ndarray() + col_sym = mx.sym.var("col").as_np_ndarray() + mx_sym = data_sym[row_sel_sym, col_sym] + + mod, _ = relay.frontend.from_mxnet( + mx_sym, shape={"data": data_shape, "row_sel": row_sel, "col": col}, dtype=dtype + ) intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(data_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)