Skip to content

Commit

Permalink
[Frontend][MXNet] Fix default value for is_ascend in topk (apache#7568)
Browse files Browse the repository at this point in the history
* Use correct default value of False for is_ascend

* Add unit test for default topk is_ascend value
  • Loading branch information
Trevor Morris authored and trevor-m committed May 11, 2021
1 parent 918611f commit 0d6f36d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,7 +1234,7 @@ def _mx_topk(inputs, attrs):
new_attrs = {}
new_attrs["k"] = attrs.get_int("k", 1)
new_attrs["axis"] = attrs.get_int("axis", -1)
new_attrs["is_ascend"] = attrs.get_bool("is_ascend", True)
new_attrs["is_ascend"] = attrs.get_bool("is_ascend", False)
ret_type = attrs.get_str("ret_typ", "indices")
if ret_type == "mask":
raise tvm.error.OpAttributeUnimplemented(
Expand Down
25 changes: 17 additions & 8 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,14 +1064,23 @@ def verify(shape, axis, is_ascend, dtype="float32"):

@tvm.testing.uses_gpu
def test_forward_topk():
def verify(shape, k, axis, ret_type, is_ascend=False, dtype="float32"):
def verify(shape, k, axis, ret_type, is_ascend=None, dtype="float32"):
x_np = np.random.uniform(size=shape).astype("float32")
ref_res = mx.nd.topk(
mx.nd.array(x_np), k=k, axis=axis, ret_typ=ret_type, is_ascend=is_ascend, dtype=dtype
)
mx_sym = mx.sym.topk(
mx.sym.var("x"), k=k, axis=axis, ret_typ=ret_type, is_ascend=is_ascend, dtype=dtype
)
if is_ascend is None:
ref_res = mx.nd.topk(mx.nd.array(x_np), k=k, axis=axis, ret_typ=ret_type, dtype=dtype)
mx_sym = mx.sym.topk(mx.sym.var("x"), k=k, axis=axis, ret_typ=ret_type, dtype=dtype)
else:
ref_res = mx.nd.topk(
mx.nd.array(x_np),
k=k,
axis=axis,
ret_typ=ret_type,
is_ascend=is_ascend,
dtype=dtype,
)
mx_sym = mx.sym.topk(
mx.sym.var("x"), k=k, axis=axis, ret_typ=ret_type, is_ascend=is_ascend, dtype=dtype
)
mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
for target, ctx in tvm.testing.enabled_targets():
for kind in ["graph", "debug"]:
Expand All @@ -1086,7 +1095,7 @@ def verify(shape, k, axis, ret_type, is_ascend=False, dtype="float32"):

verify((3, 4), k=1, axis=0, ret_type="both")
verify((3, 4), k=1, axis=-1, ret_type="indices")
verify((3, 5, 6), k=2, axis=2, ret_type="value")
verify((3, 5, 6), k=2, axis=2, ret_type="value", is_ascend=False)
verify((3, 5, 6), k=2, axis=1, ret_type="value", is_ascend=True)
verify((3, 5, 6), k=0, axis=2, ret_type="both", dtype="int32")

Expand Down

0 comments on commit 0d6f36d

Please sign in to comment.