From 71e70f22a4d00f7028cac01c6ec347698c5c3939 Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Tue, 14 Jan 2020 13:11:21 -0800 Subject: [PATCH] adding asnumpy() to output of gather(implicitly called) to fix gather test in large vector and tensor tests (#17290) --- tests/nightly/test_large_array.py | 4 ++-- tests/nightly/test_large_vector.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 2c0841bf1ff8..a0d893d601b3 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -1672,10 +1672,10 @@ def test_gather(): idx = mx.nd.random.randint(0, LARGE_X, SMALL_X) # Calls gather_nd internally tmp = arr[idx] - assert np.sum(tmp[0] == 1) == SMALL_Y + assert np.sum(tmp[0].asnumpy() == 1) == SMALL_Y # Calls gather_nd internally arr[idx] += 1 - assert np.sum(arr[idx[0]] == 2) == SMALL_Y + assert np.sum(arr[idx[0]].asnumpy() == 2) == SMALL_Y def test_binary_broadcast(): diff --git a/tests/nightly/test_large_vector.py b/tests/nightly/test_large_vector.py index 8f01372fcf19..bc87fec33e79 100644 --- a/tests/nightly/test_large_vector.py +++ b/tests/nightly/test_large_vector.py @@ -1049,10 +1049,10 @@ def test_gather(): idx = mx.nd.random.randint(0, LARGE_X, 10, dtype=np.int64) # Calls gather_nd internally tmp = arr[idx] - assert np.sum(tmp == 1) == 10 + assert np.sum(tmp.asnumpy() == 1) == 10 # Calls gather_nd internally arr[idx] += 1 - assert np.sum(arr[idx] == 2) == 10 + assert np.sum(arr[idx].asnumpy() == 2) == 10 def test_infer_shape():