Skip to content

Commit

Permalink
Add test case for _npi_advanced_indexing_multiple
Browse files Browse the repository at this point in the history
- TODO: need to find a proper symbol for comparison
- currently test function is NOT valid
  • Loading branch information
insop committed Jan 19, 2021
1 parent 0476ea8 commit 6991f93
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1935,6 +1935,29 @@ def verify(data_shape, axis, use_length, length):
verify((2, 3, 4), 2, True, np.array([[3, 4, 2], [1, 2, 1]]).astype("int32"))


@pytest.mark.parametrize(
"data_shape, row_sel, col",
[
((5, 7), (0, 1, 2, 3, 4,), 2),
],
)
@pytest.mark.parametrize("dtype", ["float64", "float32"])
@tvm.testing.parametrize_targets
@pytest.mark.parametrize("kind", ["graph", "vm", "debug"])
def test_forward_npi_advanced_indexing_multiple(data_shape, row_sel, col, dtype, target, ctx, kind):
data_np = np.random.uniform(size=data_shape).astype(dtype)
data = mx.sym.var("data")
ref_res = mx.np.array(data_np)[row_sel, col]

# TODO need to add the proper symbol operator
mx_sym = mx.sym.np.(data.as_np_ndarray()[row_sel, col])
mod, _ = relay.frontend.from_mxnet(mx_sym, shape={"data": data_shape}, dtype=dtype)
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(data_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
tvm.testing.assert_allclose(ref_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)


@pytest.mark.skipif(not hasattr(mx.sym.np, "pad"), reason="mx.sym.np.pad hasn't been publish yet")
@pytest.mark.parametrize(
"data_shape, pad_width",
Expand Down

0 comments on commit 6991f93

Please sign in to comment.