Skip to content

Commit

Permalink
updated based on the PR discussion
Browse files Browse the repository at this point in the history
  • Loading branch information
insop committed Jan 19, 2021
1 parent 6991f93 commit 4c763cd
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1938,20 +1938,34 @@ def verify(data_shape, axis, use_length, length):
@pytest.mark.parametrize(
"data_shape, row_sel, col",
[
((5, 7), (0, 1, 2, 3, 4,), 2),
(
(5, 7),
(
0,
1,
2,
3,
4,
),
2,
),
],
)
@pytest.mark.parametrize("dtype", ["float64", "float32"])
@tvm.testing.parametrize_targets
@pytest.mark.parametrize("kind", ["graph", "vm", "debug"])
def test_forward_npi_advanced_indexing_multiple(data_shape, row_sel, col, dtype, target, ctx, kind):
data_np = np.random.uniform(size=data_shape).astype(dtype)
data = mx.sym.var("data")
ref_res = mx.np.array(data_np)[row_sel, col]

# TODO need to add the proper symbol operator
mx_sym = mx.sym.np.(data.as_np_ndarray()[row_sel, col])
mod, _ = relay.frontend.from_mxnet(mx_sym, shape={"data": data_shape}, dtype=dtype)
row_sel_sym = mx.sym.var("row_sel").as_np_ndarray()
data_sym = mx.sym.var("data").as_np_ndarray()
col_sym = mx.sym.var("col").as_np_ndarray()
mx_sym = data_sym[row_sel_sym, col_sym]

mod, _ = relay.frontend.from_mxnet(
mx_sym, shape={"data": data_shape, "row_sel": row_sel, "col": col}, dtype=dtype
)
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(data_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5)
Expand Down

0 comments on commit 4c763cd

Please sign in to comment.