diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 11f4515fbb1e..2add0739b901 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -25,6 +25,7 @@ from tvm.relay.testing import run_infer_type as infer_type from utils.assert_diagnostic import DiagnosticTesting +from utils import ref_funcs def int32(val): @@ -1703,5 +1704,29 @@ def verify_all_class_non_max_suppression( ) +@tvm.testing.uses_gpu +def test_gather_nd(): + def verify_gather_nd(data_shape, indices_shape, data_shape_np, indices_shape_np, batch_dims=0): + x = relay.var("x", relay.TensorType(data_shape, "float32")) + y = relay.var("y", relay.TensorType(indices_shape, "int32")) + z = relay.gather_nd(x, y, batch_dims, indices_shape[0]) + + mod = tvm.IRModule() + mod["main"] = relay.Function([x, y], z) + + data_np = np.random.uniform(size=data_shape_np).astype("float32") + indices_np = np.random.randint(low=0, high=2, size=indices_shape_np, dtype="int32") + + ref_res = ref_funcs.gather_nd(data_np, indices_np, batch_dims) + check_result([data_np, indices_np], mod, [ref_res]) + + verify_gather_nd((2, 2), (2, relay.Any()), (2, 2), (2, 3)) + verify_gather_nd((relay.Any(), 2), (2, relay.Any()), (2, 2), (2, 3)) + verify_gather_nd((relay.Any(), 2), (1, relay.Any()), (10, 2), (1, 10), 1) + verify_gather_nd( + (relay.Any(), 2, 2, 3, 4), (3, relay.Any(), relay.Any()), (3, 2, 2, 3, 4), (3, 3, 2), 2 + ) + + if __name__ == "__main__": pytest.main([__file__])